Skip to main content

candle_transformers/models/
based.rs

1//! Based from the Stanford Hazy Research group.
2//!
3//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
4//! - Simple linear attention language models balance the recall-throughput tradeoff. [Arxiv](https://arxiv.org/abs/2402.18668)
//! - [GitHub Repo](https://github.com/HazyResearch/based)
6//! - [Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
7
8use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
9use candle_nn::{
10    conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
11    Func, Linear, RmsNorm, VarBuilder,
12};
13use std::sync::Arc;
14
/// Configuration of the Taylor feature map used by the linear attention
/// mixer; `input_dim` sets the normalization of the expansion.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionFeatureMapConfig {
    input_dim: usize,
}
19
/// Configuration of the linear attention mixer (deserialized from the
/// checkpoint's `alt_mixer` section).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionConfig {
    num_heads: usize,
    // Per-head dimension of q/k before the Taylor feature expansion.
    feature_dim: usize,
    feature_map: LinearAttentionFeatureMapConfig,
}
26
/// Configuration of the sliding-window attention mixer (deserialized from
/// the checkpoint's `alt_mixer_2` section).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SlidingWindowAttentionConfig {
    num_heads: usize,
    window_size: usize,
}
32
/// Full model configuration, deserialized from the checkpoint's JSON config.
///
/// Serde renames map the GPT-style field names used by the Based repository
/// (`n_embd`, `n_layer`, ...) onto the usual transformer names.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    vocab_size: usize,
    #[serde(rename = "n_embd")]
    hidden_size: usize,
    #[serde(rename = "n_inner")]
    intermediate_size: usize,
    #[serde(rename = "n_layer")]
    num_hidden_layers: usize,
    #[serde(rename = "n_head")]
    num_attention_heads: usize,

    layer_norm_epsilon: f64,
    // Rotary embedding base; defaults to 10_000 when absent from the config.
    #[serde(default = "default_rope", rename = "rotary_emb_base")]
    rope_theta: f64,

    // Layer indices using linear attention (`la`) and sliding-window
    // attention (`swa`); every other layer uses the gated conv mixer.
    alt_mixer_layers: Vec<usize>,
    alt_mixer_2_layers: Vec<usize>,
    #[serde(rename = "alt_mixer")]
    la: LinearAttentionConfig,
    #[serde(rename = "alt_mixer_2")]
    swa: SlidingWindowAttentionConfig,
}
56
/// Default rotary-embedding base, used when `rotary_emb_base` is missing
/// from the deserialized config.
fn default_rope() -> f64 {
    1e4
}
60
/// Two-layer feed-forward block with a swiglu non-linearity in between.
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Linear,
    fc2: Linear,
}
67
impl MLP {
    /// Loads the two projection layers.
    ///
    /// `fc1` expands to `4 * hidden_size`; swiglu halves that before `fc2`
    /// maps `intermediate_size` back down to `hidden_size`.
    /// NOTE(review): the shapes only line up when
    /// `intermediate_size == 2 * hidden_size` — presumably true for Based
    /// checkpoints; confirm against the shipped configs.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
        let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
        Ok(Self { fc1, fc2 })
    }
}
75
76// Swiglu implementation.
77// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs
78fn swiglu(xs: &Tensor) -> Result<Tensor> {
79    let xs = xs.chunk(2, D::Minus1)?;
80    &xs[1].silu()? * &xs[0]
81}
82
83impl Module for MLP {
84    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85        let xs = xs.apply(&self.fc1)?;
86        let xs = swiglu(&xs)?;
87        let xs = xs.apply(&self.fc2)?;
88        Ok(xs)
89    }
90}
91
// A gated convolutional block.
#[derive(Debug, Clone)]
struct BasedConv {
    in_proj: Linear,
    out_proj: Linear,
    conv: Conv1d,
    // Rolling window over the last 3 conv-input steps, shape (1, dim, 3);
    // used for incremental single-token decoding.
    state: Tensor,
}
100
impl BasedConv {
    /// Loads the projections and the depthwise (grouped, kernel-size 3)
    /// convolution, and zero-initializes the rolling conv state.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dim = cfg.hidden_size * 2;

        let conv1d_cfg = Conv1dConfig {
            // groups == channels => depthwise convolution.
            groups: dim,
            padding: 2,
            ..Default::default()
        };

