use super::{SparseFormat, SparseOps, SparseTensor};
use crate::error::{RusTorchError, RusTorchResult};
use ndarray::{Array1, Array2, ArrayD};
use num_traits::{Float, FromPrimitive, One, Zero};
use std::collections::{HashMap, HashSet};
use std::iter::Sum;
/// Namespace type for sparse-tensor analysis routines.
///
/// The struct stores no data of its own; its only field is a zero-sized
/// marker that ties the type to the element type `T`.
pub struct SparseAnalyzer<T: Float> {
// Zero-sized marker so the otherwise unused type parameter `T` is permitted.
_phantom: std::marker::PhantomData<T>,
}
impl<T: Float + Copy + PartialOrd + Sum + std::fmt::Display> SparseAnalyzer<T> {
    /// Builds a [`SparsePatternAnalysis`] describing `tensor`.
    ///
    /// The report contains the dense element count, the stored-element count,
    /// the sparsity ratio and storage format as reported by the tensor, value
    /// statistics (minimum, maximum, mean absolute value) over the stored
    /// values, a row-regularity score and the memory saving relative to a
    /// dense layout.
    ///
    /// When the tensor stores no values the value statistics keep their
    /// default of zero. When the dense size is zero the memory efficiency is
    /// reported as `0.0` instead of dividing by zero.
    pub fn analyze_pattern(tensor: &SparseTensor<T>) -> SparsePatternAnalysis<T> {
        let mut analysis = SparsePatternAnalysis::new();
        analysis.total_elements = tensor.dense_size();
        analysis.non_zero_elements = tensor.nnz;
        analysis.sparsity_ratio = tensor.sparsity();
        analysis.format = tensor.format;

        let stored = tensor.values.len();
        if stored > 0 {
            // Single pass over the stored values. Iterating the array directly
            // (rather than through `as_slice().unwrap()`) cannot panic when
            // the storage is not contiguous.
            let (min_value, max_value, abs_sum) = tensor.values.iter().fold(
                (T::infinity(), T::neg_infinity(), T::zero()),
                |(lo, hi, acc), &v| {
                    (
                        if lo < v { lo } else { v },
                        if hi > v { hi } else { v },
                        acc + v.abs(),
                    )
                },
            );
            analysis.min_value = min_value;
            analysis.max_value = max_value;
            // Divide by the number of values actually summed so the mean stays
            // meaningful even if `nnz` and the value buffer disagree.
            if let Some(count) = T::from(stored) {
                analysis.mean_abs_value = abs_sum / count;
            }
        }

        analysis.pattern_regularity = Self::compute_pattern_regularity(tensor);

        // Computed in f64 so that very large dense sizes cannot overflow usize.
        let dense_memory = tensor.dense_size() as f64 * std::mem::size_of::<T>() as f64;
        let sparse_memory = tensor.memory_usage() as f64;
        analysis.memory_efficiency = if dense_memory > 0.0 {
            1.0 - sparse_memory / dense_memory
        } else {
            // An empty dense tensor has nothing to save; avoid NaN / -inf.
            0.0
        };
        analysis
    }

    /// Scores how evenly the stored entries are spread over the occupied rows.
    ///
    /// Only 2-D COO tensors are scored; every other input yields `0.0`.
    /// The score is `1 / (1 + cv)`, where `cv` is the coefficient of variation
    /// (standard deviation divided by mean) of the per-row entry counts, so
    /// `1.0` means every occupied row holds the same number of entries.
    /// Rows without any stored entry do not take part in the statistic.
    fn compute_pattern_regularity(tensor: &SparseTensor<T>) -> f64 {
        if tensor.format != SparseFormat::COO || tensor.shape.len() != 2 {
            return 0.0;
        }
        // A malformed tensor may lack index arrays; treat that as "no pattern"
        // rather than panicking on an out-of-bounds index.
        let row_indices = match tensor.indices.first() {
            Some(rows) => rows,
            None => return 0.0,
        };

        let mut row_counts: HashMap<_, usize> = HashMap::new();
        for &row in row_indices.iter() {
            *row_counts.entry(row).or_insert(0) += 1;
        }
        if row_counts.is_empty() {
            return 0.0;
        }

        let occupied_rows = row_counts.len() as f64;
        let mean = row_counts.values().map(|&c| c as f64).sum::<f64>() / occupied_rows;
        let variance = row_counts
            .values()
            .map(|&c| (c as f64 - mean).powi(2))
            .sum::<f64>()
            / occupied_rows;
        // Every counted row has at least one entry, so `mean` is never zero.
        let cv = variance.sqrt() / mean;
        (1.0 / (1.0 + cv)).clamp(0.0, 1.0)
    }

