use crate::{CooTensor, CsrTensor, SparseFormat, SparseTensor, TorshResult};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::{DType, TorshError};
use torsh_tensor::Tensor;
/// A sparse tensor that participates in automatic differentiation.
///
/// Wraps sparse data (COO or CSR) together with the bookkeeping needed to
/// build and walk a backward graph: a gradient slot, the producing
/// operation (`grad_fn`), and that operation's inputs.
#[derive(Debug, Clone)]
pub struct SparseAutogradTensor {
    /// Underlying sparse payload (COO or CSR).
    data: SparseData,
    /// Whether gradients should flow to this tensor during backward.
    requires_grad: bool,
    /// Stored gradient, set via `set_grad`. NOTE(review): the backward pass
    /// currently does not populate this — see `accumulate_grad`.
    grad: Option<Arc<SparseAutogradTensor>>,
    /// Operation that produced this tensor (`None` for leaves).
    grad_fn: Option<Arc<dyn SparseGradFn>>,
    /// Inputs consumed by `grad_fn`, kept alive for graph traversal.
    inputs: Vec<Arc<SparseAutogradTensor>>,
    /// Process-unique id assigned at construction from a global counter.
    id: u64,
    /// True when created directly rather than produced by an operation.
    is_leaf: bool,
}
/// Storage backend for a sparse autograd tensor: coordinate (COO) or
/// compressed-sparse-row (CSR) format.
#[derive(Debug, Clone)]
pub enum SparseData {
    /// Coordinate-format storage (row/col index lists plus values).
    Coo(CooTensor),
    /// Compressed-sparse-row storage (row pointers, column indices, values).
    Csr(CsrTensor),
}
impl SparseData {
    /// Element dtype of the wrapped sparse tensor.
    pub fn dtype(&self) -> DType {
        match self {
            Self::Coo(t) => t.dtype(),
            Self::Csr(t) => t.dtype(),
        }
    }

    /// Logical (dense) shape of the wrapped sparse tensor.
    pub fn shape(&self) -> &torsh_core::Shape {
        match self {
            Self::Coo(t) => t.shape(),
            Self::Csr(t) => t.shape(),
        }
    }

    /// Number of explicitly stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        match self {
            Self::Coo(t) => t.nnz(),
            Self::Csr(t) => t.nnz(),
        }
    }

    /// Which storage format this payload uses.
    pub fn format(&self) -> SparseFormat {
        if matches!(self, Self::Coo(_)) {
            SparseFormat::Coo
        } else {
            SparseFormat::Csr
        }
    }
}
/// A node in the sparse backward graph: maps an output gradient to input
/// gradients for one recorded operation.
pub trait SparseGradFn: Send + Sync + std::fmt::Debug {
    /// Given the gradient flowing into this op's output, return one
    /// gradient (or `None`) per input, in input order.
    fn backward(
        &self,
        grad_output: &SparseAutogradTensor,
    ) -> TorshResult<Vec<Option<SparseAutogradTensor>>>;
    /// Number of inputs this operation consumed.
    fn num_inputs(&self) -> usize;
    /// Human-readable operation name for debugging.
    fn name(&self) -> &str;
}
impl SparseAutogradTensor {
    /// Construct a tensor from sparse data, assigning a fresh unique id.
    /// New tensors start as leaves with no gradient and no grad_fn.
    pub fn new(data: SparseData, requires_grad: bool) -> Self {
        // Process-wide monotonically increasing id source. Relaxed ordering
        // suffices: only uniqueness matters, not cross-thread ordering.
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Self {
            data,
            requires_grad,
            grad: None,
            grad_fn: None,
            inputs: Vec::new(),
            id,
            is_leaf: true,
        }
    }

    /// Wrap a COO tensor as an autograd leaf.
    pub fn from_coo(coo: CooTensor, requires_grad: bool) -> Self {
        Self::new(SparseData::Coo(coo), requires_grad)
    }

    /// Wrap a CSR tensor as an autograd leaf.
    pub fn from_csr(csr: CsrTensor, requires_grad: bool) -> Self {
        Self::new(SparseData::Csr(csr), requires_grad)
    }

    /// Borrow the underlying sparse payload.
    pub fn data(&self) -> &SparseData {
        &self.data
    }

    /// Whether this tensor participates in gradient computation.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Borrow the stored gradient, if one has been set.
    pub fn grad(&self) -> Option<&SparseAutogradTensor> {
        self.grad.as_ref().map(|g| g.as_ref())
    }

