// computation-types 0.0.0
//
// Types for abstract mathematical computation
// Documentation
use core::{fmt, ops};

use paste::paste;

use crate::{
    impl_computation_fn_for_binary, impl_computation_fn_for_unary, impl_core_ops,
    impl_display_for_inline_binary, Computation, ComputationFn, NamedArgs,
};

pub use self::{same_or_zero::*, trig::*};

/// Type-level rule for combining the dimensions of two operands.
mod same_or_zero {
    use crate::peano::{Suc, Zero};

    /// Implemented for a pair of dimension types (`Self`, `Rhs`) only when
    /// the two are identical or at least one of them is [`Zero`].
    ///
    /// `Max` names the larger of the two. A pair of distinct `Suc<_>` types
    /// has no implementation, so such a combination is rejected at compile
    /// time wherever this trait is used as a bound.
    pub trait SameOrZero<Rhs> {
        /// The larger of `Self` and `Rhs`.
        type Max;
    }

    // Zero with zero stays zero.
    impl SameOrZero<Zero> for Zero {
        type Max = Zero;
    }

    // Zero on the left: the non-zero right-hand side wins.
    impl<N> SameOrZero<Suc<N>> for Zero {
        type Max = Suc<N>;
    }

    // Zero on the right: the non-zero left-hand side wins.
    impl<N> SameOrZero<Zero> for Suc<N> {
        type Max = Suc<N>;
    }

    // Identical non-zero types on both sides.
    impl<N> SameOrZero<Suc<N>> for Suc<N> {
        type Max = Suc<N>;
    }
}

/// Generates a two-operand operation type and its `Computation` impl.
///
/// Accepted forms:
///
/// * `impl_binary_op!(Op)` bounds the left item type by `ops::Op`.
/// * `impl_binary_op!(Op, module)` bounds it by `module::Op`.
/// * `impl_binary_op!(Op, module, Trait)` bounds it by `module::Trait`.
///
/// The expansion contains:
///
/// * `pub struct Op<A, B>(pub A, pub B)` deriving `Clone`, `Copy` and
///   `Debug`, declared with a `Self: Computation` where-clause;
/// * a `Computation` impl whose `Dim` is the `SameOrZero::Max` of the two
///   operand dimensions and whose `Item` is the `Output` of the bounding
///   trait applied to the two operand item types;
/// * calls to `impl_computation_fn_for_binary!` and `impl_core_ops!`, which
///   are defined elsewhere in the crate.
macro_rules! impl_binary_op {
    // Shorthand: the trait lives in `ops` and shares the operation's name.
    ( $name:ident ) => {
        impl_binary_op!($name, ops, $name);
    };
    // Shorthand: the trait shares the operation's name.
    ( $name:ident, $module:ident ) => {
        impl_binary_op!($name, $module, $name);
    };
    // Full form: everything spelled out.
    ( $name:ident, $module:ident, $trait_name:ident ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $name<A, B>(pub A, pub B)
            where
                Self: Computation;

            // `LhsDim` and `LhsItem` name the associated types of `A` so
            // that bounds can be placed on them directly.
            impl<A, B, LhsDim, LhsItem> Computation for $name<A, B>
            where
                A: Computation<Dim = LhsDim, Item = LhsItem>,
                B: Computation,
                LhsDim: SameOrZero<B::Dim>,
                LhsItem: $module::$trait_name<B::Item>,
            {
                type Dim = LhsDim::Max;
                type Item = LhsItem::Output;
            }

            impl_computation_fn_for_binary!($name);

            impl_core_ops!($name<A, B>);
        }
    };
}

