1use std::ptr;
6
7use rustfft::FftDirection;
8use rustfft::{num_complex::Complex, FftPlanner};
9
/// Computes the forward 2D FFT of `img_buffer` in place.
///
/// The buffer is read as `height` consecutive rows of `width` elements
/// (row-major). Rows are transformed first, then columns. The column pass
/// writes every transformed column as one contiguous run, so on return the
/// buffer holds the spectrum in transposed layout: `width` runs of `height`
/// elements each.
///
/// NOTE(review): assumes `img_buffer.len() == width * height`; nothing in this
/// function checks it.
pub fn fft_2d(width: usize, height: usize, img_buffer: &mut [Complex<f64>]) {
    fft_2d_with_direction(width, height, img_buffer, FftDirection::Forward)
}
28
/// Computes the inverse 2D FFT of `img_buffer` in place.
///
/// The buffer is read as `height` consecutive rows of `width` elements
/// (row-major). As with the forward transform, the result is left in
/// transposed layout: `width` runs of `height` elements each.
///
/// NOTE(review): rustfft documents its transforms as unnormalized, so a
/// forward/inverse round trip presumably needs a manual `1 / (width * height)`
/// scaling by the caller - confirm against the rustfft version in use.
pub fn ifft_2d(width: usize, height: usize, img_buffer: &mut [Complex<f64>]) {
    fft_2d_with_direction(width, height, img_buffer, FftDirection::Inverse)
}
45
46fn fft_2d_with_direction(
61 width: usize,
62 height: usize,
63 img_buffer: &mut [Complex<f64>],
64 direction: FftDirection,
65) {
66 let mut planner = FftPlanner::new();
68 let fft_width = planner.plan_fft(width, direction);
69 let mut scratch = vec![Complex::default(); fft_width.get_inplace_scratch_len()];
70 for row_buffer in img_buffer.chunks_exact_mut(width) {
71 fft_width.process_with_scratch(row_buffer, &mut scratch);
72 }
73
74 let mut transposed = transpose(width, height, img_buffer);
76 let fft_height = planner.plan_fft(height, direction);
77 scratch.resize(fft_height.get_outofplace_scratch_len(), Complex::default());
78 for (tr_buf, col_buf) in transposed
79 .chunks_exact_mut(height)
80 .zip(img_buffer.chunks_exact_mut(height))
81 {
82 fft_height.process_outofplace_with_scratch(tr_buf, col_buf, &mut scratch);
83 }
84}
85
/// Returns the transpose of a row-major `height` x `width` matrix.
///
/// The element at (`row`, `col`) of `matrix` is written to (`col`, `row`) of
/// the result, which is laid out as `width` rows of `height` elements. The
/// result has the same length as `matrix`; any elements beyond
/// `width * height` are left at `T::default()`.
///
/// # Panics
///
/// Panics if `matrix` holds fewer than `width * height` elements.
fn transpose<T: Copy + Default>(width: usize, height: usize, matrix: &[T]) -> Vec<T> {
    let mut transposed = vec![T::default(); matrix.len()];
    for row in 0..height {
        for col in 0..width {
            transposed[col * height + row] = matrix[row * width + col];
        }
    }
    transposed
}
100
/// Undoes the quadrant shift performed by `fftshift`.
///
/// `matrix` is read as `height` rows of `width` elements (row-major). Along
/// each axis of length `n`, the element at index `i` is moved to index
/// `(i + n / 2) % n` (integer division). `fftshift` moves elements by
/// `n - n / 2` instead, so the two shifts add up to a full rotation and cancel
/// each other for every size. For even lengths both shifts are the same
/// permutation; they differ only when a dimension is odd.
///
/// The result has the same length as `matrix`; any elements beyond
/// `width * height` are left at `T::default()`.
///
/// # Panics
///
/// Panics if `matrix` holds fewer than `width * height` elements.
pub fn ifftshift<T: Copy + Default>(width: usize, height: usize, matrix: &[T]) -> Vec<T> {
    let mut unshifted = vec![T::default(); matrix.len()];
    let half_width = width / 2;
    let half_height = height / 2;
    // Number of leading source columns that land after the first `half_width`
    // destination columns.
    let head_len = width - half_width;
    for row in 0..height {
        let src_start = row * width;
        let src_row = &matrix[src_start..src_start + width];
        // `height > 0` whenever the loop body runs, so the modulo is defined.
        let dst_start = ((row + half_height) % height) * width;
        let dst_row = &mut unshifted[dst_start..dst_start + width];
        // Rotate the row right by `half_width` columns.
        dst_row[half_width..].copy_from_slice(&src_row[..head_len]);
        dst_row[..half_width].copy_from_slice(&src_row[head_len..]);
    }
    unshifted
}
113
/// Swaps the quadrants of a row-major `height` x `width` matrix into a new
/// vector, as is typically done to bring the low frequencies of a spectrum to
/// the centre.
///
/// The first `height / 2` rows move to the bottom of the result and the
/// remaining rows move to the top. Within every row the first `width / 2`
/// columns move to the right end and the remaining columns to the left end.
/// In other words, along an axis of length `n` the element at index `i` ends
/// up at index `(i + n - n / 2) % n`.
///
/// The result has the same length as `matrix`; any elements beyond
/// `width * height` are left at `T::default()`.
///
/// NOTE(review): for odd `n` this moves index 0 to `(n + 1) / 2`, whereas
/// NumPy's `fftshift` appears to move it to `n / 2` - confirm which convention
/// callers expect.
///
/// # Panics
///
/// Panics if `matrix` holds fewer than `width * height` elements.
pub fn fftshift<T: Copy + Default>(width: usize, height: usize, matrix: &[T]) -> Vec<T> {
    let mut shifted = vec![T::default(); matrix.len()];
    let left_len = width / 2;
    let right_len = width - left_len;
    let top_rows = height / 2;
    let bottom_rows = height - top_rows;
    for src_row in 0..height {
        // Top rows go down by `bottom_rows`, bottom rows go up by `top_rows`.
        let dst_row = if src_row < top_rows {
            src_row + bottom_rows
        } else {
            src_row - top_rows
        };
        let src_start = src_row * width;
        let src = &matrix[src_start..src_start + width];
        let dst_start = dst_row * width;
        let dst = &mut shifted[dst_start..dst_start + width];
        // Exchange the left and right parts of the row.
        let (src_left, src_right) = src.split_at(left_len);
        let (dst_left, dst_right) = dst.split_at_mut(right_len);
        dst_right.copy_from_slice(src_left);
        dst_left.copy_from_slice(src_right);
    }
    shifted
}
147
/// Swaps the quadrants of a row-major `height` x `width` matrix in place and
/// returns the same slice.
///
/// When `width` or `height` is 1, the first and last `matrix.len() / 2`
/// elements are exchanged. Otherwise the matrix is processed with two series
/// of row-segment swaps: top-left with bottom-right, then top-right with
/// bottom-left. Because everything is done with swaps, applying the function
/// twice restores the input. For odd dimensions the centre element stays
/// where it is, so the result differs from the copying `fftshift`.
///
/// # Safety
///
/// When both `width` and `height` differ from 1, `matrix` must hold at least
/// `width * height` elements. The swaps address elements up to that index
/// through raw pointers without any bounds check.
pub unsafe fn fftshift_zerocopy<T: Copy>(
    width: usize,
    height: usize,
    matrix: &mut [T],
) -> &mut [T] {
    // Floor and ceil halves only differ for odd dimensions.
    let (half_w_floor, half_w_ceil) = (width / 2, width.div_ceil(2));
    let (half_h_floor, half_h_ceil) = (height / 2, height.div_ceil(2));
    let len = matrix.len();

    let base = matrix.as_mut_ptr();

    if width == 1 || height == 1 {
        // SAFETY: the two ranges are `[0, len / 2)` and `[ceil(len / 2), len)`;
        // both lie inside the slice and cannot overlap.
        unsafe { ptr::swap_nonoverlapping(base, base.add(len.div_ceil(2)), len / 2) };
        return matrix;
    }

    // Top-left segment of each upper row <-> bottom-right segment of the
    // matching lower row.
    for row in 0..half_h_floor {
        let upper = row * width;
        let lower = (row + half_h_ceil) * width + half_w_floor;
        // SAFETY: `lower` starts at least one full row after `upper`, so the
        // segments are disjoint; the furthest element touched is at index
        // `width * height - 1`, which the caller guarantees to exist.
        unsafe { ptr::swap_nonoverlapping(base.add(upper), base.add(lower), half_w_ceil) };
    }
    // Top-right segment of each upper row <-> bottom-left segment of the
    // matching lower row.
    for row in 0..half_h_ceil {
        let count = width - half_w_ceil;
        let upper = row * width + half_w_ceil;
        let lower = (row + half_h_floor) * width;
        // SAFETY: `height >= 2` here, so `lower` starts at or after the end of
        // the row containing `upper`; both segments end before index
        // `width * height`, which the caller guarantees to be in bounds.
        unsafe { ptr::swap_nonoverlapping(base.add(upper), base.add(lower), count) };
    }

    matrix
}
192
/// 2D cosine transforms built on `rustdct`, using the same
/// rows-then-transposed-columns scheme as the FFT helpers of the parent module.
#[cfg(feature = "rustdct")]
pub mod dcst {

