1use std::{array, mem::size_of};
2
3use generic_array::{
4 typenum::{Unsigned, U3, U4},
5 GenericArray,
6};
7use itertools::iproduct;
8
9use crate::{
10 fields::{ByteCombine, ByteCombineConstants, Field as _, SumPoly},
11 internal_keys::PublicKey,
12 parameter::{BaseParameters, OWFParameters, QSProof, TauParameters},
13 rijndael_32::{
14 bitslice, convert_from_batchblocks, inv_bitslice, mix_columns_0, rijndael_add_round_key,
15 rijndael_key_schedule, rijndael_shift_rows_1, sub_bytes, sub_bytes_nots, State, RCON_TABLE,
16 },
17 universal_hashing::{ZKHasherInit, ZKProofHasher, ZKVerifyHasher},
18 utils::{bit_combine_with_delta, contains_zeros, convert_gq, transpose_and_into_field, Field},
19};
20
/// Prover-side output of the key-expansion constraints: the expanded key as
/// bytes (`PRODRUN128Bytes` long) together with the matching expanded
/// per-bit field values (`PRODRUN128` long), i.e. `(k, vk)` as returned by
/// `aes_key_exp_cstrnts_mkey0`.
type KeyCstrnts<O> = (
    Box<GenericArray<u8, <O as OWFParameters>::PRODRUN128Bytes>>,
    Box<GenericArray<Field<O>, <O as OWFParameters>::PRODRUN128>>,
);

/// Borrowed byte matrix with `LAMBDA` rows of `LAMBDALBYTES` bytes each; this
/// is the `gv` input that `aes_prove` hands to `transpose_and_into_field`.
type CstrntsVal<'a, O> = &'a GenericArray<
    GenericArray<u8, <O as OWFParameters>::LAMBDALBYTES>,
    <O as OWFParameters>::LAMBDA,
