1use alloc::vec::Vec;
2
3use p3_matrix::dense::RowMajorMatrix;
4use p3_matrix::row_index_mapped::{RowIndexMap, RowIndexMappedView};
5use p3_matrix::Matrix;
6use p3_util::{log2_strict_usize, reverse_bits_len};
7
8#[inline]
9pub(crate) fn cfft_permute_index(index: usize, log_n: usize) -> usize {
10 let (index, lsb) = (index >> 1, index & 1);
11 reverse_bits_len(
12 if lsb == 0 {
13 index
14 } else {
15 (1 << log_n) - index - 1
16 },
17 log_n,
18 )
19}
20
21pub(crate) fn cfft_permute_slice<T: Clone>(xs: &[T]) -> Vec<T> {
22 let log_n = log2_strict_usize(xs.len());
23 (0..xs.len())
24 .map(|i| xs[cfft_permute_index(i, log_n)].clone())
25 .collect()
26}
27
28pub(crate) fn cfft_permute_slice_chunked_in_place<T>(xs: &mut [T], chunk_size: usize) {
29 assert_eq!(xs.len() % chunk_size, 0);
30 let n_chunks = xs.len() / chunk_size;
31 let log_n = log2_strict_usize(n_chunks);
32 for i in 0..n_chunks {
33 let j = cfft_permute_index(i, log_n);
34 if i < j {
35 for k in 0..chunk_size {
37 xs.swap(i * chunk_size + k, j * chunk_size + k);
38 }
39 }
49 }
50}
51
/// A view of the matrix `M` whose row indices are remapped through [`CfftPerm`].
pub type CfftView<M> = RowIndexMappedView<CfftPerm, M>;
53
/// Row-index permutation that reorders the `2^log_height` rows of a matrix into CFFT order.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CfftPerm {
    /// Base-2 logarithm of the number of rows being permuted.
    log_height: usize,
}
58
59impl RowIndexMap for CfftPerm {
60 fn height(&self) -> usize {
61 1 << self.log_height
62 }
63 fn map_row_index(&self, r: usize) -> usize {
64 cfft_permute_index(r, self.log_height)
65 }
66 fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
67 &self,
68 inner: Inner,
69 ) -> RowMajorMatrix<T> {
70 let mut inner = inner.to_row_major_matrix();
71 cfft_permute_slice_chunked_in_place(&mut inner.values, inner.width);
72 inner
73 }
74}
75
/// Extension trait for wrapping a matrix in a [`CfftView`], whose rows are reindexed by
/// [`CfftPerm`].
pub(crate) trait CfftPermutable<T: Send + Sync>: Matrix<T> + Sized {
    /// Consumes `self` and returns a view of it with rows permuted by [`CfftPerm`].
    fn cfft_perm_rows(self) -> CfftView<Self>;
}
79
80impl<T: Send + Sync, M: Matrix<T>> CfftPermutable<T> for M {
81 fn cfft_perm_rows(self) -> CfftView<M> {
82 RowIndexMappedView {
83 index_map: CfftPerm {
84 log_height: log2_strict_usize(self.height()),
85 },
86 inner: self,
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    #[test]
    fn ordering() {
        assert_eq!(
            (0..8).map(|i| cfft_permute_index(i, 3)).collect_vec(),
            &[0, 7, 4, 3, 2, 5, 6, 1],
        );
        for log_n in 1..5 {
            let n = 1 << log_n;
            let sigma = |i| cfft_permute_index(i, log_n);
            for i in 0..n {
                // The map is an involution, which the in-place pair-swap relies on.
                assert_eq!(sigma(sigma(i)), i);
            }
            assert_eq!(
                cfft_permute_slice(&(0..n).collect_vec()),
                (0..n).map(sigma).collect_vec()
            );
        }
    }

    /// The in-place chunked permutation must agree with permuting the chunks out of place.
    #[test]
    fn chunked_in_place_matches_out_of_place() {
        for log_n in 0..5usize {
            let n_chunks = 1usize << log_n;
            for chunk_size in 1..4usize {
                let original = (0..n_chunks * chunk_size).collect_vec();
                let chunks = original
                    .chunks(chunk_size)
                    .map(|c| c.to_vec())
                    .collect_vec();
                let expected = cfft_permute_slice(&chunks).concat();
                let mut actual = original;
                cfft_permute_slice_chunked_in_place(&mut actual, chunk_size);
                assert_eq!(actual, expected);
            }
        }
    }
}