use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use ferrum_kernels::backend::Backend;
use ferrum_types::{FerrumError, Result};
use half::{bf16, f16};
use memmap2::Mmap;
use safetensors::{Dtype, SafeTensors};
use crate::config::{QuantConfig, QuantMethod};
use crate::dense::DenseLinear;
use crate::gptq::GptqLinear;
use crate::loader::WeightLoader;
use crate::traits::Linear;
/// One memory-mapped safetensors file together with the tensor names it contains.
struct Shard {
/// Memory map over the shard file; tensor views handed out by `get` borrow from it.
mmap: Mmap,
/// Tensor names collected from the parsed file when the shard was opened.
names: Vec<String>,
}
impl Shard {
    /// Memory-maps the safetensors file at `path` and records the tensor names it holds.
    ///
    /// # Errors
    /// Returns an I/O error when the file cannot be opened or mapped, and a model
    /// error when the mapped bytes do not parse as safetensors.
    fn open(path: &Path) -> Result<Self> {
        let file = match File::open(path) {
            Ok(handle) => handle,
            Err(e) => return Err(FerrumError::io(format!("open {path:?}: {e}"))),
        };
        // SAFETY: NOTE(review) — a file-backed map is only sound while the file is not
        // truncated or rewritten underneath it. Nothing here enforces that; the loader
        // presumably relies on model files being immutable while loaded — confirm.
        let mapped = unsafe { Mmap::map(&file) };
        let mmap = mapped.map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?;
        let names = {
            let parsed = SafeTensors::deserialize(&mmap)
                .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
            parsed.names().iter().map(|n| n.to_string()).collect()
        };
        Ok(Self { mmap, names })
    }

    /// Returns a view of tensor `name` that borrows from this shard's memory map.
    ///
    /// The parsed form is not stored in the struct, so the mapped bytes are
    /// deserialized again on every call.
    ///
    /// # Errors
    /// Returns a model error when re-parsing fails or the tensor lookup fails.
    fn get<'a>(&'a self, name: &str) -> Result<safetensors::tensor::TensorView<'a>> {
        let parsed = SafeTensors::deserialize(&self.mmap)
            .map_err(|e| FerrumError::model(format!("reparse: {e}")))?;
        match parsed.tensor(name) {
            Ok(view) => Ok(view),
            Err(e) => Err(FerrumError::model(format!("tensor '{name}': {e}"))),
        }
    }
}
/// Weight loader that reads tensors from one or more memory-mapped safetensors shards.
pub struct NativeSafetensorsLoader<B: Backend> {
/// Every opened shard, in the order the shard files were opened.
shards: Vec<Shard>,
/// Maps a tensor name to the position in `shards` of the shard that holds it.
index: HashMap<String, usize>,
/// Quantization settings found alongside the weights, if any.
quant_config: Option<QuantConfig>,
/// Ties the loader to backend `B` without storing a value of that type.
_m: std::marker::PhantomData<B>,
}
impl<B: Backend> NativeSafetensorsLoader<B> {
/// Opens the safetensors weights stored in `model_dir` and indexes every tensor.
///
/// A single `model.safetensors` file takes precedence; otherwise the shard list is
/// taken from `model.safetensors.index.json`. When the same tensor name appears in
/// several shards, the shard opened last wins.
///
/// # Errors
/// Returns a model error when neither file exists, and propagates any failure from
/// opening a shard, parsing the shard index, or loading the quantization config.
pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
    let dir = model_dir.as_ref();
    let single_file = dir.join("model.safetensors");
    let shard_index = dir.join("model.safetensors.index.json");

    let shard_paths: Vec<std::path::PathBuf> = if single_file.exists() {
        vec![single_file]
    } else if shard_index.exists() {
        let file_names = Self::parse_sharded_index(&shard_index)?;
        file_names.into_iter().map(|name| dir.join(name)).collect()
    } else {
        return Err(FerrumError::model(format!(
            "no safetensors files in {dir:?}"
        )));
    };

