p3_field/batch_inverse.rs
1use alloc::vec::Vec;
2
3use p3_maybe_rayon::prelude::*;
4use tracing::instrument;
5
6use crate::field::Field;
7use crate::{
8 ExtensionField, FieldArray, PackedFieldExtension, PackedValue, PrimeCharacteristicRing,
9};
10
11/// Compute the multiplicative inverse of every element in a slice via Montgomery's trick.
12///
13/// Replaces `n` field inversions with one inversion plus `~3n` multiplications:
14/// - forward pass: build prefix products of the inputs,
15/// - one inversion of the full product,
16/// - reverse pass: derive each individual inverse from the prefix products.
17///
18/// The forward pass is a long dependency chain. It is parallelised on two axes:
19/// - 4-lane packed arrays — four independent chains run side by side,
20/// - 1024-element chunks — dispatched across Rayon workers.
21///
22/// Lengths not a multiple of 4 finish with a scalar pass on the trailing 1..=3 elements.
23///
24/// # Panics
25///
26/// Panics if any input is zero.
27#[instrument(level = "debug", skip_all)]
28#[must_use]
29pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
30 // 1024-element chunks per Rayon task.
31 //
32 // Why 1024:
33 // - amortizes the one field inversion per chunk over many multiplies,
34 // - leaves enough chunks for work-stealing on long slices.
35 const CHUNK_SIZE: usize = 1024;
36
37 // 4-lane packing.
38 //
39 // Why 4:
40 // - smallest packed-field width on every backend,
41 // - wider lanes risk register spills in the per-lane dependency chains.
42 const WIDTH: usize = 4;
43
44 // Pre-allocate the output: each Rayon task writes a disjoint sub-slice.
45 let mut result = F::zero_vec(x.len());
46
47 x.par_chunks(CHUNK_SIZE)
48 .zip(result.par_chunks_mut(CHUNK_SIZE))
49 .for_each(|(x_chunk, result_chunk)| {
50 // Phase 1 — split the chunk:
51 // - packed: 4-aligned prefix viewed as 4-lane arrays,
52 // - tail: 0..=3 trailing scalars (m = n - n%4).
53 //
54 // x_chunk: [ x_0 .. x_{m-1} | x_m .. x_{n-1} ]
55 // └──── packed ────┘└──── tail ────┘
56 let (x_packed, x_tail) = FieldArray::<F, WIDTH>::pack_slice_with_suffix(x_chunk);
57 let (result_packed, result_tail) =
58 FieldArray::<F, WIDTH>::pack_slice_with_suffix_mut(result_chunk);
59
60 // Phase 2 — packed pass: 4 independent Montgomery chains, one per lane.
61 //
62 // Final inversion lands on a 4-lane array → one scalar inversion per chunk.
63 batch_multiplicative_inverse_general(x_packed, result_packed, |y| y.inverse());
64
65 // Phase 3 — tail pass: 0..=3 leftover scalars.
66 //
67 // Empty when n % 4 == 0; this call then returns immediately.
68 batch_multiplicative_inverse_general(x_tail, result_tail, |y| y.inverse());
69 });
70
71 result
72}
73
74/// A simple single-threaded implementation of Montgomery's trick. Since not all `PrimeCharacteristicRing`s
75/// support inversion, this takes a custom inversion function.
76///
77/// Unlike [`batch_multiplicative_inverse`], this writes into a caller-provided buffer,
78/// avoiding heap allocation. This makes it suitable for small, fixed-size inputs
79/// such as packed field lanes.
80#[inline]
81pub fn batch_multiplicative_inverse_general<F, Inv>(x: &[F], result: &mut [F], inv: Inv)
82where
83 F: PrimeCharacteristicRing + Copy,
84 Inv: Fn(F) -> F,
85{
86 let n = x.len();
87 assert_eq!(result.len(), n);
88 if n == 0 {
89 return;
90 }
91
92 result[0] = F::ONE;
93 for i in 1..n {
94 result[i] = result[i - 1] * x[i - 1];
95 }
96
97 let product = result[n - 1] * x[n - 1];
98 let mut inv = inv(product);
99
100 for i in (0..n).rev() {
101 result[i] *= inv;
102 inv *= x[i];
103 }
104}
105
106/// Per-lane inverse of a packed extension via Montgomery's trick. Allocation-free.
107///
108/// Dispatches on `F::Packing::WIDTH` to a const-generic body that materializes the `W`
109/// lanes via [`PackedFieldExtension::extract`], runs [`batch_multiplicative_inverse_general`]
110/// over a stack-sized `[EF; W]` buffer, and rebuilds the packed extension via
111/// [`PackedFieldExtension::from_ext_fn`]. After monomorphization the match folds to
112/// the single live arm.
113///
114/// All `PackedField` backends in this workspace use `WIDTH ∈ {1, 2, 4, 8, 16}`; the
115/// fallback arm panics if a future backend introduces a different width.
116#[inline]
117pub fn invert_packed_extension<F, EF>(packed: EF::ExtensionPacking) -> EF::ExtensionPacking
118where
119 F: Field,
120 EF: ExtensionField<F>,
121{
122 match F::Packing::WIDTH {
123 1 => invert_packed_extension_const::<F, EF, 1>(packed),
124 2 => invert_packed_extension_const::<F, EF, 2>(packed),
125 4 => invert_packed_extension_const::<F, EF, 4>(packed),
126 8 => invert_packed_extension_const::<F, EF, 8>(packed),
127 16 => invert_packed_extension_const::<F, EF, 16>(packed),
128 w => panic!("unsupported PackedField WIDTH = {w}"),
129 }
130}
131
132#[inline]
133fn invert_packed_extension_const<F, EF, const W: usize>(
134 packed: EF::ExtensionPacking,
135) -> EF::ExtensionPacking
136where
137 F: Field,
138 EF: ExtensionField<F>,
139{
140 let lanes: [EF; W] = core::array::from_fn(|i| packed.extract(i));
141 let mut invs = [EF::ZERO; W];
142 batch_multiplicative_inverse_general(&lanes, &mut invs, |x| x.inverse());
143 EF::ExtensionPacking::from_ext_fn(|i| invs[i])
144}