poseidon_parameters/
matrix.rs

1use core::convert::TryInto;
2use core::ops::Mul;
3
4use crate::error::PoseidonParameterError;
5use crate::matrix_ops::{dot_product, MatrixOperations, SquareMatrixOperations};
6use decaf377::Fq;
7
8/// Represents a matrix over `PrimeField` elements.
9///
10/// This matrix can be used to represent row or column
11/// vectors.
12#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct Matrix<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> {
14    /// Elements of the matrix, stored in a fixed-size array.
15    ///
16    pub elements: [Fq; N_ELEMENTS],
17}
18
19impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize>
20    Matrix<N_ROWS, N_COLS, N_ELEMENTS>
21{
22    pub fn transpose(&self) -> Matrix<N_COLS, N_ROWS, N_ELEMENTS> {
23        let mut transposed_elements = [Fq::default(); N_ELEMENTS];
24
25        let mut index = 0;
26        for j in 0..self.n_cols() {
27            for i in 0..self.n_rows() {
28                transposed_elements[index] = self.get_element(i, j);
29                index += 1;
30            }
31        }
32        Matrix::<N_COLS, N_ROWS, N_ELEMENTS>::new(&transposed_elements)
33    }
34
35    /// Create a new matrix from a slice of elements.
36    pub const fn new_from_known(elements: [Fq; N_ELEMENTS]) -> Self {
37        if N_ELEMENTS != N_ROWS * N_COLS {
38            panic!("Matrix has an insufficient number of elements")
39        }
40
41        Self { elements }
42    }
43}
44
45impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> MatrixOperations
46    for Matrix<N_ROWS, N_COLS, N_ELEMENTS>
47{
48    fn new(elements: &[Fq]) -> Self {
49        // Note: We use a third const generic to denote the number of elements in the
50        // matrix here due to `generic_const_exprs` being an unstable Rust feature at
51        // the time of writing.
52        if N_ELEMENTS != N_ROWS * N_COLS {
53            panic!("Matrix has an insufficient number of elements")
54        }
55
56        let elements: [Fq; N_ELEMENTS] = elements
57            .try_into()
58            .expect("Matrix has the correct number of elements");
59
60        Self { elements }
61    }
62
63    fn elements(&self) -> &[Fq] {
64        &self.elements
65    }
66
67    fn get_element(&self, i: usize, j: usize) -> Fq {
68        self.elements[i * N_COLS + j]
69    }
70
71    fn set_element(&mut self, i: usize, j: usize, val: Fq) {
72        self.elements[i * N_COLS + j] = val
73    }
74
75    fn n_rows(&self) -> usize {
76        N_ROWS
77    }
78
79    fn n_cols(&self) -> usize {
80        N_COLS
81    }
82
83    fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
84    where
85        Self: Sized,
86    {
87        let mut new_elements = [Fq::default(); N_ELEMENTS];
88        let mut index = 0;
89        for i in 0..self.n_rows() {
90            for j in 0..self.n_cols() {
91                new_elements[index] = self.get_element(i, j) * rhs.get_element(i, j);
92                index += 1;
93            }
94        }
95
96        Ok(Self::new(&new_elements))
97    }
98}
99
100/// Multiply two `Matrix`
101pub fn mat_mul<
102    const LHS_N_ROWS: usize,
103    const LHS_N_COLS: usize,
104    const LHS_N_ELEMENTS: usize,
105    const RHS_N_ROWS: usize,
106    const RHS_N_COLS: usize,
107    const RHS_N_ELEMENTS: usize,
108    const RESULT_N_ELEMENTS: usize,
109>(
110    lhs: &Matrix<LHS_N_ROWS, LHS_N_COLS, LHS_N_ELEMENTS>,
111    rhs: &Matrix<RHS_N_ROWS, RHS_N_COLS, RHS_N_ELEMENTS>,
112) -> Matrix<LHS_N_ROWS, RHS_N_COLS, RESULT_N_ELEMENTS> {
113    let rhs_T = rhs.transpose();
114
115    let mut new_elements = [Fq::default(); RESULT_N_ELEMENTS];
116
117    let mut index = 0;
118    for row in lhs.iter_rows() {
119        // Rows of the transposed matrix are the columns of the original matrix
120        for column in rhs_T.iter_rows() {
121            new_elements[index] = dot_product(row, column);
122            index += 1;
123        }
124    }
125
126    Matrix::<LHS_N_ROWS, RHS_N_COLS, RESULT_N_ELEMENTS>::new(&new_elements)
127}
128
129/// Multiply scalar by Matrix
130impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize> Mul<Fq>
131    for Matrix<N_ROWS, N_COLS, N_ELEMENTS>
132{
133    type Output = Matrix<N_ROWS, N_COLS, N_ELEMENTS>;
134
135    fn mul(self, rhs: Fq) -> Self::Output {
136        let elements = self.elements();
137        let mut new_elements = [Fq::default(); N_ELEMENTS];
138        for (i, &element) in elements.iter().enumerate() {
139            new_elements[i] = element * rhs;
140        }
141        Self::new(&new_elements)
142    }
143}
144
145impl<const N_ROWS: usize, const N_COLS: usize, const N_ELEMENTS: usize>
146    Matrix<N_ROWS, N_COLS, N_ELEMENTS>
147{
148    /// Get row vector at a specified row index
149    pub fn row_vector(&self, i: usize) -> Matrix<1, N_COLS, N_ELEMENTS> {
150        let mut row_elements = [Fq::default(); N_COLS];
151        for j in 0..N_COLS {
152            row_elements[j] = self.get_element(i, j);
153        }
154        Matrix::new(&row_elements)
155    }
156}
157
158impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrix<N_ROWS, N_ELEMENTS> {
159    pub fn transpose(&self) -> Self {
160        Self(self.0.transpose())
161    }
162}
163
164/// Represents a square matrix over `PrimeField` elements
165#[derive(Clone, Debug, PartialEq, Eq)]
166pub struct SquareMatrix<const N_ROWS: usize, const N_ELEMENTS: usize>(
167    pub Matrix<N_ROWS, N_ROWS, N_ELEMENTS>,
168);
169
170impl<const N_ROWS: usize, const N_ELEMENTS: usize> MatrixOperations
171    for SquareMatrix<N_ROWS, N_ELEMENTS>
172{
173    fn new(elements: &[Fq]) -> Self {
174        Self(Matrix::new(elements))
175    }
176
177    fn elements(&self) -> &[Fq] {
178        self.0.elements()
179    }
180
181    fn get_element(&self, i: usize, j: usize) -> Fq {
182        self.0.get_element(i, j)
183    }
184
185    fn set_element(&mut self, i: usize, j: usize, val: Fq) {
186        self.0.set_element(i, j, val)
187    }
188
189    fn n_rows(&self) -> usize {
190        N_ROWS
191    }
192
193    fn n_cols(&self) -> usize {
194        // Matrix is square
195        N_ROWS
196    }
197
198    fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
199    where
200        Self: Sized,
201    {
202        Ok(Self(self.0.hadamard_product(&rhs.0)?))
203    }
204}
205
206impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrixOperations
207    for SquareMatrix<N_ROWS, N_ELEMENTS>
208{
209    /// Compute the inverse of the matrix
210    fn inverse(&self) -> Result<Self, PoseidonParameterError> {
211        let identity = Self::identity();
212
213        if self.n_rows() == 1 {
214            let elements = [self
215                .get_element(0, 0)
216                .inverse()
217                .expect("inverse of single element must exist for 1x1 matrix")];
218            return Ok(Self::new(&elements));
219        }
220
221        let determinant = self.determinant();
222        if determinant == Fq::from(0u64) {
223            return Err(PoseidonParameterError::NoMatrixInverse);
224        }
225
226        let minors = self.minors();
227        let cofactor_matrix = self.cofactors();
228        let signed_minors = minors
229            .hadamard_product(&cofactor_matrix)
230            .expect("minor and cofactor matrix have correct dimensions");
231        let adj = signed_minors.transpose();
232        let matrix_inverse = adj * (Fq::from(1u64) / determinant);
233
234        debug_assert_eq!(square_mat_mul(self, &matrix_inverse), identity);
235        Ok(matrix_inverse)
236    }
237
238    /// Construct an identity matrix
239    fn identity() -> Self {
240        let elements = [Fq::from(0u64); N_ELEMENTS];
241        let mut m = Self::new(&elements);
242
243        // Set diagonals to 1
244        for i in 0..N_ROWS {
245            m.set_element(i, i, Fq::from(1u64));
246        }
247
248        m
249    }
250
251    /// Compute the (unsigned) minors of this matrix
252    fn minors(&self) -> Self {
253        match N_ROWS {
254            0 => panic!("matrix has no elements!"),
255            1 => Self::new(&[self.get_element(0, 0)]),
256            2 => {
257                let a = self.get_element(0, 0);
258                let b = self.get_element(0, 1);
259                let c = self.get_element(1, 0);
260                let d = self.get_element(1, 1);
261                Self::new(&[d, c, b, a])
262            }
263            3 => minor_matrix::<N_ROWS, 2, N_ELEMENTS, 4>(self),
264            4 => minor_matrix::<N_ROWS, 3, N_ELEMENTS, 9>(self),
265            5 => minor_matrix::<N_ROWS, 4, N_ELEMENTS, 16>(self),
266            6 => minor_matrix::<N_ROWS, 5, N_ELEMENTS, 25>(self),
267            7 => minor_matrix::<N_ROWS, 6, N_ELEMENTS, 36>(self),
268            8 => minor_matrix::<N_ROWS, 7, N_ELEMENTS, 49>(self),
269            _ => {
270                unimplemented!("poseidon-parameters only supports square matrices up to 8")
271            }
272        }
273    }
274
275    /// Compute the cofactor matrix, i.e. $C_{ij} = (-1)^{i+j}$
276    fn cofactors(&self) -> Self {
277        let dim = self.n_rows();
278        let mut elements = [Fq::from(0u64); N_ELEMENTS];
279
280        let mut index = 0;
281        for i in 0..dim {
282            for j in 0..dim {
283                elements[index] = (-Fq::from(1u64)).power([(i + j) as u64]);
284                index += 1;
285            }
286        }
287        Self::new(&elements)
288    }
289
290    /// Compute the matrix determinant
291    fn determinant(&self) -> Fq {
292        match N_ROWS {
293            0 => panic!("matrix has no elements!"),
294            1 => self.get_element(0, 0),
295            2 => determinant::<N_ROWS, 1, N_ELEMENTS, 1>(self),
296            3 => determinant::<N_ROWS, 2, N_ELEMENTS, 4>(self),
297            4 => determinant::<N_ROWS, 3, N_ELEMENTS, 9>(self),
298            5 => determinant::<N_ROWS, 4, N_ELEMENTS, 16>(self),
299            6 => determinant::<N_ROWS, 5, N_ELEMENTS, 25>(self),
300            7 => determinant::<N_ROWS, 6, N_ELEMENTS, 36>(self),
301            8 => determinant::<N_ROWS, 7, N_ELEMENTS, 49>(self),
302            _ => {
303                unimplemented!("poseidon-parameters only supports square matrices up to 8")
304            }
305        }
306    }
307}
308
309/// Multiply scalar by SquareMatrix
310impl<const N_ROWS: usize, const N_ELEMENTS: usize> Mul<Fq> for SquareMatrix<N_ROWS, N_ELEMENTS> {
311    type Output = SquareMatrix<N_ROWS, N_ELEMENTS>;
312
313    fn mul(self, rhs: Fq) -> Self::Output {
314        let elements = self.elements();
315        let mut new_elements = [Fq::default(); N_ELEMENTS];
316        for (i, &element) in elements.iter().enumerate() {
317            new_elements[i] = element * rhs;
318        }
319        Self::new(&new_elements)
320    }
321}
322
323impl<const N_ROWS: usize, const N_ELEMENTS: usize> SquareMatrix<N_ROWS, N_ELEMENTS> {
324    /// Get row vector at a specified row index.
325    pub fn row_vector(&self, i: usize) -> Matrix<1, N_ROWS, N_ELEMENTS> {
326        self.0.row_vector(i)
327    }
328
329    /// Create a 2x2 `SquareMatrix` from four elements.
330    pub fn new_2x2(a: Fq, b: Fq, c: Fq, d: Fq) -> SquareMatrix<2, 4> {
331        SquareMatrix::<2, 4>::new(&[a, b, c, d])
332    }
333
334    /// Create a new matrix from a slice of elements.
335    pub const fn new_from_known(elements: [Fq; N_ELEMENTS]) -> Self {
336        Self(Matrix::new_from_known(elements))
337    }
338}
339
340/// Multiply two matrices
341pub fn square_mat_mul<
342    const LHS_N_ROWS: usize,
343    const LHS_N_ELEMENTS: usize,
344    const RHS_N_ROWS: usize,
345    const RHS_N_ELEMENTS: usize,
346    const RESULT_N_ELEMENTS: usize,
347>(
348    lhs: &SquareMatrix<LHS_N_ROWS, LHS_N_ELEMENTS>,
349    rhs: &SquareMatrix<RHS_N_ROWS, RHS_N_ELEMENTS>,
350) -> SquareMatrix<LHS_N_ROWS, RESULT_N_ELEMENTS> {
351    let rhs_T = rhs.transpose();
352
353    let mut new_elements = [Fq::default(); RESULT_N_ELEMENTS];
354
355    let mut index = 0;
356    for row in lhs.iter_rows() {
357        // Rows of the transposed matrix are the columns of the original matrix
358        for column in rhs_T.iter_rows() {
359            new_elements[index] = dot_product(row, column);
360            index += 1;
361        }
362    }
363
364    SquareMatrix::<LHS_N_ROWS, RESULT_N_ELEMENTS>::new(&new_elements)
365}
366
367/// Helper function for computing matrix minors
368fn minor_matrix<
369    const DIM: usize,
370    const DIM_MINUS_1: usize,
371    const N_ELEMENTS: usize,
372    const N_ELEMENTS_DIM_MINUS_1: usize,
373>(
374    matrix: &SquareMatrix<DIM, N_ELEMENTS>,
375) -> SquareMatrix<DIM, N_ELEMENTS> {
376    let mut minor_matrix_elements = [Fq::default(); N_ELEMENTS];
377    let mut outer_index = 0;
378    for i in 0..DIM {
379        for j in 0..DIM {
380            let mut elements = [Fq::default(); N_ELEMENTS_DIM_MINUS_1];
381            let mut index = 0;
382            for k in 0..i {
383                for l in 0..j {
384                    elements[index] = matrix.get_element(k, l);
385                    index += 1;
386                }
387                for l in (j + 1)..DIM {
388                    elements[index] = matrix.get_element(k, l);
389                    index += 1;
390                }
391            }
392            for k in i + 1..DIM {
393                for l in 0..j {
394                    elements[index] = matrix.get_element(k, l);
395                    index += 1;
396                }
397                for l in (j + 1)..DIM {
398                    elements[index] = matrix.get_element(k, l);
399                    index += 1;
400                }
401            }
402            let minor = SquareMatrix::<DIM_MINUS_1, N_ELEMENTS_DIM_MINUS_1>::new(&elements);
403            minor_matrix_elements[outer_index] = minor.determinant();
404            outer_index += 1;
405        }
406    }
407    SquareMatrix::<DIM, N_ELEMENTS>::new(&minor_matrix_elements)
408}
409
410/// Helper function for computing matrix determinant
411fn determinant<
412    const DIM: usize,
413    const DIM_MINUS_1: usize,
414    const N_ELEMENTS: usize,
415    const N_ELEMENTS_DIM_MINUS_1: usize,
416>(
417    matrix: &SquareMatrix<DIM, N_ELEMENTS>,
418) -> Fq {
419    let mut det = Fq::from(0u64);
420    let mut levi_civita = true;
421
422    for i in 0..DIM {
423        let mut elements = [Fq::default(); N_ELEMENTS_DIM_MINUS_1];
424        let mut index = 0;
425        for k in 0..i {
426            for l in 1..DIM {
427                elements[index] = matrix.get_element(k, l);
428                index += 1;
429            }
430        }
431        for k in i + 1..DIM {
432            for l in 1..DIM {
433                elements[index] = matrix.get_element(k, l);
434                index += 1;
435            }
436        }
437        let minor = SquareMatrix::<DIM_MINUS_1, N_ELEMENTS_DIM_MINUS_1>::new(&elements);
438        if levi_civita {
439            det += matrix.get_element(i, 0) * minor.determinant();
440        } else {
441            det -= matrix.get_element(i, 0) * minor.determinant();
442        }
443        levi_civita = !levi_civita;
444    }
445    det
446}