//! Tape-based computation graph for reverse-mode automatic differentiation.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::grad_fn::GradFn;
use super::tensor::{Tensor, TensorId};
/// A single operation recorded on the tape during the forward pass.
#[derive(Clone)]
pub(crate) struct TapeEntry {
    /// Id of the tensor this operation produced.
    pub output_id: TensorId,
    /// Backward function mapping the output gradient to input gradients.
    pub grad_fn: Arc<dyn GradFn>,
    /// Ids of the operation's inputs, in argument order.
    pub input_ids: Vec<TensorId>,
}
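/// Records differentiable operations on a tape during the forward pass and
/// replays them in reverse to compute gradients.
///
/// # Examples
///
/// A minimal sketch of one recorded operation (`y = -x`), assuming the
/// `Tensor` and `NegBackward` APIs exercised by the tests below:
///
/// ```ignore
/// use std::sync::Arc;
///
/// let mut graph = ComputationGraph::new();
/// let x = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
/// let y = Tensor::from_slice(&[-1.0, -2.0]);
/// let (x_id, y_id) = (x.id(), y.id());
/// graph.register_tensor(x);
/// graph.register_tensor(y);
/// graph.record(y_id, Arc::new(NegBackward), vec![x_id]);
/// graph.backward(y_id, Tensor::from_slice(&[1.0, 1.0]));
/// assert!(graph.get_grad(x_id).is_some());
/// ```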
#[allow(missing_debug_implementations)]
pub struct ComputationGraph {
    /// Recorded operations, in forward execution order.
    tape: Vec<TapeEntry>,
    /// All registered tensors, keyed by id.
    tensors: HashMap<TensorId, Tensor>,
    /// Ids of registered tensors that require gradients.
    requires_grad: HashSet<TensorId>,
}
impl ComputationGraph {
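    /// Creates an empty graph with no recorded operations or tensors.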
#[must_use]
pub fn new() -> Self {
Self {
tape: Vec::new(),
tensors: HashMap::new(),
requires_grad: HashSet::new(),
}
}
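    /// Clears the tape, all registered tensors, and gradient-tracking state,
    /// readying the graph for a fresh forward pass.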
pub fn clear(&mut self) {
self.tape.clear();
self.tensors.clear();
self.requires_grad.clear();
}
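    /// Registers a tensor with the graph, tracking its id for gradient
    /// computation if it requires gradients.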
pub fn register_tensor(&mut self, tensor: Tensor) {
if tensor.requires_grad_enabled() {
self.requires_grad.insert(tensor.id());
}
self.tensors.insert(tensor.id(), tensor);
}
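    /// Records an operation on the tape so that [`backward`](Self::backward)
    /// can later differentiate through it.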
pub fn record(
&mut self,
output_id: TensorId,
grad_fn: Arc<dyn GradFn>,
input_ids: Vec<TensorId>,
) {
self.tape.push(TapeEntry {
output_id,
grad_fn,
input_ids,
});
}
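    /// Returns the registered tensor with the given id, if any.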
#[must_use]
pub fn get_tensor(&self, id: TensorId) -> Option<&Tensor> {
self.tensors.get(&id)
}
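    /// Returns a mutable reference to the registered tensor with the given
    /// id, if any.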
    #[must_use]
    pub fn get_tensor_mut(&mut self, id: TensorId) -> Option<&mut Tensor> {
self.tensors.get_mut(&id)
}
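    /// Runs reverse-mode differentiation from `output_id`, seeded with
    /// `grad_output`.
    ///
    /// The tape is replayed in reverse execution order. Each entry's gradient
    /// function maps the output gradient to per-input gradients; a tensor that
    /// feeds several operations receives the elementwise sum of all
    /// contributions (which are assumed to share its shape). The accumulated
    /// gradients are then written into registered leaf tensors that require
    /// gradients; all other gradients are discarded.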
    pub fn backward(&mut self, output_id: TensorId, grad_output: Tensor) {
        // Gradients accumulated so far, keyed by tensor id; seeded at the output.
        let mut grads: HashMap<TensorId, Tensor> = HashMap::new();
        grads.insert(output_id, grad_output);
        // Replay the tape in reverse execution order.
        for entry in self.tape.iter().rev() {
            // Skip entries whose output received no gradient (unrelated branches).
            let grad_out = match grads.get(&entry.output_id) {
                Some(g) => g.clone(),
                None => continue,
            };
            let input_grads = entry.grad_fn.backward(&grad_out);
            for (input_id, input_grad) in entry.input_ids.iter().zip(input_grads) {
                grads
                    .entry(*input_id)
                    // A tensor feeding several operations accumulates the
                    // elementwise sum of all contributions; shapes are assumed
                    // to match.
                    .and_modify(|existing| {
                        let new_data: Vec<f32> = existing
                            .data()
                            .iter()
                            .zip(input_grad.data().iter())
                            .map(|(a, b)| a + b)
                            .collect();
                        *existing = Tensor::new(&new_data, existing.shape());
                    })
                    .or_insert(input_grad);
            }
        }
        // Write the accumulated gradients into leaf tensors that require them.
        for (id, grad) in grads {
            if let Some(tensor) = self.tensors.get_mut(&id) {
                if tensor.requires_grad_enabled() && tensor.is_leaf() {
                    tensor.accumulate_grad(grad);
                }
            }
        }
    }
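    /// Returns the number of operations recorded on the tape.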
#[must_use]
pub fn len(&self) -> usize {
self.tape.len()
}
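    /// Returns `true` if no operations have been recorded on the tape.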
#[must_use]
pub fn is_empty(&self) -> bool {
self.tape.is_empty()
}
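    /// Returns a clone of the accumulated gradient for the tensor with the
    /// given id, if one exists.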
#[must_use]
pub fn get_grad(&self, id: TensorId) -> Option<Tensor> {
self.tensors.get(&id).and_then(|t| t.grad().cloned())
}
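    /// Clears the accumulated gradient of the tensor with the given id, if
    /// the tensor is registered.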
pub fn clear_grad(&mut self, id: TensorId) {
if let Some(tensor) = self.tensors.get_mut(&id) {
tensor.clear_grad();
}
}
}
impl Default for ComputationGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_graph_creation() {
let graph = ComputationGraph::new();
assert!(graph.is_empty());
assert_eq!(graph.len(), 0);
}
#[test]
fn test_graph_clear() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
graph.register_tensor(t);
assert!(!graph.tensors.is_empty());
graph.clear();
assert!(graph.is_empty());
assert!(graph.tensors.is_empty());
}
#[test]
fn test_tensor_registration() {
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0]).requires_grad();
let t2 = Tensor::from_slice(&[2.0]);
let id1 = t1.id();
let id2 = t2.id();
graph.register_tensor(t1);
graph.register_tensor(t2);
assert!(graph.get_tensor(id1).is_some());
assert!(graph.get_tensor(id2).is_some());
assert!(graph.requires_grad.contains(&id1));
assert!(!graph.requires_grad.contains(&id2));
}
#[test]
fn test_graph_default() {
let graph = ComputationGraph::default();
assert!(graph.is_empty());
}
#[test]
fn test_get_tensor_mut() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0, 2.0]);
let id = t.id();
graph.register_tensor(t);
if let Some(tensor) = graph.get_tensor_mut(id) {
assert_eq!(tensor.data(), &[1.0, 2.0]);
}
let other = Tensor::from_slice(&[3.0]);
assert!(graph.get_tensor_mut(other.id()).is_none());
}
#[test]
fn test_record_operation() {
use crate::autograd::grad_fn::NegBackward;
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0, 2.0]);
let output = Tensor::from_slice(&[-1.0, -2.0]);
let output_id = output.id();
graph.record(output_id, Arc::new(NegBackward), vec![t1.id()]);
assert_eq!(graph.len(), 1);
assert!(!graph.is_empty());
}
#[test]
fn test_get_grad_and_clear_grad() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
let id = t.id();
graph.register_tensor(t);
assert!(graph.get_grad(id).is_none());
let other = Tensor::from_slice(&[3.0]);
assert!(graph.get_grad(other.id()).is_none());
graph.clear_grad(other.id());
}
#[test]
fn test_graph_len_empty() {
let graph = ComputationGraph::new();
assert_eq!(graph.len(), 0);
}
#[test]
fn test_graph_multiple_register() {
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0]).requires_grad();
let t2 = Tensor::from_slice(&[2.0]).requires_grad();
let t3 = Tensor::from_slice(&[3.0]).requires_grad();
graph.register_tensor(t1);
graph.register_tensor(t2);
graph.register_tensor(t3);
assert_eq!(graph.tensors.len(), 3);
}
#[test]
fn test_graph_register_same_tensor_twice() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0]).requires_grad();
let id = t.id();
graph.register_tensor(t.clone());
graph.register_tensor(t);
assert!(graph.get_tensor(id).is_some());
}
#[test]
fn test_backward_simple() {
use crate::autograd::grad_fn::NegBackward;
let mut graph = ComputationGraph::new();
let input = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
let input_id = input.id();
graph.register_tensor(input);
let output = Tensor::from_slice(&[-1.0, -2.0]);
let output_id = output.id();
graph.register_tensor(output);
graph.record(output_id, Arc::new(NegBackward), vec![input_id]);
let grad_output = Tensor::from_slice(&[1.0, 1.0]);
graph.backward(output_id, grad_output);
let grad = graph.get_grad(input_id);
assert!(grad.is_some());
}
#[test]
fn test_backward_no_matching_output() {
let mut graph = ComputationGraph::new();
let output_id = Tensor::from_slice(&[1.0]).id();
let grad_output = Tensor::from_slice(&[1.0]);
graph.backward(output_id, grad_output);
assert!(graph.is_empty());
}
#[test]
fn test_backward_empty_tape() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0]).requires_grad();
let id = t.id();
graph.register_tensor(t);
let grad_output = Tensor::from_slice(&[1.0]);
graph.backward(id, grad_output);
assert!(graph.is_empty());
}
#[test]
fn test_clear_grad_existing_tensor() {
let mut graph = ComputationGraph::new();
let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
let id = t.id();
graph.register_tensor(t);
graph.clear_grad(id);
assert!(graph.get_grad(id).is_none());
}
#[test]
fn test_tape_entry_clone() {
use crate::autograd::grad_fn::NegBackward;
let entry = TapeEntry {
output_id: TensorId::new(),
grad_fn: Arc::new(NegBackward),
input_ids: vec![TensorId::new(), TensorId::new()],
};
let cloned = entry.clone();
assert_eq!(cloned.input_ids.len(), 2);
}
#[test]
fn test_graph_record_multiple_operations() {
use crate::autograd::grad_fn::NegBackward;
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0]);
let t2 = Tensor::from_slice(&[-1.0]);
let t3 = Tensor::from_slice(&[1.0]);
graph.record(t2.id(), Arc::new(NegBackward), vec![t1.id()]);
graph.record(t3.id(), Arc::new(NegBackward), vec![t2.id()]);
assert_eq!(graph.len(), 2);
}
#[test]
fn test_backward_skips_unrelated_operations() {
use crate::autograd::grad_fn::NegBackward;
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0]).requires_grad();
let t1_id = t1.id();
let t2 = Tensor::from_slice(&[-1.0]);
let t2_id = t2.id();
        let t3 = Tensor::from_slice(&[5.0]);
        let t3_id = t3.id();
graph.register_tensor(t1);
graph.register_tensor(t2);
graph.record(t2_id, Arc::new(NegBackward), vec![t1_id]);
graph.record(TensorId::new(), Arc::new(NegBackward), vec![t3_id]);
let grad_output = Tensor::from_slice(&[1.0]);
graph.backward(t2_id, grad_output);
assert!(graph.get_grad(t1_id).is_some());
}
#[test]
fn test_graph_get_tensor_nonexistent() {
let graph = ComputationGraph::new();
let fake_id = TensorId::new();
assert!(graph.get_tensor(fake_id).is_none());
}
#[test]
fn test_requires_grad_set_tracking() {
let mut graph = ComputationGraph::new();
let t1 = Tensor::from_slice(&[1.0]).requires_grad();
        let t2 = Tensor::from_slice(&[2.0]);
        let t3 = Tensor::from_slice(&[3.0]).requires_grad();
let id1 = t1.id();
let id2 = t2.id();
let id3 = t3.id();
graph.register_tensor(t1);
graph.register_tensor(t2);
graph.register_tensor(t3);
assert!(graph.requires_grad.contains(&id1));
assert!(!graph.requires_grad.contains(&id2));
assert!(graph.requires_grad.contains(&id3));
}
}