mod node;
mod var;
mod vardiff;
use ndarray::{ArrayViewMutD, Dimension, Ix, RawArrayViewMut};
use std::{
cell::{Ref, RefCell},
collections::{BTreeMap, HashSet},
hash::{Hash, Hasher},
rc::Rc,
};
pub use var::Var;
pub use vardiff::VarDiff;
pub(crate) use node::*;
pub use node::{
Backward, Constant, Convolve, ConvolveWithGroups, Data, Eval, Forward, Gradient, Input,
InputBackward, Overwrite, PaddingMode, Reflective, Replicative, Zero,
};
/// Counter that hands out increasing `usize` identifiers.
///
/// A counter whose `count` is `0` returns `1` on the first call to
/// [`OperationsCounter::next`].
pub(crate) struct OperationsCounter {
    /// Last identifier that was handed out.
    count: usize,
}

impl OperationsCounter {
    /// Advances the counter by one and returns the freshly stored value.
    pub fn next(&mut self) -> usize {
        let id = self.count + 1;
        self.count = id;
        id
    }
}
/// Global [`OperationsCounter`], initialised with a count of zero.
///
/// Because this is a `static mut`, every read or write must happen inside an `unsafe` block and
/// the language performs no synchronisation on it.
// NOTE(review): presumably the ids produced here are the keys of the history maps — confirm at
// the call sites, which are outside this file.
pub(crate) static mut OPERATIONS_COUNTER: OperationsCounter = OperationsCounter { count: 0 };
/// Record of the forward nodes collected for a variable, plus its `Changeable` entries.
#[derive(Clone)]
pub struct VarHistory {
// Forward nodes keyed by identifier; the `BTreeMap` iterates them in ascending key order.
path: BTreeMap<usize, Rc<dyn Forward>>,
// Cached flat copy of `path`'s values, filled by `prepare_buffer` and emptied when a node is
// appended.
buffer: RefCell<Vec<Rc<dyn Forward>>>,
// Entries registered through `append_changeable`.
changeables: HashSet<Changeable>,
}
impl VarHistory {
    /// Creates an empty history: no forward nodes, an empty cache and no changeables.
    pub(crate) fn new() -> Self {
        Self {
            path: BTreeMap::new(),
            buffer: RefCell::new(Vec::new()),
            changeables: HashSet::new(),
        }
    }

    /// Merges `other` into `self`, consuming `other`.
    ///
    /// The forward paths are united; when both histories hold the same identifier the entry
    /// coming from `other` replaces the one in `self`. The changeables of `other` are united
    /// with those of `self` as well, so that nothing registered on either side is lost.
    pub(crate) fn merge(&mut self, mut other: VarHistory) {
        self.path.append(&mut other.path);
        self.changeables.extend(other.changeables);
        // The path may have grown, so a previously prepared buffer would be stale.
        self.buffer.get_mut().clear();
    }

    /// Registers the forward node `next` under the identifier `id`.
    ///
    /// An existing entry with the same `id` is replaced. The cached buffer is emptied so that
    /// the next call to [`VarHistory::prepare_buffer`] rebuilds it.
    pub(crate) fn append_forward(&mut self, id: usize, next: Rc<dyn Forward>) {
        self.path.insert(id, next);
        // `&mut self` guarantees exclusive access, so no runtime borrow check is needed.
        self.buffer.get_mut().clear();
    }

    /// Registers `next` in the set of changeables; an equal entry already present is kept.
    pub(crate) fn append_changeable(&mut self, next: Changeable) {
        self.changeables.insert(next);
    }

    /// Returns the number of forward nodes stored in the path.
    pub(crate) fn len(&self) -> usize {
        self.path.len()
    }

    /// Returns `true` when the path holds no forward node.
    pub(crate) fn is_empty(&self) -> bool {
        self.path.is_empty()
    }

    /// Fills the cached buffer with the path's nodes, in ascending identifier order.
    ///
    /// Does nothing when the buffer already holds nodes.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is empty while a guard returned by [`VarHistory::buffer`] is still
    /// alive.
    pub(crate) fn prepare_buffer(&self) {
        // Only take the mutable borrow when a rebuild is really needed, so that callers holding
        // a read guard over a populated buffer are not disturbed.
        if self.buffer.borrow().is_empty() {
            *self.buffer.borrow_mut() = self.path.values().cloned().collect();
        }
    }

