#![allow(dead_code)]
use crate::{CooTensor, CsrTensor, SparseFormat, SparseTensor, TorshResult};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use torsh_core::{Shape, TorshError};
use torsh_tensor::Tensor;
/// Timing and memory statistics recorded for a single profiled operation.
#[derive(Debug, Clone)]
pub struct PerformanceMeasurement {
    pub operation: String,
    pub duration: Duration,
    pub memory_before: usize,
    pub memory_after: usize,
    pub peak_memory: usize,
    pub metrics: HashMap<String, f64>,
}

impl PerformanceMeasurement {
    /// Creates a zeroed measurement record for the named operation,
    /// with no auxiliary metrics attached yet.
    pub fn new(operation: String) -> Self {
        PerformanceMeasurement {
            operation,
            duration: Duration::default(),
            memory_before: 0,
            memory_after: 0,
            peak_memory: 0,
            metrics: HashMap::default(),
        }
    }

    /// Stores (or overwrites) a named auxiliary metric, e.g. "min_time_ms".
    pub fn add_metric(&mut self, key: String, value: f64) {
        self.metrics.insert(key, value);
    }

    /// Signed change in memory usage from before to after the operation,
    /// in bytes (negative when memory was released).
    pub fn memory_delta(&self) -> i64 {
        let before = self.memory_before as i64;
        let after = self.memory_after as i64;
        after - before
    }

    /// How far peak memory rose above the starting level, in bytes;
    /// zero when the peak never exceeded `memory_before`.
    pub fn peak_memory_increase(&self) -> usize {
        self.peak_memory.checked_sub(self.memory_before).unwrap_or(0)
    }
}
/// Tuning knobs controlling how the profiler runs each benchmark.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    pub warmup_iterations: usize,
    pub measured_iterations: usize,
    pub collect_memory: bool,
    pub gc_between_iterations: bool,
    pub max_iteration_time: Duration,
}

impl Default for BenchmarkConfig {
    /// Balanced defaults: 3 warmup runs, 10 measured runs, memory
    /// collection on, 30 s per-iteration timeout.
    fn default() -> Self {
        BenchmarkConfig {
            warmup_iterations: 3,
            measured_iterations: 10,
            collect_memory: true,
            gc_between_iterations: false,
            max_iteration_time: Duration::from_secs(30),
        }
    }
}

impl BenchmarkConfig {
    /// Minimal-overhead preset for quick smoke benchmarks
    /// (1 warmup, 3 measured runs, no memory tracking).
    pub fn fast() -> Self {
        BenchmarkConfig {
            warmup_iterations: 1,
            measured_iterations: 3,
            collect_memory: false,
            gc_between_iterations: false,
            max_iteration_time: Duration::from_secs(5),
        }
    }

    /// High-iteration preset for statistically stable timings
    /// (5 warmups, 20 measured runs, full memory tracking).
    pub fn thorough() -> Self {
        BenchmarkConfig {
            warmup_iterations: 5,
            measured_iterations: 20,
            collect_memory: true,
            gc_between_iterations: true,
            max_iteration_time: Duration::from_secs(60),
        }
    }

    /// Preset tuned for memory profiling rather than raw speed
    /// (few iterations, memory collection between each run).
    pub fn memory_focused() -> Self {
        BenchmarkConfig {
            warmup_iterations: 1,
            measured_iterations: 5,
            collect_memory: true,
            gc_between_iterations: true,
            max_iteration_time: Duration::from_secs(10),
        }
    }
}
/// Profiler for sparse-tensor operations: runs warmup and measured
/// iterations according to `config` and accumulates the results.
#[derive(Debug)]
pub struct SparseProfiler {
    /// Benchmark settings applied to every measured operation.
    pub config: BenchmarkConfig,
    /// All measurements recorded so far, across all benchmark calls.
    pub measurements: Vec<PerformanceMeasurement>,
    /// Number of recorded results per operation family
    /// (keyed by the base name, e.g. "sparse_matmul").
    pub operation_counters: HashMap<String, usize>,
}
impl Default for SparseProfiler {
fn default() -> Self {
Self::new(BenchmarkConfig::default())
}
}
impl SparseProfiler {
    /// Creates a profiler that uses `config` for all subsequent benchmarks.
    pub fn new(config: BenchmarkConfig) -> Self {
        Self {
            config,
            measurements: Vec::new(),
            operation_counters: HashMap::new(),
        }
    }

