Skip to main content

nc_polynomial/
params.rs

1// MIT License
2//
3// Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22use core::fmt;
23use std::sync::Arc;
24
25use crate::polynomial::{Polynomial, PolynomialError};
26
27/// Fixed ring settings used by cryptographic constructions over `R_q = Z_q[x] / (f(x))`.
28///
29/// Construction validates:
30/// - polynomial metadata and modulus,
31/// - modulus polynomial properties,
32/// - NTT suitability for multiplying degree-bounded polynomials,
33/// - primitive root compatibility with the chosen NTT length.
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub struct Params {
36    max_degree: usize,
37    modulus: u64,
38    modulus_poly: Polynomial,
39    primitive_root: u64,
40    ntt_length: usize,
41}
42
43/// Errors returned while validating [`Params`].
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum ParamsError {
46    /// Underlying polynomial validation failed.
47    InvalidPolynomial(PolynomialError),
48    /// Modulus polynomial cannot be zero.
49    ZeroModulusPolynomial,
50    /// Modulus polynomial must be non-constant.
51    ConstantModulusPolynomial,
52    /// `degree(f)` must match `max_degree` for a consistent fixed ring shape.
53    ModulusPolynomialDegreeMismatch { degree: usize, max_degree: usize },
54    /// Leading coefficient of `f(x)` must be invertible in `Z_q`.
55    NonInvertibleModulusPolynomialLead { coefficient: u64, modulus: u64 },
56    /// Derived transform size overflowed for the selected degree bound.
57    NttLengthOverflow { max_degree: usize },
58    /// Derived NTT length is unsupported by the modulus (`length` must divide `q - 1`).
59    NttLengthUnsupported { length: usize, modulus: u64 },
60    /// Primitive root does not induce a valid root of unity for the derived NTT length.
61    InvalidPrimitiveRoot {
62        primitive_root: u64,
63        length: usize,
64        modulus: u64,
65    },
66}
67
68impl fmt::Display for ParamsError {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        match self {
71            Self::InvalidPolynomial(err) => write!(f, "invalid polynomial parameters: {err}"),
72            Self::ZeroModulusPolynomial => write!(f, "modulus polynomial cannot be zero"),
73            Self::ConstantModulusPolynomial => {
74                write!(f, "modulus polynomial must have degree at least 1")
75            }
76            Self::ModulusPolynomialDegreeMismatch { degree, max_degree } => write!(
77                f,
78                "modulus polynomial degree {degree} must match max_degree {max_degree}"
79            ),
80            Self::NonInvertibleModulusPolynomialLead {
81                coefficient,
82                modulus,
83            } => write!(
84                f,
85                "modulus polynomial leading coefficient {coefficient} is not invertible modulo {modulus}"
86            ),
87            Self::NttLengthOverflow { max_degree } => write!(
88                f,
89                "failed to derive NTT length for max_degree {max_degree} due to overflow"
90            ),
91            Self::NttLengthUnsupported { length, modulus } => write!(
92                f,
93                "derived NTT length {length} is unsupported by modulus {modulus} (length must divide q-1)"
94            ),
95            Self::InvalidPrimitiveRoot {
96                primitive_root,
97                length,
98                modulus,
99            } => write!(
100                f,
101                "primitive root {primitive_root} is invalid for NTT length {length} under modulus {modulus}"
102            ),
103        }
104    }
105}
106
107impl std::error::Error for ParamsError {}
108
109impl Params {
110    /// Builds validated ring parameters.
111    ///
112    /// `modulus_poly_coeffs` must define a non-constant polynomial whose degree equals
113    /// `max_degree`, with invertible leading coefficient in `Z_q`.
114    pub fn new(
115        max_degree: usize,
116        modulus: u64,
117        modulus_poly_coeffs: &[u64],
118        primitive_root: u64,
119    ) -> Result<Self, ParamsError> {
120        let modulus_poly = Polynomial::new(max_degree, modulus, modulus_poly_coeffs)
121            .map_err(ParamsError::InvalidPolynomial)?;
122
123        let Some(modulus_degree) = modulus_poly.degree() else {
124            return Err(ParamsError::ZeroModulusPolynomial);
125        };
126
127        if modulus_degree == 0 {
128            return Err(ParamsError::ConstantModulusPolynomial);
129        }
130
131        if modulus_degree != max_degree {
132            return Err(ParamsError::ModulusPolynomialDegreeMismatch {
133                degree: modulus_degree,
134                max_degree,
135            });
136        }
137
138        let leading = modulus_poly
139            .coeff(modulus_degree)
140            .expect("degree lookup must have a leading coefficient");
141        if mod_inverse(leading, modulus).is_none() {
142            return Err(ParamsError::NonInvertibleModulusPolynomialLead {
143                coefficient: leading,
144                modulus,
145            });
146        }
147
148        let out_len = max_degree
149            .checked_mul(2)
150            .and_then(|value| value.checked_add(1))
151            .ok_or(ParamsError::NttLengthOverflow { max_degree })?;
152        let ntt_length = out_len
153            .checked_next_power_of_two()
154            .ok_or(ParamsError::NttLengthOverflow { max_degree })?;
155        let ntt_length_u64 =
156            u64::try_from(ntt_length).map_err(|_| ParamsError::NttLengthOverflow { max_degree })?;
157
158        if (modulus - 1) % ntt_length_u64 != 0 {
159            return Err(ParamsError::NttLengthUnsupported {
160                length: ntt_length,
161                modulus,
162            });
163        }
164
165        let root = mod_pow(primitive_root, (modulus - 1) / ntt_length_u64, modulus);
166        let is_valid_root = mod_pow(root, ntt_length_u64, modulus) == 1
167            && (ntt_length == 1 || mod_pow(root, (ntt_length / 2) as u64, modulus) != 1);
168        if !is_valid_root {
169            return Err(ParamsError::InvalidPrimitiveRoot {
170                primitive_root,
171                length: ntt_length,
172                modulus,
173            });
174        }
175
176        Ok(Self {
177            max_degree,
178            modulus,
179            modulus_poly,
180            primitive_root,
181            ntt_length,
182        })
183    }
184
185    /// Maximum supported polynomial degree.
186    pub fn max_degree(&self) -> usize {
187        self.max_degree
188    }
189
190    /// Coefficient modulus `q`.
191    pub fn modulus(&self) -> u64 {
192        self.modulus
193    }
194
195    /// Ring modulus polynomial `f(x)`.
196    pub fn modulus_poly(&self) -> &Polynomial {
197        &self.modulus_poly
198    }
199
200    /// Primitive root used to derive NTT roots of unity.
201    pub fn primitive_root(&self) -> u64 {
202        self.primitive_root
203    }
204
205    /// NTT length derived from `max_degree` for exact product convolutions.
206    pub fn ntt_length(&self) -> usize {
207        self.ntt_length
208    }
209}
210
211/// Runtime ring context that carries validated [`Params`] for operations in
212/// `R_q = Z_q[x] / (f(x))`.
213#[derive(Debug, Clone, PartialEq, Eq)]
214pub struct RingContext {
215    params: Arc<Params>,
216}
217
218impl RingContext {
219    /// Creates a context from already-validated [`Params`].
220    pub fn new(params: Params) -> Self {
221        Self {
222            params: Arc::new(params),
223        }
224    }
225
226    /// Builds and validates parameters, then returns a context.
227    pub fn from_parts(
228        max_degree: usize,
229        modulus: u64,
230        modulus_poly_coeffs: &[u64],
231        primitive_root: u64,
232    ) -> Result<Self, ParamsError> {
233        Ok(Self::new(Params::new(
234            max_degree,
235            modulus,
236            modulus_poly_coeffs,
237            primitive_root,
238        )?))
239    }
240
241    /// Returns validated parameters.
242    pub fn params(&self) -> &Params {
243        self.params.as_ref()
244    }
245
246    /// Maximum supported polynomial degree.
247    pub fn max_degree(&self) -> usize {
248        self.params.max_degree()
249    }
250
251    /// Coefficient modulus `q`.
252    pub fn modulus(&self) -> u64 {
253        self.params.modulus()
254    }
255
256    /// Ring modulus polynomial `f(x)`.
257    pub fn modulus_poly(&self) -> &Polynomial {
258        self.params.modulus_poly()
259    }
260
261    /// Primitive root used by NTT-based operations.
262    pub fn primitive_root(&self) -> u64 {
263        self.params.primitive_root()
264    }
265
266    /// Derived NTT transform length.
267    pub fn ntt_length(&self) -> usize {
268        self.params.ntt_length()
269    }
270
271    /// Builds a polynomial compatible with this context.
272    pub fn polynomial(&self, coeffs: &[u64]) -> Result<Polynomial, PolynomialError> {
273        Polynomial::new(self.max_degree(), self.modulus(), coeffs)
274    }
275
276    /// Returns additive identity in this ring.
277    pub fn zero(&self) -> Polynomial {
278        Polynomial::zero(self.max_degree(), self.modulus())
279            .expect("validated params guarantee a valid modulus")
280    }
281
282    /// Returns multiplicative identity in this ring.
283    pub fn one(&self) -> Polynomial {
284        Polynomial::one(self.max_degree(), self.modulus())
285            .expect("validated params guarantee a valid modulus")
286    }
287
288    /// Builds a canonical ring element from coefficients by reducing modulo `f(x)`.
289    pub fn element(&self, coeffs: &[u64]) -> Result<RingElem, PolynomialError> {
290        let poly = self.polynomial(coeffs)?;
291        self.element_from_polynomial(poly)
292    }
293
294    /// Converts a compatible polynomial into a canonical ring element.
295    pub fn element_from_polynomial(&self, poly: Polynomial) -> Result<RingElem, PolynomialError> {
296        self.ensure_member(&poly)?;
297        RingElem::new(Arc::clone(&self.params), poly)
298    }
299
300    /// Returns additive identity as a canonical ring element.
301    pub fn zero_element(&self) -> RingElem {
302        self.element_from_polynomial(self.zero())
303            .expect("validated params guarantee ring element creation")
304    }
305
306    /// Returns multiplicative identity as a canonical ring element.
307    pub fn one_element(&self) -> RingElem {
308        self.element_from_polynomial(self.one())
309            .expect("validated params guarantee ring element creation")
310    }
311
312    /// Reduces polynomial modulo the configured ring polynomial.
313    pub fn reduce(&self, poly: &Polynomial) -> Result<Polynomial, PolynomialError> {
314        self.ensure_member(poly)?;
315        poly.rem_mod_poly(self.modulus_poly())
316    }
317
318    /// Adds two ring elements and reduces modulo `f(x)`.
319    pub fn add(&self, lhs: &Polynomial, rhs: &Polynomial) -> Result<Polynomial, PolynomialError> {
320        self.ensure_member(lhs)?;
321        self.ensure_member(rhs)?;
322        lhs.add_mod_poly(rhs, self.modulus_poly())
323    }
324
325    /// Subtracts two ring elements and reduces modulo `f(x)`.
326    pub fn sub(&self, lhs: &Polynomial, rhs: &Polynomial) -> Result<Polynomial, PolynomialError> {
327        self.ensure_member(lhs)?;
328        self.ensure_member(rhs)?;
329        lhs.sub_mod_poly(rhs, self.modulus_poly())
330    }
331
332    /// Multiplies two ring elements and reduces modulo `f(x)`.
333    pub fn mul(&self, lhs: &Polynomial, rhs: &Polynomial) -> Result<Polynomial, PolynomialError> {
334        self.ensure_member(lhs)?;
335        self.ensure_member(rhs)?;
336        lhs.mul_mod_poly(rhs, self.modulus_poly())
337    }
338
339    fn ensure_member(&self, poly: &Polynomial) -> Result<(), PolynomialError> {
340        if poly.max_degree() != self.max_degree() || poly.modulus() != self.modulus() {
341            return Err(PolynomialError::IncompatiblePolynomials);
342        }
343        Ok(())
344    }
345}
346
347/// Canonical element of `R_q = Z_q[x] / (f(x))`.
348///
349/// Each operation applies the corresponding polynomial arithmetic and then reduces modulo `f(x)`.
350#[derive(Debug, Clone, PartialEq, Eq)]
351pub struct RingElem {
352    poly: Polynomial,
353    params: Arc<Params>,
354}
355
356impl RingElem {
357    fn new(params: Arc<Params>, poly: Polynomial) -> Result<Self, PolynomialError> {
358        if poly.max_degree() != params.max_degree() || poly.modulus() != params.modulus() {
359            return Err(PolynomialError::IncompatiblePolynomials);
360        }
361
362        let reduced = poly.rem_mod_poly(params.modulus_poly())?;
363        Ok(Self {
364            poly: reduced,
365            params,
366        })
367    }
368
369    /// Returns the validated parameter set that defines this element's ring.
370    pub fn params(&self) -> &Params {
371        self.params.as_ref()
372    }
373
374    /// Returns the underlying canonical polynomial representative.
375    pub fn polynomial(&self) -> &Polynomial {
376        &self.poly
377    }
378
379    /// Consumes and returns the underlying polynomial representative.
380    pub fn into_polynomial(self) -> Polynomial {
381        self.poly
382    }
383
384    /// Returns the dense coefficient slice of the canonical representative.
385    pub fn coefficients(&self) -> &[u64] {
386        self.poly.coefficients()
387    }
388
389    /// Returns coefficients trimmed to the actual degree.
390    pub fn trimmed_coefficients(&self) -> Vec<u64> {
391        self.poly.trimmed_coefficients()
392    }
393
394    /// Returns the representative's polynomial degree.
395    pub fn degree(&self) -> Option<usize> {
396        self.poly.degree()
397    }
398
399    /// Returns whether the element is additive identity.
400    pub fn is_zero(&self) -> bool {
401        self.poly.is_zero()
402    }
403
404    /// Adds two ring elements.
405    pub fn add(&self, rhs: &Self) -> Result<Self, PolynomialError> {
406        self.ensure_compatible(rhs)?;
407        let poly = self
408            .poly
409            .add_mod_poly(&rhs.poly, self.params.modulus_poly())?;
410        Ok(Self {
411            poly,
412            params: Arc::clone(&self.params),
413        })
414    }
415
416    /// Subtracts two ring elements.
417    pub fn sub(&self, rhs: &Self) -> Result<Self, PolynomialError> {
418        self.ensure_compatible(rhs)?;
419        let poly = self
420            .poly
421            .sub_mod_poly(&rhs.poly, self.params.modulus_poly())?;
422        Ok(Self {
423            poly,
424            params: Arc::clone(&self.params),
425        })
426    }
427
428    /// Multiplies two ring elements.
429    pub fn mul(&self, rhs: &Self) -> Result<Self, PolynomialError> {
430        self.ensure_compatible(rhs)?;
431        let poly = self
432            .poly
433            .mul_mod_poly(&rhs.poly, self.params.modulus_poly())?;
434        Ok(Self {
435            poly,
436            params: Arc::clone(&self.params),
437        })
438    }
439
440    /// Returns additive inverse.
441    pub fn neg(&self) -> Self {
442        let poly = self
443            .poly
444            .neg()
445            .rem_mod_poly(self.params.modulus_poly())
446            .expect("validated params guarantee successful reduction");
447        Self {
448            poly,
449            params: Arc::clone(&self.params),
450        }
451    }
452
453    /// Multiplies by a scalar in `Z_q`.
454    pub fn scalar_mul(&self, scalar: u64) -> Self {
455        let poly = self
456            .poly
457            .scalar_mul(scalar)
458            .rem_mod_poly(self.params.modulus_poly())
459            .expect("validated params guarantee successful reduction");
460        Self {
461            poly,
462            params: Arc::clone(&self.params),
463        }
464    }
465
466    fn ensure_compatible(&self, rhs: &Self) -> Result<(), PolynomialError> {
467        if Arc::ptr_eq(&self.params, &rhs.params) {
468            return Ok(());
469        }
470
471        if self.params.as_ref() != rhs.params.as_ref() {
472            return Err(PolynomialError::IncompatiblePolynomials);
473        }
474
475        Ok(())
476    }
477}
478
/// Computes `(a * b) mod modulus`, widening to `u128` so the product cannot overflow.
///
/// Panics if `modulus` is zero (remainder by zero).
fn mod_mul(a: u64, b: u64, modulus: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    let reduced = wide_product % u128::from(modulus);
    // `reduced < modulus <= u64::MAX`, so narrowing back is lossless.
    reduced as u64
}
482
483fn mod_pow(mut base: u64, mut exponent: u64, modulus: u64) -> u64 {
484    let mut acc = 1_u64 % modulus;
485    base %= modulus;
486
487    while exponent > 0 {
488        if exponent & 1 == 1 {
489            acc = mod_mul(acc, base, modulus);
490        }
491        base = mod_mul(base, base, modulus);
492        exponent >>= 1;
493    }
494
495    acc
496}
497
/// Returns the multiplicative inverse of `a` modulo `modulus`, or `None` when
/// `gcd(a, modulus) != 1`.
///
/// Uses the extended Euclidean algorithm over `i128`, which holds every intermediate value
/// for `u64` inputs. Panics if `modulus` is zero.
fn mod_inverse(a: u64, modulus: u64) -> Option<u64> {
    let m = i128::from(modulus);
    // Invariant: coef_x * a ≡ rem_x (mod modulus) for both tracked pairs.
    let (mut rem_prev, mut rem_cur) = (m, i128::from(a % modulus));
    let (mut coef_prev, mut coef_cur) = (0_i128, 1_i128);

    while rem_cur != 0 {
        let q = rem_prev / rem_cur;
        (coef_prev, coef_cur) = (coef_cur, coef_prev - q * coef_cur);
        (rem_prev, rem_cur) = (rem_cur, rem_prev - q * rem_cur);
    }

    // `rem_prev` now holds gcd(a, modulus); an inverse exists only when it is 1.
    if rem_prev != 1 {
        return None;
    }

    let normalized = if coef_prev < 0 { coef_prev + m } else { coef_prev };
    Some(normalized as u64)
}
519
#[cfg(test)]
mod tests {
    use super::{mod_inverse, mod_mul, mod_pow, Params, ParamsError, RingContext};
    use crate::polynomial::PolynomialError;