>;
30
/// Byte index inside a 4-byte key-schedule word, adjusted for a one-position
/// word rotation.
///
/// With `rotate` set the result is `(r + 1) % 4`; otherwise `r` is returned
/// untouched (no reduction modulo 4). Callers use this to locate bytes of a
/// rotated word (by name, undoing AES `RotWord`).
const fn inverse_rotate_word(r: usize, rotate: bool) -> usize {
    if !rotate {
        return r;
    }
    (r + 1) % 4
}
39
40pub(crate) fn aes_extendedwitness<O>(
41 owf_key: &GenericArray<u8, O::LAMBDABYTES>,
42 owf_input: &GenericArray<u8, O::InputSize>,
43) -> Option<Box<GenericArray<u8, O::LBYTES>>>
44where
45 O: OWFParameters,
46{
47 let mut input = [0u8; 32];
48 input[..O::InputSize::USIZE].clone_from_slice(owf_input);
50 let mut witness = GenericArray::default_boxed();
51 let mut index = 0;
52 let (kb, mut zeros) = rijndael_key_schedule::<U4, O::NK, O::R>(owf_key, O::SKE::USIZE);
54 for i in convert_from_batchblocks(inv_bitslice(&kb[..8])).take(4) {
56 witness[index..index + size_of::<u32>()].copy_from_slice(&i);
57 index += size_of::<u32>();
58 }
59 for i in convert_from_batchblocks(inv_bitslice(&kb[8..16]))
60 .take(O::NK::USIZE / 2 - (4 - (O::NK::USIZE / 2)))
61 {
62 witness[index..index + size_of::<u32>()].copy_from_slice(&i);
63 index += size_of::<u32>();
64 }
65 for j in 1 + (O::NK::USIZE / 8)
66 ..1 + (O::NK::USIZE / 8)
67 + (O::SKE::USIZE * ((2 - (O::NK::USIZE % 4)) * 2 + (O::NK::USIZE % 4) * 3)) / 16
68 {
69 let inside = GenericArray::<_, U3>::from_iter(
70 convert_from_batchblocks(inv_bitslice(&kb[8 * j..8 * (j + 1)])).take(3),
71 );
72 if O::NK::USIZE == 6 {
73 if j % 3 == 1 {
74 witness[index..index + size_of::<u32>()].copy_from_slice(&inside[2]);
75 index += size_of::<u32>();
76 } else if j % 3 == 0 {
77 witness[index..index + size_of::<u32>()].copy_from_slice(&inside[0]);
78 index += size_of::<u32>();
79 }
80 } else {
81 witness[index..index + size_of::<u32>()].copy_from_slice(&inside[0]);
82 index += size_of::<u32>();
83 }
84 }
85 zeros |= round_with_save(&input[..16], &kb, O::R::U8, &mut witness, &mut index);
87 if O::LAMBDA::USIZE > 128 {
88 zeros |= round_with_save(&input[16..], &kb, O::R::U8, &mut witness, &mut index);
89 }
90 if zeros {
91 None
92 } else {
93 Some(witness)
94 }
95}
96
97#[allow(clippy::too_many_arguments)]
98fn round_with_save(
99 input1: &[u8],
100 kb: &[u32],
101 r: u8,
102 witness: &mut [u8],
103 index: &mut usize,
104) -> bool {
105 let mut zeros = false;
106 let mut state = State::default();
107 bitslice(&mut state, input1, &[]);
108 rijndael_add_round_key(&mut state, &kb[..8]);
109 for j in 1..r as usize {
110 zeros |= contains_zeros(&inv_bitslice(&state)[0]);
111 sub_bytes(&mut state);
112 sub_bytes_nots(&mut state);
113 rijndael_shift_rows_1::<U4>(&mut state);
114 for i in convert_from_batchblocks(inv_bitslice(&state)).take(4) {
115 witness[*index..*index + size_of::<u32>()].copy_from_slice(&i);
116 *index += size_of::<u32>();
117 }
118 mix_columns_0(&mut state);
119 rijndael_add_round_key(&mut state, &kb[8 * j..8 * (j + 1)]);
120 }
121 zeros | contains_zeros(&inv_bitslice(&state)[0])
122}
123
124fn aes_key_exp_fwd_1<O>(
125 x: &GenericArray<u8, O::LKEBytes>,
126) -> Box<GenericArray<u8, O::PRODRUN128Bytes>>
127where
128 O: OWFParameters,
129{
130 let mut out = GenericArray::default_boxed();
131 out[..O::LAMBDABYTES::USIZE].copy_from_slice(&x[..O::LAMBDABYTES::USIZE]);
132 let mut index = O::LAMBDABYTES::USIZE;
133 let mut x_index = O::LAMBDABYTES::USIZE;
134 for j in O::NK::USIZE..(4 * (O::R::USIZE + 1)) {
135 if (j % O::NK::USIZE == 0) || ((O::NK::USIZE > 6) && (j % O::NK::USIZE == 4)) {
136 out[index..index + 32 / 8].copy_from_slice(&x[x_index..x_index + 32 / 8]);
137 index += 32 / 8;
138 x_index += 32 / 8;
139 } else {
140 for i in 0..4 {
141 out[index] = out[(32 * (j - O::NK::USIZE)) / 8 + i] ^ out[(32 * (j - 1)) / 8 + i];
142 index += 1;
143 }
144 }
145 }
146 out
147}
148
149fn aes_key_exp_fwd<O>(
150 x: &GenericArray<Field<O>, O::LKE>,
151) -> Box<GenericArray<Field<O>, O::PRODRUN128>>
152where
153 O: OWFParameters,
154{
155 let mut out = GenericArray::default_boxed();
157 out[..O::LAMBDA::USIZE].copy_from_slice(&x[..O::LAMBDA::USIZE]);
158 let mut index = O::LAMBDA::USIZE;
159 let mut x_index = O::LAMBDA::USIZE;
160 for j in O::NK::USIZE..(4 * (O::R::USIZE + 1)) {
161 if (j % O::NK::USIZE == 0) || ((O::NK::USIZE > 6) && (j % O::NK::USIZE == 4)) {
162 out[index..index + 32].copy_from_slice(&x[x_index..x_index + 32]);
163 index += 32;
164 x_index += 32;
165 } else {
166 for i in 0..32 {
167 out[index] = out[(32 * (j - O::NK::USIZE)) + i] + out[(32 * (j - 1)) + i];
168 index += 1;
169 }
170 }
171 }
172 out
173}
174
/// Backward key-expansion values computed from the prover's witness bytes.
///
/// Yields one field element per S-box byte `j` in `0..SKE`:
/// * `x_tilde = xk[indice + c] ^ x[LAMBDABYTES + j]`, where `xk` is the
///   expanded key as bytes (presumably this recovers the S-box output byte).
/// * When `rmvrcon` is set, the first byte of each word has the next
///   `RCON_TABLE` entry XORed out.
/// * The map `rotr(7) ^ rotr(5) ^ rotr(2) ^ 0x05` (the AES inverse affine
///   transform) is applied, and the result is embedded via
///   `byte_combine_bits`.
fn aes_key_exp_bwd_mtag0_mkey0<'a, O>(
    x: &'a GenericArray<u8, O::LKEBytes>,
    xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
) -> impl Iterator<Item = Field<O>> + 'a
where
    O: OWFParameters,
{
    // Byte offset of the current word in `xk`.
    let mut indice = 0;
    // Byte position (0..4) inside the current word.
    let mut c = 0;
    // Whether the current word carries a round constant to strip.
    let mut rmvrcon = true;
    // Next RCON_TABLE index to use.
    let mut ircon = 0;
    (0..O::SKE::USIZE).map(move |j| {
        let mut x_tilde = xk[indice + c] ^ x[j + O::LAMBDABYTES::USIZE];
        if rmvrcon && (c == 0) {
            let rcon = RCON_TABLE[ircon];
            ircon += 1;
            x_tilde ^= rcon;
        }

        c += 1;
        if c == 4 {
            // Word finished: step to the next word used by the constraints.
            // The step is 24 bytes for LAMBDA == 192 and 16 bytes otherwise.
            // For LAMBDA == 256 only every other word has a round constant.
            c = 0;
            if O::LAMBDA::USIZE == 192 {
                indice += 192 / 8;
            } else {
                indice += 128 / 8;
                if O::LAMBDA::USIZE == 256 {
                    rmvrcon = !rmvrcon;
                }
            }
        }

        Field::<O>::byte_combine_bits(
            x_tilde.rotate_right(7) ^ x_tilde.rotate_right(5) ^ x_tilde.rotate_right(2) ^ 0x5,
        )
    })
}
217
218fn aes_key_exp_bwd_mtag1_mkey0<'a, O>(
219 x: &'a [Field<O>],
220 xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
221) -> impl Iterator<Item = Field<O>> + 'a
222where
223 O: OWFParameters,
224{
225 let mut indice = 0;
226 let mut c = 0;
227 let mut rmvrcon = true;
228 (0..O::SKE::USIZE).map(move |j| {
230 let x_tilde: [_; 8] = array::from_fn(|i| x[8 * j + i] + xk[indice + 8 * c + i]);
232 let y_tilde =
234 array::from_fn(|i| x_tilde[(i + 7) % 8] + x_tilde[(i + 5) % 8] + x_tilde[(i + 2) % 8]);
235 c += 1;
236 if c == 4 {
238 c = 0;
239 if O::LAMBDA::USIZE == 192 {
240 indice += 192;
241 } else {
242 indice += 128;
243 if O::LAMBDA::USIZE == 256 {
244 rmvrcon = !rmvrcon;
245 }
246 }
247 }
248 Field::<O>::byte_combine(&y_tilde)
249 })
250}
251
/// Backward key-expansion values computed from the verifier's per-bit keys.
///
/// This mirrors `aes_key_exp_bwd_mtag0_mkey0`, with the public constants
/// scaled by `delta`. For each S-box byte `j`:
/// * The 8 values `q[LAMBDA + 8 * j + i]` are added to the matching values
///   of `xk`.
/// * The round constant is added bit-wise as `delta * bit` on the first byte
///   of each word while `rmvrcon` is set.
/// * The inverse affine map is applied, with its `0x05` constant becoming
///   `+delta` on bits 0 and 2.
/// * The result is combined with `byte_combine`.
fn aes_key_exp_bwd_mtag0_mkey1<'a, O>(
    x: &'a GenericArray<Field<O>, O::LKE>,
    xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
    delta: &'a Field<O>,
) -> impl Iterator<Item = Field<O>> + 'a
where
    O: OWFParameters,
{
    // Bit offset of the current word in `xk`.
    let mut indice = 0;
    // Byte position (0..4) inside the current word.
    let mut c = 0;
    // Whether the current word carries a round constant.
    let mut rmvrcon = true;
    // Next RCON_TABLE index to use.
    let mut ircon = 0;
    (0..O::SKE::USIZE).map(move |j| {
        let mut x_tilde: [_; 8] =
            array::from_fn(|i| x[8 * j + i + O::LAMBDA::USIZE] + xk[indice + 8 * c + i]);
        if rmvrcon && (c == 0) {
            let rcon = RCON_TABLE[ircon];
            ircon += 1;
            // Constant bits enter the verifier's values as delta * bit.
            for (i, x) in x_tilde.iter_mut().enumerate() {
                *x += *delta * ((rcon >> i) & 1);
            }
        }
        let mut y_tilde =
            array::from_fn(|i| x_tilde[(i + 7) % 8] + x_tilde[(i + 5) % 8] + x_tilde[(i + 2) % 8]);
        // 0x05 has bits 0 and 2 set.
        y_tilde[0] += delta;
        y_tilde[2] += delta;
        c += 1;
        if c == 4 {
            // Word finished: advance by 192 bits for LAMBDA == 192, else 128.
            // For LAMBDA == 256 only every other word has a round constant.
            c = 0;
            if O::LAMBDA::USIZE == 192 {
                indice += 192;
            } else {
                indice += 128;
                if O::LAMBDA::USIZE == 256 {
                    rmvrcon = !rmvrcon;
                }
            }
        }
        Field::<O>::byte_combine(&y_tilde)
    })
}
299
300fn aes_key_exp_cstrnts_mkey0<O>(
301 zk_hasher: &mut ZKProofHasher<Field<O>>,
302 w: &GenericArray<u8, O::LKEBytes>,
303 v: &GenericArray<Field<O>, O::LKE>,
304) -> KeyCstrnts<O>
305where
306 O: OWFParameters,
307{
308 let k = aes_key_exp_fwd_1::<O>(w);
309 let vk = aes_key_exp_fwd::<O>(v);
310 let w_b = aes_key_exp_bwd_mtag0_mkey0::<O>(w, &k);
311 let v_w_b = aes_key_exp_bwd_mtag1_mkey0::<O>(&v[O::LAMBDA::USIZE..], &vk);
312
313 zk_hasher.process(
314 iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
315 let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
316 let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
317 Field::<O>::byte_combine_bits(k[iwd / 8 + inverse_rotate_word(r, dorotword)])
318 }),
319 iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
320 let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
321 let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
322 let r = inverse_rotate_word(r, dorotword);
323 Field::<O>::byte_combine_slice(&vk[iwd + (8 * r)..iwd + (8 * r) + 8])
324 }),
325 w_b,
326 v_w_b,
327 );
328
329 (k, vk)
330}
331
332fn aes_key_exp_cstrnts_mkey1<O>(
333 zk_hasher: &mut ZKVerifyHasher<Field<O>>,
334 q: &GenericArray<Field<O>, O::LKE>,
335 delta: &Field<O>,
336) -> Box<GenericArray<Field<O>, <O as OWFParameters>::PRODRUN128>>
337where
338 O: OWFParameters,
339{
340 let q_k = aes_key_exp_fwd::<O>(q);
341 let q_w_b = aes_key_exp_bwd_mtag0_mkey1::<O>(q, &q_k, delta);
342
343 zk_hasher.process(
344 q_w_b,
345 iproduct!(0..O::SKE::USIZE / 4, 0..4).map(|(j, r)| {
346 let iwd = 32 * (O::NK::USIZE - 1) + j * if O::LAMBDA::USIZE == 192 { 192 } else { 128 };
347 let dorotword = !(O::LAMBDA::USIZE == 256 && j % 2 == 1);
348 let rotated_r = inverse_rotate_word(r, dorotword);
349 Field::<O>::byte_combine_slice(&q_k[iwd + (8 * rotated_r)..iwd + (8 * rotated_r) + 8])
350 }),
351 );
352
353 q_k
354}
355
356fn aes_enc_fwd_mkey0_mtag0<'a, O>(
357 x: &'a GenericArray<u8, O::QUOTLENC8>,
358 xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
359 input: &'a [u8; 16],
360) -> impl Iterator<Item = Field<O>> + 'a
361where
362 O: OWFParameters,
363{
364 (0..16)
365 .map(|i| {
366 Field::<O>::byte_combine_bits(input[i]) + Field::<O>::byte_combine_bits(xk[i])
368 })
369 .chain(
370 iproduct!(1..O::R::USIZE, 0..4)
371 .map(move |(j, c)| {
372 let ix: usize = 128 * (j - 1) + 32 * c;
374 let ik: usize = 128 * j + 32 * c;
375 let x_hat: [_; 4] =
376 array::from_fn(|r| Field::<O>::byte_combine_bits(x[ix / 8 + r]));
377 let mut res: [_; 4] =
378 array::from_fn(|r| Field::<O>::byte_combine_bits(xk[ik / 8 + r]));
379
380 res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
382 + x_hat[1] * Field::<O>::BYTE_COMBINE_3
383 + x_hat[2]
384 + x_hat[3];
385 res[1] += x_hat[0]
386 + x_hat[1] * Field::<O>::BYTE_COMBINE_2
387 + x_hat[2] * Field::<O>::BYTE_COMBINE_3
388 + x_hat[3];
389 res[2] += x_hat[0]
390 + x_hat[1]
391 + x_hat[2] * Field::<O>::BYTE_COMBINE_2
392 + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
393 res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
394 + x_hat[1]
395 + x_hat[2]
396 + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
397 res
398 })
399 .flatten(),
400 )
401}
402
403fn aes_enc_fwd_mkey1_mtag0<'a, O>(
404 x: &'a GenericArray<Field<O>, O::LENC>,
405 xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
406 input: &'a [u8; 16],
407 delta: &'a Field<O>,
408) -> impl Iterator<Item = Field<O>> + 'a
409where
410 O: OWFParameters,
411{
412 (0..16)
413 .map(|i| {
414 bit_combine_with_delta::<O>(input[i], delta)
416 + Field::<O>::byte_combine_slice(&xk[8 * i..(8 * i) + 8])
417 })
418 .chain(
419 iproduct!(1..O::R::USIZE, 0..4)
420 .map(move |(j, c)| {
421 let ix: usize = 128 * (j - 1) + 32 * c;
423 let ik: usize = 128 * j + 32 * c;
424 let x_hat: [_; 4] = array::from_fn(|r| {
425 Field::<O>::byte_combine_slice(&x[ix + 8 * r..ix + 8 * r + 8])
426 });
427 let mut res: [_; 4] = array::from_fn(|r| {
428 Field::<O>::byte_combine_slice(&xk[ik + 8 * r..ik + 8 * r + 8])
429 });
430
431 res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
433 + x_hat[1] * Field::<O>::BYTE_COMBINE_3
434 + x_hat[2]
435 + x_hat[3];
436 res[1] += x_hat[0]
437 + x_hat[1] * Field::<O>::BYTE_COMBINE_2
438 + x_hat[2] * Field::<O>::BYTE_COMBINE_3
439 + x_hat[3];
440 res[2] += x_hat[0]
441 + x_hat[1]
442 + x_hat[2] * Field::<O>::BYTE_COMBINE_2
443 + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
444 res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
445 + x_hat[1]
446 + x_hat[2]
447 + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
448 res
449 })
450 .flatten(),
451 )
452}
453
454fn aes_enc_fwd_mkey0_mtag1<'a, O>(
455 x: &'a GenericArray<Field<O>, O::LENC>,
456 xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
457) -> impl Iterator<Item = Field<O>> + 'a
458where
459 O: OWFParameters,
460{
461 (0..16)
462 .map(|i| {
463 Field::<O>::byte_combine_slice(&xk[8 * i..(8 * i) + 8])
465 })
466 .chain(
467 iproduct!(1..O::R::USIZE, 0..4)
468 .map(move |(j, c)| {
469 let ix: usize = 128 * (j - 1) + 32 * c;
471 let ik: usize = 128 * j + 32 * c;
472 let x_hat: [_; 4] = array::from_fn(|r| {
473 Field::<O>::byte_combine_slice(&x[ix + 8 * r..ix + 8 * r + 8])
474 });
475 let mut res: [_; 4] = array::from_fn(|r| {
476 Field::<O>::byte_combine_slice(&xk[ik + 8 * r..ik + 8 * r + 8])
477 });
478
479 res[0] += x_hat[0] * Field::<O>::BYTE_COMBINE_2
481 + x_hat[1] * Field::<O>::BYTE_COMBINE_3
482 + x_hat[2]
483 + x_hat[3];
484 res[1] += x_hat[0]
485 + x_hat[1] * Field::<O>::BYTE_COMBINE_2
486 + x_hat[2] * Field::<O>::BYTE_COMBINE_3
487 + x_hat[3];
488 res[2] += x_hat[0]
489 + x_hat[1]
490 + x_hat[2] * Field::<O>::BYTE_COMBINE_2
491 + x_hat[3] * Field::<O>::BYTE_COMBINE_3;
492 res[3] += x_hat[0] * Field::<O>::BYTE_COMBINE_3
493 + x_hat[1]
494 + x_hat[2]
495 + x_hat[3] * Field::<O>::BYTE_COMBINE_2;
496 res
497 })
498 .flatten(),
499 )
500}
501
502fn aes_enc_bkwd_mkey0_mtag0<'a, O>(
503 x: &'a GenericArray<u8, O::QUOTLENC8>,
504 xk: &'a GenericArray<u8, O::PRODRUN128Bytes>,
505 out: &'a [u8; 16],
506) -> impl Iterator<Item = Field<O>> + 'a
507where
508 O: OWFParameters,
509{
510 iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
512 let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
514 let x_t = if j < O::R::USIZE - 1 {
515 x[ird / 8]
516 } else {
517 let x_out = out[(ird - 128 * j) / 8];
518 x_out ^ xk[(128 + ird) / 8]
519 };
520 let y_t = x_t.rotate_right(7) ^ x_t.rotate_right(5) ^ x_t.rotate_right(2) ^ 0x5;
521 Field::<O>::byte_combine_bits(y_t)
522 })
523}
524
525fn aes_enc_bkwd_mkey1_mtag0<'a, O>(
526 x: &'a GenericArray<Field<O>, O::LENC>,
527 xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
528 out: &'a [u8; 16],
529 delta: &'a Field<O>,
530) -> impl Iterator<Item = Field<O>> + 'a
531where
532 O: OWFParameters,
533{
534 iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
536 let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
538 let x_t: [_; 8] = if j < O::R::USIZE - 1 {
539 array::from_fn(|i| x[ird + i])
540 } else {
541 array::from_fn(|i| {
542 *delta * ((out[(ird - 128 * j + i) / 8] >> ((ird - 128 * j + i) % 8)) & 1)
543 + xk[128 + ird + i]
544 })
545 };
546 let mut y_t = array::from_fn(|i| x_t[(i + 7) % 8] + x_t[(i + 5) % 8] + x_t[(i + 2) % 8]);
547 y_t[0] += delta;
548 y_t[2] += delta;
549 Field::<O>::byte_combine(&y_t)
550 })
551}
552
553fn aes_enc_bkwd_mkey0_mtag1<'a, O>(
554 x: &'a GenericArray<Field<O>, O::LENC>,
555 xk: &'a GenericArray<Field<O>, O::PRODRUN128>,
556) -> impl Iterator<Item = Field<O>> + 'a
557where
558 O: OWFParameters,
559{
560 iproduct!(0..O::R::USIZE, 0..4, 0..4).map(move |(j, c, k)| {
562 let ird = 128 * j + 32 * ((c + 4 - k) % 4) + 8 * k;
564 let x_t = if j < O::R::USIZE - 1 {
565 &x[ird..ird + 8]
566 } else {
567 &xk[128 + ird..136 + ird]
568 };
569 let y_t = array::from_fn(|i| x_t[(i + 7) % 8] + x_t[(i + 5) % 8] + x_t[(i + 2) % 8]);
570 Field::<O>::byte_combine(&y_t)
571 })
572}
573
574fn aes_enc_cstrnts_mkey0<O>(
575 zk_hasher: &mut ZKProofHasher<Field<O>>,
576 input: &[u8; 16],
577 output: &[u8; 16],
578 w: &GenericArray<u8, O::QUOTLENC8>,
579 v: &GenericArray<Field<O>, O::LENC>,
580 k: &GenericArray<u8, O::PRODRUN128Bytes>,
581 vk: &GenericArray<Field<O>, O::PRODRUN128>,
582) where
583 O: OWFParameters,
584{
585 let s = aes_enc_fwd_mkey0_mtag0::<O>(w, k, input);
586 let vs = aes_enc_fwd_mkey0_mtag1::<O>(v, vk);
587 let s_b = aes_enc_bkwd_mkey0_mtag0::<O>(w, k, output);
588 let v_s_b = aes_enc_bkwd_mkey0_mtag1::<O>(v, vk);
589 zk_hasher.process(s, vs, s_b, v_s_b);
590}
591
592fn aes_enc_cstrnts_mkey1<O>(
593 zk_hasher: &mut ZKVerifyHasher<Field<O>>,
594 input: &[u8; 16],
595 output: &[u8; 16],
596 q: &GenericArray<Field<O>, O::LENC>,
597 qk: &GenericArray<Field<O>, O::PRODRUN128>,
598 delta: &Field<O>,
599) where
600 O: OWFParameters,
601{
602 let qs = aes_enc_fwd_mkey1_mtag0::<O>(q, qk, input, delta);
603 let q_s_b = aes_enc_bkwd_mkey1_mtag0::<O>(q, qk, output, delta);
604 zk_hasher.process(qs, q_s_b);
605}
606
607pub(crate) fn aes_prove<O>(
609 w: &GenericArray<u8, O::LBYTES>,
610 u: &GenericArray<u8, O::LAMBDALBYTES>,
611 gv: CstrntsVal<O>,
612 pk: &PublicKey<O>,
613 chall: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
614) -> QSProof<O>
615where
616 O: OWFParameters,
617{
618 let new_v = transpose_and_into_field::<O>(gv);
619
620 let mut zk_hasher =
621 <<O as OWFParameters>::BaseParams as BaseParameters>::ZKHasher::new_zk_proof_hasher(chall);
622
623 let (k, vk) = aes_key_exp_cstrnts_mkey0::<O>(
624 &mut zk_hasher,
625 GenericArray::from_slice(&w[..O::LKE::USIZE / 8]),
626 GenericArray::from_slice(&new_v[..O::LKE::USIZE]),
627 );
628
629 aes_enc_cstrnts_mkey0::<O>(
630 &mut zk_hasher,
631 pk.owf_input[..16].try_into().unwrap(),
632 pk.owf_output[..16].try_into().unwrap(),
633 GenericArray::from_slice(&w[O::LKE::USIZE / 8..(O::LKE::USIZE + O::LENC::USIZE) / 8]),
634 GenericArray::from_slice(&new_v[O::LKE::USIZE..O::LKE::USIZE + O::LENC::USIZE]),
635 &k,
636 &vk,
637 );
638
639 if O::LAMBDA::USIZE > 128 {
640 aes_enc_cstrnts_mkey0::<O>(
641 &mut zk_hasher,
642 pk.owf_input[16..].try_into().unwrap(),
643 pk.owf_output[16..].try_into().unwrap(),
644 GenericArray::from_slice(&w[(O::LKE::USIZE + O::LENC::USIZE) / 8..O::LBYTES::USIZE]),
645 GenericArray::from_slice(&new_v[(O::LKE::USIZE + O::LENC::USIZE)..O::L::USIZE]),
646 &k,
647 &vk,
648 );
649 }
650
651 let u_s = Field::<O>::from(&u[O::LBYTES::USIZE..]);
652 let v_s = Field::<O>::sum_poly(&new_v[O::L::USIZE..O::L::USIZE + O::LAMBDA::USIZE]);
653 let (a_t, b_t) = zk_hasher.finalize(&u_s, &v_s);
654
655 (a_t.as_bytes(), b_t.as_bytes())
656}
657
658#[allow(clippy::too_many_arguments)]
660pub(crate) fn aes_verify<O, Tau>(
661 d: &GenericArray<u8, O::LBYTES>,
662 gq: Box<GenericArray<GenericArray<u8, O::LAMBDALBYTES>, O::LAMBDA>>,
663 a_t: &GenericArray<u8, O::LAMBDABYTES>,
664 chall2: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
665 chall3: &GenericArray<u8, O::LAMBDABYTES>,
666 pk: &PublicKey<O>,
667) -> GenericArray<u8, O::LAMBDABYTES>
668where
669 O: OWFParameters,
670 Tau: TauParameters,
671{
672 let delta = Field::<O>::from(chall3);
673 let new_q = convert_gq::<O, Tau>(d, gq, chall3);
674 let mut zk_hasher =
675 <<O as OWFParameters>::BaseParams as BaseParameters>::ZKHasher::new_zk_verify_hasher(
676 chall2, delta,
677 );
678
679 let qk = aes_key_exp_cstrnts_mkey1::<O>(
680 &mut zk_hasher,
681 GenericArray::from_slice(&new_q[..O::LKE::USIZE]),
682 &delta,
683 );
684
685 aes_enc_cstrnts_mkey1::<O>(
686 &mut zk_hasher,
687 pk.owf_input[..16].try_into().unwrap(),
688 pk.owf_output[..16].try_into().unwrap(),
689 GenericArray::from_slice(&new_q[O::LKE::USIZE..(O::LKE::USIZE + O::LENC::USIZE)]),
690 &qk,
691 &delta,
692 );
693 if O::LAMBDA::USIZE > 128 {
694 aes_enc_cstrnts_mkey1::<O>(
695 &mut zk_hasher,
696 pk.owf_input[16..].try_into().unwrap(),
697 pk.owf_output[16..].try_into().unwrap(),
698 GenericArray::from_slice(&new_q[O::LKE::USIZE + O::LENC::USIZE..O::L::USIZE]),
699 &qk,
700 &delta,
701 );
702 }
703
704 let q_s = Field::<O>::sum_poly(&new_q[O::L::USIZE..O::L::USIZE + O::LAMBDA::USIZE]);
705 (zk_hasher.finalize(&q_s) + Field::<O>::from(a_t) * delta).as_bytes()
706}
707
#[cfg(test)]
mod test {
    // NOTE(review): no indexed `for i in 0..n` loops remain in this module;
    // this allow looks stale and could probably be removed.
    #![allow(clippy::needless_range_loop)]

