#![allow(non_snake_case)]
use group::{
ff::{Field, PrimeField},
Group,
};
mod poly;
pub use poly::*;
#[cfg(test)]
pub(crate) mod tests;
/// A curve over which divisors (polynomials vanishing at chosen curve points) can be built.
///
/// The curve equation is encoded by [`DivisorCurve::divisor_modulus`] as
/// `y^2 - x^3 - a*x - b`, built from the coefficients returned by [`DivisorCurve::a`] and
/// [`DivisorCurve::b`].
pub trait DivisorCurve: Group
where
    Self::Scalar: PrimeField,
{
    /// The field the curve's affine coordinates live in.
    type FieldElement: PrimeField;

    /// The curve's `a` coefficient (the coefficient of `x` in the curve equation).
    fn a() -> Self::FieldElement;
    /// The curve's `b` coefficient (the constant term of the curve equation).
    fn b() -> Self::FieldElement;

    /// Returns the polynomial `y^2 - x^3 - a*x - b`.
    ///
    /// Divisor polynomials are reduced modulo this polynomial when they are multiplied together.
    fn divisor_modulus() -> Poly<Self::FieldElement> {
        let zero = Self::FieldElement::ZERO;
        let one = Self::FieldElement::ONE;
        Poly {
            // Coefficients of y^1 and y^2: only the y^2 term is present.
            y_coefficients: vec![zero, one],
            // No mixed x*y terms.
            yx_coefficients: vec![],
            // Coefficients of x^1, x^2 and x^3: -a*x - x^3.
            x_coefficients: vec![-Self::a(), zero, -one],
            // Constant term: -b.
            zero_coefficient: -Self::b(),
        }
    }

    /// Converts a curve point into its affine `(x, y)` coordinates.
    ///
    /// NOTE(review): callers in this module expect the result to satisfy `divisor_modulus`. How
    /// the identity point is handled is left to the implementer — confirm per curve.
    fn to_xy(point: Self) -> (Self::FieldElement, Self::FieldElement);
}
/// Returns `(slope, intercept)` for the line `y = slope * x + intercept` through points `a` and `b`.
///
/// In debug builds this also asserts that both points satisfy the curve's divisor modulus and lie
/// on the returned line.
///
/// # Panics
///
/// Panics if `a` and `b` share an x coordinate, since the line through them would be vertical.
fn slope_intercept<C: DivisorCurve>(a: C, b: C) -> (C::FieldElement, C::FieldElement) {
    let (x1, y1) = C::to_xy(a);
    debug_assert_eq!(C::divisor_modulus().eval(x1, y1), C::FieldElement::ZERO);
    let (x2, y2) = C::to_xy(b);
    debug_assert_eq!(C::divisor_modulus().eval(x2, y2), C::FieldElement::ZERO);

    // slope = (y2 - y1) / (x2 - x1). The inversion fails only when x1 == x2.
    let rise = y2 - y1;
    let run = x2 - x1;
    let run_inv: Option<C::FieldElement> = run.invert().into();
    let slope =
        rise * run_inv.expect("trying to get slope/intercept of points sharing an x coordinate");
    let intercept = y2 - (slope * x2);

    // Sanity-check that both points lie on y = slope * x + intercept.
    debug_assert!(bool::from((y1 - (slope * x1) - intercept).is_zero()));
    debug_assert!(bool::from((y2 - (slope * x2) - intercept).is_zero()));
    (slope, intercept)
}
/// Returns the polynomial of the line through the curve points `a` and `b`.
///
/// - If `a + b` is the identity, the line is vertical and the result is `x - a.x`.
/// - If `a == b`, the line is drawn through `a` and `-(2a)`; on an elliptic curve this is the
///   tangent at `a`.
/// - Otherwise, the result is `y - slope*x - intercept` for the line through `a` and `b`.
fn line<C: DivisorCurve>(a: C, b: C) -> Poly<C::FieldElement> {
    if (a + b) == C::identity() {
        let (a_x, _) = C::to_xy(a);
        // x - a.x
        return Poly {
            y_coefficients: vec![],
            yx_coefficients: vec![],
            x_coefficients: vec![C::FieldElement::ONE],
            zero_coefficient: -a_x,
        };
    }

    // For a doubled point, use -(2a) as the second point so the slope is well-defined.
    let other = if a == b { -(a.double()) } else { b };
    let (slope, intercept) = slope_intercept::<C>(a, other);

    // y - slope*x - intercept
    Poly {
        y_coefficients: vec![C::FieldElement::ONE],
        yx_coefficients: vec![],
        x_coefficients: vec![-slope],
        zero_coefficient: -intercept,
    }
}
/// Creates a divisor polynomial interpolating the given points.
///
/// Points are first paired off into lines. The resulting partial divisors are then merged
/// pairwise, tree-style, until one remains. Products are reduced modulo
/// [`DivisorCurve::divisor_modulus`].
///
/// Returns `None` if:
///   - fewer than two points are passed in,
///   - the points don't sum to the identity, or
///   - any of the points is the identity.
///
/// # Panics
///
/// Panics if a merge step's polynomial division leaves a non-zero remainder, which indicates an
/// internal bug.
#[allow(clippy::new_ret_no_self)]
pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement>> {
    // A single point is either the identity or cannot sum to the identity.
    if points.len() < 2 {
        return None;
    }
    if points.iter().sum::<C>() != C::identity() {
        return None;
    }

    // Each entry is (sum S of the points covered, polynomial through those points and -S).
    let mut divs = vec![];
    let mut iter = points.iter().copied();
    while let Some(a) = iter.next() {
        if a == C::identity() {
            return None;
        }
        let b = iter.next();
        if b == Some(C::identity()) {
            return None;
        }
        // With no partner, use the vertical line through `a` (which also passes through -a).
        divs.push((a + b.unwrap_or(C::identity()), line::<C>(a, b.unwrap_or(-a))));
    }

    let modulus = C::divisor_modulus();

    // Pair the partial divisors off until only one remains.
    while divs.len() > 1 {
        let mut next_divs = vec![];
        // With an odd count, carry the last divisor over to the next round unchanged.
        if (divs.len() % 2) == 1 {
            next_divs.push(divs.pop().unwrap());
        }
        while let Some((a, a_div)) = divs.pop() {
            // Even count remaining, so every `a` has a partner.
            let (b, b_div) = divs.pop().unwrap();
            let merged = merge_divisors::<C>(a, a_div, b, b_div, &modulus);
            next_divs.push((a + b, merged));
        }
        divs = next_divs;
    }

    // `points.len() >= 2` guarantees at least one divisor was created.
    Some(divs.remove(0).1)
}