    /// Builds every ring element whose first `used_len` coefficients are drawn from `values`
    /// (higher coefficients are zero), enumerating in little-endian base-`values.len()` order.
    fn enumerate_ring_elements(
        ctx: &RingContext,
        values: &[u64],
        used_len: usize,
    ) -> Vec<super::RingElem> {
        assert!(
            used_len <= ctx.max_degree() + 1,
            "used_len exceeds ring capacity"
        );
        assert!(!values.is_empty(), "values cannot be empty");

        let total = values.len().pow(used_len as u32);
        let mut out = Vec::with_capacity(total);
        for mut state in 0..total {
            let mut coeffs = vec![0_u64; used_len];
            for slot in &mut coeffs {
                let digit = state % values.len();
                *slot = values[digit];
                state /= values.len();
            }
            out.push(
                ctx.element(&coeffs)
                    .expect("enumerated coefficients should build"),
            );
        }
        out
    }

    #[test]
    fn modular_helpers_compute_expected_values() {
        assert_eq!(mod_mul(16, 16, 17), 1);
        assert_eq!(mod_mul(u64::MAX, u64::MAX, u64::MAX - 1), 1);

        assert_eq!(mod_pow(3, 16, 17), 1);
        assert_eq!(mod_pow(3, 8, 17), 16);
        assert_eq!(mod_pow(5, 0, 1), 0);

        assert_eq!(mod_inverse(3, 17), Some(6));
        assert_eq!(mod_inverse(20, 17), Some(6));
        assert_eq!(mod_inverse(0, 17), None);
        assert_eq!(mod_inverse(2, 4), None);
        assert_eq!(mod_inverse(u64::MAX - 1, u64::MAX), Some(u64::MAX - 1));
    }

