lib_q_stark_util/lib.rs
1//! Various simple utilities.
2
3#![no_std]
4
// Reject the unsupported `parallel` + wasm32 feature combination at compile time, so the
// misconfiguration is reported when building rather than surfacing later.
#[cfg(all(feature = "parallel", target_arch = "wasm32"))]
compile_error!(
    "parallel feature is not supported on wasm32; do not enable 'parallel' for WASM builds"
);
9
10extern crate alloc;
11
12use alloc::slice;
13use alloc::string::String;
14use alloc::vec::Vec;
15use core::any::type_name;
16use core::hint::unreachable_unchecked;
17use core::mem::{
18 ManuallyDrop,
19 MaybeUninit,
20};
21use core::{
22 iter,
23 mem,
24};
25
26use crate::transpose::transpose_in_place_square;
27
28pub mod array_serialization;
29pub mod linear_map;
30pub mod transpose;
31pub mod zip_eq;
32
/// Returns `ceil(log2(n))`, i.e. the smallest `k` such that `2^k >= n`.
///
/// Both `n == 0` and `n == 1` map to `0`.
#[must_use]
pub const fn log2_ceil_usize(n: usize) -> usize {
    // For n >= 2, `n - 1` has exactly `ceil(log2(n))` significant bits. Saturating keeps
    // `n == 0` at zero instead of wrapping around to `usize::MAX`.
    let significant_bits = usize::BITS - n.saturating_sub(1).leading_zeros();
    significant_bits as usize
}
38
/// Returns `ceil(log2(n))` for a `u64`, i.e. the smallest `k` such that `2^k >= n`.
///
/// This is the `u64` counterpart of [`log2_ceil_usize`]. Both `n == 0` and `n == 1` map to `0`.
/// Like its `usize` sibling it is a `const fn`, so it can be used in constant expressions.
#[must_use]
pub const fn log2_ceil_u64(n: u64) -> u64 {
    // `u32 -> u64` is lossless; `as` is used because `Into::into` is not callable in `const fn`.
    (u64::BITS - n.saturating_sub(1).leading_zeros()) as u64
}
43
44/// Computes `log_2(n)`
45///
46/// # Panics
47/// Panics if `n` is not a power of two.
48#[must_use]
49#[inline]
50pub fn log2_strict_usize(n: usize) -> usize {
51 let res = n.trailing_zeros();
52 assert_eq!(n.wrapping_shr(res), 1, "Not a power of two: {n}");
53 // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
54 // `1 << res` and vice versa.
55 unsafe {
56 assume(n == 1 << res);
57 }
58 res as usize
59}
60
/// Builds the identity array `[0, 1, ..., N - 1]`; usable in constant expressions.
#[must_use]
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    // `for` loops are not allowed in `const fn`, so fill the slots by counting down.
    let mut k = N;
    while k > 0 {
        k -= 1;
        out[k] = k;
    }
    out
}
72
73#[inline]
74pub const fn reverse_bits(x: usize, n: usize) -> usize {
75 // Assert that n is a power of 2
76 debug_assert!(n.is_power_of_two());
77 reverse_bits_len(x, n.trailing_zeros() as usize)
78}
79
/// Reverses the order of the low `bit_len` bits of `x`; bits of `x` at positions `>= bit_len`
/// are discarded.
///
/// `bit_len` is expected to be at most `usize::BITS` (a larger value underflows the shift
/// computation).
#[inline]
pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize {
    // Reverse all `usize::BITS` bits, then shift the interesting ones back down.
    //
    // For `bit_len == 0` the shift amount equals `usize::BITS`, which a plain `>>` rejects as
    // overflow. `overflowing_shr` masks the amount instead (to `0` here), so for the only
    // meaningful input, `x == 0`, the result is still `0`.
    let shift = usize::BITS - bit_len as u32;
    let (reversed, _overflowed) = x.reverse_bits().overflowing_shr(shift);
    reversed
}
90
// Lookup table of 6-bit reverses: entry `i` is `i` with its low 6 bits in reverse order
// (e.g. entry 1 = 0b000001 holds 0o40 = 0b100000). Octal literals make the pattern visible:
// each octal digit is 3 bits, so a 6-bit reversal swaps the two digits and reverses the bits
// within each.
// NB: 2^6=64 bytes is a cache line. A smaller table wastes cache space.
// Only compiled for non-aarch64 targets; the aarch64 version of
// `reverse_slice_index_bits_small` reverses bits directly instead of using this table.
#[cfg(not(target_arch = "aarch64"))]
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];
105
// Byte-size thresholds that `reverse_slice_index_bits` uses to pick an algorithm:
// - if the whole slice occupies at most `SMALL_ARR_SIZE` bytes, or each element is at least
//   `BIG_T_SIZE` bytes, the simple swap loop is used;
// - otherwise the blocked (chunk-reverse / transpose / chunk-reverse) algorithm is used.
//
// The blocked path debug-asserts `n >= 4`. That holds because there
// `size_of::<F>() < BIG_T_SIZE` and `size_of::<F>() * n > SMALL_ARR_SIZE`, so
// `n > SMALL_ARR_SIZE / BIG_T_SIZE`. The required `SMALL_ARR_SIZE >= 4 * BIG_T_SIZE` is
// therefore enforced at compile time below rather than only stated in a comment.
const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;
const _: () = assert!(
    SMALL_ARR_SIZE >= 4 * BIG_T_SIZE,
    "SMALL_ARR_SIZE must be at least 4 * BIG_T_SIZE"
);
109
/// Permutes `vals` in place so that each index is mapped to its reverse in binary: the element
/// at index `i` ends up at index `reverse_bits(i, vals.len())`, and vice versa.
///
/// If the whole slice fits in fast cache, then the trivial algorithm is cache friendly. Also, if
/// `F` is really big, then the trivial algorithm is cache-friendly, no matter the size of the
/// slice. Otherwise a blocked algorithm (reverse chunk order, transpose, reverse chunk order) is
/// used; see the comments in the body.
///
/// An empty slice is left unchanged.
///
/// # Panics
///
/// Panics if `vals.len()` is non-zero and not a power of two.
pub fn reverse_slice_index_bits<F>(vals: &mut [F])
where
    F: Copy + Send + Sync,
{
    let n = vals.len();
    if n == 0 {
        return;
    }
    // Panics unless `n` is a power of two.
    let log_n = log2_strict_usize(n);

    // If the whole slice fits in fast cache (at most `SMALL_ARR_SIZE` bytes), then the trivial
    // algorithm is cache friendly. Also, if `F` is really big (at least `BIG_T_SIZE` bytes), then
    // the trivial algorithm is cache-friendly, no matter the size of the slice.
    if core::mem::size_of::<F>() << log_n <= SMALL_ARR_SIZE ||
        core::mem::size_of::<F>() >= BIG_T_SIZE
    {
        reverse_slice_index_bits_small(vals, log_n);
    } else {
        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.

        // Algorithm:
        //
        // Treat `vals` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `log_n`
        // is even, i.e., `n` is a square number.) To perform bit-order reversal we:
        // 1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
        //    basically a series of large `memcpy`s.)
        // 2. Transpose the matrix.
        // 3. Bit-reverse the order of the rows.
        //
        // This is equivalent to, for every index `0 <= i < n`:
        // 1. bit-reversing `i[log_n / 2..log_n]`,
        // 2. swapping `i[0..log_n / 2]` and `i[log_n / 2..log_n]`,
        // 3. bit-reversing `i[log_n / 2..log_n]`.
        //
        // If `log_n` is odd, i.e., `n` is not a square number, then the above procedure requires
        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(log_n / 2)..log_n` of
        // the index (shuffling `floor(log_n / 2)` chunks of length `ceil(log_n / 2)`). At step 2,
        // we perform _two_ transposes. We treat `vals` as two matrices, one where the middle bit
        // of the index is `0` and another where the middle bit is `1`; we transpose each
        // individually.

        let lb_num_chunks = log_n >> 1;
        let lb_chunk_size = log_n - lb_num_chunks;
        // SAFETY: `vals.len() == n == 1 << log_n` and `log_n == lb_num_chunks + lb_chunk_size`,
        // which is exactly the length precondition of `reverse_slice_index_bits_chunks`.
        // NOTE(review): the safety contract of `transpose_in_place_square` is defined in the
        // `transpose` module, which is not visible here; confirm that both calls below
        // (including the offset sub-slice for odd `log_n`) satisfy it.
        unsafe {
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(vals, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // `vals` cannot be interpreted as a square matrix. We instead interpret it as a
                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
                // `vals` by `1 << lb_num_chunks`, effectively adding that to every index.
                let vals_with_offset = &mut vals[1 << lb_num_chunks..];
                transpose_in_place_square(vals_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_slice_index_bits_chunks(vals, lb_num_chunks, lb_chunk_size);
        }
    }
}
171
// Both (cfg-selected) versions of `reverse_slice_index_bits_small` below permute `vals` in place
// with swaps. They are semantically equivalent to building
//     for i in 0..vals.len() {
//         result.push(vals[reverse_bits_len(i, lb_n)]);
//     }
// and copying `result` back into `vals`, where `reverse_bits_len(i, lb_n)` computes the
// `lb_n`-bit reverse of `i`. The complications are there to guide the compiler to generate
// optimal assembly.
178
/// Bit-reverses the indices of `vals` in place (portable, lookup-table version).
///
/// Expects `vals.len() == 1 << lb_n`, as passed by `reverse_slice_index_bits`.
#[cfg(not(target_arch = "aarch64"))]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    if lb_n <= 6 {
        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
        let dst_shr_amt = 6 - lb_n as u32;
        // `src` is used both as a table index and as a slice index, hence the indexed loop.
        #[allow(clippy::needless_range_loop)]
        for src in 0..vals.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            // Swap each pair only once (from its smaller index); fixed points are left alone.
            if src < dst {
                vals.swap(src, dst);
            }
        }
    } else {
        // Split each index into its low 6 bits (`src_lo`) and the remaining high bits
        // (`src_chunk`). The table reverse of `src_lo` becomes the high part of `dst`, and the
        // `(lb_n - 6)`-bit reverse of `src_chunk` becomes the low part, computed once per chunk.
        // LLVM does not know that it does not need to reverse src at each iteration (which is
        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely
        // and the high bits of dst are dependent only on the low bits of src.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(vals.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            // `src_lo` is used both as a table index and in index arithmetic.
            #[allow(clippy::needless_range_loop)]
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                // Swap each pair only once (from its smaller index).
                if src < dst {
                    vals.swap(src, dst);
                }
            }
        }
    }
}
212
/// Bit-reverses the indices of `vals` in place (aarch64 version).
///
/// Expects `vals.len() == 1 << lb_n`. AArch64 can reverse bits in one instruction, so the
/// direct per-index computation works best here and no lookup table is needed.
#[cfg(target_arch = "aarch64")]
fn reverse_slice_index_bits_small<F>(vals: &mut [F], lb_n: usize) {
    let len = vals.len();
    (0..len).for_each(|i| {
        let j = i.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        // Visit each pair once, from its smaller index.
        if j > i {
            vals.swap(i, j);
        }
    });
}
223
224/// Split `arr` chunks and bit-reverse the order of the chunks. There are `1 << lb_num_chunks`
225/// chunks, each of length `1 << lb_chunk_size`.
226/// SAFETY: ensure that `arr.len() == 1 << lb_num_chunks + lb_chunk_size`.
227unsafe fn reverse_slice_index_bits_chunks<F>(
228 vals: &mut [F],
229 lb_num_chunks: usize,
230 lb_chunk_size: usize,
231) {
232 for i in 0..1usize << lb_num_chunks {
233 // `wrapping_shr` handles the silly case when `lb_num_chunks == 0`.
234 let j = i
235 .reverse_bits()
236 .wrapping_shr(usize::BITS - lb_num_chunks as u32);
237 if i < j {
238 unsafe {
239 core::ptr::swap_nonoverlapping(
240 vals.get_unchecked_mut(i << lb_chunk_size),
241 vals.get_unchecked_mut(j << lb_chunk_size),
242 1 << lb_chunk_size,
243 );
244 }
245 }
246 }
247}
248
/// Lets the optimizer assume that the predicate `p` is `true`.
///
/// Debug builds check `p` with `debug_assert!`; release builds do not check it at all.
///
/// # Safety
///
/// `p` must be `true`. If it is not, the behavior is undefined (debug builds panic first).
#[inline(always)]
pub unsafe fn assume(p: bool) {
    debug_assert!(p);
    if p {
        return;
    }
    // SAFETY: the caller guarantees `p`, so this point can never be reached.
    unsafe { unreachable_unchecked() }
}
263
/// Hint intended to make the compiler emit a real branch at this point. Example:
///
/// ```no_run
/// let x = 100;
/// if x > 20 {
///     println!("x is big!");
///     lib_q_stark_util::branch_hint();
/// } else {
///     println!("x is small!");
/// }
/// ```
///
/// On architectures with stable inline assembly this expands to an empty `asm!` statement;
/// elsewhere it compiles to nothing. This function has no semantics. It is a hint only.
#[inline(always)]
pub fn branch_hint() {
    // The cfg list mirrors the architectures with stable inline assembly; see the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    //
    // SAFETY: the assembly is empty, and its options declare that it touches no memory, no
    // stack and no flags.
    #[cfg(any(
        target_arch = "x86_64",
        target_arch = "x86",
        target_arch = "riscv64",
        target_arch = "riscv32",
        target_arch = "arm",
        target_arch = "aarch64",
    ))]
    unsafe {
        core::arch::asm!("", options(nostack, nomem, preserves_flags));
    }
}
294
/// Returns the name of `T` with every crate and module prefix removed, e.g.
/// `alloc::vec::Vec<core::option::Option<u8>>` becomes `Vec<Option<u8>>`.
///
/// Everything that is not part of a path is preserved, including reference `&`, tuple
/// parentheses, array brackets and `; N` lengths, and `dyn ` prefixes. Only the leading `a::b::`
/// of each path is stripped. (Previously, any text before a path inside a `<`/`>`/`,`-delimited
/// piece was dropped, so `&alloc::string::String` became `String`.)
///
/// The output follows [`core::any::type_name`], whose exact format is not guaranteed to be
/// stable, so use this for diagnostics only.
pub fn pretty_name<T>() -> String {
    // Characters that may appear inside a (possibly qualified) path like `alloc::vec::Vec`.
    fn is_path_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == ':'
    }

    let name = type_name::<T>();
    let mut result = String::with_capacity(name.len());
    // Every piece is a (possibly empty) path optionally terminated by one non-path character.
    for piece in name.split_inclusive(|c: char| !is_path_char(c)) {
        let (path, delimiter) = match piece.char_indices().next_back() {
            Some((idx, c)) if !is_path_char(c) => piece.split_at(idx),
            _ => (piece, ""),
        };
        // `rsplit` always yields at least one item, so the fallback is never taken.
        result.push_str(path.rsplit("::").next().unwrap_or(path));
        result.push_str(delimiter);
    }
    result
}
305
/// Reads up to `BUFLEN` items from `iter` into a stack buffer, similar to the nightly
/// `core::iter::Iterator::next_chunk()`.
///
/// Returns the buffer together with the number `n` of leading slots that were written; slots
/// `n..BUFLEN` stay uninitialized. `iter.next()` is called at most `BUFLEN` times and stops at
/// the first `None`, so `iter` can keep being used afterwards.
#[inline]
fn iter_next_chunk_erased<const BUFLEN: usize, I: Iterator>(
    iter: &mut I,
) -> ([MaybeUninit<I::Item>; BUFLEN], usize)
where
    I::Item: Copy,
{
    let mut buf = [const { MaybeUninit::<I::Item>::uninit() }; BUFLEN];
    let mut filled = 0;
    // `zip` asks the buffer for a slot before pulling from `iter`, so once every slot is used
    // `iter` is not advanced any further.
    for (slot, item) in buf.iter_mut().zip(iter.by_ref()) {
        slot.write(item);
        filled += 1;
    }
    (buf, filled)
}
335
/// Views a slice of `MaybeUninit<T>` as a slice of `T`.
///
/// # Safety
///
/// Every element of `slice` must be initialized. Calling this while any element is still
/// uninitialized causes undefined behavior.
///
/// Mirrors the unstable
/// <https://doc.rust-lang.org/std/primitive.slice.html#method.assume_init_ref>; remove this
/// once that is stabilized.
#[inline(always)]
pub const unsafe fn assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let len = slice.len();
    let data = slice.as_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so `data` points to `len`
    // correctly laid out `T`s in memory borrowed for the returned lifetime, and the caller
    // guarantees they are initialized.
    unsafe { core::slice::from_raw_parts(data, len) }
}
355
356/// Split an iterator into small arrays and apply `func` to each.
357///
358/// Repeatedly read `BUFLEN` elements from `input` into an array and
359/// pass the array to `func` as a slice. If less than `BUFLEN`
360/// elements are remaining, that smaller slice is passed to `func` (if
361/// it is non-empty) and the function returns.
362#[inline]
363pub fn apply_to_chunks<const BUFLEN: usize, I, H>(input: I, mut func: H)
364where
365 I: IntoIterator<Item = u8>,
366 H: FnMut(&[I::Item]),
367{
368 let mut iter = input.into_iter();
369 loop {
370 let (buf, n) = iter_next_chunk_erased::<BUFLEN, _>(&mut iter);
371 if n == 0 {
372 break;
373 }
374 func(unsafe { assume_init_ref(buf.get_unchecked(..n)) });
375 }
376}
377
378/// Pulls `N` items from `iter` and returns them as an array. If the iterator
379/// yields fewer than `N` items (but more than `0`), pads by the given default value.
380///
381/// Since the iterator is passed as a mutable reference and this function calls
382/// `next` at most `N` times, the iterator can still be used afterwards to
383/// retrieve the remaining items.
384///
385/// If `iter.next()` panics, all items already yielded by the iterator are
386/// dropped.
387#[inline]
388fn iter_next_chunk_padded<T: Copy, const N: usize>(
389 iter: &mut impl Iterator<Item = T>,
390 default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
391) -> Option<[T; N]> {
392 let (mut arr, n) = iter_next_chunk_erased::<N, _>(iter);
393 (n != 0).then(|| {
394 // Fill the rest of the array with default values.
395 arr[n..].fill(MaybeUninit::new(default));
396 unsafe { mem::transmute_copy::<_, [T; N]>(&arr) }
397 })
398}
399
400/// Returns an iterator over `N` elements of the iterator at a time.
401///
402/// The chunks do not overlap. If `N` does not divide the length of the
403/// iterator, then the last `N-1` elements will be padded with the given default value.
404///
405/// This is essentially a copy pasted version of the nightly `array_chunks` function.
406/// <https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.array_chunks>
407/// Once that is stabilized this and the functions above it should be removed.
408#[inline]
409pub fn iter_array_chunks_padded<T: Copy, const N: usize>(
410 iter: impl IntoIterator<Item = T>,
411 default: T, // Needed due to [T; M] not always implementing Default. Can probably be dropped if const generics stabilize.
412) -> impl Iterator<Item = [T; N]> {
413 let mut iter = iter.into_iter();
414 iter::from_fn(move || iter_next_chunk_padded(&mut iter, default))
415}
416
/// Reinterprets a slice of `BaseArray` elements as a flat slice of `Base` elements.
///
/// Typical uses are converting `&[[F; N]]` to `&[F]`, or `&[A]` to `&[F]` where `A` has the
/// same size, alignment and memory layout as `[F; N]` for some `N`. The result has
/// `buf.len() * (size_of::<BaseArray>() / size_of::<Base>())` elements.
///
/// # Safety
///
/// `BaseArray` must have the same memory layout as `[Base; N]` for some `N`. Rust guarantees
/// that array elements are contiguous and that an array has the alignment of its element type,
/// so this in particular requires `BaseArray` to have the same alignment as `Base`.
///
/// # Panics
///
/// Mismatched alignments, or a `size_of::<BaseArray>()` that is not a multiple of
/// `size_of::<Base>()`, are rejected by a `const` assertion when the function is instantiated.
/// A zero-sized `Base` makes the size ratio a division by zero.
#[inline]
pub const unsafe fn as_base_slice<Base, BaseArray>(buf: &[BaseArray]) -> &[Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let len = buf.len() * ratio;
    // SAFETY: the caller guarantees each `BaseArray` is laid out as `ratio` consecutive `Base`
    // values with the same alignment, so the `buf` memory holds exactly `len` valid `Base`s,
    // borrowed for the same lifetime.
    unsafe { slice::from_raw_parts(buf.as_ptr().cast::<Base>(), len) }
}
445
/// Reinterprets a mutable slice of `BaseArray` elements as a flat mutable slice of `Base`
/// elements.
///
/// Typical uses are converting `&mut [[F; N]]` to `&mut [F]`, or `&mut [A]` to `&mut [F]` where
/// `A` has the same size, alignment and memory layout as `[F; N]` for some `N`. The result has
/// `buf.len() * (size_of::<BaseArray>() / size_of::<Base>())` elements.
///
/// # Safety
///
/// `BaseArray` must have the same memory layout as `[Base; N]` for some `N`. Rust guarantees
/// that array elements are contiguous and that an array has the alignment of its element type,
/// so this in particular requires `BaseArray` to have the same alignment as `Base`.
///
/// # Panics
///
/// Mismatched alignments, or a `size_of::<BaseArray>()` that is not a multiple of
/// `size_of::<Base>()`, are rejected by a `const` assertion when the function is instantiated.
/// A zero-sized `Base` makes the size ratio a division by zero.
#[inline]
pub const unsafe fn as_base_slice_mut<Base, BaseArray>(buf: &mut [BaseArray]) -> &mut [Base] {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    let len = buf.len() * ratio;
    // SAFETY: the caller guarantees each `BaseArray` is laid out as `ratio` consecutive `Base`
    // values with the same alignment. The pointer comes from the exclusive borrow `buf`, so the
    // returned slice is the only live access to those `len` elements.
    unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<Base>(), len) }
}
474
/// Converts a vector of `BaseArray` elements into a vector of `Base` elements, reusing the
/// allocation (no reallocation or copy).
///
/// Typical uses are `Vec<[F; N]>` to `Vec<F>`, `Vec<A>` to `Vec<F>` where `A` has the same size,
/// alignment and memory layout as `[F; N]`, or `Vec<u32>` to `Vec<F>` for a 32-bit field `F`
/// (likewise `Vec<u64>` for a 64-bit field).
///
/// # Safety
///
/// `BaseArray` must have the same memory layout as `[Base; N]` for some `N`. Rust guarantees
/// that array elements are contiguous and that an array has the alignment of its element type,
/// so this in particular requires `BaseArray` to have the same alignment as `Base`.
///
/// # Panics
///
/// Mismatched alignments, or a `size_of::<BaseArray>()` that is not a multiple of
/// `size_of::<Base>()`, are rejected by a `const` assertion when the function is instantiated.
#[inline]
pub unsafe fn flatten_to_base<Base, BaseArray>(vec: Vec<BaseArray>) -> Vec<Base> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let ratio = size_of::<BaseArray>() / size_of::<Base>();
    // Take ownership of the allocation so that `vec`'s destructor never frees it.
    let mut owned = ManuallyDrop::new(vec);
    // Every `BaseArray` is `ratio` `Base`s, so length and capacity both scale by `ratio`.
    let (len, cap) = (owned.len() * ratio, owned.capacity() * ratio);
    let ptr = owned.as_mut_ptr().cast::<Base>();

    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - `size_of::<BaseArray>() == ratio * size_of::<Base>()`, so `cap` `Base`s span exactly the
    //   original allocation, and the first `len` of them are the initialized contents of the
    //   first `len / ratio` `BaseArray`s.
    // - The old vector is never dropped, so the allocation has a single owner.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