    /// Recommends a storage format for `tensor` given the expected access pattern.
    ///
    /// Row-oriented and matrix-vector workloads map to CSR, column-oriented
    /// workloads to CSC. For random or unknown access the choice depends on
    /// sparsity: COO above 95% sparsity, CSR otherwise.
    pub fn suggest_optimal_format(
        tensor: &SparseTensor<T>,
        access_pattern: AccessPattern,
    ) -> SparseFormat {
        match access_pattern {
            AccessPattern::RowMajor | AccessPattern::MatrixVector => SparseFormat::CSR,
            AccessPattern::ColumnMajor => SparseFormat::CSC,
            AccessPattern::Random | AccessPattern::Unknown => {
                if tensor.sparsity() > 0.95 {
                    SparseFormat::COO
                } else {
                    SparseFormat::CSR
                }
            }
        }
    }
}
/// Expected traversal pattern of a sparse tensor, used as a hint when
/// choosing a storage format.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Entries are mostly visited row by row.
    RowMajor,
    /// Entries are mostly visited column by column.
    ColumnMajor,
    /// The tensor is mainly used in matrix-vector products.
    MatrixVector,
    /// Entries are visited in no particular order.
    Random,
    /// The access pattern is not known in advance.
    Unknown,
}
/// Result of analysing a sparse tensor: size, value statistics and
/// storage-efficiency figures.
#[derive(Debug, Clone)]
pub struct SparsePatternAnalysis<T: Float> {
/// Number of elements the tensor would hold in dense form.
pub total_elements: usize,
/// Number of explicitly stored (non-zero) elements.
pub non_zero_elements: usize,
/// Sparsity as reported by the tensor; multiplied by 100 when printed as a
/// percentage, so presumably a fraction in `[0, 1]` (confirm against
/// `SparseTensor::sparsity`).
pub sparsity_ratio: f64,
/// Storage format of the analysed tensor.
pub format: SparseFormat,
/// Smallest stored value; stays zero when the tensor stores no values.
pub min_value: T,
/// Largest stored value; stays zero when the tensor stores no values.
pub max_value: T,
/// Mean of the absolute stored values; stays zero when the tensor stores no values.
pub mean_abs_value: T,
/// Row-regularity score, clamped to `[0, 1]`.
pub pattern_regularity: f64,
/// `1 - sparse_bytes / dense_bytes`; multiplied by 100 when printed as a percentage.
pub memory_efficiency: f64,
}
impl<T: Float + std::fmt::Display> SparsePatternAnalysis<T> {
    /// Creates an analysis with every counter and statistic zeroed and the
    /// format set to COO.
    fn new() -> Self {
        let zero = T::zero();
        Self {
            total_elements: 0,
            non_zero_elements: 0,
            sparsity_ratio: 0.0,
            format: SparseFormat::COO,
            min_value: zero,
            max_value: zero,
            mean_abs_value: zero,
            pattern_regularity: 0.0,
            memory_efficiency: 0.0,
        }
    }

    /// Renders the analysis as a multi-line, human-readable report.
    ///
    /// Sparsity and memory efficiency are printed as percentages; the value
    /// statistics are printed with six decimal places. The report has no
    /// trailing newline.
    pub fn report(&self) -> String {
        let sparsity_percent = self.sparsity_ratio * 100.0;
        let efficiency_percent = self.memory_efficiency * 100.0;
        let lines = vec![
            "Sparse Tensor Analysis Report:".to_string(),
            "================================".to_string(),
            format!("Format: {:?}", self.format),
            format!("Total elements: {}", self.total_elements),
            format!("Non-zero elements: {}", self.non_zero_elements),
            format!("Sparsity: {:.2}%", sparsity_percent),
            format!("Pattern regularity: {:.2}", self.pattern_regularity),
            format!("Memory efficiency: {:.2}%", efficiency_percent),
            format!("Value range: [{:.6}, {:.6}]", self.min_value, self.max_value),
            format!("Mean |value|: {:.6}", self.mean_abs_value),
        ];
        lines.join("\n")
    }

