#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use candle_core::{DType, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use crate::{
attention::SdpaParams,
layers::{CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
pipeline::{KvCache, NormalCache},
};
use super::config::WhisperEncoderArgs;
/// Multi-head self-attention for one audio-encoder layer.
///
/// All four projections are quantization-aware (`Arc<dyn QuantMethod>`).
/// Rotary position embeddings are applied to queries and keys inside
/// `forward`, and key/value states are appended to a per-layer `KvCache`.
pub(super) struct EncoderAttention {
    pub(super) wq: Arc<dyn QuantMethod>,
    pub(super) wk: Arc<dyn QuantMethod>,
    pub(super) wv: Arc<dyn QuantMethod>,
    pub(super) wo: Arc<dyn QuantMethod>,
    // Number of query heads.
    num_heads: usize,
    // Number of key/value heads (GQA when smaller than `num_heads`).
    num_kv_heads: usize,
    head_dim: usize,
    // Rotary table shared across all layers on the same device.
    rotary_emb: Arc<RotaryEmbedding>,
    // Precomputed SDPA settings: GQA group count, softmax scale, sliding window.
    sdpa_params: SdpaParams,
}
impl EncoderAttention {
    /// Builds the q/k/v/output projections and SDPA settings from `cfg`.
    ///
    /// Follows the Whisper convention: the key projection is always
    /// bias-free, while the other three projections honor `cfg.use_biases`.
    fn new(
        cfg: &WhisperEncoderArgs,
        rotary_emb: Arc<RotaryEmbedding>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let (dim, head_dim) = (cfg.dim, cfg.head_dim);
        let (num_heads, num_kv_heads) = (cfg.n_heads, cfg.n_kv_heads);
        let bias = cfg.use_biases;
        let q_out = num_heads * head_dim;
        let kv_out = num_kv_heads * head_dim;

        let wq = mistralrs_quant::linear_b(dim, q_out, bias, &None, vb.pp("wq"))?;
        // Key projection never carries a bias (Whisper convention).
        let wk = mistralrs_quant::linear_b(dim, kv_out, false, &None, vb.pp("wk"))?;
        let wv = mistralrs_quant::linear_b(dim, kv_out, bias, &None, vb.pp("wv"))?;
        let wo = mistralrs_quant::linear_b(q_out, dim, bias, &None, vb.pp("wo"))?;

        let sdpa_params = SdpaParams {
            n_kv_groups: num_heads / num_kv_heads,
            softcap: None,
            softmax_scale: (head_dim as f32).sqrt().recip(),
            sliding_window: cfg.sliding_window,
            sinks: None,
        };
        Ok(Self {
            wq,
            wk,
            wv,
            wo,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sdpa_params,
        })
    }

    /// Attention over `xs` with rotary embeddings and KV caching.
    ///
    /// `seqlen_offsets` gives each batch element's position offset for RoPE;
    /// cached keys/values from previous calls are prepended via `kv_cache`.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        let mut q = MatMul.qmethod_matmul(xs, &*self.wq)?;
        let mut k = MatMul.qmethod_matmul(xs, &*self.wk)?;
        let mut v = MatMul.qmethod_matmul(xs, &*self.wv)?;

