1use candle_core::{DType, Device, Module, Result, Tensor};
2use candle_nn::{Dropout, Embedding, Linear, VarBuilder, embedding, linear, linear_no_bias};
3
4use crate::mal::{ModelDef, NormPosition, NormType};
5
6fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
7 let shape = on_false.shape();
8 let mask = mask.broadcast_as(shape.dims())?;
9 let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
10 let m = mask.where_cond(&on_true, on_false)?;
11 Ok(m)
12}
13
14#[derive(Debug, Clone)]
15pub struct RMSNorm {
16 weight: Tensor,
17 eps: f64,
18}
19
20impl RMSNorm {
21 pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
22 let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?;
23 Ok(Self { weight, eps })
24 }
25}
26
27impl Module for RMSNorm {
28 fn forward(&self, x: &Tensor) -> Result<Tensor> {
29 let dtype = x.dtype();
30 let x = x.to_dtype(DType::F32)?;
31 let variance = x.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
32 let x = x.broadcast_div(&(variance + self.eps)?.sqrt()?)?;
33 let x = x.to_dtype(dtype)?;
34 x.broadcast_mul(&self.weight)
35 }
36}
37
38#[derive(Debug, Clone)]
39pub struct LayerNorm {
40 weight: Tensor,
41 bias: Option<Tensor>,
42 eps: f64,
43}
44
45impl LayerNorm {
46 pub fn new(size: usize, eps: f64, use_bias: bool, vb: VarBuilder) -> Result<Self> {
47 let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.0))?;
48 let bias = if use_bias {
49 Some(vb.get_with_hints(size, "bias", candle_nn::Init::Const(0.0))?)
50 } else {
51 None
52 };
53 Ok(Self { weight, bias, eps })
54 }
55}
56
57impl Module for LayerNorm {
58 fn forward(&self, x: &Tensor) -> Result<Tensor> {
59 let dtype = x.dtype();
60 let x = x.to_dtype(DType::F32)?;
61 let mean = x.mean_keepdim(candle_core::D::Minus1)?;
62 let x = x.broadcast_sub(&mean)?;
63 let variance = x.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
64 let x = x.broadcast_div(&(variance + self.eps)?.sqrt()?)?;
65 let x = x.to_dtype(dtype)?;
66 let x = x.broadcast_mul(&self.weight)?;
67 match &self.bias {
68 Some(bias) => x.broadcast_add(bias),
69 None => Ok(x),
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
76pub enum Norm {
77 RmsNorm(RMSNorm),
78 LayerNorm(LayerNorm),
79}
80
81impl Norm {
82 pub fn new(
83 norm_type: NormType,
84 size: usize,
85 eps: f64,
86 use_bias: bool,
87 vb: VarBuilder,
88 ) -> Result<Self> {
89 match norm_type {
90 NormType::RmsNorm | NormType::None => Ok(Self::RmsNorm(RMSNorm::new(size, eps, vb)?)),
91 NormType::LayerNorm => Ok(Self::LayerNorm(LayerNorm::new(size, eps, use_bias, vb)?)),
92 }
93 }
94}
95
96impl Module for Norm {
97 fn forward(&self, x: &Tensor) -> Result<Tensor> {
98 match self {
99 Self::RmsNorm(n) => n.forward(x),
100 Self::LayerNorm(n) => n.forward(x),
101 }
102 }
103}
104
105pub struct RotaryEmbedding {
106 cos: Tensor,
107 sin: Tensor,
108}
109
110impl RotaryEmbedding {
111 pub fn new(head_dim: usize, max_seq_len: usize, theta: f64, device: &Device) -> Result<Self> {
112 let inv_freq: Vec<f32> = (0..head_dim)
113 .step_by(2)
114 .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
115 .collect();
116 let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
117 let positions: Vec<f32> = (0..max_seq_len).map(|p| p as f32).collect();
118 let positions = Tensor::new(positions.as_slice(), device)?.unsqueeze(1)?;
119 let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
120 let cos = freqs.cos()?;
121 let sin = freqs.sin()?;
122 Ok(Self { cos, sin })
123 }
124
125 pub fn apply(&self, q: &Tensor, k: &Tensor, start_pos: usize) -> Result<(Tensor, Tensor)> {
126 let seq_len = q.dim(2)?;
127 let cos = self.cos.narrow(0, start_pos, seq_len)?;
128 let sin = self.sin.narrow(0, start_pos, seq_len)?;
129
130 let q_rot = self.rotate_half(q, &cos, &sin)?;
131 let k_rot = self.rotate_half(k, &cos, &sin)?;
132 Ok((q_rot, k_rot))
133 }
134
135 fn rotate_half(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
136 let (b, h, seq, d) = x.dims4()?;
137 let x1 = x.narrow(3, 0, d / 2)?;
138 let x2 = x.narrow(3, d / 2, d / 2)?;
139 let rotated = Tensor::cat(&[&x2.neg()?, &x1], 3)?;
140
141 let cos = cos
142 .unsqueeze(0)?
143 .unsqueeze(0)?
144 .broadcast_as((b, h, seq, d / 2))?;
145 let sin = sin
146 .unsqueeze(0)?
147 .unsqueeze(0)?
148 .broadcast_as((b, h, seq, d / 2))?;
149 let cos = Tensor::cat(&[&cos, &cos], 3)?;
150 let sin = Tensor::cat(&[&sin, &sin], 3)?;
151
152 let x_cos = x.broadcast_mul(&cos)?;
153 let rot_sin = rotated.broadcast_mul(&sin)?;
154 let result = x_cos.add(&rot_sin)?;
155 Ok(result)
156 }
157}
158
159pub struct MultiHeadAttention {
160 q_proj: Linear,
161 k_proj: Linear,
162 v_proj: Linear,
163 o_proj: Linear,
164 num_heads: usize,
165 num_kv_heads: usize,
166 head_dim: usize,
167 dropout: Dropout,
168 window_size: Option<usize>,
169 causal: bool,
170}
171
172impl MultiHeadAttention {
173 pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
174 let num_heads = config.num_heads();
175 let num_kv_heads = config.num_kv_heads();
176 let head_dim = config.head_dim();
177 let hidden_size = config.hidden_size;
178 let kv_dim = num_kv_heads * head_dim;
179
180 let (q_proj, k_proj, v_proj, o_proj) = if config.use_bias() {
181 (
182 linear(hidden_size, hidden_size, vb.pp("q_proj"))?,
183 linear(hidden_size, kv_dim, vb.pp("k_proj"))?,
184 linear(hidden_size, kv_dim, vb.pp("v_proj"))?,
185 linear(hidden_size, hidden_size, vb.pp("o_proj"))?,
186 )
187 } else {
188 (
189 linear_no_bias(hidden_size, hidden_size, vb.pp("q_proj"))?,
190 linear_no_bias(hidden_size, kv_dim, vb.pp("k_proj"))?,
191 linear_no_bias(hidden_size, kv_dim, vb.pp("v_proj"))?,
192 linear_no_bias(hidden_size, hidden_size, vb.pp("o_proj"))?,
193 )
194 };
195 let dropout = Dropout::new(config.dropout() as f32);
196 Ok(Self {
197 q_proj,
198 k_proj,
199 v_proj,
200 o_proj,
201 num_heads,
202 num_kv_heads,
203 head_dim,
204 dropout,
205 window_size: config.block.attention.window_size,
206 causal: config.block.attention.causal,
207 })
208 }
209
210 pub fn forward(
211 &self,
212 x: &Tensor,
213 mask: Option<&Tensor>,
214 rope: &RotaryEmbedding,
215 start_pos: usize,
216 train: bool,
217 ) -> Result<Tensor> {
218 let (batch_size, seq_len, _) = x.dims3()?;
219
220 let q = self.q_proj.forward(x)?;
221 let k = self.k_proj.forward(x)?;
222 let v = self.v_proj.forward(x)?;
223
224 let q = q.reshape((batch_size, seq_len, self.num_heads, self.head_dim))?;
225 let k = k.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;
226 let v = v.reshape((batch_size, seq_len, self.num_kv_heads, self.head_dim))?;
227
228 let q = q.transpose(1, 2)?.contiguous()?;
229 let k = k.transpose(1, 2)?.contiguous()?;
230 let v = v.transpose(1, 2)?.contiguous()?;
231
232 let (q, k) = rope.apply(&q, &k, start_pos)?;
233
234 let (k, v) = if self.num_kv_heads != self.num_heads {
236 let n_rep = self.num_heads / self.num_kv_heads;
237 let k = k
238 .unsqueeze(2)?
239 .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
240 .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
241 let v = v
242 .unsqueeze(2)?
243 .expand((batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim))?
244 .reshape((batch_size, self.num_heads, seq_len, self.head_dim))?;
245 (k, v)
246 } else {
247 (k, v)
248 };
249
250 #[cfg(feature = "flash-attn")]
252 let attn_output = {
253 let q = q.transpose(1, 2)?;
254 let k = k.transpose(1, 2)?;
255 let v = v.transpose(1, 2)?;
256 let softmax_scale = 1.0 / (self.head_dim as f32).sqrt();
257 let attn = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, self.causal)?;
258 attn.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
259 };
260
261 #[cfg(not(feature = "flash-attn"))]
262 let attn_output = {
263 let scale = (self.head_dim as f64).sqrt();
264 let k_t = k.transpose(2, 3)?.contiguous()?;
265 let attn_weights = q.matmul(&k_t)?.affine(1.0 / scale, 0.0)?;
266
267 let attn_weights = if self.causal {
269 match mask {
270 Some(m) => masked_fill(&attn_weights, m, f32::NEG_INFINITY)?,
271 None => attn_weights,
272 }
273 } else {
274 attn_weights
275 };
276
277 let attn_weights = if let Some(window) = self.window_size {
279 let device = attn_weights.device();
280 let mut window_mask = vec![0u8; seq_len * seq_len];
281 for i in 0..seq_len {
282 for j in 0..seq_len {
283 if (i as isize - j as isize).unsigned_abs() > window {
284 window_mask[i * seq_len + j] = 1;
285 }
286 }
287 }
288 let window_mask = Tensor::from_vec(window_mask, (seq_len, seq_len), device)?
289 .unsqueeze(0)?
290 .unsqueeze(0)?;
291 masked_fill(&attn_weights, &window_mask, f32::NEG_INFINITY)?
292 } else {
293 attn_weights
294 };
295
296 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
297 let attn_weights = if train {
298 self.dropout.forward(&attn_weights, train)?
299 } else {
300 attn_weights
301 };
302
303 let output = attn_weights.matmul(&v)?;
304 let output = output.transpose(1, 2)?.contiguous()?;
305 output.reshape((batch_size, seq_len, self.num_heads * self.head_dim))?
306 };
307
308 self.o_proj.forward(&attn_output)
309 }
310}
311
312pub struct FeedForward {
313 gate_proj: Option<Linear>,
314 up_proj: Linear,
315 down_proj: Linear,
316 dropout: Dropout,
317 use_swiglu: bool,
318}
319
320impl FeedForward {
321 pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
322 let use_swiglu = config.use_swiglu();
323 let use_gate = config.block.ffn.gate;
324 let intermediate_size = config.intermediate_size();
325
326 let gate_proj = if use_gate {
327 Some(if config.use_bias() {
328 linear(config.hidden_size, intermediate_size, vb.pp("gate_proj"))?
329 } else {
330 linear_no_bias(config.hidden_size, intermediate_size, vb.pp("gate_proj"))?
331 })
332 } else {
333 None
334 };
335
336 let (up_proj, down_proj) = if config.use_bias() {
337 (
338 linear(config.hidden_size, intermediate_size, vb.pp("up_proj"))?,
339 linear(intermediate_size, config.hidden_size, vb.pp("down_proj"))?,
340 )
341 } else {
342 (
343 linear_no_bias(config.hidden_size, intermediate_size, vb.pp("up_proj"))?,
344 linear_no_bias(intermediate_size, config.hidden_size, vb.pp("down_proj"))?,
345 )
346 };
347 let dropout = Dropout::new(config.dropout() as f32);
348 Ok(Self {
349 gate_proj,
350 up_proj,
351 down_proj,
352 dropout,
353 use_swiglu,
354 })
355 }
356
357 pub fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {
358 let hidden = if let Some(gate_proj) = &self.gate_proj {
359 let gate = gate_proj.forward(x)?;
360 let up = self.up_proj.forward(x)?;
361 if self.use_swiglu {
362 let gate = candle_nn::ops::silu(&gate)?;
363 (gate * up)?
364 } else {
365 let gate = gate.gelu_erf()?;
366 (gate * up)?
367 }
368 } else {
369 let h = self.up_proj.forward(x)?;
371 if self.use_swiglu {
372 candle_nn::ops::silu(&h)?
373 } else {
374 h.gelu_erf()?
375 }
376 };
377
378 let hidden = self.dropout.forward(&hidden, train)?;
379 self.down_proj.forward(&hidden)
380 }
381}
382
383pub struct TransformerBlock {
384 attention: MultiHeadAttention,
385 feed_forward: FeedForward,
386 attn_norm: Norm,
387 ffn_norm: Norm,
388 norm_position: NormPosition,
389 use_residual: bool,
390}
391
392impl TransformerBlock {
393 pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
394 let attention = MultiHeadAttention::new(config, vb.pp("attention"))?;
395 let feed_forward = FeedForward::new(config, vb.pp("feed_forward"))?;
396 let norm_type = config.block.norm.norm_type;
397 let attn_norm = Norm::new(
398 norm_type,
399 config.hidden_size,
400 config.norm_eps(),
401 config.use_bias(),
402 vb.pp("attn_norm"),
403 )?;
404 let ffn_norm = Norm::new(
405 norm_type,
406 config.hidden_size,
407 config.norm_eps(),
408 config.use_bias(),
409 vb.pp("ffn_norm"),
410 )?;
411 Ok(Self {
412 attention,
413 feed_forward,
414 attn_norm,
415 ffn_norm,
416 norm_position: config.block.norm_position,
417 use_residual: config.block.residual,
418 })
419 }
420
421 pub fn forward(
422 &self,
423 x: &Tensor,
424 mask: Option<&Tensor>,
425 rope: &RotaryEmbedding,
426 start_pos: usize,
427 train: bool,
428 ) -> Result<Tensor> {
429 let x = match self.norm_position {
431 NormPosition::Pre => {
432 let h = self.attn_norm.forward(x)?;
433 let h = self.attention.forward(&h, mask, rope, start_pos, train)?;
434 if self.use_residual { (x + h)? } else { h }
435 }
436 NormPosition::Post => {
437 let h = self.attention.forward(x, mask, rope, start_pos, train)?;
438 let h = if self.use_residual { (x + h)? } else { h };
439 self.attn_norm.forward(&h)?
440 }
441 };
442
443 match self.norm_position {
445 NormPosition::Pre => {
446 let h = self.ffn_norm.forward(&x)?;
447 let h = self.feed_forward.forward(&h, train)?;
448 if self.use_residual { &x + h } else { Ok(h) }
449 }
450 NormPosition::Post => {
451 let h = self.feed_forward.forward(&x, train)?;
452 let h = if self.use_residual { (&x + h)? } else { h };
453 self.ffn_norm.forward(&h)
454 }
455 }
456 }
457}
458
459pub struct Transformer {
460 embedding: Embedding,
461 layers: Vec<TransformerBlock>,
462 final_norm: Norm,
463 lm_head: Linear,
464 rope: RotaryEmbedding,
465 config: ModelDef,
466}
467
468impl Transformer {
469 pub fn new(config: &ModelDef, vb: VarBuilder) -> Result<Self> {
470 let embedding = embedding(config.vocab_size, config.hidden_size, vb.pp("embedding"))?;
471 let mut layers = Vec::with_capacity(config.num_layers);
472 for i in 0..config.num_layers {
473 layers.push(TransformerBlock::new(
474 config,
475 vb.pp(format!("layers.{}", i)),
476 )?);
477 }
478 let final_norm = Norm::new(
479 config.block.norm.norm_type,
480 config.hidden_size,
481 config.norm_eps(),
482 config.use_bias(),
483 vb.pp("final_norm"),
484 )?;
485 let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
486 let rope = RotaryEmbedding::new(
487 config.head_dim(),
488 config.max_seq_len,
489 config.rope_theta(),
490 vb.device(),
491 )?;
492 Ok(Self {
493 embedding,
494 layers,
495 final_norm,
496 lm_head,
497 rope,
498 config: config.clone(),
499 })
500 }
501
502 pub fn forward(&self, input_ids: &Tensor, start_pos: usize, train: bool) -> Result<Tensor> {
503 let (_, seq_len) = input_ids.dims2()?;
504
505 let x = self.embedding.forward(input_ids)?;
506
507 let mask = if seq_len > 1 {
508 let mut mask_data = vec![0u8; seq_len * seq_len];
509 for i in 0..seq_len {
510 for j in (i + 1)..seq_len {
511 mask_data[i * seq_len + j] = 1;
512 }
513 }
514 let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), input_ids.device())?
515 .unsqueeze(0)?
516 .unsqueeze(0)?;
517 Some(mask)
518 } else {
519 None
520 };
521
522 let mut x = x;
523 for layer in &self.layers {
524 x = layer.forward(&x, mask.as_ref(), &self.rope, start_pos, train)?;
525 }
526
527 let x = self.final_norm.forward(&x)?;
528 self.lm_head.forward(&x)
529 }
530
531 pub fn config(&self) -> &ModelDef {
532 &self.config
533 }
534
535 pub fn num_parameters(&self) -> usize {
536 self.config.estimated_params()
537 }
538}
539
540pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
541 let (batch_size, seq_len, vocab_size) = logits.dims3()?;
542 let logits = logits.reshape((batch_size * seq_len, vocab_size))?;
543 let targets = targets.reshape((batch_size * seq_len,))?;
544 candle_nn::loss::cross_entropy(&logits, &targets)
545}