use alloc::vec;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;
use core::iter::{Product, Sum, zip};
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use core::{array, slice};
use num_bigint::BigUint;
use p3_maybe_rayon::prelude::*;
use p3_util::{flatten_to_base, iter_array_chunks_padded};
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::exponentiation::bits_u64;
use crate::integers::{QuotientMap, from_integer_types};
use crate::packed::PackedField;
use crate::{Dup, Packable, PackedFieldExtension, PackedValue};
/// A ring with an associated prime subfield, together with the constants and
/// arithmetic helpers built on top of `+`, `-`, `*` and `dup`.
///
/// Implementors must supply `PrimeSubfield`, the four constants and
/// `from_prime_subfield`. Every other method visible in this trait has a default body.
/// (The items generated by `from_integer_types!` are not visible here.)
pub trait PrimeCharacteristicRing:
Sized
+ Default
+ Dup
+ Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Neg<Output = Self>
+ Mul<Output = Self>
+ MulAssign
+ Sum
+ Product
+ Debug
{
/// The prime field whose elements can be embedded via `from_prime_subfield`.
type PrimeSubfield: PrimeField;
/// The additive identity.
const ZERO: Self;
/// The multiplicative identity.
const ONE: Self;
/// The element `ONE + ONE`.
// NOTE(review): the relation to ONE is implied by the name and by `mul_2exp_u64`; confirm in implementors.
const TWO: Self;
/// The element `-ONE`.
// NOTE(review): implied by the name only; confirm in implementors.
const NEG_ONE: Self;
/// Embeds an element of the prime subfield into this ring.
#[must_use]
fn from_prime_subfield(f: Self::PrimeSubfield) -> Self;
/// Maps `true` to `ONE` and `false` to `ZERO`.
#[must_use]
#[inline(always)]
fn from_bool(b: bool) -> Self {
if b { Self::ONE } else { Self::ZERO }
}
// Macro-generated items for each listed integer type. Presumably per-type
// `from_*` conversions; see `crate::integers::from_integer_types` for the expansion.
from_integer_types!(
u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
);
/// Returns `self + self`.
#[must_use]
#[inline(always)]
fn double(&self) -> Self {
self.dup() + self.dup()
}
/// Returns `self` multiplied by the embedding of `PrimeSubfield::ONE.halve()`.
///
/// The default body calls `halve` on the prime subfield, so a type whose
/// `PrimeSubfield` is itself must override this method; otherwise the call
/// recurses into this same default.
#[must_use]
#[inline]
fn halve(&self) -> Self {
let half = Self::from_prime_subfield(Self::PrimeSubfield::ONE.halve());
self.dup() * half
}
/// Returns `self * self`.
#[must_use]
#[inline(always)]
fn square(&self) -> Self {
self.dup() * self.dup()
}
/// Returns `self * self * self`.
#[must_use]
#[inline(always)]
fn cube(&self) -> Self {
self.square() * self.dup()
}
/// Returns `x + y - 2xy` where `x = self`.
///
/// When both inputs are `ZERO` or `ONE` this is their boolean XOR.
#[must_use]
#[inline(always)]
fn xor(&self, y: &Self) -> Self {
self.dup() + y.dup() - self.dup() * y.dup().double()
}
/// Returns `self.xor(y).xor(z)`, the three-input version of [`Self::xor`].
#[must_use]
#[inline(always)]
fn xor3(&self, y: &Self, z: &Self) -> Self {
self.xor(y).xor(z)
}
/// Returns `(1 - x) * y` where `x = self`.
///
/// When both inputs are `ZERO` or `ONE` this is `(!x) & y`.
#[must_use]
#[inline(always)]
fn andn(&self, y: &Self) -> Self {
(Self::ONE - self.dup()) * y.dup()
}
/// Returns `x * (x - 1)`, which is `ZERO` when `x` is `ZERO` or `ONE`.
#[must_use]
#[inline(always)]
fn bool_check(&self) -> Self {
self.dup() * (self.dup() - Self::ONE)
}
/// Returns `self` raised to `power` by square-and-multiply.
///
/// Bits of `power` are scanned from least to most significant.
#[must_use]
#[inline]
fn exp_u64(&self, power: u64) -> Self {
// `current` holds self^(2^j) at the start of iteration j.
let mut current = self.dup();
let mut product = Self::ONE;
// NOTE(review): `bits_u64(power)` is presumably the bit length of `power`, so
// higher (zero) bits are skipped; confirm in `crate::exponentiation`.
for j in 0..bits_u64(power) {
if (power >> j) & 1 != 0 {
product *= current.dup();
}
current = current.square();
}
product
}
/// Returns `self` raised to the compile-time constant `POWER`.
///
/// Exponents `0..=7` use hand-written multiplication chains; any other value
/// falls back to [`Self::exp_u64`].
#[must_use]
#[inline(always)]
fn exp_const_u64<const POWER: u64>(&self) -> Self {
match POWER {
0 => Self::ONE,
1 => self.dup(),
2 => self.square(),
3 => self.cube(),
4 => self.square().square(),
5 => self.square().square() * self.dup(),
6 => self.square().cube(),
7 => {
// x^7 = x^3 * x^4, sharing x^2 between both factors.
let x2 = self.square();
let x3 = x2.dup() * self.dup();
let x4 = x2.square();
x3 * x4
}
_ => self.exp_u64(POWER),
}
}
/// Returns `self^(2^power_log)` by squaring `power_log` times.
#[must_use]
#[inline]
fn exp_power_of_2(&self, power_log: usize) -> Self {
let mut res = self.dup();
for _ in 0..power_log {
res = res.square();
}
res
}
/// Returns `self * TWO^exp`.
#[must_use]
#[inline]
fn mul_2exp_u64(&self, exp: u64) -> Self {
self.dup() * Self::TWO.exp_u64(exp)
}
/// Returns `self` multiplied by the embedding of `PrimeSubfield::ONE.halve()^exp`.
///
/// The default body relies on the prime subfield's `halve` and `exp_u64`.
#[must_use]
#[inline]
fn div_2exp_u64(&self, exp: u64) -> Self {
self.dup() * Self::from_prime_subfield(Self::PrimeSubfield::ONE.halve().exp_u64(exp))
}
/// Returns the iterator `ONE, self, self^2, ...`.
#[must_use]
#[inline]
fn powers(&self) -> Powers<Self> {
self.shifted_powers(Self::ONE)
}
/// Returns the iterator `start, start * self, start * self^2, ...`.
#[must_use]
#[inline]
fn shifted_powers(&self, start: Self) -> Powers<Self> {
Powers {
base: self.dup(),
current: start,
}
}
/// Returns the sum of the element-wise products of `u` and `v`.
#[must_use]
#[inline]
fn dot_product<const N: usize>(u: &[Self; N], v: &[Self; N]) -> Self {
u.iter().zip(v).map(|(x, y)| x.dup() * y.dup()).sum()
}
/// Returns the sum of the `N` elements of `input`.
///
/// Up to 8 elements are added with fixed expressions. Longer inputs are
/// summed in blocks of 8, followed by the `N % 8` leftover elements.
///
/// # Panics
///
/// Panics if `input.len() != N`.
#[must_use]
#[inline]
fn sum_array<const N: usize>(input: &[Self]) -> Self {
assert_eq!(N, input.len());
match N {
0 => Self::ZERO,
1 => input[0].dup(),
2 => input[0].dup() + input[1].dup(),
3 => input[0].dup() + input[1].dup() + input[2].dup(),
4 => (input[0].dup() + input[1].dup()) + (input[2].dup() + input[3].dup()),
5 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<1>(&input[4..]),
6 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<2>(&input[4..]),
7 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<3>(&input[4..]),
8 => Self::sum_array::<4>(&input[..4]) + Self::sum_array::<4>(&input[4..]),
_ => {
// N > 8 here, so the first block of 8 seeds the accumulator directly.
let mut acc = Self::sum_array::<8>(&input[..8]);
// `i` is the exclusive end of each further full block of 8.
for i in (16..=N).step_by(8) {
acc += Self::sum_array::<8>(&input[(i - 8)..i]);
}
// Add the tail that starts at the last multiple of 8. The const argument
// must be a literal, hence one arm per possible remainder.
match N & 7 {
0 => acc,
1 => acc + Self::sum_array::<1>(&input[(8 * (N / 8))..]),
2 => acc + Self::sum_array::<2>(&input[(8 * (N / 8))..]),
3 => acc + Self::sum_array::<3>(&input[(8 * (N / 8))..]),
4 => acc + Self::sum_array::<4>(&input[(8 * (N / 8))..]),
5 => acc + Self::sum_array::<5>(&input[(8 * (N / 8))..]),
6 => acc + Self::sum_array::<6>(&input[(8 * (N / 8))..]),
7 => acc + Self::sum_array::<7>(&input[(8 * (N / 8))..]),
_ => unreachable!(),
}
}
}
}
/// Returns a vector of `len` copies of `ZERO`.
#[must_use]
#[inline]
fn zero_vec(len: usize) -> Vec<Self> {
vec![Self::ZERO; len]
}
}
/// A type that can be read as, and rebuilt from, a list of `DIMENSION`
/// coefficients of type `F`.
pub trait BasedVectorSpace<F: PrimeCharacteristicRing>: Sized {
    /// Number of coefficients that describe one element.
    const DIMENSION: usize;