    /// Returns a read guard over the cached buffer as a slice.
    ///
    /// The slice is empty until [`VarHistory::prepare_buffer`] has been called on a non-empty
    /// path.
    pub(crate) fn buffer(&self) -> Ref<'_, [Rc<dyn Forward>]> {
        Ref::map(self.buffer.borrow(), |nodes| nodes.as_slice())
    }
}
/// Record of the backward nodes collected for a differentiable variable, plus its parameters.
#[derive(Clone)]
pub struct VarDiffHistory {
// Backward nodes keyed by identifier; the `BTreeMap` iterates them in ascending key order.
path: BTreeMap<usize, Rc<dyn Backward>>,
// Cached flat copy of `path`'s values, filled by `prepare_buffer` and emptied when a node is
// appended.
buffer: RefCell<Vec<Rc<dyn Backward>>>,
// Raw parameter handles supplied at construction and extended by `merge`.
parameters: HashSet<RawParam>,
}
impl VarDiffHistory {
    /// Creates a history with an empty backward path and the given set of `parameters`.
    pub(crate) fn new(parameters: HashSet<RawParam>) -> Self {
        Self {
            path: BTreeMap::new(),
            buffer: RefCell::new(Vec::new()),
            parameters,
        }
    }

    /// Merges `other` into `self`, consuming `other`.
    ///
    /// The backward paths are united; when both histories hold the same identifier the entry
    /// coming from `other` replaces the one in `self`. The parameter sets are united too.
    pub(crate) fn merge(&mut self, mut other: VarDiffHistory) {
        self.path.append(&mut other.path);
        self.parameters.extend(other.parameters);
        // The path may have grown, so a previously prepared buffer would be stale.
        self.buffer.get_mut().clear();
    }

    /// Registers the backward node `next` under the identifier `id`.
    ///
    /// An existing entry with the same `id` is replaced. The cached buffer is emptied so that
    /// the next call to [`VarDiffHistory::prepare_buffer`] rebuilds it.
    pub(crate) fn append_backward(&mut self, id: usize, next: Rc<dyn Backward>) {
        self.path.insert(id, next);
        // `&mut self` guarantees exclusive access, so no runtime borrow check is needed.
        self.buffer.get_mut().clear();
    }

    /// Returns the number of backward nodes stored in the path.
    pub(crate) fn len(&self) -> usize {
        self.path.len()
    }

    /// Returns `true` when the path holds no backward node.
    pub(crate) fn is_empty(&self) -> bool {
        self.path.is_empty()
    }

    /// Fills the cached buffer with the path's nodes, in ascending identifier order.
    ///
    /// Does nothing when the buffer already holds nodes.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is empty while a guard returned by [`VarDiffHistory::buffer`] is
    /// still alive.
    pub(crate) fn prepare_buffer(&self) {
        // Only take the mutable borrow when a rebuild is really needed, so that callers holding
        // a read guard over a populated buffer are not disturbed.
        if self.buffer.borrow().is_empty() {
            *self.buffer.borrow_mut() = self.path.values().cloned().collect();
        }
    }

