pub mod name_map;
use std::collections::{BTreeMap, HashMap};
use std::io::BufWriter;
use std::path::{Path, PathBuf};
use anyhow::Context;
use safetensors::{Dtype, SafeTensors};
use serde_json::Value;
use oxibonsai_core::gguf::tensor_info::keys;
use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue, TensorEntry, TensorType};
use oxibonsai_core::quant_ternary::{BlockTQ2_0_g128, BLOCK_TQ2_0_G128_BYTES};
use crate::convert::name_map::hf_to_gguf_name;
/// Summary counters returned by a successful HF → GGUF conversion.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConvertStats {
    /// Total number of tensors written to the GGUF file.
    pub n_tensors: usize,
    /// Number of tensors stored as `TQ2_0_g128` ternary blocks.
    pub n_ternary: usize,
    /// Number of tensors stored as plain `F32` (those flagged as norm weights).
    pub n_fp32: usize,
    /// Byte count reported by the GGUF writer for the finished file.
    pub output_bytes: usize,
}
pub fn convert_hf_to_gguf(
from_dir: &Path,
to_path: &Path,
quant: &str,
) -> anyhow::Result<ConvertStats> {
if quant != "tq2_0_g128" {
anyhow::bail!(
"unsupported quantisation format '{}'; only 'tq2_0_g128' is supported",
quant
);
}
let config = read_config_json(from_dir)?;
let shard_paths = discover_shard_paths(from_dir)?;
let shard_bytes_list: Vec<Vec<u8>> = shard_paths
.iter()
.map(|p| std::fs::read(p).with_context(|| format!("reading shard {:?}", p)))
.collect::<anyhow::Result<_>>()?;
let parsed_shards: Vec<SafeTensors<'_>> = shard_bytes_list
.iter()
.enumerate()
.map(|(i, bytes)| {
SafeTensors::deserialize(bytes)
.with_context(|| format!("parsing shard {:?}", shard_paths[i]))
})
.collect::<anyhow::Result<_>>()?;
let mut name_to_shard: HashMap<&str, usize> = HashMap::new();
for (shard_idx, shard) in parsed_shards.iter().enumerate() {
for name in shard.names() {
name_to_shard.insert(name, shard_idx);
}
}
let mut writer = GgufWriter::new();
write_metadata(&mut writer, &config, from_dir)?;
let tie_word_embeddings = config
.get("tie_word_embeddings")
.and_then(Value::as_bool)
.unwrap_or(false);
let mut gguf_entries: BTreeMap<String, TensorEntryPending> = BTreeMap::new();
let mut embed_tokens_f32: Option<Vec<f32>> = None;
for (hf_name, &shard_idx) in &name_to_shard {
let mapped = match hf_to_gguf_name(hf_name) {
Some(m) => m,
None => {
tracing::debug!(hf_name, "skipping unmapped tensor");
continue;
}
};
let shard = &parsed_shards[shard_idx];
let view = shard
.tensor(hf_name)
.with_context(|| format!("tensor '{}' not found in shard", hf_name))?;
let f32_data = to_f32_vec(view.dtype(), view.data());
if f32_data.is_empty() && !view.data().is_empty() {
tracing::warn!(hf_name, dtype = ?view.dtype(), "unsupported dtype — skipping tensor");
continue;
}
if mapped.gguf_name == "token_embd.weight" && tie_word_embeddings {
embed_tokens_f32 = Some(f32_data.clone());
}
let shape_hf = view.shape();
let gguf_shape: Vec<u64> = shape_hf.iter().rev().map(|&d| d as u64).collect();
gguf_entries.insert(
mapped.gguf_name.clone(),
TensorEntryPending {
gguf_name: mapped.gguf_name,
is_norm: mapped.is_norm,
gguf_shape,
f32_data,
},
);
}
if tie_word_embeddings && !gguf_entries.contains_key("output.weight") {
if let Some(embed_f32) = embed_tokens_f32 {
let embed_entry = gguf_entries
.get("token_embd.weight")
.with_context(|| "tie_word_embeddings=true but token_embd.weight not found")?;
let shape = embed_entry.gguf_shape.clone();
tracing::info!("tie_word_embeddings=true: duplicating token_embd as output.weight");
gguf_entries.insert(
"output.weight".to_string(),
TensorEntryPending {
gguf_name: "output.weight".to_string(),
is_norm: false,
gguf_shape: shape,
f32_data: embed_f32,
},
);
}
}
let mut stats = ConvertStats::default();
for pending in gguf_entries.values() {
let (raw_bytes, tensor_type) = if pending.is_norm {
let raw: Vec<u8> = pending
.f32_data
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
(raw, TensorType::F32)
} else {
let element_count: usize = pending.gguf_shape.iter().product::<u64>() as usize;
let f32_data = pad_to_multiple_of_128(&pending.f32_data, element_count);
let blocks = BlockTQ2_0_g128::quantize(&f32_data)
.with_context(|| format!("quantizing tensor '{}'", pending.gguf_name))?;
let raw = blocks_to_bytes(&blocks);
(raw, TensorType::TQ2_0_g128)
};
println!(
" converting {} {:?} -> {}",
pending.gguf_name,
pending.gguf_shape,
if pending.is_norm { "F32" } else { "TQ2_0_g128" }
);
writer.add_tensor(TensorEntry {
name: pending.gguf_name.clone(),
shape: pending.gguf_shape.clone(),
tensor_type,
data: raw_bytes,
});
if pending.is_norm {
stats.n_fp32 += 1;
} else {
stats.n_ternary += 1;
}
stats.n_tensors += 1;
}
let out_file = std::fs::File::create(to_path)
.with_context(|| format!("creating output file {:?}", to_path))?;
let mut buf_writer = BufWriter::new(out_file);
let bytes_written = writer
.write(&mut buf_writer)
.map_err(|e| anyhow::anyhow!("GGUF write error: {}", e))?;
stats.output_bytes = bytes_written;
Ok(stats)
}
/// A tensor that has been loaded and renamed but not yet encoded for GGUF.
///
/// Entries are buffered before writing so that a tied `output.weight` can be derived
/// from `token_embd.weight` once all tensors have been seen.
struct TensorEntryPending {
    /// Name under which the tensor is written to the GGUF file.
    gguf_name: String,
    /// Dimensions in GGUF order (the reverse of the safetensors shape).
    gguf_shape: Vec<u64>,
    /// Tensor values widened to `f32`, in the source's flattened order.
    f32_data: Vec<f32>,
    /// `true` means "store as plain F32"; `false` means "quantise to TQ2_0_g128".
    is_norm: bool,
}
/// Load and parse `config.json` from the checkpoint directory.
///
/// # Errors
///
/// Fails when the file cannot be read or does not contain valid JSON. The error context
/// names the offending path.
fn read_config_json(from_dir: &Path) -> anyhow::Result<Value> {
    let path = from_dir.join("config.json");
    std::fs::read_to_string(&path)
        .with_context(|| format!("reading {:?}", path))
        .and_then(|text| {
            serde_json::from_str::<Value>(&text).with_context(|| format!("parsing {:?}", path))
        })
}
/// Locate the safetensors file(s) that make up the checkpoint in `from_dir`.
///
/// A single `model.safetensors` takes precedence. Otherwise the shard file names are read
/// from `weight_map` in `model.safetensors.index.json`, de-duplicated, and returned in
/// sorted order joined onto `from_dir`. Non-string entries in `weight_map` are ignored.
///
/// # Errors
///
/// Fails when any of the following holds:
///
/// - Neither file exists.
/// - The index cannot be read or parsed.
/// - The index lacks an object-valued `weight_map`.
/// - `weight_map` names no shard files, which would otherwise yield an empty GGUF.
fn discover_shard_paths(from_dir: &Path) -> anyhow::Result<Vec<PathBuf>> {
    let single = from_dir.join("model.safetensors");
    if single.exists() {
        return Ok(vec![single]);
    }
    let index_path = from_dir.join("model.safetensors.index.json");
    if !index_path.exists() {
        anyhow::bail!(
            "neither model.safetensors nor model.safetensors.index.json found in {:?}",
            from_dir
        );
    }
    let raw = std::fs::read_to_string(&index_path)
        .with_context(|| format!("reading {:?}", index_path))?;
    let index: Value =
        serde_json::from_str(&raw).with_context(|| format!("parsing {:?}", index_path))?;
    let weight_map = index
        .get("weight_map")
        .and_then(Value::as_object)
        .with_context(|| format!("missing 'weight_map' in {:?}", index_path))?;
    // A BTreeSet de-duplicates and sorts in one pass. This avoids a linear `contains`
    // scan and a String allocation for every tensor entry in the map.
    let shard_names: std::collections::BTreeSet<&str> =
        weight_map.values().filter_map(Value::as_str).collect();
    if shard_names.is_empty() {
        anyhow::bail!("'weight_map' in {:?} does not list any shard files", index_path);
    }
    Ok(shard_names
        .into_iter()
        .map(|name| from_dir.join(name))
        .collect())
}
/// Write the GGUF header metadata derived from the Hugging Face `config.json`.
///
/// The metadata is populated as follows:
///
/// - The architecture is always recorded as `"qwen3"`.
/// - `general.name` is the final component of `from_dir`, or `"unknown"` when that
///   component is unavailable.
/// - Integer hyper-parameters are copied as `u32`. A key that is missing, not an
///   unsigned integer, or larger than `u32::MAX` is skipped with a warning instead of
///   being silently truncated.
/// - `rms_norm_eps` is written when present, with a warning when absent.
/// - `rope_theta` falls back to `10000.0` when absent.
///
/// # Errors
///
/// Currently never returns an error; the `Result` is part of the existing interface.
fn write_metadata(writer: &mut GgufWriter, config: &Value, from_dir: &Path) -> anyhow::Result<()> {
    writer.add_metadata(
        keys::GENERAL_ARCHITECTURE,
        MetadataWriteValue::Str("qwen3".to_string()),
    );
    let model_name = from_dir
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("unknown")
        .to_string();
    writer.add_metadata(keys::GENERAL_NAME, MetadataWriteValue::Str(model_name));
    // NOTE(review): the GGUF specification defines `general.quantization_version` as a
    // uint32. Confirm that the reader consuming these files expects a string here.
    writer.add_metadata(
        "general.quantization_version",
        MetadataWriteValue::Str("TQ2_0_G128".to_string()),
    );
    let u32_keys = [
        (keys::LLM_BLOCK_COUNT, "num_hidden_layers"),
        (keys::LLM_EMBEDDING_LENGTH, "hidden_size"),
        (keys::LLM_FEED_FORWARD_LENGTH, "intermediate_size"),
        (keys::LLM_ATTENTION_HEAD_COUNT, "num_attention_heads"),
        (keys::LLM_ATTENTION_HEAD_COUNT_KV, "num_key_value_heads"),
        (keys::LLM_CONTEXT_LENGTH, "max_position_embeddings"),
        (keys::LLM_VOCAB_SIZE, "vocab_size"),
    ];
    for (gguf_key, json_key) in &u32_keys {
        match config.get(*json_key).and_then(Value::as_u64) {
            // Range-checked, so the cast below can no longer wrap silently.
            Some(val) if val <= u64::from(u32::MAX) => {
                writer.add_metadata(gguf_key, MetadataWriteValue::U32(val as u32));
            }
            Some(val) => {
                tracing::warn!(json_key, val, "config.json value does not fit in u32; skipping");
            }
            None => {
                tracing::warn!(json_key, "missing or non-u64 field in config.json");
            }
        }
    }
    match config.get("rms_norm_eps").and_then(Value::as_f64) {
        Some(eps) => {
            writer.add_metadata(
                keys::LLM_ATTENTION_LAYER_NORM_RMS_EPSILON,
                MetadataWriteValue::F32(eps as f32),
            );
        }
        None => {
            tracing::warn!("missing or non-numeric rms_norm_eps in config.json");
        }
    }
    let rope_theta = config
        .get("rope_theta")
        .and_then(Value::as_f64)
        .unwrap_or(10000.0);
    writer.add_metadata(
        keys::LLM_ROPE_FREQ_BASE,
        MetadataWriteValue::F32(rope_theta as f32),
    );
    Ok(())
}
/// Decode a little-endian byte buffer of the given safetensors dtype into `f32` values.
///
/// Supports `F32`, `F16` and `BF16`. Any other dtype yields an empty Vec, which callers
/// treat as "unsupported". Trailing bytes that do not form a whole element are ignored.
fn to_f32_vec(dtype: Dtype, data: &[u8]) -> Vec<f32> {
    /// Widen every 2-byte element of `bytes` to `f32` using `convert`.
    fn widen_pairs(bytes: &[u8], convert: impl Fn([u8; 2]) -> f32) -> Vec<f32> {
        bytes
            .chunks_exact(2)
            .map(|pair| convert([pair[0], pair[1]]))
            .collect()
    }

    match dtype {
        Dtype::F16 => widen_pairs(data, |raw| half::f16::from_le_bytes(raw).to_f32()),
        Dtype::BF16 => widen_pairs(data, |raw| half::bf16::from_le_bytes(raw).to_f32()),
        Dtype::F32 => {
            let mut values = Vec::with_capacity(data.len() / 4);
            for word in data.chunks_exact(4) {
                values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
            }
            values
        }
        _ => Vec::new(),
    }
}
/// Return a copy of `f32_data` zero-extended to the next multiple of 128 elements.
///
/// Input whose length is already a multiple of 128 (including empty input) is copied
/// unchanged. The `_element_count` argument is accepted but not used.
fn pad_to_multiple_of_128(f32_data: &[f32], _element_count: usize) -> Vec<f32> {
    const GROUP: usize = 128;
    let original_len = f32_data.len();
    // Number of zeros needed to reach the next group boundary (0 when already aligned).
    let shortfall = (GROUP - original_len % GROUP) % GROUP;
    let mut out = Vec::with_capacity(original_len + shortfall);
    out.extend_from_slice(f32_data);
    out.resize(original_len + shortfall, 0.0f32);
    out
}
/// Copy the in-memory representation of `blocks` into a contiguous byte Vec.
///
/// # Panics
///
/// Panics if `size_of::<BlockTQ2_0_g128>()` differs from `BLOCK_TQ2_0_G128_BYTES`. In that
/// case the byte view would either read past each block (UB) or misframe the stream.
fn blocks_to_bytes(blocks: &[BlockTQ2_0_g128]) -> Vec<u8> {
    assert_eq!(
        std::mem::size_of::<BlockTQ2_0_g128>(),
        BLOCK_TQ2_0_G128_BYTES,
        "BlockTQ2_0_g128 in-memory size must equal BLOCK_TQ2_0_G128_BYTES"
    );
    let total = std::mem::size_of_val(blocks);
    // SAFETY: `blocks.as_ptr()` is valid for reads of `size_of_val(blocks)` bytes, and
    // `u8` has alignment 1. The assertion above ties that length to the on-disk block
    // size. NOTE(review): soundness also requires `BlockTQ2_0_g128` to contain no
    // padding bytes (e.g. `#[repr(C)]` with byte-sized fields); confirm at its definition.
    let bytes: &[u8] = unsafe { std::slice::from_raw_parts(blocks.as_ptr().cast::<u8>(), total) };
    bytes.to_vec()
}