blst/
pippenger.rs

1// Copyright Supranational LLC
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use core::num::Wrapping;
6use core::ops::{Index, IndexMut};
7use core::slice::SliceIndex;
8use std::sync::Barrier;
9
/// One rectangle of the Pippenger work grid: a contiguous run of points
/// paired with a run of scalar bits.
struct tile {
    /// Index of the first point covered by the tile.
    x: usize,
    /// Number of points covered by the tile.
    dx: usize,
    /// Scalar bit offset at which the tile's bit window starts.
    y: usize,
    /// Width of the tile's bit window (recorded, but not read in this file).
    dy: usize,
}
16
// Minimalist core::cell::Cell stand-in, but with Sync marker, which
// makes it possible to pass it to multiple threads. It works, because
// *here* each Cell is written only once and by just one thread.
//
// The value lives in an `UnsafeCell` so that writing through the pointer
// handed out by `as_ptr` (which is obtained from a shared `&self`) is
// permitted by Rust's aliasing rules. Casting `&T` to `*mut T` and writing
// through it without `UnsafeCell` is undefined behaviour.
#[repr(transparent)]
struct Cell<T: ?Sized> {
    value: core::cell::UnsafeCell<T>,
}
// SAFETY: `UnsafeCell` opts out of `Sync`; it is re-enabled on the premise
// stated above: each cell has a single writer, and readers must only look
// at the value after that write has been published to them.
unsafe impl<T: ?Sized + Sync> Sync for Cell<T> {}
impl<T> Cell<T> {
    /// Raw pointer to the wrapped value. Callers are responsible for
    /// ordering writes and reads made through it.
    pub fn as_ptr(&self) -> *mut T {
        self.value.get()
    }
}
30
/// Instantiates the multi-point API for one curve group.
///
/// Expands to:
/// * `$points`, an owned, indexable vector of affine points with a
///   (possibly parallel) constructor from non-affine points and
///   `mult`/`add` shortcuts;
/// * an implementation of `MultiPoint` for slices of affine points:
///   multi-scalar multiplication (`mult`), summation (`add`) and per-point
///   validation (`validate`);
/// * under `cfg(test)`, a test module generated by `pippenger_test_mod!`.
///
/// Work is spread over `mt::da_pool()` when its `max_count()` is at least 2
/// and the input is large enough; the remaining arguments name the
/// group-specific FFI routines.
macro_rules! pippenger_mult_impl {
    (
        $points:ident,
        $point:ty,
        $point_affine:ty,
        $to_affines:ident,
        $scratch_sizeof:ident,
        $multi_scalar_mult:ident,
        $tile_mult:ident,
        $add_or_double:ident,
        $double:ident,
        $test_mod:ident,
        $generator:ident,
        $mult:ident,
        $add:ident,
        $is_inf:ident,
        $in_group:ident,
        $from_affine:ident,
    ) => {
        /// Owned vector of affine points, usually built with `from` out of
        /// a slice of non-affine points.
        pub struct $points {
            points: Vec<$point_affine>,
        }

        // Indexing (single elements or ranges) forwards to the inner vector.
        impl<I: SliceIndex<[$point_affine]>> Index<I> for $points {
            type Output = I::Output;

            #[inline]
            fn index(&self, i: I) -> &Self::Output {
                &self.points[i]
            }
        }
        impl<I: SliceIndex<[$point_affine]>> IndexMut<I> for $points {
            #[inline]
            fn index_mut(&mut self, i: I) -> &mut Self::Output {
                &mut self.points[i]
            }
        }

        impl $points {
            /// Borrows the points as a plain slice.
            #[inline]
            pub fn as_slice(&self) -> &[$point_affine] {
                self.points.as_slice()
            }

            /// Converts `points` to affine form.
            ///
            /// Fewer than 768 points, or a pool with fewer than two
            /// threads, are converted in a single call. Otherwise the input
            /// is cut into `min(ceil(npoints / 512), ncpus)` contiguous,
            /// near-equal slices that are converted concurrently on the
            /// pool. An empty input yields an empty collection.
            pub fn from(points: &[$point]) -> Self {
                let npoints = points.len();
                let mut ret = Self {
                    points: Vec::with_capacity(npoints),
                };
                // Nothing to convert; the paths below index element 0.
                if npoints == 0 {
                    return ret;
                }
                // Assumes `$to_affines` writes one output per input, so every
                // element is initialised before it can be read.
                #[allow(clippy::uninit_vec)]
                unsafe { ret.points.set_len(npoints) };

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 768 {
                    // One contiguous input run; the trailing null presumably
                    // terminates the pointer list.
                    let p: [*const $point; 2] = [&points[0], ptr::null()];
                    unsafe { $to_affines(&mut ret.points[0], &p[0], npoints) };
                    return ret;
                }

                let mut nslices = (npoints + 511) / 512;
                nslices = core::cmp::min(nslices, ncpus);
                // The job that finishes last meets the caller at the barrier;
                // the counter tracks slices still in flight.
                let wg = Arc::new((Barrier::new(2), AtomicUsize::new(nslices)));

                // The first `npoints % nslices` slices are one point longer
                // than the rest, so exactly `nslices` jobs cover the input.
                let (mut delta, mut rem) =
                    (npoints / nslices + 1, Wrapping(npoints % nslices));
                let mut x = 0usize;
                while x < npoints {
                    let out = &mut ret.points[x];
                    let inp = &points[x];

                    delta -= (rem == Wrapping(0)) as usize;
                    rem -= Wrapping(1);
                    x += delta;

                    let wg = wg.clone();
                    pool.joined_execute(move || {
                        let p: [*const $point; 2] = [inp, ptr::null()];
                        unsafe { $to_affines(out, &p[0], delta) };
                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }
                wg.0.wait();

                ret
            }

            /// Shorthand for `self.as_slice().mult(scalars, nbits)`.
            #[inline]
            pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                self.as_slice().mult(scalars, nbits)
            }

            /// Shorthand for `self.as_slice().add()`.
            #[inline]
            pub fn add(&self) -> $point {
                self.as_slice().add()
            }
        }

        impl MultiPoint for [$point_affine] {
            type Output = $point;

            /// Multi-scalar multiplication: returns the sum over `i` of
            /// `self[i]` multiplied by the `i`-th scalar.
            ///
            /// `scalars` stores one `ceil(nbits / 8)`-byte scalar per point,
            /// back to back. Strategy:
            /// * pool with fewer than two threads: one bulk Pippenger call;
            /// * fewer than 32 points: workers claim points one at a time,
            ///   multiply each individually and accumulate; the partial sums
            ///   are then added together;
            /// * otherwise: the (points x scalar bits) plane is split into
            ///   tiles by `breakdown`, tiles are evaluated concurrently, and
            ///   rows are folded from the most significant bit window down,
            ///   doubling `window` times between rows.
            ///
            /// An empty slice returns the default point, the value every
            /// accumulator here starts from.
            ///
            /// # Panics
            /// If `scalars.len() < ceil(nbits / 8) * self.len()`.
            fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                let npoints = self.len();
                let nbytes = (nbits + 7) / 8;

                if scalars.len() < nbytes * npoints {
                    panic!("scalars length mismatch");
                }

                // An empty sum is the accumulators' starting value. Bailing
                // out here also keeps the paths below from indexing
                // `self[0]` and, on the small-input path, from blocking
                // forever on a channel that no worker would ever feed.
                if npoints == 0 {
                    return <$point>::default();
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 {
                    let p: [*const $point_affine; 2] = [&self[0], ptr::null()];
                    let s: [*const u8; 2] = [&scalars[0], ptr::null()];

                    // Zero-filled rather than `set_len` over uninitialised
                    // memory, so no reference to uninitialised `u64`s is ever
                    // formed (matches the tiled path below).
                    let sz = unsafe { $scratch_sizeof(npoints) / 8 };
                    let mut scratch = vec![0u64; sz];
                    let mut ret = <$point>::default();
                    unsafe {
                        $multi_scalar_mult(
                            &mut ret,
                            &p[0],
                            npoints,
                            &s[0],
                            nbits,
                            &mut scratch[0],
                        )
                    };
                    return ret;
                }

                if npoints < 32 {
                    let counter = Arc::new(AtomicUsize::new(0));
                    let n_workers = core::cmp::min(ncpus, npoints);
                    let (tx, rx) = sync_channel(n_workers);
                    for _ in 0..n_workers {
                        let tx = tx.clone();
                        let counter = counter.clone();

                        pool.joined_execute(move || {
                            let mut acc = <$point>::default();
                            let mut tmp = <$point>::default();
                            let mut first = true;

                            // Claim points one at a time until none remain.
                            loop {
                                let work =
                                    counter.fetch_add(1, Ordering::Relaxed);
                                if work >= npoints {
                                    break;
                                }

                                unsafe {
                                    $from_affine(&mut tmp, &self[work]);
                                    let scalar = &scalars[nbytes * work];
                                    if first {
                                        // First product goes straight into
                                        // the accumulator, skipping an add.
                                        $mult(&mut acc, &tmp, scalar, nbits);
                                        first = false;
                                    } else {
                                        $mult(&mut tmp, &tmp, scalar, nbits);
                                        $add_or_double(&mut acc, &acc, &tmp);
                                    }
                                }
                            }

                            tx.send(acc).expect("disaster");
                        });
                    }

                    // Each worker sends exactly one partial sum.
                    let mut ret = rx.recv().expect("disaster");
                    for _ in 1..n_workers {
                        let p = rx.recv().expect("disaster");
                        unsafe { $add_or_double(&mut ret, &ret, &p) };
                    }

                    return ret;
                }

                let (nx, ny, window) =
                    breakdown(nbits, pippenger_window_size(npoints), ncpus);

                // |grid[]| holds "coordinates" and place for result.
                // Layout: `nx` point columns per row; the first row covers
                // the highest bit window, the following rows step down by
                // `window` bits until bit 0.
                let mut grid: Vec<(tile, Cell<$point>)> =
                    Vec::with_capacity(nx * ny);
                // Every tile's coordinates are written below, and every
                // result cell is written by `$tile_mult` before it is read.
                #[allow(clippy::uninit_vec)]
                unsafe { grid.set_len(nx * ny) };
                let dx = npoints / nx;
                let mut y = window * (ny - 1);
                let mut total = 0usize;

                // Top row: covers the remaining `nbits - y` high bits.
                while total < nx {
                    grid[total].0.x = total * dx;
                    grid[total].0.dx = dx;
                    grid[total].0.y = y;
                    grid[total].0.dy = nbits - y;
                    total += 1;
                }
                // The last column also takes the `npoints % nx` leftovers.
                grid[total - 1].0.dx = npoints - grid[total - 1].0.x;
                // Remaining rows, each `window` bits wide, down to bit 0.
                while y != 0 {
                    y -= window;
                    for i in 0..nx {
                        grid[total].0.x = grid[i].0.x;
                        grid[total].0.dx = grid[i].0.dx;
                        grid[total].0.y = y;
                        grid[total].0.dy = window;
                        total += 1;
                    }
                }
                let grid = &grid[..];

                let points = &self[..];
                let sz = unsafe { $scratch_sizeof(0) / 8 };

                // row_sync[r] counts finished tiles in the row at
                // y == r * window; whichever worker completes a row sends
                // that row's `y` over the channel.
                let mut row_sync: Vec<AtomicUsize> = Vec::with_capacity(ny);
                row_sync.resize_with(ny, Default::default);
                let row_sync = Arc::new(row_sync);
                let counter = Arc::new(AtomicUsize::new(0));
                let n_workers = core::cmp::min(ncpus, total);
                let (tx, rx) = sync_channel(n_workers);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();
                    let row_sync = row_sync.clone();

                    pool.joined_execute(move || {
                        let mut scratch = vec![0u64; sz << (window - 1)];
                        let mut p: [*const $point_affine; 2] =
                            [ptr::null(), ptr::null()];
                        let mut s: [*const u8; 2] = [ptr::null(), ptr::null()];

                        // Tiles are claimed in grid order, i.e. top row first.
                        loop {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= total {
                                break;
                            }
                            let x = grid[work].0.x;
                            let y = grid[work].0.y;

                            p[0] = &points[x];
                            s[0] = &scalars[x * nbytes];
                            unsafe {
                                $tile_mult(
                                    grid[work].1.as_ptr(),
                                    &p[0],
                                    grid[work].0.dx,
                                    &s[0],
                                    nbits,
                                    &mut scratch[0],
                                    y,
                                    window,
                                );
                            }
                            // AcqRel publishes this tile's result to the
                            // worker that completes the row, which then
                            // hands it on via the channel.
                            if row_sync[y / window]
                                .fetch_add(1, Ordering::AcqRel)
                                == nx - 1
                            {
                                tx.send(y).expect("disaster");
                            }
                        }
                    });
                }

                // Fold rows from the top bit window down: add a row's tiles,
                // then double `window` times before the next row. Rows may
                // complete out of order; `rows` remembers completed rows
                // that cannot be folded in yet.
                let mut ret = <$point>::default();
                let mut rows = vec![false; ny];
                let mut row = 0usize;
                for _ in 0..ny {
                    let mut y = rx.recv().unwrap();
                    rows[y / window] = true;
                    while grid[row].0.y == y {
                        while row < total && grid[row].0.y == y {
                            unsafe {
                                $add_or_double(
                                    &mut ret,
                                    &ret,
                                    grid[row].1.as_ptr(),
                                );
                            }
                            row += 1;
                        }
                        if y == 0 {
                            break;
                        }
                        for _ in 0..window {
                            unsafe { $double(&mut ret, &ret) };
                        }
                        y -= window;
                        if !rows[y / window] {
                            break;
                        }
                    }
                }
                ret
            }

            /// Sums all points in the slice.
            ///
            /// Fewer than 384 points, or a pool with fewer than two threads,
            /// are summed with a single bulk-add call. Otherwise workers
            /// claim chunks of at most `npoints / nchunks + 1` points
            /// (`nchunks = ceil(npoints / 256)`), bulk-add each chunk and
            /// accumulate, and the partial sums are combined at the end.
            /// An empty slice returns the default point.
            fn add(&self) -> $point {
                let npoints = self.len();

                // Empty sum: the accumulators' starting value. Also avoids
                // indexing `self[0]` below.
                if npoints == 0 {
                    return <$point>::default();
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 384 {
                    let p: [*const _; 2] = [&self[0], ptr::null()];
                    let mut ret = <$point>::default();
                    unsafe { $add(&mut ret, &p[0], npoints) };
                    return ret;
                }

                let counter = Arc::new(AtomicUsize::new(0));
                let nchunks = (npoints + 255) / 256;
                let chunk = npoints / nchunks + 1;
                let n_workers = core::cmp::min(ncpus, nchunks);
                let (tx, rx) = sync_channel(n_workers);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();

                    pool.joined_execute(move || {
                        let mut acc = <$point>::default();
                        let mut chunk = chunk;
                        let mut p: [*const _; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work =
                                counter.fetch_add(chunk, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }
                            p[0] = &self[work];
                            // Clip the final chunk to the end of the slice.
                            if work + chunk > npoints {
                                chunk = npoints - work;
                            }
                            unsafe {
                                let mut t = MaybeUninit::<$point>::uninit();
                                $add(t.as_mut_ptr(), &p[0], chunk);
                                $add_or_double(&mut acc, &acc, t.as_ptr());
                            };
                        }
                        tx.send(acc).expect("disaster");
                    });
                }

                // Each worker sends exactly one partial sum.
                let mut ret = rx.recv().unwrap();
                for _ in 1..n_workers {
                    unsafe {
                        $add_or_double(&mut ret, &ret, &rx.recv().unwrap())
                    };
                }

                ret
            }

            /// Checks every point. The point at infinity is rejected with
            /// `BLST_PK_IS_INFINITY`, and a point failing the group check
            /// with `BLST_POINT_NOT_IN_GROUP`.
            ///
            /// With two or more workers the points are checked concurrently
            /// and the scan stops early after a failure. The error returned
            /// is the one for the failing point, so the sequential and
            /// parallel paths report the same codes. When several points are
            /// invalid, which one gets reported may vary between runs on the
            /// parallel path.
            fn validate(&self) -> Result<(), BLST_ERROR> {
                /// Validates a single point.
                fn check(point: &$point_affine) -> Result<(), BLST_ERROR> {
                    if unsafe { $is_inf(point) } {
                        return Err(BLST_ERROR::BLST_PK_IS_INFINITY);
                    }
                    if !unsafe { $in_group(point) } {
                        return Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP);
                    }
                    Ok(())
                }

                let npoints = self.len();

                let pool = mt::da_pool();
                let n_workers = core::cmp::min(npoints, pool.max_count());
                if n_workers < 2 {
                    for i in 0..npoints {
                        check(&self[i])?
                    }
                    return Ok(());
                }

                let counter = Arc::new(AtomicUsize::new(0));
                // Smallest index seen to fail `check`, or `usize::MAX` while
                // none has; doubles as the workers' early-exit flag.
                let failed = Arc::new(AtomicUsize::new(usize::MAX));
                let wg =
                    Arc::new((Barrier::new(2), AtomicUsize::new(n_workers)));

                for _ in 0..n_workers {
                    let counter = counter.clone();
                    let failed = failed.clone();
                    let wg = wg.clone();

                    pool.joined_execute(move || {
                        while failed.load(Ordering::Relaxed) == usize::MAX {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }

                            if check(&self[work]).is_err() {
                                failed.fetch_min(work, Ordering::Relaxed);
                                break;
                            }
                        }

                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }

                wg.0.wait();

                let idx = failed.load(Ordering::Relaxed);
                if idx == usize::MAX {
                    return Ok(());
                }
                // `check` only inspects the point, so re-running it on the
                // recorded index recovers the exact error code, matching the
                // sequential path. The fallback is not expected to be hit.
                Err(check(&self[idx])
                    .err()
                    .unwrap_or(BLST_ERROR::BLST_POINT_NOT_IN_GROUP))
            }
        }

        // Per-group tests; `pippenger_test_mod!` is presumably defined by
        // the `cfg(test)`-only include that precedes the invocations.
        #[cfg(test)]
        pippenger_test_mod!(
            $test_mod,
            $points,
            $point,
            $add_or_double,
            $generator,
            $mult,
        );
    };
}
455
// Test-only helpers. This include presumably defines the
// `pippenger_test_mod!` macro that `pippenger_mult_impl!` invokes under
// `cfg(test)`. `macro_rules!` macros are only visible after their
// definition, so keep it ahead of the invocations below.
#[cfg(test)]
include!("pippenger-test_mod.rs");

