qlora_rs/formats.rs

//! Export format selection and unified interface.
//!
//! Provides a unified API for exporting quantized models in different formats
//! with user-selectable backends.

use std::path::Path;

use crate::error::Result;
use crate::quantization::QuantizedTensor;
use crate::{export, native};

/// Supported export formats.
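///
/// # Examples
///
/// A minimal sketch; the `qlora_rs::formats` import path assumes this module
/// is exposed publicly under that name.
///
/// ```
/// use qlora_rs::formats::ExportFormat;
///
/// // GGUF is the default variant, and `Display` yields a human-readable label.
/// assert_eq!(ExportFormat::default(), ExportFormat::Gguf);
/// assert_eq!(ExportFormat::Gguf.to_string(), "GGUF");
/// ```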
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExportFormat {
    /// GGUF format (compatible with the llama.cpp ecosystem).
    #[default]
    Gguf,
    /// Candle native format (optimized for the Candle framework).
    Native,
}

impl std::fmt::Display for ExportFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gguf => write!(f, "GGUF"),
            Self::Native => write!(f, "Candle Native"),
        }
    }
}

/// Export configuration for quantized models.
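///
/// # Examples
///
/// A builder-style sketch; as above, the import path assumes this module is
/// public as `qlora_rs::formats`.
///
/// ```
/// use qlora_rs::formats::{ExportConfig, ExportFormat};
///
/// let config = ExportConfig::new_native()
///     .with_model_name("my-model".to_string())
///     .with_model_type("llama".to_string());
///
/// assert_eq!(config.format, ExportFormat::Native);
/// assert_eq!(config.model_name, "my-model");
/// ```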
#[derive(Debug, Clone)]
pub struct ExportConfig {
    /// Target export format.
    pub format: ExportFormat,
    /// Model name for metadata.
    pub model_name: String,
    /// Model type for metadata.
    pub model_type: String,
}

impl Default for ExportConfig {
    fn default() -> Self {
        Self {
            format: ExportFormat::Gguf,
            model_name: "qlora-model".to_string(),
            model_type: "qlora".to_string(),
        }
    }
}

impl ExportConfig {
    /// Create a new export configuration with the GGUF format.
    #[must_use]
    pub fn new_gguf() -> Self {
        Self {
            format: ExportFormat::Gguf,
            ..Default::default()
        }
    }

    /// Create a new export configuration with the native format.
    #[must_use]
    pub fn new_native() -> Self {
        Self {
            format: ExportFormat::Native,
            ..Default::default()
        }
    }

    /// Set the format for this export configuration.
    #[must_use]
    pub fn with_format(mut self, format: ExportFormat) -> Self {
        self.format = format;
        self
    }

    /// Set the model name for metadata.
    #[must_use]
    pub fn with_model_name(mut self, name: String) -> Self {
        self.model_name = name;
        self
    }

    /// Set the model type for metadata.
    #[must_use]
    pub fn with_model_type(mut self, model_type: String) -> Self {
        self.model_type = model_type;
        self
    }
}

/// Export quantized tensors using the format selected in `config`.
///
/// # Arguments
/// * `tensors` - Named quantized tensors to export
/// * `config` - Export configuration with format selection
/// * `output_path` - Path to write the exported file
///
/// # Errors
/// Returns an error if the selected backend fails to serialize the tensors
/// or write the output file.
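///
/// # Examples
///
/// A sketch of a typical call (`no_run` because it writes to disk); the
/// `qlora_rs::...` paths assume the crate's public module layout.
///
/// ```no_run
/// use candle_core::{DType, Device, Tensor};
/// use qlora_rs::formats::{export_model, ExportConfig, ExportFormat};
/// use qlora_rs::quantization::quantize_nf4;
///
/// let tensor = Tensor::zeros(&[32, 32], DType::F32, &Device::Cpu).unwrap();
/// let quantized = quantize_nf4(&tensor, 64).unwrap();
///
/// let config = ExportConfig::default().with_format(ExportFormat::Gguf);
/// export_model(&[("weights", &quantized)], config, "model.gguf").unwrap();
/// ```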
pub fn export_model<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    config: ExportConfig,
    output_path: P,
) -> Result<()> {
    match config.format {
        ExportFormat::Gguf => {
            let metadata = export::GgufMetadata {
                model_name: config.model_name,
                model_type: config.model_type,
                model_size: tensors.iter().map(|(_, t)| t.numel()).sum(),
            };
            export::export_gguf(tensors, Some(metadata), output_path)
        }
        ExportFormat::Native => {
            let metadata = native::NativeMetadata {
                model_name: config.model_name,
                model_type: config.model_type,
                compute_dtype: crate::quantization::ComputeDType::F32,
            };
            native::export_native(tensors, Some(metadata), output_path)
        }
    }
}

/// Export quantized tensors with the default GGUF format.
///
/// # Arguments
/// * `tensors` - Named quantized tensors to export
/// * `output_path` - Path to write the GGUF file
///
/// # Errors
/// Returns an error if the GGUF backend fails to serialize the tensors or
/// write the output file.
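///
/// # Examples
///
/// Equivalent to [`export_model`] with [`ExportConfig::new_gguf`]; a `no_run`
/// sketch under the same path assumptions as above.
///
/// ```no_run
/// use candle_core::{DType, Device, Tensor};
/// use qlora_rs::formats::export_gguf;
/// use qlora_rs::quantization::quantize_nf4;
///
/// let tensor = Tensor::zeros(&[16, 16], DType::F32, &Device::Cpu).unwrap();
/// let quantized = quantize_nf4(&tensor, 64).unwrap();
/// export_gguf(&[("weights", &quantized)], "model.gguf").unwrap();
/// ```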
pub fn export_gguf<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    output_path: P,
) -> Result<()> {
    export_model(tensors, ExportConfig::new_gguf(), output_path)
}

/// Export quantized tensors to the native Candle format.
///
/// # Arguments
/// * `tensors` - Named quantized tensors to export
/// * `output_path` - Path to write the native format file
///
/// # Errors
/// Returns an error if the native backend fails to serialize the tensors or
/// write the output file.
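///
/// # Examples
///
/// Mirrors [`export_gguf`] for the native backend; the `.qnat` extension
/// follows the tests in this module and is otherwise an assumption.
///
/// ```no_run
/// use candle_core::{DType, Device, Tensor};
/// use qlora_rs::formats::export_native_format;
/// use qlora_rs::quantization::quantize_nf4;
///
/// let tensor = Tensor::zeros(&[16, 16], DType::F32, &Device::Cpu).unwrap();
/// let quantized = quantize_nf4(&tensor, 64).unwrap();
/// export_native_format(&[("weights", &quantized)], "model.qnat").unwrap();
/// ```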
pub fn export_native_format<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    output_path: P,
) -> Result<()> {
    export_model(tensors, ExportConfig::new_native(), output_path)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::quantize_nf4;
    use candle_core::{Device, Tensor};

    #[test]
    fn test_export_config_builder() {
        let config = ExportConfig::default()
            .with_format(ExportFormat::Native)
            .with_model_name("my_model".to_string());

        assert_eq!(config.format, ExportFormat::Native);
        assert_eq!(config.model_name, "my_model");
    }

    #[test]
    fn test_export_gguf_via_unified_api() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_unified_gguf.gguf");
        export_gguf(&[("weights", &quantized)], &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_native_via_unified_api() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_unified_native.qnat");
        export_native_format(&[("weights", &quantized)], &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_model_with_config() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let config = ExportConfig::default()
            .with_format(ExportFormat::Native)
            .with_model_name("test_model".to_string())
            .with_model_type("test".to_string());

        let temp_path = std::env::temp_dir().join("test_config_export.qnat");
        export_model(&[("weights", &quantized)], config, &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }
}