//! numr 0.5.2
//!
//! High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU).
//! See the crate documentation for details.
//! Gradient storage and accumulation

use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::{Tensor, TensorId};
use std::collections::HashMap;

/// Storage for gradients computed during backward pass
///
/// Gradients are stored by tensor ID and accumulated when a tensor
/// is used multiple times in the computation graph.
pub struct GradStore<R: Runtime> {
    // Map from tensor ID to its (possibly accumulated) gradient tensor.
    // Private: all access goes through the methods below so accumulation
    // semantics stay in one place.
    grads: HashMap<TensorId, Tensor<R>>,
}

impl<R: Runtime> GradStore<R> {
    /// Construct an empty gradient store.
    pub fn new() -> Self {
        Self {
            grads: HashMap::new(),
        }
    }

    /// Number of gradients currently stored.
    pub fn len(&self) -> usize {
        self.grads.len()
    }

    /// Returns `true` when no gradients are stored.
    pub fn is_empty(&self) -> bool {
        self.grads.is_empty()
    }

    /// Returns `true` if a gradient is stored for the given tensor ID.
    pub fn contains(&self, id: TensorId) -> bool {
        self.grads.contains_key(&id)
    }

    /// Borrow the gradient stored for `id`, if any.
    pub fn get(&self, id: TensorId) -> Option<&Tensor<R>> {
        self.grads.get(&id)
    }

    /// Store a gradient for `id`, replacing any previous value.
    pub fn insert(&mut self, id: TensorId, grad: Tensor<R>) {
        self.grads.insert(id, grad);
    }

    /// Take the gradient for `id` out of the store, returning it if present.
    pub fn remove(&mut self, id: TensorId) -> Option<Tensor<R>> {
        self.grads.remove(&id)
    }

    /// Iterate over the IDs of all tensors that currently have gradients.
    pub fn keys(&self) -> impl Iterator<Item = &TensorId> {
        self.grads.keys()
    }

    /// Drop every stored gradient.
    pub fn clear(&mut self) {
        self.grads.clear();
    }

    /// Accumulate a gradient for a tensor
    ///
    /// If no gradient exists for this tensor, stores the gradient.
    /// If a gradient already exists, adds the new gradient to the existing one.
    ///
    /// # Arguments
    /// * `id` - The tensor ID to accumulate gradient for
    /// * `grad` - The gradient tensor to accumulate
    /// * `add_fn` - Function to add two tensors: `fn(existing, new) -> sum`
    ///
    /// This is used when a tensor is used multiple times in the computation graph,
    /// requiring its gradients to be summed according to the chain rule.
    pub fn accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F)
    where
        F: FnOnce(Tensor<R>, Tensor<R>) -> Tensor<R>,
    {
        // The add function consumes both tensors, so the existing gradient
        // (if any) must be moved out of the map before summing. The entry API
        // cannot hand out ownership of an occupied value, hence remove+insert.
        let merged = match self.grads.remove(&id) {
            Some(existing) => add_fn(existing, grad),
            None => grad, // first gradient seen for this tensor
        };
        self.grads.insert(id, merged);
    }

    /// Accumulate a gradient with a fallible addition function
    ///
    /// Like `accumulate`, but the addition function can fail and return a Result.
    /// This is the preferred method for use in backward passes where tensor
    /// operations may fail.
    ///
    /// # Errors
    /// Propagates any error returned by `add_fn`. On error, the previously
    /// stored gradient (if any) has already been removed from the store.
    pub fn try_accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F) -> Result<()>
    where
        F: FnOnce(Tensor<R>, Tensor<R>) -> Result<Tensor<R>>,
    {
        let merged = match self.grads.remove(&id) {
            Some(existing) => add_fn(existing, grad)?,
            None => grad,
        };
        self.grads.insert(id, merged);
        Ok(())
    }

    /// Insert a gradient, overwriting any existing value
    ///
    /// This is intentionally simpler than `accumulate` - it directly inserts
    /// the gradient without addition. Use this when you don't have access to
    /// a TensorOps client, or when overwriting semantics are desired.
    ///
    /// For proper gradient accumulation (adding to existing gradients),
    /// use `accumulate` with an add function instead.
    pub fn accumulate_or_insert(&mut self, id: TensorId, grad: Tensor<R>) {
        self.grads.insert(id, grad);
    }
}

impl<R: Runtime> Default for GradStore<R> {
    fn default() -> Self {
        Self::new()
    }
}