Skip to main content

p3_util/
lib.rs

1//! Various simple utilities.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::slice;
8use alloc::string::String;
9use alloc::vec::Vec;
10use core::any::type_name;
11use core::hint::unreachable_unchecked;
12use core::mem::{ManuallyDrop, MaybeUninit};
13use core::{iter, mem};
14
15use crate::transpose::transpose_in_place_square;
16
17pub mod array_serialization;
18pub mod linear_map;
19pub mod transpose;
20pub mod zip_eq;
21
/// Returns the smallest `k` such that `2^k >= n`, i.e. `ceil(log_2(n))`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    if n <= 1 {
        // Nothing (or a single element) needs zero bits to index.
        0
    } else {
        // For `n >= 2` the answer is one more than the index of the top bit of `n - 1`.
        (n - 1).ilog2() as usize + 1
    }
}
27
/// Returns the index of the highest set bit of `n`, i.e. `floor(log_2(n))`.
///
/// `floor(log2(0))` has no mathematical value; this function returns `0` for
/// `n == 0`, which matches `log2_ceil_usize(0) == 0`.
#[must_use]
pub const fn log2_floor_usize(n: usize) -> usize {
    match n {
        // Saturating convention for the undefined case.
        0 => 0,
        _ => n.ilog2() as usize,
    }
}
40
/// Returns `ceil(log_2(n))` for a `u64`, i.e. the smallest `k` with `2^k >= n`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    if n <= 1 {
        0
    } else {
        // One more than the index of the top bit of `n - 1`.
        (n - 1).ilog2() as u64 + 1
    }
}
45
/// Computes `2^log_degree`, or `None` when the result does not fit in a `usize`.
#[must_use]
pub const fn checked_pow2(log_degree: usize) -> Option<usize> {
    // A shift by `usize::BITS` or more cannot be represented.
    if log_degree >= usize::BITS as usize {
        return None;
    }
    Some(1usize << log_degree)
}
55
56/// Adds two log-sizes and computes the resulting power of two.
57///
58/// Returns:
59/// - `(a + b, 2^(a + b))` when the sum fits in a `usize` shift,
60/// - `None` if the addition overflows or the resulting power exceeds the representable range.
61#[must_use]
62pub const fn checked_log_size_sum(a: usize, b: usize) -> Option<(usize, usize)> {
63    match a.checked_add(b) {
64        Some(sum) => match checked_pow2(sum) {
65            Some(size) => Some((sum, size)),
66            None => None,
67        },
68        None => None,
69    }
70}
71
72/// Computes `log_2(n)`
73///
74/// # Panics
75/// Panics if `n` is not a power of two.
76#[must_use]
77#[inline]
78pub const fn log2_strict_usize(n: usize) -> usize {
79    let res = n.trailing_zeros();
80    assert!(n.wrapping_shr(res) == 1, "Not a power of two");
81    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
82    // `1 << res` and vice versa.
83    unsafe {
84        assume(n == 1 << res);
85    }
86    res as usize
87}
88
/// Table of every power of 3 representable in a `u64`: `POWERS_OF_3[k] == 3^k`.
///
/// The last entry is `3^40 = 12_157_665_459_056_928_801`.
///
/// The element type is `u64` rather than `usize` so that the table also compiles on
/// 32-bit targets, where `3^40` would not fit in a `usize`.
const POWERS_OF_3: [u64; 41] = {
    // Every slot starts at `3^0 = 1`; slot 0 is already final.
    let mut powers = [1u64; 41];

    // Each later entry is three times its predecessor.
    let mut exp = 1;
    while exp < powers.len() {
        powers[exp] = 3 * powers[exp - 1];
        exp += 1;
    }
    powers
};
108
109/// Maps a bit-position (i.e. `floor(log2(n))`) to the corresponding base-3 exponent.
110///
111/// Because `3^k` grows faster than `2^k`, every power of 3 has a unique highest set
112/// bit position. This lets us use `leading_zeros()` to jump straight to the answer
113/// in O(1) without any loop or binary search.
114///
115/// Entries that don't correspond to any power of 3 are unused (left as 0).
116const LOG2_TO_EXP: [u8; 64] = {
117    // Initialize every slot to 0.
118    let mut table = [0u8; 64];
119
120    // For each power of 3, record which log2 bucket it falls into.
121    let mut i = 0;
122    while i < 41 {
123        // Compute floor(log2(3^i)) via the highest set bit.
124        let log2 = (u64::BITS - 1 - POWERS_OF_3[i].leading_zeros()) as usize;
125
126        // Store the exponent i at the corresponding bit-position.
127        table[log2] = i as u8;
128        i += 1;
129    }
130    table
131};
132
133/// Computes the strict base-3 logarithm of `n`.
134///
135/// Returns `k` such that `3^k == n`. Panics if `n` is not a power of 3.
136///
137/// This is the base-3 analogue of [`log2_strict_usize`].
138///
139/// # Arguments
140///
141/// * `n` - A positive integer that must be a power of 3 (i.e., 1, 3, 9, 27, 81, ...).
142///
143/// # Returns
144///
145/// The exponent `k` where `3^k == n`.
146///
147/// # Panics
148///
149/// Panics if:
150/// - `n` is zero
151/// - `n` is not a power of 3
152#[must_use]
153#[inline]
154pub const fn log3_strict_usize(n: usize) -> usize {
155    // Zero has no logarithm - check explicitly for a clear error message.
156    assert!(n != 0, "log3_strict_usize: input must be non-zero");
157
158    // Instantly find the candidate exponent via the highest set bit.
159    //
160    // Because every power of 3 occupies a unique log2 bucket, this single
161    // lookup gives us the answer in O(1) with zero branches.
162    let log2 = (usize::BITS - 1 - n.leading_zeros()) as usize;
163    let res = LOG2_TO_EXP[log2] as usize;
164
165    // Verify the result: catches non-powers of 3 in a single O(1) check.
166    assert!(
167        POWERS_OF_3[res] as usize == n,
168        "log3_strict_usize: input is not a power of 3"
169    );
170
171    res
172}
173
/// Builds the identity index array `[0, 1, ..., N - 1]`.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    // Fill from the back; `for` loops are not available in `const fn`.
    let mut next = N;
    while next > 0 {
        next -= 1;
        out[next] = next;
    }
    out
}
185
/// Statically asserts that `T` implements [`Clone`].
///
/// The body is empty: the check is the `T: Clone` bound itself, so instantiating
/// this function with a type that is not `Clone` fails to compile.
pub const fn assert_clone<T: Clone>() {}
188
/// Statically asserts that `T` implements [`Send`].
///
/// The body is empty: the check is the `T: Send` bound itself, so instantiating
/// this function with a type that is not `Send` fails to compile.
pub const fn assert_send<T: Send>() {}
191
/// Statically asserts that `T` implements [`Sync`].
///
/// The body is empty: the check is the `T: Sync` bound itself, so instantiating
/// this function with a type that is not `Sync` fails to compile.
pub const fn assert_sync<T: Sync>() {}
194
195#[inline]
196pub const fn reverse_bits(x: usize, n: usize) -> usize {
197    // Assert that n is a power of 2
198    debug_assert!(n.is_power_of_two());
199    reverse_bits_len(x, n.trailing_zeros() as usize)
200}
201
/// Reverses the low `bit_len` bits of `x`.
///
/// `bit_len` must not exceed the word width; this is only checked in debug builds.
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // A `bit_len` wider than the word would underflow the shift amount below and give
    // a wrong, non-panicking permutation in release builds.
    debug_assert!(bit_len <= usize::BITS as usize);
    // Reverse the whole word, then bring the reversed bits back down to the bottom.
    // The shift must be a wrapping one because `bit_len == 0` asks for a shift by the
    // full word width, which a plain `>>` treats as an overflow even when `x == 0`.
    let discard = usize::BITS - bit_len as u32;
    x.reverse_bits().wrapping_shr(discard)
}
215
// Lookup table of 6-bit reverses: entry `i` is `i` with its low 6 bits reversed.
// The entries are written in octal so that each literal shows the two 3-bit halves.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
// Only compiled on non-aarch64 targets.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];
230
// Thresholds (in bytes) used by `reverse_slice_index_bits` to pick an algorithm.
// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE: the chunked path is only taken when the
// element is smaller than `BIG_T_SIZE` and the array is larger than `SMALL_ARR_SIZE`,
// and it relies on that implying at least 4 elements.
// Element size at or above which the trivial algorithm is always used.
const BIG_T_SIZE: usize = 1 << 14;
// Total array size at or below which the trivial algorithm is always used.
const SMALL_ARR_SIZE: usize = 1 << 16;
234
/// Permutes `vals` in place such that each index is mapped to its reverse in binary.
///
/// An empty slice is left untouched.
///
/// If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
/// `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
/// Otherwise a chunked algorithm (reverse chunks, transpose, reverse chunks) is used.
///
/// # Panics
///
/// Panics if `vals.len()` is non-zero and not a power of two (via `log2_strict_usize`).
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
    // `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
    // `size_of::<F>() << log_n` is the total size of the slice in bytes.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //     basically a series of large `memcpy`s.)
        //  2. Transpose the matrix.
        //  3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //  1. bit-reversing `i[lb_n / 2..lb_n]`,
        //  2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
        //  3. bit-reversing `i[lb_n / 2..lb_n]`.
        //
        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the
        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
        // perform _two_ transposes. We treat `vals` as two matrices, one where the middle bit of the
        // index is `0` and another, where the middle bit is `1`; we transpose each individually.

        // `lb_num_chunks = floor(log_n / 2)` and `lb_chunk_size = ceil(log_n / 2)`, so the two
        // always sum to `log_n` and differ by at most one.
        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        // SAFETY: `reverse_slice_index_bits_chunks` requires
        // `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`, which holds because the two
        // exponents sum to `log_n` and `n == 1 << log_n`.
        // NOTE(review): the contract of `transpose_in_place_square` is defined in
        // `crate::transpose` and is not visible here — confirm the arguments satisfy it.
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `vals` by `1 << lb_num_chunks` effectively, adding that to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}
296
297// Both functions below are semantically equivalent to:
298//     for i in 0..n {
299//         result.push(arr[reverse_bits(i, n_power)]);
300//     }
301// where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
302// to guide the compiler to generate optimal assembly.
303
304#[cfg(not(target_arch = "aarch64"))]
305fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
306    if lb_n <= 6 {
307        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
308        let dst_shr_amt = 6 - lb_n as u32;
309        for (src, &br) in BIT_REVERSE_6BIT.iter().enumerate().take(vals.len()) {
310            let dst = (br as usize).wrapping_shr(dst_shr_amt);
311            if src < dst {
312                vals.swap(src, dst);
313            }
314        }
315    } else {
316        // LLVM does not know that it does not need to reverse src at each iteration (which is
317        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
318        // bits of dst are dependent only on the low bits of src.
319        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
320        let dst_hi_shl_amt = lb_n - 6;
321        for src_chunk in 0..(vals.len() >> 6) {
322            let src_hi = src_chunk << 6;
323            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
324            for (src_lo, &br) in BIT_REVERSE_6BIT.iter().enumerate() {
325                let dst_hi = (br as usize) << dst_hi_shl_amt;
326                let src = src_hi + src_lo;
327                let dst = dst_hi + dst_lo;
328                if src < dst {
329                    vals.swap(src, dst);
330                }
331            }
332        }
333    }
334}
335
/// Bit-reverse permutation of `vals` (length `1 << lb_n`) for aarch64.
///
/// Aarch64 can reverse the bits of a general-purpose register in one instruction, so the
/// trivial per-index version works best. The gate is `target_arch = "aarch64"` alone:
/// the lookup-table variant is compiled for every non-aarch64 target, so requiring NEON
/// here would leave aarch64 targets without NEON with no definition at all.
///
/// Each pair `(src, dst)` of mutually reversed indices is swapped exactly once, when
/// `src < dst`.
#[cfg(target_arch = "aarch64")]
const fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Use a manual `while` loop to keep the function `const`.
    let mut src = 0;
    while src < vals.len() {
        // `wrapping_shr` covers `lb_n == 0`, where the shift amount is the full word width.
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            vals.swap(src, dst);
        }

        src += 1;
    }
}
350
/// Split `vals` into chunks and bit-reverse the order of the chunks. There are
/// `1 << lb_num_chunks` chunks, each of length `1 << lb_chunk_size`.
///
/// Chunk `i` is swapped with chunk `j`, where `j` is `i` with its low `lb_num_chunks`
/// bits reversed; each pair is swapped once (when `i < j`).
///
/// # Safety
///
/// The caller must ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    debug_assert_eq!(vals.len(), 1usize << (lb_num_chunks + lb_chunk_size));

    let chunk_len = 1usize << lb_chunk_size;
    // Derive both chunk pointers from a single pointer to the whole slice. A pointer obtained
    // from `get_unchecked_mut(i)` is only valid for that one element, and taking a second
    // `&mut` into `vals` would invalidate the first one.
    let base = vals.as_mut_ptr();
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            // SAFETY: `i, j < 1 << lb_num_chunks`, so by the caller's length guarantee both
            // chunks `[k << lb_chunk_size, (k + 1) << lb_chunk_size)` lie inside `vals`.
            // `i != j`, so the two chunks are disjoint.
            unsafe {
                core::ptr::swap_nonoverlapping(
                    base.add(i << lb_chunk_size),
                    base.add(j << lb_chunk_size),
                    chunk_len,
                );
            }
        }
    }
}
375
/// Tells the compiler that `p` holds, so it may optimize under that assumption.
///
/// In debug builds the predicate is additionally checked with `debug_assert!`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
#[inline(always)]
pub const unsafe fn assume(p: bool) {
    debug_assert!(p);
    if p {
        return;
    }
    // SAFETY: the caller guarantees `p` is true, so this point is never reached.
    unsafe { unreachable_unchecked() }
}
390
/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only.
///
/// On the architectures listed in the `cfg` below it emits an empty inline-assembly
/// statement; on every other architecture the body is empty.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    // SAFETY: the assembly template is empty, so it executes no instructions; the
    // `nomem`, `nostack` and `preserves_flags` options therefore hold trivially.
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}
421
/// Return a String containing the name of T but with all the crate
/// and module prefixes removed.
///
/// The name reported by [`type_name`] is scanned as a sequence of path runs (identifier
/// characters and `:`), each followed by at most one other character. Within each run
/// only the part after the last `::` is kept; every other character (`<`, `>`, `,`,
/// `(`, `)`, `&`, `[`, `]`, spaces, ...) is copied through unchanged. For example,
/// `alloc::vec::Vec<(alloc::string::String, u32)>` becomes `Vec<(String, u32)>`.
pub fn pretty_name<T>() -> String {
    let name = type_name::<T>();
    // The result is never longer than the fully qualified name.
    let mut result = String::with_capacity(name.len());

    // Characters that can be part of a `::`-separated path.
    let is_path_char = |c: char| c.is_alphanumeric() || c == '_' || c == ':';

    // Split after every non-path character so that punctuation in front of a qualified
    // path (e.g. `(` or `&`) ends up in its own piece and is not discarded with the prefix.
    for piece in name.split_inclusive(|c: char| !is_path_char(c)) {
        // `rsplit` yields the last `::`-separated segment first and always yields at
        // least one item, so the fallback is never used.
        let unqualified = piece.rsplit("::").next().unwrap_or(piece);
        result.push_str(unqualified);
    }
    result
}
432
/// Reads up to `BUFLEN` items from `iter` into a fixed-size buffer, in the spirit of
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Returns the buffer of `MaybeUninit<T>` together with the number of leading slots
/// that have been initialized. Reading stops at the first `None`; the iterator is not
/// polled again after that.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut filled = 0;

    for slot in buf.iter_mut() {
        // Stop as soon as the iterator is exhausted.
        let Some(item) = iter.next() else {
            break;
        };
        slot.write(item);
        filled += 1;
    }
    (buf, filled)
}
462
463/// Split an iterator into small arrays and apply `func` to each.
464///
465/// Repeatedly read `BUFLEN` elements from `input` into an array and
466/// pass the array to `func` as a slice. If less than `BUFLEN`
467/// elements are remaining, that smaller slice is passed to `func` (if
468/// it is non-empty) and the function returns.
469#[inline]
470pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
471where
472    I: IntoIterator<Item = u8>,
473    H: FnMut(&[I::Item]),
474{
475    let mut iter = input.into_iter();
476    loop {
477        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
478        if n == 0 {
479            break;
480        }
481        func(unsafe { buf.get_unchecked(..n).assume_init_ref() });
482    }
483}
484
485/// Pulls `N` items from `iter` and returns them as an array. If the iterator
486/// yields fewer than `N` items (but more than `0`), pads by the given default value.
487///
488/// Since the iterator is passed as a mutable reference and this function calls
489/// `next` at most `N` times, the iterator can still be used afterwards to
490/// retrieve the remaining items.
491///
492/// If `iter.next()` panics, all items already yielded by the iterator are
493/// dropped.
494#[inline]
495fn iter_next_chunk_padded<T: Copy, const N: usize>(
496    iter: &mut impl Iterator<Item = T>,
497    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
498) -> Option<[T; N]> {
499    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
500    (n != 0).then(|| {
501        // Fill the rest of the array with default values.
502        arr[n..].fill(MaybeUninit::new(default));
503        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
504    })
505}
506
507/// Returns an iterator over `N` elements of the iterator at a time.
508///
509/// The chunks do not overlap. If `N` does not divide the length of the
510/// iterator, then the last `N-1` elements will be padded with the given default value.
511///
512/// This is essentially a copy pasted version of the nightly `array_chunks` function.
513/// <https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks>
514/// Once that is stabilized this and the functions above it should be removed.
515#[inline]
516pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
517    iter: impl IntoIterator<Item = T>,
518    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
519) -> impl Iterator<Item = [T; N]> {
520    let mut iter = iter.into_iter();
521    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
522}
523
/// View a slice of `BaseArray` elements as a (longer) slice of `Base` elements.
///
/// This is useful to convert `&[[F; N]]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
/// The returned slice has `buf.len() * (size_of::<BaseArray>() / size_of::<Base>())` elements.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Instantiation fails to compile if the alignments of `Base` and `BaseArray` differ or
/// if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let flat_len = buf.len() * ratio;

    // SAFETY: alignment and size divisibility were checked above; the layout
    // equivalence is the caller's obligation.
    unsafe { slice::from_raw_parts(buf.as_ptr().cast::<Base>(), flat_len) }
}
552
/// View a mutable slice of `BaseArray` elements as a (longer) mutable slice of `Base` elements.
///
/// This is useful to convert `&mut [[F; N]]` to `&mut [F]` or `&mut [A]` to `&mut [F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
/// The returned slice has `buf.len() * (size_of::<BaseArray>() / size_of::<Base>())` elements.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Instantiation fails to compile if the alignments of `Base` and `BaseArray` differ or
/// if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let flat_ptr = buf.as_mut_ptr().cast::<Base>();
    let flat_len = buf.len() * ratio;

    // SAFETY: alignment and size divisibility were checked above; the layout
    // equivalence is the caller's obligation. `buf` is exclusively borrowed for the
    // lifetime of the returned slice.
    unsafe { slice::from_raw_parts_mut(flat_ptr, flat_len) }
}
581
/// Turn a vector of `BaseArray` elements into a vector of `Base` elements, reusing the
/// existing allocation.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can
/// also be used to convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field
/// or `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Instantiation fails to compile if the alignments of `Base` and `BaseArray` differ or
/// if the size of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();

    // Take over the allocation: the original vector must not free it on drop.
    let mut raw = ManuallyDrop::new(vec);

    // Every `BaseArray` expands to `ratio` elements, so length and capacity both scale.
    let (flat_len, flat_cap) = (raw.len() * ratio, raw.capacity() * ratio);
    let flat_ptr = raw.as_mut_ptr().cast::<Base>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - `size_of::<BaseArray>() == size_of::<Base>() * ratio`, so the allocation has
    //   the same size in bytes under the new capacity.
    // - The first `flat_len` elements cover exactly the initialized elements of the
    //   original vector.
    unsafe { Vec::from_raw_parts(flat_ptr, flat_len, flat_cap) }
}
630
/// Turn a vector of `Base` elements into a vector of `BaseArray` elements, ideally
/// reusing the existing allocation.
///
/// This is the inverse of `flatten_to_base`. Unlike `flatten_to_base`, it may not be
/// possible to avoid a new allocation: there is no way to guarantee that the capacity
/// of the incoming vector is a multiple of `d = size_of::<BaseArray>() / size_of::<Base>()`.
/// When it is not, the elements are cloned into a fresh vector.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and the alignment of
/// the array is the same as the alignment of its elements, this means that `BaseArray`
/// must have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Instantiation fails to compile if the alignments of `Base` and `BaseArray` differ or
/// if the size of `BaseArray` is not a multiple of the size of `Base`.
///
/// # Panics
///
/// This panics if the length of the vector is not a multiple of the ratio of the sizes.
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();

    let flat_len = vec.len();
    assert!(
        flat_len.is_multiple_of(ratio),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        flat_len,
        ratio
    );
    let packed_len = flat_len / ratio;

    // We could call `vec.shrink_to_fit()` here to make it more likely that the capacity
    // is a multiple of `ratio`, but that might reallocate, which would defeat the purpose.
    let flat_cap = vec.capacity();

    // Most callers are expected to pass a vector produced by `flatten_to_base`, whose
    // capacity is a multiple of `ratio`. Capacities can do strange things though, so the
    // other case is handled too. The copying path would also be correct for a divisible
    // capacity, but it is slower.
    if !flat_cap.is_multiple_of(ratio) {
        // The allocation cannot be described with a whole number of `BaseArray`
        // elements, so go via a slice and copy.
        let packed_ptr = vec.as_mut_ptr().cast::<BaseArray>();
        // SAFETY:
        // - `Base` and `BaseArray` have the same alignment (checked above).
        // - The length is divisible by `ratio`, so the first `packed_len` elements of
        //   the view cover exactly the initialized elements of `vec`.
        let view = unsafe { slice::from_raw_parts(packed_ptr, packed_len) };

        // Ideally the compiler would optimize the copy away, but it appears not to.
        return view.to_vec();
    }

    // Take over the allocation: the original vector must not free it on drop.
    let mut raw = ManuallyDrop::new(vec);
    let packed_cap = flat_cap / ratio;
    let packed_ptr = raw.as_mut_ptr() as *mut BaseArray;

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - Length and capacity are both divisible by `ratio`, so the allocation has the
    //   same size in bytes under the new capacity.
    // - The first `packed_len` elements cover exactly the initialized elements of the
    //   original vector.
    unsafe { Vec::from_raw_parts(packed_ptr, packed_len, packed_cap) }
}
715
/// Returns `true` if and only if `gcd(u, v) == 1`, using a binary GCD that only needs
/// shifts and subtractions.
///
/// Returns `false` if either input is `0`.
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Zero inputs are rejected outright, and two even numbers share the factor 2.
    let both_even = (u | v) & 1 == 0;
    if u == 0 || v == 0 || both_even {
        return false;
    }

    // At most one of the inputs is even, so powers of two can be dropped from either
    // without changing whether the gcd is 1. Start with `u`.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    // Invariant at the top of each iteration: `u` is odd, `u != 1` and `v != 0`.
    loop {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Keep `u <= v` so that the subtraction below cannot underflow.
        if u > v {
            (u, v) = (v, u);
        }

        // This looks inefficient for `v` much larger than `u`, but since the trailing
        // zeros of `v` are removed in every iteration it performs much better than a
        // first glance suggests.
        v -= u;

        // `v` reached zero without either value ever being 1, so the gcd (now held in
        // `u`) is greater than 1.
        if v == 0 {
            return false;
        }
    }
}
754
/// Inner loop of the deferred GCD algorithm.
///
/// See: <https://eprint.iacr.org/2020/972.pdf> for more information.
///
/// This runs `NUM_ROUNDS` steps of a small binary GCD on `a` and `b` (updating both in
/// place) and accumulates the linear transformation `(f0, g0, f1, g1)` that the main loop
/// then applies to the full-size numbers. Only `u64` subtractions and bit shifts are used,
/// which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` should match the top bits including
/// zeroes if the original numbers have different sizes.
#[inline]
pub const fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Update factors, starting from the identity transformation.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1
    let mut f0: i64 = 1;
    let mut g0: i64 = 0;
    let mut f1: i64 = 0;
    let mut g1: i64 = 1;

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i
    // then at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}
    // A manual countdown is used because `for` is not available in `const fn`.
    let mut remaining = NUM_ROUNDS;
    while remaining > 0 {
        let a_is_odd = *a & 1 != 0;

        if a_is_odd {
            // Make `a` the larger value, exchanging the factor rows to match.
            if *a < *b {
                let old_a = *a;
                *a = *b;
                *b = old_a;
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            // Odd minus odd is even, so the halving below is exact.
            *a -= *b;
        }
        *a >>= 1;

        if a_is_odd {
            f0 -= f1;
            g0 -= g1;
        }
        // Halving `a` is compensated by doubling the second row.
        f1 <<= 1;
        g1 <<= 1;

        remaining -= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, there is no overflow.
    // Additionally, if NUM_ROUNDS <= 63, the only possible overflow is a variable that
    // is meant to equal 2^{63}, which wraps to -2^{63}.
    (f0, g0, f1, g1)
}
802
/// Computes a scaled inverse in the prime field `F_P`, where `P < 2^FIELD_BITS`.
///
/// Arguments:
///  - a: The element to invert. It must satisfy `a < P`.
///  - b: The prime `P`, which must be greater than `2`.
///
/// Output:
/// - A signed `64-bit` integer `v` with `v = 2^{2 * FIELD_BITS - 2} a^{-1} mod P` and
///   `|v| <= 2^{2 * FIELD_BITS - 2}` (the bound is attained for `a = 1`).
///
/// The caller is responsible for passing an odd prime `b` of at most `FIELD_BITS` bits
/// together with some `a < b`. The output is undefined if either requirement is violated.
#[inline]
pub const fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Write a0 and P for the inputs `a` and `b`. The coefficients start with
    // |u|, |v| <= 2^0 and satisfy
    //   `a = u * a0 mod P`
    //   `b = v * a0 mod P`
    // while `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`.
    let mut u: i64 = 1;
    let mut v: i64 = 0;

    // Invariants when round i (counting up from 0) begins:
    // (1) `|u|, |v| <= 2^i`
    // (2) `2^i * a = u * a0 mod P`
    // (3) `2^i * b = v * a0 mod P`
    // (4) `gcd(a, b) = 1`
    // (5) `b` is odd
    // (6) `len(a) + len(b) <= max(2 * FIELD_BITS - i, 1)`
    //
    // A counting `while` loop is used because `for` is not available in `const fn`.
    let mut rounds_left = 2 * FIELD_BITS - 2;
    while rounds_left > 0 {
        if a & 1 == 1 {
            if a < b {
                // Keep the larger value in `a`; the coefficients travel with their values.
                let smaller = a;
                a = b;
                b = smaller;

                let coefficient = u;
                u = v;
                v = coefficient;
            }
            // Since `b <= a`, subtracting cannot increase `len(a) + len(b)`.
            a -= b;
            // |u - v| <= |u| + |v| <= 2^{i + 1}, and
            // `2^i * (a - b) = (u - v) * a0 mod P`, so invariant (2) still holds.
            u -= v;
        }

        // `b` is odd, so `a` is even at this point. Halving it removes one bit from
        // `len(a) + len(b)`, except when `a = 0`, where `b = 1` and the sum stays at 1.
        a >>= 1;

        // Doubling `v` keeps |v| <= 2^{i + 1} and restores invariants (2) and (3)
        // with the exponent i + 1.
        v <<= 1;

        rounds_left -= 1;
    }

    // After `2 * FIELD_BITS - 2` rounds:
    // - |u|, |v| <= 2^{2 * FIELD_BITS - 2}, which fits in an i64 because FIELD_BITS <= 32;
    // - `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`;
    // - `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd, which forces `b = 1`.
    // Therefore `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P`.
    v
}
879
/// A raw mutable pointer wrapper that implements [`Send`] and [`Sync`].
///
/// Used to enable parallel writes to disjoint slices of a pre-allocated buffer
/// from within closures that require `Send + Sync` (e.g. `rayon::ParallelIterator::for_each_init`).
///
/// The wrapper is [`Copy`] for every `T`: copying it duplicates only the pointer,
/// never the pointed-to elements.
///
/// # Safety
///
/// The caller must ensure that concurrent accesses through this pointer always
/// target **non-overlapping** memory regions.
pub struct DisjointMutPtr<T>(*mut T);