    use super::transpose;
    use rustdct::DctPlanner;

    /// Applies a type-2 DCT (`plan_dct2`) along both axes of `img_buffer`.
    ///
    /// The buffer is read as `height` rows of `width` values (row-major). Rows
    /// are transformed in place, the buffer is transposed, the former columns
    /// are transformed, and the transposed result is copied back. On return the
    /// buffer therefore holds the transform in transposed layout: `width` runs
    /// of `height` values.
    pub fn dct_2d(width: usize, height: usize, img_buffer: &mut [f64]) {
        let mut planner = DctPlanner::<f64>::new();

        // Row pass.
        let row_dct = planner.plan_dct2(width);
        let mut scratch = vec![0.0; row_dct.get_scratch_len()];
        for row in img_buffer.chunks_exact_mut(width) {
            row_dct.process_dct2_with_scratch(row, &mut scratch);
        }

        // Column pass on the transposed copy, reusing the scratch buffer.
        let mut columns = transpose(width, height, img_buffer);
        let column_dct = planner.plan_dct2(height);
        scratch.resize(column_dct.get_scratch_len(), 0.0);
        for column in columns.chunks_exact_mut(height) {
            column_dct.process_dct2_with_scratch(column, &mut scratch);
        }

        img_buffer.copy_from_slice(&columns);
    }

