// Source: p3_field/packed/packed_traits.rs

1use core::iter::{Product, Sum};
2use core::mem::MaybeUninit;
3use core::ops::{Div, DivAssign};
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// A trait to constrain types that can be packed into a packed value.
///
/// The `Packable` trait allows us to specify implementations for potentially conflicting types.
pub trait Packable: 'static + Default + Copy + Send + Sync + PartialEq + Eq {}

/// A trait for array-like structs made up of multiple scalar elements.
///
/// # Safety
/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
///   without UB.
/// - The provided slice-reinterpreting methods ([`PackedValue::pack_slice`],
///   [`PackedValue::pack_slice_mut`], [`PackedValue::pack_maybe_uninit_slice_mut`] and
///   [`PackedValue::unpack_slice`]) cast pointers between `Self::Value` and `Self`, so they
///   rely on that same layout compatibility.
pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
    /// The scalar type that is packed into this value.
    type Value: Packable;

    /// Number of scalar values packed together.
    const WIDTH: usize;

    /// Constructs a packed value using a function to generate each element.
    ///
    /// Similar to [`core::array::from_fn`].
    #[must_use]
    fn from_fn<F>(f: F) -> Self
    where
        F: FnMut(usize) -> Self::Value;

    /// Create a packed value with all lanes set to the same scalar value.
    #[inline]
    #[must_use]
    fn broadcast(value: Self::Value) -> Self {
        Self::from_fn(|_| value)
    }

    /// Interprets a slice of scalar values as a packed value reference.
    ///
    /// # Panics
    /// This function will panic if `slice.len() != Self::WIDTH`
    #[must_use]
    fn from_slice(slice: &[Self::Value]) -> &Self;

    /// Interprets a mutable slice of scalar values as a mutable packed value.
    ///
    /// # Panics
    /// This function will panic if `slice.len() != Self::WIDTH`
    #[must_use]
    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;

    /// Returns the underlying scalar values as an immutable slice.
    #[must_use]
    fn as_slice(&self) -> &[Self::Value];

    /// Returns the underlying scalar values as a mutable slice.
    #[must_use]
    fn as_slice_mut(&mut self) -> &mut [Self::Value];

    /// Extract the scalar value at the given SIMD lane.
    ///
    /// This is equivalent to `self.as_slice()[lane]` but more explicit about the
    /// SIMD extraction semantics.
    ///
    /// # Panics
    /// Panics if `lane` is out of bounds for `self.as_slice()`.
    #[inline]
    #[must_use]
    fn extract(&self, lane: usize) -> Self::Value {
        self.as_slice()[lane]
    }

    /// Packs a slice of scalar values into a slice of packed values.
    ///
    /// The returned slice borrows `buf`; no data is copied.
    ///
    /// # Panics
    /// Panics if the slice length is not divisible by `WIDTH`.
    #[inline]
    #[must_use]
    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
        // Sources vary, but this should be true on all platforms we care about.
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_ptr().cast::<Self>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: the const assertion shows that any pointer aligned for `Self::Value` is
        // also aligned for `Self`; the length check means `n` packed values span exactly
        // the `buf.len()` scalars borrowed by `buf`; and the trait's safety contract makes
        // `WIDTH` contiguous scalars a valid `Self`.
        unsafe { slice::from_raw_parts(buf_ptr, n) }
    }

    /// Converts a mutable slice of scalar values into a mutable slice of packed values.
    ///
    /// The returned slice borrows `buf`; no data is copied.
    ///
    /// # Panics
    /// Panics if the slice length is not divisible by `WIDTH`.
    #[inline]
    #[must_use]
    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: same reasoning as in `pack_slice`; in addition `buf` is an exclusive
        // borrow, so the returned slice is the only live reference to this memory.
        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
    }

    /// Converts a mutable slice of possibly uninitialized scalar values into
    /// a mutable slice of possibly uninitialized packed values.
    ///
    /// # Panics
    /// Panics if the slice length is not divisible by `WIDTH`.
    #[inline]
    #[must_use]
    fn pack_maybe_uninit_slice_mut(
        buf: &mut [MaybeUninit<Self::Value>],
    ) -> &mut [MaybeUninit<Self>] {
        const {
            assert!(align_of::<Self>() <= align_of::<Self::Value>());
        }
        assert!(
            buf.len().is_multiple_of(Self::WIDTH),
            "Slice length (got {}) must be a multiple of packed field width ({}).",
            buf.len(),
            Self::WIDTH
        );
        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
        let n = buf.len() / Self::WIDTH;
        // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so the alignment
        // and length reasoning of `pack_slice_mut` carries over. No initialization
        // requirement is introduced because the target element type is still `MaybeUninit`.
        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
    }

    /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
    ///
    /// The packed portion covers the longest prefix whose length is a multiple of `WIDTH`.
    #[inline]
    #[must_use]
    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_slice(packed), suffix)
    }

    /// Converts a mutable slice of scalar values into a pair:
    /// - a slice of packed values covering the largest aligned portion,
    /// - and a remainder slice of scalar values that couldn't be packed.
    #[inline]
    #[must_use]
    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_slice_mut(packed), suffix)
    }

    /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
    /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
    /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
    #[inline]
    #[must_use]
    fn pack_maybe_uninit_slice_with_suffix_mut(
        buf: &mut [MaybeUninit<Self::Value>],
    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
    }

    /// Reinterprets a slice of packed values as a flat slice of scalar values.
    ///
    /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
    /// contiguously in memory. This function allows direct access to those scalars.
    #[inline]
    #[must_use]
    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
        const {
            assert!(align_of::<Self>() >= align_of::<Self::Value>());
        }
        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
        let n = buf.len() * Self::WIDTH;
        // SAFETY: the const assertion shows a pointer aligned for `Self` is aligned for
        // `Self::Value`, and the trait's safety contract says each `Self` is laid out as
        // `WIDTH` contiguous scalars, so `n` scalars span exactly the memory borrowed by `buf`.
        unsafe { slice::from_raw_parts(buf_ptr, n) }
    }

    /// Pack columns from `WIDTH` rows of scalar values into `N` packed values.
    ///
    /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it
    /// into a single packed value. This is the inverse of `unpack_into`.
    ///
    /// # Panics
    /// Panics if `rows.len() != WIDTH`.
    #[inline]
    #[must_use]
    fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
        assert_eq!(rows.len(), Self::WIDTH);
        array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
    }

    /// Pack columns using a closure that provides each row's data.
    ///
    /// Lane `lane` of output column `col` is `row_fn(lane)[col]`. Useful when rows aren't
    /// contiguous in memory (e.g., strided access).
    ///
    /// Note that `row_fn` is invoked once per `(column, lane)` pair, i.e. `N` times for
    /// every lane in `0..WIDTH`, so it should be cheap and free of side effects.
    #[inline]
    #[must_use]
    fn pack_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [Self::Value; N]) -> [Self; N] {
        array::from_fn(|col| Self::from_fn(|lane| row_fn(lane)[col]))
    }

    /// Unpack `N` packed values into `WIDTH` rows of `N` scalars.
    ///
    /// ## Inputs
    /// - `packed`: An array of `N` packed values.
    /// - `rows`: A mutable slice of exactly `WIDTH` arrays to write the unpacked values.
    ///
    /// # Panics
    /// Panics if `rows.len() != WIDTH`.
    #[inline]
    fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
        assert_eq!(rows.len(), Self::WIDTH);
        for (lane, row) in rows.iter_mut().enumerate() {
            *row = array::from_fn(|col| packed[col].extract(lane));
        }
    }

    /// Unpack `N` packed values into an iterator of `WIDTH` rows.
    ///
    /// This is the iterator equivalent of `unpack_into`, yielding each row
    /// without requiring a pre-allocated buffer.
    #[inline]
    fn unpack_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [Self::Value; N]> {
        (0..Self::WIDTH).map(move |lane| array::from_fn(|col| packed[col].extract(lane)))
    }
}

