#![allow(clippy::type_complexity)]
use std::collections::{BTreeMap, BTreeSet};
use std::{boxed::Box, vec::Vec};
use super::tensorlike::Tensorlike;
use super::{storage_traits::Storage, unique_id, Tensor, UniqueId};
use crate::shapes::Shape;
/// A collection of gradient buffers, keyed by the [UniqueId] of the tensor
/// each gradient belongs to.
#[derive(Clone, Debug)]
pub struct Gradients<E, D: Storage<E>> {
    // One gradient buffer per tensor, looked up by the tensor's id.
    gradient_by_id: BTreeMap<UniqueId, D::Vec>,
    // Ids of leaf tensors whose gradients must be kept across
    // `drop_non_leafs`. `None` disables leaf tracking entirely, meaning
    // no gradient is ever dropped (see `leaky()`).
    leaf_ids: Option<BTreeSet<UniqueId>>,
}
impl<E, D: Storage<E>> Gradients<E, D> {
pub fn leaky() -> Self {
Self {
gradient_by_id: Default::default(),
leaf_ids: None,
}
}
}
impl<E, D: Storage<E>> Gradients<E, D> {
    /// Returns a mutable reference to the gradient buffer for `t`,
    /// allocating it first if it does not exist yet.
    pub fn get_or_alloc_mut<S: Shape>(
        &mut self,
        t: &impl Tensorlike<S, E, D>,
    ) -> Result<&mut D::Vec, D::Err> {
        self.try_alloc_for(t)?;
        Ok(self.get_mut(t))
    }
    /// Allocates a gradient buffer for `t` (via `try_alloc_grad`) if one
    /// is not already present. No-op when a gradient already exists.
    pub fn try_alloc_for<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> Result<(), D::Err> {
        // Entry API: only allocate into a vacant slot so an existing
        // gradient is never overwritten.
        if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id()) {
            e.insert(t.try_alloc_grad()?);
        }
        Ok(())
    }
    /// Marks `ids` as leaves (in addition to any previously recorded ones)
    /// and immediately drops all gradients that do not belong to a leaf.
    pub fn retain_leafs(&mut self, ids: &[UniqueId]) {
        self.leaf_ids
            .get_or_insert_with(Default::default)
            .extend(ids);
        self.drop_non_leafs();
    }
    /// Drops every gradient whose id is not in `leaf_ids`.
    /// Does nothing when leaf tracking is disabled (`leaf_ids` is `None`).
    pub fn drop_non_leafs(&mut self) {
        if let Some(leafs) = &self.leaf_ids {
            self.gradient_by_id.retain(|k, _| leafs.contains(k));
        }
    }
    /// Returns the gradient buffer for `t`, or `None` if it was never
    /// allocated.
    pub(crate) fn get_ref_checked<S: Shape, T>(&self, t: &Tensor<S, E, D, T>) -> Option<&D::Vec> {
        self.gradient_by_id.get(&t.id)
    }
    /// Returns a mutable reference to `t`'s gradient buffer.
    ///
    /// # Panics
    /// If the gradient was not allocated (see [Self::try_alloc_for]).
    pub(crate) fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec {
        self.gradient_by_id.get_mut(&t.id()).unwrap()
    }
    /// Returns a shared reference to `t`'s gradient buffer.
    ///
    /// NOTE(review): this only reads the map, so `&self` would suffice;
    /// the `&mut self` receiver is kept to preserve the existing interface.
    ///
    /// # Panics
    /// If the gradient was not allocated.
    pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
        self.gradient_by_id.get(&t.id()).unwrap()
    }
    /// Clones `t`'s gradient into a brand-new [Tensor] carrying `t`'s
    /// shape, strides, and device, but a fresh id and a default tape.
    ///
    /// # Panics
    /// If the gradient was not allocated.
    pub fn get<S: Shape>(&self, t: &impl Tensorlike<S, E, D>) -> Tensor<S, E, D> {
        let buf = self.gradient_by_id.get(&t.id()).unwrap().clone();
        Tensor {
            id: unique_id(),
            data: std::sync::Arc::new(buf),
            shape: *t.shape(),
            strides: t.strides(),
            device: t.dev().clone(),
            tape: Default::default(),
        }
    }
    /// Borrows `l`'s gradient mutably and `r`'s gradient immutably at the
    /// same time.
    ///
    /// # Panics
    /// If `l` and `r` have the same id, or either gradient is missing.
    pub(crate) fn mut_and_ref<L: Shape, R: Shape>(
        &mut self,
        l: &impl Tensorlike<L, E, D>,
        r: &impl Tensorlike<R, E, D>,
    ) -> (&mut D::Vec, &D::Vec) {
        // Distinct ids guarantee the two lookups hit distinct map entries.
        assert_ne!(l.id(), r.id());
        let l_ptr = self.get_mut(l) as *mut _;
        let r_ptr = self.get_ref(r) as *const _;
        // SAFETY: the ids differ (asserted above), so the pointers target
        // distinct values inside the BTreeMap, and the map is not mutated
        // structurally while these references are alive.
        let l_ref = unsafe { &mut *l_ptr };
        let r_ref = unsafe { &*r_ptr };
        (l_ref, r_ref)
    }
    /// Borrows `l1`'s and `l2`'s gradients mutably and `r`'s immutably,
    /// all at the same time.
    ///
    /// # Panics
    /// If any two of the three tensors share an id, or any gradient is
    /// missing.
    pub(crate) fn muts_and_ref<L1: Shape, L2: Shape, R: Shape>(
        &mut self,
        l1: &impl Tensorlike<L1, E, D>,
        l2: &impl Tensorlike<L2, E, D>,
        r: &impl Tensorlike<R, E, D>,
    ) -> (&mut D::Vec, &mut D::Vec, &D::Vec) {
        // Pairwise-distinct ids guarantee three distinct map entries.
        assert_ne!(l1.id(), l2.id());
        assert_ne!(l1.id(), r.id());
        assert_ne!(l2.id(), r.id());
        let l1_ptr = self.get_mut(l1) as *mut _;
        let l2_ptr = self.get_mut(l2) as *mut _;
        let r_ptr = self.get_ref(r) as *const _;
        // SAFETY: all ids are pairwise distinct (asserted above), so the
        // three pointers target distinct values inside the BTreeMap, and
        // the map is not mutated structurally while they are alive.
        let l1_ref = unsafe { &mut *l1_ptr };
        let l2_ref = unsafe { &mut *l2_ptr };
        let r_ref = unsafe { &*r_ptr };
        (l1_ref, l2_ref, r_ref)
    }
    /// Borrows the gradient of every tensor in `ls` mutably and `r`'s
    /// gradient immutably, all at the same time.
    ///
    /// # Panics
    /// If any two tensors among `ls` and `r` share an id, or any gradient
    /// is missing.
    #[inline]
    pub(crate) fn many_and_ref<L: Shape, R: Shape>(
        &mut self,
        ls: &Vec<impl Tensorlike<L, E, D>>,
        r: &impl Tensorlike<R, E, D>,
    ) -> (Vec<&mut D::Vec>, &D::Vec) {
        // O(n^2) pairwise check that every id is distinct, so the raw
        // pointers below can never alias.
        for i in 0..ls.len() {
            assert_ne!(ls[i].id(), r.id());
            for j in (i + 1)..ls.len() {
                assert_ne!(ls[i].id(), ls[j].id());
            }
        }
        let l_refs: Vec<&mut D::Vec> = ls
            .iter()
            .map(|l| {
                let l_ptr = self.get_mut(l) as *mut D::Vec;
                // SAFETY: ids are pairwise distinct (asserted above), so
                // each pointer targets a distinct value in the BTreeMap,
                // and the map is not mutated structurally while the
                // references are alive.
                unsafe { &mut *l_ptr }
            })
            .collect();
        let r_ptr = self.get_ref(r) as *const _;
        // SAFETY: `r`'s id differs from every id in `ls` (asserted above).
        let r_ref = unsafe { &*r_ptr };
        (l_refs, r_ref)
    }
}
/// A tape that records backward operations so they can later be replayed
/// (in reverse) by `execute` to produce [Gradients].
pub struct OwnedTape<E, D: Storage<E>> {
    // Recorded backward operations, each tagged with a [UniqueId];
    // `execute` sorts by id and then runs them in reverse order.
    pub(crate) operations: Vec<(UniqueId, BackwardOp<E, D, D::Err>)>,
    // The gradients that the recorded operations read and write.
    pub(crate) gradients: Gradients<E, D>,
}
impl<E, D: Storage<E>> Default for OwnedTape<E, D> {
    /// An empty tape whose gradients have leaf tracking disabled.
    fn default() -> Self {
        Self {
            operations: Vec::new(),
            gradients: Gradients::leaky(),
        }
    }
}
impl<E: std::fmt::Debug, D: Storage<E>> std::fmt::Debug for OwnedTape<E, D> {
    // The recorded operations are boxed closures (see [BackwardOp]) and
    // cannot implement Debug, so only their count is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OwnedTape")
            .field("num_operations", &self.operations.len())
            .field("gradients", &self.gradients)
            .finish()
    }
}
impl<E, D: Storage<E>> OwnedTape<E, D> {
    /// Consumes the tape, running every recorded operation in reverse
    /// chronological order (newest first), and returns the resulting
    /// [Gradients]. Stops at the first operation that errors.
    pub(crate) fn execute(mut self) -> Result<Gradients<E, D>, D::Err> {
        // Order by id ascending, then pop from the back: this runs the
        // operations from the highest id down to the lowest.
        self.operations.sort_by_key(|(id, _)| *id);
        while let Some((_, operation)) = self.operations.pop() {
            operation(&mut self.gradients)?;
        }
        Ok(self.gradients)
    }
}
/// A boxed one-shot callback that applies a single backward operation to
/// a set of [Gradients].
type BackwardOp<E, D, Err> = Box<dyn FnOnce(&mut Gradients<E, D>) -> Result<(), Err>>;
/// A zero-sized tape that discards every backward operation recorded on
/// it (see its [Tape] impl).
#[derive(Default, Debug, Clone, Copy)]
pub struct NoneTape;
/// Something backward operations can be recorded onto.
pub trait Tape<E, D: Storage<E>>: Default + Merge<Self> + Merge<NoneTape> {
    /// Whether this tape actually stores recorded operations
    /// (`true` for [OwnedTape], `false` for [NoneTape]).
    const OWNS_TAPE: bool;
    /// Records `operation` for the backward pass. Implementations are
    /// free to discard it (as [NoneTape] does).
    fn add_backward_op<F>(&mut self, operation: F)
    where
        F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>;
}
impl<E, D: Storage<E>> Tape<E, D> for OwnedTape<E, D> {
const OWNS_TAPE: bool = true;
fn add_backward_op<F>(&mut self, operation: F)
where
F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>,
{
self.operations.push((unique_id(), Box::new(operation)));
}
}
impl<E, D: Storage<E>> Tape<E, D> for NoneTape {
    const OWNS_TAPE: bool = false;
    /// Does nothing: [NoneTape] discards every backward operation.
    fn add_backward_op<F>(&mut self, _: F)
    where
        F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>,
    {
    }
}
/// Combines two tapes into one (e.g. when a binary op joins two traced
/// tensors).
pub trait Merge<T: ?Sized> {
    /// Consumes both tapes and returns their combination.
    fn merge(self, other: T) -> Self;
}
impl Merge<NoneTape> for NoneTape {
    // Merging two empty tapes is a no-op.
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}
impl<E, D: Storage<E>> Merge<NoneTape> for OwnedTape<E, D> {
    // A NoneTape carries no operations or gradients, so the owned tape
    // is returned unchanged.
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}
impl<E, D: Storage<E>> Merge<OwnedTape<E, D>> for OwnedTape<E, D> {
fn merge(mut self, mut other: Self) -> Self {
self.gradients
.gradient_by_id
.extend(other.gradients.gradient_by_id);
if let Some(leafs) = other.gradients.leaf_ids {
self.gradients
.leaf_ids
.get_or_insert_with(Default::default)
.extend(leafs);
}
self.operations.append(&mut other.operations);
self
}
}