    /// Variant of `dct_2d` that processes rows, then columns, on the rayon
    /// thread pool.
    ///
    /// Input and output layout are the same as for `dct_2d`. No shared scratch
    /// buffer is passed to the transforms, so chunks can be processed
    /// independently of each other.
    #[cfg(feature = "parallel")]
    pub fn par_dct_2d(width: usize, height: usize, img_buffer: &mut [f64]) {
        use rayon::prelude::{ParallelIterator, ParallelSliceMut};

        let mut planner = DctPlanner::<f64>::new();

        // Row pass.
        let row_dct = planner.plan_dct2(width);
        img_buffer
            .par_chunks_exact_mut(width)
            .for_each(|row| row_dct.process_dct2(row));

        // Column pass on the transposed copy.
        let mut columns = transpose(width, height, img_buffer);
        let column_dct = planner.plan_dct2(height);
        columns
            .par_chunks_exact_mut(height)
            .for_each(|column| column_dct.process_dct2(column));

        img_buffer.copy_from_slice(&columns);
    }

    /// Applies a type-3 DCT (`plan_dct3`) along both axes of `img_buffer`.
    ///
    /// The buffer is read as `height` rows of `width` values (row-major) and
    /// is left in transposed layout (`width` runs of `height` values), exactly
    /// like `dct_2d`.
    ///
    /// NOTE(review): the type-3 DCT is presumably used here as the unnormalized
    /// inverse of the type-2 DCT - confirm any scaling callers must apply.
    pub fn idct_2d(width: usize, height: usize, img_buffer: &mut [f64]) {
        let mut planner = DctPlanner::<f64>::new();

        // Row pass.
        let row_dct = planner.plan_dct3(width);
        let mut scratch = vec![0.0; row_dct.get_scratch_len()];
        for row in img_buffer.chunks_exact_mut(width) {
            row_dct.process_dct3_with_scratch(row, &mut scratch);
        }

        // Column pass on the transposed copy, reusing the scratch buffer.
        let mut columns = transpose(width, height, img_buffer);
        let column_dct = planner.plan_dct3(height);
        scratch.resize(column_dct.get_scratch_len(), 0.0);
        for column in columns.chunks_exact_mut(height) {
            column_dct.process_dct3_with_scratch(column, &mut scratch);
        }

        img_buffer.copy_from_slice(&columns);
    }