        // Shape into (batch, heads, seq, head_dim). When seq_len == 1 the
        // transpose would be a no-op, so reshape straight into that layout.
        if seq_len == 1 {
            q = q.reshape((bs, self.num_heads, seq_len, self.head_dim))?;
            k = k.reshape((bs, self.num_kv_heads, seq_len, self.head_dim))?;
            v = v.reshape((bs, self.num_kv_heads, seq_len, self.head_dim))?;
        } else {
            q = q
                .reshape((bs, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            k = k
                .reshape((bs, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            v = v
                .reshape((bs, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
        }

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
        let (k, v) = kv_cache.append(&k, &v)?;

        let attn = Sdpa.run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?;

        // NOTE(review): the masked path appears to come back head-major
        // (batch, heads, seq, head_dim) and needs a transpose before
        // flattening, while the unmasked path is already sequence-major —
        // mirrors the upstream SDPA layout convention; confirm against Sdpa.
        let attn = match attention_mask {
            Some(_) => attn.transpose(1, 2)?.reshape((bs, seq_len, ()))?,
            None => attn.reshape((bs, seq_len, ()))?,
        };
        MatMul.qmethod_matmul(&attn, &*self.wo)
    }
}
/// Gated (SwiGLU-style) feed-forward block: `w2(silu(w1(x)) * w3(x))`.
pub(super) struct EncoderMlp {
    // Gate projection (dim -> hidden_dim).
    pub(super) w1: Arc<dyn QuantMethod>,
    // Down projection (hidden_dim -> dim).
    pub(super) w2: Arc<dyn QuantMethod>,
    // Up projection (dim -> hidden_dim).
    pub(super) w3: Arc<dyn QuantMethod>,
}
impl EncoderMlp {
    /// Loads the gate/down/up projections from `vb`.
    ///
    /// `w1` (gate) and `w3` (up) are bias-free; only `w2` (down) follows
    /// `cfg.use_biases`.
    fn new(cfg: &WhisperEncoderArgs, vb: ShardedVarBuilder) -> Result<Self> {
        let (dim, hidden) = (cfg.dim, cfg.hidden_dim);
        Ok(Self {
            w1: mistralrs_quant::linear_b(dim, hidden, false, &None, vb.pp("w1"))?,
            w2: mistralrs_quant::linear_b(hidden, dim, cfg.use_biases, &None, vb.pp("w2"))?,
            w3: mistralrs_quant::linear_b(dim, hidden, false, &None, vb.pp("w3"))?,
        })
    }

    /// Computes `w2(silu(w1(xs)) * w3(xs))`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gated = (candle_nn::ops::silu(&MatMul.qmethod_matmul(xs, &*self.w1)?)?
            * MatMul.qmethod_matmul(xs, &*self.w3)?)?;
        MatMul.qmethod_matmul(&gated, &*self.w2)
    }
}
/// One pre-norm transformer encoder layer:
/// `x + attn(norm(x))`, then `y + mlp(norm(y))`.
pub(super) struct EncoderLayer {
    pub(super) attention: EncoderAttention,
    pub(super) feed_forward: EncoderMlp,
    // Applied to the input before `attention` (pre-norm).
    pub(super) attention_norm: RmsNorm,
    // Applied before `feed_forward` (pre-norm).
    pub(super) ffn_norm: RmsNorm,
}
impl EncoderLayer {
    /// Loads one encoder layer (attention, MLP, and both pre-norms) from `vb`.
    fn new(
        cfg: &WhisperEncoderArgs,
        rotary_emb: Arc<RotaryEmbedding>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        Ok(Self {
            attention: EncoderAttention::new(cfg, rotary_emb, vb.pp("attention"))?,
            feed_forward: EncoderMlp::new(cfg, vb.pp("feed_forward"))?,
            attention_norm: RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?,
            ffn_norm: RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?,
        })
    }

    /// Pre-norm residual block: attention sublayer, then feed-forward sublayer.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        // Attention sublayer with residual connection.
        let attn_out = self.attention.forward(
            &self.attention_norm.forward(xs)?,
            attention_mask,
            seqlen_offsets,
            kv_cache,
        )?;
        let hidden = (attn_out + xs)?;

        // Feed-forward sublayer with residual connection.
        let ffn_out = self.feed_forward.forward(&self.ffn_norm.forward(&hidden)?)?;
        &hidden + ffn_out
    }
}
/// Voxtral audio encoder: a two-layer Conv1d front-end over mel features
/// followed by a stack of transformer encoder layers and a final RMSNorm.
pub struct VoxtralEncoder {
    // First conv: num_mel_bins -> dim, kernel 3, stride 1 (weights held in F32).
    pub(super) conv1: candle_nn::Conv1d,
    // Second conv: dim -> dim, kernel 3, stride 2 (2x temporal downsampling).
    pub(super) conv2: candle_nn::Conv1d,
    pub(super) layers: Vec<EncoderLayer>,
    pub(super) norm: RmsNorm,
    // Per-layer KV caches shared behind a mutex; replaced wholesale by `reset_cache`.
    cache: Arc<Mutex<NormalCache>>,
    num_heads: usize,
    sliding_window: Option<usize>,
    n_layers: usize,
    // Dtype the transformer stack runs in; the conv front-end runs in F32.
    model_dtype: DType,
}
impl VoxtralEncoder {
    /// Loads the full encoder from `vb`.
    ///
    /// Conv weights/biases are converted to F32 regardless of `vb.dtype()`
    /// (the mel front-end runs in F32; see `forward`), while the transformer
    /// stack uses the builder's dtype.
    pub fn new(cfg: &WhisperEncoderArgs, vb: ShardedVarBuilder) -> Result<Self> {
        let device = vb.device().clone();
        let dtype = vb.dtype();
        let n_mels = cfg.audio_encoding_args.num_mel_bins;

        // conv1: n_mels -> dim, kernel 3, stride 1. Padding is applied
        // manually (left-only) in `forward`, so the layer itself uses none.
        let vb_c1 = vb.pp("conv_layers").pp("0").pp("conv");
        let conv1 = candle_nn::Conv1d::new(
            vb_c1
                .get((cfg.dim, n_mels, 3), "weight")?
                .to_dtype(DType::F32)?,
            Some(vb_c1.get(cfg.dim, "bias")?.to_dtype(DType::F32)?),
            candle_nn::Conv1dConfig {
                padding: 0,
                stride: 1,
                ..Default::default()
            },
        );

        // conv2: dim -> dim, kernel 3, stride 2 (2x temporal downsampling).
        let vb_c2 = vb.pp("conv_layers").pp("1").pp("conv");
        let conv2 = candle_nn::Conv1d::new(
            vb_c2
                .get((cfg.dim, cfg.dim, 3), "weight")?
                .to_dtype(DType::F32)?,
            Some(vb_c2.get(cfg.dim, "bias")?.to_dtype(DType::F32)?),
            candle_nn::Conv1dConfig {
                padding: 0,
                stride: 2,
                ..Default::default()
            },
        );

        // A single rotary table is shared by every layer on this device.
        // (Previously this was stored in a one-entry HashMap keyed by the
        // device location and re-looked-up with `.expect` per layer, which
        // added an unnecessary panic path for a key inserted moments before.)
        let rotary_emb = Arc::new(RotaryEmbedding::new(
            cfg.rope_theta as f32,
            cfg.head_dim,
            1_000_000,
            &device,
            false,
            dtype,
        )?);

        let vb_layers = vb.pp("transformer").pp("layers");
        let mut layers = Vec::with_capacity(cfg.n_layers);
        for i in 0..cfg.n_layers {
            layers.push(EncoderLayer::new(cfg, rotary_emb.clone(), vb_layers.pp(i))?);
        }
        let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("transformer").pp("norm"))?;

        Ok(Self {
            conv1,
            conv2,
            layers,
            norm,
            cache: NormalCache::new_sliding(cfg.n_layers, 1_000_000, cfg.sliding_window),
            num_heads: cfg.n_heads,
            sliding_window: cfg.sliding_window,
            n_layers: cfg.n_layers,
            model_dtype: dtype,
        })
    }

    /// Encodes a batch of mel spectrograms.
    ///
    /// `xs` is expected as (batch, time, num_mel_bins); the conv1 weight
    /// shape fixes the channel count after the transpose below. The output
    /// has shape (batch, ~time / 2, dim) due to the stride-2 conv.
    ///
    /// Note: this appends to the shared KV cache; call `reset_cache`
    /// between unrelated inputs.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_sz, _t, _mel) = xs.dims3()?;

        // Conv front-end in F32, channels-first, with manual left-only
        // (causal) zero padding since both conv layers carry padding = 0.
        let xs = xs.to_dtype(DType::F32)?;
        let xs = xs.transpose(1, 2)?;
        let xs = xs.pad_with_zeros(2, 2, 0)?;
        let xs = self.conv1.forward(&xs)?.gelu_erf()?;
        let xs = xs.pad_with_zeros(2, 1, 0)?;
        let xs = self.conv2.forward(&xs)?.gelu_erf()?;

        // Back to (batch, seq, dim) and the transformer's dtype.
        let xs = xs.transpose(1, 2)?.contiguous()?;
        let xs = xs.to_dtype(self.model_dtype)?;
        let seq_len = xs.dim(1)?;

        let mut cache = self.cache.lock().expect("Encoder cache lock poisoned");
        let seqlen_offsets = vec![0usize; b_sz];
        // The masker only needs a tensor of the right (batch, seq) shape to
        // size the mask, so zeros stand in for real token ids.
        let dummy_toks = Tensor::zeros((b_sz, seq_len), DType::U32, xs.device())?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            &dummy_toks,
            &cache.0 as &dyn PastKvLenCache,
            self.sliding_window,
            xs.dtype(),
            self.num_heads,
        )?;

        let mut hidden = xs;
        for (i, layer) in self.layers.iter().enumerate() {
            hidden = layer.forward(
                &hidden,
                attention_mask.as_ref(),
                &seqlen_offsets,
                &mut cache.0[i],
            )?;
        }
        self.norm.forward(&hidden)
    }

    /// Drops all cached KV state by swapping a freshly built sliding cache
    /// into the shared mutex, so existing `Arc` clones keep observing the
    /// same cache slot.
    pub fn reset_cache(&self) {
        let fresh = NormalCache::new_sliding(self.n_layers, 1_000_000, self.sliding_window);
        // The fresh mutex cannot actually be poisoned yet; `expect` documents
        // the invariant rather than a real failure mode.
        let inner = fresh.lock().expect("New cache lock poisoned").clone();
        *self.cache.lock().expect("Encoder cache lock poisoned") = inner;
    }
}