1use crate::backends::q4k::matmul_q4k_f32_dispatch;
7use crate::blis::attention::fused_attention_decode;
8use crate::blis::norms::rms_norm;
9use crate::error::TruenoError;
10use crate::inference::gguf::{GgmlType, GgufFile};
11
/// Transformer hyperparameters parsed from GGUF metadata.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Width of the residual stream / token embedding rows.
    pub hidden_size: usize,
    /// Width of the FFN gate/up hidden layer.
    pub intermediate_size: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads (smaller than `num_heads` under GQA).
    pub num_kv_heads: usize,
    /// Width of a single attention head.
    pub head_dim: usize,
    /// Vocabulary size — the length of the logits vector.
    pub vocab_size: usize,
    /// Epsilon added inside the RMSNorm denominator.
    pub rms_norm_eps: f32,
    /// RoPE frequency base (theta).
    pub rope_theta: f32,
    /// Maximum context length; KV caches are sized from this.
    pub max_seq_len: usize,
    /// Architecture string from `general.architecture` (e.g. "llama").
    pub arch: String,
}
27
28impl ModelConfig {
29 pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
31 let arch = gguf.meta_str("general.architecture").unwrap_or("llama").to_string();
32 let prefix = &arch; let hidden_size = gguf
35 .meta_u32(&format!("{prefix}.embedding_length"))
36 .ok_or_else(|| TruenoError::InvalidInput("Missing embedding_length in GGUF".into()))?
37 as usize;
38
39 let num_heads = gguf
40 .meta_u32(&format!("{prefix}.attention.head_count"))
41 .ok_or_else(|| TruenoError::InvalidInput("Missing head_count in GGUF".into()))?
42 as usize;
43
44 let num_kv_heads = gguf
45 .meta_u32(&format!("{prefix}.attention.head_count_kv"))
46 .unwrap_or(num_heads as u32) as usize;
47
48 let num_layers = gguf
49 .meta_u32(&format!("{prefix}.block_count"))
50 .ok_or_else(|| TruenoError::InvalidInput("Missing block_count in GGUF".into()))?
51 as usize;
52
53 let intermediate_size =
54 gguf.meta_u32(&format!("{prefix}.feed_forward_length")).ok_or_else(|| {
55 TruenoError::InvalidInput("Missing feed_forward_length in GGUF".into())
56 })? as usize;
57
58 let head_dim = hidden_size / num_heads;
59
60 let vocab_size = gguf
61 .meta_u32("tokenizer.ggml.vocab_size")
62 .or_else(|| {
63 gguf.metadata.get("tokenizer.ggml.tokens").and_then(|v| {
65 if let crate::inference::gguf::MetadataValue::Array(arr) = v {
66 Some(arr.len() as u32)
67 } else {
68 None
69 }
70 })
71 })
72 .unwrap_or(32000) as usize;
73
74 let rms_norm_eps =
75 gguf.meta_f32(&format!("{prefix}.attention.layer_norm_rms_epsilon")).unwrap_or(1e-5);
76
77 let rope_theta = gguf.meta_f32(&format!("{prefix}.rope.freq_base")).unwrap_or(10000.0);
78
79 let max_seq_len =
80 gguf.meta_u32(&format!("{prefix}.context_length")).unwrap_or(2048) as usize;
81
82 Ok(Self {
83 hidden_size,
84 intermediate_size,
85 num_layers,
86 num_heads,
87 num_kv_heads,
88 head_dim,
89 vocab_size,
90 rms_norm_eps,
91 rope_theta,
92 max_seq_len,
93 arch,
94 })
95 }
96}
97
/// A 2-D projection weight in whichever representation it was loaded as.
/// Layout is row-major: `rows` output rows, each `in_dim` inputs wide.
pub enum WeightMatrix {
    /// Raw GGML Q4_K super-blocks, consumed directly by the fused Q4_K matmul kernel.
    Q4K { data: Vec<u8>, rows: usize },
    /// Dense f32 rows (stored as float, or dequantized at load time).
    F32 { data: Vec<f32>, rows: usize },
}
105
106impl WeightMatrix {
107 pub fn rows(&self) -> usize {
108 match self {
109 WeightMatrix::Q4K { rows, .. } => *rows,
110 WeightMatrix::F32 { rows, .. } => *rows,
111 }
112 }
113}
114
/// Weights for one transformer block.
pub struct LayerWeights {
    /// RMSNorm gain applied before the attention sub-block.
    pub attn_norm: Vec<f32>,
    /// Query projection.
    pub q_weight: WeightMatrix,
    /// Key projection (kv_dim output rows under GQA).
    pub k_weight: WeightMatrix,
    /// Value projection (kv_dim output rows under GQA).
    pub v_weight: WeightMatrix,
    /// Attention output projection.
    pub o_weight: WeightMatrix,
    // Optional QKV biases — present in some checkpoints (loaded only when
    // the corresponding `.bias` tensor exists in the GGUF).
    pub q_bias: Option<Vec<f32>>,
    pub k_bias: Option<Vec<f32>>,
    pub v_bias: Option<Vec<f32>>,

