1use std::{error::Error, ops::RangeInclusive, str::FromStr};
2
3use rayon::prelude::*;
4
5use crate::{at, Matrix, MatrixElement};
6
/// Swaps the values behind two mutable `usize` references.
///
/// Thin wrapper over [`std::mem::swap`], kept so existing callers keep working.
pub fn swap(lhs: &mut usize, rhs: &mut usize) {
    std::mem::swap(lhs, rhs);
}
12
13impl<'a, T> Matrix<'a, T>
15where
16 T: MatrixElement + 'a,
17 <T as FromStr>::Err: Error + 'static,
18 Vec<T>: IntoParallelIterator,
19 Vec<&'a T>: IntoParallelRefIterator<'a>,
20{
21 pub fn determinant_helper(&self) -> T {
22 match self.nrows {
23 1 => self.at(0, 0),
24 2 => Self::det_2x2(self),
25 3 => Self::det_3x3(self),
26 n => Self::det_nxn(self.data.clone(), n),
27 }
28 }
29
30 pub fn matmul_helper(&self, other: &Self) -> Self {
32 match (self.shape(), other.shape()) {
33 ((1, 2), (2, 1)) => return self.onetwo_by_twoone(other),
34 ((2, 2), (2, 1)) => return self.twotwo_by_twoone(other),
35 ((1, 2), (2, 2)) => return self.onetwo_by_twotwo(other),
36 ((2, 2), (2, 2)) => return self.twotwo_by_twotwo(other),
37 _ => {}
38 };
39
40 let blck_size = Self::get_block_size(self, other);
47
48 if self.shape() == other.shape() {
51 return Self::blocked_matmul(self, other, blck_size);
55 }
56
57 Self::optimized_blocked_matmul(self, other, blck_size)
58 }
59
60 #[inline(always)]
62 pub fn get_block_size(&self, other: &Self) -> usize {
63 let range = Self::get_range_for_block_size(self, other);
64
65 range
66 .collect::<Vec<_>>()
67 .into_par_iter()
68 .find_last(|b| self.ncols % b == 0 || self.nrows % b == 0 || other.ncols % b == 0)
69 .unwrap()
70 }
71
72 #[inline(always)]
73 pub fn get_range_for_block_size(&self, other: &Self) -> RangeInclusive<usize> {
74 if self.nrows < 30 && self.ncols < 30 || other.nrows < 30 && other.ncols < 30 {
75 2..=10
76 } else if self.nrows < 100 && self.ncols < 100 || other.nrows < 100 && other.ncols < 100 {
77 10..=30
78 } else {
79 30..=50
80 }
81 }
82
83 #[inline(always)]
88 fn det_2x2(&self) -> T {
89 self.at(0, 0) * self.at(1, 1) - self.at(0, 1) * self.at(1, 0)
90 }
91
92 #[inline(always)]
93 fn det_3x3(&self) -> T {
94 let a = self.at(0, 0);
95 let b = self.at(0, 1);
96 let c = self.at(0, 2);
97 let d = self.at(1, 0);
98 let e = self.at(1, 1);
99 let f = self.at(1, 2);
100 let g = self.at(2, 0);
101 let h = self.at(2, 1);
102 let i = self.at(2, 2);
103
104 a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
105 }
106
107 fn det_nxn(matrix: Vec<T>, n: usize) -> T {
108 if n == 1 {
109 return matrix[0];
110 }
111
112 let mut det = T::zero();
113 let mut sign = T::one();
114
115 for col in 0..n {
116 let sub_det = Self::det_nxn(Self::submatrix(matrix.clone(), n, 0, col), n - 1);
117
118 det += sign * matrix[col] * sub_det;
119
120 sign *= -T::one();
121 }
122
123 det
124 }
125
126 fn submatrix(matrix: Vec<T>, n: usize, row_to_remove: usize, col_to_remove: usize) -> Vec<T> {
127 matrix
128 .par_iter()
129 .enumerate()
130 .filter_map(|(i, &value)| {
131 let row = i / n;
132 let col = i % n;
133 if row != row_to_remove && col != col_to_remove {
134 Some(value)
135 } else {
136 None
137 }
138 })
139 .collect()
140 }
141
142 #[inline(always)]
148 fn onetwo_by_twoone(&self, other: &Self) -> Self {
149 let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
150
151 Self::new(vec![a], (1, 1)).unwrap()
152 }
153
154 #[inline(always)]
156 fn twotwo_by_twoone(&self, other: &Self) -> Self {
157 let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
158 let b = self.at(1, 0) * other.at(0, 0) + self.at(1, 1) * other.at(1, 0);
159
160 Self::new(vec![a, b], (2, 1)).unwrap()
161 }
162
163 #[inline(always)]
166 fn onetwo_by_twotwo(&self, other: &Self) -> Self {
167 let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
168 let b = self.at(0, 0) * other.at(1, 0) + self.at(0, 1) * other.at(1, 1);
169
170 Self::new(vec![a, b], (1, 2)).unwrap()
171 }
172
173 #[inline(always)]
175 fn twotwo_by_twotwo(&self, other: &Self) -> Self {
176 let a = self.at(0, 0) * other.at(0, 0) + self.at(1, 0) * other.at(1, 0);
177 let b = self.at(0, 0) * other.at(0, 1) + self.at(0, 1) * other.at(1, 1);
178 let c = self.at(1, 0) * other.at(0, 0) + self.at(1, 1) * other.at(1, 0);
179 let d = self.at(1, 0) * other.at(1, 0) + self.at(1, 1) * other.at(1, 1);
180
181 Self::new(vec![a, b, c, d], (2, 2)).unwrap()
182 }
183
184 fn optimized_blocked_matmul(&self, other: &Self, block_size: usize) -> Self {
194 let M = self.nrows;
195 let N = self.ncols;
196 let P = other.ncols;
197
198 let mut data = vec![T::zero(); M * P];
199
200 for kk in (0..N).step_by(block_size) {
203 for jj in (0..P).step_by(block_size) {
204 for ii in (0..M).step_by(block_size) {
205 let block_end_i = (ii + block_size).min(M);
206 let block_end_j = (jj + block_size).min(P);
207 let block_end_k = (kk + block_size).min(N);
208
209 for i in ii..block_end_i {
211 for j in jj..block_end_j {
212 data[at!(i, j, P)] = (kk..block_end_k)
216 .into_par_iter()
217 .map(|k| self.at(i, k) * other.at(k, j))
218 .sum();
219 }
220 }
221 }
222 }
223 }
224 Self::new(data, (M, P)).unwrap()
225 }
226
227 fn summa(&self, other: &Self, block_size: usize) -> Self {
230 todo!()
231 }
232
233 fn naive(&self, other: &Self) -> Self {
236 let M = self.nrows;
237 let N = self.ncols;
238 let P = other.ncols;
239
240 let mut data = vec![T::zero(); M * P];
241
242 for i in 0..M {
243 for j in 0..P {
244 data[at!(i, j, P)] = (0..N)
245 .into_par_iter()
246 .map(|k| self.at(i, k) * other.at(k, j))
247 .sum();
248 }
249 }
250
251 Self::new(data, (M, P)).unwrap()
252 }
253
254 fn blocked_matmul(&self, other: &Self, block_size: usize) -> Self {
262 let n = self.nrows;
263
264 let en = block_size * (n / block_size);
265
266 let mut data = vec![T::zero(); n * n];
267
268 let t_other = other.transpose_copy();
269
270 for kk in (0..n).step_by(en) {
271 for jj in (0..n).step_by(en) {
272 for i in 0..n {
273 for j in jj..jj + block_size {
274 data[at!(i, j, n)] = (kk..kk + block_size)
275 .into_par_iter()
276 .map(|k| self.at(i, k) * t_other.at(j, k))
277 .sum();
278 }
279 }
280 }
281 }
282 Self::new(data, (n, n)).unwrap()
283 }
284}