1use crate::prune::pipeline::metrics::PruningMetrics;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
/// Summary of a completed sparse-model export, as returned by [`export_sparse_model`].
#[derive(Debug, Clone, PartialEq)]
pub struct SparseExportResult {
    /// Path of the written safetensors weights file (the output directory joined with
    /// the requested filename).
    pub weights_path: PathBuf,
    /// Path of the JSON sparsity metadata file written next to the weights.
    pub metadata_path: PathBuf,
    /// Fraction of exported weight elements equal to zero; `0.0` when no elements were
    /// exported.
    pub global_sparsity: f32,
    /// Number of tensors written to the weights file.
    pub num_tensors: usize,
}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct SparsityMetadata {
27 pub version: String,
29 pub global_sparsity: f32,
31 pub total_parameters: usize,
33 pub parameters_pruned: usize,
35 pub tensors: Vec<TensorSparsityInfo>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct TensorSparsityInfo {
42 pub name: String,
44 pub sparsity: f32,
46 pub zero_count: usize,
48 pub total_count: usize,
50}
51
52pub fn export_sparse_model(
58 weights: &HashMap<String, Vec<f32>>,
59 shapes: &HashMap<String, Vec<usize>>,
60 metrics: &PruningMetrics,
61 output_dir: impl AsRef<Path>,
62 filename: &str,
63) -> Result<SparseExportResult, std::io::Error> {
64 let output_dir = output_dir.as_ref();
65 std::fs::create_dir_all(output_dir)?;
66
67 let mut tensor_infos = Vec::new();
69 let mut total_zeros = 0usize;
70 let mut total_elements = 0usize;
71
72 let mut names: Vec<&String> = weights.keys().collect();
74 names.sort();
75
76 for name in &names {
77 let data = &weights[*name];
78 let zero_count = data.iter().filter(|&&v| v == 0.0).count();
79 let total = data.len();
80
81 tensor_infos.push(TensorSparsityInfo {
82 name: (*name).clone(),
83 sparsity: if total > 0 { zero_count as f32 / total as f32 } else { 0.0 },
84 zero_count,
85 total_count: total,
86 });
87
88 total_zeros += zero_count;
89 total_elements += total;
90 }
91
92 let global_sparsity =
93 if total_elements > 0 { total_zeros as f32 / total_elements as f32 } else { 0.0 };
94
95 let metadata = SparsityMetadata {
97 version: "1.0".to_string(),
98 global_sparsity,
99 total_parameters: metrics.total_parameters,
100 parameters_pruned: metrics.parameters_pruned,
101 tensors: tensor_infos,
102 };
103
104 let weights_path = output_dir.join(filename);
106 {
107 use safetensors::tensor::{Dtype, TensorView};
108
109 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = names
110 .iter()
111 .map(|name| {
112 let data = &weights[*name];
113 let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
114 let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
115 ((*name).clone(), bytes, shape)
116 })
117 .collect();
118
119 let views: Vec<(&str, TensorView<'_>)> = tensor_data
120 .iter()
121 .map(|(name, bytes, shape)| {
122 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
123 .expect("TensorView construction must not fail for valid F32 data");
124 (name.as_str(), view)
125 })
126 .collect();
127
128 let safetensor_bytes = safetensors::serialize(views, None)
129 .map_err(|e| std::io::Error::other(e.to_string()))?;
130
131 std::fs::write(&weights_path, safetensor_bytes)?;
132 }
133
134 let metadata_path = output_dir.join("sparsity_metadata.json");
136 let metadata_json = serde_json::to_string_pretty(&metadata)
137 .map_err(|e| std::io::Error::other(e.to_string()))?;
138 std::fs::write(&metadata_path, metadata_json)?;
139
140 Ok(SparseExportResult {
141 weights_path,
142 metadata_path,
143 global_sparsity,
144 num_tensors: names.len(),
145 })
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Two-tensor fixture: a 2x4 weight with 5 of its 8 elements zero, and a dense
    /// 2-element bias. Across both tensors, 5 of 10 elements are zero.
    fn make_test_data() -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();

        weights.insert(
            "layer.0.weight".to_string(),
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0],
        );
        shapes.insert("layer.0.weight".to_string(), vec![2, 4]);

        weights.insert("layer.0.bias".to_string(), vec![0.1, 0.2]);
        shapes.insert("layer.0.bias".to_string(), vec![2]);

        (weights, shapes)
    }

    /// Reads back and parses the `sparsity_metadata.json` written into `dir`.
    fn read_metadata(dir: &Path) -> SparsityMetadata {
        let json = std::fs::read_to_string(dir.join("sparsity_metadata.json"))
            .expect("metadata file should be readable");
        serde_json::from_str(&json).expect("metadata JSON should deserialize")
    }

    #[test]
    fn test_export_sparse_creates_files() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
                .expect("export should succeed");

        assert!(result.weights_path.exists());
        assert!(result.metadata_path.exists());
        assert_eq!(result.weights_path, tmp.path().join("sparse.safetensors"));
        assert_eq!(result.num_tensors, 2);
        assert_eq!(result.global_sparsity, 0.5);
    }

    #[test]
    fn test_export_sparse_metadata_content() {
        let (weights, shapes) = make_test_data();
        let mut metrics = PruningMetrics::new(0.5);
        metrics.update_sparsity(5, 10);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
            .expect("export should succeed");

        let meta = read_metadata(tmp.path());
        assert_eq!(meta.version, "1.0");
        assert_eq!(meta.total_parameters, 10);
        assert_eq!(meta.parameters_pruned, 5);
        assert_eq!(meta.tensors.len(), 2);
        // 5 zeros out of 8 + 2 exported elements.
        assert_eq!(meta.global_sparsity, 0.5);
    }

    #[test]
    fn test_per_tensor_sparsity() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
            .expect("export should succeed");

        let meta = read_metadata(tmp.path());

        // Tensors are reported in ascending name order.
        let names: Vec<&str> = meta.tensors.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["layer.0.bias", "layer.0.weight"]);

        let bias_info = meta
            .tensors
            .iter()
            .find(|t| t.name == "layer.0.bias")
            .expect("bias tensor should be reported");
        assert_eq!(bias_info.sparsity, 0.0);
        assert_eq!(bias_info.zero_count, 0);
        assert_eq!(bias_info.total_count, 2);

        let weight_info = meta
            .tensors
            .iter()
            .find(|t| t.name == "layer.0.weight")
            .expect("weight tensor should be reported");
        assert_eq!(weight_info.sparsity, 0.625);
        assert_eq!(weight_info.zero_count, 5);
        assert_eq!(weight_info.total_count, 8);
    }

    #[test]
    fn test_export_sparse_safetensors_valid() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
                .expect("export should succeed");

        let data = std::fs::read(&result.weights_path).expect("weights file should be readable");
        let loaded =
            safetensors::SafeTensors::deserialize(&data).expect("safetensors file should parse");
        assert_eq!(loaded.len(), 2);

        // Every tensor must round-trip with its dtype, declared shape and exact values.
        for (name, expected) in &weights {
            let view = loaded.tensor(name).expect("exported tensor should be present");
            assert_eq!(view.dtype(), safetensors::tensor::Dtype::F32);
            assert_eq!(view.shape(), shapes[name].as_slice());
            let values: Vec<f32> = view
                .data()
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            assert_eq!(&values, expected);
        }
    }

    #[test]
    fn test_export_empty_weights() {
        let weights = HashMap::new();
        let shapes = HashMap::new();
        let metrics = PruningMetrics::new(0.0);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "empty.safetensors")
                .expect("export should succeed");

        assert_eq!(result.num_tensors, 0);
        assert_eq!(result.global_sparsity, 0.0);
        assert!(result.weights_path.exists());
        assert!(read_metadata(tmp.path()).tensors.is_empty());
    }
}