1use ff::Field;
2use num_traits::Zero;
3use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub, SubAssign};
4
/// A dense matrix stored in row-major order.
///
/// `data` holds `nrows * ncols` entries, and entry `(row, column)` lives at offset
/// `row * ncols + column`. The dimensions are public fields, so code that mutates `nrows` or
/// `ncols` directly is responsible for keeping them consistent with `data`.
///
/// No trait bound is placed on `T` here; the impl blocks that need `T: Copy` state it themselves.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Matrix<T> {
    /// Entries in row-major order.
    data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
11
12impl<T: Copy> Matrix<T> {
13 pub fn new(size: (usize, usize), item: T) -> Self {
17 Matrix {
18 data: vec![item; size.0 * size.1],
19 nrows: size.0,
20 ncols: size.1,
21 }
22 }
23 pub fn new_from_iter<U: Iterator<Item = T>>(size: (usize, usize), iterator: U) -> Self {
28 let data: Vec<T> = iterator.collect();
29 assert_eq!(
30 data.len(),
31 size.0 * size.1,
32 "iterator of size {} for matrix of size {}x{}",
33 data.len(),
34 size.0,
35 size.1
36 );
37 Matrix {
38 data,
39 nrows: size.0,
40 ncols: size.1,
41 }
42 }
43
44 pub fn new_from_column_major_iter<U: Iterator<Item = T>>(
45 size: (usize, usize),
46 iterator: U,
47 ) -> Self {
48 let column_major_data = iterator.collect::<Vec<T>>();
49 assert_eq!(
50 column_major_data.len(),
51 size.0 * size.1,
52 "iterator of size {} for matrix of size {}x{}",
53 column_major_data.len(),
54 size.0,
55 size.1
56 );
57 let row_major_data = (0..size.0)
58 .flat_map(|i| {
59 (0..size.1)
60 .map(|j| column_major_data[i + j * size.0])
61 .collect::<Vec<T>>()
62 })
63 .collect();
64
65 Matrix {
66 data: row_major_data,
67 nrows: size.0,
68 ncols: size.1,
69 }
70 }
71
72 fn index(&self, x: usize, y: usize) -> usize {
73 x * self.ncols + y
74 }
75 pub fn get(&self, location: (usize, usize)) -> Option<&T> {
76 let (x, y) = location;
77 let index = self.index(x, y);
78 self.data.get(index)
79 }
80 pub fn get_mut(&mut self, location: (usize, usize)) -> Option<&mut T> {
81 let (x, y) = location;
82 let index = self.index(x, y);
83 self.data.get_mut(index)
84 }
85 pub fn col(&self, index: usize) -> Self {
86 if index >= self.ncols {
87 panic!(
88 "index for column extraction must be less than {} (found {})",
89 self.ncols, index
90 );
91 }
92 let data = (index..self.nrows * self.ncols)
93 .step_by(self.ncols)
94 .map(|i| self.data[i])
95 .collect();
96 Self {
97 data,
98 nrows: self.nrows,
99 ncols: 1,
100 }
101 }
102 pub fn map_mut(&mut self, mut f: impl FnMut(T) -> T) {
104 for x in 0..self.nrows {
105 for y in 0..self.ncols {
106 self[(x, y)] = f(self[(x, y)]);
107 }
108 }
109 }
110 pub fn mat_mul<U: Copy + Add<Output = U> + Mul<T, Output = U> + Zero>(
112 &self,
113 rhs: &Matrix<U>,
114 ) -> Matrix<U> {
115 assert_eq!(self.ncols, rhs.nrows);
116 let mut mat: Matrix<U> = Matrix::new((self.nrows, rhs.ncols), U::zero());
117 for i in 0..self.nrows {
118 for j in 0..rhs.ncols {
119 let acc = mat.get_mut((i, j)).unwrap();
120 for k in 0..self.ncols {
121 *acc = *acc + rhs[(k, j)] * self[(i, k)];
122 }
123 }
124 }
125 mat
126 }
127
128 pub fn det(&self) -> T
135 where
136 T: Field,
137 {
138 assert!(self.nrows == self.ncols && self.ncols != 0);
140 let n = self.ncols;
141
142 let mut det = T::ONE;
143 let mut rows;
144 rows = self
146 .data
147 .chunks(n)
148 .map(|c| c.to_vec())
149 .collect::<Vec<_>>()
150 .clone();
151
152 for _ in 0..n {
155 let (lz_rows_vec, nlz_rows_vec): (Vec<_>, Vec<_>) =
157 rows.iter().partition(|row| row.starts_with(&[T::ZERO]));
158
159 let (lz_rows, mut nlz_rows) = (lz_rows_vec.iter(), nlz_rows_vec.iter());
160 let Some(pivot) = nlz_rows.next() else {
162 return T::ZERO;
164 };
165
166 det *= pivot[0];
168
169 let pivot_inverse = pivot[0].invert().unwrap();
172 let normalized_pivot: Vec<_> = pivot.iter().map(|f| *f * pivot_inverse).collect();
175 let processed_nlz_rows = nlz_rows.map(|row| {
177 let lead = row[0];
178 let row: Vec<_> = row
179 .iter()
180 .zip(&normalized_pivot)
181 .map(move |(f, p)| *f - lead * p)
182 .collect();
183 row
184 });
185
186 rows = processed_nlz_rows
189 .chain(lz_rows.map(|c| c.to_vec()))
190 .map(|mut v| v.drain(1..).collect::<Vec<_>>())
191 .collect::<Vec<_>>();
192 }
193 det
194 }
195 pub fn convert<U: From<T> + Copy>(&self) -> Matrix<U> {
196 Matrix::new_from_iter(
197 (self.nrows, self.ncols),
198 self.into_iter().map(|c| U::from(c)),
199 )
200 }
201}
202
203impl<T: Copy> Index<(usize, usize)> for Matrix<T> {
204 type Output = T;
205
206 fn index(&self, index: (usize, usize)) -> &Self::Output {
207 self.get(index).unwrap()
208 }
209}
210impl<T: Copy> IndexMut<(usize, usize)> for Matrix<T> {
211 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
212 self.get_mut(index).unwrap()
213 }
214}
215impl<T: Copy> IntoIterator for Matrix<T> {
216 type Item = T;
217 type IntoIter = std::vec::IntoIter<Self::Item>;
218 fn into_iter(self) -> Self::IntoIter {
219 self.data.into_iter()
220 }
221}
222
223impl<T: Copy> IntoIterator for &Matrix<T> {
224 type Item = T;
225 type IntoIter = std::vec::IntoIter<Self::Item>;
226 fn into_iter(self) -> Self::IntoIter {
227 self.data.clone().into_iter()
228 }
229}
230
231impl<'a, T: Copy + Add<Output = T>> AddAssign<&'a Matrix<T>> for Matrix<T> {
232 fn add_assign(&mut self, rhs: &'a Matrix<T>) {
233 assert_eq!(self.nrows, rhs.nrows);
234 assert_eq!(self.ncols, rhs.ncols);
235 for i in 0..self.nrows {
236 for j in 0..self.ncols {
237 self[(i, j)] = self[(i, j)] + rhs[(i, j)];
238 }
239 }
240 }
241}
242
243impl<'a, T: Copy + Sub<Output = T>> SubAssign<&'a Matrix<T>> for Matrix<T> {
244 fn sub_assign(&mut self, rhs: &'a Matrix<T>) {
245 assert_eq!(self.nrows, rhs.nrows);
246 assert_eq!(self.ncols, rhs.ncols);
247 for i in 0..self.nrows {
248 for j in 0..self.ncols {
249 self[(i, j)] = self[(i, j)] - rhs[(i, j)];
250 }
251 }
252 }
253}
254
255impl<T: Copy + Add<Output = T>> Add for Matrix<T> {
256 type Output = Matrix<T>;
257
258 fn add(mut self, rhs: Self) -> Self::Output {
259 self += &rhs;
260 self
261 }
262}
263
264impl<T: Copy + Sub<Output = T>> Sub for Matrix<T> {
265 type Output = Matrix<T>;
266
267 fn sub(mut self, rhs: Self) -> Self::Output {
268 self -= &rhs;
269 self
270 }
271}
272
273impl<T: Copy> From<Vec<T>> for Matrix<T> {
274 fn from(v: Vec<T>) -> Self {
275 let nrows = v.len();
276 Self::new_from_iter((nrows, 1), v.into_iter())
277 }
278}
279impl<'a, T: Copy> From<&'a [T]> for Matrix<T> {
280 fn from(v: &'a [T]) -> Self {
281 let nrows = v.len();
282 Self::new_from_iter((nrows, 1), v.iter().copied())
283 }
284}
285
286#[cfg(test)]
287mod tests {
288 use super::*;
289 use crate::utils::field::ScalarField;
290 use ff::Field;
291
292 type F = ScalarField;
293
294 #[test]
295 fn test_det_dim3() {
296 let data = vec![
300 F::from(4),
301 F::from(2),
302 F::from(4),
303 F::ZERO,
304 F::ZERO,
305 F::from(3),
306 F::from(5),
307 F::from(7),
308 F::from(7),
309 ];
310
311 let mat = Matrix::new_from_iter((3, 3), data.into_iter());
312
313 let det = mat.det();
314 assert_eq!(F::from(54), det);
315 }
316
317 #[test]
318 fn test_det_dim4() {
319 let data = vec![
320 F::from(6),
321 F::from(4),
322 F::from(7),
323 F::from(8),
324 F::from(9),
325 F::from(3),
326 F::from(9),
327 F::from(8),
328 F::from(8),
329 F::from(3),
330 F::from(4),
331 F::from(9),
332 F::from(5),
333 F::from(4),
334 F::from(1),
335 F::from(3),
336 ];
337
338 let mat = Matrix::new_from_iter((4, 4), data.into_iter());
339 let det = mat.det();
340 assert_eq!(F::from(-476), det);
341 }
342}