// `Clone` and `Copy` are written by hand instead of derived: `#[derive(Clone, Copy)]`
// would add `T: Clone` / `T: Copy` bounds, making the wrapper move-only for
// non-`Copy` element types although it only ever holds a raw pointer.
impl<T> Clone for DisjointMutPtr<T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for DisjointMutPtr<T> {}

// SAFETY: The contract of DisjointMutPtr guarantees that each thread writes to
// a disjoint region, so moving the pointer to another thread is safe.
unsafe impl<T> Send for DisjointMutPtr<T> {}
// SAFETY: Same contract as for `Send`: all accesses made through shared copies of
// the pointer target disjoint regions.
unsafe impl<T> Sync for DisjointMutPtr<T> {}

impl<T> DisjointMutPtr<T> {
    /// Create a new `DisjointMutPtr` pointing at the first element of `slice`.
    ///
    /// Only the start pointer is stored; neither the length nor the lifetime of
    /// `slice` is tracked afterwards.
    #[inline]
    pub const fn new(slice: &mut [T]) -> Self {
        Self(slice.as_mut_ptr())
    }

    /// Get a mutable slice starting at `offset` with `len` elements.
    ///
    /// # Safety
    ///
    /// The caller must ensure the range `[offset, offset+len)` is within bounds
    /// and does not overlap with any other concurrent access.
    ///
    /// The returned slice is typed `'static`, but that lifetime is not checked: the
    /// caller must also ensure it is not used after the underlying buffer is gone.
    #[inline]
    pub const unsafe fn slice_mut(self, offset: usize, len: usize) -> &'static mut [T] {
        // SAFETY: the caller guarantees that `[offset, offset + len)` lies inside the
        // buffer this pointer was created from and that nothing else accesses it.
        unsafe { core::slice::from_raw_parts_mut(self.0.add(offset), len) }
    }
}
915
916#[cfg(test)]
917mod tests {
918    use alloc::vec;
919    use alloc::vec::Vec;
920
921    use proptest::prelude::*;
922    use rand::rngs::SmallRng;
923    use rand::{RngExt, SeedableRng};
924
925    use super::*;
926
927    #[test]
928    fn test_reverse_bits_len() {
929        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
930        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
931        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
932        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
933        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
934    }
935
936    #[test]
937    fn test_reverse_bits_len_full_width() {
938        // A full-width reversal is the largest valid bit length and must reverse every bit.
939        let bits = usize::BITS as usize;
940        assert_eq!(reverse_bits_len(1, bits), 1 << (bits - 1));
941        assert_eq!(reverse_bits_len(1 << (bits - 1), bits), 1);
942    }
943
944    #[test]
945    #[cfg(debug_assertions)]
946    #[should_panic(expected = "bit_len <= usize::BITS")]
947    fn test_reverse_bits_len_rejects_oversized_bit_len() {
948        // One bit past the word width: the shift would underflow into a wrong permutation.
949        // The expected message pins the guard, not the incidental subtraction-overflow panic.
950        let _ = reverse_bits_len(0, usize::BITS as usize + 1);
951    }
952
953    #[test]
954    fn test_reverse_index_bits() {
955        let mut arg = vec![10, 20, 30, 40];
956        reverse_slice_index_bits(&mut arg);
957        assert_eq!(arg, vec![10, 30, 20, 40]);
958
959        let mut input256: Vec<u64> = (0..256).collect();
960        #[rustfmt::skip]
961        let output256: Vec<u64> = vec![
962            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
963            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
964            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
965            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
966            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
967            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
968            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
969            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
970            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
971            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
972            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
973            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
974            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
975            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
976            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
977            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
978        ];
979        reverse_slice_index_bits(&mut input256[..]);
980        assert_eq!(input256, output256);
981    }
982
983    #[test]
984    fn test_apply_to_chunks_exact_fit() {
985        const CHUNK_SIZE: usize = 4;
986        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
987        let mut results: Vec<Vec<u8>> = Vec::new();
988
989        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
990            results.push(chunk.to_vec());
991        });
992
993        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
994    }
995
996    #[test]
997    fn test_apply_to_chunks_with_remainder() {
998        const CHUNK_SIZE: usize = 3;
999        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
1000        let mut results: Vec<Vec<u8>> = Vec::new();
1001
1002        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
1003            results.push(chunk.to_vec());
1004        });
1005
1006        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
1007    }
1008
1009    #[test]
1010    fn test_apply_to_chunks_empty_input() {
1011        const CHUNK_SIZE: usize = 4;
1012        let input: Vec<u8> = vec![];
1013        let mut results: Vec<Vec<u8>> = Vec::new();
1014
1015        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
1016            results.push(chunk.to_vec());
1017        });
1018
1019        assert!(results.is_empty());
1020    }
1021
1022    #[test]
1023    fn test_apply_to_chunks_single_chunk() {
1024        const CHUNK_SIZE: usize = 10;
1025        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
1026        let mut results: Vec<Vec<u8>> = Vec::new();
1027
1028        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
1029            results.push(chunk.to_vec());
1030        });
1031
1032        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
1033    }
1034
1035    #[test]
1036    fn test_apply_to_chunks_large_chunk_size() {
1037        const CHUNK_SIZE: usize = 100;
1038        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
1039        let mut results: Vec<Vec<u8>> = Vec::new();
1040
1041        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
1042            results.push(chunk.to_vec());
1043        });
1044
1045        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
1046    }
1047
1048    #[test]
1049    fn test_apply_to_chunks_large_input() {
1050        const CHUNK_SIZE: usize = 5;
1051        let input: Vec<u8> = (1..=20).collect();
1052        let mut results: Vec<Vec<u8>> = Vec::new();
1053
1054        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
1055            results.push(chunk.to_vec());
1056        });
1057
1058        assert_eq!(
1059            results,
1060            vec![
1061                vec![1, 2, 3, 4, 5],
1062                vec![6, 7, 8, 9, 10],
1063                vec![11, 12, 13, 14, 15],
1064                vec![16, 17, 18, 19, 20]
1065            ]
1066        );
1067    }
1068
1069    #[test]
1070    fn test_reverse_slice_index_bits_random() {
1071        let lengths = [32, 128, 1 << 16];
1072        let mut rng = SmallRng::seed_from_u64(1);
1073        for _ in 0..32 {
1074            for &length in &lengths {
1075                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
1076                rand_list.resize_with(length, || rng.random());
1077                let expect = reverse_index_bits_naive(&rand_list);
1078
1079                let mut actual = rand_list.clone();
1080                reverse_slice_index_bits(&mut actual);
1081
1082                assert_eq!(actual, expect);
1083            }
1084        }
1085    }
1086
1087    #[test]
1088    fn test_log2_strict_usize_edge_cases() {
1089        assert_eq!(log2_strict_usize(1), 0);
1090        assert_eq!(log2_strict_usize(2), 1);
1091        assert_eq!(log2_strict_usize(1 << 18), 18);
1092        assert_eq!(log2_strict_usize(1 << 31), 31);
1093        assert_eq!(
1094            log2_strict_usize(1 << (usize::BITS - 1)),
1095            usize::BITS as usize - 1
1096        );
1097    }
1098
1099    #[test]
1100    fn test_checked_pow2() {
1101        // 2^0 = 1, the smallest valid exponent.
1102        assert_eq!(checked_pow2(0), Some(1));
1103
1104        // 2^1 = 2.
1105        assert_eq!(checked_pow2(1), Some(2));
1106
1107        // 2^5 = 32, a typical small power.
1108        assert_eq!(checked_pow2(5), Some(32));
1109
1110        // 2^10 = 1024, commonly used as a domain size in FRI.
1111        assert_eq!(checked_pow2(10), Some(1024));
1112
1113        // 2^20 = 1_048_576, a realistic large trace length.
1114        assert_eq!(checked_pow2(20), Some(1_048_576));
1115
1116        // Largest representable power: 2^(BITS - 1).
1117        // On a 64-bit platform this is 2^63 = 0x8000_0000_0000_0000.
1118        let max_exp = usize::BITS as usize - 1;
1119        assert_eq!(checked_pow2(max_exp), Some(1usize << max_exp));
1120
1121        // Exponent equal to the bit width would shift 1 out of range.
1122        //
1123        //     1_usize << 64  (on 64-bit)  →  overflow
1124        //
1125        // Must return `None`.
1126        assert_eq!(checked_pow2(usize::BITS as usize), None);
1127
1128        // One past the maximum: also out of range.
1129        assert_eq!(checked_pow2(usize::BITS as usize + 1), None);
1130
1131        // Extreme exponent: usize::MAX is astronomically beyond
1132        // representable range — must return `None`.
1133        assert_eq!(checked_pow2(usize::MAX), None);
1134    }
1135
1136    #[test]
1137    fn test_checked_log_size_sum() {
1138        // Both zero: 0 + 0 = 0, 2^0 = 1.
1139        assert_eq!(checked_log_size_sum(0, 0), Some((0, 1)));
1140
1141        // Identity cases: adding zero to either side is a no-op.
1142        assert_eq!(checked_log_size_sum(5, 0), Some((5, 32)));
1143        assert_eq!(checked_log_size_sum(0, 10), Some((10, 1024)));
1144
1145        // Typical FRI scenario: degree_bits=10, log_quotient_chunks=2.
1146        //
1147        //     10 + 2 = 12,  2^12 = 4096
1148        assert_eq!(checked_log_size_sum(10, 2), Some((12, 4096)));
1149
1150        // Commutativity: order of operands must not matter.
1151        assert_eq!(checked_log_size_sum(2, 10), Some((12, 4096)));
1152
1153        // Large realistic case: degree_bits=20, log_chunks=3.
1154        //
1155        //     20 + 3 = 23,  2^23 = 8_388_608
1156        assert_eq!(checked_log_size_sum(20, 3), Some((23, 8_388_608)));
1157
1158        // Largest representable sum: (BITS - 2) + 1 = BITS - 1.
1159        let almost_max = usize::BITS as usize - 2;
1160        let max_exp = usize::BITS as usize - 1;
1161        assert_eq!(
1162            checked_log_size_sum(almost_max, 1),
1163            Some((max_exp, 1usize << max_exp))
1164        );
1165
1166        // Sum exactly at the bit width: overflows the shift.
1167        //
1168        //     (BITS - 1) + 1 = BITS  →  2^BITS is unrepresentable  →  None
1169        assert_eq!(checked_log_size_sum(max_exp, 1), None);
1170
1171        // Both operands large but sum still within range.
1172        //
1173        //     32 + 31 = 63  (on 64-bit)  →  2^63 is representable
1174        let half = usize::BITS as usize / 2;
1175        let other_half = max_exp - half;
1176        assert_eq!(
1177            checked_log_size_sum(half, other_half),
1178            Some((max_exp, 1usize << max_exp))
1179        );
1180
1181        // Addition itself overflows usize, not just the shift.
1182        //
1183        //     usize::MAX + 1  →  checked_add returns None  →  None
1184        assert_eq!(checked_log_size_sum(usize::MAX, 1), None);
1185
1186        // Both operands at usize::MAX: addition doubly overflows.
1187        assert_eq!(checked_log_size_sum(usize::MAX, usize::MAX), None);
1188    }
1189
1190    #[test]
1191    #[should_panic]
1192    fn test_log2_strict_usize_zero() {
1193        let _ = log2_strict_usize(0);
1194    }
1195
1196    #[test]
1197    #[should_panic]
1198    fn test_log2_strict_usize_nonpower_2() {
1199        let _ = log2_strict_usize(0x78c341c65ae6d262);
1200    }
1201
1202    #[test]
1203    #[should_panic]
1204    fn test_log2_strict_usize_max() {
1205        let _ = log2_strict_usize(usize::MAX);
1206    }
1207
1208    #[test]
1209    fn test_log3_strict_powers_of_3() {
1210        // Test all powers of 3 up to 3^12 = 531441.
1211        assert_eq!(log3_strict_usize(1), 0);
1212        assert_eq!(log3_strict_usize(3), 1);
1213        assert_eq!(log3_strict_usize(9), 2);
1214        assert_eq!(log3_strict_usize(27), 3);
1215        assert_eq!(log3_strict_usize(81), 4);
1216        assert_eq!(log3_strict_usize(243), 5);
1217        assert_eq!(log3_strict_usize(729), 6);
1218        assert_eq!(log3_strict_usize(2187), 7);
1219        assert_eq!(log3_strict_usize(6561), 8);
1220        assert_eq!(log3_strict_usize(19683), 9);
1221        assert_eq!(log3_strict_usize(59049), 10);
1222        assert_eq!(log3_strict_usize(177_147), 11);
1223        assert_eq!(log3_strict_usize(531_441), 12);
1224    }
1225
1226    #[test]
1227    #[should_panic(expected = "input must be non-zero")]
1228    fn test_log3_strict_panics_on_zero() {
1229        let _ = log3_strict_usize(0);
1230    }
1231
1232    #[test]
1233    #[should_panic(expected = "is not a power of 3")]
1234    fn test_log3_strict_panics_on_non_power_of_3() {
1235        // 2 is not a power of 3.
1236        let _ = log3_strict_usize(2);
1237    }
1238
1239    #[test]
1240    #[should_panic(expected = "is not a power of 3")]
1241    fn test_log3_strict_panics_on_power_of_2() {
1242        // 8 = 2^3 is not a power of 3.
1243        let _ = log3_strict_usize(8);
1244    }
1245
1246    #[test]
1247    #[should_panic(expected = "is not a power of 3")]
1248    fn test_log3_strict_panics_on_product_with_other_primes() {
1249        // 6 = 2 * 3 is not a power of 3.
1250        let _ = log3_strict_usize(6);
1251    }
1252
1253    proptest! {
1254        #[test]
1255        fn test_log3_strict_roundtrip(k in 0u32..25u32) {
1256            // Roundtrip: 3^k -> log3_strict_usize -> k
1257            let n = 3usize.pow(k);
1258            assert_eq!(log3_strict_usize(n), k as usize);
1259        }
1260    }
1261
1262    #[test]
1263    fn test_log2_ceil_usize_comprehensive() {
1264        // Powers of 2
1265        assert_eq!(log2_ceil_usize(0), 0);
1266        assert_eq!(log2_ceil_usize(1), 0);
1267        assert_eq!(log2_ceil_usize(2), 1);
1268        assert_eq!(log2_ceil_usize(1 << 18), 18);
1269        assert_eq!(log2_ceil_usize(1 << 31), 31);
1270        assert_eq!(
1271            log2_ceil_usize(1 << (usize::BITS - 1)),
1272            usize::BITS as usize - 1
1273        );
1274
1275        // Nonpowers; want to round up
1276        assert_eq!(log2_ceil_usize(3), 2);
1277        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
1278        assert_eq!(
1279            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
1280            usize::BITS as usize
1281        );
1282        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
1283        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
1284    }
1285
1286    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
1287        let n = arr.len();
1288        let n_power = log2_strict_usize(n);
1289
1290        let mut out = vec![None; n];
1291        for (i, v) in arr.iter().enumerate() {
1292            let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
1293            out[dst] = Some(*v);
1294        }
1295
1296        out.into_iter().map(|x| x.unwrap()).collect()
1297    }
1298
1299    #[test]
1300    fn test_relatively_prime_u64() {
1301        // Zero cases (should always return false)
1302        assert!(!relatively_prime_u64(0, 0));
1303        assert!(!relatively_prime_u64(10, 0));
1304        assert!(!relatively_prime_u64(0, 10));
1305        assert!(!relatively_prime_u64(0, 123456789));
1306
1307        // Number with itself (if greater than 1, not relatively prime)
1308        assert!(relatively_prime_u64(1, 1));
1309        assert!(!relatively_prime_u64(10, 10));
1310        assert!(!relatively_prime_u64(99999, 99999));
1311
1312        // Powers of 2 (always false since they share factor 2)
1313        assert!(!relatively_prime_u64(2, 4));
1314        assert!(!relatively_prime_u64(16, 32));
1315        assert!(!relatively_prime_u64(64, 128));
1316        assert!(!relatively_prime_u64(1024, 4096));
1317        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
1318
1319        // One number is a multiple of the other (always false)
1320        assert!(!relatively_prime_u64(5, 10));
1321        assert!(!relatively_prime_u64(12, 36));
1322        assert!(!relatively_prime_u64(15, 45));
1323        assert!(!relatively_prime_u64(100, 500));
1324
1325        // Co-prime numbers (should be true)
1326        assert!(relatively_prime_u64(17, 31));
1327        assert!(relatively_prime_u64(97, 43));
1328        assert!(relatively_prime_u64(7919, 65537));
1329        assert!(relatively_prime_u64(15485863, 32452843));
1330
1331        // Small prime numbers (should be true)
1332        assert!(relatively_prime_u64(13, 17));
1333        assert!(relatively_prime_u64(101, 103));
1334        assert!(relatively_prime_u64(1009, 1013));
1335
1336        // Large numbers (some cases where they are relatively prime or not)
1337        assert!(!relatively_prime_u64(
1338            190266297176832000,
1339            10430732356495263744
1340        ));
1341        assert!(!relatively_prime_u64(
1342            2040134905096275968,
1343            5701159354248194048
1344        ));
1345        assert!(!relatively_prime_u64(
1346            16611311494648745984,
1347            7514969329383038976
1348        ));
1349        assert!(!relatively_prime_u64(
1350            14863931409971066880,
1351            7911906750992527360
1352        ));
1353
1354        // Max values
1355        assert!(relatively_prime_u64(u64::MAX, 1));
1356        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
1357        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
1358    }
1359}