hermes-llm 1.8.21

LLM training from scratch using Candle
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{Dropout, Embedding, Linear, VarBuilder, embedding, linear, linear_no_bias};

use crate::mal::{ModelDef, NormPosition, NormType};

/// Returns a copy of `on_false` where every position selected by `mask` is replaced by the
/// scalar `on_true`.
///
/// `mask` must broadcast to `on_false`'s shape and have a predicate dtype accepted by
/// [`Tensor::where_cond`] (the callers in this file pass `u8` masks).
///
/// The fill value is cast to `on_false`'s dtype. This is required because
/// `Tensor::new(f32)` always yields an F32 tensor, and `where_cond` rejects branches of
/// different dtypes. Without the cast, half-precision (F16/BF16) scores could not be masked.
/// `-inf` is representable in those formats, so the cast preserves it.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let dims = on_false.dims();
    let mask = mask.broadcast_as(dims)?;
    let on_true = Tensor::new(on_true, on_false.device())?
        .to_dtype(on_false.dtype())?
        .broadcast_as(dims)?;
    mask.where_cond(&on_true, on_false)
}

/// Root-mean-square normalization over the last dimension.
///
/// Computes `x / sqrt(mean(x^2) + eps) * weight`. The statistics are taken in F32, and the
/// result is cast back to the input dtype before the learned `weight` is applied. There is
/// no mean subtraction and no bias.
#[derive(Debug, Clone)]
pub struct RMSNorm {
    weight: Tensor,
    eps: f64,
}

impl RMSNorm {
    /// Creates the layer, loading (or initialising to ones) a `weight` vector of length
    /// `size` under the `"weight"` key of `vb`.
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            weight: vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?,
            eps,
        })
    }
}

impl Module for RMSNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let in_dtype = x.dtype();
        let xf = x.to_dtype(DType::F32)?;
        let mean_sq = xf.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
        let rms = (mean_sq + self.eps)?.sqrt()?;
        xf.broadcast_div(&rms)?
            .to_dtype(in_dtype)?
            .broadcast_mul(&self.weight)
    }
}

/// Standard layer normalization over the last dimension.
///
/// Centers by the mean, divides by `sqrt(var + eps)` (statistics in F32), casts back to the
/// input dtype, scales by `weight` and, when present, adds `bias`.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,
    eps: f64,
}

impl LayerNorm {
    /// Creates the layer with a ones-initialised `weight` of length `size`.
    ///
    /// A zero-initialised `bias` is also created when `use_bias` is true.
    pub fn new(size: usize, eps: f64, use_bias: bool, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?;
        let bias = use_bias
            .then(|| vb.get_with_hints(size, "bias", candle_nn::Init::Const(0.0)))
            .transpose()?;
        Ok(Self { weight, bias, eps })
    }
}

impl Module for LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let in_dtype = x.dtype();
        let xf = x.to_dtype(DType::F32)?;
        let centered = xf.broadcast_sub(&xf.mean_keepdim(candle_core::D::Minus1)?)?;
        let var = centered.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
        let scaled = centered
            .broadcast_div(&(var + self.eps)?.sqrt()?)?
            .to_dtype(in_dtype)?
            .broadcast_mul(&self.weight)?;
        if let Some(bias) = &self.bias {
            scaled.broadcast_add(bias)
        } else {
            Ok(scaled)
        }
    }
}

/// Normalization layer chosen at construction time: either [`RMSNorm`] or [`LayerNorm`].
#[derive(Debug, Clone)]
pub enum Norm {
    RmsNorm(RMSNorm),
    LayerNorm(LayerNorm),
}

impl Norm {
    /// Builds the normalization selected by `norm_type`.
    ///
    /// `use_bias` is only consulted for [`LayerNorm`]; RMSNorm has no bias.
    pub fn new(
        norm_type: NormType,
        size: usize,
        eps: f64,
        use_bias: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        Ok(match norm_type {
            NormType::LayerNorm => Self::LayerNorm(LayerNorm::new(size, eps, use_bias, vb)?),
            // NOTE(review): `NormType::None` still builds (and applies) an RMSNorm rather
            // than an identity — confirm this fallback is intended.
            NormType::RmsNorm | NormType::None => Self::RmsNorm(RMSNorm::new(size, eps, vb)?),
        })
    }
}

impl Module for Norm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let inner: &dyn Module = match self {
            Self::RmsNorm(rms) => rms,
            Self::LayerNorm(ln) => ln,
        };
        inner.forward(x)
    }
}