    /// Borrows the coefficients of `self` as a slice.
    #[must_use]
    fn as_basis_coefficients_slice(&self) -> &[F];

    /// Builds an element from a slice of coefficients.
    ///
    /// Each coefficient is cloned and handed to
    /// [`Self::from_basis_coefficients_iter`], whose `Option` is returned unchanged.
    #[must_use]
    #[inline]
    fn from_basis_coefficients_slice(slice: &[F]) -> Option<Self> {
        let coefficients = slice.iter().cloned();
        Self::from_basis_coefficients_iter(coefficients)
    }

    /// Builds an element by asking `f` for the coefficient at each index.
    #[must_use]
    fn from_basis_coefficients_fn<Fn: FnMut(usize) -> F>(f: Fn) -> Self;

    /// Builds an element from an iterator of coefficients, or returns `None`
    /// when the implementor rejects the input.
    #[must_use]
    fn from_basis_coefficients_iter<I: ExactSizeIterator<Item = F>>(iter: I) -> Option<Self>;

    /// Returns the element whose `i`-th coefficient is `ONE` and all others
    /// `ZERO`, or `None` when `i >= DIMENSION`.
    #[must_use]
    #[inline]
    fn ith_basis_element(i: usize) -> Option<Self> {
        if i < Self::DIMENSION {
            Some(Self::from_basis_coefficients_fn(|j| F::from_bool(i == j)))
        } else {
            None
        }
    }

