//! Unified export entry point for quantized tensors, supporting GGUF and the
//! Candle-native on-disk format.

use std::path::Path;

use crate::error::Result;
use crate::quantization::QuantizedTensor;
use crate::{export, native};

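/// On-disk formats supported by [`export_model`].
///
/// The `Display` impl below yields a human-readable label, e.g. (sketch):
///
/// ```ignore
/// assert_eq!(ExportFormat::default().to_string(), "GGUF");
/// ```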
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExportFormat {
    /// GGUF container format (the default).
    #[default]
    Gguf,
    /// Candle-native serialization format.
    Native,
}

impl std::fmt::Display for ExportFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gguf => write!(f, "GGUF"),
            Self::Native => write!(f, "Candle Native"),
        }
    }
}

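/// Configuration for [`export_model`]: which format to write and the metadata
/// recorded alongside the tensors.
///
/// A minimal sketch of the builder-style helpers defined below (the name and
/// type strings are illustrative):
///
/// ```ignore
/// let config = ExportConfig::new_native()
///     .with_model_name("my-adapter".to_string())
///     .with_model_type("llama".to_string());
/// ```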
#[derive(Debug, Clone)]
pub struct ExportConfig {
    /// Target export format.
    pub format: ExportFormat,
    /// Model name written into the exported file's metadata.
    pub model_name: String,
    /// Model type tag written into the exported file's metadata.
    pub model_type: String,
}

impl Default for ExportConfig {
    fn default() -> Self {
        Self {
            format: ExportFormat::Gguf,
            model_name: "qlora-model".to_string(),
            model_type: "qlora".to_string(),
        }
    }
}

impl ExportConfig {
    /// Creates a config targeting the GGUF format.
    #[must_use]
    pub fn new_gguf() -> Self {
        Self {
            format: ExportFormat::Gguf,
            ..Default::default()
        }
    }

    /// Creates a config targeting the Candle-native format.
    #[must_use]
    pub fn new_native() -> Self {
        Self {
            format: ExportFormat::Native,
            ..Default::default()
        }
    }

    /// Sets the export format.
    #[must_use]
    pub fn with_format(mut self, format: ExportFormat) -> Self {
        self.format = format;
        self
    }

    /// Sets the model name recorded in the exported metadata.
    #[must_use]
    pub fn with_model_name(mut self, name: String) -> Self {
        self.model_name = name;
        self
    }

    /// Sets the model type recorded in the exported metadata.
    #[must_use]
    pub fn with_model_type(mut self, model_type: String) -> Self {
        self.model_type = model_type;
        self
    }
}

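/// Exports quantized tensors to `output_path` in the format selected by `config`.
///
/// A minimal sketch of the end-to-end flow, assuming weights quantized with
/// `quantize_nf4` as in the tests below (the shape, name, and path are
/// illustrative):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let weights = Tensor::zeros(&[32, 32], DType::F32, &Device::Cpu)?;
/// let quantized = quantize_nf4(&weights, 64)?;
///
/// let config = ExportConfig::new_gguf().with_model_name("my-model".to_string());
/// export_model(&[("weights", &quantized)], config, "my-model.gguf")?;
/// ```
///
/// # Errors
///
/// Returns an error if the underlying GGUF or native writer fails.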
pub fn export_model<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    config: ExportConfig,
    output_path: P,
) -> Result<()> {
    match config.format {
        ExportFormat::Gguf => {
            let metadata = export::GgufMetadata {
                model_name: config.model_name,
                model_type: config.model_type,
                model_size: tensors.iter().map(|(_, t)| t.numel()).sum(),
            };
            export::export_gguf(tensors, Some(metadata), output_path)
        }
        ExportFormat::Native => {
            let metadata = native::NativeMetadata {
                model_name: config.model_name,
                model_type: config.model_type,
                compute_dtype: crate::quantization::ComputeDType::F32,
            };
            native::export_native(tensors, Some(metadata), output_path)
        }
    }
}

/// Convenience wrapper that exports `tensors` as GGUF using the default config.
///
/// # Errors
///
/// Returns an error if the GGUF writer fails.
pub fn export_gguf<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    output_path: P,
) -> Result<()> {
    export_model(tensors, ExportConfig::new_gguf(), output_path)
}

/// Convenience wrapper that exports `tensors` in the Candle-native format using
/// the default config.
///
/// # Errors
///
/// Returns an error if the native writer fails.
pub fn export_native_format<P: AsRef<Path>>(
    tensors: &[(&str, &QuantizedTensor)],
    output_path: P,
) -> Result<()> {
    export_model(tensors, ExportConfig::new_native(), output_path)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::quantize_nf4;
    use candle_core::{Device, Tensor};

    #[test]
    fn test_export_config_builder() {
        let config = ExportConfig::default()
            .with_format(ExportFormat::Native)
            .with_model_name("my_model".to_string());

        assert_eq!(config.format, ExportFormat::Native);
        assert_eq!(config.model_name, "my_model");
    }

    #[test]
    fn test_export_gguf_via_unified_api() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_unified_gguf.gguf");
        export_gguf(&[("weights", &quantized)], &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_native_via_unified_api() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let temp_path = std::env::temp_dir().join("test_unified_native.qnat");
        export_native_format(&[("weights", &quantized)], &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }

    #[test]
    fn test_export_model_with_config() {
        let device = Device::Cpu;
        let tensor = Tensor::zeros(&[32, 32], candle_core::DType::F32, &device).unwrap();
        let quantized = quantize_nf4(&tensor, 64).unwrap();

        let config = ExportConfig::default()
            .with_format(ExportFormat::Native)
            .with_model_name("test_model".to_string())
            .with_model_type("test".to_string());

        let temp_path = std::env::temp_dir().join("test_config_export.qnat");
        export_model(&[("weights", &quantized)], config, &temp_path).unwrap();

        assert!(std::fs::metadata(&temp_path).is_ok());
        std::fs::remove_file(temp_path).ok();
    }
}