use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use ferrum_kernels::backend::{Backend, BackendQuantMarlin, SrcDtype};
use ferrum_types::{FerrumError, Result};
use half::{bf16, f16};
use memmap2::Mmap;
use safetensors::{Dtype, SafeTensors};
/// Translates a safetensors element type into the backend's [`SrcDtype`] for the
/// dense (unquantized) upload path.
///
/// # Errors
/// Any dtype other than F32, F16 or BF16 is rejected with a model error.
fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
    let mapped = match dtype {
        Dtype::F32 => SrcDtype::F32,
        Dtype::F16 => SrcDtype::F16,
        Dtype::BF16 => SrcDtype::BF16,
        unsupported => {
            return Err(FerrumError::model(format!(
                "dtype {unsupported:?} not supported; Dense path expects F32/F16/BF16"
            )))
        }
    };
    Ok(mapped)
}
use crate::config::{QuantConfig, QuantMethod};
use crate::dense::DenseLinear;
use crate::gptq::GptqLinear;
use crate::loader::WeightLoader;
use crate::traits::Linear;
/// Element type, shape and byte range of one tensor inside a memory-mapped shard.
struct TensorMeta {
    /// Element type as recorded in the safetensors header.
    dtype: Dtype,
    /// Dimensions as recorded in the safetensors header.
    shape: Vec<usize>,
    /// Inclusive start of the tensor's bytes, as an offset from the start of the mmap.
    data_start: usize,
    /// Exclusive end of the tensor's bytes, as an offset from the start of the mmap.
    data_end: usize,
}
/// One memory-mapped `.safetensors` file together with a pre-built tensor index.
struct Shard {
    /// Read-only mapping of the whole file; tensor bytes are sliced out of it.
    mmap: Mmap,
    /// Every tensor name stored in this shard.
    names: Vec<String>,
    /// Per-tensor dtype, shape and byte range, keyed by tensor name.
    meta: HashMap<String, TensorMeta>,
}
impl Shard {
    /// Memory-maps one `.safetensors` file and records every tensor's dtype, shape
    /// and byte range, so later reads slice the mapping directly instead of
    /// re-parsing the header.
    ///
    /// # Errors
    /// Returns an I/O error when the file cannot be opened or mapped, and a model
    /// error when the safetensors header is malformed.
    fn open(path: &Path) -> Result<Self> {
        let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
        // SAFETY: the mapping is read-only and owned by the returned `Shard`, so it
        // outlives every slice handed out by `get_cached`. As with any file mmap,
        // soundness also relies on nobody truncating or rewriting the file while it
        // is mapped.
        let mmap = unsafe { Mmap::map(&file) }
            .map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?;
        // `deserialize` validates the 8-byte length prefix, the JSON header and every
        // tensor's byte range, so no manual header re-parsing is needed here.
        let st = SafeTensors::deserialize(&mmap)
            .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
        let base = mmap.as_ptr() as usize;
        let names: Vec<String> = st.names().iter().map(|s| s.to_string()).collect();
        let mut meta = HashMap::with_capacity(names.len());
        for name in &names {
            let view = st.tensor(name).map_err(|e| {
                FerrumError::model(format!("tensor '{name}' missing during preindex: {e}"))
            })?;
            let bytes = view.data();
            // `view.data()` borrows straight from `mmap`, so the pointer difference
            // is the tensor's offset from the start of the mapping.
            let data_start = bytes.as_ptr() as usize - base;
            meta.insert(
                name.clone(),
                TensorMeta {
                    dtype: view.dtype(),
                    shape: view.shape().to_vec(),
                    data_start,
                    data_end: data_start + bytes.len(),
                },
            );
        }
        Ok(Self { mmap, names, meta })
    }