523
/// Converts a vector of `Base` elements into a vector of `BaseArray` elements, reusing the
/// allocation whenever possible.
///
/// This is the inverse of `flatten_to_base`. Unlike `flatten_to_base`, it cannot always avoid
/// a copy: the allocation can only be reused when the capacity is a multiple of the size ratio
/// `d = size_of::<BaseArray>() / size_of::<Base>()`, and nothing guarantees that. Otherwise the
/// elements are cloned into a fresh vector.
///
/// # Safety
///
/// `BaseArray` must have the same memory layout as `[Base; N]` for some `N`. Rust guarantees
/// that array elements are contiguous and that an array has the alignment of its element type,
/// so this in particular requires `BaseArray` to have the same alignment as `Base`.
///
/// # Panics
///
/// Mismatched alignments, or a `size_of::<BaseArray>()` that is not a multiple of
/// `size_of::<Base>()`, are rejected by a `const` assertion when the function is instantiated.
/// Panics at runtime if the length of the vector is not a multiple of `d`.
#[inline]
pub unsafe fn reconstitute_from_base<Base, BaseArray: Clone>(mut vec: Vec<Base>) -> Vec<BaseArray> {
    const {
        assert!(align_of::<Base>() == align_of::<BaseArray>());
        assert!(size_of::<BaseArray>().is_multiple_of(size_of::<Base>()));
    }

    let ratio = size_of::<BaseArray>() / size_of::<Base>();

    assert!(
        vec.len().is_multiple_of(ratio),
        "Vector length (got {}) must be a multiple of the extension field dimension ({}).",
        vec.len(),
        ratio
    );

    let new_len = vec.len() / ratio;
    // Calling `shrink_to_fit` might make the capacity divisible, but could itself reallocate,
    // which would defeat the purpose.
    let cap = vec.capacity();

    if !cap.is_multiple_of(ratio) {
        // Slow path: the allocation cannot be handed over as-is, so copy through a slice view.
        let data = vec.as_mut_ptr().cast::<BaseArray>();
        // SAFETY:
        // - `Base` and `BaseArray` have the same alignment (checked above).
        // - `vec.len() == new_len * ratio`, so the first `new_len` `BaseArray`s cover exactly the
        //   initialized `Base` elements.
        let grouped = unsafe { slice::from_raw_parts(data, new_len) };
        // Ideally the compiler would elide this copy, but it appears not to.
        return grouped.to_vec();
    }

    // Fast path (expected for vectors produced by `flatten_to_base`): reuse the allocation.
    let mut owned = ManuallyDrop::new(vec);
    let ptr = owned.as_mut_ptr().cast::<BaseArray>();
    // SAFETY:
    // - `Base` and `BaseArray` have the same alignment (checked above).
    // - Length and capacity are both divisible by `ratio`, so `cap / ratio` `BaseArray`s span
    //   exactly the original allocation, and the first `new_len` are initialized.
    // - The old vector is never dropped, so the allocation has a single owner.
    unsafe { Vec::from_raw_parts(ptr, new_len, cap / ratio) }
}
608
/// Returns `true` iff `u` and `v` are both non-zero and `gcd(u, v) == 1`.
///
/// A zero argument always yields `false`, even for `(0, 1)` where mathematically
/// `gcd(0, 1) == 1`. Uses a binary (Stein) GCD that returns as soon as an operand reaches `1`.
#[inline(always)]
pub const fn relatively_prime_u64(mut u: u64, mut v: u64) -> bool {
    // Zero inputs are rejected outright.
    if u == 0 || v == 0 {
        return false;
    }
    // Both even: 2 divides the gcd.
    if (u | v) & 1 == 0 {
        return false;
    }

    // From here on, factors of 2 can never contribute to the gcd, so strip them freely.
    u >>= u.trailing_zeros();
    if u == 1 {
        return true;
    }

    while v != 0 {
        v >>= v.trailing_zeros();
        if v == 1 {
            return true;
        }
        // Keep `u` as the smaller of the two odd values.
        if v < u {
            (u, v) = (v, u);
        }
        // The difference of two odd values is even (or zero). Removing trailing zeros on every
        // pass keeps this fast even when `v` is much larger than `u`.
        v -= u;
    }
    // `v` reached 0, so `u` now holds the gcd; neither operand ever became 1, so the gcd
    // exceeds 1.
    false
}
647
/// Inner loop of the deferred GCD algorithm.
///
/// See: <https://eprint.iacr.org/2020/972.pdf> for more information.
///
/// This is basically a mini GCD algorithm which builds up a transformation to apply to the larger
/// numbers in the main loop. The key point is that this small loop only uses u64s, subtractions and
/// bit shifts, which are very fast operations.
///
/// The bottom `NUM_ROUNDS` bits of `a` and `b` should match the bottom `NUM_ROUNDS` bits of
/// the corresponding big-ints and the top `NUM_ROUNDS + 2` should match the top bits including
/// zeroes if the original numbers have different sizes.
///
/// Runs exactly `NUM_ROUNDS` rounds, updating `a` and `b` in place, and returns the accumulated
/// update factors `(f0, g0, f1, g1)`. Each round either halves an even `a`, or (after swapping
/// so that `a >= b`) replaces `a` by `(a - b) >> 1`. The factors follow the same swaps and
/// subtractions, and `f1`/`g1` are doubled every round.
///
/// NOTE(review): `(a - b) >> 1` only drops a zero bit if `b` is odd; presumably callers keep
/// `b` odd. Confirm against the caller. The overflow analysis at the end assumes
/// `NUM_ROUNDS <= 62` (or 63 with one edge case); nothing enforces this at compile time.
#[inline]
pub fn gcd_inner<const NUM_ROUNDS: usize>(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) {
    // Initialise update factors.
    // At the start of round 0: -1 < f0, g0, f1, g1 <= 1
    let (mut f0, mut g0, mut f1, mut g1) = (1, 0, 0, 1);

    // If at the start of a round: -2^i < f0, g0, f1, g1 <= 2^i
    // Then, at the end of the round: -2^{i + 1} < f0, g0, f1, g1 <= 2^{i + 1}
    for _ in 0..NUM_ROUNDS {
        if *a & 1 == 0 {
            *a >>= 1;
        } else {
            // Compares the pointed-to values; swap so that `a >= b` before subtracting.
            if a < b {
                core::mem::swap(a, b);
                (f0, f1) = (f1, f0);
                (g0, g1) = (g1, g0);
            }
            *a -= *b;
            *a >>= 1;
            f0 -= f1;
            g0 -= g1;
        }
        f1 <<= 1;
        g1 <<= 1;
    }

    // -2^NUM_ROUNDS < f0, g0, f1, g1 <= 2^NUM_ROUNDS
    // Hence provided NUM_ROUNDS <= 62, we will not get any overflow.
    // Additionally, if NUM_ROUNDS <= 63, then the only source of overflow will be
    // if a variable is meant to equal 2^{63} in which case it will overflow to -2^{63}.
    (f0, g0, f1, g1)
}
691
/// Inverts elements inside the prime field `F_P` with `P < 2^FIELD_BITS`.
///
/// Arguments:
/// - a: The value we want to invert. It must be < P.
/// - b: The value of the prime `P > 2`.
///
/// Output:
/// - A `64-bit` signed integer `v` congruent to `2^{2 * FIELD_BITS - 2} a^{-1} mod P`, with
///   `|v| <= 2^{2 * FIELD_BITS - 2}` (per the loop invariants below).
///
/// It is up to the user to ensure that `b` is an odd prime with at most `FIELD_BITS` bits and
/// `a < b`. If either of these assumptions break, the output is undefined.
///
/// NOTE(review): only `FIELD_BITS <= 32` is checked at compile time. `FIELD_BITS == 0` makes
/// the loop bound `2 * FIELD_BITS - 2` underflow, and `P > 2` implies `FIELD_BITS >= 2` anyway.
#[inline]
pub fn gcd_inversion_prime_field_32<const FIELD_BITS: u32>(mut a: u32, mut b: u32) -> i64 {
    const {
        assert!(FIELD_BITS <= 32);
    }
    // `b` (the prime) must fit in `FIELD_BITS` bits.
    debug_assert!(((1_u64 << FIELD_BITS) - 1) >= b as u64);

    // Initialise u, v. Note that |u|, |v| <= 2^0
    let (mut u, mut v) = (1_i64, 0_i64);

    // Let a0 and P denote the initial values of a and b. Observe:
    // `a = u * a0 mod P`
    // `b = v * a0 mod P`
    // `len(a) + len(b) <= 2 * len(P) <= 2 * FIELD_BITS`

    for _ in 0..(2 * FIELD_BITS - 2) {
        // Assume at the start of the loop i (with n = 2 * FIELD_BITS):
        // (1) `|u|, |v| <= 2^{i}`
        // (2) `2^i * a = u * a0 mod P`
        // (3) `2^i * b = v * a0 mod P`
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i, 1)`

        if a & 1 != 0 {
            if a < b {
                (a, b) = (b, a);
                (u, v) = (v, u);
            }
            // As b < a, this subtraction cannot increase `len(a) + len(b)`
            a -= b;
            // Observe |u'| = |u - v| <= |u| + |v| <= 2^{i + 1}
            u -= v;

            // As (1) and (2) hold, we have
            // `2^i a' = 2^i * (a - b) = (u - v) * a0 mod P = u' * a0 mod P`
        }
        // As b is odd, a must now be even.
        // This reduces `len(a) + len(b)` by 1 (unless `a = 0` in which case `b = 1` and the sum of the lengths is always 1)
        a >>= 1;

        // Observe |v'| = 2|v| <= 2^{i + 1}
        v <<= 1;

        // Thus at the end of loop i:
        // (1) `|u|, |v| <= 2^{i + 1}`
        // (2) `2^{i + 1} * a = u * a0 mod P` (As we have halved a)
        // (3) `2^{i + 1} * b = v * a0 mod P` (As we have doubled v)
        // (4) `gcd(a, b) = 1`
        // (5) `b` is odd.
        // (6) `len(a) + len(b) <= max(n - i - 1, 1)`
    }

    // After the loops, we see that:
    // |u|, |v| <= 2^{2 * FIELD_BITS - 2}: Hence for FIELD_BITS <= 32 we will not overflow an i64.
    // `2^{2 * FIELD_BITS - 2} * b = v * a0 mod P`
    // `len(a) + len(b) <= 2` with `gcd(a, b) = 1` and `b` odd.
    // This implies that `b` must be `1` and so `v = 2^{2 * FIELD_BITS - 2} a0^{-1} mod P` as desired.
    v
}
764
765#[cfg(test)]
766mod tests {
767 use alloc::vec;
768 use alloc::vec::Vec;
769
770 use rand::rngs::SmallRng;
771 use rand::{
772 RngExt,
773 SeedableRng,
774 };
775
776 use super::*;
777
778 #[test]
779 fn test_reverse_bits_len() {
780 assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000);
781 assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000);
782 assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001);
783 assert_eq!(reverse_bits_len(0b00000, 5), 0b00000);
784 assert_eq!(reverse_bits_len(0b01011, 5), 0b11010);
785 }
786
787 #[test]
788 fn test_reverse_index_bits() {
789 let mut arg = vec![10, 20, 30, 40];
790 reverse_slice_index_bits(&mut arg);
791 assert_eq!(arg, vec![10, 30, 20, 40]);
792
793 let mut input256: Vec<u64> = (0..256).collect();
794 #[rustfmt::skip]
795 let output256: Vec<u64> = vec![
796 0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
797 0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
798 0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
799 0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
800 0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
801 0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
802 0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
803 0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
804 0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
805 0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
806 0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
807 0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
808 0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
809 0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
810 0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
811 0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
812 ];
813 reverse_slice_index_bits(&mut input256[..]);
814 assert_eq!(input256, output256);
815 }
816
817 #[test]
818 fn test_apply_to_chunks_exact_fit() {
819 const CHUNK_SIZE: usize = 4;
820 let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
821 let mut results: Vec<Vec<u8>> = Vec::new();
822
823 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
824 results.push(chunk.to_vec());
825 });
826
827 assert_eq!(results, vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]]);
828 }
829
830 #[test]
831 fn test_apply_to_chunks_with_remainder() {
832 const CHUNK_SIZE: usize = 3;
833 let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7];
834 let mut results: Vec<Vec<u8>> = Vec::new();
835
836 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
837 results.push(chunk.to_vec());
838 });
839
840 assert_eq!(results, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7]]);
841 }
842
843 #[test]
844 fn test_apply_to_chunks_empty_input() {
845 const CHUNK_SIZE: usize = 4;
846 let input: Vec<u8> = vec![];
847 let mut results: Vec<Vec<u8>> = Vec::new();
848
849 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
850 results.push(chunk.to_vec());
851 });
852
853 assert!(results.is_empty());
854 }
855
856 #[test]
857 fn test_apply_to_chunks_single_chunk() {
858 const CHUNK_SIZE: usize = 10;
859 let input: Vec<u8> = vec![1, 2, 3, 4, 5];
860 let mut results: Vec<Vec<u8>> = Vec::new();
861
862 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
863 results.push(chunk.to_vec());
864 });
865
866 assert_eq!(results, vec![vec![1, 2, 3, 4, 5]]);
867 }
868
869 #[test]
870 fn test_apply_to_chunks_large_chunk_size() {
871 const CHUNK_SIZE: usize = 100;
872 let input: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
873 let mut results: Vec<Vec<u8>> = Vec::new();
874
875 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
876 results.push(chunk.to_vec());
877 });
878
879 assert_eq!(results, vec![vec![1, 2, 3, 4, 5, 6, 7, 8]]);
880 }
881
882 #[test]
883 fn test_apply_to_chunks_large_input() {
884 const CHUNK_SIZE: usize = 5;
885 let input: Vec<u8> = (1..=20).collect();
886 let mut results: Vec<Vec<u8>> = Vec::new();
887
888 apply_to_chunks::<CHUNK_SIZE, _, _>(input, |chunk| {
889 results.push(chunk.to_vec());
890 });
891
892 assert_eq!(
893 results,
894 vec![
895 vec![1, 2, 3, 4, 5],
896 vec![6, 7, 8, 9, 10],
897 vec![11, 12, 13, 14, 15],
898 vec![16, 17, 18, 19, 20]
899 ]
900 );
901 }
902
903 #[test]
904 fn test_reverse_slice_index_bits_random() {
905 let lengths = [32, 128, 1 << 16];
906 let mut rng = SmallRng::seed_from_u64(1);
907 for _ in 0..32 {
908 for &length in &lengths {
909 let mut rand_list: Vec<u32> = Vec::with_capacity(length);
910 rand_list.resize_with(length, || rng.random());
911 let expect = reverse_index_bits_naive(&rand_list);
912
913 let mut actual = rand_list.clone();
914 reverse_slice_index_bits(&mut actual);
915
916 assert_eq!(actual, expect);
917 }
918 }
919 }
920
921 #[test]
922 fn test_log2_strict_usize_edge_cases() {
923 assert_eq!(log2_strict_usize(1), 0);
924 assert_eq!(log2_strict_usize(2), 1);
925 assert_eq!(log2_strict_usize(1 << 18), 18);
926 assert_eq!(log2_strict_usize(1 << 31), 31);
927 assert_eq!(
928 log2_strict_usize(1 << (usize::BITS - 1)),
929 usize::BITS as usize - 1
930 );
931 }
932
933 #[test]
934 #[should_panic]
935 fn test_log2_strict_usize_zero() {
936 let _ = log2_strict_usize(0);
937 }
938
939 #[test]
940 #[should_panic]
941 fn test_log2_strict_usize_nonpower_2() {
942 let _ = log2_strict_usize(0x78C341C65AE6D262);
943 }
944
945 #[test]
946 #[should_panic]
947 fn test_log2_strict_usize_max() {
948 let _ = log2_strict_usize(usize::MAX);
949 }
950
951 #[test]
952 fn test_log2_ceil_usize_comprehensive() {
953 // Powers of 2
954 assert_eq!(log2_ceil_usize(0), 0);
955 assert_eq!(log2_ceil_usize(1), 0);
956 assert_eq!(log2_ceil_usize(2), 1);
957 assert_eq!(log2_ceil_usize(1 << 18), 18);
958 assert_eq!(log2_ceil_usize(1 << 31), 31);
959 assert_eq!(
960 log2_ceil_usize(1 << (usize::BITS - 1)),
961 usize::BITS as usize - 1
962 );
963
964 // Nonpowers; want to round up
965 assert_eq!(log2_ceil_usize(3), 2);
966 assert_eq!(log2_ceil_usize(0x14FE901B), 29);
967 assert_eq!(
968 log2_ceil_usize((1 << (usize::BITS - 1)) + 1),
969 usize::BITS as usize
970 );
971 assert_eq!(log2_ceil_usize(usize::MAX - 1), usize::BITS as usize);
972 assert_eq!(log2_ceil_usize(usize::MAX), usize::BITS as usize);
973 }
974
975 fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
976 let n = arr.len();
977 let n_power = log2_strict_usize(n);
978
979 let mut out = vec![None; n];
980 for (i, v) in arr.iter().enumerate() {
981 let dst = i.reverse_bits() >> (usize::BITS - n_power as u32);
982 out[dst] = Some(*v);
983 }
984
985 out.into_iter().map(|x| x.unwrap()).collect()
986 }
987
988 #[test]
989 fn test_relatively_prime_u64() {
990 // Zero cases (should always return false)
991 assert!(!relatively_prime_u64(0, 0));
992 assert!(!relatively_prime_u64(10, 0));
993 assert!(!relatively_prime_u64(0, 10));
994 assert!(!relatively_prime_u64(0, 123456789));
995
996 // Number with itself (if greater than 1, not relatively prime)
997 assert!(relatively_prime_u64(1, 1));
998 assert!(!relatively_prime_u64(10, 10));
999 assert!(!relatively_prime_u64(99999, 99999));
1000
1001 // Powers of 2 (always false since they share factor 2)
1002 assert!(!relatively_prime_u64(2, 4));
1003 assert!(!relatively_prime_u64(16, 32));
1004 assert!(!relatively_prime_u64(64, 128));
1005 assert!(!relatively_prime_u64(1024, 4096));
1006 assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
1007
1008 // One number is a multiple of the other (always false)
1009 assert!(!relatively_prime_u64(5, 10));
1010 assert!(!relatively_prime_u64(12, 36));
1011 assert!(!relatively_prime_u64(15, 45));
1012 assert!(!relatively_prime_u64(100, 500));
1013
1014 // Co-prime numbers (should be true)
1015 assert!(relatively_prime_u64(17, 31));
1016 assert!(relatively_prime_u64(97, 43));
1017 assert!(relatively_prime_u64(7919, 65537));
1018 assert!(relatively_prime_u64(15485863, 32452843));
1019
1020 // Small prime numbers (should be true)
1021 assert!(relatively_prime_u64(13, 17));
1022 assert!(relatively_prime_u64(101, 103));
1023 assert!(relatively_prime_u64(1009, 1013));
1024
1025 // Large numbers (some cases where they are relatively prime or not)
1026 assert!(!relatively_prime_u64(
1027 190266297176832000,
1028 10430732356495263744
1029 ));
1030 assert!(!relatively_prime_u64(
1031 2040134905096275968,
1032 5701159354248194048
1033 ));
1034 assert!(!relatively_prime_u64(
1035 16611311494648745984,
1036 7514969329383038976
1037 ));
1038 assert!(!relatively_prime_u64(
1039 14863931409971066880,
1040 7911906750992527360
1041 ));
1042
1043 // Max values
1044 assert!(relatively_prime_u64(u64::MAX, 1));
1045 assert!(relatively_prime_u64(u64::MAX, u64::MAX - 1));
1046 assert!(!relatively_prime_u64(u64::MAX, u64::MAX));
1047 }
1048}