    /// Benchmarks converting `dense_matrix` to COO and to CSR.
    ///
    /// Returns the two measurements (COO first, then CSR); they are also
    /// stored in `self.measurements` and counted under "format_conversion".
    pub fn benchmark_format_conversion(
        &mut self,
        dense_matrix: &Tensor,
    ) -> TorshResult<Vec<PerformanceMeasurement>> {
        let operation_base = "format_conversion";
        let mut results = Vec::with_capacity(2);
        results.push(self.measure_operation(format!("{}_to_coo", operation_base), || {
            self.convert_to_coo(dense_matrix)
        })?);
        results.push(self.measure_operation(format!("{}_to_csr", operation_base), || {
            self.convert_to_csr(dense_matrix)
        })?);
        self.record_results(operation_base, &results);
        Ok(results)
    }

    /// Benchmarks a sparse-sparse matrix multiplication `lhs * rhs`.
    ///
    /// The operation label encodes both operand formats, e.g.
    /// "sparse_matmul_CsrxCoo".
    pub fn benchmark_sparse_matmul(
        &mut self,
        lhs: &dyn SparseTensor,
        rhs: &dyn SparseTensor,
    ) -> TorshResult<Vec<PerformanceMeasurement>> {
        let operation_base = "sparse_matmul";
        let measurement = self.measure_operation(
            format!("{}_{:?}x{:?}", operation_base, lhs.format(), rhs.format()),
            || self.perform_matrix_multiplication(lhs, rhs),
        )?;
        let results = vec![measurement];
        self.record_results(operation_base, &results);
        Ok(results)
    }

    /// Benchmarks a dense-to-sparse (COO) conversion at the given
    /// sparsity threshold; the threshold is encoded in the operation label.
    pub fn benchmark_dense_to_sparse(
        &mut self,
        dense_matrix: &Tensor,
        sparsity_threshold: f32,
    ) -> TorshResult<Vec<PerformanceMeasurement>> {
        let operation_base = "dense_to_sparse";
        let measurement = self.measure_operation(
            format!("{}_threshold_{}", operation_base, sparsity_threshold),
            || self.convert_dense_to_sparse(dense_matrix, sparsity_threshold),
        )?;
        let results = vec![measurement];
        self.record_results(operation_base, &results);
        Ok(results)
    }

    /// Runs the format-conversion benchmarks plus a sparsity analysis and
    /// returns every measurement produced.
    ///
    /// NOTE(review): `benchmark_format_conversion` already stores its
    /// measurements, and the `extend` below stores them a second time, so
    /// conversion measurements end up duplicated in `self.measurements`.
    /// Preserved as-is for behavioral compatibility — confirm intent.
    pub fn profile_format_comparison(
        &mut self,
        dense_matrix: &Tensor,
    ) -> TorshResult<Vec<PerformanceMeasurement>> {
        let mut all_measurements = self.benchmark_format_conversion(dense_matrix)?;
        let sparsity_ratio = self.calculate_sparsity_ratio(dense_matrix)?;
        let mut sparsity_measurement =
            PerformanceMeasurement::new("sparsity_analysis".to_string());
        sparsity_measurement.add_metric("sparsity_ratio".to_string(), sparsity_ratio as f64);
        all_measurements.push(sparsity_measurement);
        self.measurements.extend(all_measurements.clone());
        Ok(all_measurements)
    }

    /// Clears all stored measurements and operation counters.
    pub fn clear_measurements(&mut self) {
        self.measurements.clear();
        self.operation_counters.clear();
    }

    /// Number of measurements currently stored.
    pub fn measurement_count(&self) -> usize {
        self.measurements.len()
    }

    /// Returns all stored measurements whose operation name contains
    /// `operation` as a substring.
    pub fn get_measurements_for_operation(&self, operation: &str) -> Vec<&PerformanceMeasurement> {
        self.measurements
            .iter()
            .filter(|m| m.operation.contains(operation))
            .collect()
    }

