cake_core/models/common/transformer.rs

use anyhow::{anyhow, Result};
use candle_core::Tensor;
use candle_nn::{Module, RmsNorm};

use crate::cake::{Context, Forwarder};
use async_trait::async_trait;

use super::{CausalSelfAttention, MLP};
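
/// A single transformer decoder block, executed locally: pre-attention
/// RMSNorm, causal self-attention with a residual connection, then
/// post-attention RMSNorm and an MLP with a second residual connection.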
#[derive(Debug, Clone)]
pub struct Transformer {
    name: String,
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: MLP,
}

impl std::fmt::Display for Transformer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} (local)", &self.name)
    }
}
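
/// Local implementation of the `Forwarder` trait: the block's weights are
/// pulled from the context's `var_builder` under this layer's name and the
/// block is evaluated in-process.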
#[async_trait]
impl Forwarder for Transformer {
    fn load(name: String, ctx: &Context) -> Result<Box<Self>> {
        let vb = ctx
            .var_builder
            .as_ref()
            .expect("No var_builder specified")
            .pp(&name);
        let cfg = ctx.config.as_ref().expect("No config specified");

        let attn = super::CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
        let mlp = super::MLP::load(vb.pp("mlp"), cfg)?;
        let rms_1 =
            candle_nn::rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let rms_2 = candle_nn::rms_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Box::new(Self {
            name,
            rms_1,
            attn,
            rms_2,
            mlp,
        }))
    }
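
    /// Pre-norm forward pass: `h = x + Attn(RmsNorm(x))`, then
    /// `out = h + MLP(RmsNorm(h))`. `index_pos` and `block_idx` are forwarded,
    /// together with the context's cache, to the attention layer.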
    async fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor> {
        let residual = x;

        let x = self.rms_1.forward(x).map_err(|e| anyhow!("rms_1: {e}"))?;
        let x = (self
            .attn
            .forward(
                &x,
                index_pos,
                block_idx,
                ctx.cache.as_mut().expect("No cache specified"),
            )
            .map_err(|e| anyhow!("attention: {e}"))?
            + residual)
            .map_err(|e| anyhow!("residual: {e}"))?;
        let residual = &x;
        let x = self.rms_2.forward(&x).map_err(|e| anyhow!("rms_2: {e}"))?;
        let x = (self.mlp.forward(&x).map_err(|e| anyhow!("mlp: {e}"))? + residual)
            .map_err(|e| anyhow!("mlp residual: {e}"))?;

        Ok(x)
    }
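
    /// Mutable variant required by the `Forwarder` trait; the local block
    /// keeps no mutable state of its own, so it simply delegates to `forward`.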
    async fn forward_mut(
        &mut self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor> {
        self.forward(x, index_pos, block_idx, ctx).await
    }

    fn layer_name(&self) -> &str {
        &self.name
    }
}
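
// A minimal usage sketch (not part of the original file), assuming a `Context`
// already populated with `var_builder`, `config` and `cache`, and a
// "model.layers.N" naming scheme for the per-layer weights; the naming scheme
// and the surrounding driver loop are assumptions for illustration only.
//
//     let block = Transformer::load("model.layers.0".to_string(), &ctx)?;
//     let hidden = block
//         .forward(&hidden_states, index_pos, 0, &mut ctx)
//         .await?;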