use crate::{
core::{CoreAlgebra, HasDims},
error::{check_equal_dimensions, Error, Result},
store::{
GenericGradientMap1, GenericGradientMapN, GradientId, GradientStore, GraphArenaBehavior, Id,
},
};
use std::{collections::BinaryHeap, sync::Arc};
#[cfg(doc)]
use crate::prelude::*;
/// Computation graph: an arena of recorded nodes plus the algebra used to
/// evaluate values (reachable through `Graph::eval`).
pub struct Graph<C>
where
    C: Config,
{
    // Recorded operations; a node's arena id is what `Value::id` points to.
    nodes: id_arena::Arena<Node<C>, GraphArenaBehavior>,
    // Forward-evaluation algebra, created with `Default` in `Graph::new`.
    eval: C::EvalAlgebra,
}
/// Type-level configuration of a `Graph`: selects the algebras and the
/// gradient storage used by the graph.
pub trait Config {
    /// Algebra used for forward evaluation; the graph owns one instance and
    /// builds it with `Default`.
    type EvalAlgebra: Clone + Default;
    /// Algebra handed to the gradient update callbacks during back-propagation.
    type GradientAlgebra;
    /// Container holding the gradients produced during back-propagation.
    type GradientStore;
}
/// A piece of data, optionally tagged with the gradient id of the graph node
/// that produced it.
///
/// `id` is `None` for constants (see `Value::constant`), which receive no
/// gradient.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Value<D> {
    data: D,
    id: Option<GradientId<D>>,
}
/// One recorded operation in a `Graph`.
pub struct Node<C>
where
    C: Config,
{
    // Ids of the node's inputs; `None` presumably marks a constant input
    // (what `Value::input` yields for constants). TODO confirm at call sites.
    inputs: Vec<Option<Id>>,
    // Callback invoked with this node's id during back-propagation;
    // `None` for variables (leaves created by `make_variable`).
    update_func: Option<GradientUpdateFunc<C>>,
}
/// Type-erased, shareable gradient callback stored in a `Node`.
///
/// It receives the gradient algebra, the gradient store and the id of the node
/// being processed, and may fail with the crate's `Error`.
type GradientUpdateFunc<C> = Arc<
    dyn Fn(&mut <C as Config>::GradientAlgebra, &mut <C as Config>::GradientStore, Id) -> Result<()>
        + Send
        + Sync,
