qlora_rs/native.rs

//! Candle native quantized format.
//!
//! A native Rust serialization format optimized for the Candle framework,
//! designed for efficient loading and inference without conversion.
//!
//! File structure (all integers little-endian):
//! ```text
//! Header (32 bytes)
//!   - Magic: "QNAT" (4 bytes)
//!   - Version: u32 (4 bytes)
//!   - Format flags: u32 (4 bytes)
//!   - Metadata size: u64 (8 bytes)
//!   - Tensor count: u32 (4 bytes)
//!   - Reserved: u64 (8 bytes)
//!
//! Metadata section (variable)
//!   - Model name (length-prefixed string)
//!   - Model type (length-prefixed string)
//!   - Compute dtype (1 byte)
//!
//! Tensor headers (variable, one per tensor)
//!   - Name (length-prefixed string)
//!   - Shape (length-prefixed u64 array)
//!   - Block size: u32
//!   - Offset: u64
//!   - Num blocks: u32
//!
//! Tensor data (variable, in write order)
//!   - Quantized values (two 4-bit values packed per byte)
//!   - Scale factors (one f32 per block)
//!   - Zero points (optional, one f32 per block)
//!   - Double-quantized scale data (optional): `scales_quantized` bytes
//!     followed by `scales_scales` f32 values
//! ```
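//!
//! For illustration, a sketch of how a reader might parse the fixed header.
//! This module only writes the format; the hypothetical `read_fixed_header`
//! below just follows the field widths listed above:
//!
//! ```no_run
//! use std::io::Read;
//!
//! fn read_fixed_header(r: &mut impl Read) -> std::io::Result<(u32, u32, u64, u32)> {
//!     let mut magic = [0u8; 4];
//!     r.read_exact(&mut magic)?;
//!     assert_eq!(&magic, b"QNAT");
//!     let mut w32 = [0u8; 4];
//!     let mut w64 = [0u8; 8];
//!     r.read_exact(&mut w32)?;
//!     let version = u32::from_le_bytes(w32);
//!     r.read_exact(&mut w32)?;
//!     let flags = u32::from_le_bytes(w32);
//!     r.read_exact(&mut w64)?;
//!     let metadata_size = u64::from_le_bytes(w64);
//!     r.read_exact(&mut w32)?;
//!     let tensor_count = u32::from_le_bytes(w32);
//!     r.read_exact(&mut w64)?; // reserved
//!     Ok((version, flags, metadata_size, tensor_count))
//! }
//! ```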

use std::io::Write;
use std::path::Path;

use crate::error::{QLoraError, Result};
use crate::quantization::{ComputeDType, QuantizedTensor};

/// Magic bytes for Candle native quantized format.
const MAGIC: &[u8; 4] = b"QNAT";

/// Version of the Candle native format.
const VERSION: u32 = 1;

/// Format flags (reserved for future extensions).
const FORMAT_FLAGS: u32 = 0;
/// Metadata for Candle native format.
#[derive(Debug, Clone)]
pub struct NativeMetadata {
    /// Name of the model.
    pub model_name: String,
    /// Type of the model (e.g., "qlora").
    pub model_type: String,
    /// Compute data type for the model.
    pub compute_dtype: ComputeDType,
}

impl Default for NativeMetadata {
    fn default() -> Self {
        Self {
            model_name: "qlora-model".to_string(),
            model_type: "qlora".to_string(),
            compute_dtype: ComputeDType::F32,
        }
    }
}

/// Export quantized tensors to Candle native format.
///
/// # Arguments
/// * `tensors` - Named quantized tensors to export
/// * `metadata` - Optional metadata for the model
/// * `output_path` - Path to write the native format file
///
/// # Errors
/// Returns an error if the file cannot be written.
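///
/// # Example
///
/// A minimal sketch, assuming a `QuantizedTensor` produced by this crate's
/// quantization API (the module paths shown are illustrative):
///
/// ```ignore
/// use qlora_rs::native::{export_native, NativeMetadata};
/// use qlora_rs::quantization::QuantizedTensor;
///
/// fn demo(q: &QuantizedTensor) -> qlora_rs::error::Result<()> {
///     export_native(&[("model.weight", q)], Some(NativeMetadata::default()), "model.qnat")?;
///     Ok(())
/// }
/// ```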
pub fn export_native<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    metadata: Option<NativeMetadata>,
    output_path: P,
) -> Result<()> {
    let mut file = std::fs::File::create(output_path)
        .map_err(|e| QLoraError::NativeExport(format!("Failed to create output file: {e}")))?;

    let metadata = metadata.unwrap_or_default();

    // Write header; the metadata section is two length-prefixed strings plus one dtype byte
    let metadata_size = (4 + metadata.model_name.len() + 4 + metadata.model_type.len() + 1) as u64;
    write_header(&mut file, tensors.len(), metadata_size)?;

    // Write metadata
    write_metadata(&mut file, &metadata)?;

    // Write tensor headers and calculate offsets
    let _tensor_offsets = write_tensor_headers(&mut file, tensors, &metadata)?;

    // Write tensor data sequentially in the same order as the headers
    for (_name, tensor) in tensors {
        write_tensor_data(&mut file, tensor)?;
    }

    Ok(())
}

/// Write file header (32 bytes).
fn write_header<W: Write>(writer: &mut W, tensor_count: usize, metadata_size: u64) -> Result<()> {
    // Magic
    writer
        .write_all(MAGIC)
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write magic: {e}")))?;

    // Version
    writer
        .write_all(&VERSION.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write version: {e}")))?;

    // Format flags
    writer
        .write_all(&FORMAT_FLAGS.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write flags: {e}")))?;

    // Metadata size
    writer
        .write_all(&metadata_size.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write metadata size: {e}")))?;

    // Tensor count
    let count = u32::try_from(tensor_count)
        .map_err(|_| QLoraError::NativeExport("Too many tensors".into()))?;
    writer
        .write_all(&count.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write tensor count: {e}")))?;

    // Reserved (8 bytes, pads the header to 32 bytes)
    writer
        .write_all(&0u64.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write reserved: {e}")))?;

    Ok(())
}
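
// Resulting byte layout: magic 0..4, version 4..8, flags 8..12,
// metadata size 12..20, tensor count 20..24, reserved 24..32.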

/// Write metadata section.
fn write_metadata<W: Write>(writer: &mut W, metadata: &NativeMetadata) -> Result<()> {
    // Model name
    write_string(writer, &metadata.model_name)?;

    // Model type
    write_string(writer, &metadata.model_type)?;

    // Compute dtype
    let dtype_byte = match metadata.compute_dtype {
        ComputeDType::F32 => 0u8,
        ComputeDType::F16 => 1u8,
        ComputeDType::BF16 => 2u8,
    };
    writer
        .write_all(&[dtype_byte])
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write compute dtype: {e}")))?;

    Ok(())
}
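
/// Illustrative inverse of the dtype mapping in `write_metadata`, as a
/// hypothetical reader might implement it. The exporter never reads files
/// back, so this is only a sketch of the encoding contract.
#[allow(dead_code)]
fn read_compute_dtype(byte: u8) -> Result<ComputeDType> {
    match byte {
        0 => Ok(ComputeDType::F32),
        1 => Ok(ComputeDType::F16),
        2 => Ok(ComputeDType::BF16),
        other => Err(QLoraError::NativeExport(format!(
            "Unknown compute dtype byte: {other}"
        ))),
    }
}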

/// Write tensor headers and return the absolute file offset of each tensor's data.
fn write_tensor_headers<W: Write>(
    writer: &mut W,
    tensors: &[(&str, &QuantizedTensor)],
    metadata: &NativeMetadata,
) -> Result<Vec<u64>> {
    // Data starts after the fixed header, metadata, and all tensor headers;
    // each tensor's data follows the previous tensor's data contiguously.
    let mut offsets = Vec::new();
    let mut current_offset = calculate_header_size(tensors, metadata);

    for (_name, tensor) in tensors {
        offsets.push(current_offset as u64);
        current_offset += calculate_tensor_size(tensor);
    }

    // Write tensor headers
    for ((name, tensor), offset) in tensors.iter().zip(offsets.iter()) {
        write_string(writer, name)?;

        // Shape
        let shape_len = u32::try_from(tensor.shape.len())
            .map_err(|_| QLoraError::NativeExport("Tensor shape too large".into()))?;
        writer
            .write_all(&shape_len.to_le_bytes())
            .map_err(|e| QLoraError::NativeExport(format!("Failed to write shape length: {e}")))?;

        for &dim in &tensor.shape {
            let dim = u64::try_from(dim)
                .map_err(|_| QLoraError::NativeExport("Dimension too large".into()))?;
            writer
                .write_all(&dim.to_le_bytes())
                .map_err(|e| QLoraError::NativeExport(format!("Failed to write dimension: {e}")))?;
        }

        // Block size
        let block_size = u32::try_from(tensor.block_size)
            .map_err(|_| QLoraError::NativeExport("Block size too large".into()))?;
        writer
            .write_all(&block_size.to_le_bytes())
            .map_err(|e| QLoraError::NativeExport(format!("Failed to write block size: {e}")))?;

        // Offset
        writer
            .write_all(&offset.to_le_bytes())
            .map_err(|e| QLoraError::NativeExport(format!("Failed to write offset: {e}")))?;

        // Number of blocks (one scale per block)
        let num_blocks = u32::try_from(tensor.scales.len())
            .map_err(|_| QLoraError::NativeExport("Too many blocks".into()))?;
        writer
            .write_all(&num_blocks.to_le_bytes())
            .map_err(|e| QLoraError::NativeExport(format!("Failed to write block count: {e}")))?;
    }

    Ok(offsets)
}
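
// Worked example: with pre-data size H (fixed header + metadata + tensor
// headers) and tensor data sizes S1, S2, ..., the recorded offsets are
// H, H + S1, H + S1 + S2, and so on.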

/// Write tensor data (quantized values, scales, optional zero points).
fn write_tensor_data<W: Write>(writer: &mut W, tensor: &QuantizedTensor) -> Result<()> {
    // Quantized values
    writer
        .write_all(&tensor.data)
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write quantized data: {e}")))?;

    // Scale factors
    for &scale in &tensor.scales {
        writer
            .write_all(&scale.to_le_bytes())
            .map_err(|e| QLoraError::NativeExport(format!("Failed to write scale: {e}")))?;
    }

    // Zero points (if present)
    if let Some(ref zp) = tensor.zero_points {
        for &zp_val in zp {
            writer.write_all(&zp_val.to_le_bytes()).map_err(|e| {
                QLoraError::NativeExport(format!("Failed to write zero point: {e}"))
            })?;
        }
    }

    // Double-quantized scale data (if present)
    if let Some(ref scales_q) = tensor.scales_quantized {
        writer.write_all(scales_q).map_err(|e| {
            QLoraError::NativeExport(format!("Failed to write double-quantized scales: {e}"))
        })?;
    }
    if let Some(ref scales_s) = tensor.scales_scales {
        for &scale_s in scales_s {
            writer.write_all(&scale_s.to_le_bytes()).map_err(|e| {
                QLoraError::NativeExport(format!(
                    "Failed to write double-quantized scale factors: {e}"
                ))
            })?;
        }
    }

    Ok(())
}
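
// Per-tensor write order, matching the module docs: packed 4-bit values,
// per-block f32 scales, optional per-block f32 zero points, then the optional
// double-quantization payload (quantized scale bytes, then their f32 scales).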

/// Write a length-prefixed string.
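///
/// For example, `"qlora"` serializes as `05 00 00 00 71 6C 6F 72 61`:
/// a little-endian `u32` length followed by the UTF-8 bytes.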
fn write_string<W: Write>(writer: &mut W, s: &str) -> Result<()> {
    let bytes = s.as_bytes();
    let len = u32::try_from(bytes.len())
        .map_err(|_| QLoraError::NativeExport("String too long".into()))?;
    writer
        .write_all(&len.to_le_bytes())
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write string length: {e}")))?;
    writer
        .write_all(bytes)
        .map_err(|e| QLoraError::NativeExport(format!("Failed to write string: {e}")))?;
    Ok(())
}

/// Calculate the size of everything before the tensor data (fixed header,
/// metadata section, and tensor headers).
fn calculate_header_size(tensors: &[(&str, &QuantizedTensor)], metadata: &NativeMetadata) -> usize {
    let mut size = 32; // Fixed header size

    // Metadata: two length-prefixed strings plus the dtype byte
    size += 4 + metadata.model_name.len();
    size += 4 + metadata.model_type.len();
    size += 1;

    // Tensor headers
    for (name, tensor) in tensors {
        size += 4 + name.len(); // name
        size += 4 + tensor.shape.len() * 8; // shape
        size += 4; // block size
        size += 8; // offset
        size += 4; // num blocks
    }

    size
}

/// Calculate size of tensor data.
fn calculate_tensor_size(tensor: &QuantizedTensor) -> usize {
    let mut size = tensor.data.len(); // quantized values
    size += tensor.scales.len() * 4; // f32 scale factors
    if let Some(ref zp) = tensor.zero_points {
        size += zp.len() * 4; // f32 zero points
    }
    // Double-quantized scale data (if present)
    if let Some(ref scales_q) = tensor.scales_quantized {
        size += scales_q.len(); // one byte each
    }
    if let Some(ref scales_s) = tensor.scales_scales {
        size += scales_s.len() * 4; // f32 scale factors for the quantized scales
    }
    size
}
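
// Worked example: a 64x64 NF4 tensor with block_size 64 packs 4096 values
// into 2048 bytes, plus 64 blocks * 4 bytes = 256 bytes of f32 scales,
// for 2304 bytes total when no zero points or double quantization are used.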

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::quantize_nf4;
    use candle_core::{Device, Tensor};
    use std::io::Read;

    #[test]
    fn test_export_native_basic() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[64, 64], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_native.qnat");
        export_native(&[("test_tensor", &quantized)], None, &temp_path).unwrap();

        // Verify magic bytes
        let mut file = std::fs::File::open(&temp_path).unwrap();
        let mut magic = [0u8; 4];
        file.read_exact(&mut magic).unwrap();
        assert_eq!(&magic, MAGIC);

        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_native_with_metadata() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let metadata = NativeMetadata {
            model_name: "test_model".to_string(),
            model_type: "test".to_string(),
            compute_dtype: ComputeDType::F32,
        };

        let temp_path = std::env::temp_dir().join("test_native_meta.qnat");
        export_native(&[("weights", &quantized)], Some(metadata), &temp_path).unwrap();

        // Verify file was created
        let file_meta = std::fs::metadata(&temp_path).unwrap();
        assert!(file_meta.len() > 0);

        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_native_multiple_tensors() {
        let device = Device::Cpu;
        let t1 = Tensor::zeros(&[64, 64], candle_core::DType::F32, &device).unwrap();
        let t2 = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let q1 = quantize_nf4(&t1, 64).unwrap();
        let q2 = quantize_nf4(&t2, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_native_multi.qnat");
        export_native(&[("w1", &q1), ("w2", &q2)], None, &temp_path).unwrap();

        // Verify file structure
        let mut file = std::fs::File::open(&temp_path).unwrap();
        let mut magic = [0u8; 4];
        file.read_exact(&mut magic).unwrap();
        assert_eq!(&magic, MAGIC);

        std::fs::remove_file(temp_path).ok();
    }
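
    #[test]
    fn test_header_field_layout() {
        // Sketch of a layout check against the fixed 32-byte header
        // (field offsets per the module docs; single-tensor export assumed).
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[64, 64], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_native_header.qnat");
        export_native(&[("w", &quantized)], None, &temp_path).unwrap();

        let mut file = std::fs::File::open(&temp_path).unwrap();
        let mut header = [0u8; 32];
        file.read_exact(&mut header).unwrap();

        assert_eq!(&header[0..4], MAGIC);
        assert_eq!(u32::from_le_bytes(header[4..8].try_into().unwrap()), VERSION);
        assert_eq!(
            u32::from_le_bytes(header[8..12].try_into().unwrap()),
            FORMAT_FLAGS
        );
        // Tensor count lives at bytes 20..24.
        assert_eq!(u32::from_le_bytes(header[20..24].try_into().unwrap()), 1);

        std::fs::remove_file(temp_path).ok();
    }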
}