use torsh_core::TorshError;
use torsh_sparse::*;
use torsh_tensor::creation::zeros;
fn main() -> Result<(), TorshError> {
println!("ToRSh-Sparse Basic Usage Example");
println!("================================");
println!("1. Creating sparse matrix from triplets...");
let triplets = vec![
(0, 0, 1.0),
(0, 2, 2.0),
(1, 1, 3.0),
(2, 0, 4.0),
(2, 2, 5.0),
];
let coo_matrix = CooTensor::from_triplets(triplets, (3, 3))?;
println!("COO matrix created with {} non-zeros", coo_matrix.nnz());
println!("\n2. Converting to different formats...");
let csr_matrix = CsrTensor::from_coo(&coo_matrix)?;
let csc_matrix = CscTensor::from_coo(&coo_matrix)?;
println!("CSR matrix: {} non-zeros", csr_matrix.nnz());
println!("CSC matrix: {} non-zeros", csc_matrix.nnz());
println!("\n3. Element access...");
let value = csr_matrix.get(0, 0).unwrap_or(0.0);
println!("Element at (0, 0): {value}");
let value = csr_matrix.get(1, 1).unwrap_or(0.0);
println!("Element at (1, 1): {value}");
let value = csr_matrix.get(0, 1).unwrap_or(0.0); println!("Element at (0, 1): {value}");
println!("\n4. Matrix properties...");
println!("Shape: {:?}", csr_matrix.shape());
println!("Non-zeros: {}", csr_matrix.nnz());
println!("Density: {:.2}%", csr_matrix.density() * 100.0);
println!("\n5. Dense representation:");
let dense = csr_matrix.to_dense()?;
println!("Dense matrix:");
for i in 0..3 {
for j in 0..3 {
print!("{:.1} ", dense.get(&[i, j]).unwrap());
}
println!();
}
println!("\n6. Matrix-vector multiplication...");
let vector = [1.0, 2.0, 3.0];
let vector_tensor = zeros::<f32>(&[3])?;
for (i, &v) in vector.iter().enumerate() {
vector_tensor.set(&[i], v)?;
}
let result = csr_matrix.matvec(&vector_tensor)?;
println!("Matrix * vector = {:?}", result.to_vec()?);
println!("\n7. Matrix transpose...");
let transposed = csr_matrix.transpose()?;
println!("Transposed matrix has {} non-zeros", transposed.nnz());
println!("\n8. Scalar operations...");
let scaled = csr_matrix.scale(2.0)?;
println!("Scaled matrix (2x) has {} non-zeros", scaled.nnz());
println!("\n9. Matrix addition...");
let sum = csr_matrix.add(&scaled)?;
println!("Sum matrix has {} non-zeros", sum.nnz());
println!("\n10. Reduction operations...");
let total_sum = csr_matrix.sum()?;
println!("Sum of all elements: {total_sum}");
let l2_norm = csr_matrix.norm(2.0)?;
println!("L2 norm: {l2_norm:.4}");
let diagonal = csr_matrix.diagonal()?;
println!("Diagonal elements: {diagonal:?}");
println!("\nBasic usage example completed successfully!");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_operations() {
let result = main();
assert!(result.is_ok());
}
#[test]
fn test_matrix_properties() -> Result<(), TorshError> {
let triplets = vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)];
let coo = CooTensor::from_triplets(triplets, (3, 3))?;
let csr = CsrTensor::from_coo(&coo)?;
assert_eq!(csr.nnz(), 3);
assert_eq!(csr.shape(), (3, 3));
assert!((csr.density() - 3.0 / 9.0).abs() < 1e-10);
Ok(())
}
#[test]
fn test_format_conversions() -> Result<(), TorshError> {
let triplets = vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)];
let coo = CooTensor::from_triplets(triplets, (3, 3))?;
let csr = CsrTensor::from_coo(&coo)?;
let csc = CscTensor::from_coo(&coo)?;
assert_eq!(coo.nnz(), csr.nnz());
assert_eq!(coo.nnz(), csc.nnz());
assert_eq!(csr.get(0, 0)?, 1.0);
assert_eq!(csr.get(1, 1)?, 2.0);
assert_eq!(csr.get(2, 2)?, 3.0);
Ok(())
}
}