    /// Variant of `idct_2d` that processes rows, then columns, on the rayon
    /// thread pool.
    ///
    /// Input and output layout are the same as for `idct_2d`.
    #[cfg(feature = "parallel")]
    pub fn par_idct_2d(width: usize, height: usize, img_buffer: &mut [f64]) {
        use rayon::prelude::{ParallelIterator, ParallelSliceMut};

        let mut planner = DctPlanner::<f64>::new();

        // Row pass.
        let row_dct = planner.plan_dct3(width);
        img_buffer
            .par_chunks_exact_mut(width)
            .for_each(|row| row_dct.process_dct3(row));

        // Column pass on the transposed copy.
        let mut columns = transpose(width, height, img_buffer);
        let column_dct = planner.plan_dct3(height);
        columns
            .par_chunks_exact_mut(height)
            .for_each(|column| column_dct.process_dct3(column));

        img_buffer.copy_from_slice(&columns);
    }
}
323
// Unit tests for the shift helpers and, when the corresponding features are
// enabled, for the agreement between sequential and parallel DCT variants.
#[cfg(test)]
mod tests {
    use super::*;

    // Checks the in-place quadrant shift against hand-written results, against
    // itself (applying it twice is the identity) and against `fftshift`.
    #[test]
    #[rustfmt::skip]
    fn test_zerocopy_fft_shift() {
        // 3 x 3 fixture: odd dimensions.
        let mut matrix = [
            1, 2, 3,
            4, 5, 6,
            7, 8, 9,
        ];
        // 4 x 4 fixture: even dimensions.
        let mut matrix2 = [
            1, 2, 3, 4,
            5, 6, 7, 8,
            9, 10, 11, 12,
            13, 14, 15, 16
        ];
        unsafe {
            // Even, non-square matrix (width 4, height 2).
            assert_eq!(
                fftshift_zerocopy(4, 2, &mut [
                    1, 2, 3, 4,
                    5, 6, 7, 8,
                ]),
                [7, 8, 5, 6,
                 3, 4, 1, 2,],
            );
            // Shifting twice restores the input, for odd dimensions...
            assert_eq!(
                fftshift_zerocopy(3, 3, fftshift_zerocopy(3, 3, &mut matrix.clone()),),
                &matrix,
            );
            // ...and for even dimensions.
            assert_eq!(
                fftshift_zerocopy(4, 4, fftshift_zerocopy(4, 4, &mut matrix2.clone()),),
                &matrix2,
            );

            // Odd dimensions: the centre element (5) stays in place.
            assert_eq!(
                fftshift_zerocopy(3, 3, &mut matrix),
                [8, 9, 4,
                 3, 5, 7,
                 6, 1, 2,],
            );

            // For even dimensions the in-place and the copying shift agree.
            assert_eq!(
                &fftshift(4, 4, matrix2.clone().as_slice()),
                fftshift_zerocopy(4, 4, &mut matrix2),
            );
        }
    }

