use alloc::vec::Vec;
use core::ops::{Add, AddAssign, Mul, Neg, Sub};
use p3_field::extension::ComplexExtendable;
use p3_field::{ExtensionField, Field, batch_multiplicative_inverse};
/// A pair of coordinates `(x, y)` over `F`.
///
/// `Point::new` debug-asserts that `x^2 + y^2 == 1`, so values built through it
/// are expected to satisfy that equation.
// The lint would suggest `#[non_exhaustive]`, but that attribute only restricts
// other crates; the private field below also blocks struct-literal construction
// from other modules of this crate.
#[allow(clippy::manual_non_exhaustive)]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct Point<F> {
/// First coordinate.
pub x: F,
/// Second coordinate.
pub y: F,
// Zero-sized private marker: code outside this module cannot write a
// `Point { .. }` literal and has to go through the provided constructors.
_private: (),
}
impl<F: Field> Point<F> {
#[inline]
pub fn new(x: F, y: F) -> Self {
debug_assert_eq!(x.square() + y.square(), F::ONE);
Self { x, y, _private: () }
}
const ZERO: Self = Self {
x: F::ONE,
y: F::ZERO,
_private: (),
};
pub fn from_projective_line(t: F) -> Self {
let t2 = t.square();
let inv_denom = (F::ONE + t2).try_inverse().expect("t^2 = -1");
Self::new((F::ONE - t2) * inv_denom, t.double() * inv_denom)
}
pub fn to_projective_line(self) -> Option<F> {
(self.x + F::ONE).try_inverse().map(|x| x * self.y)
}
pub fn double(self) -> Self {
Self::new(self.x.square().double() - F::ONE, self.x.double() * self.y)
}
pub fn v_n(mut self, log_n: usize) -> F {
for _ in 0..(log_n - 1) {
self.x = self.x.square().double() - F::ONE; }
self.x
}
pub fn v_n_prod(mut self, log_n: usize) -> F {
let mut output = self.x;
for _ in 0..(log_n - 2) {
self.x = self.x.square().double() - F::ONE; output *= self.x;
}
output
}
pub fn v_tilde_p<EF: ExtensionField<F>>(self, at: Point<EF>) -> EF {
(at - self).to_projective_line().unwrap()
}
pub fn s_p_at_p(self, log_n: usize) -> F {
-self.v_n_prod(log_n).mul_2exp_u64((2 * log_n - 1) as u64) * self.y
}
pub fn v_p<EF: ExtensionField<F>>(self, at: Point<EF>) -> (EF, EF) {
let diff = -at + self;
(EF::ONE - diff.x, -diff.y)
}
}
/// For every `p` in `points`, computes
/// `((at - p).x + 1) / ((at - p).y * p.s_p_at_p(log_n))`.
///
/// All denominators are inverted together with a single call to
/// `batch_multiplicative_inverse`. The output has one entry per input point,
/// in the same order.
///
/// NOTE(review): behaviour when a denominator is zero is whatever
/// `batch_multiplicative_inverse` does in that case — confirm against its docs.
pub fn compute_lagrange_den_batched<F: Field, EF: ExtensionField<F>>(
    points: &[Point<F>],
    at: Point<EF>,
    log_n: usize,
) -> Vec<EF> {
    let mut numerators: Vec<EF> = Vec::with_capacity(points.len());
    let mut denominators: Vec<EF> = Vec::with_capacity(points.len());

    for &point in points {
        let delta = at - point;
        numerators.push(delta.x + F::ONE);
        denominators.push(delta.y * point.s_p_at_p(log_n));
    }

    let inverses = batch_multiplicative_inverse(&denominators);

    numerators
        .into_iter()
        .zip(inverses)
        .map(|(numerator, inverse)| numerator * inverse)
        .collect()
}
impl<F: ComplexExtendable> Point<F> {
    /// Builds the point whose coordinates are the real and imaginary parts of
    /// `F::circle_two_adic_generator(log_n)`.
    pub fn generator(log_n: usize) -> Self {
        let root = F::circle_two_adic_generator(log_n);
        let (x, y) = (root.real(), root.imag());
        Self::new(x, y)
    }
}
impl<F: Field> Neg for Point<F> {
    type Output = Self;

    /// Negates the `y` coordinate and leaves `x` untouched: `(x, y) -> (x, -y)`.
    fn neg(self) -> Self::Output {
        Self {
            y: -self.y,
            ..self
        }
    }
}
impl<F: Field, EF: ExtensionField<F>> Add<Point<F>> for Point<EF> {
type Output = Self;
fn add(self, rhs: Point<F>) -> Self::Output {
Self::new(
self.x * rhs.x - self.y * rhs.y,
self.x * rhs.y + self.y * rhs.x,
)
}
}
impl<F: Field> AddAssign for Point<F> {
    /// In-place addition: replaces `self` with `self + rhs`.
    fn add_assign(&mut self, rhs: Self) {
        let sum = *self + rhs;
        *self = sum;
    }
}
impl<F: Field, EF: ExtensionField<F>> Sub<Point<F>> for Point<EF> {
type Output = Self;
fn sub(self, rhs: Point<F>) -> Self::Output {
Self::new(
self.x * rhs.x + self.y * rhs.y,
self.y * rhs.x - self.x * rhs.y,
)
}
}
impl<F: Field> Mul<usize> for Point<F> {
    type Output = Self;

    /// Adds `self` to itself `rhs` times using binary double-and-add.
    ///
    /// `rhs == 0` yields `Self::ZERO`.
    fn mul(self, rhs: usize) -> Self::Output {
        let mut acc = Self::ZERO;
        let mut addend = self;
        let mut remaining = rhs;

        while remaining != 0 {
            // Lowest bit set: the current power-of-two multiple contributes.
            if remaining & 1 == 1 {
                acc += addend;
            }
            remaining >>= 1;
            addend = addend.double();
        }

        acc
    }
}
#[cfg(test)]
mod tests {
    use p3_mersenne_31::Mersenne31;

    use super::*;

    type F = Mersenne31;
    type Pt = Point<F>;

    /// Exercises the point operators on `Pt::generator(3)` and cross-checks
    /// `v_n_prod` against the product of individual `v_n` values.
    #[test]
    fn test_arithmetic() {
        // The assertions below pin `g * 8 == ZERO` and `g * 7 == -g`.
        let g = Pt::generator(3);
        assert_eq!(g - g, Pt::ZERO);
        assert_eq!(g + g, g * 2);
        assert_eq!(g + g + g, g * 3);
        assert_eq!(g * 7, -g);
        assert_eq!(g * 8, Pt::ZERO);

        // `v_n_prod(log_n)` must equal the product of `v_n(i)` for `i` in `1..log_n`.
        let log_n = 10;
        let generator = Pt::generator(10);
        let expected: F = (1..log_n).map(|i| generator.v_n(i)).product();
        assert_eq!(generator.v_n_prod(log_n), expected);
    }
}