p3_util/
lib.rs

//! Various simple utilities.

#![no_std]

extern crate alloc;

use alloc::slice;
use alloc::string::String;
use alloc::vec::Vec;
use core::any::type_name;
use core::hint::unreachable_unchecked;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::{iter, mem};

use crate::transpose::transpose_in_place_square;

pub mod array_serialization;
pub mod linear_map;
pub mod transpose;
pub mod zip_eq;

/// Computes `ceil(log_2(n))`.
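///
/// Note that this definition sends `0` to `0`. A couple of checks, matching the
/// unit tests at the bottom of this file:
///
/// ```
/// assert_eq!(p3_util::log2_ceil_usize(0), 0);
/// assert_eq!(p3_util::log2_ceil_usize(5), 3);
/// assert_eq!(p3_util::log2_ceil_usize(8), 3);
/// ```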
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize
}

/// Computes `ceil(log_2(n))`.
#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    (u64::BITS - n.saturating_sub(1).leading_zeros()) as u64
}

/// Computes `log_2(n)`.
///
/// # Panics
/// Panics if `n` is not a power of two.
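///
/// For example:
///
/// ```
/// assert_eq!(p3_util::log2_strict_usize(1), 0);
/// assert_eq!(p3_util::log2_strict_usize(16), 4);
/// ```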
#[must_use]
#[inline]
pub const fn log2_strict_usize(n: usize) -> usize {
    let res = n.trailing_zeros();
    assert!(n.wrapping_shr(res) == 1, "Not a power of two");
    // Tell the optimizer about the semantics of `log2_strict_usize`, i.e. it can replace `n` with
    // `1 << res` and vice versa.
    unsafe {
        assume(n == 1 << res);
    }
    res as usize
}

/// Returns `[0, ..., N - 1]`.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut indices_arr = [0; N];
    let mut i = 0;
    while i < N {
        indices_arr[i] = i;
        i += 1;
    }
    indices_arr
}

/// Reverses the low `log_2(n)` bits of `x`, where `n` must be a power of two.
#[inline]
pub const fn reverse_bits(x: usize, n: usize) -> usize {
    // Assert that n is a power of 2.
    debug_assert!(n.is_power_of_two());
    reverse_bits_len(x, n.trailing_zeros() as usize)
}

/// Reverses the low `bit_len` bits of `x`.
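///
/// For example, mirroring one of the unit tests below:
///
/// ```
/// assert_eq!(p3_util::reverse_bits_len(0b01011, 5), 0b11010);
/// ```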
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // NB: The only reason we need overflowing_shr() here as opposed
    // to plain '>>' is to accommodate the case x == bit_len == 0,
    // which would become `0 >> 64`. Rust thinks that any shift of 64
    // bits causes overflow, even when the argument is zero.
    x.reverse_bits()
        .overflowing_shr(usize::BITS - bit_len as u32)
        .0
}

// Lookup table of 6-bit reverses.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;

/// Permutes `vals` such that each index is mapped to its reverse in binary.
///
/// # Panics
/// Panics if `vals.len()` is neither zero nor a power of two.
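///
/// A small worked example: with 8 elements, index `i` swaps with the reversal of its
/// 3-bit representation (`1 = 0b001` swaps with `0b100 = 4`, `3 = 0b011` with `0b110 = 6`):
///
/// ```
/// let mut xs = [0, 1, 2, 3, 4, 5, 6, 7];
/// p3_util::reverse_slice_index_bits(&mut xs);
/// assert_eq!(xs, [0, 4, 2, 6, 1, 5, 3, 7]);
/// ```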
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // If the whole array fits in fast cache, then the trivial algorithm is cache-friendly. Also, if
    // `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //     basically a series of large `memcpy`s.)
        //  2. Transpose the matrix.
        //  3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //  1. bit-reversing `i[lb_n / 2..lb_n]`,
        //  2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        //  3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n` of the
        // index (shuffling `2^floor(lb_n / 2)` chunks of length `2^ceil(lb_n / 2)`). At step 2, we
        // perform _two_ transposes. We treat `vals` as two matrices, one where the middle bit of
        // the index is `0` and another where the middle bit is `1`; we transpose each individually.

        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `vals` by `1 << lb_num_chunks`, effectively adding that offset to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}

// Both functions below are semantically equivalent to:
//     for i in 0..n {
//         result.push(vals[reverse_bits(i, lb_n)]);
//     }
// where reverse_bits(i, lb_n) computes the lb_n-bit reverse. The complications are there
// to guide the compiler to generate optimal assembly.

#[cfg(not(target_arch = "aarch64"))]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n as u32;
        #[allow(clippy::needless_range_loop)]
        for src in 0..vals.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            if src < dst {
                vals.swap(src, dst);
            }
        }
    } else {
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
        // and the high bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(vals.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            #[allow(clippy::needless_range_loop)]
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    vals.swap(src, dst);
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
const fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    // Use a manual `while` loop to enable `const`.
    let mut src = 0;
    while src < vals.len() {
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            vals.swap(src, dst);
        }

        src += 1;
    }
}

/// Split `vals` into chunks and bit-reverse the order of the chunks. There are
/// `1 << lb_num_chunks` chunks, each of length `1 << lb_chunk_size`.
///
/// # Safety
///
/// The caller must ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            // SAFETY: `i != j`, so the two equal-length chunks are disjoint, and the length
            // requirement on `vals` guarantees both chunks are in bounds.
            unsafe {
                core::ptr::swap_nonoverlapping(
                    vals.get_unchecked_mut(i << lb_chunk_size),
                    vals.get_unchecked_mut(j << lb_chunk_size),
                    1 << lb_chunk_size,
                );
            }
        }
    }
}

/// Allow the compiler to assume that the given predicate `p` is always `true`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
#[inline(always)]
pub const unsafe fn assume(p: bool) {
    debug_assert!(p);
    if !p {
        unsafe {
            unreachable_unchecked();
        }
    }
}

/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}

/// Return a `String` containing the name of `T` but with all the crate
/// and module prefixes removed.
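///
/// For example (note that the exact output of `core::any::type_name`, which this
/// builds on, is not formally guaranteed to be stable across compiler versions):
///
/// ```
/// assert_eq!(p3_util::pretty_name::<Vec<u64>>(), "Vec<u64>");
/// ```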
pub fn pretty_name<T>() -> String {
    let name = type_name::<T>();
    let mut result = String::new();
    for qual in name.split_inclusive(&['<', '>', ',']) {
        result.push_str(qual.split("::").last().unwrap());
    }
    result
}

/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Returns an array of `MaybeUninit<T>` and the number of items in the
/// array which have been correctly initialized.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut i = 0;

    while i < BUFLEN {
        if let Some(c) = iter.next() {
            // Copy the next item into `buf`.
            // SAFETY: `i < BUFLEN`, so the write is in bounds and `i + 1` cannot overflow.
            unsafe {
                buf.get_unchecked_mut(i).write(c);
                i = i.unchecked_add(1);
            }
        } else {
            // No more items in the iterator.
            break;
        }
    }
    (buf, i)
}

/// Gets a shared reference to the contained value.
///
/// # Safety
///
/// Calling this when the content is not yet fully initialized causes undefined
/// behavior: it is up to the caller to guarantee that every `MaybeUninit<T>` in
/// the slice really is in an initialized state.
///
/// Copied from:
/// <https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref>
/// Once that is stabilized, this should be removed.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that
    // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`.
    // The pointer obtained is valid since it refers to memory owned by `slice` which is a
    // reference and thus guaranteed to be valid for reads.
    unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) }
}

/// Split an iterator into small arrays and apply `func` to each.
///
/// Repeatedly read `BUFLEN` elements from `input` into an array and
/// pass the array to `func` as a slice. If fewer than `BUFLEN`
/// elements remain, that smaller slice is passed to `func` (if
/// it is non-empty) and the function returns.
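///
/// For example, summing 4-byte chunks (the final, shorter chunk has length 2):
///
/// ```
/// let mut sums = Vec::new();
/// p3_util::apply_to_chunks::<4, _, _>(1u8..=6, |chunk| {
///     sums.push(chunk.iter().copied().sum::<u8>());
/// });
/// assert_eq!(sums, vec![10, 11]);
/// ```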
#[inline]
pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
where
    I: IntoIterator<Item = u8>,
    H: FnMut(&[I::Item]),
{
    let mut iter = input.into_iter();
    loop {
        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
        if n == 0 {
            break;
        }
        // SAFETY: `iter_next_chunk_erased` initialized the first `n` elements of `buf`.
        func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
    }
}

/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items (but more than `0`), the rest of the array is
/// padded with the given default value.
///
/// Since the iterator is passed as a mutable reference and this function calls
/// `next` at most `N` times, the iterator can still be used afterwards to
/// retrieve the remaining items.
///
/// If `iter.next()` panics, all items already yielded by the iterator are
/// dropped.
#[inline]
fn iter_next_chunk_padded<T: Copy, const N: usize>(
    iter: &mut impl Iterator<Item = T>,
    default: T, // Needed due to [T; N] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> Option<[T; N]> {
    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
    (n != 0).then(|| {
        // Fill the rest of the array with default values.
        arr[n..].fill(MaybeUninit::new(default));
        // SAFETY: all `N` elements are now initialized, and `MaybeUninit<T>` has the same layout as `T`.
        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
    })
}

/// Returns an iterator over `N` elements of the iterator at a time.
///
/// The chunks do not overlap. If `N` does not divide the length of the
/// iterator, the final chunk is padded up to length `N` with the given default value.
///
/// This is essentially a copy-pasted version of the nightly `array_chunks` function:
/// <https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks>
/// Once that is stabilized, this and the functions above it should be removed.
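///
/// For example, the trailing chunk below is padded with the default value `0`:
///
/// ```
/// let chunks: Vec<[u8; 3]> = p3_util::iter_array_chunks_padded([1, 2, 3, 4], 0).collect();
/// assert_eq!(chunks, vec![[1, 2, 3], [4, 0, 0]]);
/// ```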
#[inline]
pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
    iter: impl IntoIterator<Item = T>,
    default: T, // Needed due to [T; N] not always implementing Default. Can probably be dropped if const generics stabilize.
) -> impl Iterator<Item = [T; N]> {
    let mut iter = iter.into_iter();
    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
}

/// Reinterpret a slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[F; N]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics (at compile time, via a `const` block) if the alignments of `Base` and
/// `BaseArray` differ, or if the size of `BaseArray` is not a multiple of the size of `Base`.
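///
/// For example, `[u16; 2]` satisfies the layout requirements above, so a slice of
/// pairs can be viewed as a flat slice of `u16`s:
///
/// ```
/// let arrs = [[1u16, 2], [3, 4]];
/// let flat: &[u16] = unsafe { p3_util::as_base_slice(&arrs) };
/// assert_eq!(flat, &[1, 2, 3, 4]);
/// ```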
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts(buf_ptr, n) }
}

/// Reinterpret a mutable slice of `BaseArray` elements as a mutable slice of `Base` elements.
///
/// This is useful to convert `&mut [F; N]` to `&mut [F]` or `&mut [A]` to `&mut [F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics (at compile time, via a `const` block) if the alignments of `Base` and
/// `BaseArray` differ, or if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    let buf_ptr = buf.as_mut_ptr().cast::<Base>();
    let n = buf.len() * d;
    unsafe { slice::from_raw_parts_mut(buf_ptr, n) }
}

/// Convert a vector of `BaseArray` elements to a vector of `Base` elements without any
/// reallocations.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can also
/// be used to safely convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field
/// or `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics (at compile time, via a `const` block) if the alignments of `Base` and
/// `BaseArray` differ, or if the size of `BaseArray` is not a multiple of the size of `Base`.
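///
/// For example, flattening `Vec<[u16; 2]>` into `Vec<u16>` (a pair of types satisfying
/// the layout requirements above):
///
/// ```
/// let v: Vec<[u16; 2]> = vec![[1, 2], [3, 4]];
/// let flat: Vec<u16> = unsafe { p3_util::flatten_to_base(v) };
/// assert_eq!(flat, vec![1, 2, 3, 4]);
/// ```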
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();
    // Prevent running `vec`'s destructor so we are in complete control
    // of the allocation.
    let mut values = ManuallyDrop::new(vec);

    // Each `BaseArray` is an array of `d` `Base` elements, so the length and capacity of
    // the new vector will be multiplied by `d`.
    let new_len = values.len() * d;
    let new_cap = values.capacity() * d;

    // Safe as BaseArray and Base have the same alignment.
    let ptr = values.as_mut_ptr() as *mut Base;

    unsafe {
        // Safety:
        // - BaseArray and Base have the same alignment.
        // - As size_of::<BaseArray>() == size_of::<Base>() * d:
        //      -- The capacity of the new vector is equal to the capacity of the old vector.
        //      -- The first new_len elements of the new vector correspond to the first
        //         len elements of the old vector and so are properly initialized.
        Vec::from_raw_parts(ptr, new_len, new_cap)
    }
}

/// Convert a vector of `Base` elements to a vector of `BaseArray` elements, ideally without any
/// reallocations.
///
/// This is the inverse of `flatten_to_base`. Unfortunately, unlike `flatten_to_base`, it may not be
/// possible to avoid allocations. The issue is that there is no way to guarantee that the capacity
/// of the vector is a multiple of `d`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Panics
///
/// This panics (at compile time, via a `const` block) if the alignments of `Base` and
/// `BaseArray` differ, or if the size of `BaseArray` is not a multiple of the size of `Base`.
/// It also panics at runtime if the length of the vector is not a multiple of the ratio of the sizes.
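///
/// For example, undoing the flattening shown for `flatten_to_base` (the length, here
/// `4`, must be a multiple of the size ratio, here `2`):
///
/// ```
/// let flat: Vec<u16> = vec![1, 2, 3, 4];
/// let arrs: Vec<[u16; 2]> = unsafe { p3_util::reconstitute_from_base(flat) };
/// assert_eq!(arrs, vec![[1, 2], [3, 4]]);
/// ```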
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // We could call vec.shrink_to_fit() here to try and increase the probability that
    // the capacity is a multiple of d. That might cause a reallocation though, which
    // would defeat the whole purpose.
    let cap = vec.capacity();

    // The assumption is that basically all callers of `reconstitute_from_base` will be calling it
    // with a vector constructed from `flatten_to_base`, so the capacity should be a multiple of `d`.
    // But capacities can do strange things so we need to support both possibilities.
    // Note that the `else` branch would also work if the capacity is a multiple of `d`, but it is slower.
    if cap.is_multiple_of(d) {
        // Prevent running `vec`'s destructor so we are in complete control
        // of the allocation.
        let mut values = ManuallyDrop::new(vec);

        // If we are on this branch then the capacity is a multiple of `d`.
        let new_cap = cap / d;

        // Safe as BaseArray and Base have the same alignment.
        let ptr = values.as_mut_ptr() as *mut BaseArray;

        unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //      -- If we have reached this point, the length and capacity are both divisible by `d`.
            //      -- The capacity of the new vector is equal to the capacity of the old vector.
            //      -- The first new_len elements of the new vector correspond to the first
            //         len elements of the old vector and so are properly initialized.
            Vec::from_raw_parts(ptr, new_len, new_cap)
        }
    } else {
        // If the capacity is not a multiple of `d`, we go via slices.

        let buf_ptr = vec.as_mut_ptr().cast::<BaseArray>();
        let slice = unsafe {
            // Safety:
            // - BaseArray and Base have the same alignment.
            // - As size_of::<Base>() == size_of::<BaseArray>() / d:
            //      -- If we have reached this point, the length is divisible by `d`.
            //      -- The first new_len elements of the slice correspond to the first
            //         len elements of the old slice and so are properly initialized.
            slice::from_raw_parts(buf_ptr, new_len)
        };

        // Ideally the compiler could optimize this away to avoid the copy, but it appears not to.
        slice.to_vec()
    }
}

/// Returns `true` if `u` and `v` are coprime, i.e. `gcd(u, v) == 1`, using a binary-GCD
/// style loop of shifts and subtractions.
///
/// Returns `false` if either argument is `0`.
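///
/// For example:
///
/// ```
/// assert!(p3_util::relatively_prime_u64(15, 28)); // gcd(15, 28) = 1
/// assert!(!p3_util::relatively_prime_u64(12, 18)); // gcd(12, 18) = 6
/// ```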
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Check that neither input is 0.
    if u == 0 || v == 0 {
        return false;
    }

    // Check divisibility by 2.
    if (u | v) & 1 == 0 {
        return false;
    }

    // Remove factors of 2 from `u` and `v`.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Ensure u <= v.
        if u > v {
            core::mem::swap(&mut u, &mut v);
        }

        // This looks inefficient for v >> u but, thanks to the fact that we remove the
        // trailing zeros of v in every iteration, it ends up much more performant than
        // a first glance implies.
        v -= u;
    }
    // If we made it through the loop, then at no point was u or v equal to 1, so the gcd
    // must be greater than 1.
    false
}

/// Inner loop of the deferred GCD algorithm.
///
/// See <https://eprint.iacr.org/2020/972.pdf> for more information.
///
/// This is basically a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. The key point is that this small loop only uses u64s, subtractions and
/// bit shifts, which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints, and the top `NUM_ROUNDS + 2` bits should match the top bits of the
/// big-ints, including zeroes if the original numbers have different sizes.
#[inline]
pub const fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Initialise update factors.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1
    let (mut f0, mut g0, mut f1, mut g1) = (1, 0, 0, 1);

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i
    // then, at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}
    // Use a manual `while` loop to enable `const`.
    let mut round = 0;
    while round < NUM_ROUNDS {
        if *a & 1 == 0 {
            *a >>= 1;
        } else {
            if *a < *b {
                core::mem::swap(a, b);
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            *a -= *b;
            *a >>= 1;
            f0 -= f1;
            g0 -= g1;
        }
        f1 <<= 1;
        g1 <<= 1;

        round += 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63}, in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}

/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
///  - a: The value we want to invert. It must be < P.
///  - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` congruent to `2^{2 * FIELD_BITS - 2} * a^{-1} mod P` with
///   size `|v| < 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions breaks, the output is undefined.
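///
/// As a small worked check (`P = 31`, `FIELD_BITS = 5`): `3^{-1} = 21 mod 31` and
/// `2^8 = 8 mod 31`, so the output must be congruent to `8 * 21 = 168 = 13 mod 31`:
///
/// ```
/// let v = p3_util::gcd_inversion_prime_field_32::<5>(3, 31);
/// assert_eq!(v.rem_euclid(31), 13);
/// ```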
#[inline]
pub const fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Initialise u, v. Note that |u|, |v| <= 2^0
    let (mut u, mut v) = (1_i64, 0_i64);

    // Let a0 and P denote the initial values of a and b, and write n = 2 * FIELD_BITS. Observe:
    // `a = u * a0 mod P`
    // `b = v * a0 mod P`
    // `len(a) + len(b) <= 2 * len(P) <= n`

    // Use a manual `while` loop to enable `const`.
    let mut i = 0;
    while i < 2 * FIELD_BITS - 2 {
        // Assume at the start of loop iteration i:
        // (1) `|u|, |v| <= 2^{i}`
        // (2) `2^i * a = u * a0 mod P`
        // (3) `2^i * b = v * a0 mod P`
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i, 1)`

        if a & 1 != 0 {
            if a < b {
                (a, b) = (b, a);
                (u, v) = (v, u);
            }
            // As b <= a, this subtraction cannot increase `len(a) + len(b)`.
            a -= b;
            // Observe |u'| = |u - v| <= |u| + |v| <= 2^{i + 1}
            u -= v;

            // As (1) and (2) hold, we have
            // `2^i * a' = 2^i * (a - b) = (u - v) * a0 mod P = u' * a0 mod P`
        }
        // As b is odd, a must now be even.
        // This reduces `len(a) + len(b)` by 1 (unless `a = 0`, in which case `b = 1` and the sum
        // of the lengths is always 1).
        a >>= 1;

        // Observe |v'| = 2|v| <= 2^{i + 1}
        v <<= 1;

        // Thus at the end of loop iteration i:
        // (1) `|u|, |v| <= 2^{i + 1}`
        // (2) `2^{i + 1} * a = u * a0 mod P`  (As we have halved a)
        // (3) `2^{i + 1} * b = v * a0 mod P`  (As we have doubled v)
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i - 1, 1)`

        i += 1;
    }

    // After the loop, we see that:
    // |u|, |v| <= 2^{2 * FIELD_BITS - 2}: Hence for FIELD_BITS <= 32 we will not overflow an i64.
    // `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    // `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // This implies that `b` must be `1`, and so `v = 2^{2 * FIELD_BITS - 2} * a0^{-1} mod P` as desired.
    v
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_reverse_bits_len() {
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }

    #[test]
    fn test_reverse_index_bits() {
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }

    #[test]
    fn test_apply_to_chunks_exact_fit() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_with_remainder() {
        const CHUNK_SIZE: usize = 3;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
    }

    #[test]
    fn test_apply_to_chunks_empty_input() {
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert!(results.is_empty());
    }

    #[test]
    fn test_apply_to_chunks_single_chunk() {
        const CHUNK_SIZE: usize = 10;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
    }

    #[test]
    fn test_apply_to_chunks_large_chunk_size() {
        const CHUNK_SIZE: usize = 100;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_large_input() {
        const CHUNK_SIZE: usize = 5;
        let input: Vec<u8> = (1..=20).collect();
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(
            results,
            vec![
                vec![1, 2, 3, 4, 5],
                vec![6, 7, 8, 9, 10],
                vec![11, 12, 13, 14, 15],
                vec![16, 17, 18, 19, 20]
            ]
        );
    }

    #[test]
    fn test_reverse_slice_index_bits_random() {
        let lengths = [32, 128, 1 << 16];
        let mut rng = SmallRng::seed_from_u64(1);
        for _ in 0..32 {
            for &length in &lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.random());
                let expect = reverse_index_bits_naive(&rand_list);

                let mut actual = rand_list.clone();
                reverse_slice_index_bits(&mut actual);

                assert_eq!(actual, expect);
            }
        }
    }

    #[test]
    fn test_log2_strict_usize_edge_cases() {
        assert_eq!(log2_strict_usize(1), 0);
        assert_eq!(log2_strict_usize(2), 1);
        assert_eq!(log2_strict_usize(1 << 18), 18);
        assert_eq!(log2_strict_usize(1 << 31), 31);
        assert_eq!(
            log2_strict_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_zero() {
        let _ = log2_strict_usize(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_nonpower_2() {
        let _ = log2_strict_usize(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        let _ = log2_strict_usize(usize::MAX);
    }

    #[test]
    fn test_log2_ceil_usize_comprehensive() {
        // Powers of 2
        assert_eq!(log2_ceil_usize(0), 0);
        assert_eq!(log2_ceil_usize(1), 0);
        assert_eq!(log2_ceil_usize(2), 1);
        assert_eq!(log2_ceil_usize(1 << 18), 18);
        assert_eq!(log2_ceil_usize(1 << 31), 31);
        assert_eq!(
            log2_ceil_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );

        // Nonpowers; want to round up
        assert_eq!(log2_ceil_usize(3), 2);
        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
        assert_eq!(
            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
    }

    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict_usize(n);

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
            out[dst] = Some(*v);
        }

        out.into_iter().map(|x| x.unwrap()).collect()
    }

    #[test]
    fn test_relatively_prime_u64() {
        // Zero cases (should always return false)
        assert!(!relatively_prime_u64(0, 0));
        assert!(!relatively_prime_u64(10, 0));
        assert!(!relatively_prime_u64(0, 10));
        assert!(!relatively_prime_u64(0, 123456789));

        // Number with itself (if greater than 1, not relatively prime)
        assert!(relatively_prime_u64(1, 1));
        assert!(!relatively_prime_u64(10, 10));
        assert!(!relatively_prime_u64(99999, 99999));

        // Powers of 2 (always false since they share factor 2)
        assert!(!relatively_prime_u64(2, 4));
        assert!(!relatively_prime_u64(16, 32));
        assert!(!relatively_prime_u64(64, 128));
        assert!(!relatively_prime_u64(1024, 4096));

        // One number is a multiple of the other (always false)
        assert!(!relatively_prime_u64(5, 10));
        assert!(!relatively_prime_u64(12, 36));
        assert!(!relatively_prime_u64(15, 45));
        assert!(!relatively_prime_u64(100, 500));

        // Co-prime numbers (should be true)
        assert!(relatively_prime_u64(17, 31));
        assert!(relatively_prime_u64(97, 43));
        assert!(relatively_prime_u64(7919, 65537));
        assert!(relatively_prime_u64(15485863, 32452843));

        // Small prime numbers (should be true)
        assert!(relatively_prime_u64(13, 17));
        assert!(relatively_prime_u64(101, 103));
        assert!(relatively_prime_u64(1009, 1013));

        // Large even numbers (never relatively prime as they share the factor 2)
        assert!(!relatively_prime_u64(
            190266297176832000,
            10430732356495263744
        ));
        assert!(!relatively_prime_u64(
            2040134905096275968,
            5701159354248194048
        ));
        assert!(!relatively_prime_u64(
            16611311494648745984,
            7514969329383038976
        ));
        assert!(!relatively_prime_u64(
            14863931409971066880,
            7911906750992527360
        ));

        // Max values
        assert!(relatively_prime_u64(u64::MAX, 1));
        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
    }
}