1use core::convert::TryInto;
2use core::ops::Mul;
3
4use crate::error::PoseidonParameterError;
5use crate::matrix_ops::{dot_product, MatrixOperations, SquareMatrixOperations};
6use decaf377::Fq;
7
8#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct Matrix<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> {
14 pub elements: [Fq; N_ELEMENTS],
17}
18
19impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize>
20 Matrix<N_ROWS, N_COLS, N_ELEMENTS>
21{
22 pub fn transpose(&self) -> Matrix<N_COLS, N_ROWS, N_ELEMENTS> {
23 let mut transposed_elements = [Fq::default(); N_ELEMENTS];
24
25 let mut index = 0;
26 for j in 0..self.n_cols() {
27 for i in 0..self.n_rows() {
28 transposed_elements[index] = self.get_element(i, j);
29 index += 1;
30 }
31 }
32 Matrix::<N_COLS, N_ROWS, N_ELEMENTS>::new(&transposed_elements)
33 }
34
35 pub const fn new_from_known(elements: [Fq; N_ELEMENTS]) -> Self {
37 if N_ELEMENTS != N_ROWS * N_COLS {
38 panic!("Matrix has an insufficient number of elements")
39 }
40
41 Self { elements }
42 }
43}
44
45impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> MatrixOperations
46 for Matrix<N_ROWS, N_COLS, N_ELEMENTS>
47{
48 fn new(elements: &[Fq]) -> Self {
49 if N_ELEMENTS != N_ROWS * N_COLS {
53 panic!("Matrix has an insufficient number of elements")
54 }
55
56 let elements: [Fq; N_ELEMENTS] = elements
57 .try_into()
58 .expect("Matrix has the correct number of elements");
59
60 Self { elements }
61 }
62
63 fn elements(&self) -> &[Fq] {
64 &self.elements
65 }
66
67 fn get_element(&self, i: usize, j: usize) -> Fq {
68 self.elements[i * N_COLS + j]
69 }
70
71 fn set_element(&mut self, i: usize, j: usize, val: Fq) {
72 self.elements[i * N_COLS + j] = val
73 }
74
75 fn n_rows(&self) -> usize {
76 N_ROWS
77 }
78
79 fn n_cols(&self) -> usize {
80 N_COLS
81 }
82
83 fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
84 where
85 Self: Sized,
86 {
87 let mut new_elements = [Fq::default(); N_ELEMENTS];
88 let mut index = 0;
89 for i in 0..self.n_rows() {
90 for j in 0..self.n_cols() {
91 new_elements[index] = self.get_element(i, j) * rhs.get_element(i, j);
92 index += 1;
93 }
94 }
95
96 Ok(Self::new(&new_elements))
97 }
98}
99
100pub fn mat_mul<
102 const LHS_N_ROWS: usize,
103 const LHS_N_COLS: usize,
104 const LHS_N_ELEMENTS: usize,
105 const RHS_N_ROWS: usize,
106 const RHS_N_COLS: usize,
107 const RHS_N_ELEMENTS: usize,
108 const RESULT_N_ELEMENTS: usize,
109>(
110 lhs: &Matrix<LHS_N_ROWS, LHS_N_COLS, LHS_N_ELEMENTS>,
111 rhs: &Matrix<RHS_N_ROWS, RHS_N_COLS, RHS_N_ELEMENTS>,
112) -> Matrix<LHS_N_ROWS, RHS_N_COLS, RESULT_N_ELEMENTS> {
113 let rhs_T = rhs.transpose();
114
115 let mut new_elements = [Fq::default(); RESULT_N_ELEMENTS];
116
117 let mut index = 0;
118 for row in lhs.iter_rows() {
119 for column in rhs_T.iter_rows() {
121 new_elements[index] = dot_product(row, column);
122 index += 1;
123 }
124 }
125
126 Matrix::<LHS_N_ROWS, RHS_N_COLS, RESULT_N_ELEMENTS>::new(&new_elements)
127}
128
129impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> Mul<Fq>
131 for Matrix<N_ROWS, N_COLS, N_ELEMENTS>
132{
133 type Output = Matrix<N_ROWS, N_COLS, N_ELEMENTS>;
134
135 fn mul(self, rhs: Fq) -> Self::Output {
136 let elements = self.elements();
137 let mut new_elements = [Fq::default(); N_ELEMENTS];
138 for (i, &element) in elements.iter().enumerate() {
139 new_elements[i] = element * rhs;
140 }
141 Self::new(&new_elements)
142 }
143}
144
145impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize>
146 Matrix<N_ROWS, N_COLS, N_ELEMENTS>
147{
148 pub fn row_vector(&self, i: usize) -> Matrix<1, N_COLS, N_ELEMENTS> {
150 let mut row_elements = [Fq::default(); N_COLS];
151 for j in 0..N_COLS {
152 row_elements[j] = self.get_element(i, j);
153 }
154 Matrix::new(&row_elements)
155 }
156}
157
158impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrix<N_ROWS, N_ELEMENTS> {
159 pub fn transpose(&self) -> Self {
160 Self(self.0.transpose())
161 }
162}
163
164#[derive(Clone, Debug, PartialEq, Eq)]
166pub struct SquareMatrix<const N_ROWS: usize, const N_ELEMENTS: usize>(
167 pub Matrix<N_ROWS, N_ROWS, N_ELEMENTS>,
168);
169
170impl<const N_ROWS: usize, const N_ELEMENTS: usize> MatrixOperations
171 for SquareMatrix<N_ROWS, N_ELEMENTS>
172{
173 fn new(elements: &[Fq]) -> Self {
174 Self(Matrix::new(elements))
175 }
176
177 fn elements(&self) -> &[Fq] {
178 self.0.elements()
179 }
180
181 fn get_element(&self, i: usize, j: usize) -> Fq {
182 self.0.get_element(i, j)
183 }
184
185 fn set_element(&mut self, i: usize, j: usize, val: Fq) {
186 self.0.set_element(i, j, val)
187 }
188
189 fn n_rows(&self) -> usize {
190 N_ROWS
191 }
192
193 fn n_cols(&self) -> usize {
194 N_ROWS
196 }
197
198 fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
199 where
200 Self: Sized,
201 {
202 Ok(Self(self.0.hadamard_product(&rhs.0)?))
203 }
204}
205
206impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrixOperations
207 for SquareMatrix<N_ROWS, N_ELEMENTS>
208{
209 fn inverse(&self) -> Result<Self, PoseidonParameterError> {
211 let identity = Self::identity();
212
213 if self.n_rows() == 1 {
214 let elements = [self
215 .get_element(0, 0)
216 .inverse()
217 .expect("inverse of single element must exist for 1x1 matrix")];
218 return Ok(Self::new(&elements));
219 }
220
221 let determinant = self.determinant();
222 if determinant == Fq::from(0u64) {
223 return Err(PoseidonParameterError::NoMatrixInverse);
224 }
225
226 let minors = self.minors();
227 let cofactor_matrix = self.cofactors();
228 let signed_minors = minors
229 .hadamard_product(&cofactor_matrix)
230 .expect("minor and cofactor matrix have correct dimensions");
231 let adj = signed_minors.transpose();
232 let matrix_inverse = adj * (Fq::from(1u64) / determinant);
233
234 debug_assert_eq!(square_mat_mul(self, &matrix_inverse), identity);
235 Ok(matrix_inverse)
236 }
237
238 fn identity() -> Self {
240 let elements = [Fq::from(0u64); N_ELEMENTS];
241 let mut m = Self::new(&elements);
242
243 for i in 0..N_ROWS {
245 m.set_element(i, i, Fq::from(1u64));
246 }
247
248 m
249 }
250
251 fn minors(&self) -> Self {
253 match N_ROWS {
254 0 => panic!("matrix has no elements!"),
255 1 => Self::new(&[self.get_element(0, 0)]),
256 2 => {
257 let a = self.get_element(0, 0);
258 let b = self.get_element(0, 1);
259 let c = self.get_element(1, 0);
260 let d = self.get_element(1, 1);
261 Self::new(&[d, c, b, a])
262 }
263 3 => minor_matrix::<N_ROWS, 2, N_ELEMENTS, 4>(self),
264 4 => minor_matrix::<N_ROWS, 3, N_ELEMENTS, 9>(self),
265 5 => minor_matrix::<N_ROWS, 4, N_ELEMENTS, 16>(self),
266 6 => minor_matrix::<N_ROWS, 5, N_ELEMENTS, 25>(self),
267 7 => minor_matrix::<N_ROWS, 6, N_ELEMENTS, 36>(self),
268 8 => minor_matrix::<N_ROWS, 7, N_ELEMENTS, 49>(self),
269 _ => {
270 unimplemented!("poseidon-parameters only supports square matrices up to 8")
271 }
272 }
273 }
274
275 fn cofactors(&self) -> Self {
277 let dim = self.n_rows();
278 let mut elements = [Fq::from(0u64); N_ELEMENTS];
279
280 let mut index = 0;
281 for i in 0..dim {
282 for j in 0..dim {
283 elements[index] = (-Fq::from(1u64)).power([(i + j) as u64]);
284 index += 1;
285 }
286 }
287 Self::new(&elements)
288 }
289
290 fn determinant(&self) -> Fq {
292 match N_ROWS {
293 0 => panic!("matrix has no elements!"),
294 1 => self.get_element(0, 0),
295 2 => determinant::<N_ROWS, 1, N_ELEMENTS, 1>(self),
296 3 => determinant::<N_ROWS, 2, N_ELEMENTS, 4>(self),
297 4 => determinant::<N_ROWS, 3, N_ELEMENTS, 9>(self),
298 5 => determinant::<N_ROWS, 4, N_ELEMENTS, 16>(self),
299 6 => determinant::<N_ROWS, 5, N_ELEMENTS, 25>(self),
300 7 => determinant::<N_ROWS, 6, N_ELEMENTS, 36>(self),
301 8 => determinant::<N_ROWS, 7, N_ELEMENTS, 49>(self),
302 _ => {
303 unimplemented!("poseidon-parameters only supports square matrices up to 8")
304 }
305 }
306 }
307}
308
309impl<const N_ROWS: usize, const N_ELEMENTS: usize> Mul<Fq> for SquareMatrix<N_ROWS, N_ELEMENTS> {
311 type Output = SquareMatrix<N_ROWS, N_ELEMENTS>;
312
313 fn mul(self, rhs: Fq) -> Self::Output {
314 let elements = self.elements();
315 let mut new_elements = [Fq::default(); N_ELEMENTS];
316 for (i, &element) in elements.iter().enumerate() {
317 new_elements[i] = element * rhs;
318 }
319 Self::new(&new_elements)
320 }
321}
322
323impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrix<N_ROWS, N_ELEMENTS> {
324 pub fn row_vector(&self, i: usize) -> Matrix<1, N_ROWS, N_ELEMENTS> {
326 self.0.row_vector(i)
327 }
328
329 pub fn new_2x2(a: Fq, b: Fq, c: Fq, d: Fq) -> SquareMatrix<2, 4> {
331 SquareMatrix::<2, 4>::new(&[a, b, c, d])
332 }
333
334 pub const fn new_from_known(elements: [Fq; N_ELEMENTS]) -> Self {
336 Self(Matrix::new_from_known(elements))
337 }
338}
339
340pub fn square_mat_mul<
342 const LHS_N_ROWS: usize,
343 const LHS_N_ELEMENTS: usize,
344 const RHS_N_ROWS: usize,
345 const RHS_N_ELEMENTS: usize,
346 const RESULT_N_ELEMENTS: usize,
347>(
348 lhs: &SquareMatrix<LHS_N_ROWS, LHS_N_ELEMENTS>,
349 rhs: &SquareMatrix<RHS_N_ROWS, RHS_N_ELEMENTS>,
350) -> SquareMatrix<LHS_N_ROWS, RESULT_N_ELEMENTS> {
351 let rhs_T = rhs.transpose();
352
353 let mut new_elements = [Fq::default(); RESULT_N_ELEMENTS];
354
355 let mut index = 0;
356 for row in lhs.iter_rows() {
357 for column in rhs_T.iter_rows() {
359 new_elements[index] = dot_product(row, column);
360 index += 1;
361 }
362 }
363
364 SquareMatrix::<LHS_N_ROWS, RESULT_N_ELEMENTS>::new(&new_elements)
365}
366
367fn minor_matrix<
369 const DIM: usize,
370 const DIM_MINUS_1: usize,
371 const N_ELEMENTS: usize,
372 const N_ELEMENTS_DIM_MINUS_1: usize,
373>(
374 matrix: &SquareMatrix<DIM, N_ELEMENTS>,
375) -> SquareMatrix<DIM, N_ELEMENTS> {
376 let mut minor_matrix_elements = [Fq::default(); N_ELEMENTS];
377 let mut outer_index = 0;
378 for i in 0..DIM {
379 for j in 0..DIM {
380 let mut elements = [Fq::default(); N_ELEMENTS_DIM_MINUS_1];
381 let mut index = 0;
382 for k in 0..i {
383 for l in 0..j {
384 elements[index] = matrix.get_element(k, l);
385 index += 1;
386 }
387 for l in (j + 1)..DIM {
388 elements[index] = matrix.get_element(k, l);
389 index += 1;
390 }
391 }
392 for k in i + 1..DIM {
393 for l in 0..j {
394 elements[index] = matrix.get_element(k, l);
395 index += 1;
396 }
397 for l in (j + 1)..DIM {
398 elements[index] = matrix.get_element(k, l);
399 index += 1;
400 }
401 }
402 let minor = SquareMatrix::<DIM_MINUS_1, N_ELEMENTS_DIM_MINUS_1>::new(&elements);
403 minor_matrix_elements[outer_index] = minor.determinant();
404 outer_index += 1;
405 }
406 }
407 SquareMatrix::<DIM, N_ELEMENTS>::new(&minor_matrix_elements)
408}
409
410fn determinant<
412 const DIM: usize,
413 const DIM_MINUS_1: usize,
414 const N_ELEMENTS: usize,
415 const N_ELEMENTS_DIM_MINUS_1: usize,
416>(
417 matrix: &SquareMatrix<DIM, N_ELEMENTS>,
418) -> Fq {
419 let mut det = Fq::from(0u64);
420 let mut levi_civita = true;
421
422 for i in 0..DIM {
423 let mut elements = [Fq::default(); N_ELEMENTS_DIM_MINUS_1];
424 let mut index = 0;
425 for k in 0..i {
426 for l in 1..DIM {
427 elements[index] = matrix.get_element(k, l);
428 index += 1;
429 }
430 }
431 for k in i + 1..DIM {
432 for l in 1..DIM {
433 elements[index] = matrix.get_element(k, l);
434 index += 1;
435 }
436 }
437 let minor = SquareMatrix::<DIM_MINUS_1, N_ELEMENTS_DIM_MINUS_1>::new(&elements);
438 if levi_civita {
439 det += matrix.get_element(i, 0) * minor.determinant();
440 } else {
441 det -= matrix.get_element(i, 0) * minor.determinant();
442 }
443 levi_civita = !levi_civita;
444 }
445 det
446}