use std::cmp::Ordering;
use std::convert::TryFrom;
use std::{iter, ops};
use serde::{Deserialize, Serialize};
use crate::elliptic::curves::{Curve, Scalar};
/// Degree of a [`Polynomial`].
///
/// `Finite(d)` is an ordinary degree `d`. `Infinity` is what [`Polynomial::degree`] returns
/// when no coefficient is non-zero (including a polynomial with no coefficients at all).
/// Under this type's `Ord` impl, `Infinity` compares greater than every finite degree.
// NOTE(review): variant names/order are part of the derived serde representation — do not
// reorder or rename without considering already-serialized data.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PolynomialDegree {
    /// Degree reported for a polynomial with no non-zero coefficient.
    Infinity,
    /// A concrete degree: the index of the highest non-zero coefficient.
    Finite(u16),
}
impl From<u16> for PolynomialDegree {
fn from(deg: u16) -> Self {
PolynomialDegree::Finite(deg)
}
}
impl PartialOrd for PolynomialDegree {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PolynomialDegree {
    /// Orders finite degrees numerically and places `Infinity` above every finite degree.
    fn cmp(&self, other: &Self) -> Ordering {
        // Map each degree to a sortable key: the `bool` ranks `Infinity` after all finite
        // values (`false < true`), and the `u16` orders finite degrees among themselves.
        let rank = |d: &Self| match *d {
            Self::Finite(deg) => (false, deg),
            Self::Infinity => (true, 0),
        };
        rank(self).cmp(&rank(other))
    }
}
/// Polynomial with coefficients of type `Scalar<E>`, stored as a coefficient vector.
///
/// `coefficients[i]` is the coefficient of `x^i`; evaluation applies them from the highest
/// index down. The vector is not normalized: it may contain trailing zeros or be empty.
#[derive(Clone, Debug, Serialize, Deserialize)]
// `bound = ""` stops serde from inferring `E: Serialize` / `E: Deserialize` bounds on the
// generic parameter. Presumably `Scalar<E>` is (de)serializable for every `E: Curve` — confirm.
#[serde(bound = "")]
pub struct Polynomial<E: Curve> {
    coefficients: Vec<Scalar<E>>,
}
impl<E: Curve> Polynomial<E> {
    /// Builds a polynomial from its coefficients, lowest degree first.
    ///
    /// `coefficients[i]` is the coefficient of `x^i`. The vector is stored as given: trailing
    /// zeros are kept and an empty vector is accepted (it behaves as the zero polynomial).
    pub fn from_coefficients(coefficients: Vec<Scalar<E>>) -> Self {
        Self { coefficients }
    }

    /// Samples a polynomial with random coefficients for the requested degree.
    ///
    /// * `Finite(d)`: draws `d + 1` coefficients with `Scalar::random`.
    /// * `Infinity`: returns a polynomial with no coefficients (the zero polynomial).
    ///
    /// The leading coefficient is not re-drawn if it happens to be zero, so
    /// [`degree`](Self::degree) of the result can in principle be lower than requested.
    pub fn sample_exact(degree: impl Into<PolynomialDegree>) -> Self {
        match degree.into() {
            PolynomialDegree::Finite(degree) => Self::from_coefficients(
                iter::repeat_with(Scalar::random)
                    .take(usize::from(degree) + 1)
                    .collect(),
            ),
            PolynomialDegree::Infinity => Self::from_coefficients(vec![]),
        }
    }

    /// Samples a polynomial with `n + 1` coefficients whose constant term is `const_term`.
    ///
    /// The remaining `n` coefficients (for `x^1 ..= x^n`) are drawn with `Scalar::random`.
    pub fn sample_exact_with_fixed_const_term(n: u16, const_term: Scalar<E>) -> Self {
        // `take(0)` yields nothing, so `n == 0` needs no special case.
        let random_coefficients = iter::repeat_with(Scalar::random).take(usize::from(n));
        Self::from_coefficients(iter::once(const_term).chain(random_coefficients).collect())
    }

    /// Returns the index of the highest non-zero coefficient as a finite degree, or
    /// [`PolynomialDegree::Infinity`] if every coefficient is zero or there are none.
    ///
    /// # Panics
    ///
    /// Panics if that index exceeds `u16::MAX`. This can only happen for polynomials built
    /// with [`from_coefficients`](Self::from_coefficients) from more than 65 536 coefficients.
    pub fn degree(&self) -> PolynomialDegree {
        self.coefficients
            .iter()
            .rposition(|coef| !coef.is_zero())
            .map_or(PolynomialDegree::Infinity, |i| {
                PolynomialDegree::Finite(
                    u16::try_from(i).expect("polynomial degree guaranteed to fit into u16"),
                )
            })
    }

