Skip to main content

webgpu_groth16/bucket/
mod.rs

//! Pippenger bucket sorting with signed-digit scalar decomposition.
//!
//! Prepares [`BucketData`] for GPU MSM dispatch by decomposing scalars into
//! signed-digit windows and grouping points by (window, bucket_value).
//! Large buckets are split into sub-buckets for GPU load balancing (see
//! [`GpuCurve::MSM_MAX_CHUNK_SIZE`]).
//!
//! Two modes (selected by [`GpuCurve::HAS_G1_GLV`]):
//! - **Standard** ([`compute_bucket_sorting`]): direct scalar decomposition.
//! - **GLV-capable** ([`compute_glv_bucket_sorting`],
//!   [`compute_glv_bucket_data`]): uses curve-provided endomorphism
//!   decomposition hooks when available.

14use crate::gpu::curve::{G1MsmDecomposition, GpuCurve};
15
/// Bucket sorting result for GPU MSM dispatch.
///
/// Uses a Structure-of-Arrays layout: each array is uploaded as a separate
/// `storage<read_only>` GPU buffer. This avoids struct padding issues in WGSL
/// and allows independent buffer bindings per kernel.
///
/// The dispatch arrays (`bucket_pointers`, `bucket_sizes`, `bucket_values`,
/// `window_starts`, `window_counts`) describe *dispatched units*. When no
/// bucket was split (`has_chunks == false`) they coincide one-to-one with the
/// logical buckets. When chunking is active they describe *sub-buckets*, and
/// the `reduce_starts`/`reduce_counts` arrays map each logical bucket to its
/// sub-bucket range for a post-aggregation reduction pass.
///
/// Invariants:
/// - `bucket_pointers[i]` is the starting index in `base_indices` for
///   sub-bucket `i`
/// - `bucket_sizes[i]` is the count of points in sub-bucket `i`
/// - `bucket_values[i]` is the scalar weight for sub-bucket `i` (in `[1,
///   2^(c-1)]`)
/// - `window_starts[w]` is the first sub-bucket index belonging to window `w`
/// - `window_counts[w]` is the number of sub-buckets in window `w`
/// - `reduce_starts[j]` is the first sub-bucket index for original bucket `j`
/// - `reduce_counts[j]` is the number of sub-buckets for original bucket `j`
/// - logical buckets are ordered by window, then by ascending bucket value
#[derive(Debug, Clone)]
pub struct BucketData {
    /// Point indices grouped contiguously per bucket, in window order. An
    /// index whose digit was negative has `GpuCurve::MSM_INDEX_SIGN_BIT`
    /// OR-ed in.
    pub base_indices: Vec<u32>,
    /// Sub-bucket pointers into base_indices (length = num_dispatched).
    pub bucket_pointers: Vec<u32>,
    /// Sub-bucket sizes (length = num_dispatched).
    pub bucket_sizes: Vec<u32>,
    /// Sub-bucket values, same as parent's value (length = num_dispatched).
    pub bucket_values: Vec<u32>,
    /// Sub-bucket window starts (length = num_windows).
    pub window_starts: Vec<u32>,
    /// Sub-bucket counts per window (length = num_windows).
    pub window_counts: Vec<u32>,
    /// Number of signed-digit windows.
    pub num_windows: u32,
    /// Number of original (logical) buckets.
    pub num_active_buckets: u32,
    /// Number of dispatched sub-buckets (>= num_active_buckets when chunking
    /// occurs).
    pub num_dispatched: u32,
    /// Original bucket values for weight/subsum passes (length =
    /// num_active_buckets).
    pub orig_bucket_values: Vec<u32>,
    /// Original window starts for weight/subsum passes (length = num_windows).
    pub orig_window_starts: Vec<u32>,
    /// Original window counts for weight/subsum passes (length = num_windows).
    pub orig_window_counts: Vec<u32>,
    /// Start offset in the dispatch buffer for each original bucket.
    pub reduce_starts: Vec<u32>,
    /// Number of sub-buckets for each original bucket.
    pub reduce_counts: Vec<u32>,
    /// Whether any bucket was split into sub-buckets.
    pub has_chunks: bool,
    /// Bucket width `c` (window size in bits) used to build this data.
    pub bucket_width: usize,
}
71
72impl BucketData {
73    /// Print bucket size distribution statistics for diagnosing workload
74    /// imbalance. Only active when the `timing` feature is enabled.
75    #[cfg(feature = "timing")]
76    pub fn print_distribution_stats(&self, label: &str) {
77        if self.num_active_buckets == 0 {
78            eprintln!("[bucket-diag] {label}: 0 active buckets");
79            return;
80        }
81        let mut sizes: Vec<u32> = self.bucket_sizes.clone();
82        sizes.sort();
83        let n = sizes.len();
84        let total: u32 = sizes.iter().sum();
85        let max = *sizes.last().unwrap();
86        let min = *sizes.first().unwrap();
87        let mean = total as f64 / n as f64;
88        let median = sizes[n / 2];
89        let p90 = sizes[(n * 90) / 100];
90        let p95 = sizes[(n * 95) / 100];
91        let p99 = sizes[n.saturating_sub(1).min((n * 99) / 100)];
92
93        let over_64 = sizes.iter().filter(|&&s| s > 64).count();
94        let over_256 = sizes.iter().filter(|&&s| s > 256).count();
95        let over_1024 = sizes.iter().filter(|&&s| s > 1024).count();
96
97        eprintln!(
98            "[bucket-diag] {label}: {n} active buckets, {total} total points, \
99             c={}",
100            self.bucket_width
101        );
102        eprintln!(
103            "[bucket-diag]   min={min} max={max} mean={mean:.1} \
104             median={median}"
105        );
106        eprintln!("[bucket-diag]   p90={p90} p95={p95} p99={p99}");
107        eprintln!(
108            "[bucket-diag]   >64: {over_64}  >256: {over_256}  >1024: \
109             {over_1024}"
110        );
111
112        // Per-window summary for windows with large buckets
113        for w in 0..self.num_windows as usize {
114            let start = self.window_starts[w] as usize;
115            let count = self.window_counts[w] as usize;
116            if count == 0 {
117                continue;
118            }
119            let w_sizes: Vec<u32> = (start..start + count)
120                .map(|i| self.bucket_sizes[i])
121                .collect();
122            let w_max = *w_sizes.iter().max().unwrap();
123            let w_total: u32 = w_sizes.iter().sum();
124            // Find the bucket value with max size
125            let max_idx = w_sizes.iter().position(|&s| s == w_max).unwrap();
126            let max_val = self.bucket_values[start + max_idx];
127            if w_max > 32 {
128                eprintln!(
129                    "[bucket-diag]   window {w}: {count} buckets, \
130                     max_size={w_max} (val={max_val}), total={w_total}"
131                );
132            }
133        }
134    }
135}
136
137/// Builds `BucketData` from pre-computed signed-digit window decompositions.
138///
139/// `all_windows[i]` contains the (absolute_value, is_negative) pairs for point
140/// `i`. `c` is the bucket width (window size in bits).
141///
142/// ## Algorithm (two-pass Pippenger bucket sorting with sub-bucket chunking)
143///
144/// **Pass 1 — Group points by (window, bucket_value):**
145/// For each window w, iterate over all points and place each into the bucket
146/// corresponding to its signed-digit value. Produces flat arrays of:
147/// base_indices (point IDs, sign-encoded), pointers, sizes, and values per
148/// bucket.
149///
150/// **Pass 2 — Split oversized buckets for GPU load balancing:**
151/// Buckets with more than `G::MSM_MAX_CHUNK_SIZE` points are split into
152/// sub-buckets. Each sub-bucket becomes an independent GPU thread. A
153/// reduce_starts/reduce_counts table records which sub-buckets belong to the
154/// same logical bucket, so a later GPU reduce pass can sum the sub-bucket
155/// partials back together.
156fn build_bucket_data<G: GpuCurve>(
157    all_windows: &[Vec<(u32, bool)>],
158    c: usize,
159) -> BucketData {
160    let num_windows = all_windows.iter().map(|w| w.len()).max().unwrap_or(0);
161    let num_buckets = (1usize << (c - 1)) + 1;
162
163    // First pass: collect points into logical buckets per window.
164    let mut base_indices = Vec::new();
165    let mut orig_pointers = Vec::new();
166    let mut orig_sizes = Vec::new();
167    let mut orig_values = Vec::new();
168    let mut orig_window_starts = Vec::new();
169    let mut orig_window_counts = Vec::new();
170
171    for w in 0..num_windows {
172        let mut buckets: Vec<Vec<u32>> = vec![Vec::new(); num_buckets];
173
174        for (i, windows) in all_windows.iter().enumerate() {
175            if w < windows.len() {
176                let (abs, neg) = windows[w];
177                if abs != 0 {
178                    let entry = if neg {
179                        i as u32 | G::MSM_INDEX_SIGN_BIT
180                    } else {
181                        i as u32
182                    };
183                    buckets[abs as usize].push(entry);
184                }
185            }
186        }
187
188        orig_window_starts.push(orig_values.len() as u32);
189        let mut count = 0u32;
190
191        for (val, indices) in buckets.into_iter().enumerate() {
192            if !indices.is_empty() {
193                orig_pointers.push(base_indices.len() as u32);
194                orig_sizes.push(indices.len() as u32);
195                orig_values.push(val as u32);
196                base_indices.extend(indices);
197                count += 1;
198            }
199        }
200        orig_window_counts.push(count);
201    }
202
203    let num_active_buckets = orig_sizes.len() as u32;
204
205    // Second pass: split large buckets into sub-buckets.
206    let mut bucket_pointers = Vec::new();
207    let mut bucket_sizes = Vec::new();
208    let mut bucket_values = Vec::new();
209    let mut window_starts = Vec::new();
210    let mut window_counts = Vec::new();
211    let mut reduce_starts = Vec::new();
212    let mut reduce_counts = Vec::new();
213    let mut has_chunks = false;
214
215    for w in 0..num_windows {
216        let w_start = orig_window_starts[w] as usize;
217        let w_count = orig_window_counts[w] as usize;
218        window_starts.push(bucket_pointers.len() as u32);
219        let mut dispatched_in_window = 0u32;
220
221        for b in 0..w_count {
222            let orig_idx = w_start + b;
223            let ptr = orig_pointers[orig_idx];
224            let size = orig_sizes[orig_idx];
225            let val = orig_values[orig_idx];
226
227            let sub_start = bucket_pointers.len() as u32;
228
229            if size <= G::MSM_MAX_CHUNK_SIZE {
230                bucket_pointers.push(ptr);
231                bucket_sizes.push(size);
232                bucket_values.push(val);
233                reduce_starts.push(sub_start);
234                reduce_counts.push(1);
235                dispatched_in_window += 1;
236            } else {
237                has_chunks = true;
238                let num_chunks = size.div_ceil(G::MSM_MAX_CHUNK_SIZE);
239                for chunk in 0..num_chunks {
240                    let chunk_start = ptr + chunk * G::MSM_MAX_CHUNK_SIZE;
241                    let chunk_size = (size - chunk * G::MSM_MAX_CHUNK_SIZE)
242                        .min(G::MSM_MAX_CHUNK_SIZE);
243                    bucket_pointers.push(chunk_start);
244                    bucket_sizes.push(chunk_size);
245                    bucket_values.push(val);
246                    dispatched_in_window += 1;
247                }
248                reduce_starts.push(sub_start);
249                reduce_counts.push(num_chunks);
250            }
251        }
252        window_counts.push(dispatched_in_window);
253    }
254
255    let num_dispatched = bucket_pointers.len() as u32;
256
257    BucketData {
258        base_indices,
259        bucket_pointers,
260        bucket_sizes,
261        bucket_values,
262        window_starts,
263        window_counts,
264        num_windows: num_windows as u32,
265        num_active_buckets,
266        num_dispatched,
267        orig_bucket_values: orig_values,
268        orig_window_starts,
269        orig_window_counts,
270        reduce_starts,
271        reduce_counts,
272        has_chunks,
273        bucket_width: c,
274    }
275}
276
277pub fn optimal_glv_c<G: GpuCurve>(n: usize) -> usize {
278    G::g1_msm_bucket_width(n)
279}
280
281pub fn compute_bucket_sorting<G: GpuCurve>(
282    scalars: &[G::Scalar],
283) -> BucketData {
284    compute_bucket_sorting_with_width::<G>(scalars, G::bucket_width())
285}
286
287pub fn compute_bucket_sorting_with_width<G: GpuCurve>(
288    scalars: &[G::Scalar],
289    c: usize,
290) -> BucketData {
291    let all_windows: Vec<Vec<(u32, bool)>> = scalars
292        .iter()
293        .map(|s| G::scalar_to_signed_windows(s, c))
294        .collect();
295    build_bucket_data::<G>(&all_windows, c)
296}
297
298/// Curve-capability-aware G1 bucket sorting with signed-digit decomposition.
299///
300/// For GLV-capable curves, decomposes each scalar into two components and
301/// builds a 2N-entry bases buffer with conditional point negation. For non-GLV
302/// curves, falls back to standard signed-window sorting and returns the
303/// original base bytes.
304///
305/// Returns `(combined_bases_bytes, bucket_data)` where `combined_bases_bytes`
306/// is a 2N×G1_GPU_BYTES buffer laid out as:
307///   [maybe_neg(P₀), maybe_neg(φ(P₀)), maybe_neg(P₁), maybe_neg(φ(P₁)), ...]
308pub fn compute_glv_bucket_sorting<G: GpuCurve>(
309    scalars: &[G::Scalar],
310    bases_bytes: &[u8],
311    phi_bases_bytes: &[u8],
312    c: usize,
313) -> (Vec<u8>, BucketData) {
314    if !G::HAS_G1_GLV {
315        let bd = compute_bucket_sorting_with_width::<G>(scalars, c);
316        return (bases_bytes.to_vec(), bd);
317    }
318
319    let n = scalars.len();
320    debug_assert_eq!(bases_bytes.len(), n * G::G1_GPU_BYTES);
321    debug_assert_eq!(phi_bases_bytes.len(), n * G::G1_GPU_BYTES);
322
323    // Decompose all scalars and build the combined bases buffer.
324    let mut combined_bases = Vec::with_capacity(n * 2 * G::G1_GPU_BYTES);
325    let mut all_windows: Vec<Vec<(u32, bool)>> = Vec::with_capacity(n * 2);
326
327    for (i, scalar) in scalars.iter().enumerate() {
328        if let Some((k1_windows, k1_neg, k2_windows, k2_neg)) =
329            G::decompose_g1_msm_scalar_glv_windows(scalar, c)
330        {
331            let src_start = i * G::G1_GPU_BYTES;
332            let mut p_bytes =
333                bases_bytes[src_start..src_start + G::G1_GPU_BYTES].to_vec();
334            if k1_neg {
335                G::negate_g1_base_bytes(&mut p_bytes);
336            }
337            combined_bases.extend_from_slice(&p_bytes);
338
339            let mut phi_bytes = phi_bases_bytes
340                [src_start..src_start + G::G1_GPU_BYTES]
341                .to_vec();
342            if k2_neg {
343                G::negate_g1_base_bytes(&mut phi_bytes);
344            }
345            combined_bases.extend_from_slice(&phi_bytes);
346
347            all_windows.push(k1_windows);
348            all_windows.push(k2_windows);
349        } else if let G1MsmDecomposition::Standard { windows } =
350            G::decompose_g1_msm_scalar(scalar, c)
351        {
352            let src_start = i * G::G1_GPU_BYTES;
353            combined_bases.extend_from_slice(
354                &bases_bytes[src_start..src_start + G::G1_GPU_BYTES],
355            );
356            all_windows.push(windows);
357        }
358    }
359
360    (combined_bases, build_bucket_data::<G>(&all_windows, c))
361}
362
363/// Curve-capability-aware bucket sorting that returns only BucketData (no bases
364/// buffer).
365///
366/// For GLV-capable curves with persistent bases, GLV negation is folded into
367/// `base_indices` sign bits (XOR with signed-digit window sign) instead of
368/// mutating base bytes. For non-GLV curves this is equivalent to standard
369/// sorting.
370pub fn compute_glv_bucket_data<G: GpuCurve>(
371    scalars: &[G::Scalar],
372    c: usize,
373) -> BucketData {
374    if !G::HAS_G1_GLV {
375        return compute_bucket_sorting_with_width::<G>(scalars, c);
376    }
377
378    let n = scalars.len();
379    let mut all_windows: Vec<Vec<(u32, bool)>> = Vec::with_capacity(n * 2);
380
381    for scalar in scalars.iter() {
382        if let Some((mut k1_windows, k1_neg, mut k2_windows, k2_neg)) =
383            G::decompose_g1_msm_scalar_glv_windows(scalar, c)
384        {
385            if k1_neg {
386                for w in &mut k1_windows {
387                    if w.0 != 0 {
388                        w.1 = !w.1;
389                    }
390                }
391            }
392            all_windows.push(k1_windows);
393
394            if k2_neg {
395                for w in &mut k2_windows {
396                    if w.0 != 0 {
397                        w.1 = !w.1;
398                    }
399                }
400            }
401            all_windows.push(k2_windows);
402        } else if let G1MsmDecomposition::Standard { windows } =
403            G::decompose_g1_msm_scalar(scalar, c)
404        {
405            all_windows.push(windows);
406        }
407    }
408
409    build_bucket_data::<G>(&all_windows, c)
410}
411
412#[cfg(test)]
413mod tests;