faest/
aes.rs

1use std::{array, mem::size_of};
2
3use generic_array::{
4    typenum::{Unsigned, U3, U4},
5    GenericArray,
6};
7use itertools::iproduct;
8
9use crate::{
10    fields::{ByteCombine, ByteCombineConstants, Field as _, SumPoly},
11    internal_keys::PublicKey,
12    parameter::{BaseParameters, OWFParameters, QSProof, TauParameters},
13    rijndael_32::{
14        bitslice, convert_from_batchblocks, inv_bitslice, mix_columns_0, rijndael_add_round_key,
15        rijndael_key_schedule, rijndael_shift_rows_1, sub_bytes, sub_bytes_nots, State, RCON_TABLE,
16    },
17    universal_hashing::{ZKHasherInit, ZKProofHasher, ZKVerifyHasher},
18    utils::{bit_combine_with_delta, contains_zeros, convert_gq, transpose_and_into_field, Field},
19};
20
21type KeyCstrnts<O> = (
22    Box<GenericArray<u8, <O as OWFParameters>::PRODRUN128Bytes>>,
23    Box<GenericArray<Field<O>, <O as OWFParameters>::PRODRUN128>>,
24);
25
26type CstrntsVal<'a, O> = &'a GenericArray<
27    GenericArray<u8, <O as OWFParameters>::LAMBDALBYTES>,
28    <O as OWFParameters>::LAMBDA,
29>;
30
/// Returns the source position of byte `r` (0..4) of a word.
///
/// With `rotate` set, the result is `(r + 1) % 4`, i.e. `(r - 3) mod 4`: the
/// index that byte `r` of a word rotated left by one byte came from.
/// Without rotation, `r` is returned unchanged.
const fn inverse_rotate_word(r: usize, rotate: bool) -> usize {
    match rotate {
        true => (r + 1) % 4,
        false => r,
    }
}
39
40pub(crate) fn aes_extendedwitness<O>(
41    owf_key: &GenericArray<u8, O::LAMBDABYTES>,
42    owf_input: &GenericArray<u8, O::InputSize>,
43) -> Option<Box<GenericArray<u8, O::LBYTES>>>
44where
45    O: OWFParameters,
46{
47    let mut input = [0u8; 32];
48    // Step 0
49    input[..O::InputSize::USIZE].clone_from_slice(owf_input);
50    let mut witness = GenericArray::default_boxed();
51    let mut index = 0;
52    // Step 3
53    let (kb, mut zeros) = rijndael_key_schedule::<U4, O::NK, O::R>(owf_key, O::SKE::USIZE);
54    // Step 4
55    for i in convert_from_batchblocks(inv_bitslice(&kb[..8])).take(4) {
56        witness[index..index + size_of::<u32>()].copy_from_slice(&i);
57        index += size_of::<u32>();
58    }
59    for i in convert_from_batchblocks(inv_bitslice(&kb[8..16]))
60        .take(O::NK::USIZE / 2 - (4 - (O::NK::USIZE / 2)))
61    {
62        witness[index..index + size_of::<u32>()].copy_from_slice(&i);
63        index += size_of::<u32>();
64    }
65    for j in 1 + (O::NK::USIZE / 8)
66        ..1 + (O::NK::USIZE / 8)
67            + (O::SKE::USIZE * ((2 - (O::NK::USIZE % 4)) * 2 + (O::NK::USIZE % 4) * 3)) / 16
68    {
69        let inside = GenericArray::<_, U3>::from_iter(
70            convert_from_batchblocks(inv_bitslice(&kb[8 * j..8 * (j + 1)])).take(3),
71        );
72        if O::NK::USIZE == 6 {
73            if j % 3 == 1 {
74                witness[index..index + size_of::<u32>()].copy_from_slice(&inside[2]);
75                index += size_of::<u32>();
76            } else if j % 3 == 0 {
77                witness[index..index + size_of::<u32>()].copy_from_slice(&inside[0]);
78                index += size_of::<u32>();
79            }
80        } else {
81            witness[index..index + size_of::<u32>()].copy_from_slice(&inside[0]);
82            index += size_of::<u32>();
83        }
84    }
85    // Step 5
86    zeros |= round_with_save(&input[..16], &kb, O::R::U8, &mut witness, &mut index);
87    if O::LAMBDA::USIZE > 128 {
88        zeros |= round_with_save(&input[16..], &kb, O::R::U8, &mut witness, &mut index);
89    }
90    if zeros {
91        None
92    } else {
93        Some(witness)
94    }
95}
96
97#[allow(clippy::too_many_arguments)]
98fn round_with_save(
99    input1: &[u8],
100    kb: &[u32],
101    r: u8,
102    witness: &mut [u8],
103    index: &mut usize,
104) -> bool {
105    let mut zeros = false;
106    let mut state = State::default();
107    bitslice(&mut state, input1, &[]);
108    rijndael_add_round_key(&mut state, &kb[..8]);
109    for j in 1..r as usize {
110        zeros |= contains_zeros(&inv_bitslice(&state)[0]);
111        sub_bytes(&mut state);
112        sub_bytes_nots(&mut state);
113        rijndael_shift_rows_1::<U4>(&mut state);
114        for i in convert_from_batchblocks(inv_bitslice(&state)).take(4) {
115            witness[*index..*index + size_of::<u32>()].copy_from_slice(&i);
116            *index += size_of::<u32>();
117        }
118        mix_columns_0(&mut state);
119        rijndael_add_round_key(&mut state, &kb[8 * j..8 * (j + 1)]);
120    }
121    zeros | contains_zeros(&inv_bitslice(&state)[0])
122}
123
124fn aes_key_exp_fwd_1<O>(
125    x: &GenericArray<u8, O::LKEBytes>,
126) -> Box<GenericArray<u8, O::PRODRUN128Bytes>>
127where
128    O: OWFParameters,
129{
130    let mut out = GenericArray::default_boxed();
131    out[..O::LAMBDABYTES::USIZE].copy_from_slice(&x[..O::LAMBDABYTES::USIZE]);
132    let mut index = O::LAMBDABYTES::USIZE;
133    let mut x_index = O::LAMBDABYTES::USIZE;
134    for j in O::NK::USIZE..(4 * (O::R::USIZE + 1)) {
135        if (j % O::NK::USIZE == 0) || ((O::NK::USIZE > 6) && (j % O::NK::USIZE == 4)) {
136            out[index..index + 32 / 8].copy_from_slice(&x[x_index..x_index + 32 / 8]);
137            index += 32 / 8;
138            x_index += 32 / 8;
139        } else {
140            for i in 0..4 {
141                out[index] = out[(32 * (j - O::NK::USIZE)) / 8 + i] ^ out[(32 * (j - 1)) / 8 + i];
142                index += 1;
143            }
144        }
145    }
146    out
147}
148
149fn aes_key_exp_fwd<O>(
150    x: &GenericArray<Field<O>, O::LKE>,
151) -> Box<GenericArray<Field<O>, O::PRODRUN128>>
152where
153    O: OWFParameters,
154{
155    // Step 1 is ok by construction
156    let mut out = GenericArray::default_boxed();
157    out[..O::LAMBDA::USIZE].copy_from_slice(&x[..O::LAMBDA::USIZE]);
158    let mut index = O::LAMBDA::USIZE;
159    let mut x_index = O::LAMBDA::USIZE;
160    for j in O::NK::USIZE..(4 * (O::R::USIZE + 1)) {
161        if (j % O::NK::USIZE == 0) || ((O::NK::USIZE > 6) && (j % O::NK::USIZE == 4)) {
162            out[index..index + 32].copy_from_slice(&x[x_index..x_index + 32]);
163            index += 32;
164            x_index += 32;
165        } else {
166            for i in 0..32 {
167                out[index] = out[(32 * (j - O::NK::USIZE)) + i] + out[(32 * (j - 1)) + i];
168                index += 1;
169            }
170        }
171    }
172    out
173}
174
175fn aes_key_exp_bwd_mtag0_mkey0<'a, O>(
176    x: &'a GenericArray<u8, O::LKEBytes>,
177    xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
178) -> impl Iterator<Item = Field<O>> + 'a
179where
180    O: OWFParameters,
181{
182    let mut indice = 0;
183    let mut c = 0;
184    let mut rmvrcon = true;
185    let mut ircon = 0;
186    // Step 6
187    (0..O::SKE::USIZE).map(move |j| {
188        // Step 7
189        let mut x_tilde = xk[indice + c] ^ x[j + O::LAMBDABYTES::USIZE];
190        // Step 8
191        if rmvrcon && (c == 0) {
192            let rcon = RCON_TABLE[ircon];
193            ircon += 1;
194            // Step 11
195            x_tilde ^= rcon;
196        }
197
198        c += 1;
199        // Step 21
200        if c == 4 {
201            c = 0;
202            if O::LAMBDA::USIZE == 192 {
203                indice += 192 / 8;
204            } else {
205                indice += 128 / 8;
206                if O::LAMBDA::USIZE == 256 {
207                    rmvrcon = !rmvrcon;
208                }
209            }
210        }
211
212        Field::<O>::byte_combine_bits(
213            x_tilde.rotate_right(7) ^ x_tilde.rotate_right(5) ^ x_tilde.rotate_right(2) ^ 0x5,
214        )
215    })
216}
217
218fn aes_key_exp_bwd_mtag1_mkey0<'a, O>(
219    x: &'a [Field<O>],
220    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
221) -> impl Iterator<Item = Field<O>> + 'a
222where
223    O: OWFParameters,
224{
225    let mut indice = 0;
226    let mut c = 0;
227    let mut rmvrcon = true;
228    // Step 6
229    (0..O::SKE::USIZE).map(move |j| {
230        // Step 7
231        let x_tilde: [_; 8] = array::from_fn(|i| x[8 * j + i] + xk[indice + 8 * c + i]);
232        // Step 15
233        let y_tilde =
234            array::from_fn(|i| x_tilde[(i + 7) % 8] + x_tilde[(i + 5) % 8] + x_tilde[(i + 2) % 8]);
235        c += 1;
236        // Step 21
237        if c == 4 {
238            c = 0;
239            if O::LAMBDA::USIZE == 192 {
240                indice += 192;
241            } else {
242                indice += 128;
243                if O::LAMBDA::USIZE == 256 {
244                    rmvrcon = !rmvrcon;
245                }
246            }
247        }
248        Field::<O>::byte_combine(&y_tilde)
249    })
250}
251
252fn aes_key_exp_bwd_mtag0_mkey1<'a, O>(
253    x: &'a GenericArray<Field<O>, O::LKE>,
254    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
255    delta: &'a Field<O>,
256) -> impl Iterator<Item = Field<O>> + 'a
257where
258    O: OWFParameters,
259{
260    let mut indice = 0;
261    let mut c = 0;
262    let mut rmvrcon = true;
263    let mut ircon = 0;
264    // Step 6
265    (0..O::SKE::USIZE).map(move |j| {
266        // Step 7
267        let mut x_tilde: [_; 8] =
268            array::from_fn(|i| x[8 * j + i + O::LAMBDA::USIZE] + xk[indice + 8 * c + i]);
269        // Step 8
270        if rmvrcon && (c == 0) {
271            let rcon = RCON_TABLE[ircon];
272            ircon += 1;
273            // Step 11
274            for (i, x) in x_tilde.iter_mut().enumerate() {
275                *x += *delta * ((rcon >> i) & 1);
276            }
277        }
278        // Step 15
279        let mut y_tilde =
280            array::from_fn(|i| x_tilde[(i + 7) % 8] + x_tilde[(i + 5) % 8] + x_tilde[(i + 2) % 8]);
281        y_tilde[0] += delta;
282        y_tilde[2] += delta;
283        c += 1;
284        // Step 21
285        if c == 4 {
286            c = 0;
287            if O::LAMBDA::USIZE == 192 {
288                indice += 192;
289            } else {
290                indice += 128;
291                if O::LAMBDA::USIZE == 256 {
292                    rmvrcon = !rmvrcon;
293                }
294            }
295        }
296        Field::<O>::byte_combine(&y_tilde)
297    })
298}
299
300fn aes_key_exp_cstrnts_mkey0<O>(
301    zk_hasher: &mut ZKProofHasher<Field<O>>,
302    w: &GenericArray<u8, O::LKEBytes>,
303    v: &GenericArray<Field<O>, O::LKE>,
304) -> KeyCstrnts<O>
305where
306    O: OWFParameters,
307{
308    let k = aes_key_exp_fwd_1::<O>(w);
309    let vk = aes_key_exp_fwd::<O>(v);
310    let w_b = aes_key_exp_bwd_mtag0_mkey0::<O>(w, &k);
311    let v_w_b = aes_key_exp_bwd_mtag1_mkey0::<O>(&v[O::LAMBDA::USIZE..], &vk);
312
313    zk_hasher.process(
314        iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
315            let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
316            let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
317            Field::<O>::byte_combine_bits(k[iwd / 8 + inverse_rotate_word(r, dorotword)])
318        }),
319        iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
320            let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
321            let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
322            let r = inverse_rotate_word(r, dorotword);
323            Field::<O>::byte_combine_slice(&vk[iwd + (8 * r)..iwd + (8 * r) + 8])
324        }),
325        w_b,
326        v_w_b,
327    );
328
329    (k, vk)
330}
331
332fn aes_key_exp_cstrnts_mkey1<O>(
333    zk_hasher: &mut ZKVerifyHasher<Field<O>>,
334    q: &GenericArray<Field<O>, O::LKE>,
335    delta: &Field<O>,
336) -> Box<GenericArray<Field<O>, <O as OWFParameters>::PRODRUN128>>
337where
338    O: OWFParameters,
339{
340    let q_k = aes_key_exp_fwd::<O>(q);
341    let q_w_b = aes_key_exp_bwd_mtag0_mkey1::<O>(q, &q_k, delta);
342
343    zk_hasher.process(
344        q_w_b,
345        iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
346            let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
347            let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
348            let rotated_r = inverse_rotate_word(r, dorotword);
349            Field::<O>::byte_combine_slice(&q_k[iwd + (8 * rotated_r)..iwd + (8 * rotated_r) + 8])
350        }),
351    );
352
353    q_k
354}
355
356fn aes_enc_fwd_mkey0_mtag0<'a, O>(
357    x: &'a GenericArray<u8, O::QUOTLENC8>,
358    xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
359    input: &'a [u8; 16],
360) -> impl Iterator<Item = Field<O>> + 'a
361where
362    O: OWFParameters,
363{
364    (0..16)
365        .map(|i| {
366            // Step 2-5
367            Field::<O>::byte_combine_bits(input[i]) + Field::<O>::byte_combine_bits(xk[i])
368        })
369        .chain(
370            iproduct!(1..O::R::USIZE, 0..4)
371                .map(move |(j, c)| {
372                    // Step 6
373                    let ix: usize = 128 * (j - 1) + 32 * c;
374                    let ik: usize = 128 * j + 32 * c;
375                    let x_hat: [_; 4] =
376                        array::from_fn(|r| Field::<O>::byte_combine_bits(x[ix / 8 + r]));
377                    let mut res: [_; 4] =
378                        array::from_fn(|r| Field::<O>::byte_combine_bits(xk[ik / 8 + r]));
379
380                    // Step 16
381                    res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
382                        + x_hat[1] * Field::<O>::BYTE_COMBINE_3
383                        + x_hat[2]
384                        + x_hat[3];
385                    res[1] += x_hat[0]
386                        + x_hat[1] * Field::<O>::BYTE_COMBINE_2
387                        + x_hat[2] * Field::<O>::BYTE_COMBINE_3
388                        + x_hat[3];
389                    res[2] += x_hat[0]
390                        + x_hat[1]
391                        + x_hat[2] * Field::<O>::BYTE_COMBINE_2
392                        + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
393                    res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
394                        + x_hat[1]
395                        + x_hat[2]
396                        + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
397                    res
398                })
399                .flatten(),
400        )
401}
402
403fn aes_enc_fwd_mkey1_mtag0<'a, O>(
404    x: &'a GenericArray<Field<O>, O::LENC>,
405    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
406    input: &'a [u8; 16],
407    delta: &'a Field<O>,
408) -> impl Iterator<Item = Field<O>> + 'a
409where
410    O: OWFParameters,
411{
412    (0..16)
413        .map(|i| {
414            // Step 2-5
415            bit_combine_with_delta::<O>(input[i], delta)
416                + Field::<O>::byte_combine_slice(&xk[8 * i..(8 * i) + 8])
417        })
418        .chain(
419            iproduct!(1..O::R::USIZE, 0..4)
420                .map(move |(j, c)| {
421                    // Step 6
422                    let ix: usize = 128 * (j - 1) + 32 * c;
423                    let ik: usize = 128 * j + 32 * c;
424                    let x_hat: [_; 4] = array::from_fn(|r| {
425                        Field::<O>::byte_combine_slice(&x[ix + 8 * r..ix + 8 * r + 8])
426                    });
427                    let mut res: [_; 4] = array::from_fn(|r| {
428                        Field::<O>::byte_combine_slice(&xk[ik + 8 * r..ik + 8 * r + 8])
429                    });
430
431                    // Step 16
432                    res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
433                        + x_hat[1] * Field::<O>::BYTE_COMBINE_3
434                        + x_hat[2]
435                        + x_hat[3];
436                    res[1] += x_hat[0]
437                        + x_hat[1] * Field::<O>::BYTE_COMBINE_2
438                        + x_hat[2] * Field::<O>::BYTE_COMBINE_3
439                        + x_hat[3];
440                    res[2] += x_hat[0]
441                        + x_hat[1]
442                        + x_hat[2] * Field::<O>::BYTE_COMBINE_2
443                        + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
444                    res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
445                        + x_hat[1]
446                        + x_hat[2]
447                        + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
448                    res
449                })
450                .flatten(),
451        )
452}
453
454fn aes_enc_fwd_mkey0_mtag1<'a, O>(
455    x: &'a GenericArray<Field<O>, O::LENC>,
456    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
457) -> impl Iterator<Item = Field<O>> + 'a
458where
459    O: OWFParameters,
460{
461    (0..16)
462        .map(|i| {
463            // Step 2-5
464            Field::<O>::byte_combine_slice(&xk[8 * i..(8 * i) + 8])
465        })
466        .chain(
467            iproduct!(1..O::R::USIZE, 0..4)
468                .map(move |(j, c)| {
469                    // Step 6
470                    let ix: usize = 128 * (j - 1) + 32 * c;
471                    let ik: usize = 128 * j + 32 * c;
472                    let x_hat: [_; 4] = array::from_fn(|r| {
473                        Field::<O>::byte_combine_slice(&x[ix + 8 * r..ix + 8 * r + 8])
474                    });
475                    let mut res: [_; 4] = array::from_fn(|r| {
476                        Field::<O>::byte_combine_slice(&xk[ik + 8 * r..ik + 8 * r + 8])
477                    });
478
479                    // Step 16
480                    res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
481                        + x_hat[1] * Field::<O>::BYTE_COMBINE_3
482                        + x_hat[2]
483                        + x_hat[3];
484                    res[1] += x_hat[0]
485                        + x_hat[1] * Field::<O>::BYTE_COMBINE_2
486                        + x_hat[2] * Field::<O>::BYTE_COMBINE_3
487                        + x_hat[3];
488                    res[2] += x_hat[0]
489                        + x_hat[1]
490                        + x_hat[2] * Field::<O>::BYTE_COMBINE_2
491                        + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
492                    res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
493                        + x_hat[1]
494                        + x_hat[2]
495                        + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
496                    res
497                })
498                .flatten(),
499        )
500}
501
502fn aes_enc_bkwd_mkey0_mtag0<'a, O>(
503    x: &'a GenericArray<u8, O::QUOTLENC8>,
504    xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
505    out: &'a [u8; 16],
506) -> impl Iterator<Item = Field<O>> + 'a
507where
508    O: OWFParameters,
509{
510    // Step 2
511    iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
512        // Step 4
513        let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
514        let x_t = if j < O::R::USIZE - 1 {
515            x[ird / 8]
516        } else {
517            let x_out = out[(ird - 128 * j) / 8];
518            x_out ^ xk[(128 + ird) / 8]
519        };
520        let y_t = x_t.rotate_right(7) ^ x_t.rotate_right(5) ^ x_t.rotate_right(2) ^ 0x5;
521        Field::<O>::byte_combine_bits(y_t)
522    })
523}
524
525fn aes_enc_bkwd_mkey1_mtag0<'a, O>(
526    x: &'a GenericArray<Field<O>, O::LENC>,
527    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
528    out: &'a [u8; 16],
529    delta: &'a Field<O>,
530) -> impl Iterator<Item = Field<O>> + 'a
531where
532    O: OWFParameters,
533{
534    // Step 2
535    iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
536        // Step 4
537        let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
538        let x_t: [_; 8] = if j < O::R::USIZE - 1 {
539            array::from_fn(|i| x[ird + i])
540        } else {
541            array::from_fn(|i| {
542                *delta * ((out[(ird - 128 * j + i) / 8] >> ((ird - 128 * j + i) % 8)) & 1)
543                    + xk[128 + ird + i]
544            })
545        };
546        let mut y_t = array::from_fn(|i| x_t[(i + 7) % 8] + x_t[(i + 5) % 8] + x_t[(i + 2) % 8]);
547        y_t[0] += delta;
548        y_t[2] += delta;
549        Field::<O>::byte_combine(&y_t)
550    })
551}
552
553fn aes_enc_bkwd_mkey0_mtag1<'a, O>(
554    x: &'a GenericArray<Field<O>, O::LENC>,
555    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
556) -> impl Iterator<Item = Field<O>> + 'a
557where
558    O: OWFParameters,
559{
560    // Step 2
561    iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
562        // Step 4
563        let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
564        let x_t = if j < O::R::USIZE - 1 {
565            &x[ird..ird + 8]
566        } else {
567            &xk[128 + ird..136 + ird]
568        };
569        let y_t = array::from_fn(|i| x_t[(i + 7) % 8] + x_t[(i + 5) % 8] + x_t[(i + 2) % 8]);
570        Field::<O>::byte_combine(&y_t)
571    })
572}
573
574fn aes_enc_cstrnts_mkey0<O>(
575    zk_hasher: &mut ZKProofHasher<Field<O>>,
576    input: &[u8; 16],
577    output: &[u8; 16],
578    w: &GenericArray<u8, O::QUOTLENC8>,
579    v: &GenericArray<Field<O>, O::LENC>,
580    k: &GenericArray<u8, O::PRODRUN128Bytes>,
581    vk: &GenericArray<Field<O>, O::PRODRUN128>,
582) where
583    O: OWFParameters,
584{
585    let s = aes_enc_fwd_mkey0_mtag0::<O>(w, k, input);
586    let vs = aes_enc_fwd_mkey0_mtag1::<O>(v, vk);
587    let s_b = aes_enc_bkwd_mkey0_mtag0::<O>(w, k, output);
588    let v_s_b = aes_enc_bkwd_mkey0_mtag1::<O>(v, vk);
589    zk_hasher.process(s, vs, s_b, v_s_b);
590}
591
592fn aes_enc_cstrnts_mkey1<O>(
593    zk_hasher: &mut ZKVerifyHasher<Field<O>>,
594    input: &[u8; 16],
595    output: &[u8; 16],
596    q: &GenericArray<Field<O>, O::LENC>,
597    qk: &GenericArray<Field<O>, O::PRODRUN128>,
598    delta: &Field<O>,
599) where
600    O: OWFParameters,
601{
602    let qs = aes_enc_fwd_mkey1_mtag0::<O>(q, qk, input, delta);
603    let q_s_b = aes_enc_bkwd_mkey1_mtag0::<O>(q, qk, output, delta);
604    zk_hasher.process(qs, q_s_b);
605}
606
607// Bits are represented as bytes : each times we manipulate bit data, we divide length by 8
608pub(crate) fn aes_prove<O>(
609    w: &GenericArray<u8, O::LBYTES>,
610    u: &GenericArray<u8, O::LAMBDALBYTES>,
611    gv: CstrntsVal<O>,
612    pk: &PublicKey<O>,
613    chall: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
614) -> QSProof<O>
615where
616    O: OWFParameters,
617{
618    let new_v = transpose_and_into_field::<O>(gv);
619
620    let mut zk_hasher =
621        <<O as OWFParameters>::BaseParams as BaseParameters>::ZKHasher::new_zk_proof_hasher(chall);
622
623    let (k, vk) = aes_key_exp_cstrnts_mkey0::<O>(
624        &mut zk_hasher,
625        GenericArray::from_slice(&w[..O::LKE::USIZE / 8]),
626        GenericArray::from_slice(&new_v[..O::LKE::USIZE]),
627    );
628
629    aes_enc_cstrnts_mkey0::<O>(
630        &mut zk_hasher,
631        pk.owf_input[..16].try_into().unwrap(),
632        pk.owf_output[..16].try_into().unwrap(),
633        GenericArray::from_slice(&w[O::LKE::USIZE / 8..(O::LKE::USIZE + O::LENC::USIZE) / 8]),
634        GenericArray::from_slice(&new_v[O::LKE::USIZE..O::LKE::USIZE + O::LENC::USIZE]),
635        &k,
636        &vk,
637    );
638
639    if O::LAMBDA::USIZE > 128 {
640        aes_enc_cstrnts_mkey0::<O>(
641            &mut zk_hasher,
642            pk.owf_input[16..].try_into().unwrap(),
643            pk.owf_output[16..].try_into().unwrap(),
644            GenericArray::from_slice(&w[(O::LKE::USIZE + O::LENC::USIZE) / 8..O::LBYTES::USIZE]),
645            GenericArray::from_slice(&new_v[(O::LKE::USIZE + O::LENC::USIZE)..O::L::USIZE]),
646            &k,
647            &vk,
648        );
649    }
650
651    let u_s = Field::<O>::from(&u[O::LBYTES::USIZE..]);
652    let v_s = Field::<O>::sum_poly(&new_v[O::L::USIZE..O::L::USIZE + O::LAMBDA::USIZE]);
653    let (a_t, b_t) = zk_hasher.finalize(&u_s, &v_s);
654
655    (a_t.as_bytes(), b_t.as_bytes())
656}
657
658// Bits are represented as bytes : each times we manipulate bit data, we divide length by 8
659#[allow(clippy::too_many_arguments)]
660pub(crate) fn aes_verify<O, Tau>(
661    d: &GenericArray<u8, O::LBYTES>,
662    gq: Box<GenericArray<GenericArray<u8, O::LAMBDALBYTES>, O::LAMBDA>>,
663    a_t: &GenericArray<u8, O::LAMBDABYTES>,
664    chall2: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
665    chall3: &GenericArray<u8, O::LAMBDABYTES>,
666    pk: &PublicKey<O>,
667) -> GenericArray<u8, O::LAMBDABYTES>
668where
669    O: OWFParameters,
670    Tau: TauParameters,
671{
672    let delta = Field::<O>::from(chall3);
673    let new_q = convert_gq::<O, Tau>(d, gq, chall3);
674    let mut zk_hasher =
675        <<O as OWFParameters>::BaseParams as BaseParameters>::ZKHasher::new_zk_verify_hasher(
676            chall2, delta,
677        );
678
679    let qk = aes_key_exp_cstrnts_mkey1::<O>(
680        &mut zk_hasher,
681        GenericArray::from_slice(&new_q[..O::LKE::USIZE]),
682        &delta,
683    );
684
685    aes_enc_cstrnts_mkey1::<O>(
686        &mut zk_hasher,
687        pk.owf_input[..16].try_into().unwrap(),
688        pk.owf_output[..16].try_into().unwrap(),
689        GenericArray::from_slice(&new_q[O::LKE::USIZE..(O::LKE::USIZE + O::LENC::USIZE)]),
690        &qk,
691        &delta,
692    );
693    if O::LAMBDA::USIZE > 128 {
694        aes_enc_cstrnts_mkey1::<O>(
695            &mut zk_hasher,
696            pk.owf_input[16..].try_into().unwrap(),
697            pk.owf_output[16..].try_into().unwrap(),
698            GenericArray::from_slice(&new_q[O::LKE::USIZE + O::LENC::USIZE..O::L::USIZE]),
699            &qk,
700            &delta,
701        );
702    }
703
704    let q_s = Field::<O>::sum_poly(&new_q[O::L::USIZE..O::L::USIZE + O::LAMBDA::USIZE]);
705    (zk_hasher.finalize(&q_s) + Field::<O>::from(a_t) * delta).as_bytes()
706}
707
#[cfg(test)]
mod test {
    // The former `#![allow(clippy::needless_range_loop)]` was removed: this
    // module contains no index-range loops.

