use num_traits::Float;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// Marker trait for any value that may appear as a constant or variable in an
/// expression.
///
/// It only bundles bounds: cloneable, has a `Default`, shareable across threads
/// (`Send + Sync`), `'static`, and printable via both `Display` and `Debug`.
/// It declares no methods.
pub trait NumericType:
    std::fmt::Debug + std::fmt::Display + Clone + Default + Send + Sync + 'static
{
}

/// Blanket implementation: every type meeting the bounds above is a
/// [`NumericType`] automatically, so no per-type impls are needed.
impl<T> NumericType for T where
    T: std::fmt::Debug + std::fmt::Display + Clone + Default + Send + Sync + 'static
{
}
/// A floating-point scalar usable inside expressions.
///
/// Every implementor names a canonical float type, [`FloatType::DefaultFloat`],
/// that its values can be converted into.
pub trait FloatType: NumericType + num_traits::Float + Copy + 'static {
    /// Target type of [`FloatType::to_default_float`]; itself a [`FloatType`].
    type DefaultFloat: FloatType;

    /// Converts `self` into [`FloatType::DefaultFloat`].
    fn to_default_float(self) -> Self::DefaultFloat;
}
/// A signed integer scalar usable inside expressions (implemented below for
/// `i32` and `i64`).
pub trait IntType: NumericType + Copy + 'static {
    /// Unsigned counterpart (`u32` for `i32`, `u64` for `i64` in the impls here).
    type Unsigned: UIntType;
    /// Float type produced by [`IntType::to_default_float`].
    type DefaultFloat: FloatType;

    /// Checked conversion to [`IntType::Unsigned`].
    ///
    /// The provided impls return `None` when `self` is negative.
    fn to_unsigned(self) -> Option<Self::Unsigned>;

    /// Converts to [`IntType::DefaultFloat`]; for 64-bit integers the provided
    /// impl may round large magnitudes.
    fn to_default_float(self) -> Self::DefaultFloat;
}
/// An unsigned integer scalar usable inside expressions (implemented below for
/// `u32` and `u64`).
pub trait UIntType: NumericType + Copy + 'static {
    /// Signed counterpart (`i32` for `u32`, `i64` for `u64` in the impls here).
    type Signed: IntType;
    /// Float type produced by [`UIntType::to_default_float`].
    type DefaultFloat: FloatType;

    /// Checked conversion to [`UIntType::Signed`].
    ///
    /// The provided impls return `None` when `self` exceeds the signed maximum.
    fn to_signed(self) -> Option<Self::Signed>;

    /// Converts to [`UIntType::DefaultFloat`]; for 64-bit integers the provided
    /// impl may round large values.
    fn to_default_float(self) -> Self::DefaultFloat;
}
impl FloatType for f32 {
    type DefaultFloat = f64;

    /// Widens to `f64`. Every `f32` value is exactly representable as an
    /// `f64`, so the infallible `From` conversion is used.
    fn to_default_float(self) -> f64 {
        self.into()
    }
}
impl FloatType for f64 {
    // `f64` is its own default float.
    type DefaultFloat = Self;

    /// Identity: the value is already an `f64`.
    fn to_default_float(self) -> Self {
        self
    }
}
impl IntType for i32 {
    type Unsigned = u32;
    type DefaultFloat = f64;

    /// Checked conversion to `u32`; negative values yield `None`.
    fn to_unsigned(self) -> Option<u32> {
        self.try_into().ok()
    }

    /// Lossless widening: `f64` represents every `i32` exactly.
    fn to_default_float(self) -> f64 {
        self.into()
    }
}
impl IntType for i64 {
    type Unsigned = u64;
    type DefaultFloat = f64;

    /// Checked conversion to `u64`; negative values yield `None`.
    fn to_unsigned(self) -> Option<u64> {
        self.try_into().ok()
    }

    /// Converts to `f64`. There is no lossless `From<i64>` for `f64`.
    /// Magnitudes above 2^53 are rounded to the nearest representable `f64`
    /// by the `as` cast.
    fn to_default_float(self) -> f64 {
        self as f64
    }
}
impl UIntType for u32 {
    type Signed = i32;
    type DefaultFloat = f64;

    /// Checked conversion to `i32`; values above `i32::MAX` yield `None`.
    fn to_signed(self) -> Option<i32> {
        self.try_into().ok()
    }

    /// Lossless widening: `f64` represents every `u32` exactly.
    fn to_default_float(self) -> f64 {
        self.into()
    }
}
impl UIntType for u64 {
    type Signed = i64;
    type DefaultFloat = f64;

    /// Checked conversion to `i64`; values above `i64::MAX` yield `None`.
    fn to_signed(self) -> Option<i64> {
        self.try_into().ok()
    }

    /// Converts to `f64`. There is no lossless `From<u64>` for `f64`.
    /// Values above 2^53 are rounded to the nearest representable `f64` by
    /// the `as` cast.
    fn to_default_float(self) -> f64 {
        self as f64
    }
}
/// Promotes a value to the numeric type `T`.
///
/// The type parameter selects the promotion target, and [`PromoteTo::Output`]
/// names the resulting type. Every implementation in this block uses `f64` for
/// both.
pub trait PromoteTo<T> {
    /// Type produced by [`PromoteTo::promote`].
    type Output;

