ferrum_quantization/gguf/
loader.rs1use std::path::Path;
21use std::sync::Arc;
22
23use candle_core::Device;
24use ferrum_kernels::backend::{Backend, BackendQuantGguf, BackendQuantMarlin};
25use ferrum_types::{FerrumError, Result};
26
27use crate::config::QuantConfig;
28use crate::gguf::file::GgufFile;
29use crate::gguf::linear::GgufLinear;
30use crate::gguf::names::{ferrum_to_gguf, gate_up_split_parts, qkv_split_parts};
31use crate::loader::WeightLoader;
32use crate::traits::Linear;
33
/// Environment variable that enables `[gguf-load]` tracing on stderr.
///
/// Only the *presence* of the variable is checked; its value is ignored, so both
/// `FERRUM_GGUF_LOAD_TRACE=` and `FERRUM_GGUF_LOAD_TRACE=0` turn tracing on.
const GGUF_LOAD_TRACE_ENV: &str = "FERRUM_GGUF_LOAD_TRACE";

/// Loader options resolved once from the process environment.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct GgufLoaderRuntimeConfig {
    /// Print which load path (fused Q4, multi-quant, eager fallback) each fused weight took.
    load_trace: bool,
}

impl GgufLoaderRuntimeConfig {
    /// Reads the configuration from the current process environment.
    ///
    /// Uses [`std::env::vars_os`] rather than [`std::env::vars`]: the latter panics while
    /// iterating if *any* variable name or value is not valid Unicode, which would make
    /// loader construction panic because of unrelated environment entries.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os())
    }

    /// Builds the configuration from explicit `(name, value)` pairs.
    ///
    /// `load_trace` is `true` iff some pair's name equals [`GGUF_LOAD_TRACE_ENV`]
    /// exactly (case-sensitive); values are never inspected.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<std::ffi::OsStr>,
    {
        let wanted = std::ffi::OsStr::new(GGUF_LOAD_TRACE_ENV);
        Self {
            load_trace: vars
                .into_iter()
                .any(|(name, _value)| name.as_ref() == wanted),
        }
    }
}
59
60pub struct GgufLoader<B: Backend + BackendQuantGguf + BackendQuantMarlin> {
66 gguf: Arc<GgufFile>,
67 decode_device: Device,
72 runtime_config: GgufLoaderRuntimeConfig,
73 _marker: std::marker::PhantomData<B>,
74}
75
76impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> GgufLoader<B> {
77 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
80 let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
81 Ok(Self {
82 gguf: Arc::new(gguf),
83 decode_device: Device::Cpu,
84 runtime_config: GgufLoaderRuntimeConfig::from_env(),
85 _marker: std::marker::PhantomData,
86 })
87 }
88
89 pub fn from_file(gguf: Arc<GgufFile>) -> Self {
92 Self {
93 gguf,
94 decode_device: Device::Cpu,
95 runtime_config: GgufLoaderRuntimeConfig::from_env(),
96 _marker: std::marker::PhantomData,
97 }
98 }
99
100 pub fn gguf(&self) -> &GgufFile {
104 &self.gguf
105 }
106
107 fn locate(&self, ferrum_name: &str) -> Result<String> {
112 let gguf_name = ferrum_to_gguf(ferrum_name).ok_or_else(|| {
113 FerrumError::model(format!(
114 "GgufLoader: unrecognised tensor name '{ferrum_name}' (no GGUF mapping)"
115 ))
116 })?;
117 if !self.gguf.has_tensor(&gguf_name) {
118 return Err(FerrumError::model(format!(
119 "GgufLoader: tensor '{ferrum_name}' (mapped to '{gguf_name}') not present in GGUF"
120 )));
121 }
122 Ok(gguf_name)
123 }
124
125 fn read_dequant(&self, gguf_name: &str) -> Result<Vec<f32>> {
128 let qt = self
129 .gguf
130 .read_tensor(gguf_name, &self.decode_device)
131 .map_err(candle_to_ferrum)?;
132 let dense = qt
133 .dequantize(&self.decode_device)
134 .map_err(candle_to_ferrum)?;
135 let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
136 flat.to_vec1::<f32>().map_err(candle_to_ferrum)
137 }
138
139 fn rows_cols(&self, gguf_name: &str) -> Result<(usize, usize)> {
143 let info = self
144 .gguf
145 .tensor_info(gguf_name)
146 .ok_or_else(|| FerrumError::model(format!("tensor info missing for '{gguf_name}'")))?;
147 let dims = info.shape.dims();
148 if dims.len() != 2 {
149 return Err(FerrumError::model(format!(
150 "expected 2-D tensor for '{gguf_name}', got rank {}",
151 dims.len()
152 )));
153 }
154 Ok((dims[0], dims[1]))
155 }
156
157 fn load_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
175 if let Some(fast) = self.try_load_fused_q4k(parts)? {
176 if self.runtime_config.load_trace {
177 eprintln!("[gguf-load] {:?} → fused-Q4 (homogeneous)", parts);
178 }
179 return Ok(fast);
180 }
181 if let Some(multi) = self.try_load_fused_multi_quant(parts)? {
182 if self.runtime_config.load_trace {
183 eprintln!("[gguf-load] {:?} → MultiQuant (mixed dtype)", parts);
184 }
185 return Ok(multi);
186 }
187 if self.runtime_config.load_trace {
188 eprintln!("[gguf-load] {:?} → eager fp32 fallback ⚠", parts);
189 }
190 self.load_fused_eager(parts)
191 }
192
193 fn try_load_fused_multi_quant(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
200 let mut spec: Vec<(ferrum_kernels::backend::GgufQuantType, &[u8], usize)> = Vec::new();
201 let mut cols_check: Option<usize> = None;
202
203 for stem in parts {
204 let weight_name = format!("{stem}.weight");
205 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
206 FerrumError::model(format!(
207 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
208 ))
209 })?;
210 if !self.gguf.has_tensor(&gguf_name) {
211 return Err(FerrumError::model(format!(
212 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
213 )));
214 }
215
216 let has_bias = ferrum_to_gguf(&format!("{stem}.bias"))
219 .map(|n| self.gguf.has_tensor(&n))
220 .unwrap_or(false);
221 if has_bias {
222 return Ok(None);
223 }
224
225 let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
226 FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
227 })?;
228 let kind = match info.ggml_dtype {
229 candle_core::quantized::GgmlDType::Q4K => {
230 ferrum_kernels::backend::GgufQuantType::Q4K
231 }
232 candle_core::quantized::GgmlDType::Q6K => {
233 ferrum_kernels::backend::GgufQuantType::Q6K
234 }
235 _ => return Ok(None), };
237
238 let dims = info.shape.dims();
239 if dims.len() != 2 {
240 return Ok(None);
241 }
242 let (rows, cols) = (dims[0], dims[1]);
243 if cols % 256 != 0 {
244 return Ok(None);
245 }
246 match cols_check {
247 Some(c) if c != cols => {
248 return Err(FerrumError::model(format!(
249 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
250 )))
251 }
252 _ => cols_check = Some(cols),
253 }
254
255 let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
259 FerrumError::model(format!(
260 "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
261 ))
262 })?;
263 spec.push((kind, bytes, rows));
264 }
265
266 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
267 let parts_view: Vec<(_, &[u8], _)> = spec
268 .iter()
269 .map(|(kind, bytes, rows)| (*kind, *bytes, *rows))
270 .collect();
271 let quant = match crate::QuantLinear::<B>::from_gguf_fused(&parts_view, cols) {
272 Ok(q) => q,
273 Err(_) => return Ok(None), };
275 Ok(Some(Box::new(quant)))
276 }
277
278 fn try_load_fused_q4k(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
281 let mut fused_bytes: Vec<u8> = Vec::new();
282 let mut total_rows = 0usize;
283 let mut cols_check: Option<usize> = None;
284
285 for stem in parts {
286 let weight_name = format!("{stem}.weight");
287 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
288 FerrumError::model(format!(
289 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
290 ))
291 })?;
292 if !self.gguf.has_tensor(&gguf_name) {
293 return Err(FerrumError::model(format!(
294 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
295 )));
296 }
297
298 let bias_name = ferrum_to_gguf(&format!("{stem}.bias"))
301 .map(|n| self.gguf.has_tensor(&n))
302 .unwrap_or(false);
303 if bias_name {
304 return Ok(None);
305 }
306
307 let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
308 FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
309 })?;
310
311 if !matches!(info.ggml_dtype, candle_core::quantized::GgmlDType::Q4K) {
313 return Ok(None);
314 }
315
316 let dims = info.shape.dims();
317 if dims.len() != 2 {
318 return Ok(None);
319 }
320 let (rows, cols) = (dims[0], dims[1]);
321
322 if cols % 256 != 0 {
326 return Ok(None);
327 }
328
329 match cols_check {
330 Some(c) if c != cols => {
331 return Err(FerrumError::model(format!(
332 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
333 )))
334 }
335 _ => cols_check = Some(cols),
336 }
337
338 let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
345 FerrumError::model(format!(
346 "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
347 ))
348 })?;
349 let expected = rows * (cols / 256) * 144;
351 debug_assert_eq!(
352 bytes.len(),
353 expected,
354 "Q4K byte count mismatch for '{gguf_name}': got {} expected {}",
355 bytes.len(),
356 expected
357 );
358
359 fused_bytes.extend_from_slice(bytes);
360 total_rows += rows;
361 }
362
363 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
364 let quant = crate::QuantLinear::<B>::from_gguf_bytes(
365 ferrum_kernels::backend::GgufQuantType::Q4K,
366 &fused_bytes,
367 total_rows,
368 cols,
369 )?;
370 Ok(Some(Box::new(quant)))
371 }
372
373 fn load_fused_eager(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
376 let mut fused: Vec<f32> = Vec::new();
377 let mut total_rows = 0usize;
378 let mut cols_check: Option<usize> = None;
379
380 for stem in parts {
381 let weight_name = format!("{stem}.weight");
382 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
383 FerrumError::model(format!(
384 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
385 ))
386 })?;
387 if !self.gguf.has_tensor(&gguf_name) {
388 return Err(FerrumError::model(format!(
389 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
390 )));
391 }
392 let (rows, cols) = self.rows_cols(&gguf_name)?;
393 match cols_check {
394 Some(c) if c != cols => {
395 return Err(FerrumError::model(format!(
396 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
397 )))
398 }
399 _ => cols_check = Some(cols),
400 }
401 let data = self.read_dequant(&gguf_name)?;
402 debug_assert_eq!(data.len(), rows * cols);
403 fused.extend_from_slice(&data);
404 total_rows += rows;
405 }
406
407 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
408 Ok(Box::new(GgufLinear::<B>::from_dense_rows(
409 &fused, total_rows, cols,
410 )))
411 }
412}
413
414impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> WeightLoader<B> for GgufLoader<B> {
415 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
416 let gguf_name = self.locate(name)?;
417 let raw = self.read_dequant(&gguf_name)?;
418 Ok(B::from_slice(&raw))
419 }
420
421 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
422 if let Some(gguf_weight) = ferrum_to_gguf(&format!("{name}.weight")) {
424 if self.gguf.has_tensor(&gguf_weight) {
425 let info = self.gguf.tensor_info(&gguf_weight).ok_or_else(|| {
431 FerrumError::model(format!("tensor_info missing for '{gguf_weight}'"))
432 })?;
433 let dims = info.shape.dims();
434 if dims.len() != 2 {
435 return Err(FerrumError::model(format!(
436 "GgufLoader::load_linear '{name}': expected rank-2 weight, got rank {}",
437 dims.len()
438 )));
439 }
440 let (n_rows, n_cols) = (dims[0], dims[1]);
441
442 let quant_kind = match info.ggml_dtype {
443 candle_core::quantized::GgmlDType::Q4K => {
444 Some(ferrum_kernels::backend::GgufQuantType::Q4K)
445 }
446 candle_core::quantized::GgmlDType::Q6K => {
447 Some(ferrum_kernels::backend::GgufQuantType::Q6K)
448 }
449 _ => None,
450 };
451 if let Some(kind) = quant_kind {
452 let has_bias = ferrum_to_gguf(&format!("{name}.bias"))
461 .map(|n| self.gguf.has_tensor(&n))
462 .unwrap_or(false);
463 if !has_bias {
464 let bytes = self.gguf.tensor_byte_slice(&gguf_weight).ok_or_else(|| {
473 FerrumError::model(format!(
474 "GgufLoader: tensor_byte_slice failed for '{gguf_weight}'"
475 ))
476 })?;
477 let quant =
478 crate::QuantLinear::<B>::from_gguf_bytes(kind, bytes, n_rows, n_cols)?;
479 return Ok(Box::new(quant));
480 }
481 }
483
484 let qt = self
485 .gguf
486 .read_tensor(&gguf_weight, &self.decode_device)
487 .map_err(candle_to_ferrum)?;
488 if let Some(gguf_bias) = ferrum_to_gguf(&format!("{name}.bias")) {
489 if self.gguf.has_tensor(&gguf_bias) {
490 let bqt = self
491 .gguf
492 .read_tensor(&gguf_bias, &self.decode_device)
493 .map_err(candle_to_ferrum)?;
494 let linear = GgufLinear::<B>::from_qtensor_with_bias(&qt, &bqt)
495 .map_err(candle_to_ferrum)?;
496 return Ok(Box::new(linear));
497 }
498 }
499 let linear = GgufLinear::<B>::from_qtensor(&qt).map_err(candle_to_ferrum)?;
500 return Ok(Box::new(linear));
501 }
502 }
503
504 if let Some(layer_prefix) = name.strip_suffix("self_attn.qkv_proj") {
506 let parts = qkv_split_parts(layer_prefix);
507 return self.load_fused(&parts);
508 }
509 if let Some(layer_prefix) = name.strip_suffix("mlp.gate_up_proj") {
511 let parts = gate_up_split_parts(layer_prefix);
512 return self.load_fused(&parts);
513 }
514
515 Err(FerrumError::model(format!(
516 "GgufLoader: could not load Linear '{name}' — no direct weight, no split components"
517 )))
518 }
519
520 fn has_tensor(&self, name: &str) -> bool {
521 match ferrum_to_gguf(name) {
522 Some(g) => self.gguf.has_tensor(&g),
523 None => false,
524 }
525 }
526
527 fn quant_config(&self) -> Option<&QuantConfig> {
528 None
534 }
535}
536
537fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
538 FerrumError::model(format!("candle: {e}"))
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    /// Presence of the variable (even with an empty value) enables tracing.
    #[test]
    fn gguf_loader_runtime_config_parses_load_trace_presence() {
        let cfg =
            GgufLoaderRuntimeConfig::from_env_vars([(GGUF_LOAD_TRACE_ENV, ""), ("OTHER", "1")]);
        assert!(cfg.load_trace);

        let cfg = GgufLoaderRuntimeConfig::from_env_vars([("OTHER", "1")]);
        assert!(!cfg.load_trace);
    }

    /// The value is never inspected: "0" still counts as present.
    #[test]
    fn gguf_loader_runtime_config_ignores_value() {
        let cfg = GgufLoaderRuntimeConfig::from_env_vars([(GGUF_LOAD_TRACE_ENV, "0")]);
        assert!(cfg.load_trace);
    }

    /// Only the exact, case-sensitive name matches.
    #[test]
    fn gguf_loader_runtime_config_requires_exact_name() {
        let cfg = GgufLoaderRuntimeConfig::from_env_vars([
            ("ferrum_gguf_load_trace", "1"),
            ("FERRUM_GGUF_LOAD_TRACE_EXTRA", "1"),
        ]);
        assert!(!cfg.load_trace);
    }

    /// Owned pairs are accepted, and an empty environment yields the default config.
    #[test]
    fn gguf_loader_runtime_config_accepts_owned_pairs_and_empty_input() {
        let owned = vec![(GGUF_LOAD_TRACE_ENV.to_string(), String::from("1"))];
        assert!(GgufLoaderRuntimeConfig::from_env_vars(owned).load_trace);

        let empty: Vec<(&str, &str)> = Vec::new();
        assert_eq!(
            GgufLoaderRuntimeConfig::from_env_vars(empty),
            GgufLoaderRuntimeConfig::default()
        );
    }
}