/// Precomputed rotary position embedding (RoPE) tables.
///
/// `cos` and `sin` have shape `(max_seq_len, head_dim)`. Each of the `head_dim / 2`
/// frequencies is stored twice (`[f_0..f_{d/2}, f_0..f_{d/2}]`), so the tables line up
/// directly with the "rotate half" layout used by [`RotaryEmbedding::apply`]. The tables
/// are kept in F32 and cast to the activation dtype on use.
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
}

impl RotaryEmbedding {
    /// Builds the tables for positions `0..max_seq_len`.
    ///
    /// Frequencies are `1 / theta^(i / head_dim)` for even `i`.
    pub fn new(head_dim: usize, max_seq_len: usize, theta: f64, device: &Device) -> Result<Self> {
        let inv_freq: Vec<f32> = (0..head_dim)
            .step_by(2)
            .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
        let positions: Vec<f32> = (0..max_seq_len).map(|p| p as f32).collect();
        let positions = Tensor::new(positions.as_slice(), device)?.unsqueeze(1)?;
        // (max_seq_len, 1) x (1, head_dim / 2) -> (max_seq_len, head_dim / 2)
        let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
        // Duplicate the half-width table once here rather than on every forward call.
        let freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
        Ok(Self {
            cos: freqs.cos()?,
            sin: freqs.sin()?,
        })
    }

    /// Rotates `q` and `k` by the angles for positions `start_pos..start_pos + seq_len`.
    ///
    /// Both tensors are read as `(batch, heads, seq, head_dim)`; `seq` is taken from `q`.
    /// Fails (via `narrow`) if the requested positions exceed the precomputed `max_seq_len`.
    pub fn apply(&self, q: &Tensor, k: &Tensor, start_pos: usize) -> Result<(Tensor, Tensor)> {
        let seq_len = q.dim(2)?;
        let cos = self.cos.narrow(0, start_pos, seq_len)?;
        let sin = self.sin.narrow(0, start_pos, seq_len)?;

        let q_rot = self.rotate_half(q, &cos, &sin)?;
        let k_rot = self.rotate_half(k, &cos, &sin)?;
        Ok((q_rot, k_rot))
    }

    /// Computes `x * cos + rotate_half(x) * sin`.
    ///
    /// Here `rotate_half([x1, x2]) = [-x2, x1]` along the last dimension. `cos`/`sin` are
    /// full-width `(seq, head_dim)` slices.
    fn rotate_half(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        let (_, _, _, d) = x.dims4()?;
        let x1 = x.narrow(3, 0, d / 2)?;
        let x2 = x.narrow(3, d / 2, d / 2)?;
        let rotated = Tensor::cat(&[&x2.neg()?, &x1], 3)?;

        // The tables are F32. Match the activation dtype so F16/BF16 inputs do not hit a
        // dtype mismatch; this is a cheap clone when the dtypes already agree.
        let cos = cos.to_dtype(x.dtype())?;
        let sin = sin.to_dtype(x.dtype())?;

        // (seq, d) broadcasts against (b, h, seq, d) by right-alignment; no
        // explicit expand/cat of the tables is needed.
        x.broadcast_mul(&cos)? + rotated.broadcast_mul(&sin)?
    }
}

pub struct MultiHeadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    dropout: Dropout,
    window_size: Option<usize>,
    causal: bool,
}

impl MultiHeadAttention {
    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
        let num_heads = config.num_heads();
        let num_kv_heads = config.num_kv_heads();
        let head_dim = config.head_dim();
        let hidden_size = config.hidden_size;
        let kv_dim = num_kv_heads * head_dim;

