poseidon_parameters/
mds_matrix.rs1use crate::{
2 error::PoseidonParameterError,
3 matrix::{Matrix, SquareMatrix},
4 matrix_ops::{MatrixOperations, SquareMatrixOperations},
5};
6use decaf377::Fq;
7
8#[derive(Clone, Debug, PartialEq, Eq)]
10pub struct MdsMatrix<
11 const STATE_SIZE: usize,
12 const STATE_SIZE_MINUS_1: usize,
13 const NUM_ELEMENTS: usize,
14 const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
15>(pub SquareMatrix<STATE_SIZE, NUM_ELEMENTS>);
16
17impl<
18 const STATE_SIZE: usize,
19 const STATE_SIZE_MINUS_1: usize,
20 const NUM_ELEMENTS: usize,
21 const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
22 > MatrixOperations
23 for MdsMatrix<STATE_SIZE, STATE_SIZE_MINUS_1, NUM_ELEMENTS, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2>
24{
25 fn new(elements: &[Fq]) -> Self {
26 assert!(STATE_SIZE == STATE_SIZE_MINUS_1 + 1);
27 assert!(STATE_SIZE * STATE_SIZE == NUM_ELEMENTS);
28 assert!(STATE_SIZE_MINUS_1 * STATE_SIZE_MINUS_1 == NUM_ELEMENTS_STATE_SIZE_MINUS_1_2);
29 Self(SquareMatrix::new(elements))
30 }
31
32 fn elements(&self) -> &[Fq] {
33 self.0.elements()
34 }
35
36 fn get_element(&self, i: usize, j: usize) -> Fq {
37 self.0.get_element(i, j)
38 }
39
40 fn set_element(&mut self, i: usize, j: usize, val: Fq) {
41 self.0.set_element(i, j, val)
42 }
43
44 fn n_rows(&self) -> usize {
45 self.0.n_rows()
46 }
47
48 fn n_cols(&self) -> usize {
49 self.0.n_cols()
50 }
51
52 fn hadamard_product(&self, rhs: &Self) -> Result<Self, PoseidonParameterError>
53 where
54 Self: Sized,
55 {
56 Ok(Self(self.0.hadamard_product(&rhs.0)?))
57 }
58}
59
60impl<
61 const STATE_SIZE: usize,
62 const STATE_SIZE_MINUS_1: usize,
63 const NUM_ELEMENTS: usize,
64 const NUM_ELEMENTS_STATE_SIZE_MINUS_1_2: usize,
65 > MdsMatrix<STATE_SIZE, STATE_SIZE_MINUS_1, NUM_ELEMENTS, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2>
66{
67 pub fn from_elements(elements: &[Fq]) -> Self {
76 Self(SquareMatrix::new(elements))
77 }
78
79 pub fn transpose(&self) -> Self {
80 Self(self.0.transpose())
81 }
82
83 pub fn inverse(&self) -> SquareMatrix<STATE_SIZE, NUM_ELEMENTS> {
85 self.0
86 .inverse()
87 .expect("all well-formed MDS matrices should have inverses")
88 }
89
90 pub fn v(&self) -> Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1> {
94 let elements = &self.0 .0.elements()[1..self.0 .0.n_rows()];
95 Matrix::new(elements)
96 }
97
98 pub fn w(&self) -> Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1> {
102 let mut elements = [Fq::from(0u64); STATE_SIZE_MINUS_1];
103 for i in 1..self.n_rows() {
104 elements[i - 1] = self.get_element(i, 0);
105 }
106 Matrix::new(&elements)
107 }
108
109 pub fn hat(&self) -> SquareMatrix<STATE_SIZE_MINUS_1, NUM_ELEMENTS_STATE_SIZE_MINUS_1_2> {
115 let dim = self.n_rows();
116 let mut mhat_elements = [Fq::from(0u64); NUM_ELEMENTS_STATE_SIZE_MINUS_1_2];
117 let mut index = 0;
118 for i in 1..dim {
119 for j in 1..dim {
120 mhat_elements[index] = self.0.get_element(i, j);
121 index += 1;
122 }
123 }
124 SquareMatrix::new(&mhat_elements)
125 }
126
127 pub const fn new_from_known(elements: [Fq; NUM_ELEMENTS]) -> Self {
136 Self(SquareMatrix::new_from_known(elements))
137 }
138}
139
140#[derive(Clone, Debug, PartialEq, Eq)]
142pub struct OptimizedMdsMatrices<
143 const N_ROUNDS: usize,
144 const N_PARTIAL_ROUNDS: usize,
145 const STATE_SIZE: usize,
146 const STATE_SIZE_MINUS_1: usize,
147 const NUM_MDS_ELEMENTS: usize,
148 const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
149> {
150 pub M_hat: SquareMatrix<STATE_SIZE_MINUS_1, NUM_STATE_SIZE_MINUS_1_ELEMENTS>,
152 pub v: Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1>,
154 pub w: Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1>,
156 pub M_prime: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
158 pub M_doubleprime: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
160 pub M_inverse: SquareMatrix<STATE_SIZE, NUM_MDS_ELEMENTS>,
162 pub M_hat_inverse: SquareMatrix<STATE_SIZE_MINUS_1, NUM_STATE_SIZE_MINUS_1_ELEMENTS>,
164 pub M_00: Fq,
166 pub M_i: Matrix<STATE_SIZE, STATE_SIZE, NUM_MDS_ELEMENTS>,
168 pub v_collection: [Matrix<1, STATE_SIZE_MINUS_1, STATE_SIZE_MINUS_1>; N_PARTIAL_ROUNDS],
170 pub w_hat_collection: [Matrix<STATE_SIZE_MINUS_1, 1, STATE_SIZE_MINUS_1>; N_PARTIAL_ROUNDS],
172}