//! Transformer block used by the model implementations.
//!
//! Source: `cake_core/models/common/transformer.rs`
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use candle_core::Tensor;
use candle_nn::{Module, RmsNorm};

use crate::cake::{Context, Forwarder};

use super::{CausalSelfAttention, MLP};
9
/// Transformer block with causal self attention and several caching strategies.
#[derive(Debug, Clone)]
pub struct Transformer {
    // Weight prefix this block was loaded under; also shown by `Display`
    // and returned by `layer_name()`.
    name: String,
    // Pre-attention RMSNorm, loaded from "input_layernorm".
    rms_1: RmsNorm,
    // Causal self-attention sub-layer, loaded from "self_attn".
    attn: CausalSelfAttention,
    // Pre-MLP RMSNorm, loaded from "post_attention_layernorm".
    rms_2: RmsNorm,
    // Feed-forward sub-layer, loaded from "mlp".
    mlp: MLP,
}
19
20impl std::fmt::Display for Transformer {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        write!(f, "{} (local)", &self.name)
23    }
24}
25
26#[async_trait]
27impl Forwarder for Transformer {
28    fn load(name: String, ctx: &Context) -> Result<Box<Self>> {
29        let vb = ctx
30            .var_builder
31            .as_ref()
32            .expect("No var_builder specified")
33            .pp(&name);
34        let cfg = ctx.config.as_ref().expect("No config specified");
35
36        let attn = super::CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
37        let mlp = super::MLP::load(vb.pp("mlp"), cfg)?;
38        let rms_1 =
39            candle_nn::rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
40        let rms_2 = candle_nn::rms_norm(
41            cfg.hidden_size,
42            cfg.rms_norm_eps,
43            vb.pp("post_attention_layernorm"),
44        )?;
45        Ok(Box::new(Self {
46            name,
47            rms_1,
48            attn,
49            rms_2,
50            mlp,
51        }))
52    }
53
54    async fn forward(
55        &self,
56        x: &Tensor,
57        index_pos: usize,
58        block_idx: usize,
59        ctx: &mut Context,
60    ) -> Result<Tensor> {
61        let residual = x;
62
63        let x = self.rms_1.forward(x).map_err(|e| anyhow!("rms_1: {e}"))?;
64        let x = (self
65            .attn
66            .forward(
67                &x,
68                index_pos,
69                block_idx,
70                ctx.cache.as_mut().expect("No cache specified"),
71            )
72            .map_err(|e| anyhow!("attention: {e}"))?
73            + residual)
74            .map_err(|e| anyhow!("residual: {e}"))?;
75        let residual = &x;
76        let x = self.rms_2.forward(&x).map_err(|e| anyhow!("rms_2: {e}"))?;
77        let x = (self.mlp.forward(&x).map_err(|e| anyhow!("mlp: {e}"))? + residual)
78            .map_err(|e| anyhow!("mlp residual: {e}"))?;
79
80        Ok(x)
81    }
82
83    async fn forward_mut(
84        &mut self,
85        x: &Tensor,
86        index_pos: usize,
87        block_idx: usize,
88        ctx: &mut Context,
89    ) -> Result<Tensor> {
90        self.forward(x, index_pos, block_idx, ctx).await
91    }
92
93    fn layer_name(&self) -> &str {
94        &self.name
95    }
96}