// hekate_core/poly/variant.rs

// SPDX-License-Identifier: Apache-2.0
// This file is part of the hekate project.
// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

18use crate::errors;
19use crate::tensor::TensorProduct;
20use crate::trace::TraceCompatibleField;
21use alloc::vec;
22use alloc::vec::Vec;
23use core::fmt;
24use hekate_math::{Bit, Block8, Block16, Block32, Block64, Block128, Flat};
25use zeroize::Zeroize;
26
/// Error conditions reported by `PolyVariant` operations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// Computing `1 << num_vars` would overflow `usize`.
    DomainTooLarge { num_vars: usize },

    /// The polynomial's length is not `2^num_vars`.
    DomainSizeMismatch { expected_len: usize, got_len: usize },

    /// The evaluation point carries the
    /// wrong number of coordinates.
    PointDimensionMismatch { expected_len: usize, got_len: usize },

    /// The variant is read-only and cannot
    /// be folded lazily.
    UnsupportedFold { kind: &'static str },
}

impl fmt::Display for Error {
    /// Writes a single-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so match on a copy of `self`
        // and let the format strings capture the bindings by name.
        match *self {
            Self::DomainTooLarge { num_vars } => write!(
                f,
                "Virtual polynomial domain too large: num_vars={num_vars}"
            ),
            Self::DomainSizeMismatch { expected_len, got_len } => write!(
                f,
                "Virtual polynomial domain size mismatch: expected {expected_len}, got {got_len}",
            ),
            Self::PointDimensionMismatch { expected_len, got_len } => write!(
                f,
                "Virtual polynomial point dimension mismatch: expected {expected_len}, got {got_len}",
            ),
            Self::UnsupportedFold { kind } => write!(
                f,
                "Virtual polynomial cannot be folded lazily for kind: {kind}"
            ),
        }
    }
}
76
/// Zero-copy MLE view over a physical trace column.
///
/// Each variant stores just enough state (mostly borrowed slices of
/// `'a` trace data) to produce the hypercube value at index `i` on
/// demand via `get_at`. Every variant must deliver `get_at(i)`
/// without heap allocation, the hot path inside Sumcheck.
///
/// Only `Eq` takes part in the derived `Zeroize`; every other
/// variant is marked `#[zeroize(skip)]`.
#[derive(Clone, Debug, Zeroize)]
pub enum PolyVariant<'a, F>
where
    F: TraceCompatibleField,
{
    /// Fully materialized hypercube: `P(i) = h[i]`.
    #[zeroize(skip)]
    Dense(&'a [Flat<F>]),
    /// Dense column read one row ahead, cyclically:
    /// `P(i) = h[i + 1]`, and the last row reads `h[0]`.
    #[zeroize(skip)]
    Shifted(&'a [Flat<F>]),

    /// `Eq(x, r)` held lazily as a `TensorProduct`:
    /// `O(num_vars)` memory, `O(1)` fold.
    /// Reported length is `2^num_vars`.
    Eq(TensorProduct<F>),

    /// Bit `bit_idx` of `data[i]` on a `B8` column, as
    /// returned by `tower_bit`, lifted to `F::ONE` /
    /// `F::ZERO`. NOTE(review): `data` is stored as
    /// `Flat`, so this is presumably the tower-basis bit
    /// rather than a raw `(data[i] >> bit_idx) & 1`.
    #[zeroize(skip)]
    PackedBitB8 {
        data: &'a [Flat<Block8>],
        bit_idx: usize,
    },

    /// Same as `PackedBitB8`, on a `B16` column.
    #[zeroize(skip)]
    PackedBitB16 {
        data: &'a [Flat<Block16>],
        bit_idx: usize,
    },

    /// Same as `PackedBitB8`, on a `B32` column.
    #[zeroize(skip)]
    PackedBitB32 {
        data: &'a [Flat<Block32>],
        bit_idx: usize,
    },

    /// Same as `PackedBitB8`, on a `B64` column.
    #[zeroize(skip)]
    PackedBitB64 {
        data: &'a [Flat<Block64>],
        bit_idx: usize,
    },

    /// `S(i) = Σ_k cols[k][i]` over boolean
    /// columns (one-hot selector groups), summed
    /// in `F`. Length is that of the first column,
    /// or `0` when `cols` is empty.
    #[zeroize(skip)]
    CompositeSelector(Vec<&'a [Bit]>),

    /// Mask that is `1` everywhere and
    /// `1 - product_of_challenges` at
    /// index `2^N - 1`. Used to kill
    /// cross-row wrap in Sumcheck.
    /// Reported length is `2^num_vars`.
    #[zeroize(skip)]
    TransitionMask {
        num_vars: usize,
        product_of_challenges: F,
    },

    // ==============================================================
    // Indirect:
    // P(x) = data[indices[x]].
    // Zero-copy permutations and arbitrary wiring.
    // Length is `indices.len()`, independent of `data.len()`.
    // ==============================================================
    /// Indirect read from a `Bit` column.
    #[zeroize(skip)]
    IndirectBit {
        data: &'a [Bit],
        indices: &'a [usize],
    },
    /// Indirect read from a `B8` column, promoted into `F`.
    #[zeroize(skip)]
    IndirectB8 {
        data: &'a [Flat<Block8>],
        indices: &'a [usize],
    },
    /// Indirect read from a `B16` column, promoted into `F`.
    #[zeroize(skip)]
    IndirectB16 {
        data: &'a [Flat<Block16>],
        indices: &'a [usize],
    },
    /// Indirect read from a `B32` column, promoted into `F`.
    #[zeroize(skip)]
    IndirectB32 {
        data: &'a [Flat<Block32>],
        indices: &'a [usize],
    },
    /// Indirect read from a `B64` column, promoted into `F`.
    #[zeroize(skip)]
    IndirectB64 {
        data: &'a [Flat<Block64>],
        indices: &'a [usize],
    },
    /// Indirect read from a `B128` column, promoted into `F`.
    #[zeroize(skip)]
    IndirectB128 {
        data: &'a [Flat<Block128>],
        indices: &'a [usize],
    },

    // ====================================================
    // Stride Access:
    // P(i) = data[start + i * step].
    // Length is the explicit `len` field, not derived
    // from `data`.
    // ====================================================
    /// Strided read from a `Bit` column.
    #[zeroize(skip)]
    StrideBit {
        data: &'a [Bit],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided read from a `B8` column, promoted into `F`.
    #[zeroize(skip)]
    StrideB8 {
        data: &'a [Flat<Block8>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided read from a `B16` column, promoted into `F`.
    #[zeroize(skip)]
    StrideB16 {
        data: &'a [Flat<Block16>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided read from a `B32` column, promoted into `F`.
    #[zeroize(skip)]
    StrideB32 {
        data: &'a [Flat<Block32>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided read from a `B64` column, promoted into `F`.
    #[zeroize(skip)]
    StrideB64 {
        data: &'a [Flat<Block64>],
        start: usize,
        step: usize,
        len: usize,
    },
    /// Strided read from a `B128` column, promoted into `F`.
    #[zeroize(skip)]
    StrideB128 {
        data: &'a [Flat<Block128>],
        start: usize,
        step: usize,
        len: usize,
    },

    // ====================================================
    // Cyclic Rotation:
    // P(i) = data[(i + rotation) % len].
    // Uses bitwise masking for modulo (len must be power of 2).
    // Length is `data.len()`.
    // ====================================================
    /// Rotated read from a `Bit` column.
    #[zeroize(skip)]
    RotationBit { data: &'a [Bit], rotation: usize },
    /// Rotated read from a `B8` column, promoted into `F`.
    #[zeroize(skip)]
    RotationB8 {
        data: &'a [Flat<Block8>],
        rotation: usize,
    },
    /// Rotated read from a `B16` column, promoted into `F`.
    #[zeroize(skip)]
    RotationB16 {
        data: &'a [Flat<Block16>],
        rotation: usize,
    },
    /// Rotated read from a `B32` column, promoted into `F`.
    #[zeroize(skip)]
    RotationB32 {
        data: &'a [Flat<Block32>],
        rotation: usize,
    },
    /// Rotated read from a `B64` column, promoted into `F`.
    #[zeroize(skip)]
    RotationB64 {
        data: &'a [Flat<Block64>],
        rotation: usize,
    },
    /// Rotated read from a `B128` column, promoted into `F`.
    #[zeroize(skip)]
    RotationB128 {
        data: &'a [Flat<Block128>],
        rotation: usize,
    },

    // ====================================================
    // Compressed slice views (JIT-promoted to F):
    // P(i) = s[i], lifted into F on every read.
    // ====================================================
    #[zeroize(skip)]
    BitSlice(&'a [Bit]),
    #[zeroize(skip)]
    B8Slice(&'a [Flat<Block8>]),
    #[zeroize(skip)]
    B16Slice(&'a [Flat<Block16>]),
    #[zeroize(skip)]
    B32Slice(&'a [Flat<Block32>]),
    #[zeroize(skip)]
    B64Slice(&'a [Flat<Block64>]),
    #[zeroize(skip)]
    B128Slice(&'a [Flat<Block128>]),

    // ====================================================
    // Same-width views shifted by one row (cyclic):
    // P(i) = s[i + 1], and the last row reads s[0].
    // ====================================================
    #[zeroize(skip)]
    ShiftedBitSlice(&'a [Bit]),
    #[zeroize(skip)]
    ShiftedB8Slice(&'a [Flat<Block8>]),
    #[zeroize(skip)]
    ShiftedB16Slice(&'a [Flat<Block16>]),
    #[zeroize(skip)]
    ShiftedB32Slice(&'a [Flat<Block32>]),
    #[zeroize(skip)]
    ShiftedB64Slice(&'a [Flat<Block64>]),
    #[zeroize(skip)]
    ShiftedB128Slice(&'a [Flat<Block128>]),

    /// `PackedBitB8` read one row ahead, cyclically
    /// (the last row reads bit `bit_idx` of `data[0]`).
    #[zeroize(skip)]
    ShiftedPackedBitB8 {
        data: &'a [Flat<Block8>],
        bit_idx: usize,
    },
    /// `PackedBitB16` read one row ahead, cyclically.
    #[zeroize(skip)]
    ShiftedPackedBitB16 {
        data: &'a [Flat<Block16>],
        bit_idx: usize,
    },
    /// `PackedBitB32` read one row ahead, cyclically.
    #[zeroize(skip)]
    ShiftedPackedBitB32 {
        data: &'a [Flat<Block32>],
        bit_idx: usize,
    },
    /// `PackedBitB64` read one row ahead, cyclically.
    #[zeroize(skip)]
    ShiftedPackedBitB64 {
        data: &'a [Flat<Block64>],
        bit_idx: usize,
    },
}
307
308impl<'a, F> PolyVariant<'a, F>
309where
310    F: TraceCompatibleField,
311{
312    /// Number of hypercube points (polynomial length).
313    pub fn len(&self) -> usize {
314        match self {
315            Self::Dense(h) => h.len(),
316            Self::Shifted(h) => h.len(),
317            Self::Eq(t) => 1 << t.num_vars(),
318            Self::PackedBitB8 { data, .. } => data.len(),
319            Self::PackedBitB16 { data, .. } => data.len(),
320            Self::PackedBitB32 { data, .. } => data.len(),
321            Self::PackedBitB64 { data, .. } => data.len(),
322            Self::TransitionMask { num_vars, .. } => 1 << num_vars,
323            Self::CompositeSelector(cols) => {
324                if cols.is_empty() {
325                    0
326                } else {
327                    cols[0].len()
328                }
329            }
330            Self::IndirectBit { indices, .. } => indices.len(),
331            Self::IndirectB8 { indices, .. } => indices.len(),
332            Self::IndirectB16 { indices, .. } => indices.len(),
333            Self::IndirectB32 { indices, .. } => indices.len(),
334            Self::IndirectB64 { indices, .. } => indices.len(),
335            Self::IndirectB128 { indices, .. } => indices.len(),
336            Self::StrideBit { len, .. } => *len,
337            Self::StrideB8 { len, .. } => *len,
338            Self::StrideB16 { len, .. } => *len,
339            Self::StrideB32 { len, .. } => *len,
340            Self::StrideB64 { len, .. } => *len,
341            Self::StrideB128 { len, .. } => *len,
342            Self::RotationBit { data, .. } => data.len(),
343            Self::RotationB8 { data, .. } => data.len(),
344            Self::RotationB16 { data, .. } => data.len(),
345            Self::RotationB32 { data, .. } => data.len(),
346            Self::RotationB64 { data, .. } => data.len(),
347            Self::RotationB128 { data, .. } => data.len(),
348            Self::BitSlice(h) => h.len(),
349            Self::B8Slice(h) => h.len(),
350            Self::B16Slice(h) => h.len(),
351            Self::B32Slice(h) => h.len(),
352            Self::B64Slice(h) => h.len(),
353            Self::B128Slice(h) => h.len(),
354            Self::ShiftedBitSlice(h) => h.len(),
355            Self::ShiftedB8Slice(h) => h.len(),
356            Self::ShiftedB16Slice(h) => h.len(),
357            Self::ShiftedB32Slice(h) => h.len(),
358            Self::ShiftedB64Slice(h) => h.len(),
359            Self::ShiftedB128Slice(h) => h.len(),
360            Self::ShiftedPackedBitB8 { data, .. } => data.len(),
361            Self::ShiftedPackedBitB16 { data, .. } => data.len(),
362            Self::ShiftedPackedBitB32 { data, .. } => data.len(),
363            Self::ShiftedPackedBitB64 { data, .. } => data.len(),
364        }
365    }
366
367    pub fn is_empty(&self) -> bool {
368        self.len() == 0
369    }
370
371    /// Read the hypercube value at `index`,
372    /// lifted into `Flat<F>`. `O(1)` for
373    /// slice variants, `O(num_vars)` for `Eq`.
374    #[inline(always)]
375    pub fn get_at(&self, index: usize) -> Flat<F> {
376        match self {
377            Self::Dense(h) => h[index],
378            Self::Shifted(h) => {
379                let len = h.len();
380                let next_idx = if index + 1 == len { 0 } else { index + 1 };
381                h[next_idx]
382            }
383            Self::Eq(t) => t.evaluate_at_index(index),
384            Self::PackedBitB8 { data, bit_idx } => {
385                let bit = data[index].tower_bit(*bit_idx);
386                if bit == 1 {
387                    Flat::from_raw(F::ONE)
388                } else {
389                    Flat::from_raw(F::ZERO)
390                }
391            }
392            Self::PackedBitB16 { data, bit_idx } => {
393                let bit = data[index].tower_bit(*bit_idx);
394                if bit == 1 {
395                    Flat::from_raw(F::ONE)
396                } else {
397                    Flat::from_raw(F::ZERO)
398                }
399            }
400            Self::PackedBitB32 { data, bit_idx } => {
401                let bit = data[index].tower_bit(*bit_idx);
402                if bit == 1 {
403                    Flat::from_raw(F::ONE)
404                } else {
405                    Flat::from_raw(F::ZERO)
406                }
407            }
408            Self::PackedBitB64 { data, bit_idx } => {
409                let bit = data[index].tower_bit(*bit_idx);
410                if bit == 1 {
411                    Flat::from_raw(F::ONE)
412                } else {
413                    Flat::from_raw(F::ZERO)
414                }
415            }
416            Self::TransitionMask {
417                num_vars,
418                product_of_challenges,
419            } => {
420                let last_idx: usize = (1 << num_vars) - 1;
421                if index == last_idx {
422                    Flat::from_raw(F::ONE - *product_of_challenges)
423                } else {
424                    Flat::from_raw(F::ONE)
425                }
426            }
427            Self::CompositeSelector(cols) => {
428                let mut sum = Flat::from_raw(F::default());
429                for col in cols {
430                    if col[index].0 == 1 {
431                        sum += Flat::from_raw(F::ONE);
432                    }
433                }
434
435                sum
436            }
437
438            Self::IndirectBit { data, indices } => Flat::from_raw(F::from(data[indices[index]])),
439            Self::IndirectB8 { data, indices } => F::promote_flat(data[indices[index]]),
440            Self::IndirectB16 { data, indices } => F::promote_flat(data[indices[index]]),
441            Self::IndirectB32 { data, indices } => F::promote_flat(data[indices[index]]),
442            Self::IndirectB64 { data, indices } => F::promote_flat(data[indices[index]]),
443            Self::IndirectB128 { data, indices } => F::promote_flat(data[indices[index]]),
444
445            Self::StrideBit {
446                data, start, step, ..
447            } => Flat::from_raw(F::from(data[start + index * step])),
448            Self::StrideB8 {
449                data, start, step, ..
450            } => F::promote_flat(data[start + index * step]),
451            Self::StrideB16 {
452                data, start, step, ..
453            } => F::promote_flat(data[start + index * step]),
454            Self::StrideB32 {
455                data, start, step, ..
456            } => F::promote_flat(data[start + index * step]),
457            Self::StrideB64 {
458                data, start, step, ..
459            } => F::promote_flat(data[start + index * step]),
460            Self::StrideB128 {
461                data, start, step, ..
462            } => F::promote_flat(data[start + index * step]),
463
464            Self::RotationBit { data, rotation } => {
465                Flat::from_raw(F::from(data[(index + rotation) & (data.len() - 1)]))
466            }
467            Self::RotationB8 { data, rotation } => {
468                F::promote_flat(data[(index + rotation) & (data.len() - 1)])
469            }
470            Self::RotationB16 { data, rotation } => {
471                F::promote_flat(data[(index + rotation) & (data.len() - 1)])
472            }
473            Self::RotationB32 { data, rotation } => {
474                F::promote_flat(data[(index + rotation) & (data.len() - 1)])
475            }
476            Self::RotationB64 { data, rotation } => {
477                F::promote_flat(data[(index + rotation) & (data.len() - 1)])
478            }
479            Self::RotationB128 { data, rotation } => {
480                F::promote_flat(data[(index + rotation) & (data.len() - 1)])
481            }
482
483            Self::BitSlice(s) => Flat::from_raw(F::from(s[index])),
484            Self::B8Slice(s) => F::promote_flat(s[index]),
485            Self::B16Slice(s) => F::promote_flat(s[index]),
486            Self::B32Slice(s) => F::promote_flat(s[index]),
487            Self::B64Slice(s) => F::promote_flat(s[index]),
488            Self::B128Slice(s) => F::promote_flat(s[index]),
489
490            Self::ShiftedBitSlice(s) => {
491                let len = s.len();
492                let next_idx = if index + 1 == len { 0 } else { index + 1 };
493
494                Flat::from_raw(F::from(s[next_idx]))
495            }
496            Self::ShiftedB8Slice(s) => {
497                let len = s.len();
498                let next_idx = if index + 1 == len { 0 } else { index + 1 };
499
500                F::promote_flat(s[next_idx])
501            }
502            Self::ShiftedB16Slice(s) => {
503                let len = s.len();
504                let next_idx = if index + 1 == len { 0 } else { index + 1 };
505
506                F::promote_flat(s[next_idx])
507            }
508            Self::ShiftedB32Slice(s) => {
509                let len = s.len();
510                let next_idx = if index + 1 == len { 0 } else { index + 1 };
511
512                F::promote_flat(s[next_idx])
513            }
514            Self::ShiftedB64Slice(s) => {
515                let len = s.len();
516                let next_idx = if index + 1 == len { 0 } else { index + 1 };
517
518                F::promote_flat(s[next_idx])
519            }
520            Self::ShiftedB128Slice(s) => {
521                let len = s.len();
522                let next_idx = if index + 1 == len { 0 } else { index + 1 };
523
524                F::promote_flat(s[next_idx])
525            }
526            Self::ShiftedPackedBitB8 { data, bit_idx } => {
527                let len = data.len();
528                let next_idx = if index + 1 == len { 0 } else { index + 1 };
529                let bit = data[next_idx].tower_bit(*bit_idx);
530
531                if bit == 1 {
532                    Flat::from_raw(F::ONE)
533                } else {
534                    Flat::from_raw(F::ZERO)
535                }
536            }
537            Self::ShiftedPackedBitB16 { data, bit_idx } => {
538                let len = data.len();
539                let next_idx = if index + 1 == len { 0 } else { index + 1 };
540                let bit = data[next_idx].tower_bit(*bit_idx);
541
542                if bit == 1 {
543                    Flat::from_raw(F::ONE)
544                } else {
545                    Flat::from_raw(F::ZERO)
546                }
547            }
548            Self::ShiftedPackedBitB32 { data, bit_idx } => {
549                let len = data.len();
550                let next_idx = if index + 1 == len { 0 } else { index + 1 };
551                let bit = data[next_idx].tower_bit(*bit_idx);
552
553                if bit == 1 {
554                    Flat::from_raw(F::ONE)
555                } else {
556                    Flat::from_raw(F::ZERO)
557                }
558            }
559            Self::ShiftedPackedBitB64 { data, bit_idx } => {
560                let len = data.len();
561                let next_idx = if index + 1 == len { 0 } else { index + 1 };
562                let bit = data[next_idx].tower_bit(*bit_idx);
563
564                if bit == 1 {
565                    Flat::from_raw(F::ONE)
566                } else {
567                    Flat::from_raw(F::ZERO)
568                }
569            }
570        }
571    }
572
573    /// Evaluate the MLE at an arbitrary point.
574    /// `Eq` delegates to `TensorProduct`;
575    /// every other variant expands MLE weights
576    /// once and does a single `get_at` sweep.
577    #[inline(always)]
578    pub fn evaluate(&self, point: &[Flat<F>]) -> errors::Result<Flat<F>> {
579        match self {
580            Self::Eq(t) => Ok(t.evaluate_extension(point)?),
581            _ => {
582                let num_vars = point.len();
583                let got_len = self.len();
584
585                let Some(expected_len) = 1usize.checked_shl(num_vars as u32) else {
586                    return Err(Error::DomainTooLarge { num_vars }.into());
587                };
588
589                if got_len != expected_len {
590                    return Err(Error::DomainSizeMismatch {
591                        expected_len,
592                        got_len,
593                    }
594                    .into());
595                }
596
597                if num_vars == 0 {
598                    return Ok(self.get_at(0));
599                }
600
601                let weights = Self::expand_mle_weights(point);
602                let mut total = Flat::from_raw(F::ZERO);
603
604                for (i, w) in weights.iter().enumerate() {
605                    total += self.get_at(i) * *w;
606                }
607
608                Ok(total)
609            }
610        }
611    }
612
613    pub fn expand_mle_weights(r: &[Flat<F>]) -> Vec<Flat<F>> {
614        let num_vars = r.len();
615        let size = 1 << num_vars;
616
617        let mut weights = vec![Flat::from_raw(F::ZERO); size];
618        weights[0] = Flat::from_raw(F::ONE);
619
620        for (i, &rk) in r.iter().enumerate() {
621            let one_minus_rk = Flat::from_raw(F::ONE) - rk;
622            let current_len = 1 << i;
623
624            for i in 0..current_len {
625                let w = weights[i];
626                weights[i] = w * one_minus_rk;
627                weights[current_len + i] = w * rk;
628            }
629        }
630
631        weights
632    }
633}
634
#[cfg(test)]
mod tests {
    // Unit tests for `PolyVariant`, instantiated over `Block128`.
    use super::*;
    use hekate_math::{Bit, Block8, Block128, FlatPromote, HardwareField, TowerField};

    type F = Block128;

    // `Dense` returns the stored `Flat<F>` values unchanged.
    #[test]
    fn get_at_dense_returns_correct_values() {
        let data: Vec<Flat<F>> = (0..4u128).map(|i| F::from(i * 10).to_hardware()).collect();
        let v = PolyVariant::<F>::Dense(&data);

        assert_eq!(v.get_at(0), F::from(0u128).to_hardware());
        assert_eq!(v.get_at(1), F::from(10u128).to_hardware());
        assert_eq!(v.get_at(2), F::from(20u128).to_hardware());
        assert_eq!(v.get_at(3), F::from(30u128).to_hardware());
    }

    // `BitSlice` lifts each bit to `F::ZERO` / `F::ONE`.
    #[test]
    fn get_at_bit_slice_promotes_to_field() {
        let bits = vec![
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let v = PolyVariant::<F>::BitSlice(&bits);

        assert_eq!(v.get_at(0), Flat::from_raw(F::ZERO));
        assert_eq!(v.get_at(1), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(2), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(3), Flat::from_raw(F::ZERO));
    }

    // `B8Slice` must agree with `promote_flat` on the same inputs.
    #[test]
    fn get_at_b8_slice_promotes() {
        let data = vec![
            Block8::from(0u8).to_hardware(),
            Block8::from(0xFFu8).to_hardware(),
        ];
        let v = PolyVariant::<F>::B8Slice(&data);

        assert_eq!(
            v.get_at(0),
            F::promote_flat(Block8::from(0u8).to_hardware())
        );
        assert_eq!(
            v.get_at(1),
            F::promote_flat(Block8::from(0xFFu8).to_hardware())
        );
    }

    // `len` follows the backing slice for `Dense`/`BitSlice` and is
    // `2^num_vars` for `Eq` (5 variables -> 32 points).
    #[test]
    fn len_matches_data_size() {
        let data = vec![Flat::from_raw(F::ZERO); 16];
        assert_eq!(PolyVariant::<F>::Dense(&data).len(), 16);

        let bits = vec![Bit::from(0u32); 8];
        assert_eq!(PolyVariant::<F>::BitSlice(&bits).len(), 8);

        let eq = TensorProduct::new(vec![Flat::from_raw(F::ONE); 5]);
        assert_eq!(PolyVariant::<F>::Eq(eq).len(), 32);

        let empty: Vec<Flat<F>> = vec![];
        assert!(PolyVariant::<F>::Dense(&empty).is_empty());
    }

    // The MLE weights sum to one, so a constant table evaluates to
    // that constant at any point.
    #[test]
    fn evaluate_constant_polynomial() {
        let num_vars = 3;
        let data = vec![F::from(42u128).to_hardware(); 1 << num_vars];
        let v = PolyVariant::<F>::Dense(&data);

        let point: Vec<Flat<F>> = vec![
            Flat::from_raw(F::from(1u128).to_hardware().into_raw()),
            Flat::from_raw(F::from(2u128).to_hardware().into_raw()),
            Flat::from_raw(F::from(3u128).to_hardware().into_raw()),
        ];

        let val = v.evaluate(&point).unwrap();
        assert_eq!(val.into_raw(), F::from(42u128).to_hardware().into_raw());
    }

    // One variable over [0, 10]: the result is `r * 10`, since the
    // `(1 - r) * 0` term vanishes. NOTE(review): the expected `20`
    // assumes the product of the hardware images of 2 and 10 equals
    // the hardware image of 20 in this field. Confirm that this holds
    // by construction and not by coincidence of small operands.
    #[test]
    fn evaluate_linear_polynomial() {
        let data = vec![F::ZERO.to_hardware(), F::from(10u128).to_hardware()];
        let v = PolyVariant::<F>::Dense(&data);

        let point = vec![Flat::from_raw(F::from(2u128).to_hardware().into_raw())];
        let val = v.evaluate(&point).unwrap();
        assert_eq!(val.into_raw(), F::from(20u128).to_hardware().into_raw());
    }

    // Zero variables: the empty point selects the single stored row.
    #[test]
    fn evaluate_single_row() {
        let data = vec![F::from(99u128).to_hardware()];
        let v = PolyVariant::<F>::Dense(&data);

        let val = v.evaluate(&[]).unwrap();
        assert_eq!(val.into_raw(), F::from(99u128).to_hardware().into_raw());
    }

    // 4 rows with a 3-coordinate point (2^3 = 8) must be rejected.
    #[test]
    fn evaluate_domain_mismatch_rejected() {
        let data = vec![F::ZERO.to_hardware(); 4];
        let v = PolyVariant::<F>::Dense(&data);

        let point = vec![Flat::from_raw(F::ONE); 3];
        assert!(v.evaluate(&point).is_err());
    }

    // `Eq(x, r)` evaluated at `x = r` with boolean `r` is one.
    #[test]
    fn evaluate_eq_polynomial() {
        let r = vec![Flat::from_raw(F::ONE), Flat::from_raw(F::ZERO)];
        let eq = PolyVariant::<F>::Eq(TensorProduct::new(r.clone()));

        let val = eq.evaluate(&r).unwrap();
        assert_eq!(val.into_raw(), F::ONE.to_hardware().into_raw());
    }

    // One variable: weights are `[1 - r0, r0]` (index bit 0 <-> r[0]).
    #[test]
    fn expand_mle_weights_single_var() {
        let r = vec![Flat::from_raw(F::from(7u128).to_hardware().into_raw())];
        let w = PolyVariant::<F>::expand_mle_weights(&r);

        assert_eq!(w.len(), 2);
        let r0 = r[0];
        let one = Flat::from_raw(F::ONE);
        assert_eq!(w[0], one - r0);
        assert_eq!(w[1], r0);
    }

    // Zero variables: the only weight is `ONE`.
    #[test]
    fn expand_mle_weights_zero_vars() {
        let w = PolyVariant::<F>::expand_mle_weights(&[]);
        assert_eq!(w.len(), 1);
        assert_eq!(w[0], Flat::from_raw(F::ONE));
    }

    // `Shifted` reads one row ahead, and the last row wraps to row 0.
    #[test]
    fn shifted_get_at_wraps_cyclically() {
        let data: Vec<Flat<F>> = (1..=4u128).map(|i| F::from(i).to_hardware()).collect();
        let v = PolyVariant::<F>::Shifted(&data);

        assert_eq!(v.get_at(0), F::from(2u128).to_hardware());
        assert_eq!(v.get_at(1), F::from(3u128).to_hardware());
        assert_eq!(v.get_at(2), F::from(4u128).to_hardware());
        assert_eq!(v.get_at(3), F::from(1u128).to_hardware());
    }

    // `CompositeSelector` adds the per-column bits in `F`. Row 2 has
    // both bits set, so the expected value is computed as
    // `F::ONE + F::ONE` instead of being assumed. NOTE(review): if `F`
    // has characteristic 2 this is `ZERO`, so overlapping selectors
    // cancel. Confirm that callers guarantee one-hot (disjoint) groups.
    #[test]
    fn composite_selector_sums_columns() {
        let a = vec![
            Bit::from(1u32),
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let b = vec![
            Bit::from(0u32),
            Bit::from(1u32),
            Bit::from(1u32),
            Bit::from(0u32),
        ];
        let v = PolyVariant::<F>::CompositeSelector(vec![&a, &b]);

        assert_eq!(v.get_at(0), Flat::from_raw(F::ONE));
        assert_eq!(v.get_at(1), Flat::from_raw(F::ONE));

        let two = F::ONE + F::ONE;
        assert_eq!(v.get_at(2), Flat::from_raw(two));
        assert_eq!(v.get_at(3), Flat::from_raw(F::ZERO));
    }
}