../../.cargo/katex-header.html

plonky2_util/
lib.rs

1#![allow(clippy::needless_range_loop)]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::vec::Vec;
7use core::hint::unreachable_unchecked;
8use core::mem::size_of;
9use core::ptr::{swap, swap_nonoverlapping};
10
11use crate::transpose_util::transpose_in_place_square;
12
13mod transpose_util;
14
/// Returns the bit length of `n`: the index of its most significant set bit plus one.
///
/// `bits_u64(0)` is 0.
pub const fn bits_u64(n: u64) -> usize {
    let significant_bits = u64::BITS - n.leading_zeros();
    significant_bits as usize
}
18
/// Computes `ceil(log_2(n))`.
///
/// Both `n == 0` and `n == 1` yield 0.
#[must_use]
pub const fn log2_ceil(n: usize) -> usize {
    // For `n >= 1`, `ceil(log_2(n))` is the bit length of `n - 1`; the saturating subtraction
    // maps `n == 0` onto the same result as `n == 1`.
    let below = n.saturating_sub(1);
    let bit_length = usize::BITS - below.leading_zeros();
    bit_length as usize
}
24
25/// Computes `log_2(n)`, panicking if `n` is not a power of two.
26pub fn log2_strict(n: usize) -> usize {
27    let res = n.trailing_zeros();
28    assert!(n.wrapping_shr(res) == 1, "Not a power of two: {n}");
29    // Tell the optimizer about the semantics of `log2_strict`. i.e. it can replace `n` with
30    // `1 << res` and vice versa.
31    assume(n == 1 << res);
32    res as usize
33}
34
/// Returns the largest integer `i` such that `base**i <= n`.
///
/// # Panics
///
/// Panics if `n == 0` or `base < 2`.
pub const fn log_floor(n: u64, base: u64) -> usize {
    assert!(n > 0);
    assert!(base > 1);
    let mut exponent = 0;
    let mut power: u64 = 1;
    // Keep multiplying until the next power would exceed `n` or overflow `u64`; a power that
    // overflows is necessarily larger than `n`, so both cases end the search.
    while let Some(next) = power.checked_mul(base) {
        if next > n {
            break;
        }
        exponent += 1;
        power = next;
    }
    exponent
}
51
52/// Permutes `arr` such that each index is mapped to its reverse in binary.
53pub fn reverse_index_bits<T: Copy>(arr: &[T]) -> Vec<T> {
54    let n = arr.len();
55    let n_power = log2_strict(n);
56
57    if n_power <= 6 {
58        reverse_index_bits_small(arr, n_power)
59    } else {
60        reverse_index_bits_large(arr, n_power)
61    }
62}
63
64/* Both functions below are semantically equivalent to:
65        for i in 0..n {
66            result.push(arr[reverse_bits(i, n_power)]);
67        }
68   where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there
69   to guide the compiler to generate optimal assembly.
70*/
71
72fn reverse_index_bits_small<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
73    let n = arr.len();
74    let mut result = Vec::with_capacity(n);
75    // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
76    let dst_shr_amt = 6 - n_power;
77    for i in 0..n {
78        let src = (BIT_REVERSE_6BIT[i] as usize) >> dst_shr_amt;
79        result.push(arr[src]);
80    }
81    result
82}
83
84fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
85    let n = arr.len();
86    // LLVM does not know that it does not need to reverse src at each iteration (which is expensive
87    // on x86). We take advantage of the fact that the low bits of dst change rarely and the high
88    // bits of dst are dependent only on the low bits of src.
89    let src_lo_shr_amt = 64 - (n_power - 6);
90    let src_hi_shl_amt = n_power - 6;
91    let mut result = Vec::with_capacity(n);
92    for i_chunk in 0..(n >> 6) {
93        let src_lo = i_chunk.reverse_bits() >> src_lo_shr_amt;
94        for i_lo in 0..(1 << 6) {
95            let src_hi = (BIT_REVERSE_6BIT[i_lo] as usize) << src_hi_shl_amt;
96            let src = src_hi + src_lo;
97            result.push(arr[src]);
98        }
99    }
100    result
101}
102
103/// Bit-reverse the order of elements in `arr`.
104/// SAFETY: ensure that `arr.len() == 1 << lb_n`.
105#[cfg(not(target_arch = "aarch64"))]
106unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
107    if lb_n <= 6 {
108        // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses.
109        let dst_shr_amt = 6 - lb_n as u32;
110        for src in 0..arr.len() {
111            // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so
112            // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the
113            // correct result.
114            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
115            if src < dst {
116                swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
117            }
118        }
119    } else {
120        // LLVM does not know that it does not need to reverse src at each iteration (which is
121        // expensive on x86). We take advantage of the fact that the low bits of dst change rarely and the high
122        // bits of dst are dependent only on the low bits of src.
123        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
124        let dst_hi_shl_amt = lb_n - 6;
125        for src_chunk in 0..(arr.len() >> 6) {
126            let src_hi = src_chunk << 6;
127            // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so
128            // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the
129            // correct result.
130            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
131            for src_lo in 0..(1 << 6) {
132                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
133                let src = src_hi + src_lo;
134                let dst = dst_hi + dst_lo;
135                if src < dst {
136                    swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
137                }
138            }
139        }
140    }
141}
142
/// Bit-reverses the order of the elements of `arr` in place by swapping each element with the
/// one at its `lb_n`-bit reversed index.
///
/// # Safety
///
/// The caller must ensure that `arr.len() == 1 << lb_n`.
#[cfg(target_arch = "aarch64")]
unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
    let n = arr.len();
    // Derive every swapped pointer from this single base pointer; taking two
    // `arr.get_unchecked_mut(..)` borrows in one call creates overlapping `&mut` reborrows of
    // `arr`, which Stacked Borrows (Miri) rejects.
    let base = arr.as_mut_ptr();
    // Aarch64 can reverse bits in one instruction, so the trivial version works best.
    for src in 0..n {
        // `wrapping_shr` handles `n == 1` (`lb_n == 0`): the shift amount is then `usize::BITS`,
        // which wraps to a shift by 0, and `src == 0` reverses to 0 anyway.
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        // Visit each transposition once and skip fixed points.
        if src < dst {
            // SAFETY: `src` and `dst` are `lb_n`-bit indices, hence `< n` by the caller's
            // guarantee, and they are distinct.
            swap(base.add(src), base.add(dst));
        }
    }
}
158
/// Splits `arr` into `1 << lb_num_chunks` contiguous chunks, each of length
/// `1 << lb_chunk_size`, and bit-reverses the order of the chunks (chunk `i` is exchanged with
/// chunk `rev(i)`, using `lb_num_chunks`-bit reversal). Elements inside a chunk keep their order.
///
/// # Safety
///
/// The caller must ensure that `arr.len() == 1 << (lb_num_chunks + lb_chunk_size)`.
unsafe fn reverse_index_bits_in_place_chunks<T>(
    arr: &mut [T],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    // Derive all chunk pointers from one raw base pointer. Pointers obtained from
    // `get_unchecked_mut` are derived from a `&mut T` to a single element, so
    // `swap_nonoverlapping` accessing a whole chunk through them is out of their provenance; two
    // such calls in one expression also create overlapping `&mut` reborrows of `arr`.
    let base = arr.as_mut_ptr();
    let chunk_len = 1usize << lb_chunk_size;
    for i in 0..1usize << lb_num_chunks {
        // `wrapping_shr` handles the degenerate case `lb_num_chunks == 0`: the shift amount is
        // then `usize::BITS`, which wraps to a shift by 0, and `i == 0` reverses to 0 anyway.
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        // Swap each pair of chunks once, and never a chunk with itself.
        if i < j {
            // SAFETY: `i, j < 1 << lb_num_chunks`, so both chunks lie within `arr` by the
            // caller's length guarantee; `i != j`, so the chunks do not overlap.
            swap_nonoverlapping(
                base.add(i << lb_chunk_size),
                base.add(j << lb_chunk_size),
                chunk_len,
            );
        }
    }
}
181
182// Ensure that SMALL_ARR_SIZE >= 4 * BIG_T_SIZE.
183const BIG_T_SIZE: usize = 1 << 14;
184const SMALL_ARR_SIZE: usize = 1 << 16;
185pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
186    let n = arr.len();
187    let lb_n = log2_strict(n);
188    // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if
189    // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the
190    // array.
191    if size_of::<T>() << lb_n <= SMALL_ARR_SIZE || size_of::<T>() >= BIG_T_SIZE {
192        unsafe {
193            reverse_index_bits_in_place_small(arr, lb_n);
194        }
195    } else {
196        debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`.
197
198        // Algorithm:
199        //
200        // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is
201        // even, i.e., `n` is a square number.) To perform bit-order reversal we:
202        //  1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is
203        //     basically a series of large `memcpy`s.)
204        //  2. Transpose the matrix.
205        //  3. Bit-reverse the order of the rows.
206        // This is equivalent to, for every index `0 <= i < n`:
207        //  1. bit-reversing `i[lb_n / 2..lb_n]`,
208        //  2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`,
209        //  3. bit-reversing `i[lb_n / 2..lb_n]`.
210        //
211        // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires
212        // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the
213        // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we
214        // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the
215        // index is `0` and another, where the middle bit is `1`; we transpose each individually.
216
217        let lb_num_chunks = lb_n >> 1;
218        let lb_chunk_size = lb_n - lb_num_chunks;
219        unsafe {
220            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
221            transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0);
222            if lb_num_chunks != lb_chunk_size {
223                // `arr` cannot be interpreted as a square matrix. We instead interpret it as a
224                // `1 << lb_num_chunks` by `2` by `1 << lb_num_chunks` tensor, in row-major order.
225                // The above transpose acted on `tensor[..., 0, ...]` (all indices with middle bit
226                // `0`). We still need to transpose `tensor[..., 1, ...]`. To do so, we advance
227                // arr by `1 << lb_num_chunks` effectively, adding that to every index.
228                let arr_with_offset = &mut arr[1 << lb_num_chunks..];
229                transpose_in_place_square(arr_with_offset, lb_chunk_size, lb_num_chunks, 0);
230            }
231            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
232        }
233    }
234}
235
// Lookup table of 6-bit reverses: entry `i` is `i` with its six low bits in reverse order.
// It is generated at compile time: `u8::reverse_bits` reverses all eight bits, moving bit `k` of
// a value below 64 to bit `7 - k`, so a right shift by 2 leaves exactly the 6-bit reverse.
// NB: 2^6=64 bytes is a cacheline. A smaller table wastes cache space.
const BIT_REVERSE_6BIT: &[u8] = &{
    let mut table = [0u8; 64];
    let mut i = 0;
    while i < 64 {
        table[i] = (i as u8).reverse_bits() >> 2;
        i += 1;
    }
    table
};
249
/// Promises the optimizer that `p` is true so it can generate code under that assumption.
///
/// In debug builds the promise is checked with `debug_assert!`. In release builds passing
/// `false` reaches `unreachable_unchecked` and is therefore **undefined behavior**, even though
/// this function is not marked `unsafe`: only pass conditions that are already guaranteed.
#[inline(always)]
pub fn assume(p: bool) {
    debug_assert!(p);
    if p {
        return;
    }
    // SAFETY: relies on the caller never passing `false` (see the doc comment above).
    unsafe { unreachable_unchecked() }
}
259
/// Tries to force Rust to emit a branch. Place the call inside one arm of a conditional:
///
/// ```ignore
/// if x > 2 {
///     y = foo();
///     branch_hint();
/// } else {
///     y = bar();
/// }
/// ```
///
/// This function has no semantics; it is a hint only. On targets outside the list below it
/// compiles to nothing.
#[inline(always)]
pub fn branch_hint() {
    // The listed targets are the ones with supported inline assembly; see the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    {
        // SAFETY: the assembly is empty; it touches no memory, uses no stack and keeps flags.
        unsafe { core::arch::asm!("", options(nomem, nostack, preserves_flags)) };
    }
}
285
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::OsRng;
    use rand::Rng;

    use crate::{bits_u64, log2_ceil, log2_strict, log_floor};

    /// Returns a vector of `length` random `u32`s.
    fn random_list(rng: &mut OsRng, length: usize) -> Vec<u32> {
        let mut list = vec![0u32; length];
        // Filling in bulk avoids one `OsRng` request per element, keeping large cases fast.
        rng.fill(&mut list[..]);
        list
    }

    #[test]
    fn test_reverse_index_bits() {
        // Trivial lengths, both sides of the 64-element lookup-table boundary, and a large array.
        let lengths = [1, 2, 4, 32, 64, 128, 1 << 16];
        let mut rng = OsRng;
        for _ in 0..32 {
            for length in lengths {
                let rand_list = random_list(&mut rng, length);

                let out = super::reverse_index_bits(&rand_list);
                let expect = reverse_index_bits_naive(&rand_list);

                // Comparing whole vectors also catches length mismatches, which `zip` would hide.
                assert_eq!(out, expect);
            }
        }
    }

    #[test]
    fn test_reverse_index_bits_in_place() {
        // For `u32` elements, lengths up to `1 << 14` take the simple path and larger ones take
        // the blocked path; `1 << 17` has an odd `log2`, exercising the two-transpose case.
        let lengths = [1, 2, 4, 32, 64, 128, 1 << 16, 1 << 17];
        let mut rng = OsRng;
        for _ in 0..32 {
            for length in lengths {
                let mut rand_list = random_list(&mut rng, length);

                let expect = reverse_index_bits_naive(&rand_list);

                super::reverse_index_bits_in_place(&mut rand_list);

                assert_eq!(rand_list, expect);
            }
        }
    }

    #[test]
    fn test_bits_u64() {
        assert_eq!(bits_u64(0), 0);
        assert_eq!(bits_u64(1), 1);
        assert_eq!(bits_u64(2), 2);
        assert_eq!(bits_u64(255), 8);
        assert_eq!(bits_u64(256), 9);
        assert_eq!(bits_u64(u64::MAX), 64);
    }

    #[test]
    fn test_log_floor() {
        assert_eq!(log_floor(1, 2), 0);
        assert_eq!(log_floor(7, 2), 2);
        assert_eq!(log_floor(8, 2), 3);
        assert_eq!(log_floor(99, 10), 1);
        assert_eq!(log_floor(100, 10), 2);
        assert_eq!(log_floor(u64::MAX, 2), 63);
        assert_eq!(log_floor(u64::MAX, 10), 19);
        assert_eq!(log_floor(u64::MAX, u64::MAX), 1);
    }

    #[test]
    fn test_log2_strict() {
        assert_eq!(log2_strict(1), 0);
        assert_eq!(log2_strict(2), 1);
        assert_eq!(log2_strict(1 << 18), 18);
        assert_eq!(log2_strict(1 << 31), 31);
        assert_eq!(
            log2_strict(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_zero() {
        log2_strict(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_nonpower_2() {
        log2_strict(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        log2_strict(usize::MAX);
    }

    #[test]
    fn test_log2_ceil() {
        // Powers of 2
        assert_eq!(log2_ceil(0), 0);
        assert_eq!(log2_ceil(1), 0);
        assert_eq!(log2_ceil(2), 1);
        assert_eq!(log2_ceil(1 << 18), 18);
        assert_eq!(log2_ceil(1 << 31), 31);
        assert_eq!(log2_ceil(1 << (usize::BITS - 1)), usize::BITS as usize - 1);

        // Nonpowers; want to round up
        assert_eq!(log2_ceil(3), 2);
        assert_eq!(log2_ceil(0x14fe901b), 29);
        assert_eq!(
            log2_ceil((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil(usize::MAX), usize::BITS as usize);
    }

    /// Reference implementation: moves `arr[i]` to index `rev(i)` one element at a time.
    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict(n);

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            // Reverse all `usize::BITS` bits, then keep the top `n_power` of them. For
            // `n_power == 0` the shift would be the full width, which `checked_shr` rejects;
            // the only index is then 0.
            let dst = i
                .reverse_bits()
                .checked_shr(usize::BITS - n_power as u32)
                .unwrap_or(0);
            out[dst] = Some(*v);
        }

        out.into_iter().map(|x| x.unwrap()).collect()
    }
}