    /// Concatenates the coefficients of every element of `vec`, in order.
    #[must_use]
    #[inline]
    fn flatten_to_base(vec: Vec<Self>) -> Vec<F> {
        let mut flat = Vec::new();
        for element in vec {
            flat.extend_from_slice(element.as_basis_coefficients_slice());
        }
        flat
    }

    /// Rebuilds elements from a flat coefficient vector, `DIMENSION`
    /// coefficients per element.
    ///
    /// # Panics
    ///
    /// Panics if `vec.len()` is not a multiple of `DIMENSION`, or if
    /// [`Self::from_basis_coefficients_slice`] rejects a chunk.
    #[must_use]
    #[inline]
    fn reconstitute_from_base(vec: Vec<F>) -> Vec<Self>
    where
        F: Sync,
        Self: Send,
    {
        assert_eq!(vec.len() % Self::DIMENSION, 0);

        vec.par_chunks_exact(Self::DIMENSION)
            .map(|coefficients| {
                let rebuilt = Self::from_basis_coefficients_slice(coefficients);
                rebuilt.expect("Chunk length not equal to dimension")
            })
            .collect()
    }
}
/// Every ring is a one-dimensional vector space over itself.
impl<F: PrimeCharacteristicRing> BasedVectorSpace<F> for F {
    const DIMENSION: usize = 1;