        // in_proj outputs 4 * hidden = 2 * dim: one half feeds the conv
        // branch, the other is the gate (split in `forward`).
        let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
        let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
        let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
        let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
        Ok(Self {
            in_proj,
            out_proj,
            conv,
            state,
        })
    }

    /// Incremental (single-token) convolution step.
    ///
    /// Slides the 3-step state window (drop the oldest step, append `xs`)
    /// and evaluates the depthwise convolution as an elementwise product
    /// with the kernel followed by a sum over the taps.
    fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
        // roll(-1) + narrow(.., 0, l-1) + cat == shift left by one, drop the
        // wrapped-around element, append the newest step.
        self.state = self.state.roll(-1, D::Minus1)?;
        let (_, _, l) = self.state.dims3()?;
        self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
        self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;

        // Kernel is permuted from (dim, 1, 3) to (1, dim, 3) so it lines up
        // with the state layout; sum over the tap dimension yields the conv
        // output for the single new position.
        let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
            .sum_keepdim(0)?
            .sum(D::Minus1)?;

        // Re-insert the (length-1) sequence dimension.
        let xs = xs.unsqueeze(1)?;

        Ok(xs)
    }

    /// Forward pass of the gated conv block.
    ///
    /// `seqlen_offset > 0` selects the incremental decoding path; otherwise
    /// the whole prompt is convolved at once and the rolling state is primed
    /// with the last (up to 3) input steps for subsequent `step` calls.
    fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let xs = xs.apply(&self.in_proj)?;
        // us[0] feeds the conv branch, us[1] is the gate.
        let us = xs.chunk(2, D::Minus1)?;
        let (_b, l, _d) = us[0].dims3()?;
        let u_conv = if seqlen_offset > 0 {
            self.step(&us[0])?
        } else {
            // Prime the state with the last k prompt steps; for prompts
            // shorter than the kernel the remaining slots keep the
            // (presumably zero-initialized) state values.
            let k = std::cmp::min(3, l);
            self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
            let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
            self.state = Tensor::cat(&[&self.state, &xs], 2)?;

            // padding = 2 pads both ends; truncating the output back to `l`
            // keeps only the causally valid positions.
            us[0]
                .transpose(1, 2)?
                .apply(&self.conv)?
                .narrow(D::Minus1, 0, l)?
                .transpose(1, 2)?
        };

        // SiLU-gated output projected back to hidden_size.
        let u_conv = u_conv.silu()?;
        let v = u_conv.broadcast_mul(&us[1])?;
        let xs = v.apply(&self.out_proj)?;

        Ok(xs)
    }
}
164
// Linear attention approximating softmax using second order Taylor polynomials.
#[derive(Debug, Clone)]
struct LinearAttention {
    proj_q: Linear,
    proj_k: Linear,
    proj_v: Linear,
    out_proj: Linear,
    feature_dim: usize,
    num_heads: usize,
    // Normalization dimension of the Taylor feature map.
    input_dim: usize,
    // Running states for incremental decoding: sum of feature-mapped keys
    // and the key-value outer-product accumulator.
    k_state: Tensor,
    kv_state: Tensor,
}
178
impl LinearAttention {
    /// Loads the q/k/v and output projections and zero-initializes the
    /// running attention states.
    ///
    /// `expanded_size` is the dimension of the second-order Taylor feature
    /// map: 1 (constant) + feature_dim (linear) + feature_dim^2 (quadratic).
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let input_dim = cfg.la.feature_map.input_dim;
        let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
        let proj_k = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_k"),
        )?;
        let proj_q = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_q"),
        )?;

        let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
        let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
        let k_state = Tensor::zeros(
            (1, cfg.la.num_heads, 1, 1, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;
        let kv_state = Tensor::zeros(
            (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;

        Ok(Self {
            proj_q,
            proj_k,
            proj_v,
            out_proj,
            feature_dim: cfg.la.feature_dim,
            num_heads: cfg.la.num_heads,
            input_dim,
            k_state,
            kv_state,
        })
    }

    /// Feature map from the second-order Taylor expansion of exp(<q, k>):
    /// phi(x) = [1, x / d^(1/4), flatten(x ⊗ x) / (sqrt(2) * sqrt(d))]
    /// concatenated along the last dimension, with d = input_dim.
    fn taylor_expansion(&self) -> Result<Func<'static>> {
        let r2 = std::f64::consts::SQRT_2;
        let rd = (self.input_dim as f64).sqrt();
        let rrd = rd.sqrt();

        Ok(Func::new(move |xs| {
            // Shape for the constant `1` feature: like xs, last dim = 1.
            let dims = xs.dims();
            let mut d = dims.to_vec();
            if let Some(last) = d.last_mut() {
                *last = 1;
            };

            // Outer product x ⊗ x, flattened over its last two dims.
            let x = xs
                .unsqueeze(D::Minus1)?
                .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
            let x = (x.flatten_from(D::Minus2)? / r2)?;
            let o = Tensor::ones(d, xs.dtype(), xs.device())?;
            let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;

            Ok(x)
        }))
    }

    /// Forward pass of the linear attention block.
    ///
    /// * Prefill (`seqlen_offset == 0`): full causal linear attention over
    ///   the prompt with a lower-triangular mask and a cumulative-sum
    ///   normalizer; also snapshots `k_state`/`kv_state` for decoding.
    /// * Decode (`seqlen_offset > 0`): folds the newest key/value into the
    ///   running states and computes the output from them.
    fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        // Numerical floor for the normalizer denominators.
        let eps = 1e-12;

        let feature_map = self.taylor_expansion()?;

        let (b, l, d) = xs.dims3()?;
        let q = xs.apply(&self.proj_q)?;
        let k = xs.apply(&self.proj_k)?;
        let v = xs.apply(&self.proj_v)?;

        // (b, l, h * f) -> (b, h, l, f); v uses the per-head value dim.
        let q = q
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, l, self.num_heads, d / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()?;

        // Expand q/k with the Taylor feature map (last dim -> expanded_size).
        let q = feature_map.forward(&q)?;
        let k = feature_map.forward(&k)?;

        let y = if seqlen_offset > 0 {
            let (_b, _h, l, _d) = k.dims4()?;
            let q = q.unsqueeze(D::Minus2)?;
            let k = k.unsqueeze(D::Minus2)?;
            let v = v.unsqueeze(D::Minus1)?;
            // NOTE(review): these narrows act on the *last* dimension (the
            // feature dim for k, the unsqueezed singleton for v), not the
            // sequence dim; with the usual single-token decode (l == 1) the
            // k slice is just the leading constant Taylor feature. Verify
            // against the reference implementation.
            let kn = k.narrow(D::Minus1, l - 1, 1)?;
            let vn = v.narrow(D::Minus1, l - 1, 1)?;

            // Accumulate the running key sum and key-value product.
            self.k_state = self.k_state.broadcast_add(&kn)?;
            self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;

            // y = (q . kv_state) / (q . k_state + eps)
            let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
            let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
            num.broadcast_div(&den)?
        } else {
            // Snapshot states for decoding: sum of feature-mapped keys over
            // the sequence, and k^T v.
            self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
            self.kv_state = k
                .transpose(2, 3)?
                .matmul(&v)?
                .transpose(2, 3)?
                .unsqueeze(2)?;
            // Causal attention: (q k^T) masked to lower-triangular, times v.
            let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
            let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
            let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;

            // Per-position normalizer: 1 / (q . cumsum(k) + eps).
            let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
            aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
        };

        // (b, h, l, dv) -> (b, l, h * dv), then the output projection.
        let (b, h, l, d) = y.dims4()?;
        let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
        let y = self.out_proj.forward(&y)?;

        Ok(y)
    }
}
305
// Rotary embeddings used in local attention.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    // Precomputed tables, one row per position.
    sin: Tensor,
    cos: Tensor,
}
312
impl RotaryEmbedding {
    /// Precomputes the sin/cos tables.
    ///
    /// Frequencies follow the usual `theta^(-2i/dim)` schedule with
    /// `theta = cfg.rope_theta` and `dim` the per-head size; positions are
    /// tabulated up to `max_seq_len`.
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = 2048; // Hardcoded, missing from config.
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        // Outer product of positions and inverse frequencies:
        // freqs[t, i] = t * inv_freq[i].
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    /// Applies rotary embeddings to `q` and `k` (shape `(b, h, seq, dim)`),
    /// starting at absolute position `seqlen_offset` (non-zero during
    /// incremental decoding).
    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}
