use super::function::Function;
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
/// One node of the autograd graph.
///
/// Leaf nodes (built by `new_leaf`) have no `function` and no `inputs`. Function nodes
/// (built by `new_function`) record the function that produced them, links to their input
/// nodes, and the input tensors supplied by the caller.
pub struct GraphNode<T: Float + Send + Sync + 'static> {
    /// Function that produced this node; `None` for leaf nodes.
    pub function: Option<Arc<dyn Function<T>>>,
    /// Non-owning links to the input nodes. `ComputationGraph` creates them with
    /// `Arc::downgrade`, so the strong references stay in the graph's node map.
    pub inputs: Vec<Weak<GraphNode<T>>>,
    /// Accumulated gradient; `None` until `accumulate_grad` first stores one.
    pub grad: Option<Tensor<T>>,
    /// Tensors handed to `Function::backward` when gradients are propagated through this node.
    pub input_tensors: Vec<Tensor<T>>,
    /// Whether back-propagation should accumulate gradients into this node.
    pub requires_grad: bool,
}
impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    GraphNode<T>
{
    /// Creates a leaf node: no producing function, no inputs, no saved input tensors,
    /// and no gradient yet.
    pub fn new_leaf(requires_grad: bool) -> Arc<Self> {
        Arc::new(Self {
            requires_grad,
            grad: None,
            function: None,
            inputs: Vec::new(),
            input_tensors: Vec::new(),
        })
    }

    /// Creates a node produced by `function`.
    ///
    /// `inputs` links to the nodes that fed `function`; `input_tensors` are stored as given.
    /// The gradient starts out empty.
    pub fn new_function(
        function: Arc<dyn Function<T>>,
        inputs: Vec<Weak<GraphNode<T>>>,
        input_tensors: Vec<Tensor<T>>,
        requires_grad: bool,
    ) -> Arc<Self> {
        let node = Self {
            function: Some(function),
            inputs,
            input_tensors,
            grad: None,
            requires_grad,
        };
        Arc::new(node)
    }

    /// Adds `grad` to the stored gradient, or stores `grad` as-is when none exists yet.
    pub fn accumulate_grad(&mut self, grad: Tensor<T>) {
        // The sum is computed before the assignment, so a panicking `Add` leaves `self.grad`
        // untouched.
        self.grad = Some(match &self.grad {
            Some(previous) => previous + &grad,
            None => grad,
        });
    }

    /// Discards the stored gradient.
    pub fn zero_grad(&mut self) {
        self.grad = None;
    }
}
/// Owns the nodes of an autograd graph and runs reverse-mode differentiation over them.
///
/// Nodes are addressed by the ids returned from [`ComputationGraph::add_leaf`] and
/// [`ComputationGraph::add_function`]. The graph holds the strong `Arc` of every node, and
/// nodes link to their inputs through `Weak` pointers. `Arc::get_mut` refuses access while
/// any other `Arc` *or* `Weak` exists, so a node that feeds another node can never be
/// mutated in place. For that reason, gradients computed by [`ComputationGraph::backward`]
/// are kept in a per-graph table and read back with [`ComputationGraph::grad`].
#[derive(Default)]
pub struct ComputationGraph<T: Float + Send + Sync + 'static> {
    /// Every node, keyed by id; these are the graph's strong references.
    nodes: HashMap<usize, Arc<GraphNode<T>>>,
    /// Gradients accumulated by `backward`, keyed by node id.
    grads: HashMap<usize, Tensor<T>>,
    /// Id assigned to the next inserted node.
    next_id: usize,
}

impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
    ComputationGraph<T>
{
    /// Creates an empty graph.
    pub fn new() -> Self {
        ComputationGraph {
            nodes: HashMap::new(),
            grads: HashMap::new(),
            next_id: 0,
        }
    }

    /// Adds a leaf node (no producing function) and returns its id.
    pub fn add_leaf(&mut self, requires_grad: bool) -> usize {
        self.insert_node(GraphNode::new_leaf(requires_grad))
    }

    /// Adds a node produced by `function` and returns its id.
    ///
    /// `input_ids[i]` should name the node that produced `input_tensors[i]`. An id that is
    /// not in the graph is stored as a dead `Weak` (never upgradable) rather than dropped.
    /// This keeps `inputs[i]` positionally aligned with the `i`-th gradient returned by
    /// `Function::backward`.
    pub fn add_function(
        &mut self,
        function: Arc<dyn Function<T>>,
        input_ids: Vec<usize>,
        input_tensors: Vec<Tensor<T>>,
        requires_grad: bool,
    ) -> usize {
        let input_nodes: Vec<Weak<GraphNode<T>>> = input_ids
            .iter()
            .map(|input_id| self.nodes.get(input_id).map_or_else(Weak::new, Arc::downgrade))
            .collect();
        let node = GraphNode::new_function(function, input_nodes, input_tensors, requires_grad);
        self.insert_node(node)
    }

    /// Returns the node stored under `id`.
    pub fn get_node(&self, id: usize) -> Option<&Arc<GraphNode<T>>> {
        self.nodes.get(&id)
    }

    /// Returns mutable access to the `Arc` stored under `id`.
    ///
    /// Mutating the node itself still requires `Arc::get_mut`. That returns `None` while any
    /// other `Arc` or `Weak` to the node exists, for example once another node uses it as
    /// an input.
    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut Arc<GraphNode<T>>> {
        self.nodes.get_mut(&id)
    }

    /// Returns the gradient accumulated for node `id` by [`ComputationGraph::backward`].
    pub fn grad(&self, id: usize) -> Option<&Tensor<T>> {
        self.grads.get(&id)
    }

    /// Discards every gradient accumulated by [`ComputationGraph::backward`].
    ///
    /// It also clears the `grad` field of each node the graph owns exclusively.
    pub fn zero_grad(&mut self) {
        self.grads.clear();
        for node in self.nodes.values_mut() {
            if let Some(node) = Arc::get_mut(node) {
                node.zero_grad();
            }
        }
    }

    /// Back-propagates from `root_id`.
    ///
    /// The root is seeded with `grad_output`, or with `Tensor::ones(&[])` when
    /// `grad_output` is `None`. Does nothing if `root_id` is not in the graph.
    ///
    /// Nodes reachable from the root are visited in reverse topological order:
    /// - For each node with a producing function, `Function::backward` is called with the
    ///   gradient that node received during *this* pass.
    /// - Every returned `Some` gradient is summed into the matching input, but only if that
    ///   input has `requires_grad` set.
    ///
    /// Afterwards, the root's seed and every propagated gradient are added to the graph's
    /// gradient table (see [`ComputationGraph::grad`]). Nodes the graph owns exclusively
    /// (in practice the root, when nothing consumes it) also get the gradient added to their
    /// own `grad` field.
    pub fn backward(&mut self, root_id: usize, grad_output: Option<Tensor<T>>) {
        if !self.nodes.contains_key(&root_id) {
            return;
        }
        let seed = grad_output.unwrap_or_else(|| Tensor::ones(&[]));
        let ids_by_ptr = self.ids_by_ptr();
        let order = self.topological_order(root_id, &ids_by_ptr);

        // Gradients produced by this pass only. Propagating the persistent, already
        // accumulated gradients instead would re-count every earlier `backward` call.
        let mut pending: HashMap<usize, Tensor<T>> = HashMap::new();
        pending.insert(root_id, seed);
        let mut finished: Vec<(usize, Tensor<T>)> = Vec::with_capacity(order.len());

        for &node_id in order.iter().rev() {
            // Every consumer of `node_id` comes earlier in this order, so its gradient is final.
            let grad = match pending.remove(&node_id) {
                Some(grad) => grad,
                None => continue,
            };
            if let Some(node) = self.nodes.get(&node_id) {
                if let Some(function) = &node.function {
                    let input_refs: Vec<&Tensor<T>> = node.input_tensors.iter().collect();
                    let grad_inputs = function.backward(&grad, &input_refs);
                    // `zip` tolerates a function that returns fewer gradients than inputs
                    // instead of panicking on an out-of-bounds index.
                    for (input, grad_input) in node.inputs.iter().zip(grad_inputs.iter()) {
                        let input_id = match ids_by_ptr.get(&Weak::as_ptr(input)) {
                            Some(&input_id) => input_id,
                            // Dead link or placeholder for an unknown input id.
                            None => continue,
                        };
                        let wants_grad = self
                            .nodes
                            .get(&input_id)
                            .map_or(false, |input_node| input_node.requires_grad);
                        if wants_grad {
                            if let Some(grad_input) = grad_input {
                                Self::accumulate_into(
                                    &mut pending,
                                    input_id,
                                    grad_input.clone(),
                                );
                            }
                        }
                    }
                }
            }
            finished.push((node_id, grad));
        }

        for (node_id, grad) in finished {
            // Succeeds only when no other `Arc`/`Weak` points at the node.
            if let Some(node) = self.nodes.get_mut(&node_id).and_then(Arc::get_mut) {
                node.accumulate_grad(grad.clone());
            }
            Self::accumulate_into(&mut self.grads, node_id, grad);
        }
    }

    /// Stores `node` under a freshly allocated id and returns that id.
    fn insert_node(&mut self, node: Arc<GraphNode<T>>) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.nodes.insert(id, node);
        id
    }

    /// Adds `grad` to `table[id]`, or inserts it when `id` has no entry yet.
    fn accumulate_into(table: &mut HashMap<usize, Tensor<T>>, id: usize, grad: Tensor<T>) {
        match table.entry(id) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                let sum = slot.get() + &grad;
                slot.insert(sum);
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(grad);
            }
        }
    }

    /// Maps each node's address to its id, so a `Weak` edge can be resolved in O(1) via
    /// `Weak::as_ptr` instead of scanning every node with `Arc::ptr_eq`.
    fn ids_by_ptr(&self) -> HashMap<*const GraphNode<T>, usize> {
        self.nodes
            .iter()
            .map(|(&id, node)| (Arc::as_ptr(node), id))
            .collect()
    }

    /// Returns `root_id` and every node reachable from it in post-order: each node appears
    /// after all of its inputs.
    ///
    /// The traversal is iterative, so long chains cannot exhaust the call stack. The graph is
    /// acyclic because `add_function` only links to nodes that already exist.
    fn topological_order(
        &self,
        root_id: usize,
        ids_by_ptr: &HashMap<*const GraphNode<T>, usize>,
    ) -> Vec<usize> {
        let mut order = Vec::new();
        let mut visited = std::collections::HashSet::new();
        // `(id, true)` means the node's inputs have already been scheduled.
        let mut work = vec![(root_id, false)];
        while let Some((id, inputs_scheduled)) = work.pop() {
            if inputs_scheduled {
                order.push(id);
                continue;
            }
            if !visited.insert(id) {
                continue;
            }
            work.push((id, true));
            if let Some(node) = self.nodes.get(&id) {
                // Pushed in reverse so inputs are visited in their declared order.
                for input in node.inputs.iter().rev() {
                    if let Some(&input_id) = ids_by_ptr.get(&Weak::as_ptr(input)) {
                        if !visited.contains(&input_id) {
                            work.push((input_id, false));
                        }
                    }
                }
            }
        }
        order
    }
}