1use candle_core::{DType, Device as CandleDevice, IndexOp, Result as CandleResult, Tensor};
10use candle_nn::{self, Module, VarBuilder};
11
// Convenience aliases over candle-nn building blocks used throughout this file.
type Linear = candle_nn::Linear;
type Embedding = candle_nn::Embedding;
type RmsNorm = candle_nn::RmsNorm;
16
17fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> CandleResult<Linear> {
18 let w = vb.get((out_dim, in_dim), "weight")?;
19 Ok(Linear::new(w, None))
20}
21
/// An RMSNorm layer that also keeps a direct handle to its weight tensor so
/// the raw gamma parameters can be exported later (used by the CUDA decode
/// path elsewhere in this file).
pub struct RmsNormWithWeight {
    // Wrapped candle RmsNorm that performs the actual normalization.
    norm: RmsNorm,
    // Second handle to the same weight tensor the norm was built from.
    pub weight: Tensor,
}
27
impl Module for RmsNormWithWeight {
    /// Delegates normalization to the inner `RmsNorm`.
    fn forward(&self, xs: &Tensor) -> CandleResult<Tensor> {
        self.norm.forward(xs)
    }
}
33
34fn rms_norm_with_weight(size: usize, eps: f64, vb: VarBuilder) -> CandleResult<RmsNormWithWeight> {
35 let w = vb.get(size, "weight")?;
36 Ok(RmsNormWithWeight {
37 norm: RmsNorm::new(w.clone(), eps),
38 weight: w,
39 })
40}
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use ferrum_types::{FerrumError, Result};
use tracing::{debug, info};
44
/// Per-request KV cache with all storage pre-allocated up front: one K and one
/// V tensor per decoder layer, each of shape `(max_len, num_kv_heads, head_dim)`.
pub struct PreAllocKvCache {
    /// Key tensors, indexed by layer.
    pub k_caches: Vec<Tensor>,
    /// Value tensors, indexed by layer.
    pub v_caches: Vec<Tensor>,
    /// Number of positions already filled; rows `0..current_len` are valid.
    pub current_len: usize,
    /// Capacity in positions; writes past this fail inside `slice_set`.
    pub max_len: usize,
}
56
57impl PreAllocKvCache {
58 pub fn new(
59 num_layers: usize,
60 max_len: usize,
61 num_kv_heads: usize,
62 head_dim: usize,
63 dtype: DType,
64 device: &CandleDevice,
65 ) -> CandleResult<Self> {
66 let mut k_caches = Vec::with_capacity(num_layers);
67 let mut v_caches = Vec::with_capacity(num_layers);
68 for _ in 0..num_layers {
69 k_caches.push(Tensor::zeros(
70 (max_len, num_kv_heads, head_dim),
71 dtype,
72 device,
73 )?);
74 v_caches.push(Tensor::zeros(
75 (max_len, num_kv_heads, head_dim),
76 dtype,
77 device,
78 )?);
79 }
80 Ok(Self {
81 k_caches,
82 v_caches,
83 current_len: 0,
84 max_len,
85 })
86 }
87}
88
/// Precomputed RoPE tables: `cos` and `sin`, each of shape
/// `(max_position_embeddings, head_dim / 2)` in the model dtype.
pub struct RotaryEmbedding {
    pub cos: Tensor,
    pub sin: Tensor,
}
95
impl RotaryEmbedding {
    /// Precomputes cos/sin tables for all positions up to
    /// `cfg.max_position_embeddings`, using the standard RoPE frequencies
    /// `theta^(-i/head_dim)` for even dimension indices `i`.
    pub fn new(cfg: &Config, dtype: DType, device: &CandleDevice) -> CandleResult<Self> {
        let head_dim = cfg.head_dim;
        // One inverse frequency per even dimension index: 1 / theta^(i/d),
        // yielding head_dim/2 frequencies.
        let inv_freq: Vec<f32> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_t = Tensor::new(inv_freq, device)?;
        // Outer product: positions (P, 1) x inv_freq (1, D/2) -> angles (P, D/2).
        let positions = Tensor::arange(0, cfg.max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let angles = positions.matmul(&inv_freq_t.reshape((1, inv_freq_t.elem_count()))?)?;
        // Tables are built in F32 then cast to the model dtype.
        let cos = angles.cos()?.to_dtype(dtype)?;
        let sin = angles.sin()?.to_dtype(dtype)?;
        Ok(Self { cos, sin })
    }

    /// Applies rotary embeddings to `x` (layout `(batch, heads, seq, head_dim)`,
    /// per the `dims4` destructuring), using table rows `pos..pos + seq_len`.
    fn apply(&self, x: &Tensor, pos: usize) -> CandleResult<Tensor> {
        let (_, _, seq_len, _) = x.dims4()?;
        let cos = self.cos.narrow(0, pos, seq_len)?;
        let sin = self.sin.narrow(0, pos, seq_len)?;
        candle_nn::rotary_emb::rope(x, &cos, &sin)
    }
}
120
/// Static Llama hyper-parameters, resolved once at load time.
#[derive(Debug, Clone)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    /// MLP inner width (output size of the gate/up projections).
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// KV heads; smaller than `num_attention_heads` under GQA/MQA.
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    /// RoPE base frequency (theta).
    pub rope_theta: f32,
    pub max_position_embeddings: usize,
    /// When true the LM head reuses the token-embedding matrix.
    pub tie_word_embeddings: bool,
    /// Per-head dimension; may be set explicitly and need not equal
    /// `hidden_size / num_attention_heads`.
    pub head_dim: usize,
}
137
/// Multi-head self-attention weights plus the head geometry needed to
/// reshape activations (supports grouped-query attention).
pub struct Attention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub head_dim: usize,
}
147
impl Attention {
    /// Multi-head self-attention with RoPE and a pre-allocated KV cache.
    ///
    /// `x` is `(batch, seq, hidden)`; `pos` is the absolute position of the
    /// first token in `x`; `layer_idx` selects this layer's K/V slots in
    /// `kv_cache`. New K/V rows are written starting at `kv_cache.current_len`,
    /// which the caller (`Model::forward`) advances once per full pass so all
    /// layers write at the same offset.
    fn forward(
        &self,
        x: &Tensor,
        pos: usize,
        layer_idx: usize,
        rotary: &RotaryEmbedding,
        kv_cache: &mut PreAllocKvCache,
    ) -> CandleResult<Tensor> {
        let (b, seq_len, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // (b, seq, heads, head_dim) -> (b, heads, seq, head_dim).
        // Q and K are made contiguous before the rope call below.
        let q = q
            .reshape((b, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Rotary embeddings are applied to Q/K only, offset by `pos`.
        let q = rotary.apply(&q, pos)?;
        let k = rotary.apply(&k, pos)?;

        // Cache layout is (seq, kv_heads, head_dim): transpose back and drop
        // the batch dim. NOTE(review): squeeze(0) implies batch size 1 here.
        let k_for_cache = k.transpose(1, 2)?.contiguous()?;
        let v_for_cache = v.transpose(1, 2)?.contiguous()?;
        let k_squeezed = k_for_cache.squeeze(0)?;
        let v_squeezed = v_for_cache.squeeze(0)?;

        // Append this call's K/V rows at the current cache offset.
        let start = kv_cache.current_len;
        let valid_len = start + seq_len;
        kv_cache.k_caches[layer_idx].slice_set(&k_squeezed, 0, start)?;
        kv_cache.v_caches[layer_idx].slice_set(&v_squeezed, 0, start)?;

        // Read back all valid rows as (1, kv_heads, valid_len, head_dim).
        let k_full = kv_cache.k_caches[layer_idx]
            .narrow(0, 0, valid_len)?
            .unsqueeze(0)?
            .transpose(1, 2)?;
        let v_full = kv_cache.v_caches[layer_idx]
            .narrow(0, 0, valid_len)?
            .unsqueeze(0)?
            .transpose(1, 2)?;

        // GQA: replicate each KV head n_rep times to match the query heads.
        let n_rep = self.num_attention_heads / self.num_key_value_heads;
        let k_full = crate::architectures::repeat_kv(k_full, n_rep)?;
        let v_full = crate::architectures::repeat_kv(v_full, n_rep)?;

        // Scaled dot-product scores, computed in F32 (presumably for
        // accumulation precision with F16/BF16 weights).
        let scale = (self.head_dim as f64).sqrt();
        let att = q
            .to_dtype(DType::F32)?
            .matmul(&k_full.to_dtype(DType::F32)?.t()?)?;
        let att = (att / scale)?;

        // Causal mask, only needed when attending over more than one query
        // position (prefill). Entry (i, j) is 1 when cached position j lies
        // strictly after the query's absolute position start + i; those
        // scores are replaced with -inf before the softmax.
        let att = if seq_len > 1 {
            let mask: Vec<u8> = (0..seq_len)
                .flat_map(|i| (0..valid_len).map(move |j| u8::from(j > i + start)))
                .collect();
            let mask = Tensor::from_slice(&mask, (1, 1, seq_len, valid_len), x.device())?
                .broadcast_as(att.shape())?;
            let neg_inf = Tensor::new(f32::NEG_INFINITY, x.device())?.broadcast_as(att.shape())?;
            mask.where_cond(&neg_inf, &att)?
        } else {
            att
        };

        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // Weighted sum over values, then back to the input dtype.
        let y = att
            .matmul(&v_full.to_dtype(DType::F32)?.contiguous()?)?
            .to_dtype(x.dtype())?;

        // (b, heads, seq, head_dim) -> (b, seq, heads * head_dim), then the
        // output projection.
        let y =
            y.transpose(1, 2)?
                .reshape((b, seq_len, self.num_attention_heads * self.head_dim))?;
        self.o_proj.forward(&y)
    }

    /// Loads the four projection matrices from the layer's `self_attn` scope.
    /// Q and O use `heads * head_dim`; K and V use `kv_heads * head_dim`.
    fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
        let q_dim = cfg.num_attention_heads * cfg.head_dim;
        let kv_dim = cfg.num_key_value_heads * cfg.head_dim;
        Ok(Self {
            q_proj: linear_no_bias(cfg.hidden_size, q_dim, vb.pp("q_proj"))?,
            k_proj: linear_no_bias(cfg.hidden_size, kv_dim, vb.pp("k_proj"))?,
            v_proj: linear_no_bias(cfg.hidden_size, kv_dim, vb.pp("v_proj"))?,
            o_proj: linear_no_bias(q_dim, cfg.hidden_size, vb.pp("o_proj"))?,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
        })
    }
}
251
/// SwiGLU feed-forward block: `down(silu(gate(x)) * up(x))`.
pub struct Mlp {
    pub gate_proj: Linear,
    pub up_proj: Linear,
    pub down_proj: Linear,
}
257
258impl Mlp {
259 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
260 let gate = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
261 let up = self.up_proj.forward(x)?;
262 self.down_proj.forward(&(gate * up)?)
263 }
264
265 fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
266 Ok(Self {
267 gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
268 up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
269 down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
270 })
271 }
272}
273
/// One pre-norm transformer decoder block: RMSNorm -> attention -> residual,
/// then RMSNorm -> MLP -> residual.
pub struct DecoderLayer {
    pub self_attn: Attention,
    pub mlp: Mlp,
    pub input_layernorm: RmsNormWithWeight,
    pub post_attention_layernorm: RmsNormWithWeight,
}
280
281impl DecoderLayer {
282 fn forward(
283 &self,
284 x: &Tensor,
285 pos: usize,
286 layer_idx: usize,
287 rotary: &RotaryEmbedding,
288 kv_cache: &mut PreAllocKvCache,
289 ) -> CandleResult<Tensor> {
290 let residual = x;
291 let x = self.input_layernorm.forward(x)?;
292 let x = (self
293 .self_attn
294 .forward(&x, pos, layer_idx, rotary, kv_cache)?
295 + residual)?;
296 let residual = &x;
297 let x = (self
298 .mlp
299 .forward(&self.post_attention_layernorm.forward(&x)?)?
300 + residual)?;
301 Ok(x)
302 }
303
304 fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
305 Ok(Self {
306 self_attn: Attention::load(vb.pp("self_attn"), cfg)?,
307 mlp: Mlp::load(vb.pp("mlp"), cfg)?,
308 input_layernorm: rms_norm_with_weight(
309 cfg.hidden_size,
310 cfg.rms_norm_eps,
311 vb.pp("input_layernorm"),
312 )?,
313 post_attention_layernorm: rms_norm_with_weight(
314 cfg.hidden_size,
315 cfg.rms_norm_eps,
316 vb.pp("post_attention_layernorm"),
317 )?,
318 })
319 }
320}
321
/// Llama-style decoder-only transformer with per-request pre-allocated KV caches.
pub struct Model {
    pub embed_tokens: Embedding,
    pub layers: Vec<DecoderLayer>,
    pub norm: RmsNormWithWeight,
    pub lm_head: Linear,
    pub rotary_emb: RotaryEmbedding,
    pub config: Config,
    // KV caches keyed by request/session id (the `cache_key` in `forward`).
    kv_caches: HashMap<String, PreAllocKvCache>,
}
334
335impl Model {
336 pub fn load(
337 vb: VarBuilder,
338 cfg: &Config,
339 dtype: DType,
340 device: &CandleDevice,
341 ) -> CandleResult<Self> {
342 let embed_tokens =
343 candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
344 let lm_head = if cfg.tie_word_embeddings {
345 Linear::new(embed_tokens.embeddings().clone(), None)
346 } else {
347 linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
348 };
349 let norm = rms_norm_with_weight(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
350 let layers: Vec<DecoderLayer> = (0..cfg.num_hidden_layers)
351 .map(|i| DecoderLayer::load(vb.pp(format!("model.layers.{i}")), cfg))
352 .collect::<CandleResult<_>>()?;
353 let rotary_emb = RotaryEmbedding::new(cfg, dtype, device)?;
354
355 Ok(Self {
356 embed_tokens,
357 layers,
358 norm,
359 lm_head,
360 rotary_emb,
361 config: cfg.clone(),
362 kv_caches: HashMap::new(),
363 })
364 }
365
366 pub fn forward(
367 &mut self,
368 input_ids: &Tensor,
369 pos: usize,
370 cache_key: &str,
371 ) -> CandleResult<Tensor> {
372 let (_, seq_len) = input_ids.dims2()?;
373
374 if !self.kv_caches.contains_key(cache_key) {
376 let kv = PreAllocKvCache::new(
377 self.config.num_hidden_layers,
378 self.config.max_position_embeddings,
379 self.config.num_key_value_heads,
380 self.config.head_dim,
381 DType::F16,
382 &input_ids.device(),
383 )?;
384 self.kv_caches.insert(cache_key.to_string(), kv);
385 }
386 let kv_cache = self.kv_caches.get_mut(cache_key).unwrap();
387
388 let mut x = self.embed_tokens.forward(input_ids)?;
389 for (li, layer) in self.layers.iter().enumerate() {
390 x = layer.forward(&x, pos, li, &self.rotary_emb, kv_cache)?;
391 }
392 let x = self.norm.forward(&x)?;
393 let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
394 let logits = self.lm_head.forward(&x)?;
395
396 kv_cache.current_len += seq_len;
398
399 logits.to_dtype(DType::F32)
400 }
401
402 pub fn clear_kv_cache_for(&mut self, cache_key: &str) {
403 self.kv_caches.remove(cache_key);
404 }
405
406 pub fn export_kv_cache(&self, cache_key: &str) -> Option<Vec<(Tensor, Tensor, usize, usize)>> {
409 let kv = self.kv_caches.get(cache_key)?;
410 Some(
411 kv.k_caches
412 .iter()
413 .zip(kv.v_caches.iter())
414 .map(|(k, v)| (k.clone(), v.clone(), kv.current_len, kv.max_len))
415 .collect(),
416 )
417 }
418
419 pub fn release_cache(&self, _cache_key: &str) {
420 }
423}
424
/// Thread-safe wrapper around [`Model`] exposing prefill/decode entry points.
pub struct LlamaModelWrapper {
    // Mutex so `&self` methods can run forward passes (which need `&mut Model`).
    pub(crate) model: parking_lot::Mutex<Model>,
    config: Config,
    device: CandleDevice,
    dtype: DType,
    /// Directory the weights were loaded from, if known.
    pub model_dir: Option<std::path::PathBuf>,
}
434
impl LlamaModelWrapper {
    /// Builds the Llama model from pre-loaded weights (`vb`) and a generic
    /// `ModelDefinition`, translating the latter into this module's `Config`.
    ///
    /// # Errors
    /// Returns `FerrumError::model` if weight loading fails.
    pub fn from_varbuilder(
        vb: VarBuilder,
        config: &crate::definition::ModelDefinition,
        device: CandleDevice,
        dtype: DType,
    ) -> Result<Self> {
        info!("Creating Llama model from weights...");

        // `head_dim` may be specified explicitly in extra_params; otherwise
        // fall back to hidden_size / num_attention_heads.
        let head_dim = config
            .extra_params
            .get("head_dim")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(config.hidden_size / config.num_attention_heads);
        let cfg = Config {
            vocab_size: config.vocab_size,
            hidden_size: config.hidden_size,
            intermediate_size: config.intermediate_size,
            num_hidden_layers: config.num_hidden_layers,
            num_attention_heads: config.num_attention_heads,
            // GQA/MQA models declare fewer KV heads; default to full MHA.
            num_key_value_heads: config
                .num_key_value_heads
                .unwrap_or(config.num_attention_heads),
            rms_norm_eps: config.norm_eps,
            rope_theta: config.rope_theta.unwrap_or(10000.0) as f32,
            max_position_embeddings: config.max_position_embeddings,
            tie_word_embeddings: config
                .extra_params
                .get("tie_word_embeddings")
                .and_then(|v| v.as_bool())
                .unwrap_or(false),
            head_dim,
        };

        debug!(
            "Llama config: hidden={}, layers={}, heads={}, kv_heads={}, head_dim={}",
            cfg.hidden_size,
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.num_key_value_heads,
            cfg.head_dim
        );

        let model = Model::load(vb, &cfg, dtype, &device)
            .map_err(|e| FerrumError::model(format!("Failed to load Llama model: {}", e)))?;

        info!("Llama model created successfully");

        Ok(Self {
            model: parking_lot::Mutex::new(model),
            config: cfg,
            device,
            dtype,
            model_dir: None,
        })
    }

    /// Prefill pass: discards any existing cache under `cache_key`, then runs
    /// the whole prompt starting from position 0.
    pub fn forward_prefill(&self, input_ids: &Tensor, cache_key: &str) -> Result<Tensor> {
        let mut model = self.model.lock();
        model.clear_kv_cache_for(cache_key);
        model
            .forward(input_ids, 0, cache_key)
            .map_err(|e| FerrumError::model(format!("Prefill failed: {}", e)))
    }

    /// Single decode step for the token at absolute position `pos`, reusing
    /// the KV cache registered under `cache_key`.
    pub fn forward_decode(&self, token_id: &Tensor, pos: usize, cache_key: &str) -> Result<Tensor> {
        let mut model = self.model.lock();
        model
            .forward(token_id, pos, cache_key)
            .map_err(|e| FerrumError::model(format!("Decode failed: {}", e)))
    }

    /// Snapshots the KV cache for `cache_key` (see `Model::export_kv_cache`).
    pub fn export_kv_cache(&self, cache_key: &str) -> Option<Vec<(Tensor, Tensor, usize, usize)>> {
        self.model.lock().export_kv_cache(cache_key)
    }

    /// Frees the KV cache associated with `cache_key`.
    pub fn release_cache(&self, cache_key: &str) {
        self.model.lock().clear_kv_cache_for(cache_key);
    }

    /// Model configuration as resolved at load time.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Device the weights live on.
    pub fn device(&self) -> &CandleDevice {
        &self.device
    }

    /// Alias of [`Self::device`] kept for callers expecting this name.
    pub fn candle_device(&self) -> &CandleDevice {
        &self.device
    }

    /// Parameter dtype the model was loaded with.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Records the on-disk directory the model was loaded from.
    pub fn set_model_dir(&mut self, dir: std::path::PathBuf) {
        self.model_dir = Some(dir);
    }

    /// Uploads all weights onto a dedicated CUDA stream and constructs a
    /// decode runner that uses fused QKV and gate/up matrices for fast
    /// single-token decoding.
    ///
    /// # Errors
    /// Fails if the wrapper's device is not CUDA, or if any stream/copy
    /// operation fails.
    #[cfg(feature = "cuda")]
    pub fn create_decode_runner(
        &self,
    ) -> Result<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner> {
        use ferrum_cuda_kernels::decode_buffers::ModelDims;
        use ferrum_cuda_kernels::weight_store::{
            GpuWeight, LayerWeights, LinearWeight, TransformerGpuWeights,
        };

        let model = self.model.lock();
        let cfg = &self.config;

        // Synchronize candle's stream, then create a separate stream in the
        // same context for the weight uploads below.
        let cuda_device = self
            .device
            .as_cuda_device()
            .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
        let candle_stream = cuda_device.cuda_stream();
        candle_stream
            .synchronize()
            .map_err(|e| FerrumError::model(format!("sync: {e}")))?;
        let rs = candle_stream
            .context()
            .new_stream()
            .map_err(|e| FerrumError::model(format!("new_stream: {e}")))?;

        let embed_table = GpuWeight::from_tensor(model.embed_tokens.embeddings(), &rs)
            .map_err(|e| FerrumError::model(format!("embed: {e}")))?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for (li, layer) in model.layers.iter().enumerate() {
            // Fuse Q/K/V along the output dimension so the runner needs a
            // single GEMM per attention projection step.
            let qkv_fused = candle_core::Tensor::cat(
                &[
                    layer.self_attn.q_proj.weight(),
                    layer.self_attn.k_proj.weight(),
                    layer.self_attn.v_proj.weight(),
                ],
                0,
            )
            .map_err(|e| FerrumError::model(format!("qkv cat L{li}: {e}")))?;

            // Likewise fuse the MLP gate and up projections.
            let gate_up_fused = candle_core::Tensor::cat(
                &[layer.mlp.gate_proj.weight(), layer.mlp.up_proj.weight()],
                0,
            )
            .map_err(|e| FerrumError::model(format!("gate_up cat L{li}: {e}")))?;

            layers.push(LayerWeights {
                input_ln_w: GpuWeight::from_tensor(&layer.input_layernorm.weight, &rs)
                    .map_err(|e| FerrumError::model(format!("input_ln: {e}")))?,
                qkv_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(&qkv_fused, &rs)
                        .map_err(|e| FerrumError::model(format!("qkv: {e}")))?,
                ),
                // This Llama architecture has no per-head Q/K norms.
                q_norm_w: None,
                k_norm_w: None,
                o_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(layer.self_attn.o_proj.weight(), &rs)
                        .map_err(|e| FerrumError::model(format!("o: {e}")))?,
                ),
                post_ln_w: GpuWeight::from_tensor(&layer.post_attention_layernorm.weight, &rs)
                    .map_err(|e| FerrumError::model(format!("post_ln: {e}")))?,
                gate_up_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(&gate_up_fused, &rs)
                        .map_err(|e| FerrumError::model(format!("gate_up: {e}")))?,
                ),
                down_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(layer.mlp.down_proj.weight(), &rs)
                        .map_err(|e| FerrumError::model(format!("down: {e}")))?,
                ),
            });
        }

        let final_norm_w = GpuWeight::from_tensor(&model.norm.weight, &rs)
            .map_err(|e| FerrumError::model(format!("final_norm: {e}")))?;
        let lm_head_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(model.lm_head.weight(), &rs)
                .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?,
        );
        let rope_cos = GpuWeight::from_tensor(&model.rotary_emb.cos, &rs)
            .map_err(|e| FerrumError::model(format!("rope_cos: {e}")))?;
        let rope_sin = GpuWeight::from_tensor(&model.rotary_emb.sin, &rs)
            .map_err(|e| FerrumError::model(format!("rope_sin: {e}")))?;

        let weights = TransformerGpuWeights {
            embed_table,
            layers,
            final_norm_w,
            lm_head_w,
            rope_cos,
            rope_sin,
        };

        let dims = ModelDims {
            hidden_size: cfg.hidden_size,
            intermediate_size: cfg.intermediate_size,
            num_attention_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            vocab_size: cfg.vocab_size,
            num_layers: cfg.num_hidden_layers,
            max_seq_len: cfg.max_position_embeddings,
            quantized: false,
            // Batch size can be tuned via env; defaults to single-sequence.
            max_batch_size: std::env::var("FERRUM_MAX_BATCH")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(1),
        };

        // Make sure all uploads on the transfer stream finished before the
        // runner starts using the weights.
        rs.synchronize()
            .map_err(|e| FerrumError::model(format!("sync: {e}")))?;

        ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner::new(
            weights,
            dims,
            cuda_device.clone(),
            rs,
        )
        .map_err(|e| FerrumError::model(format!("CudaDecodeRunner: {e}")))
    }
}