Skip to main content

p3_field/packed/
packed_traits.rs

1use core::iter::{Product, Sum};
2use core::mem::MaybeUninit;
3use core::ops::{Div, DivAssign};
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// Marker trait for scalar types that may occupy the lanes of a packed value.
///
/// Using an explicit marker trait (instead of deriving packability from the listed bounds)
/// lets us specify implementations for potentially conflicting types.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
13
14/// A trait for array-like structs made up of multiple scalar elements.
15///
16/// # Safety
17/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
18///   without UB.
19pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
20    /// The scalar type that is packed into this value.
21    type Value: Packable;
22
23    /// Number of scalar values packed together.
24    const WIDTH: usize;
25
26    /// Constructs a packed value using a function to generate each element.
27    ///
28    /// Similar to [`core::array::from_fn`].
29    #[must_use]
30    fn from_fn<F>(f: F) -> Self
31    where
32        F: FnMut(usize) -> Self::Value;
33
34    /// Create a packed value with all lanes set to the same scalar value.
35    #[inline]
36    #[must_use]
37    fn broadcast(value: Self::Value) -> Self {
38        Self::from_fn(|_| value)
39    }
40
41    /// Interprets a slice of scalar values as a packed value reference.
42    ///
43    /// # Panics:
44    /// This function will panic if `slice.len() != Self::WIDTH`
45    #[must_use]
46    fn from_slice(slice: &[Self::Value]) -> &Self;
47
48    /// Interprets a mutable slice of scalar values as a mutable packed value.
49    ///
50    /// # Panics:
51    /// This function will panic if `slice.len() != Self::WIDTH`
52    #[must_use]
53    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
54
55    /// Returns the underlying scalar values as an immutable slice.
56    #[must_use]
57    fn as_slice(&self) -> &[Self::Value];
58
59    /// Returns the underlying scalar values as a mutable slice.
60    #[must_use]
61    fn as_slice_mut(&mut self) -> &mut [Self::Value];
62
63    /// Extract the scalar value at the given SIMD lane.
64    ///
65    /// This is equivalent to `self.as_slice()[lane]` but more explicit about the
66    /// SIMD extraction semantics.
67    #[inline]
68    #[must_use]
69    fn extract(&self, lane: usize) -> Self::Value {
70        self.as_slice()[lane]
71    }
72
73    /// Packs a slice of scalar values into a slice of packed values.
74    ///
75    /// # Panics
76    /// Panics if the slice length is not divisible by `WIDTH`.
77    #[inline]
78    #[must_use]
79    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
80        // Sources vary, but this should be true on all platforms we care about.
81        const {
82            assert!(align_of::<Self>() <= align_of::<Self::Value>());
83        }
84        assert!(
85            buf.len().is_multiple_of(Self::WIDTH),
86            "Slice length (got {}) must be a multiple of packed field width ({}).",
87            buf.len(),
88            Self::WIDTH
89        );
90        let buf_ptr = buf.as_ptr().cast::<Self>();
91        let n = buf.len() / Self::WIDTH;
92        unsafe { slice::from_raw_parts(buf_ptr, n) }
93    }
94
95    /// Converts a mutable slice of scalar values into a mutable slice of packed values.
96    ///
97    /// # Panics
98    /// Panics if the slice length is not divisible by `WIDTH`.
99    #[inline]
100    #[must_use]
101    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
102        const {
103            assert!(align_of::<Self>() <= align_of::<Self::Value>());
104        }
105        assert!(
106            buf.len().is_multiple_of(Self::WIDTH),
107            "Slice length (got {}) must be a multiple of packed field width ({}).",
108            buf.len(),
109            Self::WIDTH
110        );
111        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
112        let n = buf.len() / Self::WIDTH;
113        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
114    }
115
116    /// Converts a mutable slice of possibly uninitialized scalar values into
117    /// a mutable slice of possibly uninitialized packed values.
118    ///
119    /// # Panics
120    /// Panics if the slice length is not divisible by `WIDTH`.
121    #[inline]
122    #[must_use]
123    fn pack_maybe_uninit_slice_mut(
124        buf: &mut [MaybeUninit<Self::Value>],
125    ) -> &mut [MaybeUninit<Self>] {
126        const {
127            assert!(align_of::<Self>() <= align_of::<Self::Value>());
128        }
129        assert!(
130            buf.len().is_multiple_of(Self::WIDTH),
131            "Slice length (got {}) must be a multiple of packed field width ({}).",
132            buf.len(),
133            Self::WIDTH
134        );
135        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
136        let n = buf.len() / Self::WIDTH;
137        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
138    }
139
140    /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
141    #[inline]
142    #[must_use]
143    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
144        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
145        (Self::pack_slice(packed), suffix)
146    }
147
148    /// Converts a mutable slice of scalar values into a pair:
149    /// - a slice of packed values covering the largest aligned portion,
150    /// - and a remainder slice of scalar values that couldn't be packed.
151    #[inline]
152    #[must_use]
153    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
154        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
155        (Self::pack_slice_mut(packed), suffix)
156    }
157
158    /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
159    /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
160    /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
161    #[inline]
162    #[must_use]
163    fn pack_maybe_uninit_slice_with_suffix_mut(
164        buf: &mut [MaybeUninit<Self::Value>],
165    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
166        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
167        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
168    }
169
170    /// Reinterprets a slice of packed values as a flat slice of scalar values.
171    ///
172    /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
173    /// contiguously in memory. This function allows direct access to those scalars.
174    #[inline]
175    #[must_use]
176    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
177        const {
178            assert!(align_of::<Self>() >= align_of::<Self::Value>());
179        }
180        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
181        let n = buf.len() * Self::WIDTH;
182        unsafe { slice::from_raw_parts(buf_ptr, n) }
183    }
184
185    /// Pack columns from `WIDTH` rows of scalar values into `N` packed values.
186    ///
187    /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it
188    /// into a single packed value. This is the inverse of `unpack_into`.
189    ///
190    /// ## Panics
191    /// Panics if `rows.len() != WIDTH`.
192    #[inline]
193    #[must_use]
194    fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
195        assert_eq!(rows.len(), Self::WIDTH);
196        array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
197    }
198
199    /// Pack columns using a closure that provides each row's data.
200    ///
201    /// Calls `row_fn(lane)` for each lane `0..WIDTH` to get `[Self::Value; N]`,
202    /// then transposes columns into packed values. Useful when rows aren't
203    /// contiguous in memory (e.g., strided access).
204    #[inline]
205    #[must_use]
206    fn pack_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [Self::Value; N]) -> [Self; N] {
207        array::from_fn(|col| Self::from_fn(|lane| row_fn(lane)[col]))
208    }
209
210    /// Unpack `N` packed values into `WIDTH` rows of `N` scalars.
211    ///
212    /// ## Inputs
213    /// - `packed`: An array of `N` packed values.
214    /// - `rows`: A mutable slice of exactly `WIDTH` arrays to write the unpacked values.
215    ///
216    /// ## Panics
217    /// Panics if `rows.len() != WIDTH`.
218    #[inline]
219    fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
220        assert_eq!(rows.len(), Self::WIDTH);
221        #[allow(clippy::needless_range_loop)]
222        for lane in 0..Self::WIDTH {
223            rows[lane] = array::from_fn(|col| packed[col].extract(lane));
224        }
225    }
226
227    /// Unpack `N` packed values into an iterator of `WIDTH` rows.
228    ///
229    /// This is the iterator equivalent of `unpack_into`, yielding each row
230    /// without requiring a pre-allocated buffer.
231    #[inline]
232    fn unpack_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [Self::Value; N]> {
233        (0..Self::WIDTH).map(move |lane| array::from_fn(|col| packed[col].extract(lane)))
234    }
235}
236
237unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
238    type Value = T;
239    const WIDTH: usize = WIDTH;
240
241    #[inline]
242    fn from_slice(slice: &[Self::Value]) -> &Self {
243        assert_eq!(slice.len(), Self::WIDTH);
244        unsafe { &*slice.as_ptr().cast() }
245    }
246
247    #[inline]
248    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
249        assert_eq!(slice.len(), Self::WIDTH);
250        unsafe { &mut *slice.as_mut_ptr().cast() }
251    }
252
253    #[inline]
254    fn from_fn<Fn>(f: Fn) -> Self
255    where
256        Fn: FnMut(usize) -> Self::Value,
257    {
258        core::array::from_fn(f)
259    }
260
261    #[inline]
262    fn as_slice(&self) -> &[Self::Value] {
263        self
264    }
265
266    #[inline]
267    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
268        self
269    }
270}
271
272/// An array of field elements which can be packed into a vector for SIMD operations.
273///
274/// # Safety
275/// - See `PackedValue` above.
276pub unsafe trait PackedField: Algebra<Self::Scalar>
277    + PackedValue<Value = Self::Scalar>
278    // TODO: Implement packed / packed division
279    + Div<Self::Scalar, Output = Self>
280    + DivAssign<Self::Scalar>
281    + Sum<Self::Scalar>
282    + Product<Self::Scalar>
283{
284    type Scalar: Field;
285
286    /// Construct an iterator which returns powers of `base` packed into packed field elements.
287    ///
288    /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
289    #[must_use]
290    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
291        Self::packed_shifted_powers(base, Self::Scalar::ONE)
292    }
293
294    /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
295    ///
296    /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
297    #[must_use]
298    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
299        let mut current: Self = start.into();
300        let slice = current.as_slice_mut();
301        for i in 1..Self::WIDTH {
302            slice[i] = slice[i - 1] * base;
303        }
304
305        Powers {
306            base: base.exp_u64(Self::WIDTH as u64).into(),
307            current,
308        }
309    }
310
311    /// Compute a linear combination of a slice of base field elements and
312    /// a slice of packed field elements. The slices must have equal length
313    /// and it must be a compile time constant.
314    ///
315    /// # Panics
316    ///
317    /// May panic if the length of either slice is not equal to `N`.
318    #[must_use]
319    fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
320        assert_eq!(coeffs.len(), N);
321        assert_eq!(vecs.len(), N);
322        let combined: [Self; N] = array::from_fn(|i| vecs[i] * coeffs[i]);
323        Self::sum_array::<N>(&combined)
324    }
325}
326
327/// # Safety
328/// - `WIDTH` is assumed to be a power of 2.
329pub unsafe trait PackedFieldPow2: PackedField {
330    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
331    /// chunks. This is best seen with an example. If we have:
332    /// ```text
333    /// A = [x0, y0, x1, y1]
334    /// B = [x2, y2, x3, y3]
335    /// ```
336    ///
337    /// then
338    ///
339    /// ```text
340    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
341    /// ```
342    ///
343    /// Pairs that were adjacent in the input are at corresponding positions in the output.
344    ///
345    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
346    ///
347    /// ```text
348    /// A = [x0, x1, y0, y1]
349    /// B = [x2, x3, y2, y3]
350    /// ```
351    ///
352    /// we obtain
353    ///
354    /// ```text
355    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
356    /// ```
357    ///
358    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
359    /// transposing those matrices.
360    ///
361    /// When `block_len = WIDTH`, this operation is a no-op.
362    ///
363    /// # Panics
364    /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
365    /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
366    #[must_use]
367    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
368}
369
370/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
371///
372/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
373///
374/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
375///
376/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
377/// as `[EF; W]` by making use of the chosen basis `B` again.
378pub trait PackedFieldExtension<
379    BaseField: Field,
380    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
381>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
382{
383    /// Given a slice of extension field `EF` elements of length `W`,
384    /// convert into the array `[[F; D]; W]` transpose to
385    /// `[[F; W]; D]` and then pack to get `[PF; D]`.
386    #[must_use]
387    fn from_ext_slice(ext_slice: &[ExtField]) -> Self;
388
389    /// Extract the extension field element at the given SIMD lane.
390    #[inline]
391    #[must_use]
392    fn extract(&self, lane: usize) -> ExtField {
393        ExtField::from_basis_coefficients_fn(|d| {
394            self.as_basis_coefficients_slice()[d].as_slice()[lane]
395        })
396    }
397
398    /// Convert an iterator of packed extension field elements to an iterator of
399    /// extension field elements.
400    ///
401    /// This performs the inverse transformation to `from_ext_slice`.
402    #[inline]
403    #[must_use]
404    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
405        iter.into_iter()
406            .flat_map(|x| (0..BaseField::Packing::WIDTH).map(move |lane| x.extract(lane)))
407    }
408
409    /// Similar to `packed_powers`, construct an iterator which returns
410    /// powers of `base` packed into `PackedFieldExtension` elements.
411    #[must_use]
412    fn packed_ext_powers(base: ExtField) -> Powers<Self>;
413
414    /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
415    ///
416    /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
417    /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
418    /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
419    #[must_use]
420    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
421        Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
422    }
423}
424
425unsafe impl<T: Packable> PackedValue for T {
426    type Value = Self;
427
428    const WIDTH: usize = 1;
429
430    #[inline]
431    fn from_slice(slice: &[Self::Value]) -> &Self {
432        assert_eq!(slice.len(), Self::WIDTH);
433        &slice[0]
434    }
435
436    #[inline]
437    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
438        assert_eq!(slice.len(), Self::WIDTH);
439        &mut slice[0]
440    }
441
442    #[inline]
443    fn from_fn<Fn>(mut f: Fn) -> Self
444    where
445        Fn: FnMut(usize) -> Self::Value,
446    {
447        f(0)
448    }
449
450    #[inline]
451    fn as_slice(&self) -> &[Self::Value] {
452        slice::from_ref(self)
453    }
454
455    #[inline]
456    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
457        slice::from_mut(self)
458    }
459}
460
461unsafe impl<F: Field> PackedField for F {
462    type Scalar = Self;
463}
464
465unsafe impl<F: Field> PackedFieldPow2 for F {
466    #[inline]
467    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
468        match block_len {
469            1 => (*self, other),
470            _ => panic!("unsupported block length"),
471        }
472    }
473}
474
475impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
476    #[inline]
477    fn from_ext_slice(ext_slice: &[F]) -> Self {
478        *F::Packing::from_slice(ext_slice)
479    }
480
481    #[inline]
482    fn packed_ext_powers(base: F) -> Powers<Self> {
483        F::Packing::packed_powers(base)
484    }
485}
486
487impl Packable for u8 {}
488
489impl Packable for u16 {}
490
491impl Packable for u32 {}
492
493impl Packable for u64 {}
494
495impl Packable for u128 {}