use alloc::vec;
use alloc::vec::Vec;
use core::iter::Sum;
use core::mem::MaybeUninit;
use core::ops::Mul;
use num_bigint::BigUint;
use p3_maybe_rayon::prelude::*;
use crate::field::Field;
use crate::{PackedValue, PrimeCharacteristicRing, PrimeField, PrimeField32};
/// Iterate over the cyclic subgroup generated by `generator`.
///
/// Yields the first `order` entries of `generator.powers()` (the power
/// sequence starting from the generator's zeroth power). The caller is
/// responsible for `order` matching the true subgroup order; this function
/// simply truncates the sequence.
pub fn cyclic_subgroup_known_order<F: Field>(
    generator: F,
    order: usize,
) -> impl Iterator<Item = F> + Clone {
    let all_powers = generator.powers();
    all_powers.take(order)
}
/// Iterate over a coset of the cyclic subgroup generated by `generator`.
///
/// Yields the first `order` entries of the shifted power sequence
/// `generator.shifted_powers(shift)`, i.e. each element of the subgroup
/// multiplied by `shift`. As with [`cyclic_subgroup_known_order`], `order`
/// is trusted, not checked.
pub fn cyclic_subgroup_coset_known_order<F: Field>(
    generator: F,
    shift: F,
    order: usize,
) -> impl Iterator<Item = F> + Clone {
    let coset_powers = generator.shifted_powers(shift);
    coset_powers.take(order)
}
/// Multiply every element of `slice` by the scalar `s`, in place, on the
/// current thread only (no parallelism).
///
/// The slice is split into a packed prefix (processed `F::Packing` lanes at
/// a time) and a short unpacked suffix, both scaled by `s`.
pub fn scale_slice_in_place_single_core<F: Field>(slice: &mut [F], s: F) {
    let (body, tail) = F::Packing::pack_slice_with_suffix_mut(slice);
    // Broadcast the scalar into a packed value once, outside the loop.
    let s_packed: F::Packing = s.into();
    for chunk in body.iter_mut() {
        *chunk *= s_packed;
    }
    for elem in tail.iter_mut() {
        *elem *= s;
    }
}
/// Multiply every element of `slice` by the scalar `s`, in place, scaling
/// the packed portion with a parallel iterator.
///
/// The short unpacked suffix is handled serially on the calling thread.
#[inline]
pub fn par_scale_slice_in_place<F: Field>(slice: &mut [F], s: F) {
    let (body, tail) = F::Packing::pack_slice_with_suffix_mut(slice);
    // Broadcast the scalar once before fanning out to worker threads.
    let s_packed: F::Packing = s.into();
    body.par_iter_mut().for_each(|chunk| *chunk *= s_packed);
    for elem in tail.iter_mut() {
        *elem *= s;
    }
}
/// Compute `slice[i] += other[i] * s` for every index `i`, in place.
///
/// Both slices are split into a packed prefix and an unpacked suffix; the
/// packed parts are combined lane-wise. In debug builds, panics if the two
/// slices differ in length.
pub fn add_scaled_slice_in_place<F: Field>(slice: &mut [F], other: &[F], s: F) {
    debug_assert_eq!(slice.len(), other.len(), "slices must have equal length");
    let (dst_body, dst_tail) = F::Packing::pack_slice_with_suffix_mut(slice);
    let (src_body, src_tail) = F::Packing::pack_slice_with_suffix(other);
    // Broadcast the scalar into a packed value once, outside the loop.
    let s_packed: F::Packing = s.into();
    for (dst, src) in dst_body.iter_mut().zip(src_body) {
        *dst += *src * s_packed;
    }
    for (dst, src) in dst_tail.iter_mut().zip(src_tail) {
        *dst += *src * s;
    }
}
/// Compute `slice[i] += other[i] * s` for every index `i`, processing the
/// packed portion in parallel.
///
/// The short unpacked suffix is handled serially on the calling thread. In
/// debug builds, panics if the two slices differ in length.
pub fn par_add_scaled_slice_in_place<F: Field>(slice: &mut [F], other: &[F], s: F) {
    debug_assert_eq!(slice.len(), other.len(), "slices must have equal length");
    let (dst_body, dst_tail) = F::Packing::pack_slice_with_suffix_mut(slice);
    let (src_body, src_tail) = F::Packing::pack_slice_with_suffix(other);
    // Broadcast the scalar once before fanning out to worker threads.
    let s_packed: F::Packing = s.into();
    dst_body
        .par_iter_mut()
        .zip(src_body.par_iter())
        .for_each(|(dst, src)| *dst += *src * s_packed);
    for (dst, src) in dst_tail.iter_mut().zip(src_tail) {
        *dst += *src * s;
    }
}
/// Embed `x` into the first coordinate of a length-`D` array, filling the
/// remaining coordinates with `R::ZERO` (i.e. returns `[x, 0, ..., 0]`).
///
/// Implemented with `MaybeUninit` so it can run in `const` contexts for
/// non-`Copy` `R`, where `[R::ZERO; D]` array-repeat syntax is unavailable.
///
/// NOTE(review): `arr[0]` is out of bounds when `D == 0`, so instantiating
/// this with `D = 0` fails at const evaluation — confirm callers always use
/// `D >= 1`.
#[inline]
#[must_use]
pub const fn field_to_array<R: PrimeCharacteristicRing, const D: usize>(x: R) -> [R; D] {
    // Start from a fully uninitialized array of D slots.
    let mut arr = [const { MaybeUninit::uninit() }; D];
    // Slot 0 takes ownership of `x`.
    arr[0] = MaybeUninit::new(x);
    // Zero-fill the rest. A `while` loop is used because `for` is not
    // permitted in `const fn`.
    let mut i = 1;
    while i < D {
        arr[i] = MaybeUninit::new(R::ZERO);
        i += 1;
    }
    // SAFETY: every slot of `arr` was initialized above (slot 0 with `x`,
    // slots 1..D with `R::ZERO`), and `[MaybeUninit<R>; D]` has the same
    // size and layout as `[R; D]`, so this copy reinterprets fully
    // initialized data.
    unsafe { core::mem::transmute_copy::<_, [R; D]>(&arr) }
}
/// Halve `x` in the prime field with (odd) modulus `P`, i.e. return
/// `x / 2 (mod P)` for canonical inputs `0 <= x < P`.
///
/// Even inputs are shifted right directly. Odd inputs use the identity
/// `x / 2 ≡ (x + P) / 2 = (x >> 1) + (P + 1) / 2 (mod P)`, which stays in
/// range without overflow.
#[inline]
#[must_use]
pub const fn halve_u32<const P: u32>(x: u32) -> u32 {
    if x & 1 == 0 {
        x >> 1
    } else {
        // (P + 1) >> 1 is the canonical representative of 1/2 mod P.
        (x >> 1) + ((P + 1) >> 1)
    }
}
/// Halve `x` in the prime field with (odd) modulus `P`, i.e. return
/// `x / 2 (mod P)` for canonical inputs `0 <= x < P`.
///
/// 64-bit counterpart of [`halve_u32`]: even inputs shift right; odd inputs
/// use `x / 2 ≡ (x >> 1) + (P + 1) / 2 (mod P)`.
#[inline]
#[must_use]
pub const fn halve_u64<const P: u64>(x: u64) -> u64 {
    if x & 1 == 0 {
        x >> 1
    } else {
        // (P + 1) >> 1 is the canonical representative of 1/2 mod P.
        (x >> 1) + ((P + 1) >> 1)
    }
}
/// Combine `vals` into one element of `TF`, treating them as little-endian
/// base-2^32 digits.
///
/// Thin wrapper around [`reduce_packed`] with `radix_bits = 32`.
#[must_use]
pub fn reduce_32<SF: PrimeField32, TF: PrimeField>(vals: &[SF]) -> TF {
    reduce_packed::<SF, TF>(vals, 32)
}
/// Combine `vals` into one element of `TF` as little-endian digits in base
/// `2^radix_bits`, with each digit shifted up by one.
///
/// Because every digit is `val + 1` (so in `1..=p`), a zero limb still
/// contributes to the result: `[a]` and `[a, 0]` reduce to different values.
/// The debug assertion requires `radix_bits < 64` and that the shifted digit
/// range fits below the base (`SF::ORDER_U32 < 2^radix_bits`).
#[must_use]
pub fn reduce_packed_shifted<SF: PrimeField32, TF: PrimeField>(vals: &[SF], radix_bits: u32) -> TF {
    debug_assert!((radix_bits < 64) && ((SF::ORDER_U32 as u64) < (1u64 << radix_bits)));
    let base = TF::from_int(1u64 << radix_bits);
    // Horner evaluation from the most significant limb down.
    let mut acc = TF::ZERO;
    for val in vals.iter().rev() {
        let digit = val.as_canonical_u32() as u64 + 1;
        acc = acc * base + TF::from_int(digit);
    }
    acc
}
/// Bit length of the largest canonical element of `F`, i.e. the number of
/// bits needed to represent `F::ORDER_U32 - 1`.
#[inline]
#[must_use]
pub const fn absorb_radix_bits<F: PrimeField32>() -> u32 {
    let max_elem = F::ORDER_U32 - 1;
    u32::BITS - max_elem.leading_zeros()
}
/// Combine `vals` into one element of `TF` as little-endian digits in base
/// `2^radix_bits`.
///
/// The debug assertion requires `radix_bits < 64` and that every canonical
/// digit fits in `radix_bits` bits (`absorb_radix_bits::<SF>() <= radix_bits`).
#[must_use]
pub fn reduce_packed<SF: PrimeField32, TF: PrimeField>(vals: &[SF], radix_bits: u32) -> TF {
    debug_assert!((absorb_radix_bits::<SF>() <= radix_bits) && (radix_bits < 64));
    let base = TF::from_int(1u64 << radix_bits);
    // Horner evaluation from the most significant limb down.
    let mut acc = TF::ZERO;
    for val in vals.iter().rev() {
        acc = acc * base + TF::from_int(val.as_canonical_u32());
    }
    acc
}
/// `floor(log2(F::ORDER_U32 - 1))`: the largest `radix_bits` such that every
/// `radix_bits`-bit value is a distinct canonical element of `F`, making the
/// limb embedding injective (see [`split_pf_to_packed_limbs`]).
#[inline]
#[must_use]
pub const fn injective_pack_bits<F: PrimeField32>() -> u32 {
    let max_elem = F::ORDER_U32 - 1;
    max_elem.ilog2()
}
/// Maximum number of base-`2^radix_bits` limbs with digits drawn from the
/// canonical range of `F` (max digit `F::ORDER_U32 - 1`) whose combined
/// value is guaranteed to stay below `PF::order()`.
#[must_use]
pub fn max_packed_injective_limbs<F: PrimeField32, PF: PrimeField>(radix_bits: u32) -> usize {
    let max_digit = F::ORDER_U32 - 1;
    max_packed_injective_limbs_with_max_digit::<PF>(radix_bits, max_digit)
}
/// Core computation for the `max_*_injective_limbs` helpers: the largest
/// number of base-`2^radix_bits` limbs, each digit at most `max_digit`, whose
/// maximum combined value `sum(max_digit * base^i)` stays strictly below
/// `PF::order()`.
fn max_packed_injective_limbs_with_max_digit<PF: PrimeField>(
    radix_bits: u32,
    max_digit: u32,
) -> usize {
    debug_assert!((0 < radix_bits) && (radix_bits < 64));
    let max_digit = BigUint::from(max_digit);
    let base = BigUint::from(1u32) << (radix_bits as usize);
    let pf_order = PF::order();
    let mut num_limbs = 0usize;
    // Running maximum representable value with `num_limbs` limbs, and the
    // place value `base^num_limbs` of the next limb.
    let mut max_val = BigUint::ZERO;
    let mut place = BigUint::from(1u32);
    loop {
        let candidate = &max_val + &max_digit * &place;
        if candidate >= pf_order {
            // Adding another limb could overflow the target field's order.
            return num_limbs;
        }
        max_val = candidate;
        place *= &base;
        num_limbs += 1;
    }
}
/// Like [`max_packed_injective_limbs`], but for shifted digits in
/// `1..=F::ORDER_U32` (max digit `F::ORDER_U32`), as produced by
/// [`reduce_packed_shifted`].
#[must_use]
pub fn max_shifted_packed_injective_limbs<F: PrimeField32, PF: PrimeField>(
    radix_bits: u32,
) -> usize {
    let max_digit = F::ORDER_U32;
    max_packed_injective_limbs_with_max_digit::<PF>(radix_bits, max_digit)
}
/// [`max_packed_injective_limbs`] evaluated at the minimal radix
/// `absorb_radix_bits::<F>()` (the bit length of `F`'s largest element).
#[must_use]
pub fn max_absorb_injective_limbs<F: PrimeField32, PF: PrimeField>() -> usize {
    let radix_bits = absorb_radix_bits::<F>();
    max_packed_injective_limbs::<F, PF>(radix_bits)
}
/// [`max_shifted_packed_injective_limbs`] evaluated at the minimal radix
/// `absorb_radix_bits::<F>()` (the bit length of `F`'s largest element).
#[must_use]
pub fn max_shifted_absorb_injective_limbs<F: PrimeField32, PF: PrimeField>() -> usize {
    let radix_bits = absorb_radix_bits::<F>();
    max_shifted_packed_injective_limbs::<F, PF>(radix_bits)
}
/// Whether `num_limbs` limbs of `radix_bits` bits each can represent every
/// element of `SF`, i.e. whether `2^(num_limbs * radix_bits) >= SF::order()`.
///
/// Returns `false` if `num_limbs * radix_bits` overflows `usize`.
#[must_use]
pub fn pf_packed_limbs_cover_order<SF: PrimeField>(num_limbs: usize, radix_bits: u32) -> bool {
    match num_limbs.checked_mul(radix_bits as usize) {
        Some(total_bits) => (BigUint::from(1u32) << total_bits) >= SF::order(),
        None => false,
    }
}
/// Split `val` into `num_limbs` little-endian limbs of `radix_bits` bits
/// each, embedding every limb into `TF`.
///
/// Requires (debug-checked) `0 < radix_bits < 64` and
/// `radix_bits <= injective_pack_bits::<TF>()`, so every limb value is a
/// distinct element of `TF`. If `num_limbs` limbs are not enough to cover
/// `val`, the high bits are silently dropped; extra limbs are zero.
#[must_use]
pub fn split_pf_to_packed_limbs<SF: PrimeField, TF: PrimeField32>(
    val: SF,
    num_limbs: usize,
    radix_bits: u32,
) -> Vec<TF> {
    debug_assert!((0 < radix_bits) && (radix_bits < 64));
    debug_assert!(
        radix_bits <= injective_pack_bits::<TF>(),
        "radix_bits must be ≤ injective_pack_bits::<TF>() for injective limb embedding"
    );
    let mask: u32 = (1u32 << radix_bits) - 1;
    let mut remaining = val.as_canonical_biguint();
    (0..num_limbs)
        .map(|_| {
            // Lowest 32-bit digit (0 if `remaining` is zero), masked down to
            // `radix_bits` bits.
            let low_digit = remaining.iter_u32_digits().next().unwrap_or(0);
            let limb = low_digit & mask;
            remaining >>= radix_bits;
            TF::from_int(limb)
        })
        .collect()
}
/// Number of base-`TF::ORDER_U32` limbs used when squeezing an element of
/// `PF`.
///
/// The loop finds the largest `c` with `p^c < PF::order()` (where `p` is
/// `TF`'s order), then subtracts one.
///
/// NOTE(review): the final `saturating_sub(1)` drops one limb below the
/// strict `p^c < n` bound — presumably an intentional safety margin; confirm
/// against the squeezing callers.
#[must_use]
pub fn squeeze_field_order_num_limbs<PF: PrimeField, TF: PrimeField32>() -> usize {
    let p = BigUint::from(TF::ORDER_U32);
    let n = PF::order();
    let mut count = 0usize;
    let mut power = BigUint::from(1u32);
    loop {
        let next_power = &power * &p;
        if next_power >= n {
            break;
        }
        power = next_power;
        count += 1;
    }
    count.saturating_sub(1)
}
/// Split `val` into `num_limbs` little-endian base-`TF::ORDER_U32` digits,
/// each returned as an element of `TF`.
///
/// If `num_limbs` digits are not enough to cover `val`, the remaining high
/// part is silently dropped; extra limbs are zero.
#[must_use]
pub fn split_pf_to_field_order_limbs<SF: PrimeField, TF: PrimeField32>(
    val: SF,
    num_limbs: usize,
) -> Vec<TF> {
    let modulus = TF::ORDER_U32;
    let mut remaining = val.as_canonical_biguint();
    (0..num_limbs)
        .map(|_| {
            // `remaining % modulus` fits in a single u32 digit; an empty
            // digit vector means the remainder is zero.
            let digit = (&remaining % modulus)
                .to_u32_digits()
                .first()
                .copied()
                .unwrap_or(0);
            remaining /= modulus;
            TF::from_int(digit)
        })
        .collect()
}
/// Split `val` into exactly `n` elements of `TF`, one per little-endian
/// 64-bit digit of its canonical representation.
///
/// Digits beyond the first `n` are dropped; if `val` has fewer than `n`
/// digits, the result is padded with `TF::ZERO`. Each u64 digit is mapped
/// into `TF` via `TF::from_u64`.
#[must_use]
pub fn split_32<SF: PrimeField, TF: PrimeField32>(val: SF, n: usize) -> Vec<TF> {
    let digits = val.as_canonical_biguint().to_u64_digits();
    let mut out: Vec<TF> = Vec::with_capacity(n);
    out.extend(digits.iter().take(n).map(|&d| TF::from_u64(d)));
    // Zero-pad up to the requested length.
    out.resize_with(n, || TF::ZERO);
    out
}
/// Dot product of two iterators: the sum of pairwise products.
///
/// Iteration stops at the shorter of the two inputs (standard `zip`
/// semantics); two empty iterators yield the additive identity of `S`.
#[must_use]
pub fn dot_product<S, LI, RI>(li: LI, ri: RI) -> S
where
    LI: Iterator,
    RI: Iterator,
    LI::Item: Mul<RI::Item>,
    S: Sum<<LI::Item as Mul<RI::Item>>::Output>,
{
    let products = li.zip(ri).map(|(lhs, rhs)| lhs * rhs);
    S::sum(products)
}