    /// Views `self` as a one-element slice.
    #[inline]
    fn as_basis_coefficients_slice(&self) -> &[F] {
        slice::from_ref(self)
    }

    /// Returns the coefficient produced for index `0`.
    #[inline]
    fn from_basis_coefficients_fn<Fn: FnMut(usize) -> F>(mut f: Fn) -> Self {
        f(0)
    }

    /// Returns the single item of `iter`, or `None` unless `iter` reports
    /// exactly one item.
    #[inline]
    fn from_basis_coefficients_iter<I: ExactSizeIterator<Item = F>>(mut iter: I) -> Option<Self> {
        if iter.len() != 1 {
            return None;
        }
        // The iterator reported exactly one item, so `next` yields it.
        Some(iter.next().unwrap())
    }

    /// Returns `vec` unchanged: elements already are their own coefficients.
    #[inline]
    fn flatten_to_base(vec: Vec<Self>) -> Vec<F> {
        vec
    }

    /// Returns `vec` unchanged: coefficients already are elements.
    #[inline]
    fn reconstitute_from_base(vec: Vec<F>) -> Vec<Self> {
        vec
    }
}
/// A ring exposing the power map `x -> x^N` under a dedicated name.
///
/// NOTE(review): the name suggests implementors promise this map is injective;
/// the trait itself only supplies the power computation. Confirm with implementors.
pub trait InjectiveMonomial<const N: u64>: PrimeCharacteristicRing {
/// Returns `self^N`, computed with `exp_const_u64::<N>`.
#[must_use]
#[inline]
fn injective_exp_n(&self) -> Self {
self.exp_const_u64::<N>()
}
}
/// Extension of [`InjectiveMonomial`] that adds a second map, `injective_exp_root_n`.
pub trait PermutationMonomial<const N: u64>: InjectiveMonomial<N> {
/// No default is provided; implementors define the computation.
///
/// NOTE(review): presumably the inverse of `injective_exp_n`, i.e. an `N`-th
/// root. Confirm with implementors.
#[must_use]
fn injective_exp_root_n(&self) -> Self;
}
pub trait Algebra<F>:
PrimeCharacteristicRing
+ From<F>
+ Add<F, Output = Self>
+ AddAssign<F>
+ Sub<F, Output = Self>
+ SubAssign<F>
+ Mul<F, Output = Self>
+ MulAssign<F>
{
#[must_use]
#[inline]
fn mixed_dot_product<const N: usize>(a: &[Self; N], f: &[F; N]) -> Self
where
F: Dup,
{
let products: [Self; N] = core::array::from_fn(|i| a[i].dup() * f[i].dup());
Self::sum_array::<N>(&products)
}
const BATCHED_LC_CHUNK: usize = 8;
#[must_use]
#[inline]
fn batched_linear_combination(values: &[Self], coeffs: &[F]) -> Self
where
F: Dup,
{
const {
assert!(
matches!(Self::BATCHED_LC_CHUNK, 1 | 2 | 4 | 8 | 16 | 32 | 64),
"BATCHED_LC_CHUNK must be one of 1, 2, 4, 8, 16, 32, or 64"
);
}
match Self::BATCHED_LC_CHUNK {
1 => chunked_linear_combination::<1, Self, F>(values, coeffs),
2 => chunked_linear_combination::<2, Self, F>(values, coeffs),
4 => chunked_linear_combination::<4, Self, F>(values, coeffs),
8 => chunked_linear_combination::<8, Self, F>(values, coeffs),
16 => chunked_linear_combination::<16, Self, F>(values, coeffs),
32 => chunked_linear_combination::<32, Self, F>(values, coeffs),
64 => chunked_linear_combination::<64, Self, F>(values, coeffs),
_ => unreachable!(),
}
}
}
#[must_use]
#[inline]
pub fn chunked_linear_combination<const CHUNK: usize, A: Algebra<F> + Dup, F: Dup>(
values: &[A],
coeffs: &[F],
) -> A {
const { assert!(CHUNK != 0, "chunked_linear_combination requires CHUNK > 0") }
assert_eq!(values.len(), coeffs.len());
let (val_chunks, val_rem) = values.as_chunks::<CHUNK>();
let (coeff_chunks, coeff_rem) = coeffs.as_chunks::<CHUNK>();
debug_assert_eq!(val_chunks.len(), coeff_chunks.len());
let mut acc = A::ZERO;
for (vc, cc) in zip(val_chunks, coeff_chunks) {
acc += A::mixed_dot_product::<CHUNK>(vc, cc);
}
debug_assert_eq!(val_rem.len(), coeff_rem.len());
for (v, c) in zip(val_rem, coeff_rem) {
acc += v.dup() * c.dup();
}
acc
}
/// Every ring is an algebra over itself, using the default method bodies and
/// the default `BATCHED_LC_CHUNK`.
impl<R: PrimeCharacteristicRing> Algebra<R> for R {}
/// A type that can be turned into a stream of raw bytes, and from there into
/// little-endian `u32`/`u64` words.
pub trait RawDataSerializable: Sized {
    /// Number of bytes read from each element by the parallel stream helpers.
    // NOTE(review): presumably also the exact length of `into_bytes`; confirm with implementors.
    const NUM_BYTES: usize;