        let (q_proj, k_proj, v_proj, o_proj) = if config.use_bias() {
            (
                linear(hidden_size, hidden_size, vb.pp("q_proj"))?,
                linear(hidden_size, kv_dim, vb.pp("k_proj"))?,
                linear(hidden_size, kv_dim, vb.pp("v_proj"))?,
                linear(hidden_size, hidden_size, vb.pp("o_proj"))?,
            )
        } else {
            (
                linear_no_bias(hidden_size, hidden_size, vb.pp("q_proj"))?,
                linear_no_bias(hidden_size, kv_dim, vb.pp("k_proj"))?,
                linear_no_bias(hidden_size, kv_dim, vb.pp("v_proj"))?,
                linear_no_bias(hidden_size, hidden_size, vb.pp("o_proj"))?,
            )
        };
        let dropout = Dropout::new(config.dropout() as f32);
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            dropout,
            window_size: config.block.attention.window_size,
            causal: config.block.attention.causal,
        })
    }

    pub fn forward(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        rope: &RotaryEmbedding,
        start_pos: usize,
        train: bool,
    ) -> Result<Tensor> {
        let (batch_size, seq_len, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q.reshape((batch_size, seq_len, self.num_heads, self.head_dim))?;
        let k = k.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;
        let v = v.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;

        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;

        let (q, k) = rope.apply(&q, &k, start_pos)?;

        // Repeat KV heads for GQA (Grouped Query Attention)
        let (k, v) = if self.num_kv_heads != self.num_heads {
            let n_rep = self.num_heads / self.num_kv_heads;
            let k = k
                .unsqueeze(2)?
                .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
                .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
            let v = v
                .unsqueeze(2)?
                .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
                .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
            (k, v)
        } else {
            (k, v)
        };

        // Use Flash Attention if available (CUDA only), otherwise standard attention
        #[cfg(feature = "flash-attn")]
        let attn_output = {
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            let softmax_scale = 1.0 / (self.head_dim as f32).sqrt();
            let attn = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, self.causal)?;
            attn.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
        };

        #[cfg(not(feature = "flash-attn"))]
        let attn_output = {
            let scale = (self.head_dim as f64).sqrt();
            let k_t = k.transpose(2, 3)?.contiguous()?;
            let attn_weights = q.matmul(&k_t)?.affine(1.0 / scale, 0.0)?;

            // Apply causal mask if needed
            let attn_weights = if self.causal {
                match mask {
                    Some(m) => masked_fill(&attn_weights, m, f32::NEG_INFINITY)?,
                    None => attn_weights,
                }
            } else {
                attn_weights
            };

            // Apply sliding window mask if configured
            let attn_weights = if let Some(window) = self.window_size {
                let device = attn_weights.device();
                let mut window_mask = vec![0u8; seq_len * seq_len];
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        if (i as isize - j as isize).unsigned_abs() > window {
                            window_mask[i * seq_len + j] = 1;
                        }
                    }
                }
                let window_mask = Tensor::from_vec(window_mask, (seq_len, seq_len), device)?
                    .unsqueeze(0)?
                    .unsqueeze(0)?;
                masked_fill(&attn_weights, &window_mask, f32::NEG_INFINITY)?
            } else {
                attn_weights
            };

            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            let attn_weights = if train {
                self.dropout.forward(&attn_weights, train)?
            } else {
                attn_weights
            };

            let output = attn_weights.matmul(&v)?;
            let output = output.transpose(1, 2)?.contiguous()?;
            output.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
        };

        self.o_proj.forward(&attn_output)
    }
}

/// Position-wise feed-forward network.
///
/// With a gate projection it computes `down(act(gate(x)) * up(x))`; without one it is the
/// plain `down(act(up(x)))`. `act` is SiLU when `use_swiglu` is set and erf-based GELU
/// otherwise. Dropout is applied to the hidden activations before `down_proj`.
pub struct FeedForward {
    gate_proj: Option<Linear>,
    up_proj: Linear,
    down_proj: Linear,
    dropout: Dropout,
    use_swiglu: bool,
}

impl FeedForward {
    /// Builds the projections from `config`.
    ///
    /// The gate projection exists only when `config.block.ffn.gate` is set. Biases follow
    /// `config.use_bias()`.
    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
        let hidden = config.hidden_size;
        let inter = config.intermediate_size();
        let with_bias = config.use_bias();

        // One constructor for every projection so the bias choice lives in a single place.
        let proj = |in_dim: usize, out_dim: usize, name: &str| -> Result<Linear> {
            let sub = vb.pp(name);
            if with_bias {
                linear(in_dim, out_dim, sub)
            } else {
                linear_no_bias(in_dim, out_dim, sub)
            }
        };

        let gate_proj = config
            .block
            .ffn
            .gate
            .then(|| proj(hidden, inter, "gate_proj"))
            .transpose()?;
        let up_proj = proj(hidden, inter, "up_proj")?;
        let down_proj = proj(inter, hidden, "down_proj")?;

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            dropout: Dropout::new(config.dropout() as f32),
            use_swiglu: config.use_swiglu(),
        })
    }

    /// Applies the feed-forward network; dropout is active only when `train` is true.
    pub fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {
        let activate = |t: &Tensor| {
            if self.use_swiglu {
                candle_nn::ops::silu(t)
            } else {
                t.gelu_erf()
            }
        };

        let hidden = match &self.gate_proj {
            Some(gate_proj) => {
                let gate = activate(&gate_proj.forward(x)?)?;
                (gate * self.up_proj.forward(x)?)?
            }
            // No gating: a simple MLP.
            None => activate(&self.up_proj.forward(x)?)?,
        };

        self.down_proj.forward(&self.dropout.forward(&hidden, train)?)
    }
}

/// One transformer layer: an attention sublayer followed by a feed-forward sublayer.
///
/// Each sublayer uses the configured norm placement (pre or post) and an optional
/// residual connection.
pub struct TransformerBlock {
    attention: MultiHeadAttention,
    feed_forward: FeedForward,
    attn_norm: Norm,
    ffn_norm: Norm,
    norm_position: NormPosition,
    use_residual: bool,
}

