// p3_field/packed/packed_traits.rs

1use core::iter::{Product, Sum};
2use core::mem::MaybeUninit;
3use core::ops::{Div, DivAssign};
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// Marker trait for the scalar types that may occupy the lanes of a packed value.
///
/// Having a dedicated trait (rather than bounding directly on its supertraits) lets us write
/// implementations for types that would otherwise be considered potentially conflicting.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
13
14/// A trait for array-like structs made up of multiple scalar elements.
15///
16/// # Safety
17/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
18///   without UB.
19pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
20    /// The scalar type that is packed into this value.
21    type Value: Packable;
22
23    /// Number of scalar values packed together.
24    const WIDTH: usize;
25
26    /// Interprets a slice of scalar values as a packed value reference.
27    ///
28    /// # Panics:
29    /// This function will panic if `slice.len() != Self::WIDTH`
30    #[must_use]
31    fn from_slice(slice: &[Self::Value]) -> &Self;
32
33    /// Interprets a mutable slice of scalar values as a mutable packed value.
34    ///
35    /// # Panics:
36    /// This function will panic if `slice.len() != Self::WIDTH`
37    #[must_use]
38    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
39
40    /// Constructs a packed value using a function to generate each element.
41    ///
42    /// Similar to `core:array::from_fn`.
43    #[must_use]
44    fn from_fn<F>(f: F) -> Self
45    where
46        F: FnMut(usize) -> Self::Value;
47
48    /// Returns the underlying scalar values as an immutable slice.
49    #[must_use]
50    fn as_slice(&self) -> &[Self::Value];
51
52    /// Returns the underlying scalar values as a mutable slice.
53    #[must_use]
54    fn as_slice_mut(&mut self) -> &mut [Self::Value];
55
56    /// Packs a slice of scalar values into a slice of packed values.
57    ///
58    /// # Panics
59    /// Panics if the slice length is not divisible by `WIDTH`.
60    #[inline]
61    #[must_use]
62    fn pack_slice(buf: &[Self::Value]) -> &[Self] {
63        // Sources vary, but this should be true on all platforms we care about.
64        const {
65            assert!(align_of::<Self>() <= align_of::<Self::Value>());
66        }
67        assert!(
68            buf.len().is_multiple_of(Self::WIDTH),
69            "Slice length (got {}) must be a multiple of packed field width ({}).",
70            buf.len(),
71            Self::WIDTH
72        );
73        let buf_ptr = buf.as_ptr().cast::<Self>();
74        let n = buf.len() / Self::WIDTH;
75        unsafe { slice::from_raw_parts(buf_ptr, n) }
76    }
77
78    /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
79    #[inline]
80    #[must_use]
81    fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
82        let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
83        (Self::pack_slice(packed), suffix)
84    }
85
86    /// Converts a mutable slice of scalar values into a mutable slice of packed values.
87    ///
88    /// # Panics
89    /// Panics if the slice length is not divisible by `WIDTH`.
90    #[inline]
91    #[must_use]
92    fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
93        const {
94            assert!(align_of::<Self>() <= align_of::<Self::Value>());
95        }
96        assert!(
97            buf.len().is_multiple_of(Self::WIDTH),
98            "Slice length (got {}) must be a multiple of packed field width ({}).",
99            buf.len(),
100            Self::WIDTH
101        );
102        let buf_ptr = buf.as_mut_ptr().cast::<Self>();
103        let n = buf.len() / Self::WIDTH;
104        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
105    }
106
107    /// Converts a mutable slice of possibly uninitialized scalar values into
108    /// a mutable slice of possibly uninitialized packed values.
109    ///
110    /// # Panics
111    /// Panics if the slice length is not divisible by `WIDTH`.
112    #[inline]
113    #[must_use]
114    fn pack_maybe_uninit_slice_mut(
115        buf: &mut [MaybeUninit<Self::Value>],
116    ) -> &mut [MaybeUninit<Self>] {
117        const {
118            assert!(align_of::<Self>() <= align_of::<Self::Value>());
119        }
120        assert!(
121            buf.len().is_multiple_of(Self::WIDTH),
122            "Slice length (got {}) must be a multiple of packed field width ({}).",
123            buf.len(),
124            Self::WIDTH
125        );
126        let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
127        let n = buf.len() / Self::WIDTH;
128        unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
129    }
130
131    /// Converts a mutable slice of scalar values into a pair:
132    /// - a slice of packed values covering the largest aligned portion,
133    /// - and a remainder slice of scalar values that couldn't be packed.
134    #[inline]
135    #[must_use]
136    fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
137        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
138        (Self::pack_slice_mut(packed), suffix)
139    }
140
141    /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
142    /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
143    /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
144    #[inline]
145    #[must_use]
146    fn pack_maybe_uninit_slice_with_suffix_mut(
147        buf: &mut [MaybeUninit<Self::Value>],
148    ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
149        let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
150        (Self::pack_maybe_uninit_slice_mut(packed), suffix)
151    }
152
153    /// Reinterprets a slice of packed values as a flat slice of scalar values.
154    ///
155    /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
156    /// contiguously in memory. This function allows direct access to those scalars.
157    #[inline]
158    #[must_use]
159    fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
160        const {
161            assert!(align_of::<Self>() >= align_of::<Self::Value>());
162        }
163        let buf_ptr = buf.as_ptr().cast::<Self::Value>();
164        let n = buf.len() * Self::WIDTH;
165        unsafe { slice::from_raw_parts(buf_ptr, n) }
166    }
167
168    /// Extract the scalar value at the given SIMD lane.
169    ///
170    /// This is equivalent to `self.as_slice()[lane]` but more explicit about the
171    /// SIMD extraction semantics.
172    #[inline]
173    #[must_use]
174    fn extract(&self, lane: usize) -> Self::Value {
175        self.as_slice()[lane]
176    }
177
178    /// Unpack `N` packed values into `WIDTH` rows of `N` scalars.
179    ///
180    /// ## Inputs
181    /// - `packed`: An array of `N` packed values.
182    /// - `rows`: A mutable slice of exactly `WIDTH` arrays to write the unpacked values.
183    ///
184    /// ## Panics
185    /// Panics if `rows.len() != WIDTH`.
186    #[inline]
187    fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
188        assert_eq!(rows.len(), Self::WIDTH);
189        #[allow(clippy::needless_range_loop)]
190        for lane in 0..Self::WIDTH {
191            rows[lane] = array::from_fn(|col| packed[col].extract(lane));
192        }
193    }
194}
195
196unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
197    type Value = T;
198    const WIDTH: usize = WIDTH;
199
200    #[inline]
201    fn from_slice(slice: &[Self::Value]) -> &Self {
202        assert_eq!(slice.len(), Self::WIDTH);
203        unsafe { &*slice.as_ptr().cast() }
204    }
205
206    #[inline]
207    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
208        assert_eq!(slice.len(), Self::WIDTH);
209        unsafe { &mut *slice.as_mut_ptr().cast() }
210    }
211
212    #[inline]
213    fn from_fn<Fn>(f: Fn) -> Self
214    where
215        Fn: FnMut(usize) -> Self::Value,
216    {
217        core::array::from_fn(f)
218    }
219
220    #[inline]
221    fn as_slice(&self) -> &[Self::Value] {
222        self
223    }
224
225    #[inline]
226    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
227        self
228    }
229}
230
231/// An array of field elements which can be packed into a vector for SIMD operations.
232///
233/// # Safety
234/// - See `PackedValue` above.
235pub unsafe trait PackedField: Algebra<Self::Scalar>
236    + PackedValue<Value = Self::Scalar>
237    // TODO: Implement packed / packed division
238    + Div<Self::Scalar, Output = Self>
239    + DivAssign<Self::Scalar>
240    + Sum<Self::Scalar>
241    + Product<Self::Scalar>
242{
243    type Scalar: Field;
244
245    /// Construct an iterator which returns powers of `base` packed into packed field elements.
246    ///
247    /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
248    #[must_use]
249    fn packed_powers(base: Self::Scalar) -> Powers<Self> {
250        Self::packed_shifted_powers(base, Self::Scalar::ONE)
251    }
252
253    /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
254    ///
255    /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
256    #[must_use]
257    fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
258        let mut current: Self = start.into();
259        let slice = current.as_slice_mut();
260        for i in 1..Self::WIDTH {
261            slice[i] = slice[i - 1] * base;
262        }
263
264        Powers {
265            base: base.exp_u64(Self::WIDTH as u64).into(),
266            current,
267        }
268    }
269
270    /// Compute a linear combination of a slice of base field elements and
271    /// a slice of packed field elements. The slices must have equal length
272    /// and it must be a compile time constant.
273    ///
274    /// # Panics
275    ///
276    /// May panic if the length of either slice is not equal to `N`.
277    #[must_use]
278    fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
279        assert_eq!(coeffs.len(), N);
280        assert_eq!(vecs.len(), N);
281        let combined: [Self; N] = array::from_fn(|i| vecs[i] * coeffs[i]);
282        Self::sum_array::<N>(&combined)
283    }
284}
285
286/// # Safety
287/// - `WIDTH` is assumed to be a power of 2.
288pub unsafe trait PackedFieldPow2: PackedField {
289    /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
290    /// chunks. This is best seen with an example. If we have:
291    /// ```text
292    /// A = [x0, y0, x1, y1]
293    /// B = [x2, y2, x3, y3]
294    /// ```
295    ///
296    /// then
297    ///
298    /// ```text
299    /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
300    /// ```
301    ///
302    /// Pairs that were adjacent in the input are at corresponding positions in the output.
303    ///
304    /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
305    ///
306    /// ```text
307    /// A = [x0, x1, y0, y1]
308    /// B = [x2, x3, y2, y3]
309    /// ```
310    ///
311    /// we obtain
312    ///
313    /// ```text
314    /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
315    /// ```
316    ///
317    /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
318    /// transposing those matrices.
319    ///
320    /// When `block_len = WIDTH`, this operation is a no-op.
321    ///
322    /// # Panics
323    /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
324    /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
325    #[must_use]
326    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
327}
328
329/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
330///
331/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
332///
333/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
334///
335/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
336/// as `[EF; W]` by making use of the chosen basis `B` again.
337pub trait PackedFieldExtension<
338    BaseField: Field,
339    ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
340>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
341{
342    /// Given a slice of extension field `EF` elements of length `W`,
343    /// convert into the array `[[F; D]; W]` transpose to
344    /// `[[F; W]; D]` and then pack to get `[PF; D]`.
345    #[must_use]
346    fn from_ext_slice(ext_slice: &[ExtField]) -> Self;
347
348    /// Extract the extension field element at the given SIMD lane.
349    #[inline]
350    #[must_use]
351    fn extract(&self, lane: usize) -> ExtField {
352        ExtField::from_basis_coefficients_fn(|d| {
353            self.as_basis_coefficients_slice()[d].as_slice()[lane]
354        })
355    }
356
357    /// Convert an iterator of packed extension field elements to an iterator of
358    /// extension field elements.
359    ///
360    /// This performs the inverse transformation to `from_ext_slice`.
361    #[inline]
362    #[must_use]
363    fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
364        iter.into_iter()
365            .flat_map(|x| (0..BaseField::Packing::WIDTH).map(move |i| x.extract(i)))
366    }
367
368    /// Similar to `packed_powers`, construct an iterator which returns
369    /// powers of `base` packed into `PackedFieldExtension` elements.
370    #[must_use]
371    fn packed_ext_powers(base: ExtField) -> Powers<Self>;
372
373    /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
374    ///
375    /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
376    /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
377    /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
378    #[must_use]
379    fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
380        Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
381    }
382}
383
384unsafe impl<T: Packable> PackedValue for T {
385    type Value = Self;
386
387    const WIDTH: usize = 1;
388
389    #[inline]
390    fn from_slice(slice: &[Self::Value]) -> &Self {
391        assert_eq!(slice.len(), Self::WIDTH);
392        &slice[0]
393    }
394
395    #[inline]
396    fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
397        assert_eq!(slice.len(), Self::WIDTH);
398        &mut slice[0]
399    }
400
401    #[inline]
402    fn from_fn<Fn>(mut f: Fn) -> Self
403    where
404        Fn: FnMut(usize) -> Self::Value,
405    {
406        f(0)
407    }
408
409    #[inline]
410    fn as_slice(&self) -> &[Self::Value] {
411        slice::from_ref(self)
412    }
413
414    #[inline]
415    fn as_slice_mut(&mut self) -> &mut [Self::Value] {
416        slice::from_mut(self)
417    }
418}
419
420unsafe impl<F: Field> PackedField for F {
421    type Scalar = Self;
422}
423
424unsafe impl<F: Field> PackedFieldPow2 for F {
425    #[inline]
426    fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
427        match block_len {
428            1 => (*self, other),
429            _ => panic!("unsupported block length"),
430        }
431    }
432}
433
434impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
435    #[inline]
436    fn from_ext_slice(ext_slice: &[F]) -> Self {
437        *F::Packing::from_slice(ext_slice)
438    }
439
440    #[inline]
441    fn packed_ext_powers(base: F) -> Powers<Self> {
442        F::Packing::packed_powers(base)
443    }
444}
445
446impl Packable for u8 {}
447
448impl Packable for u16 {}
449
450impl Packable for u32 {}
451
452impl Packable for u64 {}
453
454impl Packable for u128 {}