mod mul;
use crate::{eval, eval::Eval, expr::Expr, grad, grad::Grad};
/// Defines an expression node for a unary operator trait from `core::ops`.
///
/// `unary_op!(Name, name, "doc")` generates:
/// - a tuple struct `Name<T>` holding the operand node, documented with `"doc"`;
/// - a `const` impl of `core::ops::Name` for `Expr<T>`, so applying the operator to an
///   `Expr` builds an `Expr<Name<T>>` node rather than computing a value;
/// - an evaluation impl (through `implement_eval!`) that evaluates the operand and then
///   applies `core::ops::Name::name` to the result.
macro_rules! unary_op {
($Name:ident, $name:ident, $doc:literal) => {
// The node only stores its operand; all computation happens at evaluation time.
#[derive(Debug)]
#[doc = $doc]
pub struct $Name<T: ~const Eval>(T);
// Applying the operator to an `Expr` just wraps the operand's inner node. The bound on
// `Evaluated` means the operator is only available on `Expr<T>` when the value that `T`
// evaluates to supports the same operator.
impl<T: ~const Eval<Evaluated: ~const core::ops::$Name>> const core::ops::$Name for Expr<T> {
type Output = Expr<$Name<T>>;
#[inline(always)]
fn $name(self) -> Self::Output {
Expr($Name(self.0))
}
}
// Evaluation rule. The result type is `<T::Evaluated as core::ops::$Name>::Output`.
// The first body consumes the operand (`self.0.eval()`); the `else` body evaluates it
// through a reference (`(&self.0).eval()`). NOTE(review): the precise meaning of the
// `>->` / `where own … else …` syntax is defined by `implement_eval!` elsewhere in the
// crate — confirm there before changing this invocation.
$crate::implement_eval!(
T: Eval<Evaluated: ~const core::ops::$Name> =>
$Name<T> >-> <T::Evaluated as core::ops::$Name>::Output:
|self| where own {
core::ops::$Name::$name(self.0.eval())
} else {
core::ops::$Name::$name((&self.0).eval())
}
);
};
}
/// Implements differentiation for a unary node `$Name` (as defined by `unary_op!`).
///
/// The derivative of `$Name(t)` is built as `$Diff(dt)`: the node kind `$Diff` applied to
/// the operand's derivative. This shape fits operators that are linear in their operand
/// (e.g. `Neg`, since d(-t) = -(dt)).
///
/// Four impls are emitted, all sharing one set of generic bounds:
/// - `grad::Typed`: names the derivative type `$Diff<T::Differentiated>`;
/// - `grad::Own`: differentiates by consuming the node;
/// - `grad::Ref`: differentiates through a shared reference;
/// - `grad::Grad`: empty impl tying the above together.
macro_rules! unary_grad {
    // Internal helper: one `const` impl of `grad::$Trait` for `$Name<T>` carrying the
    // shared bounds, with the caller-supplied items as its body.
    (@impl $Name:ident, $Diff:ident, $Trait:ident { $($body:tt)* }) => {
        impl<
            T: ~const Eval<Evaluated: ~const core::ops::$Name>
                + ~const Grad<Differentiated: ~const Eval<Evaluated: ~const core::ops::$Diff>>,
        > const grad::$Trait for $Name<T>
        {
            $($body)*
        }
    };
    ($Name:ident, $Diff:ident) => {
        unary_grad!(@impl $Name, $Diff, Typed {
            type Differentiated = $Diff<T::Differentiated>;
        });
        unary_grad!(@impl $Name, $Diff, Own {
            #[inline(always)]
            fn grad<U>(self, x: &U) -> Self::Differentiated {
                $Diff(self.0.grad(x))
            }
        });
        unary_grad!(@impl $Name, $Diff, Ref {
            #[inline(always)]
            fn grad<U>(&self, x: &U) -> Self::Differentiated {
                $Diff((&self.0).grad(x))
            }
        });
        unary_grad!(@impl $Name, $Diff, Grad {});
    };
}
/// Defines an expression node for a binary operator trait from `core::ops`.
///
/// `binary_op!(Name, name, "doc")` generates:
/// - a tuple struct `Name<L, R>` holding the left and right operand nodes, documented
///   with `"doc"`;
/// - a `const` impl of `core::ops::Name<Expr<R>>` for `Expr<L>`, so combining two `Expr`s
///   with the operator builds an `Expr<Name<L, R>>` node rather than computing a value;
/// - an evaluation impl (through `implement_eval!`) that evaluates both operands and then
///   applies `core::ops::Name::name` to the results.
macro_rules! binary_op {
($Name:ident, $name:ident, $doc:literal) => {
// The node only stores its operands; all computation happens at evaluation time.
#[derive(Debug)]
#[doc = $doc]
pub struct $Name<L: ~const Eval, R: ~const Eval>(L, R);
// Combining two `Expr`s just pairs their inner nodes. The impl exists only when the left
// operand's evaluated type implements the operator with the right operand's evaluated
// type as its right-hand side.
impl<
L: ~const Eval<Evaluated: ~const core::ops::$Name<R::Evaluated>>,
R: ~const Eval,
> const core::ops::$Name<Expr<R>> for Expr<L>
{
type Output = Expr<$Name<L, R>>;
#[inline(always)]
fn $name(self, arg: Expr<R>) -> Self::Output {
Expr($Name(self.0, arg.0))
}
}
// Evaluation rule. The result type is
// `<L::Evaluated as core::ops::$Name<R::Evaluated>>::Output`. The first body consumes
// both operands; the `else` body evaluates them through references. NOTE(review): the
// precise meaning of the `>->` / `where own … else …` syntax is defined by
// `implement_eval!` elsewhere in the crate — confirm there before changing this call.
$crate::implement_eval!(
L: Eval<Evaluated: ~const core::ops::$Name<R::Evaluated>>, R: Eval =>
$Name<L, R> >-> <L::Evaluated as core::ops::$Name<R::Evaluated>>::Output:
|self| where own {
core::ops::$Name::$name(self.0.eval(), self.1.eval())
} else {
core::ops::$Name::$name((&self.0).eval(), (&self.1).eval())
}
);
};
}
/// Implements differentiation for a binary node `$Name` (as defined by `binary_op!`).
///
/// The derivative of `$Name(l, r)` is built as `$Diff(dl, dr)`: the node kind `$Diff`
/// applied to the derivatives of both operands. This shape fits operators that are linear
/// in their operands (e.g. `Add` and `Sub`), not ones like `Mul` that need a product rule.
///
/// Four impls are emitted, all sharing one set of generic bounds:
/// - `grad::Typed`: names the derivative type `$Diff<L::Differentiated, R::Differentiated>`;
/// - `grad::Own`: differentiates by consuming the node;
/// - `grad::Ref`: differentiates through a shared reference;
/// - `grad::Grad`: empty impl tying the above together.
macro_rules! binary_grad {
    // Internal helper: one `const` impl of `grad::$Trait` for `$Name<L, R>` carrying the
    // shared bounds, with the caller-supplied items as its body. The left operand's
    // derivative must evaluate to something that `$Diff` can combine with the evaluated
    // derivative of the right operand.
    (@impl $Name:ident, $Diff:ident, $Trait:ident { $($body:tt)* }) => {
        impl<
            L: ~const Eval<Evaluated: ~const core::ops::$Name<R::Evaluated>>
                + ~const Grad<
                    Differentiated: ~const Eval<
                        Evaluated: ~const core::ops::$Diff<
                            <R::Differentiated as eval::Typed>::Evaluated,
                        >,
                    >,
                >,
            R: ~const Grad,
        > const grad::$Trait for $Name<L, R>
        {
            $($body)*
        }
    };
    ($Name:ident, $Diff:ident) => {
        binary_grad!(@impl $Name, $Diff, Typed {
            type Differentiated = $Diff<L::Differentiated, R::Differentiated>;
        });
        binary_grad!(@impl $Name, $Diff, Own {
            #[inline(always)]
            fn grad<U>(self, x: &U) -> Self::Differentiated {
                $Diff(self.0.grad(x), self.1.grad(x))
            }
        });
        binary_grad!(@impl $Name, $Diff, Ref {
            #[inline(always)]
            fn grad<U>(&self, x: &U) -> Self::Differentiated {
                $Diff((&self.0).grad(x), (&self.1).grad(x))
            }
        });
        binary_grad!(@impl $Name, $Diff, Grad {});
    };
}
// Expression nodes for the overloadable operators in `core::ops`. Each invocation defines
// the node struct (documented with the given string), the lazy operator impl on `Expr`,
// and the node's evaluation impl.
unary_op!(Neg, neg, "Arithmetic negation (e.g. `-4`).");
unary_op!(Not, not, "Logical or bitwise negation (e.g. `!true`).");
binary_op!(Add, add, "Arithmetic addition (e.g. `a + b`).");
binary_op!(BitAnd, bitand, "Bitwise conjunction (e.g. `a & b`).");
binary_op!(BitOr, bitor, "Bitwise inclusive-or (e.g. `a | b`).");
binary_op!(BitXor, bitxor, "Bitwise exclusive-or (e.g. `a ^ b`).");
binary_op!(Div, div, "Arithmetic division (e.g. `a / b`).");
binary_op!(Mul, mul, "Arithmetic multiplication (e.g. `a * b`).");
binary_op!(Rem, rem, "Arithmetic remainder (e.g. `a % b`).");
binary_op!(Shl, shl, "Arithmetic left-shift (e.g. `a << b`).");
binary_op!(Shr, shr, "Arithmetic right-shift (e.g. `a >> b`).");
binary_op!(Sub, sub, "Arithmetic subtraction (e.g. `a - b`).");

// Gradient rules for the operators whose derivative is the same operator applied to the
// operands' derivatives: d(-t) = -dt, d(a + b) = da + db, d(a - b) = da - db.
// NOTE(review): `Mul` needs the product rule; presumably it is handled in `mod mul` — confirm.
unary_grad!(Neg, Neg);
binary_grad!(Add, Add);
binary_grad!(Sub, Sub);