#![warn(missing_docs)]
#[allow(unused_imports)]
#[macro_use]
extern crate aegir_derive;
#[doc(hidden)]
pub use self::aegir_derive::*;
#[allow(unused_imports)]
#[macro_use]
extern crate aegir_compile;
#[doc(hidden)]
pub use self::aegir_compile::*;
#[allow(unused_imports)]
use paste::paste;
#[macro_use]
extern crate itertools;
pub mod errors;
use errors::*;
/// A symbol naming a value in an expression.
///
/// Implementors must be `Copy`, comparable for equality, and printable via
/// both `Debug` and `Display`.
pub trait Identifier:
    Copy + PartialEq + Eq + std::fmt::Debug + std::fmt::Display
{
    /// Wraps this identifier in a [`meta::Variable`].
    fn into_var(self) -> meta::Variable<Self> {
        meta::Variable(self)
    }
}
/// Predefined identifiers declared through the `ids!` macro.
///
/// Covers the Latin letters `A`-`Z` (with symbols `a`-`z`) and the Greek
/// letters `Alpha`-`Omega` (with symbols given as lowercase Unicode escapes).
pub mod ids {
    // Entries are written `Name::symbol`.
    // NOTE(review): the expansion of `ids!` is defined outside this file;
    // presumably it declares one identifier type per entry - confirm.
    ids!(
        A::a, B::b, C::c, D::d, E::e, F::f, G::g, H::h, I::i,
        J::j, K::k, L::l, M::m, N::n, O::o, P::p, Q::q, R::r,
        S::s, T::t, U::u, V::v, W::w, X::x, Y::y, Z::z,
        // Lowercase Greek letters, U+03B1 (alpha) through U+03C9 (omega).
        Alpha::"\u{03B1}", Beta::"\u{03B2}", Gamma::"\u{03B3}", Delta::"\u{03B4}",
        Epsilon::"\u{03B5}", Zeta::"\u{03B6}", Eta::"\u{03B7}", Theta::"\u{03B8}",
        Iota::"\u{03B9}", Kappa::"\u{03BA}", Lambda::"\u{03BB}", Mu::"\u{03BC}",
        Nu::"\u{03BD}", Xi::"\u{03BE}", Omicron::"\u{03BF}", Pi::"\u{03C0}",
        // U+03C2 is the word-final sigma variant and is skipped on purpose:
        // sigma is U+03C3, tau is U+03C4 and upsilon is U+03C5.
        Rho::"\u{03C1}", Sigma::"\u{03C3}", Tau::"\u{03C4}", Upsilon::"\u{03C5}",
        Phi::"\u{03C6}", Chi::"\u{03C7}", Psi::"\u{03C8}", Omega::"\u{03C9}"
    );
}
/// Marker trait for types that can serve as a context when reading buffers
/// or evaluating functions.
///
/// It declares no methods of its own; implementors only have to satisfy the
/// `AsRef<Self>` supertrait.
pub trait Context: AsRef<Self> {}
/// A [`Context`] from which the buffer bound to identifier type `I` can be read.
pub trait Read<I: Identifier>: Context {
    /// The buffer type stored under `I`.
    type Buffer: buffers::Buffer;

    /// Returns the buffer associated with `ident`, or `None` if there is none.
    fn read(&self, ident: I) -> Option<Self::Buffer>;

    /// Like [`Read::read`], but wraps the buffer in `buffers::Spec::Raw`.
    ///
    /// Returns `None` exactly when `read` does.
    fn read_spec(&self, ident: I) -> Option<buffers::Spec<Self::Buffer>> {
        let buffer = self.read(ident)?;

        Some(buffers::Spec::Raw(buffer))
    }