    /// Returns the raw bytes, dtype and shape recorded for `name` in this shard.
    ///
    /// # Errors
    /// Returns a model error when `name` is not stored in this shard.
    fn get_cached(&self, name: &str) -> Result<(&[u8], Dtype, &[usize])> {
        let m = self
            .meta
            .get(name)
            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in shard")))?;
        Ok((&self.mmap[m.data_start..m.data_end], m.dtype, &m.shape))
    }
}
/// Weight loader that reads dense and GPTQ-quantized tensors directly from
/// memory-mapped safetensors shards and uploads them to backend `B`.
pub struct NativeSafetensorsLoader<B: Backend + BackendQuantMarlin> {
    /// All opened shards, in the order they were listed.
    shards: Vec<Shard>,
    /// Tensor name -> index into `shards`.
    index: HashMap<String, usize>,
    /// Quantization settings from `quantize_config.json` / `config.json`, if any.
    quant_config: Option<QuantConfig>,
    /// Ties the loader to backend `B` without storing a value of it.
    _m: std::marker::PhantomData<B>,
}
impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
let dir = model_dir.as_ref();
let shard_paths = if dir.join("model.safetensors").exists() {
vec![dir.join("model.safetensors")]
} else if dir.join("model.safetensors.index.json").exists() {
Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
.into_iter()
.map(|name| dir.join(name))
.collect()
} else {
return Err(FerrumError::model(format!(
"no safetensors files in {dir:?}"
)));
};
let mut shards = Vec::with_capacity(shard_paths.len());
let mut index: HashMap<String, usize> = HashMap::new();
for (i, p) in shard_paths.iter().enumerate() {
let shard = Shard::open(p)?;
for name in &shard.names {
index.insert(name.clone(), i);
}
shards.push(shard);
}
let quant_config = load_quantize_config(dir)?;
Ok(Self {
shards,
index,
quant_config,
_m: std::marker::PhantomData,
})
}
fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
let data = std::fs::read_to_string(index_path)
.map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
let json: serde_json::Value = serde_json::from_str(&data)
.map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
let weight_map = json
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| FerrumError::model("index missing weight_map"))?;
let mut files: Vec<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
files.sort();
files.dedup();
Ok(files)
}
fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
let shard_idx = *self
.index
.get(name)
.ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
let (data_bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
let data = dtype_to_f32(dtype, data_bytes)?;
Ok((data, shape.to_vec()))
}
fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
let shard_idx = *self
.index
.get(name)
.ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
let (data_bytes, st_dtype, shape) = self.shards[shard_idx].get_cached(name)?;
let dtype = map_src_dtype(st_dtype)?;
Ok((data_bytes, dtype, shape.to_vec()))
}
fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
let mut total_rows = 0usize;
let mut cols = 0usize;
let mut dtype: Option<SrcDtype> = None;
let mut bytes: Vec<u8> = Vec::new();
for n in names {
let (raw, d, shape) = self.read_bytes_typed(n)?;
if shape.len() != 2 {
return Err(FerrumError::model(format!(
"cat_rows_bytes: '{n}' is {shape:?}, need 2D"
)));
}
match dtype {
Some(prev) if prev != d => {
return Err(FerrumError::model(format!(
"cat_rows_bytes: dtype mismatch on '{n}'"
)))
}
_ => dtype = Some(d),
}
if cols == 0 {
cols = shape[1];
} else if cols != shape[1] {
return Err(FerrumError::model(format!(
"cat_rows_bytes: col mismatch {cols} vs {}",
shape[1]
)));
}
total_rows += shape[0];
bytes.extend_from_slice(raw);
}
Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
}
fn cat_optional_biases(
&self,
weight_names: &[String],
out_features: usize,
) -> Result<Option<Vec<f32>>> {
let bias_names: Vec<String> = weight_names
.iter()
.map(|name| {
name.strip_suffix(".weight")
.map(|stem| format!("{stem}.bias"))
.unwrap_or_else(|| format!("{name}.bias"))
})
.collect();
let any_bias = bias_names.iter().any(|name| self.has(name));
if !any_bias {
return Ok(None);
}
if let Some(missing) = bias_names.iter().find(|name| !self.has(name)) {
return Err(FerrumError::model(format!(
"dense fusion bias mix: '{missing}' missing while another fused part has bias"
)));
}
let mut fused = Vec::new();
for name in &bias_names {
let (bias, shape) = self.read_f32(name)?;
if shape.len() != 1 {
return Err(FerrumError::model(format!(
"dense fusion bias '{name}': expected 1D, got {shape:?}"
)));
}
fused.extend_from_slice(&bias);
}
if fused.len() != out_features {
return Err(FerrumError::model(format!(
"dense fusion bias length {} != out_features {out_features}",
fused.len()
)));
}
Ok(Some(fused))
}
fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
let shard_idx = *self
.index
.get(name)
.ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
let (bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
if dtype != Dtype::I32 {
return Err(FerrumError::model(format!(
"'{name}': expected I32, got {:?}",
dtype
)));
}
debug_assert_eq!(bytes.len() % 4, 0);
let count = bytes.len() / 4;
let mut out = Vec::<i32>::with_capacity(count);
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
out.set_len(count);
}
Ok((out, shape.to_vec()))
}
fn has(&self, name: &str) -> bool {
self.index.contains_key(name)
}
pub fn read_gptq_raw(
&self,
name: &str,
) -> Result<(Vec<i32>, Vec<f32>, Vec<i32>, Option<Vec<i32>>, usize, usize)> {
let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
let (scales, _) = self.read_f32(&format!("{name}.scales"))?;
let (qzeros, _) = self.read_i32(&format!("{name}.qzeros"))?;
let g_idx = if self.has(&format!("{name}.g_idx")) {
Some(self.read_i32(&format!("{name}.g_idx"))?.0)
} else {
None
};
if qw_shape.len() != 2 {
return Err(FerrumError::model(format!(
"'{name}.qweight' expected 2D, got {qw_shape:?}"
)));
}
let k = qw_shape[0] * 8;
let n = qw_shape[1];
Ok((qweight, scales, qzeros, g_idx, k, n))
}
pub fn quant_config_ref(&self) -> Option<&crate::config::QuantConfig> {
self.quant_config.as_ref()
}
pub fn load_stacked_gptq_experts(
&self,
expert_prefix_fmt: &str,
num_experts: usize,
proj_names: &[&str],
) -> Result<(
std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
usize,
usize,
)> {
let qcfg = self.quant_config.as_ref().ok_or_else(|| {
FerrumError::model(
"load_stacked_gptq_experts requires quantize_config.json".to_string(),
)
})?;
if qcfg.method != QuantMethod::Gptq {
return Err(FerrumError::model(format!(
"stacked GPTQ load but quant_method={:?}",
qcfg.method
)));
}
let mut qw_rows = 0usize;
let mut sc_rows = 0usize;
let mut qz_rows = 0usize;
let mut n_per_expert = 0usize;
let mut n_per_expert_scales = 0usize;
let mut n_per_expert_zeros = 0usize;
let mut k_shared = 0usize;
let mut g_idx_first: Option<Vec<i32>> = None;
let total_pairs = num_experts * proj_names.len();
let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs); let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(total_pairs);
let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);
for e in 0..num_experts {
let prefix = expert_prefix_fmt.replace("{e}", &e.to_string());
let mut e_n = 0usize;
let mut e_n_scales = 0usize;
let mut e_n_zeros = 0usize;
for proj in proj_names {
let name = format!("{prefix}{proj}");
let (qw, qw_sh) = self.read_i32(&format!("{name}.qweight"))?;
let (sc, sc_sh) = self.read_f32(&format!("{name}.scales"))?;
let (qz, qz_sh) = self.read_i32(&format!("{name}.qzeros"))?;
if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
return Err(FerrumError::model(format!(
"stacked GPTQ '{name}': expected 2D, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
)));
}
if qw_rows == 0 {
qw_rows = qw_sh[0];
sc_rows = sc_sh[0];
qz_rows = qz_sh[0];
k_shared = qw_sh[0] * 8;
} else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
return Err(FerrumError::model(format!(
"stacked GPTQ '{name}': row mismatch qw {} sc {} qz {} vs ref {qw_rows}/{sc_rows}/{qz_rows}",
qw_sh[0], sc_sh[0], qz_sh[0]
)));
}
e_n += qw_sh[1];
e_n_scales += sc_sh[1];
e_n_zeros += qz_sh[1];
qw_parts.push((qw, qw_sh[1]));
sc_parts.push((sc, sc_sh[1]));
qz_parts.push((qz, qz_sh[1]));
let g_key = format!("{name}.g_idx");
if self.has(&g_key) {
let (gx, _) = self.read_i32(&g_key)?;
match &g_idx_first {
None => g_idx_first = Some(gx),
Some(prev) => {
if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
return Err(FerrumError::model(format!(
"stacked GPTQ '{name}': g_idx mismatch with first \
expert — Marlin requires identical act-order across \
experts in the same stacked tile"
)));
}
}
}
}
}
if e == 0 {
n_per_expert = e_n;
n_per_expert_scales = e_n_scales;
n_per_expert_zeros = e_n_zeros;
} else if e_n != n_per_expert
|| e_n_scales != n_per_expert_scales
|| e_n_zeros != n_per_expert_zeros
{
return Err(FerrumError::model(format!(
"stacked GPTQ expert {e} N mismatch: qw {e_n} sc {e_n_scales} qz {e_n_zeros} vs expert 0 {n_per_expert}/{n_per_expert_scales}/{n_per_expert_zeros}"
)));
}
}
let proj_count = proj_names.len();
let pairs_per_expert = proj_count;
debug_assert_eq!(total_pairs, num_experts * pairs_per_expert);
let mut per_expert_qw: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
let mut per_expert_sc: Vec<Vec<f32>> = Vec::with_capacity(num_experts);
let mut per_expert_qz: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
for e in 0..num_experts {
let mut qw: Vec<i32> = Vec::with_capacity(qw_rows * n_per_expert);
let mut sc: Vec<f32> = Vec::with_capacity(sc_rows * n_per_expert_scales);
let mut qz: Vec<i32> = Vec::with_capacity(qz_rows * n_per_expert_zeros);
for r in 0..qw_rows {
for j in 0..pairs_per_expert {
let pair_idx = e * pairs_per_expert + j;
let (data, cols) = &qw_parts[pair_idx];
qw.extend_from_slice(&data[r * cols..(r + 1) * cols]);
}
}
for r in 0..sc_rows {
for j in 0..pairs_per_expert {
let pair_idx = e * pairs_per_expert + j;
let (data, cols) = &sc_parts[pair_idx];
sc.extend_from_slice(&data[r * cols..(r + 1) * cols]);
}
}
for r in 0..qz_rows {
for j in 0..pairs_per_expert {
let pair_idx = e * pairs_per_expert + j;
let (data, cols) = &qz_parts[pair_idx];
qz.extend_from_slice(&data[r * cols..(r + 1) * cols]);
}
}
per_expert_qw.push(qw);
per_expert_sc.push(sc);
per_expert_qz.push(qz);
}
drop(qw_parts);
drop(sc_parts);
drop(qz_parts);
let qw_refs: Vec<&[i32]> = per_expert_qw.iter().map(|v| v.as_slice()).collect();
let sc_refs: Vec<&[f32]> = per_expert_sc.iter().map(|v| v.as_slice()).collect();
let qz_refs: Vec<&[i32]> = per_expert_qz.iter().map(|v| v.as_slice()).collect();
let store = B::load_gptq_stacked(
&qw_refs,
&sc_refs,
&qz_refs,
g_idx_first.as_deref(),
qcfg.bits,
qcfg.group_size,
k_shared,
n_per_expert,
)?;
Ok((store, n_per_expert, k_shared))
}
}
impl<B: Backend + BackendQuantMarlin> WeightLoader<B> for NativeSafetensorsLoader<B> {
    /// Uploads tensor `name` to the backend in its stored dtype (F32/F16/BF16).
    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
        let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
        Ok(B::from_weight_bytes(raw, src_dtype))
    }

    /// Builds a linear layer for `name`, trying these sources in order:
    ///
    /// 1. a GPTQ tensor stored under `name` (`{name}.qweight`);
    /// 2. for a `qkv_proj`/`gate_up_proj` name, its GPTQ split parts fused;
    /// 3. a dense `{name}.weight`, plus `{name}.bias` when present;
    /// 4. for a `qkv_proj`/`gate_up_proj` name, its dense split parts fused.
    ///
    /// # Errors
    /// Fails when none of these sources exist, or the matching source is malformed.
    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
        /// Split-checkpoint part names (without suffix) for a fused projection name.
        fn split_parts(name: &str) -> Option<Vec<String>> {
            if let Some(prefix) = name.strip_suffix("qkv_proj") {
                Some(vec![
                    format!("{prefix}q_proj"),
                    format!("{prefix}k_proj"),
                    format!("{prefix}v_proj"),
                ])
            } else {
                name.strip_suffix("gate_up_proj")
                    .map(|prefix| vec![format!("{prefix}gate_proj"), format!("{prefix}up_proj")])
            }
        }

        if self.has(&format!("{name}.qweight")) {
            return self.load_gptq_linear(name);
        }
        let split = split_parts(name);
        if let Some(parts) = &split {
            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
                return self.load_gptq_linear_fused(parts);
            }
        }

        let direct = format!("{name}.weight");
        if self.has(&direct) {
            let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
            if shape.len() != 2 {
                return Err(FerrumError::model(format!(
                    "linear '{name}': expected 2D weight, got {shape:?}"
                )));
            }
            let rows = shape[0];
            let weight = B::from_weight_bytes(raw, src_dtype);
            let mut linear = DenseLinear::<B>::from_buffer(weight, rows, shape[1]);
            // The GPTQ and fused dense paths attach the bias; the direct dense path
            // used to drop it silently.
            let bias_key = format!("{name}.bias");
            if self.has(&bias_key) {
                let (bias, bias_shape) = self.read_f32(&bias_key)?;
                if bias_shape != [rows] {
                    return Err(FerrumError::model(format!(
                        "'{bias_key}' {bias_shape:?} != [{rows}]"
                    )));
                }
                linear = linear.with_bias(B::from_slice(&bias));
            }
            return Ok(Box::new(linear));
        }

        if let Some(parts) = split {
            let weight_names: Vec<String> = parts.iter().map(|p| format!("{p}.weight")).collect();
            if weight_names.iter().all(|w| self.has(w)) {
                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&weight_names)?;
                let weight = B::from_weight_bytes(&bytes, dtype);
                let mut linear = DenseLinear::<B>::from_buffer(weight, rows, cols);
                if let Some(bias) = self.cat_optional_biases(&weight_names, rows)? {
                    linear = linear.with_bias(B::from_slice(&bias));
                }
                return Ok(Box::new(linear));
            }
        }

        Err(FerrumError::model(format!(
            "could not load linear '{name}' — no direct `.weight`, no split components"
        )))
    }

    /// Returns whether any shard stores a tensor called `name`.
    fn has_tensor(&self, name: &str) -> bool {
        self.has(name)
    }

    /// Returns the loaded quantization config, if any.
    fn quant_config(&self) -> Option<&QuantConfig> {
        self.quant_config.as_ref()
    }
}
impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
let qcfg = self.quant_config.as_ref().ok_or_else(|| {
FerrumError::model(format!(
"'{name}.qweight' present but no quantize_config.json — \
can't determine bits/group_size"
))
})?;
if qcfg.method != QuantMethod::Gptq {
return Err(FerrumError::model(format!(
"'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
qcfg.method
)));
}
let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
let g_idx = if self.has(&format!("{name}.g_idx")) {
Some(self.read_i32(&format!("{name}.g_idx"))?.0)
} else {
None
};
if qw_shape.len() != 2 {
return Err(FerrumError::model(format!(
"'{name}.qweight' expected 2D, got {qw_shape:?}"
)));
}
let in_features = qw_shape[0] * 8;
let out_features = qw_shape[1];
let is_desc_act = validate_gptq_g_idx(name, qcfg, g_idx.as_deref(), in_features)?;
#[cfg(not(feature = "cuda"))]
if is_desc_act {
let dequant_f32 = dequantize_gptq_with_g_idx(
&qweight,
&scales_f32,
&qzeros,
g_idx.as_ref().expect("desc_act=true requires g_idx"),
qcfg.group_size,
in_features,
out_features,
);
let mut linear =
crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
let bias_key = format!("{name}.bias");
if self.has(&bias_key) {
let (bias, _) = self.read_f32(&bias_key)?;
linear = linear.with_bias(B::from_slice(&bias));
}
tracing::info!(
"GPTQ load (desc_act dequant→DenseLinear, non-cuda): name={name} K={in_features} N={out_features}"
);
return Ok(Box::new(linear));
}
#[cfg(feature = "cuda")]
let _ = is_desc_act; if sc_shape.len() != 2 || sc_shape[1] != out_features {
return Err(FerrumError::model(format!(
"'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
)));
}
let bias_key = format!("{name}.bias");
let bias_vec = if self.has(&bias_key) {
let (bias, bias_shape) = self.read_f32(&bias_key)?;
if bias_shape != [out_features] {
return Err(FerrumError::model(format!(
"'{bias_key}' {bias_shape:?} != [{out_features}]"
)));
}
Some(bias)
} else {
None
};
let linear = GptqLinear::<B>::from_raw(
&qweight,
&scales_f32,
&qzeros,
g_idx.as_deref(),
bias_vec.as_deref(),
qcfg.bits,
qcfg.group_size,
in_features,
out_features,
)?;
Ok(Box::new(linear))
}
fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
let qcfg = self.quant_config.as_ref().ok_or_else(|| {
FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
})?;
if qcfg.method != QuantMethod::Gptq {
return Err(FerrumError::model(format!(
"GPTQ fusion but quant_method={:?}",
qcfg.method
)));
}
let mut qw_acc: Vec<i32> = Vec::new();
let mut sc_acc: Vec<f32> = Vec::new();
let mut qz_acc: Vec<i32> = Vec::new();
let mut qw_rows = 0usize;
let mut sc_rows = 0usize;
let mut qz_rows = 0usize;
let mut total_n = 0usize;
let mut total_n_scales = 0usize;
let mut total_n_zeros = 0usize;
let mut g_idx: Option<Vec<i32>> = None;
let mut g_idx_presence: Vec<(String, bool)> = Vec::with_capacity(parts.len());
let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
for p in parts {
let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
return Err(FerrumError::model(format!(
"GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
)));
}
if qw_rows == 0 {
qw_rows = qw_sh[0];
sc_rows = sc_sh[0];
qz_rows = qz_sh[0];
} else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
return Err(FerrumError::model(format!(
"GPTQ fusion row mismatch on '{p}'"
)));
}
total_n += qw_sh[1];
total_n_scales += sc_sh[1];
total_n_zeros += qz_sh[1];
qw_parts.push((qw, qw_sh[0], qw_sh[1]));
sc_parts.push((sc, sc_sh[0], sc_sh[1]));
qz_parts.push((qz, qz_sh[0], qz_sh[1]));
let g_key = format!("{p}.g_idx");
if self.has(&g_key) {
let (gx, gx_shape) = self.read_i32(&g_key)?;
if gx_shape != [qw_rows * 8] {
return Err(FerrumError::model(format!(
"GPTQ fusion '{p}': g_idx shape {gx_shape:?} incompatible with K={}",
qw_rows * 8
)));
}
match &g_idx {
None => g_idx = Some(gx),
Some(prev) => {
if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
return Err(FerrumError::model(format!(
"GPTQ fusion '{p}': g_idx mismatch with first part; \
fused qkv/gate_up requires identical act-order across parts"
)));
}
}
}
g_idx_presence.push((p.clone(), true));
} else {
g_idx_presence.push((p.clone(), false));
}
}
qw_acc.reserve(qw_rows * total_n);
for r in 0..qw_rows {
for (part, _rows, cols) in &qw_parts {
qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
}
}
sc_acc.reserve(sc_rows * total_n_scales);
for r in 0..sc_rows {
for (part, _rows, cols) in &sc_parts {
sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
}
}
qz_acc.reserve(qz_rows * total_n_zeros);
for r in 0..qz_rows {
for (part, _rows, cols) in &qz_parts {
qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
}
}
let in_features = qw_rows * 8;
let out_features = total_n;
if g_idx.is_some() {
let missing = g_idx_presence
.iter()
.filter_map(|(part, present)| (!present).then_some(part.as_str()))
.collect::<Vec<_>>();
if !missing.is_empty() {
return Err(FerrumError::model(format!(
"GPTQ fusion requires all parts to carry g_idx when any part does; \
missing g_idx for {missing:?}"
)));
}
}
let fused_name = format!("GPTQ fusion {}", parts.join("+"));
let is_desc_act = validate_gptq_g_idx(&fused_name, qcfg, g_idx.as_deref(), in_features)?;
#[cfg(not(feature = "cuda"))]
if is_desc_act {
let dequant_f32 = dequantize_gptq_with_g_idx(
&qw_acc,
&sc_acc,
&qz_acc,
g_idx.as_ref().expect("desc_act=true requires g_idx"),
qcfg.group_size,
in_features,
out_features,
);
let mut linear =
crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
let mut bias_acc: Vec<f32> = Vec::new();
let mut any_bias = false;
for p in parts {
let bk = format!("{p}.bias");
if self.has(&bk) {
any_bias = true;
bias_acc.extend_from_slice(&self.read_f32(&bk)?.0);
} else if any_bias {
return Err(FerrumError::model(format!(
"GPTQ fusion bias mix: '{p}' has no bias but earlier part did"
)));
}
}
if any_bias {
linear = linear.with_bias(B::from_slice(&bias_acc));
}
tracing::info!(
"GPTQ fused load (desc_act dequant→DenseLinear, non-cuda): K={in_features} N={out_features} parts={}",
parts.len()
);
return Ok(Box::new(linear));
}
#[cfg(feature = "cuda")]
let _ = is_desc_act;
let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
let any = bias_keys.iter().any(|k| self.has(k));
let all = bias_keys.iter().all(|k| self.has(k));
if any && !all {
return Err(FerrumError::model(
"GPTQ fusion: inconsistent bias presence across parts".to_string(),
));
}
let fused_bias = if all {
let mut fused: Vec<f32> = Vec::with_capacity(out_features);
for k in &bias_keys {
let (b, _) = self.read_f32(k)?;
fused.extend_from_slice(&b);
}
if fused.len() != out_features {
return Err(FerrumError::model(format!(
"GPTQ fusion bias length {} != out_features {out_features}",
fused.len()
)));
}
Some(fused)
} else {
None
};
let linear = GptqLinear::<B>::from_raw(
&qw_acc,
&sc_acc,
&qz_acc,
g_idx.as_deref(),
fused_bias.as_deref(),
qcfg.bits,
qcfg.group_size,
in_features,
out_features,
)?;
Ok(Box::new(linear))
}
#[allow(dead_code)]
fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
let mut total_rows = 0usize;
let mut cols = 0usize;
let mut out: Vec<f32> = Vec::new();
for n in names {
let (data, shape) = self.read_f32(n)?;
if shape.len() != 2 {
return Err(FerrumError::model(format!(
"cat_rows: '{n}' is {shape:?}, need 2D"
)));
}
if cols == 0 {
cols = shape[1];
} else if cols != shape[1] {
return Err(FerrumError::model(format!(
"cat_rows: col mismatch {cols} vs {}",
shape[1]
)));
}
total_rows += shape[0];
out.extend_from_slice(&data);
}
Ok((total_rows, cols, out))
}
}
/// Reports whether `g_idx` departs from the trivial layout where input channel `i`
/// belongs to group `i / group_size`, i.e. whether activation reordering was used.
///
/// `group_size` must be non-zero whenever `g_idx` is non-empty (integer division).
fn gptq_g_idx_is_desc_act(g_idx: &[i32], group_size: usize) -> bool {
    let gs = group_size as i32;
    let trivial = g_idx
        .iter()
        .enumerate()
        .all(|(channel, &group)| group == channel as i32 / gs);
    !trivial
}
/// Checks an optional GPTQ `g_idx` against the quantize config and K.
///
/// Returns `Ok(true)` when `g_idx` encodes a non-trivial (act-order) grouping and
/// `Ok(false)` when it is absent or trivial.
///
/// # Errors
/// The function returns an error in any of these cases:
/// - the config says `desc_act` but there is no `g_idx`;
/// - `group_size` is 0 while a `g_idx` is present;
/// - the length of `g_idx` is not `in_features`;
/// - an entry falls outside `0..ceil(in_features / group_size)`.
fn validate_gptq_g_idx(
    name: &str,
    qcfg: &QuantConfig,
    g_idx: Option<&[i32]>,
    in_features: usize,
) -> Result<bool> {
    let g_idx = match g_idx {
        Some(groups) => groups,
        None if qcfg.desc_act => {
            return Err(FerrumError::model(format!(
                "{name}: quantize_config desc_act=true but no g_idx tensor was found"
            )))
        }
        None => return Ok(false),
    };
    let group_size = qcfg.group_size;
    if group_size == 0 {
        return Err(FerrumError::model(format!(
            "{name}: GPTQ g_idx present but group_size is 0"
        )));
    }
    if g_idx.len() != in_features {
        return Err(FerrumError::model(format!(
            "{name}: g_idx length {} must match K={in_features}",
            g_idx.len()
        )));
    }
    let group_count = in_features.div_ceil(group_size);
    let first_bad = g_idx
        .iter()
        .enumerate()
        .find(|&(_, &g)| g < 0 || g as usize >= group_count);
    if let Some((idx, &group)) = first_bad {
        return Err(FerrumError::model(format!(
            "{name}: g_idx[{idx}]={group} outside expected group range 0..{}",
            group_count.saturating_sub(1)
        )));
    }
    Ok(gptq_g_idx_is_desc_act(g_idx, group_size))
}
/// Dequantizes 4-bit GPTQ weights with an explicit (possibly act-ordered) `g_idx`.
///
/// Input layouts, as indexed below:
/// - `qweight`: `(k / 8) x n` i32 words, eight 4-bit values per word along K,
///   with the lowest nibble first;
/// - `scales`: `groups x n`;
/// - `qzeros`: `groups x (n / 8)` i32 words, with eight 4-bit zero points per
///   word along N. The stored value is one less than the real zero point (hence
///   the `+ 1`).
///
/// Returns an `n x k` row-major `f32` matrix: entry `(col, ki)` is
/// `(q - zero) * scale`, taking `scale` and `zero` from group `g_idx[ki]`.
/// `_group_size` is unused because `g_idx` already names each channel's group.
///
/// # Panics
/// Panics when an input slice is too short for the indices above.
#[cfg(not(feature = "cuda"))]
fn dequantize_gptq_with_g_idx(
    qweight: &[i32],
    scales: &[f32],
    qzeros: &[i32],
    g_idx: &[i32],
    _group_size: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    debug_assert_eq!(g_idx.len(), k);
    let packed_rows = k / 8;
    let zero_words_per_group = n / 8;
    let mut out = vec![0.0f32; n * k];
    // Column-outer order writes `out` sequentially; each element is written exactly
    // once, so the result does not depend on loop order.
    for col in 0..n {
        let zero_word = col / 8;
        let zero_shift = (col % 8) * 4;
        let out_row = col * k;
        for pr in 0..packed_rows {
            let word = qweight[pr * n + col] as u32;
            for nib in 0..8 {
                let ki = pr * 8 + nib;
                let group = g_idx[ki] as usize;
                let q = ((word >> (nib * 4)) & 0xF) as i32;
                let zp_word = qzeros[group * zero_words_per_group + zero_word] as u32;
                let zero = ((zp_word >> zero_shift) & 0xF) as i32 + 1;
                out[out_row + ki] = (q - zero) as f32 * scales[group * n + col];
            }
        }
    }
    out
}
fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
match dtype {
Dtype::F32 => {
debug_assert_eq!(raw.len() % 4, 0);
let n = raw.len() / 4;
let mut out = Vec::<f32>::with_capacity(n);
unsafe {
std::ptr::copy_nonoverlapping(raw.as_ptr(), out.as_mut_ptr() as *mut u8, raw.len());
out.set_len(n);
}
Ok(out)
}
Dtype::F16 => {
debug_assert_eq!(raw.len() % 2, 0);
let n = raw.len() / 2;
let mut tmp = Vec::<f16>::with_capacity(n);
unsafe {
std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
tmp.set_len(n);
}
let mut out = Vec::with_capacity(n);
for h in &tmp {
out.push(h.to_f32());
}
Ok(out)
}
Dtype::BF16 => {
debug_assert_eq!(raw.len() % 2, 0);
let n = raw.len() / 2;
let mut tmp = Vec::<bf16>::with_capacity(n);
unsafe {
std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
tmp.set_len(n);
}
let mut out = Vec::with_capacity(n);
for h in &tmp {
out.push(h.to_f32());
}
Ok(out)
}
other => Err(FerrumError::model(format!(
"dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
use a format-specific loader (GPTQ / AWQ / GGUF)",
))),
}
}
/// Loads quantization settings for the model in `dir`.
///
/// Sources are tried in this order:
/// 1. `quantize_config.json`, deserialized directly into [`QuantConfig`].
/// 2. The `quantization_config` object in `config.json`. Its `quant_method` is
///    matched case-insensitively against gptq/awq/gguf. Missing fields default
///    to `bits = 0`, `group_size = 128` (negative values clamp to 0) and
///    `desc_act = sym = false`.
///
/// Returns `Ok(None)` when neither source yields a recognized method.
///
/// # Errors
/// Fails on unreadable files or malformed JSON.
fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
    let standalone = dir.join("quantize_config.json");
    if standalone.exists() {
        let text = std::fs::read_to_string(&standalone)
            .map_err(|e| FerrumError::io(format!("read {standalone:?}: {e}")))?;
        let parsed: QuantConfig = serde_json::from_str(&text)
            .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
        return Ok(Some(parsed));
    }

    let hf_config = dir.join("config.json");
    if !hf_config.exists() {
        return Ok(None);
    }
    let text = std::fs::read_to_string(&hf_config)
        .map_err(|e| FerrumError::io(format!("read {hf_config:?}: {e}")))?;
    let root: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
    let Some(qc) = root.get("quantization_config") else {
        return Ok(None);
    };

    let method_name = qc
        .get("quant_method")
        .and_then(|v| v.as_str())
        .map(str::to_lowercase);
    let method = match method_name.as_deref() {
        Some("gptq") => QuantMethod::Gptq,
        Some("awq") => QuantMethod::Awq,
        Some("gguf") => QuantMethod::Gguf,
        _ => QuantMethod::None,
    };
    if method == QuantMethod::None {
        return Ok(None);
    }

    let flag = |key: &str| qc.get(key).and_then(|v| v.as_bool()).unwrap_or(false);
    Ok(Some(QuantConfig {
        method,
        bits: qc.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
        group_size: qc
            .get("group_size")
            .and_then(|v| v.as_i64())
            .unwrap_or(128)
            .max(0) as usize,
        desc_act: flag("desc_act"),
        sym: flag("sym"),
    }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 4-bit GPTQ config with group_size 2, so K=4 spans exactly two groups.
    fn gptq_config(desc_act: bool) -> QuantConfig {
        QuantConfig {
            method: QuantMethod::Gptq,
            bits: 4,
            group_size: 2,
            desc_act,
            sym: true,
        }
    }

    #[test]
    fn validate_gptq_g_idx_requires_tensor_when_desc_act_configured() {
        let err = validate_gptq_g_idx("proj", &gptq_config(true), None, 4)
            .unwrap_err()
            .to_string();
        assert!(err.contains("desc_act=true"));
        assert!(err.contains("no g_idx"));
    }

    #[test]
    fn validate_gptq_g_idx_accepts_trivial_non_desc_act_order() {
        let is_desc_act =
            validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 1, 1]), 4).unwrap();
        assert!(!is_desc_act);
    }

    #[test]
    fn validate_gptq_g_idx_detects_nontrivial_act_order() {
        let is_desc_act =
            validate_gptq_g_idx("proj", &gptq_config(false), Some(&[1, 1, 0, 0]), 4).unwrap();
        assert!(is_desc_act);
    }

    #[test]
    fn validate_gptq_g_idx_rejects_invalid_shape_and_group() {
        let short = validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 1]), 4)
            .unwrap_err()
            .to_string();
        assert!(short.contains("must match K=4"));
        let out_of_range = validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 2, 1]), 4)
            .unwrap_err()
            .to_string();
        assert!(out_of_range.contains("outside expected group range"));
    }

    #[test]
    fn validate_gptq_g_idx_rejects_zero_group_size_with_g_idx() {
        let cfg = QuantConfig {
            group_size: 0,
            ..gptq_config(false)
        };
        let err = validate_gptq_g_idx("proj", &cfg, Some(&[0, 0, 0, 0]), 4)
            .unwrap_err()
            .to_string();
        assert!(err.contains("group_size is 0"));
    }

    /// K=16 (two packed rows), N=8 (one zero word per group), act-ordered so the
    /// first eight channels use group 1 and the last eight use group 0.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn dequantize_gptq_with_g_idx_honours_act_order_groups() {
        let mut qweight = vec![0x7654_3210i32; 8];
        qweight.extend(vec![-1i32; 8]);
        let mut scales = vec![1.0f32; 8];
        scales.extend(vec![2.0f32; 8]);
        // Group 0 stores zero-1 = 7 for every column; group 1 stores 0.
        let qzeros = vec![0x7777_7777i32, 0i32];
        let mut g_idx = vec![1i32; 8];
        g_idx.extend(vec![0i32; 8]);
        let w = dequantize_gptq_with_g_idx(&qweight, &scales, &qzeros, &g_idx, 8, 16, 8);
        assert_eq!(w.len(), 16 * 8);
        for col in 0..8 {
            for ki in 0..8 {
                assert_eq!(w[col * 16 + ki], (ki as f32 - 1.0) * 2.0);
            }
            for ki in 8..16 {
                assert_eq!(w[col * 16 + ki], 7.0);
            }
        }
    }
}