    /// Consumes `self` and returns it converted to [`PromoteTo::Output`].
    fn promote(self) -> Self::Output;
}

// Lossless promotions: `f64` represents every `f32`, `i32` and `u32` exactly,
// so the infallible `From` conversions are used.

impl PromoteTo<f64> for f32 {
    type Output = f64;
    fn promote(self) -> Self::Output {
        self.into()
    }
}

impl PromoteTo<f64> for i32 {
    type Output = f64;
    fn promote(self) -> Self::Output {
        self.into()
    }
}

impl PromoteTo<f64> for u32 {
    type Output = f64;
    fn promote(self) -> Self::Output {
        self.into()
    }
}

// Potentially lossy promotions: `f64` has a 53-bit significand, so 64-bit
// integers beyond 2^53 in magnitude are rounded to the nearest `f64` by `as`.

impl PromoteTo<f64> for i64 {
    type Output = f64;
    fn promote(self) -> Self::Output {
        self as f64
    }
}

impl PromoteTo<f64> for u64 {
    type Output = f64;
    fn promote(self) -> Self::Output {
        self as f64
    }
}
/// Final-tagless constructors for typed mathematical expressions.
///
/// Each implementor chooses its own representation [`MathExpr::Repr`] of an
/// expression whose value has type `T`. The associated functions below build
/// such representations from literals, variables and operators.
pub trait MathExpr {
    /// Implementor-defined representation of an expression of type `T`.
    type Repr<T>;

    /// Lifts the literal `value` into an expression.
    fn constant<T>(value: T) -> Self::Repr<T>
    where
        T: NumericType;

    /// A variable identified by `index`.
    fn var<T>(index: usize) -> Self::Repr<T>
    where
        T: NumericType;

    /// `left + right`. Operand types may differ; the result type is the
    /// `Output` of `Lhs: Add<Rhs>`.
    fn add<Lhs, Rhs, Out>(left: Self::Repr<Lhs>, right: Self::Repr<Rhs>) -> Self::Repr<Out>
    where
        Lhs: NumericType + Add<Rhs, Output = Out>,
        Rhs: NumericType,
        Out: NumericType;

    /// `left - right`, typed through `Lhs: Sub<Rhs>`.
    fn sub<Lhs, Rhs, Out>(left: Self::Repr<Lhs>, right: Self::Repr<Rhs>) -> Self::Repr<Out>
    where
        Lhs: NumericType + Sub<Rhs, Output = Out>,
        Rhs: NumericType,
        Out: NumericType;

    /// `left * right`, typed through `Lhs: Mul<Rhs>`.
    fn mul<Lhs, Rhs, Out>(left: Self::Repr<Lhs>, right: Self::Repr<Rhs>) -> Self::Repr<Out>
    where
        Lhs: NumericType + Mul<Rhs, Output = Out>,
        Rhs: NumericType,
        Out: NumericType;

    /// `left / right`, typed through `Lhs: Div<Rhs>`.
    fn div<Lhs, Rhs, Out>(left: Self::Repr<Lhs>, right: Self::Repr<Rhs>) -> Self::Repr<Out>
    where
        Lhs: NumericType + Div<Rhs, Output = Out>,
        Rhs: NumericType,
        Out: NumericType;

    /// `base` raised to the power `exp` (floating-point types only).
    fn pow<T>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;

    /// Arithmetic negation `-expr`.
    fn neg<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Neg<Output = T>;

    /// Natural logarithm of `expr`.
    fn ln<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;

    /// Exponential `e^expr`.
    fn exp<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;

    /// Square root of `expr`.
    fn sqrt<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;

    /// Sine of `expr`.
    fn sin<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;

