p3_field/
packed.rs

1use alloc::vec::Vec;
2use core::mem::MaybeUninit;
3use core::ops::Div;
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// Marker trait for scalar types that may occupy the lanes of a packed value.
///
/// Being opt-in, it lets implementations be written for types that could otherwise
/// conflict, such as the blanket `PackedValue` implementation for every `T: Packable`.
pub trait Packable: 'static + Copy + Default + Eq + PartialEq + Send + Sync {}
13
14/// A trait for array-like structs made up of multiple scalar elements.
15///
16/// # Safety
17/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
18///   without UB.
19pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
20    /// The scalar type that is packed into this value.
21    type Value: Packable;
22
23    /// Number of scalar values packed together.
24    const WIDTH: usize;
25
26    /// Interprets a slice of scalar values as a packed value reference.
27    ///
28    /// # Panics:
29    /// This function will panic if `slice.len() != Self::WIDTH`
30    fn from_slice(slice: &[Self::Value]) -> &Self;
31
32    /// Interprets a mutable slice of scalar values as a mutable packed value.
33    ///
34    /// # Panics:
35    /// This function will panic if `slice.len() != Self::WIDTH`
36    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
37
38    /// Constructs a packed value using a function to generate each element.
39    ///
40    /// Similar to `core:array::from_fn`.
41    fn from_fn<F>(f: F) -> Self
42    where
43        F: FnMut(usize) -> Self::Value;
44
45    /// Returns the underlying scalar values as an immutable slice.
46    fn as_slice(&self) -> &[Self::Value];
47
48    /// Returns the underlying scalar values as a mutable slice.
49    fn as_slice_mut(&mut self) -> &mut [Self::Value];
50
51    /// Packs a slice of scalar values into a slice of packed values.
52    ///
53    /// # Panics
54    /// Panics if the slice length is not divisible by `WIDTH`.
55    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
56        // Sources vary, but this should be true on all platforms we care about.
57        // This should be a const assert, but trait methods can't access `Self` in a const context,
58        // even with inner struct instantiation. So we will trust LLVM to optimize this out.
59        assert!(align_of::<Self>() <= align_of::<Self::Value>());
60        assert!(
61            buf.len() % Self::WIDTH == 0,
62            "Slice length (got {}) must be a multiple of packed field width ({}).",
63            buf.len(),
64            Self::WIDTH
65        );
66        let buf_ptr = buf.as_ptr().cast::<Self>();
67        let n = buf.len() / Self::WIDTH;
68        unsafe { slice::from_raw_parts(buf_ptr, n) }
69    }
70
71    /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
72    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
73        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
74        (Self::pack_slice(packed), suffix)
75    }
76
77    /// Converts a mutable slice of scalar values into a mutable slice of packed values.
78    ///
79    /// # Panics
80    /// Panics if the slice length is not divisible by `WIDTH`.
81    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
82        assert!(align_of::<Self>() <= align_of::<Self::Value>());
83        assert!(
84            buf.len() % Self::WIDTH == 0,
85            "Slice length (got {}) must be a multiple of packed field width ({}).",
86            buf.len(),
87            Self::WIDTH
88        );
89        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
90        let n = buf.len() / Self::WIDTH;
91        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
92    }
93
94    /// Converts a mutable slice of possibly uninitialized scalar values into
95    /// a mutable slice of possibly uninitialized packed values.
96    ///
97    /// # Panics
98    /// Panics if the slice length is not divisible by `WIDTH`.
99    fn pack_maybe_uninit_slice_mut(
100        buf: &mut [MaybeUninit<Self::Value>],
101    ) -> &mut [MaybeUninit<Self>] {
102        assert!(align_of::<Self>() <= align_of::<Self::Value>());
103        assert!(
104            buf.len() % Self::WIDTH == 0,
105            "Slice length (got {}) must be a multiple of packed field width ({}).",
106            buf.len(),
107            Self::WIDTH
108        );
109        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
110        let n = buf.len() / Self::WIDTH;
111        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
112    }
113
114    /// Converts a mutable slice of scalar values into a pair:
115    /// - a slice of packed values covering the largest aligned portion,
116    /// - and a remainder slice of scalar values that couldn't be packed.
117    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
118        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
119        (Self::pack_slice_mut(packed), suffix)
120    }
121
122    /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
123    /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
124    /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
125    fn pack_maybe_uninit_slice_with_suffix_mut(
126        buf: &mut [MaybeUninit<Self::Value>],
127    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
128        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
129        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
130    }
131
132    /// Reinterprets a slice of packed values as a flat slice of scalar values.
133    ///
134    /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
135    /// contiguously in memory. This function allows direct access to those scalars.
136    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
137        assert!(align_of::<Self>() >= align_of::<Self::Value>());
138        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
139        let n = buf.len() * Self::WIDTH;
140        unsafe { slice::from_raw_parts(buf_ptr, n) }
141    }
142}
143
144unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
145    type Value = T;
146    const WIDTH: usize = WIDTH;
147
148    fn from_slice(slice: &[Self::Value]) -> &Self {
149        assert_eq!(slice.len(), Self::WIDTH);
150        slice.try_into().unwrap()
151    }
152
153    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
154        assert_eq!(slice.len(), Self::WIDTH);
155        slice.try_into().unwrap()
156    }
157
158    fn from_fn<F>(f: F) -> Self
159    where
160        F: FnMut(usize) -> Self::Value,
161    {
162        core::array::from_fn(f)
163    }
164
165    fn as_slice(&self) -> &[Self::Value] {
166        self
167    }
168
169    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
170        self
171    }
172}
173
174/// An array of field elements which can be packed into a vector for SIMD operations.
175///
176/// # Safety
177/// - See `PackedValue` above.
178pub unsafe trait PackedField: Algebra<Self::Scalar>
179    + PackedValue<Value = Self::Scalar>
180    // TODO: Implement packed / packed division
181    + Div<Self::Scalar, Output = Self>
182{
183    type Scalar: Field;
184
185    /// Construct an iterator which returns powers of `base` packed into packed field elements.
186    ///
187    /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
188    #[must_use]
189    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
190        Self::packed_shifted_powers(base, Self::Scalar::ONE)
191    }
192
193    /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
194    ///
195    /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
196    #[must_use]
197    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
198        let mut current: Self = start.into();
199        let slice = current.as_slice_mut();
200        for i in 1..Self::WIDTH {
201            slice[i] = slice[i - 1] * base;
202        }
203
204        Powers {
205            base: base.exp_u64(Self::WIDTH as u64).into(),
206            current,
207        }
208    }
209
210    /// Compute a linear combination of a slice of base field elements and
211    /// a slice of packed field elements. The slices must have equal length
212    /// and it must be a compile time constant.
213    ///
214    /// # Panics
215    ///
216    /// May panic if the length of either slice is not equal to `N`.
217    fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
218        assert_eq!(coeffs.len(), N);
219        assert_eq!(vecs.len(), N);
220        let combined: [Self; N] = array::from_fn(|i| vecs[i] * coeffs[i]);
221        Self::sum_array::<N>(&combined)
222    }
223}
224
225/// # Safety
226/// - `WIDTH` is assumed to be a power of 2.
227pub unsafe trait PackedFieldPow2: PackedField {
228    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
229    /// chunks. This is best seen with an example. If we have:
230    /// ```text
231    /// A = [x0, y0, x1, y1]
232    /// B = [x2, y2, x3, y3]
233    /// ```
234    ///
235    /// then
236    ///
237    /// ```text
238    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
239    /// ```
240    ///
241    /// Pairs that were adjacent in the input are at corresponding positions in the output.
242    ///
243    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
244    ///
245    /// ```text
246    /// A = [x0, x1, y0, y1]
247    /// B = [x2, x3, y2, y3]
248    /// ```
249    ///
250    /// we obtain
251    ///
252    /// ```text
253    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
254    /// ```
255    ///
256    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
257    /// transposing those matrices.
258    ///
259    /// When `block_len = WIDTH`, this operation is a no-op.
260    ///
261    /// # Panics
262    /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
263    /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
264    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
265}
266
267/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
268///
269/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
270///
271/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
272///
273/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
274/// as `[EF; W]` by making use of the chosen basis `B` again.
275pub trait PackedFieldExtension<
276    BaseField: Field,
277    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
278>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
279{
280    /// Given a slice of extension field `EF` elements of length `W`,
281    /// convert into the array `[[F; D]; W]` transpose to
282    /// `[[F; W]; D]` and then pack to get `[PF; D]`.
283    fn from_ext_slice(ext_slice: &[ExtField]) -> Self;
284
285    /// Given a iterator of packed extension field elements, convert to an iterator of
286    /// extension field elements.
287    ///
288    /// This performs the inverse transformation to `from_ext_slice`.
289    #[inline]
290    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
291        iter.into_iter().flat_map(|x| {
292            let packed_coeffs = x.as_basis_coefficients_slice();
293            (0..BaseField::Packing::WIDTH)
294                .map(|i| ExtField::from_basis_coefficients_fn(|j| packed_coeffs[j].as_slice()[i]))
295                .collect::<Vec<_>>() // PackedFieldExtension's should reimplement this to avoid this allocation.
296        })
297    }
298
299    /// Similar to `packed_powers`, construct an iterator which returns
300    /// powers of `base` packed into `PackedFieldExtension` elements.
301    fn packed_ext_powers(base: ExtField) -> Powers<Self>;
302
303    /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
304    ///
305    /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
306    /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
307    /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
308    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
309        Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
310    }
311}
312
313unsafe impl<T: Packable> PackedValue for T {
314    type Value = Self;
315
316    const WIDTH: usize = 1;
317
318    fn from_slice(slice: &[Self::Value]) -> &Self {
319        &slice[0]
320    }
321
322    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
323        &mut slice[0]
324    }
325
326    fn from_fn<Fn>(mut f: Fn) -> Self
327    where
328        Fn: FnMut(usize) -> Self::Value,
329    {
330        f(0)
331    }
332
333    fn as_slice(&self) -> &[Self::Value] {
334        slice::from_ref(self)
335    }
336
337    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
338        slice::from_mut(self)
339    }
340}
341
342unsafe impl<F: Field> PackedField for F {
343    type Scalar = Self;
344}
345
346unsafe impl<F: Field> PackedFieldPow2 for F {
347    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
348        match block_len {
349            1 => (*self, other),
350            _ => panic!("unsupported block length"),
351        }
352    }
353}
354
355impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
356    fn from_ext_slice(ext_slice: &[F]) -> Self {
357        *F::Packing::from_slice(ext_slice)
358    }
359
360    fn packed_ext_powers(base: F) -> Powers<Self> {
361        F::Packing::packed_powers(base)
362    }
363}
364
365impl Packable for u8 {}
366
367impl Packable for u16 {}
368
369impl Packable for u32 {}
370
371impl Packable for u64 {}
372
373impl Packable for u128 {}