use super::{SparseFormat, SparseOps, SparseTensor};
use crate::error::{RusTorchError, RusTorchResult};
use ndarray::{Array1, Array2, ArrayD};
use num_traits::{Float, FromPrimitive, One, Zero};
/// Sparse linear-algebra entry points exposed under the `cuda` feature.
///
/// Every method in this block computes its result with ordinary loops over the
/// tensor's index/value arrays; no device API is called from this code.
#[cfg(feature = "cuda")]
pub struct CudaSparseOps;

#[cfg(feature = "cuda")]
impl CudaSparseOps {
    /// Creates the operations handle. The current implementation cannot fail.
    pub fn init() -> RusTorchResult<Self> {
        Ok(CudaSparseOps)
    }

    /// Builds the `InvalidOperation` error used by the validation paths below.
    fn invalid(operation: &str, message: String) -> RusTorchError {
        RusTorchError::InvalidOperation {
            operation: operation.to_string(),
            message,
        }
    }

    /// Checks the dimensional invariants the CSR kernels index by.
    ///
    /// The caller must already have verified that `matrix.indices` holds at least
    /// two arrays (`row_ptr`, `col_indices`).
    ///
    /// # Errors
    /// `InvalidOperation` when the shape has fewer than two dimensions, when
    /// `inner_dim` differs from the matrix column count, or when `row_ptr` is too
    /// short to describe every row.
    fn check_csr_dims<T: Float + Copy>(
        matrix: &SparseTensor<T>,
        inner_dim: usize,
        operation: &str,
    ) -> RusTorchResult<()> {
        if matrix.shape.len() < 2 {
            return Err(Self::invalid(
                operation,
                format!(
                    "CSR matrix needs a two-dimensional shape, got {} dimension(s)",
                    matrix.shape.len()
                ),
            ));
        }
        let (rows, cols) = (matrix.shape[0], matrix.shape[1]);
        if inner_dim != cols {
            return Err(Self::invalid(
                operation,
                format!(
                    "Dimension mismatch: matrix has {} columns but operand has {} rows",
                    cols, inner_dim
                ),
            ));
        }
        let row_ptr_len = matrix.indices[0].len();
        if row_ptr_len < rows + 1 {
            return Err(Self::invalid(
                operation,
                format!(
                    "row_ptr has {} entries but {} rows need {}",
                    row_ptr_len,
                    rows,
                    rows + 1
                ),
            ));
        }
        Ok(())
    }

    /// Sparse matrix × dense vector product for a CSR matrix.
    ///
    /// # Errors
    /// `InvalidOperation` when the matrix is not CSR, lacks `row_ptr`/`col_indices`,
    /// has dimensions incompatible with `vector`, or carries row extents / column
    /// indices that point outside the stored data. These malformed inputs used to
    /// panic on an out-of-bounds index.
    pub fn spmv<T: Float + Copy + std::ops::AddAssign>(
        &self,
        sparse_matrix: &SparseTensor<T>,
        vector: &Array1<T>,
    ) -> RusTorchResult<Array1<T>> {
        if sparse_matrix.format != SparseFormat::CSR {
            return Err(RusTorchError::InvalidOperation {
                operation: "cuda_spmv".to_string(),
                message: "CUDA SpMV requires CSR format".to_string(),
            });
        }
        if sparse_matrix.indices.len() < 2 {
            return Err(RusTorchError::InvalidOperation {
                operation: "spmv_gpu".to_string(),
                message: "CSR format requires row_ptr and col_indices".to_string(),
            });
        }
        Self::check_csr_dims(sparse_matrix, vector.len(), "cuda_spmv")?;

        let row_ptr = &sparse_matrix.indices[0];
        let col_indices = &sparse_matrix.indices[1];
        let rows = sparse_matrix.shape[0];
        let cols = sparse_matrix.shape[1];
        // A row may only reference entries present in both parallel arrays.
        let nnz = sparse_matrix.values.len().min(col_indices.len());

        let mut result = Array1::zeros(rows);
        for i in 0..rows {
            let (start, end) = (row_ptr[i], row_ptr[i + 1]);
            if end > nnz {
                return Err(Self::invalid(
                    "cuda_spmv",
                    format!("row {} ends at {} but only {} entries are stored", i, end, nnz),
                ));
            }
            for j in start..end {
                let col = col_indices[j];
                if col >= cols {
                    return Err(Self::invalid(
                        "cuda_spmv",
                        format!("column index {} is out of range for {} columns", col, cols),
                    ));
                }
                result[i] += sparse_matrix.values[j] * vector[col];
            }
        }
        Ok(result)
    }