// A plain array of packable scalars is itself a packed value: it already has exactly the
// `[Value; WIDTH]` layout, so every conversion below is a safe standard-library operation.
unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
    type Value = T;
    const WIDTH: usize = WIDTH;

    #[inline]
    fn from_fn<F>(f: F) -> Self
    where
        F: FnMut(usize) -> Self::Value,
    {
        array::from_fn(f)
    }

    #[inline]
    fn from_slice(slice: &[Self::Value]) -> &Self {
        assert_eq!(slice.len(), Self::WIDTH);
        // The length was just checked, so the leading `WIDTH`-element chunk is the whole slice.
        slice
            .first_chunk()
            .expect("slice length was asserted to equal WIDTH")
    }

    #[inline]
    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
        assert_eq!(slice.len(), Self::WIDTH);
        // The length was just checked, so the leading `WIDTH`-element chunk is the whole slice.
        slice
            .first_chunk_mut()
            .expect("slice length was asserted to equal WIDTH")
    }

    #[inline]
    fn as_slice(&self) -> &[Self::Value] {
        self
    }

    #[inline]
    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
        self
    }
}
270
271/// An array of field elements which can be packed into a vector for SIMD operations.
272///
273/// # Safety
274/// - See `PackedValue` above.
275pub unsafe trait PackedField:
276    Algebra<Self::Scalar>
277    + PackedValue<Value = Self::Scalar>
278    + Div<Self, Output = Self>
279    + Div<Self::Scalar, Output = Self>
280    + DivAssign<Self>
281    + DivAssign<Self::Scalar>
282    + Sum<Self::Scalar>
283    + Product<Self::Scalar>
284{
285    type Scalar: Field;
286
287    /// Construct an iterator which returns powers of `base` packed into packed field elements.
288    ///
289    /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
290    #[must_use]
291    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
292        Self::packed_shifted_powers(base, Self::Scalar::ONE)
293    }
294
295    /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
296    ///
297    /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
298    #[must_use]
299    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
300        let mut current: Self = start.into();
301        let slice = current.as_slice_mut();
302        for i in 1..Self::WIDTH {
303            slice[i] = slice[i - 1] * base;
304        }
305
306        Powers {
307            base: base.exp_u64(Self::WIDTH as u64).into(),
308            current,
309        }
310    }
311}
312
313/// # Safety
314/// - `WIDTH` is assumed to be a power of 2.
315pub unsafe trait PackedFieldPow2: PackedField {
316    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
317    /// chunks. This is best seen with an example. If we have:
318    /// ```text
319    /// A = [x0, y0, x1, y1]
320    /// B = [x2, y2, x3, y3]
321    /// ```
322    ///
323    /// then
324    ///
325    /// ```text
326    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
327    /// ```
328    ///
329    /// Pairs that were adjacent in the input are at corresponding positions in the output.
330    ///
331    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
332    ///
333    /// ```text
334    /// A = [x0, x1, y0, y1]
335    /// B = [x2, x3, y2, y3]
336    /// ```
337    ///
338    /// we obtain
339    ///
340    /// ```text
341    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
342    /// ```
343    ///
344    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
345    /// transposing those matrices.
346    ///
347    /// When `block_len = WIDTH`, this operation is a no-op.
348    ///
349    /// # Panics
350    /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
351    /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
352    #[must_use]
353    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
354}
355
356/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
357///
358/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
359///
360/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
361///
362/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
363/// as `[EF; W]` by making use of the chosen basis `B` again.
364pub trait PackedFieldExtension<
365    BaseField: Field,
366    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
367>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
368{
369    /// Construct a packed extension by applying `f` to each lane.
370    ///
371    /// This is the extension-field analog of [`PackedValue::from_fn`] and the canonical
372    /// primitive constructor for packed extensions: every other constructor in this
373    /// trait (`from_ext_slice`, `pack_ext_columns`, etc.) routes through it.
374    ///
375    /// `f` is called once per `(basis_coefficient, lane)` pair (`D * W` calls total),
376    /// hence the [`Fn`] bound — closures with side effects are unsuitable.
377    ///
378    /// The default impl uses only the [`BasedVectorSpace`] machinery the trait already
379    /// requires. Concrete impls should override when the extension struct exposes its
380    /// base packings directly, e.g. `Self::new(F::Packing::pack_columns_fn(|l| f(l).value))`.
381    #[inline]
382    #[must_use]
383    fn from_ext_fn(f: impl Fn(usize) -> ExtField) -> Self {
384        Self::from_basis_coefficients_fn(|d| {
385            BaseField::Packing::from_fn(|lane| f(lane).as_basis_coefficients_slice()[d])
386        })
387    }
388
389    /// Pack a length-`WIDTH` slice of extension field elements into one packed extension.
390    ///
391    /// ## Panics
392    /// Panics if `slice.len() != BaseField::Packing::WIDTH`.
393    #[inline]
394    #[must_use]
395    fn from_ext_slice(slice: &[ExtField]) -> Self {
396        assert_eq!(slice.len(), BaseField::Packing::WIDTH);
397        Self::from_ext_fn(|lane| slice[lane])
398    }
399
400    /// Pack `N` columns from `W` rows of extension field elements into `N` packed extensions.
401    ///
402    /// This is the extension-field analog of [`PackedValue::pack_columns`]: given `W` rows
403    /// of `N` extension elements, lane `lane` of output column `col` is `rows[lane][col]`.
404    ///
405    /// ## Panics
406    /// Panics if `rows.len() != BaseField::Packing::WIDTH`.
407    #[inline]
408    #[must_use]
409    fn pack_ext_columns<const N: usize>(rows: &[[ExtField; N]]) -> [Self; N] {
410        assert_eq!(rows.len(), BaseField::Packing::WIDTH);
411        array::from_fn(|col| Self::from_ext_fn(|lane| rows[lane][col]))
412    }
413
414    /// Pack `N` columns using a closure that produces each row.
415    ///
416    /// Analog of [`PackedValue::pack_columns_fn`].
417    #[inline]
418    #[must_use]
419    fn pack_ext_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [ExtField; N]) -> [Self; N] {
420        array::from_fn(|col| Self::from_ext_fn(|lane| row_fn(lane)[col]))
421    }
422
423    /// Extract the extension field element at the given SIMD lane.
424    #[inline]
425    #[must_use]
426    fn extract(&self, lane: usize) -> ExtField {
427        ExtField::from_basis_coefficients_fn(|d| {
428            self.as_basis_coefficients_slice()[d].as_slice()[lane]
429        })
430    }
431
432    /// Write all `W` lanes into the given slice.
433    ///
434    /// This is the extension-field analog of [`PackedValue::as_slice`], but the lanes of
435    /// a packed extension are not contiguous in memory (the layout is `[[F; W]; D]`,
436    /// indexed first by basis coefficient), so the lanes must be copied rather than
437    /// borrowed.
438    ///
439    /// ## Panics
440    /// Panics if `out.len() != BaseField::Packing::WIDTH`.
441    #[inline]
442    fn to_ext_slice(&self, out: &mut [ExtField]) {
443        assert_eq!(out.len(), BaseField::Packing::WIDTH);
444        for (lane, slot) in out.iter_mut().enumerate() {
445            *slot = self.extract(lane);
446        }
447    }
448
449    /// Unpack `N` packed extensions into `W` rows of `N` extension elements.
450    ///
451    /// Inverse of [`PackedFieldExtension::pack_ext_columns`]. Lane `lane` of input
452    /// column `col` is written to `rows[lane][col]`.
453    ///
454    /// ## Panics
455    /// Panics if `rows.len() != BaseField::Packing::WIDTH`.
456    #[inline]
457    fn unpack_ext_into<const N: usize>(packed: &[Self; N], rows: &mut [[ExtField; N]]) {
458        assert_eq!(rows.len(), BaseField::Packing::WIDTH);
459        for (lane, row) in rows.iter_mut().enumerate() {
460            *row = array::from_fn(|col| {
461                ExtField::from_basis_coefficients_fn(|d| {
462                    packed[col].as_basis_coefficients_slice()[d].as_slice()[lane]
463                })
464            });
465        }
466    }
467
468    /// Iterator equivalent of [`PackedFieldExtension::unpack_ext_into`].
469    ///
470    /// Yields `WIDTH` rows of `N` extension elements without requiring a pre-allocated
471    /// buffer. Analog of [`PackedValue::unpack_iter`].
472    #[inline]
473    fn unpack_ext_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [ExtField; N]> {
474        (0..BaseField::Packing::WIDTH).map(move |lane| {
475            array::from_fn(|col| {
476                ExtField::from_basis_coefficients_fn(|d| {
477                    packed[col].as_basis_coefficients_slice()[d].as_slice()[lane]
478                })
479            })
480        })
481    }
482
483    /// Convert an iterator of packed extension field elements to an iterator of
484    /// extension field elements (flat — one `ExtField` per lane per packed value).
485    #[inline]
486    #[must_use]
487    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
488        iter.into_iter()
489            .flat_map(|x| (0..BaseField::Packing::WIDTH).map(move |lane| x.extract(lane)))
490    }
491
492    /// Similar to `packed_powers`, construct an iterator which returns
493    /// powers of `base` packed into `PackedFieldExtension` elements.
494    #[must_use]
495    fn packed_ext_powers(base: ExtField) -> Powers<Self>;
496
497    /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
498    ///
499    /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
500    /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
501    /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
502    #[must_use]
503    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
504        Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
505    }
506}
507
508unsafe impl<T: Packable> PackedValue for T {
509    type Value = Self;
510
511    const WIDTH: usize = 1;
512
513    #[inline]
514    fn from_slice(slice: &[Self::Value]) -> &Self {
515        assert_eq!(slice.len(), Self::WIDTH);
516        &slice[0]
517    }
518
519    #[inline]
520    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
521        assert_eq!(slice.len(), Self::WIDTH);
522        &mut slice[0]
523    }
524
525    #[inline]
526    fn from_fn<Fn>(mut f: Fn) -> Self
527    where
528        Fn: FnMut(usize) -> Self::Value,
529    {
530        f(0)
531    }
532
533    #[inline]
534    fn as_slice(&self) -> &[Self::Value] {
535        slice::from_ref(self)
536    }
537
538    #[inline]
539    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
540        slice::from_mut(self)
541    }
542}
543
544unsafe impl<F: Field> PackedField for F {
545    type Scalar = Self;
546}
547
548unsafe impl<F: Field> PackedFieldPow2 for F {
549    #[inline]
550    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
551        match block_len {
552            1 => (*self, other),
553            _ => panic!("unsupported block length"),
554        }
555    }
556}
557
558impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
559    #[inline]
560    fn from_ext_fn(f: impl Fn(usize) -> F) -> Self {
561        F::Packing::from_fn(f)
562    }
563
564    #[inline]
565    fn from_ext_slice(slice: &[F]) -> Self {
566        *F::Packing::from_slice(slice)
567    }
568
569    #[inline]
570    fn packed_ext_powers(base: F) -> Powers<Self> {
571        F::Packing::packed_powers(base)
572    }
573}
574
575impl Packable for u8 {}
576
577impl Packable for u16 {}
578
579impl Packable for u32 {}
580
581impl Packable for u64 {}
582
583impl Packable for u128 {}