p3_util/lib.rs
1//! Various simple utilities.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::slice;
8use alloc::string::String;
9use alloc::vec::Vec;
10use core::any::type_name;
11use core::hint::unreachable_unchecked;
12use core::mem::{ManuallyDrop, MaybeUninit};
13use core::{iter, mem};
14
15use crate::transpose::transpose_in_place_square;
16
17pub mod array_serialization;
18pub mod linear_map;
19pub mod transpose;
20pub mod zip_eq;
21
/// Returns `ceil(log_2(n))`, i.e. the smallest `k` with `2^k >= n`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    // The bit length of `n - 1` is the number of doublings needed to reach `n`.
    // `saturating_sub` keeps `n == 0` from wrapping around.
    let unused_high_bits = n.saturating_sub(1).leading_zeros();
    (usize::BITS - unused_high_bits) as usize
}
27
/// Returns `floor(log_2(n))`, the index of the highest set bit of `n`.
///
/// `floor(log2(0))` is mathematically undefined; this function returns `0` for
/// `n == 0`, mirroring `log2_ceil_usize(0) == 0`.
#[must_use]
pub const fn log2_floor_usize(n: usize) -> usize {
    match n {
        0 => 0,
        _ => (usize::BITS - 1 - n.leading_zeros()) as usize,
    }
}
40
/// Returns `ceil(log_2(n))` for a `u64`, i.e. the smallest `k` with `2^k >= n`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    // Same construction as `log2_ceil_usize`: the bit length of `n - 1`.
    let unused_high_bits = n.saturating_sub(1).leading_zeros();
    (u64::BITS - unused_high_bits) as u64
}
45
/// Computes `2^log_degree`, or `None` when the result does not fit in a `usize`.
#[must_use]
pub const fn checked_pow2(log_degree: usize) -> Option<usize> {
    // A shift by `usize::BITS` or more cannot be represented.
    if log_degree >= usize::BITS as usize {
        return None;
    }
    Some(1usize << log_degree)
}
55
56/// Adds two log-sizes and computes the resulting power of two.
57///
58/// Returns:
59/// - `(a + b, 2^(a + b))` when the sum fits in a `usize` shift,
60/// - `None` if the addition overflows or the resulting power exceeds the representable range.
61#[must_use]
62pub const fn checked_log_size_sum(a: usize, b: usize) -> Option<(usize, usize)> {
63 match a.checked_add(b) {
64 Some(sum) => match checked_pow2(sum) {
65 Some(size) => Some((sum, size)),
66 None => None,
67 },
68 None => None,
69 }
70}
71
72/// Computes `log_2(n)`
73///
74/// # Panics
75/// Panics if `n` is not a power of two.
76#[must_use]
77#[inline]
78pub const fn log2_strict_usize(n: usize) -> usize {
79 let res = n.trailing_zeros();
80 assert!(n.wrapping_shr(res) == 1, "Not a power of two");
81 // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
82 // `1 << res` and vice versa.
83 unsafe {
84 assume(n == 1 << res);
85 }
86 res as usize
87}
88
/// Table of every power of 3 that fits in a `u64`: entry `k` holds `3^k`.
///
/// The largest entry is `3^40 = 12_157_665_459_056_928_801`.
///
/// The element type is `u64` rather than `usize` so the table also builds on 32-bit
/// targets, where `3^40` would not fit in a `usize`.
const POWERS_OF_3: [u64; 41] = {
    // Every slot starts as `3^0 = 1`; slots `1..` are then overwritten in order.
    let mut powers = [1u64; 41];
    let mut exp = 1;
    while exp < powers.len() {
        powers[exp] = 3 * powers[exp - 1];
        exp += 1;
    }
    powers
};
108
109/// Maps a bit-position (i.e. `floor(log2(n))`) to the corresponding base-3 exponent.
110///
111/// Because `3^k` grows faster than `2^k`, every power of 3 has a unique highest set
112/// bit position. This lets us use `leading_zeros()` to jump straight to the answer
113/// in O(1) without any loop or binary search.
114///
115/// Entries that don't correspond to any power of 3 are unused (left as 0).
116const LOG2_TO_EXP: [u8; 64] = {
117 // Initialize every slot to 0.
118 let mut table = [0u8; 64];
119
120 // For each power of 3, record which log2 bucket it falls into.
121 let mut i = 0;
122 while i < 41 {
123 // Compute floor(log2(3^i)) via the highest set bit.
124 let log2 = (u64::BITS - 1 - POWERS_OF_3[i].leading_zeros()) as usize;
125
126 // Store the exponent i at the corresponding bit-position.
127 table[log2] = i as u8;
128 i += 1;
129 }
130 table
131};
132
133/// Computes the strict base-3 logarithm of `n`.
134///
135/// Returns `k` such that `3^k == n`. Panics if `n` is not a power of 3.
136///
137/// This is the base-3 analogue of [`log2_strict_usize`].
138///
139/// # Arguments
140///
141/// * `n` - A positive integer that must be a power of 3 (i.e., 1, 3, 9, 27, 81, ...).
142///
143/// # Returns
144///
145/// The exponent `k` where `3^k == n`.
146///
147/// # Panics
148///
149/// Panics if:
150/// - `n` is zero
151/// - `n` is not a power of 3
152#[must_use]
153#[inline]
154pub const fn log3_strict_usize(n: usize) -> usize {
155 // Zero has no logarithm - check explicitly for a clear error message.
156 assert!(n != 0, "log3_strict_usize: input must be non-zero");
157
158 // Instantly find the candidate exponent via the highest set bit.
159 //
160 // Because every power of 3 occupies a unique log2 bucket, this single
161 // lookup gives us the answer in O(1) with zero branches.
162 let log2 = (usize::BITS - 1 - n.leading_zeros()) as usize;
163 let res = LOG2_TO_EXP[log2] as usize;
164
165 // Verify the result: catches non-powers of 3 in a single O(1) check.
166 assert!(
167 POWERS_OF_3[res] as usize == n,
168 "log3_strict_usize: input is not a power of 3"
169 );
170
171 res
172}
173
/// Builds the array `[0, 1, ..., N - 1]`.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    // Iterators are unavailable in `const fn`, so count down manually.
    let mut idx = N;
    while idx > 0 {
        idx -= 1;
        out[idx] = idx;
    }
    out
}
185
/// Compile-time check that `T` implements [`Clone`]; does nothing at runtime.
pub const fn assert_clone<T>()
where
    T: Clone,
{
}
188
/// Compile-time check that `T` implements [`Send`]; does nothing at runtime.
pub const fn assert_send<T>()
where
    T: Send,
{
}
191
/// Compile-time check that `T` implements [`Sync`]; does nothing at runtime.
pub const fn assert_sync<T>()
where
    T: Sync,
{
}
194
195#[inline]
196pub const fn reverse_bits(x: usize, n: usize) -> usize {
197 // Assert that n is a power of 2
198 debug_assert!(n.is_power_of_two());
199 reverse_bits_len(x, n.trailing_zeros() as usize)
200}
201
/// Reverses the low `bit_len` bits of `x`.
///
/// The whole word is reversed and then shifted right so that the reversed bits land in
/// the low `bit_len` positions. `bit_len` must not exceed `usize::BITS`; this is only
/// checked in builds with debug assertions enabled.
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // A `bit_len` wider than the word would underflow the shift amount below and, in
    // release builds, silently produce a wrong permutation. Reject it up front.
    debug_assert!(bit_len <= usize::BITS as usize);
    let shift = usize::BITS - bit_len as u32;
    // NB: a wrapping shift (rather than plain `>>`) is needed for `bit_len == 0`, where
    // the shift amount equals the word width. Rust treats that as an overflow even when
    // the value being shifted is zero.
    x.reverse_bits().wrapping_shr(shift)
}
215
// Lookup table of 6-bit reverses: entry `i` holds `i` with its low 6 bits in reverse
// order (for example index `0o01` maps to `0o40`).
// The entries are written in octal so each literal shows its two 3-bit groups.
// Only compiled for targets other than aarch64.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];
230
/// Element size in bytes at or above which `reverse_slice_index_bits` always uses the
/// simple per-element algorithm, regardless of the slice length.
const BIG_T_SIZE: usize = 1 << 14;
/// Total slice size in bytes at or below which `reverse_slice_index_bits` uses the
/// simple per-element algorithm.
const SMALL_ARR_SIZE: usize = 1 << 16;

