// File: poseidon_paramgen/poseidon_build.rs

1use std::fmt::Display;
2
3use ark_ff::PrimeField;
4use ark_std::vec::Vec;
5use num::BigUint;
6
7use poseidon_parameters::v1::{
8    Alpha, ArcMatrix, Matrix, MatrixOperations, MdsMatrix, OptimizedArcMatrix,
9    OptimizedMdsMatrices, PoseidonParameters, SquareMatrix,
10};
11use poseidon_parameters::v2::PoseidonParameters as V2PoseidonParameters;
12
13// use crate::generate;
14use crate::{v1, v2};
15
16/// Create v1 parameter code.
17pub fn v1_compile<F: PrimeField>(
18    M: usize,
19    t_values: Vec<usize>,
20    p: F::BigInt,
21    allow_inverse: bool,
22) -> String {
23    let mut params_code = "use ark_ff::PrimeField;\n
24use poseidon_parameters::v1::{Alpha, ArcMatrix, RoundNumbers, SquareMatrix, Matrix, MdsMatrix, OptimizedArcMatrix, OptimizedMdsMatrices, PoseidonParameters, MatrixOperations};\n\n"
25        .to_string();
26
27    for t in t_values {
28        let params = v1::generate::<F>(M, t, p, allow_inverse);
29        params_code.push_str(&format!("{}", DisplayablePoseidonParameters(&params))[..]);
30    }
31
32    params_code
33}
34
35/// Create v2 parameter code.poseidon-parameters/src/v2.rs
36pub fn v2_compile<F: PrimeField>(
37    M: usize,
38    t_values: Vec<usize>,
39    p: F::BigInt,
40    allow_inverse: bool,
41) -> String {
42    let mut params_code = "use ark_ff::PrimeField;\n
43use poseidon_parameters::v2::{Alpha, ArcMatrix, RoundNumbers, SquareMatrix, Matrix, PoseidonParameters, MatrixOperations};\n\n"
44        .to_string();
45
46    for t in t_values {
47        let params = v2::generate::<F>(M, t, p, allow_inverse);
48        params_code.push_str(&format!("{}", DisplayableV2PoseidonParameters(&params))[..]);
49    }
50
51    params_code
52}
53
54struct DisplayableV2PoseidonParameters<'a, F: PrimeField>(&'a V2PoseidonParameters<F>);
55impl<F: PrimeField> Display for DisplayableV2PoseidonParameters<'_, F> {
56    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        todo!()
58    }
59}
60
61struct DisplayablePoseidonParameters<'a, F: PrimeField>(&'a PoseidonParameters<F>);
62impl<F: PrimeField> Display for DisplayablePoseidonParameters<'_, F> {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        let this = self.0;
65
66        let capacity = 1;
67        let rate = this.t - capacity;
68
69        let rounds = this.rounds;
70
71        let r_P = rounds.partial();
72        let r_F = rounds.full();
73
74        let arc = &this.arc;
75        let mds = &this.mds;
76        let alpha = this.alpha;
77        let optimized_mds = &this.optimized_mds;
78        let optimized_arc = &this.optimized_arc;
79
80        write!(
81            f,
82            r"/// Parameters for the rate-{rate} instance of Poseidon.
83pub fn rate_{rate}<F: PrimeField>() -> PoseidonParameters<F> {{
84    PoseidonParameters {{
85        M: {},
86        t: {},
87        arc: {},
88        mds: {},
89        alpha: {},
90        rounds: RoundNumbers {{r_P: {r_P}, r_F: {r_F}}},
91        optimized_mds: {},
92        optimized_arc: {},
93    }}
94}}
95",
96            this.M,
97            this.t,
98            DisplayableArcMatrix(arc),
99            DisplayableMdsMatrix(mds),
100            DisplayableAlpha(alpha),
101            DisplayableOptimizedMdsMatrices(optimized_mds),
102            DisplayableOptimizedArcMatrix(optimized_arc),
103        )
104    }
105}
106
107struct DisplayableAlpha(Alpha);
108impl Display for DisplayableAlpha {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        match self.0 {
111            Alpha::Exponent(exp) => write!(f, "Alpha::Exponent({exp})"),
112            Alpha::Inverse => write!(f, "Alpha::Inverse"),
113        }
114    }
115}
116
117struct DisplayableMatrix<'a, F: PrimeField>(&'a Matrix<F>);
118impl<F: PrimeField> Display for DisplayableMatrix<'_, F> {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        let n_rows = self.0.n_rows();
121        let n_cols = self.0.n_cols();
122        let elements = serialize_slice_f(&self.0.elements().to_vec());
123        write!(f, r"Matrix::new({n_rows}, {n_cols}, {elements})",)
124    }
125}
126
127struct DisplayableSquareMatrix<'a, F: PrimeField>(&'a SquareMatrix<F>);
128impl<F: PrimeField> Display for DisplayableSquareMatrix<'_, F> {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        let n_rows = self.0.n_rows();
131        let n_cols = self.0.n_cols();
132        let elements = serialize_slice_f(&self.0.elements().to_vec());
133        write!(f, r"SquareMatrix::new({n_rows}, {n_cols}, {elements})",)
134    }
135}
136
137fn serialize_slice_matrix_f<F: PrimeField>(elements: &[Matrix<F>]) -> String {
138    let mut new_str = "vec![".to_string();
139    for elem in elements {
140        new_str.push_str(&format!("{}, ", DisplayableMatrix(elem)).to_string());
141    }
142    // Remove the trailing ", "
143    new_str.pop();
144    new_str.pop();
145    new_str.push(']');
146    new_str
147}
148
149fn serialize_slice_f<F: PrimeField>(elements: &[F]) -> String {
150    let mut new_str = "vec![".to_string();
151    for elem in elements {
152        // We use the BigUint type here since the Display of the field element
153        // is not in decimal: see https://github.com/arkworks-rs/algebra/issues/320
154        let elem_bigint: BigUint = (*elem).into();
155        new_str.push_str("F::from_str(\"");
156        new_str.push_str(&format!("{}", elem_bigint).to_string());
157        new_str.push_str("\").map_err(|_| ()).unwrap(), ");
158    }
159    // Remove the trailing ", "
160    new_str.pop();
161    new_str.pop();
162    new_str.push(']');
163    new_str
164}
165
166fn serialize_slice_of_vecs_f<F: PrimeField>(elements: &[Vec<F>]) -> String {
167    let mut new_str = "vec![".to_string();
168    for r in elements {
169        for c in r {
170            let elem_bigint: BigUint = (*c).into();
171            new_str.push_str("F::from_str(\"");
172            new_str.push_str(&format!("{}", elem_bigint).to_string());
173            new_str.push_str("\").map_err(|_| ()).unwrap(), ");
174        }
175    }
176    // Remove the trailing ", "
177    new_str.pop();
178    new_str.pop();
179    new_str.push(']');
180    new_str
181}
182
183fn serialize_f<F: PrimeField>(single_element: &F) -> String {
184    let mut new_str = "F::from_str(\"".to_string();
185    let elem_bigint: BigUint = (*single_element).into();
186    new_str.push_str(&format!("{}", elem_bigint));
187    new_str.push_str("\").map_err(|_| ()).unwrap()");
188    new_str
189}
190
191struct DisplayableOptimizedMdsMatrices<'a, F: PrimeField>(&'a OptimizedMdsMatrices<F>);
192impl<F: PrimeField> Display for DisplayableOptimizedMdsMatrices<'_, F> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        let this = self.0;
195
196        let M_hat = &this.M_hat;
197        let v = &this.v;
198        let w = &this.w;
199        let M_prime = &this.M_prime;
200        let M_doubleprime = &this.M_doubleprime;
201        let M_inverse = &this.M_inverse;
202        let M_hat_inverse = &this.M_hat_inverse;
203        let M_i = &this.M_i;
204        let v_collection = &this.v_collection;
205        let w_hat_collection = &this.w_hat_collection;
206
207        write!(
208            f,
209            r"OptimizedMdsMatrices {{
210                M_hat: {},
211                v: {},
212                w: {},
213                M_prime: {},
214                M_doubleprime: {},
215                M_inverse: {},
216                M_hat_inverse: {},
217                M_00: {},
218                M_i: {},
219                v_collection: {},
220                w_hat_collection: {},
221            }}",
222            DisplayableSquareMatrix(M_hat),
223            DisplayableMatrix(v),
224            DisplayableMatrix(w),
225            DisplayableSquareMatrix(M_prime),
226            DisplayableSquareMatrix(M_doubleprime),
227            DisplayableSquareMatrix(M_inverse),
228            DisplayableSquareMatrix(M_hat_inverse),
229            serialize_f(&this.M_00),
230            DisplayableMatrix(M_i),
231            serialize_slice_matrix_f(v_collection),
232            serialize_slice_matrix_f(w_hat_collection),
233        )
234    }
235}
236
237struct DisplayableMdsMatrix<'a, F: PrimeField>(&'a MdsMatrix<F>);
238impl<F: PrimeField> Display for DisplayableMdsMatrix<'_, F> {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        let mds_elements: Vec<F> = self.0.elements().to_vec();
241
242        let mut mds_str = "MdsMatrix::from_elements(".to_string();
243        mds_str.push_str(&serialize_slice_f(&mds_elements));
244        mds_str.push(')');
245        write!(f, "{}", &mds_str[..])
246    }
247}
248
249struct DisplayableArcMatrix<'a, F: PrimeField>(&'a ArcMatrix<F>);
250impl<F: PrimeField> Display for DisplayableArcMatrix<'_, F> {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        let n_rows = self.0.n_rows();
253        let n_cols = self.0.n_cols();
254
255        let arc = self.0.clone();
256        let elements: Vec<Vec<F>> = arc.into();
257
258        let mut arc_str = "ArcMatrix::new(".to_string();
259        arc_str.push_str(&n_rows.to_string());
260        arc_str.push_str(", ");
261        arc_str.push_str(&n_cols.to_string());
262        arc_str.push_str(r", ");
263        arc_str.push_str(&serialize_slice_of_vecs_f(&elements));
264        arc_str.push(')');
265        write!(f, "{}", &arc_str[..])
266    }
267}
268
269fn optimized_arc_matrix_to_vec_vec<F: PrimeField>(matrix: &OptimizedArcMatrix<F>) -> Vec<Vec<F>> {
270    let mut rows = Vec::<Vec<F>>::new();
271    let arc: &ArcMatrix<F> = &matrix.0;
272
273    for i in 0..arc.n_rows() {
274        let mut row = Vec::new();
275        for j in 0..arc.n_cols() {
276            row.push(arc.get_element(i, j));
277        }
278        rows.push(row);
279    }
280    rows
281}
282
283struct DisplayableOptimizedArcMatrix<'a, F: PrimeField>(&'a OptimizedArcMatrix<F>);
284impl<F: PrimeField> Display for DisplayableOptimizedArcMatrix<'_, F> {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        let n_rows = self.0 .0.n_rows();
287        let n_cols = self.0 .0.n_cols();
288        let elements = optimized_arc_matrix_to_vec_vec(self.0);
289
290        let mut arc_str = "OptimizedArcMatrix::new(".to_string();
291        arc_str.push_str(&n_rows.to_string());
292        arc_str.push_str(", ");
293        arc_str.push_str(&n_cols.to_string());
294        arc_str.push_str(r", ");
295        arc_str.push_str(&serialize_slice_of_vecs_f(&elements));
296        arc_str.push(')');
297        write!(f, "{}", &arc_str[..])
298    }
299}