    #[test]
    fn params_build_and_expose_validated_fields() {
        let params = Params::new(4, 17, &[1, 0, 0, 0, 1], 3).expect("params should validate");

        assert_eq!(params.max_degree(), 4);
        assert_eq!(params.modulus(), 17);
        assert_eq!(
            params.modulus_poly().trimmed_coefficients(),
            vec![1, 0, 0, 0, 1]
        );
        assert_eq!(params.primitive_root(), 3);
        assert_eq!(params.ntt_length(), 16);
    }

    #[test]
    fn params_reject_zero_modulus_polynomial() {
        let err = Params::new(4, 17, &[], 3).expect_err("expected error");
        assert_eq!(err, ParamsError::ZeroModulusPolynomial);
    }

    #[test]
    fn params_reject_constant_modulus_polynomial() {
        let err = Params::new(4, 17, &[3], 3).expect_err("expected error");
        assert_eq!(err, ParamsError::ConstantModulusPolynomial);
    }

    #[test]
    fn params_reject_degree_mismatch() {
        let err = Params::new(4, 17, &[1, 0, 1], 3).expect_err("expected error");
        assert_eq!(
            err,
            ParamsError::ModulusPolynomialDegreeMismatch {
                degree: 2,
                max_degree: 4,
            }
        );
    }

