use std::collections::HashMap;
use std::{boxed::Box, vec::Vec};
use crate::arrays::{HasArrayData, HasArrayType};
use crate::devices::{AllocateZeros, HasDevice};
use crate::unique_id::{HasUniqueId, UniqueId};
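/// Records the operations needed to compute gradients, in the order the
/// forward pass produced them. [GradientTape::execute] then replays them in
/// reverse, which is what makes this a reverse-mode autodiff tape.
///
/// A minimal usage sketch, mirroring the unit test at the bottom of this
/// file (it assumes some value `t` implementing
/// `HasUniqueId + HasArrayType + HasDevice` whose array supports `iter_mut`):
/// ```ignore
/// let mut tape = GradientTape::default();
/// tape.add_backward_op(move |grads| {
///     // accumulate into the gradient slot associated with `t`'s id
///     for x in grads.mut_gradient(&t).iter_mut() {
///         *x += 1.0;
///     }
/// });
/// let gradients = tape.execute();
/// ```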
#[derive(Default)]
#[allow(clippy::type_complexity)]
pub struct GradientTape {
operations: Vec<Box<dyn FnOnce(&mut Gradients)>>,
}
impl std::fmt::Debug for GradientTape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GradientTape")
.field("num_operations", &self.operations.len())
.finish()
}
}
impl GradientTape {
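    /// Pushes a backward operation onto the tape. The closure receives the
    /// [Gradients] accumulator when the tape is executed.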
pub(crate) fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, operation: F) {
self.operations.push(Box::new(operation));
}
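    /// Consumes the tape, running every recorded operation in reverse order
    /// of insertion, and returns the accumulated [Gradients].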
pub fn execute(mut self) -> Gradients {
let mut gradients: Gradients = Default::default();
for operation in self.operations.drain(..).rev() {
(operation)(&mut gradients);
}
gradients
}
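    /// Moves all operations from `other` onto the end of this tape, leaving
    /// `other` empty.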
pub fn append(&mut self, other: &mut Self) {
self.operations.append(&mut other.operations);
}
}
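/// A [Tape] that owns a [GradientTape] and records every operation passed to
/// it.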
#[derive(Default, Debug)]
pub struct OwnedTape(pub(crate) Box<GradientTape>);
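/// A zero-sized [Tape] that silently discards every operation passed to it,
/// for forward passes that do not need gradients.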
#[derive(Default, Debug, Clone, Copy)]
pub struct NoneTape;
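/// Something that can hold backward operations: either it actually records
/// them ([OwnedTape]) or it drops them ([NoneTape]).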
pub trait Tape: Merge<Self> + Merge<NoneTape> + Default {
const OWNS_TAPE: bool;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, operation: F);
}
impl Tape for OwnedTape {
const OWNS_TAPE: bool = true;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, operation: F) {
self.0.add_backward_op(operation)
}
}
impl Tape for NoneTape {
const OWNS_TAPE: bool = false;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients)>(&mut self, _operation: F) {}
}
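/// Combines two tapes into one. Merging with [NoneTape] returns `self`
/// unchanged; merging two [OwnedTape]s appends the operations of `other`
/// onto `self`.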
pub trait Merge<T: ?Sized> {
fn merge(self, other: T) -> Self;
}
impl Merge<NoneTape> for NoneTape {
fn merge(self, _: NoneTape) -> Self {
self
}
}
impl Merge<NoneTape> for OwnedTape {
fn merge(self, _: NoneTape) -> Self {
self
}
}
impl Merge<OwnedTape> for OwnedTape {
fn merge(mut self, mut other: Self) -> Self {
self.0.append(other.0.as_mut());
self
}
}
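/// Maps each tensor's [UniqueId] to its gradient array. The arrays are
/// type-erased behind `Box<dyn Any>` so gradients of different shapes can
/// live in the same container.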
#[derive(Debug, Default)]
pub struct Gradients {
gradient_by_id: HashMap<UniqueId, Box<dyn std::any::Any>>,
}
impl Gradients {
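    /// Borrows a mutable gradient for `l` and a shared gradient for `r` at
    /// the same time. `l` and `r` must have distinct ids; the `assert_ne!`
    /// below enforces this, which keeps the raw-pointer juggling sound.
    ///
    /// A usage sketch (assuming `l` and `r` are two different tensors):
    /// ```ignore
    /// let (l_grad, r_grad) = gradients.mut_and_ref(&l, &r);
    /// ```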
pub fn mut_and_ref<L, R>(&mut self, l: &L, r: &R) -> (&mut L::Array, &R::Array)
where
L: HasUniqueId + HasArrayType + HasDevice,
R: HasUniqueId + HasArrayType,
{
assert_ne!(l.id(), r.id());
let l_ptr = self.mut_gradient(l) as *mut L::Array;
let r_ptr = self.ref_gradient(r) as *const R::Array;
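        // SAFETY: `l` and `r` have distinct ids (asserted above), so the two
        // pointers target different boxed allocations and never alias.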
let l_ref = unsafe { &mut *l_ptr };
let r_ref = unsafe { &*r_ptr };
(l_ref, r_ref)
}
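    /// Borrows three mutable gradients and one shared gradient at once. All
    /// four ids must be distinct; the `assert_ne!` checks below enforce this
    /// before any raw pointer is dereferenced.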
pub fn muts_and_ref<L1, L2, L3, R>(
&mut self,
l1: &L1,
l2: &L2,
l3: &L3,
r: &R,
) -> (&mut L1::Array, &mut L2::Array, &mut L3::Array, &R::Array)
where
L1: HasUniqueId + HasArrayType + HasDevice,
L2: HasUniqueId + HasArrayType + HasDevice,
L3: HasUniqueId + HasArrayType + HasDevice,
R: HasUniqueId + HasArrayType,
    {
        assert_ne!(l1.id(), l2.id());
        assert_ne!(l1.id(), l3.id());
        assert_ne!(l1.id(), r.id());
        assert_ne!(l2.id(), l3.id());
        assert_ne!(l2.id(), r.id());
        assert_ne!(l3.id(), r.id());
let l1_ptr = self.mut_gradient(l1) as *mut L1::Array;
let l2_ptr = self.mut_gradient(l2) as *mut L2::Array;
let l3_ptr = self.mut_gradient(l3) as *mut L3::Array;
let r_ptr = self.ref_gradient(r) as *const R::Array;
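        // SAFETY: the assert_ne! checks above guarantee all four ids are
        // distinct, so the raw pointers target different boxed allocations
        // and never alias.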
let l1_ref = unsafe { &mut *l1_ptr };
let l2_ref = unsafe { &mut *l2_ptr };
let l3_ref = unsafe { &mut *l3_ptr };
let r_ref = unsafe { &*r_ptr };
(l1_ref, l2_ref, l3_ref, r_ref)
}
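    /// Removes and returns the gradient recorded for `t`, if any.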
pub fn remove<T: HasUniqueId + HasArrayType>(&mut self, t: &T) -> Option<Box<T::Array>> {
self.gradient_by_id
.remove_entry(t.id())
.map(|e| e.1.downcast().unwrap())
}
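    /// Returns a mutable reference to the gradient for `t`, allocating a
    /// zero-filled array on `T::Device` if no gradient has been recorded yet.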
pub fn mut_gradient<T: HasUniqueId + HasArrayType + HasDevice>(
&mut self,
t: &T,
) -> &mut T::Array {
self.gradient_by_id
.entry(*t.id())
.or_insert_with(|| T::Device::zeros::<T::Array>())
.as_mut()
.downcast_mut()
.unwrap()
}
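    /// Returns a shared reference to the gradient for `t`.
    ///
    /// # Panics
    /// If no gradient has been recorded for `t`.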
pub fn ref_gradient<T: HasUniqueId + HasArrayType>(&self, t: &T) -> &T::Array {
self.gradient_by_id
.get(t.id())
.unwrap()
.as_ref()
.downcast_ref()
.unwrap()
}
}
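/// Something that can supply (and typically consume) the gradient associated
/// with a parameter `p`, returning `None` if no gradient exists for it.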
pub trait GradientProvider {
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice + HasArrayData;
}
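/// Something whose parameters can be updated from gradients pulled out of a
/// [GradientProvider]. Parameters for which no gradient was available are
/// recorded in [UnusedTensors].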
pub trait CanUpdateWithGradients {
fn update<G: GradientProvider>(&mut self, grads: &mut G, unused: &mut UnusedTensors);
}
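/// Collects the [UniqueId]s of tensors that received no gradient during an
/// update.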
#[derive(Debug, Default)]
pub struct UnusedTensors {
pub ids: Vec<UniqueId>,
}
impl UnusedTensors {
pub fn add<T: HasUniqueId>(&mut self, t: &T) {
self.ids.push(*t.id());
}
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}
pub fn len(&self) -> usize {
self.ids.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::devices::Cpu;
use crate::unique_id::unique_id;
struct Tensor {
id: UniqueId,
}
impl HasUniqueId for Tensor {
fn id(&self) -> &UniqueId {
&self.id
}
}
impl HasArrayType for Tensor {
type Array = [f32; 5];
type Dtype = f32;
}
impl HasDevice for Tensor {
type Device = Cpu;
}
#[test]
fn test_backward() {
let id = unique_id();
let t1: Tensor = Tensor { id };
let _t1: Tensor = Tensor { id };
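        // `t1` and `_t1` share the same id, so the gradient written through
        // `_t1` inside the backward op is readable through `t1` afterwards.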
let mut tape = GradientTape::default();
tape.add_backward_op(move |g| {
let t_grad = g.mut_gradient(&_t1);
for x in t_grad.iter_mut() {
*x += 1.0;
}
});
let g = tape.execute();
assert_eq!(g.ref_gradient(&t1), &[1.0; 5]);
}
}