p3_util/lib.rs
//! Various simple utilities.

#![no_std]

extern crate alloc;

use alloc::slice;
use alloc::string::String;
use alloc::vec::Vec;
use core::any::type_name;
use core::hint::unreachable_unchecked;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::{iter, mem};

use crate::transpose::transpose_in_place_square;

pub mod array_serialization;
pub mod linear_map;
pub mod transpose;
pub mod zip_eq;

/// Computes `ceil(log_2(n))`.
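///
/// For example (illustrative):
/// ```
/// assert_eq!(p3_util::log2_ceil_usize(1), 0);
/// assert_eq!(p3_util::log2_ceil_usize(8), 3);
/// assert_eq!(p3_util::log2_ceil_usize(9), 4); // non-powers of two round up
/// ```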
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize
}

#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    (u64::BITS - n.saturating_sub(1).leading_zeros()) as u64
}

/// Computes `log_2(n)`
///
/// # Panics
/// Panics if `n` is not a power of two.
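///
/// For example (illustrative):
/// ```
/// assert_eq!(p3_util::log2_strict_usize(16), 4);
/// ```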
#[must_use]
#[inline]
pub const fn log2_strict_usize(n: usize) -> usize {
    let res = n.trailing_zeros();
    assert!(n.wrapping_shr(res) == 1, "Not a power of two");
    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
    // `1 << res` and vice versa.
    unsafe {
        assume(n == 1 << res);
    }
    res as usize
}

/// Returns `[0, ..., N - 1]`.
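///
/// For example (illustrative):
/// ```
/// assert_eq!(p3_util::indices_arr::<4>(), [0, 1, 2, 3]);
/// ```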
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut indices_arr = [0; N];
    let mut i = 0;
    while i < N {
        indices_arr[i] = i;
        i += 1;
    }
    indices_arr
}

#[inline]
pub const fn reverse_bits(x: usize, n: usize) -> usize {
    // Assert that n is a power of 2
    debug_assert!(n.is_power_of_two());
    reverse_bits_len(x, n.trailing_zeros() as usize)
}

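/// Reverses the low `bit_len` bits of `x`, discarding any higher bits.
///
/// For example (illustrative):
/// ```
/// assert_eq!(p3_util::reverse_bits_len(0b0011, 4), 0b1100);
/// ```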
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // NB: The only reason we need overflowing_shr() here as opposed
    // to plain '>>' is to accommodate the case x == bit_len == 0,
    // which would become `0 >> 64`. Rust thinks that any shift of 64
    // bits causes overflow, even when the argument is zero.
    x.reverse_bits()
        .overflowing_shr(usize::BITS - bit_len as u32)
        .0
}

// Lookup table of 6-bit reverses.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;

/// Permutes `vals` such that each index is mapped to its reverse in binary.
///
/// If the whole array fits in fast cache, then the trivial algorithm is cache-friendly. Also, if
/// `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
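///
/// For example (illustrative):
/// ```
/// let mut vals = vec![10, 20, 30, 40];
/// p3_util::reverse_slice_index_bits(&mut vals);
/// assert_eq!(vals, vec![10, 30, 20, 40]);
/// ```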
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // If the whole array fits in fast cache, then the trivial algorithm is cache-friendly. Also, if
    // `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //   1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //      basically a series of large `memcpy`s.)
        //   2. Transpose the matrix.
        //   3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //   1. bit-reversing `i[lb_n / 2..lb_n]`,
        //   2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        //   3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n` of the
        // index (shuffling `1 << floor(lb_n / 2)` chunks, each of length `1 << ceil(lb_n / 2)`). At
        // step 2, we perform _two_ transposes. We treat `vals` as two matrices, one where the middle
        // bit of the index is `0` and another, where the middle bit is `1`; we transpose each individually.
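        //
        // As an illustrative check with `lb_n = 4`: for `i = 0b0111`, step 1 maps it to `0b1011`
        // (high half `01` reversed to `10`), step 2 swaps the halves to give `0b1110`, and step 3
        // reverses the new high half `11` (a no-op here), so `i` ends up at `0b1110`, the full
        // 4-bit reversal of `0b0111`.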

        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `vals` by `1 << lb_num_chunks`, effectively adding that to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}

// Both functions below are semantically equivalent to:
//     for i in 0..n {
//         result.push(arr[reverse_bits(i, n_power)]);
//     }
// where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
// to guide the compiler to generate optimal assembly.

#[cfg(not(target_arch = "aarch64"))]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n as u32;
        #[allow(clippy::needless_range_loop)]
        for src in 0..vals.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            if src < dst {
                vals.swap(src, dst);
            }
        }
    } else {
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
        // and the high bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(vals.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            #[allow(clippy::needless_range_loop)]
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    vals.swap(src, dst);
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
const fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    // Use a manual `while` loop to enable `const`.
    let mut src = 0;
    while src < vals.len() {
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            vals.swap(src, dst);
        }

        src += 1;
    }
}

/// Split `vals` into chunks and bit-reverse the order of the chunks. There are `1 << lb_num_chunks`
/// chunks, each of length `1 << lb_chunk_size`.
/// SAFETY: ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            unsafe {
                core::ptr::swap_nonoverlapping(
                    vals.get_unchecked_mut(i << lb_chunk_size),
                    vals.get_unchecked_mut(j << lb_chunk_size),
                    1 << lb_chunk_size,
                );
            }
        }
    }
}

/// Allow the compiler to assume that the given predicate `p` is always `true`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
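///
/// For example (illustrative):
/// ```
/// let x = 10usize;
/// // SAFETY: `x` is indeed nonzero here, so the optimizer may rely on this fact.
/// unsafe { p3_util::assume(x != 0) };
/// ```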
#[inline(always)]
pub const unsafe fn assume(p: bool) {
    debug_assert!(p);
    if !p {
        unsafe {
            unreachable_unchecked();
        }
    }
}

/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}

/// Return a String containing the name of T but with all the crate
/// and module prefixes removed.
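///
/// For example, `Vec<u32>` is typically rendered as `"Vec<u32>"` rather than
/// `"alloc::vec::Vec<u32>"` (the exact output of `type_name` is not guaranteed to be stable).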
pub fn pretty_name<T>() -> String {
    let name = type_name::<T>();
    let mut result = String::new();
    for qual in name.split_inclusive(&['<', '>', ',']) {
        result.push_str(qual.split("::").last().unwrap());
    }
    result
}

/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Returns an array of `MaybeUninit<T>` and the number of items in the
/// array which have been correctly initialized.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut i = 0;

    while i < BUFLEN {
        if let Some(c) = iter.next() {
            // Copy the next Item into `buf`.
            unsafe {
                buf.get_unchecked_mut(i).write(c);
                i = i.unchecked_add(1);
            }
        } else {
            // No more items in the iterator.
            break;
        }
    }
    (buf, i)
}

/// Gets a shared reference to the contained value.
///
/// # Safety
///
/// Calling this when the content is not yet fully initialized causes undefined
/// behavior: it is up to the caller to guarantee that every `MaybeUninit<T>` in
/// the slice really is in an initialized state.
///
/// Copied from:
/// https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref
/// Once that is stabilized, this should be removed.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that
    // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`.
    // The pointer obtained is valid since it refers to memory owned by `slice` which is a
    // reference and thus guaranteed to be valid for reads.
    unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) }
}

/// Split an iterator into small arrays and apply `func` to each.
///
/// Repeatedly read `BUFLEN` elements from `input` into an array and
/// pass the array to `func` as a slice. If fewer than `BUFLEN`
/// elements remain, that smaller slice is passed to `func` (if
/// it is non-empty) and the function returns.
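///
/// For example (illustrative):
/// ```
/// let mut chunks = Vec::new();
/// p3_util::apply_to_chunks::<3, _, _>(vec![1u8, 2, 3, 4, 5, 6, 7], |chunk| {
///     chunks.push(chunk.to_vec());
/// });
/// assert_eq!(chunks, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
/// ```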
#[inline]
pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
where
    I: IntoIterator<Item = u8>,
    H: FnMut(&[I::Item]),
{
    let mut iter = input.into_iter();
    loop {
        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
        if n == 0 {
            break;
        }
        func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
    }
}

/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items (but more than `0`), pads with the given default value.
///
/// Since the iterator is passed as a mutable reference and this function calls
/// `next` at most `N` times, the iterator can still be used afterwards to
/// retrieve the remaining items.
///
/// If `iter.next()` panics, all items already yielded by the iterator are
/// dropped.
#[inline]
fn iter_next_chunk_padded<T: Copy, const N: usize>(
    iter: &mut impl Iterator<Item = T>,
    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> Option<[T; N]> {
    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
    (n != 0).then(|| {
        // Fill the rest of the array with default values.
        arr[n..].fill(MaybeUninit::new(default));
        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
    })
}

/// Returns an iterator over `N` elements of the iterator at a time.
///
/// The chunks do not overlap. If `N` does not divide the length of the
/// iterator, the final chunk is padded to length `N` with the given default value.
///
/// This is essentially a copy-pasted version of the nightly `array_chunks` function.
/// https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks
/// Once that is stabilized this and the functions above it should be removed.
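///
/// For example (illustrative):
/// ```
/// let chunks: Vec<[u8; 3]> = p3_util::iter_array_chunks_padded([1u8, 2, 3, 4], 0).collect();
/// assert_eq!(chunks, vec![[1, 2, 3], [4, 0, 0]]);
/// ```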
#[inline]
pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
    iter: impl IntoIterator<Item = T>,
    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> impl Iterator<Item = [T; N]> {
    let mut iter = iter.into_iter();
    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
}

/// Reinterpret a slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[[F; N]]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
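///
/// For example (illustrative):
/// ```
/// let pairs = [[1u32, 2], [3, 4]];
/// let flat: &[u32] = unsafe { p3_util::as_base_slice(&pairs[..]) };
/// assert_eq!(flat, &[1, 2, 3, 4][..]);
/// ```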
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts(buf_ptr, n) }
}

/// Reinterpret a mutable slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&mut [[F; N]]` to `&mut [F]` or `&mut [A]` to `&mut [F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_mut_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
}

/// Convert a vector of `BaseArray` elements to a vector of `Base` elements without any
/// reallocations.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can also
/// be used to safely convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field
/// or `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
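///
/// For example (illustrative):
/// ```
/// let vec: Vec<[u16; 4]> = vec![[1, 2, 3, 4], [5, 6, 7, 8]];
/// let flat: Vec<u16> = unsafe { p3_util::flatten_to_base(vec) };
/// assert_eq!(flat, vec![1, 2, 3, 4, 5, 6, 7, 8]);
/// ```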
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();
    // Prevent running `vec`'s destructor so we are in complete control
    // of the allocation.
    let mut values = ManuallyDrop::new(vec);

    // Each `BaseArray` is an array of `d` elements, so the length and capacity of
    // the new vector will be multiplied by `d`.
    let new_len = values.len() * d;
    let new_cap = values.capacity() * d;

    // Safe as BaseArray and Base have the same alignment.
    let ptr = values.as_mut_ptr() as *mut Base;

    unsafe {
        // Safety:
        // - BaseArray and Base have the same alignment.
        // - As size_of::<BaseArray>() == size_of::<Base>() * d:
        //   -- The capacity of the new vector is equal to the capacity of the old vector.
        //   -- The first new_len elements of the new vector correspond to the first
        //      len elements of the old vector and so are properly initialized.
        Vec::from_raw_parts(ptr, new_len, new_cap)
    }
}

/// Convert a vector of `Base` elements to a vector of `BaseArray` elements, ideally without any
/// reallocations.
///
/// This is an inverse of `flatten_to_base`. Unfortunately, unlike `flatten_to_base`, it may not be
/// possible to avoid allocations. The issue is that there is no way to guarantee that the capacity
/// of the vector is a multiple of `d`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics if the size of `BaseArray` is not a multiple of the size of `Base`.
/// This panics if the length of the vector is not a multiple of the ratio of the sizes.
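///
/// For example (illustrative):
/// ```
/// let flat: Vec<u16> = vec![1, 2, 3, 4, 5, 6];
/// let arrays: Vec<[u16; 2]> = unsafe { p3_util::reconstitute_from_base(flat) };
/// assert_eq!(arrays, vec![[1, 2], [3, 4], [5, 6]]);
/// ```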
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // We could call vec.shrink_to_fit() here to try and increase the probability that
    // the capacity is a multiple of d. That might cause a reallocation though which
    // would defeat the whole purpose.
    let cap = vec.capacity();

    // The assumption is that basically all callers of `reconstitute_from_base` will be calling it
    // with a vector constructed from `flatten_to_base` and so the capacity should be a multiple of `d`.
    // But capacities can do strange things so we need to support both possibilities.
    // Note that the `else` branch would also work if the capacity is a multiple of `d` but it is slower.
    if cap.is_multiple_of(d) {
        // Prevent running `vec`'s destructor so we are in complete control
        // of the allocation.
        let mut values = ManuallyDrop::new(vec);

        // If we are on this branch then the capacity is a multiple of `d`.
        let new_cap = cap / d;

        // Safe as BaseArray and Base have the same alignment.
        let ptr = values.as_mut_ptr() as *mut BaseArray;

        unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //   -- If we have reached this point, the length and capacity are both divisible by `d`.
            //   -- The capacity of the new vector is equal to the capacity of the old vector.
            //   -- The first new_len elements of the new vector correspond to the first
            //      len elements of the old vector and so are properly initialized.
            Vec::from_raw_parts(ptr, new_len, new_cap)
        }
    } else {
        // If the capacity is not a multiple of `d`, we go via slices.

        let buf_ptr = vec.as_mut_ptr().cast::<BaseArray>();
        let slice = unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //   -- If we have reached this point, the length is divisible by `d`.
            //   -- The first new_len elements of the slice correspond to the first
            //      len elements of the old slice and so are properly initialized.
            slice::from_raw_parts(buf_ptr, new_len)
        };

        // Ideally the compiler could optimize this away to avoid the copy but it appears not to.
        slice.to_vec()
    }
}

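/// Determine whether `u` and `v` are coprime, i.e. `gcd(u, v) == 1`.
///
/// Returns `false` if either input is `0`.
///
/// For example (illustrative):
/// ```
/// assert!(p3_util::relatively_prime_u64(15, 28));
/// assert!(!p3_util::relatively_prime_u64(12, 18));
/// ```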
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Check that neither input is 0.
    if u == 0 || v == 0 {
        return false;
    }

    // Check divisibility by 2.
    if (u | v) & 1 == 0 {
        return false;
    }

    // Remove factors of 2 from `u` and `v`.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Ensure u <= v.
        if u > v {
            core::mem::swap(&mut u, &mut v);
        }

        // This looks inefficient for v >> u but, thanks to the fact that we remove
        // trailing_zeros of v in every iteration, it ends up much more performant
        // than first glance implies.
        v -= u;
    }
    // If we made it through the loop, at no point was u or v equal to 1 and so the gcd
    // must be greater than 1.
    false
}

/// Inner loop of the deferred GCD algorithm.
///
/// See: https://eprint.iacr.org/2020/972.pdf for more information.
///
/// This is basically a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. The key point is that this small loop only uses u64s, subtractions and
/// bit shifts, which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` bits should match the top bits, including
/// zeroes if the original numbers have different sizes.
#[inline]
pub const fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Initialise update factors.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1
    let (mut f0, mut g0, mut f1, mut g1) = (1, 0, 0, 1);

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i
    // Then, at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}
    // Use a manual `while` loop to enable `const`.
    let mut round = 0;
    while round < NUM_ROUNDS {
        if *a & 1 == 0 {
            *a >>= 1;
        } else {
            if *a < *b {
                core::mem::swap(a, b);
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            *a -= *b;
            *a >>= 1;
            f0 -= f1;
            g0 -= g1;
        }
        f1 <<= 1;
        g1 <<= 1;

        round += 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63} in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}

/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
/// - a: The value we want to invert. It must be < P.
/// - b: The value of the prime `P > 2`.
///
/// Output:
/// - A 64-bit signed integer `v` equal to `2^{2 * FIELD_BITS - 2} a^{-1} mod P` with
///   size `|v| < 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions breaks, the output is undefined.
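///
/// For example (illustrative): with `FIELD_BITS = 3`, `P = 7` and `a = 3`, the result is congruent
/// to `2^4 * 3^{-1} = 80 ≡ 3 (mod 7)`:
/// ```
/// let v = p3_util::gcd_inversion_prime_field_32::<3>(3, 7);
/// assert_eq!(v.rem_euclid(7), 3);
/// assert!(v.abs() < 16); // |v| < 2^{2 * FIELD_BITS - 2}
/// ```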
#[inline]
pub const fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Initialise u, v. Note that |u|, |v| <= 2^0
    let (mut u, mut v) = (1_i64, 0_i64);

    // Let a0 and P denote the initial values of a and b. Observe:
    //     `a = u * a0 mod P`
    //     `b = v * a0 mod P`
    //     `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`

    // Use a manual `while` loop to enable `const`.
    let mut i = 0;
    while i < 2 * FIELD_BITS - 2 {
        // Assume at the start of loop iteration i:
        // (1) `|u|, |v| <= 2^{i}`
        // (2) `2^i * a = u * a0 mod P`
        // (3) `2^i * b = v * a0 mod P`
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i, 1)`

        if a & 1 != 0 {
            if a < b {
                (a, b) = (b, a);
                (u, v) = (v, u);
            }
            // As b < a, this subtraction cannot increase `len(a) + len(b)`.
            a -= b;
            // Observe |u'| = |u - v| <= |u| + |v| <= 2^{i + 1}
            u -= v;

            // As (1) and (2) hold, we have
            // `2^i a' = 2^i * (a - b) = (u - v) * a0 mod P = u' * a0 mod P`
        }
        // As b is odd, a must now be even.
        // This reduces `len(a) + len(b)` by 1 (unless `a = 0`, in which case `b = 1` and the sum of the lengths is always 1).
        a >>= 1;

        // Observe |v'| = 2|v| <= 2^{i + 1}
        v <<= 1;

        // Thus at the end of loop iteration i:
        // (1) `|u|, |v| <= 2^{i + 1}`
        // (2) `2^{i + 1} * a = u * a0 mod P` (As we have halved a)
        // (3) `2^{i + 1} * b = v * a0 mod P` (As we have doubled v)
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i - 1, 1)`

        i += 1;
    }

    // After the loop, we see that:
    //     |u|, |v| <= 2^{2 * FIELD_BITS - 2}: Hence for FIELD_BITS <= 32 we will not overflow an i64.
    //     `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    //     `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // This implies that `b` must be `1` and so `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P` as desired.
    v
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_reverse_bits_len() {
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }

    #[test]
    fn test_reverse_index_bits() {
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }

    #[test]
    fn test_apply_to_chunks_exact_fit() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_with_remainder() {
        const CHUNK_SIZE: usize = 3;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
    }

    #[test]
    fn test_apply_to_chunks_empty_input() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert!(results.is_empty());
    }

    #[test]
    fn test_apply_to_chunks_single_chunk() {
        const CHUNK_SIZE: usize = 10;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
    }

    #[test]
    fn test_apply_to_chunks_large_chunk_size() {
        const CHUNK_SIZE: usize = 100;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_large_input() {
        const CHUNK_SIZE: usize = 5;
        let input: Vec<u8> = (1..=20).collect();
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(
            results,
            vec![
                vec![1, 2, 3, 4, 5],
                vec![6, 7, 8, 9, 10],
                vec![11, 12, 13, 14, 15],
                vec![16, 17, 18, 19, 20]
            ]
        );
    }

    #[test]
    fn test_reverse_slice_index_bits_random() {
        let lengths = [32, 128, 1 << 16];
        let mut rng = SmallRng::seed_from_u64(1);
        for _ in 0..32 {
            for &length in &lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.random());
                let expect = reverse_index_bits_naive(&rand_list);

                let mut actual = rand_list.clone();
                reverse_slice_index_bits(&mut actual);

                assert_eq!(actual, expect);
            }
        }
    }

    #[test]
    fn test_log2_strict_usize_edge_cases() {
        assert_eq!(log2_strict_usize(1), 0);
        assert_eq!(log2_strict_usize(2), 1);
        assert_eq!(log2_strict_usize(1 << 18), 18);
        assert_eq!(log2_strict_usize(1 << 31), 31);
        assert_eq!(
            log2_strict_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_zero() {
        let _ = log2_strict_usize(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_nonpower_2() {
        let _ = log2_strict_usize(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        let _ = log2_strict_usize(usize::MAX);
    }

    #[test]
    fn test_log2_ceil_usize_comprehensive() {
        // Powers of 2
        assert_eq!(log2_ceil_usize(0), 0);
        assert_eq!(log2_ceil_usize(1), 0);
        assert_eq!(log2_ceil_usize(2), 1);
        assert_eq!(log2_ceil_usize(1 << 18), 18);
        assert_eq!(log2_ceil_usize(1 << 31), 31);
        assert_eq!(
            log2_ceil_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );

        // Nonpowers; want to round up
        assert_eq!(log2_ceil_usize(3), 2);
        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
        assert_eq!(
            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
    }

    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict_usize(n);

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
            out[dst] = Some(*v);
        }

        out.into_iter().map(|x| x.unwrap()).collect()
    }

    #[test]
    fn test_relatively_prime_u64() {
        // Zero cases (should always return false)
        assert!(!relatively_prime_u64(0, 0));
        assert!(!relatively_prime_u64(10, 0));
        assert!(!relatively_prime_u64(0, 10));
        assert!(!relatively_prime_u64(0, 123456789));

        // Number with itself (if greater than 1, not relatively prime)
        assert!(relatively_prime_u64(1, 1));
        assert!(!relatively_prime_u64(10, 10));
        assert!(!relatively_prime_u64(99999, 99999));

        // Powers of 2 (always false since they share factor 2)
        assert!(!relatively_prime_u64(2, 4));
        assert!(!relatively_prime_u64(16, 32));
        assert!(!relatively_prime_u64(64, 128));
        assert!(!relatively_prime_u64(1024, 4096));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));

        // One number is a multiple of the other (always false)
        assert!(!relatively_prime_u64(5, 10));
        assert!(!relatively_prime_u64(12, 36));
        assert!(!relatively_prime_u64(15, 45));
        assert!(!relatively_prime_u64(100, 500));

        // Co-prime numbers (should be true)
        assert!(relatively_prime_u64(17, 31));
        assert!(relatively_prime_u64(97, 43));
        assert!(relatively_prime_u64(7919, 65537));
        assert!(relatively_prime_u64(15485863, 32452843));

        // Small prime numbers (should be true)
        assert!(relatively_prime_u64(13, 17));
        assert!(relatively_prime_u64(101, 103));
        assert!(relatively_prime_u64(1009, 1013));

        // Large numbers (some cases where they are relatively prime or not)
        assert!(!relatively_prime_u64(
            190266297176832000,
            10430732356495263744
        ));
        assert!(!relatively_prime_u64(
            2040134905096275968,
            5701159354248194048
        ));
        assert!(!relatively_prime_u64(
            16611311494648745984,
            7514969329383038976
        ));
        assert!(!relatively_prime_u64(
            14863931409971066880,
            7911906750992527360
        ));

        // Max values
        assert!(relatively_prime_u64(u64::MAX, 1));
        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
    }
}