    /// Sparse matrix × dense matrix product for a CSR left operand.
    ///
    /// The result has `sparse_a.shape[0]` rows and `dense_b.ncols()` columns.
    ///
    /// # Errors
    /// Same conditions as [`CudaSparseOps::spmv`], with `dense_b.nrows()` playing
    /// the role of the vector length.
    pub fn spmm<T: Float + Copy + std::ops::AddAssign>(
        &self,
        sparse_a: &SparseTensor<T>,
        dense_b: &Array2<T>,
    ) -> RusTorchResult<Array2<T>> {
        if sparse_a.format != SparseFormat::CSR {
            return Err(RusTorchError::InvalidOperation {
                operation: "cuda_spmm".to_string(),
                message: "CUDA SpMM requires CSR format".to_string(),
            });
        }
        if sparse_a.indices.len() < 2 {
            return Err(RusTorchError::InvalidOperation {
                operation: "spmm_gpu".to_string(),
                message: "CSR format requires row_ptr and col_indices".to_string(),
            });
        }
        Self::check_csr_dims(sparse_a, dense_b.nrows(), "cuda_spmm")?;

        let row_ptr = &sparse_a.indices[0];
        let col_indices = &sparse_a.indices[1];
        let rows = sparse_a.shape[0];
        let cols = sparse_a.shape[1];
        let out_cols = dense_b.ncols();
        let nnz = sparse_a.values.len().min(col_indices.len());

        let mut result = Array2::zeros((rows, out_cols));
        for i in 0..rows {
            let (start, end) = (row_ptr[i], row_ptr[i + 1]);
            if end > nnz {
                return Err(Self::invalid(
                    "cuda_spmm",
                    format!("row {} ends at {} but only {} entries are stored", i, end, nnz),
                ));
            }
            for j in start..end {
                let col_idx = col_indices[j];
                if col_idx >= cols {
                    return Err(Self::invalid(
                        "cuda_spmm",
                        format!("column index {} is out of range for {} columns", col_idx, cols),
                    ));
                }
                let sparse_val = sparse_a.values[j];
                for k in 0..out_cols {
                    result[[i, k]] += sparse_val * dense_b[[col_idx, k]];
                }
            }
        }
        Ok(result)
    }

    /// Adds two sparse tensors that share format, shape and sparsity pattern.
    ///
    /// The sum is formed by adding the stored values position by position, which
    /// is only meaningful when both operands store the same coordinates in the
    /// same order. Operands that differ in shape or index arrays are rejected
    /// instead of silently producing a wrong result.
    ///
    /// # Errors
    /// `InvalidOperation` when formats, shapes, or index arrays differ.
    pub fn sparse_add<T: Float + Copy + std::ops::AddAssign>(
        &self,
        a: &SparseTensor<T>,
        b: &SparseTensor<T>,
    ) -> RusTorchResult<SparseTensor<T>> {
        if a.format != b.format {
            return Err(RusTorchError::InvalidOperation {
                operation: "sparse_add_gpu".to_string(),
                message: "Both tensors must have same format".to_string(),
            });
        }
        if a.shape != b.shape {
            return Err(Self::invalid(
                "sparse_add_gpu",
                "Both tensors must have the same shape".to_string(),
            ));
        }
        let same_pattern = a.values.len() == b.values.len()
            && a.indices.len() == b.indices.len()
            && a.indices.iter().zip(b.indices.iter()).all(|(lhs, rhs)| {
                lhs.len() == rhs.len() && lhs.iter().zip(rhs.iter()).all(|(p, q)| p == q)
            });
        if !same_pattern {
            return Err(Self::invalid(
                "sparse_add_gpu",
                "Both tensors must have the same sparsity pattern".to_string(),
            ));
        }

        let mut result = a.clone();
        for (dst, &src) in result.values.iter_mut().zip(b.values.iter()) {
            *dst += src;
        }
        Ok(result)
    }

