p3_circle/
ordering.rs

1use alloc::vec::Vec;
2
3use p3_matrix::dense::RowMajorMatrix;
4use p3_matrix::row_index_mapped::{RowIndexMap, RowIndexMappedView};
5use p3_matrix::Matrix;
6use p3_util::{log2_strict_usize, reverse_bits_len};
7
8#[inline]
9pub(crate) fn cfft_permute_index(index: usize, log_n: usize) -> usize {
10    let (index, lsb) = (index >> 1, index & 1);
11    reverse_bits_len(
12        if lsb == 0 {
13            index
14        } else {
15            (1 << log_n) - index - 1
16        },
17        log_n,
18    )
19}
20
21pub(crate) fn cfft_permute_slice<T: Clone>(xs: &[T]) -> Vec<T> {
22    let log_n = log2_strict_usize(xs.len());
23    (0..xs.len())
24        .map(|i| xs[cfft_permute_index(i, log_n)].clone())
25        .collect()
26}
27
28pub(crate) fn cfft_permute_slice_chunked_in_place<T>(xs: &mut [T], chunk_size: usize) {
29    assert_eq!(xs.len() % chunk_size, 0);
30    let n_chunks = xs.len() / chunk_size;
31    let log_n = log2_strict_usize(n_chunks);
32    for i in 0..n_chunks {
33        let j = cfft_permute_index(i, log_n);
34        if i < j {
35            // somehow this is slightly faster than the unsafe block below
36            for k in 0..chunk_size {
37                xs.swap(i * chunk_size + k, j * chunk_size + k);
38            }
39            /*
40            unsafe {
41                core::ptr::swap_nonoverlapping(
42                    xs.as_mut_ptr().add(i * chunk_size),
43                    xs.as_mut_ptr().add(j * chunk_size),
44                    chunk_size,
45                );
46            }
47            */
48        }
49    }
50}
51
52pub type CfftView<M> = RowIndexMappedView<CfftPerm, M>;
53
54#[derive(Copy, Clone)]
55pub struct CfftPerm {
56    log_height: usize,
57}
58
59impl RowIndexMap for CfftPerm {
60    fn height(&self) -> usize {
61        1 << self.log_height
62    }
63    fn map_row_index(&self, r: usize) -> usize {
64        cfft_permute_index(r, self.log_height)
65    }
66    fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
67        &self,
68        inner: Inner,
69    ) -> RowMajorMatrix<T> {
70        let mut inner = inner.to_row_major_matrix();
71        cfft_permute_slice_chunked_in_place(&mut inner.values, inner.width);
72        inner
73    }
74}
75
76pub(crate) trait CfftPermutable<T: Send + Sync>: Matrix<T> + Sized {
77    fn cfft_perm_rows(self) -> CfftView<Self>;
78}
79
80impl<T: Send + Sync, M: Matrix<T>> CfftPermutable<T> for M {
81    fn cfft_perm_rows(self) -> CfftView<M> {
82        RowIndexMappedView {
83            index_map: CfftPerm {
84                log_height: log2_strict_usize(self.height()),
85            },
86            inner: self,
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Checks the index permutation against a hand-derived reference, verifies
    /// that it is an involution for small sizes, and confirms that the slice
    /// variant agrees with the index variant.
    #[test]
    fn ordering() {
        // Reference ordering for eight entries, derived by hand.
        let expected: [usize; 8] = [0, 7, 4, 3, 2, 5, 6, 1];
        let actual = (0..8).map(|i| cfft_permute_index(i, 3)).collect_vec();
        assert_eq!(actual, &expected);

        for log_n in 1..5 {
            let n = 1 << log_n;
            let sigma = |i| cfft_permute_index(i, log_n);

            // Involution: σ(σ(i)) = i for every index.
            for i in 0..n {
                let round_trip = sigma(sigma(i));
                assert_eq!(round_trip, i);
            }

            // Permuting the identity slice must reproduce σ itself.
            let identity = (0..n).collect_vec();
            let via_index = (0..n).map(sigma).collect_vec();
            assert_eq!(cfft_permute_slice(&identity), via_index);
        }
    }
}