Skip to main content

hermes_llm/
model.rs

1use candle_core::{DType, Device, Module, Result, Tensor};
2use candle_nn::{Dropout, Embedding, Linear, VarBuilder, embedding, linear, linear_no_bias};
3
4use crate::mal::{ModelDef, NormPosition, NormType};
5
6fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
7    let shape = on_false.shape();
8    let mask = mask.broadcast_as(shape.dims())?;
9    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
10    let m = mask.where_cond(&on_true, on_false)?;
11    Ok(m)
12}
13
14#[derive(Debug, Clone)]
15pub struct RMSNorm {
16    weight: Tensor,
17    eps: f64,
18}
19
20impl RMSNorm {
21    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
22        let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?;
23        Ok(Self { weight, eps })
24    }
25}
26
27impl Module for RMSNorm {
28    fn forward(&self, x: &Tensor) -> Result<Tensor> {
29        let dtype = x.dtype();
30        let x = x.to_dtype(DType::F32)?;
31        let variance = x.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
32        let x = x.broadcast_div(&(variance + self.eps)?.sqrt()?)?;
33        let x = x.to_dtype(dtype)?;
34        x.broadcast_mul(&self.weight)
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct LayerNorm {
40    weight: Tensor,
41    bias: Option<Tensor>,
42    eps: f64,
43}
44
45impl LayerNorm {
46    pub fn new(size: usize, eps: f64, use_bias: bool, vb: VarBuilder) -> Result<Self> {
47        let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?;
48        let bias = if use_bias {
49            Some(vb.get_with_hints(size, "bias", candle_nn::Init::Const(0.0))?)
50        } else {
51            None
52        };
53        Ok(Self { weight, bias, eps })
54    }
55}
56
57impl Module for LayerNorm {
58    fn forward(&self, x: &Tensor) -> Result<Tensor> {
59        let dtype = x.dtype();
60        let x = x.to_dtype(DType::F32)?;
61        let mean = x.mean_keepdim(candle_core::D::Minus1)?;
62        let x = x.broadcast_sub(&mean)?;
63        let variance = x.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
64        let x = x.broadcast_div(&(variance + self.eps)?.sqrt()?)?;
65        let x = x.to_dtype(dtype)?;
66        let x = x.broadcast_mul(&self.weight)?;
67        match &self.bias {
68            Some(bias) => x.broadcast_add(bias),
69            None => Ok(x),
70        }
71    }
72}
73
74/// Unified normalization layer that can be either RMSNorm or LayerNorm
75#[derive(Debug, Clone)]
76pub enum Norm {
77    RmsNorm(RMSNorm),
78    LayerNorm(LayerNorm),
79}
80
81impl Norm {
82    pub fn new(
83        norm_type: NormType,
84        size: usize,
85        eps: f64,
86        use_bias: bool,
87        vb: VarBuilder,
88    ) -> Result<Self> {
89        match norm_type {
90            NormType::RmsNorm | NormType::None => Ok(Self::RmsNorm(RMSNorm::new(size, eps, vb)?)),
91            NormType::LayerNorm => Ok(Self::LayerNorm(LayerNorm::new(size, eps, use_bias, vb)?)),
92        }
93    }
94}
95
96impl Module for Norm {
97    fn forward(&self, x: &Tensor) -> Result<Tensor> {
98        match self {
99            Self::RmsNorm(n) => n.forward(x),
100            Self::LayerNorm(n) => n.forward(x),
101        }
102    }
103}
104
105pub struct RotaryEmbedding {
106    cos: Tensor,
107    sin: Tensor,
108}
109
110impl RotaryEmbedding {
111    pub fn new(head_dim: usize, max_seq_len: usize, theta: f64, device: &Device) -> Result<Self> {
112        let inv_freq: Vec<f32> = (0..head_dim)
113            .step_by(2)
114            .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
115            .collect();
116        let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
117        let positions: Vec<f32> = (0..max_seq_len).map(|p| p as f32).collect();
118        let positions = Tensor::new(positions.as_slice(), device)?.unsqueeze(1)?;
119        let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
120        let cos = freqs.cos()?;
121        let sin = freqs.sin()?;
122        Ok(Self { cos, sin })
123    }
124
125    pub fn apply(&self, q: &Tensor, k: &Tensor, start_pos: usize) -> Result<(Tensor, Tensor)> {
126        let seq_len = q.dim(2)?;
127        let cos = self.cos.narrow(0, start_pos, seq_len)?;
128        let sin = self.sin.narrow(0, start_pos, seq_len)?;
129
130        let q_rot = self.rotate_half(q, &cos, &sin)?;
131        let k_rot = self.rotate_half(k, &cos, &sin)?;
132        Ok((q_rot, k_rot))
133    }
134
135    fn rotate_half(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
136        let (b, h, seq, d) = x.dims4()?;
137        let x1 = x.narrow(3, 0, d / 2)?;
138        let x2 = x.narrow(3, d / 2, d / 2)?;
139        let rotated = Tensor::cat(&[&x2.neg()?, &x1], 3)?;
140
141        let cos = cos
142            .unsqueeze(0)?
143            .unsqueeze(0)?
144            .broadcast_as((b, h, seq, d / 2))?;
145        let sin = sin
146            .unsqueeze(0)?
147            .unsqueeze(0)?
148            .broadcast_as((b, h, seq, d / 2))?;
149        let cos = Tensor::cat(&[&cos, &cos], 3)?;
150        let sin = Tensor::cat(&[&sin, &sin], 3)?;
151
152        let x_cos = x.broadcast_mul(&cos)?;
153        let rot_sin = rotated.broadcast_mul(&sin)?;
154        let result = x_cos.add(&rot_sin)?;
155        Ok(result)
156    }
157}
158
159pub struct MultiHeadAttention {
160    q_proj: Linear,
161    k_proj: Linear,
162    v_proj: Linear,
163    o_proj: Linear,
164    num_heads: usize,
165    num_kv_heads: usize,
166    head_dim: usize,
167    dropout: Dropout,
168    window_size: Option<usize>,
169    causal: bool,
170}
171
172impl MultiHeadAttention {
173    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
174        let num_heads = config.num_heads();
175        let num_kv_heads = config.num_kv_heads();
176        let head_dim = config.head_dim();
177        let hidden_size = config.hidden_size;
178        let kv_dim = num_kv_heads * head_dim;
179
180        let (q_proj, k_proj, v_proj, o_proj) = if config.use_bias() {
181            (
182                linear(hidden_size, hidden_size, vb.pp("q_proj"))?,
183                linear(hidden_size, kv_dim, vb.pp("k_proj"))?,
184                linear(hidden_size, kv_dim, vb.pp("v_proj"))?,
185                linear(hidden_size, hidden_size, vb.pp("o_proj"))?,
186            )
187        } else {
188            (
189                linear_no_bias(hidden_size, hidden_size, vb.pp("q_proj"))?,
190                linear_no_bias(hidden_size, kv_dim, vb.pp("k_proj"))?,
191                linear_no_bias(hidden_size, kv_dim, vb.pp("v_proj"))?,
192                linear_no_bias(hidden_size, hidden_size, vb.pp("o_proj"))?,
193            )
194        };
195        let dropout = Dropout::new(config.dropout() as f32);
196        Ok(Self {
197            q_proj,
198            k_proj,
199            v_proj,
200            o_proj,
201            num_heads,
202            num_kv_heads,
203            head_dim,
204            dropout,
205            window_size: config.block.attention.window_size,
206            causal: config.block.attention.causal,
207        })
208    }
209
210    pub fn forward(
211        &self,
212        x: &Tensor,
213        mask: Option<&Tensor>,
214        rope: &RotaryEmbedding,
215        start_pos: usize,
216        train: bool,
217    ) -> Result<Tensor> {
218        let (batch_size, seq_len, _) = x.dims3()?;
219
220        let q = self.q_proj.forward(x)?;
221        let k = self.k_proj.forward(x)?;
222        let v = self.v_proj.forward(x)?;
223
224        let q = q.reshape((batch_size, seq_len, self.num_heads, self.head_dim))?;
225        let k = k.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;
226        let v = v.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;
227
228        let q = q.transpose(1, 2)?.contiguous()?;
229        let k = k.transpose(1, 2)?.contiguous()?;
230        let v = v.transpose(1, 2)?.contiguous()?;
231
232        let (q, k) = rope.apply(&q, &k, start_pos)?;
233
234        // Repeat KV heads for GQA (Grouped Query Attention)
235        let (k, v) = if self.num_kv_heads != self.num_heads {
236            let n_rep = self.num_heads / self.num_kv_heads;
237            let k = k
238                .unsqueeze(2)?
239                .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
240                .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
241            let v = v
242                .unsqueeze(2)?
243                .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
244                .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
245            (k, v)
246        } else {
247            (k, v)
248        };
249
250        // Use Flash Attention if available (CUDA only), otherwise standard attention
251        #[cfg(feature = "flash-attn")]
252        let attn_output = {
253            let q = q.transpose(1, 2)?;
254            let k = k.transpose(1, 2)?;
255            let v = v.transpose(1, 2)?;
256            let softmax_scale = 1.0 / (self.head_dim as f32).sqrt();
257            let attn = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, self.causal)?;
258            attn.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
259        };
260
261        #[cfg(not(feature = "flash-attn"))]
262        let attn_output = {
263            let scale = (self.head_dim as f64).sqrt();
264            let k_t = k.transpose(2, 3)?.contiguous()?;
265            let attn_weights = q.matmul(&k_t)?.affine(1.0 / scale, 0.0)?;
266
267            // Apply causal mask if needed
268            let attn_weights = if self.causal {
269                match mask {
270                    Some(m) => masked_fill(&attn_weights, m, f32::NEG_INFINITY)?,
271                    None => attn_weights,
272                }
273            } else {
274                attn_weights
275            };
276
277            // Apply sliding window mask if configured
278            let attn_weights = if let Some(window) = self.window_size {
279                let device = attn_weights.device();
280                let mut window_mask = vec![0u8; seq_len * seq_len];
281                for i in 0..seq_len {
282                    for j in 0..seq_len {
283                        if (i as isize - j as isize).unsigned_abs() > window {
284                            window_mask[i * seq_len + j] = 1;
285                        }
286                    }
287                }
288                let window_mask = Tensor::from_vec(window_mask, (seq_len, seq_len), device)?
289                    .unsqueeze(0)?
290                    .unsqueeze(0)?;
291                masked_fill(&attn_weights, &window_mask, f32::NEG_INFINITY)?
292            } else {
293                attn_weights
294            };
295
296            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
297            let attn_weights = if train {
298                self.dropout.forward(&attn_weights, train)?
299            } else {
300                attn_weights
301            };
302
303            let output = attn_weights.matmul(&v)?;
304            let output = output.transpose(1, 2)?.contiguous()?;
305            output.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
306        };
307
308        self.o_proj.forward(&attn_output)
309    }
310}
311
312pub struct FeedForward {
313    gate_proj: Option<Linear>,
314    up_proj: Linear,
315    down_proj: Linear,
316    dropout: Dropout,
317    use_swiglu: bool,
318}
319
320impl FeedForward {
321    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
322        let use_swiglu = config.use_swiglu();
323        let use_gate = config.block.ffn.gate;
324        let intermediate_size = config.intermediate_size();
325
326        let gate_proj = if use_gate {
327            Some(if config.use_bias() {
328                linear(config.hidden_size, intermediate_size, vb.pp("gate_proj"))?
329            } else {
330                linear_no_bias(config.hidden_size, intermediate_size, vb.pp("gate_proj"))?
331            })
332        } else {
333            None
334        };
335
336        let (up_proj, down_proj) = if config.use_bias() {
337            (
338                linear(config.hidden_size, intermediate_size, vb.pp("up_proj"))?,
339                linear(intermediate_size, config.hidden_size, vb.pp("down_proj"))?,
340            )
341        } else {
342            (
343                linear_no_bias(config.hidden_size, intermediate_size, vb.pp("up_proj"))?,
344                linear_no_bias(intermediate_size, config.hidden_size, vb.pp("down_proj"))?,
345            )
346        };
347        let dropout = Dropout::new(config.dropout() as f32);
348        Ok(Self {
349            gate_proj,
350            up_proj,
351            down_proj,
352            dropout,
353            use_swiglu,
354        })
355    }
356
357    pub fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {
358        let hidden = if let Some(gate_proj) = &self.gate_proj {
359            let gate = gate_proj.forward(x)?;
360            let up = self.up_proj.forward(x)?;
361            if self.use_swiglu {
362                let gate = candle_nn::ops::silu(&gate)?;
363                (gate * up)?
364            } else {
365                let gate = gate.gelu_erf()?;
366                (gate * up)?
367            }
368        } else {
369            // No gating - simple MLP
370            let h = self.up_proj.forward(x)?;
371            if self.use_swiglu {
372                candle_nn::ops::silu(&h)?
373            } else {
374                h.gelu_erf()?
375            }
376        };
377
378        let hidden = self.dropout.forward(&hidden, train)?;
379        self.down_proj.forward(&hidden)
380    }
381}
382
383pub struct TransformerBlock {
384    attention: MultiHeadAttention,
385    feed_forward: FeedForward,
386    attn_norm: Norm,
387    ffn_norm: Norm,
388    norm_position: NormPosition,
389    use_residual: bool,
390}
391
392impl TransformerBlock {
393    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
394        let attention = MultiHeadAttention::new(config, vb.pp("attention"))?;
395        let feed_forward = FeedForward::new(config, vb.pp("feed_forward"))?;
396        let norm_type = config.block.norm.norm_type;
397        let attn_norm = Norm::new(
398            norm_type,
399            config.hidden_size,
400            config.norm_eps(),
401            config.use_bias(),
402            vb.pp("attn_norm"),
403        )?;
404        let ffn_norm = Norm::new(
405            norm_type,
406            config.hidden_size,
407            config.norm_eps(),
408            config.use_bias(),
409            vb.pp("ffn_norm"),
410        )?;
411        Ok(Self {
412            attention,
413            feed_forward,
414            attn_norm,
415            ffn_norm,
416            norm_position: config.block.norm_position,
417            use_residual: config.block.residual,
418        })
419    }
420
421    pub fn forward(
422        &self,
423        x: &Tensor,
424        mask: Option<&Tensor>,
425        rope: &RotaryEmbedding,
426        start_pos: usize,
427        train: bool,
428    ) -> Result<Tensor> {
429        // Attention
430        let x = match self.norm_position {
431            NormPosition::Pre => {
432                let h = self.attn_norm.forward(x)?;
433                let h = self.attention.forward(&h, mask, rope, start_pos, train)?;
434                if self.use_residual { (x + h)? } else { h }
435            }
436            NormPosition::Post => {
437                let h = self.attention.forward(x, mask, rope, start_pos, train)?;
438                let h = if self.use_residual { (x + h)? } else { h };
439                self.attn_norm.forward(&h)?
440            }
441        };
442
443        // FFN
444        match self.norm_position {
445            NormPosition::Pre => {
446                let h = self.ffn_norm.forward(&x)?;
447                let h = self.feed_forward.forward(&h, train)?;
448                if self.use_residual { &x + h } else { Ok(h) }
449            }
450            NormPosition::Post => {
451                let h = self.feed_forward.forward(&x, train)?;
452                let h = if self.use_residual { (&x + h)? } else { h };
453                self.ffn_norm.forward(&h)
454            }
455        }
456    }
457}
458
459pub struct Transformer {
460    embedding: Embedding,
461    layers: Vec<TransformerBlock>,
462    final_norm: Norm,
463    lm_head: Linear,
464    rope: RotaryEmbedding,
465    config: ModelDef,
466}
467
468impl Transformer {
469    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
470        let embedding = embedding(config.vocab_size, config.hidden_size, vb.pp("embedding"))?;
471        let mut layers = Vec::with_capacity(config.num_layers);
472        for i in 0..config.num_layers {
473            layers.push(TransformerBlock::new(
474                config,
475                vb.pp(format!("layers.{}", i)),
476            )?);
477        }
478        let final_norm = Norm::new(
479            config.block.norm.norm_type,
480            config.hidden_size,
481            config.norm_eps(),
482            config.use_bias(),
483            vb.pp("final_norm"),
484        )?;
485        let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
486        let rope = RotaryEmbedding::new(
487            config.head_dim(),
488            config.max_seq_len,
489            config.rope_theta(),
490            vb.device(),
491        )?;
492        Ok(Self {
493            embedding,
494            layers,
495            final_norm,
496            lm_head,
497            rope,
498            config: config.clone(),
499        })
500    }
501
502    pub fn forward(&self, input_ids: &Tensor, start_pos: usize, train: bool) -> Result<Tensor> {
503        let (_, seq_len) = input_ids.dims2()?;
504
505        let x = self.embedding.forward(input_ids)?;
506
507        let mask = if seq_len > 1 {
508            let mut mask_data = vec![0u8; seq_len * seq_len];
509            for i in 0..seq_len {
510                for j in (i + 1)..seq_len {
511                    mask_data[i * seq_len + j] = 1;
512                }
513            }
514            let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), input_ids.device())?
515                .unsqueeze(0)?
516                .unsqueeze(0)?;
517            Some(mask)
518        } else {
519            None
520        };
521
522        let mut x = x;
523        for layer in &self.layers {
524            x = layer.forward(&x, mask.as_ref(), &self.rope, start_pos, train)?;
525        }
526
527        let x = self.final_norm.forward(&x)?;
528        self.lm_head.forward(&x)
529    }
530
531    pub fn config(&self) -> &ModelDef {
532        &self.config
533    }
534
535    pub fn num_parameters(&self) -> usize {
536        self.config.estimated_params()
537    }
538}
539
540pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
541    let (batch_size, seq_len, vocab_size) = logits.dims3()?;
542    let logits = logits.reshape((batch_size * seq_len, vocab_size))?;
543    let targets = targets.reshape((batch_size * seq_len,))?;
544    candle_nn::loss::cross_entropy(&logits, &targets)
545}