    /// Returns textual tuning hints derived from the analysis.
    ///
    /// At most one sparsity hint is produced (very high above 95%, low below
    /// 50%), followed by a hint for pattern regularity above 0.8 and one for
    /// memory efficiency below 0.3, in that order.
    pub fn optimization_recommendations(&self) -> Vec<String> {
        let sparsity_hint = if self.sparsity_ratio > 0.95 {
            Some("Very high sparsity - consider COO format for memory efficiency")
        } else if self.sparsity_ratio < 0.5 {
            Some("Low sparsity - consider dense representation")
        } else {
            None
        };
        let regularity_hint = if self.pattern_regularity > 0.8 {
            Some("High pattern regularity - structured pruning may be beneficial")
        } else {
            None
        };
        let efficiency_hint = if self.memory_efficiency < 0.3 {
            Some("Low memory efficiency - sparse format may not be optimal")
        } else {
            None
        };

        vec![sparsity_hint, regularity_hint, efficiency_hint]
            .into_iter()
            .flatten()
            .map(|hint| hint.to_string())
            .collect()
    }
}
/// Stateless namespace for sparse-tensor consistency checks.
pub struct SparseValidator;
impl SparseValidator {
pub fn validate<T: Float + Copy + PartialOrd>(tensor: &SparseTensor<T>) -> RusTorchResult<()> {
if tensor.values.len() != tensor.nnz {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: "Values length doesn't match nnz count".to_string(),
});
}
match tensor.format {
SparseFormat::COO => {
for (dim, indices) in tensor.indices.iter().enumerate() {
if indices.len() != tensor.nnz {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: format!("COO indices dimension {} length mismatch", dim),
});
}
if dim < tensor.shape.len() {
let max_allowed = tensor.shape[dim];
for &idx in indices.iter() {
if idx >= max_allowed {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: format!(
"Index {} exceeds dimension {} size {}",
idx, dim, max_allowed
),
});
}
}
}
}
}
SparseFormat::CSR => {
if tensor.indices.len() != 2 {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: "CSR format requires exactly 2 index arrays".to_string(),
});
}
if tensor.indices[1].len() != tensor.nnz {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: "CSR col_indices length must match nnz".to_string(),
});
}
let max_cols = tensor.shape[1];
for &col_idx in tensor.indices[1].iter() {
if col_idx >= max_cols {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: format!(
"Column index {} exceeds matrix width {}",
col_idx, max_cols
),
});
}
}
}
SparseFormat::CSC => {
if tensor.indices.len() != 2 {
return Err(RusTorchError::InvalidParameters {
operation: "sparse_validation".to_string(),
message: "CSC format requires exactly 2 index arrays".to_string(),
});
}
}
}
match tensor.format {
SparseFormat::CSR => Self::validate_csr(tensor)?,
SparseFormat::COO => Self::validate_coo(tensor)?,
SparseFormat::CSC => {
return Err(RusTorchError::NotImplemented {
feature: "CSC format validation".to_string(),
});
}
}
Ok(())
}
fn validate_csr<T: Float>(tensor: &SparseTensor<T>) -> RusTorchResult<()> {
if tensor.shape.len() != 2 {
return Err(RusTorchError::InvalidParameters {
operation: "csr_validation".to_string(),
message: "CSR format requires 2D tensors".to_string(),
});
}
if tensor.indices.len() != 2 {
return Err(RusTorchError::InvalidParameters {
operation: "csr_validation".to_string(),
message: "CSR format requires exactly 2 index arrays".to_string(),
});
}
let row_ptr = &tensor.indices[0];
let col_indices = &tensor.indices[1];
if row_ptr.len() != tensor.shape[0] + 1 {
return Err(RusTorchError::InvalidParameters {
operation: "csr_validation".to_string(),
message: "Row pointer length must be rows + 1".to_string(),
});
}
for i in 1..row_ptr.len() {
if row_ptr[i] < row_ptr[i - 1] {
return Err(RusTorchError::InvalidParameters {
operation: "csr_validation".to_string(),
message: "Row pointer must be non-decreasing".to_string(),
});
}
}
if row_ptr[row_ptr.len() - 1] != tensor.nnz {
return Err(RusTorchError::InvalidParameters {
operation: "csr_validation".to_string(),
message: "Last row pointer must equal nnz".to_string(),
});
}
Ok(())
}
fn validate_coo<T: Float>(tensor: &SparseTensor<T>) -> RusTorchResult<()> {
if tensor.indices.len() != tensor.shape.len() {
return Err(RusTorchError::InvalidParameters {
operation: "coo_validation".to_string(),
message: "COO format requires one index array per dimension".to_string(),
});
}
if tensor.shape.len() == 2 {
let mut coordinate_set = HashSet::new();
for i in 0..tensor.nnz {
let coord = (tensor.indices[0][i], tensor.indices[1][i]);
if coordinate_set.contains(&coord) {
return Err(RusTorchError::InvalidParameters {
operation: "coo_validation".to_string(),
message: "Duplicate coordinates found in COO tensor".to_string(),
});
}
coordinate_set.insert(coord);
}
}
Ok(())
}
}
/// Stateless namespace for validated conversions between sparse storage formats.
pub struct SparseConverter;
impl SparseConverter {
    /// Converts `tensor` to `target_format`, validating both input and output.
    ///
    /// COO -> CSR and CSR -> COO are delegated to the tensor's own conversion
    /// methods; converting to the format the tensor already has yields a
    /// clone. Any other pair is reported as not implemented.
    ///
    /// # Errors
    ///
    /// Propagates validation and conversion errors, and returns
    /// `RusTorchError::NotImplemented` for unsupported format pairs.
    pub fn convert<T: Float + Zero + One + Copy + std::ops::AddAssign + FromPrimitive>(
        tensor: &SparseTensor<T>,
        target_format: SparseFormat,
    ) -> RusTorchResult<SparseTensor<T>> {
        SparseValidator::validate(tensor)?;

        let source_format = tensor.format;
        let converted = if source_format == target_format {
            tensor.clone()
        } else {
            match (source_format, target_format) {
                (SparseFormat::COO, SparseFormat::CSR) => tensor.to_csr()?,
                (SparseFormat::CSR, SparseFormat::COO) => tensor.to_coo()?,
                _ => {
                    let feature =
                        format!("Conversion from {:?} to {:?}", tensor.format, target_format);
                    return Err(RusTorchError::NotImplemented { feature });
                }
            }
        };

        SparseValidator::validate(&converted)?;
        Ok(converted)
    }

