#![allow(clippy::needless_range_loop)]
#![no_std]

extern crate alloc;

use alloc::vec::Vec;
use core::hint::unreachable_unchecked;
use core::mem::size_of;
use core::ptr::{swap, swap_nonoverlapping};

use crate::transpose_util::transpose_in_place_square;

mod transpose_util;

/// Returns the number of bits needed to represent `n` (0 for `n == 0`).
pub const fn bits_u64(n: u64) -> usize {
    (64 - n.leading_zeros()) as usize
}

/// Computes `ceil(log_2(n))`. Returns 0 for `n = 0` and `n = 1`.
#[must_use]
pub const fn log2_ceil(n: usize) -> usize {
    (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize
}

/// Computes `log_2(n)`, panicking if `n` is not a power of two.
pub fn log2_strict(n: usize) -> usize {
    let res = n.trailing_zeros();
    assert!(n.wrapping_shr(res) == 1, "Not a power of two: {n}");
    // Tell the optimizer that `n` is exactly `1 << res`, so later code can exploit it.
    assume(n == 1 << res);
    res as usize
}

/// Computes `floor(log_base(n))`, i.e. the largest `i` such that `base^i <= n`.
/// Panics if `n == 0` or `base < 2`.
pub const fn log_floor(n: u64, base: u64) -> usize {
    assert!(n > 0);
    assert!(base > 1);
    let mut i = 0;
    let mut cur: u64 = 1;
    loop {
        let (mul, overflow) = cur.overflowing_mul(base);
        if overflow || mul > n {
            return i;
        } else {
            i += 1;
            cur = mul;
        }
    }
}

/// Permutes `arr` by the bit-reversal permutation on indices: element `i` of the result is
/// `arr[j]`, where `j` is `i` with its lowest `log2(arr.len())` bits reversed. The length of
/// `arr` must be a power of two.
pub fn reverse_index_bits<T: Copy>(arr: &[T]) -> Vec<T> {
    let n = arr.len();
    let n_power = log2_strict(n);

    if n_power <= 6 {
        reverse_index_bits_small(arr, n_power)
    } else {
        reverse_index_bits_large(arr, n_power)
    }
}

fn reverse_index_bits_small<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    let n = arr.len();
    let mut result = Vec::with_capacity(n);
    // `BIT_REVERSE_6BIT` holds 6-bit reversals; shifting turns them into `n_power`-bit reversals.
    let dst_shr_amt = 6 - n_power;
    for i in 0..n {
        let src = (BIT_REVERSE_6BIT[i] as usize) >> dst_shr_amt;
        result.push(arr[src]);
    }
    result
}

fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
    let n = arr.len();
    // Process the output in chunks of 64: within a chunk only the high bits of the source index
    // change (looked up in the 6-bit table), while the low bits come from reversing the chunk
    // index once per chunk.
    let src_lo_shr_amt = usize::BITS as usize - (n_power - 6);
    let src_hi_shl_amt = n_power - 6;
    let mut result = Vec::with_capacity(n);
    for i_chunk in 0..(n >> 6) {
        let src_lo = i_chunk.reverse_bits() >> src_lo_shr_amt;
        for i_lo in 0..(1 << 6) {
            let src_hi = (BIT_REVERSE_6BIT[i_lo] as usize) << src_hi_shl_amt;
            let src = src_hi + src_lo;
            result.push(arr[src]);
        }
    }
    result
}

#[cfg(not(target_arch = "aarch64"))]
unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
    if lb_n <= 6 {
        // Look up each index's bit-reversed partner in the 6-bit table and swap each pair once.
        let dst_shr_amt = 6 - lb_n as u32;
        for src in 0..arr.len() {
            let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt);
            if src < dst {
                swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
            }
        }
    } else {
        // Same chunked scheme as `reverse_index_bits_large`: the table reverses the low 6 bits of
        // the source index, and `reverse_bits` handles the chunk index.
        let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32;
        let dst_hi_shl_amt = lb_n - 6;
        for src_chunk in 0..(arr.len() >> 6) {
            let src_hi = src_chunk << 6;
            let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt);
            for src_lo in 0..(1 << 6) {
                let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt;
                let src = src_hi + src_lo;
                let dst = dst_hi + dst_lo;
                if src < dst {
                    swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn reverse_index_bits_in_place_small<T>(arr: &mut [T], lb_n: usize) {
    // aarch64 can reverse bits in a single instruction (`rbit`), so the straightforward version
    // works well there and the lookup table is unnecessary.
    for src in 0..arr.len() {
        let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32);
        if src < dst {
            swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst));
        }
    }
}

/// Swaps whole chunks of `1 << lb_chunk_size` contiguous elements: chunk `i` is exchanged with
/// chunk `j`, where `j` is the bit reversal of `i` as an `lb_num_chunks`-bit index.
unsafe fn reverse_index_bits_in_place_chunks<T>(
    arr: &mut [T],
    lb_num_chunks: usize,
    lb_chunk_size: usize,
) {
    for i in 0..1usize << lb_num_chunks {
        let j = i
            .reverse_bits()
            .wrapping_shr(usize::BITS - lb_num_chunks as u32);
        if i < j {
            swap_nonoverlapping(
                arr.get_unchecked_mut(i << lb_chunk_size),
                arr.get_unchecked_mut(j << lb_chunk_size),
                1 << lb_chunk_size,
            );
        }
    }
}

const BIG_T_SIZE: usize = 1 << 14;
const SMALL_ARR_SIZE: usize = 1 << 16;

/// Permutes `arr` in place by the bit-reversal permutation on indices.
/// `arr.len()` must be a power of two.
pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
    let n = arr.len();
    let lb_n = log2_strict(n);
    // If the whole array fits in fast cache, the simple swap loop is already cache-friendly. The
    // same holds if `T` itself is very large, since each access then touches a big contiguous
    // region regardless of the access pattern.
    if size_of::<T>() << lb_n <= SMALL_ARR_SIZE || size_of::<T>() >= BIG_T_SIZE {
        unsafe {
            reverse_index_bits_in_place_small(arr, lb_n);
        }
    } else {
        // Cache-friendly path: view `arr` as a matrix of `1 << lb_num_chunks` rows of
        // `1 << lb_chunk_size` elements each. The full bit reversal is performed as three passes
        // that each move memory in large contiguous blocks: permute the rows by the bit reversal
        // of their row index, transpose the square block(s), then permute the rows again.
        debug_assert!(n >= 4);
        let lb_num_chunks = lb_n >> 1;
        let lb_chunk_size = lb_n - lb_num_chunks;
        unsafe {
            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
            transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0);
            if lb_num_chunks != lb_chunk_size {
                // Here `lb_chunk_size == lb_num_chunks + 1`, so the matrix consists of two square
                // halves; transpose the second half as well.
                let arr_with_offset = &mut arr[1 << lb_num_chunks..];
                transpose_in_place_square(arr_with_offset, lb_chunk_size, lb_num_chunks, 0);
            }
            reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size);
        }
    }
}

/// Lookup table of 6-bit bit reversals: entry `i` is the 6-bit reversal of `i`, written in octal.
#[rustfmt::skip]
const BIT_REVERSE_6BIT: &[u8] = &[
    0o00, 0o40, 0o20, 0o60, 0o10, 0o50, 0o30, 0o70,
    0o04, 0o44, 0o24, 0o64, 0o14, 0o54, 0o34, 0o74,
    0o02, 0o42, 0o22, 0o62, 0o12, 0o52, 0o32, 0o72,
    0o06, 0o46, 0o26, 0o66, 0o16, 0o56, 0o36, 0o76,
    0o01, 0o41, 0o21, 0o61, 0o11, 0o51, 0o31, 0o71,
    0o05, 0o45, 0o25, 0o65, 0o15, 0o55, 0o35, 0o75,
    0o03, 0o43, 0o23, 0o63, 0o13, 0o53, 0o33, 0o73,
    0o07, 0o47, 0o27, 0o67, 0o17, 0o57, 0o37, 0o77,
];

/// Equivalent to `debug_assert!(p)` in debug builds; in release builds it tells the optimizer
/// that it may assume `p` holds. Calling this with a false `p` in a release build is undefined
/// behavior, so `p` must be a genuine invariant.
#[inline(always)]
pub fn assume(p: bool) {
    debug_assert!(p);
    if !p {
        unsafe {
            unreachable_unchecked();
        }
    }
}

/// Encourages the compiler to emit a real branch (rather than, e.g., a conditional move) at the
/// call site. The empty `asm!` block is opaque to the optimizer but emits no instructions, so it
/// has no runtime cost. On targets not listed below this is a no-op.
#[inline(always)]
pub fn branch_hint() {
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use rand::rngs::OsRng;
    use rand::Rng;

    use crate::{log2_ceil, log2_strict};

    #[test]
    fn test_reverse_index_bits() {
        let lengths = [32, 128, 1 << 16];
        let mut rng = OsRng;
        for _ in 0..32 {
            for length in lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.gen());

                let out = super::reverse_index_bits(&rand_list);
                let expect = reverse_index_bits_naive(&rand_list);

                for (out, expect) in out.iter().zip(&expect) {
                    assert_eq!(out, expect);
                }
            }
        }
    }
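
    // Illustrative example added for clarity (not part of the original test suite): a concrete
    // length-8 case, where reversing the 3-bit indices 0..8 yields the order 0, 4, 2, 6, 1, 5, 3, 7.
    #[test]
    fn test_reverse_index_bits_fixed_example() {
        let input: Vec<u32> = (0..8).collect();
        let expected: Vec<u32> = vec![0, 4, 2, 6, 1, 5, 3, 7];
        assert_eq!(super::reverse_index_bits(&input), expected);
    }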

    #[test]
    fn test_reverse_index_bits_in_place() {
        let lengths = [32, 128, 1 << 16];
        let mut rng = OsRng;
        for _ in 0..32 {
            for length in lengths {
                let mut rand_list: Vec<u32> = Vec::with_capacity(length);
                rand_list.resize_with(length, || rng.gen());

                let expect = reverse_index_bits_naive(&rand_list);

                super::reverse_index_bits_in_place(&mut rand_list);

                for (got, expect) in rand_list.iter().zip(&expect) {
                    assert_eq!(got, expect);
                }
            }
        }
    }
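
    // Illustrative property check added for clarity (not part of the original test suite): the
    // bit-reversal permutation is an involution, so applying it in place twice must restore the
    // original order.
    #[test]
    fn test_reverse_index_bits_in_place_involution() {
        let mut rng = OsRng;
        let original: Vec<u32> = (0..256).map(|_| rng.gen()).collect();
        let mut arr = original.clone();
        super::reverse_index_bits_in_place(&mut arr);
        super::reverse_index_bits_in_place(&mut arr);
        assert_eq!(arr, original);
    }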

    #[test]
    fn test_log2_strict() {
        assert_eq!(log2_strict(1), 0);
        assert_eq!(log2_strict(2), 1);
        assert_eq!(log2_strict(1 << 18), 18);
        assert_eq!(log2_strict(1 << 31), 31);
        assert_eq!(
            log2_strict(1 << (usize::BITS - 1)),
            usize::BITS as usize - 1
        );
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_zero() {
        log2_strict(0);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_nonpower_2() {
        log2_strict(0x78c341c65ae6d262);
    }

    #[test]
    #[should_panic]
    fn test_log2_strict_usize_max() {
        log2_strict(usize::MAX);
    }

    #[test]
    fn test_log2_ceil() {
        assert_eq!(log2_ceil(0), 0);
        assert_eq!(log2_ceil(1), 0);
        assert_eq!(log2_ceil(2), 1);
        assert_eq!(log2_ceil(1 << 18), 18);
        assert_eq!(log2_ceil(1 << 31), 31);
        assert_eq!(log2_ceil(1 << (usize::BITS - 1)), usize::BITS as usize - 1);

        assert_eq!(log2_ceil(3), 2);
        assert_eq!(log2_ceil(0x14fe901b), 29);
        assert_eq!(
            log2_ceil((1 << (usize::BITS - 1)) + 1),
            usize::BITS as usize
        );
        assert_eq!(log2_ceil(usize::MAX - 1), usize::BITS as usize);
        assert_eq!(log2_ceil(usize::MAX), usize::BITS as usize);
    }
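
    // Illustrative checks added for clarity (not part of the original test suite), exercising the
    // remaining helpers: `bits_u64` counts the bits needed to represent its argument, and
    // `log_floor(n, base)` returns the largest `i` with `base^i <= n`.
    #[test]
    fn test_bits_u64_and_log_floor() {
        assert_eq!(super::bits_u64(0), 0);
        assert_eq!(super::bits_u64(1), 1);
        assert_eq!(super::bits_u64(255), 8);
        assert_eq!(super::bits_u64(256), 9);
        assert_eq!(super::bits_u64(u64::MAX), 64);

        assert_eq!(super::log_floor(1, 10), 0);
        assert_eq!(super::log_floor(999, 10), 2);
        assert_eq!(super::log_floor(1000, 10), 3);
        assert_eq!(super::log_floor(u64::MAX, 2), 63);
    }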

    fn reverse_index_bits_naive<T: Copy>(arr: &[T]) -> Vec<T> {
        let n = arr.len();
        let n_power = log2_strict(n);

        let mut out = vec![None; n];
        for (i, v) in arr.iter().enumerate() {
            let dst = i.reverse_bits() >> (usize::BITS as usize - n_power);
            out[dst] = Some(*v);
        }

        out.into_iter().map(|x| x.unwrap()).collect()
    }
}