    /// Replace (or clear, with `None`) the stored gradient.
    pub fn set_grad(&mut self, grad: Option<SparseAutogradTensor>) {
        self.grad = grad.map(Arc::new);
    }

    /// Unique id assigned at construction.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// True if this tensor was created directly rather than by an operation.
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Record a gradient for `target_tensor`.
    ///
    /// NOTE(review): this is a stub — it only logs and discards `_new_grad`.
    /// The target sits behind an immutable `Arc`, so real accumulation would
    /// need interior mutability or an external store such as
    /// `SparseGradientAccumulator`.
    fn accumulate_grad(
        &self,
        target_tensor: &SparseAutogradTensor,
        _new_grad: &SparseAutogradTensor,
    ) -> TorshResult<()> {
        if target_tensor.requires_grad() {
            println!(
                "Accumulating gradient for tensor ID: {}",
                target_tensor.id()
            );
        }
        Ok(())
    }

    /// Run the backward pass from this tensor, seeding with a unit gradient
    /// (same sparsity pattern as `self`, all stored values 1.0).
    ///
    /// # Errors
    /// Returns `AutogradError` if this tensor does not require gradients.
    pub fn backward(&self, retain_graph: bool) -> TorshResult<()> {
        if !self.requires_grad {
            return Err(TorshError::AutogradError(
                "Tensor does not require gradients".to_string(),
            ));
        }
        let unit_grad = self.create_unit_grad()?;
        self.backward_impl(&unit_grad, retain_graph)
    }

