use crate::graph::{node::Node, tape::Tape};
use crate::tensor::{constant::Constant, Tensor};
use arrayfire::Array;
use std::rc::Rc;
/// A differentiable tensor: a shared graph [`Node`] together with the [`Tape`]
/// it was recorded on.
///
/// NOTE(review): judging from the tests, `Variable::<B, L, R, C>` is backed by an
/// array with dims `(R, C, L, B)` (e.g. `<3, 4, 2, 1>` holds a `2,1,4,3` array) —
/// confirm against the `Tensor` trait before relying on this.
pub struct Variable<const B: u64, const L: u64, const R: u64, const C: u64> {
// Tape this variable's node was pushed onto; walked by `backward` and `reset`.
tape: Tape,
// Shared handle to the graph node holding this variable's data and gradient.
node: Rc<Node>,
}
impl<const B: u64, const L: u64, const R: u64, const C: u64> Variable<B, L, R, C> {
    /// Creates a variable for `node` and records it on `tape`.
    ///
    /// `node` is moved into an [`Rc`]. One handle is pushed onto `tape` and the
    /// other is kept by the returned variable, so both refer to the same node.
    pub(crate) fn new(mut tape: Tape, node: Node) -> Self {
        let node = Rc::new(node);
        tape.push(Rc::clone(&node));
        Self { tape, node }
    }

    /// Returns the tape this variable's node was pushed onto.
    pub(crate) const fn tape(&self) -> &Tape {
        &self.tape
    }

    /// Back-propagates from this variable through its tape.
    ///
    /// Seeds this variable's own gradient with `Node::ones_grad`. It then calls
    /// `Node::reverse` on every node yielded by `Tape::nodes`, iterating in
    /// reverse (presumably most recently pushed first).
    pub fn backward(&self) {
        self.node.ones_grad();
        // NOTE(review): this walks the whole tape, including any nodes pushed after
        // this variable; confirm that is intended when `self` is an intermediate value.
        for node in self.tape.nodes().rev() {
            node.reverse();
        }
    }

    /// Calls `Node::zero_grad` on every node of this variable's tape.
    pub fn reset(&self) {
        for node in self.tape.nodes().rev() {
            node.zero_grad();
        }
    }

    /// Returns a copy of this variable's current gradient, wrapped in a new variable.
    ///
    /// The result is declared on a fresh `Tape::default()`, so it is not connected
    /// to this variable's tape.
    #[must_use]
    pub fn grad(&self) -> Self {
        Self::new(Tape::default(), Node::declaration(self.node.grad().clone()))
    }

    /// Consumes the variable and returns a [`Constant`] holding a copy of its data.
    #[must_use]
    pub fn freeze(self) -> Constant<B, L, R, C> {
        Constant::new(self.data())
    }
}
impl<const B: u64, const L: u64, const R: u64, const C: u64> From<&Variable<B, L, R, C>>
    for Rc<Node>
{
    /// Returns another shared handle to the graph node behind `variable`.
    ///
    /// Only the reference count is bumped; the node itself is not copied.
    #[inline]
    fn from(variable: &Variable<B, L, R, C>) -> Self {
        Rc::clone(&variable.node)
    }
}
impl<const B: u64, const L: u64, const R: u64, const C: u64> Tensor<B, L, R, C>
    for Variable<B, L, R, C>
{
    /// Returns an owned copy of the array stored in this variable's node.
    fn data(&self) -> Array<f32> {
        let stored = self.node.data();
        stored.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::Variable;
    use crate::graph::{node::Node, tape::Tape};
    use crate::tensor::Tensor;
    use crate::tests::equal_arrays;

    /// Builds a fresh `Variable::<3, 4, 2, 1>` on an empty tape, filled with `5.0`.
    fn fives() -> Variable<3, 4, 2, 1> {
        Variable::new(
            Tape::default(),
            Node::declaration(arrayfire::constant!(5.0; 2,1,4,3)),
        )
    }

    #[test]
    fn new_variable() {
        let variable = fives();
        let actual = variable.data();
        let expected = arrayfire::constant!(5.0; 2,1,4,3);
        assert!(equal_arrays(actual, expected));
        // The first node pushed onto a fresh tape is expected to get id 0.
        assert_eq!(variable.node.id(), 0);
    }

    #[test]
    fn variable_freeze() {
        let frozen = fives().freeze();
        let expected = arrayfire::constant!(5.0; 2,1,4,3);
        assert!(equal_arrays((&frozen).into(), expected));
    }

    #[test]
    fn variable_backward() {
        let variable = fives();
        variable.backward();
        // A lone declaration seeded by `backward` should carry an all-ones gradient.
        let actual = variable.grad().data();
        let expected = arrayfire::constant!(1.0; 2,1,4,3);
        assert!(equal_arrays(actual, expected));
    }
}