use super::graph::{ComputationGraph, NodeId, OperationType};
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use std::sync::{Arc, Mutex};
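/// Shared, thread-safe handle to a computation graph.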
pub type GraphRef = Arc<Mutex<ComputationGraph>>;
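/// A tensor tracked as a node in a shared computation graph; the basic
/// building block for reverse-mode autodiff.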
#[derive(Debug, Clone)]
pub struct Variable {
graph: GraphRef,
node_id: NodeId,
requires_grad: bool,
}
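/// Shared handle to a `Variable`.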
pub type VariableRef = Arc<Variable>;
impl Variable {
/// Creates a leaf variable on a fresh computation graph.
pub fn new(tensor: Tensor, requires_grad: bool) -> Self {
Self::new_internal(tensor, requires_grad, None)
}
/// Creates a named leaf variable on a fresh computation graph.
pub fn new_with_name(tensor: Tensor, requires_grad: bool, name: String) -> Self {
Self::new_internal(tensor, requires_grad, Some(name))
}
fn new_internal(tensor: Tensor, requires_grad: bool, name: Option<String>) -> Self {
let graph = Arc::new(Mutex::new(ComputationGraph::new()));
let node_id = {
let mut graph_guard = graph.lock().expect("lock should not be poisoned");
graph_guard.add_node(tensor, requires_grad, name)
};
Self {
graph,
node_id,
requires_grad,
}
}
pub fn from_graph(graph: GraphRef, node_id: NodeId, requires_grad: bool) -> Self {
Self {
graph,
node_id,
requires_grad,
}
}
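/// Returns a clone of the tensor currently stored at this node.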
pub fn data(&self) -> Result<Tensor> {
let graph = self.graph.lock().expect("lock should not be poisoned");
graph.get_value(self.node_id).cloned().ok_or_else(|| {
TrustformersError::tensor_op_error(
&format!("Node {} not found in graph", self.node_id),
"Variable::data",
)
})
}
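/// Returns a clone of the accumulated gradient, or `None` if the backward
/// pass has not reached this node yet.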
pub fn grad(&self) -> Result<Option<Tensor>> {
let graph = self.graph.lock().expect("lock should not be poisoned");
Ok(graph.get_gradient(self.node_id).cloned())
}
pub fn node_id(&self) -> NodeId {
self.node_id
}
pub fn requires_grad(&self) -> bool {
self.requires_grad
}
pub fn graph(&self) -> GraphRef {
self.graph.clone()
}
pub fn shape(&self) -> Result<Vec<usize>> {
let graph = self.graph.lock().expect("lock should not be poisoned");
graph.get_value(self.node_id).map(|tensor| tensor.shape()).ok_or_else(|| {
TrustformersError::tensor_op_error(
&format!("Node {} not found in graph", self.node_id),
"Variable::shape",
)
})
}
pub fn item(&self) -> Result<f32> {
let tensor = self.data()?;
tensor.to_scalar()
}
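/// Runs the backward pass from this node, letting the graph pick a default
/// seed gradient.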
pub fn backward(&self) -> Result<()> {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.backward(self.node_id, None)
}
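/// Runs the backward pass seeded with an explicit upstream gradient.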
pub fn backward_with_grad(&self, grad: Tensor) -> Result<()> {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.backward(self.node_id, Some(grad))
}
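/// Clears gradients for every node in the underlying graph, not just this one.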
pub fn zero_grad(&self) {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.zero_grad();
}
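/// Returns a copy of this variable's data as a leaf on a brand-new graph,
/// with gradient tracking disabled.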
pub fn detach(&self) -> Result<Variable> {
let tensor = self.data()?;
Ok(Variable::new(tensor, false))
}
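/// Returns a gradient-tracking copy on a brand-new graph. Unlike PyTorch's
/// `requires_grad_`, this does not mutate in place.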
pub fn requires_grad_(&self) -> Result<Variable> {
let tensor = self.data()?;
Ok(Variable::new(tensor, true))
}
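/// Replaces the tensor stored at this node in place.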
pub fn set_data(&self, tensor: Tensor) -> Result<()> {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.update_value(self.node_id, tensor)
}
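// Binary operations; both operands must live on the same graph (see `binary_op`).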
pub fn add(&self, other: &Variable) -> Result<Variable> {
self.binary_op(other, OperationType::Add)
}
pub fn sub(&self, other: &Variable) -> Result<Variable> {
self.binary_op(other, OperationType::Subtract)
}
pub fn mul(&self, other: &Variable) -> Result<Variable> {
self.binary_op(other, OperationType::Multiply)
}
pub fn div(&self, other: &Variable) -> Result<Variable> {
self.binary_op(other, OperationType::Divide)
}
pub fn matmul(&self, other: &Variable) -> Result<Variable> {
self.binary_op(other, OperationType::MatrixMultiply)
}
pub fn neg(&self) -> Result<Variable> {
self.unary_op(OperationType::Negate)
}
pub fn square(&self) -> Result<Variable> {
self.unary_op(OperationType::Square)
}
pub fn sqrt(&self) -> Result<Variable> {
self.unary_op(OperationType::Sqrt)
}
pub fn log(&self) -> Result<Variable> {
self.unary_op(OperationType::Log)
}
pub fn exp(&self) -> Result<Variable> {
self.unary_op(OperationType::Exp)
}
pub fn sigmoid(&self) -> Result<Variable> {
self.unary_op(OperationType::Sigmoid)
}
pub fn tanh(&self) -> Result<Variable> {
self.unary_op(OperationType::Tanh)
}
pub fn relu(&self) -> Result<Variable> {
self.unary_op(OperationType::ReLU)
}
pub fn leaky_relu(&self, alpha: f32) -> Result<Variable> {
self.unary_op(OperationType::LeakyReLU(alpha))
}
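/// Applies softmax along the last axis.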
pub fn softmax(&self) -> Result<Variable> {
self.unary_op(OperationType::Softmax)
}
pub fn reshape(&self, shape: Vec<usize>) -> Result<Variable> {
self.unary_op(OperationType::Reshape(shape))
}
pub fn transpose(&self, permutation: Vec<usize>) -> Result<Variable> {
self.unary_op(OperationType::Transpose(permutation))
}
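/// Sums over the given axes, or over all elements when `axes` is `None`.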
pub fn sum(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
self.unary_op(OperationType::Sum(axes))
}
pub fn mean(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
self.unary_op(OperationType::Mean(axes))
}
pub fn max(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
self.unary_op(OperationType::Max(axes))
}
pub fn min(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
self.unary_op(OperationType::Min(axes))
}
/// Adds a scalar constant element-wise.
pub fn add_scalar(&self, scalar: f32) -> Result<Variable> {
let scalar_var = self.scalar_operand(scalar)?;
self.add(&scalar_var)
}
/// Subtracts a scalar constant element-wise.
pub fn sub_scalar(&self, scalar: f32) -> Result<Variable> {
let scalar_var = self.scalar_operand(scalar)?;
self.sub(&scalar_var)
}
/// Multiplies by a scalar constant element-wise.
pub fn mul_scalar(&self, scalar: f32) -> Result<Variable> {
let scalar_var = self.scalar_operand(scalar)?;
self.mul(&scalar_var)
}
/// Divides by a scalar constant element-wise.
pub fn div_scalar(&self, scalar: f32) -> Result<Variable> {
let scalar_var = self.scalar_operand(scalar)?;
self.div(&scalar_var)
}
/// Registers `scalar` as a non-differentiable constant node in *this*
/// variable's graph. Building the operand with `Variable::new` would
/// allocate a fresh graph and make every `*_scalar` call fail the
/// `Arc::ptr_eq` check in `binary_op`.
fn scalar_operand(&self, scalar: f32) -> Result<Variable> {
let tensor = Tensor::scalar(scalar)?;
let node_id = {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.add_node(tensor, false, None)
};
Ok(Variable::from_graph(self.graph.clone(), node_id, false))
}
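/// Evaluates the op eagerly, then records a result node wired to both
/// operands. Operands from different graphs are rejected rather than merged.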
fn binary_op(&self, other: &Variable, op: OperationType) -> Result<Variable> {
if !Arc::ptr_eq(&self.graph, &other.graph) {
return Err(TrustformersError::tensor_op_error(
"Variables must be from the same computation graph",
"Variable::binary_op",
));
}
let result_tensor = self.compute_binary_tensor_op(&other.data()?, &op)?;
let requires_grad = self.requires_grad || other.requires_grad;
let node_id = {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.add_operation_node(
result_tensor,
op,
vec![self.node_id, other.node_id],
requires_grad,
None,
)?
};
Ok(Variable::from_graph(
self.graph.clone(),
node_id,
requires_grad,
))
}
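/// Evaluates the op eagerly, then records a result node wired to this one.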
fn unary_op(&self, op: OperationType) -> Result<Variable> {
let result_tensor = self.compute_unary_tensor_op(&op)?;
let node_id = {
let mut graph = self.graph.lock().expect("lock should not be poisoned");
graph.add_operation_node(
result_tensor,
op,
vec![self.node_id],
self.requires_grad,
None,
)?
};
Ok(Variable::from_graph(
self.graph.clone(),
node_id,
self.requires_grad,
))
}
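/// Eager tensor-level evaluation of a binary op; runs before the result
/// node is recorded in the graph.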
fn compute_binary_tensor_op(&self, other: &Tensor, op: &OperationType) -> Result<Tensor> {
let self_tensor = self.data()?;
match op {
OperationType::Add => Tensor::add(&self_tensor, other),
OperationType::Subtract => Tensor::sub(&self_tensor, other),
OperationType::Multiply => self_tensor.mul(other),
OperationType::Divide => Tensor::div(&self_tensor, other),
OperationType::MatrixMultiply => self_tensor.matmul(other),
_ => Err(TrustformersError::tensor_op_error(
&format!("Unsupported binary operation: {:?}", op),
"Variable::compute_binary_tensor_op",
)),
}
}
fn compute_unary_tensor_op(&self, op: &OperationType) -> Result<Tensor> {
let self_tensor = self.data()?;
match op {
OperationType::Negate => self_tensor.neg(),
OperationType::Square => self_tensor.clone().mul(&self_tensor),
OperationType::Sqrt => self_tensor.sqrt(),
OperationType::Log => self_tensor.log(),
OperationType::Exp => self_tensor.exp(),
OperationType::Sigmoid => self_tensor.sigmoid(),
OperationType::Tanh => self_tensor.tanh(),
OperationType::ReLU => self_tensor.relu(),
OperationType::LeakyReLU(alpha) => self_tensor.leaky_relu(*alpha),
OperationType::Softmax => self_tensor.softmax(-1),
OperationType::Reshape(shape) => self_tensor.reshape(shape),
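// NOTE: only a two-axis swap is supported; longer permutations use their
// first two entries, and shorter ones fall back to swapping axes 0 and 1.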
OperationType::Transpose(permutation) => {
if permutation.len() >= 2 {
self_tensor.transpose(permutation[0], permutation[1])
} else {
self_tensor.transpose(0, 1)
}
},
OperationType::Sum(axes) => {
match axes {
Some(axes_vec) => self_tensor.sum_axes(axes_vec),
None => {
let shape = self_tensor.shape();
let all_axes: Vec<usize> = (0..shape.len()).collect();
self_tensor.sum_axes(&all_axes)
},
}
},
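// TODO: per-axis mean is not implemented; `axes` is ignored and the mean
// is taken over all elements.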
OperationType::Mean(_axes) => {
self_tensor.mean()
},
_ => Err(TrustformersError::tensor_op_error(
&format!("Unsupported unary operation: {:?}", op),
"Variable::compute_unary_tensor_op",
)),
}
}
pub fn set_requires_grad(&mut self, requires_grad: bool) {
self.requires_grad = requires_grad;
if let Ok(mut graph) = self.graph.lock() {
if let Some(node) = graph.get_node_mut(self.node_id) {
node.requires_grad = requires_grad;
}
}
}
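/// Wraps a tensor as a non-differentiable leaf on a fresh graph.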
pub fn from_tensor(tensor: Tensor) -> Self {
Variable::new(tensor, false)
}
}
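// Convenience constructors; each allocates its own fresh graph, so their
// results can only be combined through ops on the same variable chain.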
impl Variable {
pub fn scalar(value: f32, requires_grad: bool) -> Result<Self> {
let tensor = Tensor::scalar(value)?;
Ok(Variable::new(tensor, requires_grad))
}
pub fn zeros(shape: &[usize], requires_grad: bool) -> Result<Self> {
let tensor = Tensor::zeros(shape)?;
Ok(Variable::new(tensor, requires_grad))
}
pub fn ones(shape: &[usize], requires_grad: bool) -> Result<Self> {
let tensor = Tensor::ones(shape)?;
Ok(Variable::new(tensor, requires_grad))
}
pub fn randn(shape: &[usize], requires_grad: bool) -> Result<Self> {
let tensor = Tensor::randn(shape)?;
Ok(Variable::new(tensor, requires_grad))
}
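/// Uniform-random constructor.
///
/// NOTE: currently backed by the same normal sampler as `randn`; swap in a
/// uniform sampler here if/when `Tensor` grows one.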
pub fn rand(shape: &[usize], requires_grad: bool) -> Result<Self> {
let tensor = Tensor::randn(shape)?;
Ok(Variable::new(tensor, requires_grad))
}
}
use std::ops::{Add, Div, Mul, Neg, Sub};
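// Operator sugar for `&Variable`. Each body dispatches explicitly as
// `Variable::add(self, rhs)` (etc.) to make clear that the inherent method
// above is called, not the trait method recursing into itself.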
impl Add for &Variable {
type Output = Result<Variable>;
fn add(self, rhs: Self) -> Self::Output {
Variable::add(self, rhs)
}
}
impl Sub for &Variable {
type Output = Result<Variable>;
fn sub(self, rhs: Self) -> Self::Output {
Variable::sub(self, rhs)
}
}
impl Mul for &Variable {
type Output = Result<Variable>;
fn mul(self, rhs: Self) -> Self::Output {
Variable::mul(self, rhs)
}
}
impl Div for &Variable {
type Output = Result<Variable>;
fn div(self, rhs: Self) -> Self::Output {
Variable::div(self, rhs)
}
}
impl Neg for &Variable {
type Output = Result<Variable>;
fn neg(self) -> Self::Output {
Variable::neg(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_variable_creation() {
let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
let var = Variable::new(tensor, true);
assert!(var.requires_grad());
assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
}
#[test]
fn test_variable_operations() {
use super::super::AutodiffEngine;
use std::sync::Arc;
let engine = Arc::new(AutodiffEngine::default());
let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
let c = a.add(&b).expect("Addition failed");
assert_eq!(c.item().expect("operation failed in test"), 5.0);
let d = a.mul(&b).expect("Multiplication failed");
assert_eq!(d.item().expect("operation failed in test"), 6.0);
}
#[test]
fn test_gradient_computation() {
use super::super::AutodiffEngine;
use std::sync::Arc;
let engine = Arc::new(AutodiffEngine::default());
let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
let c = a.mul(&b).expect("Multiplication failed");
engine.backward(&c, None).expect("operation failed in test");
let grad_a = engine
.get_grad(&a)
.expect("operation failed in test")
.expect("operation failed in test");
let grad_b = engine
.get_grad(&b)
.expect("operation failed in test")
.expect("operation failed in test");
assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
}
#[test]
fn test_activation_functions() {
let x = Variable::scalar(0.0, true).expect("operation failed in test");
let sigmoid_x = x.sigmoid().expect("Sigmoid failed");
assert_eq!(sigmoid_x.item().expect("operation failed in test"), 0.5);
let tanh_x = x.tanh().expect("Tanh failed");
assert_eq!(tanh_x.item().expect("operation failed in test"), 0.0);
}
#[test]
fn test_tensor_operations() {
let x = Variable::ones(&[2, 3], true).expect("operation failed in test");
let sum_x = x.sum(None).expect("operation failed in test");
assert_eq!(sum_x.item().expect("operation failed in test"), 6.0);
let mean_x = x.mean(None).expect("Mean calculation failed");
assert_eq!(mean_x.item().expect("operation failed in test"), 1.0);
}
#[test]
fn test_reshape_operation() {
let x = Variable::ones(&[2, 3], true).expect("operation failed in test");
let reshaped = x.reshape(vec![3, 2]).expect("Reshape failed");
assert_eq!(
reshaped.shape().expect("operation failed in test"),
vec![3, 2]
);
}
#[test]
fn test_detach_operation() {
let x = Variable::scalar(2.0, true).expect("operation failed in test");
let y = x.detach().expect("operation failed in test");
assert!(x.requires_grad());
assert!(!y.requires_grad());
}
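// A minimal sketch exercising the scalar convenience ops; assumes
// `Tensor::scalar` behaves as elsewhere in this file and relies on the
// shared-graph `scalar_operand` helper above.
#[test]
fn test_scalar_operations() {
let x = Variable::scalar(4.0, true).expect("operation failed in test");
let y = x.add_scalar(1.0).expect("Scalar addition failed");
assert_eq!(y.item().expect("operation failed in test"), 5.0);
let z = x.mul_scalar(0.5).expect("Scalar multiplication failed");
assert_eq!(z.item().expect("operation failed in test"), 2.0);
}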
}