impl TransformerBlock {
    /// Builds the attention, feed-forward and two norm layers under `vb`'s sub-prefixes.
    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
        let attention = MultiHeadAttention::new(config, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(config, vb.pp("feed_forward"))?;
        let make_norm = |name: &str| {
            Norm::new(
                config.block.norm.norm_type,
                config.hidden_size,
                config.norm_eps(),
                config.use_bias(),
                vb.pp(name),
            )
        };
        Ok(Self {
            attention,
            feed_forward,
            attn_norm: make_norm("attn_norm")?,
            ffn_norm: make_norm("ffn_norm")?,
            norm_position: config.block.norm_position,
            use_residual: config.block.residual,
        })
    }

    /// Runs attention, then the feed-forward network, each wrapped by [`Self::sublayer`].
    pub fn forward(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        rope: &RotaryEmbedding,
        start_pos: usize,
        train: bool,
    ) -> Result<Tensor> {
        let after_attn = self.sublayer(x, &self.attn_norm, |h| {
            self.attention.forward(h, mask, rope, start_pos, train)
        })?;
        self.sublayer(&after_attn, &self.ffn_norm, |h| {
            self.feed_forward.forward(h, train)
        })
    }

    /// Wraps `f` with this block's norm placement and optional residual connection.
    ///
    /// - Pre-norm: `x + f(norm(x))`.
    /// - Post-norm: `norm(x + f(x))`.
    ///
    /// The `x +` term is omitted when `use_residual` is false.
    fn sublayer(
        &self,
        x: &Tensor,
        norm: &Norm,
        f: impl Fn(&Tensor) -> Result<Tensor>,
    ) -> Result<Tensor> {
        let add_residual = |h: Tensor| if self.use_residual { x + h } else { Ok(h) };
        match self.norm_position {
            NormPosition::Pre => add_residual(f(&norm.forward(x)?)?),
            NormPosition::Post => norm.forward(&add_residual(f(x)?)?),
        }
    }
}

pub struct Transformer {
    embedding: Embedding,
    layers: Vec<TransformerBlock>,
    final_norm: Norm,
    lm_head: Linear,
    rope: RotaryEmbedding,
    config: ModelDef,
}

impl Transformer {
    pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
        let embedding = embedding(config.vocab_size, config.hidden_size, vb.pp("embedding"))?;
        let mut layers = Vec::with_capacity(config.num_layers);
        for i in 0..config.num_layers {
            layers.push(TransformerBlock::new(
                config,
                vb.pp(format!("layers.{}", i)),
            )?);
        }
        let final_norm = Norm::new(
            config.block.norm.norm_type,
            config.hidden_size,
            config.norm_eps(),
            config.use_bias(),
            vb.pp("final_norm"),
        )?;
        let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
        let rope = RotaryEmbedding::new(
            config.head_dim(),
            config.max_seq_len,
            config.rope_theta(),
            vb.device(),
        )?;
        Ok(Self {
            embedding,
            layers,
            final_norm,
            lm_head,
            rope,
            config: config.clone(),
        })
    }

    pub fn forward(&self, input_ids: &Tensor, start_pos: usize, train: bool) -> Result<Tensor> {
        let (_, seq_len) = input_ids.dims2()?;

        let x = self.embedding.forward(input_ids)?;

        let mask = if seq_len > 1 {
            let mut mask_data = vec![0u8; seq_len * seq_len];
            for i in 0..seq_len {
                for j in (i + 1)..seq_len {
                    mask_data[i * seq_len + j] = 1;
                }
            }
            let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), input_ids.device())?
                .unsqueeze(0)?
                .unsqueeze(0)?;
            Some(mask)
        } else {
            None
        };

        let mut x = x;
        for layer in &self.layers {
            x = layer.forward(&x, mask.as_ref(), &self.rope, start_pos, train)?;
        }

        let x = self.final_norm.forward(&x)?;
        self.lm_head.forward(&x)
    }

    pub fn config(&self) -> &ModelDef {
        &self.config
    }

    pub fn num_parameters(&self) -> usize {
        self.config.estimated_params()
    }
}

/// Token-level cross-entropy between `(batch, seq, vocab)` logits and `targets`.
///
/// `targets` must hold `batch * seq` class indices. Both tensors are flattened to one row
/// per token, and the result is delegated to `candle_nn::loss::cross_entropy`.
pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
    let (batch, seq, vocab) = logits.dims3()?;
    let rows = batch * seq;
    candle_nn::loss::cross_entropy(&logits.reshape((rows, vocab))?, &targets.reshape((rows,))?)
}