// The chunked path of `reverse_slice_index_bits` is only taken when an element is smaller
// than `BIG_T_SIZE` and the slice is larger than `SMALL_ARR_SIZE` bytes. The inequality
// below is what guarantees that path always sees at least 4 elements, so enforce it at
// compile time instead of relying on a comment.
const _: () = assert!(SMALL_ARR_SIZE >= 4 * BIG_T_SIZE);
234
/// Permutes `vals` in place such that each index is mapped to its reverse in binary.
///
/// An empty slice is returned unchanged.
///
/// If the whole slice fits in fast cache, then the trivial algorithm is cache friendly. Also,
/// if `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of
/// the slice. Otherwise a chunked algorithm built from chunk swaps and transposes is used.
///
/// # Panics
///
/// Panics if `vals.len()` is non-zero and not a power of two (via `log2_strict_usize`).
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    let log_n = log2_strict_usize(n);

    // Use the trivial algorithm when the slice occupies at most `SMALL_ARR_SIZE` bytes in
    // total, or when a single element is at least `BIG_T_SIZE` bytes.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE
        || core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that
        // `log_n` is even, i.e., `n` is a square number.) To perform bit-order reversal we:
        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //     basically a series of large `memcpy`s.)
        //  2. Transpose the matrix.
        //  3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        //  1. bit-reversing `i[log_n / 2..log_n]`,
        //  2. swapping `i[0..log_n / 2]` and `i[log_n / 2..log_n]`,
        //  3. bit-reversing `i[log_n / 2..log_n]`.
        //
        // If `log_n` is odd, i.e., `n` is not a square number, then the above procedure
        // requires slight modification. At steps 1 and 3 we bit-reverse bits
        // `ceil(log_n / 2)..log_n` of the index (shuffling `2^floor(log_n / 2)` chunks of
        // length `2^ceil(log_n / 2)`). At step 2, we perform _two_ transposes. We treat `vals`
        // as two matrices, one where the middle bit of the index is `0` and another, where the
        // middle bit is `1`; we transpose each individually.

        // `lb_num_chunks = floor(log_n / 2)` and `lb_chunk_size = ceil(log_n / 2)`.
        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        // SAFETY: `lb_num_chunks + lb_chunk_size == log_n` and `vals.len() == 1 << log_n`,
        // which is the length contract of `reverse_slice_index_bits_chunks`.
        // NOTE(review): the contract of `transpose_in_place_square` lives in `crate::transpose`
        // and is not visible here; confirm the arguments below satisfy it.
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `vals` by `1 << lb_num_chunks`, effectively adding that to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}
296
// Both `reverse_slice_index_bits_small` variants below (one per target architecture) are
// semantically equivalent to an in-place version of:
//     for i in 0..n {
//         result.push(vals[reverse_bits_len(i, lb_n)]);
//     }
// where `reverse_bits_len(i, lb_n)` computes the `lb_n`-bit reverse of `i`. The complications
// are there to guide the compiler to generate optimal assembly.
303
304#[cfg(not(target_arch = "aarch64"))]
305fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
306 if lb_n <= 6 {
307 // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
308 let dst_shr_amt = 6 - lb_n as u32;
309 for (src, &br) in BIT_REVERSE_6BIT.iter().enumerate().take(vals.len()) {
310 let dst = (br as usize).wrapping_shr(dst_shr_amt);
311 if src < dst {
312 vals.swap(src, dst);
313 }
314 }
315 } else {
316 // LLVM does not know that it does not need to reverse src at each iteration (which is
317 // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
318 // bits of dst are dependent only on the low bits of src.
319 let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
320 let dst_hi_shl_amt = lb_n - 6;
321 for src_chunk in 0..(vals.len() >> 6) {
322 let src_hi = src_chunk << 6;
323 let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
324 for (src_lo, &br) in BIT_REVERSE_6BIT.iter().enumerate() {
325 let dst_hi = (br as usize) << dst_hi_shl_amt;
326 let src = src_hi + src_lo;
327 let dst = dst_hi + dst_lo;
328 if src < dst {
329 vals.swap(src, dst);
330 }
331 }
332 }
333 }
334}
335
/// Bit-reverses the indices of `vals` in place (aarch64 variant).
///
/// `lb_n` is the base-2 logarithm of the slice length that the caller passes alongside
/// `vals`. Each pair of positions `(src, dst)` is swapped once, when `src < dst`.
///
/// This variant is compiled for every aarch64 target. The body only relies on
/// `usize::reverse_bits` and no NEON intrinsics, and the lookup-table variant is gated on
/// `not(target_arch = "aarch64")`, so also requiring the `neon` target feature here
/// would leave aarch64 targets without NEON with no implementation at all.
#[cfg(target_arch = "aarch64")]
const fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    // A manual `while` loop is used because `for` is not available in `const fn`.
    let mut src = 0;
    while src < vals.len() {
        // `wrapping_shr` covers `lb_n == 0`, where the shift amount equals the word width.
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            vals.swap(src, dst);
        }

        src += 1;
    }
}
350
/// Splits `vals` into `1 << lb_num_chunks` contiguous chunks of `1 << lb_chunk_size`
/// elements each and bit-reverses the order of the chunks: chunk `i` trades places with
/// the chunk whose index is the `lb_num_chunks`-bit reversal of `i`. Elements inside a
/// chunk keep their relative order.
///
/// # Safety
///
/// The caller must ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    debug_assert_eq!(
        Some(vals.len()),
        1usize.checked_shl((lb_num_chunks + lb_chunk_size) as u32),
        "slice length must equal 1 << (lb_num_chunks + lb_chunk_size)"
    );

    let chunk_len = 1usize << lb_chunk_size;
    // Derive every chunk pointer from one raw pointer to the whole slice. A pointer
    // obtained from a `&mut F` to a single element must not be used to access the
    // rest of its chunk, and re-borrowing the slice for the second chunk would
    // invalidate the first pointer.
    let base = vals.as_mut_ptr();
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            // SAFETY: `i` and `j` are distinct chunk indices below `1 << lb_num_chunks`, so
            // by the caller's length guarantee both chunks lie inside `vals` and, being
            // different chunks, cannot overlap.
            unsafe {
                core::ptr::swap_nonoverlapping(
                    base.add(i << lb_chunk_size),
                    base.add(j << lb_chunk_size),
                    chunk_len,
                );
            }
        }
    }
}
375
/// Lets the compiler assume that the predicate `p` always holds.
///
/// In builds with debug assertions enabled, a false `p` is caught by a `debug_assert!`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
#[inline(always)]
pub const unsafe fn assume(p: bool) {
    debug_assert!(p);
    match p {
        true => {}
        // SAFETY: the caller guarantees that `p` is true, so this arm is never reached.
        false => unsafe { unreachable_unchecked() },
    }
}
390
/// Try to force Rust to emit a branch. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     p3_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only. On architectures other than
/// the ones listed in the body it compiles to an empty function.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    {
        // SAFETY: the assembly template is empty, so no instruction is executed.
        unsafe { core::arch::asm!("", options(nomem, nostack, preserves_flags)) };
    }
}
421
/// Returns the name of `T` with crate and module prefixes removed.
///
/// The full type name is cut after every `<`, `>` and `,` (the delimiter stays attached
/// to the preceding piece), and from each piece only the text following its last `::`
/// is kept.
pub fn pretty_name<T>() -> String {
    type_name::<T>()
        .split_inclusive(&['<', '>', ','])
        // `split` always yields at least one item, so `last()` cannot be `None`.
        .map(|piece| piece.split("::").last().unwrap())
        .collect()
}
432
/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Pulls up to `BUFLEN` items from `iter` and returns the buffer together with the
/// number of leading slots that were written. Slots at or beyond that count remain
/// uninitialized. Reading stops at the first `None`, or once the buffer is full.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut filled = 0;

    for slot in &mut buf {
        // Stop as soon as the iterator runs dry.
        let Some(item) = iter.next() else {
            break;
        };
        slot.write(item);
        filled += 1;
    }
    (buf, filled)
}
462
463/// Split an iterator into small arrays and apply `func` to each.
464///
465/// Repeatedly read `BUFLEN` elements from `input` into an array and
466/// pass the array to `func` as a slice. If less than `BUFLEN`
467/// elements are remaining, that smaller slice is passed to `func` (if
468/// it is non-empty) and the function returns.
469#[inline]
470pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
471where
472 I: IntoIterator<Item = u8>,
473 H: FnMut(&[I::Item]),
474{
475 let mut iter = input.into_iter();
476 loop {
477 let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
478 if n == 0 {
479 break;
480 }
481 func(unsafe { buf.get_unchecked(..n).assume_init_ref() });
482 }
483}
484
485/// Pulls `N` items from `iter` and returns them as an array. If the iterator
486/// yields fewer than `N` items (but more than `0`), pads by the given default value.
487///
488/// Since the iterator is passed as a mutable reference and this function calls
489/// `next` at most `N` times, the iterator can still be used afterwards to
490/// retrieve the remaining items.
491///
492/// If `iter.next()` panics, all items already yielded by the iterator are
493/// dropped.
494#[inline]
495fn iter_next_chunk_padded<T: Copy, const N: usize>(
496 iter: &mut impl Iterator<Item = T>,
497 default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
498) -> Option<[T; N]> {
499 let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
500 (n != 0).then(|| {
501 // Fill the rest of the array with default values.
502 arr[n..].fill(MaybeUninit::new(default));
503 unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
504 })
505}
506
507/// Returns an iterator over `N` elements of the iterator at a time.
508///
509/// The chunks do not overlap. If `N` does not divide the length of the
510/// iterator, then the last `N-1` elements will be padded with the given default value.
511///
512/// This is essentially a copy pasted version of the nightly `array_chunks` function.
513/// <https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks>
514/// Once that is stabilized this and the functions above it should be removed.
515#[inline]
516pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
517 iter: impl IntoIterator<Item = T>,
518 default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
519) -> impl Iterator<Item = [T; N]> {
520 let mut iter = iter.into_iter();
521 iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
522}
523
/// Reinterprets a slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[[F; N]]` to `&[F]`, or `&[A]` to `&[F]` where `A` has the
/// same size, alignment and memory layout as `[F; N]` for some `N`. The returned slice
/// is `size_of::<BaseArray>() / size_of::<Base>()` times as long as `buf`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ, or if the size
/// of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let flat_len = buf.len() * ratio;
    let head: *const Base = buf.as_ptr().cast();

    // SAFETY: `head` points at `buf.len()` contiguous `BaseArray` values which, by the
    // caller's layout guarantee, are `flat_len` contiguous, properly aligned `Base` values.
    unsafe { slice::from_raw_parts(head, flat_len) }
}
552
/// Reinterprets a mutable slice of `BaseArray` elements as a mutable slice of `Base` elements.
///
/// This is useful to convert `&mut [[F; N]]` to `&mut [F]`, or `&mut [A]` to `&mut [F]`
/// where `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. The
/// returned slice is `size_of::<BaseArray>() / size_of::<Base>()` times as long as `buf`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ, or if the size
/// of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let flat_len = buf.len() * ratio;
    let head: *mut Base = buf.as_mut_ptr().cast();

    // SAFETY: `head` points at `buf.len()` contiguous `BaseArray` values which, by the
    // caller's layout guarantee, are `flat_len` contiguous, properly aligned `Base` values.
    // The exclusive borrow of `buf` is carried over to the returned slice.
    unsafe { slice::from_raw_parts_mut(head, flat_len) }
}
581
/// Converts a vector of `BaseArray` elements into a vector of `Base` elements without
/// reallocating.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>`, or `Vec<A>` to `Vec<F>` where `A`
/// has the same size, alignment and memory layout as `[F; N]` for some `N`. It can also be
/// used to convert `Vec<u32>` to `Vec<F>` if `F` is a 32-bit field, or `Vec<u64>` to
/// `Vec<F>` if `F` is a 64-bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ, or if the size
/// of `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let ratio = size_of::<BaseArray>() / size_of::<Base>();

    // Suppress the destructor of `vec`: ownership of its allocation moves to the result.
    let mut parts = ManuallyDrop::new(vec);
    let (len, cap) = (parts.len(), parts.capacity());
    let head = parts.as_mut_ptr().cast::<Base>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - Each `BaseArray` holds exactly `ratio` `Base` values, so scaling the length and
    //   capacity by `ratio` describes the very same allocation, byte for byte.
    // - The first `len * ratio` `Base` values are the `len` initialized `BaseArray` values.
    unsafe { Vec::from_raw_parts(head, len * ratio, cap * ratio) }
}
630
/// Converts a vector of `Base` elements into a vector of `BaseArray` elements, ideally
/// without reallocating.
///
/// This is the inverse of `flatten_to_base`. Unlike `flatten_to_base`, it cannot always
/// avoid an allocation: there is no way to guarantee that the capacity of the input
/// vector is a multiple of `size_of::<BaseArray>() / size_of::<Base>()`. When it is not,
/// the elements are cloned into a fresh vector.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ, or if the size
/// of `BaseArray` is not a multiple of the size of `Base`.
///
/// # Panics
///
/// Panics if the length of the vector is not a multiple of the ratio of the sizes.
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` elements packed into one `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // Calling `vec.shrink_to_fit()` here might make the capacity a multiple of `d` more
    // often, but it could itself reallocate, which would defeat the purpose.
    let cap = vec.capacity();

    // Most callers are expected to pass a vector produced by `flatten_to_base`, whose
    // capacity is a multiple of `d`. Capacities can do strange things though, so the other
    // case must be supported as well. The copying path below would also be correct for a
    // divisible capacity; it is just slower.
    if !cap.is_multiple_of(d) {
        let head = vec.as_mut_ptr().cast::<BaseArray>();
        // SAFETY:
        // - `Base` and `BaseArray` have the same alignment.
        // - The length is divisible by `d`, so the first `new_len` `BaseArray` values cover
        //   exactly the `len` initialized `Base` values of `vec`, which is still alive.
        let regrouped = unsafe { slice::from_raw_parts(head, new_len) };

        // Ideally the compiler could optimize this copy away, but it appears not to.
        return regrouped.to_vec();
    }

    // Suppress the destructor of `vec`: ownership of its allocation moves to the result.
    let mut values = ManuallyDrop::new(vec);
    let ptr = values.as_mut_ptr().cast::<BaseArray>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment.
    // - Both the length and the capacity are divisible by `d`, so dividing them by `d`
    //   describes the very same allocation, byte for byte.
    // - The first `new_len` `BaseArray` values are the `len` initialized `Base` values.
    unsafe { Vec::from_raw_parts(ptr, new_len, cap / d) }
}
715
/// Returns `true` if and only if `gcd(u, v) == 1`.
///
/// Returns `false` whenever either input is `0`. The test is a binary GCD that uses
/// only shifts and subtractions.
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // A zero input is never treated as coprime to anything.
    if u == 0 || v == 0 {
        return false;
    }

    // If both inputs are even they share the factor 2.
    if u & 1 == 0 && v & 1 == 0 {
        return false;
    }

    // Make `u` odd. Factors of 2 cannot be shared at this point, so dropping them is safe.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        // Make `v` odd as well.
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Keep `u <= v` so the subtraction below cannot underflow.
        if u > v {
            (u, v) = (v, u);
        }

        // This looks inefficient for `v` much larger than `u`, but because the trailing
        // zeros of `v` are stripped in every iteration it performs far better than a
        // first glance suggests.
        v -= u;
    }
    // The loop ended with `v == 0` without `u` or `v` ever being 1, so the gcd (now held
    // in `u`) is greater than 1.
    false
}
754
/// Inner loop of the deferred GCD algorithm.
///
/// See: <https://eprint.iacr.org/2020/972.pdf> for more information.
///
/// This is a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. It only uses `u64`s, subtractions and bit shifts, which are
/// very fast operations. `a` and `b` are updated in place and the accumulated update
/// factors `(f0, g0, f1, g1)` are returned.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` should match the top bits including
/// zeroes if the original numbers have different sizes.
#[inline]
pub const fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Update factors. Before round 0: -1 < f0, g0, f1, g1 <= 1.
    let mut f0: i64 = 1;
    let mut g0: i64 = 0;
    let mut f1: i64 = 0;
    let mut g1: i64 = 1;

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i,
    // then at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}.
    // A manual `while` loop is used because `for` is not available in `const fn`.
    let mut remaining = NUM_ROUNDS;
    while remaining > 0 {
        if *a & 1 == 1 {
            if *a < *b {
                // Swap `a` with `b`, and the factor pairs along with them.
                let tmp = *a;
                *a = *b;
                *b = tmp;
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            *a -= *b;
            f0 -= f1;
            g0 -= g1;
        }
        // `a` is even at this point in either case, so halve it.
        *a >>= 1;
        f1 <<= 1;
        g1 <<= 1;

        remaining -= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63} in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}
