ferrum_quantization/gguf/
loader.rs1use std::path::Path;
21use std::sync::Arc;
22
23use candle_core::Device;
24use ferrum_kernels::backend::Backend;
25use ferrum_types::{FerrumError, Result};
26
27use crate::config::QuantConfig;
28use crate::gguf::file::GgufFile;
29use crate::gguf::linear::GgufLinear;
30use crate::gguf::names::{ferrum_to_gguf, gate_up_split_parts, qkv_split_parts};
31use crate::loader::WeightLoader;
32use crate::traits::Linear;
33
34pub struct GgufLoader<B: Backend> {
40 gguf: Arc<GgufFile>,
41 decode_device: Device,
46 _marker: std::marker::PhantomData<B>,
47}
48
49impl<B: Backend> GgufLoader<B> {
50 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
53 let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
54 Ok(Self {
55 gguf: Arc::new(gguf),
56 decode_device: Device::Cpu,
57 _marker: std::marker::PhantomData,
58 })
59 }
60
61 pub fn from_file(gguf: Arc<GgufFile>) -> Self {
64 Self {
65 gguf,
66 decode_device: Device::Cpu,
67 _marker: std::marker::PhantomData,
68 }
69 }
70
71 pub fn gguf(&self) -> &GgufFile {
75 &self.gguf
76 }
77
78 fn locate(&self, ferrum_name: &str) -> Result<String> {
83 let gguf_name = ferrum_to_gguf(ferrum_name).ok_or_else(|| {
84 FerrumError::model(format!(
85 "GgufLoader: unrecognised tensor name '{ferrum_name}' (no GGUF mapping)"
86 ))
87 })?;
88 if !self.gguf.has_tensor(&gguf_name) {
89 return Err(FerrumError::model(format!(
90 "GgufLoader: tensor '{ferrum_name}' (mapped to '{gguf_name}') not present in GGUF"
91 )));
92 }
93 Ok(gguf_name)
94 }
95
96 fn read_dequant(&self, gguf_name: &str) -> Result<Vec<f32>> {
99 let qt = self
100 .gguf
101 .read_tensor(gguf_name, &self.decode_device)
102 .map_err(candle_to_ferrum)?;
103 let dense = qt
104 .dequantize(&self.decode_device)
105 .map_err(candle_to_ferrum)?;
106 let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
107 flat.to_vec1::<f32>().map_err(candle_to_ferrum)
108 }
109
110 fn rows_cols(&self, gguf_name: &str) -> Result<(usize, usize)> {
114 let info = self
115 .gguf
116 .tensor_info(gguf_name)
117 .ok_or_else(|| FerrumError::model(format!("tensor info missing for '{gguf_name}'")))?;
118 let dims = info.shape.dims();
119 if dims.len() != 2 {
120 return Err(FerrumError::model(format!(
121 "expected 2-D tensor for '{gguf_name}', got rank {}",
122 dims.len()
123 )));
124 }
125 Ok((dims[0], dims[1]))
126 }
127
128 fn load_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
146 if let Some(fast) = self.try_load_fused_q4k(parts)? {
147 if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
148 eprintln!("[gguf-load] {:?} → fused-Q4 (homogeneous)", parts);
149 }
150 return Ok(fast);
151 }
152 if let Some(multi) = self.try_load_fused_multi_quant(parts)? {
153 if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
154 eprintln!("[gguf-load] {:?} → MultiQuant (mixed dtype)", parts);
155 }
156 return Ok(multi);
157 }
158 if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
159 eprintln!("[gguf-load] {:?} → eager fp32 fallback ⚠", parts);
160 }
161 self.load_fused_eager(parts)
162 }
163
164 fn try_load_fused_multi_quant(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
171 let mut spec: Vec<(ferrum_kernels::backend::GgufQuantType, &[u8], usize)> = Vec::new();
172 let mut cols_check: Option<usize> = None;
173
174 for stem in parts {
175 let weight_name = format!("{stem}.weight");
176 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
177 FerrumError::model(format!(
178 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
179 ))
180 })?;
181 if !self.gguf.has_tensor(&gguf_name) {
182 return Err(FerrumError::model(format!(
183 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
184 )));
185 }
186
187 let has_bias = ferrum_to_gguf(&format!("{stem}.bias"))
190 .map(|n| self.gguf.has_tensor(&n))
191 .unwrap_or(false);
192 if has_bias {
193 return Ok(None);
194 }
195
196 let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
197 FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
198 })?;
199 let kind = match info.ggml_dtype {
200 candle_core::quantized::GgmlDType::Q4K => {
201 ferrum_kernels::backend::GgufQuantType::Q4K
202 }
203 candle_core::quantized::GgmlDType::Q6K => {
204 ferrum_kernels::backend::GgufQuantType::Q6K
205 }
206 _ => return Ok(None), };
208
209 let dims = info.shape.dims();
210 if dims.len() != 2 {
211 return Ok(None);
212 }
213 let (rows, cols) = (dims[0], dims[1]);
214 if cols % 256 != 0 {
215 return Ok(None);
216 }
217 match cols_check {
218 Some(c) if c != cols => {
219 return Err(FerrumError::model(format!(
220 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
221 )))
222 }
223 _ => cols_check = Some(cols),
224 }
225
226 let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
230 FerrumError::model(format!(
231 "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
232 ))
233 })?;
234 spec.push((kind, bytes, rows));
235 }
236
237 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
238 let parts_view: Vec<(_, &[u8], _)> = spec
239 .iter()
240 .map(|(kind, bytes, rows)| (*kind, *bytes, *rows))
241 .collect();
242 let quant = match crate::QuantLinear::<B>::from_gguf_fused(&parts_view, cols) {
243 Ok(q) => q,
244 Err(_) => return Ok(None), };
246 Ok(Some(Box::new(quant)))
247 }
248
249 fn try_load_fused_q4k(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
252 let mut fused_bytes: Vec<u8> = Vec::new();
253 let mut total_rows = 0usize;
254 let mut cols_check: Option<usize> = None;
255
256 for stem in parts {
257 let weight_name = format!("{stem}.weight");
258 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
259 FerrumError::model(format!(
260 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
261 ))
262 })?;
263 if !self.gguf.has_tensor(&gguf_name) {
264 return Err(FerrumError::model(format!(
265 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
266 )));
267 }
268
269 let bias_name = ferrum_to_gguf(&format!("{stem}.bias"))
272 .map(|n| self.gguf.has_tensor(&n))
273 .unwrap_or(false);
274 if bias_name {
275 return Ok(None);
276 }
277
278 let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
279 FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
280 })?;
281
282 if !matches!(info.ggml_dtype, candle_core::quantized::GgmlDType::Q4K) {
284 return Ok(None);
285 }
286
287 let dims = info.shape.dims();
288 if dims.len() != 2 {
289 return Ok(None);
290 }
291 let (rows, cols) = (dims[0], dims[1]);
292
293 if cols % 256 != 0 {
297 return Ok(None);
298 }
299
300 match cols_check {
301 Some(c) if c != cols => {
302 return Err(FerrumError::model(format!(
303 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
304 )))
305 }
306 _ => cols_check = Some(cols),
307 }
308
309 let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
316 FerrumError::model(format!(
317 "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
318 ))
319 })?;
320 let expected = rows * (cols / 256) * 144;
322 debug_assert_eq!(
323 bytes.len(),
324 expected,
325 "Q4K byte count mismatch for '{gguf_name}': got {} expected {}",
326 bytes.len(),
327 expected
328 );
329
330 fused_bytes.extend_from_slice(bytes);
331 total_rows += rows;
332 }
333
334 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
335 let quant = crate::QuantLinear::<B>::from_gguf_bytes(
336 ferrum_kernels::backend::GgufQuantType::Q4K,
337 &fused_bytes,
338 total_rows,
339 cols,
340 )?;
341 Ok(Some(Box::new(quant)))
342 }
343
344 fn load_fused_eager(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
347 let mut fused: Vec<f32> = Vec::new();
348 let mut total_rows = 0usize;
349 let mut cols_check: Option<usize> = None;
350
351 for stem in parts {
352 let weight_name = format!("{stem}.weight");
353 let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
354 FerrumError::model(format!(
355 "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
356 ))
357 })?;
358 if !self.gguf.has_tensor(&gguf_name) {
359 return Err(FerrumError::model(format!(
360 "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
361 )));
362 }
363 let (rows, cols) = self.rows_cols(&gguf_name)?;
364 match cols_check {
365 Some(c) if c != cols => {
366 return Err(FerrumError::model(format!(
367 "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
368 )))
369 }
370 _ => cols_check = Some(cols),
371 }
372 let data = self.read_dequant(&gguf_name)?;
373 debug_assert_eq!(data.len(), rows * cols);
374 fused.extend_from_slice(&data);
375 total_rows += rows;
376 }
377
378 let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
379 Ok(Box::new(GgufLinear::<B>::from_dense_rows(
380 &fused, total_rows, cols,
381 )))
382 }
383}
384
385impl<B: Backend> WeightLoader<B> for GgufLoader<B> {
386 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
387 let gguf_name = self.locate(name)?;
388 let raw = self.read_dequant(&gguf_name)?;
389 Ok(B::from_slice(&raw))
390 }
391
392 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
393 if let Some(gguf_weight) = ferrum_to_gguf(&format!("{name}.weight")) {
395 if self.gguf.has_tensor(&gguf_weight) {
396 let info = self.gguf.tensor_info(&gguf_weight).ok_or_else(|| {
402 FerrumError::model(format!("tensor_info missing for '{gguf_weight}'"))
403 })?;
404 let dims = info.shape.dims();
405 if dims.len() != 2 {
406 return Err(FerrumError::model(format!(
407 "GgufLoader::load_linear '{name}': expected rank-2 weight, got rank {}",
408 dims.len()
409 )));
410 }
411 let (n_rows, n_cols) = (dims[0], dims[1]);
412
413 let quant_kind = match info.ggml_dtype {
414 candle_core::quantized::GgmlDType::Q4K => {
415 Some(ferrum_kernels::backend::GgufQuantType::Q4K)
416 }
417 candle_core::quantized::GgmlDType::Q6K => {
418 Some(ferrum_kernels::backend::GgufQuantType::Q6K)
419 }
420 _ => None,
421 };
422 if let Some(kind) = quant_kind {
423 let has_bias = ferrum_to_gguf(&format!("{name}.bias"))
432 .map(|n| self.gguf.has_tensor(&n))
433 .unwrap_or(false);
434 if !has_bias {
435 let bytes = self.gguf.tensor_byte_slice(&gguf_weight).ok_or_else(|| {
444 FerrumError::model(format!(
445 "GgufLoader: tensor_byte_slice failed for '{gguf_weight}'"
446 ))
447 })?;
448 let quant =
449 crate::QuantLinear::<B>::from_gguf_bytes(kind, bytes, n_rows, n_cols)?;
450 return Ok(Box::new(quant));
451 }
452 }
454
455 let qt = self
456 .gguf
457 .read_tensor(&gguf_weight, &self.decode_device)
458 .map_err(candle_to_ferrum)?;
459 if let Some(gguf_bias) = ferrum_to_gguf(&format!("{name}.bias")) {
460 if self.gguf.has_tensor(&gguf_bias) {
461 let bqt = self
462 .gguf
463 .read_tensor(&gguf_bias, &self.decode_device)
464 .map_err(candle_to_ferrum)?;
465 let linear = GgufLinear::<B>::from_qtensor_with_bias(&qt, &bqt)
466 .map_err(candle_to_ferrum)?;
467 return Ok(Box::new(linear));
468 }
469 }
470 let linear = GgufLinear::<B>::from_qtensor(&qt).map_err(candle_to_ferrum)?;
471 return Ok(Box::new(linear));
472 }
473 }
474
475 if let Some(layer_prefix) = name.strip_suffix("self_attn.qkv_proj") {
477 let parts = qkv_split_parts(layer_prefix);
478 return self.load_fused(&parts);
479 }
480 if let Some(layer_prefix) = name.strip_suffix("mlp.gate_up_proj") {
482 let parts = gate_up_split_parts(layer_prefix);
483 return self.load_fused(&parts);
484 }
485
486 Err(FerrumError::model(format!(
487 "GgufLoader: could not load Linear '{name}' — no direct weight, no split components"
488 )))
489 }
490
491 fn has_tensor(&self, name: &str) -> bool {
492 match ferrum_to_gguf(name) {
493 Some(g) => self.gguf.has_tensor(&g),
494 None => false,
495 }
496 }
497
498 fn quant_config(&self) -> Option<&QuantConfig> {
499 None
505 }
506}
507
508fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
509 FerrumError::model(format!("candle: {e}"))
510}