// poseidon_parameters/mds_matrix.rs

1use crate::{
2    error::PoseidonParameterError,
3    matrix::{Matrix, SquareMatrix},
4    matrix_ops::{MatrixOperations, SquareMatrixOperations},
5};
6use decaf377::Fq;
7
8/// Represents an MDS (maximum distance separable) matrix.
9#[derive(Clone, Debug, PartialEq, Eq)]
10pub struct MdsMatrix<
11    const STATE_SIZE: usize,
12    const STATE_SIZE_MINUS_1: usize,
13    const NUM_ELEMENTS: usize,
14    const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
15>(pub SquareMatrix<STATE_SIZE, NUM_ELEMENTS>);
16
17impl<
18        const STATE_SIZE: usize,
19        const STATE_SIZE_MINUS_1: usize,
20        const NUM_ELEMENTS: usize,
21        const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
22    > MatrixOperations
23    for MdsMatrix<STATE_SIZE, STATE_SIZE_MINUS_1, NUM_ELEMENTS, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2>
24{
25    fn new(elements: &[Fq]) -> Self {
26        assert!(STATE_SIZE == STATE_SIZE_MINUS_1 + 1);
27        assert!(STATE_SIZE * STATE_SIZE == NUM_ELEMENTS);
28        assert!(STATE_SIZE_MINUS_1 * STATE_SIZE_MINUS_1 == NUM_ELEMENTS_STATE_SIZE_MINUS_1_2);
29        Self(SquareMatrix::new(elements))
30    }
31
32    fn elements(&self) -> &[Fq] {
33        self.0.elements()
34    }
35
36    fn get_element(&self, i: usize, j: usize) -> Fq {
37        self.0.get_element(i, j)
38    }
39
40    fn set_element(&mut self, i: usize, j: usize, val: Fq) {
41        self.0.set_element(i, j, val)
42    }
43
44    fn n_rows(&self) -> usize {
45        self.0.n_rows()
46    }
47
48    fn n_cols(&self) -> usize {
49        self.0.n_cols()
50    }
51
52    fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
53    where
54        Self: Sized,
55    {
56        Ok(Self(self.0.hadamard_product(&rhs.0)?))
57    }
58}
59
60impl<
61        const STATE_SIZE: usize,
62        const STATE_SIZE_MINUS_1: usize,
63        const NUM_ELEMENTS: usize,
64        const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
65    > MdsMatrix<STATE_SIZE, STATE_SIZE_MINUS_1, NUM_ELEMENTS, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2>
66{
67    /// Instantiate an MDS matrix from a list of elements.
68    ///
69    /// # Security
70    ///
71    /// You must ensure this matrix was generated securely,
72    /// using the Cauchy method in `fixed_cauchy_matrix` or
73    /// using the random subsampling method described in the original
74    /// paper.
75    pub fn from_elements(elements: &[Fq]) -> Self {
76        Self(SquareMatrix::new(elements))
77    }
78
79    pub fn transpose(&self) -> Self {
80        Self(self.0.transpose())
81    }
82
83    /// Compute inverse of MDS matrix
84    pub fn inverse(&self) -> SquareMatrix<STATE_SIZE, NUM_ELEMENTS> {
85        self.0
86            .inverse()
87            .expect("all well-formed MDS matrices should have inverses")
88    }
89
90    /// Return the elements M_{0,1} .. M_{0,t} from the first row
91    ///
92    /// Ref: p.20 of the Poseidon paper
93    pub fn v(&self) -> Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1> {
94        let elements = &self.0 .0.elements()[1..self.0 .0.n_rows()];
95        Matrix::new(elements)
96    }
97
98    /// Return the elements M_{1,0} .. M_{t,0} from the first column
99    ///
100    /// Ref: p.20 of the Poseidon paper
101    pub fn w(&self) -> Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1> {
102        let mut elements = [Fq::from(0u64); STATE_SIZE_MINUS_1];
103        for i in 1..self.n_rows() {
104            elements[i - 1] = self.get_element(i, 0);
105        }
106        Matrix::new(&elements)
107    }
108
109    /// Compute the (t - 1) x (t - 1) Mhat matrix from the MDS matrix
110    ///
111    /// This is simply the MDS matrix with the first row and column removed
112    ///
113    /// Ref: p.20 of the Poseidon paper
114    pub fn hat(&self) -> SquareMatrix<STATE_SIZE_MINUS_1, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2> {
115        let dim = self.n_rows();
116        let mut mhat_elements = [Fq::from(0u64); NUM_ELEMENTS_STATE_SIZE_MINUS_1_2];
117        let mut index = 0;
118        for i in 1..dim {
119            for j in 1..dim {
120                mhat_elements[index] = self.0.get_element(i, j);
121                index += 1;
122            }
123        }
124        SquareMatrix::new(&mhat_elements)
125    }
126
127    /// Create a new matrix from a slice of elements.
128    ///
129    /// # Security
130    ///
131    /// You must ensure this matrix was generated securely,
132    /// using the Cauchy method in `fixed_cauchy_matrix` or
133    /// using the random subsampling method described in the original
134    /// paper.
135    pub const fn new_from_known(elements: [Fq; NUM_ELEMENTS]) -> Self {
136        Self(SquareMatrix::new_from_known(elements))
137    }
138}
139
140/// Represents an optimized MDS (maximum distance separable) matrix.
141#[derive(Clone, Debug, PartialEq, Eq)]
142pub struct OptimizedMdsMatrices<
143    const N_ROUNDS: usize,
144    const N_PARTIAL_ROUNDS: usize,
145    const STATE_SIZE: usize,
146    const STATE_SIZE_MINUS_1: usize,
147    const NUM_MDS_ELEMENTS: usize,
148    const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
149> {
150    /// A (t - 1) x (t - 1) MDS submatrix derived from the MDS matrix.
151    pub M_hat: SquareMatrix<STATE_SIZE_MINUS_1, NUM_STATE_SIZE_MINUS_1_ELEMENTS>,
152    /// A 1 x (t - 1) (row) vector derived from the MDS matrix.
153    pub v: Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1>,
154    /// A (t - 1) x 1 (column) vector derived from the MDS matrix.
155    pub w: Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1>,
156    /// A matrix formed from Mhat (an MDS submatrix of the MDS matrix).
157    pub M_prime: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
158    /// A sparse matrix formed from M,
159    pub M_doubleprime: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
160    /// The inverse of the t x t MDS matrix (needed to compute round constants).
161    pub M_inverse: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
162    /// The inverse of the (t - 1) x (t - 1) Mhat matrix.
163    pub M_hat_inverse: SquareMatrix<STATE_SIZE_MINUS_1, NUM_STATE_SIZE_MINUS_1_ELEMENTS>,
164    /// Element at M00
165    pub M_00: Fq,
166    /// M_i
167    pub M_i: Matrix<STATE_SIZE, STATE_SIZE, NUM_MDS_ELEMENTS>,
168    /// v_collection: one per partial round.
169    pub v_collection: [Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1>; N_PARTIAL_ROUNDS],
170    /// w_hat_collection: one per round
171    pub w_hat_collection: [Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1>; N_PARTIAL_ROUNDS],
172}