Skip to main content

entrenar/prune/pipeline/
prune_quantize.rs

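//! Prune-then-quantize pipeline.
//!
//! Combines pruning and quantization into a single operation:
//! 1. Prune weights (zero the smallest-magnitude values, pooled across all tensors).
//! 2. Quantize every tensor to Q4_0 or Q8_0 and dequantize it back to `f32`, so the stored
//!    values carry the quantization rounding.
//! 3. Export the pruned, quantization-rounded weights as `f32` SafeTensors.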
7
8use crate::quant::{Q4_0, Q8_0};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
13/// Quantization format for the prune-quantize pipeline
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[allow(dead_code)]
16enum PruneQuantFormat {
17    /// 4-bit quantization
18    Q4_0,
19    /// 8-bit quantization
20    Q8_0,
21}
22
23/// Configuration for prune-then-quantize pipeline
24#[derive(Debug, Clone)]
25#[allow(dead_code)]
26struct PruneQuantConfig {
27    /// Target sparsity (0.0 to 1.0)
28    target_sparsity: f32,
29    /// Quantization format
30    quant_format: PruneQuantFormat,
31}
32
33/// Result of prune-then-quantize pipeline
34#[derive(Debug, Clone)]
35#[allow(dead_code)]
36struct PruneQuantizeResult {
37    /// Path to output file
38    output_path: PathBuf,
39    /// Achieved sparsity before quantization
40    achieved_sparsity: f32,
41    /// Quantization format used
42    quant_format: PruneQuantFormat,
43    /// Number of tensors
44    num_tensors: usize,
45    /// Output file size in bytes
46    file_size: u64,
47}
48
/// Apply global magnitude pruning: zero the weights with the smallest absolute values.
///
/// All tensors are pooled, so the threshold is global rather than per tensor. The threshold is
/// the magnitude at rank `floor(total * target_sparsity)` in ascending order. Every weight whose
/// magnitude is strictly below the threshold becomes `0.0`. Weights that tie with the threshold
/// are kept, so the achieved sparsity may fall short of the target but never exceeds it.
///
/// Magnitudes are ordered with `f32::total_cmp`, which ranks NaN above every number. NaN weights
/// therefore survive partial pruning instead of corrupting the ordering of the other weights.
///
/// * `weights` - tensors keyed by name, pruned in place.
/// * `target_sparsity` - fraction of weights to zero. Values `<= 0.0` or NaN leave the weights
///   untouched; values `>= 1.0` zero every weight (including infinities and NaN).
///
/// Returns `(pruned, total)`. `pruned` is how many weights were set to zero, including weights
/// that were already zero and fell below the threshold. `total` is the overall number of weights.
#[allow(dead_code)] // no non-test caller in this file yet
fn magnitude_prune(
    weights: &mut HashMap<String, Vec<f32>>,
    target_sparsity: f32,
) -> (usize, usize) {
    let total: usize = weights.values().map(Vec::len).sum();
    if total == 0 || target_sparsity.is_nan() || target_sparsity <= 0.0 {
        return (0, total);
    }

    let prune_count = ((total as f32 * target_sparsity) as usize).min(total);
    if prune_count == total {
        // No finite threshold can catch infinities or NaN, so zero everything directly.
        weights.values_mut().for_each(|data| data.fill(0.0));
        return (total, total);
    }

    // Only the magnitude at rank `prune_count` matters, so a linear-time selection replaces a
    // full sort. `total_cmp` is a total order, so NaN cannot make the selection inconsistent.
    let mut magnitudes: Vec<f32> = weights.values().flatten().map(|v| v.abs()).collect();
    let threshold = *magnitudes.select_nth_unstable_by(prune_count, f32::total_cmp).1;

    let mut pruned = 0;
    for value in weights.values_mut().flatten() {
        // `total_cmp` rather than `<`: a NaN threshold must still prune every number below it.
        if value.abs().total_cmp(&threshold).is_lt() {
            *value = 0.0;
            pruned += 1;
        }
    }

    (pruned, total)
}
82
83/// Prune model weights and quantize, saving as SafeTensors with quantized data
84///
85/// Pipeline:
86/// 1. Apply magnitude pruning to reach target sparsity
87/// 2. Quantize all tensors to Q4_0 or Q8_0
88/// 3. Dequantize back to f32 for SafeTensors storage (preserving quantization effects)
89/// 4. Export as SafeTensors
90#[allow(dead_code)]
91fn prune_and_quantize(
92    weights: &HashMap<String, Vec<f32>>,
93    shapes: &HashMap<String, Vec<usize>>,
94    config: &PruneQuantConfig,
95    output_dir: impl AsRef<Path>,
96    filename: &str,
97) -> Result<PruneQuantizeResult, std::io::Error> {
98    let output_dir = output_dir.as_ref();
99    std::fs::create_dir_all(output_dir)?;
100
101    // Clone weights for pruning
102    let mut pruned_weights = weights.clone();
103
104    // Step 1: Prune
105    let (pruned_count, total_count) = magnitude_prune(&mut pruned_weights, config.target_sparsity);
106    let achieved_sparsity =
107        if total_count > 0 { pruned_count as f32 / total_count as f32 } else { 0.0 };
108
109    // Step 2: Quantize then dequantize (applies quantization rounding)
110    let quantized_weights: HashMap<String, Vec<f32>> = pruned_weights
111        .iter()
112        .map(|(name, data)| {
113            let deq = match config.quant_format {
114                PruneQuantFormat::Q4_0 => Q4_0::quantize(data).dequantize(),
115                PruneQuantFormat::Q8_0 => Q8_0::quantize(data).dequantize(),
116            };
117            (name.clone(), deq)
118        })
119        .collect();
120
121    // Step 3: Export as SafeTensors
122    use safetensors::tensor::{Dtype, TensorView};
123
124    let mut sorted_names: Vec<&String> = quantized_weights.keys().collect();
125    sorted_names.sort();
126
127    let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = sorted_names
128        .iter()
129        .map(|name| {
130            let data = &quantized_weights[*name];
131            let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
132            let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
133            ((*name).clone(), bytes, shape)
134        })
135        .collect();
136
137    let views: Vec<(&str, TensorView<'_>)> = tensor_data
138        .iter()
139        .map(|(name, bytes, shape)| {
140            let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
141                .expect("TensorView construction must not fail for valid F32 data");
142            (name.as_str(), view)
143        })
144        .collect();
145
146    let mut metadata = std::collections::HashMap::new();
147    metadata.insert("sparsity".to_string(), format!("{achieved_sparsity:.4}"));
148    metadata.insert(
149        "quantization".to_string(),
150        match config.quant_format {
151            PruneQuantFormat::Q4_0 => "Q4_0".to_string(),
152            PruneQuantFormat::Q8_0 => "Q8_0".to_string(),
153        },
154    );
155
156    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
157        .map_err(|e| std::io::Error::other(e.to_string()))?;
158
159    let output_path = output_dir.join(filename);
160    std::fs::write(&output_path, &safetensor_bytes)?;
161
162    Ok(PruneQuantizeResult {
163        output_path,
164        achieved_sparsity,
165        quant_format: config.quant_format,
166        num_tensors: sorted_names.len(),
167        file_size: safetensor_bytes.len() as u64,
168    })
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// One 8x8 tensor holding `(i - 32) * 0.1` for `i in 0..64`, plus its shape map.
    fn make_test_weights() -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();

        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        weights.insert("layer.0.weight".to_string(), data);
        shapes.insert("layer.0.weight".to_string(), vec![8, 8]);

        (weights, shapes)
    }

    #[test]
    fn test_prune_and_quantize_q4() {
        let (weights, shapes) = make_test_weights();
        let config =
            PruneQuantConfig { target_sparsity: 0.5, quant_format: PruneQuantFormat::Q4_0 };
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            prune_and_quantize(&weights, &shapes, &config, tmp.path(), "pruned.safetensors")
                .expect("prune_and_quantize should succeed");

        // Pruning is strictly below the threshold: it may undershoot the target, never overshoot.
        assert!(result.achieved_sparsity >= 0.3);
        assert!(result.achieved_sparsity <= 0.5);
        assert_eq!(result.quant_format, PruneQuantFormat::Q4_0);
        assert_eq!(result.num_tensors, 1);
        assert!(result.output_path.exists());
        assert!(result.file_size > 0);

        let on_disk = std::fs::metadata(&result.output_path)
            .expect("output metadata should be readable")
            .len();
        assert_eq!(on_disk, result.file_size);
    }

    #[test]
    fn test_prune_and_quantize_q8() {
        let (weights, shapes) = make_test_weights();
        let config =
            PruneQuantConfig { target_sparsity: 0.3, quant_format: PruneQuantFormat::Q8_0 };
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            prune_and_quantize(&weights, &shapes, &config, tmp.path(), "pruned-q8.safetensors")
                .expect("prune_and_quantize should succeed");

        assert_eq!(result.quant_format, PruneQuantFormat::Q8_0);
        assert!(result.achieved_sparsity > 0.0);
        assert!(result.achieved_sparsity <= 0.3);
        assert_eq!(result.num_tensors, 1);
        assert!(result.output_path.exists());
        assert!(result.file_size > 0);
    }

    #[test]
    fn test_prune_and_quantize_no_sparsity() {
        let (weights, shapes) = make_test_weights();
        let config =
            PruneQuantConfig { target_sparsity: 0.0, quant_format: PruneQuantFormat::Q4_0 };
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            prune_and_quantize(&weights, &shapes, &config, tmp.path(), "unpruned.safetensors")
                .expect("prune_and_quantize should succeed");

        assert_eq!(result.achieved_sparsity, 0.0);
        assert!(result.output_path.exists());
        // The pipeline works on a copy; the caller's weights must be left untouched.
        assert_eq!(weights, make_test_weights().0);
    }

    #[test]
    fn test_magnitude_prune_basic() {
        let mut weights = HashMap::new();
        weights.insert("w".to_string(), vec![0.1, 0.5, 0.01, 0.8, 0.02, 0.9]);

        let (pruned, total) = magnitude_prune(&mut weights, 0.5);
        assert_eq!(total, 6);
        assert_eq!(pruned, 3);

        // 0.01, 0.02 and 0.1 are the three smallest magnitudes.
        assert_eq!(weights["w"], vec![0.0, 0.5, 0.0, 0.8, 0.0, 0.9]);
    }

    #[test]
    fn test_magnitude_prune_is_global_across_tensors() {
        let mut weights = HashMap::new();
        weights.insert("a".to_string(), vec![1.0, 2.0]);
        weights.insert("b".to_string(), vec![3.0, 0.5]);

        assert_eq!(magnitude_prune(&mut weights, 0.5), (2, 4));
        assert_eq!(weights["a"], vec![0.0, 2.0]);
        assert_eq!(weights["b"], vec![3.0, 0.0]);
    }

    #[test]
    fn test_magnitude_prune_full_sparsity() {
        let mut weights = HashMap::new();
        weights.insert("w".to_string(), vec![1.0, -2.0, 3.0]);

        assert_eq!(magnitude_prune(&mut weights, 1.0), (3, 3));
        assert!(weights["w"].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_magnitude_prune_empty() {
        let mut weights: HashMap<String, Vec<f32>> = HashMap::new();
        assert_eq!(magnitude_prune(&mut weights, 0.5), (0, 0));
    }

    #[test]
    fn test_output_safetensors_valid() {
        let (weights, shapes) = make_test_weights();
        let config =
            PruneQuantConfig { target_sparsity: 0.5, quant_format: PruneQuantFormat::Q4_0 };
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result = prune_and_quantize(&weights, &shapes, &config, tmp.path(), "test.safetensors")
            .expect("prune_and_quantize should succeed");

        let data = std::fs::read(&result.output_path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 1);

        let tensor = loaded.tensor("layer.0.weight").expect("tensor should be present");
        assert_eq!(tensor.dtype(), safetensors::tensor::Dtype::F32);
        assert_eq!(tensor.shape(), &[8, 8][..]);
        assert_eq!(tensor.data().len(), 64 * std::mem::size_of::<f32>());

        // Check metadata
        let (_, meta) = safetensors::SafeTensors::read_metadata(&data)
            .expect("header deserialization should succeed");
        let md = meta.metadata().as_ref().expect("metadata should be present");
        assert_eq!(md.get("quantization").map(String::as_str), Some("Q4_0"));
        let sparsity: f32 = md["sparsity"].parse().expect("sparsity should be numeric");
        assert!(sparsity > 0.0 && sparsity <= 0.5);
    }
}