1use crate::common::RadixFactor;
2use crate::Complex;
3use crate::FftNum;
4use std::ops::{Deref, DerefMut};
5
/// Copies `input` into `output` transposed: the element read from
/// `input[x + y * width]` is written to `output[y + x * height]` for every
/// `x < width` and `y < height`.
///
/// # Safety
///
/// No bounds checks are performed. Both `input` and `output` must contain at
/// least `width * height` elements.
pub unsafe fn transpose_small<T: Copy>(width: usize, height: usize, input: &[T], output: &mut [T]) {
    for col in 0..width {
        // Every element of this input column lands in one contiguous output run.
        let dst_base = col * height;
        for row in 0..height {
            let value = *input.get_unchecked(row * width + col);
            *output.get_unchecked_mut(dst_base + row) = value;
        }
    }
}
19
/// Reinterprets a slice of `T` as a slice of `U` with the same element count.
///
/// # Safety
///
/// The data pointer and length of `slice` are handed unchanged to
/// `std::slice::from_raw_parts`, so the caller must guarantee that
/// `slice.len()` values of type `U` can validly be read from that address
/// (compatible size, alignment and bit patterns).
#[allow(unused)]
pub unsafe fn workaround_transmute<T, U>(slice: &[T]) -> &[U] {
    std::slice::from_raw_parts(slice.as_ptr().cast::<U>(), slice.len())
}
/// Reinterprets a mutable slice of `T` as a mutable slice of `U` with the same
/// element count.
///
/// # Safety
///
/// The data pointer and length of `slice` are handed unchanged to
/// `std::slice::from_raw_parts_mut`, so the caller must guarantee that
/// `slice.len()` values of type `U` can validly be read and written at that
/// address (compatible size, alignment and bit patterns).
#[allow(unused)]
pub unsafe fn workaround_transmute_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    let element_count = slice.len();
    std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), element_count)
}
32
33pub(crate) trait LoadStore<T: FftNum>: DerefMut {
34 unsafe fn load(&self, idx: usize) -> Complex<T>;
35 unsafe fn store(&mut self, val: Complex<T>, idx: usize);
36}
37
38impl<T: FftNum> LoadStore<T> for &mut [Complex<T>] {
39 #[inline(always)]
40 unsafe fn load(&self, idx: usize) -> Complex<T> {
41 debug_assert!(idx < self.len());
42 *self.get_unchecked(idx)
43 }
44 #[inline(always)]
45 unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
46 debug_assert!(idx < self.len());
47 *self.get_unchecked_mut(idx) = val;
48 }
49}
50impl<T: FftNum, const N: usize> LoadStore<T> for &mut [Complex<T>; N] {
51 #[inline(always)]
52 unsafe fn load(&self, idx: usize) -> Complex<T> {
53 debug_assert!(idx < self.len());
54 *self.get_unchecked(idx)
55 }
56 #[inline(always)]
57 unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
58 debug_assert!(idx < self.len());
59 *self.get_unchecked_mut(idx) = val;
60 }
61}
62
63pub(crate) struct DoubleBuf<'a, T> {
64 pub input: &'a [Complex<T>],
65 pub output: &'a mut [Complex<T>],
66}
67impl<'a, T> Deref for DoubleBuf<'a, T> {
68 type Target = [Complex<T>];
69 fn deref(&self) -> &Self::Target {
70 self.input
71 }
72}
73impl<'a, T> DerefMut for DoubleBuf<'a, T> {
74 fn deref_mut(&mut self) -> &mut Self::Target {
75 self.output
76 }
77}
78impl<'a, T: FftNum> LoadStore<T> for DoubleBuf<'a, T> {
79 #[inline(always)]
80 unsafe fn load(&self, idx: usize) -> Complex<T> {
81 debug_assert!(idx < self.input.len());
82 *self.input.get_unchecked(idx)
83 }
84 #[inline(always)]
85 unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
86 debug_assert!(idx < self.output.len());
87 *self.output.get_unchecked_mut(idx) = val;
88 }
89}
90
91pub(crate) trait Load<T: FftNum>: Deref {
92 unsafe fn load(&self, idx: usize) -> Complex<T>;
93}
94
95impl<T: FftNum> Load<T> for &[Complex<T>] {
96 #[inline(always)]
97 unsafe fn load(&self, idx: usize) -> Complex<T> {
98 debug_assert!(idx < self.len());
99 *self.get_unchecked(idx)
100 }
101}
102impl<T: FftNum, const N: usize> Load<T> for &[Complex<T>; N] {
103 #[inline(always)]
104 unsafe fn load(&self, idx: usize) -> Complex<T> {
105 debug_assert!(idx < self.len());
106 *self.get_unchecked(idx)
107 }
108}
109
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::random_signal;
    use num_complex::Complex;
    use num_traits::Zero;

    /// Runs `transpose_small` on every shape from 1x1 up to 15x15 and checks
    /// each element against the expected index mapping.
    #[test]
    fn test_transpose() {
        for width in 1..16usize {
            for height in 1..16usize {
                let len = width * height;

                let input: Vec<Complex<f32>> = random_signal(len);
                let mut output = vec![Zero::zero(); len];

                unsafe { transpose_small(width, height, &input, &mut output) };

                for x in 0..width {
                    for y in 0..height {
                        let source = input[x + y * width];
                        let transposed = output[y + x * height];
                        assert_eq!(source, transposed, "x = {}, y = {}", x, y);
                    }
                }
            }
        }
    }
}
145
/// Calls `chunk_fn` on each consecutive `chunk_size`-element chunk of
/// `buffer`, front to back.
///
/// # Errors
///
/// Returns `Err(())` when `buffer.len()` is not a multiple of `chunk_size`.
/// Every complete chunk is still passed to `chunk_fn` first; only the short
/// tail is skipped.
///
/// Note: a `chunk_size` of zero never terminates, because splitting off an
/// empty chunk never shrinks the remaining buffer.
pub fn iter_chunks<T>(
    mut buffer: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T]),
) -> Result<(), ()> {
    loop {
        if buffer.len() < chunk_size {
            break;
        }

        // Peel one chunk off the front and keep iterating over what is left.
        let (chunk, rest) = buffer.split_at_mut(chunk_size);
        buffer = rest;

        chunk_fn(chunk);
    }

    // Anything still in the buffer is a partial chunk that was not processed.
    match buffer.len() {
        0 => Ok(()),
        _ => Err(()),
    }
}
168
/// Walks `buffer1` and `buffer2` in lockstep, calling `chunk_fn` on each pair
/// of corresponding `chunk_size`-element chunks, front to back.
///
/// If the buffers differ in length, only their common prefix is processed.
///
/// # Errors
///
/// Returns `Err(())` when the two buffers have different lengths, or when the
/// processed length is not a multiple of `chunk_size`. Every complete pair of
/// chunks is still passed to `chunk_fn` before the error is returned.
///
/// Note: a `chunk_size` of zero never terminates, because splitting off an
/// empty chunk never shrinks the remaining buffers.
pub fn iter_chunks_zipped<T>(
    mut buffer1: &mut [T],
    mut buffer2: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T], &mut [T]),
) -> Result<(), ()> {
    // Record a length mismatch in either direction, then trim both buffers to
    // the shared prefix so the loop below sees equal lengths.
    let common_len = buffer1.len().min(buffer2.len());
    let uneven = buffer1.len() != buffer2.len();
    buffer1 = &mut buffer1[..common_len];
    buffer2 = &mut buffer2[..common_len];

    while buffer1.len() >= chunk_size && buffer2.len() >= chunk_size {
        // Peel one chunk off the front of each buffer.
        let (head1, tail1) = buffer1.split_at_mut(chunk_size);
        buffer1 = tail1;

        let (head2, tail2) = buffer2.split_at_mut(chunk_size);
        buffer2 = tail2;

        chunk_fn(head1, head2);
    }

    // Success requires equal input lengths and no partial chunk left over.
    if !uneven && buffer1.is_empty() {
        Ok(())
    } else {
        Err(())
    }
}
206
207pub fn bitreversed_transpose<T: Copy, const D: usize>(
212 height: usize,
213 input: &[T],
214 output: &mut [T],
215) {
216 let width = input.len() / height;
217
218 assert!(D > 1 && input.len() % height == 0 && input.len() == output.len());
220
221 let strided_width = width / D;
222 let rev_digits = if D.is_power_of_two() {
223 let width_bits = width.trailing_zeros();
224 let d_bits = D.trailing_zeros();
225
226 assert!(width_bits % d_bits == 0);
228 width_bits / d_bits
229 } else {
230 compute_logarithm::<D>(width).unwrap()
231 };
232
233 for x in 0..strided_width {
234 let mut i = 0;
235 let x_fwd = [(); D].map(|_| {
236 let value = D * x + i;
237 i += 1;
238 value
239 }); let x_rev = x_fwd.map(|x| reverse_bits::<D>(x, rev_digits));
241
242 for r in x_rev {
247 assert!(r < width);
248 }
249 for y in 0..height {
250 for (fwd, rev) in x_fwd.iter().zip(x_rev.iter()) {
251 let input_index = *fwd + y * width;
252 let output_index = y + *rev * height;
253
254 unsafe {
255 let temp = *input.get_unchecked(input_index);
256 *output.get_unchecked_mut(output_index) = temp;
257 }
258 }
259 }
260 }
261}
262
/// Reverses the lowest `rev_digits` base-`D` digits of `value`.
///
/// Digits of `value` above the lowest `rev_digits` are discarded.
///
/// # Panics
///
/// Panics if `D < 2`.
pub fn reverse_bits<const D: usize>(value: usize, rev_digits: u32) -> usize {
    assert!(D > 1);

    // Repeatedly move the lowest remaining digit of `remaining` onto the low
    // end of `reversed`.
    let (reversed, _remaining) = (0..rev_digits).fold(
        (0usize, value),
        |(reversed, remaining), _| (reversed * D + remaining % D, remaining / D),
    );
    reversed
}
277
/// Returns `Some(n)` when `value == D^n` exactly, and `None` otherwise.
///
/// `None` is also returned when `value` is zero or when `D < 2`.
pub fn compute_logarithm<const D: usize>(value: usize) -> Option<u32> {
    if D < 2 || value == 0 {
        return None;
    }

    let mut exponent: u32 = 0;
    let mut remainder = value;

    // Strip factors of D for as long as they divide evenly.
    loop {
        if remainder % D != 0 {
            break;
        }
        remainder /= D;
        exponent += 1;
    }

    // `value` was a pure power of D only if nothing but 1 is left.
    match remainder {
        1 => Some(exponent),
        _ => None,
    }
}
298
299pub(crate) struct TransposeFactor {
300 pub factor: RadixFactor,
301 pub count: u8,
302}
303
304pub(crate) fn factor_transpose<T: Copy, const D: usize>(
309 height: usize,
310 input: &[T],
311 output: &mut [T],
312 factors: &[TransposeFactor],
313) {
314 let width = input.len() / height;
315
316 assert!(width % D == 0 && D > 1 && input.len() % width == 0 && input.len() == output.len());
318
319 let strided_width = width / D;
320 for x in 0..strided_width {
321 let mut i = 0;
322 let x_fwd = [(); D].map(|_| {
323 let value = D * x + i;
324 i += 1;
325 value
326 }); let x_rev = x_fwd.map(|x| reverse_remainders(x, factors));
328
329 for r in x_rev {
334 assert!(r < width);
335 }
336 for y in 0..height {
337 for (fwd, rev) in x_fwd.iter().zip(x_rev.iter()) {
338 let input_index = *fwd + y * width;
339 let output_index = y + *rev * height;
340
341 unsafe {
342 let temp = *input.get_unchecked(input_index);
343 *output.get_unchecked_mut(output_index) = temp;
344 }
345 }
346 }
347 }
348}
349
350pub(crate) fn reverse_remainders(value: usize, factors: &[TransposeFactor]) -> usize {
354 let mut result: usize = 0;
355 let mut value = value;
356 for f in factors.iter() {
357 match f.factor {
358 RadixFactor::Factor2 => {
359 for _ in 0..f.count {
360 result = (result * 2) + (value % 2);
361 value = value / 2;
362 }
363 }
364 RadixFactor::Factor3 => {
365 for _ in 0..f.count {
366 result = (result * 3) + (value % 3);
367 value = value / 3;
368 }
369 }
370 RadixFactor::Factor4 => {
371 for _ in 0..f.count {
372 result = (result * 4) + (value % 4);
373 value = value / 4;
374 }
375 }
376 RadixFactor::Factor5 => {
377 for _ in 0..f.count {
378 result = (result * 5) + (value % 5);
379 value = value / 5;
380 }
381 }
382 RadixFactor::Factor6 => {
383 for _ in 0..f.count {
384 result = (result * 6) + (value % 6);
385 value = value / 6;
386 }
387 }
388 RadixFactor::Factor7 => {
389 for _ in 0..f.count {
390 result = (result * 7) + (value % 7);
391 value = value / 7;
392 }
393 }
394 }
395 }
396 result
397}