use ferrum_types::{FerrumError, Result};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, serde::Deserialize)]
pub struct QuantizeConfig {
pub bits: usize,
pub group_size: i64,
#[serde(default)]
pub sym: bool,
#[serde(default)]
pub desc_act: bool,
#[serde(default)]
pub quant_method: String,
}
impl QuantizeConfig {
pub fn from_model_dir(model_dir: &Path) -> Result<Option<Self>> {
let path = model_dir.join("quantize_config.json");
if !path.exists() {
let config_path = model_dir.join("config.json");
if config_path.exists() {
if let Ok(content) = std::fs::read_to_string(&config_path) {
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
if let Some(qc) = config.get("quantization_config") {
if let Ok(qconfig) =
serde_json::from_value::<QuantizeConfig>(qc.clone())
{
tracing::info!("GPTQ config found in config.json: {:?}", qconfig);
return Ok(Some(qconfig));
}
}
}
}
}
return Ok(None);
}
let content = std::fs::read_to_string(&path)
.map_err(|e| FerrumError::model(format!("read quantize_config.json: {e}")))?;
let config: QuantizeConfig = serde_json::from_str(&content)
.map_err(|e| FerrumError::model(format!("parse quantize_config.json: {e}")))?;
tracing::info!("GPTQ config: {:?}", config);
Ok(Some(config))
}
pub fn effective_group_size(&self, k: usize) -> usize {
if self.group_size <= 0 {
k } else {
self.group_size as usize
}
}
}
#[derive(Debug)]
pub struct GptqLayerWeights {
pub qweight: Vec<i32>,
pub scales: Vec<half::f16>,
pub qzeros: Option<Vec<i32>>,
pub k: usize,
pub n: usize,
pub group_size: usize,
pub symmetric: bool,
}
impl GptqLayerWeights {
pub fn dequantize_cpu(&self) -> Vec<half::f16> {
let mut output = vec![half::f16::ZERO; self.k * self.n];
let packed_rows = self.k / 8;
for packed_row in 0..packed_rows {
for col in 0..self.n {
let packed = self.qweight[packed_row * self.n + col];
let base_k = packed_row * 8;
let group = base_k / self.group_size;
let scale = self.scales[group * self.n + col].to_f32();
let zero = if self.symmetric {
8
} else if let Some(ref qz) = self.qzeros {
let zp_packed = qz[group * (self.n / 8) + col / 8];
let zp_shift = (col % 8) * 4;
(zp_packed >> zp_shift) & 0xF
} else {
8
};
for i in 0..8 {
let val = (packed >> (i * 4)) & 0xF;
let dequantized = (val - zero) as f32 * scale;
output[(base_k + i as usize) * self.n + col] = half::f16::from_f32(dequantized);
}
}
}
output
}
}
pub fn load_gptq_weights(
model_dir: &Path,
qconfig: &QuantizeConfig,
) -> Result<HashMap<String, GptqLayerWeights>> {
use safetensors::SafeTensors;
let safetensor_files = find_safetensor_files(model_dir)?;
if safetensor_files.is_empty() {
return Err(FerrumError::model("No safetensor files found"));
}
let mut result = HashMap::new();
for path in &safetensor_files {
let data = std::fs::read(path)
.map_err(|e| FerrumError::model(format!("read {}: {e}", path.display())))?;
let st = SafeTensors::deserialize(&data)
.map_err(|e| FerrumError::model(format!("parse {}: {e}", path.display())))?;
for (name, _) in st.tensors() {
if !name.ends_with(".qweight") {
continue;
}
let prefix = name.strip_suffix(".qweight").unwrap().to_string();
let qw_tensor = st
.tensor(&format!("{prefix}.qweight"))
.map_err(|e| FerrumError::model(format!("{prefix}.qweight: {e}")))?;
let qweight: Vec<i32> = bytemuck::cast_slice(qw_tensor.data()).to_vec();
let qw_shape = qw_tensor.shape();
let packed_k = qw_shape[0]; let n = qw_shape[1];
let k = packed_k * 8;
let sc_tensor = st
.tensor(&format!("{prefix}.scales"))
.map_err(|e| FerrumError::model(format!("{prefix}.scales: {e}")))?;
let scales: Vec<half::f16> = bytemuck::cast_slice(sc_tensor.data()).to_vec();
let qzeros = if !qconfig.sym {
let qz_tensor = st
.tensor(&format!("{prefix}.qzeros"))
.map_err(|e| FerrumError::model(format!("{prefix}.qzeros: {e}")))?;
Some(bytemuck::cast_slice(qz_tensor.data()).to_vec())
} else {
None
};
let gs = qconfig.effective_group_size(k);
tracing::debug!(
"GPTQ layer: {prefix} K={k} N={n} group_size={gs} sym={}",
qconfig.sym
);
result.insert(
prefix,
GptqLayerWeights {
qweight,
scales,
qzeros,
k,
n,
group_size: gs,
symmetric: qconfig.sym,
},
);
}
}
tracing::info!("Loaded {} GPTQ quantized layers (raw)", result.len());
fuse_qkv_and_gate_up(&mut result);
tracing::info!(
"After fusion: {} GPTQ layers (includes fused qkv_proj, gate_up_proj)",
result.len()
);
Ok(result)
}
fn fuse_qkv_and_gate_up(weights: &mut HashMap<String, GptqLayerWeights>) {
let prefixes: Vec<String> = weights
.keys()
.filter(|k| k.ends_with(".self_attn.q_proj"))
.map(|k| k.strip_suffix(".self_attn.q_proj").unwrap().to_string())
.collect();
for layer_prefix in &prefixes {
let q_key = format!("{layer_prefix}.self_attn.q_proj");
let k_key = format!("{layer_prefix}.self_attn.k_proj");
let v_key = format!("{layer_prefix}.self_attn.v_proj");
if let (Some(q), Some(k), Some(v)) = (
weights.get(&q_key),
weights.get(&k_key),
weights.get(&v_key),
) {
if q.k == k.k && q.k == v.k {
let fused = fuse_columns(&[q, k, v]);
let fused_key = format!("{layer_prefix}.self_attn.qkv_proj");
tracing::info!(
"Fused {q_key}+{k_key}+{v_key} → {fused_key} K={} N={}",
fused.k,
fused.n
);
weights.insert(fused_key, fused);
}
}
let gate_key = format!("{layer_prefix}.mlp.gate_proj");
let up_key = format!("{layer_prefix}.mlp.up_proj");
if let (Some(gate), Some(up)) = (weights.get(&gate_key), weights.get(&up_key)) {
if gate.k == up.k {
let fused = fuse_columns(&[gate, up]);
let fused_key = format!("{layer_prefix}.mlp.gate_up_proj");
tracing::info!(
"Fused {gate_key}+{up_key} → {fused_key} K={} N={}",
fused.k,
fused.n
);
weights.insert(fused_key, fused);
}
}
}
}
fn fuse_columns(parts: &[&GptqLayerWeights]) -> GptqLayerWeights {
let k = parts[0].k;
let gs = parts[0].group_size;
let sym = parts[0].symmetric;
let total_n: usize = parts.iter().map(|p| p.n).sum();
let packed_k = k / 8;
let num_groups = k / gs;
let mut qweight = vec![0i32; packed_k * total_n];
let mut col_offset = 0;
for part in parts {
for row in 0..packed_k {
for col in 0..part.n {
qweight[row * total_n + col_offset + col] = part.qweight[row * part.n + col];
}
}
col_offset += part.n;
}
let mut scales = vec![half::f16::ZERO; num_groups * total_n];
col_offset = 0;
for part in parts {
for row in 0..num_groups {
for col in 0..part.n {
scales[row * total_n + col_offset + col] = part.scales[row * part.n + col];
}
}
col_offset += part.n;
}
let qzeros = if !sym {
let mut all_zeros = vec![0u8; num_groups * total_n];
let mut col_off = 0usize;
for part in parts {
if let Some(ref qz) = part.qzeros {
let part_n8 = part.n / 8;
for row in 0..num_groups {
for col in 0..part.n {
let packed = qz[row * part_n8 + col / 8];
let val = ((packed >> ((col % 8) * 4)) & 0xF) as u8;
all_zeros[row * total_n + col_off + col] = val;
}
}
}
col_off += part.n;
}
let total_n8 = total_n / 8;
let mut packed_zeros = vec![0i32; num_groups * total_n8];
for row in 0..num_groups {
for col in 0..total_n {
let val = all_zeros[row * total_n + col] as i32;
packed_zeros[row * total_n8 + col / 8] |= val << ((col % 8) * 4);
}
}
Some(packed_zeros)
} else {
None
};
GptqLayerWeights {
qweight,
scales,
qzeros,
k,
n: total_n,
group_size: gs,
symmetric: sym,
}
}
fn find_safetensor_files(model_dir: &Path) -> Result<Vec<PathBuf>> {
let mut files = Vec::new();
let single = model_dir.join("model.safetensors");
if single.exists() {
files.push(single);
return Ok(files);
}
let index_path = model_dir.join("model.safetensors.index.json");
if index_path.exists() {
let content = std::fs::read_to_string(&index_path)
.map_err(|e| FerrumError::model(format!("read index: {e}")))?;
let index: serde_json::Value = serde_json::from_str(&content)
.map_err(|e| FerrumError::model(format!("parse index: {e}")))?;
if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
let mut seen = std::collections::HashSet::new();
for filename in weight_map.values().filter_map(|v| v.as_str()) {
if seen.insert(filename.to_string()) {
let path = model_dir.join(filename);
if path.exists() {
files.push(path);
}
}
}
}
}
Ok(files)
}