    let mut shards = Vec::with_capacity(shard_paths.len());
    let mut index: HashMap<String, usize> = HashMap::new();
    for (slot, shard_path) in shard_paths.iter().enumerate() {
        let shard = Shard::open(shard_path)?;
        // Later shards overwrite earlier entries for a duplicated tensor name.
        index.extend(shard.names.iter().cloned().map(|tensor| (tensor, slot)));
        shards.push(shard);
    }

    Ok(Self {
        shards,
        index,
        quant_config: load_quantize_config(dir)?,
        _m: std::marker::PhantomData,
    })
}
/// Reads a sharded-checkpoint index file and returns the distinct shard file names.
///
/// The names are the string values of the JSON object stored under `weight_map`,
/// sorted and de-duplicated. Values that are not strings are skipped.
///
/// # Errors
/// Returns an I/O error when the file cannot be read, a serialization error when it
/// is not valid JSON, and a model error when `weight_map` is absent or not an object.
fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
    let text = match std::fs::read_to_string(index_path) {
        Ok(contents) => contents,
        Err(e) => return Err(FerrumError::io(format!("read {index_path:?}: {e}"))),
    };
    let root: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
    let weight_map = match root.get("weight_map").and_then(|v| v.as_object()) {
        Some(map) => map,
        None => return Err(FerrumError::model("index missing weight_map")),
    };

    let mut files = Vec::with_capacity(weight_map.len());
    for entry in weight_map.values() {
        if let Some(file_name) = entry.as_str() {
            files.push(file_name.to_string());
        }
    }
    // Many tensors map to the same shard file; keep each file once.
    files.sort();
    files.dedup();
    Ok(files)
}
/// Reads tensor `name` and returns its values converted to `f32` together with its shape.
///
/// # Errors
/// Returns a model error when `name` is not indexed, and propagates failures from
/// the shard lookup and from `dtype_to_f32`.
fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
    let shard = match self.index.get(name) {
        Some(&slot) => &self.shards[slot],
        None => {
            return Err(FerrumError::model(format!(
                "tensor '{name}' not in index"
            )))
        }
    };
    let view = shard.get(name)?;
    let dims = view.shape().to_vec();
    let values = dtype_to_f32(view.dtype(), view.data())?;
    Ok((values, dims))
}
/// Reads tensor `name`, which must be stored as `I32`, and returns its values and shape.
///
/// The raw bytes are decoded as little-endian 32-bit integers.
///
/// # Errors
/// Returns a model error when `name` is not indexed or its dtype is not `I32`, and
/// propagates failures from the shard lookup.
fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
    let shard = match self.index.get(name) {
        Some(&slot) => &self.shards[slot],
        None => {
            return Err(FerrumError::model(format!(
                "tensor '{name}' not in index"
            )))
        }
    };
    let view = shard.get(name)?;
    let dims = view.shape().to_vec();
    if view.dtype() != Dtype::I32 {
        return Err(FerrumError::model(format!(
            "'{name}': expected I32, got {:?}",
            view.dtype()
        )));
    }

    let raw = view.data();
    debug_assert_eq!(raw.len() % 4, 0);
    let values: Vec<i32> = raw
        .chunks_exact(4)
        .map(|word| i32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect();
    Ok((values, dims))
}
/// Returns `true` when `name` is a key of the tensor index.
fn has(&self, name: &str) -> bool {
self.index.contains_key(name)
}
}
impl<B: Backend> WeightLoader<B> for NativeSafetensorsLoader<B> {
/// Reads tensor `name` as `f32` values and passes them to `B::from_slice`.
///
/// The tensor's shape is discarded; only the flat data is handed to the backend.
fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
let (data, _) = self.read_f32(name)?;
Ok(B::from_slice(&data))
}
/// Builds the linear layer called `name`, choosing the representation by which tensors exist.
///
/// Probes, in order:
/// 1. `{name}.qweight` — a GPTQ layer stored directly;
/// 2. GPTQ split parts (`q_proj`/`k_proj`/`v_proj` for a `qkv_proj` name,
///    `gate_proj`/`up_proj` for a `gate_up_proj` name), fused into one layer;
/// 3. `{name}.weight` — a dense 2D weight;
/// 4. dense split parts, concatenated row-wise.
///
/// # Errors
/// Returns a model error when the dense weight is not 2D or when none of the probes
/// match, and propagates failures from the underlying readers.
fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
    // Base names of the separately stored projections, when `name` is a fused name.
    // The two suffixes cannot both match, so checking them in sequence is exact.
    let split_bases: Option<Vec<String>> = if let Some(prefix) = name.strip_suffix("qkv_proj") {
        Some(vec![
            format!("{prefix}q_proj"),
            format!("{prefix}k_proj"),
            format!("{prefix}v_proj"),
        ])
    } else if let Some(prefix) = name.strip_suffix("gate_up_proj") {
        Some(vec![format!("{prefix}gate_proj"), format!("{prefix}up_proj")])
    } else {
        None
    };

    if self.has(&format!("{name}.qweight")) {
        return self.load_gptq_linear(name);
    }
    if let Some(bases) = &split_bases {
        if bases.iter().all(|b| self.has(&format!("{b}.qweight"))) {
            return self.load_gptq_linear_fused(bases);
        }
    }

    let direct = format!("{name}.weight");
    if self.has(&direct) {
        let (data, shape) = self.read_f32(&direct)?;
        return match shape.as_slice() {
            &[rows, cols] => Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols))),
            _ => Err(FerrumError::model(format!(
                "linear '{name}': expected 2D weight, got {shape:?}"
            ))),
        };
    }

    if let Some(bases) = &split_bases {
        let weight_keys: Vec<String> = bases.iter().map(|b| format!("{b}.weight")).collect();
        if weight_keys.iter().all(|k| self.has(k)) {
            let (rows, cols, data) = self.cat_rows(&weight_keys)?;
            return Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols)));
        }
    }

    Err(FerrumError::model(format!(
        "could not load linear '{name}' — no direct `.weight`, no split components"
    )))
}
/// Reports whether a tensor called `name` is present; delegates to `has`.
fn has_tensor(&self, name: &str) -> bool {
self.has(name)
}
/// Returns a reference to the stored quantization config, or `None` when there is none.
fn quant_config(&self) -> Option<&QuantConfig> {
self.quant_config.as_ref()
}
}
impl<B: Backend> NativeSafetensorsLoader<B> {
/// Loads the GPTQ-quantized linear layer stored under `name`.
///
/// Reads `{name}.qweight`, `{name}.scales` and `{name}.qzeros`, plus `{name}.g_idx`
/// and `{name}.bias` when present, and hands them to `GptqLinear::from_raw`.
///
/// `in_features` is derived from the configured bit width: GPTQ packs `32 / bits`
/// quantized values into each `i32` of `qweight` along the input dimension, so a
/// `qweight` with `r` rows covers `r * 32 / bits` input features. For 4-bit
/// checkpoints this is the familiar `r * 8`.
///
/// # Errors
/// Returns a model error when no quantization config is available, the method is
/// not GPTQ, the bit width is 0 or larger than 32, `qweight` is not 2D or its row
/// count is inconsistent with the bit width, `scales` does not match `qweight`, or
/// the bias shape is not `[out_features]`. Reader and `from_raw` failures propagate.
fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
    let qcfg = self.quant_config.as_ref().ok_or_else(|| {
        FerrumError::model(format!(
            "'{name}.qweight' present but no quantize_config.json — \
             can't determine bits/group_size"
        ))
    })?;
    if qcfg.method != QuantMethod::Gptq {
        return Err(FerrumError::model(format!(
            "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
            qcfg.method
        )));
    }
    // Guard before dividing: a config without a usable `bits` must not panic here.
    let bits = qcfg.bits as usize;
    if bits == 0 || bits > 32 {
        return Err(FerrumError::model(format!(
            "'{name}.qweight': unsupported GPTQ bit width {bits}"
        )));
    }

