1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//! Non-element-wise operations, such as the dot product and matrix multiplication.
//! As such, they need to be called explicitly.

use crate::{
	types::{Matrix, Vector},
	view::{TransposedMatrixView, VectorView},
};

impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
	/// Returns a [`TransposedMatrixView`] that borrows this matrix.
	///
	/// The view only stores a reference to `self`; its const parameters are
	/// this matrix's `M` and `N` in swapped order (`N, M`).
	pub fn transpose(&self) -> TransposedMatrixView<'_, T, N, M> {
		let matrix = self;
		TransposedMatrixView { matrix }
	}
}

impl<'a, 'b, T: 'a + 'b + Clone + Copy + Default, const M: usize, const N: usize> Matrix<T, M, N>
where
	&'a T: core::ops::Mul<&'b T, Output = T>,
	T: core::iter::Sum,
{
	/// Multiplies `self` with `other` and returns the product as a new matrix.
	///
	/// The result starts out as `Matrix::default()`. For every vector yielded by
	/// iterating `other` (paired with an index `outer` in `0..O`) and every
	/// vector view yielded by iterating `self.transpose()` (paired with an index
	/// `inner` in `0..M`), the entry `output[outer][inner]` is set to the dot
	/// product of that view with that vector.
	///
	/// NOTE(review): entries are only overwritten for as many items as the two
	/// iterators yield; presumably that is exactly `O` and `M` — confirm against
	/// the `IntoIterator` impls.
	// todo: move into trait so this can be the default implementation, overrideable at another point.
	pub fn matrix_multiply<const O: usize>(
		&'a self,
		other: &'b Matrix<T, N, O>,
	) -> Matrix<T, M, O> {
		// todo: do this without default-initializing
		// The explicit type pins the result's dimensions up front, so the
		// indexing below type-checks without any helper branch.
		let mut output: Matrix<T, M, O> = Matrix::default();
		let transposed: TransposedMatrixView<T, N, M> = self.transpose();

		for (outer, other_vec) in (0..O).zip(other) {
			let other_vec: &'b Vector<T, N> = other_vec;
			let out_vec = &mut output[outer];
			// `transposed` is passed by value on every outer iteration, exactly
			// as before; this relies on the view being `Copy`.
			for (inner, self_vec) in (0..M).zip(transposed) {
				let self_vec: VectorView<T, N, M> = self_vec;
				let entry: &mut T = &mut out_vec[inner];
				*entry = self_vec.dot(other_vec);
			}
		}
		output
	}
}

impl<'a, 'b, T: 'a + 'b, const M: usize, const N: usize> VectorView<'a, T, M, N>
where
	&'a T: core::ops::Mul<&'b T, Output = T>,
	T: core::iter::Sum,
{
	/// Dot product of this view with `other`.
	///
	/// Applies the `*` operator to the view and `other` (presumably the
	/// element-wise product — confirm against the `Mul` impl), then sums the
	/// items of the result. Consumes the view.
	pub fn dot(self, other: &'b Vector<T, M>) -> T {
		let products = self * other;
		products.into_iter().sum()
	}
}

impl<'a, 'b, T: 'a + 'b, const M: usize> Vector<T, M>
where
	&'a T: core::ops::Mul<&'b T, Output = T>,
	T: core::iter::Sum,
{
	/// Dot product of this vector with `other`.
	///
	/// Applies the `*` operator to the two vector references (presumably the
	/// element-wise product — confirm against the `Mul` impl), then sums the
	/// items of the result.
	pub fn dot(&'a self, other: &'b Vector<T, M>) -> T {
		let products = self * other;
		products.into_iter().sum()
	}
}

// Smoke test: checks that multiplying a `Matrix<f32, 2, 3>` by a
// `Matrix<f32, 3, 4>` type-checks as a `Matrix<f32, 2, 4>` and runs without
// panicking. The resulting values are not asserted.
#[test]
fn matrix_multiply() {
	use rand::{thread_rng, Rng};

	let mut rng = thread_rng();
	let lhs: Matrix<f32, 2, 3> = rng.gen();
	let rhs: Matrix<f32, 3, 4> = rng.gen();

	let _product: Matrix<f32, 2, 4> = lhs.matrix_multiply(&rhs);
}