    /// RMSNorm gain applied before the FFN sub-block.
    pub ffn_norm: Vec<f32>,
    /// SwiGLU gate projection (hidden -> intermediate).
    pub gate_weight: WeightMatrix,
    /// SwiGLU up projection (hidden -> intermediate).
    pub up_weight: WeightMatrix,
    /// FFN down projection (intermediate -> hidden).
    pub down_weight: WeightMatrix,
}
134
/// All model weights: embedding table, final norm + LM head, and the
/// per-layer transformer blocks.
pub struct ModelWeights {
    // token_embd: dense vocab_size × hidden_size embedding rows;
    // output_weight: LM head (may be tied to the embeddings).
    pub token_embd: Vec<f32>, pub output_norm: Vec<f32>, pub output_weight: WeightMatrix,
    pub layers: Vec<LayerWeights>,
}
142
/// Pre-allocated scratch buffers for a single-token forward pass, so the
/// decode loop performs no per-token heap allocation.
pub struct ForwardArena {
    /// Normed input to the attention sub-block (hidden_size).
    pub attn_input: Vec<f32>,
    /// Query vector (hidden_size).
    pub q: Vec<f32>,
    /// Key vector for the current position (kv_dim).
    pub k: Vec<f32>,
    /// Value vector for the current position (kv_dim).
    pub v: Vec<f32>,
    /// Concatenated per-head attention outputs (hidden_size).
    pub attn_out: Vec<f32>,
    /// Attention output after the O projection (hidden_size).
    pub attn_proj: Vec<f32>,
    /// Normed input to the FFN (hidden_size).
    pub ffn_input: Vec<f32>,
    /// Gate projection output (intermediate_size).
    pub gate: Vec<f32>,
    /// Up projection output (intermediate_size).
    pub up: Vec<f32>,
    /// silu(gate) * up (intermediate_size).
    pub swiglu: Vec<f32>,
    /// FFN output after the down projection (hidden_size).
    pub ffn_out: Vec<f32>,
    /// Residual-stream activation carried across layers (hidden_size).
    pub hidden: Vec<f32>,
    /// Post-attention residual (hidden_size).
    pub residual: Vec<f32>,
    /// Final-norm output before the LM head (hidden_size).
    pub normed: Vec<f32>,
    /// Contiguous gather buffer for one head's cached keys (max_seq_len × head_dim).
    pub k_cache_head: Vec<f32>,
    /// Contiguous gather buffer for one head's cached values (max_seq_len × head_dim).
    pub v_cache_head: Vec<f32>,
}
164
165impl ForwardArena {
166 pub fn new(config: &ModelConfig) -> Self {
167 let kv_dim = config.num_kv_heads * config.head_dim;
168 let head_cache_size = config.max_seq_len * config.head_dim;
169 Self {
170 attn_input: vec![0.0f32; config.hidden_size],
171 q: vec![0.0f32; config.hidden_size],
172 k: vec![0.0f32; kv_dim],
173 v: vec![0.0f32; kv_dim],
174 attn_out: vec![0.0f32; config.hidden_size],
175 attn_proj: vec![0.0f32; config.hidden_size],
176 ffn_input: vec![0.0f32; config.hidden_size],
177 gate: vec![0.0f32; config.intermediate_size],
178 up: vec![0.0f32; config.intermediate_size],
179 swiglu: vec![0.0f32; config.intermediate_size],
180 ffn_out: vec![0.0f32; config.hidden_size],
181 hidden: vec![0.0f32; config.hidden_size],
182 residual: vec![0.0f32; config.hidden_size],
183 normed: vec![0.0f32; config.hidden_size],
184 k_cache_head: vec![0.0f32; head_cache_size],
185 v_cache_head: vec![0.0f32; head_cache_size],
186 }
187 }
188}
189
/// Per-layer key/value cache. Each layer holds one flat buffer laid out as
/// `max_seq_len` rows of `kv_dim` (= num_kv_heads * head_dim) floats.
pub struct KvCache {
    /// Cached keys, one buffer per layer.
    pub k: Vec<Vec<f32>>,
    /// Cached values, one buffer per layer.
    pub v: Vec<Vec<f32>>,
    /// Number of positions currently populated.
    pub seq_len: usize,
}
198
199impl KvCache {
200 pub fn new(config: &ModelConfig) -> Self {
201 let kv_dim = config.num_kv_heads * config.head_dim;
202 let layer_size = config.max_seq_len * kv_dim;
203 Self {
204 k: (0..config.num_layers).map(|_| vec![0.0f32; layer_size]).collect(),
205 v: (0..config.num_layers).map(|_| vec![0.0f32; layer_size]).collect(),
206 seq_len: 0,
207 }
208 }
209}
210
/// A loaded llama-architecture model: hyperparameters plus all weights.
pub struct LlamaModel {
    pub config: ModelConfig,
    pub weights: ModelWeights,
}
216
impl LlamaModel {
    /// Parse the config and load all weights from a GGUF file, logging a
    /// one-line shape summary to stderr.
    pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
        let config = ModelConfig::from_gguf(gguf)?;