    /// Returns a read guard over the cached buffer as a slice.
    ///
    /// The slice is empty until [`VarDiffHistory::prepare_buffer`] has been called on a
    /// non-empty path.
    pub(crate) fn buffer(&self) -> Ref<'_, [Rc<dyn Backward>]> {
        Ref::map(self.buffer.borrow(), |nodes| nodes.as_slice())
    }
}
/// Raw handle to a parameter: pointers to its data and gradient buffers plus their shape.
///
/// Equality and hashing are derived, so two handles are equal when both pointer addresses and
/// the shape are equal; the pointed-to values are never read for comparison.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct RawParam {
// Pointer to the first element of the data buffer.
data: *mut f32,
// Pointer to the first element of the gradient buffer.
grad: *mut f32,
// Shape used for both views built by `into_param`.
shape: Vec<Ix>,
}
impl RawParam {
pub(crate) fn new(data: *mut f32, grad: *mut f32, shape: Vec<Ix>) -> Self {
Self { data, grad, shape }
}
pub(crate) fn into_param<'a>(self) -> Param<'a> {
let shape = self.shape;
unsafe {
let raw_data = RawArrayViewMut::from_shape_ptr(shape.clone(), self.data);
let raw_grad = RawArrayViewMut::from_shape_ptr(shape, self.grad);
let data = raw_data.deref_into_view_mut();
let grad = raw_grad.deref_into_view_mut();
Param { data, grad }
}
}
}
/// Pair of mutable, dynamically-dimensioned views over a parameter's data and gradient.
#[derive(Debug)]
pub struct Param<'a> {
/// Mutable view over the parameter's data.
pub data: ArrayViewMutD<'a, f32>,
/// Mutable view over the parameter's gradient.
pub grad: ArrayViewMutD<'a, f32>,
}
/// An `Eval` node paired with an identifier.
///
/// The identifier alone decides equality and hashing (see the trait impls on this type).
#[derive(Clone)]
pub(super) struct Changeable {
// Identifier used for equality and hashing.
id: usize,
// The node itself, shared through `Rc`.
node: Rc<dyn Eval>,
}
impl PartialEq for Changeable {
    /// Two changeables are equal when their identifiers match; `node` is not compared.
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_id, rhs_id) = (self.id, rhs.id);
        lhs_id == rhs_id
    }
}
// Equality compares `usize` identifiers only, so it is a full equivalence relation.
impl Eq for Changeable {}
impl Hash for Changeable {
    /// Feeds only the identifier into the hasher, mirroring the `PartialEq` implementation.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        usize::hash(&self.id, hasher);
    }
}
/// Binary operation exposed as `lhs.mm(rhs)`.
///
/// NOTE(review): the name suggests a matrix-matrix product; implementors are not visible in
/// this file — confirm.
pub trait MatMatMul<Rhs> {
/// Type returned by [`MatMatMul::mm`].
type Output;
/// Takes `self` and `other` by value and returns the result of the operation.
fn mm(self, other: Rhs) -> Self::Output;
}
/// Binary operation exposed as `lhs.mm_t(rhs)`.
///
/// NOTE(review): the name suggests a matrix-matrix product with a transposed operand;
/// implementors are not visible in this file — confirm which operand is transposed.
pub trait MatMatMulT<Rhs> {
/// Type returned by [`MatMatMulT::mm_t`].
type Output;
/// Takes `self` and `other` by value and returns the result of the operation.
fn mm_t(self, other: Rhs) -> Self::Output;
}
/// Binary operation exposed as `lhs.mv(rhs)`.
///
/// NOTE(review): the name suggests a matrix-vector product; implementors are not visible in
/// this file — confirm.
pub trait MatVecMul<Rhs> {
/// Type returned by [`MatVecMul::mv`].
type Output;
/// Takes `self` and `other` by value and returns the result of the operation.
fn mv(self, other: Rhs) -> Self::Output;
}
/// Binary operation exposed as `lhs.vm(rhs)`.
///
/// NOTE(review): the name suggests a vector-matrix product; implementors are not visible in
/// this file — confirm.
pub trait VecMatMul<Rhs> {
/// Type returned by [`VecMatMul::vm`].
type Output;
/// Takes `self` and `other` by value and returns the result of the operation.
fn vm(self, other: Rhs) -> Self::Output;
}
/// Binary operation exposed as `lhs.vv(rhs)`.
///
/// NOTE(review): the name suggests a vector-vector product; implementors are not visible in
/// this file — confirm.
pub trait VecVecMul<Rhs> {
/// Type returned by [`VecVecMul::vv`].
type Output;
/// Takes `self` and `other` by value and returns the result of the operation.
fn vv(self, other: Rhs) -> Self::Output;
}
/// Binary operation exposed as `lhs.cat(rhs, axis)`.
///
/// NOTE(review): the name suggests concatenation along an existing axis; implementors are not
/// visible in this file — confirm.
pub trait Cat<Rhs> {
/// Type returned by [`Cat::cat`].
type Output;
/// Takes `self` and `other` by value, together with the index `axis`, and returns the result.
fn cat(self, other: Rhs, axis: usize) -> Self::Output;
}
/// Binary operation exposed as `lhs.stack(rhs, axis)`.
///
/// NOTE(review): the name suggests stacking along a new axis; implementors are not visible in
/// this file — confirm.
pub trait Stack<Rhs> {
/// Type returned by [`Stack::stack`].
type Output;
/// Takes `self` and `other` by value, together with the index `axis`, and returns the result.
fn stack(self, other: Rhs, axis: usize) -> Self::Output;
}
/// Type-erased access to a variable's node and history, for a fixed dimensionality `D`.
pub trait Variable<D: Dimension> {
/// Returns a shared handle to the node as a `Data` trait object of dimension `D`.
fn get_node(&self) -> Rc<dyn Data<Dim = D>>;
/// Returns an owned [`VarHistory`] for the variable.
fn get_past(&self) -> VarHistory;
}
impl<T: Data<Dim = D>, D: Dimension> Variable<D> for Var<T> {
fn get_node(&self) -> Rc<dyn Data<Dim = D>> {
self.node.clone()
}
fn get_past(&self) -> VarHistory {
self.past.clone()
}
}
/// Type-erased access to a differentiable variable's parts, for a fixed dimensionality `D`.
pub trait DifferentiableVariable<D: Dimension> {
/// Returns the associated variable boxed as a [`Variable`] trait object.
fn get_var(&self) -> Box<dyn Variable<D>>;
/// Returns a shared handle to the node as a `GradientOverwrite` trait object.
fn get_node(&self) -> Rc<dyn GradientOverwrite<D>>;
/// Returns an owned [`VarDiffHistory`] for the variable.
fn get_past(&self) -> VarDiffHistory;
}
impl<T: Data<Dim = D>, U: GradientOverwrite<D>, D: Dimension> DifferentiableVariable<D>
for VarDiff<T, U>
{
fn get_var(&self) -> Box<dyn Variable<D>> {
Box::new(self.var.clone())
}
fn get_node(&self) -> Rc<dyn GradientOverwrite<D>> {
self.node.clone()
}
fn get_past(&self) -> VarDiffHistory {
self.past.clone()
}
}
#[cfg(test)]
mod test;