/// Generates a single-operand operation type and its `Computation` impl.
///
/// Accepted forms:
///
/// * `impl_unary_op!(Op)` bounds the item type by `ops::Op`.
/// * `impl_unary_op!(Op, module)` bounds it by `module::Op`.
/// * `impl_unary_op!(Op, module, Trait)` bounds it by `module::Trait`.
/// * `impl_unary_op!(Op, module, Trait, Item)` or
///   `impl_unary_op!(Op, module, Trait, Item::Assoc)` additionally selects
///   the resulting item type: the operand's item type itself, or the named
///   associated type of it.
///
/// The first three forms use `Item::Output` as the resulting item type.
///
/// The expansion contains:
///
/// * `pub struct Op<A>(pub A)` deriving `Clone`, `Copy` and `Debug`,
///   declared with a `Self: Computation` where-clause;
/// * a `Computation` impl that keeps the operand's `Dim`;
/// * calls to `impl_computation_fn_for_unary!` and `impl_core_ops!`, which
///   are defined elsewhere in the crate.
macro_rules! impl_unary_op {
    // Shorthand: trait in `ops`, named like the operation, yielding `Output`.
    ( $name:ident ) => {
        impl_unary_op!($name, ops, $name, Item::Output);
    };
    // Shorthand: trait named like the operation, yielding `Output`.
    ( $name:ident, $module:ident ) => {
        impl_unary_op!($name, $module, $name, Item::Output);
    };
    // Shorthand: explicit trait, yielding `Output`.
    ( $name:ident, $module:ident, $trait_name:ident ) => {
        impl_unary_op!($name, $module, $trait_name, Item::Output);
    };
    // Full form: the literal token `Item`, optionally followed by
    // `::Assoc`, picks the resulting item type.
    ( $name:ident, $module:ident, $trait_name:ident, Item $( :: $assoc:ident )? ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $name<A>(pub A)
            where
                Self: Computation;

            // `Elem` names the operand's item type so it can be bounded.
            impl<A, Elem> Computation for $name<A>
            where
                A: Computation<Item = Elem>,
                Elem: $module::$trait_name,
            {
                type Dim = A::Dim;
                type Item = Elem $( ::$assoc )?;
            }

            impl_computation_fn_for_unary!($name);

            impl_core_ops!($name<A>);
        }
    };
}

// Binary arithmetic operations. Each is bounded by the `core::ops` trait of
// the same name and hands its infix symbol to
// `impl_display_for_inline_binary!`.
impl_binary_op!(Add);
impl_display_for_inline_binary!(Add, "+");

impl_binary_op!(Sub);
impl_display_for_inline_binary!(Sub, "-");

impl_binary_op!(Mul);
impl_display_for_inline_binary!(Mul, "*");

impl_binary_op!(Div);
impl_display_for_inline_binary!(Div, "/");

// Exponentiation is bounded by `num_traits::Pow` instead of a `core::ops` trait.
impl_binary_op!(Pow, num_traits);
impl_display_for_inline_binary!(Pow, "^");

// Unary operations. `Neg` yields `<Item as ops::Neg>::Output`; `Abs` requires
// `num_traits::Signed` and keeps the operand's item type unchanged.
impl_unary_op!(Neg);
impl_unary_op!(Abs, num_traits, Signed, Item);

impl<A> fmt::Display for Neg<A>
where
    Self: Computation,
    A: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "-{}", self.0)
    }
}

impl<A> fmt::Display for Abs<A>
where
    Self: Computation,
    A: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}.abs()", self.0)
    }
}

/// Trigonometric operation types: `Sin`, `Cos`, `Tan`, `Asin`, `Acos`, `Atan`.
///
/// Every type requires its operand's item type to implement
/// `num_traits::real::Real` and keeps that item type unchanged.
mod trig {
    use num_traits::real;

    use super::*;

