use super::modulo_ring::ModuloRingDouble;
use super::modulo_ring::{ModuloRingLarge, ModuloRingSingle};
use crate::{
arch::word::{DoubleWord, Word},
buffer::Buffer,
error::panic_different_rings,
};
use alloc::boxed::Box;
/// An element of a modular ring.
///
/// Wraps an internal representation that pairs a raw value with a borrow of the
/// ring it belongs to; there is one representation per ring kind (single-word,
/// double-word, or large multi-word storage).
pub struct Modulo<'ring>(ModuloRepr<'ring>);
/// Internal representation of a `Modulo`: the raw value together with a
/// reference to the ring it lives in, split by how the value is stored.
pub(crate) enum ModuloRepr<'ring> {
    /// Value held in a single `Word`, belonging to a single-word ring.
    Single(ModuloSingleRaw, &'ring ModuloRingSingle),
    /// Value held in a `DoubleWord`, belonging to a double-word ring.
    Double(ModuloDoubleRaw, &'ring ModuloRingDouble),
    /// Value held in a boxed slice of `Word`s, belonging to a large ring.
    Large(ModuloLargeRaw, &'ring ModuloRingLarge),
}
/// Raw single-word value of a `Modulo`, stored without its ring.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) struct ModuloSingleRaw(pub(crate) Word);

/// Raw double-word value of a `Modulo`, stored without its ring.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) struct ModuloDoubleRaw(pub(crate) DoubleWord);

/// Raw multi-word value of a `Modulo`, stored without its ring.
///
/// Backed by a heap-allocated slice, so it is `Clone` but not `Copy`.
#[derive(PartialEq, Eq, Clone)]
pub(crate) struct ModuloLargeRaw(pub(crate) Box<[Word]>);
impl<'a> Modulo<'a> {
#[inline]
pub(crate) fn repr(&self) -> &ModuloRepr<'a> {
&self.0
}
#[inline]
pub(crate) fn repr_mut(&mut self) -> &mut ModuloRepr<'a> {
&mut self.0
}
#[inline]
pub(crate) fn into_repr(self) -> ModuloRepr<'a> {
self.0
}
#[inline]
pub(crate) const fn from_single(raw: ModuloSingleRaw, ring: &'a ModuloRingSingle) -> Self {
debug_assert!(ring.is_valid(raw));
Modulo(ModuloRepr::Single(raw, ring))
}
#[inline]
pub(crate) const fn from_double(raw: ModuloDoubleRaw, ring: &'a ModuloRingDouble) -> Self {
debug_assert!(ring.is_valid(raw));
Modulo(ModuloRepr::Double(raw, ring))
}
#[inline]
pub(crate) fn from_large(raw: ModuloLargeRaw, ring: &'a ModuloRingLarge) -> Self {
debug_assert!(ring.is_valid(&raw));
Modulo(ModuloRepr::Large(raw, ring))
}
#[inline]
pub(crate) fn check_same_ring_single(lhs: &ModuloRingSingle, rhs: &ModuloRingSingle) {
if lhs != rhs {
panic_different_rings();
}
}
#[inline]
pub(crate) fn check_same_ring_double(lhs: &ModuloRingDouble, rhs: &ModuloRingDouble) {
if lhs != rhs {
panic_different_rings();
}
}
#[inline]
pub(crate) fn check_same_ring_large(lhs: &ModuloRingLarge, rhs: &ModuloRingLarge) {
if lhs != rhs {
panic_different_rings();
}
}
}
impl ModuloSingleRaw {
pub const fn one(ring: &ModuloRingSingle) -> Self {
let modulo = Self(1 << ring.shift());
debug_assert!(ring.is_valid(modulo));
modulo
}
}
impl ModuloDoubleRaw {
pub const fn one(ring: &ModuloRingDouble) -> Self {
let modulo = Self(1 << ring.shift());
debug_assert!(ring.is_valid(modulo));
modulo
}
}
impl ModuloLargeRaw {
pub fn one(ring: &ModuloRingLarge) -> Self {
let modulus = ring.normalized_modulus();
let mut buf = Buffer::allocate_exact(modulus.len());
buf.push(1 << ring.shift());
buf.push_zeros(modulus.len() - 1);
let modulo = Self(buf.into_boxed_slice());
debug_assert!(ring.is_valid(&modulo));
modulo
}
}
impl Clone for Modulo<'_> {
#[inline]
fn clone(&self) -> Self {
Modulo(self.0.clone())
}
#[inline]
fn clone_from(&mut self, source: &Self) {
self.0.clone_from(&source.0);
}
}
impl Clone for ModuloRepr<'_> {
#[inline]
fn clone(&self) -> Self {
match self {
ModuloRepr::Single(modulo, ring) => ModuloRepr::Single(*modulo, ring),
ModuloRepr::Double(modulo, ring) => ModuloRepr::Double(*modulo, ring),
ModuloRepr::Large(modulo, ring) => ModuloRepr::Large(modulo.clone(), ring),
}
}
#[inline]
fn clone_from(&mut self, source: &Self) {
if let (ModuloRepr::Large(raw, ring), ModuloRepr::Large(src_raw, src_ring)) =
(&mut *self, source)
{
*ring = src_ring;
raw.0.clone_from(&src_raw.0);
} else {
*self = source.clone();
}
}
}