>;
impl<C> Node<C>
where
    C: Config,
{
    /// Forgets the node's inputs and drops its gradient callback, releasing
    /// whatever the callback captured.
    fn clear(&mut self) {
        self.update_func = None;
        self.inputs.clear();
    }
}
impl<C: Config> Default for Graph<C> {
fn default() -> Self {
Self::new()
}
}
impl<C: Config> Graph<C> {
pub fn new() -> Self {
Self {
nodes: id_arena::Arena::new(),
eval: C::EvalAlgebra::default(),
}
}
#[inline]
pub fn eval(&mut self) -> &mut C::EvalAlgebra {
&mut self.eval
}
}
impl<C: Config> Graph<C> {
#[inline]
pub(crate) fn make_variable<D>(&mut self, data: D) -> Value<D> {
let node = Node {
inputs: Vec::new(),
update_func: None,
};
let id = Some(GradientId::new(self.nodes.alloc(node)));
Value { id, data }
}
pub fn make_node<D, G, F, Dims>(
&mut self,
data: D,
inputs: Vec<Option<Id>>,
update_func: F,
) -> Value<D>
where
C::GradientAlgebra: CoreAlgebra<D, Value = G>,
C::GradientStore: GradientStore<GradientId<D>, G>,
D: HasDims<Dims = Dims>,
G: HasDims<Dims = Dims> + Clone + 'static,
Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
F: Fn(&mut C::GradientAlgebra, &mut C::GradientStore, G) -> Result<()>
+ 'static
+ Send
+ Sync,
{
self.make_generic_node::<D, D, G, G, F, Dims>(data, inputs, update_func)
}
pub fn make_generic_node<S, D, GS, GD, F, Dims>(
&mut self,
data: D,
inputs: Vec<Option<Id>>,
update_func: F,
) -> Value<D>
where
C::GradientAlgebra: CoreAlgebra<S, Value = GS>,
C::GradientAlgebra: CoreAlgebra<D, Value = GD>,
C::GradientStore: GradientStore<GradientId<D>, GD>,
C::GradientStore: GradientStore<GradientId<S>, GS>,
D: HasDims<Dims = Dims>,
GD: HasDims<Dims = Dims> + Clone + 'static,
Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
F: Fn(&mut C::GradientAlgebra, &mut C::GradientStore, GD) -> Result<()>
+ 'static
+ Send
+ Sync,
{
if inputs.iter().all(|id| id.is_none()) {
return Value::constant(data);
}
let dims = data.dims();
let update_func: GradientUpdateFunc<C> =
Arc::new(move |algebra, store, index| -> Result<()> {
let value: GD = store
.get(GradientId::<D>::new(index))
.ok_or_else(|| Error::missing_gradient(func_name!()))?
.clone();
check_equal_dimensions(func_name!(), &[&value.dims(), &dims])?;
update_func(algebra, store, value)
});
let node = Node {
inputs,
update_func: Some(update_func),
};
let id = Some(GradientId::new(self.nodes.alloc(node)));
Value { id, data }
}
}
impl<C: Config> Graph<C> {
    /// Back-propagates `gradient` from node `gid` without modifying the graph.
    ///
    /// The store is seeded with `gradient` under `gid`. Node ids are then
    /// popped from a max-heap, so they are visited in decreasing id order.
    /// Assuming every node's inputs were allocated before the node itself, as
    /// `make_node` does when its inputs come from existing `Value`s, this is a
    /// reverse topological order. For each visited node, its update callback
    /// (if any) runs, and then its tracked inputs are queued.
    ///
    /// # Errors
    /// Returns the crate's missing-node error if a queued id is not in the
    /// arena. Also propagates the first error returned by an update callback.
    #[inline]
    fn do_compute_gradients<D, G>(
        &self,
        graph: &mut C::GradientAlgebra,
        gid: GradientId<D>,
        gradient: G,
    ) -> Result<C::GradientStore>
    where
        C::GradientAlgebra: CoreAlgebra<D, Value = G>,
        C::GradientStore: GradientStore<GradientId<D>, G> + Default,
    {
        let mut store = C::GradientStore::default();
        store.insert(gid, gradient);
        let mut heap = BinaryHeap::with_capacity(self.nodes.len());
        heap.push(gid.inner);
        // Ids pop in non-increasing order, so duplicates pop back to back.
        // `guard` holds the last processed id; anything not strictly below it
        // has already been handled.
        let mut guard = gid.inner.next_id();
        while let Some(id) = heap.pop() {
            if id < guard {
                guard = id;
                let node = self
                    .nodes
                    .get(id)
                    .ok_or_else(|| Error::missing_node(func_name!()))?;
                if let Some(update_func) = &node.update_func {
                    update_func(graph, &mut store, id)?;
                }
                // Constant inputs (`None`) receive no gradient and are skipped.
                heap.extend(node.inputs.iter().flatten().copied());
            }
        }
        Ok(store)
    }

    /// Same traversal as `do_compute_gradients`, but consumes the graph.
    ///
    /// Each node is cleared right after it is processed, which drops its
    /// callback and anything the callback captured early. Each id is processed
    /// at most once, so clearing never affects the remaining traversal.
    ///
    /// # Errors
    /// Same as `do_compute_gradients`.
    #[inline]
    fn do_compute_gradients_once<D, G>(
        mut self,
        graph: &mut C::GradientAlgebra,
        gid: GradientId<D>,
        gradient: G,
    ) -> Result<C::GradientStore>
    where
        C::GradientAlgebra: CoreAlgebra<D, Value = G>,
        C::GradientStore: GradientStore<GradientId<D>, G> + Default,
    {
        let mut store = C::GradientStore::default();
        store.insert(gid, gradient);
        let mut heap = BinaryHeap::with_capacity(self.nodes.len());
        heap.push(gid.inner);
        // See `do_compute_gradients` for how `guard` deduplicates ids.
        let mut guard = gid.inner.next_id();
        while let Some(id) = heap.pop() {
            if id < guard {
                guard = id;
                let node = self
                    .nodes
                    .get_mut(id)
                    .ok_or_else(|| Error::missing_node(func_name!()))?;
                if let Some(update_func) = &node.update_func {
                    update_func(graph, &mut store, id)?;
                }
                heap.extend(node.inputs.iter().flatten().copied());
                node.clear();
            }
        }
        Ok(store)
    }
}
/// Configuration in which gradients are computed with the evaluation algebra
/// `E` itself and collected in a `GenericGradientMap1`.
///
/// Gradients here are plain `E` values rather than nodes recorded on a graph
/// (compare `ConfigN`), so presumably only first-order gradients are available.
pub struct Config1<E>(std::marker::PhantomData<E>);

