oxibonsai_model/convert/
common.rs1use std::path::Path;
14
15use anyhow::Context;
16use serde_json::Value;
17
18use oxibonsai_core::gguf::tensor_info::keys;
19use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue};
20use oxibonsai_core::quant_ternary::{BlockTQ2_0_g128, BLOCK_TQ2_0_G128_BYTES};
21
/// Record of `usize` counters for a conversion; `Default` yields all zeros.
///
/// NOTE(review): nothing in this file populates these fields, so the per-field
/// descriptions below are inferred from the field names only — confirm against
/// the code that fills the struct in.
#[derive(Debug, Clone, Default)]
pub struct ConvertStats {
    /// Presumably the total number of tensors handled — TODO confirm.
    pub n_tensors: usize,
    /// Presumably the number of tensors stored in the ternary format — TODO confirm.
    pub n_ternary: usize,
    /// Presumably the number of tensors kept as FP32 — TODO confirm.
    pub n_fp32: usize,
    /// Presumably the size in bytes of the produced output — TODO confirm.
    pub output_bytes: usize,
}
38
39pub fn read_config_json(config_path: &Path) -> anyhow::Result<Value> {
45 let raw = std::fs::read_to_string(config_path)
46 .with_context(|| format!("reading {:?}", config_path))?;
47 let value: Value =
48 serde_json::from_str(&raw).with_context(|| format!("parsing {:?}", config_path))?;
49 Ok(value)
50}
51
52pub fn write_metadata(
58 writer: &mut GgufWriter,
59 config: &Value,
60 model_name: &str,
61) -> anyhow::Result<()> {
62 writer.add_metadata(
64 keys::GENERAL_ARCHITECTURE,
65 MetadataWriteValue::Str("qwen3".to_string()),
66 );
67
68 writer.add_metadata(
70 keys::GENERAL_NAME,
71 MetadataWriteValue::Str(model_name.to_string()),
72 );
73
74 writer.add_metadata(
76 "general.quantization_version",
77 MetadataWriteValue::Str("TQ2_0_G128".to_string()),
78 );
79
80 let u32_keys = [
82 (keys::LLM_BLOCK_COUNT, "num_hidden_layers"),
83 (keys::LLM_EMBEDDING_LENGTH, "hidden_size"),
84 (keys::LLM_FEED_FORWARD_LENGTH, "intermediate_size"),
85 (keys::LLM_ATTENTION_HEAD_COUNT, "num_attention_heads"),
86 (keys::LLM_ATTENTION_HEAD_COUNT_KV, "num_key_value_heads"),
87 (keys::LLM_CONTEXT_LENGTH, "max_position_embeddings"),
88 (keys::LLM_VOCAB_SIZE, "vocab_size"),
89 ];
90 for (gguf_key, json_key) in &u32_keys {
91 if let Some(val) = config.get(*json_key).and_then(Value::as_u64) {
92 writer.add_metadata(gguf_key, MetadataWriteValue::U32(val as u32));
93 } else {
94 tracing::warn!(json_key, "missing or non-u64 field in config.json");
95 }
96 }
97
98 if let Some(eps) = config.get("rms_norm_eps").and_then(Value::as_f64) {
100 writer.add_metadata(
101 keys::LLM_ATTENTION_LAYER_NORM_RMS_EPSILON,
102 MetadataWriteValue::F32(eps as f32),
103 );
104 }
105
106 let rope_theta = resolve_rope_theta(config);
113 writer.add_metadata(
114 keys::LLM_ROPE_FREQ_BASE,
115 MetadataWriteValue::F32(rope_theta as f32),
116 );
117
118 if let Some(rp) = config.get("rope_parameters").and_then(Value::as_object) {
125 let rope_type = rp.get("rope_type").and_then(Value::as_str).unwrap_or("");
126 if rope_type.eq_ignore_ascii_case("yarn") {
127 let factor = rp.get("factor").and_then(Value::as_f64);
128 let original_max_pos = rp
129 .get("original_max_position_embeddings")
130 .and_then(Value::as_u64);
131 tracing::info!(
132 ?factor,
133 ?original_max_pos,
134 "YARN rope_parameters detected; GGUF YARN metadata not plumbed, relying on architecture defaults"
135 );
136 }
137 }
138
139 Ok(())
140}
141
142fn resolve_rope_theta(config: &Value) -> f64 {
149 if let Some(v) = config.get("rope_theta").and_then(Value::as_f64) {
150 return v;
151 }
152 if let Some(v) = config
153 .get("rope_parameters")
154 .and_then(|rp| rp.get("rope_theta"))
155 .and_then(Value::as_f64)
156 {
157 return v;
158 }
159 tracing::warn!(
160 "config.json missing both `rope_theta` and `rope_parameters.rope_theta`; \
161 falling back to default 10000.0"
162 );
163 10000.0
164}
165
/// Returns a copy of `f32_data` whose length is rounded up to the next
/// multiple of 128 by appending `0.0` values.
///
/// Input that is already aligned (including empty input) is copied unchanged.
pub fn pad_to_multiple_of_128(f32_data: &[f32]) -> Vec<f32> {
    const GROUP: usize = 128;

    // Zeros still needed to reach the next boundary; the outer `% GROUP`
    // maps the already-aligned case (GROUP - 0) back to 0.
    let missing = (GROUP - f32_data.len() % GROUP) % GROUP;
    let target_len = f32_data.len() + missing;

    let mut out = Vec::with_capacity(target_len);
    out.extend_from_slice(f32_data);
    out.resize(target_len, 0.0_f32);
    out
}
182
183pub fn blocks_to_bytes(blocks: &[BlockTQ2_0_g128]) -> Vec<u8> {
193 let total = blocks.len() * BLOCK_TQ2_0_G128_BYTES;
194 let bytes: &[u8] = unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, total) };
196 bytes.to_vec()
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A length that is already a multiple of 128 comes back unchanged.
    #[test]
    fn pad_aligned_is_identity() {
        let input = vec![1.0_f32; 128];
        let output = pad_to_multiple_of_128(&input);
        assert_eq!(output, input);
    }

    /// 130 elements are extended to 256, keeping the prefix and zero-filling the rest.
    #[test]
    fn pad_extends_to_next_block() {
        let input = vec![1.0_f32; 130];
        let output = pad_to_multiple_of_128(&input);
        assert_eq!(output.len(), 256);

        let (head, tail) = output.split_at(input.len());
        assert_eq!(head, input.as_slice());
        assert!(tail.iter().all(|&x| x == 0.0));
    }

    /// Empty input yields empty output.
    #[test]
    fn empty_input_stays_empty() {
        let output = pad_to_multiple_of_128(&Vec::<f32>::new());
        assert!(output.is_empty());
    }

    /// A top-level `rope_theta` is returned as-is.
    #[test]
    fn rope_theta_top_level_wins() {
        let config = json!({ "rope_theta": 500000.0 });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 500000.0);
    }

    /// With no top-level key, `rope_parameters.rope_theta` is used.
    #[test]
    fn rope_theta_nested_under_rope_parameters() {
        let config = json!({
            "rope_parameters": {
                "factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn",
            },
        });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 1_000_000.0);
    }

    /// When both locations are present, the top-level value is chosen.
    #[test]
    fn rope_theta_top_level_takes_precedence_over_nested() {
        let config = json!({
            "rope_theta": 250000.0,
            "rope_parameters": {
                "rope_theta": 1_000_000.0,
                "rope_type": "yarn",
            },
        });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 250000.0);
    }

    /// With neither location present, the default 10000.0 is returned.
    #[test]
    fn rope_theta_fallback_when_missing() {
        let config = json!({ "hidden_size": 2048 });
        let theta = resolve_rope_theta(&config);
        assert_eq!(theta, 10000.0);
    }
}