// arcis_compiler/utils/matrix.rs

1use ff::Field;
2use num_traits::Zero;
3use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub, SubAssign};
4
/// A dense matrix whose entries are stored row by row: entry `(row, col)` sits at
/// `data[row * ncols + col]`.
///
/// The constructors keep `data.len() == nrows * ncols`. Because `nrows` and `ncols` are
/// public, callers that mutate them are responsible for preserving that relationship.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Matrix<T>
where
    T: Copy,
{
    /// Entries in row-major order.
    data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
11
12impl<T: Copy> Matrix<T> {
13    /// Builds a Matrix.
14    /// * size should be nrows then ncols
15    /// * item is what will fill the matrix
16    pub fn new(size: (usize, usize), item: T) -> Self {
17        Matrix {
18            data: vec![item; size.0 * size.1],
19            nrows: size.0,
20            ncols: size.1,
21        }
22    }
23    /// Builds a Matrix.
24    /// * size should be nrows then ncols
25    /// * iterator is what will fill the matrix. It will be wholly consumed, and will fail if size
26    ///   is wrong.
27    pub fn new_from_iter<U: Iterator<Item = T>>(size: (usize, usize), iterator: U) -> Self {
28        let data: Vec<T> = iterator.collect();
29        assert_eq!(
30            data.len(),
31            size.0 * size.1,
32            "iterator of size {} for matrix of size {}x{}",
33            data.len(),
34            size.0,
35            size.1
36        );
37        Matrix {
38            data,
39            nrows: size.0,
40            ncols: size.1,
41        }
42    }
43
44    pub fn new_from_column_major_iter<U: Iterator<Item = T>>(
45        size: (usize, usize),
46        iterator: U,
47    ) -> Self {
48        let column_major_data = iterator.collect::<Vec<T>>();
49        assert_eq!(
50            column_major_data.len(),
51            size.0 * size.1,
52            "iterator of size {} for matrix of size {}x{}",
53            column_major_data.len(),
54            size.0,
55            size.1
56        );
57        let row_major_data = (0..size.0)
58            .flat_map(|i| {
59                (0..size.1)
60                    .map(|j| column_major_data[i + j * size.0])
61                    .collect::<Vec<T>>()
62            })
63            .collect();
64
65        Matrix {
66            data: row_major_data,
67            nrows: size.0,
68            ncols: size.1,
69        }
70    }
71
72    fn index(&self, x: usize, y: usize) -> usize {
73        x * self.ncols + y
74    }
75    pub fn get(&self, location: (usize, usize)) -> Option<&T> {
76        let (x, y) = location;
77        let index = self.index(x, y);
78        self.data.get(index)
79    }
80    pub fn get_mut(&mut self, location: (usize, usize)) -> Option<&mut T> {
81        let (x, y) = location;
82        let index = self.index(x, y);
83        self.data.get_mut(index)
84    }
85    pub fn col(&self, index: usize) -> Self {
86        if index >= self.ncols {
87            panic!(
88                "index for column extraction must be less than {} (found {})",
89                self.ncols, index
90            );
91        }
92        let data = (index..self.nrows * self.ncols)
93            .step_by(self.ncols)
94            .map(|i| self.data[i])
95            .collect();
96        Self {
97            data,
98            nrows: self.nrows,
99            ncols: 1,
100        }
101    }
102    /// Applies a function to all items in self, in-place.
103    pub fn map_mut(&mut self, mut f: impl FnMut(T) -> T) {
104        for x in 0..self.nrows {
105            for y in 0..self.ncols {
106                self[(x, y)] = f(self[(x, y)]);
107            }
108        }
109    }
110    /// Matrix Multiplication
111    pub fn mat_mul<U: Copy + Add<Output = U> + Mul<T, Output = U> + Zero>(
112        &self,
113        rhs: &Matrix<U>,
114    ) -> Matrix<U> {
115        assert_eq!(self.ncols, rhs.nrows);
116        let mut mat: Matrix<U> = Matrix::new((self.nrows, rhs.ncols), U::zero());
117        for i in 0..self.nrows {
118            for j in 0..rhs.ncols {
119                let acc = mat.get_mut((i, j)).unwrap();
120                for k in 0..self.ncols {
121                    *acc = *acc + rhs[(k, j)] * self[(i, k)];
122                }
123            }
124        }
125        mat
126    }
127
128    /*
129       return the determinant of a square matrix.
130       Panics if matrix is not square or of size 0x0.
131
132       The determinant is computed via gaus elimination
133    */
134    pub fn det(&self) -> T
135    where
136        T: Field,
137    {
138        // non-empty square matrix
139        assert!(self.nrows == self.ncols && self.ncols != 0);
140        let n = self.ncols;
141
142        let mut det = T::ONE;
143        let mut rows;
144        // we start with the complete matrix, each round will reduce the matrix dimension by one.
145        rows = self
146            .data
147            .chunks(n)
148            .map(|c| c.to_vec())
149            .collect::<Vec<_>>()
150            .clone();
151
152        // Each recursion step removes the pivot's elements row and column and multiplies the pivot
153        // onto the determinant.
154        for _ in 0..n {
155            // we partition into rows that have a leading zero and rows that don't
156            let (lz_rows_vec, nlz_rows_vec): (Vec<_>, Vec<_>) =
157                rows.iter().partition(|row| row.starts_with(&[T::ZERO]));
158
159            let (lz_rows, mut nlz_rows) = (lz_rows_vec.iter(), nlz_rows_vec.iter());
160            // take pivot element
161            let Some(pivot) = nlz_rows.next() else {
162                // no pivot row implies the rank is less than n i.e. the determinant is zero
163                return T::ZERO;
164            };
165
166            // multiply pivot onto the determinant
167            det *= pivot[0];
168
169            // subtract all leading non zero values with the pivot element (forward elimination).
170
171            let pivot_inverse = pivot[0].invert().unwrap();
172            // precomputing pivot row such that the leading value is one. This reduces the number of
173            // multiplications in the forward elimination multiplications by 50%
174            let normalized_pivot: Vec<_> = pivot.iter().map(|f| *f * pivot_inverse).collect();
175            // forward elimination with normalized pivot row
176            let processed_nlz_rows = nlz_rows.map(|row| {
177                let lead = row[0];
178                let row: Vec<_> = row
179                    .iter()
180                    .zip(&normalized_pivot)
181                    .map(move |(f, p)| *f - lead * p)
182                    .collect();
183                row
184            });
185
186            // collect the reamining rows (without pivot row) and remove the pivot column (all first
187            // elements (i.e. zeros) from the remaining rows).
188            rows = processed_nlz_rows
189                .chain(lz_rows.map(|c| c.to_vec()))
190                .map(|mut v| v.drain(1..).collect::<Vec<_>>())
191                .collect::<Vec<_>>();
192        }
193        det
194    }
195    pub fn convert<U: From<T> + Copy>(&self) -> Matrix<U> {
196        Matrix::new_from_iter(
197            (self.nrows, self.ncols),
198            self.into_iter().map(|c| U::from(c)),
199        )
200    }
201}
202
203impl<T: Copy> Index<(usize, usize)> for Matrix<T> {
204    type Output = T;
205
206    fn index(&self, index: (usize, usize)) -> &Self::Output {
207        self.get(index).unwrap()
208    }
209}
210impl<T: Copy> IndexMut<(usize, usize)> for Matrix<T> {
211    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
212        self.get_mut(index).unwrap()
213    }
214}
215impl<T: Copy> IntoIterator for Matrix<T> {
216    type Item = T;
217    type IntoIter = std::vec::IntoIter<Self::Item>;
218    fn into_iter(self) -> Self::IntoIter {
219        self.data.into_iter()
220    }
221}
222
223impl<T: Copy> IntoIterator for &Matrix<T> {
224    type Item = T;
225    type IntoIter = std::vec::IntoIter<Self::Item>;
226    fn into_iter(self) -> Self::IntoIter {
227        self.data.clone().into_iter()
228    }
229}
230
231impl<'a, T: Copy + Add<Output = T>> AddAssign<&'a Matrix<T>> for Matrix<T> {
232    fn add_assign(&mut self, rhs: &'a Matrix<T>) {
233        assert_eq!(self.nrows, rhs.nrows);
234        assert_eq!(self.ncols, rhs.ncols);
235        for i in 0..self.nrows {
236            for j in 0..self.ncols {
237                self[(i, j)] = self[(i, j)] + rhs[(i, j)];
238            }
239        }
240    }
241}
242
243impl<'a, T: Copy + Sub<Output = T>> SubAssign<&'a Matrix<T>> for Matrix<T> {
244    fn sub_assign(&mut self, rhs: &'a Matrix<T>) {
245        assert_eq!(self.nrows, rhs.nrows);
246        assert_eq!(self.ncols, rhs.ncols);
247        for i in 0..self.nrows {
248            for j in 0..self.ncols {
249                self[(i, j)] = self[(i, j)] - rhs[(i, j)];
250            }
251        }
252    }
253}
254
255impl<T: Copy + Add<Output = T>> Add for Matrix<T> {
256    type Output = Matrix<T>;
257
258    fn add(mut self, rhs: Self) -> Self::Output {
259        self += &rhs;
260        self
261    }
262}
263
264impl<T: Copy + Sub<Output = T>> Sub for Matrix<T> {
265    type Output = Matrix<T>;
266
267    fn sub(mut self, rhs: Self) -> Self::Output {
268        self -= &rhs;
269        self
270    }
271}
272
273impl<T: Copy> From<Vec<T>> for Matrix<T> {
274    fn from(v: Vec<T>) -> Self {
275        let nrows = v.len();
276        Self::new_from_iter((nrows, 1), v.into_iter())
277    }
278}
279impl<'a, T: Copy> From<&'a [T]> for Matrix<T> {
280    fn from(v: &'a [T]) -> Self {
281        let nrows = v.len();
282        Self::new_from_iter((nrows, 1), v.iter().copied())
283    }
284}
285
286#[cfg(test)]
287mod tests {
288    use super::*;
289    use crate::utils::field::ScalarField;
290    use ff::Field;
291
292    type F = ScalarField;
293
294    #[test]
295    fn test_det_dim3() {
296        // 4 2 4
297        // 0 0 3
298        // 5 7 7
299        let data = vec![
300            F::from(4),
301            F::from(2),
302            F::from(4),
303            F::ZERO,
304            F::ZERO,
305            F::from(3),
306            F::from(5),
307            F::from(7),
308            F::from(7),
309        ];
310
311        let mat = Matrix::new_from_iter((3, 3), data.into_iter());
312
313        let det = mat.det();
314        assert_eq!(F::from(54), det);
315    }
316
317    #[test]
318    fn test_det_dim4() {
319        let data = vec![
320            F::from(6),
321            F::from(4),
322            F::from(7),
323            F::from(8),
324            F::from(9),
325            F::from(3),
326            F::from(9),
327            F::from(8),
328            F::from(8),
329            F::from(3),
330            F::from(4),
331            F::from(9),
332            F::from(5),
333            F::from(4),
334            F::from(1),
335            F::from(3),
336        ];
337
338        let mat = Matrix::new_from_iter((4, 4), data.into_iter());
339        let det = mat.det();
340        assert_eq!(F::from(-476), det);
341    }
342}