use crate::graph::{Graph, NodeId};
/// Hyper-parameters describing a Whisper encoder architecture.
pub struct WhisperConfig {
    /// Transformer hidden size (embedding dimension).
    pub d_model: usize,
    /// Number of self-attention heads per layer.
    pub n_heads: u32,
    /// Number of encoder transformer layers.
    pub n_layers: usize,
    /// Inner dimension of the feed-forward (fc1/fc2) block.
    pub ffn_dim: usize,
    /// Number of mel-spectrogram input channels.
    pub n_mels: usize,
    /// Maximum number of encoder positions (rows of the positional embedding).
    pub max_source_positions: usize,
    /// Epsilon used by every layer norm.
    pub layer_norm_eps: f32,
}

impl WhisperConfig {
    /// Configuration matching the `whisper-tiny` checkpoint.
    pub fn whisper_tiny() -> Self {
        Self {
            d_model: 384,
            n_heads: 6,
            n_layers: 4,
            ffn_dim: 1536,
            n_mels: 80,
            max_source_positions: 1500,
            layer_norm_eps: 1e-5,
        }
    }

    /// Per-head dimension, i.e. `d_model / n_heads`.
    ///
    /// Debug-asserts that `d_model` is evenly divisible by `n_heads`;
    /// without the check, integer division would silently truncate and
    /// produce an attention graph with the wrong head width.
    pub fn head_dim(&self) -> u32 {
        debug_assert_eq!(
            self.d_model as u32 % self.n_heads,
            0,
            "d_model ({}) must be divisible by n_heads ({})",
            self.d_model,
            self.n_heads
        );
        self.d_model as u32 / self.n_heads
    }
}
/// Builds the Whisper encoder forward graph into `g` and returns the node id
/// of the final layer-normed hidden states.
///
/// Parameter names follow the HF checkpoint layout under `model.encoder.*`.
/// NOTE(review): the conv biases are loaded as `fused_bias` tensors sized to
/// the full activation (`batch * d * len`), i.e. presumably pre-broadcast by
/// the weight-conversion step — confirm against the loader.
pub fn build_encoder(g: &mut Graph, config: &WhisperConfig, batch: u32, mel_len: u32) -> NodeId {
let d = config.d_model;
let prefix = "model.encoder";
// Flattened mel-spectrogram input: batch * n_mels * mel_len elements.
let mel = g.input("mel", &[(batch * config.n_mels as u32 * mel_len) as usize]);
// Conv1 weight, flattened to d * n_mels * 3 elements (kernel size 3 along time).
let conv1_w = g.parameter(&format!("{prefix}.conv1.weight"), &[d * config.n_mels * 3]);
// Conv1 bias fused to the full output activation shape (see NOTE above).
let conv1_b = g.parameter(
&format!("{prefix}.conv1.fused_bias"),
&[(batch as usize * d * mel_len as usize)],
);
// Conv1: 1-D conv over time expressed as HxW with W=1.
// Args appear to be (input, weight, batch, in_ch, h, w, out_ch, kh, kw,
// stride, pad_h, pad_w): kernel 3, stride 1, pad 1 — output length mel_len.
// TODO(review): confirm the positional argument order against conv2d_hw.
let x = g.conv2d_hw(
mel,
conv1_w,
batch,
config.n_mels as u32,
mel_len,
1,
d as u32,
3,
1,
1,
1,
0,
);
let x = g.add(x, conv1_b);
let x = g.gelu(x);
// Output length of the stride-2 conv below: (len + 2*pad - kernel)/stride + 1
// with kernel 3, stride 2, pad 1.
let seq_len = (mel_len + 2 - 3) / 2 + 1;
let conv2_w = g.parameter(&format!("{prefix}.conv2.weight"), &[d * d * 3]);
let conv2_b = g.parameter(
&format!("{prefix}.conv2.fused_bias"),
&[(batch as usize * d * seq_len as usize)],
);
// Conv2: kernel 3, stride 2, pad 1 — halves the time dimension to seq_len.
let x = g.conv2d_hw(
x, conv2_w, batch, d as u32, mel_len, 1, d as u32, 3, 1, 2, 1, 0,
);
let x = g.add(x, conv2_b);
let x = g.gelu(x);
// The reshape/transpose below drop the batch dimension entirely, so only
// batch=1 is representable here.
assert_eq!(batch, 1, "Whisper encoder currently supports batch=1 only");
// [d, seq_len] -> [seq_len, d] so each row is one time step's hidden state.
let x = g.reshape(x, &[d, seq_len as usize]);
let x = g.transpose(x);
// Positional embedding declared with exactly seq_len rows.
// NOTE(review): assumes seq_len <= max_source_positions and that the loader
// slices the checkpoint tensor accordingly — verify.
let pos_embed = g.parameter(
&format!("{prefix}.embed_positions.weight"),
&[seq_len as usize, d],
);
let mut x = g.add(x, pos_embed);
// Pre-LN transformer layers: LN -> self-attention -> residual,
// then LN -> FFN -> residual.
for i in 0..config.n_layers {
let lname = format!("{prefix}.layers.{i}");
let ln1_w = g.parameter(&format!("{lname}.self_attn_layer_norm.weight"), &[d]);
let ln1_b = g.parameter(&format!("{lname}.self_attn_layer_norm.bias"), &[d]);
let h = g.layer_norm(x, ln1_w, ln1_b, config.layer_norm_eps);
let wq = g.parameter(&format!("{lname}.self_attn.q_proj.weight"), &[d, d]);
let wk = g.parameter(&format!("{lname}.self_attn.k_proj.weight"), &[d, d]);
let wv = g.parameter(&format!("{lname}.self_attn.v_proj.weight"), &[d, d]);
// q and v carry biases; k_proj deliberately has none, matching the
// Whisper checkpoint (k_proj.bias does not exist there).
let q_b = g.parameter(&format!("{lname}.self_attn.q_proj.bias"), &[d]);
let v_b = g.parameter(&format!("{lname}.self_attn.v_proj.bias"), &[d]);
let q = g.matmul(h, wq);
let q = g.bias_add(q, q_b);
let k = g.matmul(h, wk);
let v = g.matmul(h, wv);
let v = g.bias_add(v, v_b);
// n_heads is passed twice — presumably (query heads, key/value heads),
// identical for plain multi-head attention. TODO(review): confirm
// against full_attention's signature.
let attn = g.full_attention(
q,
k,
v,
config.n_heads,
config.n_heads, config.head_dim(),
);
let wo = g.parameter(&format!("{lname}.self_attn.out_proj.weight"), &[d, d]);
let wo_b = g.parameter(&format!("{lname}.self_attn.out_proj.bias"), &[d]);
let attn_out = g.matmul(attn, wo);
let attn_out = g.bias_add(attn_out, wo_b);
// Residual connection around attention.
x = g.add(x, attn_out);
let ln2_w = g.parameter(&format!("{lname}.final_layer_norm.weight"), &[d]);
let ln2_b = g.parameter(&format!("{lname}.final_layer_norm.bias"), &[d]);
let h = g.layer_norm(x, ln2_w, ln2_b, config.layer_norm_eps);
// Feed-forward: d -> ffn_dim -> GELU -> d.
let ff1_w = g.parameter(&format!("{lname}.fc1.weight"), &[d, config.ffn_dim]);
let ff1_b = g.parameter(&format!("{lname}.fc1.bias"), &[config.ffn_dim]);
let h = g.matmul(h, ff1_w);
let h = g.bias_add(h, ff1_b);
let h = g.gelu(h);
let ff2_w = g.parameter(&format!("{lname}.fc2.weight"), &[config.ffn_dim, d]);
let ff2_b = g.parameter(&format!("{lname}.fc2.bias"), &[d]);
let h = g.matmul(h, ff2_w);
let h = g.bias_add(h, ff2_b);
// Residual connection around the FFN.
x = g.add(x, h);
}
// Final encoder-wide layer norm (pre-LN architectures norm once at the end).
let final_ln_w = g.parameter(&format!("{prefix}.layer_norm.weight"), &[d]);
let final_ln_b = g.parameter(&format!("{prefix}.layer_norm.bias"), &[d]);
g.layer_norm(x, final_ln_w, final_ln_b, config.layer_norm_eps)
}
/// Builds a small training graph: encoder, a linear projection to a fixed
/// number of classes, and a cross-entropy loss against a `labels` input.
///
/// The 64-class projection (`train_proj.weight`) is a training scaffold for
/// exercising the encoder, not part of the real Whisper model.
pub fn build_training_graph(config: &WhisperConfig, batch: u32, mel_len: u32) -> Graph {
    let mut g = Graph::new();
    let encoder_out = build_encoder(&mut g, config, batch, mel_len);
    // Encoder output length after the stride-2 conv (kernel 3, pad 1). This
    // must match the formula used inside `build_encoder`; the previous
    // `mel_len / 2` disagreed with it for odd `mel_len`, mis-shaping `labels`
    // relative to `logits`.
    let seq_len = ((mel_len + 2 - 3) / 2 + 1) as usize;
    let num_classes = 64;
    let proj_w = g.parameter("train_proj.weight", &[config.d_model, num_classes]);
    let logits = g.matmul(encoder_out, proj_w);
    let labels = g.input("labels", &[seq_len, num_classes]);
    let loss = g.cross_entropy_loss(logits, labels);
    g.set_outputs(vec![loss]);
    g
}
/// Returns the checkpoint names of encoder weight matrices that are stored
/// transposed relative to the layout this graph expects: the four attention
/// projections plus fc1/fc2, for every encoder layer.
pub fn transposed_weight_names(config: &WhisperConfig) -> Vec<String> {
    let prefix = "model.encoder";
    (0..config.n_layers)
        .flat_map(|layer| {
            let base = format!("{prefix}.layers.{layer}");
            // FFN names are built before `base` moves into the closure below.
            let fc1 = format!("{base}.fc1.weight");
            let fc2 = format!("{base}.fc2.weight");
            ["q_proj", "k_proj", "v_proj", "out_proj"]
                .into_iter()
                .map(move |proj| format!("{base}.self_attn.{proj}.weight"))
                .chain([fc1, fc2])
        })
        .collect()
}