    /// Evaluates the polynomial at `point_x` using Horner's scheme.
    ///
    /// A polynomial with no coefficients (e.g. from `sample_exact(PolynomialDegree::Infinity)`)
    /// evaluates to zero at every point.
    pub fn evaluate(&self, point_x: &Scalar<E>) -> Scalar<E> {
        // Horner: ((c_n * x + c_{n-1}) * x + ...) * x + c_0. Seeding the fold with zero lets
        // an empty coefficient list evaluate to zero instead of panicking. For a non-empty
        // list the first step is `0 * x + c_n = c_n`, so the result is unchanged.
        self.coefficients
            .iter()
            .rev()
            .fold(Scalar::zero(), |acc, coef| acc * point_x + coef)
    }

    /// Evaluates the polynomial at `point_x` after converting it into a `Scalar<E>`.
    pub fn evaluate_bigint<B>(&self, point_x: B) -> Scalar<E>
    where
        Scalar<E>: From<B>,
    {
        self.evaluate(&Scalar::from(point_x))
    }

    /// Lazily evaluates the polynomial at each of `points_x`, yielding results in input order.
    pub fn evaluate_many<'i, I>(&'i self, points_x: I) -> impl Iterator<Item = Scalar<E>> + 'i
    where
        I: IntoIterator<Item = &'i Scalar<E>> + 'i,
    {
        points_x.into_iter().map(move |x| self.evaluate(x))
    }

    /// Like [`evaluate_many`](Self::evaluate_many), but each point is first converted into a
    /// `Scalar<E>` via `From`.
    pub fn evaluate_many_bigint<'i, B, I>(
        &'i self,
        points_x: I,
    ) -> impl Iterator<Item = Scalar<E>> + 'i
    where
        I: IntoIterator<Item = B> + 'i,
        Scalar<E>: From<B>,
    {
        points_x.into_iter().map(move |x| self.evaluate_bigint(x))
    }

    /// Returns the stored coefficients, lowest degree first (not normalized).
    pub fn coefficients(&self) -> &[Scalar<E>] {
        &self.coefficients
    }

    /// Evaluates the `j`-th Lagrange basis polynomial for the nodes `xs` at the point `x`:
    ///
    /// `l_j(x) = Π_{m ≠ j} (x - xs[m]) / (xs[j] - xs[m])`
    ///
    /// # Panics
    ///
    /// Panics if `j` is out of bounds for `xs`, or if the nodes are not pairwise distinct
    /// (the denominator is then zero and cannot be inverted).
    pub fn lagrange_basis(x: &Scalar<E>, j: u16, xs: &[Scalar<E>]) -> Scalar<E> {
        let j = usize::from(j);
        let x_j = &xs[j];
        // `enumerate` yields `usize` indices, so long node lists cannot overflow the counter
        // (a `0u16..` range panics once it passes 65 535).
        let num: Scalar<E> = xs
            .iter()
            .enumerate()
            .filter(|&(m, _)| m != j)
            .map(|(_, x_m)| x - x_m)
            .product();
        let denom: Scalar<E> = xs
            .iter()
            .enumerate()
            .filter(|&(m, _)| m != j)
            .map(|(_, x_m)| x_j - x_m)
            .product();
        let denom_inv = denom
            .invert()
            .expect("elements in xs are not pairwise distinct");
        num * denom_inv
    }
}
impl<E: Curve> ops::Mul<&Scalar<E>> for &Polynomial<E> {
    type Output = Polynomial<E>;

    /// Multiplies every coefficient by `scalar`, producing a polynomial of the same length.
    fn mul(self, scalar: &Scalar<E>) -> Self::Output {
        let mut scaled: Vec<Scalar<E>> = Vec::with_capacity(self.coefficients().len());
        for coef in self.coefficients() {
            scaled.push(coef * scalar);
        }
        Polynomial::from_coefficients(scaled)
    }
}
impl<E: Curve> ops::Add for &Polynomial<E> {
type Output = Polynomial<E>;
fn add(self, g: Self) -> Self::Output {
let len1 = self.coefficients.len();
let len2 = g.coefficients.len();
let overlapped = self
.coefficients()
.iter()
.zip(g.coefficients())
.map(|(f_coef, g_coef)| f_coef + g_coef);
let tail = if len1 < len2 {
&g.coefficients()[len1..]
} else {
&self.coefficients()[len2..]
};
Polynomial::from_coefficients(overlapped.chain(tail.iter().cloned()).collect())
}
}
impl<E: Curve> ops::Sub for &Polynomial<E> {
type Output = Polynomial<E>;
fn sub(self, g: Self) -> Self::Output {
let len1 = self.coefficients.len();
let len2 = g.coefficients.len();
let overlapped = self
.coefficients()
.iter()
.zip(g.coefficients())
.map(|(f_coef, g_coef)| f_coef - g_coef);
let tail = if len1 < len2 {
g.coefficients()[len1..].iter().map(|x| -x).collect()
} else {
self.coefficients()[len2..].to_vec()
};
Polynomial::from_coefficients(overlapped.chain(tail.into_iter()).collect())
}
}