    use super::*;

    use crate::{
        fields::{GF128, GF192, GF256},
        parameter::{
            FAEST128sParameters, FAEST192sParameters, FAEST256sParameters, FAESTParameters,
            OWFParameters, OWF128, OWF192, OWF256,
        },
        utils::test::read_test_data,
    };

    use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
    use serde::Deserialize;

    /// One vector from `AesExtendedWitness.json`.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesExtendedWitness {
        lambda: u16,
        key: Vec<u8>,
        input: Vec<u8>,
        w: Vec<u8>,
    }

    #[test]
    fn aes_extended_witness_test() {
        let database: Vec<AesExtendedWitness> = read_test_data("AesExtendedWitness.json");
        for data in database {
            match data.lambda {
                128 => {
                    let res = aes_extendedwitness::<OWF128>(
                        GenericArray::from_slice(&data.key),
                        GenericArray::from_slice(
                            &data.input[..<OWF128 as OWFParameters>::InputSize::USIZE],
                        ),
                    );
                    assert_eq!(res.unwrap().as_slice(), &data.w);
                }
                192 => {
                    let res = aes_extendedwitness::<OWF192>(
                        GenericArray::from_slice(&data.key),
                        GenericArray::from_slice(
                            &data.input[..<OWF192 as OWFParameters>::InputSize::USIZE],
                        ),
                    );
                    assert_eq!(res.unwrap().as_slice(), &data.w);
                }
                // Any other lambda is exercised with the 256-bit parameters.
                _ => {
                    let res = aes_extendedwitness::<OWF256>(
                        GenericArray::from_slice(&data.key),
                        GenericArray::from_slice(
                            &data.input[..<OWF256 as OWFParameters>::InputSize::USIZE],
                        ),
                    );
                    assert_eq!(res.unwrap().as_slice(), &data.w);
                }
            }
        }
    }

    /// One vector from `AesProve.json`, with the expected `at` / `bt` outputs.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesProve {
        lambda: u16,
        w: Vec<u8>,
        input: Vec<u8>,
        output: Vec<u8>,
        at: Vec<u8>,
        bt: Vec<u8>,
    }

    impl AesProve {
        fn as_pk<O>(&self) -> PublicKey<O>
        where
            O: OWFParameters,
        {
            PublicKey {
                owf_input: GenericArray::from_slice(&self.input).clone(),
                owf_output: GenericArray::from_slice(&self.output).clone(),
            }
        }
    }

    #[test]
    fn aes_prove_test() {
        let database: Vec<AesProve> = read_test_data("AesProve.json");
        for data in database {
            match data.lambda {
                128 => {
                    let res = aes_prove::<OWF128>(
                        GenericArray::from_slice(&data.w),
                        &GenericArray::generate(|_| 19),
                        &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                        &data.as_pk(),
                        &GenericArray::generate(|_| 47),
                    );
                    assert_eq!(res.0.as_slice(), &data.at);
                    assert_eq!(res.1.as_slice(), &data.bt);
                }
                192 => {
                    let res = aes_prove::<OWF192>(
                        GenericArray::from_slice(&data.w),
                        &GenericArray::generate(|_| 19),
                        &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                        &data.as_pk(),
                        &GenericArray::generate(|_| 47),
                    );
                    assert_eq!(res.0.as_slice(), &data.at);
                    assert_eq!(res.1.as_slice(), &data.bt);
                }
                _ => {
                    let res = aes_prove::<OWF256>(
                        GenericArray::from_slice(&data.w),
                        &GenericArray::generate(|_| 19),
                        &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                        &data.as_pk(),
                        &GenericArray::generate(|_| 47),
                    );
                    assert_eq!(res.0.as_slice(), &data.at);
                    assert_eq!(res.1.as_slice(), &data.bt);
                }
            }
        }
    }

    /// One vector from `AesVerify.json`; `res` holds the expected output as u64 limbs.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesVerify {
        lambda: u16,
        gq: Vec<Vec<u8>>,
        d: Vec<u8>,
        chall2: Vec<u8>,
        chall3: Vec<u8>,
        at: Vec<u8>,
        input: Vec<u8>,
        output: Vec<u8>,
        res: Vec<u64>,
    }

    impl AesVerify {
        fn res_as_u8(&self) -> Vec<u8> {
            self.res.iter().flat_map(|x| x.to_le_bytes()).collect()
        }

        fn as_pk<O>(&self) -> PublicKey<O>
        where
            O: OWFParameters,
        {
            PublicKey {
                owf_input: GenericArray::from_slice(&self.input).clone(),
                owf_output: GenericArray::from_slice(&self.output).clone(),
            }
        }

        fn as_gq<LHI, LHO>(&self) -> GenericArray<GenericArray<u8, LHI>, LHO>
        where
            LHI: ArrayLength,
            LHO: ArrayLength,
        {
            self.gq
                .iter()
                .map(|x| GenericArray::from_slice(x).clone())
                .collect()
        }
    }

    /// Borrowing wrapper around `super::aes_verify`, which takes `gq` boxed.
    fn aes_verify<O, Tau>(
        d: &GenericArray<u8, O::LBYTES>,
        gq: &GenericArray<GenericArray<u8, O::LAMBDALBYTES>, O::LAMBDA>,
        a_t: &GenericArray<u8, O::LAMBDABYTES>,
        chall2: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
        chall3: &GenericArray<u8, O::LAMBDABYTES>,
        pk: &PublicKey<O>,
    ) -> GenericArray<u8, O::LAMBDABYTES>
    where
        O: OWFParameters,
        Tau: TauParameters,
    {
        super::aes_verify::<O, Tau>(
            d,
            Box::<GenericArray<_, _>>::from_iter(gq.iter().cloned()),
            a_t,
            chall2,
            chall3,
            pk,
        )
    }

    #[test]
    fn aes_verify_test() {
        let database: Vec<AesVerify> = read_test_data("AesVerify.json");
        for data in database {
            match data.lambda {
                128 => {
                    let out = aes_verify::<OWF128, <FAEST128sParameters as FAESTParameters>::Tau>(
                        GenericArray::from_slice(&data.d[..]),
                        &data.as_gq(),
                        GenericArray::from_slice(&data.at),
                        GenericArray::from_slice(&data.chall2[..]),
                        GenericArray::from_slice(&data.chall3[..]),
                        &data.as_pk(),
                    );
                    assert_eq!(
                        GF128::from(&data.res_as_u8()[..16]),
                        GF128::from(out.as_slice())
                    );
                }
                192 => {
                    let out = aes_verify::<OWF192, <FAEST192sParameters as FAESTParameters>::Tau>(
                        GenericArray::from_slice(&data.d[..]),
                        &data.as_gq(),
                        GenericArray::from_slice(&data.at),
                        GenericArray::from_slice(&data.chall2[..]),
                        GenericArray::from_slice(&data.chall3[..]),
                        &data.as_pk(),
                    );
                    assert_eq!(
                        GF192::from(&data.res_as_u8()[..24]),
                        GF192::from(out.as_slice())
                    );
                }
                _ => {
                    let out = aes_verify::<OWF256, <FAEST256sParameters as FAESTParameters>::Tau>(
                        GenericArray::from_slice(&data.d[..]),
                        &data.as_gq(),
                        GenericArray::from_slice(&data.at),
                        GenericArray::from_slice(&data.chall2[..]),
                        GenericArray::from_slice(&data.chall3[..]),
                        &data.as_pk(),
                    );
                    assert_eq!(
                        GF256::from(&data.res_as_u8()[..32]),
                        GF256::from(out.as_slice())
                    );
                }
            }
        }
    }
}