    /// Cosine of `expr`.
    fn cos<T>(expr: Self::Repr<T>) -> Self::Repr<T>
    where
        T: NumericType + Float;
}
/// Statistical functions expressed purely in terms of [`MathExpr`] primitives.
///
/// Every method has a default body, so any `MathExpr` implementor gains them
/// with an empty `impl StatisticalExpr for ... {}`.
pub trait StatisticalExpr: MathExpr {
    /// Logistic function `1 / (1 + exp(-x))`.
    fn logistic<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        // `Repr<T>` is not required to be `Clone`, so the constant `1` is
        // constructed twice: once for the denominator, once as the numerator.
        let unit = Self::constant(T::one());
        let decay = Self::exp(Self::neg(x));
        let denominator = Self::add(unit, decay);
        Self::div(Self::constant(T::one()), denominator)
    }

    /// Softplus function `ln(1 + exp(x))`.
    ///
    /// NOTE(review): this is the direct formula. If an interpreter evaluates it
    /// with IEEE-754 floats, `exp(x)` overflows to `+inf` for large positive `x`
    /// (about 709.78 for `f64`, 88.72 for `f32`). The result is then `+inf`
    /// instead of roughly `x`. The stable form `max(x, 0) + ln(1 + exp(-|x|))`
    /// needs `abs`/`max`, which `MathExpr` does not offer. It would also reuse
    /// `x`, which requires `Repr<T>: Clone`.
    fn softplus<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        Self::ln(Self::add(Self::constant(T::one()), Self::exp(x)))
    }

    /// Alias for [`StatisticalExpr::logistic`].
    fn sigmoid<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        Self::logistic::<T>(x)
    }
}
/// A range of index values, as consumed by `SummationExpr::sum_finite`.
pub trait RangeType: std::fmt::Debug + Clone + Send + Sync + 'static {
    /// Type of the range bounds, members and length.
    type IndexType: NumericType;

    /// Lower bound of the range.
    fn start(&self) -> Self::IndexType;

    /// Upper bound of the range. Whether it is inclusive is left to the
    /// implementor.
    fn end(&self) -> Self::IndexType;

    /// Returns `true` if `value` lies within the range.
    fn contains(&self, value: &Self::IndexType) -> bool;

    /// Number of elements, expressed in the index type itself.
    fn len(&self) -> Self::IndexType;

    /// Returns `true` if the range has no elements.
    fn is_empty(&self) -> bool;
}
/// The body of a summation: an expression parameterised by an index of type `T`.
pub trait SummandFunction<T>: std::fmt::Debug + Clone {
    /// Representation of the summand body.
    type Body: Clone;

    /// Name of the bound index variable.
    fn index_var(&self) -> &str;

    /// Borrow of the (unapplied) body.
    fn body(&self) -> &Self::Body;

    /// Produces the body with the index fixed to `index`.
    fn apply(&self, index: T) -> Self::Body;

    /// Returns `true` if the body references the index variable.
    fn depends_on_index(&self) -> bool;

    /// Splits the body into a list of factors and a remaining body.
    ///
    /// NOTE(review): going by the name, the list presumably holds factors that
    /// do not depend on the index. The exact contract is implementor-defined.
    fn extract_independent_factors(&self) -> (Vec<Self::Body>, Self::Body);
}
/// Summation constructors layered on top of [`MathExpr`].
pub trait SummationExpr: MathExpr {
    /// Sum of `function` over every index in `range`.
    fn sum_finite<T, R, F>(range: Self::Repr<R>, function: Self::Repr<F>) -> Self::Repr<T>
    where
        T: NumericType,
        R: RangeType,
        F: SummandFunction<T>,
        Self::Repr<T>: Clone;

    /// Series of `function` beginning at index `start`, with no upper bound.
    fn sum_infinite<T, F>(start: Self::Repr<T>, function: Self::Repr<F>) -> Self::Repr<T>
    where
        T: NumericType,
        F: SummandFunction<T>,
        Self::Repr<T>: Clone;

    /// Sum of `function` over an integer `range`, marked as telescoping.
    ///
    /// NOTE(review): how implementors exploit the telescoping structure is not
    /// visible from this trait.
    fn sum_telescoping<T, F>(
        range: Self::Repr<crate::final_tagless::IntRange>,
        function: Self::Repr<F>,
    ) -> Self::Repr<T>
    where
        T: NumericType,
        F: SummandFunction<T>;

    /// Builds an `IntRange` from `start` to `end`. Inclusivity is defined by
    /// `IntRange`.
    fn range_to<T>(
        start: Self::Repr<T>,
        end: Self::Repr<T>,
    ) -> Self::Repr<crate::final_tagless::IntRange>
    where
        T: NumericType;

    /// Wraps `body` as a summand function of the index variable named
    /// `index_var`.
    fn function<T>(
        index_var: &str,
        body: Self::Repr<T>,
    ) -> Self::Repr<crate::final_tagless::ASTFunction<T>>
    where
        T: NumericType;
}
/// Untyped counterpart of `MathExpr`: a single representation type for every
/// expression, with literals fixed to `f64`.
pub trait ASTMathExpr {
    /// Implementor-defined expression representation.
    type Repr;

    // Leaves.

    /// Lifts the literal `value` into an expression.
    fn constant(value: f64) -> Self::Repr;
    /// A variable identified by `index`.
    fn var(index: usize) -> Self::Repr;

    // Binary operators.

    /// `left + right`.
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left - right`.
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left * right`.
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left / right`.
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `base` raised to the power `exp`.
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;

    // Unary operators and elementary functions.

    /// Arithmetic negation `-expr`.
    fn neg(expr: Self::Repr) -> Self::Repr;
    /// Natural logarithm of `expr`.
    fn ln(expr: Self::Repr) -> Self::Repr;
    /// Exponential `e^expr`.
    fn exp(expr: Self::Repr) -> Self::Repr;
    /// Square root of `expr`.
    fn sqrt(expr: Self::Repr) -> Self::Repr;
    /// Sine of `expr`.
    fn sin(expr: Self::Repr) -> Self::Repr;
    /// Cosine of `expr`.
    fn cos(expr: Self::Repr) -> Self::Repr;
}