    let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
    let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
    let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
    let g_idx_key = format!("{name}.g_idx");
    let g_idx = if self.has(&g_idx_key) {
        Some(self.read_i32(&g_idx_key)?.0)
    } else {
        None
    };

    if qw_shape.len() != 2 {
        return Err(FerrumError::model(format!(
            "'{name}.qweight' expected 2D, got {qw_shape:?}"
        )));
    }
    // Each i32 row of qweight holds 32 bits' worth of packed input features.
    let packed_bits = qw_shape[0] * 32;
    if packed_bits % bits != 0 {
        return Err(FerrumError::model(format!(
            "'{name}.qweight' {qw_shape:?} is not a whole number of {bits}-bit values"
        )));
    }
    let in_features = packed_bits / bits;
    let out_features = qw_shape[1];
    if sc_shape.len() != 2 || sc_shape[1] != out_features {
        return Err(FerrumError::model(format!(
            "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
        )));
    }

    let mut linear = GptqLinear::<B>::from_raw(
        &qweight,
        &scales_f32,
        &qzeros,
        g_idx.as_deref(),
        qcfg.bits,
        qcfg.group_size,
        in_features,
        out_features,
    )?;

    let bias_key = format!("{name}.bias");
    if self.has(&bias_key) {
        let (bias, bias_shape) = self.read_f32(&bias_key)?;
        if bias_shape != [out_features] {
            return Err(FerrumError::model(format!(
                "'{bias_key}' {bias_shape:?} != [{out_features}]"
            )));
        }
        linear = linear.with_bias(&bias);
    }
    Ok(Box::new(linear))
}
/// Loads several GPTQ projections and fuses them column-wise into one linear layer.
///
/// For every base name in `parts` this reads `.qweight`, `.scales` and `.qzeros`,
/// requires all parts to agree on the row count of each tensor, and concatenates
/// them so that row `r` of the result is row `r` of every part in `parts` order.
/// `g_idx` is taken from the first part that has one. Biases are concatenated in
/// `parts` order, and must be present on either all parts or none.
///
/// `in_features` is derived from the configured bit width (`rows * 32 / bits`,
/// i.e. `rows * 8` for 4-bit) because GPTQ packs `32 / bits` values per `i32`.
///
/// # Errors
/// Returns a model error when no quantization config is available, the method is
/// not GPTQ, the bit width is 0 or larger than 32, any tensor is not 2D, a part's
/// `scales` width differs from its `qweight` width, row counts differ between parts,
/// bias presence is inconsistent, or the fused bias length is wrong. Reader and
/// `from_raw` failures propagate.
fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
    /// Places row-major matrices side by side; each entry is `(data, column count)`
    /// and every matrix must have exactly `rows` rows.
    fn hcat<T: Copy>(blocks: &[(Vec<T>, usize)], rows: usize) -> Vec<T> {
        let width: usize = blocks.iter().map(|(_, cols)| *cols).sum();
        let mut fused = Vec::with_capacity(rows * width);
        for r in 0..rows {
            for (data, cols) in blocks {
                let cols = *cols;
                fused.extend_from_slice(&data[r * cols..(r + 1) * cols]);
            }
        }
        fused
    }

    let qcfg = self.quant_config.as_ref().ok_or_else(|| {
        FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
    })?;
    if qcfg.method != QuantMethod::Gptq {
        return Err(FerrumError::model(format!(
            "GPTQ fusion but quant_method={:?}",
            qcfg.method
        )));
    }
    // Guard before dividing: a config without a usable `bits` must not panic here.
    let bits = qcfg.bits as usize;
    if bits == 0 || bits > 32 {
        return Err(FerrumError::model(format!(
            "GPTQ fusion: unsupported bit width {bits}"
        )));
    }

    let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(parts.len());
    let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(parts.len());
    let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(parts.len());
    // Row counts (qweight, scales, qzeros) of the first part. Tracked as an Option so
    // that a zero-row first part cannot be mistaken for "not seen yet".
    let mut expected_rows: Option<(usize, usize, usize)> = None;
    let mut g_idx: Option<Vec<i32>> = None;

    for p in parts {
        let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
        let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
        let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
        if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
            return Err(FerrumError::model(format!(
                "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
            )));
        }
        // Same per-layer check the unfused GPTQ path applies.
        if sc_sh[1] != qw_sh[1] {
            return Err(FerrumError::model(format!(
                "'{p}.scales' {sc_sh:?} incompatible with qweight {qw_sh:?}"
            )));
        }

        let rows = (qw_sh[0], sc_sh[0], qz_sh[0]);
        match expected_rows {
            None => expected_rows = Some(rows),
            Some(first) if first != rows => {
                return Err(FerrumError::model(format!(
                    "GPTQ fusion row mismatch on '{p}'"
                )));
            }
            Some(_) => {}
        }

        qw_parts.push((qw, qw_sh[1]));
        sc_parts.push((sc, sc_sh[1]));
        qz_parts.push((qz, qz_sh[1]));

        if g_idx.is_none() && self.has(&format!("{p}.g_idx")) {
            g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
        }
    }

    let (qw_rows, sc_rows, qz_rows) = expected_rows.unwrap_or((0, 0, 0));
    let qw_acc = hcat(&qw_parts, qw_rows);
    let sc_acc = hcat(&sc_parts, sc_rows);
    let qz_acc = hcat(&qz_parts, qz_rows);

    // Each i32 row of qweight holds 32 bits' worth of packed input features.
    let packed_bits = qw_rows * 32;
    if packed_bits % bits != 0 {
        return Err(FerrumError::model(format!(
            "GPTQ fusion: {qw_rows} qweight rows are not a whole number of {bits}-bit values"
        )));
    }
    let in_features = packed_bits / bits;
    let out_features: usize = qw_parts.iter().map(|(_, cols)| *cols).sum();

    let mut linear = GptqLinear::<B>::from_raw(
        &qw_acc,
        &sc_acc,
        &qz_acc,
        g_idx.as_deref(),
        qcfg.bits,
        qcfg.group_size,
        in_features,
        out_features,
    )?;

    let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
    let any = bias_keys.iter().any(|k| self.has(k));
    let all = bias_keys.iter().all(|k| self.has(k));
    if any && !all {
        return Err(FerrumError::model(
            "GPTQ fusion: inconsistent bias presence across parts".to_string(),
        ));
    }
    if all {
        let mut fused: Vec<f32> = Vec::with_capacity(out_features);
        for k in &bias_keys {
            let (b, _) = self.read_f32(k)?;
            fused.extend_from_slice(&b);
        }
        if fused.len() != out_features {
            return Err(FerrumError::model(format!(
                "GPTQ fusion bias length {} != out_features {out_features}",
                fused.len()
            )));
        }
        linear = linear.with_bias(&fused);
    }
    Ok(Box::new(linear))
}
/// Concatenates the named 2D tensors row-wise and returns `(rows, cols, data)`.
///
/// All tensors must share the same column count; `rows` is the sum of their row
/// counts and `data` is their `f32` contents appended in `names` order. An empty
/// `names` slice yields `(0, 0, [])`.
///
/// # Errors
/// Returns a model error when a tensor is not 2D or its column count differs from
/// the first tensor's, and propagates failures from `read_f32`.
fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
    let mut total_rows = 0usize;
    // Column count of the first tensor. Tracked as an Option so a zero-width first
    // tensor still takes part in the mismatch check instead of looking "unset".
    let mut width: Option<usize> = None;
    let mut out: Vec<f32> = Vec::new();
    for n in names {
        let (data, shape) = self.read_f32(n)?;
        if shape.len() != 2 {
            return Err(FerrumError::model(format!(
                "cat_rows: '{n}' is {shape:?}, need 2D"
            )));
        }
        match width {
            None => width = Some(shape[1]),
            Some(cols) if cols != shape[1] => {
                return Err(FerrumError::model(format!(
                    "cat_rows: col mismatch {cols} vs {}",
                    shape[1]
                )));
            }
            Some(_) => {}
        }
        total_rows += shape[0];
        out.extend_from_slice(&data);
    }
    Ok((total_rows, width.unwrap_or(0), out))
}
}
/// Decodes little-endian tensor bytes of type `dtype` into a vector of `f32`.
///
/// `F32` is read directly; `F16` and `BF16` are widened to `f32`. Trailing bytes
/// that do not fill a whole element are ignored in release builds and trip a
/// `debug_assert` in debug builds.
///
/// # Errors
/// Returns a model error for any other dtype.
fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
    let decoded: Vec<f32> = match dtype {
        Dtype::F32 => {
            debug_assert_eq!(raw.len() % 4, 0);
            raw.chunks_exact(4)
                .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
                .collect()
        }
        Dtype::F16 => {
            debug_assert_eq!(raw.len() % 2, 0);
            raw.chunks_exact(2)
                .map(|h| f16::from_le_bytes([h[0], h[1]]).to_f32())
                .collect()
        }
        Dtype::BF16 => {
            debug_assert_eq!(raw.len() % 2, 0);
            raw.chunks_exact(2)
                .map(|h| bf16::from_le_bytes([h[0], h[1]]).to_f32())
                .collect()
        }
        other => {
            return Err(FerrumError::model(format!(
                "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
                 use a format-specific loader (GPTQ / AWQ / GGUF)",
            )))
        }
    };
    Ok(decoded)
}
/// Looks for quantization settings in `dir`.
///
/// `quantize_config.json` is preferred and deserialized straight into `QuantConfig`.
/// Otherwise the `quantization_config` object of `config.json` is consulted, with
/// these fallbacks for missing fields: `quant_method` "none", `bits` 0,
/// `group_size` 128 (negative values are clamped to 0), `desc_act` false, `sym` false.
/// Returns `Ok(None)` when neither source yields a recognised quantization method.
///
/// # Errors
/// Returns an I/O error when an existing file cannot be read and a serialization
/// error when it cannot be parsed.
fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
    let p = dir.join("quantize_config.json");
    if p.exists() {
        let text = std::fs::read_to_string(&p)
            .map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
        return match serde_json::from_str::<QuantConfig>(&text) {
            Ok(parsed) => Ok(Some(parsed)),
            Err(e) => Err(FerrumError::serialization(format!(
                "parse quantize_config.json: {e}"
            ))),
        };
    }

    let cfg = dir.join("config.json");
    if !cfg.exists() {
        return Ok(None);
    }
    let text = std::fs::read_to_string(&cfg)
        .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
    let root: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
    let section = match root.get("quantization_config") {
        Some(value) => value,
        None => return Ok(None),
    };

    let method_name = section
        .get("quant_method")
        .and_then(|v| v.as_str())
        .unwrap_or("none")
        .to_lowercase();
    let method = match method_name.as_str() {
        "gptq" => QuantMethod::Gptq,
        "awq" => QuantMethod::Awq,
        "gguf" => QuantMethod::Gguf,
        _ => QuantMethod::None,
    };
    if method == QuantMethod::None {
        return Ok(None);
    }

    // Boolean fields share the same "absent or non-bool means false" rule.
    let flag = |key: &str| section.get(key).and_then(|v| v.as_bool()).unwrap_or(false);
    let bits = section.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
    let group_size = section
        .get("group_size")
        .and_then(|v| v.as_i64())
        .unwrap_or(128)
        .max(0) as usize;

    Ok(Some(QuantConfig {
        method,
        bits,
        group_size,
        desc_act: flag("desc_act"),
        sym: flag("sym"),
    }))
}