802
/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
///  - a: The value we want to invert. It must be < P.
///  - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` equal to `2^{2 * FIELD_BITS - 2} a^{-1} mod P` with
///   size `|v| < 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions break, the output is undefined.
#[inline]
pub const fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Initially `|u|, |v| <= 2^0`.
    let mut u: i64 = 1;
    let mut v: i64 = 0;

    // Let a0 and P denote the initial values of a and b. Observe:
    // `a = u * a0 mod P`
    // `b = v * a0 mod P`
    // `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`

    // A manual `while` loop is used because `for` is not available in `const fn`.
    let total_rounds = 2 * FIELD_BITS - 2;
    let mut round = 0;
    while round < total_rounds {
        // Invariants at the start of round i (with n = 2 * FIELD_BITS):
        // (1) `|u|, |v| <= 2^{i}`
        // (2) `2^i * a = u * a0 mod P`
        // (3) `2^i * b = v * a0 mod P`
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i, 1)`

        if a & 1 == 1 {
            if a < b {
                // Swap `a` with `b`, and `u` with `v` along with them.
                let tmp = a;
                a = b;
                b = tmp;
                (u, v) = (v, u);
            }
            // As `b <= a`, this subtraction cannot increase `len(a) + len(b)`.
            a -= b;
            // |u'| = |u - v| <= |u| + |v| <= 2^{i + 1}, and by (2) and (3):
            // `2^i * a' = 2^i * (a - b) = (u - v) * a0 mod P = u' * a0 mod P`.
            u -= v;
        }
        // As `b` is odd, `a` must now be even. Halving it reduces `len(a) + len(b)` by 1
        // (unless `a = 0`, in which case `b = 1` and the sum of the lengths stays 1).
        a >>= 1;

        // |v'| = 2|v| <= 2^{i + 1}
        v <<= 1;

        // Hence at the end of round i:
        // (1) `|u|, |v| <= 2^{i + 1}`
        // (2) `2^{i + 1} * a = u * a0 mod P` (as we have halved a)
        // (3) `2^{i + 1} * b = v * a0 mod P` (as we have doubled v)
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i - 1, 1)`

        round += 1;
    }

    // After the loop:
    // `|u|, |v| <= 2^{2 * FIELD_BITS - 2}`, so for FIELD_BITS <= 32 an i64 cannot overflow.
    // `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    // `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // This implies that `b` must be `1` and so `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P`.
    v
}
879
/// A raw mutable pointer wrapper that implements [`Send`] and [`Sync`].
///
/// Used to enable parallel writes to disjoint slices of a pre-allocated buffer
/// from within closures that require `Send + Sync` (e.g. `rayon::ParallelIterator::for_each_init`).
///
/// The wrapper is `Copy` for every `T`, since it only stores a raw pointer.
///
/// # Safety
///
/// The caller must ensure that concurrent accesses through this pointer always
/// target **non-overlapping** memory regions. The wrapper keeps no lifetime and places
/// no `Send`/`Sync` bound on `T`, so the caller is also responsible for making sure the
/// underlying buffer outlives every slice obtained from it and that moving `T` values
/// across threads is acceptable.
pub struct DisjointMutPtr<T>(*mut T);