    /// Returns the shape of the buffer associated with `ident`.
    ///
    /// The default implementation reads the whole buffer via [`Read::read`]
    /// and then queries its shape; returns `None` exactly when `read` does.
    fn read_shape(&self, ident: I) -> Option<buffers::shapes::ShapeOf<Self::Buffer>> {
        use buffers::shapes::Shaped;

        let buffer = self.read(ident)?;

        Some(buffer.shape())
    }
}
/// Declares a public context struct that derives `Context`.
///
/// Invoked as `ctx_type!(Name { field_a: IdA, field_b: IdB })`. The generated
/// struct `Name` has one generic type parameter per entry (named `__IdA`,
/// `__IdB`, ...) and one public field per entry of that type, each annotated
/// with `#[id(Id)]`.
///
/// NOTE(review): `paste!` and the `Context` derive are referenced unqualified,
/// so they presumably have to be in scope at the call site - confirm.
#[macro_export]
macro_rules! ctx_type {
    ($name:ident { $($field:ident: $id:ident),+ }) => {
        paste! {
            #[derive(Context)]
            pub struct $name<$([<__ $id>]),+> {
                $(
                    #[id($id)]
                    pub $field: [<__ $id>]
                ),+
            }
        }
    }
}
/// Builds an ad-hoc context value from `Id = value` pairs.
///
/// Invoked as `ctx!(IdA = a, IdB = b)`. The expansion is a block expression
/// that first declares a struct named `Ctx` through `ctx_type!`, with one
/// field per identifier (named `_` followed by the lower-cased identifier),
/// and then evaluates to an instance of it holding the given values.
///
/// NOTE(review): `paste!` and `ctx_type!` are referenced unqualified, so they
/// presumably have to be in scope at the call site - confirm.
#[macro_export]
macro_rules! ctx {
    ($($id:ident = $buffer:expr),+) => {{
        paste! {
            ctx_type!(Ctx { $([<_ $id:lower>]: $id),+ });

            Ctx {
                $([<_ $id:lower>]: $buffer),+
            }
        }
    }}
}
/// Base trait for expression nodes.
///
/// It has no required methods. Every provided method consumes `self` and
/// wraps it (together with any second operand) in the matching operator type
/// from [`ops`], so expressions can be built by chaining calls.
pub trait Node {
    /// Builds `ops::Add(self, other)`.
    fn add<N>(self, other: N) -> ops::Add<Self, N>
    where
        Self: Sized,
        N: Node,
    {
        ops::Add(self, other)
    }

    /// Builds `ops::Sub(self, other)`.
    fn sub<N>(self, other: N) -> ops::Sub<Self, N>
    where
        Self: Sized,
        N: Node,
    {
        ops::Sub(self, other)
    }

    /// Builds `ops::Mul(self, other)`.
    fn mul<N>(self, other: N) -> ops::Mul<Self, N>
    where
        Self: Sized,
        N: Node,
    {
        ops::Mul(self, other)
    }

    /// Builds `ops::Div(self, other)`.
    fn div<N>(self, other: N) -> ops::Div<Self, N>
    where
        Self: Sized,
        N: Node,
    {
        ops::Div(self, other)
    }

    /// Builds an `ops::TensorDot` of `self` and `other`.
    ///
    /// The value is created through the `ops::Contract` constructor.
    fn dot<N>(self, other: N) -> ops::TensorDot<Self, N>
    where
        Self: Sized,
        N: Node,
    {
        ops::Contract(self, other)
    }

    /// Builds `ops::Abs(self)`.
    fn abs(self) -> ops::Abs<Self>
    where
        Self: Sized,
    {
        ops::Abs(self)
    }

    /// Builds `ops::Negate(self)`.
    fn neg(self) -> ops::Negate<Self>
    where
        Self: Sized,
    {
        ops::Negate(self)
    }

    /// Builds `ops::Power(self, power)`.
    ///
    /// Unlike the binary combinators above, `P` carries no `Node` bound here.
    fn pow<P>(self, power: P) -> ops::Power<Self, P>
    where
        Self: Sized,
    {
        ops::Power(self, power)
    }

    /// Builds `ops::Ln(self)`.
    fn ln(self) -> ops::Ln<Self>
    where
        Self: Sized,
    {
        ops::Ln(self)
    }

    /// Builds `ops::Square(self)`.
    fn squared(self) -> ops::Square<Self>
    where
        Self: Sized,
    {
        ops::Square(self)
    }

    /// Builds `ops::Sum(self)`.
    fn sum(self) -> ops::Sum<Self>
    where
        Self: Sized,
    {
        ops::Sum(self)
    }

