use crate::prune::pipeline::metrics::PruningMetrics;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Summary of a completed sparse-model export: where the files were
/// written and headline statistics about the exported tensors.
#[derive(Debug, Clone)]
pub struct SparseExportResult {
    /// Path of the safetensors weights file that was written.
    pub weights_path: PathBuf,
    /// Path of the JSON sparsity-metadata file that was written.
    pub metadata_path: PathBuf,
    /// Fraction of exported elements equal to zero, computed over all tensors.
    pub global_sparsity: f32,
    /// Number of tensors written to the weights file.
    pub num_tensors: usize,
}
/// Top-level sparsity report serialized to `sparsity_metadata.json`
/// alongside the exported weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparsityMetadata {
    /// Metadata schema version (currently "1.0").
    pub version: String,
    /// Fraction of all exported elements that are exactly zero.
    pub global_sparsity: f32,
    /// Total parameter count, copied from the pruning metrics.
    pub total_parameters: usize,
    /// Count of parameters removed by pruning, copied from the pruning metrics.
    pub parameters_pruned: usize,
    /// Per-tensor sparsity breakdown, in sorted tensor-name order.
    pub tensors: Vec<TensorSparsityInfo>,
}
/// Sparsity statistics for a single exported tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSparsityInfo {
    /// Tensor name (the key used in the weights map / safetensors file).
    pub name: String,
    /// `zero_count / total_count`, or 0.0 for an empty tensor.
    pub sparsity: f32,
    /// Number of elements exactly equal to zero.
    pub zero_count: usize,
    /// Total number of elements in the tensor.
    pub total_count: usize,
}
/// Exports pruned weights as a safetensors file plus a JSON sparsity report.
///
/// Tensors are written in sorted-name order so output is deterministic
/// regardless of `HashMap` iteration order. Each tensor's shape is looked up
/// in `shapes`; a missing entry falls back to a 1-D shape of the data length.
/// Per-tensor and global sparsity are recomputed from the actual data (an
/// element counts as zero when `v == 0.0`, which also matches `-0.0`; NaN is
/// never counted), while `total_parameters` / `parameters_pruned` are copied
/// from `metrics`.
///
/// # Errors
///
/// Returns an error if the output directory cannot be created, a tensor's
/// shape is inconsistent with its data length (`ErrorKind::InvalidInput`),
/// safetensors or JSON serialization fails, or a file write fails.
pub fn export_sparse_model(
    weights: &HashMap<String, Vec<f32>>,
    shapes: &HashMap<String, Vec<usize>>,
    metrics: &PruningMetrics,
    output_dir: impl AsRef<Path>,
    filename: &str,
) -> Result<SparseExportResult, std::io::Error> {
    let output_dir = output_dir.as_ref();
    std::fs::create_dir_all(output_dir)?;

    // Sort names once; both the metadata and the safetensors payload use
    // this order so the two artifacts always agree.
    let mut names: Vec<&String> = weights.keys().collect();
    names.sort();

    let mut tensor_infos = Vec::with_capacity(names.len());
    let mut total_zeros = 0usize;
    let mut total_elements = 0usize;
    for name in &names {
        let data = &weights[*name];
        let zero_count = data.iter().filter(|&&v| v == 0.0).count();
        let total = data.len();
        tensor_infos.push(TensorSparsityInfo {
            name: (*name).clone(),
            sparsity: if total > 0 { zero_count as f32 / total as f32 } else { 0.0 },
            zero_count,
            total_count: total,
        });
        total_zeros += zero_count;
        total_elements += total;
    }
    let global_sparsity =
        if total_elements > 0 { total_zeros as f32 / total_elements as f32 } else { 0.0 };

    let metadata = SparsityMetadata {
        version: "1.0".to_string(),
        global_sparsity,
        total_parameters: metrics.total_parameters,
        parameters_pruned: metrics.parameters_pruned,
        tensors: tensor_infos,
    };

    let weights_path = output_dir.join(filename);
    {
        use safetensors::tensor::{Dtype, TensorView};
        // Materialize owned byte buffers first so the borrowed TensorViews
        // built below have something stable to point into.
        let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = names
            .iter()
            .map(|name| {
                let data = &weights[*name];
                let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
                let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
                ((*name).clone(), bytes, shape)
            })
            .collect();
        let views: Vec<(&str, TensorView<'_>)> = tensor_data
            .iter()
            .map(|(name, bytes, shape)| {
                // A shape whose element count disagrees with the data length
                // is caller error; surface it as InvalidInput instead of
                // panicking (this function already returns io::Error).
                let view = TensorView::new(Dtype::F32, shape.clone(), bytes).map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        format!("tensor '{name}': {e}"),
                    )
                })?;
                Ok((name.as_str(), view))
            })
            .collect::<Result<_, std::io::Error>>()?;
        let safetensor_bytes = safetensors::serialize(views, None)
            .map_err(|e| std::io::Error::other(e.to_string()))?;
        std::fs::write(&weights_path, safetensor_bytes)?;
    }

    let metadata_path = output_dir.join("sparsity_metadata.json");
    let metadata_json = serde_json::to_string_pretty(&metadata)
        .map_err(|e| std::io::Error::other(e.to_string()))?;
    std::fs::write(&metadata_path, metadata_json)?;

    Ok(SparseExportResult {
        weights_path,
        metadata_path,
        global_sparsity,
        num_tensors: names.len(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixture: a 2x4 weight tensor with 5 of 8 zeros plus a fully dense
    /// 2-element bias tensor.
    fn make_test_data() -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();
        weights.insert(
            "layer.0.weight".to_string(),
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0],
        );
        shapes.insert("layer.0.weight".to_string(), vec![2, 4]);
        weights.insert("layer.0.bias".to_string(), vec![0.1, 0.2]);
        shapes.insert("layer.0.bias".to_string(), vec![2]);
        (weights, shapes)
    }

    /// Reads back the metadata JSON written next to the weights file.
    fn load_metadata(dir: &Path) -> SparsityMetadata {
        let json = std::fs::read_to_string(dir.join("sparsity_metadata.json"))
            .expect("file read should succeed");
        serde_json::from_str(&json).expect("JSON deserialization should succeed")
    }

    #[test]
    fn test_export_sparse_creates_files() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let dir = TempDir::new().expect("temp file creation should succeed");
        let result =
            export_sparse_model(&weights, &shapes, &metrics, dir.path(), "sparse.safetensors")
                .expect("operation should succeed");
        assert_eq!(result.num_tensors, 2);
        assert!(result.weights_path.exists());
        assert!(result.metadata_path.exists());
    }

    #[test]
    fn test_export_sparse_metadata_content() {
        let (weights, shapes) = make_test_data();
        let mut metrics = PruningMetrics::new(0.5);
        metrics.update_sparsity(5, 10);
        let dir = TempDir::new().expect("temp file creation should succeed");
        export_sparse_model(&weights, &shapes, &metrics, dir.path(), "sparse.safetensors")
            .expect("parsing should succeed");
        let meta = load_metadata(dir.path());
        assert_eq!(meta.version, "1.0");
        assert_eq!(meta.total_parameters, 10);
        assert_eq!(meta.parameters_pruned, 5);
        assert_eq!(meta.tensors.len(), 2);
    }

    #[test]
    fn test_per_tensor_sparsity() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let dir = TempDir::new().expect("temp file creation should succeed");
        export_sparse_model(&weights, &shapes, &metrics, dir.path(), "sparse.safetensors")
            .expect("parsing should succeed");
        let meta = load_metadata(dir.path());
        let lookup = |target: &str| {
            meta.tensors
                .iter()
                .find(|t| t.name == target)
                .expect("operation should succeed")
        };
        let bias_info = lookup("layer.0.bias");
        assert_eq!(bias_info.zero_count, 0);
        assert_eq!(bias_info.sparsity, 0.0);
        let weight_info = lookup("layer.0.weight");
        assert_eq!(weight_info.zero_count, 5);
        assert!(weight_info.sparsity > 0.5);
    }

    #[test]
    fn test_export_sparse_safetensors_valid() {
        let (weights, shapes) = make_test_data();
        let metrics = PruningMetrics::new(0.5);
        let dir = TempDir::new().expect("temp file creation should succeed");
        let result =
            export_sparse_model(&weights, &shapes, &metrics, dir.path(), "sparse.safetensors")
                .expect("operation should succeed");
        let bytes = std::fs::read(&result.weights_path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_export_empty_weights() {
        let weights: HashMap<String, Vec<f32>> = HashMap::new();
        let shapes: HashMap<String, Vec<usize>> = HashMap::new();
        let metrics = PruningMetrics::new(0.0);
        let dir = TempDir::new().expect("temp file creation should succeed");
        let result =
            export_sparse_model(&weights, &shapes, &metrics, dir.path(), "empty.safetensors")
                .expect("operation should succeed");
        assert_eq!(result.global_sparsity, 0.0);
        assert_eq!(result.num_tensors, 0);
    }
}