1use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
9use candle_nn::{
10 conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
11 Func, Linear, RmsNorm, VarBuilder,
12};
13use std::sync::Arc;
14
/// Configuration of the Taylor-expansion feature map used by the
/// linear-attention mixer (see `LinearAttention::taylor_expansion`).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionFeatureMapConfig {
    // Dimension the feature map normalizes by (sqrt factors in the expansion).
    input_dim: usize,
}
19
/// Configuration of the linear-attention mixer (`alt_mixer` in the checkpoint).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionConfig {
    // Number of linear-attention heads.
    num_heads: usize,
    // Per-head width of the q/k projections (before feature-map expansion).
    feature_dim: usize,
    // Feature-map settings for the Taylor expansion.
    feature_map: LinearAttentionFeatureMapConfig,
}
26
/// Configuration of the sliding-window softmax attention mixer
/// (`alt_mixer_2` in the checkpoint).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SlidingWindowAttentionConfig {
    // Number of attention heads.
    num_heads: usize,
    // Attention window size; the mask uses half of this value as the reach.
    window_size: usize,
}
32
/// Model configuration, deserialized from the checkpoint's `config.json`.
/// Serde renames map the GPT-style `n_*` field names onto descriptive ones.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    vocab_size: usize,
    #[serde(rename = "n_embd")]
    hidden_size: usize,
    #[serde(rename = "n_inner")]
    intermediate_size: usize,
    #[serde(rename = "n_layer")]
    num_hidden_layers: usize,
    #[serde(rename = "n_head")]
    num_attention_heads: usize,

    layer_norm_epsilon: f64,
    // RoPE base frequency; defaults to 10_000 when `rotary_emb_base` is absent.
    #[serde(default = "default_rope", rename = "rotary_emb_base")]
    rope_theta: f64,

    // Layer indices that use the linear-attention mixer.
    alt_mixer_layers: Vec<usize>,
    // Layer indices that use the sliding-window attention mixer.
    alt_mixer_2_layers: Vec<usize>,
    #[serde(rename = "alt_mixer")]
    la: LinearAttentionConfig,
    #[serde(rename = "alt_mixer_2")]
    swa: SlidingWindowAttentionConfig,
}
56
/// Default RoPE base frequency, used when the config omits `rotary_emb_base`.
fn default_rope() -> f64 {
    1e4
}
60
61#[derive(Debug, Clone)]
62#[allow(clippy::upper_case_acronyms)]
63struct MLP {
64 fc1: Linear,
65 fc2: Linear,
66}
67
68impl MLP {
69 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
70 let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
71 let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
72 Ok(Self { fc1, fc2 })
73 }
74}
75
76fn swiglu(xs: &Tensor) -> Result<Tensor> {
79 let xs = xs.chunk(2, D::Minus1)?;
80 &xs[1].silu()? * &xs[0]
81}
82
83impl Module for MLP {
84 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
85 let xs = xs.apply(&self.fc1)?;
86 let xs = swiglu(&xs)?;
87 let xs = xs.apply(&self.fc2)?;
88 Ok(xs)
89 }
90}
91
/// Gated short-convolution sequence mixer (the default mixer of the model).
#[derive(Debug, Clone)]
struct BasedConv {
    // hidden_size -> 4 * hidden_size; output is chunked into value + gate.
    in_proj: Linear,
    // 2 * hidden_size -> hidden_size.
    out_proj: Linear,
    // Depthwise (groups == channels) causal conv, kernel size 3.
    conv: Conv1d,
    // Rolling cache of the last 3 time steps, shape (1, 2*hidden, 3);
    // used during single-token decoding.
    state: Tensor,
}
100
101impl BasedConv {
102 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
103 let dim = cfg.hidden_size * 2;
104
105 let conv1d_cfg = Conv1dConfig {
106 groups: dim,
107 padding: 2,
108 ..Default::default()
109 };
110
111 let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
112 let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
113 let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
114 let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
115 Ok(Self {
116 in_proj,
117 out_proj,
118 conv,
119 state,
120 })
121 }
122
123 fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
124 self.state = self.state.roll(-1, D::Minus1)?;
125 let (_, _, l) = self.state.dims3()?;
126 self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
127 self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;
128
129 let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
130 .sum_keepdim(0)?
131 .sum(D::Minus1)?;
132
133 let xs = xs.unsqueeze(1)?;
134
135 Ok(xs)
136 }
137
138 fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
139 let xs = xs.apply(&self.in_proj)?;
140 let us = xs.chunk(2, D::Minus1)?;
141 let (_b, l, _d) = us[0].dims3()?;
142 let u_conv = if seqlen_offset > 0 {
143 self.step(&us[0])?
144 } else {
145 let k = std::cmp::min(3, l);
146 self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
147 let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
148 self.state = Tensor::cat(&[&self.state, &xs], 2)?;
149
150 us[0]
151 .transpose(1, 2)?
152 .apply(&self.conv)?
153 .narrow(D::Minus1, 0, l)?
154 .transpose(1, 2)?
155 };
156
157 let u_conv = u_conv.silu()?;
158 let v = u_conv.broadcast_mul(&us[1])?;
159 let xs = v.apply(&self.out_proj)?;
160
161 Ok(xs)
162 }
163}
164
165#[derive(Debug, Clone)]
167struct LinearAttention {
168 proj_q: Linear,
169 proj_k: Linear,
170 proj_v: Linear,
171 out_proj: Linear,
172 feature_dim: usize,
173 num_heads: usize,
174 input_dim: usize,
175 k_state: Tensor,
176 kv_state: Tensor,
177}
178
impl LinearAttention {
    /// Builds the linear-attention mixer.
    ///
    /// `expanded_size` is the width of the Taylor feature map:
    /// 1 (constant term) + feature_dim (linear term) + feature_dim^2
    /// (second-order term). The zero-initialized `k_state` / `kv_state`
    /// buffers are the recurrent decoding state; they are fully overwritten
    /// during prefill. Leading dim 1 assumes batch size 1 for cached
    /// decoding — TODO confirm against callers.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let input_dim = cfg.la.feature_map.input_dim;
        let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
        let proj_k = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_k"),
        )?;
        let proj_q = linear_no_bias(
            cfg.hidden_size,
            cfg.la.num_heads * cfg.la.feature_dim,
            vb.pp("proj_q"),
        )?;

        let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
        let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
        let k_state = Tensor::zeros(
            (1, cfg.la.num_heads, 1, 1, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;
        let kv_state = Tensor::zeros(
            (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
            vb.dtype(),
            vb.device(),
        )?;

        Ok(Self {
            proj_q,
            proj_k,
            proj_v,
            out_proj,
            feature_dim: cfg.la.feature_dim,
            num_heads: cfg.la.num_heads,
            input_dim,
            k_state,
            kv_state,
        })
    }

    /// Returns the second-order Taylor feature map:
    /// x -> [1, x / d^(1/4), (x ⊗ x) / (sqrt(2) * sqrt(d))] concatenated on
    /// the last dimension, where d = `input_dim`. The output's last dim is
    /// 1 + dim + dim^2 for an input with last dim `dim`.
    fn taylor_expansion(&self) -> Result<Func<'static>> {
        let r2 = std::f64::consts::SQRT_2;
        let rd = (self.input_dim as f64).sqrt();
        let rrd = rd.sqrt();

        Ok(Func::new(move |xs| {
            // Shape of the constant-term block: same as xs with last dim = 1.
            let dims = xs.dims();
            let mut d = dims.to_vec();
            if let Some(last) = d.last_mut() {
                *last = 1;
            };

            // Outer product x_i * x_j, flattened over the last two dims.
            let x = xs
                .unsqueeze(D::Minus1)?
                .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
            let x = (x.flatten_from(D::Minus2)? / r2)?;
            let o = Tensor::ones(d, xs.dtype(), xs.device())?;
            let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;

            Ok(x)
        }))
    }

    /// Linear-attention forward pass.
    ///
    /// Prefill (`seqlen_offset == 0`): computes causally masked quadratic
    /// attention in feature space and (re)initializes the recurrent
    /// `k_state` / `kv_state` from the whole prefix. Decode
    /// (`seqlen_offset > 0`): updates the recurrent states with the newest
    /// token and evaluates the linear-attention recurrence directly.
    fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        // Numerical floor for the normalizer denominator.
        let eps = 1e-12;

        let feature_map = self.taylor_expansion()?;

        let (b, l, d) = xs.dims3()?;
        let q = xs.apply(&self.proj_q)?;
        let k = xs.apply(&self.proj_k)?;
        let v = xs.apply(&self.proj_v)?;

        // Split into heads: q/k -> (b, h, l, feature_dim), v -> (b, h, l, d/h).
        let q = q
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b, l, self.num_heads, self.feature_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, l, self.num_heads, d / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()?;

        // Expand q/k through the feature map (last dim grows to 1 + f + f^2).
        let q = feature_map.forward(&q)?;
        let k = feature_map.forward(&k)?;

        let y = if seqlen_offset > 0 {
            // Recurrent decode: fold the newest expanded key (and key·value
            // outer product) into the running states, then evaluate
            // y = (q · kv_state) / (q · k_state).
            let (_b, _h, l, _d) = k.dims4()?;
            let q = q.unsqueeze(D::Minus2)?;
            let k = k.unsqueeze(D::Minus2)?;
            let v = v.unsqueeze(D::Minus1)?;
            let kn = k.narrow(D::Minus1, l - 1, 1)?;
            let vn = v.narrow(D::Minus1, l - 1, 1)?;

            self.k_state = self.k_state.broadcast_add(&kn)?;
            self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;

            let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
            let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
            num.broadcast_div(&den)?
        } else {
            // Prefill: rebuild the recurrent states from the full prefix.
            self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
            self.kv_state = k
                .transpose(2, 3)?
                .matmul(&v)?
                .transpose(2, 3)?
                .unsqueeze(2)?;
            // Quadratic path with a causal (lower-triangular) mask.
            let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
            let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
            let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;

            // Per-position normalizer: 1 / (q · cumsum(k)).
            let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
            aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
        };

        // Merge heads back: (b, h, l, d') -> (b, l, h * d').
        let (b, h, l, d) = y.dims4()?;
        let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
        let y = self.out_proj.forward(&y)?;

        Ok(y)
    }
}
305
306#[derive(Debug, Clone)]
308struct RotaryEmbedding {
309 sin: Tensor,
310 cos: Tensor,
311}
312
313impl RotaryEmbedding {
314 fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
315 let dim = cfg.hidden_size / cfg.num_attention_heads;
316 let max_seq_len = 2048; let inv_freq: Vec<_> = (0..dim)
318 .step_by(2)
319 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
320 .collect();
321 let inv_freq_len = inv_freq.len();
322 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
323 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
324 .to_dtype(dtype)?
325 .reshape((max_seq_len, 1))?;
326 let freqs = t.matmul(&inv_freq)?;
327 Ok(Self {
328 sin: freqs.sin()?,
329 cos: freqs.cos()?,
330 })
331 }
332
333 fn apply_rotary_emb_qkv(
334 &self,
335 q: &Tensor,
336 k: &Tensor,
337 seqlen_offset: usize,
338 ) -> Result<(Tensor, Tensor)> {
339 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
340 let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
341 let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
342 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
343 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
344 Ok((q_embed, k_embed))
345 }
346}
347
348#[derive(Debug, Clone)]
350struct SlidingWindowAttention {
351 wqkv: Linear,
352 out_proj: Linear,
353 num_heads: usize,
354 head_dim: usize,
355 hidden_size: usize,
356 rotary_emb: Arc<RotaryEmbedding>,
357 kv_cache: Option<(Tensor, Tensor)>,
358}
359
360impl SlidingWindowAttention {
361 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
362 let hidden_size = cfg.hidden_size;
363 let num_heads = cfg.swa.num_heads;
364 let head_dim = hidden_size / num_heads;
365 let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
366 let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
367 let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
368 Ok(Self {
369 wqkv,
370 out_proj,
371 hidden_size,
372 num_heads,
373 head_dim,
374 rotary_emb,
375 kv_cache: None,
376 })
377 }
378
379 fn forward(
380 &mut self,
381 xs: &Tensor,
382 attention_mask: Option<&Tensor>,
383 seqlen_offset: usize,
384 ) -> Result<Tensor> {
385 let (b_sz, q_len, _) = xs.dims3()?;
386
387 let qkv = xs.apply(&self.wqkv)?;
388 let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;
389
390 let q = qkv.i((.., .., 0))?;
391 let k = qkv.i((.., .., 1))?;
392 let v = qkv.i((.., .., 2))?;
393
394 let q = q
395 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
396 .transpose(1, 2)?;
397 let k = k
398 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
399 .transpose(1, 2)?;
400 let v = v
401 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
402 .transpose(1, 2)?;
403
404 let (q, k) = self
405 .rotary_emb
406 .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
407
408 let (k, v) = match &self.kv_cache {
409 None => (k, v),
410 Some((prev_k, prev_v)) => {
411 let k = Tensor::cat(&[prev_k, &k], 2)?;
412 let v = Tensor::cat(&[prev_v, &v], 2)?;
413 (k, v)
414 }
415 };
416 self.kv_cache = Some((k.clone(), v.clone()));
417
418 let scale = 1f64 / f64::sqrt(self.head_dim as f64);
419 let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
420
421 let attn_weights = match attention_mask {
422 None => attn_weights,
423 Some(mask) => attn_weights.broadcast_add(mask)?,
424 };
425 let attn_weights = softmax_last_dim(&attn_weights)?;
426 let attn_output = attn_weights.matmul(&v)?;
427 let out = attn_output
428 .transpose(1, 2)?
429 .reshape((b_sz, q_len, self.hidden_size))?
430 .apply(&self.out_proj)?;
431
432 Ok(out)
433 }
434}
435
436#[derive(Debug, Clone)]
438enum SequenceMixer {
439 Based(BasedConv),
440 Linear(LinearAttention),
441 Sliding(SlidingWindowAttention),
442}
443
444impl SequenceMixer {
445 fn forward(
446 &mut self,
447 xs: &Tensor,
448 attention_mask: Option<&Tensor>,
449 pos: usize,
450 ) -> Result<Tensor> {
451 match self {
452 Self::Based(b) => b.forward(xs, pos),
453 Self::Linear(b) => b.forward(xs, pos),
454 Self::Sliding(b) => b.forward(xs, attention_mask, pos),
455 }
456 }
457}
458
459#[derive(Debug, Clone)]
460struct DecoderLayer {
461 mlp: MLP,
462 norm1: RmsNorm,
463 norm2: RmsNorm,
464 mixer: SequenceMixer,
465}
466
467impl DecoderLayer {
468 fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
469 let mlp = MLP::new(cfg, vb.pp("mlp"))?;
470 let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
471 let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;
472
473 let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
474 let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);
475
476 let mixer = if l_attn {
477 SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
478 } else if sw_attn {
479 SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
480 } else {
481 SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
482 };
483
484 Ok(Self {
485 mlp,
486 norm1,
487 norm2,
488 mixer,
489 })
490 }
491
492 fn forward(
493 &mut self,
494 xs: &Tensor,
495 attention_mask: Option<&Tensor>,
496 seqlen_offset: usize,
497 ) -> Result<Tensor> {
498 let residual = xs;
499 let xs = self.norm1.forward(xs)?;
500 let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
501 let xs = (xs + residual)?;
502 let residual = &xs;
503 let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
504 residual + xs
505 }
506}
507
508#[derive(Debug, Clone)]
509pub struct Model {
510 embed_tokens: super::with_tracing::Embedding,
511 layers: Vec<DecoderLayer>,
512 norm: RmsNorm,
513 lm_head: Linear,
514 sliding_window: usize,
515 device: Device,
516 dtype: DType,
517}
518
519impl Model {
520 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
521 let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
522 let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
523 let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
524 let vb_m = vb.pp("transformer");
525 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
526 let vb_l = vb_m.pp("layers");
527 for layer_idx in 0..cfg.num_hidden_layers {
528 let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
529 layers.push(layer)
530 }
531 let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
532 Ok(Self {
533 embed_tokens,
534 layers,
535 norm,
536 lm_head,
537 sliding_window: cfg.swa.window_size,
538 device: vb.device().clone(),
539 dtype: vb.dtype(),
540 })
541 }
542
543 fn prepare_decoder_attention_mask(
544 &self,
545 b_size: usize,
546 tgt_len: usize,
547 seqlen_offset: usize,
548 ) -> Result<Tensor> {
549 let sliding_window = self.sliding_window / 2;
550 let mask: Vec<_> = (0..tgt_len)
551 .flat_map(|i| {
552 (0..tgt_len).map(move |j| {
553 if i < j || j + sliding_window < i {
554 f32::NEG_INFINITY
555 } else {
556 0.
557 }
558 })
559 })
560 .collect();
561 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
562 let mask = if seqlen_offset > 0 {
563 let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
564 Tensor::cat(&[&mask0, &mask], D::Minus1)?
565 } else {
566 mask
567 };
568 mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
569 .to_dtype(self.dtype)
570 }
571
572 pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
573 let (b_size, seq_len) = input_ids.dims2()?;
574 let attention_mask = if seq_len <= 1 {
575 None
576 } else {
577 let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
578 Some(mask)
579 };
580 let mut xs = self.embed_tokens.forward(input_ids)?;
581 for layer in self.layers.iter_mut() {
582 xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
583 }
584 xs.narrow(1, seq_len - 1, 1)?
585 .apply(&self.norm)?
586 .apply(&self.lm_head)
587 }
588}