p3_field/packed/packed_traits.rs
1use core::iter::{Product, Sum};
2use core::mem::MaybeUninit;
3use core::ops::{Div, DivAssign};
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// Marker trait for scalar types that may occupy the lanes of a packed value.
///
/// Gating the blanket `PackedValue` implementation on `Packable` (instead of on every `T`)
/// is what lets this crate provide separate implementations for types that would otherwise
/// conflict, such as fixed-size arrays.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
13
14/// A trait for array-like structs made up of multiple scalar elements.
15///
16/// # Safety
17/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
18/// without UB.
19pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
20 /// The scalar type that is packed into this value.
21 type Value: Packable;
22
23 /// Number of scalar values packed together.
24 const WIDTH: usize;
25
26 /// Constructs a packed value using a function to generate each element.
27 ///
28 /// Similar to [`core::array::from_fn`].
29 #[must_use]
30 fn from_fn<F>(f: F) -> Self
31 where
32 F: FnMut(usize) -> Self::Value;
33
34 /// Create a packed value with all lanes set to the same scalar value.
35 #[inline]
36 #[must_use]
37 fn broadcast(value: Self::Value) -> Self {
38 Self::from_fn(|_| value)
39 }
40
41 /// Interprets a slice of scalar values as a packed value reference.
42 ///
43 /// # Panics:
44 /// This function will panic if `slice.len() != Self::WIDTH`
45 #[must_use]
46 fn from_slice(slice: &[Self::Value]) -> &Self;
47
48 /// Interprets a mutable slice of scalar values as a mutable packed value.
49 ///
50 /// # Panics:
51 /// This function will panic if `slice.len() != Self::WIDTH`
52 #[must_use]
53 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
54
55 /// Returns the underlying scalar values as an immutable slice.
56 #[must_use]
57 fn as_slice(&self) -> &[Self::Value];
58
59 /// Returns the underlying scalar values as a mutable slice.
60 #[must_use]
61 fn as_slice_mut(&mut self) -> &mut [Self::Value];
62
63 /// Extract the scalar value at the given SIMD lane.
64 ///
65 /// This is equivalent to `self.as_slice()[lane]` but more explicit about the
66 /// SIMD extraction semantics.
67 #[inline]
68 #[must_use]
69 fn extract(&self, lane: usize) -> Self::Value {
70 self.as_slice()[lane]
71 }
72
73 /// Packs a slice of scalar values into a slice of packed values.
74 ///
75 /// # Panics
76 /// Panics if the slice length is not divisible by `WIDTH`.
77 #[inline]
78 #[must_use]
79 fn pack_slice(buf: &[Self::Value]) -> &[Self] {
80 // Sources vary, but this should be true on all platforms we care about.
81 const {
82 assert!(align_of::<Self>() <= align_of::<Self::Value>());
83 }
84 assert!(
85 buf.len().is_multiple_of(Self::WIDTH),
86 "Slice length (got {}) must be a multiple of packed field width ({}).",
87 buf.len(),
88 Self::WIDTH
89 );
90 let buf_ptr = buf.as_ptr().cast::<Self>();
91 let n = buf.len() / Self::WIDTH;
92 unsafe { slice::from_raw_parts(buf_ptr, n) }
93 }
94
95 /// Converts a mutable slice of scalar values into a mutable slice of packed values.
96 ///
97 /// # Panics
98 /// Panics if the slice length is not divisible by `WIDTH`.
99 #[inline]
100 #[must_use]
101 fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
102 const {
103 assert!(align_of::<Self>() <= align_of::<Self::Value>());
104 }
105 assert!(
106 buf.len().is_multiple_of(Self::WIDTH),
107 "Slice length (got {}) must be a multiple of packed field width ({}).",
108 buf.len(),
109 Self::WIDTH
110 );
111 let buf_ptr = buf.as_mut_ptr().cast::<Self>();
112 let n = buf.len() / Self::WIDTH;
113 unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
114 }
115
116 /// Converts a mutable slice of possibly uninitialized scalar values into
117 /// a mutable slice of possibly uninitialized packed values.
118 ///
119 /// # Panics
120 /// Panics if the slice length is not divisible by `WIDTH`.
121 #[inline]
122 #[must_use]
123 fn pack_maybe_uninit_slice_mut(
124 buf: &mut [MaybeUninit<Self::Value>],
125 ) -> &mut [MaybeUninit<Self>] {
126 const {
127 assert!(align_of::<Self>() <= align_of::<Self::Value>());
128 }
129 assert!(
130 buf.len().is_multiple_of(Self::WIDTH),
131 "Slice length (got {}) must be a multiple of packed field width ({}).",
132 buf.len(),
133 Self::WIDTH
134 );
135 let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
136 let n = buf.len() / Self::WIDTH;
137 unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
138 }
139
140 /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
141 #[inline]
142 #[must_use]
143 fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
144 let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
145 (Self::pack_slice(packed), suffix)
146 }
147
148 /// Converts a mutable slice of scalar values into a pair:
149 /// - a slice of packed values covering the largest aligned portion,
150 /// - and a remainder slice of scalar values that couldn't be packed.
151 #[inline]
152 #[must_use]
153 fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
154 let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
155 (Self::pack_slice_mut(packed), suffix)
156 }
157
158 /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
159 /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
160 /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
161 #[inline]
162 #[must_use]
163 fn pack_maybe_uninit_slice_with_suffix_mut(
164 buf: &mut [MaybeUninit<Self::Value>],
165 ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
166 let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
167 (Self::pack_maybe_uninit_slice_mut(packed), suffix)
168 }
169
170 /// Reinterprets a slice of packed values as a flat slice of scalar values.
171 ///
172 /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
173 /// contiguously in memory. This function allows direct access to those scalars.
174 #[inline]
175 #[must_use]
176 fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
177 const {
178 assert!(align_of::<Self>() >= align_of::<Self::Value>());
179 }
180 let buf_ptr = buf.as_ptr().cast::<Self::Value>();
181 let n = buf.len() * Self::WIDTH;
182 unsafe { slice::from_raw_parts(buf_ptr, n) }
183 }
184
185 /// Pack columns from `WIDTH` rows of scalar values into `N` packed values.
186 ///
187 /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it
188 /// into a single packed value. This is the inverse of `unpack_into`.
189 ///
190 /// ## Panics
191 /// Panics if `rows.len() != WIDTH`.
192 #[inline]
193 #[must_use]
194 fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
195 assert_eq!(rows.len(), Self::WIDTH);
196 array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
197 }
198
199 /// Pack columns using a closure that provides each row's data.
200 ///
201 /// Calls `row_fn(lane)` for each lane `0..WIDTH` to get `[Self::Value; N]`,
202 /// then transposes columns into packed values. Useful when rows aren't
203 /// contiguous in memory (e.g., strided access).
204 #[inline]
205 #[must_use]
206 fn pack_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [Self::Value; N]) -> [Self; N] {
207 array::from_fn(|col| Self::from_fn(|lane| row_fn(lane)[col]))
208 }
209
210 /// Unpack `N` packed values into `WIDTH` rows of `N` scalars.
211 ///
212 /// ## Inputs
213 /// - `packed`: An array of `N` packed values.
214 /// - `rows`: A mutable slice of exactly `WIDTH` arrays to write the unpacked values.
215 ///
216 /// ## Panics
217 /// Panics if `rows.len() != WIDTH`.
218 #[inline]
219 fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
220 assert_eq!(rows.len(), Self::WIDTH);
221 #[allow(clippy::needless_range_loop)]
222 for lane in 0..Self::WIDTH {
223 rows[lane] = array::from_fn(|col| packed[col].extract(lane));
224 }
225 }
226
227 /// Unpack `N` packed values into an iterator of `WIDTH` rows.
228 ///
229 /// This is the iterator equivalent of `unpack_into`, yielding each row
230 /// without requiring a pre-allocated buffer.
231 #[inline]
232 fn unpack_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [Self::Value; N]> {
233 (0..Self::WIDTH).map(move |lane| array::from_fn(|col| packed[col].extract(lane)))
234 }
235}
236
237unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
238 type Value = T;
239 const WIDTH: usize = WIDTH;
240
241 #[inline]
242 fn from_slice(slice: &[Self::Value]) -> &Self {
243 assert_eq!(slice.len(), Self::WIDTH);
244 unsafe { &*slice.as_ptr().cast() }
245 }
246
247 #[inline]
248 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
249 assert_eq!(slice.len(), Self::WIDTH);
250 unsafe { &mut *slice.as_mut_ptr().cast() }
251 }
252
253 #[inline]
254 fn from_fn<Fn>(f: Fn) -> Self
255 where
256 Fn: FnMut(usize) -> Self::Value,
257 {
258 core::array::from_fn(f)
259 }
260
261 #[inline]
262 fn as_slice(&self) -> &[Self::Value] {
263 self
264 }
265
266 #[inline]
267 fn as_slice_mut(&mut self) -> &mut [Self::Value] {
268 self
269 }
270}
271
272/// An array of field elements which can be packed into a vector for SIMD operations.
273///
274/// # Safety
275/// - See `PackedValue` above.
276pub unsafe trait PackedField: Algebra<Self::Scalar>
277 + PackedValue<Value = Self::Scalar>
278 // TODO: Implement packed / packed division
279 + Div<Self::Scalar, Output = Self>
280 + DivAssign<Self::Scalar>
281 + Sum<Self::Scalar>
282 + Product<Self::Scalar>
283{
284 type Scalar: Field;
285
286 /// Construct an iterator which returns powers of `base` packed into packed field elements.
287 ///
288 /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
289 #[must_use]
290 fn packed_powers(base: Self::Scalar) -> Powers<Self> {
291 Self::packed_shifted_powers(base, Self::Scalar::ONE)
292 }
293
294 /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
295 ///
296 /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
297 #[must_use]
298 fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
299 let mut current: Self = start.into();
300 let slice = current.as_slice_mut();
301 for i in 1..Self::WIDTH {
302 slice[i] = slice[i - 1] * base;
303 }
304
305 Powers {
306 base: base.exp_u64(Self::WIDTH as u64).into(),
307 current,
308 }
309 }
310
311 /// Compute a linear combination of a slice of base field elements and
312 /// a slice of packed field elements. The slices must have equal length
313 /// and it must be a compile time constant.
314 ///
315 /// # Panics
316 ///
317 /// May panic if the length of either slice is not equal to `N`.
318 #[must_use]
319 fn packed_linear_combination<const N: usize>(coeffs: &[Self::Scalar], vecs: &[Self]) -> Self {
320 assert_eq!(coeffs.len(), N);
321 assert_eq!(vecs.len(), N);
322 let combined: [Self; N] = array::from_fn(|i| vecs[i] * coeffs[i]);
323 Self::sum_array::<N>(&combined)
324 }
325}
326
327/// # Safety
328/// - `WIDTH` is assumed to be a power of 2.
329pub unsafe trait PackedFieldPow2: PackedField {
330 /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
331 /// chunks. This is best seen with an example. If we have:
332 /// ```text
333 /// A = [x0, y0, x1, y1]
334 /// B = [x2, y2, x3, y3]
335 /// ```
336 ///
337 /// then
338 ///
339 /// ```text
340 /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
341 /// ```
342 ///
343 /// Pairs that were adjacent in the input are at corresponding positions in the output.
344 ///
345 /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
346 ///
347 /// ```text
348 /// A = [x0, x1, y0, y1]
349 /// B = [x2, x3, y2, y3]
350 /// ```
351 ///
352 /// we obtain
353 ///
354 /// ```text
355 /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
356 /// ```
357 ///
358 /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
359 /// transposing those matrices.
360 ///
361 /// When `block_len = WIDTH`, this operation is a no-op.
362 ///
363 /// # Panics
364 /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
365 /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
366 #[must_use]
367 fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
368}
369
370/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
371///
372/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
373///
374/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
375///
376/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
377/// as `[EF; W]` by making use of the chosen basis `B` again.
378pub trait PackedFieldExtension<
379 BaseField: Field,
380 ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
381>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
382{
383 /// Given a slice of extension field `EF` elements of length `W`,
384 /// convert into the array `[[F; D]; W]` transpose to
385 /// `[[F; W]; D]` and then pack to get `[PF; D]`.
386 #[must_use]
387 fn from_ext_slice(ext_slice: &[ExtField]) -> Self;
388
389 /// Extract the extension field element at the given SIMD lane.
390 #[inline]
391 #[must_use]
392 fn extract(&self, lane: usize) -> ExtField {
393 ExtField::from_basis_coefficients_fn(|d| {
394 self.as_basis_coefficients_slice()[d].as_slice()[lane]
395 })
396 }
397
398 /// Convert an iterator of packed extension field elements to an iterator of
399 /// extension field elements.
400 ///
401 /// This performs the inverse transformation to `from_ext_slice`.
402 #[inline]
403 #[must_use]
404 fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
405 iter.into_iter()
406 .flat_map(|x| (0..BaseField::Packing::WIDTH).map(move |lane| x.extract(lane)))
407 }
408
409 /// Similar to `packed_powers`, construct an iterator which returns
410 /// powers of `base` packed into `PackedFieldExtension` elements.
411 #[must_use]
412 fn packed_ext_powers(base: ExtField) -> Powers<Self>;
413
414 /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
415 ///
416 /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
417 /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
418 /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
419 #[must_use]
420 fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
421 Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
422 }
423}
424
425unsafe impl<T: Packable> PackedValue for T {
426 type Value = Self;
427
428 const WIDTH: usize = 1;
429
430 #[inline]
431 fn from_slice(slice: &[Self::Value]) -> &Self {
432 assert_eq!(slice.len(), Self::WIDTH);
433 &slice[0]
434 }
435
436 #[inline]
437 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
438 assert_eq!(slice.len(), Self::WIDTH);
439 &mut slice[0]
440 }
441
442 #[inline]
443 fn from_fn<Fn>(mut f: Fn) -> Self
444 where
445 Fn: FnMut(usize) -> Self::Value,
446 {
447 f(0)
448 }
449
450 #[inline]
451 fn as_slice(&self) -> &[Self::Value] {
452 slice::from_ref(self)
453 }
454
455 #[inline]
456 fn as_slice_mut(&mut self) -> &mut [Self::Value] {
457 slice::from_mut(self)
458 }
459}
460
461unsafe impl<F: Field> PackedField for F {
462 type Scalar = Self;
463}
464
465unsafe impl<F: Field> PackedFieldPow2 for F {
466 #[inline]
467 fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
468 match block_len {
469 1 => (*self, other),
470 _ => panic!("unsupported block length"),
471 }
472 }
473}
474
475impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
476 #[inline]
477 fn from_ext_slice(ext_slice: &[F]) -> Self {
478 *F::Packing::from_slice(ext_slice)
479 }
480
481 #[inline]
482 fn packed_ext_powers(base: F) -> Powers<Self> {
483 F::Packing::packed_powers(base)
484 }
485}
486
487impl Packable for u8 {}
488
489impl Packable for u16 {}
490
491impl Packable for u32 {}
492
493impl Packable for u64 {}
494
495impl Packable for u128 {}