// oxifft_codegen_impl/gen_notw.rs
use crate::symbolic::emit_body_from_symbolic;
use proc_macro2::TokenStream;
use quote::quote;
use syn::LitInt;
10
11pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
17 let size: LitInt = syn::parse2(input)?;
18 let n: usize = size.base10_parse().map_err(|_| {
19 syn::Error::new(
20 size.span(),
21 "gen_notw_codelet: expected an integer size literal",
22 )
23 })?;
24
25 match n {
26 2 => Ok(gen_size_2()),
27 4 => Ok(gen_size_4()),
28 8 => Ok(gen_size_8()),
29 16 => Ok(gen_size_16()),
30 32 => Ok(gen_size_32()),
31 64 => Ok(gen_size_64()),
32 _ => Err(syn::Error::new(
33 size.span(),
34 format!("gen_notw_codelet: unsupported size {n} (expected one of 2, 4, 8, 16, 32, 64)"),
35 )),
36 }
37}
38
39fn gen_size_2() -> TokenStream {
40 quote! {
41 #[inline(always)]
43 pub fn codelet_notw_2<T: crate::kernel::Float>(
44 x: &mut [crate::kernel::Complex<T>],
45 _sign: i32,
46 ) {
47 debug_assert!(x.len() >= 2);
48 let a = x[0];
49 let b = x[1];
50 x[0] = a + b;
51 x[1] = a - b;
52 }
53 }
54}
55
56fn gen_size_4() -> TokenStream {
57 quote! {
58 #[inline(always)]
60 pub fn codelet_notw_4<T: crate::kernel::Float>(
61 x: &mut [crate::kernel::Complex<T>],
62 sign: i32,
63 ) {
64 debug_assert!(x.len() >= 4);
65
66 let x0 = x[0];
67 let x1 = x[1];
68 let x2 = x[2];
69 let x3 = x[3];
70
71 let t0 = x0 + x2;
73 let t1 = x0 - x2;
74 let t2 = x1 + x3;
75 let t3 = x1 - x3;
76
77 let t3_rot = if sign < 0 {
79 crate::kernel::Complex::new(t3.im, -t3.re)
80 } else {
81 crate::kernel::Complex::new(-t3.im, t3.re)
82 };
83
84 x[0] = t0 + t2;
86 x[1] = t1 + t3_rot;
87 x[2] = t0 - t2;
88 x[3] = t1 - t3_rot;
89 }
90 }
91}
92
93fn gen_size_8() -> TokenStream {
94 quote! {
97 #[inline(always)]
101 pub fn codelet_notw_8<T: crate::kernel::Float>(
102 x: &mut [crate::kernel::Complex<T>],
103 sign: i32,
104 ) {
105 debug_assert!(x.len() >= 8);
106
107 let c2 = T::from_f64(0.707_106_781_186_547_6_f64);
109
110 let mut a = [crate::kernel::Complex::<T>::zero(); 8];
114 a[0] = x[0]; a[1] = x[4];
115 a[2] = x[2]; a[3] = x[6];
116 a[4] = x[1]; a[5] = x[5];
117 a[6] = x[3]; a[7] = x[7];
118
119 for i in (0..8usize).step_by(2) {
121 let t = a[i + 1];
122 a[i + 1] = a[i] - t;
123 a[i] = a[i] + t;
124 }
125
126 for group in (0..8usize).step_by(4) {
129 let t = a[group + 2];
131 a[group + 2] = a[group] - t;
132 a[group] = a[group] + t;
133
134 let t = a[group + 3];
136 let t_tw = if sign < 0 {
137 crate::kernel::Complex::new(t.im, -t.re)
138 } else {
139 crate::kernel::Complex::new(-t.im, t.re)
140 };
141 a[group + 3] = a[group + 1] - t_tw;
142 a[group + 1] = a[group + 1] + t_tw;
143 }
144
145 let t = a[4];
149 a[4] = a[0] - t;
150 a[0] = a[0] + t;
151
152 let t = a[5];
154 let t_tw = if sign < 0 {
155 crate::kernel::Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
156 } else {
157 crate::kernel::Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
158 };
159 a[5] = a[1] - t_tw;
160 a[1] = a[1] + t_tw;
161
162 let t = a[6];
164 let t_tw = if sign < 0 {
165 crate::kernel::Complex::new(t.im, -t.re)
166 } else {
167 crate::kernel::Complex::new(-t.im, t.re)
168 };
169 a[6] = a[2] - t_tw;
170 a[2] = a[2] + t_tw;
171
172 let t = a[7];
174 let t_tw = if sign < 0 {
175 crate::kernel::Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
176 } else {
177 crate::kernel::Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
178 };
179 a[7] = a[3] - t_tw;
180 a[3] = a[3] + t_tw;
181
182 for i in 0..8usize {
184 x[i] = a[i];
185 }
186 }
187 }
188}
189
190fn gen_size_16() -> TokenStream {
191 let fwd = emit_body_from_symbolic(16, true);
195 let inv = emit_body_from_symbolic(16, false);
196 quote! {
197 #[inline(always)]
201 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
202 pub fn codelet_notw_16<T: crate::kernel::Float>(
203 x: &mut [crate::kernel::Complex<T>],
204 sign: i32,
205 ) {
206 debug_assert!(x.len() >= 16);
207 if sign < 0 {
208 #fwd
209 } else {
210 #inv
211 }
212 }
213 }
214}
215
216fn gen_size_32() -> TokenStream {
217 let fwd = emit_body_from_symbolic(32, true);
219 let inv = emit_body_from_symbolic(32, false);
220 quote! {
221 #[inline(always)]
225 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
226 pub fn codelet_notw_32<T: crate::kernel::Float>(
227 x: &mut [crate::kernel::Complex<T>],
228 sign: i32,
229 ) {
230 debug_assert!(x.len() >= 32);
231 if sign < 0 {
232 #fwd
233 } else {
234 #inv
235 }
236 }
237 }
238}
239
240fn gen_size_64() -> TokenStream {
241 let fwd = emit_body_from_symbolic(64, true);
243 let inv = emit_body_from_symbolic(64, false);
244 quote! {
245 #[inline(always)]
249 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
250 pub fn codelet_notw_64<T: crate::kernel::Float>(
251 x: &mut [crate::kernel::Complex<T>],
252 sign: i32,
253 ) {
254 debug_assert!(x.len() >= 64);
255 if sign < 0 {
256 #fwd
257 } else {
258 #inv
259 }
260 }
261 }
262}