    /// Stores copies of `results` and bumps the counter for
    /// `operation_base` by the number of results recorded.
    fn record_results(&mut self, operation_base: &str, results: &[PerformanceMeasurement]) {
        self.measurements.extend_from_slice(results);
        *self
            .operation_counters
            .entry(operation_base.to_string())
            .or_insert(0) += results.len();
    }

    /// Builds a COO tensor shaped like `dense_matrix`.
    ///
    /// NOTE(review): benchmarking stub — it fabricates ~10% density with
    /// repeating indices instead of reading the tensor's actual values.
    /// Assumes `dense_matrix` is 2-D (indexes `shape[0]`/`shape[1]`).
    fn convert_to_coo(&self, dense_matrix: &Tensor) -> TorshResult<CooTensor> {
        let shape = dense_matrix.shape().to_vec();
        // Placeholder density: one nonzero per ten elements.
        let nnz = (shape[0] * shape[1]) / 10;
        let mut row_indices = Vec::with_capacity(nnz);
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values = Vec::with_capacity(nnz);
        for i in 0..nnz {
            row_indices.push(i % shape[0]);
            col_indices.push(i % shape[1]);
            values.push(1.0);
        }
        CooTensor::new(row_indices, col_indices, values, Shape::new(shape))
    }

    /// Converts to CSR by first building the placeholder COO representation.
    fn convert_to_csr(&self, dense_matrix: &Tensor) -> TorshResult<CsrTensor> {
        let coo = self.convert_to_coo(dense_matrix)?;
        CsrTensor::from_coo(&coo)
    }

    /// Multiplies two sparse tensors, coercing operands to CSR as needed.
    fn perform_matrix_multiplication(
        &self,
        lhs: &dyn SparseTensor,
        rhs: &dyn SparseTensor,
    ) -> TorshResult<Box<dyn SparseTensor>> {
        // Fast path: both already CSR — multiply without cloning.
        if let (SparseFormat::Csr, SparseFormat::Csr) = (lhs.format(), rhs.format()) {
            let lhs_csr = lhs
                .as_any()
                .downcast_ref::<CsrTensor>()
                .expect("CSR tensor downcast should succeed");
            let rhs_csr = rhs
                .as_any()
                .downcast_ref::<CsrTensor>()
                .expect("CSR tensor downcast should succeed");
            return self.csr_multiply_simplified(lhs_csr, rhs_csr);
        }
        // General path: coerce each operand to CSR first.
        let lhs_csr = self.as_csr(lhs)?;
        let rhs_csr = self.as_csr(rhs)?;
        self.csr_multiply_simplified(&lhs_csr, &rhs_csr)
    }

    /// Returns a CSR copy of `tensor`: clones when it is already CSR,
    /// otherwise converts from COO.
    ///
    /// # Panics
    /// Panics if a non-CSR tensor is not actually a `CooTensor` (the only
    /// other format this profiler handles).
    fn as_csr(&self, tensor: &dyn SparseTensor) -> TorshResult<CsrTensor> {
        match tensor.format() {
            SparseFormat::Csr => Ok(tensor
                .as_any()
                .downcast_ref::<CsrTensor>()
                .expect("CSR tensor downcast should succeed")
                .clone()),
            _ => CsrTensor::from_coo(
                tensor
                    .as_any()
                    .downcast_ref::<CooTensor>()
                    .expect("COO tensor downcast should succeed"),
            ),
        }
    }

    /// Simplified CSR multiplication used only for benchmarking overhead.
    ///
    /// Validates inner dimensions, then fabricates a fixed 100-nnz result
    /// rather than computing the true product.
    ///
    /// NOTE(review): the fabricated `row_ptrs` of `(0..=rows)` only
    /// describe exactly 100 nonzeros when `rows == 100`; confirm
    /// `CsrTensor::new` accepts this for other shapes.
    fn csr_multiply_simplified(
        &self,
        lhs: &CsrTensor,
        rhs: &CsrTensor,
    ) -> TorshResult<Box<dyn SparseTensor>> {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();
        if lhs_shape.dims()[1] != rhs_shape.dims()[0] {
            return Err(TorshError::InvalidArgument(
                "Matrix dimensions incompatible for multiplication".to_string(),
            ));
        }
        let result_shape = Shape::new(vec![lhs_shape.dims()[0], rhs_shape.dims()[1]]);
        // Placeholder payload: 100 unit values with sequential column indices.
        let result_values = vec![1.0; 100];
        let result_col_indices = (0..100).collect();
        let result_row_ptrs = (0..=lhs_shape.dims()[0]).collect();
        let result = CsrTensor::new(
            result_row_ptrs,
            result_col_indices,
            result_values,
            result_shape,
        )?;
        Ok(Box::new(result))
    }

