use crate::{
arith::{self, trim_zeros},
modulo::PolynomialModulus,
};
use num::{One, Zero};
use std::ops::{Add, Mul, Neg, Sub};
/// A polynomial with coefficients of type `T`, parameterized by the const `N`.
///
/// `coeffs[i]` holds the coefficient of `X^i` (the last stored element is the
/// leading coefficient). `Polynomial::new` and `Polynomial::from_coeffs`
/// require `N` to be a power of two and trim trailing zero coefficients.
///
/// The derived `PartialEq`/`Eq` compare the coefficient vectors directly, so
/// two equal polynomials only compare equal when both are stored trimmed.
/// NOTE(review): `N` is passed to `PolynomialModulus::new(N)` for reduction;
/// presumably this is the ring modulus (e.g. X^N + 1) — confirm in `modulo`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Polynomial<T, const N: usize> {
/// Coefficients in ascending order of degree; `Zero::zero` represents the
/// zero polynomial as an empty vector.
pub(crate) coeffs: Vec<T>,
}
impl<T, const N: usize> Polynomial<T, N> {
    /// Builds a polynomial from raw coefficients, converting each one into `T`.
    ///
    /// `coeffs[i]` is the coefficient of `X^i`. Trailing zero coefficients
    /// (after conversion into `T`) are trimmed before the degree bound is
    /// checked. No modular reduction is applied; see
    /// [`Polynomial::from_coeffs`] for a constructor that reduces.
    ///
    /// # Panics
    ///
    /// Panics if `N` is not a power of two, or if more than `N` coefficients
    /// remain after trimming (i.e. the degree is not less than `N`).
    pub fn new<K>(coeffs: Vec<K>) -> Self
    where
        T: Zero + From<K>,
    {
        assert!(N.is_power_of_two(), "N must be a power of two");
        let mut coeffs: Vec<T> = coeffs.into_iter().map(T::from).collect();
        // Trim before checking so trailing zeros (including inputs that become
        // zero once converted into `T`) do not count towards the degree.
        trim_zeros(&mut coeffs);
        assert!(
            coeffs.len() <= N,
            "The degree of the polynomial must be less than N"
        );
        Polynomial { coeffs }
    }

    /// Builds a polynomial from coefficients already of type `T`, trims
    /// trailing zeros, and reduces the result with `arith::modulo` against
    /// `PolynomialModulus::<T>::new(N)`.
    ///
    /// Unlike [`Polynomial::new`], no length check is performed here; inputs
    /// with more than `N` coefficients are handed to the reduction.
    ///
    /// # Panics
    ///
    /// Panics if `N` is not a power of two.
    pub fn from_coeffs(mut coeffs: Vec<T>) -> Self
    where
        T: Zero + One + Clone,
        for<'a> &'a T: Mul<Output = T> + Sub<Output = T> + Add<Output = T>,
    {
        assert!(N.is_power_of_two(), "N must be a power of two");
        trim_zeros(&mut coeffs);
        arith::modulo(Polynomial { coeffs }, PolynomialModulus::<T>::new(N))
    }

    /// Returns the degree: the index of the highest stored coefficient.
    ///
    /// The zero polynomial stores no coefficients; its degree is reported as
    /// `0` instead of underflowing (which would panic in debug builds and
    /// wrap to `usize::MAX` in release builds).
    pub fn deg(&self) -> usize {
        self.coeffs.len().saturating_sub(1)
    }

    /// Returns the highest-degree stored coefficient, or `T::zero()` for the
    /// zero polynomial.
    pub fn leading_coefficient(&self) -> T
    where
        T: Zero + Clone,
    {
        self.coeffs.last().cloned().unwrap_or_else(T::zero)
    }

    /// Returns the coefficient of `X^idx`, or `T::zero()` when `idx` is past
    /// the stored coefficients.
    pub fn coefficient(&self, idx: usize) -> T
    where
        T: Zero + Clone,
    {
        self.coeffs.get(idx).cloned().unwrap_or_else(T::zero)
    }

    /// Maps every stored coefficient through `f`, producing a polynomial over
    /// `U` with the same `N`.
    ///
    /// Trailing zeros produced by `f` are trimmed; no modular reduction is
    /// applied.
    pub fn mapv<U, F>(&self, f: F) -> Polynomial<U, N>
    where
        U: Zero,
        F: Fn(&T) -> U,
    {
        let mut coeffs: Vec<U> = self.coeffs.iter().map(f).collect();
        trim_zeros(&mut coeffs);
        Polynomial { coeffs }
    }

    /// Mutates every stored coefficient in place with `f`, then trims trailing
    /// zeros.
    ///
    /// Only stored coefficients are visited; positions above `deg()` are not.
    pub fn coeffs_mut<F>(&mut self, f: F)
    where
        T: Zero,
        F: FnMut(&mut T),
    {
        self.coeffs.iter_mut().for_each(f);
        trim_zeros(&mut self.coeffs);
    }