    /// Depth-first propagation of `grad_output` through this node's grad_fn
    /// to its recorded inputs.
    ///
    /// NOTE(review): `_retain_graph` is currently unused — the graph is
    /// never freed, so retain/non-retain behave identically.
    fn backward_impl(
        &self,
        grad_output: &SparseAutogradTensor,
        _retain_graph: bool,
    ) -> TorshResult<()> {
        if let Some(grad_fn) = &self.grad_fn {
            let input_grads = grad_fn.backward(grad_output)?;
            for (i, grad) in input_grads.into_iter().enumerate() {
                if let Some(input_grad) = grad {
                    // Guard against a grad_fn returning more gradients than
                    // there are recorded inputs.
                    if i < self.inputs.len() {
                        let input_tensor = &self.inputs[i];
                        self.accumulate_grad(input_tensor, &input_grad)?;
                        // Recurse only into interior nodes that track grads.
                        if input_tensor.grad_fn.is_some() && input_tensor.requires_grad() {
                            input_tensor.backward_impl(&input_grad, _retain_graph)?;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Build a gradient tensor with the same sparsity pattern and format as
    /// `self`, with every stored value equal to 1.0.
    fn create_unit_grad(&self) -> TorshResult<SparseAutogradTensor> {
        match &self.data {
            SparseData::Coo(coo) => {
                let values = vec![1.0; coo.nnz()];
                let unit_coo = CooTensor::new(
                    coo.row_indices().to_vec(),
                    coo.col_indices().to_vec(),
                    values,
                    coo.shape().clone(),
                )?;
                Ok(SparseAutogradTensor::from_coo(unit_coo, false))
            }
            SparseData::Csr(csr) => {
                let values = vec![1.0; csr.nnz()];
                let unit_csr = CsrTensor::new(
                    csr.row_ptr().to_vec(),
                    csr.col_indices().to_vec(),
                    values,
                    csr.shape().clone(),
                )?;
                Ok(SparseAutogradTensor::from_csr(unit_csr, false))
            }
        }
    }

    /// Sparse-sparse matrix multiplication with autograd tracking.
    ///
    /// Both operands must use the same storage format; mixed COO/CSR input
    /// is rejected. When either operand requires gradients, the result
    /// records a `SparseMmGradFn` node.
    ///
    /// # Errors
    /// `ComputeError` on mixed formats or from the underlying multiply.
    pub fn sparse_mm(&self, other: &SparseAutogradTensor) -> TorshResult<SparseAutogradTensor> {
        let result_data = match (&self.data, &other.data) {
            (SparseData::Csr(a), SparseData::Csr(b)) => {
                let result = a.multiply_csr(b)?;
                SparseData::Csr(result)
            }
            (SparseData::Coo(a), SparseData::Coo(b)) => {
                let result = a.multiply_coo(b)?;
                SparseData::Coo(result)
            }
            _ => {
                return Err(TorshError::ComputeError(
                    "Mixed format sparse multiplication not supported".to_string(),
                ))
            }
        };
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SparseAutogradTensor::new(result_data, requires_grad);
        if requires_grad {
            // Strong Arcs live in `result.inputs`; the grad_fn holds only
            // weak references, so the graph does not form a reference cycle.
            let self_arc = Arc::new(self.clone());
            let other_arc = Arc::new(other.clone());
            let grad_fn = Arc::new(SparseMmGradFn {
                input_shapes: [self.data().shape().clone(), other.data().shape().clone()],
                input_a: Some(Arc::downgrade(&self_arc)),
                input_b: Some(Arc::downgrade(&other_arc)),
            });
            result.grad_fn = Some(grad_fn);
            result.inputs = vec![self_arc, other_arc];
            result.is_leaf = false;
        }
        Ok(result)
    }

    /// Elementwise sparse addition with autograd tracking.
    ///
    /// Both operands must use the same storage format.
    ///
    /// # Errors
    /// `ComputeError` on mixed formats or from the underlying addition.
    pub fn add(&self, other: &SparseAutogradTensor) -> TorshResult<SparseAutogradTensor> {
        let result_data = match (&self.data, &other.data) {
            (SparseData::Coo(a), SparseData::Coo(b)) => {
                let result = a.add_coo(b)?;
                SparseData::Coo(result)
            }
            (SparseData::Csr(a), SparseData::Csr(b)) => {
                let result = a.add_csr(b)?;
                SparseData::Csr(result)
            }
            _ => {
                return Err(TorshError::ComputeError(
                    "Mixed format sparse addition not supported".to_string(),
                ))
            }
        };
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SparseAutogradTensor::new(result_data, requires_grad);
        if requires_grad {
            let grad_fn = Arc::new(SparseAddGradFn);
            result.grad_fn = Some(grad_fn);
            result.inputs = vec![Arc::new(self.clone()), Arc::new(other.clone())];
            result.is_leaf = false;
        }
        Ok(result)
    }

    /// Materialize the sparse data as a dense `Tensor`.
    pub fn to_dense(&self) -> TorshResult<Tensor> {
        match &self.data {
            SparseData::Coo(coo) => coo.to_dense(),
            SparseData::Csr(csr) => csr.to_dense(),
        }
    }
}
/// Backward function recorded by `SparseAutogradTensor::sparse_mm`.
#[derive(Debug)]
struct SparseMmGradFn {
    /// Shapes of `[lhs, rhs]` captured at forward time, used to size the
    /// input gradients.
    input_shapes: [torsh_core::Shape; 2],
    /// Weak back-references to the forward inputs. If an upgrade fails
    /// (the forward tensors were dropped), zero gradients are produced.
    input_a: Option<std::sync::Weak<SparseAutogradTensor>>,
    input_b: Option<std::sync::Weak<SparseAutogradTensor>>,
}
impl SparseGradFn for SparseMmGradFn {
    /// Produce gradients for both multiplication operands.
    ///
    /// When either weak input reference can no longer be upgraded (the
    /// forward tensors were dropped), falls back to zero gradients sized
    /// from the captured input shapes.
    fn backward(
        &self,
        grad_output: &SparseAutogradTensor,
    ) -> TorshResult<Vec<Option<SparseAutogradTensor>>> {
        let lhs = self.input_a.as_ref().and_then(|w| w.upgrade());
        let rhs = self.input_b.as_ref().and_then(|w| w.upgrade());
        if let (Some(a), Some(b)) = (lhs, rhs) {
            let grad_a = self.compute_grad_a(grad_output, &b)?;
            let grad_b = self.compute_grad_b(grad_output, &a)?;
            Ok(vec![grad_a, grad_b])
        } else {
            // Inputs are gone; emit zero gradients of the recorded shapes.
            let grad_a = self.create_zero_grad(&self.input_shapes[0], grad_output)?;
            let grad_b = self.create_zero_grad(&self.input_shapes[1], grad_output)?;
            Ok(vec![grad_a, grad_b])
        }
    }

    fn num_inputs(&self) -> usize {
        2
    }

    fn name(&self) -> &str {
        "SparseMm"
    }
}
impl SparseMmGradFn {
fn compute_grad_a(
&self,
grad_output: &SparseAutogradTensor,
input_b: &SparseAutogradTensor,
) -> TorshResult<Option<SparseAutogradTensor>> {
match (grad_output.data(), input_b.data()) {
(SparseData::Coo(_), SparseData::Coo(_)) => {
let grad_a_shape = &self.input_shapes[0];
let unit_coo = CooTensor::new(
vec![0], vec![0],
vec![1.0],
grad_a_shape.clone(),
)?;
Ok(Some(SparseAutogradTensor::from_coo(unit_coo, false)))
}
(SparseData::Csr(_), SparseData::Csr(_)) => {
let grad_a_shape = &self.input_shapes[0];
let rows = grad_a_shape.dims()[0];
if rows > 0 && grad_a_shape.dims()[1] > 0 {
let mut row_ptr = vec![0; rows + 1];
row_ptr[1] = 1; let unit_csr = CsrTensor::new(
row_ptr,
vec![0], vec![1.0],
grad_a_shape.clone(),
)?;
Ok(Some(SparseAutogradTensor::from_csr(unit_csr, false)))
} else {
self.create_zero_grad(grad_a_shape, grad_output)
}
}
_ => {
let grad_a_shape = &self.input_shapes[0];
self.create_zero_grad(grad_a_shape, grad_output)
}
}
}
fn compute_grad_b(
&self,
grad_output: &SparseAutogradTensor,
input_a: &SparseAutogradTensor,
) -> TorshResult<Option<SparseAutogradTensor>> {
match (grad_output.data(), input_a.data()) {
(SparseData::Coo(_), SparseData::Coo(_)) => {
let grad_b_shape = &self.input_shapes[1];
let unit_coo = CooTensor::new(
vec![0], vec![0],
vec![1.0],
grad_b_shape.clone(),
)?;
Ok(Some(SparseAutogradTensor::from_coo(unit_coo, false)))
}
(SparseData::Csr(_), SparseData::Csr(_)) => {
let grad_b_shape = &self.input_shapes[1];
let rows = grad_b_shape.dims()[0];
if rows > 0 && grad_b_shape.dims()[1] > 0 {
let mut row_ptr = vec![0; rows + 1];
row_ptr[1] = 1; let unit_csr = CsrTensor::new(
row_ptr,
vec![0], vec![1.0],
grad_b_shape.clone(),
)?;
Ok(Some(SparseAutogradTensor::from_csr(unit_csr, false)))
} else {
self.create_zero_grad(grad_b_shape, grad_output)
}
}
_ => {
let grad_b_shape = &self.input_shapes[1];
self.create_zero_grad(grad_b_shape, grad_output)
}
}
}
fn create_zero_grad(
&self,
shape: &torsh_core::Shape,
grad_output: &SparseAutogradTensor,
) -> TorshResult<Option<SparseAutogradTensor>> {
match grad_output.data() {
SparseData::Coo(_) => {
let zero_coo = CooTensor::new(vec![], vec![], vec![], shape.clone())?;
Ok(Some(SparseAutogradTensor::from_coo(zero_coo, false)))
}
SparseData::Csr(_) => {
let rows = shape.dims()[0];
let zero_csr = CsrTensor::new(vec![0; rows + 1], vec![], vec![], shape.clone())?;
Ok(Some(SparseAutogradTensor::from_csr(zero_csr, false)))
}
}
}
}
/// Backward function recorded by `SparseAutogradTensor::add`.
#[derive(Debug)]
struct SparseAddGradFn;

impl SparseGradFn for SparseAddGradFn {
    /// Addition routes the incoming gradient through unchanged to both
    /// operands: d(a + b)/da = d(a + b)/db = 1.
    fn backward(
        &self,
        grad_output: &SparseAutogradTensor,
    ) -> TorshResult<Vec<Option<SparseAutogradTensor>>> {
        let passthrough = Some(grad_output.clone());
        Ok(vec![passthrough.clone(), passthrough])
    }

    fn num_inputs(&self) -> usize {
        2
    }

    fn name(&self) -> &str {
        "SparseAdd"
    }
}
impl CooTensor {
    /// Elementwise addition of two COO tensors.
    ///
    /// Validates that shapes match, then currently always errors: the
    /// actual addition is not yet implemented.
    pub fn add_coo(&self, other: &CooTensor) -> TorshResult<CooTensor> {
        if self.shape() != other.shape() {
            return Err(TorshError::ComputeError(
                "Shape mismatch for sparse addition".to_string(),
            ));
        }
        Err(TorshError::ComputeError(
            "COO addition not yet implemented".to_string(),
        ))
    }

    /// COO-COO matrix multiplication — not yet implemented, always errors.
    ///
    /// NOTE(review): unlike `add_coo`, no shape-compatibility check is
    /// performed before erroring; add one when implementing.
    pub fn multiply_coo(&self, _other: &CooTensor) -> TorshResult<CooTensor> {
        Err(TorshError::ComputeError(
            "COO multiplication not yet implemented".to_string(),
        ))
    }
}
impl CsrTensor {
    /// Elementwise addition of two CSR tensors.
    ///
    /// Validates that shapes match, then currently always errors: the
    /// actual addition is not yet implemented.
    pub fn add_csr(&self, other: &CsrTensor) -> TorshResult<CsrTensor> {
        if self.shape() != other.shape() {
            return Err(TorshError::ComputeError(
                "Shape mismatch for sparse addition".to_string(),
            ));
        }
        Err(TorshError::ComputeError(
            "CSR addition not yet implemented".to_string(),
        ))
    }

    /// CSR-CSR matrix multiplication — not yet implemented, always errors.
    ///
    /// NOTE(review): unlike `add_csr`, no dimension-compatibility check is
    /// performed before erroring; add one when implementing.
    pub fn multiply_csr(&self, _other: &CsrTensor) -> TorshResult<CsrTensor> {
        Err(TorshError::ComputeError(
            "CSR multiplication not yet implemented".to_string(),
        ))
    }
}
/// External store that sums gradients per tensor id.
///
/// Complements the autograd graph: tensors themselves are immutable once
/// wrapped in `Arc`, so accumulated gradients can be kept here instead.
pub struct SparseGradientAccumulator {
    /// Map from `SparseAutogradTensor::id()` to its summed gradient.
    gradients: HashMap<u64, SparseAutogradTensor>,
}
impl SparseGradientAccumulator {
    /// Create an empty accumulator.
    pub fn new() -> Self {
        Self {
            gradients: HashMap::new(),
        }
    }

    /// Store `grad` for `tensor_id`, summing with any gradient already
    /// recorded for that id.
    ///
    /// # Errors
    /// Propagates any error from the sparse addition of existing + new.
    pub fn accumulate(&mut self, tensor_id: u64, grad: SparseAutogradTensor) -> TorshResult<()> {
        let merged = match self.gradients.get(&tensor_id) {
            Some(existing) => existing.add(&grad)?,
            None => grad,
        };
        self.gradients.insert(tensor_id, merged);
        Ok(())
    }

    /// Look up the accumulated gradient for a tensor id, if any.
    pub fn get_grad(&self, tensor_id: u64) -> Option<&SparseAutogradTensor> {
        self.gradients.get(&tensor_id)
    }

    /// Drop all stored gradients.
    pub fn clear(&mut self) {
        self.gradients.clear();
    }
}
/// `Default` delegates to `new`, yielding an empty accumulator.
impl Default for SparseGradientAccumulator {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::Shape;

    /// A freshly built COO-backed autograd tensor is a leaf, honors its
    /// requires_grad flag, and reports shape and nnz correctly.
    #[test]
    fn test_sparse_autograd_tensor_creation() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];
        let coo = CooTensor::new(rows, cols, vals, Shape::new(vec![3, 3])).unwrap();

        let tensor = SparseAutogradTensor::from_coo(coo, true);

        assert!(tensor.requires_grad());
        assert!(tensor.is_leaf());
        assert_eq!(tensor.data().shape().dims(), &[3, 3]);
        assert_eq!(tensor.data().nnz(), 3);
    }

    /// Gradients are stored per tensor id and removed by `clear`.
    #[test]
    fn test_gradient_accumulator() {
        let coo = CooTensor::new(vec![0, 1], vec![0, 1], vec![1.0, 2.0], Shape::new(vec![2, 2]))
            .unwrap();
        let grad = SparseAutogradTensor::from_coo(coo, false);

        let mut acc = SparseGradientAccumulator::new();
        let id = 123;
        acc.accumulate(id, grad).unwrap();
        assert!(acc.get_grad(id).is_some());

        acc.clear();
        assert!(acc.get_grad(id).is_none());
    }
}