//! Dynamic (define-by-run) automatic differentiation: operations executed
//! while a tape is active are recorded and can later be replayed in reverse
//! to compute gradients.

use crate::error::AutogradError;
use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
use std::cell::RefCell;
use std::sync::{Arc, Mutex, RwLock};

thread_local! {
    static ACTIVE_TAPE: RefCell<Option<Arc<Mutex<TapeInner>>>> = RefCell::new(None);
}
static TENSOR_ID_COUNTER: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(1);
fn next_tensor_id() -> usize {
TENSOR_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
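/// A single recorded operation on the tape: the output tensor id, the input
/// tensor ids, and a backward closure mapping the output gradient to one
/// gradient per input (a vector-Jacobian product).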
pub struct TapeOp {
pub output_id: usize,
pub input_ids: Vec<usize>,
pub backward_fn: Box<dyn Fn(&ArrayD<f64>) -> Vec<ArrayD<f64>> + Send + Sync>,
pub name: String,
}
impl std::fmt::Debug for TapeOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TapeOp")
.field("name", &self.name)
.field("output_id", &self.output_id)
.field("input_ids", &self.input_ids)
.finish()
}
}
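/// Tape storage shared between `GradientTape` handles and the thread-local
/// active tape: the recorded ops plus a cache of tensor values keyed by id.
/// Backward closures capture the values they need, so the cache is not
/// consulted during the backward pass itself.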
struct TapeInner {
ops: Vec<TapeOp>,
values: std::collections::HashMap<usize, ArrayD<f64>>,
}
impl TapeInner {
fn new() -> Self {
Self {
ops: Vec::new(),
values: std::collections::HashMap::new(),
}
}
}
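/// A handle to a recorded tape of operations, from which gradients can be
/// computed via [`GradientTape::gradient`]. Usually obtained from [`with_tape`].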
pub struct GradientTape {
inner: Arc<Mutex<TapeInner>>,
}
impl GradientTape {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(TapeInner::new())),
}
}
pub fn operations_count(&self) -> usize {
self.inner
.lock()
.map(|g| g.ops.len())
.unwrap_or(0)
}
pub fn record_op(&self, op: TapeOp) {
if let Ok(mut inner) = self.inner.lock() {
inner.ops.push(op);
}
}
pub fn store_value(&self, tensor_id: usize, value: ArrayD<f64>) {
if let Ok(mut inner) = self.inner.lock() {
inner.values.insert(tensor_id, value);
}
}
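    /// Computes gradients of `output` with respect to each tensor in `inputs`
    /// by reverse-mode accumulation: the output gradient is seeded with ones,
    /// then recorded ops are replayed in reverse order, each contributing its
    /// vector-Jacobian product to its inputs. Inputs that do not participate
    /// in the recorded graph receive all-zero gradients.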
pub fn gradient(
&self,
output: &DynTensor,
inputs: &[&DynTensor],
) -> Result<Vec<ArrayD<f64>>, AutogradError> {
let inner = self
.inner
.lock()
.map_err(|e| AutogradError::OperationError(format!("Tape lock poisoned: {e}")))?;
let output_shape = output.shape();
let output_id = output.id();
let mut grad_map: std::collections::HashMap<usize, ArrayD<f64>> =
std::collections::HashMap::new();
let seed = Array::ones(output_shape.as_slice());
grad_map.insert(output_id, seed);
        // Walk the recorded ops in reverse, propagating the gradient from the
        // output back to each input and accumulating contributions.
        for op in inner.ops.iter().rev() {
            let grad_out = match grad_map.get(&op.output_id) {
                Some(g) => g.clone(),
                None => continue,
            };
            let input_grads = (op.backward_fn)(&grad_out);
            for (input_id, input_grad) in op.input_ids.iter().zip(input_grads.into_iter()) {
                let entry = grad_map
                    .entry(*input_id)
                    .or_insert_with(|| Array::zeros(input_grad.raw_dim()));
                // Accumulate in place instead of cloning the running sum.
                *entry += &input_grad;
            }
        }
let results = inputs
.iter()
.map(|t| {
grad_map
.get(&t.id())
.cloned()
.unwrap_or_else(|| Array::zeros(t.shape().as_slice()))
})
.collect();
Ok(results)
}
}
impl Default for GradientTape {
fn default() -> Self {
Self::new()
}
}
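/// A dynamically-shaped tensor with a unique id and a shared, lock-protected
/// value. Cloning is cheap: clones share the same underlying data and id.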
#[derive(Clone, Debug)]
pub struct DynTensor {
data: Arc<DynTensorData>,
}
#[derive(Debug)]
struct DynTensorData {
id: usize,
value: RwLock<ArrayD<f64>>,
}
impl DynTensor {
pub fn new(value: ArrayD<f64>) -> Self {
let id = next_tensor_id();
ACTIVE_TAPE.with(|cell| {
if let Some(tape_arc) = cell.borrow().as_ref() {
if let Ok(mut tape) = tape_arc.lock() {
tape.values.insert(id, value.clone());
}
}
});
Self {
data: Arc::new(DynTensorData {
id,
value: RwLock::new(value),
}),
}
}
pub fn id(&self) -> usize {
self.data.id
}
    pub fn value(&self) -> ArrayD<f64> {
        // Falls back to an empty 0-d array if the lock was poisoned by a panic.
        self.data
            .value
            .read()
            .map(|v| v.clone())
            .unwrap_or_else(|_| Array::zeros(IxDyn(&[])))
    }
pub fn shape(&self) -> Vec<usize> {
self.data
.value
.read()
.map(|v| v.shape().to_vec())
.unwrap_or_default()
}
pub fn is_scalar(&self) -> bool {
self.data
.value
.read()
.map(|v| v.ndim() == 0 || (v.ndim() == 1 && v.len() == 1))
.unwrap_or(false)
}
pub fn scalar_value(&self) -> Result<f64, AutogradError> {
let v = self.value();
if v.len() == 1 {
v.iter()
.next()
.copied()
.ok_or_else(|| AutogradError::OperationError("Empty tensor".to_string()))
} else {
Err(AutogradError::OperationError(format!(
"scalar_value: tensor has {} elements, expected 1",
v.len()
)))
}
}
}
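/// Records `op` on the thread-local active tape, if one is installed;
/// silently does nothing otherwise (e.g. outside [`with_tape`]).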
fn record_op_on_tape(op: TapeOp) {
ACTIVE_TAPE.with(|cell| {
if let Some(tape_arc) = cell.borrow().as_ref() {
if let Ok(mut tape) = tape_arc.lock() {
tape.ops.push(op);
}
}
});
}
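/// Stores a tensor value on the thread-local active tape, if one is installed.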
fn store_value_on_tape(id: usize, value: &ArrayD<f64>) {
ACTIVE_TAPE.with(|cell| {
if let Some(tape_arc) = cell.borrow().as_ref() {
if let Ok(mut tape) = tape_arc.lock() {
tape.values.insert(id, value.clone());
}
}
});
}
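/// Eager, define-by-run operation builder: each method computes its result
/// immediately and, if a tape is active (see [`with_tape`]), records a
/// backward closure so gradients can be computed later.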
pub struct DynamicGraph;
impl DynamicGraph {
pub fn new() -> Self {
DynamicGraph
}
pub fn tensor(&self, value: ArrayD<f64>) -> DynTensor {
DynTensor::new(value)
}
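    /// Elementwise addition. Backward: the output gradient flows unchanged to
    /// both inputs.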
pub fn add(&self, a: &DynTensor, b: &DynTensor) -> DynTensor {
let out_val = a.value() + &b.value();
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let (a_id, b_id) = (a.id(), b.id());
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id, b_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad.clone(), grad.clone()]
}),
name: "add".to_string(),
});
out
}
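    /// Elementwise subtraction. Backward: `d(a - b)/da = 1`, `d(a - b)/db = -1`.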
pub fn sub(&self, a: &DynTensor, b: &DynTensor) -> DynTensor {
let out_val = a.value() - &b.value();
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let (a_id, b_id) = (a.id(), b.id());
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id, b_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad.clone(), grad.mapv(|v| -v)]
}),
name: "sub".to_string(),
});
out
}
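    /// Elementwise multiplication. Backward (product rule):
    /// `d(a * b)/da = b`, `d(a * b)/db = a`.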
pub fn mul(&self, a: &DynTensor, b: &DynTensor) -> DynTensor {
let (a_val, b_val) = (a.value(), b.value());
let out_val = &a_val * &b_val;
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let (a_id, b_id) = (a.id(), b.id());
let (a_v, b_v) = (a_val.clone(), b_val.clone());
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id, b_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad * &b_v, grad * &a_v]
}),
name: "mul".to_string(),
});
out
}
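    /// Elementwise division. Backward (quotient rule):
    /// `d(a / b)/da = 1 / b`, `d(a / b)/db = -a / b^2`.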
pub fn div(&self, a: &DynTensor, b: &DynTensor) -> DynTensor {
let (a_val, b_val) = (a.value(), b.value());
let out_val = &a_val / &b_val;
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let (a_id, b_id) = (a.id(), b.id());
let (a_v, b_v) = (a_val.clone(), b_val.clone());
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id, b_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
let da = grad / &b_v;
let db = &(grad * &a_v) * &b_v.mapv(|v| -1.0 / (v * v));
vec![da, db]
}),
name: "div".to_string(),
});
out
}
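    /// Adds the scalar `s` to every element. Backward: identity.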
pub fn add_scalar(&self, a: &DynTensor, s: f64) -> DynTensor {
let out_val = a.value().mapv(|v| v + s);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| vec![grad.clone()]),
name: "add_scalar".to_string(),
});
out
}
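    /// Multiplies every element by the scalar `s`. Backward: scales the output
    /// gradient by `s`.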
pub fn mul_scalar(&self, a: &DynTensor, s: f64) -> DynTensor {
let out_val = a.value().mapv(|v| v * s);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad.mapv(|v| v * s)]
}),
name: "mul_scalar".to_string(),
});
out
}
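    /// Elementwise `exp`. Backward reuses the cached output, since
    /// `d(e^x)/dx = e^x`.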
pub fn exp(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let out_val = a_val.mapv(f64::exp);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let out_v = out_val.clone();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad * &out_v]
}),
name: "exp".to_string(),
});
out
}
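    /// Elementwise natural logarithm. Backward: `d(ln x)/dx = 1 / x`.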
pub fn ln(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let out_val = a_val.mapv(f64::ln);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let a_v = a_val.clone();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
vec![grad / &a_v]
}),
name: "ln".to_string(),
});
out
}
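    /// Elementwise `tanh`. Backward reuses the cached output, since
    /// `d(tanh x)/dx = 1 - tanh^2 x`.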
pub fn tanh(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let out_val = a_val.mapv(f64::tanh);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let out_v = out_val.clone();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
let d = out_v.mapv(|v| 1.0 - v * v);
vec![grad * &d]
}),
name: "tanh".to_string(),
});
out
}
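    /// Elementwise ReLU. Backward: masks the output gradient where the input
    /// was non-positive.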
pub fn relu(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let out_val = a_val.mapv(|v| if v > 0.0 { v } else { 0.0 });
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let a_v = a_val.clone();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
let mask = a_v.mapv(|v| if v > 0.0 { 1.0 } else { 0.0 });
vec![grad * &mask]
}),
name: "relu".to_string(),
});
out
}
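    /// Sums all elements into a 0-d tensor. Backward: broadcasts the scalar
    /// output gradient back to the input shape.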
pub fn sum_all(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let s: f64 = a_val.iter().sum();
let out_val = Array::from_elem(IxDyn(&[]), s);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let a_shape = a_val.shape().to_vec();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
let g_scalar = grad.iter().next().copied().unwrap_or(1.0);
vec![Array::from_elem(a_shape.as_slice(), g_scalar)]
}),
name: "sum_all".to_string(),
});
out
}
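    /// Mean of all elements as a 0-d tensor. Backward: broadcasts the scalar
    /// output gradient divided by the element count.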
pub fn mean(&self, a: &DynTensor) -> DynTensor {
let a_val = a.value();
let n = a_val.len() as f64;
let m: f64 = a_val.iter().sum::<f64>() / n;
let out_val = Array::from_elem(IxDyn(&[]), m);
let out = DynTensor::new(out_val.clone());
store_value_on_tape(out.id(), &out_val);
let a_id = a.id();
let a_shape = a_val.shape().to_vec();
record_op_on_tape(TapeOp {
output_id: out.id(),
input_ids: vec![a_id],
backward_fn: Box::new(move |grad: &ArrayD<f64>| {
let g_scalar = grad.iter().next().copied().unwrap_or(1.0);
vec![Array::from_elem(a_shape.as_slice(), g_scalar / n)]
}),
name: "mean".to_string(),
});
out
}
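    /// Define-by-run conditional: only the taken branch executes, so only its
    /// operations are recorded on the tape.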
pub fn if_else<ThenFn, ElseFn>(
&self,
cond: bool,
then_fn: ThenFn,
else_fn: ElseFn,
) -> DynTensor
where
ThenFn: FnOnce() -> DynTensor,
ElseFn: FnOnce() -> DynTensor,
{
if cond {
then_fn()
} else {
else_fn()
}
}
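    /// Define-by-run loop: the body is re-executed (and re-recorded) each
    /// iteration while `cond_fn` holds, up to `max_iters` iterations.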
pub fn while_loop<CondFn, BodyFn>(
&self,
cond_fn: CondFn,
body_fn: BodyFn,
init: DynTensor,
max_iters: usize,
) -> DynTensor
where
CondFn: Fn(&DynTensor) -> bool,
BodyFn: Fn(DynTensor) -> DynTensor,
{
let mut state = init;
let mut iters = 0usize;
while cond_fn(&state) && iters < max_iters {
state = body_fn(state);
iters += 1;
}
state
}
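    /// Threads `init` through `fn_step` over `xs`, returning the final carry
    /// and the per-step outputs stacked along a new leading axis (similar in
    /// spirit to `jax.lax.scan`).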
pub fn scan<StepFn>(
&self,
fn_step: StepFn,
init: DynTensor,
xs: Vec<DynTensor>,
) -> (DynTensor, DynTensor)
where
StepFn: Fn(DynTensor, DynTensor) -> (DynTensor, DynTensor),
{
let mut carry = init;
let mut outputs: Vec<ArrayD<f64>> = Vec::with_capacity(xs.len());
let mut output_tensors: Vec<DynTensor> = Vec::with_capacity(xs.len());
for x in xs {
let (new_carry, out_i) = fn_step(carry, x);
outputs.push(out_i.value());
output_tensors.push(out_i);
carry = new_carry;
}
let stacked = if outputs.is_empty() {
DynTensor::new(Array::zeros(IxDyn(&[0])))
} else {
let single_shape = outputs[0].shape().to_vec();
let n = outputs.len();
let mut stacked_shape = vec![n];
stacked_shape.extend_from_slice(&single_shape);
            let mut data = Vec::with_capacity(n * outputs[0].len());
            for row in &outputs {
                // Iterate in logical order; `as_slice()` returns `None` for
                // non-contiguous arrays and would silently drop their data.
                data.extend(row.iter().copied());
            }
let stacked_val = Array::from_shape_vec(stacked_shape.as_slice(), data)
.unwrap_or_else(|_| Array::zeros(IxDyn(&[n])));
DynTensor::new(stacked_val)
};
(carry, stacked)
}
}
impl Default for DynamicGraph {
fn default() -> Self {
Self::new()
}
}
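/// Runs `f` with a fresh tape installed as the thread-local active tape and
/// returns `f`'s result together with the recorded [`GradientTape`].
///
/// Any previously active tape is restored afterwards, so calls may nest. Note
/// that no panic guard is installed: if `f` panics, the tape is left in place.
///
/// A minimal sketch of the intended usage (illustrative, not a doctest):
///
/// ```ignore
/// let g = DynamicGraph::new();
/// let ((x, loss), tape) = with_tape(|| {
///     let x = g.tensor(Array1::from(vec![1.0, 2.0]).into_dyn());
///     let y = g.mul(&x, &x); // y = x * x, recorded on the tape
///     (x.clone(), g.sum_all(&y))
/// });
/// // d(sum(x^2))/dx = 2x, so the gradient is [2.0, 4.0].
/// let grads = tape.gradient(&loss, &[&x]);
/// ```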
pub fn with_tape<F, T>(f: F) -> (T, GradientTape)
where
F: FnOnce() -> T,
{
    let tape_inner = Arc::new(Mutex::new(TapeInner::new()));
    // Save any previously active tape so nested `with_tape` calls restore it
    // rather than clearing it unconditionally.
    let prev = ACTIVE_TAPE.with(|cell| cell.borrow_mut().replace(tape_inner.clone()));
    let result = f();
    ACTIVE_TAPE.with(|cell| {
        *cell.borrow_mut() = prev;
    });
    let tape = GradientTape { inner: tape_inner };
    (result, tape)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
fn vec_tensor(v: Vec<f64>) -> ArrayD<f64> {
Array1::from(v).into_dyn()
}
fn scalar_tensor(v: f64) -> ArrayD<f64> {
Array::from_elem(IxDyn(&[]), v)
}
#[test]
fn test_add_gradient() {
let g = DynamicGraph::new();
let (out, tape) = with_tape(|| {
let a = g.tensor(vec_tensor(vec![1.0, 2.0]));
let b = g.tensor(vec_tensor(vec![3.0, 4.0]));
let c = g.add(&a, &b);
let a2 = g.tensor(vec_tensor(vec![1.0, 2.0]));
let b2 = g.tensor(vec_tensor(vec![3.0, 4.0]));
let _ = g.add(&a2, &b2);
c
});
let vals: Vec<f64> = out.value().iter().copied().collect();
assert!((vals[0] - 4.0).abs() < 1e-10);
assert!((vals[1] - 6.0).abs() < 1e-10);
assert!(tape.operations_count() >= 1);
}
#[test]
fn test_mul_scalar_gradient() {
let g = DynamicGraph::new();
        let (out, _tape) = with_tape(|| {
let x = g.tensor(vec_tensor(vec![1.0, 2.0, 3.0]));
g.mul_scalar(&x, 5.0)
});
let vals: Vec<f64> = out.value().iter().copied().collect();
assert!((vals[0] - 5.0).abs() < 1e-10);
assert!((vals[2] - 15.0).abs() < 1e-10);
        // `x2` is a leaf created outside the second tape, so its gradient is
        // all zeros; this only checks that the gradient call succeeds.
        let x2 = g.tensor(vec_tensor(vec![1.0, 2.0, 3.0]));
        let (y, tape2) = with_tape(|| {
            let x_inner = g.tensor(vec_tensor(vec![1.0, 2.0, 3.0]));
            let s = g.sum_all(&x_inner);
            g.mul_scalar(&s, 2.0)
        });
        let grads = tape2.gradient(&y, &[&x2]);
        assert!(grads.is_ok());
}
#[test]
fn test_if_else_then_branch() {
let g = DynamicGraph::new();
let result = g.if_else(
true,
|| g.tensor(scalar_tensor(1.0)),
|| g.tensor(scalar_tensor(2.0)),
);
assert!((result.scalar_value().unwrap_or(0.0) - 1.0).abs() < 1e-10);
}
#[test]
fn test_if_else_else_branch() {
let g = DynamicGraph::new();
let result = g.if_else(
false,
|| g.tensor(scalar_tensor(1.0)),
|| g.tensor(scalar_tensor(2.0)),
);
assert!((result.scalar_value().unwrap_or(0.0) - 2.0).abs() < 1e-10);
}
#[test]
fn test_while_loop_basic() {
let g = DynamicGraph::new();
let (result, _tape) = with_tape(|| {
let init = g.tensor(Array1::from(vec![0.0_f64]).into_dyn());
g.while_loop(
|s: &DynTensor| s.scalar_value().map(|v| v < 5.0).unwrap_or(false),
|s: DynTensor| g.add_scalar(&s, 1.0),
init,
1000,
)
});
assert!((result.scalar_value().unwrap_or(0.0) - 5.0).abs() < 1e-10);
}
#[test]
fn test_while_loop_max_iters() {
let g = DynamicGraph::new();
let (result, _tape) = with_tape(|| {
let init = g.tensor(Array1::from(vec![0.0_f64]).into_dyn());
g.while_loop(
|_s: &DynTensor| true,
|s: DynTensor| g.add_scalar(&s, 1.0),
init,
10,
)
});
assert!((result.scalar_value().unwrap_or(0.0) - 10.0).abs() < 1e-10);
}
#[test]
fn test_scan_cumsum() {
let g = DynamicGraph::new();
        let ((carry, ys), _tape) = with_tape(|| {
let init = g.tensor(Array1::from(vec![0.0_f64]).into_dyn());
let xs: Vec<_> = (1..=4_i32)
.map(|i| g.tensor(Array1::from(vec![i as f64]).into_dyn()))
.collect();
g.scan(
|carry, x| {
let new_carry = g.add(&carry, &x);
let out = new_carry.clone();
(new_carry, out)
},
init,
xs,
)
        });
assert!((carry.scalar_value().unwrap_or(0.0) - 10.0).abs() < 1e-10);
assert_eq!(ys.shape()[0], 4);
}
#[test]
fn test_scan_empty_xs() {
let g = DynamicGraph::new();
        let ((carry, ys), _tape) = with_tape(|| {
let init = g.tensor(Array1::from(vec![5.0_f64]).into_dyn());
g.scan(
|carry, x| (g.add(&carry, &x), x),
init,
vec![],
)
        });
assert!((carry.scalar_value().unwrap_or(0.0) - 5.0).abs() < 1e-10);
assert_eq!(ys.shape()[0], 0);
}
#[test]
fn test_exp_ln_round_trip() {
let g = DynamicGraph::new();
let (out, _tape) = with_tape(|| {
let x = g.tensor(Array1::from(vec![1.0_f64, 2.0, 3.0]).into_dyn());
let e = g.exp(&x);
g.ln(&e)
});
let vals: Vec<f64> = out.value().iter().copied().collect();
assert!((vals[0] - 1.0).abs() < 1e-9);
assert!((vals[1] - 2.0).abs() < 1e-9);
assert!((vals[2] - 3.0).abs() < 1e-9);
}
#[test]
fn test_tape_gradient_sum_squared() {
let g = DynamicGraph::new();
        let x_data = vec_tensor(vec![1.0, 2.0, 3.0]);
        // `x_leaf` is created before the tape is active, so its gradient is
        // zeros; this exercises the gradient API rather than checking values.
        let x_leaf = g.tensor(x_data.clone());
let (loss, tape) = with_tape(|| {
let x = g.tensor(x_data.clone());
let sq = g.mul(&x, &x);
g.sum_all(&sq)
});
assert!((loss.scalar_value().unwrap_or(0.0) - 14.0).abs() < 1e-10);
let grads = tape.gradient(&loss, &[&x_leaf]);
assert!(grads.is_ok());
}
#[test]
fn test_relu_zero_at_negative() {
let g = DynamicGraph::new();
let (out, _tape) = with_tape(|| {
let x = g.tensor(Array1::from(vec![-1.0_f64, 0.0, 2.0]).into_dyn());
g.relu(&x)
});
let vals: Vec<f64> = out.value().iter().copied().collect();
assert!((vals[0]).abs() < 1e-10);
assert!((vals[1]).abs() < 1e-10);
assert!((vals[2] - 2.0).abs() < 1e-10);
}
#[test]
fn test_with_tape_returns_ops() {
let g = DynamicGraph::new();
let (_, tape) = with_tape(|| {
let x = g.tensor(vec_tensor(vec![1.0]));
let y = g.add_scalar(&x, 2.0);
let z = g.mul_scalar(&y, 3.0);
g.exp(&z)
});
assert!(tape.operations_count() >= 3);
}
}