    #[test]
    fn params_reject_unsupported_ntt_length() {
        let err = Params::new(4, 19, &[1, 0, 0, 0, 1], 2).expect_err("expected error");
        assert_eq!(
            err,
            ParamsError::NttLengthUnsupported {
                length: 16,
                modulus: 19,
            }
        );
    }

    #[test]
    fn params_reject_invalid_primitive_root() {
        let err = Params::new(4, 17, &[1, 0, 0, 0, 1], 1).expect_err("expected error");
        assert_eq!(
            err,
            ParamsError::InvalidPrimitiveRoot {
                primitive_root: 1,
                length: 16,
                modulus: 17,
            }
        );
    }

    #[test]
    fn params_reject_roots_without_exact_order() {
        // 2 is a 16th root of unity mod 17 but already has order 8 (2^8 = 256 ≡ 1).
        let err = Params::new(4, 17, &[1, 0, 0, 0, 1], 2).expect_err("expected error");
        assert_eq!(
            err,
            ParamsError::InvalidPrimitiveRoot {
                primitive_root: 2,
                length: 16,
                modulus: 17,
            }
        );

        // 0 never yields a root of unity.
        let err = Params::new(4, 17, &[1, 0, 0, 0, 1], 0).expect_err("expected error");
        assert_eq!(
            err,
            ParamsError::InvalidPrimitiveRoot {
                primitive_root: 0,
                length: 16,
                modulus: 17,
            }
        );
    }

