1use alloc::string::String;
2use alloc::vec;
3use alloc::vec::Vec;
4
5use p3_field::{
6 BasedVectorSpace, PrimeField, PrimeField32, absorb_radix_bits, max_absorb_injective_limbs,
7 reduce_packed, split_pf_to_field_order_limbs, squeeze_field_order_num_limbs,
8};
9use p3_symmetric::{CryptographicPermutation, Hash, MerkleCap};
10
11use crate::{
12 CanFinalizeDigest, CanObserve, CanSample, CanSampleBits, DuplexChallenger, FieldChallenger,
13};
14
15#[derive(Clone, Debug)]
34pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize, const RATE: usize>
35where
36 F: PrimeField32,
37 PF: PrimeField,
38 P: CryptographicPermutation<[PF; WIDTH]>,
39{
40 inner: DuplexChallenger<PF, P, WIDTH, RATE>,
42 f_buffer: Vec<F>,
43 f_squeeze_buffer: Vec<F>,
45}
46
47impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
48where
49 F: PrimeField32,
50 PF: PrimeField,
51 P: CryptographicPermutation<[PF; WIDTH]>,
52{
53 #[inline]
56 #[must_use]
57 pub const fn absorb_radix_bits(&self) -> u32 {
58 absorb_radix_bits::<F>()
59 }
60
61 #[inline]
65 #[must_use]
66 pub fn absorb_num_f_elms(&self) -> usize {
67 max_absorb_injective_limbs::<F, PF>()
68 }
69
70 #[inline]
75 #[must_use]
76 pub fn squeeze_num_f_elms(&self) -> usize {
77 squeeze_field_order_num_limbs::<PF, F>()
78 }
79
80 #[inline]
82 #[must_use]
83 pub const fn pending_f_squeeze_len(&self) -> usize {
84 self.f_squeeze_buffer.len()
85 }
86
87 pub fn new(permutation: P) -> Result<Self, String> {
88 if F::order() >= PF::order() {
89 return Err(String::from("F::order() must be less than PF::order()"));
90 }
91 if RATE >= WIDTH {
92 return Err(String::from("RATE must be less than WIDTH"));
93 }
94
95 Ok(Self {
96 inner: DuplexChallenger::new(permutation),
97 f_buffer: vec![],
98 f_squeeze_buffer: vec![],
99 })
100 }
101
102 fn flush_f_if_non_empty(&mut self) {
103 if self.f_buffer.is_empty() {
104 return;
105 }
106 let n_in = self.f_buffer.len();
107 let absorb_n = self.absorb_num_f_elms();
108 assert!(n_in <= absorb_n * RATE);
109 let rb = self.absorb_radix_bits();
110 let packed: Vec<PF> = self
111 .f_buffer
112 .chunks(absorb_n)
113 .map(|chunk| reduce_packed(chunk, rb))
114 .collect();
115 self.inner.absorb_rate_padded_with_tag(&packed, n_in as u8);
116 self.f_buffer.clear();
117 self.f_squeeze_buffer.clear();
118 }
119
120 fn refill_f_squeeze_from_inner(&mut self) {
121 self.f_squeeze_buffer.clear();
122 let squeeze_n = self.squeeze_num_f_elms();
123 for &pf in &self.inner.output_buffer {
124 self.f_squeeze_buffer
125 .extend(split_pf_to_field_order_limbs::<PF, F>(pf, squeeze_n));
126 }
127 self.inner.output_buffer.clear();
131 }
132}
133
134impl<F, PF, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
135 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
136where
137 F: PrimeField32,
138 PF: PrimeField,
139 P: CryptographicPermutation<[PF; WIDTH]>,
140{
141}
142
143impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
144 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
145where
146 F: PrimeField32,
147 PF: PrimeField,
148 P: CryptographicPermutation<[PF; WIDTH]>,
149{
150 fn observe(&mut self, value: F) {
151 self.inner.output_buffer.clear();
152 self.f_squeeze_buffer.clear();
153 self.f_buffer.push(value);
154 if self.f_buffer.len() == self.absorb_num_f_elms() * RATE {
155 self.flush_f_if_non_empty();
156 }
157 }
158}
159
160impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
161 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
162where
163 F: PrimeField32,
164 PF: PrimeField,
165 P: CryptographicPermutation<[PF; WIDTH]>,
166{
167 fn observe(&mut self, values: [F; N]) {
168 for value in values {
169 self.observe(value);
170 }
171 }
172}
173
174impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, PF, N>>
175 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
176where
177 F: PrimeField32,
178 PF: PrimeField,
179 P: CryptographicPermutation<[PF; WIDTH]>,
180{
181 fn observe(&mut self, values: Hash<F, PF, N>) {
182 self.inner.output_buffer.clear();
183 self.f_squeeze_buffer.clear();
184 self.flush_f_if_non_empty();
185
186 let words: &[PF; N] = values.as_ref();
187
188 for chunk in words.chunks(RATE) {
189 self.inner
190 .absorb_rate_padded_with_tag(chunk, chunk.len() as u8);
191 self.f_squeeze_buffer.clear();
192 }
193 }
194}
195
196impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
197 CanObserve<&MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
198where
199 F: PrimeField32,
200 PF: PrimeField,
201 P: CryptographicPermutation<[PF; WIDTH]>,
202{
203 fn observe(&mut self, cap: &MerkleCap<F, [PF; N]>) {
204 for digest in cap.roots() {
205 self.observe(Hash::<F, PF, N>::from(*digest));
206 }
207 }
208}
209
210impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize>
211 CanObserve<MerkleCap<F, [PF; N]>> for MultiField32Challenger<F, PF, P, WIDTH, RATE>
212where
213 F: PrimeField32,
214 PF: PrimeField,
215 P: CryptographicPermutation<[PF; WIDTH]>,
216{
217 fn observe(&mut self, cap: MerkleCap<F, [PF; N]>) {
218 self.observe(&cap);
219 }
220}
221
222impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
223 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
224where
225 F: PrimeField32,
226 PF: PrimeField,
227 P: CryptographicPermutation<[PF; WIDTH]>,
228{
229 fn observe(&mut self, valuess: Vec<Vec<F>>) {
230 for values in valuess {
231 for value in values {
232 self.observe(value);
233 }
234 }
235 }
236}
237
238impl<F, EF, PF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
239 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
240where
241 F: PrimeField32,
242 EF: BasedVectorSpace<F>,
243 PF: PrimeField,
244 P: CryptographicPermutation<[PF; WIDTH]>,
245{
246 fn sample(&mut self) -> EF {
247 EF::from_basis_coefficients_fn(|_| {
248 self.flush_f_if_non_empty();
249 if self.f_squeeze_buffer.is_empty() {
250 if !self.inner.input_buffer.is_empty() || self.inner.output_buffer.is_empty() {
251 self.inner.duplexing();
252 }
253 self.refill_f_squeeze_from_inner();
254 }
255 self.f_squeeze_buffer
256 .pop()
257 .expect("Output buffer should be non-empty")
258 })
259 }
260}
261
262impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
263 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
264where
265 F: PrimeField32,
266 PF: PrimeField,
267 P: CryptographicPermutation<[PF; WIDTH]>,
268{
269 fn sample_bits(&mut self, bits: usize) -> usize {
278 assert!(bits < (usize::BITS as usize));
279 assert!((1 << bits) < F::ORDER_U32);
280 let rand_f: F = self.sample();
281 let rand_usize = rand_f.as_canonical_u32() as usize;
282 rand_usize & ((1 << bits) - 1)
283 }
284}
285
286impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanFinalizeDigest
287 for MultiField32Challenger<F, PF, P, WIDTH, RATE>
288where
289 F: PrimeField32,
290 PF: PrimeField,
291 P: CryptographicPermutation<[PF; WIDTH]>,
292{
293 type Digest = [PF; RATE];
294
295 fn finalize(mut self) -> [PF; RATE] {
296 let had_pending_f = !self.f_buffer.is_empty();
300 self.flush_f_if_non_empty();
301 if !had_pending_f {
302 self.inner.duplexing();
303 }
304 self.inner.sponge_state[..RATE].try_into().unwrap()
305 }
306}
307
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::{
        Field, PrimeCharacteristicRing, PrimeField, injective_pack_bits, split_pf_to_packed_limbs,
        squeeze_field_order_num_limbs,
    };
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;

    const WIDTH: usize = 8;
    const RATE: usize = 4;

    type F = BabyBear;
    type PF = Goldilocks;

    /// Permutation that ignores its input and writes `1, 2, ..., WIDTH`, giving fixed outputs.
    #[derive(Clone)]
    struct TestPermutation;

    impl Permutation<[PF; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            for (i, val) in input.iter_mut().enumerate() {
                *val = PF::from_u8((i + 1) as u8);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for TestPermutation {}

    /// Permutation where every output element depends on the sum of all input elements.
    #[derive(Clone)]
    struct MixingPermutation;

    impl Permutation<[PF; WIDTH]> for MixingPermutation {
        fn permute_mut(&self, input: &mut [PF; WIDTH]) {
            let sum: PF = input.iter().copied().sum();
            for (i, val) in input.iter_mut().enumerate() {
                *val = sum + PF::from_u8((i + 1) as u8);
            }
        }
    }

    impl CryptographicPermutation<[PF; WIDTH]> for MixingPermutation {}

    /// Packing parameters for BabyBear over Goldilocks.
    #[test]
    fn test_packing() {
        let c = MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        assert_eq!(c.absorb_radix_bits(), 31);
        assert_eq!(c.absorb_num_f_elms(), 2);
        assert_eq!(c.squeeze_num_f_elms(), 1);
        assert_eq!(squeeze_field_order_num_limbs::<PF, F>(), 1);
    }

    /// One sample squeezes only the RATE part of the state: the pending F limbs are
    /// `RATE * squeeze_num_f_elms` minus the one just returned.
    #[test]
    fn test_output_buffer_excludes_capacity() {
        let permutation = TestPermutation;
        let mut challenger =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();

        let squeeze_n = challenger.squeeze_num_f_elms();

        let _: F = challenger.sample();

        let expected_output_size = RATE * squeeze_n;

        assert_eq!(
            challenger.pending_f_squeeze_len(),
            expected_output_size - 1,
            "Pending F squeeze should be RATE * squeeze_num_f_elms minus one sample"
        );
        assert_eq!(
            challenger.inner.output_buffer.len(),
            0,
            "After refill, inner PF output buffer is drained like popped F outputs"
        );
    }

    /// Identical observations give identical digests; different observations do not.
    #[test]
    fn test_finalize() {
        let new_chal =
            || MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();

        // Same transcript on both sides.
        let mut c1 = new_chal();
        let mut c2 = new_chal();
        for i in 0..5u8 {
            c1.observe(F::from_u8(i));
            c2.observe(F::from_u8(i));
        }
        assert_eq!(c1.finalize(), c2.finalize());

        // Transcripts differing in every observed value.
        let mut c1 = new_chal();
        let mut c2 = new_chal();
        for i in 0..5u8 {
            c1.observe(F::from_u8(i));
            c2.observe(F::from_u8(i + 1));
        }
        assert_ne!(c1.finalize(), c2.finalize());
    }

    /// The digest reflects how many squeezed batches have been started, not how many
    /// limbs were taken from the current batch.
    #[test]
    fn test_finalize_sample_interaction() {
        let batch_size = {
            let c =
                MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
            c.squeeze_num_f_elms() * RATE
        };

        let digest = |n_samples: usize| {
            let mut c =
                MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
            for i in 0..3u8 {
                c.observe(F::from_u8(i));
            }
            for _ in 0..n_samples {
                let _: F = c.sample();
            }
            c.finalize()
        };

        // Without sampling, finalize absorbs the pending values itself.
        // After one sample they are already absorbed, so finalize runs an extra duplexing.
        assert_ne!(digest(0), digest(1));

        // Samples within the first squeezed batch do not permute again.
        assert_eq!(digest(1), digest(2));
        assert_eq!(digest(1), digest(batch_size));

        // Exhausting the first batch forces another duplexing before the next sample.
        assert_ne!(digest(batch_size), digest(batch_size + 1));

        // Samples within the second batch again leave the state unchanged.
        assert_eq!(digest(batch_size + 1), digest(batch_size + 2));
    }

    /// Observing `[1]` and `[1, 0, ...]` (padded to one packed limb group) must give different
    /// digests, since the absorb tag records how many F values were observed.
    #[test]
    fn test_partial_absorb_length_distinct_from_padded_equivalent() {
        let ne = {
            let c =
                MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
            c.absorb_num_f_elms()
        };
        assert_eq!(ne, 2);

        let mut a =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        a.observe(F::ONE);

        let mut b =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        b.observe(F::ONE);
        for _ in 1..ne {
            b.observe(F::ZERO);
        }

        assert_ne!(a.finalize(), b.finalize());
    }

    /// `[2^30, 0]` and `[0, 1]` could collide if limbs were packed with a 30-bit radix.
    /// They must yield distinct digests.
    #[test]
    fn test_absorb_no_radix_overflow_collision() {
        let mut a =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        a.observe(F::from_u32(1 << 30));
        a.observe(F::ZERO);

        let mut b =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        b.observe(F::ZERO);
        b.observe(F::ONE);

        assert_ne!(a.finalize(), b.finalize());
    }

    /// A full F batch is flushed on `observe`, leaving one rate row of PF output, while F limbs
    /// are only produced lazily by `sample`.
    #[test]
    fn test_duplexing_respects_rate() {
        let permutation = TestPermutation;
        let mut challenger =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(permutation).unwrap();

        let absorb_n = challenger.absorb_num_f_elms();

        for i in 0..(absorb_n * RATE) {
            challenger.observe(F::from_u8(i as u8));
        }

        assert_eq!(
            challenger.inner.output_buffer.len(),
            RATE,
            "After a full F batch flush, inner holds one rate row of PF elements"
        );
        assert_eq!(
            challenger.pending_f_squeeze_len(),
            0,
            "F limbs are produced on sample() via split_pf_to_field_order_limbs, not on observe"
        );
    }

    /// A squeezed PF value just above `F::ORDER` must reduce to an F limb above the
    /// `2^injective_pack_bits` threshold, i.e. squeezing is not capped below that threshold.
    #[test]
    fn test_squeeze_covers_full_f_range() {
        let pack_bits = injective_pack_bits::<F>();
        let threshold = 1u32 << pack_bits;
        let v_raw = F::ORDER_U32 as u64 + threshold as u64 + 1;
        let pf_val = PF::from_u64(v_raw);
        let limbs = split_pf_to_field_order_limbs::<PF, F>(pf_val, 1);
        assert_eq!(limbs[0].as_canonical_u32(), threshold + 1);
        assert!(
            limbs[0].as_canonical_u32() > threshold,
            "c0 must exceed the old base-2^30 ceiling"
        );
    }

    /// Two BN254 values that agree in their low 192 bits and differ only at bit 200 must give
    /// different digests when observed natively as a one-element `Hash`.
    #[test]
    fn test_observe_hash_native_pf_high_bits_distinct() {
        use num_bigint::BigUint;
        use p3_bn254::Bn254;

        type PF254 = Bn254;

        #[derive(Clone)]
        struct Bn254MixingPermutation;

        impl Permutation<[PF254; WIDTH]> for Bn254MixingPermutation {
            fn permute_mut(&self, input: &mut [PF254; WIDTH]) {
                let sum: PF254 = input.iter().copied().sum();
                for (i, val) in input.iter_mut().enumerate() {
                    *val = sum + PF254::from_u8((i + 1) as u8);
                }
            }
        }

        impl CryptographicPermutation<[PF254; WIDTH]> for Bn254MixingPermutation {}

        let pack_bits = injective_pack_bits::<F>();
        let observe_n = PF254::bits().div_ceil(pack_bits as usize);

        let a = PF254::from_biguint(BigUint::from(1u32)).unwrap();
        let b = PF254::from_biguint(BigUint::from(1u32) + (BigUint::from(1u32) << 200)).unwrap();
        assert_ne!(a, b);

        let digest = |h: PF254| {
            let mut c =
                MultiField32Challenger::<F, PF254, _, WIDTH, RATE>::new(Bn254MixingPermutation)
                    .unwrap();
            c.observe(Hash::<F, PF254, 1>::from([h]));
            c.finalize()
        };

        assert_ne!(digest(a), digest(b));

        let limbs_a = split_pf_to_packed_limbs::<PF254, F>(a, observe_n, pack_bits);
        let limbs_b = split_pf_to_packed_limbs::<PF254, F>(b, observe_n, pack_bits);
        assert_ne!(limbs_a, limbs_b);

        // The low three 64-bit digits agree, so only the high bits distinguish `a` from `b`.
        let d_a = a.as_canonical_biguint().to_u64_digits();
        let d_b = b.as_canonical_biguint().to_u64_digits();
        let take3 = |d: &[u64]| {
            let mut v = [0u64; 3];
            for (i, x) in d.iter().take(3).enumerate() {
                v[i] = *x;
            }
            v
        };
        assert_eq!(take3(&d_a), take3(&d_b));
    }

    /// Observing a PF element natively via `Hash` must not collide with observing its packed
    /// F limbs one at a time.
    #[test]
    fn test_observe_hash_native_vs_expanded_f_not_equal() {
        let g = PF::from_u64(123456789);
        let mut native =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        native.observe(Hash::<F, PF, 1>::from([g]));

        let mut via_f =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        let pb = injective_pack_bits::<F>();
        let n = PF::bits().div_ceil(pb as usize);
        for f in split_pf_to_packed_limbs::<PF, F>(g, n, pb) {
            via_f.observe(f);
        }

        assert_ne!(native.finalize(), via_f.finalize());
    }

    /// Observing 8 F values (one full batch of `2 * RATE`) must leave the inner sponge in the
    /// same state as manually packing consecutive pairs and absorbing them with tag 8.
    #[test]
    fn test_inner_sponge_matches_manual_absorb_chain() {
        let mut m =
            MultiField32Challenger::<F, PF, _, WIDTH, RATE>::new(MixingPermutation).unwrap();
        for i in 0..8u8 {
            m.observe(F::from_u8(i));
        }
        let d_m = m.inner.sponge_state;

        let mut inner = DuplexChallenger::<PF, _, WIDTH, RATE>::new(MixingPermutation);
        let packed: Vec<PF> = (0..8)
            .step_by(2)
            .map(|j| {
                reduce_packed::<F, PF>(
                    &[F::from_u8(j), F::from_u8(j + 1)],
                    absorb_radix_bits::<F>(),
                )
            })
            .collect();
        inner.absorb_rate_padded_with_tag(&packed, 8);
        assert_eq!(d_m, inner.sponge_state);
    }
}