Skip to main content

lib_q_stark_util/
lib.rs

1//! Various simple utilities.
2
3#![no_std]
4
5#[cfg(all(feature = "parallel", target_arch = "wasm32"))]
6compile_error!(
7    "parallel feature is not supported on wasm32; do not enable 'parallel' for WASM builds"
8);
9
10extern crate alloc;
11
12use alloc::slice;
13use alloc::string::String;
14use alloc::vec::Vec;
15use core::any::type_name;
16use core::hint::unreachable_unchecked;
17use core::mem::{
18    ManuallyDrop,
19    MaybeUninit,
20};
21use core::{
22    iter,
23    mem,
24};
25
26use crate::transpose::transpose_in_place_square;
27
28pub mod array_serialization;
29pub mod linear_map;
30pub mod transpose;
31pub mod zip_eq;
32
/// Computes `ceil(log_2(n))`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    // `ceil(log_2(n))` is the bit length of `n - 1`. The saturating subtraction
    // keeps `n == 0` from underflowing.
    let predecessor = n.saturating_sub(1);
    let bit_length = usize::BITS - predecessor.leading_zeros();
    bit_length as usize
}
38
/// Computes `ceil(log_2(n))` for a `u64`.
///
/// Both `n == 0` and `n == 1` yield `0`.
#[must_use]
pub fn log2_ceil_u64(n: u64) -> u64 {
    // Same approach as `log2_ceil_usize`: the answer is the bit length of `n - 1`.
    let predecessor = n.saturating_sub(1);
    u64::from(u64::BITS - predecessor.leading_zeros())
}
43
44/// Computes `log_2(n)`
45///
46/// # Panics
47/// Panics if `n` is not a power of two.
48#[must_use]
49#[inline]
50pub fn log2_strict_usize(n: usize) -> usize {
51    let res = n.trailing_zeros();
52    assert_eq!(n.wrapping_shr(res), 1, "Not a power of two: {n}");
53    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
54    // `1 << res` and vice versa.
55    unsafe {
56        assume(n == 1 << res);
57    }
58    res as usize
59}
60
/// Returns the array `[0, 1, ..., N - 1]`.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    // `for` loops are not usable in `const fn`, so count down by hand.
    let mut remaining = N;
    while remaining > 0 {
        remaining -= 1;
        out[remaining] = remaining;
    }
    out
}
72
73#[inline]
74pub const fn reverse_bits(x: usize, n: usize) -> usize {
75    // Assert that n is a power of 2
76    debug_assert!(n.is_power_of_two());
77    reverse_bits_len(x, n.trailing_zeros() as usize)
78}
79
/// Reverses the lowest `bit_len` bits of `x`.
///
/// The full word is reversed and then shifted right so that the reversed bits end up in
/// the lowest `bit_len` positions.
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    let shift = usize::BITS - bit_len as u32;
    // `overflowing_shr` rather than `>>` because `bit_len == 0` (with `x == 0`) asks for a
    // shift by the full word width, which plain `>>` treats as an overflow even though
    // the operand is zero. The wrapped shift amount is harmless in that case.
    let (shifted, _shift_wrapped) = x.reverse_bits().overflowing_shr(shift);
    shifted
}
90
// Lookup table of 6-bit reverses: entry `i` holds `i` with its six low bits reversed.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
#[cfg(not(target_arch = "aarch64"))]
const BIT_REVERSE_6BIT: &[u8] = {
    // Built at compile time rather than spelled out entry by entry.
    const TABLE: [u8; 1 << 6] = {
        let mut table = [0u8; 1 << 6];
        let mut value = 0;
        while value < 1 << 6 {
            // Reversing the whole byte leaves the six interesting bits at the top of the
            // byte; shifting right by two moves them back down to bit 0.
            table[value] = (value as u8).reverse_bits() >> 2;
            value += 1;
        }
        table
    };
    &TABLE
};
105
// Thresholds used by `reverse_slice_index_bits` to choose between the trivial and the
// transpose-based algorithm:
//  - `BIG_T_SIZE`: element size in bytes at or above which the trivial algorithm is used.
//  - `SMALL_ARR_SIZE`: total array size in bytes at or below which the trivial algorithm
//    is used.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;

