ferrum_models/loader/runner_weights.rs

//! Generic CUDA decode runner weight loader.
//!
//! Loads transformer weights from safetensors, fuses separate Q/K/V → QKV
//! and gate/up → gate_up, then uploads to a CUDA stream. Architecture-agnostic:
//! works for Llama, Qwen2, Mistral, and any model with the standard naming.

#[cfg(feature = "cuda")]
use candle_core::{DType, Device as CandleDevice, Tensor};
#[cfg(feature = "cuda")]
use candle_nn::VarBuilder;
#[cfg(feature = "cuda")]
use ferrum_cuda_kernels::{
    decode_buffers::ModelDims,
    weight_store::{GpuWeight, LayerWeights, LinearWeight, TransformerGpuWeights},
};
#[cfg(feature = "cuda")]
use ferrum_types::{FerrumError, Result};
#[cfg(feature = "cuda")]
use std::sync::Arc;

/// Architecture-specific weight naming and structure.
#[cfg(feature = "cuda")]
pub struct WeightConfig {
    pub num_hidden_layers: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub vocab_size: usize,
    pub max_seq_len: usize,
    pub rope_theta: f64,
    /// Whether the model has Q/K normalization (Qwen3 yes, Llama/Qwen2 no).
    pub has_qk_norm: bool,
    /// Whether the QKV projection is already fused into one weight.
    pub qkv_fused: bool,
    /// Whether the gate+up MLP projection is already fused.
    pub gate_up_fused: bool,
}
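
// A minimal sketch of a `WeightConfig` for a Llama-3-8B-style checkpoint.
// The concrete numbers are illustrative assumptions taken from that model
// family, not read from any config.json; real callers should derive every
// field from the model's own config.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn example_llama_style_config() -> WeightConfig {
    WeightConfig {
        num_hidden_layers: 32,
        hidden_size: 4096,
        intermediate_size: 14336,
        num_attention_heads: 32,
        num_kv_heads: 8, // GQA: 4 query heads share each KV head
        head_dim: 128,   // hidden_size / num_attention_heads
        vocab_size: 128_256,
        max_seq_len: 8192,
        rope_theta: 500_000.0,
        has_qk_norm: false,   // Llama/Qwen2 have no Q/K norms
        qkv_fused: false,     // separate q_proj/k_proj/v_proj on disk
        gate_up_fused: false, // separate gate_proj/up_proj on disk
    }
}
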
/// Load transformer weights from safetensors for the CUDA decode runner.
///
/// Handles both fused (Qwen3) and separate (Llama/Qwen2) weight formats.
/// Returns `TransformerGpuWeights` ready for `CudaDecodeRunner`.
#[cfg(feature = "cuda")]
pub fn load_runner_weights(
    vb: &VarBuilder,
    cfg: &WeightConfig,
    device: &CandleDevice,
) -> Result<(
    TransformerGpuWeights,
    ModelDims,
    Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
)> {
    let cuda_device = device
        .as_cuda_device()
        .map_err(|e| FerrumError::model(format!("not a CUDA device: {e}")))?;

    // Synchronize candle's stream, then create a dedicated runner stream.
    let candle_stream = cuda_device.cuda_stream();
    candle_stream
        .synchronize()
        .map_err(|e| FerrumError::model(format!("candle stream sync: {e}")))?;
    let rs = candle_stream
        .context()
        .new_stream()
        .map_err(|e| FerrumError::model(format!("new_stream: {e}")))?;

    // Embedding table
    let embed_t = vb
        .get(
            (cfg.vocab_size, cfg.hidden_size),
            "model.embed_tokens.weight",
        )
        .map_err(|e| FerrumError::model(format!("embed get: {e}")))?;
    let embed_table = GpuWeight::from_tensor(&embed_t, &rs)
        .map_err(|e| FerrumError::model(format!("embed upload: {e}")))?;

    // Per-layer weights
    let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
    let q_dim = cfg.num_attention_heads * cfg.head_dim;
    let kv_dim = cfg.num_kv_heads * cfg.head_dim;

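    // Shape sketch with illustrative numbers (assume 32 Q heads, 8 KV heads,
    // head_dim 128, hidden 4096): q_proj is [4096, 4096], k_proj/v_proj are
    // [1024, 4096], so the fused QKV weight below is
    // [q_dim + 2 * kv_dim, hidden] = [6144, 4096]. gate/up fuse the same way
    // along dim 0 to [2 * intermediate, hidden].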
    for li in 0..cfg.num_hidden_layers {
        let prefix = format!("model.layers.{li}");

        // Input layer norm
        let ln_w = vb
            .get(cfg.hidden_size, &format!("{prefix}.input_layernorm.weight"))
            .map_err(|e| FerrumError::model(format!("input_ln L{li}: {e}")))?;
        let input_ln_w = GpuWeight::from_tensor(&ln_w, &rs)
            .map_err(|e| FerrumError::model(format!("input_ln L{li}: {e}")))?;

        // QKV projection — fuse if separate
        let qkv_tensor = if cfg.qkv_fused {
            vb.get(
                (q_dim + 2 * kv_dim, cfg.hidden_size),
                &format!("{prefix}.self_attn.qkv_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("qkv L{li}: {e}")))?
        } else {
            let q = vb
                .get(
                    (q_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.q_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("q L{li}: {e}")))?;
            let k = vb
                .get(
                    (kv_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.k_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("k L{li}: {e}")))?;
            let v = vb
                .get(
                    (kv_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.v_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("v L{li}: {e}")))?;
            Tensor::cat(&[&q, &k, &v], 0)
                .map_err(|e| FerrumError::model(format!("qkv cat L{li}: {e}")))?
        };
        let qkv_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&qkv_tensor, &rs)
                .map_err(|e| FerrumError::model(format!("qkv L{li}: {e}")))?,
        );

        // Q/K norms (Qwen3 only)
        let q_norm_w = if cfg.has_qk_norm {
            let t = vb
                .get(cfg.head_dim, &format!("{prefix}.self_attn.q_norm.weight"))
                .map_err(|e| FerrumError::model(format!("q_norm L{li}: {e}")))?;
            Some(
                GpuWeight::from_tensor(&t, &rs)
                    .map_err(|e| FerrumError::model(format!("q_norm L{li}: {e}")))?,
            )
        } else {
            None
        };
        let k_norm_w = if cfg.has_qk_norm {
            let t = vb
                .get(cfg.head_dim, &format!("{prefix}.self_attn.k_norm.weight"))
                .map_err(|e| FerrumError::model(format!("k_norm L{li}: {e}")))?;
            Some(
                GpuWeight::from_tensor(&t, &rs)
                    .map_err(|e| FerrumError::model(format!("k_norm L{li}: {e}")))?,
            )
        } else {
            None
        };

        // O projection
        let o_t = vb
            .get(
                (cfg.hidden_size, q_dim),
                &format!("{prefix}.self_attn.o_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("o L{li}: {e}")))?;
        let o_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&o_t, &rs)
                .map_err(|e| FerrumError::model(format!("o L{li}: {e}")))?,
        );

        // Post-attention layer norm
        let pln_t = vb
            .get(
                cfg.hidden_size,
                &format!("{prefix}.post_attention_layernorm.weight"),
            )
            .map_err(|e| FerrumError::model(format!("post_ln L{li}: {e}")))?;
        let post_ln_w = GpuWeight::from_tensor(&pln_t, &rs)
            .map_err(|e| FerrumError::model(format!("post_ln L{li}: {e}")))?;

        // MLP gate+up — fuse if separate
        let gate_up_tensor = if cfg.gate_up_fused {
            vb.get(
                (2 * cfg.intermediate_size, cfg.hidden_size),
                &format!("{prefix}.mlp.gate_up_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("gate_up L{li}: {e}")))?
        } else {
            let gate = vb
                .get(
                    (cfg.intermediate_size, cfg.hidden_size),
                    &format!("{prefix}.mlp.gate_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("gate L{li}: {e}")))?;
            let up = vb
                .get(
                    (cfg.intermediate_size, cfg.hidden_size),
                    &format!("{prefix}.mlp.up_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("up L{li}: {e}")))?;
            Tensor::cat(&[&gate, &up], 0)
                .map_err(|e| FerrumError::model(format!("gate_up cat L{li}: {e}")))?
        };
        let gate_up_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&gate_up_tensor, &rs)
                .map_err(|e| FerrumError::model(format!("gate_up L{li}: {e}")))?,
        );

        // Down projection
        let down_t = vb
            .get(
                (cfg.hidden_size, cfg.intermediate_size),
                &format!("{prefix}.mlp.down_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("down L{li}: {e}")))?;
        let down_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&down_t, &rs)
                .map_err(|e| FerrumError::model(format!("down L{li}: {e}")))?,
        );

        layers.push(LayerWeights {
            input_ln_w,
            qkv_w,
            q_norm_w,
            k_norm_w,
            o_w,
            post_ln_w,
            gate_up_w,
            down_w,
        });
    }

    // Final norm
    let fn_t = vb
        .get(cfg.hidden_size, "model.norm.weight")
        .map_err(|e| FerrumError::model(format!("final_norm get: {e}")))?;
    let final_norm_w = GpuWeight::from_tensor(&fn_t, &rs)
        .map_err(|e| FerrumError::model(format!("final_norm upload: {e}")))?;

    // LM head (falls back to the embedding table when weights are tied)
    let lm_t = vb
        .get((cfg.vocab_size, cfg.hidden_size), "lm_head.weight")
        .or_else(|_| {
            vb.get(
                (cfg.vocab_size, cfg.hidden_size),
                "model.embed_tokens.weight",
            )
        })
        .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?;
    let lm_head_w = LinearWeight::Fp16(
        GpuWeight::from_tensor(&lm_t, &rs)
            .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?,
    );

    // RoPE cos/sin tables — compute from config
    let (rope_cos, rope_sin) = compute_rope_tables(cfg, device, &rs)?;

    let weights = TransformerGpuWeights {
        embed_table,
        layers,
        final_norm_w,
        lm_head_w,
        rope_cos,
        rope_sin,
    };

    let dims = ModelDims {
        hidden_size: cfg.hidden_size,
        intermediate_size: cfg.intermediate_size,
        num_attention_heads: cfg.num_attention_heads,
        num_kv_heads: cfg.num_kv_heads,
        head_dim: cfg.head_dim,
        vocab_size: cfg.vocab_size,
        num_layers: cfg.num_hidden_layers,
        max_seq_len: cfg.max_seq_len,
        quantized: false,
        // Batch capacity comes from the FERRUM_MAX_BATCH env var; default 1.
        max_batch_size: std::env::var("FERRUM_MAX_BATCH")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(1),
    };

    // Ensure all uploads on the runner stream have completed before returning.
    rs.synchronize()
        .map_err(|e| FerrumError::model(format!("stream sync: {e}")))?;

    Ok((weights, dims, rs))
}
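
// A hedged usage sketch for `load_runner_weights`. The safetensors paths and
// the mmap-based `VarBuilder` construction are assumptions about the caller,
// not something this module prescribes; only the final call mirrors the
// signature above.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn example_load(paths: &[std::path::PathBuf]) -> Result<()> {
    let device = CandleDevice::new_cuda(0)
        .map_err(|e| FerrumError::model(format!("cuda init: {e}")))?;
    // Safety: mmap requires the files to stay unmodified while mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(paths, DType::F16, &device)
            .map_err(|e| FerrumError::model(format!("safetensors mmap: {e}")))?
    };
    let cfg = example_llama_style_config();
    let (_weights, _dims, _stream) = load_runner_weights(&vb, &cfg, &device)?;
    // `_weights`, `_dims`, and `_stream` would be handed to the decode runner.
    Ok(())
}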

/// Public wrapper that computes RoPE tables for the tensor-parallel (TP) loader.
#[cfg(feature = "cuda")]
pub fn compute_rope_tables_for_tp(
    cfg: &super::tp_weight_loader::TpWeightConfig,
    device: &CandleDevice,
    stream: &Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
) -> Result<(GpuWeight, GpuWeight)> {
    // RoPE tables depend only on head_dim, max_seq_len, and rope_theta; the
    // fusion flags are irrelevant here, so they are set to arbitrary values.
    let w = WeightConfig {
        num_hidden_layers: cfg.num_hidden_layers,
        hidden_size: cfg.hidden_size,
        intermediate_size: cfg.intermediate_size,
        num_attention_heads: cfg.num_attention_heads,
        num_kv_heads: cfg.num_kv_heads,
        head_dim: cfg.head_dim,
        vocab_size: cfg.vocab_size,
        max_seq_len: cfg.max_seq_len,
        rope_theta: cfg.rope_theta,
        has_qk_norm: cfg.has_qk_norm,
        qkv_fused: false,
        gate_up_fused: false,
    };
    compute_rope_tables(&w, device, stream)
}

/// Compute RoPE cos/sin tables from config parameters.
#[cfg(feature = "cuda")]
fn compute_rope_tables(
    cfg: &WeightConfig,
    _device: &CandleDevice,
    stream: &Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
) -> Result<(GpuWeight, GpuWeight)> {
    let half_dim = cfg.head_dim / 2;
    let max_len = cfg.max_seq_len;

    // inv_freq = 1.0 / (theta ^ (2i / head_dim)) for i in 0..head_dim/2
    let mut inv_freq = vec![0f32; half_dim];
    for i in 0..half_dim {
        inv_freq[i] = 1.0 / (cfg.rope_theta as f32).powf(2.0 * i as f32 / cfg.head_dim as f32);
    }

    // cos/sin tables: [max_len, half_dim], stored as f16
    let total = max_len * half_dim;
    let mut cos_data = vec![half::f16::ZERO; total];
    let mut sin_data = vec![half::f16::ZERO; total];

    for pos in 0..max_len {
        for i in 0..half_dim {
            let angle = pos as f32 * inv_freq[i];
            cos_data[pos * half_dim + i] = half::f16::from_f32(angle.cos());
            sin_data[pos * half_dim + i] = half::f16::from_f32(angle.sin());
        }
    }

    // Upload to GPU
    let cos_slice = stream
        .clone_htod(&cos_data)
        .map_err(|e| FerrumError::model(format!("rope cos upload: {e}")))?;
    let sin_slice = stream
        .clone_htod(&sin_data)
        .map_err(|e| FerrumError::model(format!("rope sin upload: {e}")))?;

    Ok((
        GpuWeight {
            slice: cos_slice,
            len: total,
        },
        GpuWeight {
            slice: sin_slice,
            len: total,
        },
    ))
}
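
// A CPU-only sanity check of the RoPE math above; the module and test names
// are illustrative additions. It re-derives inv_freq exactly as
// `compute_rope_tables` does and checks two easy invariants: inv_freq[0] is
// always 1, and the position-0 row is (cos, sin) = (1, 0).
#[cfg(test)]
mod rope_math_tests {
    #[test]
    fn rope_table_row_zero_is_identity() {
        let head_dim = 128usize;
        let rope_theta = 10_000.0f64; // common default; illustrative here
        let half_dim = head_dim / 2;

        // Same formula as compute_rope_tables: inv_freq[i] = theta^(-2i/head_dim).
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / (rope_theta as f32).powf(2.0 * i as f32 / head_dim as f32))
            .collect();

        // theta^0 = 1, so the first frequency is exactly 1.
        assert_eq!(inv_freq[0], 1.0);

        // At position 0 every angle is 0, so cos = 1 and sin = 0 exactly.
        for &f in &inv_freq {
            let angle = 0.0f32 * f;
            assert_eq!(angle.cos(), 1.0);
            assert_eq!(angle.sin(), 0.0);
        }
    }
}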