// arcis_compiler/utils/crypto/rescue_desc.rs

1use crate::{
2    core::{
3        actually_used_field::ActuallyUsedField,
4        global_value::{field_array::FieldArray, value::FieldValue},
5    },
6    traits::{Pow, Random},
7    utils::{matrix::Matrix, number::Number, used_field::UsedField},
8};
9use num_bigint::BigInt;
10use num_traits::{ToPrimitive, Zero};
11use sha3::{
12    digest::{ExtendableOutput, Update, XofReader},
13    Shake256,
14};
15use std::{
16    fmt::Debug,
17    iter::successors,
18    ops::{Add, Mul, Sub},
19};
20
21pub trait RescueArg<F: UsedField>:
22    Copy
23    + Debug
24    + Add<Self, Output = Self>
25    + Sub<Self, Output = Self>
26    + Mul<Self, Output = Self>
27    + Mul<F, Output = Self>
28    + Zero
29    + Pow
30    + Random
31    + From<F>
32{
33}
34
35impl<F: ActuallyUsedField> RescueArg<F> for F {}
36
37impl<F: ActuallyUsedField> RescueArg<F> for FieldValue<F> {}
38
39impl<const N: usize, F: ActuallyUsedField> RescueArg<F> for FieldArray<N, F> {}
40
/// Which Rescue construction a `RescueDesc` instantiates.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RescueMode {
    /// The keyed block cipher; its round keys come from a key schedule.
    BlockCipher,
    /// The hash function; `capacity` is the number of state elements that are not
    /// part of the rate (`rate = m - capacity`).
    HashFunction { capacity: usize },
}

/// Security level in bits used to derive the block cipher's number of rounds.
const SECURITY_LEVEL_BLOCK_CIPHER: usize = 128;