347
// Local attention using a small sliding window.
#[derive(Debug, Clone)]
struct SlidingWindowAttention {
    // Fused q/k/v projection (checkpoint name "Wqkv").
    wqkv: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    // (keys, values) accumulated across forward calls for decoding.
    kv_cache: Option<(Tensor, Tensor)>,
}
359
impl SlidingWindowAttention {
    /// Loads the fused qkv projection, the output projection, and builds
    /// the rotary embedding tables.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let num_heads = cfg.swa.num_heads;
        let head_dim = hidden_size / num_heads;
        let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
        let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
        Ok(Self {
            wqkv,
            out_proj,
            hidden_size,
            num_heads,
            head_dim,
            rotary_emb,
            kv_cache: None,
        })
    }

    /// Scaled-dot-product attention with rotary embeddings and a kv-cache.
    ///
    /// The sliding window itself is enforced by `attention_mask` (built by
    /// the caller), not inside this method.
    /// NOTE(review): the kv-cache is never truncated to the window size, so
    /// it grows with the generated sequence — confirm this is intended.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Fused projection, split into q/k/v along the 3rd dimension.
        let qkv = xs.apply(&self.wqkv)?;
        let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;

        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;

        // (b, l, h, d) -> (b, h, l, d).
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (q, k) = self
            .rotary_emb
            .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;

        // Prepend cached keys/values, then refresh the cache.
        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((prev_k, prev_v)) => {
                let k = Tensor::cat(&[prev_k, &k], 2)?;
                let v = Tensor::cat(&[prev_v, &v], 2)?;
                (k, v)
            }
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        // Standard softmax attention with 1/sqrt(head_dim) scaling.
        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };
        let attn_weights = softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&v)?;
        // (b, h, l, d) -> (b, l, hidden), then the output projection.
        let out = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.out_proj)?;

        Ok(out)
    }
}
435
// The model layers use three types of mixers.
#[derive(Debug, Clone)]
enum SequenceMixer {
    // Gated short convolution.
    Based(BasedConv),
    // Taylor-feature linear attention.
    Linear(LinearAttention),
    // Windowed softmax attention.
    Sliding(SlidingWindowAttention),
}
443
444impl SequenceMixer {
445    fn forward(
446        &mut self,
447        xs: &Tensor,
448        attention_mask: Option<&Tensor>,
449        pos: usize,
450    ) -> Result<Tensor> {
451        match self {
452            Self::Based(b) => b.forward(xs, pos),
453            Self::Linear(b) => b.forward(xs, pos),
454            Self::Sliding(b) => b.forward(xs, attention_mask, pos),
455        }
456    }
457}
458
/// One pre-norm decoder layer: a sequence mixer followed by an MLP, each
/// wrapped in an RmsNorm + residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    mlp: MLP,
    norm1: RmsNorm,
    norm2: RmsNorm,
    mixer: SequenceMixer,
}
466
impl DecoderLayer {
    /// Builds one layer; the mixer type is selected from the config's
    /// layer-index lists: linear attention for `alt_mixer_layers`,
    /// sliding-window attention for `alt_mixer_2_layers`, and the gated
    /// convolution for all remaining layers.
    fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
        let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;

