use crate::nn::{
    linear, linear_from, matmul_dtype, MaybeQuantizedLinear, MaybeQuantizedVarBuilder,
};
use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};

use crate::kv_cache::KvCache;
use candle::Context;

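/// Configuration for the streaming transformer, typically deserialized from the model's
/// JSON config. `cross_attention` bundles the gating mode, the norm type applied before
/// cross-attention, and an optional key/value input dimension.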
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub d_model: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    pub causal: bool,
    pub norm_first: bool,
    pub bias_ff: bool,
    pub bias_attn: bool,
    pub layer_scale: Option<f64>,
    pub positional_embedding: PositionalEmbedding,
    pub use_conv_block: bool,
    pub cross_attention: Option<(CrossAttentionGating, crate::NormType, Option<usize>)>,
    pub conv_kernel_size: usize,
    pub use_conv_bias: bool,
    pub gating: Option<candle_nn::Activation>,
    pub norm: crate::NormType,
    pub context: usize,
    pub max_period: usize,
    pub max_seq_len: usize,

    pub kv_repeat: usize,
    pub dim_feedforward: usize,
    pub conv_layout: bool,

    #[serde(default)]
    pub shared_cross_attn: bool,
}

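/// Positional embedding scheme applied in the self-attention layers.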
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum PositionalEmbedding {
    Rope,
    Sin,
    None,
}

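/// How the output of the cross-attention block is gated before being added back to the
/// residual stream: passed through unchanged, scaled by a learned constant (tanh- or
/// sigmoid-squashed), or scaled by a small input-conditioned gate.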
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CrossAttentionGating {
    Normal,
    ConstantGatedTanh,
    ConstantGatedSigmoid,
    ConditionalGatedTanh,
    ConditionalGatedSigmoid,
    ConditionalGatedSigmoidLearnableBias,
    ConditionalGatedTanhLearnableBias,
}

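/// Source used for cross-attention: either a tensor of conditioning tokens, or
/// pre-computed key/value tensors (see `maybe_precompute_ca_kv`).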
#[derive(Debug, Clone)]
pub enum CaSrc {
    Tokens(Tensor),
    KeysValues((Tensor, Tensor)),
}

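/// Per-channel learned scaling applied to a residual branch (LayerScale). The `_init`
/// argument is unused here since the scale is always loaded from the checkpoint.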
#[derive(Debug, Clone)]
pub struct LayerScale {
    scale: Tensor,
}

impl LayerScale {
    pub fn new(d_model: usize, _init: f64, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let scale = vb.get_unquantized(d_model, "scale")?;
        Ok(Self { scale })
    }
}

impl Module for LayerScale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&self.scale)
    }
}

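/// Runtime counterpart of [`CrossAttentionGating`]: multiplies the cross-attention output
/// by 1 (`Normal`), by a constant per-model gate, or by a gate predicted from the input
/// through a small two-layer projection with a ReLU in between.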
#[derive(Debug, Clone)]
pub enum XaGate {
    Normal,
    ConstantGated {
        alpha: Tensor,
    },
    ConditionalGated {
        in_proj: MaybeQuantizedLinear,
        out_proj: MaybeQuantizedLinear,
        activation: candle_nn::init::NonLinearity,
        learnable_bias: bool,
    },
}

impl XaGate {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let gating_cfg =
            cfg.cross_attention.map(|v| v.0).context("no cross-attention specified")?;
        match gating_cfg {
            CrossAttentionGating::Normal => Ok(Self::Normal),
            CrossAttentionGating::ConstantGatedTanh => {
                let alpha = vb.get_unquantized((1, 1, 1), "alpha")?.tanh()?;
                Ok(Self::ConstantGated { alpha })
            }
            CrossAttentionGating::ConstantGatedSigmoid => {
                let alpha =
                    candle_nn::ops::sigmoid(&(vb.get_unquantized((1, 1, 1), "alpha")? - 4.0)?)?;
                Ok(Self::ConstantGated { alpha })
            }
            CrossAttentionGating::ConditionalGatedTanh
            | CrossAttentionGating::ConditionalGatedSigmoid
            | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
            | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                let dim = cfg.d_model;
                let hidden_dims = (0.125 * dim as f32).floor() as usize;
                let learnable_bias = matches!(
                    gating_cfg,
                    CrossAttentionGating::ConditionalGatedSigmoidLearnableBias
                        | CrossAttentionGating::ConditionalGatedTanhLearnableBias
                );
                let in_proj = linear(dim, hidden_dims, false, vb.pp("alpha.0"))?;
                let out_proj = linear(hidden_dims, dim, learnable_bias, vb.pp("alpha.2"))?;
                let activation = match gating_cfg {
                    CrossAttentionGating::ConditionalGatedTanh
                    | CrossAttentionGating::ConditionalGatedTanhLearnableBias => {
                        candle_nn::init::NonLinearity::Tanh
                    }
                    CrossAttentionGating::ConditionalGatedSigmoid
                    | CrossAttentionGating::ConditionalGatedSigmoidLearnableBias => {
                        candle_nn::init::NonLinearity::Sigmoid
                    }
                    _ => candle::bail!("Invalid cross-attention config specified."),
                };
                Ok(Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias })
            }
        }
    }
}

impl Module for XaGate {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Normal => Ok(xs.clone()),
            Self::ConstantGated { alpha } => xs.broadcast_mul(alpha),
            Self::ConditionalGated { in_proj, out_proj, activation, learnable_bias } => {
                let alpha = xs.apply(in_proj)?.relu()?.apply(out_proj)?;
                let alpha = match (activation, learnable_bias) {
                    (candle_nn::init::NonLinearity::Tanh, _) => alpha.tanh(),
                    (candle_nn::init::NonLinearity::Sigmoid, true) => {
                        candle_nn::ops::sigmoid(&alpha)
                    }
                    (candle_nn::init::NonLinearity::Sigmoid, false) => {
                        candle_nn::ops::sigmoid(&(alpha - 4.0)?)
                    }
                    _ => candle::bail!("Invalid non-linearity specified in cross-attention gating"),
                };
                xs * alpha?
            }
        }
    }
}

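/// Multi-head cross-attention over a (possibly pre-computed) set of keys and values.
/// The module keeps no state across steps: keys/values come from the provided `CaSrc`
/// rather than a KV cache, and the output is passed through an [`XaGate`].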
#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
    in_proj_q: MaybeQuantizedLinear,
    in_proj_kv: MaybeQuantizedLinear,
    out_proj: MaybeQuantizedLinear,
    kv_repeat: usize,
    num_heads: usize,
    gate: XaGate,
    span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
    pub fn new(
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
        gate_vb: Option<MaybeQuantizedVarBuilder>,
    ) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let out_kv_dim = num_kv * (embed_dim / cfg.num_heads);
        let out_dim = embed_dim + 2 * out_kv_dim;
        let (in_proj_q, in_proj_kv) = if vb.contains_key("in_proj_weight") {
            match &vb {
                MaybeQuantizedVarBuilder::Quantized(_) => candle::bail!(
                    "Quantized cross-attention layers require a separate in_proj_weight_q and in_proj_weight_kv"
                ),
                MaybeQuantizedVarBuilder::Real(weights) => {
                    let in_proj_weight = weights.get((out_dim, embed_dim), "in_proj_weight")?;
                    let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
                    let in_proj_weight_kv = in_proj_weight.narrow(0, embed_dim, 2 * out_kv_dim)?;
                    let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn {
                        let b = weights.get(out_dim, "in_proj_bias")?;
                        let in_proj_bias_q = b.narrow(0, 0, embed_dim)?;
                        let in_proj_bias_kv = b.narrow(0, embed_dim, 2 * out_kv_dim)?;
                        (Some(in_proj_bias_q), Some(in_proj_bias_kv))
                    } else {
                        (None, None)
                    };
                    (
                        MaybeQuantizedLinear::Real(candle_nn::Linear::new(
                            in_proj_weight_q,
                            in_proj_bias_q,
                        )),
                        MaybeQuantizedLinear::Real(candle_nn::Linear::new(
                            in_proj_weight_kv,
                            in_proj_bias_kv,
                        )),
                    )
                }
            }
        } else {
            let kv_in_dim = match cfg.cross_attention.map(|v| v.2) {
                None => candle::bail!("cfg.cross_attention is None in cross_attention module"),
                Some(d) => match d {
                    None | Some(0) => embed_dim,
                    Some(dd) => dd,
                },
            };
            let in_proj_weight_q = vb.get((embed_dim, embed_dim), "in_proj_weight_q")?;
            let in_proj_weight_kv = vb.get((2 * out_kv_dim, kv_in_dim), "in_proj_weight_kv")?;

            let (in_proj_bias_q, in_proj_bias_kv) = if cfg.bias_attn {
                (
                    Some(vb.get_unquantized(embed_dim, "in_proj_bias_q")?),
                    Some(vb.get_unquantized(2 * out_kv_dim, "in_proj_bias_kv")?),
                )
            } else {
                (None, None)
            };

            let in_proj_q = linear_from(in_proj_weight_q, in_proj_bias_q)?;
            let in_proj_kv = linear_from(in_proj_weight_kv, in_proj_bias_kv)?;
            (in_proj_q, in_proj_kv)
        };

        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        let gate = match gate_vb {
            None => XaGate::new(cfg, vb.pp("gate"))?,
            Some(layer_gate_vb) => XaGate::new(cfg, layer_gate_vb)?,
        };
        Ok(Self {
            in_proj_q,
            in_proj_kv,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            gate,
            span: tracing::span!(tracing::Level::TRACE, "mhca"),
        })
    }

    pub fn is_quantized(&self) -> bool {
        match self.in_proj_q {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

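    /// Projects a `CaSrc` into per-head key/value tensors of shape
    /// `(batch, num_heads, seq_len, head_dim)`; pre-computed keys/values are returned
    /// as-is.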
    pub fn compute_kv(&self, ca_src: &CaSrc) -> Result<(Tensor, Tensor)> {
        match ca_src {
            CaSrc::KeysValues(cakv) => Ok(cakv.clone()),
            CaSrc::Tokens(xs) => {
                let kv = xs.apply(&self.in_proj_kv)?;
                let (ca_b, ca_t, ca_dim) = kv.dims3()?;
                let head_dim = ca_dim / (2 * self.num_heads);
                let kv = kv.reshape((ca_b, ca_t, 2, (), head_dim))?;
                let kv =
                    if self.is_quantized() { kv.to_dtype(matmul_dtype(xs.device()))? } else { kv };
                let k = kv.i((.., .., 0))?;
                let v = kv.i((.., .., 1))?;
                let k = k.transpose(1, 2)?.contiguous()?;
                let v = v.transpose(1, 2)?.contiguous()?;
                Ok((k, v))
            }
        }
    }

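    /// Cross-attention forward pass. `mask`, when provided, is broadcast-added to the
    /// attention logits before the softmax.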
    pub fn forward(&self, xs: &Tensor, ca_src: &CaSrc, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        let q = xs.apply(&self.in_proj_q)?;
        let original_dtype = q.dtype();
        let q = q.reshape((b, t, self.num_heads, head_dim))?;
        let q = if self.is_quantized() { q.to_dtype(matmul_dtype(xs.device()))? } else { q };
        let (k, v) = self.compute_kv(ca_src)?;
        let q = q.transpose(1, 2)?.contiguous()?;
        let pre_ws = q.matmul(&k.t()?)?;
        let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

        let pre_ws = match mask {
            None => pre_ws,
            Some(mask) => pre_ws.broadcast_add(mask)?,
        };

        let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?;
        let xs = ws.matmul(&v)?;
        let xs = xs
            .transpose(1, 2)?
            .reshape((b, t, hd))?
            .to_dtype(original_dtype)?
            .apply(&self.out_proj)?
            .apply(&self.gate)?;
        Ok(xs)
    }
}

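/// Pre-computed sine/cosine tables for rotary position embeddings, applied with the
/// interleaved (`rope_i`) convention in f32 before casting back to the input dtype.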
#[derive(Debug, Clone)]
pub struct Rope {
    sin: Tensor,
    cos: Tensor,
}

impl Rope {
    pub fn apply_rotary_emb(&self, qk: &Tensor) -> Result<Tensor> {
        let qk_dtype = qk.dtype();
        candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &self.cos, &self.sin)?
            .to_dtype(qk_dtype)
    }
}

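/// Holds the inverse frequencies used to build a [`Rope`] for a given set of positions.
/// A minimal sketch of how it is used (hypothetical standalone usage, assuming a CPU
/// device and a head dimension of 64):
///
/// ```ignore
/// let rotary = RotaryEmbedding::new(64, 10_000.0, &candle::Device::Cpu)?;
/// let pos = candle::Tensor::arange(0u32, 8u32, &candle::Device::Cpu)?;
/// let rope = rotary.rope(&pos)?; // sin/cos tables for positions 0..8
/// ```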
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    inv_freq: Tensor,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, theta: f32, dev: &Device) -> Result<Self> {
        let inv_freq: Vec<_> =
            (0..dim).step_by(2).map(|i| 1f32 / theta.powf(i as f32 / dim as f32)).collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        Ok(Self { inv_freq })
    }

    pub fn rope(&self, pos: &Tensor) -> Result<Rope> {
        let t = pos.to_dtype(DType::F32)?;
        let freqs = match *t.dims() {
            [d] => t.reshape((d, 1))?.matmul(&self.inv_freq)?,
            [b, d] => t.reshape((b * d, 1))?.matmul(&self.inv_freq)?.reshape((b, d, ()))?,
            _ => candle::bail!("Invalid shape for rotary embedding {pos:?}"),
        };
        Ok(Rope { sin: freqs.sin()?, cos: freqs.cos()? })
    }
}

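// Thin wrapper around the optional `flash-attn` backend; the stub below keeps the call
// site identical when the feature is disabled.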
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}

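/// Streaming multi-head self-attention with a rolling KV cache limited to `context`
/// past positions. Only `kv_repeat == 1` (no grouped-query attention) is currently
/// supported by `forward`.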
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
    in_proj: MaybeQuantizedLinear,
    out_proj: MaybeQuantizedLinear,
    kv_repeat: usize,
    num_heads: usize,
    context: usize,
    kv_cache: KvCache,
    use_flash_attn: bool,
    span: tracing::Span,
}

impl StreamingMultiheadAttention {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_kv = cfg.num_heads / cfg.kv_repeat;
        let out_dim = embed_dim + 2 * num_kv * (embed_dim / cfg.num_heads);
        let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
        let in_proj_bias =
            if cfg.bias_attn { Some(vb.get_unquantized(out_dim, "in_proj_bias")?) } else { None };
        let in_proj = linear_from(in_proj_weight, in_proj_bias)?;
        let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
        Ok(Self {
            in_proj,
            out_proj,
            kv_repeat: cfg.kv_repeat,
            num_heads: cfg.num_heads,
            context: cfg.context,
            kv_cache: KvCache::new(2, cfg.context),
            use_flash_attn: false,
            span: tracing::span!(tracing::Level::TRACE, "mha"),
        })
    }

    pub fn is_quantized(&self) -> bool {
        match self.in_proj {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if self.kv_repeat != 1 {
            candle::bail!("only kv-repeat = 1 is supported")
        }
        let (b, t, hd) = xs.dims3()?;
        let head_dim = hd / self.num_heads;
        let qkv = xs.apply(&self.in_proj)?.reshape((b, t, 3, self.num_heads, head_dim))?;
        let original_dtype = qkv.dtype();
        let qkv = if self.is_quantized() { qkv.to_dtype(matmul_dtype(xs.device()))? } else { qkv };
        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;
        let mut q = q.transpose(1, 2)?.contiguous()?;
        let mut k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;
        if let Some(rope) = rope.as_ref() {
            q = rope.apply_rotary_emb(&q)?;
            k = rope.apply_rotary_emb(&k)?;
        }

        let (k, v) = { self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)? };
        let k_len = k.dim(2)?;
        let k_target_len = t + usize::min(self.context, k_len - t);
        let (k, v) = if k_target_len < k_len {
            let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
            let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
            (k, v)
        } else {
            (k.clone(), v.clone())
        };

        let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            let softmax_scale = 1f32 / (head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, softmax_scale, mask.is_some())?.transpose(1, 2)?
        } else {
            let pre_ws = q.matmul(&k.t()?)?;
            let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;

            let pre_ws = match mask {
                None => pre_ws,
                Some(mask) => pre_ws.broadcast_add(mask)?,
            };

            let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?;
            ws.matmul(&v)?
        };

        let xs = xs
            .transpose(1, 2)?
            .reshape((b, t, hd))?
            .to_dtype(original_dtype)?
            .apply(&self.out_proj)?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.kv_cache.reset()
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.kv_cache = kv_cache
    }
}

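/// Feed-forward block: either a plain `linear1 -> GELU(erf) -> linear2` stack, or a gated
/// variant where the input projection produces two halves combined as `act(a) * b`
/// (SwiGLU-style when the activation is SiLU).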
#[derive(Debug, Clone)]
pub enum Mlp {
    NoGating {
        linear1: MaybeQuantizedLinear,
        linear2: MaybeQuantizedLinear,
    },
    Gating {
        linear_in: MaybeQuantizedLinear,
        linear_out: MaybeQuantizedLinear,
        activation: candle_nn::Activation,
    },
}

impl Mlp {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let d_model = cfg.d_model;
        match cfg.gating {
            None => {
                let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?;
                let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?;
                Ok(Self::NoGating { linear1, linear2 })
            }
            Some(activation) => {
                let vb = vb.pp("gating");
                let hidden = if cfg.dim_feedforward == 4 * d_model {
                    11 * d_model / 4
                } else {
                    2 * cfg.dim_feedforward / 3
                };
                let linear_in = linear(d_model, 2 * hidden, cfg.bias_ff, vb.pp("linear_in"))?;
                let linear_out = linear(hidden, d_model, cfg.bias_ff, vb.pp("linear_out"))?;
                Ok(Self::Gating { linear_in, linear_out, activation })
            }
        }
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2),
            Self::Gating { linear_in, linear_out, activation } => {
                let xs = xs.apply(linear_in)?;
                let (b, t, _) = xs.dims3()?;
                let xs = xs.reshape((b, t, 2, ()))?;
                let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
                xs.apply(linear_out)
            }
        }
    }
}

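/// RMS normalization with a learned per-channel scale stored as `alpha`.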
#[derive(Debug, Clone)]
pub struct RmsNorm {
    pub(crate) alpha: Tensor,
    pub(crate) eps: f32,
}

impl RmsNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let alpha = vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?;
        Ok(Self { alpha, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
    }
}

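/// Layer normalization wrapper that accepts either an `alpha` or a `weight` tensor for
/// the scale, depending on how the checkpoint names it.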
#[derive(Debug, Clone)]
pub struct LayerNorm {
    inner: candle_nn::LayerNorm,
}

impl LayerNorm {
    pub fn new(d_model: usize, eps: f32, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let bias = vb.get_unquantized(d_model, "bias")?;
        let alpha = if vb.contains_key("alpha") {
            vb.get_unquantized((1, 1, d_model), "alpha")?.reshape(d_model)?
        } else {
            vb.get_unquantized(d_model, "weight")?.reshape(d_model)?
        };
        let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64);
        Ok(Self { inner })
    }
}

impl Module for LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.inner.forward(xs)
    }
}

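/// Either layer norm (eps 1e-5) or RMS norm (eps 1e-8), selected from the config.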
#[derive(Debug, Clone)]
pub enum Norm {
    LayerNorm(LayerNorm),
    RmsNorm(RmsNorm),
}

impl Norm {
    pub fn new(d_model: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let norm = Self::new_shortcut(d_model, cfg.norm, vb)?;
        Ok(norm)
    }

    pub fn new_shortcut(
        d_model: usize,
        typ: crate::NormType,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let norm = match typ {
            crate::NormType::LayerNorm => {
                let norm = LayerNorm::new(d_model, 1e-5, vb)?;
                Self::LayerNorm(norm)
            }
            crate::NormType::RmsNorm => {
                let norm = RmsNorm::new(d_model, 1e-8, vb)?;
                Self::RmsNorm(norm)
            }
        };
        Ok(norm)
    }
}

impl Module for Norm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::LayerNorm(m) => m.forward(xs),
            Self::RmsNorm(m) => m.forward(xs),
        }
    }
}

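/// A single pre-norm transformer layer: self-attention, optional gated cross-attention,
/// and an MLP, each wrapped in a residual connection with optional LayerScale.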
#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
    self_attn: StreamingMultiheadAttention,
    mlp: Mlp,
    norm1: Norm,
    norm2: Norm,
    layer_scale_1: Option<LayerScale>,
    layer_scale_2: Option<LayerScale>,
    cross_attn: Option<(Norm, StreamingMultiheadCrossAttention)>,
    norm_first: bool,
    span: tracing::Span,
}

impl StreamingTransformerLayer {
    pub fn new(
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
        shared_ca_vb: Option<MaybeQuantizedVarBuilder>,
    ) -> Result<Self> {
        if cfg.use_conv_block {
            candle::bail!("conv-block is not supported")
        }
        let d_model = cfg.d_model;
        let mlp = Mlp::new(cfg, vb.clone())?;
        let norm1 = Norm::new(d_model, cfg, vb.pp("norm1"))?;
        let norm2 = Norm::new(d_model, cfg, vb.pp("norm2"))?;
        let layer_scale_1 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_1"))?;
                Some(ls)
            }
        };
        let layer_scale_2 = match cfg.layer_scale {
            None => None,
            Some(ls) => {
                let ls = LayerScale::new(d_model, ls, vb.pp("layer_scale_2"))?;
                Some(ls)
            }
        };
        let self_attn = StreamingMultiheadAttention::new(cfg, vb.pp("self_attn"))?;
        let cross_attn = match cfg.cross_attention.map(|v| v.1) {
            Some(norm_type) => {
                let norm_cross = Norm::new_shortcut(d_model, norm_type, vb.pp("norm_cross"))?;
                let cross_attn = match shared_ca_vb {
                    None => {
                        StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"), None)?
                    }
                    Some(shared_vb) => StreamingMultiheadCrossAttention::new(
                        cfg,
                        shared_vb.pp("cross_attention"),
                        Some(vb.pp("cross_attention.gate")),
                    )?,
                };
                Some((norm_cross, cross_attn))
            }
            None => None,
        };
        Ok(Self {
            self_attn,
            mlp,
            norm1,
            norm2,
            layer_scale_1,
            layer_scale_2,
            cross_attn,
            norm_first: cfg.norm_first,
            span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
        })
    }

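    /// Runs one layer. Only the pre-norm (`norm_first = true`) path is implemented;
    /// cross-attention is applied only when the layer has a cross-attention block and a
    /// `ca_src` is provided.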
    pub fn forward(
        &mut self,
        xs: &Tensor,
        rope: Option<&Rope>,
        ca_src: Option<&CaSrc>,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        if !self.norm_first {
            candle::bail!("only norm_first = true is supported")
        }
        let norm1 = xs.apply(&self.norm1)?;
        let xs = (xs
            + self.self_attn.forward(&norm1, rope, mask)?.apply(&self.layer_scale_1.as_ref())?)?;

        let xs = match (self.cross_attn.as_mut(), ca_src) {
            (Some((norm_cross, cross_attn)), Some(ca_src)) => {
                let residual = &xs;
                let xs = xs.apply(norm_cross)?;
                (residual + cross_attn.forward(&xs, ca_src, None)?)?
            }
            _ => xs,
        };

        let xs =
            (&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?.apply(&self.layer_scale_2.as_ref()))?;
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache();
    }

    pub fn set_kv_cache(&mut self, kv_cache: KvCache) {
        self.self_attn.set_kv_cache(kv_cache);
    }
}

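/// Stack of streaming transformer layers. `last_reset_pos` tracks, per batch entry, the
/// KV-cache position at which that entry was last reset so that attention can be masked
/// accordingly.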
#[derive(Debug, Clone)]
pub struct StreamingTransformer {
    layers: Vec<StreamingTransformerLayer>,
    positional_embedding: PositionalEmbedding,
    max_period: usize,
    causal: bool,
    num_heads: usize,
    context: usize,
    last_reset_pos: Vec<usize>,
    rope: Option<RotaryEmbedding>,
}

impl StreamingTransformer {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let rope = match cfg.positional_embedding {
            PositionalEmbedding::Rope => {
                let rope = RotaryEmbedding::new(
                    cfg.d_model / cfg.num_heads,
                    cfg.max_period as f32,
                    vb.device(),
                )?;
                Some(rope)
            }
            PositionalEmbedding::None | PositionalEmbedding::Sin => None,
        };
        let mut layers = Vec::with_capacity(cfg.num_layers);
        for layer_idx in 0..cfg.num_layers {
            let shared_vb = if cfg.shared_cross_attn { Some(vb_l.pp(0)) } else { None };
            let layer = StreamingTransformerLayer::new(cfg, vb_l.pp(layer_idx), shared_vb)?;
            layers.push(layer)
        }
        Ok(Self {
            layers,
            positional_embedding: cfg.positional_embedding,
            max_period: cfg.max_period,
            causal: cfg.causal,
            num_heads: cfg.num_heads,
            context: cfg.context,
            last_reset_pos: vec![],
            rope,
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        self.forward_ca(xs, None)
    }

    fn current_seq_len(&self) -> usize {
        self.layers[0].self_attn.kv_cache.current_seq_len()
    }

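    /// Forward pass with an optional cross-attention source. Builds an attention mask
    /// that is causal, limited to `context` past positions, and that ignores positions
    /// written before a batch entry's last reset; the mask is skipped in the common
    /// single-step case where no entry was recently reset.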
    pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&CaSrc>) -> Result<Tensor> {
        let (b, t, c) = xs.dims3()?;
        if !self.causal {
            candle::bail!("only causal mode is supported")
        }
        if self.last_reset_pos.is_empty() {
            self.last_reset_pos.resize(b, 0);
        }
        let current_seq_len = self.current_seq_len();
        let mask = {
            let ks = self.layers[0].self_attn.kv_cache.positions(t);
            let min_ks = ks.iter().min().context("no positions, is t == 0?")?;
            if t == 1 && self.last_reset_pos.iter().all(|v| v <= min_ks) {
                None
            } else {
                let mut mask = Vec::with_capacity(b * self.num_heads * t * ks.len());
                for &last_reset_pos in self.last_reset_pos.iter() {
                    for t_pos in 0..t {
                        let t_pos = t_pos + current_seq_len;
                        for &k_pos in ks.iter() {
                            let m = if last_reset_pos <= k_pos
                                && k_pos <= t_pos
                                && t_pos <= k_pos + self.context
                            {
                                0f32
                            } else {
                                f32::NEG_INFINITY
                            };
                            mask.push(m);
                        }
                    }
                }
                let mask = Tensor::from_vec(mask, (b, 1, t, ks.len()), xs.device())?
                    .to_dtype(xs.dtype())?
                    .expand((b, self.num_heads, t, ks.len()))?;
                Some(mask)
            }
        };
        let pos =
            Tensor::arange(current_seq_len as u32, (current_seq_len + t) as u32, xs.device())?;
        let rope = match self.rope {
            Some(ref rope) => Some(rope.rope(&pos)?),
            None => None,
        };
        let mut xs = match self.positional_embedding {
            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
            PositionalEmbedding::Sin => {
                let dev = xs.device();
                let theta = self.max_period as f32;
                let half_dim = c / 2;
                let positions = pos.unsqueeze(1)?.to_dtype(DType::F32)?;
                let inv_freq: Vec<_> = (0..half_dim)
                    .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
                    .collect();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let freqs = positions.broadcast_mul(&inv_freq)?;
                let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?;
                xs.broadcast_add(&pos_emb)?
            }
        };
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, rope.as_ref(), ca_src, mask.as_ref())?
        }
        Ok(xs)
    }

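    /// Converts a `CaSrc::Tokens` source into pre-computed keys/values using the first
    /// layer's cross-attention projections, so the projection is not redone at every
    /// step (the same pair is then reused by every layer, which assumes the
    /// cross-attention weights are shared).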
    pub fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        let ca_src = match ca_src {
            None => None,
            Some(CaSrc::KeysValues(_)) => ca_src,
            Some(tokens) => {
                if self.layers.is_empty() {
                    Some(tokens)
                } else {
                    match &self.layers[0].cross_attn {
                        None => Some(tokens),
                        Some((_, ca_module)) => {
                            let (k, v) = ca_module.compute_kv(&tokens)?;
                            Some(CaSrc::KeysValues((k, v)))
                        }
                    }
                }
            }
        };
        Ok(ca_src)
    }

    pub fn copy_state(&mut self, from: &Self) -> Result<()> {
        if self.layers.len() != from.layers.len() {
            candle::bail!("cannot copy kv-caches as the transformers have different depths")
        }
        self.last_reset_pos = from.last_reset_pos.clone();
        self.layers
            .iter_mut()
            .zip(from.layers.iter())
            .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
        Ok(())
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        if self.last_reset_pos.is_empty() {
            self.last_reset_pos.resize(batch_size, 0);
        }
        if batch_idx >= self.last_reset_pos.len() {
            candle::bail!("batch_idx {} is out of bounds for last_reset_pos", batch_idx)
        }
        self.last_reset_pos[batch_idx] = self.current_seq_len();
        Ok(())
    }
}

impl StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        self.last_reset_pos.clear();
        self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
    }

    fn step(&mut self, xs: &StreamTensor, _: &StreamMask) -> Result<StreamTensor> {
        match xs.as_option() {
            None => Ok(StreamTensor::empty()),
            Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
        }
    }
}

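/// A [`StreamingTransformer`] with optional input/output projections, used when the
/// surrounding dimensions differ from `d_model`. With `conv_layout` set, inputs and
/// outputs are `(batch, channels, time)` and transposed internally to
/// `(batch, time, channels)`.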
#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
    transformer: StreamingTransformer,
    input_proj: Option<MaybeQuantizedLinear>,
    output_projs: Vec<Option<MaybeQuantizedLinear>>,
    conv_layout: bool,
    span: tracing::Span,
}

impl ProjectedTransformer {
    pub fn new(
        input_dim: usize,
        output_dims: &[usize],
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let transformer = StreamingTransformer::new(cfg, vb.pp("transformer"))?;
        let input_proj = if input_dim == cfg.d_model {
            None
        } else {
            let l = linear(input_dim, cfg.d_model, false, vb.pp("input_proj"))?;
            Some(l)
        };
        let mut output_projs = Vec::with_capacity(output_dims.len());
        let vb_o = vb.pp("output_projs");
        for (i, &output_dim) in output_dims.iter().enumerate() {
            let output_proj = if output_dim == cfg.d_model {
                None
            } else {
                let l = linear(cfg.d_model, output_dim, false, vb_o.pp(i))?;
                Some(l)
            };
            output_projs.push(output_proj)
        }
        Ok(Self {
            transformer,
            input_proj,
            output_projs,
            conv_layout: cfg.conv_layout,
            span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
        })
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
        let _enter = self.span.enter();
        let xs = if self.conv_layout { xs.transpose(1, 2)? } else { xs.clone() };
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.forward(&xs)?;
        let mut ys = Vec::with_capacity(self.output_projs.len());
        for output_proj in self.output_projs.iter() {
            let ys_ = xs.apply(&output_proj.as_ref())?;
            let ys_ = if self.conv_layout { ys_.transpose(1, 2)? } else { ys_ };
            ys.push(ys_)
        }
        Ok(ys)
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.transformer.reset_batch_idx(batch_idx, batch_size)
    }
}

impl StreamingModule for ProjectedTransformer {
    fn reset_state(&mut self) {
        self.transformer.reset_state()
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let xs = xs.apply(&|x: &Tensor| {
            if self.conv_layout {
                x.transpose(1, 2)
            } else {
                Ok(x.clone())
            }
        })?;
        let xs = xs.apply(&self.input_proj.as_ref())?;
        let xs = self.transformer.step(&xs, m)?;
        let ys = xs.apply(&self.output_projs[0].as_ref())?;
        ys.apply(&|y: &Tensor| {
            if self.conv_layout {
                y.transpose(1, 2)
            } else {
                Ok(y.clone())
            }
        })
    }
}

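/// Either a standard projected transformer or the batched variant, behind one streaming
/// interface.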
#[derive(Debug, Clone)]
pub enum Transformer {
    Standard(ProjectedTransformer),
    Batched(crate::batched_transformer::ProjectedTransformer),
}

impl StreamingModule for Transformer {
    fn reset_state(&mut self) {
        match self {
            Transformer::Standard(t) => t.reset_state(),
            Transformer::Batched(t) => t.reset_state(),
        }
    }

    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        match self {
            Transformer::Standard(t) => t.step(xs, m),
            Transformer::Batched(t) => t.step(xs, m),
        }
    }
}

impl Transformer {
    pub fn new(
        batch_size: Option<usize>,
        dim: usize,
        cfg: &Config,
        vb: candle_nn::VarBuilder,
    ) -> Result<Self> {
        let transformer = match batch_size {
            Some(batch_size) => {
                let transformer = crate::batched_transformer::ProjectedTransformer::new(
                    dim,
                    &[dim],
                    batch_size,
                    cfg,
                    MaybeQuantizedVarBuilder::Real(vb),
                )?;
                Transformer::Batched(transformer)
            }
            None => {
                let transformer = ProjectedTransformer::new(
                    dim,
                    &[dim],
                    cfg,
                    MaybeQuantizedVarBuilder::Real(vb),
                )?;
                Transformer::Standard(transformer)
            }
        };
        Ok(transformer)
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
        match self {
            Transformer::Standard(t) => t.forward(xs),
            Transformer::Batched(t) => t.forward(xs, &().into()),
        }
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        match self {
            Transformer::Standard(t) => t.reset_batch_idx(batch_idx, batch_size),
            Transformer::Batched(t) => t.reset_batch_idx(batch_idx),
        }
    }
}