use arrayfire::{constant, Array};
use std::cell::{Ref, RefCell, RefMut};
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Id source for `Node`.
///
/// `Node::new` takes the current value with `fetch_add(1)` and `Drop for Node` gives one back
/// with `fetch_sub(1)`, so the value tracks "nodes created minus nodes dropped".
/// NOTE(review): this is process-global, shared by every thread (including parallel tests).
static COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Identifier assigned to each `Node` from the module-level counter.
// The type name repeats the module name (presumably this file is the `node` module —
// confirm), which `clippy::module_name_repetitions` flags; the lint is silenced on purpose.
#[allow(clippy::module_name_repetitions)]
pub type NodeId = usize;
/// Records how a `Node` was produced, so the reverse pass knows where to send its gradient.
enum Origin {
/// Leaf node built directly from data; `Node::reverse` does nothing for it.
Declaration,
/// Output of a one-operand operation.
Unary(UnaryOp),
/// Output of a two-operand operation; one of the operands may be a plain constant array.
Binary(BinaryOp),
}
/// A value in the computation graph together with its accumulated gradient.
///
/// `data` and `grad` live in `RefCell`s so they can be read and updated through the shared
/// `Rc<Node>` handles that operations keep to their operands.
pub struct Node {
/// The array this node holds.
data: RefCell<Array<f32>>,
/// Gradient accumulated by the reverse pass; created as zeros with the same dims as `data`.
grad: RefCell<Array<f32>>,
/// How this node was produced; `Node::reverse` dispatches on it.
origin: Origin,
/// Counter-based id (the counter is decremented again when a node is dropped).
id: NodeId,
}
impl Node {
fn new(data: Array<f32>, origin: Origin) -> Self {
let dims = data.dims();
Self {
data: RefCell::new(data),
grad: RefCell::new(constant(0.0, dims)),
origin,
id: COUNTER.fetch_add(1, Ordering::Relaxed),
}
}
pub(crate) fn declaration(data: Array<f32>) -> Self {
Self::new(data, Origin::Declaration)
}
pub(crate) fn unary(data: Array<f32>, param: Rc<Self>, reverse: UnaryReverseFn) -> Self {
Self::new(data, Origin::Unary(UnaryOp { param, reverse }))
}
pub(crate) fn binary_varvar(
data: Array<f32>,
params: (Rc<Self>, Rc<Self>),
reverse: BinaryReverseFn,
) -> Self {
Self::new(
data,
Origin::Binary(BinaryOp {
params: BinaryParams::VarVar(params.0, params.1),
reverse,
}),
)
}
pub(crate) fn binary_varconst(
data: Array<f32>,
params: (Rc<Self>, Array<f32>),
reverse: BinaryReverseFn,
) -> Self {
Self::new(
data,
Origin::Binary(BinaryOp {
params: BinaryParams::VarConst(params.0, params.1),
reverse,
}),
)
}
pub(crate) fn binary_constvar(
data: Array<f32>,
params: (Array<f32>, Rc<Self>),
reverse: BinaryReverseFn,
) -> Self {
Self::new(
data,
Origin::Binary(BinaryOp {
params: BinaryParams::ConstVar(params.0, params.1),
reverse,
}),
)
}
pub(crate) fn data(&self) -> Ref<Array<f32>> {
self.data.borrow()
}
pub(crate) fn data_mut(&self) -> RefMut<Array<f32>> {
self.data.borrow_mut()
}
pub(crate) fn grad(&self) -> Ref<Array<f32>> {
self.grad.borrow()
}
pub(crate) fn grad_mut(&self) -> RefMut<Array<f32>> {
self.grad.borrow_mut()
}
pub(crate) fn reverse(&self) {
match self.origin {
Origin::Unary(ref op) => {
op.reverse(&self.grad());
}
Origin::Binary(ref op) => {
op.reverse(&self.grad());
}
Origin::Declaration => {}
}
}
pub(crate) fn ones_grad(&self) {
let dims = self.grad().dims();
*self.grad_mut() = constant(1.0, dims);
}
pub(crate) fn zero_grad(&self) {
let dims = self.grad().dims();
*self.grad_mut() = constant(0.0, dims);
}
pub(crate) const fn id(&self) -> NodeId {
self.id
}
pub(crate) const fn is_declaration(&self) -> bool {
matches!(self.origin, Origin::Declaration)
}
}
impl Drop for Node {
    /// Decrements `COUNTER`, so the next `Node::new` hands out an id one lower than it
    /// otherwise would.
    ///
    /// NOTE(review): the released value is this node's id only if this node was the most
    /// recently created live node. With out-of-order drops, a later node can receive an id
    /// that a live node still holds.
    fn drop(&mut self) {
        let _previous = COUNTER.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Operands of a binary operation: graph nodes (`Var`) or plain arrays (`Const`).
///
/// Only `Var` operands receive a gradient contribution in `BinaryOp::reverse`.
enum BinaryParams {
/// Both operands are nodes.
VarVar(Rc<Node>, Rc<Node>),
/// Left operand is a node, right operand is a constant array.
VarConst(Rc<Node>, Array<f32>),
/// Left operand is a constant array, right operand is a node.
ConstVar(Array<f32>, Rc<Node>),
}
/// Backward rule of a unary operation: given the upstream gradient `df` and the operand's
/// value `x`, returns the partial to accumulate into the operand's gradient.
pub type UnaryReverseFn = fn(df: &Array<f32>, x: &Array<f32>) -> Array<f32>;
/// Backward rule of a binary operation: given the upstream gradient `df` and both operand
/// values, returns `(partial for lhs, partial for rhs)`. A partial belonging to a constant
/// operand is computed but discarded by the caller.
pub type BinaryReverseFn =
    fn(df: &Array<f32>, lhs: &Array<f32>, rhs: &Array<f32>) -> (Array<f32>, Array<f32>);
/// A recorded unary operation: its single node operand and its backward rule.
struct UnaryOp {
/// The node the operation was applied to; it receives the gradient contribution.
param: Rc<Node>,
/// Maps `(upstream gradient, operand value)` to the operand's partial.
reverse: UnaryReverseFn,
}
impl UnaryOp {
    /// Adds this operation's contribution to its operand's gradient.
    ///
    /// `df` is the gradient of the operation's output. The stored rule turns it and the
    /// operand's value into a partial, which is accumulated (`+=`), not assigned.
    fn reverse(&self, df: &Array<f32>) {
        let operand = &self.param;
        let contribution = {
            // Hold the value borrow only while the rule runs, so it is released before
            // the gradient is mutably borrowed.
            let value = operand.data();
            (self.reverse)(df, &value)
        };
        *operand.grad_mut() += contribution;
    }
}
/// A recorded binary operation: its operands and its backward rule.
struct BinaryOp {
/// The two operands; only node (`Var`) operands receive gradients.
params: BinaryParams,
/// Maps `(upstream gradient, lhs value, rhs value)` to `(lhs partial, rhs partial)`.
reverse: BinaryReverseFn,
}
impl BinaryOp {
    /// Adds this operation's contributions to the gradients of its node operands.
    ///
    /// `df` is the gradient of the operation's output. The stored rule is called with `df`
    /// and both operand values and returns `(lhs partial, rhs partial)`. Each partial that
    /// belongs to a node is accumulated (`+=`) into that node's gradient. A partial that
    /// belongs to a constant operand is discarded.
    fn reverse(&self, df: &Array<f32>) {
        // Bindings are named `lhs`/`rhs` (not `param_*`): a `&param…` borrow contains the
        // HTML entity prefix `&para`, which a previous copy of this code had mangled into
        // an invalid `¶` character.
        match &self.params {
            BinaryParams::VarVar(lhs, rhs) => {
                // Both `data()` borrows are temporaries of this `let`, so they end before the
                // gradients are mutably borrowed below. If `lhs` and `rhs` are the same node,
                // the two shared borrows coexist and the later `grad_mut` calls run one after
                // the other, so no `RefCell` conflict arises.
                let (partial_lhs, partial_rhs) = (self.reverse)(df, &lhs.data(), &rhs.data());
                *lhs.grad_mut() += partial_lhs;
                *rhs.grad_mut() += partial_rhs;
            }
            BinaryParams::VarConst(lhs, constant_rhs) => {
                let (partial_lhs, _) = (self.reverse)(df, &lhs.data(), constant_rhs);
                *lhs.grad_mut() += partial_lhs;
            }
            BinaryParams::ConstVar(constant_lhs, rhs) => {
                let (_, partial_rhs) = (self.reverse)(df, constant_lhs, &rhs.data());
                *rhs.grad_mut() += partial_rhs;
            }
        }
    }
}
// Unit tests for `Node` construction, id assignment and gradient resets.
//
// NOTE(review): `COUNTER` is a process-wide static, and `cargo test` runs tests on multiple
// threads by default. The absolute-id assertions in `new_node` and
// `node_sequentially_reused_unique_ids` can therefore see ids consumed or released by other
// tests running at the same time (both expect the first node to get id 0). They are only
// reliable with `--test-threads=1` or when no other test creates nodes concurrently.
#[cfg(test)]
pub(crate) mod tests {
use super::{Node, Origin};
use crate::tests::equal_arrays;
/// A fresh node keeps its data, starts with a zero gradient of the same shape, and is a
/// `Declaration`.
#[test]
fn new_node() {
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert!(equal_arrays(
node.data().clone(),
arrayfire::constant!(2.0; 1,2,3,4)
));
assert!(equal_arrays(
node.grad().clone(),
arrayfire::constant!(0.0; 1,2,3,4)
));
assert!(matches!(node.origin, Origin::Declaration));
// Assumes no other node exists in the process at this point (see module note).
assert_eq!(node.id(), 0);
}
/// Ids increase while nodes are alive, and the id of the most recently created node is
/// handed out again after that node is dropped.
#[test]
fn node_sequentially_reused_unique_ids() {
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert_eq!(node.id(), 0);
// Shadowing with `let` does not drop the previous node; ids 0 and 1 both stay alive.
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert_eq!(node.id(), 1);
{
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert_eq!(node.id(), 2);
}
// The scoped node was dropped above, decrementing the counter, so id 2 is reused.
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert_eq!(node.id(), 2);
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
assert_eq!(node.id(), 3);
}
/// `ones_grad` replaces the gradient with ones of the same shape.
#[test]
fn ones_grad() {
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
node.ones_grad();
assert!(equal_arrays(
node.grad().clone(),
arrayfire::constant!(1.0; 1,2,3,4)
));
}
/// `zero_grad` replaces the gradient with zeros of the same shape.
#[test]
fn zero_grad() {
let node = Node::new(arrayfire::constant!(2.0; 1,2,3,4), Origin::Declaration);
node.zero_grad();
assert!(equal_arrays(
node.grad().clone(),
arrayfire::constant!(0.0; 1,2,3,4)
));
}
}