1use core::{fmt, ops};
2
3use paste::paste;
4
5use crate::{
6 impl_computation_fn_for_binary, impl_computation_fn_for_unary, impl_core_ops,
7 impl_display_for_inline_binary, Computation, ComputationFn, NamedArgs,
8};
9
10pub use self::{same_or_zero::*, trig::*};
11
12mod same_or_zero {
13 use crate::peano::{Suc, Zero};
14
15 pub trait SameOrZero<B> {
16 type Max;
17 }
18
19 impl SameOrZero<Zero> for Zero {
20 type Max = Zero;
21 }
22
23 impl<A> SameOrZero<Suc<A>> for Zero {
24 type Max = Suc<A>;
25 }
26
27 impl<A> SameOrZero<Zero> for Suc<A> {
28 type Max = Suc<A>;
29 }
30
31 impl<A> SameOrZero<Suc<A>> for Suc<A> {
32 type Max = Suc<A>;
33 }
34}
35
/// Defines a binary computation node `$name<A, B>` together with its
/// `Computation` impl, its computation-fn impl and its operator impls.
///
/// Accepted forms:
/// - `impl_binary_op!(Name)`: the item bound is `ops::Name`.
/// - `impl_binary_op!(Name, module)`: the item bound is `module::Name`.
/// - `impl_binary_op!(Name, module, Trait)`: the item bound is `module::Trait`.
///
/// The node's dimension is `SameOrZero::Max` of the two operand dimensions.
/// Its item is the `Output` of the bound trait applied to the operand items.
macro_rules! impl_binary_op {
    ( $name:ident ) => {
        impl_binary_op!($name, ops, $name);
    };
    ( $name:ident, $module:ident ) => {
        impl_binary_op!($name, $module, $name);
    };
    ( $name:ident, $module:ident, $bound_trait:ident ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $name<A, B>(pub A, pub B)
            where
                Self: Computation;

            impl<A, B, ADim, AItem> Computation for $name<A, B>
            where
                B: Computation,
                A: Computation<Dim = ADim, Item = AItem>,
                AItem: $module::$bound_trait<B::Item>,
                ADim: SameOrZero<B::Dim>,
            {
                type Dim = ADim::Max;
                type Item = AItem::Output;
            }

            impl_computation_fn_for_binary!($name);

            impl_core_ops!($name<A, B>);
        }
    };
}
67
/// Defines a unary computation node `$name<A>` together with its
/// `Computation` impl, its computation-fn impl and its operator impls.
///
/// Accepted forms:
/// - `impl_unary_op!(Name)`: bound `ops::Name`, item is its `Output`.
/// - `impl_unary_op!(Name, module)`: bound `module::Name`, item is its `Output`.
/// - `impl_unary_op!(Name, module, Trait)`: bound `module::Trait`, item is its `Output`.
/// - `impl_unary_op!(Name, module, Trait, Item)`: bound `module::Trait`, and the
///   item type is the operand's item type, unchanged.
/// - `impl_unary_op!(Name, module, Trait, Item::Assoc)`: item is `Assoc`.
///
/// The node's dimension is always the operand's dimension.
macro_rules! impl_unary_op {
    ( $name:ident ) => {
        impl_unary_op!($name, ops, $name, Item::Output);
    };
    ( $name:ident, $module:ident ) => {
        impl_unary_op!($name, $module, $name, Item::Output);
    };
    ( $name:ident, $module:ident, $bound_trait:ident ) => {
        impl_unary_op!($name, $module, $bound_trait, Item::Output);
    };
    ( $name:ident, $module:ident, $bound_trait:ident, Item $( :: $assoc:ident )? ) => {
        paste! {
            #[derive(Clone, Copy, Debug)]
            pub struct $name<A>(pub A)
            where
                Self: Computation;

            impl<A, AItem> Computation for $name<A>
            where
                AItem: $module::$bound_trait,
                A: Computation<Item = AItem>,
            {
                type Dim = A::Dim;
                // Either the operand item itself or one of its associated types.
                type Item = AItem $( :: $assoc )?;
            }

            impl_computation_fn_for_unary!($name);

            impl_core_ops!($name<A>);
        }
    };
}
101
102impl_binary_op!(Add);
103impl_binary_op!(Sub);
104impl_binary_op!(Mul);
105impl_binary_op!(Div);
106impl_binary_op!(Pow, num_traits);
107impl_unary_op!(Neg);
108impl_unary_op!(Abs, num_traits, Signed, Item);
109
110impl_display_for_inline_binary!(Add, "+");
111impl_display_for_inline_binary!(Sub, "-");
112impl_display_for_inline_binary!(Mul, "*");
113impl_display_for_inline_binary!(Div, "/");
114impl_display_for_inline_binary!(Pow, "^");
115
116impl<A> fmt::Display for Neg<A>
117where
118 Self: Computation,
119 A: fmt::Display,
120{
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 write!(f, "-{}", self.0)
123 }
124}
125
126impl<A> fmt::Display for Abs<A>
127where
128 Self: Computation,
129 A: fmt::Display,
130{
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 write!(f, "{}.abs()", self.0)
133 }
134}
135
136mod trig {
137 use num_traits::real;
138
139 use super::*;
140
141 impl_unary_op!(Sin, real, Real, Item);
142 impl_unary_op!(Cos, real, Real, Item);
143 impl_unary_op!(Tan, real, Real, Item);
144 impl_unary_op!(Asin, real, Real, Item);
145 impl_unary_op!(Acos, real, Real, Item);
146 impl_unary_op!(Atan, real, Real, Item);
147
148 macro_rules! impl_display {
149 ( $op:ident ) => {
150 paste::paste! {
151 impl<A> fmt::Display for $op<A>
152 where
153 Self: Computation,
154 A: fmt::Display,
155 {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "{}.{}()", self.0, stringify!([<$op:lower>]))
158 }
159 }
160 }
161 };
162 }
163
164 impl_display!(Sin);
165 impl_display!(Cos);
166 impl_display!(Tan);
167 impl_display!(Asin);
168 impl_display!(Acos);
169 impl_display!(Atan);
170}
171
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{val, Computation};

    /// Asserts that the infix computation `val!(x) <op> val!(y)` displays as
    /// `(<x> <op> <y>)`, where `<x>`/`<y>` are the `Display` of the operands.
    ///
    /// Only the infix form exists. A former method-call arm (`x.op(y)`) was
    /// never used, and its expected format matched no `Display` impl in this
    /// module.
    macro_rules! assert_op_display {
        ( $x:ident $op:tt $y:ident ) => {
            prop_assert_eq!(
                (val!($x) $op val!($y)).to_string(),
                format!("({} {} {})", val!($x), stringify!($op), val!($y))
            );
        };
    }

    #[proptest]
    fn add_should_display(x: i32, y: i32) {
        assert_op_display!(x + y);
    }

    #[proptest]
    fn sub_should_display(x: i32, y: i32) {
        assert_op_display!(x - y);
    }

    #[proptest]
    fn mul_should_display(x: i32, y: i32) {
        assert_op_display!(x * y);
    }

    #[proptest]
    fn div_should_display(x: i32, y: i32) {
        assert_op_display!(x / y);
    }

    // `pow` is a method call in Rust but displays infix as `(x ^ y)`.
    #[proptest]
    fn pow_should_display(x: i32, y: u32) {
        prop_assert_eq!(
            val!(x).pow(val!(y)).to_string(),
            format!("({} ^ {})", val!(x), val!(y))
        );
    }

    #[proptest]
    fn neg_should_display(x: i32) {
        prop_assert_eq!((-val!(x)).to_string(), format!("-{}", val!(x)));
    }

    #[proptest]
    fn abs_should_display(x: i32) {
        prop_assert_eq!(val!(x).abs().to_string(), format!("{}.abs()", val!(x)));
    }

    mod trig {
        use super::*;

        #[proptest]
        fn sin_should_display(x: f32) {
            prop_assert_eq!(val!(x).sin().to_string(), format!("{}.sin()", val!(x)));
        }

        #[proptest]
        fn cos_should_display(x: f32) {
            prop_assert_eq!(val!(x).cos().to_string(), format!("{}.cos()", val!(x)));
        }

        #[proptest]
        fn tan_should_display(x: f32) {
            prop_assert_eq!(val!(x).tan().to_string(), format!("{}.tan()", val!(x)));
        }

        #[proptest]
        fn asin_should_display(x: f32) {
            prop_assert_eq!(val!(x).asin().to_string(), format!("{}.asin()", val!(x)));
        }

        #[proptest]
        fn acos_should_display(x: f32) {
            prop_assert_eq!(val!(x).acos().to_string(), format!("{}.acos()", val!(x)));
        }

        #[proptest]
        fn atan_should_display(x: f32) {
            prop_assert_eq!(val!(x).atan().to_string(), format!("{}.atan()", val!(x)));
        }
    }
}