use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use scirs2_core::array_protocol::{
self, matmul, sum, transpose, ArrayFunction, ArrayProtocol, NdarrayWrapper, NotImplemented,
};
use scirs2_core::ndarray_ext::Array2;
/// A 2-D `f64` matrix stored in coordinate (COO) form: one `(row, col)`
/// coordinate per stored value.
struct SparseArray {
    /// `(row, col)` coordinates of the stored entries; parallel to `values`.
    indices: Vec<(usize, usize)>,
    /// Stored entry values; `values[k]` belongs at `indices[k]`.
    values: Vec<f64>,
    /// Logical `(rows, cols)` of the full matrix.
    shape: (usize, usize),
}

impl SparseArray {
    /// Builds a sparse array from explicit coordinates and values.
    ///
    /// # Panics
    /// Panics if `indices` and `values` differ in length, or if any coordinate
    /// lies outside `shape` (such an entry would otherwise only surface later as
    /// an out-of-bounds panic in `to_dense`).
    #[allow(dead_code)] // not called by `main` in this example
    fn indices(indices: Vec<(usize, usize)>, values: Vec<f64>, shape: (usize, usize)) -> Self {
        assert_eq!(
            indices.len(),
            values.len(),
            "Indices and values must have the same length"
        );
        assert!(
            indices.iter().all(|&(i, j)| i < shape.0 && j < shape.1),
            "Indices must lie within the array shape"
        );
        Self {
            indices,
            values,
            shape,
        }
    }

    /// Converts a dense matrix, keeping every entry that is not exactly `0.0`.
    ///
    /// NaN entries are kept, because `NaN != 0.0`.
    fn array(array: &Array2<f64>) -> Self {
        let (indices, values): (Vec<(usize, usize)>, Vec<f64>) = array
            .indexed_iter()
            .filter(|&(_, &val)| val != 0.0)
            .map(|(idx, &val)| (idx, val))
            .unzip();
        Self {
            indices,
            values,
            shape: array.dim(),
        }
    }

    /// Expands into a dense matrix of `self.shape`, with zeros everywhere else.
    ///
    /// If a coordinate is stored more than once, the last value written wins.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros(self.shape);
        for (&(row, col), &val) in self.indices.iter().zip(&self.values) {
            dense[[row, col]] = val;
        }
        dense
    }

    /// Number of stored entries.
    fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Fraction of the matrix's entries that are stored, i.e. `nnz / (rows * cols)`.
    ///
    /// Despite the name, this is the density (higher means *less* sparse).
    /// Returns `0.0` for a matrix with no entries instead of `0/0 = NaN`.
    fn sparsity(&self) -> f64 {
        let total = self.shape.0 * self.shape.1;
        if total == 0 {
            return 0.0;
        }
        self.nnz() as f64 / total as f64
    }
}
impl fmt::Debug for SparseArray {
    /// Summarises the array by its shape, stored-entry count and `sparsity()`
    /// value rather than listing every coordinate.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shape = self.shape;
        let stored = self.nnz();
        let ratio = self.sparsity();

