computation_types/
linalg.rs

1use core::{fmt, ops};
2use std::marker::PhantomData;
3
4use crate::{
5    impl_computation_fn_for_binary, impl_core_ops,
6    math::Mul,
7    peano::{One, Two, Zero},
8    sum::Sum,
9    Computation, ComputationFn, NamedArgs,
10};
11
12/// See [`Computation::identity_matrix`].
13#[derive(Clone, Copy, Debug)]
14pub struct IdentityMatrix<Len, T>
15where
16    Self: Computation,
17{
18    pub len: Len,
19    pub(super) ty: PhantomData<T>,
20}
21
22impl<Len, T> IdentityMatrix<Len, T>
23where
24    Self: Computation,
25{
26    pub fn new(len: Len) -> Self {
27        Self {
28            len,
29            ty: PhantomData,
30        }
31    }
32}
33
34impl<Len, T> Computation for IdentityMatrix<Len, T>
35where
36    Len: Computation<Dim = Zero, Item = usize>,
37{
38    type Dim = Two;
39    type Item = T;
40}
41
42impl<Len, T> ComputationFn for IdentityMatrix<Len, T>
43where
44    Self: Computation,
45    Len: ComputationFn,
46    IdentityMatrix<Len::Filled, T>: Computation,
47{
48    type Filled = IdentityMatrix<Len::Filled, T>;
49
50    fn fill(self, named_args: NamedArgs) -> Self::Filled {
51        IdentityMatrix {
52            len: self.len.fill(named_args),
53            ty: self.ty,
54        }
55    }
56
57    fn arg_names(&self) -> crate::Names {
58        self.len.arg_names()
59    }
60}
61
62impl_core_ops!(IdentityMatrix<Len, T>);
63
64impl<Len, T> fmt::Display for IdentityMatrix<Len, T>
65where
66    Self: Computation,
67    Len: fmt::Display,
68{
69    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70        write!(f, "identity_matrix({})", self.len)
71    }
72}
73
74/// A computation representing a diagonal matrix
75/// from an element.
76#[derive(Clone, Copy, Debug)]
77pub struct FromDiagElem<Len, Elem>
78where
79    Self: Computation,
80{
81    pub len: Len,
82    pub elem: Elem,
83}
84
85impl<Len, Elem> FromDiagElem<Len, Elem>
86where
87    Self: Computation,
88{
89    #[allow(missing_docs)]
90    pub fn new(len: Len, elem: Elem) -> Self {
91        Self { len, elem }
92    }
93}
94
95impl<Len, Elem> Computation for FromDiagElem<Len, Elem>
96where
97    Len: Computation<Dim = Zero, Item = usize>,
98    Elem: Computation<Dim = Zero>,
99{
100    type Dim = Two;
101    type Item = Elem::Item;
102}
103
104impl<Len, Elem> ComputationFn for FromDiagElem<Len, Elem>
105where
106    Self: Computation,
107    Len: ComputationFn,
108    Elem: ComputationFn,
109    FromDiagElem<Len::Filled, Elem::Filled>: Computation,
110{
111    type Filled = FromDiagElem<Len::Filled, Elem::Filled>;
112
113    fn fill(self, named_args: NamedArgs) -> Self::Filled {
114        let (args_0, args_1) = named_args
115            .partition(&self.len.arg_names(), &self.elem.arg_names())
116            .unwrap_or_else(|e| panic!("{}", e,));
117        FromDiagElem {
118            len: self.len.fill(args_0),
119            elem: self.elem.fill(args_1),
120        }
121    }
122
123    fn arg_names(&self) -> crate::Names {
124        self.len.arg_names().union(self.elem.arg_names())
125    }
126}
127
128impl_core_ops!(FromDiagElem<Len, Elem>);
129
130impl<Len, Elem> fmt::Display for FromDiagElem<Len, Elem>
131where
132    Self: Computation,
133    Len: fmt::Display,
134    Elem: fmt::Display,
135{
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        write!(f, "from_diag_elem({}, {})", self.len, self.elem)
138    }
139}
140
141/// See [`Computation::scalar_product`].
142pub type ScalarProduct<A, B> = Sum<Mul<A, B>>;
143
144/// See [`Computation::scalar_product`].
145pub fn scalar_product<A, B>(x: A, y: B) -> ScalarProduct<A, B>
146where
147    Mul<A, B>: Computation,
148    ScalarProduct<A, B>: Computation,
149{
150    Sum(Mul(x, y))
151}
152
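// `ScalarProduct` is only an alias for `Sum<Mul<A, B>>`, so giving it a
// dedicated `Display` impl would overlap with the impl that presumably
// already covers `Sum`. Once specialization (overlapping trait
// implementations) is available, we could add the following. Note that the
// operands live inside the wrapped `Mul`, i.e. at `self.0 .0` and `self.0 .1`:
//
// ```
// impl<A, B> fmt::Display for ScalarProduct<A, B>
// where
//     Self: Computation,
//     A: fmt::Display,
//     B: fmt::Display,
// {
//     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
//         write!(f, "({} . {})", self.0 .0, self.0 .1)
//     }
// }
// ```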
168
169/// See [`Computation::mat_mul`].
170#[derive(Clone, Copy, Debug)]
171pub struct MatMul<A, B>(pub A, pub B)
172where
173    Self: Computation;
174
175impl<A, B> Computation for MatMul<A, B>
176where
177    A: Computation<Dim = Two>,
178    B: Computation<Dim = Two>,
179    A::Item: ops::Mul<B::Item>,
180    <A::Item as ops::Mul<B::Item>>::Output: ops::Add,
181{
182    type Dim = Two;
183    type Item = <<A::Item as ops::Mul<B::Item>>::Output as ops::Add>::Output;
184}
185
186impl_computation_fn_for_binary!(MatMul);
187
188impl_core_ops!(MatMul<A, B>);
189
190impl<A, B> fmt::Display for MatMul<A, B>
191where
192    Self: Computation,
193    A: fmt::Display,
194    B: fmt::Display,
195{
196    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197        write!(f, "({} x {})", self.0, self.1)
198    }
199}
200
201/// See [`Computation::mul_out`].
202#[derive(Clone, Copy, Debug)]
203pub struct MulOut<A, B>(pub A, pub B)
204where
205    Self: Computation;
206
207impl<A, B> Computation for MulOut<A, B>
208where
209    A: Computation<Dim = One>,
210    B: Computation<Dim = One>,
211    A::Item: ops::Mul<B::Item>,
212{
213    type Dim = Two;
214    type Item = <A::Item as ops::Mul<B::Item>>::Output;
215}
216
217impl_computation_fn_for_binary!(MulOut);
218
219impl_core_ops!(MulOut<A, B>);
220
221impl<A, B> fmt::Display for MulOut<A, B>
222where
223    Self: Computation,
224    A: fmt::Display,
225    B: fmt::Display,
226{
227    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228        write!(f, "(col({}) x row({}))", self.0, self.1)
229    }
230}
231
232/// See [`Computation::mul_col`].
233#[derive(Clone, Copy, Debug)]
234pub struct MulCol<A, B>(pub A, pub B)
235where
236    Self: Computation;
237
238impl<A, B> Computation for MulCol<A, B>
239where
240    A: Computation<Dim = Two>,
241    B: Computation<Dim = One>,
242    A::Item: ops::Mul<B::Item>,
243    <A::Item as ops::Mul<B::Item>>::Output: ops::Add,
244{
245    type Dim = One;
246    type Item = <<A::Item as ops::Mul<B::Item>>::Output as ops::Add>::Output;
247}
248
249impl_computation_fn_for_binary!(MulCol);
250
251impl_core_ops!(MulCol<A, B>);
252
253impl<A, B> fmt::Display for MulCol<A, B>
254where
255    Self: Computation,
256    A: fmt::Display,
257    B: fmt::Display,
258{
259    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260        write!(f, "({} x col({}))", self.0, self.1)
261    }
262}
263
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{linalg::FromDiagElem, run::Matrix, val, val1, val2, Computation};

    // Each test renders the expected string from the operands first,
    // then consumes the operands to build the computation under test.

    #[proptest]
    fn identity_matrix_should_display(x: usize) {
        let len = val!(x);
        let expected = format!("identity_matrix({})", len);
        prop_assert_eq!(len.identity_matrix::<i32>().to_string(), expected);
    }

    #[proptest]
    fn from_diag_elem_should_display(len: usize, elem: i32) {
        let (len, elem) = (val!(len), val!(elem));
        let expected = format!("from_diag_elem({}, {})", len, elem);
        prop_assert_eq!(FromDiagElem::new(len, elem).to_string(), expected);
    }

    // Once specialization (overlapping trait implementations) allows
    // `ScalarProduct` its own `Display` impl, this test could be enabled:
    //
    // ```
    // #[proptest]
    // fn scalar_product_should_display(x: i32, y: i32, z: i32, q: i32) {
    //     let lhs = val1!([x, y]);
    //     let rhs = val1!([z, q]);
    //     prop_assert_eq!(
    //         lhs.scalar_product(rhs).to_string(),
    //         format!("({} . {})", lhs, rhs)
    //     );
    // }
    // ```

    #[proptest]
    fn mat_mul_should_display(x: i32, y: i32, z: i32, q: i32) {
        let lhs = val2!(Matrix::from_vec((2, 1), vec![x, y]).unwrap());
        let rhs = val2!(Matrix::from_vec((1, 2), vec![z, q]).unwrap());
        let expected = format!("({} x {})", lhs, rhs);
        prop_assert_eq!(lhs.mat_mul(rhs).to_string(), expected);
    }

    #[proptest]
    fn mul_out_should_display(x: i32, y: i32, z: i32, q: i32) {
        let (lhs, rhs) = (val1!([x, y]), val1!([z, q]));
        let expected = format!("(col({}) x row({}))", lhs, rhs);
        prop_assert_eq!(lhs.mul_out(rhs).to_string(), expected);
    }

    #[proptest]
    fn mul_col_should_display(x: i32, y: i32, z: i32, q: i32) {
        let lhs = val2!(Matrix::from_vec((2, 1), vec![x, y]).unwrap());
        let rhs = val1!([z, q]);
        let expected = format!("({} x col({}))", lhs, rhs);
        prop_assert_eq!(lhs.mul_col(rhs).to_string(), expected);
    }
}