1use std::io::Write;
37use std::path::Path;
38
39use crate::error::{QLoraError, Result};
40use crate::quantization::{ComputeDType, QuantizedTensor};
41
/// Magic bytes identifying the QLoRA native export format ("QNAT").
const MAGIC: &[u8; 4] = b"QNAT";

/// On-disk format version written into the file header.
const VERSION: u32 = 1;

/// Format feature flags; no flags are defined yet, so this is always zero.
const FORMAT_FLAGS: u32 = 0;
/// Model-level metadata serialized after the file header by `export_native`.
#[derive(Debug, Clone)]
pub struct NativeMetadata {
    // Human-readable model name, written as a length-prefixed string.
    pub model_name: String,
    // Model type tag (e.g. "qlora"), written as a length-prefixed string.
    pub model_type: String,
    // Compute dtype encoded as a single byte (F32=0, F16=1, BF16=2).
    pub compute_dtype: ComputeDType,
}
61
62impl Default for NativeMetadata {
63 fn default() -> Self {
64 Self {
65 model_name: "qlora-model".to_string(),
66 model_type: "qlora".to_string(),
67 compute_dtype: ComputeDType::F32,
68 }
69 }
70}
71
72pub fn export_native<P: AsRef<Path>>(
82 tensors: &[(&str, &QuantizedTensor)],
83 metadata: Option<NativeMetadata>,
84 output_path: P,
85) -> Result<()> {
86 let mut file = std::fs::File::create(output_path)
87 .map_err(|e| QLoraError::NativeExport(format!("Failed to create output file: {e}")))?;
88
89 let metadata = metadata.unwrap_or_default();
90
91 write_header(&mut file, tensors.len())?;
93
94 write_metadata(&mut file, &metadata)?;
96
97 let _tensor_offsets = write_tensor_headers(&mut file, tensors)?;
99
100 for (_name, tensor) in tensors {
102 write_tensor_data(&mut file, tensor)?;
103 }
104
105 Ok(())
106}
107
108fn write_header<W: Write>(writer: &mut W, tensor_count: usize) -> Result<()> {
110 writer
112 .write_all(MAGIC)
113 .map_err(|e| QLoraError::NativeExport(format!("Failed to write magic: {e}")))?;
114
115 writer
117 .write_all(&VERSION.to_le_bytes())
118 .map_err(|e| QLoraError::NativeExport(format!("Failed to write version: {e}")))?;
119
120 writer
122 .write_all(&FORMAT_FLAGS.to_le_bytes())
123 .map_err(|e| QLoraError::NativeExport(format!("Failed to write flags: {e}")))?;
124
125 writer
127 .write_all(&0u64.to_le_bytes())
128 .map_err(|e| QLoraError::NativeExport(format!("Failed to write metadata size: {e}")))?;
129
130 let count = u32::try_from(tensor_count)
132 .map_err(|_| QLoraError::NativeExport("Too many tensors".into()))?;
133 writer
134 .write_all(&count.to_le_bytes())
135 .map_err(|e| QLoraError::NativeExport(format!("Failed to write tensor count: {e}")))?;
136
137 writer
139 .write_all(&0u32.to_le_bytes())
140 .map_err(|e| QLoraError::NativeExport(format!("Failed to write reserved: {e}")))?;
141
142 Ok(())
143}
144
145fn write_metadata<W: Write>(writer: &mut W, metadata: &NativeMetadata) -> Result<()> {
147 write_string(writer, &metadata.model_name)?;
149
150 write_string(writer, &metadata.model_type)?;
152
153 let dtype_byte = match metadata.compute_dtype {
155 ComputeDType::F32 => 0u8,
156 ComputeDType::F16 => 1u8,
157 ComputeDType::BF16 => 2u8,
158 };
159 writer
160 .write_all(&[dtype_byte])
161 .map_err(|e| QLoraError::NativeExport(format!("Failed to write compute dtype: {e}")))?;
162
163 Ok(())
164}
165
166fn write_tensor_headers<W: Write>(
168 writer: &mut W,
169 tensors: &[(&str, &QuantizedTensor)],
170) -> Result<Vec<u64>> {
171 let mut offsets = Vec::new();
172 let mut current_offset = calculate_header_size(tensors);
173
174 for (_name, tensor) in tensors {
175 offsets.push(current_offset as u64);
176 current_offset += calculate_tensor_size(tensor);
177 }
178
179 for ((name, tensor), offset) in tensors.iter().zip(offsets.iter()) {
181 write_string(writer, name)?;
182
183 let shape_len = u32::try_from(tensor.shape.len())
185 .map_err(|_| QLoraError::NativeExport("Tensor shape too large".into()))?;
186 writer
187 .write_all(&shape_len.to_le_bytes())
188 .map_err(|e| QLoraError::NativeExport(format!("Failed to write shape length: {e}")))?;
189
190 for &dim in &tensor.shape {
191 let dim = u64::try_from(dim)
192 .map_err(|_| QLoraError::NativeExport("Dimension too large".into()))?;
193 writer
194 .write_all(&dim.to_le_bytes())
195 .map_err(|e| QLoraError::NativeExport(format!("Failed to write dimension: {e}")))?;
196 }
197
198 let block_size = u32::try_from(tensor.block_size)
200 .map_err(|_| QLoraError::NativeExport("Block size too large".into()))?;
201 writer
202 .write_all(&block_size.to_le_bytes())
203 .map_err(|e| QLoraError::NativeExport(format!("Failed to write block size: {e}")))?;
204
205 writer
207 .write_all(&offset.to_le_bytes())
208 .map_err(|e| QLoraError::NativeExport(format!("Failed to write offset: {e}")))?;
209
210 let num_blocks = u32::try_from(tensor.scales.len())
212 .map_err(|_| QLoraError::NativeExport("Too many blocks".into()))?;
213 writer
214 .write_all(&num_blocks.to_le_bytes())
215 .map_err(|e| QLoraError::NativeExport(format!("Failed to write block count: {e}")))?;
216 }
217
218 Ok(offsets)
219}
220
221fn write_tensor_data<W: Write>(writer: &mut W, tensor: &QuantizedTensor) -> Result<()> {
223 writer
225 .write_all(&tensor.data)
226 .map_err(|e| QLoraError::NativeExport(format!("Failed to write quantized data: {e}")))?;
227
228 for &scale in &tensor.scales {
230 writer
231 .write_all(&scale.to_le_bytes())
232 .map_err(|e| QLoraError::NativeExport(format!("Failed to write scale: {e}")))?;
233 }
234
235 if let Some(ref zp) = tensor.zero_points {
237 for &zp_val in zp {
238 writer.write_all(&zp_val.to_le_bytes()).map_err(|e| {
239 QLoraError::NativeExport(format!("Failed to write zero point: {e}"))
240 })?;
241 }
242 }
243
244 if let Some(ref scales_q) = tensor.scales_quantized {
246 writer.write_all(scales_q).map_err(|e| {
247 QLoraError::NativeExport(format!("Failed to write double-quantized scales: {e}"))
248 })?;
249 }
250 if let Some(ref scales_s) = tensor.scales_scales {
251 for &scale_s in scales_s {
252 writer.write_all(&scale_s.to_le_bytes()).map_err(|e| {
253 QLoraError::NativeExport(format!(
254 "Failed to write double-quantized scale factors: {e}"
255 ))
256 })?;
257 }
258 }
259
260 Ok(())
261}
262
263fn write_string<W: Write>(writer: &mut W, s: &str) -> Result<()> {
265 let bytes = s.as_bytes();
266 let len = u32::try_from(bytes.len())
267 .map_err(|_| QLoraError::NativeExport("String too long".into()))?;
268 writer
269 .write_all(&len.to_le_bytes())
270 .map_err(|e| QLoraError::NativeExport(format!("Failed to write string length: {e}")))?;
271 writer
272 .write_all(bytes)
273 .map_err(|e| QLoraError::NativeExport(format!("Failed to write string: {e}")))?;
274 Ok(())
275}
276
277fn calculate_header_size(tensors: &[(&str, &QuantizedTensor)]) -> usize {
279 let mut size = 32; size += 4 + 11 + 4 + 8 + 1; for (name, tensor) in tensors {
286 size += 4 + name.len(); size += 4 + tensor.shape.len() * 8; size += 4; size += 8; size += 4; }
292
293 size
294}
295
296fn calculate_tensor_size(tensor: &QuantizedTensor) -> usize {
298 let mut size = tensor.data.len(); size += tensor.scales.len() * 4; if let Some(ref zp) = tensor.zero_points {
301 size += zp.len() * 4; }
303 if let Some(ref scales_q) = tensor.scales_quantized {
305 size += std::mem::size_of_val(scales_q.as_slice());
306 }
307 if let Some(ref scales_s) = tensor.scales_scales {
308 size += std::mem::size_of_val(scales_s.as_slice());
309 }
310 size
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::quantize_nf4;
    use candle_core::{Device, Tensor};
    use std::io::Read;

    /// Exporting a single tensor yields a file beginning with the magic bytes.
    #[test]
    fn test_export_native_basic() {
        let dev = Device::Cpu;
        let weights = Tensor::zeros(&[64, 64], candle_core::DType::F32, &dev).unwrap();
        let q = quantize_nf4(&weights, 64).unwrap();

        let path = std::env::temp_dir().join("test_native.qnat");
        export_native(&[("test_tensor", &q)], None, &path).unwrap();

        let mut out = std::fs::File::open(&path).unwrap();
        let mut head = [0u8; 4];
        out.read_exact(&mut head).unwrap();
        assert_eq!(&head, MAGIC);

        std::fs::remove_file(path).ok();
    }

    /// Exporting with explicit metadata produces a non-empty file.
    #[test]
    fn test_export_native_with_metadata() {
        let dev = Device::Cpu;
        let weights = Tensor::zeros(&[32, 32], candle_core::DType::F32, &dev).unwrap();
        let q = quantize_nf4(&weights, 64).unwrap();

        let meta = NativeMetadata {
            model_name: "test_model".to_string(),
            model_type: "test".to_string(),
            compute_dtype: ComputeDType::F32,
        };

        let path = std::env::temp_dir().join("test_native_meta.qnat");
        export_native(&[("weights", &q)], Some(meta), &path).unwrap();

        let written = std::fs::metadata(&path).unwrap();
        assert!(written.len() > 0);

        std::fs::remove_file(path).ok();
    }

    /// Exporting multiple tensors still starts the file with the magic bytes.
    #[test]
    fn test_export_native_multiple_tensors() {
        let dev = Device::Cpu;
        let a = Tensor::zeros(&[64, 64], candle_core::DType::F32, &dev).unwrap();
        let b = Tensor::zeros(&[32, 32], candle_core::DType::F32, &dev).unwrap();
        let qa = quantize_nf4(&a, 64).unwrap();
        let qb = quantize_nf4(&b, 64).unwrap();

        let path = std::env::temp_dir().join("test_native_multi.qnat");
        export_native(&[("w1", &qa), ("w2", &qb)], None, &path).unwrap();

        let mut out = std::fs::File::open(&path).unwrap();
        let mut head = [0u8; 4];
        out.read_exact(&mut head).unwrap();
        assert_eq!(&head, MAGIC);

        std::fs::remove_file(path).ok();
    }
}