p3_field/packed/packed_traits.rs
1use core::iter::{Product, Sum};
2use core::mem::MaybeUninit;
3use core::ops::{Div, DivAssign};
4use core::{array, slice};
5
6use crate::field::Field;
7use crate::{Algebra, BasedVectorSpace, ExtensionField, Powers, PrimeCharacteristicRing};
8
/// Marker trait for scalar types that are allowed to occupy the lanes of a packed value.
///
/// Making this an explicit opt-in marker lets us provide implementations for types that
/// would otherwise conflict with one another.
pub trait Packable: Copy + Default + Eq + PartialEq + Send + Sync + 'static {}
13
14/// A trait for array-like structs made up of multiple scalar elements.
15///
16/// # Safety
17/// - If `P` implements `PackedField` then `P` must be castable to/from `[P::Value; P::WIDTH]`
18/// without UB.
19pub unsafe trait PackedValue: 'static + Copy + Send + Sync {
20 /// The scalar type that is packed into this value.
21 type Value: Packable;
22
23 /// Number of scalar values packed together.
24 const WIDTH: usize;
25
26 /// Constructs a packed value using a function to generate each element.
27 ///
28 /// Similar to [`core::array::from_fn`].
29 #[must_use]
30 fn from_fn<F>(f: F) -> Self
31 where
32 F: FnMut(usize) -> Self::Value;
33
34 /// Create a packed value with all lanes set to the same scalar value.
35 #[inline]
36 #[must_use]
37 fn broadcast(value: Self::Value) -> Self {
38 Self::from_fn(|_| value)
39 }
40
41 /// Interprets a slice of scalar values as a packed value reference.
42 ///
43 /// # Panics:
44 /// This function will panic if `slice.len() != Self::WIDTH`
45 #[must_use]
46 fn from_slice(slice: &[Self::Value]) -> &Self;
47
48 /// Interprets a mutable slice of scalar values as a mutable packed value.
49 ///
50 /// # Panics:
51 /// This function will panic if `slice.len() != Self::WIDTH`
52 #[must_use]
53 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self;
54
55 /// Returns the underlying scalar values as an immutable slice.
56 #[must_use]
57 fn as_slice(&self) -> &[Self::Value];
58
59 /// Returns the underlying scalar values as a mutable slice.
60 #[must_use]
61 fn as_slice_mut(&mut self) -> &mut [Self::Value];
62
63 /// Extract the scalar value at the given SIMD lane.
64 ///
65 /// This is equivalent to `self.as_slice()[lane]` but more explicit about the
66 /// SIMD extraction semantics.
67 #[inline]
68 #[must_use]
69 fn extract(&self, lane: usize) -> Self::Value {
70 self.as_slice()[lane]
71 }
72
73 /// Packs a slice of scalar values into a slice of packed values.
74 ///
75 /// # Panics
76 /// Panics if the slice length is not divisible by `WIDTH`.
77 #[inline]
78 #[must_use]
79 fn pack_slice(buf: &[Self::Value]) -> &[Self] {
80 // Sources vary, but this should be true on all platforms we care about.
81 const {
82 assert!(align_of::<Self>() <= align_of::<Self::Value>());
83 }
84 assert!(
85 buf.len().is_multiple_of(Self::WIDTH),
86 "Slice length (got {}) must be a multiple of packed field width ({}).",
87 buf.len(),
88 Self::WIDTH
89 );
90 let buf_ptr = buf.as_ptr().cast::<Self>();
91 let n = buf.len() / Self::WIDTH;
92 unsafe { slice::from_raw_parts(buf_ptr, n) }
93 }
94
95 /// Converts a mutable slice of scalar values into a mutable slice of packed values.
96 ///
97 /// # Panics
98 /// Panics if the slice length is not divisible by `WIDTH`.
99 #[inline]
100 #[must_use]
101 fn pack_slice_mut(buf: &mut [Self::Value]) -> &mut [Self] {
102 const {
103 assert!(align_of::<Self>() <= align_of::<Self::Value>());
104 }
105 assert!(
106 buf.len().is_multiple_of(Self::WIDTH),
107 "Slice length (got {}) must be a multiple of packed field width ({}).",
108 buf.len(),
109 Self::WIDTH
110 );
111 let buf_ptr = buf.as_mut_ptr().cast::<Self>();
112 let n = buf.len() / Self::WIDTH;
113 unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
114 }
115
116 /// Converts a mutable slice of possibly uninitialized scalar values into
117 /// a mutable slice of possibly uninitialized packed values.
118 ///
119 /// # Panics
120 /// Panics if the slice length is not divisible by `WIDTH`.
121 #[inline]
122 #[must_use]
123 fn pack_maybe_uninit_slice_mut(
124 buf: &mut [MaybeUninit<Self::Value>],
125 ) -> &mut [MaybeUninit<Self>] {
126 const {
127 assert!(align_of::<Self>() <= align_of::<Self::Value>());
128 }
129 assert!(
130 buf.len().is_multiple_of(Self::WIDTH),
131 "Slice length (got {}) must be a multiple of packed field width ({}).",
132 buf.len(),
133 Self::WIDTH
134 );
135 let buf_ptr = buf.as_mut_ptr().cast::<MaybeUninit<Self>>();
136 let n = buf.len() / Self::WIDTH;
137 unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
138 }
139
140 /// Packs a slice into packed values and returns the packed portion and any remaining suffix.
141 #[inline]
142 #[must_use]
143 fn pack_slice_with_suffix(buf: &[Self::Value]) -> (&[Self], &[Self::Value]) {
144 let (packed, suffix) = buf.split_at(buf.len() - buf.len() % Self::WIDTH);
145 (Self::pack_slice(packed), suffix)
146 }
147
148 /// Converts a mutable slice of scalar values into a pair:
149 /// - a slice of packed values covering the largest aligned portion,
150 /// - and a remainder slice of scalar values that couldn't be packed.
151 #[inline]
152 #[must_use]
153 fn pack_slice_with_suffix_mut(buf: &mut [Self::Value]) -> (&mut [Self], &mut [Self::Value]) {
154 let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
155 (Self::pack_slice_mut(packed), suffix)
156 }
157
158 /// Converts a mutable slice of possibly uninitialized scalar values into a pair:
159 /// - a slice of possibly uninitialized packed values covering the largest aligned portion,
160 /// - and a remainder slice of possibly uninitialized scalar values that couldn't be packed.
161 #[inline]
162 #[must_use]
163 fn pack_maybe_uninit_slice_with_suffix_mut(
164 buf: &mut [MaybeUninit<Self::Value>],
165 ) -> (&mut [MaybeUninit<Self>], &mut [MaybeUninit<Self::Value>]) {
166 let (packed, suffix) = buf.split_at_mut(buf.len() - buf.len() % Self::WIDTH);
167 (Self::pack_maybe_uninit_slice_mut(packed), suffix)
168 }
169
170 /// Reinterprets a slice of packed values as a flat slice of scalar values.
171 ///
172 /// Each packed value contains `Self::WIDTH` scalar values, which are laid out
173 /// contiguously in memory. This function allows direct access to those scalars.
174 #[inline]
175 #[must_use]
176 fn unpack_slice(buf: &[Self]) -> &[Self::Value] {
177 const {
178 assert!(align_of::<Self>() >= align_of::<Self::Value>());
179 }
180 let buf_ptr = buf.as_ptr().cast::<Self::Value>();
181 let n = buf.len() * Self::WIDTH;
182 unsafe { slice::from_raw_parts(buf_ptr, n) }
183 }
184
185 /// Pack columns from `WIDTH` rows of scalar values into `N` packed values.
186 ///
187 /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it
188 /// into a single packed value. This is the inverse of `unpack_into`.
189 ///
190 /// ## Panics
191 /// Panics if `rows.len() != WIDTH`.
192 #[inline]
193 #[must_use]
194 fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
195 assert_eq!(rows.len(), Self::WIDTH);
196 array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
197 }
198
199 /// Pack columns using a closure that provides each row's data.
200 ///
201 /// Calls `row_fn(lane)` for each lane `0..WIDTH` to get `[Self::Value; N]`,
202 /// then transposes columns into packed values. Useful when rows aren't
203 /// contiguous in memory (e.g., strided access).
204 #[inline]
205 #[must_use]
206 fn pack_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [Self::Value; N]) -> [Self; N] {
207 array::from_fn(|col| Self::from_fn(|lane| row_fn(lane)[col]))
208 }
209
210 /// Unpack `N` packed values into `WIDTH` rows of `N` scalars.
211 ///
212 /// ## Inputs
213 /// - `packed`: An array of `N` packed values.
214 /// - `rows`: A mutable slice of exactly `WIDTH` arrays to write the unpacked values.
215 ///
216 /// ## Panics
217 /// Panics if `rows.len() != WIDTH`.
218 #[inline]
219 fn unpack_into<const N: usize>(packed: &[Self; N], rows: &mut [[Self::Value; N]]) {
220 assert_eq!(rows.len(), Self::WIDTH);
221 for (lane, row) in rows.iter_mut().enumerate() {
222 *row = array::from_fn(|col| packed[col].extract(lane));
223 }
224 }
225
226 /// Unpack `N` packed values into an iterator of `WIDTH` rows.
227 ///
228 /// This is the iterator equivalent of `unpack_into`, yielding each row
229 /// without requiring a pre-allocated buffer.
230 #[inline]
231 fn unpack_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [Self::Value; N]> {
232 (0..Self::WIDTH).map(move |lane| array::from_fn(|col| packed[col].extract(lane)))
233 }
234}
235
236unsafe impl<T: Packable, const WIDTH: usize> PackedValue for [T; WIDTH] {
237 type Value = T;
238 const WIDTH: usize = WIDTH;
239
240 #[inline]
241 fn from_slice(slice: &[Self::Value]) -> &Self {
242 assert_eq!(slice.len(), Self::WIDTH);
243 unsafe { &*slice.as_ptr().cast() }
244 }
245
246 #[inline]
247 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
248 assert_eq!(slice.len(), Self::WIDTH);
249 unsafe { &mut *slice.as_mut_ptr().cast() }
250 }
251
252 #[inline]
253 fn from_fn<Fn>(f: Fn) -> Self
254 where
255 Fn: FnMut(usize) -> Self::Value,
256 {
257 core::array::from_fn(f)
258 }
259
260 #[inline]
261 fn as_slice(&self) -> &[Self::Value] {
262 self
263 }
264
265 #[inline]
266 fn as_slice_mut(&mut self) -> &mut [Self::Value] {
267 self
268 }
269}
270
271/// An array of field elements which can be packed into a vector for SIMD operations.
272///
273/// # Safety
274/// - See `PackedValue` above.
275pub unsafe trait PackedField:
276 Algebra<Self::Scalar>
277 + PackedValue<Value = Self::Scalar>
278 + Div<Self, Output = Self>
279 + Div<Self::Scalar, Output = Self>
280 + DivAssign<Self>
281 + DivAssign<Self::Scalar>
282 + Sum<Self::Scalar>
283 + Product<Self::Scalar>
284{
285 type Scalar: Field;
286
287 /// Construct an iterator which returns powers of `base` packed into packed field elements.
288 ///
289 /// E.g. if `Self::WIDTH = 4`, returns: `[base^0, base^1, base^2, base^3], [base^4, base^5, base^6, base^7], ...`.
290 #[must_use]
291 fn packed_powers(base: Self::Scalar) -> Powers<Self> {
292 Self::packed_shifted_powers(base, Self::Scalar::ONE)
293 }
294
295 /// Construct an iterator which returns powers of `base` multiplied by `start` and packed into packed field elements.
296 ///
297 /// E.g. if `Self::WIDTH = 4`, returns: `[start, start*base, start*base^2, start*base^3], [start*base^4, start*base^5, start*base^6, start*base^7], ...`.
298 #[must_use]
299 fn packed_shifted_powers(base: Self::Scalar, start: Self::Scalar) -> Powers<Self> {
300 let mut current: Self = start.into();
301 let slice = current.as_slice_mut();
302 for i in 1..Self::WIDTH {
303 slice[i] = slice[i - 1] * base;
304 }
305
306 Powers {
307 base: base.exp_u64(Self::WIDTH as u64).into(),
308 current,
309 }
310 }
311}
312
313/// # Safety
314/// - `WIDTH` is assumed to be a power of 2.
315pub unsafe trait PackedFieldPow2: PackedField {
316 /// Take interpret two vectors as chunks of `block_len` elements. Unpack and interleave those
317 /// chunks. This is best seen with an example. If we have:
318 /// ```text
319 /// A = [x0, y0, x1, y1]
320 /// B = [x2, y2, x3, y3]
321 /// ```
322 ///
323 /// then
324 ///
325 /// ```text
326 /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3])
327 /// ```
328 ///
329 /// Pairs that were adjacent in the input are at corresponding positions in the output.
330 ///
331 /// `r` lets us set the size of chunks we're interleaving. If we set `block_len = 2`, then for
332 ///
333 /// ```text
334 /// A = [x0, x1, y0, y1]
335 /// B = [x2, x3, y2, y3]
336 /// ```
337 ///
338 /// we obtain
339 ///
340 /// ```text
341 /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3])
342 /// ```
343 ///
344 /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
345 /// transposing those matrices.
346 ///
347 /// When `block_len = WIDTH`, this operation is a no-op.
348 ///
349 /// # Panics
350 /// This may panic if `block_len` does not divide `WIDTH`. Since `WIDTH` is specified to be a power of 2,
351 /// `block_len` must also be a power of 2. It cannot be 0 and it cannot exceed `WIDTH`.
352 #[must_use]
353 fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
354}
355
356/// Fix a field `F` a packing width `W` and an extension field `EF` of `F`.
357///
358/// By choosing a basis `B`, `EF` can be transformed into an array `[F; D]`.
359///
360/// A type should implement PackedFieldExtension if it can be transformed into `[F::Packing; D] ~ [[F; W]; D]`
361///
362/// This is interpreted by taking a transpose to get `[[F; D]; W]` which can then be reinterpreted
363/// as `[EF; W]` by making use of the chosen basis `B` again.
364pub trait PackedFieldExtension<
365 BaseField: Field,
366 ExtField: ExtensionField<BaseField, ExtensionPacking = Self>,
367>: Algebra<ExtField> + Algebra<BaseField::Packing> + BasedVectorSpace<BaseField::Packing>
368{
369 /// Construct a packed extension by applying `f` to each lane.
370 ///
371 /// This is the extension-field analog of [`PackedValue::from_fn`] and the canonical
372 /// primitive constructor for packed extensions: every other constructor in this
373 /// trait (`from_ext_slice`, `pack_ext_columns`, etc.) routes through it.
374 ///
375 /// `f` is called once per `(basis_coefficient, lane)` pair (`D * W` calls total),
376 /// hence the [`Fn`] bound — closures with side effects are unsuitable.
377 ///
378 /// The default impl uses only the [`BasedVectorSpace`] machinery the trait already
379 /// requires. Concrete impls should override when the extension struct exposes its
380 /// base packings directly, e.g. `Self::new(F::Packing::pack_columns_fn(|l| f(l).value))`.
381 #[inline]
382 #[must_use]
383 fn from_ext_fn(f: impl Fn(usize) -> ExtField) -> Self {
384 Self::from_basis_coefficients_fn(|d| {
385 BaseField::Packing::from_fn(|lane| f(lane).as_basis_coefficients_slice()[d])
386 })
387 }
388
389 /// Pack a length-`WIDTH` slice of extension field elements into one packed extension.
390 ///
391 /// ## Panics
392 /// Panics if `slice.len() != BaseField::Packing::WIDTH`.
393 #[inline]
394 #[must_use]
395 fn from_ext_slice(slice: &[ExtField]) -> Self {
396 assert_eq!(slice.len(), BaseField::Packing::WIDTH);
397 Self::from_ext_fn(|lane| slice[lane])
398 }
399
400 /// Pack `N` columns from `W` rows of extension field elements into `N` packed extensions.
401 ///
402 /// This is the extension-field analog of [`PackedValue::pack_columns`]: given `W` rows
403 /// of `N` extension elements, lane `lane` of output column `col` is `rows[lane][col]`.
404 ///
405 /// ## Panics
406 /// Panics if `rows.len() != BaseField::Packing::WIDTH`.
407 #[inline]
408 #[must_use]
409 fn pack_ext_columns<const N: usize>(rows: &[[ExtField; N]]) -> [Self; N] {
410 assert_eq!(rows.len(), BaseField::Packing::WIDTH);
411 array::from_fn(|col| Self::from_ext_fn(|lane| rows[lane][col]))
412 }
413
414 /// Pack `N` columns using a closure that produces each row.
415 ///
416 /// Analog of [`PackedValue::pack_columns_fn`].
417 #[inline]
418 #[must_use]
419 fn pack_ext_columns_fn<const N: usize>(row_fn: impl Fn(usize) -> [ExtField; N]) -> [Self; N] {
420 array::from_fn(|col| Self::from_ext_fn(|lane| row_fn(lane)[col]))
421 }
422
423 /// Extract the extension field element at the given SIMD lane.
424 #[inline]
425 #[must_use]
426 fn extract(&self, lane: usize) -> ExtField {
427 ExtField::from_basis_coefficients_fn(|d| {
428 self.as_basis_coefficients_slice()[d].as_slice()[lane]
429 })
430 }
431
432 /// Write all `W` lanes into the given slice.
433 ///
434 /// This is the extension-field analog of [`PackedValue::as_slice`], but the lanes of
435 /// a packed extension are not contiguous in memory (the layout is `[[F; W]; D]`,
436 /// indexed first by basis coefficient), so the lanes must be copied rather than
437 /// borrowed.
438 ///
439 /// ## Panics
440 /// Panics if `out.len() != BaseField::Packing::WIDTH`.
441 #[inline]
442 fn to_ext_slice(&self, out: &mut [ExtField]) {
443 assert_eq!(out.len(), BaseField::Packing::WIDTH);
444 for (lane, slot) in out.iter_mut().enumerate() {
445 *slot = self.extract(lane);
446 }
447 }
448
449 /// Unpack `N` packed extensions into `W` rows of `N` extension elements.
450 ///
451 /// Inverse of [`PackedFieldExtension::pack_ext_columns`]. Lane `lane` of input
452 /// column `col` is written to `rows[lane][col]`.
453 ///
454 /// ## Panics
455 /// Panics if `rows.len() != BaseField::Packing::WIDTH`.
456 #[inline]
457 fn unpack_ext_into<const N: usize>(packed: &[Self; N], rows: &mut [[ExtField; N]]) {
458 assert_eq!(rows.len(), BaseField::Packing::WIDTH);
459 for (lane, row) in rows.iter_mut().enumerate() {
460 *row = array::from_fn(|col| {
461 ExtField::from_basis_coefficients_fn(|d| {
462 packed[col].as_basis_coefficients_slice()[d].as_slice()[lane]
463 })
464 });
465 }
466 }
467
468 /// Iterator equivalent of [`PackedFieldExtension::unpack_ext_into`].
469 ///
470 /// Yields `WIDTH` rows of `N` extension elements without requiring a pre-allocated
471 /// buffer. Analog of [`PackedValue::unpack_iter`].
472 #[inline]
473 fn unpack_ext_iter<const N: usize>(packed: [Self; N]) -> impl Iterator<Item = [ExtField; N]> {
474 (0..BaseField::Packing::WIDTH).map(move |lane| {
475 array::from_fn(|col| {
476 ExtField::from_basis_coefficients_fn(|d| {
477 packed[col].as_basis_coefficients_slice()[d].as_slice()[lane]
478 })
479 })
480 })
481 }
482
483 /// Convert an iterator of packed extension field elements to an iterator of
484 /// extension field elements (flat — one `ExtField` per lane per packed value).
485 #[inline]
486 #[must_use]
487 fn to_ext_iter(iter: impl IntoIterator<Item = Self>) -> impl Iterator<Item = ExtField> {
488 iter.into_iter()
489 .flat_map(|x| (0..BaseField::Packing::WIDTH).map(move |lane| x.extract(lane)))
490 }
491
492 /// Similar to `packed_powers`, construct an iterator which returns
493 /// powers of `base` packed into `PackedFieldExtension` elements.
494 #[must_use]
495 fn packed_ext_powers(base: ExtField) -> Powers<Self>;
496
497 /// Similar to `packed_ext_powers` but only returns `unpacked_len` powers of `base`.
498 ///
499 /// Note that the length of the returned iterator will be `unpacked_len / WIDTH` and
500 /// not `len` as the iterator is over packed extension field elements. If `unpacked_len`
501 /// is not divisible by `WIDTH`, `unpacked_len` will be rounded up to the next multiple of `WIDTH`.
502 #[must_use]
503 fn packed_ext_powers_capped(base: ExtField, unpacked_len: usize) -> impl Iterator<Item = Self> {
504 Self::packed_ext_powers(base).take(unpacked_len.div_ceil(BaseField::Packing::WIDTH))
505 }
506}
507
508unsafe impl<T: Packable> PackedValue for T {
509 type Value = Self;
510
511 const WIDTH: usize = 1;
512
513 #[inline]
514 fn from_slice(slice: &[Self::Value]) -> &Self {
515 assert_eq!(slice.len(), Self::WIDTH);
516 &slice[0]
517 }
518
519 #[inline]
520 fn from_slice_mut(slice: &mut [Self::Value]) -> &mut Self {
521 assert_eq!(slice.len(), Self::WIDTH);
522 &mut slice[0]
523 }
524
525 #[inline]
526 fn from_fn<Fn>(mut f: Fn) -> Self
527 where
528 Fn: FnMut(usize) -> Self::Value,
529 {
530 f(0)
531 }
532
533 #[inline]
534 fn as_slice(&self) -> &[Self::Value] {
535 slice::from_ref(self)
536 }
537
538 #[inline]
539 fn as_slice_mut(&mut self) -> &mut [Self::Value] {
540 slice::from_mut(self)
541 }
542}
543
544unsafe impl<F: Field> PackedField for F {
545 type Scalar = Self;
546}
547
548unsafe impl<F: Field> PackedFieldPow2 for F {
549 #[inline]
550 fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
551 match block_len {
552 1 => (*self, other),
553 _ => panic!("unsupported block length"),
554 }
555 }
556}
557
558impl<F: Field> PackedFieldExtension<F, F> for F::Packing {
559 #[inline]
560 fn from_ext_fn(f: impl Fn(usize) -> F) -> Self {
561 F::Packing::from_fn(f)
562 }
563
564 #[inline]
565 fn from_ext_slice(slice: &[F]) -> Self {
566 *F::Packing::from_slice(slice)
567 }
568
569 #[inline]
570 fn packed_ext_powers(base: F) -> Powers<Self> {
571 F::Packing::packed_powers(base)
572 }
573}
574
575impl Packable for u8 {}
576
577impl Packable for u16 {}
578
579impl Packable for u32 {}
580
581impl Packable for u64 {}
582
583impl Packable for u128 {}