Skip to main content

entrenar/cli/commands/
quantize.rs

1//! Quantize command implementation
2
3use crate::cli::logging::log;
4use crate::cli::LogLevel;
5use crate::config::{QuantMethod, QuantizeArgs};
6use crate::quant::{quantize_tensor, QuantGranularity, QuantMode, QuantizedTensor};
7use safetensors::SafeTensors;
8use std::collections::HashMap;
9
10/// Load and deserialize a SafeTensors model from disk.
11fn load_safetensors(args: &QuantizeArgs) -> Result<Vec<u8>, String> {
12    std::fs::read(&args.model).map_err(|e| format!("Failed to read model file: {e}"))
13}
14
15/// Serialize quantized tensors and write to output path (JSON format).
16fn save_quantized_json(
17    quantized_tensors: &HashMap<String, QuantizedTensor>,
18    args: &QuantizeArgs,
19) -> Result<(), String> {
20    let output_data = serde_json::to_vec_pretty(quantized_tensors)
21        .map_err(|e| format!("Failed to serialize: {e}"))?;
22
23    std::fs::write(&args.output, &output_data)
24        .map_err(|e| format!("Failed to write output: {e}"))?;
25
26    Ok(())
27}
28
29/// Serialize quantized tensors to SafeTensors format with I8 dtype + scale tensors.
30///
31/// For each tensor `name`, outputs:
32/// - `name` → I8 data (the quantized weights)
33/// - `name.__scale` → F32 scale factors (per-tensor or per-channel)
34fn save_quantized_safetensors(
35    quantized_tensors: &HashMap<String, QuantizedTensor>,
36    args: &QuantizeArgs,
37) -> Result<(), String> {
38    use safetensors::tensor::{Dtype, TensorView};
39
40    // Collect data buffers that live long enough for TensorView references
41    let mut i8_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
42    let mut scale_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
43
44    for (name, qt) in quantized_tensors {
45        // I8 data: reinterpret i8 as u8 for safetensors byte storage
46        let i8_bytes: Vec<u8> = qt.data.iter().map(|&v| v as u8).collect();
47        i8_buffers.push((name.clone(), i8_bytes, qt.shape.clone()));
48
49        // Scale factors as F32
50        let scale_name = format!("{name}.__scale");
51        let scale_bytes: Vec<u8> = qt.params.scales.iter().flat_map(|s| s.to_le_bytes()).collect();
52        let scale_shape = vec![qt.params.scales.len()];
53        scale_buffers.push((scale_name, scale_bytes, scale_shape));
54    }
55
56    // Build TensorViews
57    let mut views: Vec<(&str, TensorView<'_>)> = Vec::new();
58
59    for (name, bytes, shape) in &i8_buffers {
60        let view = TensorView::new(Dtype::I8, shape.clone(), bytes)
61            .map_err(|e| format!("Failed to create I8 TensorView for {name}: {e}"))?;
62        views.push((name.as_str(), view));
63    }
64
65    for (name, bytes, shape) in &scale_buffers {
66        let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
67            .map_err(|e| format!("Failed to create F32 TensorView for {name}: {e}"))?;
68        views.push((name.as_str(), view));
69    }
70
71    // Metadata
72    let mut metadata = HashMap::new();
73    metadata.insert("quantization".to_string(), format!("int{}", args.bits));
74    metadata.insert("method".to_string(), format!("{:?}", args.method).to_lowercase());
75    metadata.insert("num_tensors".to_string(), quantized_tensors.len().to_string());
76
77    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
78        .map_err(|e| format!("SafeTensors serialization failed: {e}"))?;
79
80    std::fs::write(&args.output, safetensor_bytes)
81        .map_err(|e| format!("Failed to write output: {e}"))?;
82
83    Ok(())
84}
85
86/// Validate and convert CLI arguments to quant module types.
87fn resolve_quant_params(args: &QuantizeArgs) -> Result<(QuantMode, QuantGranularity), String> {
88    if args.bits != 4 && args.bits != 8 {
89        return Err(format!("Unsupported bit width: {}. Use 4 or 8.", args.bits));
90    }
91
92    let mode = match args.method {
93        QuantMethod::Symmetric => QuantMode::Symmetric,
94        QuantMethod::Asymmetric => QuantMode::Asymmetric,
95    };
96
97    let granularity =
98        if args.per_channel { QuantGranularity::PerChannel } else { QuantGranularity::PerTensor };
99
100    Ok((mode, granularity))
101}
102
/// Running totals of input and output sizes, used to report a compression
/// ratio.
struct ByteAccumulator {
    /// Total size in bytes of the tensors before quantization.
    original: usize,
    /// Total size in bytes of the tensors after quantization.
    quantized: usize,
}

impl ByteAccumulator {
    /// Create an accumulator with both totals at zero.
    fn new() -> Self {
        ByteAccumulator { original: 0, quantized: 0 }
    }

    /// Ratio of original to quantized bytes.
    ///
    /// Yields `1.0` when nothing has been accumulated on the quantized side,
    /// which avoids dividing by zero.
    fn compression_ratio(&self) -> f64 {
        match self.quantized {
            0 => 1.0,
            packed => self.original as f64 / packed as f64,
        }
    }
}
122
123/// Quantize a single F32 tensor and return the result with byte accounting.
124fn quantize_single_tensor(
125    tensor: &safetensors::tensor::TensorView<'_>,
126    granularity: QuantGranularity,
127    mode: QuantMode,
128    bits: u8,
129) -> (QuantizedTensor, usize) {
130    let shape: Vec<usize> = tensor.shape().to_vec();
131    let num_elements: usize = shape.iter().product();
132    let original_bytes = num_elements * 4; // 4 bytes per f32
133
134    let bytes = tensor.data();
135    let values: Vec<f32> = bytes
136        .chunks_exact(4)
137        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
138        .collect();
139
140    let quantized = quantize_tensor(&values, &shape, granularity, mode, bits);
141    (quantized, original_bytes)
142}
143
144/// Log verbose details about quantization arguments.
145fn log_quant_args(args: &QuantizeArgs, level: LogLevel) {
146    log(level, LogLevel::Verbose, &format!("  Method: {:?}", args.method));
147    log(level, LogLevel::Verbose, &format!("  Per-channel: {}", args.per_channel));
148    log(level, LogLevel::Verbose, &format!("  Output: {}", args.output.display()));
149}
150
151pub fn run_quantize(args: QuantizeArgs, level: LogLevel) -> Result<(), String> {
152    log(
153        level,
154        LogLevel::Normal,
155        &format!("Quantizing {} to {}-bit", args.model.display(), args.bits),
156    );
157
158    log_quant_args(&args, level);
159
160    let (mode, granularity) = resolve_quant_params(&args)?;
161
162    let data = load_safetensors(&args)?;
163    let tensors =
164        SafeTensors::deserialize(&data).map_err(|e| format!("Failed to parse safetensors: {e}"))?;
165
166    let mut quantized_tensors: HashMap<String, QuantizedTensor> = HashMap::new();
167    let mut bytes = ByteAccumulator::new();
168
169    for name in tensors.names() {
170        let tensor =
171            tensors.tensor(name).map_err(|e| format!("Failed to get tensor {name}: {e}"))?;
172
173        if tensor.dtype() != safetensors::tensor::Dtype::F32 {
174            log(level, LogLevel::Verbose, &format!("  Skipping {name} (not F32)"));
175            continue;
176        }
177
178        let (quantized, original_bytes) =
179            quantize_single_tensor(&tensor, granularity, mode, args.bits);
180        bytes.original += original_bytes;
181        bytes.quantized += quantized.memory_bytes();
182
183        log(
184            level,
185            LogLevel::Verbose,
186            &format!(
187                "  Quantized {}: {:?} -> {} bytes",
188                name,
189                tensor.shape(),
190                quantized.memory_bytes()
191            ),
192        );
193
194        quantized_tensors.insert((*name).to_string(), quantized);
195    }
196
197    if args.safetensors {
198        save_quantized_safetensors(&quantized_tensors, &args)?;
199    } else {
200        save_quantized_json(&quantized_tensors, &args)?;
201    }
202
203    log(
204        level,
205        LogLevel::Normal,
206        &format!(
207            "Quantization complete: {} tensors, {:.1}x compression",
208            quantized_tensors.len(),
209            bytes.compression_ratio()
210        ),
211    );
212    log(level, LogLevel::Normal, &format!("  Output: {}", args.output.display()));
213
214    Ok(())
215}
216
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use safetensors::tensor::{Dtype, TensorView};
    use std::path::{Path, PathBuf};

    /// Write a small safetensors fixture to `path`: one 8x8 F32 tensor named
    /// `test_weight` holding the values `(i - 32) * 0.1` for `i` in `0..64`.
    fn create_test_safetensors(path: &Path) {
        let mut bytes: Vec<u8> = Vec::with_capacity(64 * 4);
        for i in 0..64 {
            let value = (i as f32 - 32.0) * 0.1;
            bytes.extend_from_slice(&value.to_le_bytes());
        }

        let view = TensorView::new(Dtype::F32, vec![8, 8], &bytes).unwrap();
        let serialized =
            safetensors::serialize(vec![("test_weight", view)], None::<HashMap<String, String>>)
                .unwrap();
        std::fs::write(path, serialized).unwrap();
    }

    /// Build the arguments shared by every test here: 8-bit symmetric,
    /// per-tensor, no calibration data. Only the paths and the output format
    /// vary between tests.
    fn int8_symmetric_args(model: PathBuf, output: PathBuf, as_safetensors: bool) -> QuantizeArgs {
        QuantizeArgs {
            model,
            output,
            bits: 8,
            method: crate::config::QuantMethod::Symmetric,
            per_channel: false,
            calibration_data: None,
            safetensors: as_safetensors,
        }
    }

    /// SafeTensors output holds the I8 weight tensor and its F32 scale tensor.
    #[test]
    fn test_wasm002_quantize_safetensors_int8_output() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.safetensors");
        create_test_safetensors(&input_path);

        let args = int8_symmetric_args(input_path, output_path.clone(), true);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize should succeed");

        // The output must parse as a valid safetensors file.
        let data = std::fs::read(&output_path).unwrap();
        let tensors = SafeTensors::deserialize(&data).unwrap();

        // Both the weight tensor and its companion scale tensor are present.
        let names: Vec<&str> = tensors.names().into_iter().collect();
        assert!(names.contains(&"test_weight"), "Must contain weight tensor");
        assert!(names.contains(&"test_weight.__scale"), "Must contain scale tensor");

        // The weight keeps its shape and is stored as one byte per element.
        let weight = tensors.tensor("test_weight").unwrap();
        assert_eq!(weight.dtype(), Dtype::I8);
        assert_eq!(weight.shape(), &[8, 8]);
        assert_eq!(weight.data().len(), 64); // 64 i8 values = 64 bytes

        let scale = tensors.tensor("test_weight.__scale").unwrap();
        assert_eq!(scale.dtype(), Dtype::F32);
    }

    /// The quantized SafeTensors file is smaller on disk than the F32 input.
    #[test]
    fn test_wasm002_quantize_safetensors_compression() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.safetensors");
        create_test_safetensors(&input_path);

        let args = int8_symmetric_args(input_path.clone(), output_path.clone(), true);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");

        let input_size = std::fs::metadata(&input_path).unwrap().len();
        let output_size = std::fs::metadata(&output_path).unwrap().len();

        // Only a strict size reduction is required here, not a specific ratio.
        assert!(
            output_size < input_size,
            "Int8 output ({output_size}) must be smaller than F32 input ({input_size})"
        );
    }

    /// With the `safetensors` flag off, the output is a JSON object keyed by
    /// tensor name.
    #[test]
    fn test_wasm002_quantize_json_still_works() {
        let dir = tempfile::tempdir().unwrap();
        let input_path = dir.path().join("model.safetensors");
        let output_path = dir.path().join("model_int8.json");
        create_test_safetensors(&input_path);

        let args = int8_symmetric_args(input_path, output_path.clone(), false);
        run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");

        let json = std::fs::read_to_string(&output_path).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_object());
        assert!(parsed.get("test_weight").is_some());
    }
}