use std::path::Path;
use std::sync::Arc;
use candle_core::Device;
use ferrum_kernels::backend::{Backend, BackendQuantGguf, BackendQuantMarlin};
use ferrum_types::{FerrumError, Result};
use crate::config::QuantConfig;
use crate::gguf::file::GgufFile;
use crate::gguf::linear::GgufLinear;
use crate::gguf::names::{ferrum_to_gguf, gate_up_split_parts, qkv_split_parts};
use crate::loader::WeightLoader;
use crate::traits::Linear;
/// Name of the environment variable that switches on GGUF load tracing.
///
/// Only the presence of the variable matters; its value is never inspected.
const GGUF_LOAD_TRACE_ENV: &str = "FERRUM_GGUF_LOAD_TRACE";

/// Runtime switches for the GGUF loader, derived from the process environment.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct GgufLoaderRuntimeConfig {
    /// `true` when [`GGUF_LOAD_TRACE_ENV`] is present in the environment.
    load_trace: bool,
}

impl GgufLoaderRuntimeConfig {
    /// Builds the configuration from the current process environment.
    ///
    /// The environment is enumerated with `std::env::vars_os` rather than
    /// `std::env::vars`: the latter panics as soon as it meets a key or value
    /// that is not valid Unicode, which would let an unrelated variable crash
    /// loader construction. Names are converted lossily; a name containing
    /// invalid bytes gains a U+FFFD replacement character and therefore can
    /// never compare equal to the ASCII flag name. Values are dropped because
    /// the flag is presence-only.
    fn from_env() -> Self {
        Self::from_env_vars(
            std::env::vars_os()
                .map(|(name, _value)| (name.to_string_lossy().into_owned(), String::new())),
        )
    }