    /// Consumes `self` and yields its bytes.
    #[must_use]
    fn into_bytes(self) -> impl IntoIterator<Item = u8>;

    /// Concatenates the bytes of every element of `input`, in order.
    #[must_use]
    fn into_byte_stream(input: impl IntoIterator<Item = Self>) -> impl IntoIterator<Item = u8> {
        input.into_iter().flat_map(Self::into_bytes)
    }

    /// Groups the byte stream of `input` into little-endian `u32` words.
    ///
    /// `0` is passed to `iter_array_chunks_padded` as the padding byte.
    #[must_use]
    fn into_u32_stream(input: impl IntoIterator<Item = Self>) -> impl IntoIterator<Item = u32> {
        let byte_stream = Self::into_byte_stream(input);
        let words = iter_array_chunks_padded(byte_stream, 0);
        words.map(u32::from_le_bytes)
    }

    /// Groups the byte stream of `input` into little-endian `u64` words.
    ///
    /// `0` is passed to `iter_array_chunks_padded` as the padding byte.
    #[must_use]
    fn into_u64_stream(input: impl IntoIterator<Item = Self>) -> impl IntoIterator<Item = u64> {
        let byte_stream = Self::into_byte_stream(input);
        let words = iter_array_chunks_padded(byte_stream, 0);
        words.map(u64::from_le_bytes)
    }

    /// Turns a stream of `N`-element arrays into a stream of `N`-byte arrays.
    ///
    /// For each input array, output row `i` (for `i` in `0..NUM_BYTES`) holds
    /// byte `i` of every element, so lane `j` of the output is the byte
    /// stream of lane `j` of the input.
    ///
    /// # Panics
    ///
    /// Panics if an element yields fewer than `NUM_BYTES` bytes.
    #[must_use]
    fn into_parallel_byte_streams<const N: usize>(
        input: impl IntoIterator<Item = [Self; N]>,
    ) -> impl IntoIterator<Item = [u8; N]> {
        input.into_iter().flat_map(|lanes| {
            // Materialise each lane's bytes so they can be indexed by position.
            let per_lane: [Vec<u8>; N] =
                lanes.map(|elem| elem.into_bytes().into_iter().collect());
            (0..Self::NUM_BYTES)
                .map(move |byte_idx| array::from_fn(|lane| per_lane[lane][byte_idx]))
        })
    }

