moshi_db/batched_transformer.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
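//! Batch-oriented streaming transformer built on a scattered KV cache, so that
//! each stream in the batch can advance, be masked out, or be reset
//! independently while sharing the same weights.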
use crate::nn::{
    linear, linear_from, matmul_dtype, MaybeQuantizedLinear, MaybeQuantizedVarBuilder,
};
use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
use candle::{IndexOp, Module, Result, Tensor};

use crate::kv_cache::{
    IndicesAndMask, ScatteredCacheBuilder as KvCacheBuilder, ScatteredKvCache as KvCache,
};

use crate::transformer::{
    CaSrc, Config, LayerScale, PositionalEmbedding, Rope, RotaryEmbedding,
    StreamingMultiheadCrossAttention,
};

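/// Streaming multi-head self-attention backed by a scattered KV cache. Only
/// `kv_repeat == 1` is supported.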
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
    // Self-attention with KV Cache
    in_proj: MaybeQuantizedLinear,
    out_proj: MaybeQuantizedLinear,
    kv_repeat: usize,
    num_heads: usize,
    context: usize,
    kv_cache: KvCache,
    span: tracing::Span,
}

impl StreamingMultiheadAttention {
    pub fn new(
        cfg: &Config,
        builder: &KvCacheBuilder,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let head_dim = embed_dim / cfg.num_heads;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let out_dim = embed_dim + 2 * num_kv * (embed_dim / cfg.num_heads);
        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
        let in_proj_bias =
            if cfg.bias_attn { Some(vb.get_unquantized(out_dim, "in_proj_bias")?) } else { None };
        let in_proj = linear_from(in_proj_weight, in_proj_bias)?;
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        Ok(Self {
            in_proj,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            context: cfg.context,
            kv_cache: builder.make_cache(num_kv, head_dim)?,
            span: tracing::span!(tracing::Level::TRACE, "mha"),
        })
    }

    pub fn is_quantized(&self) -> bool {
        match self.in_proj {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

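    /// Runs one attention step: projects `xs` into q/k/v, optionally applies
    /// rotary embeddings, appends k/v to the scattered KV cache, trims the
    /// cached part to the context window, and computes masked attention using
    /// the mask carried by `iam`.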
    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        iam: &IndicesAndMask,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        // time_dim = 1, layout: b,t,h,d
        let qkv = xs.apply(&self.in_proj)?.reshape((b, t, 3, self.num_heads, head_dim))?;
        let original_dtype = qkv.dtype();
        let qkv = if self.is_quantized() { qkv.to_dtype(matmul_dtype(xs.device()))? } else { qkv };
        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;
        // qk_layer_norm = None
        // kv_repeat = 1, otherwise we would need repeat_kv
        let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
        let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
        let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
        if let Some(rope) = rope.as_ref() {
            q = rope.apply_rotary_emb(&q)?;
            k = rope.apply_rotary_emb(&k)?;
        }

        let (k, v) = { self.kv_cache.append(&k.contiguous()?, &v.contiguous()?, iam)? };
        // The KV cache currently keeps all the data; trim the part that comes
        // from the cache down to at most `context` entries so that it stays
        // consistent with the shape of the mask we provide.
        let k_len = k.dim(2)?;
        let k_target_len = t + usize::min(self.context, k_len - t);
        let (k, v) = if k_target_len < k_len {
            let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
            let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
            (k, v)
        } else {
            (k.clone(), v.clone())
        };

        let xs = {
            let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
            let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
            let pre_ws = pre_ws.broadcast_add(iam.mask())?;
            let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
            ws.matmul(&v)? // b,h,t,d
        };

        let xs = xs
            .transpose(1, 2)? // b,t,h,d
            .reshape((b, t, hd))?
            .to_dtype(original_dtype)?
            .apply(&self.out_proj)?;
        Ok(xs)
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.kv_cache = kv_cache
    }
}

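/// Feed-forward block: either a plain two-layer MLP with a GELU activation, or
/// a gated variant where the input projection produces two halves that are
/// combined as `activation(a) * b`.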
#[derive(Debug, Clone)]
pub enum Mlp {
    // Feed Forward layers
    NoGating {
        linear1: MaybeQuantizedLinear,
        linear2: MaybeQuantizedLinear,
    },
    Gating {
        linear_in: MaybeQuantizedLinear,
        linear_out: MaybeQuantizedLinear,
        activation: candle_nn::Activation,
    },
}

impl Mlp {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let d_model = cfg.d_model;
        match cfg.gating {
            None => {
                let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?;
                let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?;
                Ok(Self::NoGating { linear1, linear2 })
            }
            Some(activation) => {
                let vb = vb.pp("gating");
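                // The gated MLP uses two input matrices, so the hidden width is
                // shrunk to roughly 2/3 of `dim_feedforward` to keep the
                // parameter count comparable to the ungated variant.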
                let hidden = if cfg.dim_feedforward == 4 * d_model {
                    11 * d_model / 4
                } else {
                    2 * cfg.dim_feedforward / 3
                };
                let linear_in = linear(d_model, 2 * hidden, cfg.bias_ff, vb.pp("linear_in"))?;
                let linear_out = linear(hidden, d_model, cfg.bias_ff, vb.pp("linear_out"))?;
                Ok(Self::Gating { linear_in, linear_out, activation })
            }
        }
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2),
            Self::Gating { linear_in, linear_out, activation } => {
                let xs = xs.apply(linear_in)?;
                let (b, t, _) = xs.dims3()?;
                let xs = xs.reshape((b, t, 2, ()))?;
                let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
                xs.apply(linear_out)
            }
        }
    }
}

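/// RMS normalization with a learned per-channel scale (`alpha`).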
#[derive(Debug, Clone)]
pub struct RmsNorm {
    pub(crate) alpha: Tensor,
    pub(crate) eps: f32,
}

impl RmsNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let alpha = vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?;
        Ok(Self { alpha, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
    }
}

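/// Layer normalization; the scale is loaded from either an `alpha` or a
/// `weight` tensor depending on which key is present in the checkpoint.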
#[derive(Debug, Clone)]
pub struct LayerNorm {
    inner: candle_nn::LayerNorm,
}

impl LayerNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let bias = vb.get_unquantized(d_model, "bias")?;
        let alpha = if vb.contains_key("alpha") {
            vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?
        } else {
            vb.get_unquantized(d_model, "weight")?.reshape(d_model)?
        };
        let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64);
        Ok(Self { inner })
    }
}

impl Module for LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.inner.forward(xs)
    }
}

#[derive(Debug, Clone)]
pub enum Norm {
    LayerNorm(LayerNorm),
    RmsNorm(RmsNorm),
}

impl Norm {
    pub fn new(d_model: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let norm = Self::new_shortcut(d_model, cfg.norm, vb)?;
        Ok(norm)
    }

    pub fn new_shortcut(
        d_model: usize,
        typ: crate::NormType,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let norm = match typ {
            crate::NormType::LayerNorm => {
                let norm = LayerNorm::new(d_model, 1e-5, vb)?;
                Self::LayerNorm(norm)
            }
            crate::NormType::RmsNorm => {
                let norm = RmsNorm::new(d_model, 1e-8, vb)?;
                Self::RmsNorm(norm)
            }
        };
        Ok(norm)
    }
}

impl Module for Norm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::LayerNorm(m) => m.forward(xs),
            Self::RmsNorm(m) => m.forward(xs),
        }
    }
}

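/// A single pre-norm transformer layer: self-attention with a scattered KV
/// cache, optional cross-attention, and a feed-forward block, with optional
/// LayerScale applied to the self-attention and feed-forward branches.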
#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
    self_attn: StreamingMultiheadAttention,
    mlp: Mlp,
    norm1: Norm,
    norm2: Norm,
    layer_scale_1: Option<LayerScale>,
    layer_scale_2: Option<LayerScale>,
    cross_attn: Option<(Norm, StreamingMultiheadCrossAttention)>,
    norm_first: bool,
    span: tracing::Span,
}

impl StreamingTransformerLayer {
    pub fn new(
        cfg: &Config,
        builder: &KvCacheBuilder,
        vb: MaybeQuantizedVarBuilder,
        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
    ) -> Result<Self> {
        if cfg.use_conv_block {
            candle::bail!("conv-block is not supported")
        }
        let d_model = cfg.d_model;
        let mlp = Mlp::new(cfg, vb.clone())?;
        let norm1 = Norm::new(d_model, cfg, vb.pp("norm1"))?;
        let norm2 = Norm::new(d_model, cfg, vb.pp("norm2"))?;
        let layer_scale_1 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_1"))?;
                Some(ls)
            }
        };
        let layer_scale_2 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_2"))?;
                Some(ls)
            }
        };
        let self_attn = StreamingMultiheadAttention::new(cfg, builder, vb.pp("self_attn"))?;
        let cross_attn = match cfg.cross_attention.map(|v| v.1) {
            Some(norm_type) => {
                let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
                let cross_attn = match shared_ca_vb {
                    None => {
                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
                    }
                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
                        cfg,
                        shared_vb.pp("cross_attention"),
                        Some(vb.pp("cross_attention.gate")),
                    )?,
                };
                Some((norm_cross, cross_attn))
            }
            None => None,
        };
        Ok(Self {
            self_attn,
            mlp,
            norm1,
            norm2,
            layer_scale_1,
            layer_scale_2,
            cross_attn,
            norm_first: cfg.norm_first,
            span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
        })
    }

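    /// Applies the layer in pre-norm order: `norm1` then self-attention (plus
    /// residual), then optional cross-attention against `ca_src` (plus
    /// residual), then `norm2` and the MLP (plus residual).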
    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        ca_src: Option<&CaSrc>,
        iam: &IndicesAndMask,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if !self.norm_first {
            candle::bail!("only norm_first = true is supported")
        }
        let norm1 = xs.apply(&self.norm1)?;
        let xs = (xs
            + self.self_attn.forward(&norm1, rope, iam)?.apply(&self.layer_scale_1.as_ref())?)?;

        let xs = match (self.cross_attn.as_mut(), ca_src) {
            (Some((norm_cross, cross_attn)), Some(ca_src)) => {
                let residual = &xs;
                let xs = xs.apply(norm_cross)?;
                (residual + cross_attn.forward(&xs, ca_src, None)?)?
            }
            _ => xs,
        };

        let xs =
            (&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?.apply(&self.layer_scale_2.as_ref()))?;
        Ok(xs)
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.self_attn.set_kv_cache(kv_cache);
    }
}

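/// The main transformer stack. It owns the scattered KV-cache builder, which
/// tracks per-stream positions and is used to build every layer's attention
/// cache as well as the per-step indices and mask.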
#[derive(Debug, Clone)]
pub struct StreamingTransformer {
    // Main transformer
    layers: Vec<StreamingTransformerLayer>,
    positional_embedding: PositionalEmbedding,
    causal: bool,
    builder: KvCacheBuilder,
    rope: Option<RotaryEmbedding>,
}

impl StreamingTransformer {
    pub fn new(batch_size: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let rope = match cfg.positional_embedding {
            PositionalEmbedding::Rope => {
                let rope = RotaryEmbedding::new(
                    cfg.d_model / cfg.num_heads,
                    cfg.max_period as f32,
                    vb.device(),
                )?;
                Some(rope)
            }
            PositionalEmbedding::None | PositionalEmbedding::Sin => None,
        };
        let mut layers = Vec::with_capacity(cfg.num_layers);
        let builder = KvCacheBuilder::new(batch_size, cfg.context, vb.dtype(), vb.device())?;
        for layer_idx in 0..cfg.num_layers {
            // Also pass the first layer's var-builder: when cross-attention is
            // shared, only that layer holds the KQV projection weights.
            let shared_vb = if cfg.shared_cross_attn { Some(vb_l.pp(0)) } else { None };
            let layer =
                StreamingTransformerLayer::new(cfg, &builder, vb_l.pp(layer_idx), shared_vb)?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            positional_embedding: cfg.positional_embedding,
            causal: cfg.causal,
            builder,
            rope,
        })
    }

    pub fn forward(&mut self, xs: &Tensor, m: &StreamMask) -> Result<Tensor> {
        self.forward_ca(xs, None, m)
    }

    pub fn batch_size(&self) -> usize {
        self.builder.batch_size()
    }

    fn positions(&self) -> &[usize] {
        self.builder.positions()
    }

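    /// Runs the full stack for one chunk of `t` steps. The stream mask is
    /// turned into cache indices and an attention mask, per-stream RoPE
    /// positions are derived from the cache positions, and every layer is then
    /// applied in sequence (with `ca_src` as the cross-attention input, if any).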
    pub fn forward_ca(
        &mut self,
        xs: &Tensor,
        ca_src: Option<&CaSrc>,
        m: &StreamMask,
    ) -> Result<Tensor> {
        let (b, t, _c) = xs.dims3()?;
        if b != self.batch_size() {
            candle::bail!("unexpected batch size {b} != {}", self.batch_size())
        }
        if !self.causal {
            candle::bail!("only causal mode is supported")
        }
        let iam = match m.cpu() {
            None => candle::bail!("batched-transformer expects a mask"),
            Some(m) => self.builder.indices_and_mask(t, m)?,
        };
        let rope = match self.rope {
            Some(ref rope) => {
                let pos = self
                    .positions()
                    .iter()
                    .map(|&v| (0..t).map(|i| (v + i) as u32).collect::<Vec<_>>())
                    .collect::<Vec<_>>();
                let pos = Tensor::new(pos, xs.device())?;
                Some(rope.rope(&pos)?)
            }
            None => None,
        };
        let mut xs = match self.positional_embedding {
            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
            PositionalEmbedding::Sin => candle::bail!("sin positional embedding is not supported"),
        };
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, rope.as_ref(), ca_src, &iam)?
        }
        Ok(xs)
    }

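    /// If the cross-attention source is still token-based, projects it into
    /// keys and values once (using the first layer's cross-attention module) so
    /// that subsequent steps can reuse them.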
    pub fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        let ca_src = match ca_src {
            None => None,
            Some(CaSrc::KeysValues(_)) => ca_src,
            Some(tokens) => {
                if self.layers.is_empty() {
                    Some(tokens)
                } else {
                    match &self.layers[0].cross_attn {
                        None => Some(tokens),
                        Some((_, ca_module)) => {
                            let (k, v) = ca_module.compute_kv(&tokens)?;
                            Some(CaSrc::KeysValues((k, v)))
                        }
                    }
                }
            }
        };
        Ok(ca_src)
    }

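    /// Copies the per-layer KV caches from another transformer; both must have
    /// the same number of layers.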
    pub fn copy_state(&mut self, from: &Self) -> Result<()> {
        if self.layers.len() != from.layers.len() {
            candle::bail!("cannot copy kv-caches as the transformers have different depths")
        }
        self.layers
            .iter_mut()
            .zip(from.layers.iter())
            .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
        Ok(())
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize) -> Result<()> {
        if batch_idx >= self.batch_size() {
            candle::bail!("batch_idx {batch_idx} is out of bounds for batch size {}", self.batch_size())
        }
        self.builder.reset_batch_index(batch_idx);
        Ok(())
    }
}

impl StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        self.builder.reset();
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        match xs.as_option() {
            None => Ok(StreamTensor::empty()),
            Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs, m)?)),
        }
    }
}

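/// A transformer wrapped with optional input/output linear projections, used
/// when the surrounding modules have a different dimension than `d_model`.
/// With `conv_layout`, inputs and outputs use a (batch, channels, time) layout
/// and are transposed around the transformer.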
#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
    // Projected transformer with unquantized projection
    transformer: StreamingTransformer,
    input_proj: Option<MaybeQuantizedLinear>,
    output_projs: Vec<Option<MaybeQuantizedLinear>>,
    conv_layout: bool,
    span: tracing::Span,
}

impl ProjectedTransformer {
    pub fn new(
        input_dim: usize,
        output_dims: &[usize],
        batch_size: usize,
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let transformer = StreamingTransformer::new(batch_size, cfg, vb.pp("transformer"))?;
        let input_proj = if input_dim == cfg.d_model {
            None
        } else {
            let l = linear(input_dim, cfg.d_model, false, vb.pp("input_proj"))?;
            Some(l)
        };
        let mut output_projs = Vec::with_capacity(output_dims.len());
        let vb_o = vb.pp("output_projs");
        for (i, &output_dim) in output_dims.iter().enumerate() {
            let output_proj = if output_dim == cfg.d_model {
                None
            } else {
                let l = linear(cfg.d_model, output_dim, false, vb_o.pp(i))?;
                Some(l)
            };
            output_projs.push(output_proj)
        }
        Ok(Self {
            transformer,
            input_proj,
            output_projs,
            conv_layout: cfg.conv_layout,
            span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
        })
    }

    pub fn forward(&mut self, xs: &Tensor, m: &StreamMask) -> Result<Vec<Tensor>> {
        let _enter = self.span.enter();
        let xs = if self.conv_layout { xs.transpose(1, 2)? } else { xs.clone() };
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.forward(&xs, m)?;
        let mut ys = Vec::with_capacity(self.output_projs.len());
        for output_proj in self.output_projs.iter() {
            let ys_ = xs.apply(&output_proj.as_ref())?;
            let ys_ = if self.conv_layout { ys_.transpose(1, 2)? } else { ys_ };
            ys.push(ys_)
        }
        Ok(ys)
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize) -> Result<()> {
        self.transformer.reset_batch_idx(batch_idx)
    }
}

impl StreamingModule for ProjectedTransformer {
    fn reset_state(&mut self) {
        self.transformer.reset_state()
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let xs = xs.apply(&|x: &Tensor| {
            if self.conv_layout {
                x.transpose(1, 2)
            } else {
                Ok(x.clone())
            }
        })?;
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.step(&xs, m)?;
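        // Streaming steps only return the output of the first projection.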
        let ys = xs.apply(&self.output_projs[0].as_ref())?;
        ys.apply(&|y: &Tensor| {
            if self.conv_layout {
                y.transpose(1, 2)
            } else {
                Ok(y.clone())
            }
        })
    }
}