Skip to main content

entrenar/prune/pipeline/
sparse_export.rs

1//! Sparse model export with sparsity metadata
2//!
3//! Exports pruned model weights along with a `sparsity_metadata.json` sidecar
4//! containing per-tensor sparsity statistics.
5
6use crate::prune::pipeline::metrics::PruningMetrics;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
/// Result of a successful [`export_sparse_model`] call.
///
/// Derives `PartialEq` (but not `Eq`, because of the `f32` field) so callers
/// and tests can compare export results directly.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseExportResult {
    /// Path of the written SafeTensors weight file (`output_dir/filename`).
    pub weights_path: PathBuf,
    /// Path of the written `sparsity_metadata.json` sidecar.
    pub metadata_path: PathBuf,
    /// Fraction of exported elements equal to zero (0.0 when nothing was exported).
    pub global_sparsity: f32,
    /// Number of tensors written to the weight file.
    pub num_tensors: usize,
}
23
24/// Sparsity metadata sidecar (serialized to sparsity_metadata.json)
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct SparsityMetadata {
27    /// Format version
28    pub version: String,
29    /// Global sparsity (fraction of zero parameters)
30    pub global_sparsity: f32,
31    /// Total parameters
32    pub total_parameters: usize,
33    /// Parameters pruned (zero)
34    pub parameters_pruned: usize,
35    /// Per-tensor sparsity information
36    pub tensors: Vec<TensorSparsityInfo>,
37}
38
39/// Per-tensor sparsity statistics
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct TensorSparsityInfo {
42    /// Tensor name
43    pub name: String,
44    /// Sparsity ratio for this tensor
45    pub sparsity: f32,
46    /// Number of zero elements
47    pub zero_count: usize,
48    /// Total elements
49    pub total_count: usize,
50}
51
52/// Export a sparse (pruned) model with sparsity metadata sidecar
53///
54/// Produces:
55/// - Weight file (SafeTensors format via bytemuck)
56/// - `sparsity_metadata.json` with per-tensor sparsity stats
57pub fn export_sparse_model(
58    weights: &HashMap<String, Vec<f32>>,
59    shapes: &HashMap<String, Vec<usize>>,
60    metrics: &PruningMetrics,
61    output_dir: impl AsRef<Path>,
62    filename: &str,
63) -> Result<SparseExportResult, std::io::Error> {
64    let output_dir = output_dir.as_ref();
65    std::fs::create_dir_all(output_dir)?;
66
67    // Compute per-tensor sparsity
68    let mut tensor_infos = Vec::new();
69    let mut total_zeros = 0usize;
70    let mut total_elements = 0usize;
71
72    // Sort keys for deterministic output
73    let mut names: Vec<&String> = weights.keys().collect();
74    names.sort();
75
76    for name in &names {
77        let data = &weights[*name];
78        let zero_count = data.iter().filter(|&&v| v == 0.0).count();
79        let total = data.len();
80
81        tensor_infos.push(TensorSparsityInfo {
82            name: (*name).clone(),
83            sparsity: if total > 0 { zero_count as f32 / total as f32 } else { 0.0 },
84            zero_count,
85            total_count: total,
86        });
87
88        total_zeros += zero_count;
89        total_elements += total;
90    }
91
92    let global_sparsity =
93        if total_elements > 0 { total_zeros as f32 / total_elements as f32 } else { 0.0 };
94
95    // Build metadata
96    let metadata = SparsityMetadata {
97        version: "1.0".to_string(),
98        global_sparsity,
99        total_parameters: metrics.total_parameters,
100        parameters_pruned: metrics.parameters_pruned,
101        tensors: tensor_infos,
102    };
103
104    // Write weight file as SafeTensors
105    let weights_path = output_dir.join(filename);
106    {
107        use safetensors::tensor::{Dtype, TensorView};
108
109        let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = names
110            .iter()
111            .map(|name| {
112                let data = &weights[*name];
113                let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
114                let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
115                ((*name).clone(), bytes, shape)
116            })
117            .collect();
118
119        let views: Vec<(&str, TensorView<'_>)> = tensor_data
120            .iter()
121            .map(|(name, bytes, shape)| {
122                let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
123                    .expect("TensorView construction must not fail for valid F32 data");
124                (name.as_str(), view)
125            })
126            .collect();
127
128        let safetensor_bytes = safetensors::serialize(views, None)
129            .map_err(|e| std::io::Error::other(e.to_string()))?;
130
131        std::fs::write(&weights_path, safetensor_bytes)?;
132    }
133
134    // Write sparsity metadata
135    let metadata_path = output_dir.join("sparsity_metadata.json");
136    let metadata_json = serde_json::to_string_pretty(&metadata)
137        .map_err(|e| std::io::Error::other(e.to_string()))?;
138    std::fs::write(&metadata_path, metadata_json)?;
139
140    Ok(SparseExportResult {
141        weights_path,
142        metadata_path,
143        global_sparsity,
144        num_tensors: names.len(),
145    })
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixture with two tensors, 10 elements and 5 zeros in total:
    /// - `layer.0.weight`: shape `[2, 4]`, 5 of 8 elements zero (62.5% sparse);
    /// - `layer.0.bias`: shape `[2]`, no zeros (0% sparse).
    fn make_test_data() -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();

        // 62.5% sparse tensor: zeros at indices 1, 2, 4, 6 and 7.
        weights.insert(
            "layer.0.weight".to_string(),
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0],
        );
        shapes.insert("layer.0.weight".to_string(), vec![2, 4]);

        // 0% sparse tensor.
        weights.insert("layer.0.bias".to_string(), vec![0.1, 0.2]);
        shapes.insert("layer.0.bias".to_string(), vec![2]);

        (weights, shapes)
    }

    /// Reads and parses the sidecar written into `dir`.
    fn read_metadata(dir: &Path) -> SparsityMetadata {
        let json = std::fs::read_to_string(dir.join("sparsity_metadata.json"))
            .expect("metadata sidecar should be readable");
        serde_json::from_str(&json).expect("metadata sidecar should deserialize")
    }

    /// Looks up a tensor entry by name, panicking with a clear message if absent.
    fn find_tensor<'a>(meta: &'a SparsityMetadata, name: &str) -> &'a TensorSparsityInfo {
        meta.tensors
            .iter()
            .find(|t| t.name == name)
            .unwrap_or_else(|| panic!("tensor `{name}` missing from metadata"))
    }

    #[test]
    fn test_export_sparse_creates_files() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
                .expect("export should succeed");

        assert_eq!(result.weights_path, tmp.path().join("sparse.safetensors"));
        assert_eq!(result.metadata_path, tmp.path().join("sparsity_metadata.json"));
        assert!(result.weights_path.exists());
        assert!(result.metadata_path.exists());
        assert_eq!(result.num_tensors, 2);
    }

    #[test]
    fn test_export_sparse_metadata_content() {
        let (weights, shapes) = make_test_data();
        let mut metrics = PruningMetrics::new(0.5);
        metrics.update_sparsity(5, 10);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
            .expect("export should succeed");

        let meta = read_metadata(tmp.path());
        assert_eq!(meta.version, "1.0");
        assert_eq!(meta.total_parameters, 10);
        assert_eq!(meta.parameters_pruned, 5);
        assert_eq!(meta.tensors.len(), 2);
    }

    #[test]
    fn test_per_tensor_sparsity() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
            .expect("export should succeed");

        let meta = read_metadata(tmp.path());

        // layer.0.bias has no zeros: 0% sparsity.
        let bias_info = find_tensor(&meta, "layer.0.bias");
        assert_eq!(bias_info.sparsity, 0.0);
        assert_eq!(bias_info.zero_count, 0);
        assert_eq!(bias_info.total_count, 2);

        // layer.0.weight has 5/8 zeros: 62.5% sparsity (exactly representable).
        let weight_info = find_tensor(&meta, "layer.0.weight");
        assert_eq!(weight_info.sparsity, 0.625);
        assert_eq!(weight_info.zero_count, 5);
        assert_eq!(weight_info.total_count, 8);
    }

    #[test]
    fn test_global_sparsity() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
                .expect("export should succeed");

        // 5 zeros out of 10 elements across both tensors.
        assert_eq!(result.global_sparsity, 0.5);
        assert_eq!(read_metadata(tmp.path()).global_sparsity, 0.5);
    }

    #[test]
    fn test_metadata_tensors_sorted_by_name() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
            .expect("export should succeed");

        let meta = read_metadata(tmp.path());
        let names: Vec<&str> = meta.tensors.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["layer.0.bias", "layer.0.weight"]);
    }

    #[test]
    fn test_export_sparse_safetensors_valid() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "sparse.safetensors")
                .expect("export should succeed");

        let data = std::fs::read(&result.weights_path).expect("weight file should be readable");
        let loaded =
            safetensors::SafeTensors::deserialize(&data).expect("weight file should deserialize");
        assert_eq!(loaded.len(), 2);

        // Round-trip: dtype, shape and little-endian values must survive intact.
        let tensor = loaded.tensor("layer.0.weight").expect("weight tensor should be present");
        assert_eq!(tensor.dtype(), safetensors::tensor::Dtype::F32);
        assert_eq!(tensor.shape().to_vec(), vec![2, 4]);
        let values: Vec<f32> = tensor
            .data()
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        assert_eq!(values, weights["layer.0.weight"]);
    }

    #[test]
    fn test_missing_shape_defaults_to_1d() {
        let mut weights = HashMap::new();
        weights.insert("flat".to_string(), vec![0.0, 1.0, 0.0]);
        let shapes = HashMap::new();
        let metrics = PruningMetrics::new(0.5);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "flat.safetensors")
                .expect("export should succeed");

        let data = std::fs::read(&result.weights_path).expect("weight file should be readable");
        let loaded =
            safetensors::SafeTensors::deserialize(&data).expect("weight file should deserialize");
        let tensor = loaded.tensor("flat").expect("tensor should be present");
        assert_eq!(tensor.shape().to_vec(), vec![3]);
    }

    #[test]
    fn test_export_empty_weights() {
        let weights = HashMap::new();
        let shapes = HashMap::new();
        let metrics = PruningMetrics::new(0.0);
        let tmp = TempDir::new().expect("temp dir creation should succeed");

        let result =
            export_sparse_model(&weights, &shapes, &metrics, tmp.path(), "empty.safetensors")
                .expect("export should succeed");

        assert_eq!(result.num_tensors, 0);
        assert_eq!(result.global_sparsity, 0.0);
        assert!(read_metadata(tmp.path()).tensors.is_empty());
    }
}