use crate::error::{Error, Result};
use crate::model::config::AudioConfig;
use crate::nn::VarBuilder;
use crate::nn::conv1d::Conv1d;
use crate::nn::layernorm::LayerNorm;
use crate::nn::linear::Linear;
use numr::autograd::Var;
use numr::ops::{
ActivationOps, BinaryOps, ConvOps, MatmulOps, NormalizationOps, ReduceOps, ScalarOps, ShapeOps,
TensorOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
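/// One Whisper encoder block: pre-norm multi-head self-attention followed by
/// a pre-norm two-layer GELU feed-forward network, each wrapped in a residual
/// connection.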
pub struct WhisperEncoderLayer<R: Runtime> {
ln1: LayerNorm<R>,
q_proj: Linear<R>,
k_proj: Linear<R>,
v_proj: Linear<R>,
out_proj: Linear<R>,
ln2: LayerNorm<R>,
fc1: Linear<R>,
fc2: Linear<R>,
num_heads: usize,
head_dim: usize,
}
impl<R: Runtime> WhisperEncoderLayer<R> {
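    /// Builds the layer from `vb`, which should be scoped to a single
    /// `layers.{i}` prefix. Expected sub-keys follow the Hugging Face Whisper
    /// layout: `self_attn.{q,k,v,out}_proj`, `self_attn_layer_norm`, `fc1`,
    /// `fc2`, and `final_layer_norm`.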
pub fn from_varbuilder(
vb: &mut VarBuilder<'_, R>,
hidden_size: usize,
num_heads: usize,
) -> Result<Self> {
        debug_assert_eq!(
            hidden_size % num_heads,
            0,
            "hidden_size must be divisible by num_heads"
        );
        let head_dim = hidden_size / num_heads;
let mut self_attn = vb.pp("self_attn");
let q_w = self_attn.take_tensor("q_proj.weight")?;
let q_b = self_attn.take_tensor_optional("q_proj.bias")?;
let q_proj = Linear::new(q_w, q_b, false);
let k_w = self_attn.take_tensor("k_proj.weight")?;
let k_b = self_attn.take_tensor_optional("k_proj.bias")?;
let k_proj = Linear::new(k_w, k_b, false);
let v_w = self_attn.take_tensor("v_proj.weight")?;
let v_b = self_attn.take_tensor_optional("v_proj.bias")?;
let v_proj = Linear::new(v_w, v_b, false);
let out_w = self_attn.take_tensor("out_proj.weight")?;
let out_b = self_attn.take_tensor_optional("out_proj.bias")?;
let out_proj = Linear::new(out_w, out_b, false);
drop(self_attn);
let mut ln1_vb = vb.pp("self_attn_layer_norm");
let ln1_w = ln1_vb.take_tensor("weight")?;
let ln1_b = ln1_vb.take_tensor("bias")?;
let ln1 = LayerNorm::new(ln1_w, ln1_b, 1e-5, false);
drop(ln1_vb);
let mut fc_vb = vb.pp("fc1");
let fc1_w = fc_vb.take_tensor("weight")?;
let fc1_b = fc_vb.take_tensor_optional("bias")?;
let fc1 = Linear::new(fc1_w, fc1_b, false);
drop(fc_vb);
let mut fc2_vb = vb.pp("fc2");
let fc2_w = fc2_vb.take_tensor("weight")?;
let fc2_b = fc2_vb.take_tensor_optional("bias")?;
let fc2 = Linear::new(fc2_w, fc2_b, false);
drop(fc2_vb);
let mut ln2_vb = vb.pp("final_layer_norm");
let ln2_w = ln2_vb.take_tensor("weight")?;
let ln2_b = ln2_vb.take_tensor("bias")?;
let ln2 = LayerNorm::new(ln2_w, ln2_b, 1e-5, false);
Ok(Self {
ln1,
q_proj,
k_proj,
v_proj,
out_proj,
ln2,
fc1,
fc2,
num_heads,
head_dim,
})
}
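    /// Runs the layer without building an autograd graph.
    ///
    /// `x` has shape `[batch, seq_len, hidden]`; the output has the same
    /// shape. A minimal usage sketch (assuming a `client` implementing the
    /// listed op traits and an input tensor `x`):
    ///
    /// ```ignore
    /// let y = layer.forward_inference(&client, &x)?;
    /// ```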
pub fn forward_inference<C>(&self, client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = x.shape();
let batch = shape[0];
let seq_len = shape[1];
let hidden = shape[2];
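        // Inference-only pass: wrap tensors in non-tracking Vars so the
        // LayerNorm/Linear modules can be reused without recording a graph.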
let x_var = Var::new(x.clone(), false);
let normed = self.ln1.forward(client, &x_var)?;
        let normed_var = Var::new(normed.tensor().clone(), false);
let q = self.q_proj.forward(client, &normed_var)?;
let k = self.k_proj.forward(client, &normed_var)?;
let v = self.v_proj.forward(client, &normed_var)?;
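        // Split the hidden dimension into heads:
        // [batch, seq, hidden] -> [batch, heads, seq, head_dim].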
let q = q
.tensor()
.reshape(&[batch, seq_len, self.num_heads, self.head_dim])
.map_err(Error::Numr)?
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous();
let k = k
.tensor()
.reshape(&[batch, seq_len, self.num_heads, self.head_dim])
.map_err(Error::Numr)?
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous();
let v = v
.tensor()
.reshape(&[batch, seq_len, self.num_heads, self.head_dim])
.map_err(Error::Numr)?
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous();
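        // Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V.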
let k_t = k.transpose(2, 3).map_err(Error::Numr)?.contiguous();
        let scale = (self.head_dim as f64).sqrt();
        let scores = client.matmul(&q, &k_t).map_err(Error::Numr)?;
        let scores = client.div_scalar(&scores, scale).map_err(Error::Numr)?;
let attn_weights = client.softmax(&scores, -1).map_err(Error::Numr)?;
let attn_out = client.matmul(&attn_weights, &v).map_err(Error::Numr)?;
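        // Merge the heads back:
        // [batch, heads, seq, head_dim] -> [batch, seq, hidden].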
let attn_out = attn_out
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous()
.reshape(&[batch, seq_len, hidden])
.map_err(Error::Numr)?;
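        // Project the attention output and add the first residual connection.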
let attn_out_var = Var::new(attn_out, false);
let projected = self.out_proj.forward(client, &attn_out_var)?;
let after_attn = client.add(x, projected.tensor()).map_err(Error::Numr)?;
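        // Pre-norm feed-forward block: LayerNorm -> fc1 -> GELU -> fc2,
        // followed by the second residual connection.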
let after_attn_var = Var::new(after_attn.clone(), false);
let normed2 = self.ln2.forward(client, &after_attn_var)?;
let normed2_var = Var::new(normed2.tensor().clone(), false);
let fc1_out = self.fc1.forward(client, &normed2_var)?;
let activated = client.gelu(fc1_out.tensor()).map_err(Error::Numr)?;
let activated_var = Var::new(activated, false);
let fc2_out = self.fc2.forward(client, &activated_var)?;
client
.add(&after_attn, fc2_out.tensor())
.map_err(Error::Numr)
}
}
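/// The Whisper audio encoder: two 1-D convolutions over the mel spectrogram,
/// additive position embeddings loaded from the checkpoint, a stack of
/// encoder layers, and a final LayerNorm.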
pub struct WhisperEncoder<R: Runtime> {
conv1: Conv1d<R>,
conv2: Conv1d<R>,
position_embedding: Var<R>,
layers: Vec<WhisperEncoderLayer<R>>,
ln_post: LayerNorm<R>,
}
impl<R: Runtime> WhisperEncoder<R> {
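    /// Builds the encoder from `vb`, expecting the Hugging Face Whisper key
    /// layout: `conv1`, `conv2`, `embed_positions.weight`, `layers.{i}`, and
    /// `layer_norm`.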
pub fn from_varbuilder(vb: &mut VarBuilder<'_, R>, config: &AudioConfig) -> Result<Self> {
let hidden = config.hidden_size;
let mut conv1_vb = vb.pp("conv1");
let conv1_w = conv1_vb.take_tensor("weight")?;
let conv1_b = conv1_vb.take_tensor_optional("bias")?;
drop(conv1_vb);
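        // conv1 keeps the frame rate (stride 1, padding 1); conv2 below
        // halves it (stride 2, padding 1), matching the Whisper architecture.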
let conv1 = Conv1d::new(
conv1_w,
conv1_b,
            1,
            numr::ops::PaddingMode::conv1d(1, 1),
            1,
            1,
            false,
        );
let mut conv2_vb = vb.pp("conv2");
let conv2_w = conv2_vb.take_tensor("weight")?;
let conv2_b = conv2_vb.take_tensor_optional("bias")?;
drop(conv2_vb);
let conv2 = Conv1d::new(
conv2_w,
conv2_b,
            2,
            numr::ops::PaddingMode::conv1d(1, 1),
            1,
            1,
            false,
        );
let pos_emb = vb.take_tensor("embed_positions.weight")?;
let position_embedding = Var::new(pos_emb, false);
let mut layers = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
let mut layer_vb = vb.pp(&format!("layers.{i}"));
layers.push(WhisperEncoderLayer::from_varbuilder(
&mut layer_vb,
hidden,
config.num_heads,
)?);
}
let mut ln_vb = vb.pp("layer_norm");
let ln_w = ln_vb.take_tensor("weight")?;
let ln_b = ln_vb.take_tensor("bias")?;
let ln_post = LayerNorm::new(ln_w, ln_b, 1e-5, false);
Ok(Self {
conv1,
conv2,
position_embedding,
layers,
ln_post,
})
}
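    /// Encodes a log-mel spectrogram into hidden states for the decoder's
    /// cross-attention.
    ///
    /// `mel` is expected as `[batch, n_mels, frames]`; conv2's stride of 2
    /// halves the frame axis, so the output is `[batch, frames / 2, hidden]`.
    /// A minimal usage sketch (assuming a `client` implementing the listed
    /// op traits and a prepared `mel` tensor):
    ///
    /// ```ignore
    /// let hidden_states = encoder.forward_inference(&client, &mel)?;
    /// ```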
pub fn forward_inference<C>(&self, client: &C, mel: &Tensor<R>) -> Result<Tensor<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ConvOps<R>
+ ReduceOps<R>
+ ShapeOps<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + ConvOps<R> + ReduceOps<R> + BinaryOps<R>,
{
let x = self.conv1.forward_inference(client, mel)?;
let x = client.gelu(&x).map_err(Error::Numr)?;
let x = self.conv2.forward_inference(client, &x)?;
let x = client.gelu(&x).map_err(Error::Numr)?;
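        // Conv output is [batch, hidden, frames]; move the frame axis to the
        // middle for the attention layers: [batch, seq, hidden].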
let x = x.transpose(1, 2).map_err(Error::Numr)?.contiguous();
let seq_len = x.shape()[1];
let pos = self.position_embedding.tensor();
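        // Slice the position table to the actual sequence length. If the
        // input is longer than the table, the full table is used and the add
        // below surfaces the mismatch as a shape error.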
        let pos_seq_len = pos.shape()[0];
        let pos_slice = if seq_len <= pos_seq_len {
pos.narrow(0, 0, seq_len).map_err(Error::Numr)?
} else {
pos.clone()
};
let x = client.add(&x, &pos_slice).map_err(Error::Numr)?;
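        // Run the encoder stack, then apply the final LayerNorm.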
let mut hidden = x;
for layer in &self.layers {
hidden = layer.forward_inference(client, &hidden)?;
}
let hidden_var = Var::new(hidden, false);
let output = self.ln_post.forward(client, &hidden_var)?;
Ok(output.tensor().clone())
}
}