pub mod dequant;
pub mod error;
pub mod reader;
pub mod role_map;
use std::collections::BTreeMap;
use std::io::BufWriter;
use std::path::Path;
use serde_json::Value;
use oxibonsai_core::gguf::writer::{GgufWriter, TensorEntry, TensorType};
use oxibonsai_core::quant_ternary::BlockTQ2_0_g128;
use oxionnx_proto::types::{NodeProto, TensorProto};
use crate::convert::common::{
blocks_to_bytes, pad_to_multiple_of_128, read_config_json, write_metadata, ConvertStats,
};
pub use self::error::{DequantError, OnnxImportError};
pub use self::role_map::OnnxRole;
pub fn convert_onnx_to_gguf(
onnx_path: &Path,
to_path: &Path,
quant: &str,
) -> Result<ConvertStats, OnnxImportError> {
if quant != "tq2_0_g128" {
return Err(OnnxImportError::Other(format!(
"unsupported quantisation format '{quant}'; only 'tq2_0_g128' is supported"
)));
}
let mut reader = reader::OnnxReader::open(onnx_path)?;
let config_path = reader::locate_config_json(onnx_path)?;
let config = read_config_json(&config_path).map_err(|e| {
OnnxImportError::Other(format!("reading {:?}: {e}", config_path))
})?;
let mut writer = GgufWriter::new();
let model_name = onnx_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
write_metadata(&mut writer, &config, model_name)
.map_err(|e| OnnxImportError::Other(format!("writing metadata: {e}")))?;
let tie_word_embeddings = config
.get("tie_word_embeddings")
.and_then(Value::as_bool)
.unwrap_or(false);
let num_hidden_layers = config
.get("num_hidden_layers")
.and_then(Value::as_u64)
.ok_or_else(|| {
OnnxImportError::Other(
"config.json is missing required field 'num_hidden_layers'".to_string(),
)
})? as usize;
let vocab_size = config
.get("vocab_size")
.and_then(Value::as_u64)
.ok_or_else(|| {
OnnxImportError::Other("config.json is missing required field 'vocab_size'".to_string())
})? as usize;
let hidden_size = config
.get("hidden_size")
.and_then(Value::as_u64)
.ok_or_else(|| {
OnnxImportError::Other(
"config.json is missing required field 'hidden_size'".to_string(),
)
})? as usize;
let mut gguf_entries: BTreeMap<String, PendingTensor> = BTreeMap::new();
let init_names: Vec<String> = reader
.model
.graph
.initializers
.iter()
.map(|t| t.name.clone())
.collect();
for name in &init_names {
let Some(role) = role_map::classify_initializer(name, num_hidden_layers) else {
continue;
};
match role {
OnnxRole::NormFp { gguf_name } => {
let (f32_data, shape_onnx) = read_fp_initializer(&mut reader, name)?;
let gguf_shape: Vec<u64> = shape_onnx.iter().rev().map(|&d| d as u64).collect();
gguf_entries.insert(
gguf_name.clone(),
PendingTensor {
gguf_name,
kind: TensorKind::Norm,
gguf_shape,
f32_data,
},
);
}
OnnxRole::EmbeddingFp => {
let (f32_data, shape_onnx) = read_fp_initializer(&mut reader, name)?;
let gguf_shape: Vec<u64> = shape_onnx.iter().rev().map(|&d| d as u64).collect();
gguf_entries.insert(
"token_embd.weight".to_string(),
PendingTensor {
gguf_name: "token_embd.weight".to_string(),
kind: TensorKind::Weight,
gguf_shape,
f32_data,
},
);
}
OnnxRole::LmHeadFp => {
let (f32_data, shape_onnx) = read_fp_initializer(&mut reader, name)?;
let gguf_shape: Vec<u64> = shape_onnx.iter().rev().map(|&d| d as u64).collect();
gguf_entries.insert(
"output.weight".to_string(),
PendingTensor {
gguf_name: "output.weight".to_string(),
kind: TensorKind::Weight,
gguf_shape,
f32_data,
},
);
}
OnnxRole::MatMulPacked { .. }
| OnnxRole::MatMulScales { .. }
| OnnxRole::MatMulZeroPoints { .. } => {
continue;
}
}
}
let matmul_snapshot: Vec<MatMulNbitsMeta> = reader
.model
.graph
.nodes
.iter()
.filter(|n| n.op_type == "MatMulNBits")
.map(collect_matmul_meta)
.collect::<Result<Vec<_>, _>>()?;
for meta in &matmul_snapshot {
let gguf_name = role_map::matmul_node_to_gguf(&meta.node_name)?;
let packed_tensor = resolve_matmul_input_tensor(&reader, &meta.packed_name)?;
let scales_tensor = resolve_matmul_input_tensor(&reader, &meta.scales_name)?;
let zp_tensor: Option<TensorProto> = match meta.zero_points_name.as_ref() {
Some(name) => Some(resolve_matmul_input_tensor(&reader, name)?),
None => None,
};
let packed_bytes: Vec<u8> = reader.initializer_bytes(&packed_tensor)?.to_vec();
let scales_bytes: Vec<u8> = reader.initializer_bytes(&scales_tensor)?.to_vec();
let zp_bytes_opt: Option<Vec<u8>> = if let Some(zp) = zp_tensor.as_ref() {
Some(reader.initializer_bytes(zp)?.to_vec())
} else {
None
};
let scales_f32 =
reader::bytes_to_f32(&scales_bytes, scales_tensor.data_type, &meta.scales_name)?;
let f32_row_major = dequant::dequantize_matmul_nbits(
&packed_bytes,
&scales_f32,
zp_bytes_opt.as_deref(),
meta.n,
meta.k,
meta.bits,
meta.block_size,
)
.map_err(|e| OnnxImportError::Dequant {
node: meta.node_name.clone(),
source: e,
})?;
let gguf_shape = vec![meta.k as u64, meta.n as u64];
gguf_entries.insert(
gguf_name.clone(),
PendingTensor {
gguf_name,
kind: TensorKind::Weight,
gguf_shape,
f32_data: f32_row_major,
},
);
}
let mut token_embd_emitted = false;
if let Some(quant_tensor) = reader
.find_initializer("model_embed_tokens_weight_quant")
.cloned()
{
let scales_tensor = reader
.find_initializer("model_embed_tokens_weight_scales")
.cloned()
.ok_or_else(|| OnnxImportError::MissingNamedInitializer {
name: "model_embed_tokens_weight_scales".to_string(),
})?;
let zp_tensor = reader
.find_initializer("model_embed_tokens_weight_zp_4b")
.cloned()
.ok_or_else(|| OnnxImportError::MissingNamedInitializer {
name: "model_embed_tokens_weight_zp_4b".to_string(),
})?;
let n = vocab_size;
let k = hidden_size;
let block_size = dequant::EXPECTED_BLOCK_SIZE; let n_blocks = k.div_ceil(block_size);
let expected_quant_dims = [n as i64, (k / 4) as i64];
let expected_scales_dims = [n as i64, n_blocks as i64];
let expected_zp_dims = [n as i64, (n_blocks / 2) as i64];
if quant_tensor.dims.as_slice() != expected_quant_dims {
return Err(OnnxImportError::Other(format!(
"GBQ embed 'model_embed_tokens_weight_quant' has dims {:?}, expected {:?} \
(N=vocab_size={}, K/4={})",
quant_tensor.dims,
expected_quant_dims,
n,
k / 4
)));
}
if scales_tensor.dims.as_slice() != expected_scales_dims {
return Err(OnnxImportError::Other(format!(
"GBQ embed 'model_embed_tokens_weight_scales' has dims {:?}, expected {:?} \
(N=vocab_size={}, n_blocks={})",
scales_tensor.dims, expected_scales_dims, n, n_blocks
)));
}
if zp_tensor.dims.as_slice() != expected_zp_dims {
return Err(OnnxImportError::Other(format!(
"GBQ embed 'model_embed_tokens_weight_zp_4b' has dims {:?}, expected {:?} \
(N=vocab_size={}, n_blocks/2={})",
zp_tensor.dims,
expected_zp_dims,
n,
n_blocks / 2
)));
}
let quant_bytes: Vec<u8> = reader.initializer_bytes(&quant_tensor)?.to_vec();
let scales_bytes: Vec<u8> = reader.initializer_bytes(&scales_tensor)?.to_vec();
let zp_bytes: Vec<u8> = reader.initializer_bytes(&zp_tensor)?.to_vec();
let scales_f32 =
reader::bytes_to_f32(&scales_bytes, scales_tensor.data_type, &scales_tensor.name)?;
let zp_repacked =
dequant::repack_4bit_zp_to_2bit(&zp_bytes, n * n_blocks).map_err(|e| {
OnnxImportError::Dequant {
node: "GatherBlockQuantized(embed_tokens)".to_string(),
source: e,
}
})?;
let f32_row_major = dequant::dequantize_matmul_nbits(
&quant_bytes,
&scales_f32,
Some(&zp_repacked),
n,
k,
2,
block_size,
)
.map_err(|e| OnnxImportError::Dequant {
node: "GatherBlockQuantized(embed_tokens)".to_string(),
source: e,
})?;
let gguf_shape = vec![k as u64, n as u64];
gguf_entries.insert(
"token_embd.weight".to_string(),
PendingTensor {
gguf_name: "token_embd.weight".to_string(),
kind: TensorKind::Weight,
gguf_shape,
f32_data: f32_row_major,
},
);
token_embd_emitted = true;
tracing::info!(
"GBQ embed detected: N={}, K={}, emitted token_embd.weight via 2-bit re-pack (bits=4 attribute overridden)",
n,
k
);
}
if tie_word_embeddings && !token_embd_emitted {
match (
gguf_entries.contains_key("token_embd.weight"),
gguf_entries.contains_key("output.weight"),
) {
(false, true) => {
if let Some(source) = gguf_entries.get("output.weight") {
let cloned = PendingTensor {
gguf_name: "token_embd.weight".to_string(),
kind: source.kind,
gguf_shape: source.gguf_shape.clone(),
f32_data: source.f32_data.clone(),
};
tracing::info!(
"tie_word_embeddings=true: duplicating output.weight as token_embd.weight"
);
gguf_entries.insert("token_embd.weight".to_string(), cloned);
}
}
(true, false) => {
if let Some(source) = gguf_entries.get("token_embd.weight") {
let cloned = PendingTensor {
gguf_name: "output.weight".to_string(),
kind: source.kind,
gguf_shape: source.gguf_shape.clone(),
f32_data: source.f32_data.clone(),
};
tracing::info!(
"tie_word_embeddings=true: duplicating token_embd.weight as output.weight"
);
gguf_entries.insert("output.weight".to_string(), cloned);
}
}
_ => {}
}
}
let mut stats = ConvertStats::default();
for pending in gguf_entries.values() {
let (raw_bytes, tensor_type) = match pending.kind {
TensorKind::Norm => {
let raw: Vec<u8> = pending
.f32_data
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
(raw, TensorType::F32)
}
TensorKind::Weight => {
let padded = pad_to_multiple_of_128(&pending.f32_data);
let blocks = BlockTQ2_0_g128::quantize(&padded).map_err(|e| {
OnnxImportError::Requantize {
tensor: pending.gguf_name.clone(),
msg: format!("{e}"),
}
})?;
let raw = blocks_to_bytes(&blocks);
(raw, TensorType::TQ2_0_g128)
}
};
println!(
" converting {} {:?} -> {}",
pending.gguf_name,
pending.gguf_shape,
match pending.kind {
TensorKind::Norm => "F32",
TensorKind::Weight => "TQ2_0_g128",
}
);
writer.add_tensor(TensorEntry {
name: pending.gguf_name.clone(),
shape: pending.gguf_shape.clone(),
tensor_type,
data: raw_bytes,
});
match pending.kind {
TensorKind::Norm => stats.n_fp32 += 1,
TensorKind::Weight => stats.n_ternary += 1,
}
stats.n_tensors += 1;
}
let out_file = std::fs::File::create(to_path).map_err(|e| OnnxImportError::Io {
path: to_path.to_path_buf(),
source: e,
})?;
let mut buf_writer = BufWriter::new(out_file);
let bytes_written = writer
.write(&mut buf_writer)
.map_err(|e| OnnxImportError::GgufWrite(format!("{e}")))?;
stats.output_bytes = bytes_written;
Ok(stats)
}
/// How a pending tensor is encoded when `convert_onnx_to_gguf` writes it out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TensorKind {
/// Written as little-endian `F32` bytes and counted in `n_fp32`.
Norm,
/// Padded to a multiple of 128 values, quantised to `TQ2_0_g128`, and
/// counted in `n_ternary`.
Weight,
}
/// A dequantised tensor held in memory until it is encoded and handed to the
/// GGUF writer.
struct PendingTensor {
/// Tensor name used in the GGUF output.
gguf_name: String,
/// Selects the output encoding (see `TensorKind`).
kind: TensorKind,
/// Shape recorded in the GGUF entry. FP initializers use their ONNX dims
/// reversed; dequantised MatMul weights use `[K, N]`.
gguf_shape: Vec<u64>,
/// Tensor contents as `f32` values.
f32_data: Vec<f32>,
}
/// Loads the initializer called `name` and returns its contents converted to
/// `f32`, together with a copy of its dims.
///
/// # Errors
///
/// Returns [`OnnxImportError::MissingNamedInitializer`] when no initializer
/// has that name, and propagates any error from reading the tensor bytes or
/// converting them to `f32`.
fn read_fp_initializer(
    reader: &mut reader::OnnxReader,
    name: &str,
) -> Result<(Vec<f32>, Vec<i64>), OnnxImportError> {
    let Some(found) = reader.find_initializer(name) else {
        return Err(OnnxImportError::MissingNamedInitializer {
            name: name.to_string(),
        });
    };
    // Take an owned copy so the reader can be borrowed again for the bytes.
    let proto = found.clone();

    let raw = reader.initializer_bytes(&proto)?.to_vec();
    let values = reader::bytes_to_f32(&raw, proto.data_type, &proto.name)?;
    let dims = proto.dims.clone();
    Ok((values, dims))
}
/// Input names and attributes extracted from one `MatMulNBits` node by
/// `collect_matmul_meta`.
struct MatMulNbitsMeta {
/// The node's name, or `"<anon>"` when the node has an empty name.
node_name: String,
/// Name of the node's input at index 1.
packed_name: String,
/// Name of the node's input at index 2.
scales_name: String,
/// Name of the node's input at index 3, if present and non-empty.
zero_points_name: Option<String>,
/// Value of the node's `N` attribute.
n: usize,
/// Value of the node's `K` attribute.
k: usize,
/// Value of the node's `bits` attribute.
bits: u32,
/// Value of the node's `block_size` attribute.
block_size: usize,
}
/// Extracts the input names and quantisation attributes of a `MatMulNBits`
/// node.
///
/// Inputs 1 and 2 are required; input 3 is optional and an empty name is
/// treated as absent. The attributes `bits`, `block_size`, `N` and `K` must
/// all be present, strictly positive, and representable in the target
/// integer types.
///
/// # Errors
///
/// * [`OnnxImportError::MissingInitializer`] when input 1 or 2 is absent.
/// * [`OnnxImportError::MissingAttribute`] when a required attribute is absent.
/// * [`OnnxImportError::Other`] when an attribute is non-positive or too
///   large for its target type.
fn collect_matmul_meta(node: &NodeProto) -> Result<MatMulNbitsMeta, OnnxImportError> {
    let node_name = if node.name.is_empty() {
        "<anon>".to_string()
    } else {
        node.name.clone()
    };

    let packed_name = node
        .inputs
        .get(1)
        .ok_or_else(|| OnnxImportError::MissingInitializer {
            node: node_name.clone(),
            index: 1,
            name: "<missing>".to_string(),
        })?
        .clone();
    let scales_name = node
        .inputs
        .get(2)
        .ok_or_else(|| OnnxImportError::MissingInitializer {
            node: node_name.clone(),
            index: 2,
            name: "<missing>".to_string(),
        })?
        .clone();
    let zero_points_name = node.inputs.get(3).cloned().filter(|s| !s.is_empty());

    // Build the error lazily so the node name is cloned only on failure.
    let required_attr = |attr: &'static str| {
        reader::attr_int(&node.attributes, attr).ok_or_else(|| {
            OnnxImportError::MissingAttribute {
                node: node_name.clone(),
                attr,
            }
        })
    };
    let bits_i = required_attr("bits")?;
    let block_size_i = required_attr("block_size")?;
    let n_i = required_attr("N")?;
    let k_i = required_attr("K")?;

    if bits_i <= 0 || block_size_i <= 0 || n_i <= 0 || k_i <= 0 {
        return Err(OnnxImportError::Other(format!(
            "MatMulNBits node '{node_name}' has non-positive attribute(s): bits={bits_i} block_size={block_size_i} N={n_i} K={k_i}"
        )));
    }

    // Checked narrowing: a plain `as` cast would silently wrap an oversized
    // attribute (e.g. `bits` above `u32::MAX`) into a plausible small value.
    let out_of_range = |attr: &str| {
        OnnxImportError::Other(format!(
            "MatMulNBits node '{node_name}' attribute '{attr}' is out of range: bits={bits_i} block_size={block_size_i} N={n_i} K={k_i}"
        ))
    };
    let n = usize::try_from(n_i).map_err(|_| out_of_range("N"))?;
    let k = usize::try_from(k_i).map_err(|_| out_of_range("K"))?;
    let bits = u32::try_from(bits_i).map_err(|_| out_of_range("bits"))?;
    let block_size = usize::try_from(block_size_i).map_err(|_| out_of_range("block_size"))?;

    Ok(MatMulNbitsMeta {
        node_name,
        packed_name,
        scales_name,
        zero_points_name,
        n,
        k,
        bits,
        block_size,
    })
}
/// Finds the initializer tensor that backs the `MatMulNBits` input `name`.
///
/// Resolution order:
/// 1. An initializer named `name` is returned directly.
/// 2. Otherwise the first node listing `name` among its outputs is looked up.
///    If it is a `Reshape` whose first input is an initializer, that
///    initializer is returned.
///
/// # Errors
///
/// * [`OnnxImportError::Other`] when a producing node exists but does not
///   lead to an initializer as described above.
/// * [`OnnxImportError::MissingNamedInitializer`] when `name` is neither an
///   initializer nor the output of any node.
fn resolve_matmul_input_tensor(
    reader: &reader::OnnxReader,
    name: &str,
) -> Result<TensorProto, OnnxImportError> {
    if let Some(direct) = reader.find_initializer(name) {
        return Ok(direct.clone());
    }

    let Some(producer) = reader
        .model
        .graph
        .nodes
        .iter()
        .find(|candidate| candidate.outputs.iter().any(|out| out == name))
    else {
        return Err(OnnxImportError::MissingNamedInitializer {
            name: name.to_string(),
        });
    };

    // Only a Reshape is looked through, and only via its first input.
    let behind_reshape = match producer.op_type.as_str() {
        "Reshape" => producer
            .inputs
            .first()
            .and_then(|src| reader.find_initializer(src)),
        _ => None,
    };

    match behind_reshape {
        Some(tensor) => Ok(tensor.clone()),
        None => Err(OnnxImportError::Other(format!(
            "MatMulNBits input '{name}' is produced by node '{}' (op '{}') whose inputs are not a resolvable initializer",
            producer.name, producer.op_type
        ))),
    }
}