    #[test]
    fn params_accept_root_congruent_to_generator() {
        // 20 ≡ 3 (mod 17) is accepted and stored exactly as supplied.
        let params = Params::new(4, 17, &[1, 0, 0, 0, 1], 20).expect("params should validate");
        assert_eq!(params.primitive_root(), 20);
    }

    #[test]
    fn params_error_display_messages_are_descriptive() {
        assert_eq!(
            ParamsError::ZeroModulusPolynomial.to_string(),
            "modulus polynomial cannot be zero"
        );
        assert_eq!(
            ParamsError::NttLengthUnsupported {
                length: 16,
                modulus: 19,
            }
            .to_string(),
            "derived NTT length 16 is unsupported by modulus 19 (length must divide q-1)"
        );
    }

    #[test]
    fn ring_context_applies_ring_reduction_ops() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let a = ctx
            .polynomial(&[16, 2, 0, 1, 0])
            .expect("poly should build");
        let b = ctx.polynomial(&[3, 0, 1, 0, 0]).expect("poly should build");

        let sum = ctx.add(&a, &b).expect("add should succeed");
        let diff = ctx.sub(&a, &b).expect("sub should succeed");
        let prod = ctx.mul(&a, &b).expect("mul should succeed");

        assert_eq!(sum.trimmed_coefficients(), vec![2, 2, 1, 1]);
        assert_eq!(diff.trimmed_coefficients(), vec![13, 2, 16, 1]);
        assert_eq!(prod.trimmed_coefficients(), vec![14, 5, 16, 5]);
    }

    #[test]
    fn ring_context_rejects_incompatible_members() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let foreign = RingContext::from_parts(3, 17, &[1, 0, 0, 1], 3)
            .expect("context should build")
            .one();

        let err = ctx
            .reduce(&foreign)
            .expect_err("expected incompatibility error");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn ring_elem_creation_canonicalizes_representation() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");

        let elem = ctx
            .element(&[0, 0, 0, 0, 16])
            .expect("ring element should build");

        assert_eq!(elem.trimmed_coefficients(), vec![1]);
    }

    #[test]
    fn ring_elem_add_sub_mul_without_explicit_modulus_poly() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let a = ctx
            .element(&[16, 2, 0, 1, 0])
            .expect("ring element should build");
        let b = ctx
            .element(&[3, 0, 1, 0, 0])
            .expect("ring element should build");

        let sum = a.add(&b).expect("add should succeed");
        let diff = a.sub(&b).expect("sub should succeed");
        let prod = a.mul(&b).expect("mul should succeed");

        assert_eq!(sum.trimmed_coefficients(), vec![2, 2, 1, 1]);
        assert_eq!(diff.trimmed_coefficients(), vec![13, 2, 16, 1]);
        assert_eq!(prod.trimmed_coefficients(), vec![14, 5, 16, 5]);
    }

    #[test]
    fn ring_elem_rejects_mixed_contexts() {
        let ctx_a =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let ctx_b = RingContext::from_parts(3, 17, &[1, 0, 0, 1], 3).expect("context should build");

        let a = ctx_a.one_element();
        let b = ctx_b.one_element();

        let err = a.add(&b).expect_err("expected incompatibility error");
        assert_eq!(err, PolynomialError::IncompatiblePolynomials);
    }

    #[test]
    fn ring_elem_accepts_structurally_equal_contexts() {
        let ctx_a =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let ctx_b =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");

        let sum = ctx_a
            .one_element()
            .add(&ctx_b.one_element())
            .expect("equal params from distinct contexts should be compatible");
        assert_eq!(sum.trimmed_coefficients(), vec![2]);
    }

    #[test]
    fn ring_elem_identity_inverse_and_scalar_laws_hold_exhaustively() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let elems = enumerate_ring_elements(&ctx, &[0, 1, 2, 3], 3);
        let zero = ctx.zero_element();
        let one = ctx.one_element();

        for elem in &elems {
            assert_eq!(elem.add(&zero).expect("e+0 should work"), *elem);
            assert_eq!(zero.add(elem).expect("0+e should work"), *elem);
            assert_eq!(elem.sub(&zero).expect("e-0 should work"), *elem);
            assert_eq!(elem.mul(&one).expect("e*1 should work"), *elem);
            assert_eq!(one.mul(elem).expect("1*e should work"), *elem);
            assert!(
                elem.add(&elem.neg()).expect("e+(-e) should work").is_zero(),
                "additive inverse should cancel"
            );
            assert_eq!(elem.scalar_mul(ctx.modulus() + 7), elem.scalar_mul(7));
            assert!(elem.scalar_mul(0).is_zero());
        }
    }

    #[test]
    fn ring_elem_ops_match_context_ops_exhaustive_small_domain() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let elems = enumerate_ring_elements(&ctx, &[0, 1, 2, 3], 3);

        for a in &elems {
            for b in &elems {
                let add_ring = a.add(b).expect("ring add should work").into_polynomial();
                let add_ctx = ctx
                    .add(a.polynomial(), b.polynomial())
                    .expect("context add should work");
                assert_eq!(add_ring, add_ctx);

                let sub_ring = a.sub(b).expect("ring sub should work").into_polynomial();
                let sub_ctx = ctx
                    .sub(a.polynomial(), b.polynomial())
                    .expect("context sub should work");
                assert_eq!(sub_ring, sub_ctx);

                let mul_ring = a.mul(b).expect("ring mul should work").into_polynomial();
                let mul_ctx = ctx
                    .mul(a.polynomial(), b.polynomial())
                    .expect("context mul should work");
                assert_eq!(mul_ring, mul_ctx);
            }
        }
    }

    #[test]
    fn ring_elem_associativity_and_distributivity_hold_on_dense_sample() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let elems = enumerate_ring_elements(&ctx, &[0, 1, 2, 3], 3);
        let sample: Vec<_> = elems.iter().take(16).cloned().collect();

        for a in &sample {
            for b in &sample {
                assert_eq!(
                    a.add(b).expect("a+b should work"),
                    b.add(a).expect("b+a should work"),
                    "addition should commute"
                );
                assert_eq!(
                    a.mul(b).expect("a*b should work"),
                    b.mul(a).expect("b*a should work"),
                    "multiplication should commute in this ring"
                );
                for c in &sample {
                    assert_eq!(
                        a.add(&b.add(c).expect("b+c should work"))
                            .expect("a+(b+c) should work"),
                        a.add(b)
                            .expect("a+b should work")
                            .add(c)
                            .expect("(a+b)+c should work"),
                        "addition should associate"
                    );
                    assert_eq!(
                        a.mul(&b.mul(c).expect("b*c should work"))
                            .expect("a*(b*c) should work"),
                        a.mul(b)
                            .expect("a*b should work")
                            .mul(c)
                            .expect("(a*b)*c should work"),
                        "multiplication should associate"
                    );
                    assert_eq!(
                        a.mul(&b.add(c).expect("b+c should work"))
                            .expect("a*(b+c) should work"),
                        a.mul(b)
                            .expect("a*b should work")
                            .add(&a.mul(c).expect("a*c should work"))
                            .expect("ab+ac should work"),
                        "left distributivity should hold"
                    );
                }
            }
        }
    }

    #[test]
    fn ring_elem_round_trips_through_polynomial_conversion() {
        let ctx =
            RingContext::from_parts(4, 17, &[1, 0, 0, 0, 1], 3).expect("context should build");
        let elems = enumerate_ring_elements(&ctx, &[0, 1, 2, 3], 3);

        for elem in elems {
            let reconstructed = ctx
                .element_from_polynomial(elem.clone().into_polynomial())
                .expect("reconstruction should work");
            assert_eq!(elem, reconstructed);
        }
    }
}