    /// Builds `ops::Sigmoid(self)`.
    fn sigmoid(self) -> ops::Sigmoid<Self>
    where
        Self: Sized,
    {
        ops::Sigmoid(self)
    }
}
/// A [`Node`] that can be queried for an identifier of type `T`.
pub trait Contains<T: Identifier>: Node {
    /// Reports whether `ident` is contained in this node.
    // NOTE(review): whether this covers nested sub-expressions is decided by
    // the implementors, which are not visible here - confirm.
    fn contains(&self, ident: T) -> bool;
}
/// A [`Node`] that can be evaluated against a context of type `C`.
pub trait Function<C: Context>: Node {
    /// The buffer type produced by a successful evaluation.
    type Value: buffers::Buffer;

    /// The error type produced by a failed evaluation.
    type Error: std::error::Error;

    /// Evaluates this function against `ctx`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor fails to produce a value.
    fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> AegirResult<Self, C>;

    /// Evaluates this function and wraps the value in `buffers::Spec::Raw`.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by [`Function::evaluate`] unchanged.
    fn evaluate_spec<CR: AsRef<C>>(
        &self,
        ctx: CR,
    ) -> Result<buffers::Spec<Self::Value>, Self::Error> {
        match self.evaluate(ctx) {
            Ok(value) => Ok(buffers::Spec::Raw(value)),
            Err(error) => Err(error),
        }
    }

    /// Evaluates this function and returns the shape of the resulting value.
    ///
    /// The default implementation performs a full [`Function::evaluate`] and
    /// then queries the shape of the result.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by [`Function::evaluate`] unchanged.
    fn evaluate_shape<CR: AsRef<C>>(
        &self,
        ctx: CR,
    ) -> Result<buffers::shapes::ShapeOf<Self::Value>, Self::Error> {
        match self.evaluate(ctx) {
            Ok(value) => Ok(buffers::shapes::Shaped::shape(&value)),
            Err(error) => Err(error),
        }
    }
}
pub trait Differentiable<T: Identifier>: Node {
type Adjoint: Node;
fn adjoint(&self, target: T) -> Self::Adjoint;
fn evaluate_adjoint<C: Context, CR: AsRef<C>>(
&self,
target: T,
ctx: CR,
) -> AegirResult<Self::Adjoint, C>
where
Self: Function<C>,
Self::Adjoint: Function<C>,
{
self.adjoint(target).evaluate(ctx)
}
fn evaluate_dual<C: Context, CR: AsRef<C>>(
&self,
target: T,
ctx: CR,
) -> Result<
DualOf<Self, C, T>,
BinaryError<Self::Error, <AdjointOf<Self, T> as Function<C>>::Error, NoError>,
>
where
Self: Function<C>,
Self::Adjoint: Function<C>,
{
let value = self.evaluate(&ctx).map_err(BinaryError::Left)?;
let adjoint = self.evaluate_adjoint(target, ctx).map_err(BinaryError::Right)?;
Ok(dual!(value, adjoint))
}
}
/// The error type of `F` when evaluated as a [`Function`] over context `C`.
pub type ErrorOf<F, C> = <F as Function<C>>::Error;
/// The value type of `F` when evaluated as a [`Function`] over context `C`.
pub type ValueOf<F, C> = <F as Function<C>>::Value;
/// The `Result` of evaluating `F` over context `C`: its value or its error.
pub type AegirResult<F, C> = Result<ValueOf<F, C>, ErrorOf<F, C>>;
/// The adjoint node type of `F` with respect to identifier type `T`.
pub type AdjointOf<F, T> = <F as Differentiable<T>>::Adjoint;
/// A [`Dual`] pairing the value of `F` over `C` with the value of its adjoint
/// (with respect to `T`) over the same context.
pub type DualOf<F, C, T> = Dual<ValueOf<F, C>, ValueOf<AdjointOf<F, T>, C>>;
extern crate self as aegir;
mod dual;
pub use self::dual::Dual;
pub mod fmt;
pub mod buffers;
pub mod meta;
pub mod ops;