        let mut summary = f.debug_struct("SparseArray");
        summary.field("shape", &shape);
        summary.field("nnz", &stored);
        summary.field("sparsity", &ratio);
        summary.finish()
    }
}
impl ArrayProtocol for SparseArray {
    /// Returns a deep copy of this array behind the protocol trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(SparseArray {
            indices: self.indices.clone(),
            values: self.values.clone(),
            shape: self.shape,
        })
    }

    /// Dispatches an array-protocol operation to this sparse implementation.
    ///
    /// Supported operations, matched on `func.name`:
    /// - `matmul`: `args[1]` may be a `SparseArray` or a dense
    ///   `NdarrayWrapper<f64, Ix2>`; returns `self · args[1]` as a `SparseArray`.
    /// - `add`: same operand kinds as `matmul`; shapes must match exactly.
    /// - `sum`: without an `axis` kwarg, returns the `f64` sum of the stored
    ///   values. With a `usize` `axis` of 0 or 1, returns a `SparseArray` holding
    ///   the sums along that axis, with the reduced axis kept at length 1.
    /// - `transpose`: returns a `SparseArray` with every coordinate swapped.
    ///
    /// Only `args[1]` is inspected for binary operations; `args[0]` is presumably
    /// `self` (NOTE(review): confirm against the protocol's dispatch).
    ///
    /// # Errors
    /// Returns `NotImplemented` for unknown operations, a wrong argument count,
    /// unsupported operand types, incompatible shapes, or an axis other than 0/1.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _type_ids: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() != 2 {
                    return Err(NotImplemented);
                }
                let mut converted = None;
                // `&*args[1]` passes the boxed value itself; `&args[1]` would be
                // coerced into a `&dyn Any` whose concrete type is the `Box`, and
                // every downcast would then fail.
                let other = resolve_operand(&*args[1], &mut converted).ok_or(NotImplemented)?;
                if self.shape.1 != other.shape.0 {
                    return Err(NotImplemented);
                }
                let product = self.to_dense().dot(&other.to_dense());
                Ok(Box::new(SparseArray::array(&product)))
            }
            "scirs2::array_protocol::operations::add" => {
                if args.len() != 2 {
                    return Err(NotImplemented);
                }
                let mut converted = None;
                let other = resolve_operand(&*args[1], &mut converted).ok_or(NotImplemented)?;
                if self.shape != other.shape {
                    return Err(NotImplemented);
                }
                let total = &self.to_dense() + &other.to_dense();
                Ok(Box::new(SparseArray::array(&total)))
            }
            "scirs2::array_protocol::operations::sum" => {
                if let Some(&axis) = kwargs
                    .get("axis")
                    .and_then(|axis_box| axis_box.downcast_ref::<usize>())
                {
                    // `sum_axis` panics on an out-of-range axis; report it instead.
                    if axis >= 2 {
                        return Err(NotImplemented);
                    }
                    let axis = scirs2_core::ndarray::Axis(axis);
                    // `sum_axis` yields a 1-D array; re-inserting the reduced axis
                    // (length 1) keeps the result 2-D so it fits in a SparseArray.
                    let reduced = self.to_dense().sum_axis(axis).insert_axis(axis);
                    return Ok(Box::new(SparseArray::array(&reduced)));
                }
                let total: f64 = self.values.iter().sum();
                Ok(Box::new(total))
            }
            "scirs2::array_protocol::operations::transpose" => {
                let swapped = self.indices.iter().map(|&(i, j)| (j, i)).collect();
                Ok(Box::new(SparseArray {
                    indices: swapped,
                    values: self.values.clone(),
                    shape: (self.shape.1, self.shape.0),
                }))
            }
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// Resolves a protocol argument (expected to hold a `&dyn ArrayProtocol`) to a
/// `SparseArray`.
///
/// A `SparseArray` operand is borrowed directly. A dense
/// `NdarrayWrapper<f64, Ix2>` operand is converted and stored in `storage`, and
/// a borrow of that converted value is returned. Any other operand yields `None`.
fn resolve_operand<'a>(
    arg: &'a dyn Any,
    storage: &'a mut Option<SparseArray>,
) -> Option<&'a SparseArray> {
    let operand = arg.downcast_ref::<&dyn ArrayProtocol>()?;
    let any = operand.as_any();
    if let Some(sparse) = any.downcast_ref::<SparseArray>() {
        return Some(sparse);
    }
    let wrapper = any.downcast_ref::<NdarrayWrapper<f64, scirs2_core::ndarray::Ix2>>()?;
    let converted: &'a SparseArray = storage.insert(SparseArray::array(wrapper.as_array()));
    Some(converted)
}
#[allow(dead_code)]
fn main() {
array_protocol::init();
println!("Custom Array Protocol Example");
println!("============================");
let mut dense = Array2::<f64>::zeros((5, 5));
dense[[0, 0]] = 1.0;
dense[[1, 2]] = 2.0;
dense[[2, 1]] = 3.0;
dense[[3, 3]] = 4.0;
dense[[4, 4]] = 5.0;
let sparse = SparseArray::array(&dense);
println!("\nOriginal sparse array:");
println!("{:?}", sparse);
println!("Non-zero elements: {}", sparse.nnz());
println!("Sparsity: {:.2}%", sparse.sparsity() * 100.0);
let wrapped_dense = NdarrayWrapper::new(dense.clone());
println!("\n1. Matrix multiplication:");
match matmul(&sparse, &sparse) {
Ok(result) => {
if let Some(sparse_result) = result.as_any().downcast_ref::<SparseArray>() {
println!("Sparse * Sparse result:");
println!("{:?}", sparse_result);
} else {
println!("Result is not a SparseArray type");
}
}
Err(e) => println!("Error in Sparse * Sparse: {}", e),
}
match matmul(&sparse, &wrapped_dense) {
Ok(result) => {
if let Some(sparse_result) = result.as_any().downcast_ref::<SparseArray>() {
println!("Sparse * Dense result:");
println!("{:?}", sparse_result);
} else {
println!("Result is not a SparseArray type");
}
}
Err(e) => println!("Error in Sparse * Dense: {}", e),
}
println!("\n2. Sum operation:");
match sum(&sparse, None) {
Ok(result) => {
if let Some(sum_value) = result.downcast_ref::<f64>() {
println!("Sum of sparse array: {}", sum_value);
} else {
println!("Result is not a f64 type");
}
}
Err(e) => println!("Error in Sum operation: {}", e),
}
println!("\n3. Transpose operation:");
match transpose(&sparse) {
Ok(result) => {
if let Some(sparse_result) = result.as_any().downcast_ref::<SparseArray>() {
println!("Transpose of sparse array:");
println!("{:?}", sparse_result);
} else {
println!("Result is not a SparseArray type");
}
}
Err(e) => println!("Error in Transpose operation: {}", e),
}
println!("\nVerification with dense operations:");
let dense_result = dense.dot(&dense);
match matmul(&sparse, &sparse) {
Ok(sparse_result) => {
if let Some(sparse_array) = sparse_result.as_any().downcast_ref::<SparseArray>() {
let sparse_dense = sparse_array.to_dense();
let is_approx_equal = dense_result
.iter()
.zip(sparse_dense.iter())
.all(|(&a, &b)| (a - b).abs() < 1e-10);
println!("Matrix multiplication matches dense: {}", is_approx_equal);
} else {
println!("Matrix multiplication result is not a SparseArray type");
}
}
Err(e) => println!("Error in matrix multiplication verification: {}", e),
}
let dense_sum = dense.sum();
match sum(&sparse, None) {
Ok(result) => {
if let Some(sparse_sum) = result.downcast_ref::<f64>() {
println!(
"Sum matches dense: {}",
(dense_sum - sparse_sum).abs() < 1e-10
);
} else {
println!("Sum result is not a f64 type");
}
}
Err(e) => println!("Error in sum verification: {}", e),
}
let dense_transpose = dense.t().to_owned();
match transpose(&sparse) {
Ok(sparse_transpose) => {
if let Some(sparse_array) = sparse_transpose.as_any().downcast_ref::<SparseArray>() {
let sparse_dense_transpose = sparse_array.to_dense();
let is_transpose_equal = dense_transpose
.iter()
.zip(sparse_dense_transpose.iter())
.all(|(&a, &b)| (a - b).abs() < 1e-10);
println!("Transpose matches dense: {}", is_transpose_equal);
} else {
println!("Transpose result is not a SparseArray type");
}
}
Err(e) => println!("Error in transpose verification: {}", e),
}
println!("\nAll operations completed successfully!");
}