use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
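/// Identifier assigned to each node in a [`ComputationGraph`].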
pub type NodeId = usize;
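/// A dynamic computation graph for reverse-mode automatic differentiation.
///
/// Nodes hold tensor values and optional gradients; each non-leaf node
/// records the operation and parents that produced it.
/// [`ComputationGraph::backward`] walks the graph in reverse topological
/// order, applying the chain rule and accumulating gradients into every
/// node with `requires_grad` set.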
#[derive(Debug)]
pub struct ComputationGraph {
nodes: HashMap<NodeId, GraphNode>,
next_id: NodeId,
topological_order: Vec<NodeId>,
dirty: bool,
root_nodes: Vec<NodeId>,
leaf_nodes: Vec<NodeId>,
}
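/// A single value in the graph, plus the bookkeeping needed for
/// backpropagation: the producing operation, parent/child edges, and the
/// accumulated gradient (if any).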
#[derive(Debug, Clone)]
pub struct GraphNode {
pub id: NodeId,
pub value: Tensor,
pub gradient: Option<Tensor>,
pub operation: Option<OperationType>,
pub parents: Vec<NodeId>,
pub children: Vec<NodeId>,
pub requires_grad: bool,
pub is_leaf: bool,
pub name: Option<String>,
pub shape: Vec<usize>,
}
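/// The operations the graph knows how to record. Variants carry the
/// parameters their backward pass needs, e.g. reduction axes, the LeakyReLU
/// slope, or a permutation.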
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperationType {
Add,
Subtract,
Multiply,
Divide,
MatrixMultiply,
Negate,
Reciprocal,
Square,
Sqrt,
Log,
Exp,
Sigmoid,
Tanh,
ReLU,
LeakyReLU(f32),
Softmax,
LogSoftmax,
Reshape(Vec<usize>),
Transpose(Vec<usize>),
Slice(Vec<std::ops::Range<usize>>),
    Concat(usize),
    Split(Vec<usize>),
    Sum(Option<Vec<usize>>),
    Mean(Option<Vec<usize>>),
    Max(Option<Vec<usize>>),
    Min(Option<Vec<usize>>),
    LayerNorm(f32),
    Dropout(f32),
    BatchNorm(f32),
Custom(String),
}
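/// Implemented by custom operations that supply their own backward pass.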
pub trait GradientFunction: Send + Sync {
fn backward(&self, grad_output: &Tensor, inputs: &[&Tensor]) -> Result<Vec<Tensor>>;
fn operation_type(&self) -> OperationType;
}
impl ComputationGraph {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
next_id: 0,
topological_order: Vec::new(),
dirty: false,
root_nodes: Vec::new(),
leaf_nodes: Vec::new(),
}
}
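    /// Adds a leaf node holding `value` and returns its id.
    ///
    /// A minimal usage sketch, in the style of the tests below:
    ///
    /// ```ignore
    /// let mut graph = ComputationGraph::new();
    /// let x = Tensor::ones(&[2, 3])?;
    /// let x_id = graph.add_node(x, true, Some("x".to_string()));
    /// ```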
pub fn add_node(&mut self, value: Tensor, requires_grad: bool, name: Option<String>) -> NodeId {
let id = self.next_id;
self.next_id += 1;
let shape = value.shape();
let node = GraphNode {
id,
value,
gradient: None,
operation: None,
parents: Vec::new(),
children: Vec::new(),
requires_grad,
is_leaf: true,
name,
shape,
};
self.nodes.insert(id, node);
if requires_grad {
self.root_nodes.push(id);
}
self.dirty = true;
id
}
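    /// Adds a node produced by `operation` from the given `parents`, wiring
    /// up parent/child edges. Fails if any parent id is unknown.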
pub fn add_operation_node(
&mut self,
value: Tensor,
operation: OperationType,
parents: Vec<NodeId>,
requires_grad: bool,
name: Option<String>,
) -> Result<NodeId> {
let id = self.next_id;
self.next_id += 1;
for parent_id in &parents {
if let Some(parent) = self.nodes.get_mut(parent_id) {
parent.children.push(id);
} else {
return Err(TrustformersError::tensor_op_error(
&format!("Parent node {} not found", parent_id),
"ComputationGraph::add_operation_node",
));
}
}
let shape = value.shape();
let node = GraphNode {
id,
value,
gradient: None,
operation: Some(operation),
parents,
children: Vec::new(),
requires_grad,
is_leaf: false,
name,
shape,
};
self.nodes.insert(id, node);
self.dirty = true;
Ok(id)
}
pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
self.nodes.get(&id)
}
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
self.nodes.get_mut(&id)
}
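    /// Rebuilds the topological order with Kahn's algorithm: seed a queue
    /// with parentless nodes, then repeatedly pop a node and decrement each
    /// child's in-degree, enqueueing children that reach zero. If some node
    /// is never emitted, the remaining edges must contain a cycle. The
    /// computation is skipped when the graph is unchanged (`dirty` is
    /// false).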
pub fn compute_topological_order(&mut self) -> Result<()> {
if !self.dirty {
return Ok(());
}
let mut in_degree = HashMap::new();
let mut queue = VecDeque::new();
let mut result = Vec::new();
for (id, node) in &self.nodes {
in_degree.insert(*id, node.parents.len());
if node.parents.is_empty() {
queue.push_back(*id);
}
}
while let Some(node_id) = queue.pop_front() {
result.push(node_id);
let Some(node) = self.nodes.get(&node_id) else {
continue;
};
for child_id in &node.children {
let Some(degree) = in_degree.get_mut(child_id) else {
continue;
};
*degree -= 1;
if *degree == 0 {
queue.push_back(*child_id);
}
}
}
if result.len() != self.nodes.len() {
return Err(TrustformersError::tensor_op_error(
"Cycle detected in computation graph",
"ComputationGraph::compute_topological_order",
));
}
self.topological_order = result;
self.dirty = false;
Ok(())
}
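    /// Runs reverse-mode differentiation from `output_id`, seeding with
    /// `grad_output` (or a ones tensor of the output's shape) and
    /// accumulating gradients into every ancestor that requires them.
    ///
    /// A minimal sketch of the chain rule it applies, mirroring the tests
    /// below:
    ///
    /// ```ignore
    /// // c = a * b  =>  dc/da = b and dc/db = a
    /// graph.backward(node_c, None)?;
    /// let grad_a = graph.get_gradient(node_a); // equals the value of b
    /// ```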
pub fn backward(&mut self, output_id: NodeId, grad_output: Option<Tensor>) -> Result<()> {
self.compute_topological_order()?;
        if let Some(output_node) = self.nodes.get_mut(&output_id) {
            // Seed the backward pass with the provided gradient, defaulting
            // to ones of the output's shape (d out / d out = 1).
            let seed = match grad_output {
                Some(grad) => grad,
                None => Tensor::ones(&output_node.shape)?,
            };
            output_node.gradient = Some(seed);
        } else {
            return Err(TrustformersError::tensor_op_error(
                &format!("Output node {} not found", output_id),
                "ComputationGraph::backward",
            ));
        }
        for &node_id in self.topological_order.iter().rev() {
            let Some(node) = self.nodes.get(&node_id) else {
                continue;
            };
            // Only operation nodes that already received a gradient can
            // propagate anything further.
            let (Some(grad), Some(operation)) = (&node.gradient, &node.operation) else {
                continue;
            };
            // Clone the edge list so the immutable borrow of self.nodes ends
            // before parent gradients are mutated below.
            let parents = node.parents.clone();
            let parent_gradients =
                self.compute_operation_gradients(operation, grad, &parents)?;
            for (parent_id, parent_grad) in parents.iter().zip(parent_gradients.iter()) {
let Some(parent_node) = self.nodes.get_mut(parent_id) else {
continue;
};
if !parent_node.requires_grad {
continue;
}
if let Some(ref mut existing_grad) = parent_node.gradient {
*existing_grad = existing_grad.add(parent_grad)?;
} else {
parent_node.gradient = Some(parent_grad.clone());
}
}
}
Ok(())
}
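    /// Dispatches on `operation` to produce one gradient per parent, given
    /// the gradient flowing into the node. Operations without an implemented
    /// backward rule fall through to the zero-gradient catch-all arm.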
fn compute_operation_gradients(
&self,
operation: &OperationType,
grad_output: &Tensor,
parent_ids: &[NodeId],
) -> Result<Vec<Tensor>> {
        // Look up parent values defensively instead of indexing, so a
        // missing parent surfaces as an error rather than a panic.
        let parent_values: Vec<&Tensor> = parent_ids
            .iter()
            .map(|id| {
                self.nodes.get(id).map(|node| &node.value).ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        &format!("Parent node {} not found", id),
                        "ComputationGraph::compute_operation_gradients",
                    )
                })
            })
            .collect::<Result<Vec<_>>>()?;
match operation {
OperationType::Add => {
Ok(vec![grad_output.clone(), grad_output.clone()])
},
OperationType::Subtract => {
Ok(vec![grad_output.clone(), grad_output.neg()?])
},
OperationType::Multiply => {
if parent_values.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Multiply operation requires exactly 2 inputs",
"ComputationGraph::compute_operation_gradients",
));
}
Ok(vec![
grad_output.mul(parent_values[1])?,
grad_output.mul(parent_values[0])?,
])
},
OperationType::Divide => {
if parent_values.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Divide operation requires exactly 2 inputs",
"ComputationGraph::compute_operation_gradients",
));
}
let a = parent_values[0];
let b = parent_values[1];
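                // d(a/b)/da = 1/b and d(a/b)/db = -a/b^2.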
Ok(vec![
grad_output.div(b)?,
grad_output.mul(a)?.neg()?.div(&b.mul(b)?)?,
])
},
OperationType::MatrixMultiply => {
if parent_values.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"MatrixMultiply operation requires exactly 2 inputs",
"ComputationGraph::compute_operation_gradients",
));
}
let a = parent_values[0];
let b = parent_values[1];
let a_shape = a.shape();
let b_shape = b.shape();
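                // For C = A @ B: dC/dA = grad_output @ B^T and
                // dC/dB = A^T @ grad_output, transposing only the last two
                // (matrix) dimensions in the batched case.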
let grad_a = if a_shape.len() == 2 && b_shape.len() == 2 {
grad_output.matmul(&b.transpose(1, 0)?)?
} else {
let b_transposed = b.transpose(2, 1)?;
grad_output.matmul(&b_transposed)?
};
let grad_b = if a_shape.len() == 2 && b_shape.len() == 2 {
a.transpose(1, 0)?.matmul(grad_output)?
} else {
let a_transposed = a.permute(&[0, 2, 1])?;
a_transposed.matmul(grad_output)?
};
Ok(vec![grad_a, grad_b])
},
OperationType::Sigmoid => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Sigmoid operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
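                // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).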
let sigmoid_out = parent_values[0].sigmoid()?;
let one = Tensor::ones(&sigmoid_out.shape())?;
let grad_input = grad_output.mul(&sigmoid_out)?.mul(&one.sub(&sigmoid_out)?)?;
Ok(vec![grad_input])
},
OperationType::Tanh => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Tanh operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
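                // d/dx tanh(x) = 1 - tanh(x)^2.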
let tanh_out = parent_values[0].tanh()?;
let one = Tensor::ones(&tanh_out.shape())?;
let grad_input = grad_output.mul(&one.sub(&tanh_out.mul(&tanh_out)?)?)?;
Ok(vec![grad_input])
},
OperationType::ReLU => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"ReLU operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
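                // d/dx relu(x) = 1 where x > 0, else 0.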
let input = parent_values[0];
let zero = Tensor::zeros(&input.shape())?;
let mask = input.greater(&zero)?;
let grad_input = grad_output.mul(&mask)?;
Ok(vec![grad_input])
},
OperationType::LeakyReLU(alpha) => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"LeakyReLU operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
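                // d/dx leaky_relu(x) = 1 where x > 0, else alpha.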
let input = parent_values[0];
let zero = Tensor::zeros(&input.shape())?;
let alpha_tensor = Tensor::scalar(*alpha)?;
let one = Tensor::ones(&input.shape())?;
let positive_mask = input.greater(&zero)?;
let negative_mask = one.sub(&positive_mask)?;
let grad_input =
grad_output.mul(&positive_mask.add(&negative_mask.mul(&alpha_tensor)?)?)?;
Ok(vec![grad_input])
},
OperationType::Sum(axes) => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Sum operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
let input_shape = parent_values[0].shape();
let grad_input =
self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;
Ok(vec![grad_input])
},
OperationType::Mean(axes) => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Mean operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
let input_shape = parent_values[0].shape();
let grad_broadcasted =
self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;
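                // The mean divides by the count of reduced elements, so its
                // gradient is scaled by 1/N as well.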
let num_elements = if let Some(axes) = axes {
axes.iter().map(|&axis| input_shape[axis]).product::<usize>()
} else {
input_shape.iter().product::<usize>()
};
let grad_input = grad_broadcasted.scalar_div(num_elements as f32)?;
Ok(vec![grad_input])
},
            OperationType::Reshape(_target_shape) => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Reshape operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
let original_shape = parent_values[0].shape();
let grad_input = grad_output.reshape(&original_shape)?;
Ok(vec![grad_input])
},
OperationType::Transpose(permutation) => {
if parent_values.len() != 1 {
return Err(TrustformersError::tensor_op_error(
"Transpose operation requires exactly 1 input",
"ComputationGraph::compute_operation_gradients",
));
}
let inverse_permutation = self.compute_inverse_permutation(permutation)?;
let grad_input = grad_output.permute(&inverse_permutation)?;
Ok(vec![grad_input])
},
            _ => {
                // Fallback for operations without an implemented backward
                // rule: return zero gradients instead of panicking, so the
                // rest of the backward pass can still run.
                let zero_grads = parent_values
                    .iter()
                    .map(|input| Tensor::zeros(&input.shape()))
                    .collect::<Result<Vec<_>>>()?;
                Ok(zero_grads)
            },
}
}
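    /// Expands a reduced gradient back to `original_shape`: re-inserts the
    /// reduced axes as size-1 dimensions (or treats the gradient as a scalar
    /// when no axes were given), then broadcasts.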
fn broadcast_gradient(
&self,
grad_output: &Tensor,
original_shape: &[usize],
axes: Option<&Vec<usize>>,
) -> Result<Tensor> {
        if let Some(axes) = axes {
            // Re-insert the reduced axes as size-1 dimensions. Insert in
            // ascending order so earlier insertions do not shift the
            // positions of later ones.
            let mut sorted_axes = axes.clone();
            sorted_axes.sort_unstable();
            let mut result = grad_output.clone();
            for &axis in &sorted_axes {
                result = result.unsqueeze(axis)?;
            }
            result.broadcast_to(original_shape)
        } else {
            // Full reduction: broadcast the scalar gradient over every
            // element of the original shape.
            grad_output.clone().broadcast_to(original_shape)
        }
}
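    /// Computes the inverse of `permutation`, i.e. the permutation that
    /// restores the original axis order. For example, the inverse of
    /// `[2, 0, 1]` is `[1, 2, 0]`.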
    fn compute_inverse_permutation(&self, permutation: &[usize]) -> Result<Vec<usize>> {
        // inverse[p] = i exactly when permutation[i] = p. usize::MAX marks
        // unassigned slots so out-of-range and repeated indices are both
        // rejected as invalid permutations.
        let mut inverse = vec![usize::MAX; permutation.len()];
        for (i, &p) in permutation.iter().enumerate() {
            if p >= permutation.len() || inverse[p] != usize::MAX {
                return Err(TrustformersError::tensor_op_error(
                    &format!("Invalid permutation index: {}", p),
                    "ComputationGraph::compute_inverse_permutation",
                ));
            }
            inverse[p] = i;
        }
        Ok(inverse)
    }
    /// Clears all stored gradients. Alias for [`Self::zero_gradients`].
    pub fn zero_grad(&mut self) {
        self.zero_gradients();
    }
pub fn get_gradient(&self, node_id: NodeId) -> Option<&Tensor> {
self.nodes.get(&node_id)?.gradient.as_ref()
}
pub fn get_value(&self, node_id: NodeId) -> Option<&Tensor> {
self.nodes.get(&node_id).map(|node| &node.value)
}
pub fn update_value(&mut self, node_id: NodeId, value: Tensor) -> Result<()> {
if let Some(node) = self.nodes.get_mut(&node_id) {
node.value = value;
node.shape = node.value.shape();
Ok(())
} else {
Err(TrustformersError::tensor_op_error(
&format!("Node {} not found", node_id),
"ComputationGraph::update_value",
))
}
}
pub fn get_root_nodes(&self) -> &[NodeId] {
&self.root_nodes
}
pub fn get_leaf_nodes(&self) -> &[NodeId] {
&self.leaf_nodes
}
pub fn set_leaf_node(&mut self, node_id: NodeId) {
if !self.leaf_nodes.contains(&node_id) {
self.leaf_nodes.push(node_id);
}
}
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
pub fn get_topological_order(&self) -> &[NodeId] {
&self.topological_order
}
pub fn zero_gradients(&mut self) {
for node in self.nodes.values_mut() {
node.gradient = None;
}
}
pub fn export_graph(&self) -> GraphExport {
let nodes: Vec<_> = self.nodes.values().cloned().collect();
GraphExport {
nodes,
topological_order: self.topological_order.clone(),
}
}
}
#[derive(Debug, Clone)]
pub struct GraphExport {
pub nodes: Vec<GraphNode>,
pub topological_order: Vec<NodeId>,
}
impl Default for ComputationGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_graph_creation() {
let mut graph = ComputationGraph::new();
assert_eq!(graph.num_nodes(), 0);
let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
let node_id = graph.add_node(tensor, true, Some("test".to_string()));
assert_eq!(graph.num_nodes(), 1);
assert_eq!(node_id, 0);
}
#[test]
fn test_topological_order() {
let mut graph = ComputationGraph::new();
let a = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
let b = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
let c = a.add(&b).expect("Addition failed");
let node_a = graph.add_node(a, true, Some("a".to_string()));
let node_b = graph.add_node(b, true, Some("b".to_string()));
let node_c = graph
.add_operation_node(
c,
OperationType::Add,
vec![node_a, node_b],
true,
Some("c".to_string()),
)
.expect("operation failed in test");
graph.compute_topological_order().expect("operation failed in test");
let order = graph.get_topological_order();
assert_eq!(order.len(), 3);
let a_pos = order.iter().position(|&id| id == node_a).expect("operation failed in test");
let b_pos = order.iter().position(|&id| id == node_b).expect("operation failed in test");
let c_pos = order.iter().position(|&id| id == node_c).expect("operation failed in test");
assert!(a_pos < c_pos);
assert!(b_pos < c_pos);
}
#[test]
fn test_backward_pass() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(2.0).expect("tensor operation failed");
let b = Tensor::scalar(3.0).expect("tensor operation failed");
let c = a.mul(&b).expect("Multiplication failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
let node_c = graph
.add_operation_node(
c,
OperationType::Multiply,
vec![node_a, node_b],
true,
Some("c".to_string()),
)
.expect("operation failed in test");
graph.backward(node_c, None).expect("operation failed in test");
let grad_a = graph.get_gradient(node_a).expect("operation failed in test");
let grad_b = graph.get_gradient(node_b).expect("operation failed in test");
assert_eq!(
grad_a.to_vec_f32().expect("operation failed in test")[0],
3.0
);
assert_eq!(
grad_b.to_vec_f32().expect("operation failed in test")[0],
2.0
);
}
#[test]
fn test_gradient_accumulation() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(2.0).expect("tensor operation failed");
let d = a.add(&a).expect("Addition failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_d = graph
.add_operation_node(
d,
OperationType::Add,
vec![node_a, node_a],
true,
Some("d".to_string()),
)
.expect("operation failed in test");
graph.backward(node_d, None).expect("operation failed in test");
let grad_a = graph.get_gradient(node_a).expect("operation failed in test");
assert_eq!(
grad_a.to_vec_f32().expect("operation failed in test")[0],
2.0
);
}
#[test]
fn test_new_graph_is_empty() {
let graph = ComputationGraph::new();
assert_eq!(graph.num_nodes(), 0);
}
#[test]
fn test_add_single_node() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
let id = graph.add_node(tensor, true, Some("x".to_string()));
assert_eq!(id, 0);
assert_eq!(graph.num_nodes(), 1);
}
#[test]
fn test_add_multiple_nodes() {
let mut graph = ComputationGraph::new();
for i in 0..5 {
let tensor = Tensor::scalar(i as f32).expect("tensor operation failed");
let id = graph.add_node(tensor, true, None);
assert_eq!(id, i);
}
assert_eq!(graph.num_nodes(), 5);
}
#[test]
fn test_get_node() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(42.0).expect("tensor operation failed");
let id = graph.add_node(tensor, true, Some("test".to_string()));
let node = graph.get_node(id).expect("node should exist");
assert_eq!(node.id, id);
assert_eq!(node.name, Some("test".to_string()));
assert!(node.requires_grad);
assert!(node.is_leaf);
}
#[test]
fn test_get_node_nonexistent() {
let graph = ComputationGraph::new();
assert!(graph.get_node(999).is_none());
}
#[test]
fn test_get_node_mut() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
let id = graph.add_node(tensor, true, None);
let node = graph.get_node_mut(id).expect("node should exist");
node.name = Some("modified".to_string());
assert_eq!(
graph.get_node(id).expect("node should exist").name,
Some("modified".to_string())
);
}
#[test]
fn test_add_operation_node() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(1.0).expect("tensor operation failed");
let b = Tensor::scalar(2.0).expect("tensor operation failed");
let c = a.add(&b).expect("Addition failed");
let node_a = graph.add_node(a, true, None);
let node_b = graph.add_node(b, true, None);
let node_c = graph
.add_operation_node(c, OperationType::Add, vec![node_a, node_b], true, None)
.expect("operation failed in test");
let op_node = graph.get_node(node_c).expect("node should exist");
assert!(!op_node.is_leaf);
assert_eq!(op_node.operation, Some(OperationType::Add));
assert_eq!(op_node.parents, vec![node_a, node_b]);
}
#[test]
fn test_add_operation_node_nonexistent_parent() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
let result = graph.add_operation_node(tensor, OperationType::Add, vec![999], true, None);
assert!(result.is_err());
}
#[test]
fn test_topological_order_linear_chain() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(1.0).expect("tensor operation failed");
let b = a.add(&a).expect("Addition failed");
let c = b.add(&b).expect("Addition failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_b = graph
.add_operation_node(
b,
OperationType::Add,
vec![node_a, node_a],
true,
Some("b".to_string()),
)
.expect("operation failed in test");
let node_c = graph
.add_operation_node(
c,
OperationType::Add,
vec![node_b, node_b],
true,
Some("c".to_string()),
)
.expect("operation failed in test");
graph.compute_topological_order().expect("operation failed in test");
let order = graph.get_topological_order();
let a_pos = order.iter().position(|&id| id == node_a).expect("a not found");
let b_pos = order.iter().position(|&id| id == node_b).expect("b not found");
let c_pos = order.iter().position(|&id| id == node_c).expect("c not found");
assert!(a_pos < b_pos);
assert!(b_pos < c_pos);
}
#[test]
fn test_topological_order_single_node() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
graph.add_node(tensor, true, None);
graph.compute_topological_order().expect("operation failed in test");
assert_eq!(graph.get_topological_order().len(), 1);
}
#[test]
fn test_topological_order_idempotent() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
graph.add_node(tensor, true, None);
graph.compute_topological_order().expect("operation failed in test");
let order1 = graph.get_topological_order().to_vec();
graph.compute_topological_order().expect("operation failed in test");
let order2 = graph.get_topological_order().to_vec();
assert_eq!(order1, order2);
}
#[test]
fn test_backward_subtraction() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(5.0).expect("tensor operation failed");
let b = Tensor::scalar(3.0).expect("tensor operation failed");
let c = a.sub(&b).expect("Subtraction failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
let node_c = graph
.add_operation_node(c, OperationType::Subtract, vec![node_a, node_b], true, None)
.expect("operation failed in test");
graph.backward(node_c, None).expect("operation failed in test");
let grad_a = graph.get_gradient(node_a).expect("gradient should exist");
let grad_b = graph.get_gradient(node_b).expect("gradient should exist");
assert_eq!(
grad_a.to_vec_f32().expect("operation failed in test")[0],
1.0
);
assert_eq!(
grad_b.to_vec_f32().expect("operation failed in test")[0],
-1.0
);
}
#[test]
fn test_backward_nonexistent_output() {
let mut graph = ComputationGraph::new();
let tensor = Tensor::scalar(1.0).expect("tensor operation failed");
graph.add_node(tensor, true, None);
let result = graph.backward(999, None);
assert!(result.is_err());
}
#[test]
fn test_graph_node_clone() {
let tensor = Tensor::scalar(std::f32::consts::PI).expect("tensor operation failed");
let node = GraphNode {
id: 0,
value: tensor,
gradient: None,
operation: Some(OperationType::Add),
parents: vec![1, 2],
children: vec![3],
requires_grad: true,
is_leaf: false,
name: Some("test".to_string()),
shape: vec![1],
};
let cloned = node.clone();
assert_eq!(cloned.id, 0);
assert_eq!(cloned.parents, vec![1, 2]);
assert_eq!(cloned.name, Some("test".to_string()));
}
#[test]
fn test_operation_type_eq() {
assert_eq!(OperationType::Add, OperationType::Add);
assert_ne!(OperationType::Add, OperationType::Subtract);
}
#[test]
fn test_operation_type_variants() {
let _add = OperationType::Add;
let _sub = OperationType::Subtract;
let _mul = OperationType::Multiply;
let _div = OperationType::Divide;
let _mm = OperationType::MatrixMultiply;
let _neg = OperationType::Negate;
let _rec = OperationType::Reciprocal;
let _sq = OperationType::Square;
let _sqrt = OperationType::Sqrt;
let _log = OperationType::Log;
let _exp = OperationType::Exp;
let _sig = OperationType::Sigmoid;
let _tnh = OperationType::Tanh;
let _relu = OperationType::ReLU;
let _lrelu = OperationType::LeakyReLU(0.01);
let _smax = OperationType::Softmax;
let _lsmax = OperationType::LogSoftmax;
let _resh = OperationType::Reshape(vec![2, 3]);
let _sum = OperationType::Sum(None);
let _mean = OperationType::Mean(Some(vec![0]));
let _ln = OperationType::LayerNorm(1e-5);
let _drop = OperationType::Dropout(0.1);
let _cust = OperationType::Custom("test".to_string());
}
#[test]
fn test_operation_type_clone() {
let op = OperationType::LeakyReLU(0.01);
let cloned = op.clone();
assert_eq!(op, cloned);
}
#[test]
fn test_no_grad_node_skipped_in_backward() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(2.0).expect("tensor operation failed");
let b = Tensor::scalar(3.0).expect("tensor operation failed");
let c = a.mul(&b).expect("Multiplication failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
        let node_b = graph.add_node(b.clone(), false, Some("b".to_string()));
        let node_c = graph
.add_operation_node(c, OperationType::Multiply, vec![node_a, node_b], true, None)
.expect("operation failed in test");
graph.backward(node_c, None).expect("operation failed in test");
let grad_a = graph.get_gradient(node_a).expect("gradient should exist");
assert_eq!(
grad_a.to_vec_f32().expect("operation failed in test")[0],
3.0
);
assert!(graph.get_gradient(node_b).is_none());
}
#[test]
fn test_reset_gradients() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(2.0).expect("tensor operation failed");
let b = Tensor::scalar(3.0).expect("tensor operation failed");
let c = a.mul(&b).expect("Multiplication failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
let node_c = graph
.add_operation_node(c, OperationType::Multiply, vec![node_a, node_b], true, None)
.expect("operation failed in test");
graph.backward(node_c, None).expect("operation failed in test");
assert!(graph.get_gradient(node_a).is_some());
graph.zero_gradients();
assert!(graph.get_gradient(node_a).is_none());
}
#[test]
fn test_backward_division() {
let mut graph = ComputationGraph::new();
let a = Tensor::scalar(6.0).expect("tensor operation failed");
let b = Tensor::scalar(3.0).expect("tensor operation failed");
let c = a.div(&b).expect("Division failed");
let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
let node_c = graph
.add_operation_node(c, OperationType::Divide, vec![node_a, node_b], true, None)
.expect("operation failed in test");
graph.backward(node_c, None).expect("operation failed in test");
let grad_a = graph.get_gradient(node_a).expect("gradient should exist");
let ga = grad_a.to_vec_f32().expect("operation failed in test")[0];
assert!((ga - 1.0 / 3.0).abs() < 1e-5);
let grad_b = graph.get_gradient(node_b).expect("gradient should exist");
let gb = grad_b.to_vec_f32().expect("operation failed in test")[0];
assert!((gb - (-2.0 / 3.0)).abs() < 1e-5);
}
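    // A minimal sketch of a Sigmoid regression test, assuming Tensor::scalar
    // and Tensor::sigmoid behave as compute_operation_gradients expects:
    // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), so at x = 0 the
    // gradient is 0.25.
    #[test]
    fn test_backward_sigmoid() {
        let mut graph = ComputationGraph::new();
        let x = Tensor::scalar(0.0).expect("tensor operation failed");
        let y = x.sigmoid().expect("Sigmoid failed");
        let node_x = graph.add_node(x.clone(), true, Some("x".to_string()));
        let node_y = graph
            .add_operation_node(y, OperationType::Sigmoid, vec![node_x], true, None)
            .expect("operation failed in test");
        graph.backward(node_y, None).expect("operation failed in test");
        let grad_x = graph.get_gradient(node_x).expect("gradient should exist");
        let gx = grad_x.to_vec_f32().expect("operation failed in test")[0];
        assert!((gx - 0.25).abs() < 1e-5);
    }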
}