// `Clone` and `Copy` are implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone`
// and `T: Copy` bounds, although copying the wrapper only copies the raw pointer.
impl<T> Clone for DisjointMutPtr<T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for DisjointMutPtr<T> {}

// SAFETY: The contract of DisjointMutPtr guarantees that each thread writes to
// a disjoint region, so sharing the pointer across threads is safe.
unsafe impl<T> Send for DisjointMutPtr<T> {}
unsafe impl<T> Sync for DisjointMutPtr<T> {}

impl<T> DisjointMutPtr<T> {
    /// Create a new `DisjointMutPtr` pointing at the first element of `slice`.
    ///
    /// The length of `slice` is not recorded.
    #[inline]
    pub const fn new(slice: &mut [T]) -> Self {
        Self(slice.as_mut_ptr())
    }

    /// Get a mutable slice starting at `offset` with `len` elements.
    ///
    /// The returned slice carries a `'static` lifetime that is not tied to the original
    /// buffer.
    ///
    /// # Safety
    ///
    /// The caller must ensure the range `[offset, offset+len)` is within the bounds of
    /// the buffer this pointer was created from, that the buffer is still alive for as
    /// long as the returned slice is used, and that the range does not overlap with any
    /// other concurrent access.
    #[inline]
    pub const unsafe fn slice_mut(self, offset: usize, len: usize) -> &'static mut [T] {
        // SAFETY: the caller guarantees the range is in bounds, live and unaliased.
        unsafe { core::slice::from_raw_parts_mut(self.0.add(offset), len) }
    }
}
915
#[cfg(test)]
mod tests {
    // Unit tests for the bit-reversal, chunking, logarithm and gcd helpers of this crate.

