use std::path::Path;
use anyhow::Context;
use serde_json::Value;
use oxibonsai_core::gguf::tensor_info::keys;
use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue};
use oxibonsai_core::quant_ternary::{BlockTQ2_0_g128, BLOCK_TQ2_0_G128_BYTES};
/// Counters describing a conversion run.
///
/// Every field is zero in the [`Default`] value. Nothing visible in this part
/// of the file updates the counters, so the per-field notes below are inferred
/// from the field names and should be confirmed against the code that fills
/// them in.
#[derive(Clone, Debug, Default)]
pub struct ConvertStats {
    /// Number of tensors handled (presumably the total across all kinds — confirm).
    pub n_tensors: usize,
    /// Number of tensors counted as ternary (presumably those quantised to TQ2_0_g128 — confirm).
    pub n_ternary: usize,
    /// Number of tensors counted as FP32 (presumably those kept unquantised — confirm).
    pub n_fp32: usize,
    /// Size of the produced output in bytes (presumably the written GGUF file — confirm).
    pub output_bytes: usize,
}
/// Reads the file at `config_path` and parses its contents as JSON.
///
/// # Errors
///
/// Returns an error carrying the context `reading <path>` when the file cannot
/// be read into a string, or `parsing <path>` when its contents are not valid
/// JSON.
pub fn read_config_json(config_path: &Path) -> anyhow::Result<Value> {
    let text = std::fs::read_to_string(config_path)
        .with_context(|| format!("reading {:?}", config_path))?;
    // The parse result is already the function's return type once the context
    // is attached, so it is returned directly.
    serde_json::from_str::<Value>(&text).with_context(|| format!("parsing {:?}", config_path))
}
pub fn write_metadata(
writer: &mut GgufWriter,
config: &Value,
model_name: &str,
) -> anyhow::Result<()> {
writer.add_metadata(
keys::GENERAL_ARCHITECTURE,
MetadataWriteValue::Str("qwen3".to_string()),
);
writer.add_metadata(
keys::GENERAL_NAME,
MetadataWriteValue::Str(model_name.to_string()),
);
writer.add_metadata(
"general.quantization_version",
MetadataWriteValue::Str("TQ2_0_G128".to_string()),
);
let u32_keys = [
(keys::LLM_BLOCK_COUNT, "num_hidden_layers"),
(keys::LLM_EMBEDDING_LENGTH, "hidden_size"),
(keys::LLM_FEED_FORWARD_LENGTH, "intermediate_size"),
(keys::LLM_ATTENTION_HEAD_COUNT, "num_attention_heads"),
(keys::LLM_ATTENTION_HEAD_COUNT_KV, "num_key_value_heads"),
(keys::LLM_CONTEXT_LENGTH, "max_position_embeddings"),
(keys::LLM_VOCAB_SIZE, "vocab_size"),
];
for (gguf_key, json_key) in &u32_keys {
if let Some(val) = config.get(*json_key).and_then(Value::as_u64) {
writer.add_metadata(gguf_key, MetadataWriteValue::U32(val as u32));
} else {
tracing::warn!(json_key, "missing or non-u64 field in config.json");
}
}
if let Some(eps) = config.get("rms_norm_eps").and_then(Value::as_f64) {
writer.add_metadata(
keys::LLM_ATTENTION_LAYER_NORM_RMS_EPSILON,
MetadataWriteValue::F32(eps as f32),
);
}
let rope_theta = resolve_rope_theta(config);
writer.add_metadata(
keys::LLM_ROPE_FREQ_BASE,
MetadataWriteValue::F32(rope_theta as f32),
);
if let Some(rp) = config.get("rope_parameters").and_then(Value::as_object) {
let rope_type = rp.get("rope_type").and_then(Value::as_str).unwrap_or("");
if rope_type.eq_ignore_ascii_case("yarn") {
let factor = rp.get("factor").and_then(Value::as_f64);
let original_max_pos = rp
.get("original_max_position_embeddings")
.and_then(Value::as_u64);
tracing::info!(
?factor,
?original_max_pos,
"YARN rope_parameters detected; GGUF YARN metadata not plumbed, relying on architecture defaults"
);
}
}
Ok(())
}
/// Determines the RoPE base frequency from a parsed `config.json`.
///
/// Lookup order:
/// 1. top-level `rope_theta`,
/// 2. `rope_parameters.rope_theta`,
/// 3. the default `10000.0`, after logging a warning.
///
/// A candidate is used only when it is a JSON number that `Value::as_f64`
/// accepts; otherwise the next source is tried.
fn resolve_rope_theta(config: &Value) -> f64 {
    let top_level = config.get("rope_theta").and_then(Value::as_f64);
    // Evaluated lazily so the nested lookup only runs when the top-level key
    // did not yield a number.
    let nested = || {
        config
            .get("rope_parameters")
            .and_then(|rp| rp.get("rope_theta"))
            .and_then(Value::as_f64)
    };

    match top_level.or_else(nested) {
        Some(theta) => theta,
        None => {
            tracing::warn!(
                "config.json missing both `rope_theta` and `rope_parameters.rope_theta`; \
                 falling back to default 10000.0"
            );
            10000.0
        }
    }
}
/// Returns a copy of `f32_data` zero-padded at the end so that its length is a
/// multiple of 128.
///
/// Input whose length is already a multiple of 128 (including the empty
/// slice) is copied unchanged.
pub fn pad_to_multiple_of_128(f32_data: &[f32]) -> Vec<f32> {
    const GROUP: usize = 128;

    // Number of zeros required to reach the next multiple of GROUP; the outer
    // modulo maps the "already aligned" case from GROUP down to 0.
    let shortfall = (GROUP - f32_data.len() % GROUP) % GROUP;

    let mut out = Vec::with_capacity(f32_data.len() + shortfall);
    out.extend_from_slice(f32_data);
    out.extend(std::iter::repeat(0.0_f32).take(shortfall));
    out
}
/// Copies the in-memory representation of `blocks` into a new byte vector.
///
/// The result holds `blocks.len() * BLOCK_TQ2_0_G128_BYTES` bytes, laid out
/// exactly as the blocks are laid out in memory.
///
/// # Panics
///
/// Panics if `size_of::<BlockTQ2_0_g128>()` differs from
/// `BLOCK_TQ2_0_G128_BYTES`. In that case the constant no longer describes the
/// slice's real extent and a length derived from it could read past the end
/// of `blocks`.
pub fn blocks_to_bytes(blocks: &[BlockTQ2_0_g128]) -> Vec<u8> {
    // Guard the invariant the byte view relies on; both operands are
    // compile-time constants, so the check is effectively free.
    assert_eq!(
        std::mem::size_of::<BlockTQ2_0_g128>(),
        BLOCK_TQ2_0_G128_BYTES,
        "BlockTQ2_0_g128 in-memory size must equal BLOCK_TQ2_0_G128_BYTES"
    );

    // Derive the length from the slice itself so it can never exceed the
    // memory the slice actually covers.
    let total = std::mem::size_of_val(blocks);

    // SAFETY: `blocks` is a live shared borrow, so its base pointer is
    // non-null and valid for reads of exactly `size_of_val(blocks)` bytes for
    // the lifetime of `bytes`; `u8` has alignment 1, so the cast pointer is
    // suitably aligned. NOTE(review): this also assumes `BlockTQ2_0_g128`
    // contains no padding bytes, since padding would be read as uninitialised
    // `u8` — confirm against the type's definition.
    let bytes: &[u8] = unsafe { std::slice::from_raw_parts(blocks.as_ptr().cast::<u8>(), total) };
    bytes.to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A length that is already a multiple of 128 comes back unchanged.
    #[test]
    fn pad_aligned_is_identity() {
        let input = vec![1.0_f32; 128];
        let output = pad_to_multiple_of_128(&input);
        assert_eq!(output, input);
    }

    /// 130 elements grow to 256: the original data followed by zeros.
    #[test]
    fn pad_extends_to_next_block() {
        let input = vec![1.0_f32; 130];
        let output = pad_to_multiple_of_128(&input);
        assert_eq!(output.len(), 256);

        let (head, tail) = output.split_at(130);
        assert_eq!(head, &input[..]);
        assert!(tail.iter().all(|&x| x == 0.0));
    }

    /// Zero elements is a multiple of 128, so nothing is appended.
    #[test]
    fn empty_input_stays_empty() {
        let input: Vec<f32> = Vec::new();
        let output = pad_to_multiple_of_128(&input);
        assert!(output.is_empty());
    }

    /// A top-level `rope_theta` is returned as-is.
    #[test]
    fn rope_theta_top_level_wins() {
        let config = json!({ "rope_theta": 500000.0 });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 500000.0);
    }

    /// Without a top-level key, `rope_parameters.rope_theta` is used.
    #[test]
    fn rope_theta_nested_under_rope_parameters() {
        let config = json!({
            "rope_parameters": {
                "factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn"
            }
        });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 1_000_000.0);
    }

    /// When both locations carry a value, the top-level one is chosen.
    #[test]
    fn rope_theta_top_level_takes_precedence_over_nested() {
        let config = json!({
            "rope_theta": 250000.0,
            "rope_parameters": {
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn"
            }
        });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 250000.0);
    }

    /// With neither key present the default of 10000.0 is returned.
    #[test]
    fn rope_theta_fallback_when_missing() {
        let config = json!({ "hidden_size": 2048 });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 10000.0);
    }
}