computation_types/
lib.rs

1#![feature(min_specialization)]
2#![feature(unsized_fn_params)]
3#![allow(internal_features)]
4#![warn(missing_debug_implementations)]
5
6//! Types for abstract mathematical computation.
7//!
//! Note:
//! documentation is currently lacking.
//! The best way to learn this framework
//! is to read its tests
//! and to see how it is used to implement algorithms
//! in Optimal.
14//!
15//! # Examples
16//!
17//! ```
18//! use computation_types::{named_args, val, Run};
19//!
20//! let one_plus_one = val!(1) + val!(1);
21//! assert_eq!(one_plus_one.to_string(), "(1 + 1)");
22//! assert_eq!(one_plus_one.run(), 2);
23//! ```
24
25mod function;
26pub mod macros;
27mod named_args;
28mod names;
29pub mod peano;
30pub mod run;
31
32pub mod arg;
33pub mod black_box;
34pub mod cmp;
35pub mod control_flow;
36pub mod enumerate;
37pub mod len;
38pub mod linalg;
39pub mod math;
40pub mod rand;
41pub mod sum;
42pub mod val;
43pub mod zip;
44
45pub use crate::{arg::*, function::*, named_args::*, names::*, run::Run, val::*};
46
47/// A type representing a computation.
48///
49/// This trait does little on its own.
50/// Additional traits,
51/// such as [`Run`],
52/// must be implemented
53/// to use a computation.
54#[allow(clippy::len_without_is_empty)]
55pub trait Computation {
56    type Dim;
57    type Item;
58
59    // `math`
60
61    fn add<Rhs>(self, rhs: Rhs) -> math::Add<Self, Rhs>
62    where
63        Self: Sized,
64        math::Add<Self, Rhs>: Computation,
65    {
66        math::Add(self, rhs)
67    }
68
69    fn sub<Rhs>(self, rhs: Rhs) -> math::Sub<Self, Rhs>
70    where
71        Self: Sized,
72        math::Sub<Self, Rhs>: Computation,
73    {
74        math::Sub(self, rhs)
75    }
76
77    fn mul<Rhs>(self, rhs: Rhs) -> math::Mul<Self, Rhs>
78    where
79        Self: Sized,
80        math::Mul<Self, Rhs>: Computation,
81    {
82        math::Mul(self, rhs)
83    }
84
85    fn div<Rhs>(self, rhs: Rhs) -> math::Div<Self, Rhs>
86    where
87        Self: Sized,
88        math::Div<Self, Rhs>: Computation,
89    {
90        math::Div(self, rhs)
91    }
92
93    fn pow<Rhs>(self, rhs: Rhs) -> math::Pow<Self, Rhs>
94    where
95        Self: Sized,
96        math::Pow<Self, Rhs>: Computation,
97    {
98        math::Pow(self, rhs)
99    }
100
101    fn neg(self) -> math::Neg<Self>
102    where
103        Self: Sized,
104        math::Neg<Self>: Computation,
105    {
106        math::Neg(self)
107    }
108
109    fn abs(self) -> math::Abs<Self>
110    where
111        Self: Sized,
112        math::Abs<Self>: Computation,
113    {
114        math::Abs(self)
115    }
116
117    // `math::trig`
118
119    fn sin(self) -> math::Sin<Self>
120    where
121        Self: Sized,
122        math::Sin<Self>: Computation,
123    {
124        math::Sin(self)
125    }
126
127    fn cos(self) -> math::Cos<Self>
128    where
129        Self: Sized,
130        math::Cos<Self>: Computation,
131    {
132        math::Cos(self)
133    }
134
135    fn tan(self) -> math::Tan<Self>
136    where
137        Self: Sized,
138        math::Tan<Self>: Computation,
139    {
140        math::Tan(self)
141    }
142
143    fn asin(self) -> math::Asin<Self>
144    where
145        Self: Sized,
146        math::Asin<Self>: Computation,
147    {
148        math::Asin(self)
149    }
150
151    fn acos(self) -> math::Acos<Self>
152    where
153        Self: Sized,
154        math::Acos<Self>: Computation,
155    {
156        math::Acos(self)
157    }
158
159    fn atan(self) -> math::Atan<Self>
160    where
161        Self: Sized,
162        math::Atan<Self>: Computation,
163    {
164        math::Atan(self)
165    }
166
167    // `cmp`
168
169    fn eq<Rhs>(self, rhs: Rhs) -> cmp::Eq<Self, Rhs>
170    where
171        Self: Sized,
172        cmp::Eq<Self, Rhs>: Computation,
173    {
174        cmp::Eq(self, rhs)
175    }
176
177    fn ne<Rhs>(self, rhs: Rhs) -> cmp::Ne<Self, Rhs>
178    where
179        Self: Sized,
180        cmp::Ne<Self, Rhs>: Computation,
181    {
182        cmp::Ne(self, rhs)
183    }
184
185    fn lt<Rhs>(self, rhs: Rhs) -> cmp::Lt<Self, Rhs>
186    where
187        Self: Sized,
188        cmp::Lt<Self, Rhs>: Computation,
189    {
190        cmp::Lt(self, rhs)
191    }
192
193    fn le<Rhs>(self, rhs: Rhs) -> cmp::Le<Self, Rhs>
194    where
195        Self: Sized,
196        cmp::Le<Self, Rhs>: Computation,
197    {
198        cmp::Le(self, rhs)
199    }
200
201    fn gt<Rhs>(self, rhs: Rhs) -> cmp::Gt<Self, Rhs>
202    where
203        Self: Sized,
204        cmp::Gt<Self, Rhs>: Computation,
205    {
206        cmp::Gt(self, rhs)
207    }
208
209    fn ge<Rhs>(self, rhs: Rhs) -> cmp::Ge<Self, Rhs>
210    where
211        Self: Sized,
212        cmp::Ge<Self, Rhs>: Computation,
213    {
214        cmp::Ge(self, rhs)
215    }
216
217    fn max(self) -> cmp::Max<Self>
218    where
219        Self: Sized,
220        cmp::Max<Self>: Computation,
221    {
222        cmp::Max(self)
223    }
224
225    fn not(self) -> cmp::Not<Self>
226    where
227        Self: Sized,
228        cmp::Not<Self>: Computation,
229    {
230        cmp::Not(self)
231    }
232
233    // `enumerate`
234
235    fn enumerate<F>(self, f: Function<(Name, Name), F>) -> enumerate::Enumerate<Self, F>
236    where
237        Self: Sized,
238        enumerate::Enumerate<Self, F>: Computation,
239    {
240        enumerate::Enumerate { child: self, f }
241    }
242
243    // `sum`
244
245    fn sum(self) -> sum::Sum<Self>
246    where
247        Self: Sized,
248        sum::Sum<Self>: Computation,
249    {
250        sum::Sum(self)
251    }
252
253    // `zip`
254
255    fn zip<Rhs>(self, rhs: Rhs) -> zip::Zip<Self, Rhs>
256    where
257        Self: Sized,
258        zip::Zip<Self, Rhs>: Computation,
259    {
260        zip::Zip(self, rhs)
261    }
262
263    fn fst(self) -> zip::Fst<Self>
264    where
265        Self: Sized,
266        zip::Fst<Self>: Computation,
267    {
268        zip::Fst(self)
269    }
270
271    fn snd(self) -> zip::Snd<Self>
272    where
273        Self: Sized,
274        zip::Snd<Self>: Computation,
275    {
276        zip::Snd(self)
277    }
278
279    // `black_box`
280
281    /// Run the given regular function `F`.
282    ///
283    /// This acts as an escape-hatch to allow regular Rust-code in a computation,
284    /// but the computation may lose features or efficiency if it is used.
285    fn black_box<F, FDim, FItem>(self, f: F) -> black_box::BlackBox<Self, F, FDim, FItem>
286    where
287        Self: Sized,
288        black_box::BlackBox<Self, F, FDim, FItem>: Computation,
289    {
290        black_box::BlackBox::new(self, f)
291    }
292
293    // `control_flow`
294
295    fn if_<ArgNames, P, FTrue, FFalse>(
296        self,
297        arg_names: ArgNames,
298        predicate: P,
299        f_true: FTrue,
300        f_false: FFalse,
301    ) -> control_flow::If<Self, ArgNames, P, FTrue, FFalse>
302    where
303        Self: Sized,
304        control_flow::If<Self, ArgNames, P, FTrue, FFalse>: Computation,
305    {
306        control_flow::If {
307            child: self,
308            arg_names,
309            predicate,
310            f_true,
311            f_false,
312        }
313    }
314
315    fn loop_while<ArgNames, F, P>(
316        self,
317        arg_names: ArgNames,
318        f: F,
319        predicate: P,
320    ) -> control_flow::LoopWhile<Self, ArgNames, F, P>
321    where
322        Self: Sized,
323        control_flow::LoopWhile<Self, ArgNames, F, P>: Computation,
324    {
325        control_flow::LoopWhile {
326            child: self,
327            arg_names,
328            f,
329            predicate,
330        }
331    }
332
333    fn then<ArgNames, F>(
334        self,
335        f: function::Function<ArgNames, F>,
336    ) -> control_flow::Then<Self, ArgNames, F>
337    where
338        Self: Sized,
339        control_flow::Then<Self, ArgNames, F>: Computation,
340    {
341        control_flow::Then { child: self, f }
342    }
343
344    // `linalg`
345
346    /// Return a `self` by `self` identity-matrix.
347    ///
348    /// Diagonal elements have a value of `1`,
349    /// and non-diagonal elements have a value of `0`.
350    ///
351    /// The type of elements,
352    /// `T`,
353    /// may need to be specified.
354    fn identity_matrix<T>(self) -> linalg::IdentityMatrix<Self, T>
355    where
356        Self: Sized,
357        linalg::IdentityMatrix<Self, T>: Computation,
358    {
359        linalg::IdentityMatrix::new(self)
360    }
361
362    /// Multiply and sum the elements of two vectors.
363    ///
364    /// This is sometimes known as the "inner product"
365    /// or "dot product".
366    fn scalar_product<Rhs>(self, rhs: Rhs) -> linalg::ScalarProduct<Self, Rhs>
367    where
368        Self: Sized,
369        math::Mul<Self, Rhs>: Computation,
370        linalg::ScalarProduct<Self, Rhs>: Computation,
371    {
372        linalg::scalar_product(self, rhs)
373    }
374
375    /// Perform matrix-multiplication.
376    fn mat_mul<Rhs>(self, rhs: Rhs) -> linalg::MatMul<Self, Rhs>
377    where
378        Self: Sized,
379        linalg::MatMul<Self, Rhs>: Computation,
380    {
381        linalg::MatMul(self, rhs)
382    }
383
384    /// Multiply elements from the Cartesian product of two vectors.
385    ///
386    /// This is sometimes known as "outer product",
387    /// and it is equivalent to matrix-multiplying a column-matrix by a row-matrix.
388    fn mul_out<Rhs>(self, rhs: Rhs) -> linalg::MulOut<Self, Rhs>
389    where
390        Self: Sized,
391        linalg::MulOut<Self, Rhs>: Computation,
392    {
393        linalg::MulOut(self, rhs)
394    }
395
396    /// Matrix-multiply a matrix by a column-matrix,
397    /// returning a vector.
398    fn mul_col<Rhs>(self, rhs: Rhs) -> linalg::MulCol<Self, Rhs>
399    where
400        Self: Sized,
401        linalg::MulCol<Self, Rhs>: Computation,
402    {
403        linalg::MulCol(self, rhs)
404    }
405
406    // Other
407
408    fn len(self) -> len::Len<Self>
409    where
410        Self: Sized,
411        len::Len<Self>: Computation,
412    {
413        len::Len(self)
414    }
415}
416
417impl<T> Computation for &T
418where
419    T: Computation + ?Sized,
420{
421    type Dim = T::Dim;
422    type Item = T::Item;
423}
424
425impl<T> Computation for &mut T
426where
427    T: Computation + ?Sized,
428{
429    type Dim = T::Dim;
430    type Item = T::Item;
431}
432
433impl<T> Computation for Box<T>
434where
435    T: Computation + ?Sized,
436{
437    type Dim = T::Dim;
438    type Item = T::Item;
439}
440
441impl<T> Computation for std::rc::Rc<T>
442where
443    T: Computation + ?Sized,
444{
445    type Dim = T::Dim;
446    type Item = T::Item;
447}
448
449impl<T> Computation for std::sync::Arc<T>
450where
451    T: Computation + ?Sized,
452{
453    type Dim = T::Dim;
454    type Item = T::Item;
455}
456
457impl<T> Computation for std::borrow::Cow<'_, T>
458where
459    T: Computation + ToOwned + ?Sized,
460{
461    type Dim = T::Dim;
462    type Item = T::Item;
463}
464
465/// A type representing a function-like computation.
466///
467/// Most computations should implement this,
468/// even if they represent a function with zero arguments.
469pub trait ComputationFn: Computation {
470    type Filled;
471
472    /// Fill arguments will values,
473    /// replacing `Arg`s with `Val`s.
474    fn fill(self, named_args: NamedArgs) -> Self::Filled;
475
476    fn arg_names(&self) -> Names;
477}
478
479impl<T> ComputationFn for &T
480where
481    T: ComputationFn + ToOwned + ?Sized,
482    T::Owned: ComputationFn,
483{
484    type Filled = <T::Owned as ComputationFn>::Filled;
485
486    fn fill(self, named_args: NamedArgs) -> Self::Filled {
487        self.to_owned().fill(named_args)
488    }
489
490    fn arg_names(&self) -> Names {
491        (*(*self)).arg_names()
492    }
493}
494
495impl<T> ComputationFn for &mut T
496where
497    T: ComputationFn + ToOwned + ?Sized,
498    T::Owned: ComputationFn,
499{
500    type Filled = <T::Owned as ComputationFn>::Filled;
501
502    fn fill(self, named_args: NamedArgs) -> Self::Filled {
503        self.to_owned().fill(named_args)
504    }
505
506    fn arg_names(&self) -> Names {
507        (*(*self)).arg_names()
508    }
509}
510
511impl<T> ComputationFn for Box<T>
512where
513    T: ComputationFn + ?Sized,
514{
515    type Filled = T::Filled;
516
517    fn fill(self, named_args: NamedArgs) -> Self::Filled {
518        (*self).fill(named_args)
519    }
520
521    fn arg_names(&self) -> Names {
522        (*(*self)).arg_names()
523    }
524}
525
526impl<T> ComputationFn for std::rc::Rc<T>
527where
528    T: ComputationFn + ToOwned + ?Sized,
529    T::Owned: ComputationFn,
530{
531    type Filled = <T::Owned as ComputationFn>::Filled;
532
533    fn fill(self, named_args: NamedArgs) -> Self::Filled {
534        self.as_ref().to_owned().fill(named_args)
535    }
536
537    fn arg_names(&self) -> Names {
538        (*(*self)).arg_names()
539    }
540}
541
542impl<T> ComputationFn for std::sync::Arc<T>
543where
544    T: ComputationFn + ToOwned + ?Sized,
545    T::Owned: ComputationFn,
546{
547    type Filled = <T::Owned as ComputationFn>::Filled;
548
549    fn fill(self, named_args: NamedArgs) -> Self::Filled {
550        self.as_ref().to_owned().fill(named_args)
551    }
552
553    fn arg_names(&self) -> Names {
554        (*(*self)).arg_names()
555    }
556}
557
558impl<T> ComputationFn for std::borrow::Cow<'_, T>
559where
560    T: ComputationFn + ToOwned + ?Sized,
561    T::Owned: ComputationFn,
562{
563    type Filled = <T::Owned as ComputationFn>::Filled;
564
565    fn fill(self, named_args: NamedArgs) -> Self::Filled {
566        self.into_owned().fill(named_args)
567    }
568
569    fn arg_names(&self) -> Names {
570        (*(*self)).arg_names()
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    // The following test requires `Eq` for computation-types:
    // ```
    // #[proptest]
    // fn args_should_propagate_correctly(
    //     #[strategy(-1000..1000)] x: i32,
    //     #[strategy(-1000..1000)] y: i32,
    //     #[strategy(-1000..1000)] z: i32,
    //     #[strategy(-1000..1000)] in_x: i32,
    //     #[strategy(-1000..1000)] in_y: i32,
    //     #[strategy(-1000..1000)] in_z: i32,
    // ) {
    //     prop_assume!((x - in_y) != 0);
    //     prop_assume!(z != 0);
    //     prop_assert_eq!(
    //         (arg!("foo", i32) / (val!(x) - arg!("bar", i32))
    //             + -(val!(z) * val!(y) + arg!("baz", i32)))
    //         .fill(named_args![("foo", in_x), ("bar", in_y), ("baz", in_z)]),
    //         val!(in_x) / (val!(x) - val!(in_y)) + -(val!(z) * val!(y) + val!(in_z))
    //     );
    //     prop_assert_eq!(
    //         (arg!("foo", i32)
    //             + (((val!(x) + val!(y) - arg!("bar", i32)) / -val!(z)) * arg!("baz", i32)))
    //         .fill(named_args![("foo", in_x), ("bar", in_y), ("baz", in_z)]),
    //         val!(in_x) + (((val!(x) + val!(y) - val!(in_y)) / -val!(z)) * val!(in_z))
    //     );
    //     prop_assert_eq!(
    //         -(-arg!("foo", i32)).fill(named_args![("foo", x)]),
    //         -(-val!(x))
    //     );
    // }
    // ```

    mod dynamic {
        use ::rand::distributions::Uniform;
        use peano::Zero;
        use run::RunCore;
        use zip::{Zip, Zip3, Zip4};

        use self::rand::Rand;

        use super::*;

        /// An objective function hidden behind `Box<dyn ...>` must be usable
        /// inside a larger computation that fills and runs it repeatedly.
        #[test]
        fn the_framework_should_support_dynamic_objective_functions() {
            // Object-safe objective function whose `fill` yields a boxed runnable computation.
            trait ObjFunc:
                ComputationFn<Dim = Zero, Item = f64, Filled = Box<dyn FilledObjFunc>>
            {
                fn boxed_clone(&self) -> Box<dyn ObjFunc>;
            }
            impl<T> ObjFunc for T
            where
                T: 'static
                    + Clone
                    + ComputationFn<Dim = Zero, Item = f64, Filled = Box<dyn FilledObjFunc>>,
            {
                fn boxed_clone(&self) -> Box<dyn ObjFunc> {
                    Box::new(self.clone())
                }
            }
            impl Clone for Box<dyn ObjFunc> {
                fn clone(&self) -> Self {
                    // Dispatch on the inner trait object:
                    // `Box<dyn ObjFunc>` itself satisfies the blanket impl above,
                    // so calling `boxed_clone` on the box would re-enter this `clone`.
                    <dyn ObjFunc as ObjFunc>::boxed_clone(&**self)
                }
            }

            // Filled counterpart of `ObjFunc`: a runnable computation producing an `f64`.
            trait FilledObjFunc: Computation<Dim = Zero, Item = f64> + RunCore<Output = f64> {
                fn boxed_clone(&self) -> Box<dyn FilledObjFunc>;
            }
            impl<T> FilledObjFunc for T
            where
                T: 'static + Clone + Computation<Dim = Zero, Item = f64> + RunCore<Output = f64>,
            {
                fn boxed_clone(&self) -> Box<dyn FilledObjFunc> {
                    Box::new(self.clone())
                }
            }
            impl Clone for Box<dyn FilledObjFunc> {
                fn clone(&self) -> Self {
                    // Dispatch on the inner trait object rather than on the box.
                    <dyn FilledObjFunc as FilledObjFunc>::boxed_clone(&**self)
                }
            }

            // Adapts a `ComputationFn` whose filled form is a `FilledObjFunc`
            // so that its `Filled` type becomes the boxed trait object `ObjFunc` requires.
            #[derive(Clone, Copy, Debug)]
            struct BoxFillObjFunc<A>(A);

            impl<A> Computation for BoxFillObjFunc<A>
            where
                A: Computation,
            {
                type Dim = A::Dim;
                type Item = A::Item;
            }

            impl<A> ComputationFn for BoxFillObjFunc<A>
            where
                A: ComputationFn,
                A::Filled: 'static + FilledObjFunc,
            {
                type Filled = Box<dyn FilledObjFunc>;

                fn fill(self, named_args: NamedArgs) -> Self::Filled {
                    let filled = self.0.fill(named_args);
                    Box::new(filled)
                }

                fn arg_names(&self) -> Names {
                    self.0.arg_names()
                }
            }

            // Random search: sample `samples` points uniformly in [0, 1)^len
            // and return the one with the lowest objective value.
            fn random_optimizer(
                len: usize,
                samples: usize,
                obj_func: Box<dyn ObjFunc>,
            ) -> impl Run<Output = Vec<f64>> {
                let uniform_distributions = vec![Uniform::new(0.0, 1.0); len];

                // Iteration counter starts at `1`, paired with one random point
                // and its objective value. Built first because it clones
                // `uniform_distributions` and `obj_func`, which later stages move.
                let initial = Zip(
                    val!(1_usize),
                    Rand::<Val1<Vec<Uniform<f64>>>, f64>::new(val1!(
                        uniform_distributions.clone()
                    ))
                    .then(Function::anonymous(
                        "point",
                        Zip(arg1!("point", f64), obj_func.clone()),
                    )),
                );

                // Keep the new point only if its value beats the best so far.
                let keep_better = Zip4(
                    arg1!("best_point", f64),
                    arg!("best_value", f64),
                    arg1!("point", f64),
                    arg!("value", f64),
                )
                .if_(
                    ("best_point", "best_value", "point", "value"),
                    arg!("value", f64).lt(arg!("best_value", f64)),
                    Zip(arg1!("point", f64), arg!("value", f64)),
                    Zip(arg1!("best_point", f64), arg!("best_value", f64)),
                );

                // Evaluate the objective at the freshly sampled point, then compare.
                let evaluate_and_compare = Zip4(
                    arg1!("best_point", f64),
                    arg!("best_value", f64),
                    arg1!("point", f64),
                    obj_func,
                )
                .then(Function::anonymous(
                    ("best_point", "best_value", "point", "value"),
                    keep_better,
                ));

                // One loop iteration: bump the counter and try a new random point.
                let step = Zip(
                    arg!("i", usize) + val!(1_usize),
                    Zip3(
                        arg1!("best_point", f64),
                        arg!("best_value", f64),
                        Rand::<Val1<Vec<Uniform<f64>>>, f64>::new(val1!(uniform_distributions)),
                    )
                    .then(Function::anonymous(
                        ("best_point", "best_value", "point"),
                        evaluate_and_compare,
                    )),
                );

                initial
                    .loop_while(
                        ("i", ("best_point", "best_value")),
                        step,
                        arg!("i", usize).lt(val!(samples)),
                    )
                    .then(Function::anonymous(
                        ("i", ("best_point", "best_value")),
                        arg1!("best_point", f64),
                    ))
            }

            let objective = BoxFillObjFunc(arg1!("point", f64).sum());
            random_optimizer(2, 10, Box::new(objective)).run();
        }
    }
}