    use alloc::vec;
    use alloc::vec::Vec;

    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;

    #[test]
    fn test_reverse_bits_len() {
        // Fixed vectors for 10-bit and 5-bit reversals.
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }

    #[test]
    fn test_reverse_bits_len_full_width() {
        // A full-width reversal is the largest valid bit length and must reverse every bit.
        let bits = usize::BITS as usize;
        assert_eq!(reverse_bits_len(1, bits), 1 << (bits - 1));
        assert_eq!(reverse_bits_len(1 << (bits - 1), bits), 1);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "bit_len <= usize::BITS")]
    fn test_reverse_bits_len_rejects_oversized_bit_len() {
        // One bit past the word width: the shift would underflow into a wrong permutation.
        // The expected message pins the guard, not the incidental subtraction-overflow panic.
        let _ = reverse_bits_len(0, usize::BITS as usize + 1);
    }

    #[test]
    fn test_reverse_index_bits() {
        // Length 4: indices 1 (0b01) and 2 (0b10) swap, 0 and 3 stay put.
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        // Length 256: the expected output is the hard-coded 8-bit reversal table.
        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }

    #[test]
    fn test_apply_to_chunks_exact_fit() {
        // Input length is an exact multiple of the chunk size: every chunk is full.
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_with_remainder() {
        // Seven items in chunks of three: the final chunk holds the single leftover item.
        const CHUNK_SIZE: usize = 3;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
    }

    #[test]
    fn test_apply_to_chunks_empty_input() {
        // No input items: the callback must never be invoked.
        const CHUNK_SIZE: usize = 4;
        let input: Vec<u8> = vec![];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert!(results.is_empty());
    }

    #[test]
    fn test_apply_to_chunks_single_chunk() {
        // Fewer items than the chunk size: one short chunk.
        const CHUNK_SIZE: usize = 10;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
    }

    #[test]
    fn test_apply_to_chunks_large_chunk_size() {
        // Chunk size far larger than the input: still exactly one chunk.
        const CHUNK_SIZE: usize = 100;
        let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
    }

    #[test]
    fn test_apply_to_chunks_large_input() {
        // Twenty items in chunks of five: four full chunks, in order.
        const CHUNK_SIZE: usize = 5;
        let input: Vec<u8> = (1..=20).collect();
        let mut results: Vec<Vec<u8>> = Vec::new();

        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });

        assert_eq!(
            results,
            vec![
                vec![1, 2, 3, 4, 5],
                vec![6, 7, 8, 9, 10],
                vec![11, 12, 13, 14, 15],
                vec![16, 17, 18, 19, 20]
            ]
        );
    }

    #[test]
    fn test_reverse_slice_index_bits_random() {
        // Compare the in-place permutation against the naive reference on random data.
        let lengths = [32, 128, 1 << 16];
        let mut rng = SmallRng::seed_from_u64(1);
        for _ in 0..32 {
            for &length in &lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.random());
                let expect = reverse_index_bits_naive(&rand_list);

                let mut actual = rand_list.clone();
                reverse_slice_index_bits(&mut actual);

                assert_eq!(actual, expect);
            }
        }
    }

    #[test]
    fn test_log2_strict_usize_edge_cases() {
        assert_eq!(log2_strict_usize(1), 0);
        assert_eq!(log2_strict_usize(2), 1);
        assert_eq!(log2_strict_usize(1 << 18), 18);
        assert_eq!(log2_strict_usize(1 << 31), 31);
        assert_eq!(
            log2_strict_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    fn test_checked_pow2() {
        // 2^0 = 1, the smallest valid exponent.
        assert_eq!(checked_pow2(0), Some(1));

        // 2^1 = 2.
        assert_eq!(checked_pow2(1), Some(2));

        // 2^5 = 32, a typical small power.
        assert_eq!(checked_pow2(5), Some(32));

        // 2^10 = 1024, commonly used as a domain size in FRI.
        assert_eq!(checked_pow2(10), Some(1024));

        // 2^20 = 1_048_576, a realistic large trace length.
        assert_eq!(checked_pow2(20), Some(1_048_576));

        // Largest representable power: 2^(BITS - 1).
        // On a 64-bit platform this is 2^63 = 0x8000_0000_0000_0000.
        let max_exp = usize::BITS as usize - 1;
        assert_eq!(checked_pow2(max_exp), Some(1usize << max_exp));

        // Exponent equal to the bit width would shift 1 out of range.
        //
        //     1_usize << 64  (on 64-bit)  →  overflow
        //
        // Must return `None`.
        assert_eq!(checked_pow2(usize::BITS as usize), None);

        // One past the maximum: also out of range.
        assert_eq!(checked_pow2(usize::BITS as usize + 1), None);

        // Extreme exponent: usize::MAX is astronomically beyond
        // representable range — must return `None`.
        assert_eq!(checked_pow2(usize::MAX), None);
    }

    #[test]
    fn test_checked_log_size_sum() {
        // Both zero: 0 + 0 = 0, 2^0 = 1.
        assert_eq!(checked_log_size_sum(0, 0), Some((0, 1)));

        // Identity cases: adding zero to either side is a no-op.
        assert_eq!(checked_log_size_sum(5, 0), Some((5, 32)));
        assert_eq!(checked_log_size_sum(0, 10), Some((10, 1024)));

        // Typical FRI scenario: degree_bits=10, log_quotient_chunks=2.
        //
        //     10 + 2 = 12,  2^12 = 4096
        assert_eq!(checked_log_size_sum(10, 2), Some((12, 4096)));

        // Commutativity: order of operands must not matter.
        assert_eq!(checked_log_size_sum(2, 10), Some((12, 4096)));

        // Large realistic case: degree_bits=20, log_chunks=3.
        //
        //     20 + 3 = 23,  2^23 = 8_388_608
        assert_eq!(checked_log_size_sum(20, 3), Some((23, 8_388_608)));

        // Largest representable sum: (BITS - 2) + 1 = BITS - 1.
        let almost_max = usize::BITS as usize - 2;
        let max_exp = usize::BITS as usize - 1;
        assert_eq!(
            checked_log_size_sum(almost_max, 1),
            Some((max_exp, 1usize << max_exp))
        );

        // Sum exactly at the bit width: overflows the shift.
        //
        //     (BITS - 1) + 1 = BITS  →  2^BITS is unrepresentable  →  None
        assert_eq!(checked_log_size_sum(max_exp, 1), None);

        // Both operands large but sum still within range.
        //
        //     32 + 31 = 63  (on 64-bit)  →  2^63 is representable
        let half = usize::BITS as usize / 2;
        let other_half = max_exp - half;
        assert_eq!(
            checked_log_size_sum(half, other_half),
            Some((max_exp, 1usize << max_exp))
        );

        // Addition itself overflows usize, not just the shift.
        //
        //     usize::MAX + 1  →  checked_add returns None  →  None
        assert_eq!(checked_log_size_sum(usize::MAX, 1), None);

        // Both operands at usize::MAX: addition doubly overflows.
        assert_eq!(checked_log_size_sum(usize::MAX, usize::MAX), None);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_zero() {
        let _ = log2_strict_usize(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_nonpower_2() {
        let _ = log2_strict_usize(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        let _ = log2_strict_usize(usize::MAX);
    }

    #[test]
    fn test_log3_strict_powers_of_3() {
        // Test all powers of 3 up to 3^12 = 531441.
        assert_eq!(log3_strict_usize(1), 0);
        assert_eq!(log3_strict_usize(3), 1);
        assert_eq!(log3_strict_usize(9), 2);
        assert_eq!(log3_strict_usize(27), 3);
        assert_eq!(log3_strict_usize(81), 4);
        assert_eq!(log3_strict_usize(243), 5);
        assert_eq!(log3_strict_usize(729), 6);
        assert_eq!(log3_strict_usize(2187), 7);
        assert_eq!(log3_strict_usize(6561), 8);
        assert_eq!(log3_strict_usize(19683), 9);
        assert_eq!(log3_strict_usize(59049), 10);
        assert_eq!(log3_strict_usize(177_147), 11);
        assert_eq!(log3_strict_usize(531_441), 12);
    }

    #[test]
    #[should_panic(expected = "input must be non-zero")]
    fn test_log3_strict_panics_on_zero() {
        let _ = log3_strict_usize(0);
    }

    #[test]
    #[should_panic(expected = "is not a power of 3")]
    fn test_log3_strict_panics_on_non_power_of_3() {
        // 2 is not a power of 3.
        let _ = log3_strict_usize(2);
    }

    #[test]
    #[should_panic(expected = "is not a power of 3")]
    fn test_log3_strict_panics_on_power_of_2() {
        // 8 = 2^3 is not a power of 3.
        let _ = log3_strict_usize(8);
    }

    #[test]
    #[should_panic(expected = "is not a power of 3")]
    fn test_log3_strict_panics_on_product_with_other_primes() {
        // 6 = 2 * 3 is not a power of 3.
        let _ = log3_strict_usize(6);
    }

    proptest! {
        #[test]
        fn test_log3_strict_roundtrip(k in 0u32..25u32) {
            // Roundtrip: 3^k -> log3_strict_usize -> k
            let n = 3usize.pow(k);
            assert_eq!(log3_strict_usize(n), k as usize);
        }
    }

    #[test]
    fn test_log2_ceil_usize_comprehensive() {
        // Zero and exact powers of 2
        assert_eq!(log2_ceil_usize(0), 0);
        assert_eq!(log2_ceil_usize(1), 0);
        assert_eq!(log2_ceil_usize(2), 1);
        assert_eq!(log2_ceil_usize(1 << 18), 18);
        assert_eq!(log2_ceil_usize(1 << 31), 31);
        assert_eq!(
            log2_ceil_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );

        // Nonpowers; want to round up
        assert_eq!(log2_ceil_usize(3), 2);
        assert_eq!(log2_ceil_usize(0x14fe901b), 29);
        assert_eq!(
            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
    }

    /// Reference bit-reversal permutation: builds a new vector in which the element
    /// at index `i` of `arr` lands at the index obtained by reversing the low
    /// `log2(arr.len())` bits of `i`.
    ///
    /// Panics (via `log2_strict_usize`) if `arr.len()` is not a power of two.
    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict_usize(n);
        // For a length-1 input there are zero index bits and the shift amount equals
        // `usize::BITS`, which a plain `>>` rejects as an overflow. `checked_shr`
        // reports that case as `None`, which is mapped to the only valid index, 0.
        let shift = usize::BITS - n_power as u32;

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            let dst = i.reverse_bits().checked_shr(shift).unwrap_or(0);
            out[dst] = Some(*v);
        }

        // Bit reversal is a permutation, so every slot has been filled exactly once.
        out.into_iter().map(|x| x.unwrap()).collect()
    }

    #[test]
    fn test_relatively_prime_u64() {
        // Zero cases (should always return false)
        assert!(!relatively_prime_u64(0, 0));
        assert!(!relatively_prime_u64(10, 0));
        assert!(!relatively_prime_u64(0, 10));
        assert!(!relatively_prime_u64(0, 123456789));

        // Number with itself (if greater than 1, not relatively prime)
        assert!(relatively_prime_u64(1, 1));
        assert!(!relatively_prime_u64(10, 10));
        assert!(!relatively_prime_u64(99999, 99999));

        // Powers of 2 (always false since they share factor 2)
        assert!(!relatively_prime_u64(2, 4));
        assert!(!relatively_prime_u64(16, 32));
        assert!(!relatively_prime_u64(64, 128));
        assert!(!relatively_prime_u64(1024, 4096));

        // One number is a multiple of the other (always false)
        assert!(!relatively_prime_u64(5, 10));
        assert!(!relatively_prime_u64(12, 36));
        assert!(!relatively_prime_u64(15, 45));
        assert!(!relatively_prime_u64(100, 500));

        // Co-prime numbers (should be true)
        assert!(relatively_prime_u64(17, 31));
        assert!(relatively_prime_u64(97, 43));
        assert!(relatively_prime_u64(7919, 65537));
        assert!(relatively_prime_u64(15485863, 32452843));

        // Small prime numbers (should be true)
        assert!(relatively_prime_u64(13, 17));
        assert!(relatively_prime_u64(101, 103));
        assert!(relatively_prime_u64(1009, 1013));

        // Large even numbers (each pair shares the factor 2, so never relatively prime)
        assert!(!relatively_prime_u64(
            190266297176832000,
            10430732356495263744
        ));
        assert!(!relatively_prime_u64(
            2040134905096275968,
            5701159354248194048
        ));
        assert!(!relatively_prime_u64(
            16611311494648745984,
            7514969329383038976
        ));
        assert!(!relatively_prime_u64(
            14863931409971066880,
            7911906750992527360
        ));

        // Max values
        assert!(relatively_prime_u64(u64::MAX, 1));
        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
    }
}