1use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{QuantMethod, QuantizeArgs};
6use crate::quant::{quantize_tensor, QuantGranularity, QuantMode, QuantizedTensor};
7use safetensors::SafeTensors;
8use std::collections::HashMap;
9
10fn load_safetensors(args: &QuantizeArgs) -> Result<Vec<u8>, String> {
12 std::fs::read(&args.model).map_err(|e| format!("Failed to read model file: {e}"))
13}
14
15fn save_quantized_json(
17 quantized_tensors: &HashMap<String, QuantizedTensor>,
18 args: &QuantizeArgs,
19) -> Result<(), String> {
20 let output_data = serde_json::to_vec_pretty(quantized_tensors)
21 .map_err(|e| format!("Failed to serialize: {e}"))?;
22
23 std::fs::write(&args.output, &output_data)
24 .map_err(|e| format!("Failed to write output: {e}"))?;
25
26 Ok(())
27}
28
29fn save_quantized_safetensors(
35 quantized_tensors: &HashMap<String, QuantizedTensor>,
36 args: &QuantizeArgs,
37) -> Result<(), String> {
38 use safetensors::tensor::{Dtype, TensorView};
39
40 let mut i8_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
42 let mut scale_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
43
44 for (name, qt) in quantized_tensors {
45 let i8_bytes: Vec<u8> = qt.data.iter().map(|&v| v as u8).collect();
47 i8_buffers.push((name.clone(), i8_bytes, qt.shape.clone()));
48
49 let scale_name = format!("{name}.__scale");
51 let scale_bytes: Vec<u8> = qt.params.scales.iter().flat_map(|s| s.to_le_bytes()).collect();
52 let scale_shape = vec![qt.params.scales.len()];
53 scale_buffers.push((scale_name, scale_bytes, scale_shape));
54 }
55
56 let mut views: Vec<(&str, TensorView<'_>)> = Vec::new();
58
59 for (name, bytes, shape) in &i8_buffers {
60 let view = TensorView::new(Dtype::I8, shape.clone(), bytes)
61 .map_err(|e| format!("Failed to create I8 TensorView for {name}: {e}"))?;
62 views.push((name.as_str(), view));
63 }
64
65 for (name, bytes, shape) in &scale_buffers {
66 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
67 .map_err(|e| format!("Failed to create F32 TensorView for {name}: {e}"))?;
68 views.push((name.as_str(), view));
69 }
70
71 let mut metadata = HashMap::new();
73 metadata.insert("quantization".to_string(), format!("int{}", args.bits));
74 metadata.insert("method".to_string(), format!("{:?}", args.method).to_lowercase());
75 metadata.insert("num_tensors".to_string(), quantized_tensors.len().to_string());
76
77 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
78 .map_err(|e| format!("SafeTensors serialization failed: {e}"))?;
79
80 std::fs::write(&args.output, safetensor_bytes)
81 .map_err(|e| format!("Failed to write output: {e}"))?;
82
83 Ok(())
84}
85
86fn resolve_quant_params(args: &QuantizeArgs) -> Result<(QuantMode, QuantGranularity), String> {
88 if args.bits != 4 && args.bits != 8 {
89 return Err(format!("Unsupported bit width: {}. Use 4 or 8.", args.bits));
90 }
91
92 let mode = match args.method {
93 QuantMethod::Symmetric => QuantMode::Symmetric,
94 QuantMethod::Asymmetric => QuantMode::Asymmetric,
95 };
96
97 let granularity =
98 if args.per_channel { QuantGranularity::PerChannel } else { QuantGranularity::PerTensor };
99
100 Ok((mode, granularity))
101}
102
/// Running byte totals used to report the overall compression ratio.
struct ByteAccumulator {
    /// Sum of the F32 sizes of every tensor that was quantized.
    original: usize,
    /// Sum of the in-memory sizes of the resulting quantized tensors.
    quantized: usize,
}

impl ByteAccumulator {
    /// Creates an accumulator with both totals at zero.
    fn new() -> Self {
        ByteAccumulator { quantized: 0, original: 0 }
    }

    /// Returns `original / quantized`, or `1.0` when nothing has been quantized yet.
    fn compression_ratio(&self) -> f64 {
        match self.quantized {
            // Avoid dividing by zero when no tensor was quantized.
            0 => 1.0,
            quantized => self.original as f64 / quantized as f64,
        }
    }
}
122
123fn quantize_single_tensor(
125 tensor: &safetensors::tensor::TensorView<'_>,
126 granularity: QuantGranularity,
127 mode: QuantMode,
128 bits: u8,
129) -> (QuantizedTensor, usize) {
130 let shape: Vec<usize> = tensor.shape().to_vec();
131 let num_elements: usize = shape.iter().product();
132 let original_bytes = num_elements * 4; let bytes = tensor.data();
135 let values: Vec<f32> = bytes
136 .chunks_exact(4)
137 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
138 .collect();
139
140 let quantized = quantize_tensor(&values, &shape, granularity, mode, bits);
141 (quantized, original_bytes)
142}
143
144fn log_quant_args(args: &QuantizeArgs, level: LogLevel) {
146 log(level, LogLevel::Verbose, &format!(" Method: {:?}", args.method));
147 log(level, LogLevel::Verbose, &format!(" Per-channel: {}", args.per_channel));
148 log(level, LogLevel::Verbose, &format!(" Output: {}", args.output.display()));
149}
150
151pub fn run_quantize(args: QuantizeArgs, level: LogLevel) -> Result<(), String> {
152 log(
153 level,
154 LogLevel::Normal,
155 &format!("Quantizing {} to {}-bit", args.model.display(), args.bits),
156 );
157
158 log_quant_args(&args, level);
159
160 let (mode, granularity) = resolve_quant_params(&args)?;
161
162 let data = load_safetensors(&args)?;
163 let tensors =
164 SafeTensors::deserialize(&data).map_err(|e| format!("Failed to parse safetensors: {e}"))?;
165
166 let mut quantized_tensors: HashMap<String, QuantizedTensor> = HashMap::new();
167 let mut bytes = ByteAccumulator::new();
168
169 for name in tensors.names() {
170 let tensor =
171 tensors.tensor(name).map_err(|e| format!("Failed to get tensor {name}: {e}"))?;
172
173 if tensor.dtype() != safetensors::tensor::Dtype::F32 {
174 log(level, LogLevel::Verbose, &format!(" Skipping {name} (not F32)"));
175 continue;
176 }
177
178 let (quantized, original_bytes) =
179 quantize_single_tensor(&tensor, granularity, mode, args.bits);
180 bytes.original += original_bytes;
181 bytes.quantized += quantized.memory_bytes();
182
183 log(
184 level,
185 LogLevel::Verbose,
186 &format!(
187 " Quantized {}: {:?} -> {} bytes",
188 name,
189 tensor.shape(),
190 quantized.memory_bytes()
191 ),
192 );
193
194 quantized_tensors.insert((*name).to_string(), quantized);
195 }
196
197 if args.safetensors {
198 save_quantized_safetensors(&quantized_tensors, &args)?;
199 } else {
200 save_quantized_json(&quantized_tensors, &args)?;
201 }
202
203 log(
204 level,
205 LogLevel::Normal,
206 &format!(
207 "Quantization complete: {} tensors, {:.1}x compression",
208 quantized_tensors.len(),
209 bytes.compression_ratio()
210 ),
211 );
212 log(level, LogLevel::Normal, &format!(" Output: {}", args.output.display()));
213
214 Ok(())
215}
216
#[cfg(test)]
// Tests are meant to panic on any unexpected failure, so `unwrap` is acceptable here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use safetensors::tensor::{Dtype, TensorView};

    /// Writes a safetensors file at `path` containing one 8x8 F32 tensor named `test_weight`,
    /// filled with 64 evenly spaced values from -3.2 to 3.1.
    fn create_test_safetensors(path: &std::path::Path) {
        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
        let view = TensorView::new(Dtype::F32, vec![8, 8], &bytes).unwrap();
        let views = vec![("test_weight", view)];
        let serialized = safetensors::serialize(views, None::<HashMap<String, String>>).unwrap();
        std::fs::write(path, serialized).unwrap();
    }

    /// Builds arguments for an 8-bit, symmetric, per-tensor run with no calibration data.
    fn int8_args(
        model: std::path::PathBuf,
        output: std::path::PathBuf,
        safetensors: bool,
    ) -> QuantizeArgs {
        QuantizeArgs {
            model,
            output,
            bits: 8,
            method: crate::config::QuantMethod::Symmetric,
            per_channel: false,
            calibration_data: None,
            safetensors,
        }
    }

    #[test]
    fn test_wasm002_quantize_safetensors_int8_output() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.safetensors");

        create_test_safetensors(&input_path);

        let args = int8_args(input_path, output_path.clone(), true);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize should succeed");

        let data = std::fs::read(&output_path).unwrap();
        let tensors = SafeTensors::deserialize(&data).unwrap();

        let names: Vec<&str> = tensors.names().into_iter().collect();
        assert!(names.contains(&"test_weight"), "Must contain weight tensor");
        assert!(names.contains(&"test_weight.__scale"), "Must contain scale tensor");

        let weight = tensors.tensor("test_weight").unwrap();
        assert_eq!(weight.dtype(), Dtype::I8);
        assert_eq!(weight.shape(), &[8, 8]);
        assert_eq!(weight.data().len(), 64);

        let scale = tensors.tensor("test_weight.__scale").unwrap();
        assert_eq!(scale.dtype(), Dtype::F32);
        assert_eq!(scale.shape().len(), 1, "Scales are stored as a rank-1 tensor");
    }

    #[test]
    fn test_wasm002_quantize_safetensors_compression() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.safetensors");

        create_test_safetensors(&input_path);

        let args = int8_args(input_path.clone(), output_path.clone(), true);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");

        let input_size = std::fs::metadata(&input_path).unwrap().len();
        let output_size = std::fs::metadata(&output_path).unwrap().len();

        assert!(
            output_size < input_size,
            "Int8 output ({output_size}) must be smaller than F32 input ({input_size})"
        );
    }

    #[test]
    fn test_wasm002_quantize_json_still_works() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.json");

        create_test_safetensors(&input_path);

        let args = int8_args(input_path, output_path.clone(), false);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");

        let json = std::fs::read_to_string(&output_path).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_object());
        assert!(parsed.get("test_weight").is_some());
    }

    #[test]
    fn test_quantize_rejects_unsupported_bit_width() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int3.safetensors");

        create_test_safetensors(&input_path);

        let mut args = int8_args(input_path, output_path.clone(), true);
        args.bits = 3;

        let err = run_quantize(args, crate::cli::LogLevel::Quiet).unwrap_err();
        assert!(err.contains("Unsupported bit width: 3"), "unexpected error: {err}");
        // Arguments are validated before any file is written.
        assert!(!output_path.exists(), "No output should be written for invalid arguments");
    }
}