// ferrite/autograd/tensor/base.rs
use std::rc::Rc;
use std::cell::RefCell;
use crate::tensor_storage::*;
use super::function::*;
use std::collections::HashSet;
/// A value in the autograd graph together with its gradient bookkeeping.
///
/// Cloning a `Tensor` clones `tensor` through `TensorStorage`'s own `Clone`
/// impl, while `grad_fn` and `grad` are `Rc` handles: a clone therefore shares
/// the same gradient function and the same gradient buffer as the original.
#[derive(Clone)]
pub struct Tensor {
    /// The underlying values.
    tensor: TensorStorage,
    /// Flag given to `Tensor::new`; it decides whether `grad` is allocated.
    requires_grad: bool,
    /// Gradient function linking this tensor to its inputs (walked via
    /// `prev()` during `backward`); `None` until `set_grad_fn` supplies one.
    grad_fn: Option<Rc<dyn GradientFunction>>,
    /// Shared, interior-mutable gradient buffer; allocated by `Tensor::new`
    /// only when `requires_grad` is `true`.
    grad: Option<Rc<RefCell<TensorStorage>>>,
}
impl Tensor {
    /// Wraps `tensor` as a new node of the autograd graph with no `grad_fn`.
    ///
    /// If `requires_grad` is `true`, a gradient buffer with the same shape as
    /// `tensor` is allocated via `TensorStorage::zeros`; otherwise no gradient
    /// buffer is kept.
    pub fn new(tensor: TensorStorage, requires_grad: bool) -> Self {
        let grad = if requires_grad {
            Some(Rc::new(RefCell::new(TensorStorage::zeros(
                tensor.shape().clone(),
                None,
            ))))
        } else {
            None
        };
        Tensor {
            tensor,
            requires_grad,
            grad_fn: None,
            grad,
        }
    }

    /// Borrows the underlying storage.
    pub fn tensor(&self) -> &TensorStorage {
        &self.tensor
    }

    /// Mutably borrows the underlying storage.
    ///
    /// NOTE(review): the gradient buffer allocated in [`Tensor::new`] is not
    /// touched here, so changing the storage's shape through this handle can
    /// leave `grad` with a stale shape — confirm callers never do that.
    pub fn tensor_mut(&mut self) -> &mut TensorStorage {
        &mut self.tensor
    }

    /// Returns the `requires_grad` flag this tensor was created with.
    // Kept as `&bool` (rather than `bool`) to preserve the public signature.
    pub fn requires_grad(&self) -> &bool {
        &self.requires_grad
    }

    /// Returns a new `Rc` handle to this tensor's gradient function, if any.
    pub fn grad_fn(&self) -> Option<Rc<dyn GradientFunction>> {
        self.grad_fn.clone()
    }

    /// Replaces this tensor's gradient function. With `None`, `backward`
    /// will not traverse past this tensor.
    pub fn set_grad_fn(&mut self, grad_fn: Option<Rc<dyn GradientFunction>>) {
        self.grad_fn = grad_fn;
    }

    /// Returns a shared handle to the gradient buffer, or `None` if this
    /// tensor has none.
    ///
    /// The handle is an `Rc` clone, so writes through it are visible to this
    /// tensor and to its clones.
    pub fn grad(&self) -> Option<Rc<RefCell<TensorStorage>>> {
        self.grad.clone()
    }

    /// Back-propagates from this tensor through the graph reachable via
    /// `grad_fn`.
    ///
    /// The gradient buffer of `self` is overwritten with `[1.0]`. Then every
    /// gradient function reachable through `grad_fn` / `GradientFunction::prev`
    /// is collected in depth-first post-order (a node after all of its
    /// inputs), and `GradientFunction::backward` is invoked on each in the
    /// reverse of that order, so a node runs before the nodes that produced
    /// its inputs. Each gradient function runs at most once, even when it is
    /// shared by several paths.
    ///
    /// # Panics
    ///
    /// - If the shape of `self` is not exactly `[1]`.
    /// - If `self` has no gradient buffer (e.g. it was built by
    ///   [`Tensor::new`] with `requires_grad == false`).
    pub fn backward(&mut self) {
        // Traversal frame: a gradient function plus the input gradient
        // functions that still have to be visited.
        type Frame = (
            Rc<dyn GradientFunction>,
            std::vec::IntoIter<Rc<dyn GradientFunction>>,
        );

        // Identity key for a graph node. Only the data address is kept: the
        // vtable half of a `dyn` pointer is not guaranteed unique (vtables can
        // be duplicated across codegen units), so hashing the wide pointer
        // could make one node look like two and run its backward twice.
        fn node_key(grad_fn: &Rc<dyn GradientFunction>) -> *const () {
            Rc::as_ptr(grad_fn) as *const ()
        }

        // Gradient functions of `grad_fn`'s inputs, in `prev()` order. Inputs
        // without a `grad_fn` have nothing to traverse and are skipped.
        fn input_grad_fns(
            grad_fn: &Rc<dyn GradientFunction>,
        ) -> std::vec::IntoIter<Rc<dyn GradientFunction>> {
            grad_fn
                .prev()
                .into_iter()
                .filter_map(|input| input.grad_fn.clone())
                .collect::<Vec<_>>()
                .into_iter()
        }

        let shape = self.tensor().shape();
        if shape.len() != 1 || shape[0] != 1 {
            panic!("backward() can only be called on scalar tensors");
        }

        if let Some(grad) = &self.grad {
            grad.borrow_mut().set_data(vec![1.0]);
        } else {
            panic!("Called backward on tensor that doesn't require grad");
        }

        // Depth-first walk with an explicit stack instead of recursion, so very
        // deep graphs cannot overflow the call stack. A node is marked visited
        // when first reached and emitted once all of its inputs are done,
        // which is exactly the recursive post-order.
        let mut topo: Vec<Rc<dyn GradientFunction>> = Vec::new();
        let mut visited: HashSet<*const ()> = HashSet::new();
        let mut stack: Vec<Frame> = Vec::new();

        if let Some(root) = &self.grad_fn {
            visited.insert(node_key(root));
            stack.push((Rc::clone(root), input_grad_fns(root)));
        }

        while let Some((_, pending)) = stack.last_mut() {
            let next_input = pending.next();
            match next_input {
                Some(input) => {
                    // `insert` returns false for nodes already reached via
                    // another path; those are not expanded again.
                    if visited.insert(node_key(&input)) {
                        let inputs = input_grad_fns(&input);
                        stack.push((input, inputs));
                    }
                }
                None => {
                    let (node, _) = stack
                        .pop()
                        .expect("stack is non-empty: last_mut() just returned Some");
                    topo.push(node);
                }
            }
        }

        // Reverse post-order: every node runs before the nodes feeding it.
        for grad_fn in topo.iter().rev() {
            grad_fn.backward();
        }
    }
}