Skip to main content

oxibonsai_model/convert/
common.rs

1//! Shared helpers for HuggingFace / ONNX → GGUF conversion pipelines.
2//!
3//! Both the safetensors path (`convert::convert_hf_to_gguf`) and the ONNX path
4//! (`convert::onnx::convert_onnx_to_gguf`) emit the same GGUF layout and share:
5//!
6//! * Parsing of the sibling `config.json`.
7//! * Writing Qwen3 metadata (architecture, dimensions, norm epsilon, rope base).
8//! * Padding `f32` weights to a multiple of the TQ2_0_g128 block size.
9//! * Serialising `BlockTQ2_0_g128` blocks into raw GGUF tensor bytes.
10//! * A single `ConvertStats` result struct so callers can report progress
11//!   uniformly.
12
13use std::path::Path;
14
15use anyhow::Context;
16use serde_json::Value;
17
18use oxibonsai_core::gguf::tensor_info::keys;
19use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue};
20use oxibonsai_core::quant_ternary::{BlockTQ2_0_g128, BLOCK_TQ2_0_G128_BYTES};
21
/// Summary counters produced by a successful conversion run.
///
/// Both the safetensors (`convert_hf_to_gguf`) and ONNX
/// (`convert_onnx_to_gguf`) pipelines return this same struct, so CLI callers
/// can report results without caring about the source format.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare whole
/// results with `==` / `assert_eq!` instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConvertStats {
    /// Total number of tensors written to the GGUF file.
    pub n_tensors: usize,
    /// Number of tensors quantized to TQ2_0_g128.
    pub n_ternary: usize,
    /// Number of tensors stored as FP32.
    pub n_fp32: usize,
    /// Total size of the output GGUF file in bytes.
    pub output_bytes: usize,
}
38
39/// Read and parse `config.json` at the given path.
40///
41/// Callers are responsible for locating the file (the safetensors path uses
42/// `from_dir.join("config.json")`; the ONNX path may need to search the ONNX
43/// parent and grandparent directories).
44pub fn read_config_json(config_path: &Path) -> anyhow::Result<Value> {
45    let raw = std::fs::read_to_string(config_path)
46        .with_context(|| format!("reading {:?}", config_path))?;
47    let value: Value =
48        serde_json::from_str(&raw).with_context(|| format!("parsing {:?}", config_path))?;
49    Ok(value)
50}
51
52/// Write Qwen3 metadata from `config.json` into a GGUF writer.
53///
54/// The caller provides the human-readable model name; for HF this is usually
55/// the directory basename, and for ONNX the `.onnx` file stem or repository
56/// identifier.
57pub fn write_metadata(
58    writer: &mut GgufWriter,
59    config: &Value,
60    model_name: &str,
61) -> anyhow::Result<()> {
62    // Architecture constant
63    writer.add_metadata(
64        keys::GENERAL_ARCHITECTURE,
65        MetadataWriteValue::Str("qwen3".to_string()),
66    );
67
68    // Human-readable model name
69    writer.add_metadata(
70        keys::GENERAL_NAME,
71        MetadataWriteValue::Str(model_name.to_string()),
72    );
73
74    // Quantisation version string
75    writer.add_metadata(
76        "general.quantization_version",
77        MetadataWriteValue::Str("TQ2_0_G128".to_string()),
78    );
79
80    // Integer keys (u32)
81    let u32_keys = [
82        (keys::LLM_BLOCK_COUNT, "num_hidden_layers"),
83        (keys::LLM_EMBEDDING_LENGTH, "hidden_size"),
84        (keys::LLM_FEED_FORWARD_LENGTH, "intermediate_size"),
85        (keys::LLM_ATTENTION_HEAD_COUNT, "num_attention_heads"),
86        (keys::LLM_ATTENTION_HEAD_COUNT_KV, "num_key_value_heads"),
87        (keys::LLM_CONTEXT_LENGTH, "max_position_embeddings"),
88        (keys::LLM_VOCAB_SIZE, "vocab_size"),
89    ];
90    for (gguf_key, json_key) in &u32_keys {
91        if let Some(val) = config.get(*json_key).and_then(Value::as_u64) {
92            writer.add_metadata(gguf_key, MetadataWriteValue::U32(val as u32));
93        } else {
94            tracing::warn!(json_key, "missing or non-u64 field in config.json");
95        }
96    }
97
98    // rms_norm_eps → F32
99    if let Some(eps) = config.get("rms_norm_eps").and_then(Value::as_f64) {
100        writer.add_metadata(
101            keys::LLM_ATTENTION_LAYER_NORM_RMS_EPSILON,
102            MetadataWriteValue::F32(eps as f32),
103        );
104    }
105
106    // rope_theta → F32 (default 10000.0 if absent)
107    //
108    // Resolution order:
109    //   1. `config["rope_theta"]`                    (top-level, legacy Qwen2 layout)
110    //   2. `config["rope_parameters"]["rope_theta"]` (nested, Qwen3 ONNX/newer layout)
111    //   3. 10000.0 fallback (with `tracing::warn!`)
112    let rope_theta = resolve_rope_theta(config);
113    writer.add_metadata(
114        keys::LLM_ROPE_FREQ_BASE,
115        MetadataWriteValue::F32(rope_theta as f32),
116    );
117
118    // If the nested `rope_parameters` block indicates YARN scaling, note it.
119    //
120    // The existing native `Ternary-Bonsai-1.7B.gguf` only carries
121    // `llm.rope.freq_base` — no `llm.rope.scaling.*` keys — so for now we log
122    // the YARN parameters at info level and rely on architecture defaults
123    // rather than inventing new metadata keys.
124    if let Some(rp) = config.get("rope_parameters").and_then(Value::as_object) {
125        let rope_type = rp.get("rope_type").and_then(Value::as_str).unwrap_or("");
126        if rope_type.eq_ignore_ascii_case("yarn") {
127            let factor = rp.get("factor").and_then(Value::as_f64);
128            let original_max_pos = rp
129                .get("original_max_position_embeddings")
130                .and_then(Value::as_u64);
131            tracing::info!(
132                ?factor,
133                ?original_max_pos,
134                "YARN rope_parameters detected; GGUF YARN metadata not plumbed, relying on architecture defaults"
135            );
136        }
137    }
138
139    Ok(())
140}
141
142/// Resolve `rope_theta` from a HuggingFace `config.json` value.
143///
144/// Looks in this order:
145///   1. Top-level `rope_theta` (legacy Qwen2 layout).
146///   2. Nested `rope_parameters.rope_theta` (Qwen3 ONNX/newer layout).
147///   3. Fallback `10000.0` with a `tracing::warn!` describing the absence.
148fn resolve_rope_theta(config: &Value) -> f64 {
149    if let Some(v) = config.get("rope_theta").and_then(Value::as_f64) {
150        return v;
151    }
152    if let Some(v) = config
153        .get("rope_parameters")
154        .and_then(|rp| rp.get("rope_theta"))
155        .and_then(Value::as_f64)
156    {
157        return v;
158    }
159    tracing::warn!(
160        "config.json missing both `rope_theta` and `rope_parameters.rope_theta`; \
161         falling back to default 10000.0"
162    );
163    10000.0
164}
165
/// Pad an f32 slice to a multiple of 128 elements for TQ2_0_g128 quantisation.
///
/// Returns a new vector holding a copy of `f32_data`, followed by as many
/// `0.0` values as are needed to reach the next multiple of 128. Input that
/// is already block-aligned (including empty input) is copied verbatim.
///
/// The output buffer is allocated once at its final size, so the unaligned
/// case no longer pays for a second allocation and copy when the tail is
/// appended.
pub fn pad_to_multiple_of_128(f32_data: &[f32]) -> Vec<f32> {
    /// Number of elements per quantisation block.
    const BLOCK_ELEMS: usize = 128;

    let len = f32_data.len();
    let padded_len = match len % BLOCK_ELEMS {
        0 => len,
        remainder => len + (BLOCK_ELEMS - remainder),
    };

    let mut padded = Vec::with_capacity(padded_len);
    padded.extend_from_slice(f32_data);
    // No-op when the input was already aligned.
    padded.resize(padded_len, 0.0_f32);
    padded
}
182
183/// Serialise a slice of `BlockTQ2_0_g128` blocks into raw bytes.
184///
185/// Each block is 34 bytes: 32 bytes of packed `qs` + 2 bytes of FP16 `d`.
186///
187/// # Safety
188///
189/// `BlockTQ2_0_g128` is `#[repr(C)]` with a compile-time size assertion of
190/// exactly 34 bytes. The cast is safe because we size the output slice using
191/// `blocks.len() * BLOCK_TQ2_0_G128_BYTES`.
192pub fn blocks_to_bytes(blocks: &[BlockTQ2_0_g128]) -> Vec<u8> {
193    let total = blocks.len() * BLOCK_TQ2_0_G128_BYTES;
194    // SAFETY: repr(C) layout with compile-time size check; byte length verified.
195    let bytes: &[u8] = unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, total) };
196    bytes.to_vec()
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Input that is exactly one block long comes back unchanged.
    #[test]
    fn pad_aligned_is_identity() {
        let aligned = vec![1.0_f32; 128];
        let out = pad_to_multiple_of_128(&aligned);
        assert_eq!(out, aligned);
    }

    /// 130 elements spill into a second block; the tail must be zeros.
    #[test]
    fn pad_extends_to_next_block() {
        let input = vec![1.0_f32; 130];
        let out = pad_to_multiple_of_128(&input);
        assert_eq!(out.len(), 256);

        let (head, tail) = out.split_at(130);
        assert_eq!(head, &input[..]);
        assert!(tail.iter().all(|&x| x == 0.0));
    }

    /// Zero elements is already a multiple of 128, so nothing is added.
    #[test]
    fn empty_input_stays_empty() {
        let nothing: Vec<f32> = Vec::new();
        let out = pad_to_multiple_of_128(&nothing);
        assert!(out.is_empty());
    }

    /// Legacy Qwen2 layout: `rope_theta` at the top level.
    #[test]
    fn rope_theta_top_level_wins() {
        let cfg = json!({ "rope_theta": 500000.0 });
        let theta = resolve_rope_theta(&cfg);
        assert_eq!(theta, 500000.0);
    }

    /// Qwen3 ONNX layout: the value lives under `rope_parameters`.
    #[test]
    fn rope_theta_nested_under_rope_parameters() {
        let cfg = json!({
            "rope_parameters": {
                "factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn",
            },
        });
        let theta = resolve_rope_theta(&cfg);
        assert_eq!(theta, 1_000_000.0);
    }

    /// When both locations carry a value, the top-level one is used.
    #[test]
    fn rope_theta_top_level_takes_precedence_over_nested() {
        let cfg = json!({
            "rope_theta": 250000.0,
            "rope_parameters": {
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn",
            },
        });
        let theta = resolve_rope_theta(&cfg);
        assert_eq!(theta, 250000.0);
    }

    /// With neither key present the default 10000.0 is returned.
    #[test]
    fn rope_theta_fallback_when_missing() {
        let cfg = json!({ "hidden_size": 2048 });
        let theta = resolve_rope_theta(&cfg);
        assert_eq!(theta, 10000.0);
    }
}