Skip to main content

ferrum_models/architectures/
llama.rs

1//! Custom Llama architecture with public fields and per-request KV cache.
2//!
3//! Based on candle's Llama implementation but restructured for:
4//! - Public weight access (CUDA runner weight extraction)
5//! - Per-request KV cache keyed by cache_id (concurrent serving)
6//! - export_kv_cache() for CUDA runner KV migration
7//! - create_decode_runner() for CUDA decode path
8
9use candle_core::{DType, Device as CandleDevice, IndexOp, Result as CandleResult, Tensor};
10use candle_nn::{self, Module, VarBuilder};
11
12// Use candle_nn types directly
13type Linear = candle_nn::Linear;
14type Embedding = candle_nn::Embedding;
15type RmsNorm = candle_nn::RmsNorm;
16
17fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> CandleResult<Linear> {
18    let w = vb.get((out_dim, in_dim), "weight")?;
19    Ok(Linear::new(w, None))
20}
21
22/// RmsNorm with public weight access for CUDA runner extraction.
23pub struct RmsNormWithWeight {
24    norm: RmsNorm,
25    pub weight: Tensor,
26}
27
28impl Module for RmsNormWithWeight {
29    fn forward(&self, xs: &Tensor) -> CandleResult<Tensor> {
30        self.norm.forward(xs)
31    }
32}
33
34fn rms_norm_with_weight(size: usize, eps: f64, vb: VarBuilder) -> CandleResult<RmsNormWithWeight> {
35    let w = vb.get(size, "weight")?;
36    Ok(RmsNormWithWeight {
37        norm: RmsNorm::new(w.clone(), eps),
38        weight: w,
39    })
40}
41use ferrum_types::{FerrumError, Result};
42use std::collections::HashMap;
43use tracing::{debug, info};
44
45// ======================== Pre-allocated KV Cache ========================
46
47/// Pre-allocated KV cache for a single sequence (all layers).
48pub struct PreAllocKvCache {
49    /// Per-layer K cache: [max_len, num_kv_heads, head_dim]
50    pub k_caches: Vec<Tensor>,
51    /// Per-layer V cache: [max_len, num_kv_heads, head_dim]
52    pub v_caches: Vec<Tensor>,
53    pub current_len: usize,
54    pub max_len: usize,
55}
56
57impl PreAllocKvCache {
58    pub fn new(
59        num_layers: usize,
60        max_len: usize,
61        num_kv_heads: usize,
62        head_dim: usize,
63        dtype: DType,
64        device: &CandleDevice,
65    ) -> CandleResult<Self> {
66        let mut k_caches = Vec::with_capacity(num_layers);
67        let mut v_caches = Vec::with_capacity(num_layers);
68        for _ in 0..num_layers {
69            k_caches.push(Tensor::zeros(
70                (max_len, num_kv_heads, head_dim),
71                dtype,
72                device,
73            )?);
74            v_caches.push(Tensor::zeros(
75                (max_len, num_kv_heads, head_dim),
76                dtype,
77                device,
78            )?);
79        }
80        Ok(Self {
81            k_caches,
82            v_caches,
83            current_len: 0,
84            max_len,
85        })
86    }
87}
88
89// ======================== Rotary Embedding ========================
90
91pub struct RotaryEmbedding {
92    pub cos: Tensor,
93    pub sin: Tensor,
94}
95
96impl RotaryEmbedding {
97    pub fn new(cfg: &Config, dtype: DType, device: &CandleDevice) -> CandleResult<Self> {
98        let head_dim = cfg.head_dim;
99        let inv_freq: Vec<f32> = (0..head_dim)
100            .step_by(2)
101            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
102            .collect();
103        let inv_freq_t = Tensor::new(inv_freq, device)?;
104        let positions = Tensor::arange(0, cfg.max_position_embeddings as u32, device)?
105            .to_dtype(DType::F32)?
106            .reshape((cfg.max_position_embeddings, 1))?;
107        let angles = positions.matmul(&inv_freq_t.reshape((1, inv_freq_t.elem_count()))?)?;
108        let cos = angles.cos()?.to_dtype(dtype)?;
109        let sin = angles.sin()?.to_dtype(dtype)?;
110        Ok(Self { cos, sin })
111    }
112
113    fn apply(&self, x: &Tensor, pos: usize) -> CandleResult<Tensor> {
114        let (_, _, seq_len, _) = x.dims4()?;
115        let cos = self.cos.narrow(0, pos, seq_len)?;
116        let sin = self.sin.narrow(0, pos, seq_len)?;
117        candle_nn::rotary_emb::rope(x, &cos, &sin)
118    }
119}
120
121// ======================== Model Components ========================
122
/// Llama hyperparameters (mirrors the relevant HuggingFace `config.json` keys).
#[derive(Debug, Clone)]
pub struct Config {
    // Token vocabulary size (embedding rows / logit count).
    pub vocab_size: usize,
    // Width of the residual stream.
    pub hidden_size: usize,
    // Inner width of the MLP (gate/up output, down input).
    pub intermediate_size: usize,
    // Number of decoder layers.
    pub num_hidden_layers: usize,
    // Number of query heads.
    pub num_attention_heads: usize,
    // Number of K/V heads; smaller than num_attention_heads under GQA.
    pub num_key_value_heads: usize,
    // Epsilon used by every RMSNorm.
    pub rms_norm_eps: f64,
    // RoPE base frequency.
    pub rope_theta: f32,
    // Maximum sequence length; also used as the KV cache capacity.
    pub max_position_embeddings: usize,
    // When true, lm_head reuses the token-embedding matrix.
    pub tie_word_embeddings: bool,
    // Per-head dimension (may differ from hidden_size / heads in some variants).
    pub head_dim: usize,
}
137
/// Self-attention weights for one layer; projections are public so the
/// CUDA runner can extract and fuse them.
pub struct Attention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    // Query-head count.
    pub num_attention_heads: usize,
    // K/V-head count (GQA when smaller than the query-head count).
    pub num_key_value_heads: usize,
    pub head_dim: usize,
}
147
impl Attention {
    /// Attention forward pass for one layer.
    ///
    /// `x` is [b, seq_len, hidden]; `pos` is the absolute position of the
    /// first token of `x`; `layer_idx` selects this layer's K/V buffers in
    /// `kv_cache`. Returns [b, seq_len, hidden].
    ///
    /// NOTE(review): the `squeeze(0)` on the cached K/V requires b == 1
    /// (one sequence per per-request cache) — confirm all callers.
    fn forward(
        &self,
        x: &Tensor,
        pos: usize,
        layer_idx: usize,
        rotary: &RotaryEmbedding,
        kv_cache: &mut PreAllocKvCache,
    ) -> CandleResult<Tensor> {
        let (b, seq_len, _) = x.dims3()?;

        // Project the input to queries, keys and values.
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // Split heads: [b, seq, heads, hd] -> [b, heads, seq, hd].
        let q = q
            .reshape((b, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Apply RoPE at absolute offset `pos`.
        let q = rotary.apply(&q, pos)?;
        let k = rotary.apply(&k, pos)?;

        // Update KV cache: [seq_len, num_kv_heads, head_dim]
        let k_for_cache = k.transpose(1, 2)?.contiguous()?; // [b, seq, nkv, hd]
        let v_for_cache = v.transpose(1, 2)?.contiguous()?;
        let k_squeezed = k_for_cache.squeeze(0)?; // [seq, nkv, hd]
        let v_squeezed = v_for_cache.squeeze(0)?;

        // Write into pre-allocated cache using slice_set (in-place).
        // `current_len` is only advanced by Model::forward after all layers
        // ran, so every layer of this pass writes at the same `start`.
        let start = kv_cache.current_len;
        let valid_len = start + seq_len;
        kv_cache.k_caches[layer_idx].slice_set(&k_squeezed, 0, start)?;
        kv_cache.v_caches[layer_idx].slice_set(&v_squeezed, 0, start)?;

        // Read back full KV for attention: [b, nkv, valid_len, hd]
        let k_full = kv_cache.k_caches[layer_idx]
            .narrow(0, 0, valid_len)?
            .unsqueeze(0)?
            .transpose(1, 2)?;
        let v_full = kv_cache.v_caches[layer_idx]
            .narrow(0, 0, valid_len)?
            .unsqueeze(0)?
            .transpose(1, 2)?;

        // GQA: replicate K/V heads to match the query-head count.
        let n_rep = self.num_attention_heads / self.num_key_value_heads;
        let k_full = crate::architectures::repeat_kv(k_full, n_rep)?;
        let v_full = crate::architectures::repeat_kv(v_full, n_rep)?;

        // Scaled dot-product attention; scores are computed in F32.
        let scale = (self.head_dim as f64).sqrt();
        let att = q
            .to_dtype(DType::F32)?
            .matmul(&k_full.to_dtype(DType::F32)?.t()?)?;
        let att = (att / scale)?;

        // Causal mask is only needed for prefill (seq_len > 1); during decode
        // the cache holds only past positions, so nothing is in the future.
        let att = if seq_len > 1 {
            // mask = 1 where key position j lies after query position i+start
            // (absolute coordinates); those scores are replaced with -inf.
            let mask: Vec<u8> = (0..seq_len)
                .flat_map(|i| (0..valid_len).map(move |j| u8::from(j > i + start)))
                .collect();
            let mask = Tensor::from_slice(&mask, (1, 1, seq_len, valid_len), x.device())?
                .broadcast_as(att.shape())?;
            let neg_inf = Tensor::new(f32::NEG_INFINITY, x.device())?.broadcast_as(att.shape())?;
            mask.where_cond(&neg_inf, &att)?
        } else {
            att
        };

        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // Weighted sum of values, then cast back to the input dtype.
        let y = att
            .matmul(&v_full.to_dtype(DType::F32)?.contiguous()?)?
            .to_dtype(x.dtype())?;

        // Merge heads: [b, heads, seq, hd] -> [b, seq, heads*hd].
        let y =
            y.transpose(1, 2)?
                .reshape((b, seq_len, self.num_attention_heads * self.head_dim))?;
        self.o_proj.forward(&y)
    }

    /// Load the four projection matrices for this layer from `vb`.
    fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
        let q_dim = cfg.num_attention_heads * cfg.head_dim;
        let kv_dim = cfg.num_key_value_heads * cfg.head_dim;
        Ok(Self {
            q_proj: linear_no_bias(cfg.hidden_size, q_dim, vb.pp("q_proj"))?,
            k_proj: linear_no_bias(cfg.hidden_size, kv_dim, vb.pp("k_proj"))?,
            v_proj: linear_no_bias(cfg.hidden_size, kv_dim, vb.pp("v_proj"))?,
            o_proj: linear_no_bias(q_dim, cfg.hidden_size, vb.pp("o_proj"))?,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
        })
    }
}
251
252pub struct Mlp {
253    pub gate_proj: Linear,
254    pub up_proj: Linear,
255    pub down_proj: Linear,
256}
257
258impl Mlp {
259    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
260        let gate = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
261        let up = self.up_proj.forward(x)?;
262        self.down_proj.forward(&(gate * up)?)
263    }
264
265    fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
266        Ok(Self {
267            gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
268            up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
269            down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
270        })
271    }
272}
273
274pub struct DecoderLayer {
275    pub self_attn: Attention,
276    pub mlp: Mlp,
277    pub input_layernorm: RmsNormWithWeight,
278    pub post_attention_layernorm: RmsNormWithWeight,
279}
280
281impl DecoderLayer {
282    fn forward(
283        &self,
284        x: &Tensor,
285        pos: usize,
286        layer_idx: usize,
287        rotary: &RotaryEmbedding,
288        kv_cache: &mut PreAllocKvCache,
289    ) -> CandleResult<Tensor> {
290        let residual = x;
291        let x = self.input_layernorm.forward(x)?;
292        let x = (self
293            .self_attn
294            .forward(&x, pos, layer_idx, rotary, kv_cache)?
295            + residual)?;
296        let residual = &x;
297        let x = (self
298            .mlp
299            .forward(&self.post_attention_layernorm.forward(&x)?)?
300            + residual)?;
301        Ok(x)
302    }
303
304    fn load(vb: VarBuilder, cfg: &Config) -> CandleResult<Self> {
305        Ok(Self {
306            self_attn: Attention::load(vb.pp("self_attn"), cfg)?,
307            mlp: Mlp::load(vb.pp("mlp"), cfg)?,
308            input_layernorm: rms_norm_with_weight(
309                cfg.hidden_size,
310                cfg.rms_norm_eps,
311                vb.pp("input_layernorm"),
312            )?,
313            post_attention_layernorm: rms_norm_with_weight(
314                cfg.hidden_size,
315                cfg.rms_norm_eps,
316                vb.pp("post_attention_layernorm"),
317            )?,
318        })
319    }
320}
321
322// ======================== Full Model ========================
323
/// Full Llama decoder with public components (for CUDA weight extraction)
/// and per-request KV caches (for concurrent serving).
pub struct Model {
    pub embed_tokens: Embedding,
    pub layers: Vec<DecoderLayer>,
    pub norm: RmsNormWithWeight,
    pub lm_head: Linear,
    pub rotary_emb: RotaryEmbedding,
    pub config: Config,
    /// Per-request KV caches, keyed by cache_id
    kv_caches: HashMap<String, PreAllocKvCache>,
}
334
335impl Model {
336    pub fn load(
337        vb: VarBuilder,
338        cfg: &Config,
339        dtype: DType,
340        device: &CandleDevice,
341    ) -> CandleResult<Self> {
342        let embed_tokens =
343            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
344        let lm_head = if cfg.tie_word_embeddings {
345            Linear::new(embed_tokens.embeddings().clone(), None)
346        } else {
347            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
348        };
349        let norm = rms_norm_with_weight(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
350        let layers: Vec<DecoderLayer> = (0..cfg.num_hidden_layers)
351            .map(|i| DecoderLayer::load(vb.pp(format!("model.layers.{i}")), cfg))
352            .collect::<CandleResult<_>>()?;
353        let rotary_emb = RotaryEmbedding::new(cfg, dtype, device)?;
354
355        Ok(Self {
356            embed_tokens,
357            layers,
358            norm,
359            lm_head,
360            rotary_emb,
361            config: cfg.clone(),
362            kv_caches: HashMap::new(),
363        })
364    }
365
366    pub fn forward(
367        &mut self,
368        input_ids: &Tensor,
369        pos: usize,
370        cache_key: &str,
371    ) -> CandleResult<Tensor> {
372        let (_, seq_len) = input_ids.dims2()?;
373
374        // Ensure KV cache exists for this request
375        if !self.kv_caches.contains_key(cache_key) {
376            let kv = PreAllocKvCache::new(
377                self.config.num_hidden_layers,
378                self.config.max_position_embeddings,
379                self.config.num_key_value_heads,
380                self.config.head_dim,
381                DType::F16,
382                &input_ids.device(),
383            )?;
384            self.kv_caches.insert(cache_key.to_string(), kv);
385        }
386        let kv_cache = self.kv_caches.get_mut(cache_key).unwrap();
387
388        let mut x = self.embed_tokens.forward(input_ids)?;
389        for (li, layer) in self.layers.iter().enumerate() {
390            x = layer.forward(&x, pos, li, &self.rotary_emb, kv_cache)?;
391        }
392        let x = self.norm.forward(&x)?;
393        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
394        let logits = self.lm_head.forward(&x)?;
395
396        // Update KV cache length
397        kv_cache.current_len += seq_len;
398
399        logits.to_dtype(DType::F32)
400    }
401
402    pub fn clear_kv_cache_for(&mut self, cache_key: &str) {
403        self.kv_caches.remove(cache_key);
404    }
405
406    /// Export KV cache for CUDA runner migration.
407    /// Returns per-layer (K_tensor, V_tensor, current_len, max_len).
408    pub fn export_kv_cache(&self, cache_key: &str) -> Option<Vec<(Tensor, Tensor, usize, usize)>> {
409        let kv = self.kv_caches.get(cache_key)?;
410        Some(
411            kv.k_caches
412                .iter()
413                .zip(kv.v_caches.iter())
414                .map(|(k, v)| (k.clone(), v.clone(), kv.current_len, kv.max_len))
415                .collect(),
416        )
417    }
418
419    pub fn release_cache(&self, _cache_key: &str) {
420        // Intentionally no-op: kv_caches is &self, can't mutate.
421        // Use clear_kv_cache_for with &mut self instead.
422    }
423}
424
425// ======================== Model Wrapper ========================
426
/// Thread-safe wrapper over `Model` exposing prefill/decode entry points.
pub struct LlamaModelWrapper {
    // Mutex because Model::forward takes &mut self (KV cache writes).
    pub(crate) model: parking_lot::Mutex<Model>,
    config: Config,
    device: CandleDevice,
    dtype: DType,
    // Optional checkpoint directory; set via set_model_dir.
    pub model_dir: Option<std::path::PathBuf>,
}
434
impl LlamaModelWrapper {
    /// Build the wrapper from an opened `VarBuilder`, translating the generic
    /// `ModelDefinition` into this module's `Config` and loading all weights.
    pub fn from_varbuilder(
        vb: VarBuilder,
        config: &crate::definition::ModelDefinition,
        device: CandleDevice,
        dtype: DType,
    ) -> Result<Self> {
        info!("Creating Llama model from weights...");

        // head_dim may be given explicitly in extra_params; otherwise fall
        // back to hidden_size / num_attention_heads.
        let head_dim = config
            .extra_params
            .get("head_dim")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(config.hidden_size / config.num_attention_heads);
        let cfg = Config {
            vocab_size: config.vocab_size,
            hidden_size: config.hidden_size,
            intermediate_size: config.intermediate_size,
            num_hidden_layers: config.num_hidden_layers,
            num_attention_heads: config.num_attention_heads,
            // MHA fallback when KV heads are unspecified.
            num_key_value_heads: config
                .num_key_value_heads
                .unwrap_or(config.num_attention_heads),
            rms_norm_eps: config.norm_eps,
            rope_theta: config.rope_theta.unwrap_or(10000.0) as f32,
            max_position_embeddings: config.max_position_embeddings,
            tie_word_embeddings: config
                .extra_params
                .get("tie_word_embeddings")
                .and_then(|v| v.as_bool())
                .unwrap_or(false),
            head_dim,
        };

        debug!(
            "Llama config: hidden={}, layers={}, heads={}, kv_heads={}, head_dim={}",
            cfg.hidden_size,
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.num_key_value_heads,
            cfg.head_dim
        );

        let model = Model::load(vb, &cfg, dtype, &device)
            .map_err(|e| FerrumError::model(format!("Failed to load Llama model: {}", e)))?;

        info!("Llama model created successfully");

        Ok(Self {
            model: parking_lot::Mutex::new(model),
            config: cfg,
            device,
            dtype,
            model_dir: None,
        })
    }

    /// Prefill: clears any stale cache for `cache_key`, then runs the whole
    /// prompt from position 0. Returns logits for the last prompt token.
    pub fn forward_prefill(&self, input_ids: &Tensor, cache_key: &str) -> Result<Tensor> {
        let mut model = self.model.lock();
        model.clear_kv_cache_for(cache_key);
        model
            .forward(input_ids, 0, cache_key)
            .map_err(|e| FerrumError::model(format!("Prefill failed: {}", e)))
    }

    /// Decode one step: runs `token_id` at absolute position `pos` against
    /// the existing cache for `cache_key`.
    pub fn forward_decode(&self, token_id: &Tensor, pos: usize, cache_key: &str) -> Result<Tensor> {
        let mut model = self.model.lock();
        model
            .forward(token_id, pos, cache_key)
            .map_err(|e| FerrumError::model(format!("Decode failed: {}", e)))
    }

    /// See `Model::export_kv_cache` — per-layer (K, V, current_len, max_len).
    pub fn export_kv_cache(&self, cache_key: &str) -> Option<Vec<(Tensor, Tensor, usize, usize)>> {
        self.model.lock().export_kv_cache(cache_key)
    }

    /// Drop the KV cache for `cache_key` (the mutex gives us &mut Model here,
    /// unlike Model::release_cache which is a documented no-op).
    pub fn release_cache(&self, cache_key: &str) {
        self.model.lock().clear_kv_cache_for(cache_key);
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn device(&self) -> &CandleDevice {
        &self.device
    }

    // Alias of device(); kept for callers using the candle-specific name.
    pub fn candle_device(&self) -> &CandleDevice {
        &self.device
    }

    pub fn dtype(&self) -> DType {
        self.dtype
    }

    pub fn set_model_dir(&mut self, dir: std::path::PathBuf) {
        self.model_dir = Some(dir);
    }

    /// Create CUDA decode runner by extracting weights directly from the model.
    ///
    /// Uploads every layer's weights onto a dedicated CUDA stream, fusing
    /// Q/K/V into one matrix and gate/up into another before upload.
    #[cfg(feature = "cuda")]
    pub fn create_decode_runner(
        &self,
    ) -> Result<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner> {
        use ferrum_cuda_kernels::decode_buffers::ModelDims;
        use ferrum_cuda_kernels::weight_store::{
            GpuWeight, LayerWeights, LinearWeight, TransformerGpuWeights,
        };

        let model = self.model.lock();
        let cfg = &self.config;

        let cuda_device = self
            .device
            .as_cuda_device()
            .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
        // Wait for candle's own stream to finish before reading weights, then
        // create a fresh stream for the runner's uploads.
        let candle_stream = cuda_device.cuda_stream();
        candle_stream
            .synchronize()
            .map_err(|e| FerrumError::model(format!("sync: {e}")))?;
        let rs = candle_stream
            .context()
            .new_stream()
            .map_err(|e| FerrumError::model(format!("new_stream: {e}")))?;

        let embed_table = GpuWeight::from_tensor(model.embed_tokens.embeddings(), &rs)
            .map_err(|e| FerrumError::model(format!("embed: {e}")))?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for (li, layer) in model.layers.iter().enumerate() {
            // Fuse Q+K+V → QKV
            let qkv_fused = candle_core::Tensor::cat(
                &[
                    layer.self_attn.q_proj.weight(),
                    layer.self_attn.k_proj.weight(),
                    layer.self_attn.v_proj.weight(),
                ],
                0,
            )
            .map_err(|e| FerrumError::model(format!("qkv cat L{li}: {e}")))?;

            // Fuse gate+up
            let gate_up_fused = candle_core::Tensor::cat(
                &[layer.mlp.gate_proj.weight(), layer.mlp.up_proj.weight()],
                0,
            )
            .map_err(|e| FerrumError::model(format!("gate_up cat L{li}: {e}")))?;

            layers.push(LayerWeights {
                input_ln_w: GpuWeight::from_tensor(&layer.input_layernorm.weight, &rs)
                    .map_err(|e| FerrumError::model(format!("input_ln: {e}")))?,
                qkv_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(&qkv_fused, &rs)
                        .map_err(|e| FerrumError::model(format!("qkv: {e}")))?,
                ),
                // This architecture has no per-head Q/K norms to upload.
                q_norm_w: None,
                k_norm_w: None,
                o_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(layer.self_attn.o_proj.weight(), &rs)
                        .map_err(|e| FerrumError::model(format!("o: {e}")))?,
                ),
                post_ln_w: GpuWeight::from_tensor(&layer.post_attention_layernorm.weight, &rs)
                    .map_err(|e| FerrumError::model(format!("post_ln: {e}")))?,
                gate_up_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(&gate_up_fused, &rs)
                        .map_err(|e| FerrumError::model(format!("gate_up: {e}")))?,
                ),
                down_w: LinearWeight::Fp16(
                    GpuWeight::from_tensor(layer.mlp.down_proj.weight(), &rs)
                        .map_err(|e| FerrumError::model(format!("down: {e}")))?,
                ),
            });
        }

        let final_norm_w = GpuWeight::from_tensor(&model.norm.weight, &rs)
            .map_err(|e| FerrumError::model(format!("final_norm: {e}")))?;
        let lm_head_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(model.lm_head.weight(), &rs)
                .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?,
        );
        let rope_cos = GpuWeight::from_tensor(&model.rotary_emb.cos, &rs)
            .map_err(|e| FerrumError::model(format!("rope_cos: {e}")))?;
        let rope_sin = GpuWeight::from_tensor(&model.rotary_emb.sin, &rs)
            .map_err(|e| FerrumError::model(format!("rope_sin: {e}")))?;

        let weights = TransformerGpuWeights {
            embed_table,
            layers,
            final_norm_w,
            lm_head_w,
            rope_cos,
            rope_sin,
        };

        let dims = ModelDims {
            hidden_size: cfg.hidden_size,
            intermediate_size: cfg.intermediate_size,
            num_attention_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            vocab_size: cfg.vocab_size,
            num_layers: cfg.num_hidden_layers,
            max_seq_len: cfg.max_position_embeddings,
            quantized: false,
            // Batch size is operator-tunable via env; defaults to 1.
            max_batch_size: std::env::var("FERRUM_MAX_BATCH")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(1),
        };

        // Make sure all uploads on the new stream completed before handing
        // the weights to the runner.
        rs.synchronize()
            .map_err(|e| FerrumError::model(format!("sync: {e}")))?;

        ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner::new(
            weights,
            dims,
            cuda_device.clone(),
            rs,
        )
        .map_err(|e| FerrumError::model(format!("CudaDecodeRunner: {e}")))
    }
}