    /// Converts `tensor` to `target_format`.
    ///
    /// A request for the format the tensor already has returns a clone. Every
    /// real conversion currently reports `NotImplemented`.
    pub fn convert_format<T: Float + Copy>(
        &self,
        tensor: &SparseTensor<T>,
        target_format: SparseFormat,
    ) -> RusTorchResult<SparseTensor<T>> {
        // Identity conversion needs no kernel.
        if tensor.format == target_format {
            return Ok(tensor.clone());
        }
        match (tensor.format, target_format) {
            (SparseFormat::COO, SparseFormat::CSR) => Err(RusTorchError::NotImplemented {
                feature: "COO to CSR conversion".to_string(),
            }),
            (SparseFormat::CSR, SparseFormat::COO) => Err(RusTorchError::NotImplemented {
                feature: "CSR to COO conversion".to_string(),
            }),
            _ => Err(RusTorchError::NotImplemented {
                feature: format!(
                    "GPU conversion from {:?} to {:?}",
                    tensor.format, target_format
                ),
            }),
        }
    }
}
/// Sparse operations handle for the `metal` feature.
///
/// Holds the system default Metal device and a command queue created from it.
/// `spmv` in this block computes its result with ordinary loops and does not
/// submit work to the queue.
#[cfg(feature = "metal")]
pub struct MetalSparseOps {
    /// Device returned by `metal::Device::system_default()`.
    pub device: metal::Device,
    /// Command queue created from `device`.
    pub command_queue: metal::CommandQueue,
}

#[cfg(feature = "metal")]
impl MetalSparseOps {
    /// Acquires the system default Metal device and creates a command queue on it.
    ///
    /// # Errors
    /// `BackendUnavailable` when no default Metal device exists.
    pub fn new() -> RusTorchResult<Self> {
        let device =
            metal::Device::system_default().ok_or_else(|| RusTorchError::BackendUnavailable {
                backend: "Metal".to_string(),
            })?;
        let command_queue = device.new_command_queue();
        Ok(Self {
            device,
            command_queue,
        })
    }

    /// Sparse matrix × dense vector product for a CSR matrix.
    ///
    /// # Errors
    /// `InvalidOperation` when the matrix is not CSR, lacks `row_ptr`/`col_indices`,
    /// has a shape with fewer than two dimensions, has a column count different
    /// from `vector.len()`, or carries row extents / column indices that point
    /// outside the stored data. These malformed inputs used to panic on an
    /// out-of-bounds index.
    pub fn spmv<T: Float + Copy + std::ops::AddAssign>(
        &self,
        sparse_matrix: &SparseTensor<T>,
        vector: &Array1<T>,
    ) -> RusTorchResult<Array1<T>> {
        if sparse_matrix.format != SparseFormat::CSR {
            return Err(RusTorchError::InvalidOperation {
                operation: "metal_spmv".to_string(),
                message: "Only CSR format supported".to_string(),
            });
        }
        if sparse_matrix.indices.len() < 2 {
            return Err(RusTorchError::InvalidOperation {
                operation: "spmv_gpu".to_string(),
                message: "CSR format requires row_ptr and col_indices".to_string(),
            });
        }

        // Shared constructor for the validation errors introduced below.
        let invalid = |message: String| RusTorchError::InvalidOperation {
            operation: "metal_spmv".to_string(),
            message,
        };

        if sparse_matrix.shape.len() < 2 {
            return Err(invalid(format!(
                "CSR matrix needs a two-dimensional shape, got {} dimension(s)",
                sparse_matrix.shape.len()
            )));
        }
        let rows = sparse_matrix.shape[0];
        let cols = sparse_matrix.shape[1];
        if vector.len() != cols {
            return Err(invalid(format!(
                "Dimension mismatch: matrix has {} columns but vector has {} elements",
                cols,
                vector.len()
            )));
        }

        let row_ptr = &sparse_matrix.indices[0];
        let col_indices = &sparse_matrix.indices[1];
        if row_ptr.len() < rows + 1 {
            return Err(invalid(format!(
                "row_ptr has {} entries but {} rows need {}",
                row_ptr.len(),
                rows,
                rows + 1
            )));
        }
        // A row may only reference entries present in both parallel arrays.
        let nnz = sparse_matrix.values.len().min(col_indices.len());

        let mut result = Array1::zeros(rows);
        for i in 0..rows {
            let (start, end) = (row_ptr[i], row_ptr[i + 1]);
            if end > nnz {
                return Err(invalid(format!(
                    "row {} ends at {} but only {} entries are stored",
                    i, end, nnz
                )));
            }
            for j in start..end {
                let col = col_indices[j];
                if col >= cols {
                    return Err(invalid(format!(
                        "column index {} is out of range for {} columns",
                        col, cols
                    )));
                }
                result[i] += sparse_matrix.values[j] * vector[col];
            }
        }
        Ok(result)
    }
}
/// Flat copy of a sparse tensor's value and index buffers, with indices stored
/// as `u32`, plus the byte alignment used when reporting its memory footprint.
#[derive(Debug, Clone)]
pub struct GpuSparseLayout<T: Float> {
    /// Stored values, copied from the source tensor.
    pub values: Vec<T>,
    /// One vector per index array of the source tensor, narrowed to `u32`.
    pub indices: Vec<Vec<u32>>,
    /// Byte granularity that [`GpuSparseLayout::memory_usage`] rounds up to.
    pub alignment: usize,
}