        let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
        let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);

        let mixer = if l_attn {
            SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
        } else if sw_attn {
            SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
        } else {
            SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
        };

        Ok(Self {
            mlp,
            norm1,
            norm2,
            mixer,
        })
    }

    /// Pre-norm residual block: `x + mixer(norm1(x))` followed by
    /// `x + mlp(norm2(x))`.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.norm1.forward(xs)?;
        let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
        residual + xs
    }
}
507
/// The full Based language model: tied token embedding / lm head, a stack of
/// decoder layers with mixed sequence mixers, and a final RmsNorm.
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: super::with_tracing::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    // Window size for the sliding-window attention mask.
    sliding_window: usize,
    device: Device,
    dtype: DType,
}
518
impl Model {
    /// Builds the full model from a checkpoint.
    ///
    /// The vocabulary is padded up to the next multiple of 8, and the token
    /// embedding reuses (ties) the `lm_head` weight matrix.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
        let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
        let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
        let vb_m = vb.pp("transformer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.swa.window_size,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    /// Builds the additive attention mask for the sliding-window layers:
    /// 0 where attention is allowed, -inf where it is masked.
    ///
    /// Position `i` may attend to `j` when `j <= i` (causal) and
    /// `i - j <= window_size / 2`. Zeros are prepended for the
    /// `seqlen_offset` cached positions so the mask lines up with the
    /// concatenated kv-cache.
    /// NOTE(review): cached positions are never window-masked here —
    /// presumably fine since the mask is only built for multi-token
    /// (prefill) passes; confirm.
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let sliding_window = self.sliding_window / 2;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    /// Runs the model on `input_ids` (shape `(batch, seq_len)`) and returns
    /// the logits for the **last** position only.
    ///
    /// `seqlen_offset` is the number of previously processed tokens; for a
    /// single-token step no mask is built (the mixers then take their
    /// incremental-decoding paths).
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}