use super::cache::{DecoderCache, DecoderLayerCache};
use super::helpers::load_layernorm;
use super::layer::WhisperDecoderLayer;
use crate::error::{Error, Result};
use crate::model::config::AudioConfig;
use crate::nn::{Embedding, LayerNorm, VarBuilder};
use numr::autograd::Var;
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, ConditionalOps, MatmulOps, NormalizationOps, ReduceOps, ScalarOps,
ShapeOps, TensorOps, UnaryOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub struct WhisperDecoder<R: Runtime> {
embed_tokens: Embedding<R>,
embed_positions: Tensor<R>,
layers: Vec<WhisperDecoderLayer<R>>,
layer_norm: LayerNorm<R>,
tied_out_weight: Tensor<R>,
vocab_size: usize,
hidden_size: usize,
}
impl<R: Runtime<DType = DType>> WhisperDecoder<R> {
pub fn from_varbuilder(vb: &mut VarBuilder<'_, R>, config: &AudioConfig) -> Result<Self> {
let hidden = config.hidden_size;
let num_heads = config.num_heads;
let num_layers = config.decoder_layer_count();
let mut tok_vb = vb.pp("embed_tokens");
let tok_weight = tok_vb.take_tensor("weight")?;
drop(tok_vb);
let tied_out_weight = tok_weight.clone();
let embed_tokens = Embedding::new(tok_weight, false);
let mut pos_vb = vb.pp("embed_positions");
let embed_positions = pos_vb.take_tensor("weight")?;
drop(pos_vb);
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let mut layer_vb = vb.pp(&format!("layers.{i}"));
layers.push(WhisperDecoderLayer::from_varbuilder(
&mut layer_vb,
hidden,
num_heads,
)?);
}
let layer_norm = load_layernorm(vb, "layer_norm")?;
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
tied_out_weight,
vocab_size: config.vocab_size,
hidden_size: hidden,
})
}
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
pub fn new_cache(&self) -> DecoderCache<R> {
DecoderCache {
layers: (0..self.layers.len())
.map(|_| DecoderLayerCache::default())
.collect(),
}
}
pub fn forward_with_cache<C>(
&self,
client: &C,
input_ids: &Tensor<R>,
encoder_out: &Tensor<R>,
position_offset: usize,
cache: &mut DecoderCache<R>,
) -> Result<Tensor<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>
+ UnaryOps<R>
+ ConditionalOps<R>
+ numr::ops::IndexingOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
assert_eq!(
cache.layers.len(),
self.layers.len(),
"decoder cache layer count mismatch"
);
let shape = input_ids.shape();
let batch = shape[0];
let seq_len = shape[1];
let tok = self.embed_tokens.forward(client, input_ids)?;
let pos = self
.embed_positions
.narrow(0, position_offset, seq_len)
.map_err(Error::Numr)?
.reshape(&[1, seq_len, self.hidden_size])
.map_err(Error::Numr)?
.broadcast_to(&[batch, seq_len, self.hidden_size])
.map_err(Error::Numr)?
.contiguous()?;
let mut hidden = client.add(tok.tensor(), &pos).map_err(Error::Numr)?;
for (layer, layer_cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
hidden = layer.forward_with_cache(client, &hidden, encoder_out, layer_cache)?;
}
let normed = self.layer_norm.forward(client, &Var::new(hidden, false))?;
let hidden = normed.tensor().clone();
let w_t = self
.tied_out_weight
.transpose(0, 1)
.map_err(Error::Numr)?
.contiguous()?;
let logits = client.matmul(&hidden, &w_t).map_err(Error::Numr)?;
Ok(logits)
}
pub fn forward_inference<C>(
&self,
client: &C,
input_ids: &Tensor<R>,
encoder_out: &Tensor<R>,
position_offset: usize,
) -> Result<Tensor<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>
+ UnaryOps<R>
+ ConditionalOps<R>
+ numr::ops::IndexingOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = input_ids.shape();
let batch = shape[0];
let seq_len = shape[1];
let tok = self.embed_tokens.forward(client, input_ids)?;
let pos = self
.embed_positions
.narrow(0, position_offset, seq_len)
.map_err(Error::Numr)?;
let pos = pos
.reshape(&[1, seq_len, self.hidden_size])
.map_err(Error::Numr)?;
let pos_b = pos
.broadcast_to(&[batch, seq_len, self.hidden_size])
.map_err(Error::Numr)?
.contiguous()?;
let mut hidden = client.add(tok.tensor(), &pos_b).map_err(Error::Numr)?;
for layer in &self.layers {
hidden = layer.forward_inference(client, &hidden, encoder_out, true)?;
}
let normed = self.layer_norm.forward(client, &Var::new(hidden, false))?;
let hidden = normed.tensor().clone();
let w_t = self
.tied_out_weight
.transpose(0, 1)
.map_err(Error::Numr)?
.contiguous()?;
let logits = client.matmul(&hidden, &w_t).map_err(Error::Numr)?;
Ok(logits)
}
}