        eprintln!(
            "Loading {} model: {}L × {}H ({}h {}kv) × {}I, vocab={}",
            config.arch,
            config.num_layers,
            config.hidden_size,
            config.num_heads,
            config.num_kv_heads,
            config.intermediate_size,
            config.vocab_size,
        );

        let weights = load_weights(gguf, &config)?;

        Ok(Self { config, weights })
    }

    /// Run one decode step: embed `token_id`, apply every transformer layer
    /// at position `pos` (appending to the KV cache), and return the
    /// vocab-size logits vector.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `token_id` indexes past the embedding
    /// table; propagates failures from the norm kernel.
    pub fn forward(
        &self,
        token_id: u32,
        pos: usize,
        kv_cache: &mut KvCache,
        arena: &mut ForwardArena,
    ) -> Result<Vec<f32>, TruenoError> {
        let cfg = &self.config;
        let w = &self.weights;

        // Embedding lookup: token_embd rows are hidden_size-wide.
        let embd_start = token_id as usize * cfg.hidden_size;
        let embd_end = embd_start + cfg.hidden_size;
        if embd_end > w.token_embd.len() {
            return Err(TruenoError::InvalidInput(format!(
                "Token ID {token_id} out of range (vocab={})",
                cfg.vocab_size
            )));
        }
        arena.hidden[..cfg.hidden_size].copy_from_slice(&w.token_embd[embd_start..embd_end]);

        // Each layer reads and rewrites arena.hidden in place.
        for (layer_idx, lw) in w.layers.iter().enumerate() {
            self.forward_layer(layer_idx, lw, pos, kv_cache, arena)?;
        }

        // Final RMSNorm before the LM head.
        rms_norm(
            &arena.hidden[..cfg.hidden_size],
            &w.output_norm,
            cfg.rms_norm_eps,
            &mut arena.normed[..cfg.hidden_size],
        )?;

        let logits =
            matmul_weight(&w.output_weight, &arena.normed[..cfg.hidden_size], cfg.hidden_size);

        Ok(logits)
    }

    /// Apply one transformer block in place on `arena.hidden`: pre-norm
    /// attention (with RoPE and KV-cache update) followed by a pre-norm
    /// SwiGLU FFN, each wrapped in a residual connection.
    fn forward_layer(
        &self,
        layer_idx: usize,
        lw: &LayerWeights,
        pos: usize,
        kv_cache: &mut KvCache,
        arena: &mut ForwardArena,
    ) -> Result<(), TruenoError> {
        let cfg = &self.config;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        let h_sz = cfg.hidden_size;

        // Attention pre-norm.
        rms_norm(
            &arena.hidden[..h_sz],
            &lw.attn_norm,
            cfg.rms_norm_eps,
            &mut arena.attn_input[..h_sz],
        )?;

        // QKV projections; K/V outputs are kv_dim-wide under GQA.
        matmul_weight_into(&lw.q_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.q[..h_sz]);
        matmul_weight_into(&lw.k_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.k[..kv_dim]);
        matmul_weight_into(&lw.v_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.v[..kv_dim]);

        // Optional QKV biases (loaded only when present in the GGUF).
        if let Some(bias) = &lw.q_bias {
            for (v, b) in arena.q[..h_sz].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.k_bias {
            for (v, b) in arena.k[..kv_dim].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.v_bias {
            for (v, b) in arena.v[..kv_dim].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }

        // Rotate Q and K by position-dependent angles (RoPE).
        apply_rope(&mut arena.q[..h_sz], cfg.num_heads, cfg.head_dim, pos, cfg.rope_theta);
        apply_rope(&mut arena.k[..kv_dim], cfg.num_kv_heads, cfg.head_dim, pos, cfg.rope_theta);

        // Append this position's K/V to the layer cache.
        let kv_off = pos * kv_dim;
        kv_cache.k[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&arena.k[..kv_dim]);
        kv_cache.v[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&arena.v[..kv_dim]);

        let seq_len = pos + 1;

        arena.attn_out[..h_sz].fill(0.0);
        // GQA: each KV head serves num_heads / num_kv_heads query heads.
        let heads_per_kv = cfg.num_heads / cfg.num_kv_heads;

        for h in 0..cfg.num_heads {
            let kv_h = h / heads_per_kv;
            let q_head = &arena.q[h * cfg.head_dim..(h + 1) * cfg.head_dim];

            // Gather this head's cached K/V rows into contiguous scratch
            // buffers (the cache is kv_dim-strided per position).
            let view_len = seq_len * cfg.head_dim;
            for s in 0..seq_len {
                let src_off = s * kv_dim + kv_h * cfg.head_dim;
                let dst_off = s * cfg.head_dim;
                arena.k_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.k[layer_idx][src_off..src_off + cfg.head_dim]);
                arena.v_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.v[layer_idx][src_off..src_off + cfg.head_dim]);
            }

            // Single-query attention over the seq_len cached positions.
            // NOTE(review): assumes the kernel applies softmax and the
            // 1/sqrt(head_dim) scaling internally — confirm in blis::attention.
            let out_head = &mut arena.attn_out[h * cfg.head_dim..(h + 1) * cfg.head_dim];
            fused_attention_decode(
                q_head,
                &arena.k_cache_head[..view_len],
                &arena.v_cache_head[..view_len],
                cfg.head_dim,
                seq_len,
                out_head,
            );
        }

        // Attention output projection.
        matmul_weight_into(
            &lw.o_weight,
            &arena.attn_out[..h_sz],
            h_sz,
            &mut arena.attn_proj[..h_sz],
        );

        // First residual connection.
        for i in 0..h_sz {
            arena.residual[i] = arena.hidden[i] + arena.attn_proj[i];
        }

        // FFN pre-norm.
        rms_norm(
            &arena.residual[..h_sz],
            &lw.ffn_norm,
            cfg.rms_norm_eps,
            &mut arena.ffn_input[..h_sz],
        )?;

        let i_sz = cfg.intermediate_size;
        matmul_weight_into(
            &lw.gate_weight,
            &arena.ffn_input[..h_sz],
            h_sz,
            &mut arena.gate[..i_sz],
        );
        matmul_weight_into(&lw.up_weight, &arena.ffn_input[..h_sz], h_sz, &mut arena.up[..i_sz]);

        // SwiGLU: silu(gate) elementwise-times up.
        for i in 0..i_sz {
            let g = arena.gate[i];
            let silu_g = g / (1.0 + (-g).exp());
            arena.swiglu[i] = silu_g * arena.up[i];
        }

        matmul_weight_into(
            &lw.down_weight,
            &arena.swiglu[..i_sz],
            i_sz,
            &mut arena.ffn_out[..h_sz],
        );

        // Second residual: the hidden state handed to the next layer.
        for i in 0..h_sz {
            arena.hidden[i] = arena.residual[i] + arena.ffn_out[i];
        }

        Ok(())
    }
}
419
/// Apply rotary position embeddings in place to `num_heads` contiguous
/// heads of width `head_dim`, rotating adjacent element pairs (i, i+1)
/// by an angle of `pos / theta^(i/head_dim)`.
fn apply_rope(x: &mut [f32], num_heads: usize, head_dim: usize, pos: usize, theta: f32) {
    for head in x.chunks_exact_mut(head_dim).take(num_heads) {
        for (pair_idx, pair) in head.chunks_exact_mut(2).enumerate() {
            let dim = 2 * pair_idx;
            let freq = 1.0 / theta.powf(dim as f32 / head_dim as f32);
            let (sin_a, cos_a) = (pos as f32 * freq).sin_cos();
            let (a, b) = (pair[0], pair[1]);
            pair[0] = a * cos_a - b * sin_a;
            pair[1] = a * sin_a + b * cos_a;
        }
    }
}
435
/// Load every tensor the forward pass needs, dequantizing where required.
/// Fails on any missing required tensor; per-layer optional biases are
/// loaded only when present.
fn load_weights(gguf: &GgufFile, config: &ModelConfig) -> Result<ModelWeights, TruenoError> {
    // The embedding table is always kept dense so token rows can be copied
    // directly during the embedding lookup.
    let token_embd = load_f32_or_dequant_tensor(
        gguf,
        "token_embd.weight",
        config.vocab_size * config.hidden_size,
    )?;

    let output_norm = load_f32_tensor(gguf, "output_norm.weight", config.hidden_size)?;

    // LM head: prefer a dedicated output.weight; otherwise fall back to tied
    // embeddings. NOTE(review): the tied path clones the full embedding
    // table, doubling that tensor's memory footprint.
    let output_weight = if gguf.tensor_info("output.weight").is_some() {
        load_weight_matrix(gguf, "output.weight", config.hidden_size)?
    } else {
        WeightMatrix::F32 { data: token_embd.clone(), rows: config.vocab_size }
    };

    let mut layers = Vec::with_capacity(config.num_layers);
    for i in 0..config.num_layers {
        // GGUF tensor names for layer i are all under "blk.{i}.".
        let prefix = format!("blk.{i}");

        let attn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.attn_norm.weight"), config.hidden_size)?;
        let ffn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.ffn_norm.weight"), config.hidden_size)?;

        let q_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_q.weight"), config.hidden_size)?;
        let k_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_k.weight"), config.hidden_size)?;
        let v_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_v.weight"), config.hidden_size)?;
        let o_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_output.weight"), config.hidden_size)?;

        // K/V biases are kv_dim-wide under GQA; Q bias spans hidden_size.
        let kv_dim = config.num_kv_heads * config.head_dim;
        let q_bias = load_optional_f32(gguf, &format!("{prefix}.attn_q.bias"), config.hidden_size);
        let k_bias = load_optional_f32(gguf, &format!("{prefix}.attn_k.bias"), kv_dim);
        let v_bias = load_optional_f32(gguf, &format!("{prefix}.attn_v.bias"), kv_dim);

        let gate_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_gate.weight"), config.hidden_size)?;
        let up_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_up.weight"), config.hidden_size)?;
        // Down projection's input dimension is the FFN width, not hidden.
        let down_weight = load_weight_matrix(
            gguf,
            &format!("{prefix}.ffn_down.weight"),
            config.intermediate_size,
        )?;

        // One-time shape log to help diagnose mis-sized checkpoints.
        if i == 0 {
            eprintln!(
                " Layer 0: Q[{}×{}] K[{}×{}] V[{}×{}] Gate[{}×{}]",
                q_weight.rows(),
                config.hidden_size,
                k_weight.rows(),
                config.hidden_size,
                v_weight.rows(),
                config.hidden_size,
                gate_weight.rows(),
                config.hidden_size,
            );
        }

        layers.push(LayerWeights {
            attn_norm,
            q_weight,
            k_weight,
            v_weight,
            o_weight,
            q_bias,
            k_bias,
            v_bias,
            ffn_norm,
            gate_weight,
            up_weight,
            down_weight,
        });
    }

    eprintln!(" Loaded {} layers", layers.len());

    Ok(ModelWeights { token_embd, output_norm, output_weight, layers })
}
528
/// Load a tensor as dense `f32`, dequantizing any supported quantized
/// format. Unsupported dtypes produce a zero-filled buffer of
/// `expected_elements` with a stderr warning (the model will run but that
/// tensor contributes nothing).
fn load_f32_or_dequant_tensor(
    gguf: &GgufFile,
    name: &str,
    expected_elements: usize,
) -> Result<Vec<f32>, TruenoError> {
    let info = gguf
        .tensor_info(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
    let data = gguf
        .tensor_data(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;

    match info.dtype {
        GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
            Ok(to_f32_from_any(data, info.dtype, expected_elements))
        }
        GgmlType::Q4K => {
            let n_elements = info.n_elements() as usize;
            Ok(crate::backends::q4k::dequantize_q4k_to_f32(data, n_elements))
        }
        GgmlType::Q6K => Ok(dequantize_q6k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q5K => Ok(dequantize_q5k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q8_0 => Ok(dequantize_q8_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_0 => Ok(dequantize_q4_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_1 => Ok(dequantize_q4_1_to_f32(data, info.n_elements() as usize)),
        _ => {
            eprintln!(
                " WARNING: tensor '{name}' has unsupported dtype {:?}, using zeros",
                info.dtype
            );
            Ok(vec![0.0f32; expected_elements])
        }
    }
}
565
566fn load_optional_f32(gguf: &GgufFile, name: &str, expected_elements: usize) -> Option<Vec<f32>> {
568 let info = gguf.tensor_info(name)?;
569 let data = gguf.tensor_data(name)?;
570 Some(to_f32_from_any(data, info.dtype, expected_elements))
571}
572
573fn load_f32_tensor(
575 gguf: &GgufFile,
576 name: &str,
577 expected_elements: usize,
578) -> Result<Vec<f32>, TruenoError> {
579 let info = gguf
580 .tensor_info(name)
581 .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
582 let data = gguf
583 .tensor_data(name)
584 .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;
585
586 Ok(to_f32_from_any(data, info.dtype, expected_elements))
587}
588
589fn load_weight_matrix(
593 gguf: &GgufFile,
594 name: &str,
595 in_dim: usize,
596) -> Result<WeightMatrix, TruenoError> {
597 let info = gguf
598 .tensor_info(name)
599 .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
600 let data = gguf
601 .tensor_data(name)
602 .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;
603
604 let n_elements = info.n_elements() as usize;
605 let out_dim = n_elements / in_dim;
606
607 match info.dtype {
608 GgmlType::Q4K => Ok(WeightMatrix::Q4K { data: data.to_vec(), rows: out_dim }),
609 GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
610 let f32_data = to_f32_from_any(data, info.dtype, n_elements);
611 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
612 }
613 GgmlType::Q6K => {
614 let f32_data = dequantize_q6k_to_f32(data, n_elements);
615 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
616 }
617 GgmlType::Q5K => {
618 let f32_data = dequantize_q5k_to_f32(data, n_elements);
619 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
620 }
621 GgmlType::Q8_0 => {
622 let f32_data = dequantize_q8_0_to_f32(data, n_elements);
623 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
624 }
625 GgmlType::Q4_0 => {
626 let f32_data = dequantize_q4_0_to_f32(data, n_elements);
627 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
628 }
629 GgmlType::Q4_1 => {
630 let f32_data = dequantize_q4_1_to_f32(data, n_elements);
631 Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
632 }
633 _ => {
634 eprintln!(" WARNING: tensor '{name}' dtype {:?} unsupported, using zeros", info.dtype);
635 Ok(WeightMatrix::F32 { data: vec![0.0f32; n_elements], rows: out_dim })
636 }
637 }
638}
639
640fn matmul_weight(weight: &WeightMatrix, input: &[f32], in_dim: usize) -> Vec<f32> {
643 match weight {
644 WeightMatrix::Q4K { data, rows } => matmul_q4k_f32_dispatch(data, input, *rows, in_dim),
645 WeightMatrix::F32 { data, rows } => {
646 let mut out = vec![0.0f32; *rows];
647 for i in 0..*rows {
648 let row = &data[i * in_dim..(i + 1) * in_dim];
649 out[i] = row.iter().zip(input.iter()).map(|(a, b)| a * b).sum();
650 }
651 out
652 }
653 }
654}
655
656fn matmul_weight_into(weight: &WeightMatrix, input: &[f32], in_dim: usize, out: &mut [f32]) {
658 match weight {
659 WeightMatrix::Q4K { data, rows } => {
660 let result = matmul_q4k_f32_dispatch(data, input, *rows, in_dim);
661 out[..*rows].copy_from_slice(&result);
662 }
663 WeightMatrix::F32 { data, rows } => {
664 for i in 0..*rows {
665 let row = &data[i * in_dim..(i + 1) * in_dim];
666 out[i] = row.iter().zip(input.iter()).map(|(a, b)| a * b).sum();
667 }
668 }
669 }
670}
671
/// Decode an IEEE 754 binary16 value (raw `u16` bit pattern) into `f32`,
/// preserving signed zeros, subnormals, infinities and NaN payloads.
fn f16_to_f32(bits: u16) -> f32 {
    let sign = ((bits as u32) >> 15) << 31;
    let exp = ((bits >> 10) & 0x1F) as u32;
    let mant = (bits & 0x3FF) as u32;

    let out = match (exp, mant) {
        // Signed zero.
        (0, 0) => sign,
        // Subnormal: shift the mantissa until its top bit reaches bit 10,
        // adjusting the exponent accordingly, then drop the implicit bit.
        (0, m) => {
            let shift = m.leading_zeros() - 21;
            let normalized = (m << shift) & 0x3FF;
            let f32_exp = (113 - shift) << 23; // (-14 - shift) rebiased by 127
            sign | f32_exp | (normalized << 13)
        }
        // Infinity / NaN: max f32 exponent, mantissa payload widened.
        (31, m) => sign | 0x7F80_0000 | (m << 13),
        // Normal: rebias the exponent from 15 to 127.
        (e, m) => sign | ((e + 112) << 23) | (m << 13),
    };
    f32::from_bits(out)
}
700
701fn to_f32_from_any(data: &[u8], dtype: GgmlType, n_elements: usize) -> Vec<f32> {
703 match dtype {
704 GgmlType::F32 => {
705 let count = n_elements.min(data.len() / 4);
707 (0..count)
708 .map(|i| {
709 let off = i * 4;
710 f32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]])
711 })
712 .collect()
713 }
714 GgmlType::F16 => {
715 let count = n_elements.min(data.len() / 2);
716 (0..count)
717 .map(|i| {
718 let off = i * 2;
719 let bits = u16::from_le_bytes([data[off], data[off + 1]]);
720 f16_to_f32(bits)
721 })
722 .collect()
723 }
724 GgmlType::Bf16 => {
725 let count = n_elements.min(data.len() / 2);
726 (0..count)
727 .map(|i| {
728 let off = i * 2;
729 let bits = u16::from_le_bytes([data[off], data[off + 1]]);
730 f32::from_bits((bits as u32) << 16)
731 })
732 .collect()
733 }
734 _ => {
735 vec![0.0f32; n_elements]
737 }
738 }
739}
740
741fn dequantize_q6k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
751 const BLOCK_SIZE: usize = 256;
752 const BLOCK_BYTES: usize = 210;
753
754 let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
755 let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
756
757 for sb in 0..num_blocks {
758 let sb_start = sb * BLOCK_BYTES;
759 if sb_start + BLOCK_BYTES > data.len() {
760 break;
761 }
762 let block = &data[sb_start..sb_start + BLOCK_BYTES];
763 let ql = &block[0..128];
764 let qh = &block[128..192];
765 let scales = &block[192..208];
766 let d = f16_to_f32(u16::from_le_bytes([block[208], block[209]]));
767
768 let out_base = sb * BLOCK_SIZE;
769 for group in 0..16usize {
770 let scale = (scales[group] as i8) as f32;
771 let group_off = group * 16;
772 for j in 0..16usize {
773 let idx = group_off + j;
774 let ql_byte = ql[idx / 2];
775 let low4 = if idx % 2 == 0 { ql_byte & 0x0F } else { ql_byte >> 4 };
776 let qh_byte = qh[idx / 4];
777 let high2 = (qh_byte >> ((idx % 4) * 2)) & 0x03;
778 let q6 = ((low4 | (high2 << 4)) as i8).wrapping_sub(32) as f32;
779 result[out_base + idx] = d * scale * q6;
780 }
781 }
782 }
783
784 result.truncate(num_elements);
785 result
786}
787
788fn dequantize_q5k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
799 const BLOCK_SIZE: usize = 256;
800 const BLOCK_BYTES: usize = 176;
801
802 let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
803 let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
804
805 for sb in 0..num_blocks {
806 let sb_start = sb * BLOCK_BYTES;
807 if sb_start + BLOCK_BYTES > data.len() {
808 break;
809 }
810 let block = &data[sb_start..sb_start + BLOCK_BYTES];
811 let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
812 let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
813
814 let sc = &block[4..16];
816 let mut scales = [0u8; 8];
817 let mut mins = [0u8; 8];
818 for i in 0..4 {
819 scales[i] = sc[i] & 0x3F;
820 mins[i] = sc[i + 4] & 0x3F;
821 scales[i + 4] = (sc[i + 8] & 0x0F) | ((sc[i] >> 6) << 4);
822 mins[i + 4] = (sc[i + 8] >> 4) | ((sc[i + 4] >> 6) << 4);
823 }
824
825 let qh = &block[16..48];
826 let qs = &block[48..176];
827
828 let out_base = sb * BLOCK_SIZE;
829 for sub in 0..8usize {
830 let scale = d * scales[sub] as f32;
831 let min = dmin * mins[sub] as f32;
832 let sub_off = sub * 32;
833 for j in 0..32usize {
834 let idx = sub_off + j;
835 let low4 = (qs[idx / 2] >> ((idx % 2) * 4)) & 0x0F;
836 let high1 = (qh[idx / 8] >> (idx % 8)) & 0x01;
837 let q5 = (low4 | (high1 << 4)) as f32;
838 result[out_base + idx] = scale * q5 - min;
839 }
840 }
841 }
842
843 result.truncate(num_elements);
844 result
845}
846
847fn dequantize_q8_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
855 const BLOCK_SIZE: usize = 32;
856 const BLOCK_BYTES: usize = 34;
857
858 let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
859 let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
860
861 for b in 0..num_blocks {
862 let b_start = b * BLOCK_BYTES;
863 if b_start + BLOCK_BYTES > data.len() {
864 break;
865 }
866 let block = &data[b_start..b_start + BLOCK_BYTES];
867 let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
868 let out_base = b * BLOCK_SIZE;
869 for j in 0..BLOCK_SIZE {
870 result[out_base + j] = d * (block[2 + j] as i8) as f32;
871 }
872 }
873
874 result.truncate(num_elements);
875 result
876}
877
878fn dequantize_q4_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
886 const BLOCK_SIZE: usize = 32;
887 const BLOCK_BYTES: usize = 18;
888
889 let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
890 let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
891
892 for b in 0..num_blocks {
893 let b_start = b * BLOCK_BYTES;
894 if b_start + BLOCK_BYTES > data.len() {
895 break;
896 }
897 let block = &data[b_start..b_start + BLOCK_BYTES];
898 let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
899 let out_base = b * BLOCK_SIZE;
900 for j in 0..16 {
901 let byte = block[2 + j];
902 let lo = (byte & 0x0F) as i32 - 8;
903 let hi = ((byte >> 4) & 0x0F) as i32 - 8;
904 result[out_base + j * 2] = d * lo as f32;
905 result[out_base + j * 2 + 1] = d * hi as f32;
906 }
907 }
908
909 result.truncate(num_elements);
910 result
911}
912
913fn dequantize_q4_1_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
922 const BLOCK_SIZE: usize = 32;
923 const BLOCK_BYTES: usize = 20;
924
925 let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
926 let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
927
928 for b in 0..num_blocks {
929 let b_start = b * BLOCK_BYTES;
930 if b_start + BLOCK_BYTES > data.len() {
931 break;
932 }
933 let block = &data[b_start..b_start + BLOCK_BYTES];
934 let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
935 let m = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
936 let out_base = b * BLOCK_SIZE;
937 for j in 0..16 {
938 let byte = block[4 + j];
939 let lo = (byte & 0x0F) as f32;
940 let hi = ((byte >> 4) & 0x0F) as f32;
941 result[out_base + j * 2] = d * lo + m;
942 result[out_base + j * 2 + 1] = d * hi + m;
943 }
944 }
945
946 result.truncate(num_elements);
947 result
948}