    /// Builds a synthetic COO tensor whose density is `1 - threshold`.
    ///
    /// NOTE(review): benchmarking stub — values are fabricated, not taken
    /// from `dense_matrix`. Assumes a 2-D shape.
    fn convert_dense_to_sparse(
        &self,
        dense_matrix: &Tensor,
        threshold: f32,
    ) -> TorshResult<CooTensor> {
        let shape = dense_matrix.shape().to_vec();
        // Expected nonzeros when `threshold` is the target sparsity ratio.
        let estimated_nnz =
            (shape[0] as f64 * shape[1] as f64 * (1.0 - threshold as f64)) as usize;
        let mut row_indices = Vec::with_capacity(estimated_nnz);
        let mut col_indices = Vec::with_capacity(estimated_nnz);
        let mut values = Vec::with_capacity(estimated_nnz);
        for i in 0..estimated_nnz {
            row_indices.push(i % shape[0]);
            col_indices.push(i % shape[1]);
            // `threshold` is already f32; the former `as f32` cast was redundant.
            values.push(1.0 + threshold);
        }
        CooTensor::new(row_indices, col_indices, values, Shape::new(shape))
    }

    /// Placeholder sparsity estimate: always reports 80% zeros.
    /// Returns 0.0 for a zero-element tensor instead of NaN (0.0 / 0.0).
    fn calculate_sparsity_ratio(&self, dense_matrix: &Tensor) -> TorshResult<f32> {
        let shape = dense_matrix.shape();
        let total_elements = shape.dims().iter().product::<usize>() as f32;
        if total_elements == 0.0 {
            return Ok(0.0);
        }
        let assumed_zeros = total_elements * 0.8;
        Ok(assumed_zeros / total_elements)
    }

    /// Runs `operation_fn` through warmup and measured iterations and
    /// returns aggregate timing (mean/min/max/std-dev) plus memory stats.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when any measured iteration exceeds
    /// `config.max_iteration_time`; propagates errors from `operation_fn`.
    fn measure_operation<F, R>(
        &self,
        operation: String,
        mut operation_fn: F,
    ) -> TorshResult<PerformanceMeasurement>
    where
        F: FnMut() -> TorshResult<R>,
    {
        let mut measurement = PerformanceMeasurement::new(operation);
        measurement.memory_before = self.get_current_memory_usage();
        // Warmup runs are executed but never recorded, so caches and
        // allocators reach a steady state before timing starts.
        for _ in 0..self.config.warmup_iterations {
            operation_fn()?;
            if self.config.gc_between_iterations {
                // Rust has no GC; black_box is a placeholder barrier only.
                std::hint::black_box(());
            }
        }
        let mut durations = Vec::with_capacity(self.config.measured_iterations);
        let mut peak_memory = measurement.memory_before;
        for _ in 0..self.config.measured_iterations {
            let start = Instant::now();
            let _start_memory = self.get_current_memory_usage();
            operation_fn()?;
            let end_memory = self.get_current_memory_usage();
            let duration = start.elapsed();
            if duration > self.config.max_iteration_time {
                return Err(TorshError::InvalidArgument(format!(
                    "Operation {} exceeded maximum iteration time",
                    measurement.operation
                )));
            }
            durations.push(duration);
            peak_memory = peak_memory.max(end_memory);
            if self.config.gc_between_iterations {
                std::hint::black_box(());
            }
        }
        measurement.duration = self.calculate_mean_duration(&durations);
        measurement.peak_memory = peak_memory;
        measurement.memory_after = self.get_current_memory_usage();
        // min/max exist only when at least one iteration was measured.
        if let (Some(&min_duration), Some(&max_duration)) =
            (durations.iter().min(), durations.iter().max())
        {
            measurement.add_metric(
                "min_time_ms".to_string(),
                min_duration.as_secs_f64() * 1000.0,
            );
            measurement.add_metric(
                "max_time_ms".to_string(),
                max_duration.as_secs_f64() * 1000.0,
            );
            measurement.add_metric("std_dev_ms".to_string(), self.calculate_std_dev(&durations));
        }
        Ok(measurement)
    }

