// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// Implements various modules for transformers, with support for both quantized and unquantized forward passes.
// Main differences between quantized and unquantized execution:
// 1. For quantized models' attention, `matmul_dtype` converts intermediate activations to BF16 for
//    more efficient matmuls.
// 2. Quantized tensors cannot easily be split (relevant for cross-attention and QKV proj weights).
// 3. Linear and quantized linear layers are two different types.
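// The quantized/unquantized split shows up throughout as a small dispatch, e.g. in the
// attention forwards below:
//     let q = if self.is_quantized() { q.to_dtype(matmul_dtype(xs.device()))? } else { q };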
use crate::nn::{
    linear, linear_from, matmul_dtype, MaybeQuantizedLinear, MaybeQuantizedVarBuilder,
};
use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};

use crate::kv_cache::KvCache;
use candle::Context;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub d_model: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    pub causal: bool,
    pub norm_first: bool,
    pub bias_ff: bool,
    pub bias_attn: bool,
    pub layer_scale: Option<f64>,
    pub positional_embedding: PositionalEmbedding,
    pub use_conv_block: bool,
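    // The cross-attention tuple is read as (gating, norm, kv_dim): for example,
    // (CrossAttentionGating::ConditionalGatedTanh, NormType::RmsNorm, Some(1024)) selects
    // conditionally gated (tanh) cross-attention, an RMS-norm before the cross-attention,
    // and a 1024-dim cross-attention source; kv_dim = None or Some(0) falls back to d_model.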
    pub cross_attention: Option<(CrossAttentionGating, crate::NormType, Option<usize>)>,
    pub conv_kernel_size: usize,
    pub use_conv_bias: bool,
    pub gating: Option<candle_nn::Activation>,
    pub norm: crate::NormType,
    pub context: usize,
    pub max_period: usize,
    pub max_seq_len: usize,

    pub kv_repeat: usize,
    pub dim_feedforward: usize,
    pub conv_layout: bool,

    #[serde(default)]
    pub shared_cross_attn: bool,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum PositionalEmbedding {
    Rope,
    Sin,
    None,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CrossAttentionGating {
    // Configures the type of gating used at the output of vision cross-attention layers.
    Normal,
    ConstantGatedTanh,
    ConstantGatedSigmoid,
    ConditionalGatedTanh,
    ConditionalGatedSigmoid,
    ConditionalGatedSigmoidLearnableBias,
    ConditionalGatedTanhLearnableBias,
}

#[derive(Debug, Clone)]
pub enum CaSrc {
    // Input to cross-attention, covering the cases where the cross-attention source can be
    // shared across timesteps and/or layers: either a single tensor (not yet projected) or
    // pre-computed K/V projections.
    Tokens(Tensor),
    KeysValues((Tensor, Tensor)),
}

#[derive(Debug, Clone)]
pub struct LayerScale {
    scale: Tensor,
}

impl LayerScale {
    pub fn new(d_model: usize, _init: f64, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let scale = vb.get_unquantized(d_model, "scale")?;
        Ok(Self { scale })
    }
}

impl Module for LayerScale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&self.scale)
    }
}

#[derive(Debug, Clone)]
pub enum XaGate {
    // Optional gating at the output of a cross-attention layer.
    // Normal: no gating (identity).
    Normal,
    // ConstantGated: multiply by a scalar.
    ConstantGated {
        alpha: Tensor,
    },
    // ConditionalGated: pass the input x through a small MLP; the output yields a vector
    // of scales (one per channel) that x is then multiplied by.
    ConditionalGated {
        in_proj: MaybeQuantizedLinear,
        out_proj: MaybeQuantizedLinear,
        activation: candle_nn::init::NonLinearity,
        learnable_bias: bool,
    },
}

impl XaGate {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let gating_cfg =
            cfg.cross_attention.map(|v| v.0).context("no cross-attention specified")?;
        match gating_cfg {
            // no gating
            CrossAttentionGating::Normal => Ok(Self::Normal),
            // constant (per-layer parameter) with tanh activation
            CrossAttentionGating::ConstantGatedTanh => {
                let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?;
                Ok(Self::ConstantGated { alpha })
            }
            // constant (per-layer parameter) with sigmoid activation
            CrossAttentionGating::ConstantGatedSigmoid => {
                let alpha =
                    candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? - 4.0)?)?;
                Ok(Self::ConstantGated { alpha })
            }
            // input conditional (small MLP) with tanh or sigmoid act
            CrossAttentionGating::ConditionalGatedTanh
            | CrossAttentionGating::ConditionalGatedSigmoid
            | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
            | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                let dim = cfg.d_model;
                let hidden_dims = (0.125 * dim as f32).floor() as usize;
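                // e.g. with d_model = 1024 this gives a gating bottleneck of 128 units (d_model / 8).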
                let learnable_bias = matches!(
                    gating_cfg,
                    CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
                        | CrossAttentionGating::ConditionalGatedTanhLearnableBias
                );
                let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?;
                let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?;
                let activation = match gating_cfg {
                    CrossAttentionGating::ConditionalGatedTanh
                    | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                        candle_nn::init::NonLinearity::Tanh
                    }
                    CrossAttentionGating::ConditionalGatedSigmoid
                    | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => {
                        candle_nn::init::NonLinearity::Sigmoid
                    }
                    _ => candle::bail!("Invalid cross-attention config specified."),
                };
                Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias })
            }
        }
    }
}

impl Module for XaGate {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Normal => Ok(xs.clone()),
            Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
            Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias } => {
                let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
                let alpha = match (activation, learnable_bias) {
                    (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh(),
                    (candle_nn::init::NonLinearity::Sigmoid, true) => {
                        candle_nn::ops::sigmoid(&alpha)
                    }
                    (candle_nn::init::NonLinearity::Sigmoid, false) => {
                        candle_nn::ops::sigmoid(&(alpha - 4.0)?)
                    }
                    _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
                };
                xs * alpha?
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
    // Cross-attention module. Q and KV projections are separate because x (speech tokens)
    // and ca_src (cross-attention source) can have different dimensions.
    in_proj_q: MaybeQuantizedLinear,
    in_proj_kv: MaybeQuantizedLinear,
    out_proj: MaybeQuantizedLinear,
    kv_repeat: usize,
    num_heads: usize,
    gate: XaGate,
    span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
    pub fn new(
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
        gate_vb: Option<MaybeQuantizedVarBuilder>,
    ) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let out_kv_dim = num_kv * (embed_dim / cfg.num_heads);
        let out_dim = embed_dim + 2 * out_kv_dim;
        // Case 1 (legacy): a single in_proj, i.e. both x and ca_src *must* have the same
        // number of dims. This is only possible for non-quantized tensors though, as we
        // need to split the Q/KV weights down the line even when they have the same shape,
        // since they take different inputs.
        let (in_proj_q, in_proj_kv) = if vb.contains_key("in_proj_weight") {
            match &vb {
                MaybeQuantizedVarBuilder::Quantized(_) => candle::bail!("Quantized cross-attention layers require a separate in_proj_weight_q and in_proj_weight_kv"),
                MaybeQuantizedVarBuilder::Real(weights) => {
                    let in_proj_weight = weights.get((out_dim, embed_dim), "in_proj_weight")?;
                    let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
                    let in_proj_weight_kv = in_proj_weight.narrow(0, embed_dim, 2 * out_kv_dim)?;
                    let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn {
                        let b = weights.get(out_dim, "in_proj_bias")?;
                        let in_proj_bias_q = b.narrow(0, 0, embed_dim)?;
                        let in_proj_bias_kv = b.narrow(0, embed_dim, 2 * out_kv_dim)?;
                        (Some(in_proj_bias_q), Some(in_proj_bias_kv))
                    } else {
                        (None, None)
                    };
                    (
                        MaybeQuantizedLinear::Real(candle_nn::Linear::new(in_proj_weight_q, in_proj_bias_q)),
                        MaybeQuantizedLinear::Real(candle_nn::Linear::new(in_proj_weight_kv, in_proj_bias_kv)),
                    )
                }
            }
        } else {
            // Case 2: separate projections for the query (x) and kv (ca_src).
            let kv_in_dim = match cfg.cross_attention.map(|v| v.2) {
                None => candle::bail!("cfg.cross_attention is None in cross_attention module"),
                Some(d) => match d {
                    None | Some(0) => embed_dim,
                    Some(dd) => dd,
                },
            };
            let in_proj_weight_q = vb.get((embed_dim, embed_dim), "in_proj_weight_q")?;
            let in_proj_weight_kv = vb.get((2 * out_kv_dim, kv_in_dim), "in_proj_weight_kv")?;

            // Biases are always unquantized.
            let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn {
                (
                    Some(vb.get_unquantized(embed_dim, "in_proj_bias_q")?),
                    Some(vb.get_unquantized(2 * out_kv_dim, "in_proj_bias_kv")?),
                )
            } else {
                (None, None)
            };

            // Finally, we can build the actual linear layers.
            let in_proj_q = linear_from(in_proj_weight_q, in_proj_bias_q)?;
            let in_proj_kv = linear_from(in_proj_weight_kv, in_proj_bias_kv)?;
            (in_proj_q, in_proj_kv)
        };

        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        let gate = match gate_vb {
            None => XaGate::new(cfg, vb.pp("gate"))?,
            Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?,
        };
        Ok(Self {
            in_proj_q,
            in_proj_kv,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            gate,
            span: tracing::span!(tracing::Level::TRACE, "mhca"),
        })
    }

    pub fn is_quantized(&self) -> bool {
        match self.in_proj_q {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

    pub fn compute_kv(&self, ca_src: &CaSrc) -> Result<(Tensor, Tensor)> {
        // This is used twice:
        // - in the standard forward pass of the cross-attention;
        // - for vision models, after loading an image we can precompute its KV projections,
        //   as the image is constant across multiple timesteps.
        match ca_src {
            CaSrc::KeysValues(cakv) => Ok(cakv.clone()),
            CaSrc::Tokens(xs) => {
                let kv = xs.apply(&self.in_proj_kv)?;
                let (ca_b, ca_t, ca_dim) = kv.dims3()?;
                let head_dim = ca_dim / (2 * self.num_heads);
                let kv = kv.reshape((ca_b, ca_t, 2, (), head_dim))?;
                // Convert to the right floating-point type for quantized models.
                let kv =
                    if self.is_quantized() { kv.to_dtype(matmul_dtype(xs.device()))? } else { kv };
                let k = kv.i((.., .., 0))?;
                let v = kv.i((.., .., 1))?;
                let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
                let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
                Ok((k, v))
            }
        }
    }

    pub fn forward(&self, xs: &Tensor, ca_src: &CaSrc, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        // time_dim = 1, layout: b,t,h,d
        let q = xs.apply(&self.in_proj_q)?;
        let original_dtype = q.dtype();
        let q = q.reshape((b, t, self.num_heads, head_dim))?;
        let q = if self.is_quantized() { q.to_dtype(matmul_dtype(xs.device()))? } else { q };
        let (k, v) = self.compute_kv(ca_src)?;
        // qk_layer_norm = None
        // kv_repeat = 1, otherwise we would need repeat_kv
        let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d

        let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
        let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

        let pre_ws = match mask {
            None => pre_ws,
            Some(mask) => pre_ws.broadcast_add(mask)?,
        };

        let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
        let xs = ws.matmul(&v)?; // b,h,t,d
        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .to_dtype(original_dtype)?
            .apply(&self.out_proj)?
            .apply(&self.gate)?;
        Ok(xs)
    }
}
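// Usage sketch (illustrative only; `xa`, `img_tokens`, and `xs` are hypothetical locals):
// for a vision model, the image K/V can be projected once and reused across timesteps:
//     let kv = xa.compute_kv(&CaSrc::Tokens(img_tokens))?;
//     let ca_src = CaSrc::KeysValues(kv);
//     let ys = xa.forward(&xs, &ca_src, None)?; // xs: (b, t, d_model)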

#[derive(Debug, Clone)]
pub struct Rope {
    sin: Tensor,
    cos: Tensor,
}

impl Rope {
    pub fn apply_rotary_emb(&self, qk: &Tensor) -> Result<Tensor> {
        let qk_dtype = qk.dtype();
        candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &self.cos, &self.sin)?
            .to_dtype(qk_dtype)
    }
}

#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    inv_freq: Tensor,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, theta: f32, dev: &Device) -> Result<Self> {
        let inv_freq: Vec<_> =
            (0..dim).step_by(2).map(|i| 1f32 / theta.powf(i as f32 / dim as f32)).collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        Ok(Self { inv_freq })
    }

    pub fn rope(&self, pos: &Tensor) -> Result<Rope> {
        let t = pos.to_dtype(DType::F32)?;
        let freqs = match *t.dims() {
            [d] => t.reshape((d, 1))?.matmul(&self.inv_freq)?,
            [b, d] => t.reshape((b * d, 1))?.matmul(&self.inv_freq)?.reshape((b, d, ()))?,
            _ => candle::bail!("Invalid shape for rotary embedding {pos:?}"),
        };
        Ok(Rope { sin: freqs.sin()?, cos: freqs.cos()? })
    }
}
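// Sketch of how these are used together (mirrors `StreamingTransformer::forward_ca` below;
// `head_dim`, `max_period`, `dev`, `t`, and `q` are hypothetical locals):
//     let rot = RotaryEmbedding::new(head_dim, max_period as f32, dev)?;
//     let pos = Tensor::arange(0u32, t as u32, dev)?; // absolute positions
//     let rope = rot.rope(&pos)?;
//     let q = rope.apply_rotary_emb(&q)?; // q: (b, h, t, head_dim), contiguous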

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}

#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
    // Self-attention with KV Cache
    in_proj: MaybeQuantizedLinear,
    out_proj: MaybeQuantizedLinear,
    kv_repeat: usize,
    num_heads: usize,
    context: usize,
    kv_cache: KvCache,
    use_flash_attn: bool,
    span: tracing::Span,
}

impl StreamingMultiheadAttention {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let out_dim = embed_dim + 2 * num_kv * (embed_dim / cfg.num_heads);
        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
        let in_proj_bias =
            if cfg.bias_attn { Some(vb.get_unquantized(out_dim, "in_proj_bias")?) } else { None };
        let in_proj = linear_from(in_proj_weight, in_proj_bias)?;
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        Ok(Self {
            in_proj,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            context: cfg.context,
            kv_cache: KvCache::new(2, cfg.context),
            use_flash_attn: false,
            span: tracing::span!(tracing::Level::TRACE, "mha"),
        })
    }

    pub fn is_quantized(&self) -> bool {
        match self.in_proj {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        // time_dim = 1, layout: b,t,h,d
        let qkv = xs.apply(&self.in_proj)?.reshape((b, t, 3, self.num_heads, head_dim))?;
        let original_dtype = qkv.dtype();
        let qkv = if self.is_quantized() { qkv.to_dtype(matmul_dtype(xs.device()))? } else { qkv };
        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;
        // qk_layer_norm = None
        // kv_repeat = 1, otherwise we would need repeat_kv
        let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
        let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
        if let Some(rope) = rope.as_ref() {
            q = rope.apply_rotary_emb(&q)?;
            k = rope.apply_rotary_emb(&k)?;
        }

        let (k, v) = { self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)? };
        // The KV cache currently keeps all of the data; trim the part that comes from the
        // cache to at most `context` entries so that it stays coherent with the mask shape
        // we provide.
        let k_len = k.dim(2)?;
        let k_target_len = t + usize::min(self.context, k_len - t);
        let (k, v) = if k_target_len < k_len {
            let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
            let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
            (k, v)
        } else {
            (k.clone(), v.clone())
        };
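        // e.g. with context = 2, t = 1 and k_len = 4, k_target_len = 1 + min(2, 3) = 3,
        // so only the last 3 key/value positions are kept.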

        let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            let softmax_scale = 1f32 / (head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, softmax_scale, mask.is_some())?.transpose(1, 2)?
        } else {
            let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
            let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

            let pre_ws = match mask {
                None => pre_ws,
                Some(mask) => pre_ws.broadcast_add(mask)?,
            };

            let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
            ws.matmul(&v)? // b,h,t,d
        };

        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .to_dtype(original_dtype)?
            .apply(&self.out_proj)?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.kv_cache.reset()
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.kv_cache = kv_cache
    }
}

#[derive(Debug, Clone)]
pub enum Mlp {
    // Feed-forward layers
    NoGating {
        linear1: MaybeQuantizedLinear,
        linear2: MaybeQuantizedLinear,
    },
    Gating {
        linear_in: MaybeQuantizedLinear,
        linear_out: MaybeQuantizedLinear,
        activation: candle_nn::Activation,
    },
}

impl Mlp {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let d_model = cfg.d_model;
        match cfg.gating {
            None => {
                let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?;
                let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?;
                Ok(Self::NoGating { linear1, linear2 })
            }
            Some(activation) => {
                let vb = vb.pp("gating");
                let hidden = if cfg.dim_feedforward == 4 * d_model {
                    11 * d_model / 4
                } else {
                    2 * cfg.dim_feedforward / 3
                };
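                // e.g. d_model = 1024 and dim_feedforward = 4096 (= 4 * d_model) gives
                // hidden = 11 * 1024 / 4 = 2816, so linear_in below maps 1024 -> 5632 and
                // the two halves are recombined by the gating in `Mlp::forward`.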
                let linear_in = linear(d_model, 2 * hidden, cfg.bias_ff, vb.pp("linear_in"))?;
                let linear_out = linear(hidden, d_model, cfg.bias_ff, vb.pp("linear_out"))?;
                Ok(Self::Gating { linear_in, linear_out, activation })
            }
        }
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2),
            Self::Gating { linear_in, linear_out, activation } => {
                let xs = xs.apply(linear_in)?;
                let (b, t, _) = xs.dims3()?;
                let xs = xs.reshape((b, t, 2, ()))?;
                let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
                xs.apply(linear_out)
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct RmsNorm {
    pub(crate) alpha: Tensor,
    pub(crate) eps: f32,
}

impl RmsNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let alpha = vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?;
        Ok(Self { alpha, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
    }
}

#[derive(Debug, Clone)]
pub struct LayerNorm {
    inner: candle_nn::LayerNorm,
}

impl LayerNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let bias = vb.get_unquantized(d_model, "bias")?;
        let alpha = if vb.contains_key("alpha") {
            vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?
        } else {
            vb.get_unquantized(d_model, "weight")?.reshape(d_model)?
        };
        let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64);
        Ok(Self { inner })
    }
}

impl Module for LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.inner.forward(xs)
    }
}

#[derive(Debug, Clone)]
pub enum Norm {
    LayerNorm(LayerNorm),
    RmsNorm(RmsNorm),
}

impl Norm {
    pub fn new(d_model: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let norm = Self::new_shortcut(d_model, cfg.norm, vb)?;
        Ok(norm)
    }

    pub fn new_shortcut(
        d_model: usize,
        typ: crate::NormType,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let norm = match typ {
            crate::NormType::LayerNorm => {
                let norm = LayerNorm::new(d_model, 1e-5, vb)?;
                Self::LayerNorm(norm)
            }
            crate::NormType::RmsNorm => {
                let norm = RmsNorm::new(d_model, 1e-8, vb)?;
                Self::RmsNorm(norm)
            }
        };
        Ok(norm)
    }
}

impl Module for Norm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::LayerNorm(m) => m.forward(xs),
            Self::RmsNorm(m) => m.forward(xs),
        }
    }
}

#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
    self_attn: StreamingMultiheadAttention,
    mlp: Mlp,
    norm1: Norm,
    norm2: Norm,
    layer_scale_1: Option<LayerScale>,
    layer_scale_2: Option<LayerScale>,
    cross_attn: Option<(Norm, StreamingMultiheadCrossAttention)>,
    norm_first: bool,
    span: tracing::Span,
}

impl StreamingTransformerLayer {
    pub fn new(
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
    ) -> Result<Self> {
        if cfg.use_conv_block {
            candle::bail!("conv-block is not supported")
        }
        let d_model = cfg.d_model;
        let mlp = Mlp::new(cfg, vb.clone())?;
        let norm1 = Norm::new(d_model, cfg, vb.pp("norm1"))?;
        let norm2 = Norm::new(d_model, cfg, vb.pp("norm2"))?;
        let layer_scale_1 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_1"))?;
                Some(ls)
            }
        };
        let layer_scale_2 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_2"))?;
                Some(ls)
            }
        };
        let self_attn = StreamingMultiheadAttention::new(cfg, vb.pp("self_attn"))?;
        let cross_attn = match cfg.cross_attention.map(|v| v.1) {
            Some(norm_type) => {
                let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
                let cross_attn = match shared_ca_vb {
                    None => {
                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
                    }
                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
                        cfg,
                        shared_vb.pp("cross_attention"),
                        Some(vb.pp("cross_attention.gate")),
                    )?,
                };
                Some((norm_cross, cross_attn))
            }
            None => None,
        };
        Ok(Self {
            self_attn,
            mlp,
            norm1,
            norm2,
            layer_scale_1,
            layer_scale_2,
            cross_attn,
            norm_first: cfg.norm_first,
            span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
        })
    }

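    // Pre-norm residual structure: x + SelfAttn(norm1(x)), then (if cross-attention is
    // enabled and a source is given) x + CrossAttn(norm_cross(x)), then x + Mlp(norm2(x)).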
    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        ca_src: Option<&CaSrc>,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if !self.norm_first {
            candle::bail!("only norm_first = true is supported")
        }
        let norm1 = xs.apply(&self.norm1)?;
        let xs = (xs
            + self.self_attn.forward(&norm1, rope, mask)?.apply(&self.layer_scale_1.as_ref())?)?;

        let xs = match (self.cross_attn.as_mut(), ca_src) {
            (Some((norm_cross, cross_attn)), Some(ca_src)) => {
                let residual = &xs;
                let xs = xs.apply(norm_cross)?;
                (residual + cross_attn.forward(&xs, ca_src, None)?)?
            }
            _ => xs,
        };

        let xs =
            (&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?.apply(&self.layer_scale_2.as_ref()))?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache();
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.self_attn.set_kv_cache(kv_cache);
    }
}

#[derive(Debug, Clone)]
pub struct StreamingTransformer {
    // Main transformer
    layers: Vec<StreamingTransformerLayer>,
    positional_embedding: PositionalEmbedding,
    max_period: usize,
    causal: bool,
    num_heads: usize,
    context: usize,
    last_reset_pos: Vec<usize>,
    rope: Option<RotaryEmbedding>,
}

impl StreamingTransformer {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let rope = match cfg.positional_embedding {
            PositionalEmbedding::Rope => {
                let rope = RotaryEmbedding::new(
                    cfg.d_model / cfg.num_heads,
                    cfg.max_period as f32,
                    vb.device(),
                )?;
                Some(rope)
            }
            PositionalEmbedding::None | PositionalEmbedding::Sin => None,
        };
        let mut layers = Vec::with_capacity(cfg.num_layers);
        for layer_idx in 0..cfg.num_layers {
            // Also pass the first layer's weights, as only it contains the Q/KV projection
            // weights for shared cross-attention layers.
            let shared_vb = if cfg.shared_cross_attn { Some(vb_l.pp(0)) } else { None };
            let layer = StreamingTransformerLayer::new(cfg, vb_l.pp(layer_idx), shared_vb)?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            positional_embedding: cfg.positional_embedding,
            max_period: cfg.max_period,
            causal: cfg.causal,
            num_heads: cfg.num_heads,
            context: cfg.context,
            last_reset_pos: vec![],
            rope,
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        self.forward_ca(xs, None)
    }

    fn current_seq_len(&self) -> usize {
        self.layers[0].self_attn.kv_cache.current_seq_len()
    }

    pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&CaSrc>) -> Result<Tensor> {
        let (b, t, c) = xs.dims3()?;
        if !self.causal {
            candle::bail!("only causal mode is supported")
        }
        if self.last_reset_pos.is_empty() {
            self.last_reset_pos.resize(b, 0);
        }
        let current_seq_len = self.current_seq_len();
        // We extract at most `context` entries from the kv_cache.
        // Note that the mask still discards values that fall further back than `context`,
        // as this can happen when t > context.
        let mask = {
            // mask shape should be b, h, t, k
            // self.layers[0].self_attn.kv_cache.attn_mask(t, xs.device())?;
            // let mask = mask.broadcast_left((b, self.num_heads))?;
            let ks = self.layers[0].self_attn.kv_cache.positions(t);
            let min_ks = ks.iter().min().context("no positions, is t == 0?")?;
            if t == 1 && self.last_reset_pos.iter().all(|v| v <= min_ks) {
                // No need for a mask here.
                None
            } else {
                let mut mask = Vec::with_capacity(b * self.num_heads * t * ks.len());
                for &last_reset_pos in self.last_reset_pos.iter() {
                    for t_pos in 0..t {
                        let t_pos = t_pos + current_seq_len;
                        for &k_pos in ks.iter() {
                            let m = if last_reset_pos <= k_pos
                                && k_pos <= t_pos
                                && t_pos <= k_pos + self.context
                            {
                                0f32
                            } else {
                                f32::NEG_INFINITY
                            };
                            mask.push(m);
                        }
                    }
                }
                let mask = Tensor::from_vec(mask, (b, 1, t, ks.len()), xs.device())?
                    .to_dtype(xs.dtype())?
                    .expand((b, self.num_heads, t, ks.len()))?;
                Some(mask)
            }
        };
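        // Illustration (assuming `positions(t)` returns the absolute key positions): with
        // context = 2, last_reset_pos = 0, current_seq_len = 3 and t = 1, the new query sits
        // at t_pos = 3 and may attend to keys at positions 1..=3, while a key at position 0
        // is masked out since t_pos > k_pos + context.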
        // pos is used for the rotary embeddings; as these are relative embeddings, there is
        // no need to adjust them for the actual position using last_reset_pos.
        let pos =
            Tensor::arange(current_seq_len as u32, (current_seq_len + t) as u32, xs.device())?;
        let rope = match self.rope {
            Some(ref rope) => Some(rope.rope(&pos)?),
            None => None,
        };
        let mut xs = match self.positional_embedding {
            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
            PositionalEmbedding::Sin => {
                let dev = xs.device();
                let theta = self.max_period as f32;
                let half_dim = c / 2;
                let positions = pos.unsqueeze(1)?.to_dtype(DType::F32)?;
                let inv_freq: Vec<_> = (0..half_dim)
                    .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
                    .collect();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let freqs = positions.broadcast_mul(&inv_freq)?;
                let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?;
                xs.broadcast_add(&pos_emb)?
            }
        };
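        // In the Sin branch, pos_emb has shape (t, c): cos and sin halves of half_dim = c / 2
        // channels each, broadcast over the batch dimension by broadcast_add.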
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, rope.as_ref(), ca_src, mask.as_ref())?
        }
        Ok(xs)
    }

    pub fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        let ca_src = match ca_src {
            None => None,
            Some(CaSrc::KeysValues(_)) => ca_src,
            Some(tokens) => {
                if self.layers.is_empty() {
                    Some(tokens)
                } else {
                    match &self.layers[0].cross_attn {
                        None => Some(tokens),
                        Some((_, ca_module)) => {
                            let (k, v) = ca_module.compute_kv(&tokens)?;
                            Some(CaSrc::KeysValues((k, v)))
                        }
                    }
                }
            }
        };
        Ok(ca_src)
    }

    pub fn copy_state(&mut self, from: &Self) -> Result<()> {
        if self.layers.len() != from.layers.len() {
            candle::bail!("cannot copy kv-caches as the transformers have different depths")
        }
        self.last_reset_pos = from.last_reset_pos.clone();
        self.layers
            .iter_mut()
            .zip(from.layers.iter())
            .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
        Ok(())
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        if self.last_reset_pos.is_empty() {
            self.last_reset_pos.resize(batch_size, 0);
        }
        if batch_idx >= self.last_reset_pos.len() {
            candle::bail!("batch_idx {} is out of bounds for last_reset_pos", batch_idx)
        }
        self.last_reset_pos[batch_idx] = self.current_seq_len();
        Ok(())
    }
}

impl StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        self.last_reset_pos.clear();
        self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
    }

    fn step(&mut self, xs: &StreamTensor, _: &StreamMask) -> Result<StreamTensor> {
        // TODO: Use the StreamMask
        match xs.as_option() {
            None => Ok(StreamTensor::empty()),
            Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
        }
    }
}

#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
    // Projected transformer with unquantized projection
    transformer: StreamingTransformer,
    input_proj: Option<MaybeQuantizedLinear>,
    output_projs: Vec<Option<MaybeQuantizedLinear>>,
    conv_layout: bool,
    span: tracing::Span,
}

impl ProjectedTransformer {
    pub fn new(
        input_dim: usize,
        output_dims: &[usize],
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let transformer = StreamingTransformer::new(cfg, vb.pp("transformer"))?;
        let input_proj = if input_dim == cfg.d_model {
            None
        } else {
            let l = linear(input_dim, cfg.d_model, false, vb.pp("input_proj"))?;
            Some(l)
        };
        let mut output_projs = Vec::with_capacity(output_dims.len());
        let vb_o = vb.pp("output_projs");
        for (i, &output_dim) in output_dims.iter().enumerate() {
            let output_proj = if output_dim == cfg.d_model {
                None
            } else {
                let l = linear(cfg.d_model, output_dim, false, vb_o.pp(i))?;
                Some(l)
            };
            output_projs.push(output_proj)
        }
        Ok(Self {
            transformer,
            input_proj,
            output_projs,
            conv_layout: cfg.conv_layout,
            span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
        let _enter = self.span.enter();
        let xs = if self.conv_layout { xs.transpose(1, 2)? } else { xs.clone() };
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.forward(&xs)?;
        let mut ys = Vec::with_capacity(self.output_projs.len());
        for output_proj in self.output_projs.iter() {
            let ys_ = xs.apply(&output_proj.as_ref())?;
            let ys_ = if self.conv_layout { ys_.transpose(1, 2)? } else { ys_ };
            ys.push(ys_)
        }
        Ok(ys)
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.transformer.reset_batch_idx(batch_idx, batch_size)
    }
}

impl StreamingModule for ProjectedTransformer {
    fn reset_state(&mut self) {
        self.transformer.reset_state()
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let xs = xs.apply(&|x: &Tensor| {
            if self.conv_layout {
                x.transpose(1, 2)
            } else {
                Ok(x.clone())
            }
        })?;
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.step(&xs, m)?;
        let ys = xs.apply(&self.output_projs[0].as_ref())?;
        ys.apply(&|y: &Tensor| {
            if self.conv_layout {
                y.transpose(1, 2)
            } else {
                Ok(y.clone())
            }
        })
    }
}

#[derive(Debug, Clone)]
pub enum Transformer {
    Standard(ProjectedTransformer),
    Batched(crate::batched_transformer::ProjectedTransformer),
}

impl StreamingModule for Transformer {
    fn reset_state(&mut self) {
        match self {
            Transformer::Standard(t) => t.reset_state(),
            Transformer::Batched(t) => t.reset_state(),
        }
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        match self {
            Transformer::Standard(t) => t.step(xs, m),
            Transformer::Batched(t) => t.step(xs, m),
        }
    }
}

impl Transformer {
    pub fn new(
        batch_size: Option<usize>,
        dim: usize,
        cfg: &Config,
        vb: candle_nn::VarBuilder,
    ) -> Result<Self> {
        let transformer = match batch_size {
            Some(batch_size) => {
                let transformer = crate::batched_transformer::ProjectedTransformer::new(
                    dim,
                    &[dim],
                    batch_size,
                    cfg,
                    MaybeQuantizedVarBuilder::Real(vb),
                )?;
                Transformer::Batched(transformer)
            }
            None => {
                let transformer = ProjectedTransformer::new(
                    dim,
                    &[dim],
                    cfg,
                    MaybeQuantizedVarBuilder::Real(vb),
                )?;
                Transformer::Standard(transformer)
            }
        };
        Ok(transformer)
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
        match self {
            Transformer::Standard(t) => t.forward(xs),
            Transformer::Batched(t) => t.forward(xs, &().into()),
        }
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        match self {
            Transformer::Standard(t) => t.reset_batch_idx(batch_idx, batch_size),
            Transformer::Batched(t) => t.reset_batch_idx(batch_idx),
        }
    }
}
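
// Usage sketch (illustrative only; `cfg`, `vb` and `xs` are hypothetical locals):
//     let mut tr = Transformer::new(None, cfg.d_model, &cfg, vb)?;
//     let ys = tr.forward(&xs)?; // xs: (b, t, d_model), or (b, d_model, t) with conv_layout
//     tr.reset_state(); // clears the KV caches / reset positions between sequences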