extern crate alloc;
use alloc::vec::Vec;
/// Borrowed view of one GPTQ-quantized 4-bit linear layer, as consumed by
/// [`repack`].
///
/// The struct only holds references and plain integers, so it is cheap to
/// copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GptqWeights<'a> {
    /// Packed 4-bit weights. `repack` requires `(ic / 8) * oc` words and reads
    /// them as a row-major `[ic / 8, oc]` grid where each word carries eight
    /// consecutive input-channel values, lowest nibble first.
    pub qweight: &'a [i32],
    /// Per-group scales. `repack` requires `(ic / group_size) * oc` entries
    /// and copies them to the output unchanged.
    pub scales: &'a [f32],
    /// Packed 4-bit zero-points. `repack` requires
    /// `(ic / group_size) * (oc / 8)` words, each carrying eight consecutive
    /// output-channel values, lowest nibble first.
    pub qzeros: &'a [i32],
    /// Optional per-input-channel group index. When present, `repack`
    /// requires exactly `ic` entries in non-decreasing order.
    pub g_idx: Option<&'a [i32]>,
    /// Number of input channels (must be a positive multiple of 16 for
    /// `repack`).
    pub ic: usize,
    /// Number of output channels (must be a positive multiple of 8 for
    /// `repack`).
    pub oc: usize,
    /// Number of input channels sharing one scale / zero-point.
    pub group_size: usize,
}
/// Owned result of [`repack`].
///
/// `Clone` and `PartialEq` are derived so that outputs can be duplicated and
/// compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct MarlinWeights {
    /// Re-packed 4-bit weights; `repack` emits `(ic / 16) * (oc * 2)` words.
    pub weight_packed: Vec<i32>,
    /// Scales, copied verbatim from the GPTQ input by `repack`.
    pub scales: Vec<f32>,
}
pub fn repack(g: &GptqWeights<'_>) -> Result<MarlinWeights, &'static str> {
if g.ic == 0 || g.oc == 0 {
return Err("gptq_to_marlin::repack: IC and OC must be positive");
}
if g.group_size == 0 || (g.group_size != 128 && g.group_size as i32 != -1i32) {
return Err(
"gptq_to_marlin::repack: group_size must be 128 (per-group) — \
per-channel (g=-1) support requires a separate _scale_perm_single \
table not yet wired",
);
}
if g.ic % 16 != 0 {
return Err("gptq_to_marlin::repack: IC must be divisible by 16");
}
if g.oc % 8 != 0 {
return Err("gptq_to_marlin::repack: OC must be divisible by 8");
}
if g.ic % g.group_size != 0 {
return Err("gptq_to_marlin::repack: IC must be divisible by group_size");
}
let num_groups = g.ic / g.group_size;
let expected_qweight_len = (g.ic / 8) * g.oc;
if g.qweight.len() != expected_qweight_len {
return Err("gptq_to_marlin::repack: qweight length != (IC/8) * OC");
}
if g.scales.len() != num_groups * g.oc {
return Err("gptq_to_marlin::repack: scales length != num_groups * OC");
}
if g.qzeros.len() != num_groups * (g.oc / 8) {
return Err("gptq_to_marlin::repack: qzeros length != num_groups * (OC/8)");
}
let mut weight_dense = alloc::vec![0u8; g.ic * g.oc];
for ic_byte in 0..(g.ic / 8) {
for oc in 0..g.oc {
let word = g.qweight[ic_byte * g.oc + oc] as u32;
for nib in 0..8usize {
let q = ((word >> (4 * nib)) & 0xF) as u8;
let ic_pos = ic_byte * 8 + nib;
weight_dense[ic_pos * g.oc + oc] = q;
}
}
}
let mut zeros_dense = alloc::vec![0u8; num_groups * g.oc];
for grp in 0..num_groups {
for oc_byte in 0..(g.oc / 8) {
let word = g.qzeros[grp * (g.oc / 8) + oc_byte] as u32;
for nib in 0..8usize {
let z = ((word >> (4 * nib)) & 0xF) as u8;
zeros_dense[grp * g.oc + oc_byte * 8 + nib] = z;
}
}
}
if let Some(idx) = g.g_idx {
if idx.len() != g.ic {
return Err("gptq_to_marlin::repack: g_idx length != IC");
}
let is_monotonic = idx.windows(2).all(|w| w[0] <= w[1]);
if !is_monotonic {
return Err(
"gptq_to_marlin::repack: act_order=True (non-monotonic g_idx) \
not yet implemented — re-quantize the GPTQ checkpoint with \
desc_act=False or wait for a Phase 48 follow-up",
);
}
}
for ic in 0..g.ic {
let grp = ic / g.group_size;
for oc in 0..g.oc {
let q = weight_dense[ic * g.oc + oc] as i32;
let zp = zeros_dense[grp * g.oc + oc] as i32;
let q_marlin = (q - zp + 8).clamp(0, 15) as u8;
weight_dense[ic * g.oc + oc] = q_marlin;
}
}
let k_blocks = g.ic / 16;
let n_words = g.oc * 16 / 8;
let mut packed = alloc::vec![0i32; k_blocks * n_words];
for kb in 0..k_blocks {
for oc in 0..g.oc {
for half in 0..2usize {
let mut word: u32 = 0;
for nib in 0..8usize {
let ic = kb * 16 + half * 8 + nib;
let q = weight_dense[ic * g.oc + oc] as u32;
word |= (q & 0xF) << (4 * nib);
}
let word_idx = kb * n_words + oc * 2 + half;
packed[word_idx] = word as i32;
}
}
}
let scales_out: Vec<f32> = g.scales.to_vec();
Ok(MarlinWeights {
weight_packed: packed,
scales: scales_out,
})
}
/// Length (64) presumably describing a Marlin weight permutation table.
/// NOTE(review): nothing visible in this file reads this constant — confirm
/// against its users.
pub const MARLIN_PERM_LEN: usize = 64;
/// Length (64) presumably describing a Marlin scale permutation table.
/// NOTE(review): nothing visible in this file reads this constant — confirm
/// against its users.
pub const MARLIN_SCALE_PERM_LEN: usize = 64;
#[cfg(test)]
mod tests {
    use super::*;

    /// A group size other than 128 must be rejected by the group-size check
    /// itself, not by an earlier guard.
    #[test]
    fn rejects_unsupported_group_size() {
        // Shapes are consistent for IC=128, OC=8, group_size=64 (2 groups),
        // so the group size is the only thing wrong with this input.
        let qweight = alloc::vec![0i32; 16 * 8];
        let scales = alloc::vec![1.0f32; 2 * 8];
        let qzeros = alloc::vec![0i32; 2];
        let g = GptqWeights {
            qweight: &qweight,
            scales: &scales,
            qzeros: &qzeros,
            g_idx: None,
            ic: 128,
            oc: 8,
            group_size: 64,
        };
        let err = repack(&g).unwrap_err();
        assert!(
            err.contains("group_size must be 128"),
            "unexpected error: {}",
            err
        );
    }

    /// A `qweight` slice shorter than `(IC/8) * OC` is reported as such.
    #[test]
    fn rejects_shape_mismatch() {
        let g = GptqWeights {
            qweight: &[0i32; 4],
            scales: &[1.0f32; 8],
            qzeros: &[0i32; 1],
            g_idx: None,
            ic: 128,
            oc: 8,
            group_size: 128,
        };
        let err = repack(&g).unwrap_err();
        assert!(err.contains("qweight length"), "unexpected error: {}", err);
    }

    /// A single-group layer with consistent shapes repacks successfully.
    #[test]
    fn accepts_minimal_shape() {
        let qweight = alloc::vec![0i32; 4096];
        let scales = alloc::vec![1.0f32; 256];
        let qzeros = alloc::vec![0x77777777i32; 32];
        let g = GptqWeights {
            qweight: &qweight,
            scales: &scales,
            qzeros: &qzeros,
            g_idx: None,
            ic: 128,
            oc: 256,
            group_size: 128,
        };
        let m = repack(&g).expect("repack should succeed");
        assert_eq!(m.weight_packed.len(), 4096);
        assert_eq!(m.scales.len(), 256);
        // Every weight nibble is 0 and every zero-point nibble is 7, so each
        // output nibble is 0 - 7 + 8 = 1.
        assert!(m.weight_packed.iter().all(|&w| w == 0x1111_1111));
    }
}