/// Merges two partial divisors with point sums `a` and `b` into one with point sum `a + b`.
///
/// Each partial divisor passes through its points plus the negation of its sum. In the general
/// case this computes `a_div * b_div * line(a, b) / (vertical(a) * vertical(b))`. The vertical
/// lines cancel the `-a`/`-b` zeros, which `line(a, b)` replaces with `-(a + b)`.
fn merge_divisors<C: DivisorCurve>(
    a: C,
    a_div: Poly<C::FieldElement>,
    b: C,
    b_div: Poly<C::FieldElement>,
    modulus: &Poly<C::FieldElement>,
) -> Poly<C::FieldElement> {
    let product = a_div.mul_mod(b_div, modulus);

    // A partial divisor whose sum is the identity has no extra zero to cancel.
    // If `a` is the identity, the line through it and `b` is `b`'s vertical line. That exactly
    // cancels `b`'s vertical-line denominator (and vice versa), so the product is already the
    // result. Handling this explicitly avoids calling `C::to_xy` on the identity, which has no
    // affine coordinates (e.g. for inputs like [P, -P, Q, -Q]).
    if (a == C::identity()) || (b == C::identity()) {
        return product;
    }

    let numerator = product.mul_mod(line::<C>(a, b), modulus);
    let denominator = line::<C>(a, -a).mul_mod(line::<C>(b, -b), modulus);
    let (q, r) = numerator.div_rem(&denominator);
    assert_eq!(r, Poly::zero());
    q
}