1use crate::quant::{Q4_0, Q8_0};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
/// Quantization format applied to the weights after pruning.
///
/// Each variant selects the same-named quantizer imported from `crate::quant`.
/// The variant name is also written verbatim into the output file's
/// `"quantization"` metadata entry by `prune_and_quantize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): dead_code is silenced here; presumably this type is not yet
// referenced outside this file and its tests — confirm before removing.
#[allow(dead_code)]
enum PruneQuantFormat {
    /// Round-trip weights through `crate::quant::Q4_0`
    /// (presumably a 4-bit block format — confirm against `crate::quant`).
    Q4_0,
    /// Round-trip weights through `crate::quant::Q8_0`
    /// (presumably an 8-bit block format — confirm against `crate::quant`).
    Q8_0,
}
22
/// Settings for a single `prune_and_quantize` run.
#[derive(Debug, Clone)]
// NOTE(review): dead_code is silenced here; presumably this type is not yet
// referenced outside this file and its tests — confirm before removing.
#[allow(dead_code)]
struct PruneQuantConfig {
    /// Fraction of all weights to prune, forwarded unchanged to
    /// `magnitude_prune`. Values `<= 0.0` disable pruning there.
    target_sparsity: f32,
    /// Quantizer applied to every tensor once pruning is done.
    quant_format: PruneQuantFormat,
}
32
/// Summary returned by `prune_and_quantize` after the output file is written.
#[derive(Debug, Clone)]
// NOTE(review): dead_code is silenced here; presumably this type is not yet
// referenced outside this file and its tests — confirm before removing.
#[allow(dead_code)]
struct PruneQuantizeResult {
    /// Location of the written file: the output directory joined with the
    /// requested file name.
    output_path: PathBuf,
    /// Pruned-weight count divided by the total weight count as reported by
    /// `magnitude_prune`; `0.0` when there are no weights at all.
    achieved_sparsity: f32,
    /// Quantization format copied from the configuration that was used.
    quant_format: PruneQuantFormat,
    /// Number of tensors serialized into the file.
    num_tensors: usize,
    /// Length in bytes of the serialized safetensors buffer that was written.
    file_size: u64,
}
48
/// Zeroes the smallest-magnitude weights across all tensors (global magnitude pruning).
///
/// One cut-off is chosen over the pooled absolute values of every tensor: the
/// magnitude at rank `floor(total * target_sparsity)` in ascending order.
/// Every weight whose absolute value is *strictly* below the cut-off is set
/// to `0.0`. Because the comparison is strict, weights tied with the cut-off
/// are kept, so the achieved sparsity can be lower than the target.
///
/// # Arguments
/// * `weights` - tensors keyed by name; modified in place.
/// * `target_sparsity` - fraction of weights to prune. Values `<= 0.0` and NaN
///   prune nothing; values `>= 1.0` prune every weight whose magnitude is
///   below `f32::MAX`.
///
/// # Returns
/// `(pruned, total)`: the number of weights found below the cut-off (weights
/// that were already zero are counted when the cut-off is positive) and the
/// total number of weights across all tensors.
///
/// # NaN handling
/// Magnitudes are ranked with `f32::total_cmp`, a total order in which NaN
/// magnitudes rank after every other value, so the cut-off is well defined
/// even when NaN weights are present. NaN weights themselves are never zeroed.
#[allow(dead_code)]
fn magnitude_prune(
    weights: &mut HashMap<String, Vec<f32>>,
    target_sparsity: f32,
) -> (usize, usize) {
    let total: usize = weights.values().map(Vec::len).sum();
    if target_sparsity.is_nan() || target_sparsity <= 0.0 {
        return (0, total);
    }

    // The float-to-usize cast saturates, so oversized targets clamp to `total`.
    let prune_count = ((total as f32 * target_sparsity) as usize).min(total);

    let threshold = if prune_count < total {
        let mut magnitudes: Vec<f32> =
            weights.values().flat_map(|data| data.iter().map(|v| v.abs())).collect();
        // Only one order statistic is needed, so a linear-time selection
        // replaces a full sort. `total_cmp` keeps the comparator a total order
        // even with NaN present.
        let (_, cutoff, _) = magnitudes.select_nth_unstable_by(prune_count, f32::total_cmp);
        if cutoff.is_nan() {
            // The cut landed among the NaNs, so every comparable weight ranks
            // below it; use the same sentinel as the prune-everything case.
            f32::MAX
        } else {
            *cutoff
        }
    } else {
        f32::MAX
    };

    let mut pruned = 0;
    for val in weights.values_mut().flatten() {
        if val.abs() < threshold {
            *val = 0.0;
            pruned += 1;
        }
    }

    (pruned, total)
}
82
83#[allow(dead_code)]
91fn prune_and_quantize(
92 weights: &HashMap<String, Vec<f32>>,
93 shapes: &HashMap<String, Vec<usize>>,
94 config: &PruneQuantConfig,
95 output_dir: impl AsRef<Path>,
96 filename: &str,
97) -> Result<PruneQuantizeResult, std::io::Error> {
98 let output_dir = output_dir.as_ref();
99 std::fs::create_dir_all(output_dir)?;
100
101 let mut pruned_weights = weights.clone();
103
104 let (pruned_count, total_count) = magnitude_prune(&mut pruned_weights, config.target_sparsity);
106 let achieved_sparsity =
107 if total_count > 0 { pruned_count as f32 / total_count as f32 } else { 0.0 };
108
109 let quantized_weights: HashMap<String, Vec<f32>> = pruned_weights
111 .iter()
112 .map(|(name, data)| {
113 let deq = match config.quant_format {
114 PruneQuantFormat::Q4_0 => Q4_0::quantize(data).dequantize(),
115 PruneQuantFormat::Q8_0 => Q8_0::quantize(data).dequantize(),
116 };
117 (name.clone(), deq)
118 })
119 .collect();
120
121 use safetensors::tensor::{Dtype, TensorView};
123
124 let mut sorted_names: Vec<&String> = quantized_weights.keys().collect();
125 sorted_names.sort();
126
127 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = sorted_names
128 .iter()
129 .map(|name| {
130 let data = &quantized_weights[*name];
131 let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
132 let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
133 ((*name).clone(), bytes, shape)
134 })
135 .collect();
136
137 let views: Vec<(&str, TensorView<'_>)> = tensor_data
138 .iter()
139 .map(|(name, bytes, shape)| {
140 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
141 .expect("TensorView construction must not fail for valid F32 data");
142 (name.as_str(), view)
143 })
144 .collect();
145
146 let mut metadata = std::collections::HashMap::new();
147 metadata.insert("sparsity".to_string(), format!("{achieved_sparsity:.4}"));
148 metadata.insert(
149 "quantization".to_string(),
150 match config.quant_format {
151 PruneQuantFormat::Q4_0 => "Q4_0".to_string(),
152 PruneQuantFormat::Q8_0 => "Q8_0".to_string(),
153 },
154 );
155
156 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
157 .map_err(|e| std::io::Error::other(e.to_string()))?;
158
159 let output_path = output_dir.join(filename);
160 std::fs::write(&output_path, &safetensor_bytes)?;
161
162 Ok(PruneQuantizeResult {
163 output_path,
164 achieved_sparsity,
165 quant_format: config.quant_format,
166 num_tensors: sorted_names.len(),
167 file_size: safetensor_bytes.len() as u64,
168 })
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixture: one 8x8 tensor named `layer.0.weight` holding a linear ramp
    /// from -3.2 upward in steps of 0.1, plus its shape table.
    fn make_test_weights() -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>) {
        const NAME: &str = "layer.0.weight";

        let ramp: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        let weights = HashMap::from([(NAME.to_string(), ramp)]);
        let shapes = HashMap::from([(NAME.to_string(), vec![8, 8])]);

        (weights, shapes)
    }

    /// Runs the full pipeline on the fixture inside a fresh temporary
    /// directory. The directory guard is handed back so the written file
    /// outlives the call; the caller unwraps the outcome with its own message.
    fn run_on_fixture(
        target_sparsity: f32,
        quant_format: PruneQuantFormat,
        filename: &str,
    ) -> (TempDir, Result<PruneQuantizeResult, std::io::Error>) {
        let (weights, shapes) = make_test_weights();
        let config = PruneQuantConfig { target_sparsity, quant_format };
        let tmp = TempDir::new().expect("temp file creation should succeed");
        let outcome = prune_and_quantize(&weights, &shapes, &config, tmp.path(), filename);
        (tmp, outcome)
    }

    /// Q4_0 at 50% target: file exists, is non-empty, and sparsity is at least 30%.
    #[test]
    fn test_prune_and_quantize_q4() {
        let (_tmp, outcome) = run_on_fixture(0.5, PruneQuantFormat::Q4_0, "pruned.safetensors");
        let result = outcome.expect("operation should succeed");

        assert!(result.achieved_sparsity >= 0.3);
        assert_eq!(result.quant_format, PruneQuantFormat::Q4_0);
        assert!(result.output_path.exists());
        assert!(result.file_size > 0);
    }

    /// Q8_0 at 30% target: the format is echoed back and the file is non-empty.
    #[test]
    fn test_prune_and_quantize_q8() {
        let (_tmp, outcome) = run_on_fixture(0.3, PruneQuantFormat::Q8_0, "pruned-q8.safetensors");
        let result = outcome.expect("operation should succeed");

        assert_eq!(result.quant_format, PruneQuantFormat::Q8_0);
        assert!(result.file_size > 0);
    }

    /// A zero target must report exactly zero achieved sparsity.
    #[test]
    fn test_prune_and_quantize_no_sparsity() {
        let (_tmp, outcome) = run_on_fixture(0.0, PruneQuantFormat::Q4_0, "unpruned.safetensors");
        let result = outcome.expect("operation should succeed");

        assert_eq!(result.achieved_sparsity, 0.0);
    }

    /// The two smallest magnitudes (indices 2 and 4) must be zeroed at 50%.
    #[test]
    fn test_magnitude_prune_basic() {
        let mut weights: HashMap<String, Vec<f32>> =
            HashMap::from([("w".to_string(), vec![0.1, 0.5, 0.01, 0.8, 0.02, 0.9])]);

        let (pruned, total) = magnitude_prune(&mut weights, 0.5);
        assert_eq!(total, 6);
        assert!(pruned >= 2);

        let data = &weights["w"];
        for idx in [2, 4] {
            assert_eq!(data[idx], 0.0);
        }
    }

    /// The written file must parse as safetensors with one tensor and carry
    /// both metadata entries.
    #[test]
    fn test_output_safetensors_valid() {
        let (_tmp, outcome) = run_on_fixture(0.5, PruneQuantFormat::Q4_0, "test.safetensors");
        let result = outcome.expect("config should be valid");

        let raw = std::fs::read(&result.output_path).expect("file read should succeed");
        let tensors = safetensors::SafeTensors::deserialize(&raw).expect("load should succeed");
        assert_eq!(tensors.len(), 1);

        let (_, header) =
            safetensors::SafeTensors::read_metadata(&raw).expect("deserialization should succeed");
        let md = header.metadata().as_ref().expect("operation should succeed");
        for key in ["sparsity", "quantization"] {
            assert!(md.contains_key(key));
        }
    }
}