// The transpose-based path relies on `SMALL_ARR_SIZE >= 4 * BIG_T_SIZE` (it guarantees at
// least four elements there). Checked at compile time so that retuning the constants
// cannot silently break the relation.
const _: () = assert!(
    SMALL_ARR_SIZE >= 4 * BIG_T_SIZE,
    "SMALL_ARR_SIZE must be at least 4 * BIG_T_SIZE"
);
109
110/// Permutes `arr` such that each index is mapped to its reverse in binary.
111///
112/// If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
113/// `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
114pub fn reverse_slice_index_bits<F>(vals: &mut [F])
115where
116    F: Copy + Send + Sync,
117{
118    let n = vals.len();
119    if n == 0 {
120        return;
121    }
122    let log_n = log2_strict_usize(n);
123
124    // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
125    // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the array.
126    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE ||
127        core::mem::size_of::<F>() >= BIG_T_SIZE
128    {
129        reverse_slice_index_bits_small(vals, log_n);
130    } else {
131        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.
132
133        // Algorithm:
134        //
135        // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
136        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
137        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
138        //     basically a series of large `memcpy`s.)
139        //  2. Transpose the matrix.
140        //  3. Bit-reverse the order of the rows.
141        //
142        // This is equivalent to, for every index `0 <= i < n`:
143        //  1. bit-reversing `i[lb_n / 2..lb_n]`,
144        //  2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
145        //  3. bit-reversing `i[lb_n / 2..lb_n]`.
146        //
147        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
148        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the
149        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
150        // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the
151        // index is `0` and another, where the middle bit is `1`; we transpose each individually.
152
153        let lb_num_chunks = log_n >> 1;
154        let lb_chunk_size = log_n - lb_num_chunks;
155        unsafe {
156            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
157            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
158            if lb_num_chunks != lb_chunk_size {
159                // `arr` cannot be interpreted as a square matrix. We instead interpret it as a
160                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
161                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
162                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
163                // arr by `1 << lb_num_chunks` effectively, adding that to every index.
164                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
165                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
166            }
167            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
168        }
169    }
170}
171
172// Both functions below are semantically equivalent to:
173//     for i in 0..n {
174//         result.push(arr[reverse_bits(i, n_power)]);
175//     }
176// where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
177// to guide the compiler to generate optimal assembly.
178
179#[cfg(not(target_arch = "aarch64"))]
180fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
181    if lb_n <= 6 {
182        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
183        let dst_shr_amt = 6 - lb_n as u32;
184        #[allow(clippy::needless_range_loop)]
185        for src in 0..vals.len() {
186            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
187            if src < dst {
188                vals.swap(src, dst);
189            }
190        }
191    } else {
192        // LLVM does not know that it does not need to reverse src at each iteration (which is
193        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
194        // bits of dst are dependent only on the low bits of src.
195        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
196        let dst_hi_shl_amt = lb_n - 6;
197        for src_chunk in 0..(vals.len() >> 6) {
198            let src_hi = src_chunk << 6;
199            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
200            #[allow(clippy::needless_range_loop)]
201            for src_lo in 0..(1 << 6) {
202                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
203                let src = src_hi + src_lo;
204                let dst = dst_hi + dst_lo;
205                if src < dst {
206                    vals.swap(src, dst);
207                }
208            }
209        }
210    }
211}
212
/// Bit-reverse permutation of `vals` for aarch64.
///
/// `lb_n` is the number of index bits; the caller passes `log_2(vals.len())`.
#[cfg(target_arch = "aarch64")]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    (0..vals.len())
        .map(|src| {
            let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
            (src, dst)
        })
        // Each pair is visited from both ends; swap only once.
        .filter(|&(src, dst)| src < dst)
        .for_each(|(src, dst)| vals.swap(src, dst));
}
223
/// Splits `vals` into `1 << lb_num_chunks` contiguous chunks of `1 << lb_chunk_size`
/// elements each and bit-reverses the order of the chunks: chunk `i` trades places with
/// the chunk whose index is the `lb_num_chunks`-bit reversal of `i`.
///
/// # Safety
///
/// The caller must ensure that `vals.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_slice_index_bits_chunks<F>(
    vals: &mut [F],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    debug_assert_eq!(vals.len(), 1 << (lb_num_chunks + lb_chunk_size));

    let chunk_len = 1usize << lb_chunk_size;
    // Derive every chunk pointer from one pointer to the whole slice. A pointer obtained
    // from a single-element `&mut F` would not be entitled to access the rest of its
    // chunk, and re-borrowing `vals` for the second pointer would invalidate the first.
    let base = vals.as_mut_ptr();

    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            // SAFETY: `i` and `j` are distinct chunk indices below `1 << lb_num_chunks`,
            // so by the caller's length guarantee both chunks lie fully inside `vals`
            // and, being distinct chunks of equal size, do not overlap.
            unsafe {
                core::ptr::swap_nonoverlapping(
                    base.add(i << lb_chunk_size),
                    base.add(j << lb_chunk_size),
                    chunk_len,
                );
            }
        }
    }
}
248
/// Lets the compiler treat the predicate `p` as always `true`.
///
/// Debug builds additionally assert `p`.
///
/// # Safety
///
/// Callers must ensure that `p` is true. If this is not the case, the behavior is undefined.
#[inline(always)]
pub unsafe fn assume(p: bool) {
    debug_assert!(p);
    if p {
        return;
    }
    // SAFETY: the caller guarantees `p`, so control never reaches this point.
    unsafe { unreachable_unchecked() }
}
263
/// Nudges the compiler towards emitting a real branch at the call site. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     lib_q_stark_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// This function has no semantics. It is a hint only. On architectures outside the list
/// below the body is empty.
#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "x86_64",
        target_arch = "x86",
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv64",
        target_arch = "riscv32",
    ))]
    // SAFETY: the assembly template is empty and declares that it touches neither memory,
    // the stack, nor the flags.
    unsafe {
        core::arch::asm!("", options(nostack, nomem, preserves_flags));
    }
}
294
/// Returns the name of `T` as a `String`, with every crate and module prefix stripped,
/// including the prefixes of nested generic arguments.
pub fn pretty_name<T>() -> String {
    type_name::<T>()
        // Cut after every `<`, `>` and `,` so that each piece holds at most one path.
        .split_inclusive(&['<', '>', ','])
        // Keep only what follows the final `::` of each piece.
        .map(|piece| piece.split("::").last().unwrap())
        .collect()
}
305
/// A C-style buffered input reader, similar to
/// `core::iter::Iterator::next_chunk()` from nightly.
///
/// Pulls at most `BUFLEN` items from `iter` and returns the buffer together with the
/// number of leading slots that were initialized. Slots beyond that count are
/// uninitialized. `iter.next()` is not called again once it has returned `None` or once
/// the buffer is full.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut filled = 0;

    for slot in &mut buf {
        // Stop at the first `None`; the remaining slots stay uninitialized.
        let Some(item) = iter.next() else {
            break;
        };
        slot.write(item);
        filled += 1;
    }
    (buf, filled)
}
335
/// Views a slice of `MaybeUninit<T>` as a slice of `T`.
///
/// # Safety
///
/// Calling this when the content is not yet fully initialized causes undefined
/// behavior: it is up to the caller to guarantee that every `MaybeUninit<T>` in
/// the slice really is in an initialized state.
///
/// Copied from:
/// <https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref>
/// Once that is stabilized, this should be removed.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let len = slice.len();
    let first = slice.as_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` is guaranteed to have the same layout as `T`, the caller
    // guarantees that all `len` elements are initialized, and the memory is valid for
    // reads for as long as the borrow of `slice` lasts.
    unsafe { core::slice::from_raw_parts(first, len) }
}
355
356/// Split an iterator into small arrays and apply `func` to each.
357///
358/// Repeatedly read `BUFLEN` elements from `input` into an array and
359/// pass the array to `func` as a slice. If less than `BUFLEN`
360/// elements are remaining, that smaller slice is passed to `func` (if
361/// it is non-empty) and the function returns.
362#[inline]
363pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
364where
365    I: IntoIterator<Item = u8>,
366    H: FnMut(&[I::Item]),
367{
368    let mut iter = input.into_iter();
369    loop {
370        let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
371        if n == 0 {
372            break;
373        }
374        func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
375    }
376}
377
378/// Pulls `N` items from `iter` and returns them as an array. If the iterator
379/// yields fewer than `N` items (but more than `0`), pads by the given default value.
380///
381/// Since the iterator is passed as a mutable reference and this function calls
382/// `next` at most `N` times, the iterator can still be used afterwards to
383/// retrieve the remaining items.
384///
385/// If `iter.next()` panics, all items already yielded by the iterator are
386/// dropped.
387#[inline]
388fn iter_next_chunk_padded<T: Copy, const N: usize>(
389    iter: &mut impl Iterator<Item = T>,
390    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
391) -> Option<[T; N]> {
392    let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
393    (n != 0).then(|| {
394        // Fill the rest of the array with default values.
395        arr[n..].fill(MaybeUninit::new(default));
396        unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
397    })
398}
399
400/// Returns an iterator over `N` elements of the iterator at a time.
401///
402/// The chunks do not overlap. If `N` does not divide the length of the
403/// iterator, then the last `N-1` elements will be padded with the given default value.
404///
405/// This is essentially a copy pasted version of the nightly `array_chunks` function.
406/// <https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks>
407/// Once that is stabilized this and the functions above it should be removed.
408#[inline]
409pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
410    iter: impl IntoIterator<Item = T>,
411    default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
412) -> impl Iterator<Item = [T; N]> {
413    let mut iter = iter.into_iter();
414    iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
415}
416
/// Reinterprets a slice of `BaseArray` elements as a slice of `Base` elements.
///
/// This is useful to convert `&[[F; N]]` to `&[F]` or `&[A]` to `&[F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ or if the size of
/// `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into one `BaseArray`.
    let per_item = size_of::<BaseArray>() / size_of::<Base>();
    let flat_len = buf.len() * per_item;
    let first: *const Base = buf.as_ptr().cast();

    // SAFETY: alignment is checked above and, by the caller's layout guarantee, `buf`
    // covers exactly `flat_len` contiguous `Base` values.
    unsafe { slice::from_raw_parts(first, flat_len) }
}
445
/// Reinterprets a mutable slice of `BaseArray` elements as a mutable slice of `Base`
/// elements.
///
/// This is useful to convert `&mut [[F; N]]` to `&mut [F]` or `&mut [A]` to `&mut [F]`
/// where `A` has the same size, alignment and memory layout as `[F; N]` for some `N`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ or if the size of
/// `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into one `BaseArray`.
    let per_item = size_of::<BaseArray>() / size_of::<Base>();
    let flat_len = buf.len() * per_item;
    let first: *mut Base = buf.as_mut_ptr().cast();

    // SAFETY: alignment is checked above and, by the caller's layout guarantee, `buf`
    // covers exactly `flat_len` contiguous `Base` values. The exclusive borrow of `buf`
    // is carried over to the returned slice.
    unsafe { slice::from_raw_parts_mut(first, flat_len) }
}
474
/// Converts a vector of `BaseArray` elements into a vector of `Base` elements without any
/// reallocation.
///
/// This is useful to convert `Vec<[F; N]>` to `Vec<F>` or `Vec<A>` to `Vec<F>` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. It can
/// also be used to convert `Vec<u32>` to `Vec<F>` if `F` is a `32` bit field or
/// `Vec<u64>` to `Vec<F>` if `F` is a `64` bit field.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ or if the size of
/// `BaseArray` is not a multiple of the size of `Base`.
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into one `BaseArray`.
    let per_item = size_of::<BaseArray>() / size_of::<Base>();

    // Suppress `vec`'s destructor: ownership of the allocation moves to the new vector.
    let mut source = ManuallyDrop::new(vec);
    let (old_len, old_cap) = (source.len(), source.capacity());
    let base_ptr = source.as_mut_ptr().cast::<Base>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - Every `BaseArray` holds `per_item` values of type `Base`, so scaling length and
    //   capacity by `per_item` describes the very same allocation in units of `Base`.
    // - The first `old_len * per_item` values correspond to the first `old_len` elements
    //   of the old vector and are therefore initialized.
    unsafe { Vec::from_raw_parts(base_ptr, old_len * per_item, old_cap * per_item) }
}
523
/// Converts a vector of `Base` elements into a vector of `BaseArray` elements, ideally
/// without any reallocation.
///
/// This is the inverse of `flatten_to_base`. Unlike `flatten_to_base`, an allocation
/// cannot always be avoided, because there is no way to guarantee that the capacity of
/// the incoming vector is a multiple of the number of `Base` values per `BaseArray`.
///
/// # Safety
///
/// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`.
/// As Rust guarantees that array elements are contiguous in memory and that the alignment
/// of an array equals the alignment of its elements, this means that `BaseArray` must
/// have the same alignment as `Base`.
///
/// # Compile-time checks
///
/// Compilation fails if the alignments of `Base` and `BaseArray` differ or if the size of
/// `BaseArray` is not a multiple of the size of `Base`.
///
/// # Panics
///
/// Panics if the length of the vector is not a multiple of the ratio of the sizes.
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    // Number of `Base` values packed into one `BaseArray`.
    let d = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(d),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        d
    );

    let new_len = vec.len() / d;

    // `vec.shrink_to_fit()` might make a capacity divisible by `d` more likely, but it
    // may itself reallocate, which would defeat the purpose.
    let cap = vec.capacity();

    // Most callers are expected to pass a vector that came out of `flatten_to_base`, whose
    // capacity is a multiple of `d`. Capacities can do strange things though, so the
    // other case is supported as well. The copying path below would also be correct for
    // a divisible capacity; it is just slower.
    if !cap.is_multiple_of(d) {
        // SAFETY:
        // - `Base` and `BaseArray` have the same alignment (checked above).
        // - The length is divisible by `d` (asserted above), so the first
        //   `new_len` `BaseArray`s cover exactly the initialized `Base` values of `vec`.
        let packed = unsafe { slice::from_raw_parts(vec.as_mut_ptr().cast::<BaseArray>(), new_len) };

        // Ideally the compiler would optimize this copy away, but it appears not to.
        return packed.to_vec();
    }

    // Suppress `vec`'s destructor: ownership of the allocation moves to the new vector.
    let mut source = ManuallyDrop::new(vec);
    let array_ptr = source.as_mut_ptr().cast::<BaseArray>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - Length and capacity are both divisible by `d` on this path, so dividing them by
    //   `d` describes the very same allocation in units of `BaseArray`.
    // - The first `new_len` elements of the new vector correspond to the initialized
    //   elements of the old vector.
    unsafe { Vec::from_raw_parts(array_ptr, new_len, cap / d) }
}
608
/// Returns `true` if `u` and `v` are relatively prime, i.e. `gcd(u, v) == 1`.
///
/// Returns `false` whenever either input is `0`.
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    let either_is_zero = u == 0 || v == 0;
    // Both even means they share the factor 2.
    let both_even = (u | v) & 1 == 0;
    if either_is_zero || both_even {
        return false;
    }

    // At most one of the inputs is even here, so powers of two can be discarded from
    // either side without affecting the answer.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }

        // Order the pair so that the subtraction below cannot underflow.
        let (smaller, larger) = if u > v { (v, u) } else { (u, v) };
        u = smaller;

        // This looks inefficient when one value is far larger than the other, but since
        // the trailing zeros of `v` are removed in every iteration it performs much
        // better than a first glance suggests.
        v = larger - smaller;
    }
    // The loop ended without either value reaching 1, so the gcd must be greater than 1.
    false
}
647
/// Inner loop of the deferred GCD algorithm.
///
/// See: <https://eprint.iacr.org/2020/972.pdf> for more information.
///
/// This is a miniature GCD computation that accumulates a linear transformation to be
/// applied to the larger numbers in the main loop. It only needs `u64`s, subtractions and
/// bit shifts, all of which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` should match the top bits including
/// zeroes if the original numbers have different sizes.
///
/// `a` and `b` are updated in place; the return value is the tuple `(f0, g0, f1, g1)` of
/// accumulated update factors.
#[inline]
pub fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Update factors. Before round 0: -1 < f0, g0, f1, g1 <= 1.
    let mut f0: i64 = 1;
    let mut g0: i64 = 0;
    let mut f1: i64 = 0;
    let mut g1: i64 = 1;

    // If -2^i < f0, g0, f1, g1 <= 2^i holds when a round starts, then
    // -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1} holds when it ends.
    for _round in 0..NUM_ROUNDS {
        if *a & 1 == 1 {
            if *a < *b {
                // Keep the larger value in `a`; the factors follow their values.
                core::mem::swap(a, b);
                core::mem::swap(&mut f0, &mut f1);
                core::mem::swap(&mut g0, &mut g1);
            }
            *a = (*a - *b) >> 1;
            f0 -= f1;
            g0 -= g1;
        } else {
            *a >>= 1;
        }
        f1 <<= 1;
        g1 <<= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence no overflow can happen as long as NUM_ROUNDS <= 62.
    // With NUM_ROUNDS <= 63 the only possible overflow is a variable that should equal
    // 2^{63}, which instead becomes -2^{63}.
    (f0, g0, f1, g1)
}
691
/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
///  - a: The value we want to invert. It must be < P.
///  - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` equal to `2^{2 * FIELD_BITS - 2} a^{-1} mod P` with
///   size `|v| <= 2^{2 * FIELD_BITS - 2}`.
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions break, the output is undefined.
#[inline]
pub fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Cofactors, starting with |u|, |v| <= 2^0.
    let mut u: i64 = 1;
    let mut v: i64 = 0;

    // Write a0 and P for the initial values of a and b. Then
    //   `a = u * a0 mod P`
    //   `b = v * a0 mod P`
    //   `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`
    //
    // Loop invariants at the start of round i:
    //   (1) `|u|, |v| <= 2^i`
    //   (2) `2^i * a = u * a0 mod P`
    //   (3) `2^i * b = v * a0 mod P`
    //   (4) `gcd(a, b) = 1`
    //   (5) `b` is odd
    //   (6) `len(a) + len(b) <= max(2 * FIELD_BITS - i, 1)`
    let rounds = 2 * FIELD_BITS - 2;
    for _ in 0..rounds {
        let a_is_odd = a & 1 == 1;
        if a_is_odd {
            if a < b {
                core::mem::swap(&mut a, &mut b);
                core::mem::swap(&mut u, &mut v);
            }
            // Since b <= a, subtracting cannot increase `len(a) + len(b)`.
            a -= b;
            // |u - v| <= |u| + |v| <= 2^{i + 1}, and by (2) and (3)
            // `2^i * (a - b) = (u - v) * a0 mod P`.
            u -= v;
        }
        // b is odd, so a is even at this point. Halving it lowers `len(a) + len(b)` by
        // one (unless `a = 0`, in which case `b = 1` and the sum of the lengths stays 1)
        // and turns (2) into `2^{i + 1} * a = u * a0 mod P`.
        a >>= 1;

        // Doubling v keeps |v| <= 2^{i + 1} and turns (3) into
        // `2^{i + 1} * b = v * a0 mod P`.
        v <<= 1;
    }

    // After the final round:
    //   |u|, |v| <= 2^{2 * FIELD_BITS - 2}, which fits an i64 for FIELD_BITS <= 32.
    //   `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    //   `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // Hence `b` must be `1`, and `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P` as desired.
    v
}
764
#[cfg(test)]
mod tests {
    //! Unit tests for the bit-reversal, chunking, logarithm and coprimality helpers.

    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::SmallRng;
    use rand::{
        RngExt,
        SeedableRng,
    };