    // The parallel forward DCT must give exactly the same values as the
    // sequential one (width 16, height 8).
    #[test]
    #[cfg(all(feature = "parallel", feature = "rustdct"))]
    fn test_identical_par_dct_result() {
        let test_vec = vec![
            54.75, 0.25, 69.39, 121.95, 15.86, 17.24, 77.48, 108.55, 127.40, 93.14, 49.28, 61.86,
            55.75, 47.64, 28.32, 35.08, 92.85, 66.36, 94.34, 12.58, 50.07, 66.83, 101.12, 67.24,
            111.74, 12.77, 114.64, 122.66, 86.15, 122.18, 33.94, 120.62, 107.30, 76.17, 99.48,
            44.19, 86.03, 113.70, 28.54, 110.29, 80.88, 127.94, 14.04, 70.76, 80.95, 79.83, 56.34,
            11.44, 65.98, 107.16, 54.12, 92.06, 5.32, 47.41, 83.55, 46.60, 17.94, 23.93, 56.11,
            64.69, 87.37, 47.92, 61.87, 63.50, 40.83, 53.61, 57.16, 18.06, 1.11, 51.35, 53.03,
            98.74, 43.84, 104.86, 52.87, 103.40, 114.36, 77.39, 45.10, 19.30, 90.93, 4.71, 95.27,
            26.99, 68.58, 112.49, 114.11, 11.85, 124.35, 28.06, 31.43, 12.53, 57.44, 63.72, 126.73,
            97.03, 97.45, 90.99, 15.45, 86.07, 27.62, 25.03, 106.54, 79.98, 49.95, 96.92, 124.75,
            80.09, 127.06, 84.39, 120.42, 124.40, 15.50, 121.84, 105.86, 24.44, 81.38, 111.54,
            27.66, 1.35, 119.06, 71.15, 108.78, 8.80, 19.83, 27.76, 75.44, 35.15,
        ];
        let mut non_para = test_vec.clone();
        let mut parallel = test_vec.clone();
        dcst::dct_2d(16, 8, &mut non_para);
        dcst::par_dct_2d(16, 8, &mut parallel);
        assert_eq!(non_para, parallel);
    }
    // The parallel inverse DCT must give exactly the same values as the
    // sequential one (width 16, height 8).
    #[test]
    #[cfg(all(feature = "parallel", feature = "rustdct"))]
    fn test_identical_par_idct_result() {
        let test_vec = vec![
            72.16, 47.41, 122.96, 52.90, 36.35, 98.84, 84.12, 34.52, 61.06, 112.66, 39.91, 67.93,
            84.70, 127.92, 13.63, 107.69, 4.49, 13.85, 124.56, 30.33, 105.12, 90.85, 75.41, 121.80,
            90.34, 105.42, 49.07, 14.55, 14.52, 33.06, 112.38, 46.69, 125.16, 73.96, 125.25, 7.11,
            4.42, 38.53, 105.64, 73.45, 43.45, 64.49, 7.68, 85.51, 109.86, 15.45, 122.59, 113.16,
            64.02, 117.34, 113.04, 56.70, 99.40, 120.27, 51.70, 11.26, 44.75, 1.58, 34.81, 30.54,
            6.71, 62.75, 72.62, 108.74, 17.51, 54.30, 44.35, 13.97, 26.33, 86.44, 34.88, 105.40,
            67.61, 22.40, 6.95, 48.64, 7.90, 76.50, 35.04, 29.92, 123.98, 83.62, 3.96, 75.32,
            42.37, 21.68, 23.58, 98.29, 19.81, 20.84, 110.50, 112.61, 92.65, 30.85, 113.19, 56.70,
            46.94, 107.89, 92.47, 12.77, 34.66, 0.95, 127.74, 53.54, 56.52, 106.35, 25.50, 52.36,
            100.60, 13.80, 19.96, 101.19, 58.99, 85.71, 30.79, 41.23, 56.03, 68.65, 46.78, 36.18,
            30.12, 63.23, 25.27, 93.14, 39.77, 72.21, 7.03, 62.79,
        ];
        let mut non_para = test_vec.clone();
        let mut parallel = test_vec.clone();
        dcst::idct_2d(16, 8, &mut non_para);
        dcst::par_idct_2d(16, 8, &mut parallel);
        assert_eq!(non_para, parallel);
    }
}