// computation_types/run.rs

1mod into_vec;
2mod matrix;
3mod run_core;
4
5use crate::{Computation, NamedArgs};
6
7pub use self::{collect::*, into_vec::*, matrix::*, run_core::*};
8
/// A computation that can be run
/// without additional compilation.
///
/// Running consumes the computation and produces its final output.
///
/// This trait is automatically implemented
/// for every type that implements both [`Computation`] and [`RunCore`]
/// and whose [`RunCore`] output implements [`Collect`]
/// for the computation's `Dim`.
/// For such types, [`Run::run`] calls [`RunCore::run_core`]
/// and then collects the result with [`Collect::collect`].
pub trait Run {
    /// The value produced by [`Run::run`].
    type Output;

    /// Runs the computation, consuming it, and returns its output.
    fn run(self) -> Self::Output;
}
19
20impl<T, Collected> Run for T
21where
22    T: Computation + RunCore,
23    T::Output: Collect<T::Dim, Collected = Collected>,
24{
25    type Output = Collected;
26
27    fn run(self) -> Self::Output {
28        self.run_core().collect()
29    }
30}
31
32mod function {
33    use crate::{function::Function, ComputationFn};
34
35    use super::{NamedArgs, Run, RunCore};
36
37    impl<ArgNames, Body> Function<ArgNames, Body> {
38        pub fn call<Args>(self, args: Args) -> <Body::Filled as Run>::Output
39        where
40            (ArgNames, Args): Into<NamedArgs>,
41            Body: ComputationFn,
42            Body::Filled: Run,
43        {
44            self.fill(args).run()
45        }
46
47        pub fn call_core<Args>(self, args: Args) -> <Body::Filled as RunCore>::Output
48        where
49            (ArgNames, Args): Into<NamedArgs>,
50            Body: ComputationFn,
51            Body::Filled: RunCore,
52        {
53            self.fill(args).run_core()
54        }
55    }
56}
57
58mod collect {
59    use paste::paste;
60
61    use crate::peano::{One, Two, Zero};
62
63    use super::{IntoVec, Matrix};
64
65    pub trait Collect<OutDims> {
66        type Collected;
67
68        fn collect(self) -> Self::Collected;
69    }
70
71    impl<T> Collect<Zero> for T {
72        type Collected = T;
73
74        fn collect(self) -> Self::Collected {
75            self
76        }
77    }
78
79    impl<T> Collect<One> for T
80    where
81        T: IntoVec,
82    {
83        type Collected = Vec<T::Item>;
84
85        fn collect(self) -> Self::Collected {
86            self.into_vec()
87        }
88    }
89
90    impl<V> Collect<Two> for Matrix<V>
91    where
92        V: IntoVec,
93    {
94        type Collected = Matrix<Vec<V::Item>>;
95
96        fn collect(self) -> Self::Collected {
97            // Neither shape nor the length of `inner` will change,
98            // so they should still be fine.
99            unsafe { Matrix::new_unchecked(self.shape(), self.into_inner().into_vec()) }
100        }
101    }
102
103    macro_rules! impl_collect_for_n_tuple {
104        ( $n:expr, $( $i:expr ),* ) => {
105            paste! {
106                impl< $( [<T $i>] ),* , $( [<DimT $i>] ),* > Collect<( $( [<DimT $i>] ),* )> for ( $( [<T $i>] ),* )
107                where
108                    $( [<T $i>]: Collect< [<DimT $i>] > ),*
109                {
110                    type Collected = ( $( [<T $i>]::Collected ),* );
111
112                    fn collect(self) -> Self::Collected {
113                        ( $( self.$i.collect() ),* )
114                    }
115                }
116            }
117        };
118    }
119
120    impl_collect_for_n_tuple!(2, 0, 1);
121    impl_collect_for_n_tuple!(3, 0, 1, 2);
122    impl_collect_for_n_tuple!(4, 0, 1, 2, 3);
123    impl_collect_for_n_tuple!(5, 0, 1, 2, 3, 4);
124    impl_collect_for_n_tuple!(6, 0, 1, 2, 3, 4, 5);
125    impl_collect_for_n_tuple!(7, 0, 1, 2, 3, 4, 5, 6);
126    impl_collect_for_n_tuple!(8, 0, 1, 2, 3, 4, 5, 6, 7);
127    impl_collect_for_n_tuple!(9, 0, 1, 2, 3, 4, 5, 6, 7, 8);
128    impl_collect_for_n_tuple!(10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
129    impl_collect_for_n_tuple!(11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
130    impl_collect_for_n_tuple!(12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
131    impl_collect_for_n_tuple!(13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
132    impl_collect_for_n_tuple!(14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
133    impl_collect_for_n_tuple!(15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
134    impl_collect_for_n_tuple!(16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
135}
136
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{val, val1};

    use super::*;

    /// Checks that composed operations on computations give the same result
    /// as the equivalent plain `i32` arithmetic.
    #[proptest]
    fn operations_should_combine(
        #[strategy(-1000..1000)] x: i32,
        #[strategy(-1000..1000)] y: i32,
        #[strategy(-1000..1000)] z: i32,
    ) {
        // Both `y - z` and `z` are used as divisors below.
        prop_assume!(y - z != 0);
        prop_assume!(z != 0);
        prop_assert_eq!(
            (val!(x) / (val!(y) - val!(z)) + -(val!(z) * val!(y))).run(),
            x / (y - z) + -(z * y)
        );
        prop_assert_eq!(
            (-(((val!(x) + val!(y) - val!(z)) / val!(z)) * val!(y))).run(),
            -(((x + y - z) / z) * y)
        );
        // Both negations must be inside the parentheses.
        // `-(-val!(x)).run()` parses as `-((-val!(x)).run())`,
        // which would apply the outer `-` to the already-run `i32`.
        prop_assert_eq!((-(-val!(x))).run(), -(-x));
        prop_assert_eq!(
            (val1!([x, y]) / (val!(y) - val!(z)) + -(val!(z) * val!(y))).run(),
            [x / (y - z) + -(z * y), y / (y - z) + -(z * y)]
        );
    }
}