use std::sync::Arc;
use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use super::compute_graph::{ComputeGraph, OpId, TensorId, OpType};
/// A [`DenseTensor`] value bundled with the bookkeeping used for automatic
/// differentiation: an optional gradient, the id under which the value is
/// known to a [`ComputeGraph`], and a flag saying whether gradients are wanted.
#[derive(Debug, Clone)]
pub struct DifferentiableTensor {
/// The tensor's value.
data: DenseTensor,
/// Gradient for this tensor, if one has been set or fetched by `backward`.
grad: Option<DenseTensor>,
/// Id of the op associated with this tensor. Only `set_op_id` assigns it;
/// every constructor and op in this file leaves it as `None`.
op_id: Option<OpId>,
/// Id used when storing this tensor's value in a graph and when asking the
/// graph for its gradient. Tensors built without a graph use `TensorId(0)`.
tensor_id: TensorId,
/// Whether `backward` should do any work for this tensor.
requires_grad: bool,
// Holds a clone of the graph taken when the tensor was created through a
// graph-aware constructor or op; `None` otherwise. No code visible in this
// file reads the field, hence the lint allowance.
#[allow(dead_code)]
graph: Option<Arc<ComputeGraph>>,
}
impl DifferentiableTensor {
pub fn new(data: DenseTensor, requires_grad: bool) -> Self {
Self {
data,
grad: None,
op_id: None,
tensor_id: TensorId(0),
requires_grad,
graph: None,
}
}
pub fn with_graph(data: DenseTensor, requires_grad: bool, graph: &mut ComputeGraph) -> Self {
let tensor_id = graph.next_tensor_id();
graph.store_value(tensor_id, data.clone());
Self {
data,
grad: None,
op_id: None,
tensor_id,
requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn data(&self) -> &DenseTensor {
&self.data
}
pub fn data_mut(&mut self) -> &mut DenseTensor {
&mut self.data
}
pub fn grad(&self) -> Option<&DenseTensor> {
self.grad.as_ref()
}
pub fn grad_mut(&mut self) -> Option<&mut DenseTensor> {
self.grad.as_mut()
}
pub fn set_grad(&mut self, grad: DenseTensor) {
self.grad = Some(grad);
}
pub fn zero_grad(&mut self) {
self.grad = None;
}
pub fn requires_grad(&self) -> bool {
self.requires_grad
}
pub fn tensor_id(&self) -> TensorId {
self.tensor_id
}
pub fn op_id(&self) -> Option<OpId> {
self.op_id
}
#[allow(dead_code)]
pub(crate) fn set_op_id(&mut self, op_id: OpId) {
self.op_id = Some(op_id);
}
pub fn shape(&self) -> &[usize] {
self.data.shape()
}
pub fn matmul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.matmul(&other.data);
graph.record_op(OpType::MatMul, &[self.tensor_id, other.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad || other.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn add(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.add(&other.data);
graph.record_op(OpType::Add, &[self.tensor_id, other.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad || other.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn sub(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.sub(&other.data);
graph.record_op(OpType::Sub, &[self.tensor_id, other.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad || other.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn mul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.mul(&other.data);
graph.record_op(OpType::Mul, &[self.tensor_id, other.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad || other.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn relu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.relu();
graph.record_op(OpType::ReLU, &[self.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn gelu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.gelu();
graph.record_op(OpType::GELU, &[self.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn softmax(&self, dim: isize, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.softmax(dim);
graph.record_op(OpType::Softmax, &[self.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn transpose(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
let output_id = graph.next_tensor_id();
let result_data = self.data.transpose(None);
graph.record_op(OpType::Transpose, &[self.tensor_id], output_id);
graph.store_value(output_id, result_data.clone());
DifferentiableTensor {
data: result_data,
grad: None,
op_id: None,
tensor_id: output_id,
requires_grad: self.requires_grad,
graph: Some(Arc::new(graph.clone())),
}
}
pub fn detach(&self) -> DifferentiableTensor {
DifferentiableTensor {
data: self.data.clone(),
grad: None,
op_id: None,
tensor_id: TensorId(0),
requires_grad: false,
graph: None,
}
}
pub fn backward(&mut self, graph: &mut ComputeGraph) {
if self.requires_grad {
graph.backward(self.tensor_id);
if let Some(grad) = graph.get_gradient(self.tensor_id).cloned() {
self.grad = Some(grad);
}
}
}
}
impl From<DenseTensor> for DifferentiableTensor {
fn from(tensor: DenseTensor) -> Self {
Self::new(tensor, false)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DenseTensor;

    /// A freshly built tensor keeps its flag and data and starts without a gradient.
    #[test]
    fn test_differentiable_tensor_creation() {
        let values = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let wrapped = DifferentiableTensor::new(values.clone(), true);

        assert!(wrapped.requires_grad());
        assert_eq!(wrapped.data(), &values);
        assert!(wrapped.grad().is_none());
    }

    /// A (1x2) by (2x2) product requires a gradient and has shape (1x2).
    #[test]
    fn test_differentiable_matmul() {
        let mut graph = ComputeGraph::new();

        let input = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![1.0, 2.0], vec![1, 2]),
            true,
            &mut graph,
        );
        let weights = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]),
            true,
            &mut graph,
        );

        let product = input.matmul(&weights, &mut graph);

        assert!(product.requires_grad());
        assert_eq!(product.shape(), &[1, 2]);
    }

    /// ReLU leaves the shape of its input unchanged.
    #[test]
    fn test_differentiable_relu() {
        let mut graph = ComputeGraph::new();
        let input = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![1, 4]),
            true,
            &mut graph,
        );

        let activated = input.relu(&mut graph);

        assert_eq!(activated.shape(), &[1, 4]);
    }
}