impl<T: Float + Copy> GpuSparseLayout<T> {
    /// Copies `tensor`'s values and index arrays into a layout with 16-byte alignment.
    pub fn from_sparse_tensor(tensor: &SparseTensor<T>) -> Self {
        let values = tensor.values.to_vec();
        // NOTE(review): `as u32` wraps silently for any index above `u32::MAX`;
        // the signature has no error channel, so the cast is kept as-is.
        let indices: Vec<Vec<u32>> = tensor
            .indices
            .iter()
            .map(|arr| arr.iter().map(|&x| x as u32).collect())
            .collect();
        Self {
            values,
            indices,
            alignment: 16,
        }
    }

    /// Returns the byte size of all value and index buffers, rounded up to the
    /// next multiple of `alignment`.
    ///
    /// `alignment` is a public field, so any value is tolerated: `0` and `1`
    /// mean "no rounding", and values that are not powers of two are rounded
    /// arithmetically. For power-of-two alignments the result equals the usual
    /// bit-mask round-up.
    pub fn memory_usage(&self) -> usize {
        let value_bytes = self.values.len() * std::mem::size_of::<T>();
        let index_bytes: usize = self
            .indices
            .iter()
            .map(|arr| arr.len() * std::mem::size_of::<u32>())
            .sum();
        let total_unaligned = value_bytes + index_bytes;
        match self.alignment {
            // `alignment - 1` would underflow for 0; neither value requires padding.
            0 | 1 => total_unaligned,
            align => match total_unaligned % align {
                0 => total_unaligned,
                rem => total_unaligned + (align - rem),
            },
        }
    }

    /// Checks that the layout fits into `available_memory` bytes.
    ///
    /// # Errors
    /// `OutOfMemory` carrying the required and available byte counts when
    /// [`GpuSparseLayout::memory_usage`] exceeds `available_memory`.
    pub fn validate_gpu_memory(&self, available_memory: usize) -> RusTorchResult<()> {
        let required_memory = self.memory_usage();
        if required_memory > available_memory {
            return Err(RusTorchError::OutOfMemory {
                requested: required_memory,
                available: available_memory,
            });
        }
        Ok(())
    }
}
/// Bounded queue of sparse tensors that are processed together.
pub struct SparseBatchProcessor<T: Float> {
    /// Maximum number of tensors `add_to_batch` will accept.
    pub max_batch_size: usize,
    /// Tensors queued since the last successful `process_batch`.
    pub batch: Vec<SparseTensor<T>>,
}

