1#![allow(clippy::items_after_statements)]
6#![allow(clippy::manual_contains)]
7
8use crate::api::{Direction, Flags};
9use crate::dft::problem::Sign;
10use crate::dft::solvers::{
11 BluesteinSolver, CooleyTukeySolver, CtVariant, DirectSolver, GenericSolver, NopSolver,
12 StockhamSolver,
13};
14use crate::kernel::{Complex, Float, Tensor};
15use crate::prelude::*;
16use crate::rdft::solvers::R2rKind;
17
/// Real-data FFT plan (R2C or C2R) of arbitrary rank over a row-major array.
///
/// The plan only records the shape and the transform kind; the underlying
/// solvers are constructed on every `execute_*` call.
pub struct RealPlanND<T: Float> {
    /// Extent of each axis of the real array (row-major, last axis contiguous).
    dims: Vec<usize>,
    /// Whether this plan performs R2C or C2R.
    kind: RealPlanKind,
    /// Ties the plan to its floating-point type without storing a `T`.
    _marker: core::marker::PhantomData<T>,
}
24impl<T: Float> RealPlanND<T> {
25 #[must_use]
27 pub fn r2c(dims: &[usize], _flags: Flags) -> Option<Self> {
28 if dims.is_empty() || dims.iter().any(|&d| d == 0) {
29 return None;
30 }
31 Some(Self {
32 dims: dims.to_vec(),
33 kind: RealPlanKind::R2C,
34 _marker: core::marker::PhantomData,
35 })
36 }
37 #[must_use]
39 pub fn c2r(dims: &[usize], _flags: Flags) -> Option<Self> {
40 if dims.is_empty() || dims.iter().any(|&d| d == 0) {
41 return None;
42 }
43 Some(Self {
44 dims: dims.to_vec(),
45 kind: RealPlanKind::C2R,
46 _marker: core::marker::PhantomData,
47 })
48 }
49 pub fn execute_r2c(&self, input: &[T], output: &mut [Complex<T>]) {
51 assert_eq!(self.kind, RealPlanKind::R2C);
52 let expected_in: usize = self.dims.iter().product();
53 let last = *self
54 .dims
55 .last()
56 .expect("RealPlanND dimensions cannot be empty");
57 let prefix: usize = self.dims[..self.dims.len() - 1].iter().product();
58 let prefix = prefix.max(1);
59 let expected_out = prefix * (last / 2 + 1);
60 assert_eq!(input.len(), expected_in);
61 assert_eq!(output.len(), expected_out);
62 match self.dims.len() {
63 1 => {
64 use crate::rdft::solvers::R2cSolver;
65 let solver = R2cSolver::new(last);
66 solver.execute(input, output);
67 }
68 2 => {
69 let plan = RealPlan2D::<T>::r2c(self.dims[0], self.dims[1], Flags::ESTIMATE)
70 .expect("Failed to create internal 2D R2C plan");
71 plan.execute_r2c(input, output);
72 }
73 3 => {
74 let plan =
75 RealPlan3D::<T>::r2c(self.dims[0], self.dims[1], self.dims[2], Flags::ESTIMATE)
76 .expect("Failed to create internal 3D R2C plan");
77 plan.execute_r2c(input, output);
78 }
79 _ => {
80 let out_last = last / 2 + 1;
81 let inner_size = last;
82 use crate::rdft::solvers::R2cSolver;
83 let r2c_solver = R2cSolver::new(last);
84 let mut temp = vec![Complex::zero(); prefix * out_last];
85 for row in 0..prefix {
86 let in_start = row * inner_size;
87 let out_start = row * out_last;
88 r2c_solver.execute(
89 &input[in_start..in_start + inner_size],
90 &mut temp[out_start..out_start + out_last],
91 );
92 }
93 use crate::dft::solvers::GenericSolver;
94 let remaining_dims = &self.dims[..self.dims.len() - 1];
95 for (dim_idx, &dim_size) in remaining_dims.iter().enumerate().rev() {
96 let solver = GenericSolver::new(dim_size);
97 let mut col_in = vec![Complex::zero(); dim_size];
98 let mut col_out = vec![Complex::zero(); dim_size];
99 let inner_stride: usize = self.dims[dim_idx + 1..]
100 .iter()
101 .map(|&d| {
102 if dim_idx == self.dims.len() - 2 {
103 out_last
104 } else {
105 d
106 }
107 })
108 .product();
109 let outer_count: usize = self.dims[..dim_idx].iter().product();
110 let outer_count = outer_count.max(1);
111 for outer in 0..outer_count {
112 for inner in 0..inner_stride {
113 for k in 0..dim_size {
114 let idx =
115 outer * (dim_size * inner_stride) + k * inner_stride + inner;
116 col_in[k] = temp[idx];
117 }
118 solver.execute(&col_in, &mut col_out, Sign::Forward);
119 for k in 0..dim_size {
120 let idx =
121 outer * (dim_size * inner_stride) + k * inner_stride + inner;
122 temp[idx] = col_out[k];
123 }
124 }
125 }
126 }
127 output.copy_from_slice(&temp);
128 }
129 }
130 }
131 pub fn execute_c2r(&self, input: &[Complex<T>], output: &mut [T]) {
133 assert_eq!(self.kind, RealPlanKind::C2R);
134 let expected_out: usize = self.dims.iter().product();
135 let last = *self
136 .dims
137 .last()
138 .expect("RealPlanND dimensions cannot be empty");
139 let prefix: usize = self.dims[..self.dims.len() - 1].iter().product();
140 let prefix = prefix.max(1);
141 let expected_in = prefix * (last / 2 + 1);
142 assert_eq!(input.len(), expected_in);
143 assert_eq!(output.len(), expected_out);
144 match self.dims.len() {
145 1 => {
146 use crate::rdft::solvers::C2rSolver;
147 let solver = C2rSolver::new(last);
148 solver.execute(input, output);
149 }
150 2 => {
151 let plan = RealPlan2D::<T>::c2r(self.dims[0], self.dims[1], Flags::ESTIMATE)
152 .expect("Failed to create internal 2D C2R plan");
153 plan.execute_c2r(input, output);
154 }
155 3 => {
156 let plan =
157 RealPlan3D::<T>::c2r(self.dims[0], self.dims[1], self.dims[2], Flags::ESTIMATE)
158 .expect("Failed to create internal 3D C2R plan");
159 plan.execute_c2r(input, output);
160 }
161 _ => {
162 let out_last = last / 2 + 1;
163 let mut temp: Vec<Complex<T>> = input.to_vec();
164 use crate::dft::solvers::GenericSolver;
165 let remaining_dims = &self.dims[..self.dims.len() - 1];
166 for (dim_idx, &dim_size) in remaining_dims.iter().enumerate() {
167 let solver = GenericSolver::new(dim_size);
168 let mut col_in = vec![Complex::zero(); dim_size];
169 let mut col_out = vec![Complex::zero(); dim_size];
170 let inner_stride: usize = self.dims[dim_idx + 1..]
171 .iter()
172 .map(|&d| {
173 if dim_idx == self.dims.len() - 2 {
174 out_last
175 } else {
176 d
177 }
178 })
179 .product();
180 let outer_count: usize = self.dims[..dim_idx].iter().product();
181 let outer_count = outer_count.max(1);
182 for outer in 0..outer_count {
183 for inner in 0..inner_stride {
184 for k in 0..dim_size {
185 let idx =
186 outer * (dim_size * inner_stride) + k * inner_stride + inner;
187 col_in[k] = temp[idx];
188 }
189 solver.execute(&col_in, &mut col_out, Sign::Backward);
190 for k in 0..dim_size {
191 let idx =
192 outer * (dim_size * inner_stride) + k * inner_stride + inner;
193 temp[idx] = col_out[k];
194 }
195 }
196 }
197 }
198 use crate::rdft::solvers::C2rSolver;
199 let c2r_solver = C2rSolver::new(last);
200 for row in 0..prefix {
201 let in_start = row * out_last;
202 let out_start = row * last;
203 c2r_solver.execute(
204 &temp[in_start..in_start + out_last],
205 &mut output[out_start..out_start + last],
206 );
207 }
208 }
209 }
210 }
211}
/// Three-dimensional complex DFT plan over a row-major `n0 × n1 × n2` array.
///
/// Execution runs a 2-D transform on every `n1 × n2` plane, then a 1-D
/// transform of length `n0` across the planes.
pub struct Plan3D<T: Float> {
    /// Extent of the slowest-varying axis (number of planes).
    n0: usize,
    /// Extent of the middle axis.
    n1: usize,
    /// Extent of the fastest-varying (contiguous) axis.
    n2: usize,
    /// Direction shared by both sub-plans.
    direction: Direction,
    /// 2-D plan applied to each `n1 × n2` plane.
    plane_plan: Plan2D<T>,
    /// 1-D plan of length `n0` applied across planes.
    z_plan: Plan<T>,
}
228impl<T: Float> Plan3D<T> {
229 #[must_use]
241 pub fn new(
242 n0: usize,
243 n1: usize,
244 n2: usize,
245 direction: Direction,
246 flags: Flags,
247 ) -> Option<Self> {
248 let plane_plan = Plan2D::new(n1, n2, direction, flags)?;
249 let z_plan = Plan::dft_1d(n0, direction, flags)?;
250 Some(Self {
251 n0,
252 n1,
253 n2,
254 direction,
255 plane_plan,
256 z_plan,
257 })
258 }
259 #[must_use]
261 pub fn size(&self) -> usize {
262 self.n0 * self.n1 * self.n2
263 }
264 #[must_use]
266 pub fn direction(&self) -> Direction {
267 self.direction
268 }
269 pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
276 let total = self.n0 * self.n1 * self.n2;
277 assert_eq!(input.len(), total, "Input size must match n0 × n1 × n2");
278 assert_eq!(output.len(), total, "Output size must match n0 × n1 × n2");
279 if total == 0 {
280 return;
281 }
282 let plane_size = self.n1 * self.n2;
283 let mut temp = vec![Complex::zero(); total];
284 for i in 0..self.n0 {
285 let plane_start = i * plane_size;
286 let plane_end = plane_start + plane_size;
287 self.plane_plan.execute(
288 &input[plane_start..plane_end],
289 &mut temp[plane_start..plane_end],
290 );
291 }
292 let mut z_col = vec![Complex::zero(); self.n0];
293 let mut z_out = vec![Complex::zero(); self.n0];
294 for j in 0..self.n1 {
295 for k in 0..self.n2 {
296 for i in 0..self.n0 {
297 z_col[i] = temp[i * plane_size + j * self.n2 + k];
298 }
299 self.z_plan.execute(&z_col, &mut z_out);
300 for i in 0..self.n0 {
301 output[i * plane_size + j * self.n2 + k] = z_out[i];
302 }
303 }
304 }
305 }
306 pub fn execute_inplace(&self, data: &mut [Complex<T>]) {
311 let total = self.n0 * self.n1 * self.n2;
312 assert_eq!(data.len(), total, "Data size must match n0 × n1 × n2");
313 if total == 0 {
314 return;
315 }
316 let plane_size = self.n1 * self.n2;
317 for i in 0..self.n0 {
318 let plane_start = i * plane_size;
319 let plane_end = plane_start + plane_size;
320 self.plane_plan
321 .execute_inplace(&mut data[plane_start..plane_end]);
322 }
323 let mut z_col = vec![Complex::zero(); self.n0];
324 for j in 0..self.n1 {
325 for k in 0..self.n2 {
326 for i in 0..self.n0 {
327 z_col[i] = data[i * plane_size + j * self.n2 + k];
328 }
329 self.z_plan.execute_inplace(&mut z_col);
330 for i in 0..self.n0 {
331 data[i * plane_size + j * self.n2 + k] = z_col[i];
332 }
333 }
334 }
335 }
336}
/// Which way a real-data transform plan converts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RealPlanKind {
    /// Real input to half-spectrum complex output (`n / 2 + 1` values along the last axis).
    R2C,
    /// Half-spectrum complex input to real output.
    C2R,
}
/// N-dimensional complex DFT plan on split storage, meaning separate real and
/// imaginary arrays.
pub struct SplitPlanND<T: Float> {
    /// Extent of each axis (row-major, last axis contiguous).
    dims: Vec<usize>,
    /// Transform direction. The backward transform is scaled by `1 / dims.product()`.
    direction: Direction,
    /// Ties the plan to `T` without storing a value.
    _marker: core::marker::PhantomData<T>,
}
351impl<T: Float> SplitPlanND<T> {
352 #[must_use]
354 pub fn new(dims: &[usize], direction: Direction, _flags: Flags) -> Option<Self> {
355 if dims.is_empty() || dims.iter().any(|&d| d == 0) {
356 return None;
357 }
358 Some(Self {
359 dims: dims.to_vec(),
360 direction,
361 _marker: core::marker::PhantomData,
362 })
363 }
364 pub fn execute(&self, in_real: &[T], in_imag: &[T], out_real: &mut [T], out_imag: &mut [T]) {
366 let total: usize = self.dims.iter().product();
367 assert_eq!(in_real.len(), total);
368 assert_eq!(in_imag.len(), total);
369 assert_eq!(out_real.len(), total);
370 assert_eq!(out_imag.len(), total);
371 let mut data: Vec<Complex<T>> = in_real
372 .iter()
373 .zip(in_imag.iter())
374 .map(|(&r, &i)| Complex::new(r, i))
375 .collect();
376 let sign = match self.direction {
377 Direction::Forward => Sign::Forward,
378 Direction::Backward => Sign::Backward,
379 };
380 use crate::dft::solvers::GenericSolver;
381 for (dim_idx, &dim_size) in self.dims.iter().enumerate().rev() {
382 let solver = GenericSolver::new(dim_size);
383 let mut buf = vec![Complex::zero(); dim_size];
384 let mut buf_out = vec![Complex::zero(); dim_size];
385 let inner_stride: usize = self.dims[dim_idx + 1..].iter().product();
386 let inner_stride = inner_stride.max(1);
387 let outer_count: usize = self.dims[..dim_idx].iter().product();
388 let outer_count = outer_count.max(1);
389 for outer in 0..outer_count {
390 for inner in 0..inner_stride {
391 for k in 0..dim_size {
392 let idx = outer * (dim_size * inner_stride) + k * inner_stride + inner;
393 buf[k] = data[idx];
394 }
395 solver.execute(&buf, &mut buf_out, sign);
396 for k in 0..dim_size {
397 let idx = outer * (dim_size * inner_stride) + k * inner_stride + inner;
398 data[idx] = buf_out[k];
399 }
400 }
401 }
402 }
403 if self.direction == Direction::Backward {
404 let scale = T::one() / T::from_usize(total);
405 for c in &mut data {
406 *c = *c * scale;
407 }
408 }
409 for (i, c) in data.iter().enumerate() {
410 out_real[i] = c.re;
411 out_imag[i] = c.im;
412 }
413 }
414 pub fn execute_inplace(&self, real: &mut [T], imag: &mut [T]) {
416 let total: usize = self.dims.iter().product();
417 assert_eq!(real.len(), total);
418 assert_eq!(imag.len(), total);
419 let mut out_real = vec![T::ZERO; total];
420 let mut out_imag = vec![T::ZERO; total];
421 self.execute(real, imag, &mut out_real, &mut out_imag);
422 real.copy_from_slice(&out_real);
423 imag.copy_from_slice(&out_imag);
424 }
425}
/// One-dimensional real-data FFT plan (R2C or C2R) of length `n`.
pub struct RealPlan<T: Float> {
    /// Length of the real signal. The complex side has `n / 2 + 1` values.
    n: usize,
    /// Whether this plan performs R2C or C2R.
    kind: RealPlanKind,
    /// Ties the plan to `T` without storing a value.
    _marker: core::marker::PhantomData<T>,
}
437impl<T: Float> RealPlan<T> {
438 #[must_use]
447 pub fn r2c_1d(n: usize, _flags: Flags) -> Option<Self> {
448 if n == 0 {
449 return None;
450 }
451 Some(Self {
452 n,
453 kind: RealPlanKind::R2C,
454 _marker: core::marker::PhantomData,
455 })
456 }
457 #[must_use]
466 pub fn c2r_1d(n: usize, _flags: Flags) -> Option<Self> {
467 if n == 0 {
468 return None;
469 }
470 Some(Self {
471 n,
472 kind: RealPlanKind::C2R,
473 _marker: core::marker::PhantomData,
474 })
475 }
476 #[must_use]
478 pub fn size(&self) -> usize {
479 self.n
480 }
481 #[must_use]
483 pub fn complex_size(&self) -> usize {
484 self.n / 2 + 1
485 }
486 #[must_use]
488 pub fn kind(&self) -> RealPlanKind {
489 self.kind
490 }
491 pub fn execute_r2c(&self, input: &[T], output: &mut [Complex<T>]) {
496 use crate::rdft::solvers::R2cSolver;
497 assert_eq!(self.kind, RealPlanKind::R2C, "Plan must be R2C");
498 assert_eq!(input.len(), self.n, "Input size must match plan size");
499 assert_eq!(
500 output.len(),
501 self.complex_size(),
502 "Output size must be n/2+1"
503 );
504 R2cSolver::new(self.n).execute(input, output);
505 }
506 pub fn execute_c2r(&self, input: &[Complex<T>], output: &mut [T]) {
513 use crate::rdft::solvers::C2rSolver;
514 assert_eq!(self.kind, RealPlanKind::C2R, "Plan must be C2R");
515 assert_eq!(input.len(), self.complex_size(), "Input size must be n/2+1");
516 assert_eq!(output.len(), self.n, "Output size must match plan size");
517 C2rSolver::new(self.n).execute_normalized(input, output);
518 }
519 pub fn execute_c2r_unnormalized(&self, input: &[Complex<T>], output: &mut [T]) {
524 use crate::rdft::solvers::C2rSolver;
525 assert_eq!(self.kind, RealPlanKind::C2R, "Plan must be C2R");
526 assert_eq!(input.len(), self.complex_size(), "Input size must be n/2+1");
527 assert_eq!(output.len(), self.n, "Output size must match plan size");
528 C2rSolver::new(self.n).execute(input, output);
529 }
530}
/// N-dimensional complex DFT plan built from one 1-D plan per axis, applied
/// to a row-major array.
pub struct PlanND<T: Float> {
    /// Extent of each axis (last axis contiguous).
    dims: Vec<usize>,
    /// Product of all extents (checked for overflow at construction).
    total_size: usize,
    /// Row-major element stride of each axis.
    strides: Vec<usize>,
    /// Transform direction.
    direction: Direction,
    /// 1-D plan for each axis, indexed like `dims`.
    plans: Vec<Plan<T>>,
}
547impl<T: Float> PlanND<T> {
548 #[must_use]
564 pub fn new(dims: &[usize], direction: Direction, flags: Flags) -> Option<Self> {
565 if dims.is_empty() {
566 return None;
567 }
568 let mut total_size: usize = 1;
569 for &d in dims {
570 total_size = total_size.checked_mul(d)?;
571 }
572 let mut strides = vec![1; dims.len()];
573 for i in (0..dims.len() - 1).rev() {
574 strides[i] = strides[i + 1] * dims[i + 1];
575 }
576 let mut plans = Vec::with_capacity(dims.len());
577 for &dim in dims {
578 plans.push(Plan::dft_1d(dim, direction, flags)?);
579 }
580 Some(Self {
581 dims: dims.to_vec(),
582 total_size,
583 strides,
584 direction,
585 plans,
586 })
587 }
588 #[must_use]
590 pub fn rank(&self) -> usize {
591 self.dims.len()
592 }
593 #[must_use]
595 pub fn dims(&self) -> &[usize] {
596 &self.dims
597 }
598 #[must_use]
600 pub fn size(&self) -> usize {
601 self.total_size
602 }
603 #[must_use]
605 pub fn direction(&self) -> Direction {
606 self.direction
607 }
608 pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
615 assert_eq!(
616 input.len(),
617 self.total_size,
618 "Input size must match total size"
619 );
620 assert_eq!(
621 output.len(),
622 self.total_size,
623 "Output size must match total size"
624 );
625 if self.total_size == 0 {
626 return;
627 }
628 let mut current = input.to_vec();
629 let mut next = vec![Complex::zero(); self.total_size];
630 for dim_idx in (0..self.dims.len()).rev() {
631 self.transform_along_dimension(¤t, &mut next, dim_idx);
632 core::mem::swap(&mut current, &mut next);
633 }
634 output.copy_from_slice(¤t);
635 }
636 pub fn execute_inplace(&self, data: &mut [Complex<T>]) {
641 assert_eq!(
642 data.len(),
643 self.total_size,
644 "Data size must match total size"
645 );
646 if self.total_size == 0 {
647 return;
648 }
649 let mut temp = vec![Complex::zero(); self.total_size];
650 for dim_idx in (0..self.dims.len()).rev() {
651 self.transform_along_dimension(data, &mut temp, dim_idx);
652 data.copy_from_slice(&temp);
653 }
654 }
655 fn transform_along_dimension(
660 &self,
661 input: &[Complex<T>],
662 output: &mut [Complex<T>],
663 dim_idx: usize,
664 ) {
665 let dim_size = self.dims[dim_idx];
666 let stride = self.strides[dim_idx];
667 let num_fibers = self.total_size / dim_size;
668 let mut fiber_in = vec![Complex::zero(); dim_size];
669 let mut fiber_out = vec![Complex::zero(); dim_size];
670 for fiber_idx in 0..num_fibers {
671 let start_idx = self.fiber_start_index(fiber_idx, dim_idx);
672 for i in 0..dim_size {
673 fiber_in[i] = input[start_idx + i * stride];
674 }
675 self.plans[dim_idx].execute(&fiber_in, &mut fiber_out);
676 for i in 0..dim_size {
677 output[start_idx + i * stride] = fiber_out[i];
678 }
679 }
680 }
681 fn fiber_start_index(&self, fiber_idx: usize, dim_idx: usize) -> usize {
685 let mut idx = 0;
686 let mut remaining = fiber_idx;
687 for d in 0..self.dims.len() {
688 if d == dim_idx {
689 continue;
690 }
691 let below_count = self.fiber_below_count(d, dim_idx);
692 let coord = remaining / below_count;
693 remaining %= below_count;
694 idx += coord * self.strides[d];
695 }
696 idx
697 }
698 fn fiber_below_count(&self, d: usize, dim_idx: usize) -> usize {
700 let mut count = 1;
701 for i in (d + 1)..self.dims.len() {
702 if i != dim_idx {
703 count *= self.dims[i];
704 }
705 }
706 count
707 }
708}
/// Two-dimensional real-data FFT plan (R2C or C2R) for a row-major
/// `n0 × n1` real array. The complex side has shape `n0 × (n1 / 2 + 1)`.
pub struct RealPlan2D<T: Float> {
    /// Number of rows.
    n0: usize,
    /// Number of real columns (the contiguous axis).
    n1: usize,
    /// Whether this plan performs R2C or C2R.
    kind: RealPlanKind,
    /// Ties the plan to `T` without storing a value.
    _marker: core::marker::PhantomData<T>,
}
719impl<T: Float> RealPlan2D<T> {
720 #[must_use]
722 pub fn r2c(n0: usize, n1: usize, _flags: Flags) -> Option<Self> {
723 if n0 == 0 || n1 == 0 {
724 return None;
725 }
726 Some(Self {
727 n0,
728 n1,
729 kind: RealPlanKind::R2C,
730 _marker: core::marker::PhantomData,
731 })
732 }
733 #[must_use]
735 pub fn c2r(n0: usize, n1: usize, _flags: Flags) -> Option<Self> {
736 if n0 == 0 || n1 == 0 {
737 return None;
738 }
739 Some(Self {
740 n0,
741 n1,
742 kind: RealPlanKind::C2R,
743 _marker: core::marker::PhantomData,
744 })
745 }
746 pub fn execute_r2c(&self, input: &[T], output: &mut [Complex<T>]) {
751 assert_eq!(self.kind, RealPlanKind::R2C);
752 let expected_in = self.n0 * self.n1;
753 let expected_out = self.n0 * (self.n1 / 2 + 1);
754 assert_eq!(input.len(), expected_in);
755 assert_eq!(output.len(), expected_out);
756 use crate::rdft::solvers::R2cSolver;
757 let out_cols = self.n1 / 2 + 1;
758 let r2c_solver = R2cSolver::new(self.n1);
759 let mut temp = vec![Complex::zero(); self.n0 * out_cols];
760 for row in 0..self.n0 {
761 let in_start = row * self.n1;
762 let out_start = row * out_cols;
763 r2c_solver.execute(
764 &input[in_start..in_start + self.n1],
765 &mut temp[out_start..out_start + out_cols],
766 );
767 }
768 use crate::dft::solvers::GenericSolver;
769 let col_solver = GenericSolver::new(self.n0);
770 let mut col_in = vec![Complex::zero(); self.n0];
771 let mut col_out = vec![Complex::zero(); self.n0];
772 for col in 0..out_cols {
773 for row in 0..self.n0 {
774 col_in[row] = temp[row * out_cols + col];
775 }
776 col_solver.execute(&col_in, &mut col_out, Sign::Forward);
777 for row in 0..self.n0 {
778 output[row * out_cols + col] = col_out[row];
779 }
780 }
781 }
782 pub fn execute_c2r(&self, input: &[Complex<T>], output: &mut [T]) {
784 assert_eq!(self.kind, RealPlanKind::C2R);
785 let expected_in = self.n0 * (self.n1 / 2 + 1);
786 let expected_out = self.n0 * self.n1;
787 assert_eq!(input.len(), expected_in);
788 assert_eq!(output.len(), expected_out);
789 let out_cols = self.n1 / 2 + 1;
790 use crate::dft::solvers::GenericSolver;
791 let col_solver = GenericSolver::new(self.n0);
792 let mut temp = vec![Complex::zero(); self.n0 * out_cols];
793 let mut col_in = vec![Complex::zero(); self.n0];
794 let mut col_out = vec![Complex::zero(); self.n0];
795 for col in 0..out_cols {
796 for row in 0..self.n0 {
797 col_in[row] = input[row * out_cols + col];
798 }
799 col_solver.execute(&col_in, &mut col_out, Sign::Backward);
800 for row in 0..self.n0 {
801 temp[row * out_cols + col] = col_out[row];
802 }
803 }
804 use crate::rdft::solvers::C2rSolver;
805 let c2r_solver = C2rSolver::new(self.n1);
806 for row in 0..self.n0 {
807 let in_start = row * out_cols;
808 let out_start = row * self.n1;
809 c2r_solver.execute(
810 &temp[in_start..in_start + out_cols],
811 &mut output[out_start..out_start + self.n1],
812 );
813 }
814 }
815}
/// Three-dimensional real-data FFT plan (R2C or C2R) for a row-major
/// `n0 × n1 × n2` real array. The complex side has shape `n0 × n1 × (n2 / 2 + 1)`.
pub struct RealPlan3D<T: Float> {
    /// Extent of the slowest-varying axis.
    n0: usize,
    /// Extent of the middle axis.
    n1: usize,
    /// Real extent of the contiguous axis.
    n2: usize,
    /// Whether this plan performs R2C or C2R.
    kind: RealPlanKind,
    /// Ties the plan to `T` without storing a value.
    _marker: core::marker::PhantomData<T>,
}
824impl<T: Float> RealPlan3D<T> {
825 #[must_use]
827 pub fn r2c(n0: usize, n1: usize, n2: usize, _flags: Flags) -> Option<Self> {
828 if n0 == 0 || n1 == 0 || n2 == 0 {
829 return None;
830 }
831 Some(Self {
832 n0,
833 n1,
834 n2,
835 kind: RealPlanKind::R2C,
836 _marker: core::marker::PhantomData,
837 })
838 }
839 #[must_use]
841 pub fn c2r(n0: usize, n1: usize, n2: usize, _flags: Flags) -> Option<Self> {
842 if n0 == 0 || n1 == 0 || n2 == 0 {
843 return None;
844 }
845 Some(Self {
846 n0,
847 n1,
848 n2,
849 kind: RealPlanKind::C2R,
850 _marker: core::marker::PhantomData,
851 })
852 }
853 pub fn execute_r2c(&self, input: &[T], output: &mut [Complex<T>]) {
855 assert_eq!(self.kind, RealPlanKind::R2C);
856 let expected_in = self.n0 * self.n1 * self.n2;
857 let expected_out = self.n0 * self.n1 * (self.n2 / 2 + 1);
858 assert_eq!(input.len(), expected_in);
859 assert_eq!(output.len(), expected_out);
860 let out_last = self.n2 / 2 + 1;
861 let slice_in_size = self.n1 * self.n2;
862 let slice_out_size = self.n1 * out_last;
863 let plan_2d = RealPlan2D::<T>::r2c(self.n1, self.n2, Flags::ESTIMATE)
864 .expect("Failed to create internal 2D R2C plan");
865 let mut temp = vec![Complex::zero(); self.n0 * slice_out_size];
866 for i in 0..self.n0 {
867 let in_start = i * slice_in_size;
868 let out_start = i * slice_out_size;
869 plan_2d.execute_r2c(
870 &input[in_start..in_start + slice_in_size],
871 &mut temp[out_start..out_start + slice_out_size],
872 );
873 }
874 use crate::dft::solvers::GenericSolver;
875 let n0_solver = GenericSolver::new(self.n0);
876 let mut col_in = vec![Complex::zero(); self.n0];
877 let mut col_out = vec![Complex::zero(); self.n0];
878 for j in 0..self.n1 {
879 for k in 0..out_last {
880 for i in 0..self.n0 {
881 col_in[i] = temp[i * slice_out_size + j * out_last + k];
882 }
883 n0_solver.execute(&col_in, &mut col_out, Sign::Forward);
884 for i in 0..self.n0 {
885 output[i * slice_out_size + j * out_last + k] = col_out[i];
886 }
887 }
888 }
889 }
890 pub fn execute_c2r(&self, input: &[Complex<T>], output: &mut [T]) {
892 assert_eq!(self.kind, RealPlanKind::C2R);
893 let expected_in = self.n0 * self.n1 * (self.n2 / 2 + 1);
894 let expected_out = self.n0 * self.n1 * self.n2;
895 assert_eq!(input.len(), expected_in);
896 assert_eq!(output.len(), expected_out);
897 let out_last = self.n2 / 2 + 1;
898 let slice_in_size = self.n1 * out_last;
899 let slice_out_size = self.n1 * self.n2;
900 use crate::dft::solvers::GenericSolver;
901 let n0_solver = GenericSolver::new(self.n0);
902 let mut temp = vec![Complex::zero(); self.n0 * slice_in_size];
903 let mut col_in = vec![Complex::zero(); self.n0];
904 let mut col_out = vec![Complex::zero(); self.n0];
905 for j in 0..self.n1 {
906 for k in 0..out_last {
907 for i in 0..self.n0 {
908 col_in[i] = input[i * slice_in_size + j * out_last + k];
909 }
910 n0_solver.execute(&col_in, &mut col_out, Sign::Backward);
911 for i in 0..self.n0 {
912 temp[i * slice_in_size + j * out_last + k] = col_out[i];
913 }
914 }
915 }
916 let plan_2d = RealPlan2D::<T>::c2r(self.n1, self.n2, Flags::ESTIMATE)
917 .expect("Failed to create internal 2D C2R plan");
918 for i in 0..self.n0 {
919 let in_start = i * slice_in_size;
920 let out_start = i * slice_out_size;
921 plan_2d.execute_c2r(
922 &temp[in_start..in_start + slice_in_size],
923 &mut output[out_start..out_start + slice_out_size],
924 );
925 }
926 }
927}
// No variant of this enum is constructed in this part of the file, which is
// presumably why `dead_code` is allowed here. NOTE(review): confirm it is
// still needed, or wire it into plan selection.
/// Candidate algorithm for a 1-D complex transform.
#[allow(dead_code)]
enum Algorithm<T: Float> {
    /// Presumably a no-op transform (cf. `NopSolver`) — TODO confirm.
    Nop,
    /// Presumably direct DFT evaluation (cf. `DirectSolver`) — TODO confirm.
    Direct,
    /// Cooley–Tukey with the given variant.
    CooleyTukey(CtVariant),
    /// Presumably the Stockham autosort FFT (cf. `StockhamSolver`) — TODO confirm.
    Stockham,
    /// Composite decomposition. NOTE(review): the meaning of the `usize`
    /// payload (factor? radix?) is not visible here; confirm.
    Composite(usize),
    /// Boxed generic solver.
    Generic(Box<GenericSolver<T>>),
    /// Boxed Bluestein solver.
    Bluestein(Box<BluesteinSolver<T>>),
}
/// Guru-interface complex DFT plan: transforms of arbitrary rank with an
/// explicit element stride per axis, optionally repeated over a batch.
pub struct GuruPlan<T: Float> {
    /// Transform axes. Each entry has extent `n`, input stride `is` and
    /// output stride `os` (in elements).
    dims: Tensor,
    /// Batch loop axes, laid out like `dims`. Empty means a single transform.
    howmany: Tensor,
    /// Transform direction.
    direction: Direction,
    /// One 1-D plan per entry of `dims`.
    plans: Vec<Plan<T>>,
}
987impl<T: Float> GuruPlan<T> {
988 pub fn dft(
1005 dims: &Tensor,
1006 howmany: &Tensor,
1007 direction: Direction,
1008 flags: Flags,
1009 ) -> Option<Self> {
1010 if dims.is_empty() {
1011 return None;
1012 }
1013 for dim in &dims.dims {
1014 if dim.n == 0 {
1015 return None;
1016 }
1017 }
1018 for dim in &howmany.dims {
1019 if dim.n == 0 {
1020 return None;
1021 }
1022 }
1023 let mut plans = Vec::with_capacity(dims.rank());
1024 for dim in &dims.dims {
1025 let plan = Plan::dft_1d(dim.n, direction, flags)?;
1026 plans.push(plan);
1027 }
1028 Some(Self {
1029 dims: dims.clone(),
1030 howmany: howmany.clone(),
1031 direction,
1032 plans,
1033 })
1034 }
1035 #[must_use]
1037 pub fn dims(&self) -> &Tensor {
1038 &self.dims
1039 }
1040 #[must_use]
1042 pub fn howmany(&self) -> &Tensor {
1043 &self.howmany
1044 }
1045 #[must_use]
1047 pub fn direction(&self) -> Direction {
1048 self.direction
1049 }
1050 #[must_use]
1052 pub fn transform_size(&self) -> usize {
1053 self.dims.total_size()
1054 }
1055 #[must_use]
1057 pub fn batch_count(&self) -> usize {
1058 if self.howmany.is_empty() {
1059 1
1060 } else {
1061 self.howmany.total_size()
1062 }
1063 }
1064 pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
1073 if self.dims.rank() == 1 && self.howmany.is_empty() {
1074 let dim = &self.dims.dims[0];
1075 if dim.is == 1 && dim.os == 1 {
1076 self.plans[0].execute(input, output);
1077 return;
1078 }
1079 }
1080 self.execute_batched(input, output);
1081 }
1082 fn execute_batched(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
1084 let batch_count = self.batch_count();
1085 if batch_count == 1 {
1086 self.execute_single(input, output, 0, 0);
1087 } else {
1088 for batch_idx in 0..batch_count {
1089 let (in_offset, out_offset) = self.compute_batch_offset(batch_idx);
1090 self.execute_single(input, output, in_offset, out_offset);
1091 }
1092 }
1093 }
1094 fn compute_batch_offset(&self, batch_idx: usize) -> (isize, isize) {
1096 if self.howmany.is_empty() {
1097 return (0, 0);
1098 }
1099 let mut in_offset: isize = 0;
1100 let mut out_offset: isize = 0;
1101 let mut remaining = batch_idx;
1102 for dim in self.howmany.dims.iter().rev() {
1103 let idx = remaining % dim.n;
1104 remaining /= dim.n;
1105 in_offset += (idx as isize) * dim.is;
1106 out_offset += (idx as isize) * dim.os;
1107 }
1108 (in_offset, out_offset)
1109 }
1110 fn execute_single(
1112 &self,
1113 input: &[Complex<T>],
1114 output: &mut [Complex<T>],
1115 in_offset: isize,
1116 out_offset: isize,
1117 ) {
1118 if self.dims.rank() == 1 {
1119 self.execute_1d(input, output, in_offset, out_offset);
1120 } else {
1121 self.execute_nd(input, output, in_offset, out_offset);
1122 }
1123 }
1124 fn execute_1d(
1126 &self,
1127 input: &[Complex<T>],
1128 output: &mut [Complex<T>],
1129 in_offset: isize,
1130 out_offset: isize,
1131 ) {
1132 let dim = &self.dims.dims[0];
1133 let n = dim.n;
1134 let in_stride = dim.is;
1135 let out_stride = dim.os;
1136 let input_contiguous: Vec<Complex<T>>;
1137 let input_slice = if in_stride == 1 && in_offset >= 0 {
1138 &input[in_offset as usize..in_offset as usize + n]
1139 } else {
1140 input_contiguous = (0..n)
1141 .map(|i| {
1142 let idx = in_offset + (i as isize) * in_stride;
1143 input[idx as usize]
1144 })
1145 .collect();
1146 &input_contiguous
1147 };
1148 let mut temp = vec![Complex::<T>::zero(); n];
1149 self.plans[0].execute(input_slice, &mut temp);
1150 for i in 0..n {
1151 let idx = out_offset + (i as isize) * out_stride;
1152 output[idx as usize] = temp[i];
1153 }
1154 }
1155 fn execute_nd(
1157 &self,
1158 input: &[Complex<T>],
1159 output: &mut [Complex<T>],
1160 in_offset: isize,
1161 out_offset: isize,
1162 ) {
1163 let total_size = self.transform_size();
1164 let mut work = vec![Complex::<T>::zero(); total_size];
1165 self.gather_nd(input, &mut work, in_offset);
1166 for (dim_idx, plan) in self.plans.iter().enumerate() {
1167 self.apply_1d_along_dimension(&mut work, dim_idx, plan);
1168 }
1169 self.scatter_nd(&work, output, out_offset);
1170 }
1171 fn gather_nd(&self, input: &[Complex<T>], work: &mut [Complex<T>], base_offset: isize) {
1173 let total = self.transform_size();
1174 for flat_idx in 0..total {
1175 let src_offset = self.compute_nd_offset(flat_idx, base_offset, true);
1176 work[flat_idx] = input[src_offset as usize];
1177 }
1178 }
1179 fn scatter_nd(&self, work: &[Complex<T>], output: &mut [Complex<T>], base_offset: isize) {
1181 let total = self.transform_size();
1182 for flat_idx in 0..total {
1183 let dst_offset = self.compute_nd_offset(flat_idx, base_offset, false);
1184 output[dst_offset as usize] = work[flat_idx];
1185 }
1186 }
1187 fn compute_nd_offset(&self, flat_idx: usize, base_offset: isize, is_input: bool) -> isize {
1189 let mut offset = base_offset;
1190 let mut remaining = flat_idx;
1191 for dim in self.dims.dims.iter().rev() {
1192 let idx = remaining % dim.n;
1193 remaining /= dim.n;
1194 let stride = if is_input { dim.is } else { dim.os };
1195 offset += (idx as isize) * stride;
1196 }
1197 offset
1198 }
1199 fn apply_1d_along_dimension(&self, work: &mut [Complex<T>], dim_idx: usize, plan: &Plan<T>) {
1201 let n = self.dims.dims[dim_idx].n;
1202 let inner_size: usize = self.dims.dims[dim_idx + 1..].iter().map(|d| d.n).product();
1203 let inner_size = if inner_size == 0 { 1 } else { inner_size };
1204 let outer_size: usize = self.dims.dims[..dim_idx].iter().map(|d| d.n).product();
1205 let outer_size = if outer_size == 0 { 1 } else { outer_size };
1206 let stride = inner_size;
1207 let mut temp_in = vec![Complex::<T>::zero(); n];
1208 let mut temp_out = vec![Complex::<T>::zero(); n];
1209 for outer in 0..outer_size {
1210 for inner in 0..inner_size {
1211 let base = outer * n * inner_size + inner;
1212 for i in 0..n {
1213 temp_in[i] = work[base + i * stride];
1214 }
1215 plan.execute(&temp_in, &mut temp_out);
1216 for i in 0..n {
1217 work[base + i * stride] = temp_out[i];
1218 }
1219 }
1220 }
1221 }
1222 pub fn execute_inplace(&self, data: &mut [Complex<T>]) {
1226 if !self.dims.is_inplace_compatible() {
1227 panic!("In-place execution requires identical input and output strides");
1228 }
1229 let batch_count = self.batch_count();
1230 if batch_count == 1 {
1231 self.execute_inplace_single(data, 0);
1232 } else {
1233 for batch_idx in 0..batch_count {
1234 let (offset, _) = self.compute_batch_offset(batch_idx);
1235 self.execute_inplace_single(data, offset);
1236 }
1237 }
1238 }
1239 fn execute_inplace_single(&self, data: &mut [Complex<T>], offset: isize) {
1241 let n = self.transform_size();
1242 let dim = &self.dims.dims[0];
1243 if self.dims.rank() == 1 && dim.is == 1 && offset >= 0 {
1244 let start = offset as usize;
1245 let end = start + n;
1246 let mut temp = vec![Complex::<T>::zero(); n];
1247 self.plans[0].execute(&data[start..end], &mut temp);
1248 data[start..end].copy_from_slice(&temp);
1249 } else {
1250 let mut work = vec![Complex::<T>::zero(); n];
1251 for (flat_idx, item) in work.iter_mut().enumerate().take(n) {
1252 let src_offset = self.compute_nd_offset(flat_idx, offset, true);
1253 *item = data[src_offset as usize];
1254 }
1255 let mut result = vec![Complex::<T>::zero(); n];
1256 if self.dims.rank() == 1 {
1257 self.plans[0].execute(&work, &mut result);
1258 } else {
1259 for (dim_idx, plan) in self.plans.iter().enumerate() {
1260 self.apply_1d_along_dimension(&mut work, dim_idx, plan);
1261 }
1262 result = work;
1263 }
1264 for (flat_idx, &item) in result.iter().enumerate().take(n) {
1265 let dst_offset = self.compute_nd_offset(flat_idx, offset, false);
1266 data[dst_offset as usize] = item;
1267 }
1268 }
1269 }
1270}
/// Two-dimensional complex DFT plan over split real/imaginary arrays. The
/// arrays are interleaved into a complex buffer and handed to a [`Plan2D`].
pub struct SplitPlan2D<T: Float> {
    /// Underlying interleaved-complex 2-D plan.
    plan: Plan2D<T>,
}
1276impl<T: Float> SplitPlan2D<T> {
1277 #[must_use]
1285 pub fn new(n0: usize, n1: usize, direction: Direction, flags: Flags) -> Option<Self> {
1286 let plan = Plan2D::new(n0, n1, direction, flags)?;
1287 Some(Self { plan })
1288 }
1289 #[must_use]
1291 pub fn rows(&self) -> usize {
1292 self.plan.n0
1293 }
1294 #[must_use]
1296 pub fn cols(&self) -> usize {
1297 self.plan.n1
1298 }
1299 #[must_use]
1301 pub fn size(&self) -> usize {
1302 self.plan.size()
1303 }
1304 #[must_use]
1306 pub fn direction(&self) -> Direction {
1307 self.plan.direction
1308 }
1309 pub fn execute(&self, in_real: &[T], in_imag: &[T], out_real: &mut [T], out_imag: &mut [T]) {
1316 let n = self.size();
1317 assert_eq!(in_real.len(), n, "Input real size must match n0 × n1");
1318 assert_eq!(in_imag.len(), n, "Input imaginary size must match n0 × n1");
1319 assert_eq!(out_real.len(), n, "Output real size must match n0 × n1");
1320 assert_eq!(
1321 out_imag.len(),
1322 n,
1323 "Output imaginary size must match n0 × n1"
1324 );
1325 let input: Vec<Complex<T>> = in_real
1326 .iter()
1327 .zip(in_imag.iter())
1328 .map(|(&re, &im)| Complex::new(re, im))
1329 .collect();
1330 let mut output = vec![Complex::<T>::zero(); n];
1331 self.plan.execute(&input, &mut output);
1332 for (i, c) in output.iter().enumerate() {
1333 out_real[i] = c.re;
1334 out_imag[i] = c.im;
1335 }
1336 }
1337 pub fn execute_inplace(&self, real: &mut [T], imag: &mut [T]) {
1339 let n = self.size();
1340 assert_eq!(real.len(), n, "Real size must match n0 × n1");
1341 assert_eq!(imag.len(), n, "Imaginary size must match n0 × n1");
1342 let mut data: Vec<Complex<T>> = real
1343 .iter()
1344 .zip(imag.iter())
1345 .map(|(&re, &im)| Complex::new(re, im))
1346 .collect();
1347 self.plan.execute_inplace(&mut data);
1348 for (i, c) in data.iter().enumerate() {
1349 real[i] = c.re;
1350 imag[i] = c.im;
1351 }
1352 }
1353}
1354pub struct SplitPlan3D<T: Float> {
1356 n0: usize,
1357 n1: usize,
1358 n2: usize,
1359 direction: Direction,
1360 _marker: core::marker::PhantomData<T>,
1361}
1362impl<T: Float> SplitPlan3D<T> {
1363 #[must_use]
1365 pub fn new(
1366 n0: usize,
1367 n1: usize,
1368 n2: usize,
1369 direction: Direction,
1370 _flags: Flags,
1371 ) -> Option<Self> {
1372 if n0 == 0 || n1 == 0 || n2 == 0 {
1373 return None;
1374 }
1375 Some(Self {
1376 n0,
1377 n1,
1378 n2,
1379 direction,
1380 _marker: core::marker::PhantomData,
1381 })
1382 }
1383 pub fn execute(&self, in_real: &[T], in_imag: &[T], out_real: &mut [T], out_imag: &mut [T]) {
1385 let total = self.n0 * self.n1 * self.n2;
1386 assert_eq!(in_real.len(), total);
1387 assert_eq!(in_imag.len(), total);
1388 assert_eq!(out_real.len(), total);
1389 assert_eq!(out_imag.len(), total);
1390 let mut data: Vec<Complex<T>> = in_real
1391 .iter()
1392 .zip(in_imag.iter())
1393 .map(|(&r, &i)| Complex::new(r, i))
1394 .collect();
1395 let sign = match self.direction {
1396 Direction::Forward => Sign::Forward,
1397 Direction::Backward => Sign::Backward,
1398 };
1399 use crate::dft::solvers::GenericSolver;
1400 let solver_n2 = GenericSolver::new(self.n2);
1401 let mut row = vec![Complex::zero(); self.n2];
1402 let mut row_out = vec![Complex::zero(); self.n2];
1403 for i in 0..(self.n0 * self.n1) {
1404 let start = i * self.n2;
1405 row.copy_from_slice(&data[start..start + self.n2]);
1406 solver_n2.execute(&row, &mut row_out, sign);
1407 data[start..start + self.n2].copy_from_slice(&row_out);
1408 }
1409 let solver_n1 = GenericSolver::new(self.n1);
1410 let mut col = vec![Complex::zero(); self.n1];
1411 let mut col_out = vec![Complex::zero(); self.n1];
1412 for i in 0..self.n0 {
1413 for k in 0..self.n2 {
1414 for j in 0..self.n1 {
1415 col[j] = data[i * self.n1 * self.n2 + j * self.n2 + k];
1416 }
1417 solver_n1.execute(&col, &mut col_out, sign);
1418 for j in 0..self.n1 {
1419 data[i * self.n1 * self.n2 + j * self.n2 + k] = col_out[j];
1420 }
1421 }
1422 }
1423 let solver_n0 = GenericSolver::new(self.n0);
1424 let mut depth = vec![Complex::zero(); self.n0];
1425 let mut depth_out = vec![Complex::zero(); self.n0];
1426 for j in 0..self.n1 {
1427 for k in 0..self.n2 {
1428 for i in 0..self.n0 {
1429 depth[i] = data[i * self.n1 * self.n2 + j * self.n2 + k];
1430 }
1431 solver_n0.execute(&depth, &mut depth_out, sign);
1432 for i in 0..self.n0 {
1433 data[i * self.n1 * self.n2 + j * self.n2 + k] = depth_out[i];
1434 }
1435 }
1436 }
1437 if self.direction == Direction::Backward {
1438 let scale = T::one() / T::from_usize(total);
1439 for c in &mut data {
1440 *c = *c * scale;
1441 }
1442 }
1443 for (i, c) in data.iter().enumerate() {
1444 out_real[i] = c.re;
1445 out_imag[i] = c.im;
1446 }
1447 }
1448 pub fn execute_inplace(&self, real: &mut [T], imag: &mut [T]) {
1450 let total = self.n0 * self.n1 * self.n2;
1451 assert_eq!(real.len(), total);
1452 assert_eq!(imag.len(), total);
1453 let mut out_real = vec![T::ZERO; total];
1454 let mut out_imag = vec![T::ZERO; total];
1455 self.execute(real, imag, &mut out_real, &mut out_imag);
1456 real.copy_from_slice(&out_real);
1457 imag.copy_from_slice(&out_imag);
1458 }
1459}
1460pub struct Plan<T: Float> {
1465 n: usize,
1467 direction: Direction,
1469 algorithm: Algorithm<T>,
1471}
1472impl<T: Float> Plan<T> {
1473 #[must_use]
1483 pub fn dft_1d(n: usize, direction: Direction, flags: Flags) -> Option<Self> {
1484 let algorithm = Self::select_algorithm(n, flags);
1485 Some(Self {
1486 n,
1487 direction,
1488 algorithm,
1489 })
1490 }
1491 pub fn dft_2d(_n0: usize, _n1: usize, _direction: Direction, _flags: Flags) -> Option<Self> {
1493 todo!("Implement dft_2d planning")
1494 }
1495 pub fn dft_3d(
1497 _n0: usize,
1498 _n1: usize,
1499 _n2: usize,
1500 _direction: Direction,
1501 _flags: Flags,
1502 ) -> Option<Self> {
1503 todo!("Implement dft_3d planning")
1504 }
1505 pub fn r2c_1d(_n: usize, _flags: Flags) -> Option<Self> {
1507 todo!("Implement r2c_1d planning")
1508 }
1509 pub fn c2r_1d(_n: usize, _flags: Flags) -> Option<Self> {
1511 todo!("Implement c2r_1d planning")
1512 }
1513 fn select_algorithm(n: usize, _flags: Flags) -> Algorithm<T> {
1515 use crate::dft::codelets::has_composite_codelet;
1516
1517 if n <= 1 {
1518 Algorithm::Nop
1519 } else if CooleyTukeySolver::<T>::applicable(n) {
1520 Algorithm::CooleyTukey(CtVariant::Dit)
1523 } else if has_composite_codelet(n) {
1524 Algorithm::Composite(n)
1526 } else if n <= 16 {
1527 Algorithm::Direct
1529 } else if GenericSolver::<T>::applicable(n) {
1530 Algorithm::Generic(Box::new(GenericSolver::new(n)))
1531 } else {
1532 Algorithm::Bluestein(Box::new(BluesteinSolver::new(n)))
1533 }
1534 }
1535 #[must_use]
1537 pub fn size(&self) -> usize {
1538 self.n
1539 }
1540 #[must_use]
1542 pub fn direction(&self) -> Direction {
1543 self.direction
1544 }
1545 pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
1550 use crate::dft::codelets::execute_composite_codelet;
1551
1552 assert_eq!(input.len(), self.n, "Input size must match plan size");
1553 assert_eq!(output.len(), self.n, "Output size must match plan size");
1554 let sign = match self.direction {
1555 Direction::Forward => Sign::Forward,
1556 Direction::Backward => Sign::Backward,
1557 };
1558 match &self.algorithm {
1559 Algorithm::Nop => {
1560 NopSolver::new().execute(input, output);
1561 }
1562 Algorithm::Direct => {
1563 DirectSolver::new().execute(input, output, sign);
1564 }
1565 Algorithm::CooleyTukey(variant) => {
1566 CooleyTukeySolver::new(*variant).execute(input, output, sign);
1567 }
1568 Algorithm::Stockham => {
1569 StockhamSolver::new().execute(input, output, sign);
1570 }
1571 Algorithm::Composite(n) => {
1572 output.copy_from_slice(input);
1573 let sign_int = if sign == Sign::Forward { -1 } else { 1 };
1574 execute_composite_codelet(output, *n, sign_int);
1575 }
1576 Algorithm::Generic(solver) => {
1577 solver.execute(input, output, sign);
1578 }
1579 Algorithm::Bluestein(solver) => {
1580 solver.execute(input, output, sign);
1581 }
1582 }
1583 }
1584 pub fn execute_inplace(&self, data: &mut [Complex<T>]) {
1589 use crate::dft::codelets::execute_composite_codelet;
1590
1591 assert_eq!(data.len(), self.n, "Data size must match plan size");
1592 let sign = match self.direction {
1593 Direction::Forward => Sign::Forward,
1594 Direction::Backward => Sign::Backward,
1595 };
1596 match &self.algorithm {
1597 Algorithm::Nop => {
1598 NopSolver::new().execute_inplace(data);
1599 }
1600 Algorithm::Direct => {
1601 DirectSolver::new().execute_inplace(data, sign);
1602 }
1603 Algorithm::CooleyTukey(variant) => {
1604 CooleyTukeySolver::new(*variant).execute_inplace(data, sign);
1605 }
1606 Algorithm::Stockham => {
1607 let input = data.to_vec();
1609 StockhamSolver::new().execute(&input, data, sign);
1610 }
1611 Algorithm::Composite(n) => {
1612 let sign_int = if sign == Sign::Forward { -1 } else { 1 };
1613 execute_composite_codelet(data, *n, sign_int);
1614 }
1615 Algorithm::Generic(solver) => {
1616 solver.execute_inplace(data, sign);
1617 }
1618 Algorithm::Bluestein(solver) => {
1619 solver.execute_inplace(data, sign);
1620 }
1621 }
1622 }
1623}
1624pub struct SplitPlan<T: Float> {
1654 plan: Plan<T>,
1656}
1657impl<T: Float> SplitPlan<T> {
1658 #[must_use]
1665 pub fn dft_1d(n: usize, direction: Direction, flags: Flags) -> Option<Self> {
1666 let plan = Plan::dft_1d(n, direction, flags)?;
1667 Some(Self { plan })
1668 }
1669 #[must_use]
1671 pub fn size(&self) -> usize {
1672 self.plan.n
1673 }
1674 #[must_use]
1676 pub fn direction(&self) -> Direction {
1677 self.plan.direction
1678 }
1679 pub fn execute(&self, in_real: &[T], in_imag: &[T], out_real: &mut [T], out_imag: &mut [T]) {
1690 let n = self.plan.n;
1691 assert_eq!(in_real.len(), n, "Input real size must match plan size");
1692 assert_eq!(
1693 in_imag.len(),
1694 n,
1695 "Input imaginary size must match plan size"
1696 );
1697 assert_eq!(out_real.len(), n, "Output real size must match plan size");
1698 assert_eq!(
1699 out_imag.len(),
1700 n,
1701 "Output imaginary size must match plan size"
1702 );
1703 let input: Vec<Complex<T>> = in_real
1704 .iter()
1705 .zip(in_imag.iter())
1706 .map(|(&re, &im)| Complex::new(re, im))
1707 .collect();
1708 let mut output = vec![Complex::<T>::zero(); n];
1709 self.plan.execute(&input, &mut output);
1710 for (i, c) in output.iter().enumerate() {
1711 out_real[i] = c.re;
1712 out_imag[i] = c.im;
1713 }
1714 }
1715 pub fn execute_inplace(&self, real: &mut [T], imag: &mut [T]) {
1724 let n = self.plan.n;
1725 assert_eq!(real.len(), n, "Real size must match plan size");
1726 assert_eq!(imag.len(), n, "Imaginary size must match plan size");
1727 let mut data: Vec<Complex<T>> = real
1728 .iter()
1729 .zip(imag.iter())
1730 .map(|(&re, &im)| Complex::new(re, im))
1731 .collect();
1732 self.plan.execute_inplace(&mut data);
1733 for (i, c) in data.iter().enumerate() {
1734 real[i] = c.re;
1735 imag[i] = c.im;
1736 }
1737 }
1738}
1739pub struct Plan2D<T: Float> {
1744 n0: usize,
1746 n1: usize,
1748 direction: Direction,
1750 row_plan: Plan<T>,
1752 col_plan: Plan<T>,
1754}
1755impl<T: Float> Plan2D<T> {
1756 #[must_use]
1767 pub fn new(n0: usize, n1: usize, direction: Direction, flags: Flags) -> Option<Self> {
1768 let row_plan = Plan::dft_1d(n1, direction, flags)?;
1769 let col_plan = Plan::dft_1d(n0, direction, flags)?;
1770 Some(Self {
1771 n0,
1772 n1,
1773 direction,
1774 row_plan,
1775 col_plan,
1776 })
1777 }
1778 #[must_use]
1780 pub fn rows(&self) -> usize {
1781 self.n0
1782 }
1783 #[must_use]
1785 pub fn cols(&self) -> usize {
1786 self.n1
1787 }
1788 #[must_use]
1790 pub fn size(&self) -> usize {
1791 self.n0 * self.n1
1792 }
1793 #[must_use]
1795 pub fn direction(&self) -> Direction {
1796 self.direction
1797 }
1798 pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
1805 let total = self.n0 * self.n1;
1806 assert_eq!(input.len(), total, "Input size must match n0 × n1");
1807 assert_eq!(output.len(), total, "Output size must match n0 × n1");
1808 if total == 0 {
1809 return;
1810 }
1811 let mut temp = vec![Complex::zero(); total];
1812 for i in 0..self.n0 {
1813 let row_start = i * self.n1;
1814 let row_end = row_start + self.n1;
1815 self.row_plan
1816 .execute(&input[row_start..row_end], &mut temp[row_start..row_end]);
1817 }
1818 let mut col_in = vec![Complex::zero(); self.n0];
1819 let mut col_out = vec![Complex::zero(); self.n0];
1820 for j in 0..self.n1 {
1821 for i in 0..self.n0 {
1822 col_in[i] = temp[i * self.n1 + j];
1823 }
1824 self.col_plan.execute(&col_in, &mut col_out);
1825 for i in 0..self.n0 {
1826 output[i * self.n1 + j] = col_out[i];
1827 }
1828 }
1829 }
1830 pub fn execute_inplace(&self, data: &mut [Complex<T>]) {
1835 let total = self.n0 * self.n1;
1836 assert_eq!(data.len(), total, "Data size must match n0 × n1");
1837 if total == 0 {
1838 return;
1839 }
1840 for i in 0..self.n0 {
1841 let row_start = i * self.n1;
1842 let row_end = row_start + self.n1;
1843 self.row_plan.execute_inplace(&mut data[row_start..row_end]);
1844 }
1845 let mut col = vec![Complex::zero(); self.n0];
1846 for j in 0..self.n1 {
1847 for i in 0..self.n0 {
1848 col[i] = data[i * self.n1 + j];
1849 }
1850 self.col_plan.execute_inplace(&mut col);
1851 for i in 0..self.n0 {
1852 data[i * self.n1 + j] = col[i];
1853 }
1854 }
1855 }
1856}
1857pub struct R2rPlan<T: Float> {
1864 n: usize,
1866 kind: R2rKind,
1868 _marker: core::marker::PhantomData<T>,
1869}
1870impl<T: Float> R2rPlan<T> {
1871 #[must_use]
1881 pub fn r2r_1d(n: usize, kind: R2rKind, _flags: Flags) -> Option<Self> {
1882 if n == 0 {
1883 return None;
1884 }
1885 Some(Self {
1886 n,
1887 kind,
1888 _marker: core::marker::PhantomData,
1889 })
1890 }
1891 #[must_use]
1893 pub fn dct1(n: usize, flags: Flags) -> Option<Self> {
1894 Self::r2r_1d(n, R2rKind::Redft00, flags)
1895 }
1896 #[must_use]
1898 pub fn dct2(n: usize, flags: Flags) -> Option<Self> {
1899 Self::r2r_1d(n, R2rKind::Redft10, flags)
1900 }
1901 #[must_use]
1903 pub fn dct3(n: usize, flags: Flags) -> Option<Self> {
1904 Self::r2r_1d(n, R2rKind::Redft01, flags)
1905 }
1906 #[must_use]
1908 pub fn dct4(n: usize, flags: Flags) -> Option<Self> {
1909 Self::r2r_1d(n, R2rKind::Redft11, flags)
1910 }
1911 #[must_use]
1913 pub fn dst1(n: usize, flags: Flags) -> Option<Self> {
1914 Self::r2r_1d(n, R2rKind::Rodft00, flags)
1915 }
1916 #[must_use]
1918 pub fn dst2(n: usize, flags: Flags) -> Option<Self> {
1919 Self::r2r_1d(n, R2rKind::Rodft10, flags)
1920 }
1921 #[must_use]
1923 pub fn dst3(n: usize, flags: Flags) -> Option<Self> {
1924 Self::r2r_1d(n, R2rKind::Rodft01, flags)
1925 }
1926 #[must_use]
1928 pub fn dst4(n: usize, flags: Flags) -> Option<Self> {
1929 Self::r2r_1d(n, R2rKind::Rodft11, flags)
1930 }
1931 #[must_use]
1933 pub fn dht(n: usize, flags: Flags) -> Option<Self> {
1934 Self::r2r_1d(n, R2rKind::Dht, flags)
1935 }
1936 #[must_use]
1938 pub fn size(&self) -> usize {
1939 self.n
1940 }
1941 #[must_use]
1943 pub fn kind(&self) -> R2rKind {
1944 self.kind
1945 }
1946 pub fn execute(&self, input: &[T], output: &mut [T]) {
1951 use crate::rdft::solvers::R2rSolver;
1952 assert_eq!(input.len(), self.n, "Input size must match plan size");
1953 assert_eq!(output.len(), self.n, "Output size must match plan size");
1954 let solver = R2rSolver::new(self.kind);
1955 solver.execute(input, output);
1956 }
1957 pub fn execute_inplace(&self, data: &mut [T]) {
1962 use crate::rdft::solvers::R2rSolver;
1963 assert_eq!(data.len(), self.n, "Data size must match plan size");
1964 let input = data.to_vec();
1965 let solver = R2rSolver::new(self.kind);
1966 solver.execute(&input, data);
1967 }
1968}