use crate::num::Float;
use crate::{definitions::N_UNARYOPS_OF_DEEPEX_ON_STACK, exerr, ExError, ExResult};
use smallvec::{smallvec, SmallVec};
use std::{fmt::Debug, marker::PhantomData};
/// Kind of operator implementation that was requested from an [`Operator`];
/// used to word the "operator not available" error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OperatorType {
    /// Two-operand implementation.
    Bin,
    /// Single-operand implementation.
    Unary,
}
fn make_op_not_available_error(repr: &str, op_type: OperatorType) -> ExError {
let op_type_str = match op_type {
OperatorType::Bin => "binary",
OperatorType::Unary => "unary",
};
exerr!("{} operator '{}' not available", op_type_str, repr)
}
/// A named operator or constant, identified by its textual representation.
///
/// One `repr` can carry a binary implementation, a unary implementation, or both
/// (e.g. `-` in `FloatOpsFactory`), or else a constant value. Combining a constant
/// with either implementation makes `Operator::new` panic.
///
/// Equality and ordering compare only `repr`.
#[derive(Clone, Debug)]
pub struct Operator<'a, T: Clone> {
    /// Textual representation, e.g. `"+"` or `"sin"`.
    repr: &'a str,
    /// Implementation used when the operator takes two operands.
    bin_op: Option<BinOp<T>>,
    /// Implementation used when the operator takes a single operand.
    unary_op: Option<fn(T) -> T>,
    /// Value of a named constant such as `"PI"`.
    constant: Option<T>,
}
/// Two operators are equal exactly when their textual representations match;
/// implementations and constants are not compared.
impl<'a, T: Clone> PartialEq for Operator<'a, T> {
    fn eq(&self, other: &Self) -> bool {
        self.repr == other.repr
    }
}
impl<'a, T: Clone> Eq for Operator<'a, T> {}
/// Delegates to the total order of [`Ord`], so the result is never `None`.
impl<'a, T: Clone> PartialOrd for Operator<'a, T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Orders operators by their `repr` using `str` ordering.
impl<'a, T: Clone> Ord for Operator<'a, T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(self.repr, other.repr)
    }
}
/// Borrows the implementation stored in `wrapped_op`.
///
/// # Errors
/// If `wrapped_op` is `None`, returns the "not available" error for `repr`
/// and `op_type`.
fn unwrap_operator<'a, O>(
    wrapped_op: &'a Option<O>,
    repr: &str,
    op_type: OperatorType,
) -> ExResult<&'a O> {
    match wrapped_op {
        Some(op) => Ok(op),
        None => Err(make_op_not_available_error(repr, op_type)),
    }
}
impl<'a, T: Clone> Operator<'a, T> {
fn new(
repr: &'a str,
bin_op: Option<BinOp<T>>,
unary_op: Option<fn(T) -> T>,
constant: Option<T>,
) -> Operator<'a, T> {
if constant.is_some() {
if bin_op.is_some() {
panic!("Bug! Operators cannot be constant and binary. Check '{repr}'");
}
if unary_op.is_some() {
panic!("Bug! Operators cannot be constant and unary. Check '{repr}'.");
}
}
Operator {
repr,
bin_op,
unary_op,
constant,
}
}
pub fn make_bin(repr: &'a str, bin_op: BinOp<T>) -> Operator<'a, T> {
Operator::new(repr, Some(bin_op), None, None)
}
pub fn make_unary(repr: &'a str, unary_op: fn(T) -> T) -> Operator<'a, T> {
Operator::new(repr, None, Some(unary_op), None)
}
pub fn make_bin_unary(
repr: &'a str,
bin_op: BinOp<T>,
unary_op: fn(T) -> T,
) -> Operator<'a, T> {
Operator::new(repr, Some(bin_op), Some(unary_op), None)
}
pub fn make_constant(repr: &'a str, constant: T) -> Operator<'a, T> {
Operator::new(repr, None, None, Some(constant))
}
pub fn bin(&self) -> ExResult<BinOp<T>> {
let op = unwrap_operator(&self.bin_op, self.repr, OperatorType::Bin)?;
Ok(op.clone())
}
pub fn unary(&self) -> ExResult<fn(T) -> T> {
Ok(*unwrap_operator(
&self.unary_op,
self.repr,
OperatorType::Unary,
)?)
}
pub fn repr(&self) -> &'a str {
self.repr
}
pub fn has_bin(&self) -> bool {
self.bin_op.is_some()
}
pub fn has_unary(&self) -> bool {
self.unary_op.is_some()
}
pub fn constant(&self) -> Option<T> {
self.constant.clone()
}
}
/// A value that can combine two operands of type `T` into one.
pub trait OperateBinary<T> {
    /// Applies the operation with `x` as the first and `y` as the second argument.
    fn apply(&self, x: T, y: T) -> T;
}
/// A binary operator tagged with an index.
///
/// Equality and ordering consider only `idx`; the wrapped operator is ignored.
#[derive(Clone, Debug)]
pub struct BinOpWithIdx<T: Clone> {
    pub op: BinOp<T>,
    pub idx: usize,
}
impl<T: Clone> OperateBinary<T> for BinOpWithIdx<T> {
    fn apply(&self, arg1: T, arg2: T) -> T {
        let func = self.op.apply;
        func(arg1, arg2)
    }
}
impl<T: Clone> PartialEq for BinOpWithIdx<T> {
    fn eq(&self, other: &Self) -> bool {
        other.idx == self.idx
    }
}
impl<T: Clone> Eq for BinOpWithIdx<T> {}
impl<T: Clone> PartialOrd for BinOpWithIdx<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<T: Clone> Ord for BinOpWithIdx<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.idx, &other.idx)
    }
}
/// A unary function tagged with an index.
///
/// Equality and ordering consider only `idx`; the function pointer is ignored.
#[derive(Clone, Copy, Debug)]
pub struct UnaryFuncWithIdx<T> {
    pub f: fn(T) -> T,
    pub idx: usize,
}
impl<T> UnaryFuncWithIdx<T> {
    /// Calls the wrapped function on `x`.
    pub fn apply(&self, x: T) -> T {
        let func = self.f;
        func(x)
    }
}
impl<T> PartialEq for UnaryFuncWithIdx<T> {
    fn eq(&self, other: &Self) -> bool {
        other.idx == self.idx
    }
}
impl<T> Eq for UnaryFuncWithIdx<T> {}
impl<T> PartialOrd for UnaryFuncWithIdx<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<T> Ord for UnaryFuncWithIdx<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.idx, &other.idx)
    }
}
/// Small-vector of indexed unary functions; up to `N_UNARYOPS_OF_DEEPEX_ON_STACK`
/// elements are stored inline before `SmallVec` moves to heap storage.
pub type VecOfUnaryFuncs<T> = SmallVec<[UnaryFuncWithIdx<T>; N_UNARYOPS_OF_DEEPEX_ON_STACK]>;
/// A composition of unary functions.
///
/// `apply` evaluates `funcs_to_be_composed` from the last element to the first,
/// so the first element is the outermost function.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord, Debug)]
pub struct UnaryOp<T> {
    funcs_to_be_composed: VecOfUnaryFuncs<T>,
}
impl<T> UnaryOp<T>
where
    T: Clone,
{
    /// Evaluates the composition at `x`.
    ///
    /// Functions are applied from the back of the list to the front, so the
    /// first element is the outermost function (applied last).
    pub fn apply(&self, x: T) -> T {
        let mut result = x;
        for uo in self.funcs_to_be_composed.iter().rev() {
            result = uo.apply(result);
        }
        result
    }
    /// Composes `other` on top of `self`: afterwards `self.apply(x)` equals
    /// `other.apply(previous_self.apply(x))`.
    pub fn append_after(&mut self, other: UnaryOp<T>) {
        self.append_after_iter(other.funcs_to_be_composed.into_iter());
    }
    /// Removes the outermost function (the first element, which `apply` runs last).
    ///
    /// Does nothing when the composition is empty instead of panicking.
    pub fn remove_latest(&mut self) {
        if !self.funcs_to_be_composed.is_empty() {
            self.funcs_to_be_composed.remove(0);
        }
    }
    /// Puts the functions yielded by `other_iter` in front of the existing ones,
    /// keeping their relative order, so they are applied after them.
    pub fn append_after_iter<I>(&mut self, other_iter: I)
    where
        I: Iterator<Item = UnaryFuncWithIdx<T>>,
    {
        // Move the current functions out instead of cloning them one by one.
        let existing = std::mem::take(&mut self.funcs_to_be_composed);
        self.funcs_to_be_composed = other_iter.chain(existing).collect();
    }
    /// Number of composed functions.
    pub fn len(&self) -> usize {
        self.funcs_to_be_composed.len()
    }
    /// Returns `true` if no functions are composed; `apply` is then the identity.
    pub fn is_empty(&self) -> bool {
        self.funcs_to_be_composed.is_empty()
    }
    /// Creates an empty composition.
    pub fn new() -> Self {
        Self {
            funcs_to_be_composed: smallvec![],
        }
    }
    /// Wraps `v` as is; `v[0]` becomes the outermost function.
    pub fn from_vec(v: VecOfUnaryFuncs<T>) -> Self {
        Self {
            funcs_to_be_composed: v,
        }
    }
    /// Collects `iter` in order; the first yielded function becomes the outermost.
    pub fn from_iter<I>(iter: I) -> Self
    where
        I: Iterator<Item = UnaryFuncWithIdx<T>>,
    {
        Self {
            funcs_to_be_composed: iter.collect(),
        }
    }
    /// Borrows the composed functions, outermost first.
    pub fn funcs_to_be_composed(&self) -> &VecOfUnaryFuncs<T> {
        &self.funcs_to_be_composed
    }
    /// Removes all functions, turning `apply` into the identity.
    pub fn clear(&mut self) {
        self.funcs_to_be_composed.clear();
    }
}
impl<T: Clone> Default for UnaryOp<T> {
fn default() -> Self {
Self::new()
}
}
/// A binary operation plus the metadata stored with it.
#[derive(Clone, Debug)]
pub struct BinOp<T>
where
    T: Clone,
{
    /// Function combining the first and second operand.
    pub apply: fn(T, T) -> T,
    /// Priority relative to other binary operators (e.g. `^` uses 4 and `+` uses 0
    /// in `FloatOpsFactory`); presumably higher binds tighter — confirm in the parser.
    pub prio: i64,
    /// Whether swapping the operands leaves the result unchanged.
    pub is_commutative: bool,
}
impl<T> OperateBinary<T> for BinOp<T>
where
    T: Clone,
{
    fn apply(&self, x: T, y: T) -> T {
        let func = self.apply;
        func(x, y)
    }
}
/// Factory producing the list of [`Operator`]s available for values of type `T`.
///
/// Implemented by `FloatOpsFactory`; the `ops_factory!` macro generates
/// implementations for custom operator sets.
pub trait MakeOperators<T: Clone>: Clone + Debug {
    /// Returns all operators of this factory.
    fn make<'a>() -> Vec<Operator<'a, T>>;
}
/// Built-in operator factory for floating-point types.
///
/// Zero-sized: `dummy` only ties the factory to `T`.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
pub struct FloatOpsFactory<T> {
    dummy: PhantomData<T>,
}
impl<T: Debug + Float> MakeOperators<T> for FloatOpsFactory<T> {
    /// Builds the float operators: binary `^ * / + -`, binary functions
    /// `atan2 min max`, unary functions, and the constants `PI`/`π`, `E`/`e`
    /// and `TAU`/`τ`.
    fn make<'a>() -> Vec<Operator<'a, T>> {
        vec![
            Operator::make_bin(
                "^",
                BinOp {
                    apply: |a, b| a.powf(b),
                    prio: 4,
                    is_commutative: false,
                },
            ),
            // NOTE(review): `/` gets a higher prio (3) than `*` (2), and `-` (1) a
            // higher one than `+` (0); this looks deliberate — confirm against the parser.
            Operator::make_bin(
                "*",
                BinOp {
                    apply: |a, b| a * b,
                    prio: 2,
                    is_commutative: true,
                },
            ),
            Operator::make_bin(
                "/",
                BinOp {
                    apply: |a, b| a / b,
                    prio: 3,
                    is_commutative: false,
                },
            ),
            // `+` and `-` also have a unary form (identity and negation).
            Operator::make_bin_unary(
                "+",
                BinOp {
                    apply: |a, b| a + b,
                    prio: 0,
                    is_commutative: true,
                },
                |a| a,
            ),
            Operator::make_bin_unary(
                "-",
                BinOp {
                    apply: |a, b| a - b,
                    prio: 1,
                    is_commutative: false,
                },
                |a| -a,
            ),
            // Two-argument functions; the first argument is `y` for atan2.
            Operator::make_bin(
                "atan2",
                BinOp {
                    apply: |y, x| y.atan2(x),
                    prio: 0,
                    is_commutative: false,
                },
            ),
            // NOTE(review): min/max are flagged non-commutative; possibly just
            // conservative — confirm whether `Float::min`/`max` are order-independent.
            Operator::make_bin(
                "min",
                BinOp {
                    apply: |y, x| y.min(x),
                    prio: 0,
                    is_commutative: false,
                },
            ),
            Operator::make_bin(
                "max",
                BinOp {
                    apply: |y, x| y.max(x),
                    prio: 0,
                    is_commutative: false,
                },
            ),
            // Unary functions, each forwarding to the `Float` method of the same name
            // (except `log`, see below).
            Operator::make_unary("abs", |a| a.abs()),
            Operator::make_unary("signum", |a| a.signum()),
            Operator::make_unary("sin", |a| a.sin()),
            Operator::make_unary("cos", |a| a.cos()),
            Operator::make_unary("tan", |a| a.tan()),
            Operator::make_unary("asin", |a| a.asin()),
            Operator::make_unary("acos", |a| a.acos()),
            Operator::make_unary("atan", |a| a.atan()),
            Operator::make_unary("sinh", |a| a.sinh()),
            Operator::make_unary("cosh", |a| a.cosh()),
            Operator::make_unary("tanh", |a| a.tanh()),
            Operator::make_unary("asinh", |a| a.asinh()),
            Operator::make_unary("acosh", |a| a.acosh()),
            Operator::make_unary("atanh", |a| a.atanh()),
            Operator::make_unary("floor", |a| a.floor()),
            Operator::make_unary("round", |a| a.round()),
            Operator::make_unary("ceil", |a| a.ceil()),
            Operator::make_unary("trunc", |a| a.trunc()),
            Operator::make_unary("fract", |a| a.fract()),
            Operator::make_unary("exp", |a| a.exp()),
            Operator::make_unary("sqrt", |a| a.sqrt()),
            Operator::make_unary("cbrt", |a| a.cbrt()),
            Operator::make_unary("ln", |a| a.ln()),
            Operator::make_unary("log2", |a| a.log2()),
            Operator::make_unary("log10", |a| a.log10()),
            // `log` is an alias for the natural logarithm `ln`.
            Operator::make_unary("log", |a| a.ln()),
            // Named constants, each with an ASCII and a Unicode spelling where one exists.
            Operator::make_constant("PI", T::from_f64(std::f64::consts::PI)),
            Operator::make_constant("π", T::from_f64(std::f64::consts::PI)),
            Operator::make_constant("E", T::from_f64(std::f64::consts::E)),
            Operator::make_constant("e", T::from_f64(std::f64::consts::E)),
            Operator::make_constant("TAU", T::from_f64(std::f64::consts::TAU)),
            Operator::make_constant("τ", T::from_f64(std::f64::consts::TAU)),
        ]
    }
}
/// Defines a unit struct `$name` implementing `MakeOperators<$T>`; its `make`
/// returns the given operator expressions in the order they are listed.
///
/// `MakeOperators` and `Operator` are referenced unqualified, so they must be in
/// scope where the macro is invoked. A trailing comma after the last operator is
/// accepted.
///
/// ```ignore
/// ops_factory!(
///     IntegerOpsFactory,
///     i32,
///     Operator::make_unary("neg", |a| -a),
/// );
/// ```
#[macro_export]
macro_rules! ops_factory {
    ($name:ident, $T:ty, $( $ops:expr ),* $(,)?) => {
        #[derive(Clone, Debug)]
        pub struct $name;
        impl MakeOperators<$T> for $name {
            fn make<'a>() -> Vec<Operator<'a, $T>> {
                vec![$($ops,)*]
            }
        }
    };
}