impl<E> Config for Config1<E>
where
    E: Default + Clone,
{
    type GradientStore = GenericGradientMap1;
    type EvalAlgebra = E;
    type GradientAlgebra = E;
}
impl<E: Default + Clone> Graph<Config1<E>> {
    /// Back-propagates `gradient` from the node `id` and returns the gradients
    /// collected in the store. The graph is left intact and can be reused.
    ///
    /// # Errors
    /// Propagates missing-node errors and errors from the update callbacks.
    pub fn evaluate_gradients<T>(
        &self,
        id: GradientId<T>,
        gradient: T,
    ) -> Result<GenericGradientMap1>
    where
        E: CoreAlgebra<T, Value = T>,
        T: 'static,
    {
        // With only `&self`, the callbacks get a private copy of the algebra.
        let mut eval = self.eval.clone();
        self.do_compute_gradients(&mut eval, id, gradient)
    }

    /// Like `evaluate_gradients`, but consumes the graph so that nodes can be
    /// released while the traversal runs.
    ///
    /// # Errors
    /// Propagates missing-node errors and errors from the update callbacks.
    pub fn evaluate_gradients_once<T>(
        mut self,
        id: GradientId<T>,
        gradient: T,
    ) -> Result<GenericGradientMap1>
    where
        E: CoreAlgebra<T, Value = T>,
        T: 'static,
    {
        // The graph is consumed and the traversal only reads `nodes`, never
        // `eval`. Moving the algebra out therefore avoids a needless clone.
        let mut eval = std::mem::take(&mut self.eval);
        self.do_compute_gradients_once(&mut eval, id, gradient)
    }
}
/// Configuration in which back-propagation is itself recorded: the gradient
/// algebra is a `Graph<ConfigN<E>>`, so gradients are `Value`s that can carry
/// node ids. Gradients are collected in a `GenericGradientMapN`.
pub struct ConfigN<E>(std::marker::PhantomData<E>);

impl<E> Config for ConfigN<E>
where
    E: Default + Clone,
{
    type GradientStore = GenericGradientMapN;
    type GradientAlgebra = Graph<ConfigN<E>>;
    type EvalAlgebra = E;
}
impl<E> Graph<ConfigN<E>>
where
    E: Default + Clone,
{
    /// Back-propagates `gradient` from the node `id`, recording the gradient
    /// computation on this same graph, and returns the collected gradients.
    ///
    /// Because `self` serves as the gradient algebra and receives new nodes,
    /// the traversal runs over a snapshot of the nodes taken beforehand.
    ///
    /// # Errors
    /// Propagates missing-node errors and errors from the update callbacks.
    pub fn compute_gradients<D>(
        &mut self,
        id: GradientId<D>,
        gradient: Value<D>,
    ) -> Result<GenericGradientMapN>
    where
        Self: CoreAlgebra<D, Value = Value<D>>,
        D: 'static,
    {
        let snapshot = Self::clone(self);
        snapshot.do_compute_gradients_once(self, id, gradient)
    }
}
impl<D> Value<D> {
pub fn constant(data: D) -> Self {
Value { data, id: None }
}
pub fn data(&self) -> &D {
&self.data
}
pub fn id(&self) -> Option<GradientId<D>> {
self.id
}
pub fn input(&self) -> Option<Id> {
self.id.map(|id| id.inner)
}
}
impl<C: Config> Clone for Node<C> {
fn clone(&self) -> Self {
Self {
inputs: self.inputs.clone(),
update_func: self.update_func.clone(),
}
}
}
impl<C: Config> Clone for Graph<C> {
fn clone(&self) -> Self {
Self {
nodes: self.nodes.clone(),
eval: self.eval.clone(),
}
}
}
// Shows only the edges; `update_func` is an opaque closure.
impl<C: Config> std::fmt::Debug for Node<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        let mut builder = f.debug_struct("Node");
        builder.field("inputs", &self.inputs);
        builder.finish()
    }
}
// Prints one `id <- inputs; ` entry per node, in arena iteration order.
impl<C: Config> std::fmt::Debug for Graph<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        self.nodes
            .iter()
            .try_for_each(|(id, node)| write!(f, "{:?} <- {:?}; ", id, node.inputs))
    }
}