    /// Builds the configuration from an explicit list of `(name, value)` pairs.
    ///
    /// `load_trace` is set when any pair's name equals
    /// [`GGUF_LOAD_TRACE_ENV`]; the paired value is ignored, so an empty
    /// value still enables tracing.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        Self {
            load_trace: vars
                .into_iter()
                .any(|(name, _value)| name.into() == GGUF_LOAD_TRACE_ENV),
        }
    }
}
/// Weight loader that serves tensors and linear layers out of a GGUF file.
///
/// `B` is the target backend. It appears only as a type parameter: the
/// struct stores a `PhantomData<B>` and no backend value.
pub struct GgufLoader<B: Backend + BackendQuantGguf + BackendQuantMarlin> {
/// The parsed GGUF file, held behind an `Arc`.
gguf: Arc<GgufFile>,
/// Candle device passed to tensor reads and dequantization. The
/// constructors visible in this file always set it to `Device::Cpu`.
decode_device: Device,
/// Environment-derived switches, captured once at construction time.
runtime_config: GgufLoaderRuntimeConfig,
/// Ties the loader to backend `B` without storing a value of it.
_marker: std::marker::PhantomData<B>,
}
impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> GgufLoader<B> {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
Ok(Self {
gguf: Arc::new(gguf),
decode_device: Device::Cpu,
runtime_config: GgufLoaderRuntimeConfig::from_env(),
_marker: std::marker::PhantomData,
})
}
/// Wraps an already-opened GGUF file in a loader.
///
/// Tensors are decoded on `Device::Cpu`, and the runtime configuration is
/// read from the process environment at the time of this call.
pub fn from_file(gguf: Arc<GgufFile>) -> Self {
    let runtime_config = GgufLoaderRuntimeConfig::from_env();
    Self {
        _marker: std::marker::PhantomData,
        runtime_config,
        decode_device: Device::Cpu,
        gguf,
    }
}
pub fn gguf(&self) -> &GgufFile {
&self.gguf
}
/// Resolves a Ferrum tensor name to the GGUF name it is stored under.
///
/// # Errors
/// Fails when `ferrum_to_gguf` has no mapping for `ferrum_name`, or when the
/// mapped name is not present in the GGUF file.
fn locate(&self, ferrum_name: &str) -> Result<String> {
    match ferrum_to_gguf(ferrum_name) {
        None => Err(FerrumError::model(format!(
            "GgufLoader: unrecognised tensor name '{ferrum_name}' (no GGUF mapping)"
        ))),
        Some(gguf_name) if self.gguf.has_tensor(&gguf_name) => Ok(gguf_name),
        Some(gguf_name) => Err(FerrumError::model(format!(
            "GgufLoader: tensor '{ferrum_name}' (mapped to '{gguf_name}') not present in GGUF"
        ))),
    }
}
/// Reads the tensor stored under `gguf_name` and returns it as a flat `f32`
/// vector.
///
/// The tensor is read and dequantized on `decode_device`, then flattened
/// before being copied out.
///
/// # Errors
/// Any candle error from reading, dequantizing, flattening or copying is
/// converted by `candle_to_ferrum`.
fn read_dequant(&self, gguf_name: &str) -> Result<Vec<f32>> {
    let device = &self.decode_device;
    self.gguf
        .read_tensor(gguf_name, device)
        .and_then(|quantized| quantized.dequantize(device))
        .and_then(|dense| dense.flatten_all())
        .and_then(|flat| flat.to_vec1::<f32>())
        .map_err(candle_to_ferrum)
}
/// Returns the two dimensions of the tensor stored under `gguf_name`, in the
/// order they appear in its shape.
///
/// # Errors
/// Fails when the file has no tensor info for `gguf_name`, or when the
/// tensor's rank is not 2.
fn rows_cols(&self, gguf_name: &str) -> Result<(usize, usize)> {
    let info = match self.gguf.tensor_info(gguf_name) {
        Some(found) => found,
        None => {
            return Err(FerrumError::model(format!(
                "tensor info missing for '{gguf_name}'"
            )))
        }
    };
    let shape: &[usize] = info.shape.dims();
    match shape {
        [rows, cols] => Ok((*rows, *cols)),
        other => Err(FerrumError::model(format!(
            "expected 2-D tensor for '{gguf_name}', got rank {}",
            other.len()
        ))),
    }
}
/// Builds one fused linear layer from the weight stems in `parts`.
///
/// Three strategies are tried in order, and the first that yields a layer
/// wins:
/// 1. `try_load_fused_q4k`,
/// 2. `try_load_fused_multi_quant`,
/// 3. `load_fused_eager`.
///
/// When load tracing is enabled, the chosen strategy is reported on stderr.
///
/// # Errors
/// Propagates any error returned by the strategy functions.
fn load_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
    let tracing = self.runtime_config.load_trace;
    match self.try_load_fused_q4k(parts)? {
        Some(homogeneous) => {
            if tracing {
                eprintln!("[gguf-load] {:?} → fused-Q4 (homogeneous)", parts);
            }
            Ok(homogeneous)
        }
        None => match self.try_load_fused_multi_quant(parts)? {
            Some(mixed) => {
                if tracing {
                    eprintln!("[gguf-load] {:?} → MultiQuant (mixed dtype)", parts);
                }
                Ok(mixed)
            }
            None => {
                if tracing {
                    eprintln!("[gguf-load] {:?} → eager fp32 fallback ⚠", parts);
                }
                self.load_fused_eager(parts)
            }
        },
    }
}
/// Tries to fuse `parts` into one quantized linear whose parts may use
/// different quantization types (Q4K and/or Q6K).
///
/// Returns `Ok(None)` when this strategy does not apply, so the caller can
/// fall back to another one:
/// - a part has a `{stem}.bias` tensor in the file,
/// - a part's dtype is neither Q4K nor Q6K,
/// - a part is not rank 2,
/// - a part's second dimension is not a multiple of 256,
/// - `QuantLinear::from_gguf_fused` fails (its error is discarded).
///
/// # Errors
/// Fails when a part has no GGUF mapping, is missing from the file, has no
/// tensor info or byte slice, when parts disagree on their second dimension,
/// or when `parts` is empty.
fn try_load_fused_multi_quant(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
    // One (quant type, raw bytes, row count) entry per fused part.
    let mut segments: Vec<(ferrum_kernels::backend::GgufQuantType, &[u8], usize)> =
        Vec::with_capacity(parts.len());
    let mut in_features: Option<usize> = None;

    for stem in parts {
        let weight_name = format!("{stem}.weight");
        let gguf_name = match ferrum_to_gguf(&weight_name) {
            Some(mapped) => mapped,
            None => {
                return Err(FerrumError::model(format!(
                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
                )))
            }
        };
        if !self.gguf.has_tensor(&gguf_name) {
            return Err(FerrumError::model(format!(
                "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
            )));
        }

        // This path cannot carry a bias, so any bias tensor rules it out.
        let bias_present = match ferrum_to_gguf(&format!("{stem}.bias")) {
            Some(bias) => self.gguf.has_tensor(&bias),
            None => false,
        };
        if bias_present {
            return Ok(None);
        }

        let info = match self.gguf.tensor_info(&gguf_name) {
            Some(found) => found,
            None => {
                return Err(FerrumError::model(format!(
                    "tensor_info missing for '{gguf_name}'"
                )))
            }
        };
        let kind = match info.ggml_dtype {
            candle_core::quantized::GgmlDType::Q4K => ferrum_kernels::backend::GgufQuantType::Q4K,
            candle_core::quantized::GgmlDType::Q6K => ferrum_kernels::backend::GgufQuantType::Q6K,
            _ => return Ok(None),
        };

        let shape: &[usize] = info.shape.dims();
        let (rows, cols) = match shape {
            [r, k] => (*r, *k),
            _ => return Ok(None),
        };
        if cols % 256 != 0 {
            return Ok(None);
        }

        // Every part must share the same second dimension.
        if let Some(c) = in_features {
            if c != cols {
                return Err(FerrumError::model(format!(
                    "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
                )));
            }
        }
        in_features = Some(cols);

        let payload = match self.gguf.tensor_byte_slice(&gguf_name) {
            Some(raw) => raw,
            None => {
                return Err(FerrumError::model(format!(
                    "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
                )))
            }
        };
        segments.push((kind, payload, rows));
    }

    let cols = in_features.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
    // A construction failure is treated as "strategy not applicable".
    match crate::QuantLinear::<B>::from_gguf_fused(&segments, cols) {
        Ok(quant) => Ok(Some(Box::new(quant))),
        Err(_) => Ok(None),
    }
}
/// Tries to fuse `parts` into a single Q4K linear by concatenating the raw
/// quantized bytes of each part, in order.
///
/// Returns `Ok(None)` when this strategy does not apply, so the caller can
/// fall back to another one:
/// - a part has a `{stem}.bias` tensor in the file,
/// - a part's dtype is not Q4K,
/// - a part is not rank 2,
/// - a part's second dimension is not a multiple of 256.
///
/// # Errors
/// Fails when a part has no GGUF mapping, is missing from the file, has no
/// tensor info or byte slice, when parts disagree on their second dimension,
/// when a part's byte length does not match its shape, when `parts` is
/// empty, or when `QuantLinear::from_gguf_bytes` rejects the fused buffer.
fn try_load_fused_q4k(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
    // Q4_K geometry: each block of 256 weights occupies 144 bytes.
    const Q4K_BLOCK_ELEMS: usize = 256;
    const Q4K_BLOCK_BYTES: usize = 144;

    let mut fused_bytes: Vec<u8> = Vec::new();
    let mut total_rows = 0usize;
    let mut cols_check: Option<usize> = None;
    for stem in parts {
        let weight_name = format!("{stem}.weight");
        let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
            FerrumError::model(format!(
                "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
            ))
        })?;
        if !self.gguf.has_tensor(&gguf_name) {
            return Err(FerrumError::model(format!(
                "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
            )));
        }
        // This path cannot carry a bias, so any bias tensor rules it out.
        let has_bias = ferrum_to_gguf(&format!("{stem}.bias"))
            .map(|n| self.gguf.has_tensor(&n))
            .unwrap_or(false);
        if has_bias {
            return Ok(None);
        }
        let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
            FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
        })?;
        if !matches!(info.ggml_dtype, candle_core::quantized::GgmlDType::Q4K) {
            return Ok(None);
        }
        let dims = info.shape.dims();
        if dims.len() != 2 {
            return Ok(None);
        }
        let (rows, cols) = (dims[0], dims[1]);
        if cols % Q4K_BLOCK_ELEMS != 0 {
            return Ok(None);
        }
        // Every part must share the same second dimension.
        match cols_check {
            Some(c) if c != cols => {
                return Err(FerrumError::model(format!(
                    "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
                )))
            }
            _ => cols_check = Some(cols),
        }
        let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
            FerrumError::model(format!(
                "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
            ))
        })?;
        // Checked in release builds too: a payload whose size disagrees with
        // its declared shape would shift every following part inside the
        // fused buffer, so it is reported instead of concatenated.
        let expected = rows * (cols / Q4K_BLOCK_ELEMS) * Q4K_BLOCK_BYTES;
        if bytes.len() != expected {
            return Err(FerrumError::model(format!(
                "GgufLoader: Q4K byte count mismatch for '{gguf_name}': got {} expected {}",
                bytes.len(),
                expected
            )));
        }
        fused_bytes.extend_from_slice(bytes);
        total_rows += rows;
    }
    let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
    let quant = crate::QuantLinear::<B>::from_gguf_bytes(
        ferrum_kernels::backend::GgufQuantType::Q4K,
        &fused_bytes,
        total_rows,
        cols,
    )?;
    Ok(Some(Box::new(quant)))
}
/// Fuses `parts` by dequantizing every part to `f32` and concatenating the
/// results, in order, into one dense linear.
///
/// NOTE(review): this path never looks at `{stem}.bias` tensors, although the
/// quantized strategies decline precisely when such a bias exists — confirm
/// that dropping the bias here is intended.
///
/// # Errors
/// Fails when a part has no GGUF mapping, is missing from the file, is not
/// rank 2, cannot be read or dequantized, when parts disagree on their second
/// dimension, or when `parts` is empty.
fn load_fused_eager(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
    let mut dense: Vec<f32> = Vec::new();
    let mut row_count = 0usize;
    let mut in_features: Option<usize> = None;

    for stem in parts {
        let weight_name = format!("{stem}.weight");
        let gguf_name = match ferrum_to_gguf(&weight_name) {
            Some(mapped) => mapped,
            None => {
                return Err(FerrumError::model(format!(
                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
                )))
            }
        };
        if !self.gguf.has_tensor(&gguf_name) {
            return Err(FerrumError::model(format!(
                "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
            )));
        }

        let (rows, cols) = self.rows_cols(&gguf_name)?;
        // Every part must share the same second dimension.
        if let Some(c) = in_features {
            if c != cols {
                return Err(FerrumError::model(format!(
                    "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
                )));
            }
        }
        in_features = Some(cols);

        let values = self.read_dequant(&gguf_name)?;
        debug_assert_eq!(values.len(), rows * cols);
        dense.extend(values);
        row_count += rows;
    }

    let cols = in_features.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
    let linear = GgufLinear::<B>::from_dense_rows(&dense, row_count, cols);
    Ok(Box::new(linear))
}
}
impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> WeightLoader<B> for GgufLoader<B> {
/// Loads the tensor `name` as a backend buffer built from its dequantized
/// `f32` values.
///
/// # Errors
/// Fails when the name cannot be located in the file or the tensor cannot be
/// read and dequantized.
fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
    self.locate(name)
        .and_then(|gguf_name| self.read_dequant(&gguf_name))
        .map(|values| B::from_slice(&values))
}
/// Loads the linear layer `name`.
///
/// If `{name}.weight` maps to a tensor present in the file:
/// - a Q4K or Q6K weight without a bias tensor is wrapped as a `QuantLinear`
///   built directly from the raw bytes;
/// - otherwise the weight is read as a quantized tensor and wrapped as a
///   `GgufLinear`, together with `{name}.bias` when that tensor is present.
///
/// If there is no direct weight, names ending in `self_attn.qkv_proj` or
/// `mlp.gate_up_proj` are assembled from their split components via
/// `load_fused`.
///
/// # Errors
/// Fails when the direct weight is not rank 2, when tensor data cannot be
/// read or converted, or when the name has neither a direct weight nor a
/// known split form.
fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
    let direct = ferrum_to_gguf(&format!("{name}.weight"))
        .filter(|candidate| self.gguf.has_tensor(candidate));

    let gguf_weight = match direct {
        Some(found) => found,
        None => {
            // No direct weight: try the fused layers built from split parts.
            if let Some(layer_prefix) = name.strip_suffix("self_attn.qkv_proj") {
                return self.load_fused(&qkv_split_parts(layer_prefix));
            }
            if let Some(layer_prefix) = name.strip_suffix("mlp.gate_up_proj") {
                return self.load_fused(&gate_up_split_parts(layer_prefix));
            }
            return Err(FerrumError::model(format!(
                "GgufLoader: could not load Linear '{name}' — no direct weight, no split components"
            )));
        }
    };

    let info = match self.gguf.tensor_info(&gguf_weight) {
        Some(found) => found,
        None => {
            return Err(FerrumError::model(format!(
                "tensor_info missing for '{gguf_weight}'"
            )))
        }
    };
    let shape: &[usize] = info.shape.dims();
    let (n_rows, n_cols) = match shape {
        [r, k] => (*r, *k),
        other => {
            return Err(FerrumError::model(format!(
                "GgufLoader::load_linear '{name}': expected rank-2 weight, got rank {}",
                other.len()
            )))
        }
    };

    let quant_kind = match info.ggml_dtype {
        candle_core::quantized::GgmlDType::Q4K => {
            Some(ferrum_kernels::backend::GgufQuantType::Q4K)
        }
        candle_core::quantized::GgmlDType::Q6K => {
            Some(ferrum_kernels::backend::GgufQuantType::Q6K)
        }
        _ => None,
    };
    // GGUF name of the bias, kept only if the tensor actually exists.
    let bias_in_file = ferrum_to_gguf(&format!("{name}.bias"))
        .filter(|candidate| self.gguf.has_tensor(candidate));

    // Raw-bytes path: supported quant type and no bias to attach.
    match quant_kind {
        Some(kind) if bias_in_file.is_none() => {
            let payload = match self.gguf.tensor_byte_slice(&gguf_weight) {
                Some(raw) => raw,
                None => {
                    return Err(FerrumError::model(format!(
                        "GgufLoader: tensor_byte_slice failed for '{gguf_weight}'"
                    )))
                }
            };
            let quant = crate::QuantLinear::<B>::from_gguf_bytes(kind, payload, n_rows, n_cols)?;
            return Ok(Box::new(quant));
        }
        _ => {}
    }

    let weight_qt = self
        .gguf
        .read_tensor(&gguf_weight, &self.decode_device)
        .map_err(candle_to_ferrum)?;
    if let Some(gguf_bias) = bias_in_file {
        let bias_qt = self
            .gguf
            .read_tensor(&gguf_bias, &self.decode_device)
            .map_err(candle_to_ferrum)?;
        let with_bias = GgufLinear::<B>::from_qtensor_with_bias(&weight_qt, &bias_qt)
            .map_err(candle_to_ferrum)?;
        return Ok(Box::new(with_bias));
    }
    let plain = GgufLinear::<B>::from_qtensor(&weight_qt).map_err(candle_to_ferrum)?;
    Ok(Box::new(plain))
}
/// Reports whether `name` maps to a GGUF tensor that is present in the file.
///
/// Names without a GGUF mapping yield `false`.
fn has_tensor(&self, name: &str) -> bool {
    ferrum_to_gguf(name).map_or(false, |mapped| self.gguf.has_tensor(&mapped))
}
/// Always returns `None`: this loader does not report a [`QuantConfig`].
fn quant_config(&self) -> Option<&QuantConfig> {
None
}
}
/// Converts a candle error into a `FerrumError` model error whose message is
/// the candle error's display text prefixed with `candle: `.
fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
    let message = format!("candle: {e}");
    FerrumError::model(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The trace flag depends on the variable being present, not on its value.
    #[test]
    fn gguf_loader_runtime_config_parses_load_trace_presence() {
        let with_flag = [(GGUF_LOAD_TRACE_ENV, ""), ("OTHER", "1")];
        assert!(GgufLoaderRuntimeConfig::from_env_vars(with_flag).load_trace);

        let without_flag = [("OTHER", "1")];
        assert!(!GgufLoaderRuntimeConfig::from_env_vars(without_flag).load_trace);
    }
}