linalg_rs/matrix/helper.rs

1use std::{error::Error, ops::RangeInclusive, str::FromStr};
2
3use rayon::prelude::*;
4
5use crate::{at, Matrix, MatrixElement};
6
/// Swaps the values behind two mutable `usize` references in place.
///
/// Thin wrapper over [`std::mem::swap`], kept as part of this module's public API.
pub fn swap(lhs: &mut usize, rhs: &mut usize) {
    std::mem::swap(lhs, rhs);
}
12
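// Determinant and matrix-multiplication helpers (no SIMD intrinsics are used yet; parallelism comes from rayon).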
14impl<'a, T> Matrix<'a, T>
15where
16    T: MatrixElement + 'a,
17    <T as FromStr>::Err: Error + 'static,
18    Vec<T>: IntoParallelIterator,
19    Vec<&'a T>: IntoParallelRefIterator<'a>,
20{
21    pub fn determinant_helper(&self) -> T {
22        match self.nrows {
23            1 => self.at(0, 0),
24            2 => Self::det_2x2(self),
25            3 => Self::det_3x3(self),
26            n => Self::det_nxn(self.data.clone(), n),
27        }
28    }
29
30    // General helper function calling out to other matmuls based on target architecture
31    pub fn matmul_helper(&self, other: &Self) -> Self {
32        match (self.shape(), other.shape()) {
33            ((1, 2), (2, 1)) => return self.onetwo_by_twoone(other),
34            ((2, 2), (2, 1)) => return self.twotwo_by_twoone(other),
35            ((1, 2), (2, 2)) => return self.onetwo_by_twotwo(other),
36            ((2, 2), (2, 2)) => return self.twotwo_by_twotwo(other),
37            _ => {}
38        };
39
40        // Target Detection
41
42        // if let Some(result) = optim::get_optimized_matmul(self, other) {
43        //     return result;
44        // }
45
46        let blck_size = Self::get_block_size(self, other);
47
48        // println!("BS: {}", blck_size);
49
50        if self.shape() == other.shape() {
51            // Calculated from lowest possible size where
52            // nrows & blck_size == 0.
53            // Block size will never be more than 50
54            return Self::blocked_matmul(self, other, blck_size);
55        }
56
57        Self::optimized_blocked_matmul(self, other, blck_size)
58    }
59
60    // Calculate efficient blocksize
61    #[inline(always)]
62    pub fn get_block_size(&self, other: &Self) -> usize {
63        let range = Self::get_range_for_block_size(self, other);
64
65        range
66            .collect::<Vec<_>>()
67            .into_par_iter()
68            .find_last(|b| self.ncols % b == 0 || self.nrows % b == 0 || other.ncols % b == 0)
69            .unwrap()
70    }
71
72    #[inline(always)]
73    pub fn get_range_for_block_size(&self, other: &Self) -> RangeInclusive<usize> {
74        if self.nrows < 30 && self.ncols < 30 || other.nrows < 30 && other.ncols < 30 {
75            2..=10
76        } else if self.nrows < 100 && self.ncols < 100 || other.nrows < 100 && other.ncols < 100 {
77            10..=30
78        } else {
79            30..=50
80        }
81    }
82
83    // ===================================================
84    //           Determinant
85    // ===================================================
86
87    #[inline(always)]
88    fn det_2x2(&self) -> T {
89        self.at(0, 0) * self.at(1, 1) - self.at(0, 1) * self.at(1, 0)
90    }
91
92    #[inline(always)]
93    fn det_3x3(&self) -> T {
94        let a = self.at(0, 0);
95        let b = self.at(0, 1);
96        let c = self.at(0, 2);
97        let d = self.at(1, 0);
98        let e = self.at(1, 1);
99        let f = self.at(1, 2);
100        let g = self.at(2, 0);
101        let h = self.at(2, 1);
102        let i = self.at(2, 2);
103
104        a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
105    }
106
107    fn det_nxn(matrix: Vec<T>, n: usize) -> T {
108        if n == 1 {
109            return matrix[0];
110        }
111
112        let mut det = T::zero();
113        let mut sign = T::one();
114
115        for col in 0..n {
116            let sub_det = Self::det_nxn(Self::submatrix(matrix.clone(), n, 0, col), n - 1);
117
118            det += sign * matrix[col] * sub_det;
119
120            sign *= -T::one();
121        }
122
123        det
124    }
125
126    fn submatrix(matrix: Vec<T>, n: usize, row_to_remove: usize, col_to_remove: usize) -> Vec<T> {
127        matrix
128            .par_iter()
129            .enumerate()
130            .filter_map(|(i, &value)| {
131                let row = i / n;
132                let col = i % n;
133                if row != row_to_remove && col != col_to_remove {
134                    Some(value)
135                } else {
136                    None
137                }
138            })
139            .collect()
140    }
141
142    // ===================================================
143    //           Extremely specific optimizations
144    // ===================================================
145
146    // 1x2 @ 2x1 matrix
147    #[inline(always)]
148    fn onetwo_by_twoone(&self, other: &Self) -> Self {
149        let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
150
151        Self::new(vec![a], (1, 1)).unwrap()
152    }
153
154    // 2x2 @ 2x1 matrix
155    #[inline(always)]
156    fn twotwo_by_twoone(&self, other: &Self) -> Self {
157        let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
158        let b = self.at(1, 0) * other.at(0, 0) + self.at(1, 1) * other.at(1, 0);
159
160        Self::new(vec![a, b], (2, 1)).unwrap()
161    }
162
163    //
164    // 1x2 @ 2x2 matrix
165    #[inline(always)]
166    fn onetwo_by_twotwo(&self, other: &Self) -> Self {
167        let a = self.at(0, 0) * other.at(0, 0) + self.at(0, 1) * other.at(1, 0);
168        let b = self.at(0, 0) * other.at(1, 0) + self.at(0, 1) * other.at(1, 1);
169
170        Self::new(vec![a, b], (1, 2)).unwrap()
171    }
172
173    // 2x2 @ 2x2 matrix
174    #[inline(always)]
175    fn twotwo_by_twotwo(&self, other: &Self) -> Self {
176        let a = self.at(0, 0) * other.at(0, 0) + self.at(1, 0) * other.at(1, 0);
177        let b = self.at(0, 0) * other.at(0, 1) + self.at(0, 1) * other.at(1, 1);
178        let c = self.at(1, 0) * other.at(0, 0) + self.at(1, 1) * other.at(1, 0);
179        let d = self.at(1, 0) * other.at(1, 0) + self.at(1, 1) * other.at(1, 1);
180
181        Self::new(vec![a, b, c, d], (2, 2)).unwrap()
182    }
183
184    // ========================================================================
185    //
186    //    General solutions for matrix multiplication
187    //
188    // ========================================================================
189
190    /// Naive matmul if you don't have any SIMD intrinsincts
191    ///
192    /// Also blocked, but doing different than just N
193    fn optimized_blocked_matmul(&self, other: &Self, block_size: usize) -> Self {
194        let M = self.nrows;
195        let N = self.ncols;
196        let P = other.ncols;
197
198        let mut data = vec![T::zero(); M * P];
199
200        //let t_other = other.transpose_copy();
201
202        for kk in (0..N).step_by(block_size) {
203            for jj in (0..P).step_by(block_size) {
204                for ii in (0..M).step_by(block_size) {
205                    let block_end_i = (ii + block_size).min(M);
206                    let block_end_j = (jj + block_size).min(P);
207                    let block_end_k = (kk + block_size).min(N);
208
209                    // Blocking for L0 memory
210                    for i in ii..block_end_i {
211                        for j in jj..block_end_j {
212                            // for k in kk..block_end_k {
213                            //     data[at!(i, j, P)] += self.at(i, k) * other.at(k, j);
214                            // }
215                            data[at!(i, j, P)] = (kk..block_end_k)
216                                .into_par_iter()
217                                .map(|k| self.at(i, k) * other.at(k, j))
218                                .sum();
219                        }
220                    }
221                }
222            }
223        }
224        Self::new(data, (M, P)).unwrap()
225    }
226
227    // SUMMA Algorithm
228    // https://www.netlib.org/lapack/lawnspdf/lawn96.pdf
229    fn summa(&self, other: &Self, block_size: usize) -> Self {
230        todo!()
231    }
232
233    // The magnum opus of matrix multiply, also known as naive matmul
234    // Only optimization is a parallelized innermost summation
235    fn naive(&self, other: &Self) -> Self {
236        let M = self.nrows;
237        let N = self.ncols;
238        let P = other.ncols;
239
240        let mut data = vec![T::zero(); M * P];
241
242        for i in 0..M {
243            for j in 0..P {
244                data[at!(i, j, P)] = (0..N)
245                    .into_par_iter()
246                    .map(|k| self.at(i, k) * other.at(k, j))
247                    .sum();
248            }
249        }
250
251        Self::new(data, (M, P)).unwrap()
252    }
253
254    // Blocked matmul if you don't have any SIMD intrinsincts
255    // https://csapp.cs.cmu.edu/public/waside/waside-blocking.pdf
256    //
257    // Modification involves transposing the B matrix, at the cost
258    // of increased space complexity, but better cache hit rate
259    //
260    // NOTE: Only works for M N @ N M matrices for now
261    fn blocked_matmul(&self, other: &Self, block_size: usize) -> Self {
262        let n = self.nrows;
263
264        let en = block_size * (n / block_size);
265
266        let mut data = vec![T::zero(); n * n];
267
268        let t_other = other.transpose_copy();
269
270        for kk in (0..n).step_by(en) {
271            for jj in (0..n).step_by(en) {
272                for i in 0..n {
273                    for j in jj..jj + block_size {
274                        data[at!(i, j, n)] = (kk..kk + block_size)
275                            .into_par_iter()
276                            .map(|k| self.at(i, k) * t_other.at(j, k))
277                            .sum();
278                    }
279                }
280            }
281        }
282        Self::new(data, (n, n)).unwrap()
283    }
284}