    // Known-answer tests for the extended witness, the prover and the
    // verifier. Each test loads JSON vectors with `read_test_data` and
    // dispatches on `lambda`: 128 and 192 explicitly, anything else as 256.

    use super::*;

    use crate::{
        fields::{GF128, GF192, GF256},
        parameter::{
            FAEST128sParameters, FAEST192sParameters, FAEST256sParameters, FAESTParameters,
            OWFParameters, OWF128, OWF192, OWF256,
        },
        utils::test::read_test_data,
    };

    use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray};
    use serde::Deserialize;

    /// One `AesExtendedWitness.json` vector: key, OWF input and expected witness.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesExtendedWitness {
        lambda: u16,
        key: Vec<u8>,
        input: Vec<u8>,
        w: Vec<u8>,
    }

    #[test]
    fn aes_extended_witness_test() {
        let database: Vec<AesExtendedWitness> = read_test_data("AesExtendedWitness.json");
        for data in database {
            // Only the first `InputSize` bytes of the vector's input are used.
            if data.lambda == 128 {
                let res = aes_extendedwitness::<OWF128>(
                    GenericArray::from_slice(&data.key),
                    GenericArray::from_slice(
                        &data.input[..<OWF128 as OWFParameters>::InputSize::USIZE],
                    ),
                );
                assert_eq!(res.unwrap().as_slice(), &data.w);
            } else if data.lambda == 192 {
                let res = aes_extendedwitness::<OWF192>(
                    GenericArray::from_slice(&data.key),
                    GenericArray::from_slice(
                        &data.input[..<OWF192 as OWFParameters>::InputSize::USIZE],
                    ),
                );
                assert_eq!(res.unwrap().as_slice(), &data.w);
            } else {
                let res = aes_extendedwitness::<OWF256>(
                    GenericArray::from_slice(&data.key),
                    GenericArray::from_slice(
                        &data.input[..<OWF256 as OWFParameters>::InputSize::USIZE],
                    ),
                );
                assert_eq!(res.unwrap().as_slice(), &data.w);
            }
        }
    }

    /// One `AesProve.json` vector: witness, public input/output and the
    /// expected `(a_t, b_t)` pair.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesProve {
        lambda: u16,
        w: Vec<u8>,
        input: Vec<u8>,
        output: Vec<u8>,
        at: Vec<u8>,
        bt: Vec<u8>,
    }

    impl AesProve {
        /// Builds the public key from the vector's OWF input and output.
        fn as_pk<O>(&self) -> PublicKey<O>
        where
            O: OWFParameters,
        {
            PublicKey {
                owf_input: GenericArray::from_slice(&self.input).clone(),
                owf_output: GenericArray::from_slice(&self.output).clone(),
            }
        }
    }

    #[test]
    fn aes_prove_test() {
        let database: Vec<AesProve> = read_test_data("AesProve.json");
        for data in database {
            // `u`, `gv` and `chall` are filled with the constants 19, 55 and
            // 47; presumably the vectors were generated with the same fillers.
            if data.lambda == 128 {
                let res = aes_prove::<OWF128>(
                    GenericArray::from_slice(&data.w),
                    &GenericArray::generate(|_| 19),
                    &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                    &data.as_pk(),
                    &GenericArray::generate(|_| 47),
                );
                assert_eq!(res.0.as_slice(), &data.at);
                assert_eq!(res.1.as_slice(), &data.bt);
            } else if data.lambda == 192 {
                let res = aes_prove::<OWF192>(
                    GenericArray::from_slice(&data.w),
                    &GenericArray::generate(|_| 19),
                    &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                    &data.as_pk(),
                    &GenericArray::generate(|_| 47),
                );
                assert_eq!(res.0.as_slice(), &data.at);
                assert_eq!(res.1.as_slice(), &data.bt);
            } else {
                let res = aes_prove::<OWF256>(
                    GenericArray::from_slice(&data.w),
                    &GenericArray::generate(|_| 19),
                    &GenericArray::generate(|_| GenericArray::generate(|_| 55)),
                    &data.as_pk(),
                    &GenericArray::generate(|_| 47),
                );
                assert_eq!(res.0.as_slice(), &data.at);
                assert_eq!(res.1.as_slice(), &data.bt);
            }
        }
    }

    /// One `AesVerify.json` vector. The expected result `res` is stored as
    /// `u64` limbs.
    #[derive(Debug, Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct AesVerify {
        lambda: u16,
        gq: Vec<Vec<u8>>,
        d: Vec<u8>,
        chall2: Vec<u8>,
        chall3: Vec<u8>,
        at: Vec<u8>,
        input: Vec<u8>,
        output: Vec<u8>,
        res: Vec<u64>,
    }

    impl AesVerify {
        /// Flattens the `u64` limbs of `res` into little-endian bytes.
        fn res_as_u8(&self) -> Vec<u8> {
            self.res.iter().flat_map(|x| x.to_le_bytes()).collect()
        }

        /// Builds the public key from the vector's OWF input and output.
        fn as_pk<O>(&self) -> PublicKey<O>
        where
            O: OWFParameters,
        {
            PublicKey {
                owf_input: GenericArray::from_slice(&self.input).clone(),
                owf_output: GenericArray::from_slice(&self.output).clone(),
            }
        }

        /// Converts the nested `gq` vectors into a fixed-size matrix.
        fn as_gq<LHI, LHO>(&self) -> GenericArray<GenericArray<u8, LHI>, LHO>
        where
            LHI: ArrayLength,
            LHO: ArrayLength,
        {
            self.gq
                .iter()
                .map(|x| GenericArray::from_slice(x).clone())
                .collect()
        }
    }

    /// Test adapter around `super::aes_verify` that accepts `gq` by reference
    /// and boxes a copy of it.
    fn aes_verify<O, Tau>(
        d: &GenericArray<u8, O::LBYTES>,
        gq: &GenericArray<GenericArray<u8, O::LAMBDALBYTES>, O::LAMBDA>,
        a_t: &GenericArray<u8, O::LAMBDABYTES>,
        chall2: &GenericArray<u8, <<O as OWFParameters>::BaseParams as BaseParameters>::Chall>,
        chall3: &GenericArray<u8, O::LAMBDABYTES>,
        pk: &PublicKey<O>,
    ) -> GenericArray<u8, O::LAMBDABYTES>
    where
        O: OWFParameters,
        Tau: TauParameters,
    {
        super::aes_verify::<O, Tau>(
            d,
            Box::<GenericArray<_, _>>::from_iter(gq.iter().cloned()),
            a_t,
            chall2,
            chall3,
            pk,
        )
    }

    #[test]
    fn aes_verify_test() {
        let database: Vec<AesVerify> = read_test_data("AesVerify.json");
        for data in database {
            // Results are compared as field elements built from the first
            // LAMBDA / 8 bytes of the expected value.
            if data.lambda == 128 {
                let out = aes_verify::<OWF128, <FAEST128sParameters as FAESTParameters>::Tau>(
                    GenericArray::from_slice(&data.d[..]),
                    &data.as_gq(),
                    GenericArray::from_slice(&data.at),
                    GenericArray::from_slice(&data.chall2[..]),
                    GenericArray::from_slice(&data.chall3[..]),
                    &data.as_pk(),
                );
                assert_eq!(
                    GF128::from(&data.res_as_u8()[..16]),
                    GF128::from(out.as_slice())
                );
            } else if data.lambda == 192 {
                let out = aes_verify::<OWF192, <FAEST192sParameters as FAESTParameters>::Tau>(
                    GenericArray::from_slice(&data.d[..]),
                    &data.as_gq(),
                    GenericArray::from_slice(&data.at),
                    GenericArray::from_slice(&data.chall2[..]),
                    GenericArray::from_slice(&data.chall3[..]),
                    &data.as_pk(),
                );
                assert_eq!(
                    GF192::from(&data.res_as_u8()[..24]),
                    GF192::from(out.as_slice())
                );
            } else {
                let out = aes_verify::<OWF256, <FAEST256sParameters as FAESTParameters>::Tau>(
                    GenericArray::from_slice(&data.d[..]),
                    &data.as_gq(),
                    GenericArray::from_slice(&data.at),
                    GenericArray::from_slice(&data.chall2[..]),
                    GenericArray::from_slice(&data.chall3[..]),
                    &data.as_pk(),
                );
                assert_eq!(
                    GF256::from(&data.res_as_u8()[..32]),
                    GF256::from(out.as_slice())
                );
            }
        }
    }
}