    /// Implements `Display` for a unary operation type as the operand
    /// followed by `.<name>()`, where `<name>` is the type name lowercased
    /// by `paste`'s `:lower` modifier.
    macro_rules! impl_display {
        ( $func:ident ) => {
            paste::paste! {
                impl<A> fmt::Display for $func<A>
                where
                    A: fmt::Display,
                    Self: Computation,
                {
                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                        f.write_fmt(format_args!(
                            "{}.{}()",
                            self.0,
                            stringify!([<$func:lower>])
                        ))
                    }
                }
            }
        };
    }

    impl_unary_op!(Sin, real, Real, Item);
    impl_display!(Sin);

    impl_unary_op!(Cos, real, Real, Item);
    impl_display!(Cos);

    impl_unary_op!(Tan, real, Real, Item);
    impl_display!(Tan);

    impl_unary_op!(Asin, real, Real, Item);
    impl_display!(Asin);

    impl_unary_op!(Acos, real, Real, Item);
    impl_display!(Acos);

    impl_unary_op!(Atan, real, Real, Item);
    impl_display!(Atan);
}

#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{val, Computation};

    /// Asserts the `Display` output of an operation built from `val!` operands.
    ///
    /// * `assert_op_display!(x + y)` expects `"(x + y)"`, with the operator
    ///   token reproduced verbatim.
    /// * `assert_op_display!(x.method(y))` expects `"x.method(y)"`.
    macro_rules! assert_op_display {
        ( $x:ident $op:tt $y:ident ) => {{
            let expected = format!("({} {} {})", val!($x), stringify!($op), val!($y));
            prop_assert_eq!((val!($x) $op val!($y)).to_string(), expected);
        }};
        ( $x:ident . $op:ident ( $y:ident ) ) => {{
            let expected = format!("{}.{}({})", val!($x), stringify!($op), val!($y));
            prop_assert_eq!(val!($x).$op(val!($y)).to_string(), expected);
        }};
    }

    // Infix operators display as a parenthesised `(lhs op rhs)`.

    #[proptest]
    fn add_should_display(x: i32, y: i32) {
        assert_op_display!(x + y);
    }

    #[proptest]
    fn sub_should_display(x: i32, y: i32) {
        assert_op_display!(x - y);
    }

    #[proptest]
    fn mul_should_display(x: i32, y: i32) {
        assert_op_display!(x * y);
    }

    #[proptest]
    fn div_should_display(x: i32, y: i32) {
        assert_op_display!(x / y);
    }

    // `pow` is called as a method but displays with the infix `^` symbol.
    #[proptest]
    fn pow_should_display(x: i32, y: u32) {
        let expected = format!("({} ^ {})", val!(x), val!(y));
        prop_assert_eq!(val!(x).pow(val!(y)).to_string(), expected);
    }

    #[proptest]
    fn neg_should_display(x: i32) {
        let expected = format!("-{}", val!(x));
        prop_assert_eq!((-val!(x)).to_string(), expected);
    }

    #[proptest]
    fn abs_should_display(x: i32) {
        let expected = format!("{}.abs()", val!(x));
        prop_assert_eq!(val!(x).abs().to_string(), expected);
    }

    /// Trigonometric operations display as a method call on the operand.
    mod trig {
        use super::*;

        #[proptest]
        fn sin_should_display(x: f32) {
            let expected = format!("{}.sin()", val!(x));
            prop_assert_eq!(val!(x).sin().to_string(), expected);
        }

        #[proptest]
        fn cos_should_display(x: f32) {
            let expected = format!("{}.cos()", val!(x));
            prop_assert_eq!(val!(x).cos().to_string(), expected);
        }

        #[proptest]
        fn tan_should_display(x: f32) {
            let expected = format!("{}.tan()", val!(x));
            prop_assert_eq!(val!(x).tan().to_string(), expected);
        }

        #[proptest]
        fn asin_should_display(x: f32) {
            let expected = format!("{}.asin()", val!(x));
            prop_assert_eq!(val!(x).asin().to_string(), expected);
        }

        #[proptest]
        fn acos_should_display(x: f32) {
            let expected = format!("{}.acos()", val!(x));
            prop_assert_eq!(val!(x).acos().to_string(), expected);
        }

        #[proptest]
        fn atan_should_display(x: f32) {
            let expected = format!("{}.atan()", val!(x));
            prop_assert_eq!(val!(x).atan().to_string(), expected);
        }
    }
}