Skip to main content

oxifft/dft/codelets/simd/
mod.rs

1//! SIMD-optimized codelets.
2//!
3//! This module provides SIMD-accelerated versions of the small DFT kernels.
4//! The codelets use the architecture-specific SIMD backends when available.
5
6// Items after statements are intentional for precomputed twiddle tables
7#![allow(clippy::items_after_statements)]
8// Large stack arrays are intentional for performance in fixed-size transforms
9#![allow(clippy::large_stack_arrays)]
10
11use core::any::TypeId;
12
13use crate::kernel::{Complex, Float};
14use crate::simd::{detect_simd_level, SimdLevel};
15
16pub(crate) mod backends;
17mod large_sizes;
18mod small_sizes;
19#[cfg(test)]
20mod tests;
21
22pub use large_sizes::*;
23pub use small_sizes::*;
24
25/// Detect if SIMD acceleration is available and beneficial.
26#[inline]
27pub fn simd_available() -> bool {
28    let level = detect_simd_level();
29    matches!(
30        level,
31        SimdLevel::Sse2 | SimdLevel::Avx | SimdLevel::Avx2 | SimdLevel::Avx512 | SimdLevel::Neon
32    )
33}
34
35/// Size-2 DFT with automatic SIMD dispatch.
36///
37/// This function selects the best implementation based on available CPU features
38/// and the float type. For f64, uses SIMD acceleration when available.
39#[inline]
40pub fn notw_2_dispatch<T: Float>(x: &mut [Complex<T>]) {
41    // Check if T is f64 at runtime
42    if TypeId::of::<T>() == TypeId::of::<f64>() {
43        // Safety: We verified T is f64, so the memory layout is identical
44        let x_f64 = unsafe {
45            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
46        };
47        notw_2_simd_f64(x_f64);
48        return;
49    }
50    // Fallback to scalar for other types
51    super::notw_2(x);
52}
53
54/// Size-4 DFT with automatic SIMD dispatch.
55///
56/// This function selects the best implementation based on available CPU features
57/// and the float type. For f64, uses SIMD acceleration when available.
58#[inline]
59pub fn notw_4_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
60    // Check if T is f64 at runtime
61    if TypeId::of::<T>() == TypeId::of::<f64>() {
62        // Safety: We verified T is f64, so the memory layout is identical
63        let x_f64 = unsafe {
64            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
65        };
66        notw_4_simd_f64(x_f64, sign);
67        return;
68    }
69    // Fallback to scalar for other types
70    super::notw_4(x, sign);
71}
72
73/// Size-8 DFT with automatic SIMD dispatch.
74///
75/// This function selects the best implementation based on available CPU features
76/// and the float type. For f64, uses SIMD acceleration when available.
77#[inline]
78pub fn notw_8_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
79    // Check if T is f64 at runtime
80    if TypeId::of::<T>() == TypeId::of::<f64>() {
81        // Safety: We verified T is f64, so the memory layout is identical
82        let x_f64 = unsafe {
83            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
84        };
85        notw_8_simd_f64(x_f64, sign);
86        return;
87    }
88    // Fallback to scalar for other types
89    super::notw_8(x, sign);
90}
91
92/// Size-16 DFT with automatic SIMD dispatch.
93///
94/// Uses scalar implementation with twiddle recurrence optimization.
95#[inline]
96pub fn notw_16_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
97    // Check if T is f64 at runtime
98    if TypeId::of::<T>() == TypeId::of::<f64>() {
99        // Safety: We verified T is f64, so the memory layout is identical
100        let x_f64 = unsafe {
101            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
102        };
103        notw_16_simd_f64(x_f64, sign);
104        return;
105    }
106    // Fallback to scalar for other types
107    super::notw_16(x, sign);
108}
109
110/// Size-32 DFT with automatic SIMD dispatch.
111///
112/// Uses scalar implementation with twiddle recurrence optimization.
113#[inline]
114pub fn notw_32_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
115    // Check if T is f64 at runtime
116    if TypeId::of::<T>() == TypeId::of::<f64>() {
117        // Safety: We verified T is f64, so the memory layout is identical
118        let x_f64 = unsafe {
119            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
120        };
121        notw_32_simd_f64(x_f64, sign);
122        return;
123    }
124    // Fallback to scalar for other types
125    super::notw_32(x, sign);
126}
127
128/// Size-64 DFT with automatic SIMD dispatch.
129///
130/// Uses scalar implementation with twiddle recurrence optimization.
131#[inline]
132pub fn notw_64_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
133    // Check if T is f64 at runtime
134    if TypeId::of::<T>() == TypeId::of::<f64>() {
135        // Safety: We verified T is f64, so the memory layout is identical
136        let x_f64 = unsafe {
137            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
138        };
139        notw_64_simd_f64(x_f64, sign);
140        return;
141    }
142    // Fallback to scalar for other types
143    super::notw_64(x, sign);
144}
145
146/// Size-128 DFT with automatic SIMD dispatch.
147///
148/// Uses scalar implementation with twiddle recurrence optimization.
149#[inline]
150pub fn notw_128_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
151    // Check if T is f64 at runtime
152    if TypeId::of::<T>() == TypeId::of::<f64>() {
153        // Safety: We verified T is f64, so the memory layout is identical
154        let x_f64 = unsafe {
155            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
156        };
157        notw_128_simd_f64(x_f64, sign);
158        return;
159    }
160    // Fallback to scalar for other types
161    super::notw_128(x, sign);
162}
163
164/// Size-256 DFT with automatic SIMD dispatch.
165///
166/// Uses scalar implementation with twiddle recurrence optimization.
167#[inline]
168pub fn notw_256_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
169    // Check if T is f64 at runtime
170    if TypeId::of::<T>() == TypeId::of::<f64>() {
171        // Safety: We verified T is f64, so the memory layout is identical
172        let x_f64 = unsafe {
173            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
174        };
175        notw_256_simd_f64(x_f64, sign);
176        return;
177    }
178    // Fallback to scalar for other types
179    super::notw_256(x, sign);
180}
181
182/// Size-512 DFT with automatic SIMD dispatch.
183///
184/// Uses iterative DIT with SIMD butterflies for optimal performance.
185#[inline]
186pub fn notw_512_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
187    // Check if T is f64 at runtime
188    if TypeId::of::<T>() == TypeId::of::<f64>() {
189        // Safety: We verified T is f64, so the memory layout is identical
190        let x_f64 = unsafe {
191            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
192        };
193        notw_512_simd_f64(x_f64, sign);
194        return;
195    }
196    // Fallback to iterative DIT for other types.
197    // Note: Must use execute_inplace directly to avoid infinite recursion,
198    // since CooleyTukeySolver::execute dispatches back to this codelet.
199    use crate::dft::problem::Sign;
200    use crate::dft::solvers::CooleyTukeySolver;
201    let sign_enum = if sign < 0 {
202        Sign::Forward
203    } else {
204        Sign::Backward
205    };
206    CooleyTukeySolver::default().execute_dit_inplace(x, sign_enum);
207}
208
209/// Size-1024 DFT with automatic SIMD dispatch.
210///
211/// Uses iterative DIT with SIMD butterflies for optimal performance.
212#[inline]
213pub fn notw_1024_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
214    // Check if T is f64 at runtime
215    if TypeId::of::<T>() == TypeId::of::<f64>() {
216        // Safety: We verified T is f64, so the memory layout is identical
217        let x_f64 = unsafe {
218            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
219        };
220        notw_1024_simd_f64(x_f64, sign);
221        return;
222    }
223    // Fallback to iterative DIT for other types.
224    // Note: Must use execute_inplace directly to avoid infinite recursion,
225    // since CooleyTukeySolver::execute dispatches back to this codelet.
226    use crate::dft::problem::Sign;
227    use crate::dft::solvers::CooleyTukeySolver;
228    let sign_enum = if sign < 0 {
229        Sign::Forward
230    } else {
231        Sign::Backward
232    };
233    CooleyTukeySolver::default().execute_dit_inplace(x, sign_enum);
234}
235
236/// Size-4096 DFT with automatic SIMD dispatch.
237///
238/// Uses iterative DIT with SIMD butterflies for optimal performance.
239#[inline]
240pub fn notw_4096_dispatch<T: Float>(x: &mut [Complex<T>], sign: i32) {
241    // Check if T is f64 at runtime
242    if TypeId::of::<T>() == TypeId::of::<f64>() {
243        // Safety: We verified T is f64, so the memory layout is identical
244        let x_f64 = unsafe {
245            core::slice::from_raw_parts_mut(x.as_mut_ptr().cast::<Complex<f64>>(), x.len())
246        };
247        notw_4096_simd_f64(x_f64, sign);
248        return;
249    }
250    // Fallback to iterative DIT for other types.
251    // Note: Must use execute_inplace directly to avoid infinite recursion,
252    // since CooleyTukeySolver::execute dispatches back to this codelet.
253    use crate::dft::problem::Sign;
254    use crate::dft::solvers::CooleyTukeySolver;
255    let sign_enum = if sign < 0 {
256        Sign::Forward
257    } else {
258        Sign::Backward
259    };
260    CooleyTukeySolver::default().execute_dit_inplace(x, sign_enum);
261}