// G1 instantiation: `p1_affines` and `MultiPoint for [blst_p1_affine]`,
// wired to the blst_p1* routines.
pippenger_mult_impl!(
    p1_affines,
    blst_p1,
    blst_p1_affine,
    blst_p1s_to_affine,
    blst_p1s_mult_pippenger_scratch_sizeof,
    blst_p1s_mult_pippenger,
    blst_p1s_tile_pippenger,
    blst_p1_add_or_double,
    blst_p1_double,
    p1_multi_point,
    blst_p1_generator,
    blst_p1_mult,
    blst_p1s_add,
    blst_p1_affine_is_inf,
    blst_p1_affine_in_g1,
    blst_p1_from_affine,
);

// G2 instantiation: `p2_affines` and `MultiPoint for [blst_p2_affine]`,
// wired to the blst_p2* routines.
pippenger_mult_impl!(
    p2_affines,
    blst_p2,
    blst_p2_affine,
    blst_p2s_to_affine,
    blst_p2s_mult_pippenger_scratch_sizeof,
    blst_p2s_mult_pippenger,
    blst_p2s_tile_pippenger,
    blst_p2_add_or_double,
    blst_p2_double,
    p2_multi_point,
    blst_p2_generator,
    blst_p2_mult,
    blst_p2s_add,
    blst_p2_affine_is_inf,
    blst_p2_affine_in_g2,
    blst_p2_from_affine,
);
496
/// Number of significant bits in `l`: one more than the index of its
/// highest set bit, or 0 when `l == 0`.
fn num_bits(l: usize) -> usize {
    let width = core::mem::size_of::<usize>() * 8;
    width - l.leading_zeros() as usize
}
500
501fn breakdown(
502    nbits: usize,
503    window: usize,
504    ncpus: usize,
505) -> (usize, usize, usize) {
506    let mut nx: usize;
507    let mut wnd: usize;
508
509    if nbits > window * ncpus {
510        nx = 1;
511        wnd = num_bits(ncpus / 4);
512        if (window + wnd) > 18 {
513            wnd = window - wnd;
514        } else {
515            wnd = (nbits / window + ncpus - 1) / ncpus;
516            if (nbits / (window + 1) + ncpus - 1) / ncpus < wnd {
517                wnd = window + 1;
518            } else {
519                wnd = window;
520            }
521        }
522    } else {
523        nx = 2;
524        wnd = window - 2;
525        while (nbits / wnd + 1) * nx < ncpus {
526            nx += 1;
527            wnd = window - num_bits(3 * nx / 2);
528        }
529        nx -= 1;
530        wnd = window - num_bits(3 * nx / 2);
531    }
532    let ny = nbits / wnd + 1;
533    wnd = nbits / ny + 1;
534
535    (nx, ny, wnd)
536}
537
538fn pippenger_window_size(npoints: usize) -> usize {
539    let wbits = num_bits(npoints);
540
541    if wbits > 13 {
542        return wbits - 4;
543    }
544    if wbits > 5 {
545        return wbits - 3;
546    }
547    2
548}