#[cfg(feature = "cuda")]
use candle_core::{Device as CandleDevice, Tensor};
#[cfg(feature = "cuda")]
use candle_nn::VarBuilder;
#[cfg(feature = "cuda")]
use ferrum_cuda_kernels::{
    decode_buffers::ModelDims,
    weight_store::{GpuWeight, LayerWeights, LinearWeight, TransformerGpuWeights},
};
#[cfg(feature = "cuda")]
use ferrum_types::{FerrumError, Result};
#[cfg(feature = "cuda")]
use std::sync::Arc;

#[cfg(feature = "cuda")]
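/// Shape and layout of the checkpoint to load: model dimensions plus flags
/// describing how the attention and MLP projections are stored on disk.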
pub struct WeightConfig {
    pub num_hidden_layers: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub vocab_size: usize,
    pub max_seq_len: usize,
    pub rope_theta: f64,
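    /// Whether the model applies a learned per-head norm to Q and K
    /// (loads `q_norm`/`k_norm` weights of size `head_dim` per layer).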
    pub has_qk_norm: bool,
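    /// Whether the checkpoint stores a single fused `qkv_proj` weight
    /// instead of separate Q/K/V projections.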
    pub qkv_fused: bool,
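    /// Whether the checkpoint stores a single fused `gate_up_proj` weight
    /// instead of separate gate/up projections.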
    pub gate_up_fused: bool,
}

#[cfg(feature = "cuda")]
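/// Load all transformer weights onto the GPU and return them together with
/// the model dimensions for the decode kernels and the dedicated CUDA stream
/// the uploads were issued on.
///
/// A minimal sketch of a call site (illustrative: the config values below are
/// hypothetical and depend on the checkpoint being loaded):
///
/// ```ignore
/// let cfg = WeightConfig {
///     num_hidden_layers: 32,
///     hidden_size: 4096,
///     intermediate_size: 11008,
///     num_attention_heads: 32,
///     num_kv_heads: 8,
///     head_dim: 128,
///     vocab_size: 32000,
///     max_seq_len: 4096,
///     rope_theta: 10000.0,
///     has_qk_norm: false,
///     qkv_fused: false,
///     gate_up_fused: false,
/// };
/// let (weights, dims, stream) = load_runner_weights(&vb, &cfg, &device)?;
/// ```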
pub fn load_runner_weights(
    vb: &VarBuilder,
    cfg: &WeightConfig,
    device: &CandleDevice,
) -> Result<(
    TransformerGpuWeights,
    ModelDims,
    Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
)> {
    let cuda_device = device
        .as_cuda_device()
        .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;

    let candle_stream = cuda_device.cuda_stream();
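    // Drain candle's own stream first, then create a dedicated stream so the
    // weight uploads below are ordered independently of candle's work.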
    candle_stream
        .synchronize()
        .map_err(|e| FerrumError::model(format!("candle stream sync: {e}")))?;
    let rs = candle_stream
        .context()
        .new_stream()
        .map_err(|e| FerrumError::model(format!("new_stream: {e}")))?;

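    // Token-embedding table: [vocab_size, hidden_size].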
    let embed_t = vb
        .get(
            (cfg.vocab_size, cfg.hidden_size),
            "model.embed_tokens.weight",
        )
        .map_err(|e| FerrumError::model(format!("embed: {e}")))?;
    let embed_table = GpuWeight::from_tensor(&embed_t, &rs)
        .map_err(|e| FerrumError::model(format!("embed: {e}")))?;

    let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
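    // Q rows cover all attention heads; K/V rows cover only the KV heads,
    // which is smaller under grouped-query attention.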
    let q_dim = cfg.num_attention_heads * cfg.head_dim;
    let kv_dim = cfg.num_kv_heads * cfg.head_dim;

    for li in 0..cfg.num_hidden_layers {
        let prefix = format!("model.layers.{li}");

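        // Per-layer pre-attention norm weight.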
        let ln_w = vb
            .get(cfg.hidden_size, &format!("{prefix}.input_layernorm.weight"))
            .map_err(|e| FerrumError::model(format!("input_ln L{li}: {e}")))?;
        let input_ln_w = GpuWeight::from_tensor(&ln_w, &rs)
            .map_err(|e| FerrumError::model(format!("input_ln: {e}")))?;

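        // QKV projection: read the fused [q_dim + 2*kv_dim, hidden] tensor if
        // the checkpoint has one, otherwise concatenate Q, K and V row-wise so
        // the kernels always see a single fused matrix.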
        let qkv_tensor = if cfg.qkv_fused {
            vb.get(
                (q_dim + 2 * kv_dim, cfg.hidden_size),
                &format!("{prefix}.self_attn.qkv_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("qkv L{li}: {e}")))?
        } else {
            let q = vb
                .get(
                    (q_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.q_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("q L{li}: {e}")))?;
            let k = vb
                .get(
                    (kv_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.k_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("k L{li}: {e}")))?;
            let v = vb
                .get(
                    (kv_dim, cfg.hidden_size),
                    &format!("{prefix}.self_attn.v_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("v L{li}: {e}")))?;
            Tensor::cat(&[&q, &k, &v], 0)
                .map_err(|e| FerrumError::model(format!("qkv cat L{li}: {e}")))?
        };
        let qkv_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&qkv_tensor, &rs)
                .map_err(|e| FerrumError::model(format!("qkv: {e}")))?,
        );

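        // Optional per-head Q/K norm weights, present when `has_qk_norm` is set.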
        let q_norm_w = if cfg.has_qk_norm {
            let t = vb
                .get(cfg.head_dim, &format!("{prefix}.self_attn.q_norm.weight"))
                .map_err(|e| FerrumError::model(format!("q_norm L{li}: {e}")))?;
            Some(
                GpuWeight::from_tensor(&t, &rs)
                    .map_err(|e| FerrumError::model(format!("q_norm: {e}")))?,
            )
        } else {
            None
        };
        let k_norm_w = if cfg.has_qk_norm {
            let t = vb
                .get(cfg.head_dim, &format!("{prefix}.self_attn.k_norm.weight"))
                .map_err(|e| FerrumError::model(format!("k_norm L{li}: {e}")))?;
            Some(
                GpuWeight::from_tensor(&t, &rs)
                    .map_err(|e| FerrumError::model(format!("k_norm: {e}")))?,
            )
        } else {
            None
        };

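        // Attention output projection: [hidden, q_dim].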
        let o_t = vb
            .get(
                (cfg.hidden_size, q_dim),
                &format!("{prefix}.self_attn.o_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("o L{li}: {e}")))?;
        let o_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&o_t, &rs).map_err(|e| FerrumError::model(format!("o: {e}")))?,
        );

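        // Post-attention norm weight.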
        let pln_t = vb
            .get(
                cfg.hidden_size,
                &format!("{prefix}.post_attention_layernorm.weight"),
            )
            .map_err(|e| FerrumError::model(format!("post_ln L{li}: {e}")))?;
        let post_ln_w = GpuWeight::from_tensor(&pln_t, &rs)
            .map_err(|e| FerrumError::model(format!("post_ln: {e}")))?;

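        // MLP gate/up projection: fused [2*intermediate, hidden] when the
        // checkpoint provides it, otherwise concatenated from gate and up.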
        let gate_up_tensor = if cfg.gate_up_fused {
            vb.get(
                (2 * cfg.intermediate_size, cfg.hidden_size),
                &format!("{prefix}.mlp.gate_up_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("gate_up L{li}: {e}")))?
        } else {
            let gate = vb
                .get(
                    (cfg.intermediate_size, cfg.hidden_size),
                    &format!("{prefix}.mlp.gate_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("gate L{li}: {e}")))?;
            let up = vb
                .get(
                    (cfg.intermediate_size, cfg.hidden_size),
                    &format!("{prefix}.mlp.up_proj.weight"),
                )
                .map_err(|e| FerrumError::model(format!("up L{li}: {e}")))?;
            Tensor::cat(&[&gate, &up], 0)
                .map_err(|e| FerrumError::model(format!("gate_up cat L{li}: {e}")))?
        };
        let gate_up_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&gate_up_tensor, &rs)
                .map_err(|e| FerrumError::model(format!("gate_up: {e}")))?,
        );

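        // MLP down projection: [hidden, intermediate].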
        let down_t = vb
            .get(
                (cfg.hidden_size, cfg.intermediate_size),
                &format!("{prefix}.mlp.down_proj.weight"),
            )
            .map_err(|e| FerrumError::model(format!("down L{li}: {e}")))?;
        let down_w = LinearWeight::Fp16(
            GpuWeight::from_tensor(&down_t, &rs)
                .map_err(|e| FerrumError::model(format!("down: {e}")))?,
        );

        layers.push(LayerWeights {
            input_ln_w,
            qkv_w,
            q_norm_w,
            k_norm_w,
            o_w,
            post_ln_w,
            gate_up_w,
            down_w,
        });
    }

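    // Final norm before the LM head.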
    let fn_t = vb
        .get(cfg.hidden_size, "model.norm.weight")
        .map_err(|e| FerrumError::model(format!("final_norm: {e}")))?;
    let final_norm_w = GpuWeight::from_tensor(&fn_t, &rs)
        .map_err(|e| FerrumError::model(format!("final_norm: {e}")))?;

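    // LM head: projects hidden states to vocabulary logits.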
    let lm_t = vb
        .get((cfg.vocab_size, cfg.hidden_size), "lm_head.weight")
        .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?;
    let lm_head_w = LinearWeight::Fp16(
        GpuWeight::from_tensor(&lm_t, &rs)
            .map_err(|e| FerrumError::model(format!("lm_head: {e}")))?,
    );

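    // Precomputed RoPE cos/sin tables shared by all layers.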
    let (rope_cos, rope_sin) = compute_rope_tables(cfg, device, &rs)?;

    let weights = TransformerGpuWeights {
        embed_table,
        layers,
        final_norm_w,
        lm_head_w,
        rope_cos,
        rope_sin,
    };

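    // Dimensions handed to the decode kernels. The decode batch size defaults
    // to 1 and can be raised via the FERRUM_MAX_BATCH environment variable.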
    let dims = ModelDims {
        hidden_size: cfg.hidden_size,
        intermediate_size: cfg.intermediate_size,
        num_attention_heads: cfg.num_attention_heads,
        num_kv_heads: cfg.num_kv_heads,
        head_dim: cfg.head_dim,
        vocab_size: cfg.vocab_size,
        num_layers: cfg.num_hidden_layers,
        max_seq_len: cfg.max_seq_len,
        quantized: false,
        max_batch_size: std::env::var("FERRUM_MAX_BATCH")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(1),
    };

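    // Block until every upload has finished before handing the weights out.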
    rs.synchronize()
        .map_err(|e| FerrumError::model(format!("stream sync: {e}")))?;

    Ok((weights, dims, rs))
}

#[cfg(feature = "cuda")]
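/// Precompute RoPE cos/sin tables of shape `[max_seq_len, head_dim / 2]`,
/// converted to fp16 and uploaded host-to-device on the given stream.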
fn compute_rope_tables(
    cfg: &WeightConfig,
    _device: &CandleDevice,
    stream: &Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
) -> Result<(GpuWeight, GpuWeight)> {
    let half_dim = cfg.head_dim / 2;
    let max_len = cfg.max_seq_len;

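    // Standard RoPE frequency schedule: inv_freq[i] = theta^(-2i / head_dim).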
    let mut inv_freq = vec![0f32; half_dim];
    for i in 0..half_dim {
        inv_freq[i] = 1.0 / (cfg.rope_theta as f32).powf(2.0 * i as f32 / cfg.head_dim as f32);
    }

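    // Allocate and fill fp16 cos/sin tables for every (position, frequency)
    // pair, stored row-major at [pos * half_dim + i].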
    let total = max_len * half_dim;
    let mut cos_data = vec![half::f16::ZERO; total];
    let mut sin_data = vec![half::f16::ZERO; total];

    for pos in 0..max_len {
        for i in 0..half_dim {
            let angle = pos as f32 * inv_freq[i];
            cos_data[pos * half_dim + i] = half::f16::from_f32(angle.cos());
            sin_data[pos * half_dim + i] = half::f16::from_f32(angle.sin());
        }
    }

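    // Copy both tables host-to-device on the weight-upload stream.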
    let cos_slice = stream
        .clone_htod(&cos_data)
        .map_err(|e| FerrumError::model(format!("rope cos upload: {e}")))?;
    let sin_slice = stream
        .clone_htod(&sin_data)
        .map_err(|e| FerrumError::model(format!("rope sin upload: {e}")))?;

    Ok((
        GpuWeight {
            slice: cos_slice,
            len: total,
        },
        GpuWeight {
            slice: sin_slice,
            len: total,
        },
    ))
}