    /// Returns an iterator over the stored coefficients, lowest degree first.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.coeffs.iter()
    }
}
impl<T, const N: usize> One for Polynomial<T, N>
where
    T: Zero + One + Clone,
    for<'a> &'a T: Mul<Output = T> + Sub<Output = T> + Add<Output = T>,
{
    /// The constant polynomial `1`: a single stored coefficient `T::one()`.
    fn one() -> Self {
        let coeffs = vec![T::one()];
        Self { coeffs }
    }
}
impl<T, const N: usize> Zero for Polynomial<T, N>
where
    T: Zero + One + Clone,
    for<'a> &'a T: Mul<Output = T> + Sub<Output = T> + Add<Output = T>,
{
    /// The zero polynomial, stored as an empty coefficient vector.
    fn zero() -> Self {
        Self { coeffs: Vec::new() }
    }

    /// Returns `true` when every stored coefficient is zero.
    fn is_zero(&self) -> bool {
        // `all` is vacuously true on an empty iterator, which covers the
        // empty-vector representation produced by `zero()`.
        self.coeffs.iter().all(T::is_zero)
    }
}
impl<T, const N: usize> Add for Polynomial<T, N>
where
    T: Zero + One + Clone,
    for<'a> &'a T: Mul<Output = T> + Sub<Output = T> + Add<Output = T>,
{
    type Output = Self;

    /// Sums the operands with `arith::add`, then reduces the sum with
    /// `arith::modulo` against `PolynomialModulus::<T>::new(N)`.
    fn add(self, other: Self) -> Self {
        let sum = arith::add(self, other);
        arith::modulo(sum, PolynomialModulus::<T>::new(N))
    }
}
impl<T, const N: usize> Neg for Polynomial<T, N>
where
    T: Zero + Clone,
    for<'a> &'a T: Neg<Output = T>,
{
    type Output = Self;

    /// Negates every stored coefficient in place, reusing the coefficient
    /// vector; no trimming or reduction is performed.
    fn neg(mut self) -> Self {
        for coeff in &mut self.coeffs {
            *coeff = -&*coeff;
        }
        self
    }
}
impl<T, const N: usize> Sub for Polynomial<T, N>
where
    T: Zero + One + Clone,
    for<'a> &'a T: Neg<Output = T> + Mul<Output = T> + Sub<Output = T> + Add<Output = T>,
{
    type Output = Self;

    /// Computes `self - other` as `self + (-other)` via `arith::add`, then
    /// reduces with `arith::modulo` against `PolynomialModulus::<T>::new(N)`.
    fn sub(self, other: Self) -> Self {
        let negated = -other;
        let difference = arith::add(self, negated);
        arith::modulo(difference, PolynomialModulus::<T>::new(N))
    }
}
impl<T, const N: usize> Mul for Polynomial<T, N>
where
    T: Zero + One + Clone,
    for<'a> &'a T: Mul<Output = T> + Add<Output = T> + Sub<Output = T>,
{
    type Output = Self;

    /// Multiplies via `arith::cyclic_mul`. When either operand is zero, the
    /// zero polynomial is returned without calling `cyclic_mul`.
    // NOTE(review): lint allowance kept from the original; its justification is
    // not visible here (presumably tied to `cyclic_mul`'s parameter types).
    #[allow(clippy::needless_borrow)]
    fn mul(self, other: Self) -> Self {
        let either_zero = self.is_zero() || other.is_zero();
        if either_zero {
            Self::zero()
        } else {
            arith::cyclic_mul(&self, &other)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pin the expected message so the test cannot pass on an unrelated panic.
    #[test]
    #[should_panic(expected = "N must be a power of two")]
    fn test_new() {
        const INVALID_N: usize = 5;
        Polynomial::<i32, INVALID_N>::new(vec![1, 2, 3]);
    }

    // Five non-zero coefficients give degree 4, which is not below N = 4.
    #[test]
    #[should_panic(expected = "The degree of the polynomial must be less than")]
    fn test_new_too_many_coeffs() {
        Polynomial::<i32, 4>::new(vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_zero() {
        let p = Polynomial::<i32, 4>::zero();
        assert!(p.is_zero());
    }

    #[test]
    fn test_coefficient_lookup() {
        let p = Polynomial::<i32, 4>::new(vec![1, 2]);
        assert_eq!(p.coefficient(0), 1);
        assert_eq!(p.coefficient(1), 2);
        // Indices past the stored coefficients read as zero.
        assert_eq!(p.coefficient(3), 0);
        assert_eq!(p.leading_coefficient(), 2);
    }

    #[cfg(feature = "zq")]
    #[test]
    fn test_poly_over_zq() {
        use crate::zq::{ZqI32, ZqI64, ZqU32, ZqU64};
        let p = Polynomial::<ZqI32<7>, 4>::new(vec![1i32, 2, 3]);
        assert_eq!(p.deg(), 2);
        assert_eq!(p.coefficient(0), ZqI32::new(1));
        assert_eq!(p.coefficient(1), ZqI32::new(2));
        assert_eq!(p.coefficient(2), ZqI32::new(3));
        let p = Polynomial::<ZqI64<7>, 4>::new(vec![1i64, 2, 3]);
        assert_eq!(p.deg(), 2);
        assert_eq!(p.coefficient(0), ZqI64::new(1));
        assert_eq!(p.coefficient(1), ZqI64::new(2));
        assert_eq!(p.coefficient(2), ZqI64::new(3));
        let p = Polynomial::<ZqU32<7>, 4>::new(vec![1u32, 2, 3]);
        assert_eq!(p.deg(), 2);
        assert_eq!(p.coefficient(0), ZqU32::new(1));
        assert_eq!(p.coefficient(1), ZqU32::new(2));
        assert_eq!(p.coefficient(2), ZqU32::new(3));
        let p = Polynomial::<ZqU64<7>, 4>::new(vec![1u64, 2, 3]);
        assert_eq!(p.deg(), 2);
        assert_eq!(p.coefficient(0), ZqU64::new(1));
        assert_eq!(p.coefficient(1), ZqU64::new(2));
        assert_eq!(p.coefficient(2), ZqU64::new(3));
    }
}