use core::mem::MaybeUninit;
use core::num::Wrapping;
use core::ops::{Index, IndexMut};
use core::ptr;
use core::slice::SliceIndex;
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
use std::sync::{Arc, Barrier};

// `mt::da_pool()`, `MultiPoint`, and `BLST_ERROR` are assumed to be in
// scope from the surrounding crate.

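// One unit of work in the Pippenger grid: `dx` points starting at index
// `x`, paired with the `dy`-bit window of the scalars at bit offset `y`.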
struct tile {
    x: usize,
    dx: usize,
    y: usize,
    dy: usize,
}

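// Minimal `core::cell::Cell` stand-in that is also `Sync`: each grid
// cell below is written by exactly one worker thread and only read
// after that row is reported complete, so the shared mutation is sound.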
#[repr(transparent)]
struct Cell<T: ?Sized> {
    value: T,
}
unsafe impl<T: ?Sized + Sync> Sync for Cell<T> {}
impl<T> Cell<T> {
    pub fn as_ptr(&self) -> *mut T {
        &self.value as *const T as *mut T
    }
}

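// Generates the Pippenger multi-scalar-multiplication front end for a
// point type: batch conversion to affine, parallel MSM over tiles,
// parallel summation, and parallel subgroup validation.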
macro_rules! pippenger_mult_impl {
    (
        $points:ident,
        $point:ty,
        $point_affine:ty,
        $to_affines:ident,
        $scratch_sizeof:ident,
        $multi_scalar_mult:ident,
        $tile_mult:ident,
        $add_or_double:ident,
        $double:ident,
        $test_mod:ident,
        $generator:ident,
        $mult:ident,
        $add:ident,
        $is_inf:ident,
        $in_group:ident,
        $from_affine:ident,
    ) => {
        pub struct $points {
            points: Vec<$point_affine>,
        }

        impl<I: SliceIndex<[$point_affine]>> Index<I> for $points {
            type Output = I::Output;

            #[inline]
            fn index(&self, i: I) -> &Self::Output {
                &self.points[i]
            }
        }
        impl<I: SliceIndex<[$point_affine]>> IndexMut<I> for $points {
            #[inline]
            fn index_mut(&mut self, i: I) -> &mut Self::Output {
                &mut self.points[i]
            }
        }

        impl $points {
            #[inline]
            pub fn as_slice(&self) -> &[$point_affine] {
                self.points.as_slice()
            }

            pub fn from(points: &[$point]) -> Self {
                let npoints = points.len();
                let mut ret = Self {
                    points: Vec::with_capacity(npoints),
                };
                unsafe { ret.points.set_len(npoints) };

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 768 {
                    let p: [*const $point; 2] = [&points[0], ptr::null()];
                    unsafe { $to_affines(&mut ret.points[0], &p[0], npoints) };
                    return ret;
                }

                let mut nslices = (npoints + 511) / 512;
                nslices = core::cmp::min(nslices, ncpus);
                let wg =
                    Arc::new((Barrier::new(2), AtomicUsize::new(nslices)));

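                // Spread the points as evenly as possible: the first
                // `npoints % nslices` slices receive `npoints / nslices + 1`
                // points, the rest `npoints / nslices`. `Wrapping` lets
                // `rem` underflow harmlessly once it reaches zero.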
                let (mut delta, mut rem) =
                    (npoints / nslices + 1, Wrapping(npoints % nslices));
                let mut x = 0usize;
                while x < npoints {
                    let out = &mut ret.points[x];
                    let inp = &points[x];

                    delta -= (rem == Wrapping(0)) as usize;
                    rem -= Wrapping(1);
                    x += delta;

                    let wg = wg.clone();
                    pool.joined_execute(move || {
                        let p: [*const $point; 2] = [inp, ptr::null()];
                        unsafe { $to_affines(out, &p[0], delta) };
                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }
                wg.0.wait();

                ret
            }

            #[inline]
            pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                self.as_slice().mult(scalars, nbits)
            }

            #[inline]
            pub fn add(&self) -> $point {
                self.as_slice().add()
            }
        }

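        // Multi-point operations over a plain slice of affine points.
        // `scalars` must hold one `(nbits + 7) / 8`-byte scalar per point,
        // packed contiguously.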
        impl MultiPoint for [$point_affine] {
            type Output = $point;

            fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
                let npoints = self.len();
                let nbytes = (nbits + 7) / 8;

                if scalars.len() < nbytes * npoints {
                    panic!("scalars length mismatch");
                }

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 {
                    let p: [*const $point_affine; 2] =
                        [&self[0], ptr::null()];
                    let s: [*const u8; 2] = [&scalars[0], ptr::null()];

                    unsafe {
                        let mut scratch: Vec<u64> =
                            Vec::with_capacity($scratch_sizeof(npoints) / 8);
                        #[allow(clippy::uninit_vec)]
                        scratch.set_len(scratch.capacity());
                        let mut ret = <$point>::default();
                        $multi_scalar_mult(
                            &mut ret,
                            &p[0],
                            npoints,
                            &s[0],
                            nbits,
                            &mut scratch[0],
                        );
                        return ret;
                    }
                }

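                // Too few points to tile profitably: parallelize over
                // whole points instead. Each worker multiplies the points
                // it claims from the shared counter into a local
                // accumulator; the partial sums are folded below.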
                if npoints < 32 {
                    let counter = Arc::new(AtomicUsize::new(0));
                    let n_workers = core::cmp::min(ncpus, npoints);
                    let (tx, rx) = sync_channel(n_workers);
                    for _ in 0..n_workers {
                        let tx = tx.clone();
                        let counter = counter.clone();

                        pool.joined_execute(move || {
                            let mut acc = <$point>::default();
                            let mut tmp = <$point>::default();
                            let mut first = true;

                            loop {
                                let work =
                                    counter.fetch_add(1, Ordering::Relaxed);
                                if work >= npoints {
                                    break;
                                }

                                unsafe {
                                    $from_affine(&mut tmp, &self[work]);
                                    let scalar = &scalars[nbytes * work];
                                    if first {
                                        $mult(&mut acc, &tmp, scalar, nbits);
                                        first = false;
                                    } else {
                                        $mult(&mut tmp, &tmp, scalar, nbits);
                                        $add_or_double(&mut acc, &acc, &tmp);
                                    }
                                }
                            }

                            tx.send(acc).expect("disaster");
                        });
                    }

                    let mut ret = rx.recv().expect("disaster");
                    for _ in 1..n_workers {
                        let p = rx.recv().expect("disaster");
                        unsafe { $add_or_double(&mut ret, &ret, &p) };
                    }

                    return ret;
                }

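                // Pippenger proper: split the scalars into `ny` bit-windows
                // and the points into `nx` slices, giving a grid of
                // independent tiles that workers claim one at a time.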
                let (nx, ny, window) =
                    breakdown(nbits, pippenger_window_size(npoints), ncpus);

                let mut grid: Vec<(tile, Cell<$point>)> =
                    Vec::with_capacity(nx * ny);
                #[allow(clippy::uninit_vec)]
                unsafe { grid.set_len(grid.capacity()) };
                let dx = npoints / nx;
                let mut y = window * (ny - 1);
                let mut total = 0usize;

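                // Lay out the grid row by row, top bit-window first. The
                // top row covers the remaining high bits (`nbits - y`),
                // and the last tile of each row absorbs the leftover
                // points (`npoints` need not divide evenly by `nx`).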
                while total < nx {
                    grid[total].0.x = total * dx;
                    grid[total].0.dx = dx;
                    grid[total].0.y = y;
                    grid[total].0.dy = nbits - y;
                    total += 1;
                }
                grid[total - 1].0.dx = npoints - grid[total - 1].0.x;
                while y != 0 {
                    y -= window;
                    for i in 0..nx {
                        grid[total].0.x = grid[i].0.x;
                        grid[total].0.dx = grid[i].0.dx;
                        grid[total].0.y = y;
                        grid[total].0.dy = window;
                        total += 1;
                    }
                }
                let grid = &grid[..];

                let points = &self[..];
                let sz = unsafe { $scratch_sizeof(0) / 8 };

                let mut row_sync: Vec<AtomicUsize> = Vec::with_capacity(ny);
                row_sync.resize_with(ny, Default::default);
                let row_sync = Arc::new(row_sync);
                let counter = Arc::new(AtomicUsize::new(0));
                let n_workers = core::cmp::min(ncpus, total);
                let (tx, rx) = sync_channel(n_workers);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();
                    let row_sync = row_sync.clone();

                    pool.joined_execute(move || {
                        let mut scratch = vec![0u64; sz << (window - 1)];
                        let mut p: [*const $point_affine; 2] =
                            [ptr::null(), ptr::null()];
                        let mut s: [*const u8; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= total {
                                break;
                            }
                            let x = grid[work].0.x;
                            let y = grid[work].0.y;

                            p[0] = &points[x];
                            s[0] = &scalars[x * nbytes];
                            unsafe {
                                $tile_mult(
                                    grid[work].1.as_ptr(),
                                    &p[0],
                                    grid[work].0.dx,
                                    &s[0],
                                    nbits,
                                    &mut scratch[0],
                                    y,
                                    window,
                                );
                            }
                            if row_sync[y / window]
                                .fetch_add(1, Ordering::AcqRel)
                                == nx - 1
                            {
                                tx.send(y).expect("disaster");
                            }
                        }
                    });
                }

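                // Fold completed rows, highest bit-window first: add a
                // finished row's tiles into `ret`, shift by doubling
                // `window` times, and keep descending while the next row
                // down has already been reported complete.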
                let mut ret = <$point>::default();
                let mut rows = vec![false; ny];
                let mut row = 0usize;
                for _ in 0..ny {
                    let mut y = rx.recv().unwrap();
                    rows[y / window] = true;
                    while grid[row].0.y == y {
                        while row < total && grid[row].0.y == y {
                            unsafe {
                                $add_or_double(
                                    &mut ret,
                                    &ret,
                                    grid[row].1.as_ptr(),
                                );
                            }
                            row += 1;
                        }
                        if y == 0 {
                            break;
                        }
                        for _ in 0..window {
                            unsafe { $double(&mut ret, &ret) };
                        }
                        y -= window;
                        if !rows[y / window] {
                            break;
                        }
                    }
                }
                ret
            }

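            // Parallel sum of all points: workers grab chunks of roughly
            // 256 points, reduce each chunk with the batched `$add`, and
            // the per-worker accumulators are folded on the receiver side.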
            fn add(&self) -> $point {
                let npoints = self.len();

                let pool = mt::da_pool();
                let ncpus = pool.max_count();
                if ncpus < 2 || npoints < 384 {
                    let p: [*const _; 2] = [&self[0], ptr::null()];
                    let mut ret = <$point>::default();
                    unsafe { $add(&mut ret, &p[0], npoints) };
                    return ret;
                }

                let counter = Arc::new(AtomicUsize::new(0));
                let nchunks = (npoints + 255) / 256;
                let chunk = npoints / nchunks + 1;
                let n_workers = core::cmp::min(ncpus, nchunks);
                let (tx, rx) = sync_channel(n_workers);
                for _ in 0..n_workers {
                    let tx = tx.clone();
                    let counter = counter.clone();

                    pool.joined_execute(move || {
                        let mut acc = <$point>::default();
                        let mut chunk = chunk;
                        let mut p: [*const _; 2] = [ptr::null(), ptr::null()];

                        loop {
                            let work =
                                counter.fetch_add(chunk, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }
                            p[0] = &self[work];
                            if work + chunk > npoints {
                                chunk = npoints - work;
                            }
                            unsafe {
                                let mut t = MaybeUninit::<$point>::uninit();
                                $add(t.as_mut_ptr(), &p[0], chunk);
                                $add_or_double(&mut acc, &acc, t.as_ptr());
                            };
                        }
                        tx.send(acc).expect("disaster");
                    });
                }

                let mut ret = rx.recv().unwrap();
                for _ in 1..n_workers {
                    unsafe {
                        $add_or_double(&mut ret, &ret, &rx.recv().unwrap())
                    };
                }

                ret
            }

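            // Parallel infinity/subgroup check with early abort: the
            // shared `valid` flag stops every worker as soon as one
            // point fails.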
            fn validate(&self) -> Result<(), BLST_ERROR> {
                fn check(point: &$point_affine) -> Result<(), BLST_ERROR> {
                    if unsafe { $is_inf(point) } {
                        return Err(BLST_ERROR::BLST_PK_IS_INFINITY);
                    }
                    if !unsafe { $in_group(point) } {
                        return Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP);
                    }
                    Ok(())
                }

                let npoints = self.len();

                let pool = mt::da_pool();
                let n_workers = core::cmp::min(npoints, pool.max_count());
                if n_workers < 2 {
                    for i in 0..npoints {
                        check(&self[i])?;
                    }
                    return Ok(());
                }

                let counter = Arc::new(AtomicUsize::new(0));
                let valid = Arc::new(AtomicBool::new(true));
                let wg =
                    Arc::new((Barrier::new(2), AtomicUsize::new(n_workers)));

                for _ in 0..n_workers {
                    let counter = counter.clone();
                    let valid = valid.clone();
                    let wg = wg.clone();

                    pool.joined_execute(move || {
                        while valid.load(Ordering::Relaxed) {
                            let work = counter.fetch_add(1, Ordering::Relaxed);
                            if work >= npoints {
                                break;
                            }

                            if check(&self[work]).is_err() {
                                valid.store(false, Ordering::Relaxed);
                                break;
                            }
                        }

                        if wg.1.fetch_sub(1, Ordering::AcqRel) == 1 {
                            wg.0.wait();
                        }
                    });
                }

                wg.0.wait();

                if valid.load(Ordering::Relaxed) {
                    Ok(())
                } else {
                    Err(BLST_ERROR::BLST_POINT_NOT_IN_GROUP)
                }
            }
        }

        #[cfg(test)]
        pippenger_test_mod!(
            $test_mod,
            $points,
            $point,
            $add_or_double,
            $generator,
            $mult,
        );
    };
}

#[cfg(test)]
include!("pippenger-test_mod.rs");

pippenger_mult_impl!(
    p1_affines,
    blst_p1,
    blst_p1_affine,
    blst_p1s_to_affine,
    blst_p1s_mult_pippenger_scratch_sizeof,
    blst_p1s_mult_pippenger,
    blst_p1s_tile_pippenger,
    blst_p1_add_or_double,
    blst_p1_double,
    p1_multi_point,
    blst_p1_generator,
    blst_p1_mult,
    blst_p1s_add,
    blst_p1_affine_is_inf,
    blst_p1_affine_in_g1,
    blst_p1_from_affine,
);

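// Usage sketch (hypothetical values; scalars are packed contiguously at
// `(nbits + 7) / 8` bytes each):
//
//     let pts = p1_affines::from(&projective_points);
//     let msm = pts.mult(&scalar_bytes, 255);
//     let sum = pts.add();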
pippenger_mult_impl!(
    p2_affines,
    blst_p2,
    blst_p2_affine,
    blst_p2s_to_affine,
    blst_p2s_mult_pippenger_scratch_sizeof,
    blst_p2s_mult_pippenger,
    blst_p2s_tile_pippenger,
    blst_p2_add_or_double,
    blst_p2_double,
    p2_multi_point,
    blst_p2_generator,
    blst_p2_mult,
    blst_p2s_add,
    blst_p2_affine_is_inf,
    blst_p2_affine_in_g2,
    blst_p2_from_affine,
);

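// Bit length of `l`: the index of the highest set bit plus one, with
// `num_bits(0) == 0`.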
fn num_bits(l: usize) -> usize {
    8 * core::mem::size_of_val(&l) - l.leading_zeros() as usize
}

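// Pick the tile grid for `mult`: `nx` point slices by `ny` bit-windows
// of `wnd` bits each, trading window width against parallelism so the
// `nx * ny` tiles keep `ncpus` workers busy.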
fn breakdown(
    nbits: usize,
    window: usize,
    ncpus: usize,
) -> (usize, usize, usize) {
    let mut nx: usize;
    let mut wnd: usize;

    if nbits > window * ncpus {
        nx = 1;
        wnd = num_bits(ncpus / 4);
        if (window + wnd) > 18 {
            wnd = window - wnd;
        } else {
            wnd = (nbits / window + ncpus - 1) / ncpus;
            if (nbits / (window + 1) + ncpus - 1) / ncpus < wnd {
                wnd = window + 1;
            } else {
                wnd = window;
            }
        }
    } else {
        nx = 2;
        wnd = window - 2;
        while (nbits / wnd + 1) * nx < ncpus {
            nx += 1;
            wnd = window - num_bits(3 * nx / 2);
        }
        nx -= 1;
        wnd = window - num_bits(3 * nx / 2);
    }
    let ny = nbits / wnd + 1;
    wnd = nbits / ny + 1;

    (nx, ny, wnd)
}

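// Heuristic window width for Pippenger's bucket method: roughly
// `log2(npoints)` minus a small constant, never below 2 bits.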
fn pippenger_window_size(npoints: usize) -> usize {
    let wbits = num_bits(npoints);

    if wbits > 13 {
        return wbits - 4;
    }
    if wbits > 5 {
        return wbits - 3;
    }
    2
}