computation_types/
math.rs

1use core::{fmt, ops};
2
3use paste::paste;
4
5use crate::{
6    impl_computation_fn_for_binary, impl_computation_fn_for_unary, impl_core_ops,
7    impl_display_for_inline_binary, Computation, ComputationFn, NamedArgs,
8};
9
10pub use self::{same_or_zero::*, trig::*};
11
12mod same_or_zero {
13    use crate::peano::{Suc, Zero};
14
15    pub trait SameOrZero<B> {
16        type Max;
17    }
18
19    impl SameOrZero<Zero> for Zero {
20        type Max = Zero;
21    }
22
23    impl<A> SameOrZero<Suc<A>> for Zero {
24        type Max = Suc<A>;
25    }
26
27    impl<A> SameOrZero<Zero> for Suc<A> {
28        type Max = Suc<A>;
29    }
30
31    impl<A> SameOrZero<Suc<A>> for Suc<A> {
32        type Max = Suc<A>;
33    }
34}
35
// Defines a binary computation node `$op<A, B>` plus its `Computation` impl.
//
// Arguments:
// * `$op`      - name of the generated struct (and the default operator trait name).
// * `$package` - module containing the operator trait; defaults to `ops`.
// * `$bound`   - operator trait inside `$package`; defaults to `$op`.
//
// The node's `Dim` is `SameOrZero::Max` of the two operand dimensions, and its `Item`
// is the operator trait's `Output` for the left item applied to the right item.
macro_rules! impl_binary_op {
    ( $op:ident ) => {
        impl_binary_op!($op, ops);
    };
    ( $op:ident, $package:ident ) => {
        impl_binary_op!($op, $package, $op);
    };
    ( $op:ident, $package:ident, $bound:ident ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $op<A, B>(pub A, pub B)
            where
                Self: Computation;

            // `LhsDim` / `LhsItem` name the left operand's associated types so their
            // trait bounds and projections can be written directly.
            impl<A, B, LhsDim, LhsItem> Computation for $op<A, B>
            where
                A: Computation<Dim = LhsDim, Item = LhsItem>,
                B: Computation,
                LhsItem: $package::$bound<B::Item>,
                LhsDim: SameOrZero<B::Dim>,
            {
                type Item = LhsItem::Output;
                type Dim = LhsDim::Max;
            }

            impl_computation_fn_for_binary!($op);

            impl_core_ops!($op<A, B>);
        }
    };
}
67
// Defines a unary computation node `$op<A>` plus its `Computation` impl.
//
// Arguments:
// * `$op`      - name of the generated struct (and the default trait name).
// * `$package` - module containing the trait; defaults to `ops`.
// * `$bound`   - trait inside `$package` that the operand's item must implement;
//                defaults to `$op`.
// * trailing `Item` or `Item::Output` - selects the node's `Item`: the operand's item
//   type itself, or that type's `Output` under `$bound`. Defaults to `Item::Output`.
//
// The node's `Dim` is the operand's `Dim`, unchanged.
macro_rules! impl_unary_op {
    ( $op:ident ) => {
        impl_unary_op!($op, ops);
    };
    ( $op:ident, $package:ident ) => {
        impl_unary_op!($op, $package, $op);
    };
    ( $op:ident, $package:ident, $bound:ident ) => {
        impl_unary_op!($op, $package, $bound, Item::Output);
    };
    ( $op:ident, $package:ident, $bound:ident, Item $( :: $Output:ident )? ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $op<A>(pub A)
            where
                Self: Computation;

            // `T` names the operand's item type so it can be bounded and projected.
            impl<A, T> Computation for $op<A>
            where
                A: Computation<Item = T>,
                T: $package::$bound,
            {
                type Dim = A::Dim;
                type Item = T $( ::$Output )?;
            }

            impl_computation_fn_for_unary!($op);

            impl_core_ops!($op<A>);
        }
    };
}
101
102impl_binary_op!(Add);
103impl_binary_op!(Sub);
104impl_binary_op!(Mul);
105impl_binary_op!(Div);
106impl_binary_op!(Pow, num_traits);
107impl_unary_op!(Neg);
108impl_unary_op!(Abs, num_traits, Signed, Item);
109
110impl_display_for_inline_binary!(Add, "+");
111impl_display_for_inline_binary!(Sub, "-");
112impl_display_for_inline_binary!(Mul, "*");
113impl_display_for_inline_binary!(Div, "/");
114impl_display_for_inline_binary!(Pow, "^");
115
116impl<A> fmt::Display for Neg<A>
117where
118    Self: Computation,
119    A: fmt::Display,
120{
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(f, "-{}", self.0)
123    }
124}
125
126impl<A> fmt::Display for Abs<A>
127where
128    Self: Computation,
129    A: fmt::Display,
130{
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(f, "{}.abs()", self.0)
133    }
134}
135
136mod trig {
137    use num_traits::real;
138
139    use super::*;
140
141    impl_unary_op!(Sin, real, Real, Item);
142    impl_unary_op!(Cos, real, Real, Item);
143    impl_unary_op!(Tan, real, Real, Item);
144    impl_unary_op!(Asin, real, Real, Item);
145    impl_unary_op!(Acos, real, Real, Item);
146    impl_unary_op!(Atan, real, Real, Item);
147
148    macro_rules! impl_display {
149        ( $op:ident ) => {
150            paste::paste! {
151                impl<A> fmt::Display for $op<A>
152                where
153                    Self: Computation,
154                    A: fmt::Display,
155                {
156                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157                        write!(f, "{}.{}()", self.0, stringify!([<$op:lower>]))
158                    }
159                }
160            }
161        };
162    }
163
164    impl_display!(Sin);
165    impl_display!(Cos);
166    impl_display!(Tan);
167    impl_display!(Asin);
168    impl_display!(Acos);
169    impl_display!(Atan);
170}
171
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{val, Computation};

    // Asserts that a computation built from `val!` operands displays as expected:
    //
    // * `x op y`     - an infix binary node, expected as `(x op y)`;
    // * `x.method()` - a unary method-style node, expected as `x.method()`.
    //
    // Binary method-style nodes (`pow`) do not display in method form, so they are
    // asserted by hand in their own test.
    macro_rules! assert_op_display {
        ( $x:ident $op:tt $y:ident ) => {
            prop_assert_eq!(
                (val!($x) $op val!($y)).to_string(),
                format!("({} {} {})", val!($x), stringify!($op), val!($y))
            );
        };
        ( $x:ident . $op:ident () ) => {
            prop_assert_eq!(
                val!($x).$op().to_string(),
                format!("{}.{}()", val!($x), stringify!($op))
            );
        };
    }

    #[proptest]
    fn add_should_display(x: i32, y: i32) {
        assert_op_display!(x + y);
    }

    #[proptest]
    fn sub_should_display(x: i32, y: i32) {
        assert_op_display!(x - y);
    }

    #[proptest]
    fn mul_should_display(x: i32, y: i32) {
        assert_op_display!(x * y);
    }

    #[proptest]
    fn div_should_display(x: i32, y: i32) {
        assert_op_display!(x / y);
    }

    #[proptest]
    fn pow_should_display(x: i32, y: u32) {
        // `Pow` is built with method syntax but displays infix.
        prop_assert_eq!(
            val!(x).pow(val!(y)).to_string(),
            format!("({} ^ {})", val!(x), val!(y))
        );
    }

    #[proptest]
    fn neg_should_display(x: i32) {
        prop_assert_eq!((-val!(x)).to_string(), format!("-{}", val!(x)));
    }

    #[proptest]
    fn abs_should_display(x: i32) {
        assert_op_display!(x.abs());
    }

    mod trig {
        use super::*;

        #[proptest]
        fn sin_should_display(x: f32) {
            assert_op_display!(x.sin());
        }

        #[proptest]
        fn cos_should_display(x: f32) {
            assert_op_display!(x.cos());
        }

        #[proptest]
        fn tan_should_display(x: f32) {
            assert_op_display!(x.tan());
        }

        #[proptest]
        fn asin_should_display(x: f32) {
            assert_op_display!(x.asin());
        }

        #[proptest]
        fn acos_should_display(x: f32) {
            assert_op_display!(x.acos());
        }

        #[proptest]
        fn atan_should_display(x: f32) {
            assert_op_display!(x.atan());
        }
    }
}