use core::iter::{Product, Sum};
use core::mem::MaybeUninit;
use core::ops::{Div, DivAssign};
use core::{array, slice};
use crate::field::Field;
use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
/// Marker trait for plain-data types that may be stored in a single lane of a packed value.
///
/// Implementors must be `Copy`, default-constructible, comparable for equality, shareable
/// across threads (`Send + Sync`), and must not borrow non-`'static` data.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
/// A fixed-width bundle of `WIDTH` lanes of type [`Self::Value`], laid out contiguously.
///
/// Besides the required constructors and accessors, this trait provides default helpers that
/// reinterpret slices of lane values as slices of packed values (and back) without copying,
/// using raw pointer casts.
///
/// # Safety
///
/// The default methods cast between `*const Self` and `*const Self::Value`. Implementors must
/// therefore guarantee that `Self` has the layout of `[Self::Value; Self::WIDTH]`:
/// - the same size,
/// - lanes stored in order,
/// - every sequence of `WIDTH` valid values is a valid `Self`,
/// - an alignment compatible with the casts.
///
/// The size and alignment requirements are also asserted at compile time where the casts occur.
/// Lane order and validity cannot be checked.
pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
    /// The type stored in each lane.
    type Value: Packable;

    /// The number of lanes.
    const WIDTH: usize;

    /// Reinterprets a slice of lane values as a single packed value.
    ///
    /// The implementations in this file panic unless `slice.len() == Self::WIDTH`.
    #[must_use]
    fn from_slice(slice: &[Self::Value]) -> &Self;

    /// Mutable counterpart of [`Self::from_slice`].
    #[must_use]
    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;

    /// Builds a packed value whose lane `i` is `f(i)`.
    #[must_use]
    fn from_fn<F>(f: F) -> Self
    where
        F: FnMut(usize) -> Self::Value;

    /// Views the lanes as a slice of values.
    #[must_use]
    fn as_slice(&self) -> &[Self::Value];

    /// Mutable counterpart of [`Self::as_slice`].
    #[must_use]
    fn as_slice_mut(&mut self) -> &mut [Self::Value];

    /// Reinterprets a slice of values as a slice of packed values, without copying.
    ///
    /// # Panics
    /// Panics if `buf.len()` is not a multiple of `Self::WIDTH`.
    #[inline]
    #[must_use]
    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
        // Layout preconditions for the cast below: a `Value`-aligned pointer must also be
        // `Self`-aligned, and each `Self` must span exactly `WIDTH` values.
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
            assert!(size_of::<Self>() == Self::WIDTH * size_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_ptr().cast::<Self>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: `buf_ptr` is aligned for `Self` (first const assertion). The `n` packed values
        // cover exactly the `buf.len()` initialized values of `buf` (second const assertion plus
        // the length check). The result borrows `buf`, so the memory outlives it.
        unsafe { slice::from_raw_parts(buf_ptr, n) }
    }

    /// Splits `buf` into the longest prefix whose length is a multiple of `Self::WIDTH`,
    /// reinterpreted as packed values, and the remaining (shorter than `WIDTH`) suffix.
    #[inline]
    #[must_use]
    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_slice(packed), suffix)
    }

    /// Mutable counterpart of [`Self::pack_slice`].
    ///
    /// # Panics
    /// Panics if `buf.len()` is not a multiple of `Self::WIDTH`.
    #[inline]
    #[must_use]
    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
            assert!(size_of::<Self>() == Self::WIDTH * size_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: same alignment/size argument as `pack_slice`. The result takes over the
        // unique borrow of `buf`, so there is no aliasing.
        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
    }

    /// Like [`Self::pack_slice_mut`], but for possibly-uninitialized storage.
    ///
    /// # Panics
    /// Panics if `buf.len()` is not a multiple of `Self::WIDTH`.
    #[inline]
    #[must_use]
    fn pack_maybe_uninit_slice_mut(
        buf: &mut [MaybeUninit<Self::Value>],
    ) -> &mut [MaybeUninit<Self>] {
        // `MaybeUninit<T>` has the same size and alignment as `T`, so checking the
        // underlying types is sufficient.
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
            assert!(size_of::<Self>() == Self::WIDTH * size_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: same alignment/size argument as `pack_slice`. `MaybeUninit` imposes no
        // initialization requirement on the reinterpreted memory.
        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
    }

    /// Mutable counterpart of [`Self::pack_slice_with_suffix`].
    #[inline]
    #[must_use]
    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_slice_mut(packed), suffix)
    }

    /// Possibly-uninitialized counterpart of [`Self::pack_slice_with_suffix_mut`].
    #[inline]
    #[must_use]
    fn pack_maybe_uninit_slice_with_suffix_mut(
        buf: &mut [MaybeUninit<Self::Value>],
    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
    }

    /// Reinterprets a slice of packed values as the flat slice of their lanes, without copying.
    /// The result has length `buf.len() * Self::WIDTH`.
    #[inline]
    #[must_use]
    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
        // Casting in this direction needs a `Self`-aligned pointer to also be `Value`-aligned.
        const {
            assert!(align_of::<Self>() >= align_of::<Self::Value>());
            assert!(size_of::<Self>() == Self::WIDTH * size_of::<Self::Value>());
        }
        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
        let n = buf.len() * Self::WIDTH;
        // SAFETY: `buf_ptr` is aligned for `Value` (first const assertion). The `n` values cover
        // exactly the bytes of `buf` (second const assertion). The result borrows `buf`.
        unsafe { slice::from_raw_parts(buf_ptr, n) }
    }

    /// Returns the value in lane `lane`.
    ///
    /// # Panics
    /// Panics if `lane` is out of bounds for [`Self::as_slice`].
    #[inline]
    #[must_use]
    fn extract(&self, lane: usize) -> Self::Value {
        self.as_slice()[lane]
    }

    /// Transposes `N` packed values into `WIDTH` rows, so that `rows[lane][col]` becomes lane
    /// `lane` of `packed[col]`.
    ///
    /// # Panics
    /// Panics if `rows.len() != Self::WIDTH`.
    #[inline]
    fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
        assert_eq!(rows.len(), Self::WIDTH);
        for (lane, row) in rows.iter_mut().enumerate() {
            *row = array::from_fn(|col| packed[col].extract(lane));
        }
    }
}
unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
type Value = T;
const WIDTH: usize = WIDTH;
#[inline]
fn from_slice(slice: &[Self::Value]) -> &Self {
assert_eq!(slice.len(), Self::WIDTH);
unsafe { &*slice.as_ptr().cast() }
}
#[inline]
fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
assert_eq!(slice.len(), Self::WIDTH);
unsafe { &mut *slice.as_mut_ptr().cast() }
}
#[inline]
fn from_fn<Fn>(f: Fn) -> Self
where
Fn: FnMut(usize) -> Self::Value,
{
core::array::from_fn(f)
}
#[inline]
fn as_slice(&self) -> &[Self::Value] {
self
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [Self::Value] {
self
}
}
/// A packed vector of field elements of type [`Self::Scalar`], one per lane.
///
/// It is an algebra over its scalar field. It can also be divided by scalars, and it can be
/// built from sums and products of scalars.
///
/// # Safety
///
/// Marked `unsafe`. The default methods below contain no `unsafe` code themselves. The
/// obligation presumably mirrors the layout contract of [`PackedValue`]; TODO confirm.
pub unsafe trait PackedField: Algebra<Self::Scalar>
    + PackedValue<Value = Self::Scalar>
    + Div<Self::Scalar, Output = Self>
    + DivAssign<Self::Scalar>
    + Sum<Self::Scalar>
    + Product<Self::Scalar>
{
    /// The field whose elements occupy the lanes.
    type Scalar: Field;

    /// Packed powers of `base` starting from one. Equivalent to
    /// `packed_shifted_powers(base, ONE)`.
    #[must_use]
    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
        let one = Self::Scalar::ONE;
        Self::packed_shifted_powers(base, one)
    }

    /// Packed powers of `base` scaled by `start`.
    ///
    /// The initial packed value has lane `i` equal to `lane0 * base^i`. Here `lane0` is lane 0
    /// of `start.into()`, presumably `start` itself if the conversion broadcasts. The returned
    /// `Powers` carries `base^WIDTH` as its step. NOTE(review): this assumes `Powers` advances by
    /// multiplying `current` by its `base`; its definition is not in view.
    #[must_use]
    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
        let mut current: Self = start.into();
        let lanes = current.as_slice_mut();
        // Each lane is its predecessor times `base`.
        for lane in 1..Self::WIDTH {
            let prev = lanes[lane - 1];
            lanes[lane] = prev * base;
        }
        // Advancing one packed step moves every lane forward by `WIDTH` powers.
        let stride: Self = base.exp_u64(Self::WIDTH as u64).into();
        Powers {
            base: stride,
            current,
        }
    }

    /// Computes `sum_i vecs[i] * coeffs[i]` for exactly `N` terms.
    ///
    /// # Panics
    /// Panics if `coeffs.len() != N` or `vecs.len() != N`.
    #[must_use]
    fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
        assert_eq!(coeffs.len(), N);
        assert_eq!(vecs.len(), N);
        let scaled: [Self; N] = array::from_fn(|k| vecs[k] * coeffs[k]);
        Self::sum_array::<N>(&scaled)
    }
}
/// A [`PackedField`] that additionally supports block interleaving of lanes.
///
/// NOTE(review): the name suggests implementors have a power-of-two `WIDTH`. Nothing visible
/// here checks that; confirm against the implementations.
///
/// # Safety
///
/// Marked `unsafe`. The specific obligation is not spelled out in this file. It presumably
/// inherits the layout contract of `PackedValue`; TODO confirm.
pub unsafe trait PackedFieldPow2: PackedField {
    /// Combines the lanes of `self` and `other`, working in blocks of `block_len` lanes, and
    /// returns two packed values.
    ///
    /// The scalar implementation for fields in this file accepts only `block_len == 1`. It
    /// returns `(self, other)` unchanged and panics otherwise. The exact lane permutation for
    /// larger widths is presumably a block-wise interleave; confirm against the vectorized
    /// implementations.
    #[must_use]
    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
}
/// A packed representation of elements of `ExtField`, stored as one packed base-field
/// coefficient vector per basis element (via `BasedVectorSpace<BaseField::Packing>`).
///
/// Lane `i` of every coefficient together forms the `i`-th extension element.
pub trait PackedFieldExtension<
    BaseField: Field,
    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
{
    /// Packs a slice of extension elements, one per lane.
    ///
    /// NOTE(review): the base-field implementation in this file requires exactly
    /// `BaseField::Packing::WIDTH` elements; presumably other implementations do too.
    #[must_use]
    fn from_ext_slice(ext_slice: &[ExtField]) -> Self;

    /// Reassembles the extension element stored in lane `lane` from that lane of each packed
    /// basis coefficient.
    ///
    /// # Panics
    /// Panics if `lane` is out of bounds for a coefficient's lanes.
    #[inline]
    #[must_use]
    fn extract(&self, lane: usize) -> ExtField {
        let coeffs = self.as_basis_coefficients_slice();
        ExtField::from_basis_coefficients_fn(|d| coeffs[d].as_slice()[lane])
    }

    /// Flattens packed extension values into individual extension elements, visiting the lanes
    /// of each packed value in order.
    #[inline]
    #[must_use]
    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
        let width = BaseField::Packing::WIDTH;
        iter.into_iter()
            .flat_map(move |packed| (0..width).map(move |lane| packed.extract(lane)))
    }

    /// Packed powers of an extension element `base`.
    #[must_use]
    fn packed_ext_powers(base: ExtField) -> Powers<Self>;

    /// The first `ceil(unpacked_len / WIDTH)` items of [`Self::packed_ext_powers`]. These are
    /// enough packed values to cover `unpacked_len` individual powers.
    #[must_use]
    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
        let packed_len = unpacked_len.div_ceil(BaseField::Packing::WIDTH);
        Self::packed_ext_powers(base).take(packed_len)
    }
}
unsafe impl<T: Packable> PackedValue for T {
type Value = Self;
const WIDTH: usize = 1;
#[inline]
fn from_slice(slice: &[Self::Value]) -> &Self {
assert_eq!(slice.len(), Self::WIDTH);
&slice[0]
}
#[inline]
fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
assert_eq!(slice.len(), Self::WIDTH);
&mut slice[0]
}
#[inline]
fn from_fn<Fn>(mut f: Fn) -> Self
where
Fn: FnMut(usize) -> Self::Value,
{
f(0)
}
#[inline]
fn as_slice(&self) -> &[Self::Value] {
slice::from_ref(self)
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [Self::Value] {
slice::from_mut(self)
}
}
/// A field element is a packed field of width one over itself.
// SAFETY (review note): the `PackedValue<Value = F>` supertrait is presumably satisfied by the
// blanket width-1 impl for `Packable` types, whose layout is trivially `[F; 1]`. Confirm that
// `Field` implies `Packable`.
unsafe impl<F: Field> PackedField for F {
    type Scalar = F;
}
/// Scalar (width-one) implementation of block interleaving.
unsafe impl<F: Field> PackedFieldPow2 for F {
    /// Only `block_len == 1` is supported. Both inputs are then returned unchanged.
    ///
    /// # Panics
    /// Panics with "unsupported block length" for any other `block_len`.
    #[inline]
    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
        assert!(block_len == 1, "unsupported block length");
        (*self, other)
    }
}
impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
#[inline]
fn from_ext_slice(ext_slice: &[F]) -> Self {
*F::Packing::from_slice(ext_slice)
}
#[inline]
fn packed_ext_powers(base: F) -> Powers<Self> {
F::Packing::packed_powers(base)
}
}
impl Packable for u8 {}
impl Packable for u16 {}
impl Packable for u32 {}
impl Packable for u64 {}
impl Packable for u128 {}