1use crate::errors;
19use crate::tensor::TensorProduct;
20use crate::trace::TraceCompatibleField;
21use alloc::vec;
22use alloc::vec::Vec;
23use core::fmt;
24use hekate_math::{Bit, Block8, Block16, Block32, Block64, Block128, Flat};
25use zeroize::Zeroize;
26
/// Failures reported while evaluating a virtual polynomial view.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// The hypercube size `2^num_vars` cannot be represented as a `usize`.
    DomainTooLarge {
        /// Number of variables implied by the evaluation point.
        num_vars: usize,
    },

    /// The view's number of rows does not equal `2^num_vars`.
    DomainSizeMismatch {
        /// Row count implied by the evaluation point.
        expected_len: usize,
        /// Row count actually exposed by the view.
        got_len: usize,
    },

    /// An evaluation point has the wrong number of coordinates.
    ///
    /// NOTE(review): not constructed in this file; presumably raised by
    /// callers or other modules — confirm.
    PointDimensionMismatch {
        /// Number of coordinates that was expected.
        expected_len: usize,
        /// Number of coordinates that was supplied.
        got_len: usize,
    },

    /// The named variant kind cannot be folded lazily.
    ///
    /// NOTE(review): not constructed in this file — confirm call sites.
    UnsupportedFold {
        /// Human-readable name of the offending variant kind.
        kind: &'static str,
    },
}
43
44impl fmt::Display for Error {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 Self::DomainTooLarge { num_vars } => {
48 write!(
49 f,
50 "Virtual polynomial domain too large: num_vars={num_vars}"
51 )
52 }
53 Self::DomainSizeMismatch {
54 expected_len,
55 got_len,
56 } => write!(
57 f,
58 "Virtual polynomial domain size mismatch: expected {expected_len}, got {got_len}",
59 ),
60 Self::PointDimensionMismatch {
61 expected_len,
62 got_len,
63 } => write!(
64 f,
65 "Virtual polynomial point dimension mismatch: expected {expected_len}, got {got_len}",
66 ),
67 Self::UnsupportedFold { kind } => {
68 write!(
69 f,
70 "Virtual polynomial cannot be folded lazily for kind: {kind}"
71 )
72 }
73 }
74 }
75}
76
/// A lazily-evaluated view of a polynomial's evaluations over the Boolean
/// hypercube.
///
/// Each variant holds just enough data to answer `len()` and `get_at(i)`
/// without materialising a dense `Vec<Flat<F>>`; subfield data is promoted
/// into `F` on read. Every variant except `Eq` is marked `#[zeroize(skip)]`,
/// so zeroizing a `PolyVariant` only touches the owned `TensorProduct`.
#[derive(Clone, Debug, Zeroize)]
pub enum PolyVariant<'a, F>
where
    F: TraceCompatibleField,
{
    /// Evaluations stored directly in `F`; `get_at(i)` returns `h[i]`.
    #[zeroize(skip)]
    Dense(&'a [Flat<F>]),
    /// Dense evaluations read at the cyclic successor row: `get_at(i)` returns
    /// `h[i + 1]`, and the last row wraps around to `h[0]`.
    #[zeroize(skip)]
    Shifted(&'a [Flat<F>]),

    /// Tensor-product polynomial; `len()` is `2^num_vars()` and row values /
    /// extension evaluation are delegated to the `TensorProduct`.
    Eq(TensorProduct<F>),

    /// Bit `bit_idx` (via `tower_bit`) of each packed `Block8`, lifted to
    /// `ONE` when set and `ZERO` otherwise.
    #[zeroize(skip)]
    PackedBitB8 {
        data: &'a [Flat<Block8>],
        bit_idx: usize,
    },

    /// Bit `bit_idx` of each packed `Block16`, lifted to `ONE`/`ZERO`.
    #[zeroize(skip)]
    PackedBitB16 {
        data: &'a [Flat<Block16>],
        bit_idx: usize,
    },

    /// Bit `bit_idx` of each packed `Block32`, lifted to `ONE`/`ZERO`.
    #[zeroize(skip)]
    PackedBitB32 {
        data: &'a [Flat<Block32>],
        bit_idx: usize,
    },

    /// Bit `bit_idx` of each packed `Block64`, lifted to `ONE`/`ZERO`.
    #[zeroize(skip)]
    PackedBitB64 {
        data: &'a [Flat<Block64>],
        bit_idx: usize,
    },

    /// Field sum of several bit columns: `get_at(i)` adds `ONE` once for every
    /// column whose bit at row `i` is set. `len()` is the first column's length
    /// (0 when there are no columns).
    #[zeroize(skip)]
    CompositeSelector(Vec<&'a [Bit]>),

    /// Mask over `2^num_vars` rows that is `ONE` everywhere except the last row,
    /// where it is `ONE - product_of_challenges`.
    #[zeroize(skip)]
    TransitionMask {
        num_vars: usize,
        product_of_challenges: F,
    },

    /// Gathered bit view: `get_at(i)` reads `data[indices[i]]`; `len()` is
    /// `indices.len()`.
    #[zeroize(skip)]
    IndirectBit {
        data: &'a [Bit],
        indices: &'a [usize],
    },
    /// Gathered `Block8` view: `get_at(i)` promotes `data[indices[i]]` into `F`.
    #[zeroize(skip)]
    IndirectB8 {
        data: &'a [Flat<Block8>],
        indices: &'a [usize],
    },
    /// Gathered `Block16` view: `get_at(i)` promotes `data[indices[i]]` into `F`.
    #[zeroize(skip)]
    IndirectB16 {
        data: &'a [Flat<Block16>],
        indices: &'a [usize],
    },
    /// Gathered `Block32` view: `get_at(i)` promotes `data[indices[i]]` into `F`.
    #[zeroize(skip)]
    IndirectB32 {
        data: &'a [Flat<Block32>],
        indices: &'a [usize],
    },
    /// Gathered `Block64` view: `get_at(i)` promotes `data[indices[i]]` into `F`.
    #[zeroize(skip)]
    IndirectB64 {
        data: &'a [Flat<Block64>],
        indices: &'a [usize],
    },
    /// Gathered `Block128` view: `get_at(i)` promotes `data[indices[i]]` into `F`.
    #[zeroize(skip)]
    IndirectB128 {
        data: &'a [Flat<Block128>],
        indices: &'a [usize],
    },

    /// Strided bit view: `get_at(i)` reads `data[start + i * step]`. `len()`
    /// returns the stored `len`, which is not checked against `data` here.
    #[zeroize(skip)]
    StrideBit {
        data: &'a [Bit],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided `Block8` view: `get_at(i)` promotes `data[start + i * step]`.
    #[zeroize(skip)]
    StrideB8 {
        data: &'a [Flat<Block8>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided `Block16` view: `get_at(i)` promotes `data[start + i * step]`.
    #[zeroize(skip)]
    StrideB16 {
        data: &'a [Flat<Block16>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided `Block32` view: `get_at(i)` promotes `data[start + i * step]`.
    #[zeroize(skip)]
    StrideB32 {
        data: &'a [Flat<Block32>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided `Block64` view: `get_at(i)` promotes `data[start + i * step]`.
    #[zeroize(skip)]
    StrideB64 {
        data: &'a [Flat<Block64>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided `Block128` view: `get_at(i)` promotes `data[start + i * step]`.
    #[zeroize(skip)]
    StrideB128 {
        data: &'a [Flat<Block128>],
        start: usize,
        step: usize,
        len: usize,
    },

    /// Rotated bit view: `get_at(i)` reads `data[(i + rotation) & (len - 1)]`.
    /// The mask only wraps correctly when `data.len()` is a power of two.
    #[zeroize(skip)]
    RotationBit { data: &'a [Bit], rotation: usize },
    /// Rotated `Block8` view (power-of-two length assumed, see `RotationBit`).
    #[zeroize(skip)]
    RotationB8 {
        data: &'a [Flat<Block8>],
        rotation: usize,
    },
    /// Rotated `Block16` view (power-of-two length assumed, see `RotationBit`).
    #[zeroize(skip)]
    RotationB16 {
        data: &'a [Flat<Block16>],
        rotation: usize,
    },
    /// Rotated `Block32` view (power-of-two length assumed, see `RotationBit`).
    #[zeroize(skip)]
    RotationB32 {
        data: &'a [Flat<Block32>],
        rotation: usize,
    },
    /// Rotated `Block64` view (power-of-two length assumed, see `RotationBit`).
    #[zeroize(skip)]
    RotationB64 {
        data: &'a [Flat<Block64>],
        rotation: usize,
    },
    /// Rotated `Block128` view (power-of-two length assumed, see `RotationBit`).
    #[zeroize(skip)]
    RotationB128 {
        data: &'a [Flat<Block128>],
        rotation: usize,
    },

    /// Bit column promoted element-wise into `F`.
    #[zeroize(skip)]
    BitSlice(&'a [Bit]),
    /// `Block8` column promoted element-wise into `F`.
    #[zeroize(skip)]
    B8Slice(&'a [Flat<Block8>]),
    /// `Block16` column promoted element-wise into `F`.
    #[zeroize(skip)]
    B16Slice(&'a [Flat<Block16>]),
    /// `Block32` column promoted element-wise into `F`.
    #[zeroize(skip)]
    B32Slice(&'a [Flat<Block32>]),
    /// `Block64` column promoted element-wise into `F`.
    #[zeroize(skip)]
    B64Slice(&'a [Flat<Block64>]),
    /// `Block128` column promoted element-wise into `F`.
    #[zeroize(skip)]
    B128Slice(&'a [Flat<Block128>]),

    /// Bit column read at the cyclic successor row (last row wraps to row 0).
    #[zeroize(skip)]
    ShiftedBitSlice(&'a [Bit]),
    /// `Block8` column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedB8Slice(&'a [Flat<Block8>]),
    /// `Block16` column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedB16Slice(&'a [Flat<Block16>]),
    /// `Block32` column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedB32Slice(&'a [Flat<Block32>]),
    /// `Block64` column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedB64Slice(&'a [Flat<Block64>]),
    /// `Block128` column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedB128Slice(&'a [Flat<Block128>]),

    /// Packed `Block8` bit column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedPackedBitB8 {
        data: &'a [Flat<Block8>],
        bit_idx: usize,
    },
    /// Packed `Block16` bit column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedPackedBitB16 {
        data: &'a [Flat<Block16>],
        bit_idx: usize,
    },
    /// Packed `Block32` bit column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedPackedBitB32 {
        data: &'a [Flat<Block32>],
        bit_idx: usize,
    },
    /// Packed `Block64` bit column read at the cyclic successor row.
    #[zeroize(skip)]
    ShiftedPackedBitB64 {
        data: &'a [Flat<Block64>],
        bit_idx: usize,
    },
}
307
308impl<'a, F> PolyVariant<'a, F>
309where
310 F: TraceCompatibleField,
311{
312 pub fn len(&self) -> usize {
314 match self {
315 Self::Dense(h) => h.len(),
316 Self::Shifted(h) => h.len(),
317 Self::Eq(t) => 1 << t.num_vars(),
318 Self::PackedBitB8 { data, .. } => data.len(),
319 Self::PackedBitB16 { data, .. } => data.len(),
320 Self::PackedBitB32 { data, .. } => data.len(),
321 Self::PackedBitB64 { data, .. } => data.len(),
322 Self::TransitionMask { num_vars, .. } => 1 << num_vars,
323 Self::CompositeSelector(cols) => {
324 if cols.is_empty() {
325 0
326 } else {
327 cols[0].len()
328 }
329 }
330 Self::IndirectBit { indices, .. } => indices.len(),
331 Self::IndirectB8 { indices, .. } => indices.len(),
332 Self::IndirectB16 { indices, .. } => indices.len(),
333 Self::IndirectB32 { indices, .. } => indices.len(),
334 Self::IndirectB64 { indices, .. } => indices.len(),
335 Self::IndirectB128 { indices, .. } => indices.len(),
336 Self::StrideBit { len, .. } => *len,
337 Self::StrideB8 { len, .. } => *len,
338 Self::StrideB16 { len, .. } => *len,
339 Self::StrideB32 { len, .. } => *len,
340 Self::StrideB64 { len, .. } => *len,
341 Self::StrideB128 { len, .. } => *len,
342 Self::RotationBit { data, .. } => data.len(),
343 Self::RotationB8 { data, .. } => data.len(),
344 Self::RotationB16 { data, .. } => data.len(),
345 Self::RotationB32 { data, .. } => data.len(),
346 Self::RotationB64 { data, .. } => data.len(),
347 Self::RotationB128 { data, .. } => data.len(),
348 Self::BitSlice(h) => h.len(),
349 Self::B8Slice(h) => h.len(),
350 Self::B16Slice(h) => h.len(),
351 Self::B32Slice(h) => h.len(),
352 Self::B64Slice(h) => h.len(),
353 Self::B128Slice(h) => h.len(),
354 Self::ShiftedBitSlice(h) => h.len(),
355 Self::ShiftedB8Slice(h) => h.len(),
356 Self::ShiftedB16Slice(h) => h.len(),
357 Self::ShiftedB32Slice(h) => h.len(),
358 Self::ShiftedB64Slice(h) => h.len(),
359 Self::ShiftedB128Slice(h) => h.len(),
360 Self::ShiftedPackedBitB8 { data, .. } => data.len(),
361 Self::ShiftedPackedBitB16 { data, .. } => data.len(),
362 Self::ShiftedPackedBitB32 { data, .. } => data.len(),
363 Self::ShiftedPackedBitB64 { data, .. } => data.len(),
364 }
365 }
366
367 pub fn is_empty(&self) -> bool {
368 self.len() == 0
369 }
370
371 #[inline(always)]
375 pub fn get_at(&self, index: usize) -> Flat<F> {
376 match self {
377 Self::Dense(h) => h[index],
378 Self::Shifted(h) => {
379 let len = h.len();
380 let next_idx = if index + 1 == len { 0 } else { index + 1 };
381 h[next_idx]
382 }
383 Self::Eq(t) => t.evaluate_at_index(index),
384 Self::PackedBitB8 { data, bit_idx } => {
385 let bit = data[index].tower_bit(*bit_idx);
386 if bit == 1 {
387 Flat::from_raw(F::ONE)
388 } else {
389 Flat::from_raw(F::ZERO)
390 }
391 }
392 Self::PackedBitB16 { data, bit_idx } => {
393 let bit = data[index].tower_bit(*bit_idx);
394 if bit == 1 {
395 Flat::from_raw(F::ONE)
396 } else {
397 Flat::from_raw(F::ZERO)
398 }
399 }
400 Self::PackedBitB32 { data, bit_idx } => {
401 let bit = data[index].tower_bit(*bit_idx);
402 if bit == 1 {
403 Flat::from_raw(F::ONE)
404 } else {
405 Flat::from_raw(F::ZERO)
406 }
407 }
408 Self::PackedBitB64 { data, bit_idx } => {
409 let bit = data[index].tower_bit(*bit_idx);
410 if bit == 1 {
411 Flat::from_raw(F::ONE)
412 } else {
413 Flat::from_raw(F::ZERO)
414 }
415 }
416 Self::TransitionMask {
417 num_vars,
418 product_of_challenges,
419 } => {
420 let last_idx: usize = (1 << num_vars) - 1;
421 if index == last_idx {
422 Flat::from_raw(F::ONE - *product_of_challenges)
423 } else {
424 Flat::from_raw(F::ONE)
425 }
426 }
427 Self::CompositeSelector(cols) => {
428 let mut sum = Flat::from_raw(F::default());
429 for col in cols {
430 if col[index].0 == 1 {
431 sum += Flat::from_raw(F::ONE);
432 }
433 }
434
435 sum
436 }
437
438 Self::IndirectBit { data, indices } => Flat::from_raw(F::from(data[indices[index]])),
439 Self::IndirectB8 { data, indices } => F::promote_flat(data[indices[index]]),
440 Self::IndirectB16 { data, indices } => F::promote_flat(data[indices[index]]),
441 Self::IndirectB32 { data, indices } => F::promote_flat(data[indices[index]]),
442 Self::IndirectB64 { data, indices } => F::promote_flat(data[indices[index]]),
443 Self::IndirectB128 { data, indices } => F::promote_flat(data[indices[index]]),
444
445 Self::StrideBit {
446 data, start, step, ..
447 } => Flat::from_raw(F::from(data[start + index * step])),
448 Self::StrideB8 {
449 data, start, step, ..
450 } => F::promote_flat(data[start + index * step]),
451 Self::StrideB16 {
452 data, start, step, ..
453 } => F::promote_flat(data[start + index * step]),
454 Self::StrideB32 {
455 data, start, step, ..
456 } => F::promote_flat(data[start + index * step]),
457 Self::StrideB64 {
458 data, start, step, ..
459 } => F::promote_flat(data[start + index * step]),
460 Self::StrideB128 {
461 data, start, step, ..
462 } => F::promote_flat(data[start + index * step]),
463
464 Self::RotationBit { data, rotation } => {
465 Flat::from_raw(F::from(data[(index + rotation) & (data.len() - 1)]))
466 }
467 Self::RotationB8 { data, rotation } => {
468 F::promote_flat(data[(index + rotation) & (data.len() - 1)])
469 }
470 Self::RotationB16 { data, rotation } => {
471 F::promote_flat(data[(index + rotation) & (data.len() - 1)])
472 }
473 Self::RotationB32 { data, rotation } => {
474 F::promote_flat(data[(index + rotation) & (data.len() - 1)])
475 }
476 Self::RotationB64 { data, rotation } => {
477 F::promote_flat(data[(index + rotation) & (data.len() - 1)])
478 }
479 Self::RotationB128 { data, rotation } => {
480 F::promote_flat(data[(index + rotation) & (data.len() - 1)])
481 }
482
483 Self::BitSlice(s) => Flat::from_raw(F::from(s[index])),
484 Self::B8Slice(s) => F::promote_flat(s[index]),
485 Self::B16Slice(s) => F::promote_flat(s[index]),
486 Self::B32Slice(s) => F::promote_flat(s[index]),
487 Self::B64Slice(s) => F::promote_flat(s[index]),
488 Self::B128Slice(s) => F::promote_flat(s[index]),
489
490 Self::ShiftedBitSlice(s) => {
491 let len = s.len();
492 let next_idx = if index + 1 == len { 0 } else { index + 1 };
493
494 Flat::from_raw(F::from(s[next_idx]))
495 }
496 Self::ShiftedB8Slice(s) => {
497 let len = s.len();
498 let next_idx = if index + 1 == len { 0 } else { index + 1 };
499
500 F::promote_flat(s[next_idx])
501 }
502 Self::ShiftedB16Slice(s) => {
503 let len = s.len();
504 let next_idx = if index + 1 == len { 0 } else { index + 1 };
505
506 F::promote_flat(s[next_idx])
507 }
508 Self::ShiftedB32Slice(s) => {
509 let len = s.len();
510 let next_idx = if index + 1 == len { 0 } else { index + 1 };
511
512 F::promote_flat(s[next_idx])
513 }
514 Self::ShiftedB64Slice(s) => {
515 let len = s.len();
516 let next_idx = if index + 1 == len { 0 } else { index + 1 };
517
518 F::promote_flat(s[next_idx])
519 }
520 Self::ShiftedB128Slice(s) => {
521 let len = s.len();
522 let next_idx = if index + 1 == len { 0 } else { index + 1 };
523
524 F::promote_flat(s[next_idx])
525 }
526 Self::ShiftedPackedBitB8 { data, bit_idx } => {
527 let len = data.len();
528 let next_idx = if index + 1 == len { 0 } else { index + 1 };
529 let bit = data[next_idx].tower_bit(*bit_idx);
530
531 if bit == 1 {
532 Flat::from_raw(F::ONE)
533 } else {
534 Flat::from_raw(F::ZERO)
535 }
536 }
537 Self::ShiftedPackedBitB16 { data, bit_idx } => {
538 let len = data.len();
539 let next_idx = if index + 1 == len { 0 } else { index + 1 };
540 let bit = data[next_idx].tower_bit(*bit_idx);
541
542 if bit == 1 {
543 Flat::from_raw(F::ONE)
544 } else {
545 Flat::from_raw(F::ZERO)
546 }
547 }
548 Self::ShiftedPackedBitB32 { data, bit_idx } => {
549 let len = data.len();
550 let next_idx = if index + 1 == len { 0 } else { index + 1 };
551 let bit = data[next_idx].tower_bit(*bit_idx);
552
553 if bit == 1 {
554 Flat::from_raw(F::ONE)
555 } else {
556 Flat::from_raw(F::ZERO)
557 }
558 }
559 Self::ShiftedPackedBitB64 { data, bit_idx } => {
560 let len = data.len();
561 let next_idx = if index + 1 == len { 0 } else { index + 1 };
562 let bit = data[next_idx].tower_bit(*bit_idx);
563
564 if bit == 1 {
565 Flat::from_raw(F::ONE)
566 } else {
567 Flat::from_raw(F::ZERO)
568 }
569 }
570 }
571 }
572
573 #[inline(always)]
578 pub fn evaluate(&self, point: &[Flat<F>]) -> errors::Result<Flat<F>> {
579 match self {
580 Self::Eq(t) => Ok(t.evaluate_extension(point)?),
581 _ => {
582 let num_vars = point.len();
583 let got_len = self.len();
584
585 let Some(expected_len) = 1usize.checked_shl(num_vars as u32) else {
586 return Err(Error::DomainTooLarge { num_vars }.into());
587 };
588
589 if got_len != expected_len {
590 return Err(Error::DomainSizeMismatch {
591 expected_len,
592 got_len,
593 }
594 .into());
595 }
596
597 if num_vars == 0 {
598 return Ok(self.get_at(0));
599 }
600
601 let weights = Self::expand_mle_weights(point);
602 let mut total = Flat::from_raw(F::ZERO);
603
604 for (i, w) in weights.iter().enumerate() {
605 total += self.get_at(i) * *w;
606 }
607
608 Ok(total)
609 }
610 }
611 }
612
613 pub fn expand_mle_weights(r: &[Flat<F>]) -> Vec<Flat<F>> {
614 let num_vars = r.len();
615 let size = 1 << num_vars;
616
617 let mut weights = vec![Flat::from_raw(F::ZERO); size];
618 weights[0] = Flat::from_raw(F::ONE);
619
620 for (i, &rk) in r.iter().enumerate() {
621 let one_minus_rk = Flat::from_raw(F::ONE) - rk;
622 let current_len = 1 << i;
623
624 for i in 0..current_len {
625 let w = weights[i];
626 weights[i] = w * one_minus_rk;
627 weights[current_len + i] = w * rk;
628 }
629 }
630
631 weights
632 }
633}
634
// Unit tests for `PolyVariant::{len, is_empty, get_at, evaluate,
// expand_mle_weights}` over a representative subset of variants.
#[cfg(test)]
mod tests {
    use super::*;
    use hekate_math::{Bit, Block8, Block128, FlatPromote, HardwareField, TowerField};

    // Every test uses `Block128` as the target field.
    type F = Block128;

    // `Dense` returns the stored element at each row unchanged.
    #[test]
    fn get_at_dense_returns_correct_values() {
        let data: Vec<Flat<F>> = (0..4u128).map(|i| F::from(i * 10).to_hardware()).collect();
        let v = PolyVariant::<F>::Dense(&data);

        assert_eq!(v.get_at(0), F::from(0u128).to_hardware());
        assert_eq!(v.get_at(1), F::from(10u128).to_hardware());
        assert_eq!(v.get_at(2), F::from(20u128).to_hardware());
        assert_eq!(v.get_at(3), F::from(30u128).to_hardware());
    }

    // Bits map to the field's ZERO / ONE.
    #[test]
    fn get_at_bit_slice_promotes_to_field() {
        let bits = vec![
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let v = PolyVariant::<F>::BitSlice(&bits);

        assert_eq!(v.get_at(0), Flat::from_raw(F::ZERO));
        assert_eq!(v.get_at(1), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(2), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(3), Flat::from_raw(F::ZERO));
    }

    // Subfield elements are lifted with `promote_flat`.
    #[test]
    fn get_at_b8_slice_promotes() {
        let data = vec![
            Block8::from(0u8).to_hardware(),
            Block8::from(0xFFu8).to_hardware(),
        ];
        let v = PolyVariant::<F>::B8Slice(&data);

        assert_eq!(
            v.get_at(0),
            F::promote_flat(Block8::from(0u8).to_hardware())
        );
        assert_eq!(
            v.get_at(1),
            F::promote_flat(Block8::from(0xFFu8).to_hardware())
        );
    }

    // Slice variants report the slice length; `Eq` reports 2^num_vars.
    #[test]
    fn len_matches_data_size() {
        let data = vec![Flat::from_raw(F::ZERO); 16];
        assert_eq!(PolyVariant::<F>::Dense(&data).len(), 16);

        let bits = vec![Bit::from(0u32); 8];
        assert_eq!(PolyVariant::<F>::BitSlice(&bits).len(), 8);

        let eq = TensorProduct::new(vec![Flat::from_raw(F::ONE); 5]);
        assert_eq!(PolyVariant::<F>::Eq(eq).len(), 32);

        let empty: Vec<Flat<F>> = vec![];
        assert!(PolyVariant::<F>::Dense(&empty).is_empty());
    }

    // The Lagrange weights sum to ONE, so a constant table evaluates to that
    // constant at any point.
    #[test]
    fn evaluate_constant_polynomial() {
        let num_vars = 3;
        let data = vec![F::from(42u128).to_hardware(); 1 << num_vars];
        let v = PolyVariant::<F>::Dense(&data);

        let point: Vec<Flat<F>> = vec![
            Flat::from_raw(F::from(1u128).to_hardware().into_raw()),
            Flat::from_raw(F::from(2u128).to_hardware().into_raw()),
            Flat::from_raw(F::from(3u128).to_hardware().into_raw()),
        ];

        let val = v.evaluate(&point).unwrap();
        assert_eq!(val.into_raw(), F::from(42u128).to_hardware().into_raw());
    }

    // MLE of [0, 10] at r is (1 - r) * 0 + r * 10 = r * 10. The assertion
    // expects the field product of 2 and 10 to equal F::from(20).
    // NOTE(review): that holds only for this field's representation — confirm.
    #[test]
    fn evaluate_linear_polynomial() {
        let data = vec![F::ZERO.to_hardware(), F::from(10u128).to_hardware()];
        let v = PolyVariant::<F>::Dense(&data);

        let point = vec![Flat::from_raw(F::from(2u128).to_hardware().into_raw())];
        let val = v.evaluate(&point).unwrap();
        assert_eq!(val.into_raw(), F::from(20u128).to_hardware().into_raw());
    }

    // A 0-variable point evaluates to the single stored row.
    #[test]
    fn evaluate_single_row() {
        let data = vec![F::from(99u128).to_hardware()];
        let v = PolyVariant::<F>::Dense(&data);

        let val = v.evaluate(&[]).unwrap();
        assert_eq!(val.into_raw(), F::from(99u128).to_hardware().into_raw());
    }

    // 4 rows against a 3-variable point (8 rows) must be rejected.
    #[test]
    fn evaluate_domain_mismatch_rejected() {
        let data = vec![F::ZERO.to_hardware(); 4];
        let v = PolyVariant::<F>::Dense(&data);

        let point = vec![Flat::from_raw(F::ONE); 3];
        assert!(v.evaluate(&point).is_err());
    }

    // `Eq` delegates to the tensor product; evaluating it at its own Boolean
    // point is expected to give ONE.
    #[test]
    fn evaluate_eq_polynomial() {
        let r = vec![Flat::from_raw(F::ONE), Flat::from_raw(F::ZERO)];
        let eq = PolyVariant::<F>::Eq(TensorProduct::new(r.clone()));

        let val = eq.evaluate(&r).unwrap();
        assert_eq!(val.into_raw(), F::ONE.to_hardware().into_raw());
    }

    // One variable: weights are [1 - r0, r0].
    #[test]
    fn expand_mle_weights_single_var() {
        let r = vec![Flat::from_raw(F::from(7u128).to_hardware().into_raw())];
        let w = PolyVariant::<F>::expand_mle_weights(&r);

        assert_eq!(w.len(), 2);
        let r0 = r[0];
        let one = Flat::from_raw(F::ONE);
        assert_eq!(w[0], one - r0);
        assert_eq!(w[1], r0);
    }

    // Zero variables: a single weight equal to ONE.
    #[test]
    fn expand_mle_weights_zero_vars() {
        let w = PolyVariant::<F>::expand_mle_weights(&[]);
        assert_eq!(w.len(), 1);
        assert_eq!(w[0], Flat::from_raw(F::ONE));
    }

    // `Shifted` reads row i + 1, and the last row wraps to row 0.
    #[test]
    fn shifted_get_at_wraps_cyclically() {
        let data: Vec<Flat<F>> = (1..=4u128).map(|i| F::from(i).to_hardware()).collect();
        let v = PolyVariant::<F>::Shifted(&data);

        assert_eq!(v.get_at(0), F::from(2u128).to_hardware());
        assert_eq!(v.get_at(1), F::from(3u128).to_hardware());
        assert_eq!(v.get_at(2), F::from(4u128).to_hardware());
        assert_eq!(v.get_at(3), F::from(1u128).to_hardware());
    }

    // Each row adds ONE per column whose bit is set. Row 2 has both bits set,
    // so it equals ONE + ONE.
    // NOTE(review): in a characteristic-2 field that sum is ZERO — confirm intent.
    #[test]
    fn composite_selector_sums_columns() {
        let a = vec![
            Bit::from(1u32),
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let b = vec![
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let v = PolyVariant::<F>::CompositeSelector(vec![&a, &b]);

        assert_eq!(v.get_at(0), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(1), Flat::from_raw(F::ONE));

        let two = F::ONE + F::ONE;
        assert_eq!(v.get_at(2), Flat::from_raw(two));
        assert_eq!(v.get_at(3), Flat::from_raw(F::ZERO));
    }
}