    /// Current process memory usage in bytes.
    ///
    /// NOTE(review): stub — real tracking (e.g. /proc/self/statm on Linux)
    /// is not wired in yet; both former cfg branches returned the same
    /// fixed 64 MiB, so memory deltas are always zero.
    fn get_current_memory_usage(&self) -> usize {
        64 * 1024 * 1024
    }

    /// Arithmetic mean of `durations`; zero for an empty slice (previously
    /// this divided by zero and panicked when no iterations were measured).
    fn calculate_mean_duration(&self, durations: &[Duration]) -> Duration {
        if durations.is_empty() {
            return Duration::new(0, 0);
        }
        let total_nanos: u64 = durations.iter().map(|d| d.as_nanos() as u64).sum();
        Duration::from_nanos(total_nanos / durations.len() as u64)
    }

    /// Population standard deviation of `durations`, in milliseconds;
    /// 0.0 for an empty slice (previously NaN from 0.0 / 0.0).
    fn calculate_std_dev(&self, durations: &[Duration]) -> f64 {
        if durations.is_empty() {
            return 0.0;
        }
        let n = durations.len() as f64;
        let mean = durations.iter().map(|d| d.as_nanos() as f64).sum::<f64>() / n;
        let variance = durations
            .iter()
            .map(|d| {
                let diff = d.as_nanos() as f64 - mean;
                diff * diff
            })
            .sum::<f64>()
            / n;
        // Nanoseconds -> milliseconds.
        variance.sqrt() / 1_000_000.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh measurement starts fully zeroed with no metrics attached.
    #[test]
    fn test_performance_measurement_creation() {
        let measurement = PerformanceMeasurement::new("test_operation".to_string());
        assert_eq!(measurement.operation, "test_operation");
        assert_eq!(measurement.duration, Duration::new(0, 0));
        assert_eq!(measurement.memory_before, 0);
        assert_eq!(measurement.memory_after, 0);
        assert_eq!(measurement.peak_memory, 0);
        assert!(measurement.metrics.is_empty());
    }

    // add_metric stores values; the memory accessors are zero on a fresh record.
    #[test]
    fn test_performance_measurement_metrics() {
        let mut measurement = PerformanceMeasurement::new("test".to_string());
        measurement.add_metric("test_metric".to_string(), 42.0);
        assert_eq!(measurement.metrics.get("test_metric"), Some(&42.0));
        assert_eq!(measurement.memory_delta(), 0);
        assert_eq!(measurement.peak_memory_increase(), 0);
    }

    // Pins the Default config values (3 warmups, 10 runs, 30 s timeout).
    #[test]
    fn test_benchmark_config_defaults() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup_iterations, 3);
        assert_eq!(config.measured_iterations, 10);
        assert!(config.collect_memory);
        assert!(!config.gc_between_iterations);
        assert_eq!(config.max_iteration_time, Duration::from_secs(30));
    }

    // Pins the distinguishing values of the fast/thorough/memory presets.
    #[test]
    fn test_benchmark_config_presets() {
        let fast_config = BenchmarkConfig::fast();
        assert_eq!(fast_config.warmup_iterations, 1);
        assert_eq!(fast_config.measured_iterations, 3);
        assert!(!fast_config.collect_memory);
        let thorough_config = BenchmarkConfig::thorough();
        assert_eq!(thorough_config.warmup_iterations, 5);
        assert_eq!(thorough_config.measured_iterations, 20);
        assert!(thorough_config.collect_memory);
        assert!(thorough_config.gc_between_iterations);
        let memory_config = BenchmarkConfig::memory_focused();
        assert!(memory_config.collect_memory);
        assert!(memory_config.gc_between_iterations);
    }

    // A new profiler adopts the supplied config and starts empty.
    #[test]
    fn test_sparse_profiler_creation() {
        let config = BenchmarkConfig::fast();
        let profiler = SparseProfiler::new(config.clone());
        assert_eq!(profiler.config.warmup_iterations, config.warmup_iterations);
        assert!(profiler.measurements.is_empty());
        assert!(profiler.operation_counters.is_empty());
    }

    // Default profiler uses BenchmarkConfig::default() (3 warmups).
    #[test]
    fn test_sparse_profiler_default() {
        let profiler = SparseProfiler::default();
        assert_eq!(profiler.config.warmup_iterations, 3);
        assert_eq!(profiler.measurement_count(), 0);
    }

    // clear_measurements wipes both stored measurements and counters.
    #[test]
    fn test_clear_measurements() {
        let mut profiler = SparseProfiler::default();
        profiler
            .measurements
            .push(PerformanceMeasurement::new("test1".to_string()));
        profiler
            .measurements
            .push(PerformanceMeasurement::new("test2".to_string()));
        profiler.operation_counters.insert("test".to_string(), 2);
        assert_eq!(profiler.measurement_count(), 2);
        assert_eq!(profiler.operation_counters.len(), 1);
        profiler.clear_measurements();
        assert_eq!(profiler.measurement_count(), 0);
        assert!(profiler.operation_counters.is_empty());
    }

    // Lookup is a substring match on the operation name.
    #[test]
    fn test_get_measurements_for_operation() {
        let mut profiler = SparseProfiler::default();
        profiler.measurements.push(PerformanceMeasurement::new(
            "format_conversion_to_coo".to_string(),
        ));
        profiler.measurements.push(PerformanceMeasurement::new(
            "format_conversion_to_csr".to_string(),
        ));
        profiler
            .measurements
            .push(PerformanceMeasurement::new("sparse_matmul".to_string()));
        let format_measurements = profiler.get_measurements_for_operation("format_conversion");
        assert_eq!(format_measurements.len(), 2);
        let matmul_measurements = profiler.get_measurements_for_operation("matmul");
        assert_eq!(matmul_measurements.len(), 1);
        let nonexistent_measurements = profiler.get_measurements_for_operation("nonexistent");
        assert_eq!(nonexistent_measurements.len(), 0);
    }

    // Mean of 100/200/300 ms is exactly 200 ms.
    #[test]
    fn test_calculate_mean_duration() {
        let profiler = SparseProfiler::default();
        let durations = vec![
            Duration::from_millis(100),
            Duration::from_millis(200),
            Duration::from_millis(300),
        ];
        let mean = profiler.calculate_mean_duration(&durations);
        assert_eq!(mean.as_millis(), 200);
    }

    // Population std-dev of {100, 200, 300} ms is sqrt(20000/3) ~= 81.65 ms.
    #[test]
    fn test_calculate_std_dev() {
        let profiler = SparseProfiler::default();
        let durations = vec![
            Duration::from_millis(100),
            Duration::from_millis(200),
            Duration::from_millis(300),
        ];
        let std_dev = profiler.calculate_std_dev(&durations);
        assert!((std_dev - 81.65).abs() < 1.0);
    }

    // delta = after - before; peak increase = peak - before.
    #[test]
    fn test_memory_measurement_calculations() {
        let mut measurement = PerformanceMeasurement::new("test".to_string());
        measurement.memory_before = 1000;
        measurement.memory_after = 1500;
        measurement.peak_memory = 2000;
        assert_eq!(measurement.memory_delta(), 500);
        assert_eq!(measurement.peak_memory_increase(), 1000);
    }

    // When memory shrinks: delta goes negative, peak increase saturates at 0.
    #[test]
    fn test_memory_measurement_no_increase() {
        let mut measurement = PerformanceMeasurement::new("test".to_string());
        measurement.memory_before = 1000;
        measurement.memory_after = 800;
        measurement.peak_memory = 900;
        assert_eq!(measurement.memory_delta(), -200);
        assert_eq!(measurement.peak_memory_increase(), 0);
    }
}