/// Security level in bits used to derive the hash function's number of rounds
/// (also part of the seed for its round constants).
const SECURITY_LEVEL_HASH_FUNCTION: usize = 256;
52
53/// see <https://tosc.iacr.org/index.php/ToSC/article/view/8695/8287> for everything
54/// We used different MDS matrix: Cauchy-based matrix.
55/// It's easier to compute and easier to prove it is MDS.
56
57#[derive(Clone, Debug, PartialEq, Eq, Hash)]
58pub struct RescueDesc<F: UsedField, T: RescueArg<F>> {
59    pub mode: RescueMode,
60    alpha: Number,
61    alpha_inverse: Number,
62    n_rounds: usize,
63    pub m: usize,
64    mds_mat: Matrix<F>,
65    mds_mat_inverse: Matrix<F>,
66    round_keys: Vec<Matrix<T>>,
67}
68
69impl<F: UsedField, T: RescueArg<F>> RescueDesc<F, T> {
70    pub fn new_cipher_desc(key: Matrix<T>) -> Self {
71        if key.nrows == 1 || key.ncols > 1 {
72            panic!(
73                "key must be a column vector with at least 2 rows (found nrows: {}, ncols: {})",
74                key.nrows, key.ncols
75            );
76        }
77        let m = key.nrows;
78        let alpha = F::get_alpha();
79        let alpha_inverse = F::get_alpha_inverse();
80        let n_rounds = Self::get_n_rounds(RescueMode::BlockCipher, &alpha, m);
81        let (mds_mat, mds_mat_inverse) = F::mds_matrix_and_inverse(m);
82        // generate the round constants using SHAKE256 hash
83        let round_constants = Self::sample_constants(RescueMode::BlockCipher, n_rounds, m);
84
85        // do the key schedule
86        let round_keys = rescue_permutation(
87            RescueMode::BlockCipher,
88            &alpha,
89            &alpha_inverse,
90            &mds_mat,
91            &round_constants
92                .into_iter()
93                .map(|c| c.convert())
94                .collect::<Vec<Matrix<T>>>(),
95            &key,
96        );
97
98        RescueDesc {
99            mode: RescueMode::BlockCipher,
100            alpha,
101            alpha_inverse,
102            n_rounds,
103            m,
104            mds_mat,
105            mds_mat_inverse,
106            round_keys,
107        }
108    }
109
110    pub fn new_hash_desc(m: usize, capacity: usize) -> Self {
111        let alpha = F::get_alpha();
112        let alpha_inverse = F::get_alpha_inverse();
113        let n_rounds = Self::get_n_rounds(RescueMode::HashFunction { capacity }, &alpha, m);
114        let (mds_mat, mds_mat_inverse) = F::mds_matrix_and_inverse(m);
115        // generate the round constants using SHAKE256 hash
116        let round_constants =
117            Self::sample_constants(RescueMode::HashFunction { capacity }, n_rounds, m);
118
119        RescueDesc {
120            mode: RescueMode::HashFunction { capacity },
121            alpha,
122            alpha_inverse,
123            n_rounds,
124            m,
125            mds_mat,
126            mds_mat_inverse,
127            round_keys: round_constants
128                .into_iter()
129                .map(|c| c.convert())
130                .collect::<Vec<Matrix<T>>>(),
131        }
132    }
133
134    fn get_n_rounds(mode: RescueMode, alpha: &Number, m: usize) -> usize {
135        match mode {
136            RescueMode::BlockCipher => {
137                // see https://tosc.iacr.org/index.php/ToSC/article/view/8695/8287
138                let l_0 = (SECURITY_LEVEL_BLOCK_CIPHER as f64 * 2.0
139                    / ((m + 1) as f64 * (F::modulus().log2() - (alpha - 1).log2())))
140                .ceil() as usize;
141                let l_1 = if *alpha == 3 {
142                    ((SECURITY_LEVEL_BLOCK_CIPHER + 2) as f64 / (4 * m) as f64).ceil() as usize
143                } else {
144                    ((SECURITY_LEVEL_BLOCK_CIPHER + 3) as f64 / (m as f64 * 5.5)).ceil() as usize
145                };
146                2 * ([l_0, l_1, 5].into_iter().max().unwrap())
147            }
148            RescueMode::HashFunction { capacity } => {
149                // get number of rounds for Groebner basis attack
150                let rate = m - capacity;
151                fn dcon(n: usize, alpha: &Number, m: usize) -> usize {
152                    (0.5 * ((alpha.to_usize().unwrap() - 1) * m * (n - 1)) as f64 + 2.0).floor()
153                        as usize
154                }
155                fn v(n: usize, rate: usize, m: usize) -> usize {
156                    m * (n - 1) + rate
157                }
158                fn binomial(n: usize, k: usize) -> Number {
159                    fn factorial(m: Number) -> Number {
160                        if m == 0 || m == 1 {
161                            Number::from(1)
162                        } else {
163                            m.clone() * factorial(m - 1)
164                        }
165                    }
166                    factorial(Number::from(n))
167                        / (factorial(Number::from(n - k)) * factorial(Number::from(k)))
168                }
169
170                let target = Number::power_of_two(SECURITY_LEVEL_HASH_FUNCTION);
171                let mut l1 = 1;
172                let mut tmp = binomial(v(l1, rate, m) + dcon(l1, alpha, m), v(l1, rate, m));
173                while tmp.clone() * tmp <= target && l1 <= 23 {
174                    l1 += 1;
175                    tmp = binomial(v(l1, rate, m) + dcon(l1, alpha, m), v(l1, rate, m));
176                }
177
178                // set a minimum value for sanity and add 50%
179                (1.5 * [5, l1].into_iter().max().unwrap() as f64).ceil() as usize
180            }
181        }
182    }
183
184    fn sample_constants(mode: RescueMode, n_rounds: usize, m: usize) -> Vec<Matrix<F>> {
185        // setup randomness
186        let mut hasher = Shake256::default();
187        // buffer to create `FieldElements` from bytes (via `Number`)
188        // we add 16 bytes to get a distribution statistically close to uniform
189        let buffer_len = F::NUM_BITS.div_ceil(8) as usize + 16;
190        match mode {
191            RescueMode::BlockCipher => {
192                hasher.update(b"encrypt everything, compute anything");
193                let mut reader = hasher.finalize_xof();
194
195                let mut f_iter = (0..m * m + 2 * m).map(|_| {
196                    // create field element from the shake hash
197                    let randomness = reader.read_boxed(buffer_len);
198                    // we set the sign to plus to essentially read unsigned BigInts (BigUInts),
199                    // matching the noble curves TypeScript implementation used
200                    // in the client.
201                    let b = BigInt::from_bytes_le(num_bigint::Sign::Plus, &randomness);
202                    // we need not check whether the obtained field element f is in any subgroup,
203                    // because we use only prime fields (i.e. there are no subgroups)
204                    F::from(Number::from(b))
205                });
206
207                // create matrix and vectors
208                let mut round_constant_mat =
209                    Matrix::new_from_iter((m, m), (&mut f_iter).take(m * m));
210                let initial_round_constant = Matrix::new_from_iter((m, 1), (&mut f_iter).take(m));
211                let round_constant_affine_term =
212                    Matrix::new_from_iter((m, 1), (&mut f_iter).take(m));
213
214                // check for inversability
215                while round_constant_mat.det() == F::ZERO {
216                    //resample the matrix
217                    let data = vec![F::ZERO; m * m].into_iter().map(|_| {
218                        let randomness = reader.read_boxed(buffer_len);
219                        let b = BigInt::from_bytes_le(num_bigint::Sign::Plus, &randomness);
220                        F::from(Number::from(b))
221                    });
222                    round_constant_mat = Matrix::new_from_iter((m, m), data);
223                }
224
225                let mut iter = 0..2 * n_rounds;
226                successors(Some(initial_round_constant), |c| {
227                    iter.next().map(|_| {
228                        round_constant_mat.clone().mat_mul(c) + round_constant_affine_term.clone()
229                    })
230                })
231                .collect::<Vec<Matrix<F>>>()
232            }
233            RescueMode::HashFunction { capacity } => {
234                let seed = format!(
235                    "Rescue-XLIX({},{},{},{})",
236                    F::modulus(),
237                    m,
238                    capacity,
239                    SECURITY_LEVEL_HASH_FUNCTION
240                );
241                hasher.update(seed.as_bytes());
242                let mut reader = hasher.finalize_xof();
243
244                let mut round_constants = (0..2 * m * n_rounds)
245                    .map(|_| {
246                        // create field element from the shake hash
247                        let randomness = reader.read_boxed(buffer_len);
248                        // we set the sign to plus to essentially read unsigned BigInts (BigUInts),
249                        // matching the noble curves TypeScript implementation used
250                        // in the client.
251                        let b = BigInt::from_bytes_le(num_bigint::Sign::Plus, &randomness);
252                        // we need not check whether the obtained field element f is in any
253                        // subgroup, because we use only prime fields (i.e.
254                        // there are no subgroups)
255                        F::from(Number::from(b))
256                    })
257                    .collect::<Vec<F>>()
258                    .chunks(m)
259                    .map(|c| Matrix::new_from_iter((m, 1), c.iter().copied()))
260                    .collect::<Vec<Matrix<F>>>();
261                // Self::permute requires an odd number of round keys
262                // prepending a 0 matrix makes it equivalent to Algorithm 3 from https://eprint.iacr.org/2020/1143.pdf
263                round_constants.insert(0, Matrix::new((m, 1), F::ZERO));
264                round_constants
265            }
266        }
267    }
268
269    pub fn permute(&self, state: &Matrix<T>) -> Matrix<T> {
270        rescue_permutation(
271            self.mode,
272            &self.alpha,
273            &self.alpha_inverse,
274            &self.mds_mat,
275            &self.round_keys,
276            state,
277        )
278        .last()
279        .unwrap()
280        .clone()
281    }
282
283    pub fn permute_inverse(&self, state: &Matrix<T>) -> Matrix<T> {
284        rescue_permutation_inverse(
285            self.mode,
286            &self.alpha,
287            &self.alpha_inverse,
288            &self.mds_mat_inverse,
289            &self.round_keys,
290            state,
291        )
292        .last()
293        .unwrap()
294        .clone()
295    }
296}
297
298fn exponent_for_even(mode: RescueMode, alpha: Number, alpha_inverse: Number) -> Number {
299    match mode {
300        RescueMode::BlockCipher => alpha_inverse,
301        RescueMode::HashFunction { capacity: _ } => alpha,
302    }
303}
304
305fn exponent_for_odd(mode: RescueMode, alpha: Number, alpha_inverse: Number) -> Number {
306    match mode {
307        RescueMode::BlockCipher => alpha,
308        RescueMode::HashFunction { capacity: _ } => alpha_inverse,
309    }
310}
311
312fn rescue_permutation<T: RescueArg<F>, F: UsedField>(
313    mode: RescueMode,
314    alpha: &Number,
315    alpha_inverse: &Number,
316    mds_mat: &Matrix<F>,
317    subkeys: &[Matrix<T>],
318    state: &Matrix<T>,
319) -> Vec<Matrix<T>> {
320    let exponent_even = exponent_for_even(mode, alpha.clone(), alpha_inverse.clone());
321    let exponent_odd = exponent_for_odd(mode, alpha.clone(), alpha_inverse.clone());
322    let initial_key = &subkeys[0];
323    let mut iter = subkeys[1..].iter().enumerate();
324    successors(Some(state.clone() + initial_key.clone()), |s| {
325        iter.next().map(|(r, key)| {
326            let mut s = s.clone();
327            if r % 2 == 0 {
328                // we can expect x to be non-zero
329                s.map_mut(|x| x.pow(&exponent_even, true));
330            } else {
331                // we can expect x to be non-zero
332                s.map_mut(|x| x.pow(&exponent_odd, true));
333            }
334            s = mds_mat.mat_mul(&s);
335            s += key;
336            s
337        })
338    })
339    .collect::<Vec<Matrix<T>>>()
340}
341
342fn rescue_permutation_inverse<T: RescueArg<F>, F: UsedField>(
343    mode: RescueMode,
344    alpha: &Number,
345    alpha_inverse: &Number,
346    mds_mat_inverse: &Matrix<F>,
347    subkeys: &[Matrix<T>],
348    state: &Matrix<T>,
349) -> Vec<Matrix<T>> {
350    let exponent_even = exponent_for_even(mode, alpha.clone(), alpha_inverse.clone());
351    let exponent_odd = exponent_for_odd(mode, alpha.clone(), alpha_inverse.clone());
352    let initial_key = &subkeys[0];
353    let mut states = subkeys[1..]
354        .iter()
355        .rev()
356        .enumerate()
357        .scan(state.clone(), |s, (r, key)| {
358            *s -= key;
359            *s = mds_mat_inverse.mat_mul(s);
360            if r % 2 == 0 {
361                // we can expect x to be non-zero
362                s.map_mut(|x| x.pow(&exponent_even, true));
363            } else {
364                // we can expect x to be non-zero
365                s.map_mut(|x| x.pow(&exponent_odd, true));
366            }
367            Some(s.clone())
368        })
369        .collect::<Vec<Matrix<T>>>();
370    states.push(states.last().unwrap().clone() - initial_key.clone());
371    states
372}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::BaseField;
    use ff::Field;
    use rand::Rng;

    /// Asserts that `product` is the identity matrix, element by element.
    fn assert_identity(product: Matrix<BaseField>) {
        for row in 0..product.nrows {
            for col in 0..product.ncols {
                let expected = if row == col {
                    BaseField::ONE
                } else {
                    BaseField::ZERO
                };
                assert_eq!(*product.get((row, col)).unwrap(), expected);
            }
        }
    }

    /// Checks the algebraic invariants of `rescue` and that `permute_inverse` undoes
    /// `permute` on two random states.
    fn test_rescue_desc<R: Rng + ?Sized>(rng: &mut R, rescue: RescueDesc<BaseField, BaseField>) {
        // alpha and alpha_inverse must be inverse exponents modulo p - 1
        let alpha_prod = (&rescue.alpha * &rescue.alpha_inverse) % (BaseField::modulus() - 1);
        assert_eq!(alpha_prod, Number::from(1));

        // the MDS matrix and its stored inverse must multiply to the identity both ways
        assert_identity(rescue.mds_mat.mat_mul(&rescue.mds_mat_inverse));
        assert_identity(rescue.mds_mat_inverse.mat_mul(&rescue.mds_mat));

        for _ in 0..2 {
            let state = Matrix::from(gen_random_fp(rng, rescue.m));
            let round_trip = rescue.permute_inverse(&rescue.permute(&state));
            assert_eq!(round_trip, state);
        }
    }

    /// Draws `size` uniformly random base-field elements.
    fn gen_random_fp<R: Rng + ?Sized>(rng: &mut R, size: usize) -> Vec<BaseField> {
        std::iter::repeat_with(|| <BaseField as ff::Field>::random(&mut *rng))
            .take(size)
            .collect()
    }

    /// Flips fair coins until the first tails and returns the number of heads.
    fn count_heads<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let mut heads = 0;
        while rng.gen_bool(0.5) {
            heads += 1;
        }
        heads
    }

    #[test]
    fn rescue_desc() {
        let rng = &mut crate::utils::test_rng::get();

        // random state width, at least 2
        let m = 2 + count_heads(rng);
        let rescue_cipher = RescueDesc::new_cipher_desc(Matrix::from(gen_random_fp(rng, m)));
        test_rescue_desc(rng, rescue_cipher);

        // random capacity in 1..m
        let capacity = (1 + count_heads(rng)).min(m - 1);
        let rescue_hash = RescueDesc::<BaseField, BaseField>::new_hash_desc(m, capacity);
        test_rescue_desc(rng, rescue_hash);
    }
}