1use core::{fmt, ops};
2use std::marker::PhantomData;
3
4use crate::{
5 impl_computation_fn_for_binary, impl_core_ops,
6 math::Mul,
7 peano::{One, Two, Zero},
8 sum::Sum,
9 Computation, ComputationFn, NamedArgs,
10};
11
12#[derive(Clone, Copy, Debug)]
14pub struct IdentityMatrix<Len, T>
15where
16 Self: Computation,
17{
18 pub len: Len,
19 pub(super) ty: PhantomData<T>,
20}
21
22impl<Len, T> IdentityMatrix<Len, T>
23where
24 Self: Computation,
25{
26 pub fn new(len: Len) -> Self {
27 Self {
28 len,
29 ty: PhantomData,
30 }
31 }
32}
33
34impl<Len, T> Computation for IdentityMatrix<Len, T>
35where
36 Len: Computation<Dim = Zero, Item = usize>,
37{
38 type Dim = Two;
39 type Item = T;
40}
41
42impl<Len, T> ComputationFn for IdentityMatrix<Len, T>
43where
44 Self: Computation,
45 Len: ComputationFn,
46 IdentityMatrix<Len::Filled, T>: Computation,
47{
48 type Filled = IdentityMatrix<Len::Filled, T>;
49
50 fn fill(self, named_args: NamedArgs) -> Self::Filled {
51 IdentityMatrix {
52 len: self.len.fill(named_args),
53 ty: self.ty,
54 }
55 }
56
57 fn arg_names(&self) -> crate::Names {
58 self.len.arg_names()
59 }
60}
61
62impl_core_ops!(IdentityMatrix<Len, T>);
63
64impl<Len, T> fmt::Display for IdentityMatrix<Len, T>
65where
66 Self: Computation,
67 Len: fmt::Display,
68{
69 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70 write!(f, "identity_matrix({})", self.len)
71 }
72}
73
74#[derive(Clone, Copy, Debug)]
77pub struct FromDiagElem<Len, Elem>
78where
79 Self: Computation,
80{
81 pub len: Len,
82 pub elem: Elem,
83}
84
85impl<Len, Elem> FromDiagElem<Len, Elem>
86where
87 Self: Computation,
88{
89 #[allow(missing_docs)]
90 pub fn new(len: Len, elem: Elem) -> Self {
91 Self { len, elem }
92 }
93}
94
95impl<Len, Elem> Computation for FromDiagElem<Len, Elem>
96where
97 Len: Computation<Dim = Zero, Item = usize>,
98 Elem: Computation<Dim = Zero>,
99{
100 type Dim = Two;
101 type Item = Elem::Item;
102}
103
104impl<Len, Elem> ComputationFn for FromDiagElem<Len, Elem>
105where
106 Self: Computation,
107 Len: ComputationFn,
108 Elem: ComputationFn,
109 FromDiagElem<Len::Filled, Elem::Filled>: Computation,
110{
111 type Filled = FromDiagElem<Len::Filled, Elem::Filled>;
112
113 fn fill(self, named_args: NamedArgs) -> Self::Filled {
114 let (args_0, args_1) = named_args
115 .partition(&self.len.arg_names(), &self.elem.arg_names())
116 .unwrap_or_else(|e| panic!("{}", e,));
117 FromDiagElem {
118 len: self.len.fill(args_0),
119 elem: self.elem.fill(args_1),
120 }
121 }
122
123 fn arg_names(&self) -> crate::Names {
124 self.len.arg_names().union(self.elem.arg_names())
125 }
126}
127
128impl_core_ops!(FromDiagElem<Len, Elem>);
129
130impl<Len, Elem> fmt::Display for FromDiagElem<Len, Elem>
131where
132 Self: Computation,
133 Len: fmt::Display,
134 Elem: fmt::Display,
135{
136 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137 write!(f, "from_diag_elem({}, {})", self.len, self.elem)
138 }
139}
140
141pub type ScalarProduct<A, B> = Sum<Mul<A, B>>;
143
144pub fn scalar_product<A, B>(x: A, y: B) -> ScalarProduct<A, B>
146where
147 Mul<A, B>: Computation,
148 ScalarProduct<A, B>: Computation,
149{
150 Sum(Mul(x, y))
151}
152
153#[derive(Clone, Copy, Debug)]
171pub struct MatMul<A, B>(pub A, pub B)
172where
173 Self: Computation;
174
175impl<A, B> Computation for MatMul<A, B>
176where
177 A: Computation<Dim = Two>,
178 B: Computation<Dim = Two>,
179 A::Item: ops::Mul<B::Item>,
180 <A::Item as ops::Mul<B::Item>>::Output: ops::Add,
181{
182 type Dim = Two;
183 type Item = <<A::Item as ops::Mul<B::Item>>::Output as ops::Add>::Output;
184}
185
186impl_computation_fn_for_binary!(MatMul);
187
188impl_core_ops!(MatMul<A, B>);
189
190impl<A, B> fmt::Display for MatMul<A, B>
191where
192 Self: Computation,
193 A: fmt::Display,
194 B: fmt::Display,
195{
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 write!(f, "({} x {})", self.0, self.1)
198 }
199}
200
201#[derive(Clone, Copy, Debug)]
203pub struct MulOut<A, B>(pub A, pub B)
204where
205 Self: Computation;
206
207impl<A, B> Computation for MulOut<A, B>
208where
209 A: Computation<Dim = One>,
210 B: Computation<Dim = One>,
211 A::Item: ops::Mul<B::Item>,
212{
213 type Dim = Two;
214 type Item = <A::Item as ops::Mul<B::Item>>::Output;
215}
216
217impl_computation_fn_for_binary!(MulOut);
218
219impl_core_ops!(MulOut<A, B>);
220
221impl<A, B> fmt::Display for MulOut<A, B>
222where
223 Self: Computation,
224 A: fmt::Display,
225 B: fmt::Display,
226{
227 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228 write!(f, "(col({}) x row({}))", self.0, self.1)
229 }
230}
231
232#[derive(Clone, Copy, Debug)]
234pub struct MulCol<A, B>(pub A, pub B)
235where
236 Self: Computation;
237
238impl<A, B> Computation for MulCol<A, B>
239where
240 A: Computation<Dim = Two>,
241 B: Computation<Dim = One>,
242 A::Item: ops::Mul<B::Item>,
243 <A::Item as ops::Mul<B::Item>>::Output: ops::Add,
244{
245 type Dim = One;
246 type Item = <<A::Item as ops::Mul<B::Item>>::Output as ops::Add>::Output;
247}
248
249impl_computation_fn_for_binary!(MulCol);
250
251impl_core_ops!(MulCol<A, B>);
252
253impl<A, B> fmt::Display for MulCol<A, B>
254where
255 Self: Computation,
256 A: fmt::Display,
257 B: fmt::Display,
258{
259 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260 write!(f, "({} x col({}))", self.0, self.1)
261 }
262}
263
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use test_strategy::proptest;

    use crate::{linalg::FromDiagElem, run::Matrix, val, val1, val2, Computation};

    /// An identity matrix renders as `identity_matrix(<len>)`.
    #[proptest]
    fn identity_matrix_should_display(x: usize) {
        let inp = val!(x);
        let rendered = inp.identity_matrix::<i32>().to_string();
        prop_assert_eq!(rendered, format!("identity_matrix({})", inp));
    }

    /// A `FromDiagElem` renders as `from_diag_elem(<len>, <elem>)`.
    #[proptest]
    fn from_diag_elem_should_display(len: usize, elem: i32) {
        let len = val!(len);
        let elem = val!(elem);
        let rendered = FromDiagElem::new(len, elem).to_string();
        prop_assert_eq!(rendered, format!("from_diag_elem({}, {})", len, elem));
    }

    /// A matrix product renders as `(<lhs> x <rhs>)`.
    #[proptest]
    fn mat_mul_should_display(x: i32, y: i32, z: i32, q: i32) {
        let lhs = val2!(Matrix::from_vec((2, 1), vec![x, y]).unwrap());
        let rhs = val2!(Matrix::from_vec((1, 2), vec![z, q]).unwrap());
        // The operands are cloned because they are formatted again below.
        let rendered = lhs.clone().mat_mul(rhs.clone()).to_string();
        prop_assert_eq!(rendered, format!("({} x {})", lhs, rhs));
    }

    /// An outer product renders as `(col(<lhs>) x row(<rhs>))`.
    #[proptest]
    fn mul_out_should_display(x: i32, y: i32, z: i32, q: i32) {
        let lhs = val1!([x, y]);
        let rhs = val1!([z, q]);
        let rendered = lhs.mul_out(rhs).to_string();
        prop_assert_eq!(rendered, format!("(col({}) x row({}))", lhs, rhs));
    }

    /// A matrix-by-column product renders as `(<matrix> x col(<vector>))`.
    #[proptest]
    fn mul_col_should_display(x: i32, y: i32, z: i32, q: i32) {
        let lhs = val2!(Matrix::from_vec((2, 1), vec![x, y]).unwrap());
        let rhs = val1!([z, q]);
        // The matrix is cloned because it is formatted again below.
        let rendered = lhs.clone().mul_col(rhs).to_string();
        prop_assert_eq!(rendered, format!("({} x col({}))", lhs, rhs));
    }
}