    /// Converts every tensor in `tensors` to `target_format`.
    ///
    /// Conversion stops at the first failure, whose error is returned.
    pub fn batch_convert<T: Float + Zero + One + Copy + std::ops::AddAssign + FromPrimitive>(
        tensors: &[SparseTensor<T>],
        target_format: SparseFormat,
    ) -> RusTorchResult<Vec<SparseTensor<T>>> {
        tensors
            .iter()
            .map(|tensor| Self::convert(tensor, target_format))
            .collect()
    }
}
/// Stateless namespace for sparse-tensor (de)serialization entry points.
pub struct SparseIO;
impl SparseIO {
    /// Placeholder for writing a tensor to a binary file.
    ///
    /// # Errors
    ///
    /// Always returns `RusTorchError::NotImplemented`; both arguments are ignored.
    pub fn save_binary<T: Float>(
        _tensor: &SparseTensor<T>,
        _path: &std::path::Path,
    ) -> RusTorchResult<()> {
        let feature = String::from("Sparse tensor binary serialization");
        Err(RusTorchError::NotImplemented { feature })
    }

    /// Placeholder for reading a tensor from a binary file.
    ///
    /// # Errors
    ///
    /// Always returns `RusTorchError::NotImplemented`; the path is ignored.
    pub fn load_binary<T: Float>(_path: &std::path::Path) -> RusTorchResult<SparseTensor<T>> {
        let feature = String::from("Sparse tensor binary deserialization");
        Err(RusTorchError::NotImplemented { feature })
    }
}
/// Collects timing results for sparse-tensor operations on element type `T`.
pub struct SparseBenchmark<T: Float> {
/// Recorded results keyed by operation name (e.g. `"spmv"`).
pub results: HashMap<String, BenchmarkResult>,
// Zero-sized marker so the otherwise unused type parameter `T` is permitted.
_phantom: std::marker::PhantomData<T>,
}
/// Outcome of benchmarking a single operation.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
/// Name of the benchmarked operation.
pub operation: String,
/// Average wall-clock time per operation, in nanoseconds.
pub time_ns: u64,
/// Memory usage reported by the benchmarked tensor, in bytes.
pub memory_bytes: usize,
/// Operations per second, derived from the per-operation time.
pub throughput_ops: f64,
}
impl<
T: Float
+ Copy
+ Zero
+ One
+ std::ops::AddAssign
+ PartialOrd
+ Sum
+ num_traits::FromPrimitive
+ 'static,
> SparseBenchmark<T>
{
pub fn new() -> Self {
Self {
results: HashMap::new(),
_phantom: std::marker::PhantomData,
}
}
pub fn benchmark_spmv(
&mut self,
tensor: &SparseTensor<T>,
vector: &Array1<T>,
iterations: usize,
) -> RusTorchResult<()>
where
T: Zero + One + std::ops::AddAssign + num_traits::FromPrimitive,
{
let start_time = std::time::Instant::now();
for _ in 0..iterations {
let _ = tensor.spmv(vector)?;
}
let elapsed = start_time.elapsed();
let time_per_op = elapsed.as_nanos() / iterations as u128;
let result = BenchmarkResult {
operation: "spmv".to_string(),
time_ns: time_per_op as u64,
memory_bytes: tensor.memory_usage(),
throughput_ops: 1_000_000_000.0 / time_per_op as f64,
};
self.results.insert("spmv".to_string(), result);
Ok(())
}
pub fn compare_with_dense(
&mut self,
sparse_tensor: &SparseTensor<T>,
dense_equivalent: &Array2<T>,
vector: &Array1<T>,
) -> RusTorchResult<f64> {
self.benchmark_spmv(sparse_tensor, vector, 100)?;
let sparse_time = self.results["spmv"].time_ns;
let start_time = std::time::Instant::now();
for _ in 0..100 {
let _ = dense_equivalent.dot(vector);
}
let dense_time = start_time.elapsed().as_nanos() / 100;
Ok(dense_time as f64 / sparse_time as f64)
}
pub fn report(&self) -> String {
let mut report = String::from("Sparse Operations Benchmark Report:\n");
report.push_str("=====================================\n");
for (op, result) in &self.results {
report.push_str(&format!(
"{}: {:.2}μs, {:.1}MB/s throughput\n",
op,
result.time_ns as f64 / 1000.0,
result.throughput_ops / 1_000_000.0
));
}
report
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4x4 tensor with three diagonal entries reports the expected counts.
    #[test]
    fn test_sparse_analyzer() {
        let rows = Array1::from_vec(vec![0, 1, 2]);
        let cols = Array1::from_vec(vec![0, 1, 2]);
        let data = Array1::from_vec(vec![1.0f32, 2.0, 3.0]);
        let tensor = SparseTensor::from_coo(vec![rows, cols], data, vec![4, 4]).unwrap();

        let summary = SparseAnalyzer::analyze_pattern(&tensor);

        assert_eq!(summary.total_elements, 16);
        assert_eq!(summary.non_zero_elements, 3);
        assert!(summary.sparsity_ratio > 0.8);
    }

    /// A well-formed tensor passes validation; an out-of-range row index fails.
    #[test]
    fn test_sparse_validator() {
        let data = Array1::from_vec(vec![1.0f32, 2.0]);

        let good_indices = vec![Array1::from_vec(vec![0, 1]), Array1::from_vec(vec![0, 1])];
        let good = SparseTensor::from_coo(good_indices, data.clone(), vec![2, 2]).unwrap();
        assert!(SparseValidator::validate(&good).is_ok());

        let bad_indices = vec![Array1::from_vec(vec![0, 5]), Array1::from_vec(vec![0, 1])];
        let bad = SparseTensor::from_coo(bad_indices, data, vec![2, 2]).unwrap();
        assert!(SparseValidator::validate(&bad).is_err());
    }

    /// Converting COO to CSR keeps the entry count and switches the format tag.
    #[test]
    fn test_sparse_converter() {
        let coo = SparseTensor::from_coo(
            vec![Array1::from_vec(vec![0, 1]), Array1::from_vec(vec![0, 1])],
            Array1::from_vec(vec![1.0f32, 2.0]),
            vec![2, 2],
        )
        .unwrap();

        let csr = SparseConverter::convert(&coo, SparseFormat::CSR).unwrap_or_else(|e| {
            println!("CSR conversion failed: {:?}", e);
            panic!("CSR conversion should work");
        });

        assert_eq!(csr.format, SparseFormat::CSR);
        assert_eq!(csr.nnz, coo.nnz);
    }

    /// Running the SpMV benchmark records a result that shows up in the report.
    #[test]
    fn test_sparse_benchmark() {
        let coo = SparseTensor::from_coo(
            vec![Array1::from_vec(vec![0, 1]), Array1::from_vec(vec![0, 1])],
            Array1::from_vec(vec![1.0f32, 2.0]),
            vec![2, 2],
        )
        .unwrap();
        let csr = coo.to_csr().unwrap();
        let rhs = Array1::from_vec(vec![1.0, 2.0]);

        let mut harness = SparseBenchmark::new();
        harness.benchmark_spmv(&csr, &rhs, 10).unwrap();

        assert!(harness.results.contains_key("spmv"));
        let rendered = harness.report();
        assert!(rendered.contains("spmv"));
    }
}