    use super::*;

    /// Runs `apply_to_chunks` with the given chunk size and records every chunk
    /// handed to the callback, in order.
    fn collect_chunks<const CHUNK_SIZE: usize>(input: Vec<u8>) -> Vec<Vec<u8>> {
        let mut results: Vec<Vec<u8>> = Vec::new();
        apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
            results.push(chunk.to_vec());
        });
        results
    }

    /// Reference implementation of the bit-reversal permutation.
    ///
    /// Element `i` of `arr` is moved to the index obtained by reversing the low
    /// `log2(arr.len())` bits of `i`.
    ///
    /// # Panics
    /// Panics if `arr.len()` is not a power of two.
    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict_usize(n);
        // `n_power` is at most `usize::BITS - 1`, so the cast is lossless.
        let shift = usize::BITS - n_power as u32;

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            // For a single-element slice `shift == usize::BITS`; a plain `>>` would
            // overflow there, whereas the only valid destination is index 0.
            let dst = i.reverse_bits().checked_shr(shift).unwrap_or(0);
            out[dst] = Some(*v);
        }

        out.into_iter()
            .map(|x| x.expect("bit reversal is a permutation, so every slot is filled"))
            .collect()
    }

    /// `reverse_bits_len` reverses exactly the lowest `bit_len` bits.
    #[test]
    fn test_reverse_bits_len() {
        assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
        assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
        assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
        assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
        assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
    }

    /// `reverse_slice_index_bits` matches hand-written permutations for 4 and 256 elements.
    #[test]
    fn test_reverse_index_bits() {
        let mut arg = vec![10, 20, 30, 40];
        reverse_slice_index_bits(&mut arg);
        assert_eq!(arg, vec![10, 30, 20, 40]);

        let mut input256: Vec<u64> = (0..256).collect();
        #[rustfmt::skip]
        let output256: Vec<u64> = vec![
            0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
            0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
            0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
            0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
            0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
            0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
            0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
            0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
            0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
            0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
            0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
            0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
            0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
            0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
            0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
            0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
        ];
        reverse_slice_index_bits(&mut input256[..]);
        assert_eq!(input256, output256);
    }

    /// Input length is an exact multiple of the chunk size.
    #[test]
    fn test_apply_to_chunks_exact_fit() {
        let results = collect_chunks::<4>(vec![1, 2, 3, 4, 5, 6, 7, 8]);

        assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
    }

    /// A trailing partial chunk is still delivered to the callback.
    #[test]
    fn test_apply_to_chunks_with_remainder() {
        let results = collect_chunks::<3>(vec![1, 2, 3, 4, 5, 6, 7]);

        assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
    }

    /// Empty input never invokes the callback.
    #[test]
    fn test_apply_to_chunks_empty_input() {
        let results = collect_chunks::<4>(vec![]);

        assert!(results.is_empty());
    }

    /// Input shorter than the chunk size yields one short chunk.
    #[test]
    fn test_apply_to_chunks_single_chunk() {
        let results = collect_chunks::<10>(vec![1, 2, 3, 4, 5]);

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
    }

    /// A chunk size far larger than the input still yields one chunk with all bytes.
    #[test]
    fn test_apply_to_chunks_large_chunk_size() {
        let results = collect_chunks::<100>(vec![1, 2, 3, 4, 5, 6, 7, 8]);

        assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
    }

    /// Several full chunks are delivered in input order.
    #[test]
    fn test_apply_to_chunks_large_input() {
        let results = collect_chunks::<5>((1..=20).collect());

        assert_eq!(
            results,
            vec![
                vec![1, 2, 3, 4, 5],
                vec![6, 7, 8, 9, 10],
                vec![11, 12, 13, 14, 15],
                vec![16, 17, 18, 19, 20]
            ]
        );
    }

    /// `reverse_slice_index_bits` agrees with the naive reference on seeded random data.
    #[test]
    fn test_reverse_slice_index_bits_random() {
        let lengths = [32, 128, 1 << 16];
        // Fixed seed keeps the test deterministic.
        let mut rng = SmallRng::seed_from_u64(1);
        for _ in 0..32 {
            for &length in &lengths {
                let rand_list: Vec<u32> = (0..length).map(|_| rng.random()).collect();
                let expect = reverse_index_bits_naive(&rand_list);

                let mut actual = rand_list.clone();
                reverse_slice_index_bits(&mut actual);

                assert_eq!(actual, expect);
            }
        }
    }

    /// Exact powers of two, including the smallest and the largest representable one.
    #[test]
    fn test_log2_strict_usize_edge_cases() {
        assert_eq!(log2_strict_usize(1), 0);
        assert_eq!(log2_strict_usize(2), 1);
        assert_eq!(log2_strict_usize(1 << 18), 18);
        assert_eq!(log2_strict_usize(1 << 31), 31);
        assert_eq!(
            log2_strict_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    /// Zero is not a power of two and must be rejected.
    #[test]
    #[should_panic(expected = "Not a power of two")]
    fn test_log2_strict_usize_zero() {
        let _ = log2_strict_usize(0);
    }

    /// An arbitrary value with several set bits must be rejected.
    #[test]
    #[should_panic(expected = "Not a power of two")]
    fn test_log2_strict_usize_nonpower_2() {
        // The literal needs 63 bits, so it is typed as `u64` and cast: written as a bare
        // `usize` literal it would not compile on 32-bit targets. The truncated value
        // (0x5AE6D262) is still not a power of two.
        let _ = log2_strict_usize(0x78C3_41C6_5AE6_D262_u64 as usize);
    }

    /// All bits set is not a power of two and must be rejected.
    #[test]
    #[should_panic(expected = "Not a power of two")]
    fn test_log2_strict_usize_max() {
        let _ = log2_strict_usize(usize::MAX);
    }

    /// `log2_ceil_usize` is exact on powers of two and rounds up everywhere else.
    #[test]
    fn test_log2_ceil_usize_comprehensive() {
        // Zero and powers of 2
        assert_eq!(log2_ceil_usize(0), 0);
        assert_eq!(log2_ceil_usize(1), 0);
        assert_eq!(log2_ceil_usize(2), 1);
        assert_eq!(log2_ceil_usize(1 << 18), 18);
        assert_eq!(log2_ceil_usize(1 << 31), 31);
        assert_eq!(
            log2_ceil_usize(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );

        // Nonpowers; want to round up
        assert_eq!(log2_ceil_usize(3), 2);
        assert_eq!(log2_ceil_usize(0x14FE901B), 29);
        assert_eq!(
            log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
    }

    /// `relatively_prime_u64` across zero, equal, multiple, coprime and extreme inputs.
    #[test]
    fn test_relatively_prime_u64() {
        // Zero cases (should always return false)
        assert!(!relatively_prime_u64(0, 0));
        assert!(!relatively_prime_u64(10, 0));
        assert!(!relatively_prime_u64(0, 10));
        assert!(!relatively_prime_u64(0, 123456789));

        // Number with itself (if greater than 1, not relatively prime)
        assert!(relatively_prime_u64(1, 1));
        assert!(!relatively_prime_u64(10, 10));
        assert!(!relatively_prime_u64(99999, 99999));

        // Powers of 2 (always false since they share factor 2)
        assert!(!relatively_prime_u64(2, 4));
        assert!(!relatively_prime_u64(16, 32));
        assert!(!relatively_prime_u64(64, 128));
        assert!(!relatively_prime_u64(1024, 4096));

        // One number is a multiple of the other (always false)
        assert!(!relatively_prime_u64(5, 10));
        assert!(!relatively_prime_u64(12, 36));
        assert!(!relatively_prime_u64(15, 45));
        assert!(!relatively_prime_u64(100, 500));

        // Co-prime numbers (should be true)
        assert!(relatively_prime_u64(17, 31));
        assert!(relatively_prime_u64(97, 43));
        assert!(relatively_prime_u64(7919, 65537));
        assert!(relatively_prime_u64(15485863, 32452843));

        // Small prime numbers (should be true)
        assert!(relatively_prime_u64(13, 17));
        assert!(relatively_prime_u64(101, 103));
        assert!(relatively_prime_u64(1009, 1013));

        // Large numbers (every operand is even, so each pair shares the factor 2)
        assert!(!relatively_prime_u64(
            190266297176832000,
            10430732356495263744
        ));
        assert!(!relatively_prime_u64(
            2040134905096275968,
            5701159354248194048
        ));
        assert!(!relatively_prime_u64(
            16611311494648745984,
            7514969329383038976
        ));
        assert!(!relatively_prime_u64(
            14863931409971066880,
            7911906750992527360
        ));

        // Max values
        assert!(relatively_prime_u64(u64::MAX, 1));
        assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
        assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
    }
}