impl<T: Float + Copy + Zero + One + std::ops::AddAssign + PartialOrd + FromPrimitive>
    SparseBatchProcessor<T>
{
    /// Creates an empty processor that accepts up to `max_batch_size` tensors.
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            max_batch_size,
            batch: Vec::new(),
        }
    }

    /// Queues `tensor` for the next `process_batch` call.
    ///
    /// # Errors
    /// `InvalidOperation` ("Batch is full") when the queue already holds
    /// `max_batch_size` tensors.
    pub fn add_to_batch(&mut self, tensor: SparseTensor<T>) -> RusTorchResult<()> {
        if self.batch.len() >= self.max_batch_size {
            return Err(RusTorchError::InvalidOperation {
                operation: "sparse_batch_add".to_string(),
                message: "Batch is full".to_string(),
            });
        }
        self.batch.push(tensor);
        Ok(())
    }

    /// Multiplies every queued tensor by an all-ones vector of length `shape[1]`
    /// and returns the products in queue order, then empties the queue.
    ///
    /// # Errors
    /// `InvalidOperation` when a queued tensor's shape has fewer than two
    /// dimensions (previously an index panic), or whatever error the tensor's
    /// `spmv` reports. On error the queue is left untouched.
    pub fn process_batch(&mut self) -> RusTorchResult<Vec<Array1<T>>> {
        let mut results = Vec::with_capacity(self.batch.len());
        for sparse_tensor in &self.batch {
            if sparse_tensor.shape.len() < 2 {
                return Err(RusTorchError::InvalidOperation {
                    operation: "sparse_batch_process".to_string(),
                    message: format!(
                        "Batched tensors need a two-dimensional shape, got {} dimension(s)",
                        sparse_tensor.shape.len()
                    ),
                });
            }
            let dummy_vector = Array1::ones(sparse_tensor.shape[1]);
            results.push(sparse_tensor.spmv(&dummy_vector)?);
        }
        self.batch.clear();
        Ok(results)
    }

    /// Returns the fill ratio `batch.len() / max_batch_size`.
    ///
    /// A processor with zero capacity reports `1.0`: `add_to_batch` treats it as
    /// full, and dividing by zero would otherwise yield NaN or infinity.
    pub fn batch_utilization(&self) -> f32 {
        if self.max_batch_size == 0 {
            return 1.0;
        }
        self.batch.len() as f32 / self.max_batch_size as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A COO tensor converts to a layout with matching buffer counts, an aligned
    /// footprint, and a working memory check.
    #[test]
    fn test_gpu_layout_conversion() {
        let indices = vec![
            Array1::from_vec(vec![0, 1, 2]),
            Array1::from_vec(vec![1, 2, 0]),
        ];
        let values = Array1::from_vec(vec![1.0f32, 2.0, 3.0]);
        let shape = vec![3, 3];
        let sparse_tensor = SparseTensor::from_coo(indices, values, shape).unwrap();

        let gpu_layout = GpuSparseLayout::from_sparse_tensor(&sparse_tensor);
        assert_eq!(gpu_layout.values.len(), 3);
        assert_eq!(gpu_layout.indices.len(), 2);
        assert_eq!(gpu_layout.alignment, 16);

        let usage = gpu_layout.memory_usage();
        assert!(usage > 0);
        // The reported footprint is always padded to the layout's alignment.
        assert_eq!(usage % gpu_layout.alignment, 0);

        assert!(gpu_layout.validate_gpu_memory(usage).is_ok());
        assert!(gpu_layout.validate_gpu_memory(0).is_err());
    }

    /// The processor tracks utilization, refuses tensors beyond capacity, and
    /// empties itself after processing.
    #[test]
    fn test_sparse_batch_processor() {
        let mut processor = SparseBatchProcessor::new(2);
        let sparse1 = SparseTensor::from_coo(
            vec![Array1::from_vec(vec![0]), Array1::from_vec(vec![0])],
            Array1::from_vec(vec![1.0f32]),
            vec![2, 2],
        )
        .unwrap();
        let sparse2 = SparseTensor::from_coo(
            vec![Array1::from_vec(vec![1]), Array1::from_vec(vec![1])],
            Array1::from_vec(vec![2.0f32]),
            vec![2, 2],
        )
        .unwrap();
        let overflow = sparse1.clone();

        processor.add_to_batch(sparse1).unwrap();
        assert_eq!(processor.batch_utilization(), 0.5);
        processor.add_to_batch(sparse2).unwrap();
        assert_eq!(processor.batch_utilization(), 1.0);

        // A third tensor exceeds the capacity of two and must be rejected.
        assert!(processor.add_to_batch(overflow).is_err());
        assert_eq!(processor.batch_utilization(), 1.0);

        let results = processor.process_batch().unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(processor.batch_utilization(), 0.0);
    }

    /// diag(1, 2) * [1, 2] = [1, 4].
    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_sparse_ops() {
        let cuda_ops = CudaSparseOps::init().unwrap();
        let sparse_matrix = SparseTensor::from_coo(
            vec![Array1::from_vec(vec![0, 1]), Array1::from_vec(vec![0, 1])],
            Array1::from_vec(vec![1.0f32, 2.0]),
            vec![2, 2],
        )
        .unwrap()
        .to_csr()
        .unwrap();
        let vector = Array1::from_vec(vec![1.0, 2.0]);

        let result = cuda_ops.spmv(&sparse_matrix, &vector).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], 1.0);
        assert_eq!(result[1], 4.0);
    }
}