    /// Like [`Self::into_parallel_byte_streams`], but groups each lane's bytes
    /// into little-endian `u32` words.
    ///
    /// `[0; N]` is passed to `iter_array_chunks_padded` as the padding row.
    #[must_use]
    fn into_parallel_u32_streams<const N: usize>(
        input: impl IntoIterator<Item = [Self; N]>,
    ) -> impl IntoIterator<Item = [u32; N]> {
        let byte_rows = Self::into_parallel_byte_streams(input);
        iter_array_chunks_padded(byte_rows, [0; N]).map(|rows: [[u8; N]; 4]| {
            array::from_fn(|lane| {
                // Gather this lane's byte from each of the four rows.
                let le_bytes: [u8; 4] = array::from_fn(|row| rows[row][lane]);
                u32::from_le_bytes(le_bytes)
            })
        })
    }

    /// Like [`Self::into_parallel_byte_streams`], but groups each lane's bytes
    /// into little-endian `u64` words.
    ///
    /// `[0; N]` is passed to `iter_array_chunks_padded` as the padding row.
    #[must_use]
    fn into_parallel_u64_streams<const N: usize>(
        input: impl IntoIterator<Item = [Self; N]>,
    ) -> impl IntoIterator<Item = [u64; N]> {
        let byte_rows = Self::into_parallel_byte_streams(input);
        iter_array_chunks_padded(byte_rows, [0; N]).map(|rows: [[u8; N]; 8]| {
            array::from_fn(|lane| {
                // Gather this lane's byte from each of the eight rows.
                let le_bytes: [u8; 8] = array::from_fn(|row| rows[row][lane]);
                u64::from_le_bytes(le_bytes)
            })
        })
    }
}
pub trait Field:
Algebra<Self>
+ RawDataSerializable
+ Packable
+ 'static
+ Copy
+ Div<Self, Output = Self>
+ DivAssign
+ Add<Self::Packing, Output = Self::Packing>
+ Sub<Self::Packing, Output = Self::Packing>
+ Mul<Self::Packing, Output = Self::Packing>
+ Eq
+ Hash
+ Send
+ Sync
+ Display
+ Serialize
+ DeserializeOwned
{
type Packing: PackedField<Scalar = Self>;
const GENERATOR: Self;
#[must_use]
#[inline]
fn is_zero(&self) -> bool {
*self == Self::ZERO
}
#[must_use]
#[inline]
fn is_one(&self) -> bool {
*self == Self::ONE
}
#[must_use]
fn try_inverse(&self) -> Option<Self>;
#[must_use]
fn inverse(&self) -> Self {
self.try_inverse().expect("Tried to invert zero")
}
#[inline]
fn add_slices(slice_1: &mut [Self], slice_2: &[Self]) {
let (shorts_1, suffix_1) = Self::Packing::pack_slice_with_suffix_mut(slice_1);
let (shorts_2, suffix_2) = Self::Packing::pack_slice_with_suffix(slice_2);
debug_assert_eq!(shorts_1.len(), shorts_2.len());
debug_assert_eq!(suffix_1.len(), suffix_2.len());
for (x_1, &x_2) in shorts_1.iter_mut().zip(shorts_2) {
*x_1 += x_2;
}
for (x_1, &x_2) in suffix_1.iter_mut().zip(suffix_2) {
*x_1 += x_2;
}
}
#[must_use]
fn order() -> BigUint;
#[must_use]
#[inline]
fn bits() -> usize {
Self::order().bits() as usize
}
}
/// A totally ordered field that implements `QuotientMap` for every primitive
/// integer type listed in the supertraits.
pub trait PrimeField:
Field
+ Ord
+ QuotientMap<u8>
+ QuotientMap<u16>
+ QuotientMap<u32>
+ QuotientMap<u64>
+ QuotientMap<u128>
+ QuotientMap<usize>
+ QuotientMap<i8>
+ QuotientMap<i16>
+ QuotientMap<i32>
+ QuotientMap<i64>
+ QuotientMap<i128>
+ QuotientMap<isize>
{
/// Returns this element as a `BigUint`.
///
/// NOTE(review): presumably the canonical representative in `[0, order)`; confirm with implementors.
#[must_use]
fn as_canonical_biguint(&self) -> BigUint;
}
/// A prime field whose elements can be returned as a `u64`.
pub trait PrimeField64: PrimeField {
// NOTE(review): presumably the field order as a `u64`; confirm with implementors.
const ORDER_U64: u64;
/// Returns this element as a `u64`.
///
/// NOTE(review): presumably the canonical representative in `[0, ORDER_U64)`; confirm with implementors.
#[must_use]
fn as_canonical_u64(&self) -> u64;
/// Returns a `u64` for this element; the default forwards to
/// [`Self::as_canonical_u64`].
#[must_use]
#[inline(always)]
fn to_unique_u64(&self) -> u64 {
self.as_canonical_u64()
}
}
/// A prime field whose elements can be returned as a `u32`.
pub trait PrimeField32: PrimeField64 {
// NOTE(review): presumably the field order as a `u32`; confirm with implementors.
const ORDER_U32: u32;
/// Returns this element as a `u32`.
///
/// NOTE(review): presumably the canonical representative in `[0, ORDER_U32)`; confirm with implementors.
#[must_use]
fn as_canonical_u32(&self) -> u32;
/// Returns a `u32` for this element; the default forwards to
/// [`Self::as_canonical_u32`].
#[must_use]
#[inline(always)]
fn to_unique_u32(&self) -> u32 {
self.as_canonical_u32()
}
}
/// A field that is also an algebra and a based vector space over `Base`.
pub trait ExtensionField<Base: Field>: Field + Algebra<Base> + BasedVectorSpace<Base> {
/// Packed counterpart of this extension over `Base`.
type ExtensionPacking: PackedFieldExtension<Base, Self> + 'static + Copy + Send + Sync;
/// Reports whether `self` lies in the base field.
#[must_use]
fn is_in_basefield(&self) -> bool;
/// Returns `self` as a `Base` element, or `None` when that is not possible.
#[must_use]
fn as_base(&self) -> Option<Base>;
}
/// Every field is an extension of itself, reusing its own packing.
impl<F: Field> ExtensionField<F> for F {
type ExtensionPacking = F::Packing;
/// Always `true`: the base field is the field itself.
#[inline]
fn is_in_basefield(&self) -> bool {
true
}
/// Always `Some`, holding a copy of `self`.
#[inline]
fn as_base(&self) -> Option<F> {
Some(*self)
}
}
/// A field that exposes a two-adicity constant and a family of generators
/// indexed by a bit count.
pub trait TwoAdicField: Field {
// NOTE(review): presumably the largest `n` such that `2^n` divides the order
// of the multiplicative group; confirm with implementors.
const TWO_ADICITY: usize;
/// Returns the generator associated with `bits`.
///
/// NOTE(review): presumably an element of multiplicative order `2^bits`,
/// valid for `bits <= TWO_ADICITY`; confirm with implementors.
#[must_use]
fn two_adic_generator(bits: usize) -> Self;
}
/// Unbounded iterator over `current, current * base, current * base^2, ...`.
#[derive(Clone, Debug)]
pub struct Powers<R: PrimeCharacteristicRing> {
/// The factor applied after each yielded item.
pub base: R,
/// The next item to be yielded.
pub current: R,
}
impl<R: PrimeCharacteristicRing> Iterator for Powers<R> {
type Item = R;
/// Yields `current`, then multiplies `current` by `base`.
///
/// Never returns `None`.
fn next(&mut self) -> Option<R> {
let result = self.current.dup();
// Advance after taking the value so the first item is the starting value itself.
self.current *= self.base.dup();
Some(result)
}
}
impl<R: PrimeCharacteristicRing> Powers<R> {
#[inline]
#[must_use]
pub const fn take(self, n: usize) -> BoundedPowers<R> {
BoundedPowers { iter: self, n }
}
#[inline]
pub fn fill(self, slice: &mut [R]) {
slice
.iter_mut()
.zip(self)
.for_each(|(out, next)| *out = next);
}
#[inline]
#[must_use]
pub fn collect_n(self, n: usize) -> Vec<R> {
self.take(n).collect()
}
}
impl<F: Field> BoundedPowers<F> {
/// Collects the remaining `n` powers into a vector.
///
/// Fewer than 16 powers are produced one at a time through the iterator.
/// Larger requests fill a buffer of packed values in parallel chunks, then
/// flatten it to scalars and cut it to `n` entries.
///
/// As an inherent method this takes precedence over `Iterator::collect` in
/// method-call syntax when `F: Field`.
#[must_use]
pub fn collect(self) -> Vec<F> {
let num_powers = self.n;
// Small requests: fall back to the serial iterator.
if num_powers < 16 {
return self.take(num_powers).collect();
}
// Buffer of packed values holding at least `num_powers` scalars.
let width = F::Packing::WIDTH;
let num_packed = num_powers.div_ceil(width);
let mut points_packed = F::Packing::zero_vec(num_packed);
// Split the packed buffer into roughly one chunk per thread.
let num_threads = current_num_threads().max(1);
let chunk_size = num_packed.div_ceil(num_threads);
let base = self.iter.base;
// Each chunk covers `chunk_size * width` scalars, so consecutive chunk
// starting values differ by this factor.
let chunk_base = base.exp_u64((chunk_size * width) as u64);
let shift = self.iter.current;
points_packed
.par_chunks_mut(chunk_size)
.enumerate()
.for_each(|(chunk_idx, chunk_slice)| {
// First power belonging to this chunk.
let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64);
// NOTE(review): `packed_shifted_powers` presumably yields packed groups of
// consecutive powers of `base` starting at `chunk_start`; confirm in `PackedField`.
F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice);
});
// SAFETY: NOTE(review): relies on `F::Packing` meeting the layout
// requirements of `p3_util::flatten_to_base` for base type `F`; confirm
// against that function's safety contract.
let mut points = unsafe { flatten_to_base(points_packed) };
// Drop the padding scalars from the last packed value.
points.truncate(num_powers);
points
}
}
/// A [`Powers`] iterator limited to a fixed number of remaining items.
#[derive(Clone, Debug)]
pub struct BoundedPowers<R: PrimeCharacteristicRing> {
// The underlying unbounded iterator.
iter: Powers<R>,
// Number of items still to be yielded.
n: usize,
}
impl<R: PrimeCharacteristicRing> Iterator for BoundedPowers<R> {
    type Item = R;

    /// Yields the next power while the remaining count is non-zero, then
    /// returns `None`.
    fn next(&mut self) -> Option<R> {
        (self.n != 0).then(|| {
            self.n -= 1;
            // The wrapped `Powers` iterator always returns `Some`.
            self.iter.next().unwrap()
        })
    }

    /// Reports the exact number of items left.
    ///
    /// The wrapped iterator never ends, so exactly `n` items remain. Giving
    /// exact bounds lets `Iterator::collect` and `Vec::extend` reserve the
    /// right capacity up front instead of growing repeatedly.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.n, Some(self.n))
    }
}