1#![allow(clippy::cast_precision_loss)] use std::collections::HashMap;
16
17use proc_macro2::{Span, TokenStream};
18use quote::{format_ident, quote};
19use syn::{parse::ParseStream, Ident, LitInt, Token};
20
21use crate::symbolic::{ConstantFolder, Expr, StrengthReducer};
22
23pub struct RdftInput {
29 pub size: usize,
30 pub kind: RdftKind,
31}
32
/// Selects which real-data codelet [`generate`] emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RdftKind {
    /// Real input slice to complex output slice (dispatches to `gen_r2hc`).
    R2hc,
    /// Complex input slice back to real output slice (dispatches to `gen_hc2r`).
    Hc2r,
}
39
40impl syn::parse::Parse for RdftInput {
41 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
42 let kw_size: Ident = input.parse()?;
44 if kw_size != "size" {
45 return Err(syn::Error::new(
46 kw_size.span(),
47 "expected `size = N, kind = R2hc | Hc2r`",
48 ));
49 }
50 let _eq: Token![=] = input.parse()?;
51 let size_lit: LitInt = input.parse()?;
52 let size: usize = size_lit
53 .base10_parse()
54 .map_err(|_| syn::Error::new(size_lit.span(), "expected integer size literal"))?;
55
56 let _comma: Token![,] = input.parse()?;
57
58 let kw_kind: Ident = input.parse()?;
59 if kw_kind != "kind" {
60 return Err(syn::Error::new(
61 kw_kind.span(),
62 "expected `kind = R2hc | Hc2r`",
63 ));
64 }
65 let _eq2: Token![=] = input.parse()?;
66 let kind_ident: Ident = input.parse()?;
67
68 let kind = match kind_ident.to_string().as_str() {
69 "R2hc" => RdftKind::R2hc,
70 "Hc2r" => RdftKind::Hc2r,
71 other => {
72 return Err(syn::Error::new(
73 kind_ident.span(),
74 format!("unknown RDFT kind `{other}`, expected `R2hc` or `Hc2r`"),
75 ))
76 }
77 };
78
79 Ok(Self { size, kind })
80 }
81}
82
83pub fn generate(input: TokenStream) -> Result<TokenStream, syn::Error> {
92 let parsed: RdftInput = syn::parse2(input)?;
93 match parsed.kind {
94 RdftKind::R2hc => gen_r2hc(parsed.size),
95 RdftKind::Hc2r => gen_hc2r(parsed.size),
96 }
97}
98
99fn gen_r2hc(n: usize) -> Result<TokenStream, syn::Error> {
104 match n {
105 2 | 4 | 8 => Ok(emit_r2hc_codelet(n)),
106 _ => Err(syn::Error::new(
107 Span::call_site(),
108 format!("gen_rdft_codelet: unsupported size {n} for R2hc (expected 2, 4, or 8)"),
109 )),
110 }
111}
112
113fn symbolic_r2hc(n: usize) -> Vec<(Expr, Expr)> {
121 let half = n / 2;
122 let mut outputs = Vec::with_capacity(half + 1);
123
124 for k in 0..=half {
125 let mut re_acc = Expr::Const(0.0);
126 let mut im_acc = Expr::Const(0.0);
127 for j in 0..n {
128 let angle = -2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
130 let cos_val = angle.cos();
131 let sin_val = angle.sin(); let xj = Expr::input_re(j);
133 re_acc = re_acc.add(xj.clone().mul(Expr::Const(cos_val)));
134 im_acc = im_acc.add(xj.mul(Expr::Const(sin_val)));
135 }
136 let re_red = ConstantFolder::fold(&StrengthReducer::reduce(&re_acc));
137 let im_red = ConstantFolder::fold(&StrengthReducer::reduce(&im_acc));
138 outputs.push((re_red, im_red));
139 }
140 outputs
141}
142
143fn emit_r2hc_codelet(n: usize) -> TokenStream {
145 let outputs = symbolic_r2hc(n); let half = n / 2;
147 let min_out = half + 1;
148 let fn_name = format_ident!("r2hc_{n}_gen");
149 let body = emit_r2hc_body(n, &outputs);
150
151 quote! {
152 #[inline(always)]
157 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
158 pub fn #fn_name<T: crate::kernel::Float>(x: &[T], y: &mut [crate::kernel::Complex<T>]) {
159 debug_assert_eq!(x.len(), #n);
160 debug_assert!(y.len() >= #min_out);
161 #body
162 }
163 }
164}
165
166fn emit_r2hc_body(n: usize, outputs: &[(Expr, Expr)]) -> TokenStream {
168 let all_exprs: Vec<&Expr> = outputs.iter().flat_map(|(re, im)| [re, im]).collect();
170
171 let mut cse = LocalCse::new();
172 for expr in &all_exprs {
173 cse.count_recursive(expr);
174 }
175
176 let mut body = TokenStream::new();
177
178 for i in 0..n {
180 let var = format_ident!("x{i}");
181 body.extend(quote! { let #var = x[#i]; });
182 }
183
184 let assignments = cse.get_assignments();
190 for (name, expr) in &assignments {
191 let id = format_ident!("{name}");
192 let tok = emit_real_scalar(expr);
195 body.extend(quote! { let #id = #tok; });
196 }
197
198 for (k, (re_expr, im_expr)) in outputs.iter().enumerate() {
200 let re_tok = emit_real_scalar(&cse.rewrite(re_expr));
201 let im_tok = emit_real_scalar(&cse.rewrite(im_expr));
202 body.extend(quote! {
203 y[#k] = crate::kernel::Complex::new(#re_tok, #im_tok);
204 });
205 }
206
207 body
208}
209
210fn gen_hc2r(n: usize) -> Result<TokenStream, syn::Error> {
215 match n {
216 2 | 4 | 8 => Ok(emit_hc2r_codelet(n)),
217 _ => Err(syn::Error::new(
218 Span::call_site(),
219 format!("gen_rdft_codelet: unsupported size {n} for Hc2r (expected 2, 4, or 8)"),
220 )),
221 }
222}
223
224fn symbolic_hc2r(n: usize) -> Vec<Expr> {
234 let half = n / 2;
235 let mut outputs = Vec::with_capacity(n);
236
237 for j in 0..n {
238 let mut acc = Expr::input_re(0);
240
241 for k in 1..half {
243 let angle = 2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
244 let cos_val = angle.cos();
245 let sin_val = angle.sin();
246
247 let yk_re = Expr::input_re(k);
248 let yk_im = Expr::input_im(k);
249
250 let term_re = yk_re.mul(Expr::Const(cos_val));
251 let term_im = yk_im.mul(Expr::Const(sin_val));
252 let term = term_re.sub(term_im);
253 let term2 = term.mul(Expr::Const(2.0));
255 acc = acc.add(term2);
256 }
257
258 let nyquist_angle = std::f64::consts::PI * j as f64;
260 let nyquist_cos = nyquist_angle.cos(); let nyquist_term = Expr::input_re(half).mul(Expr::Const(nyquist_cos));
262 acc = acc.add(nyquist_term);
263
264 let reduced = ConstantFolder::fold(&StrengthReducer::reduce(&acc));
265 outputs.push(reduced);
266 }
267 outputs
268}
269
270fn emit_hc2r_codelet(n: usize) -> TokenStream {
272 let outputs = symbolic_hc2r(n);
273 let half = n / 2;
274 let min_in = half + 1;
275 let fn_name = format_ident!("hc2r_{n}_gen");
276 let body = emit_hc2r_body(n, &outputs, half);
277
278 quote! {
279 #[inline(always)]
284 #[allow(clippy::too_many_lines, clippy::approx_constant, clippy::suboptimal_flops)]
285 pub fn #fn_name<T: crate::kernel::Float>(y: &[crate::kernel::Complex<T>], x: &mut [T]) {
286 debug_assert!(y.len() >= #min_in);
287 debug_assert_eq!(x.len(), #n);
288 #body
289 }
290 }
291}
292
293fn emit_hc2r_body(_n: usize, outputs: &[Expr], half: usize) -> TokenStream {
295 let mut cse = LocalCse::new();
296 for expr in outputs {
297 cse.count_recursive(expr);
298 }
299
300 let mut body = TokenStream::new();
301
302 for k in 0..=half {
304 let re_var = format_ident!("y{k}_re");
305 let im_var = format_ident!("y{k}_im");
306 body.extend(quote! {
307 let #re_var = y[#k].re;
308 let #im_var = y[#k].im;
309 });
310 }
311
312 let assignments = cse.get_assignments();
314 for (name, expr) in &assignments {
315 let id = format_ident!("{name}");
316 let tok = emit_hc2r_scalar(expr);
317 body.extend(quote! { let #id = #tok; });
318 }
319
320 for (j, expr) in outputs.iter().enumerate() {
322 let val_tok = emit_hc2r_scalar(&cse.rewrite(expr));
323 body.extend(quote! { x[#j] = #val_tok; });
324 }
325
326 body
327}
328
329struct LocalCse {
338 cache: HashMap<u64, (Expr, String, usize)>,
340 counter: usize,
341}
342
343impl LocalCse {
344 fn new() -> Self {
345 Self {
346 cache: HashMap::new(),
347 counter: 0,
348 }
349 }
350
351 fn count_recursive(&mut self, expr: &Expr) {
353 match expr {
354 Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => {}
355 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) => {
356 self.count_recursive(a);
357 self.count_recursive(b);
358 let hash = expr.structural_hash();
359 let entry = self.cache.entry(hash).or_insert_with(|| {
360 let name = format!("t{}", self.counter);
361 self.counter += 1;
362 (expr.clone(), name, 0)
363 });
364 entry.2 += 1;
365 }
366 Expr::Neg(a) => {
367 self.count_recursive(a);
368 let hash = expr.structural_hash();
369 let entry = self.cache.entry(hash).or_insert_with(|| {
370 let name = format!("t{}", self.counter);
371 self.counter += 1;
372 (expr.clone(), name, 0)
373 });
374 entry.2 += 1;
375 }
376 }
377 }
378
379 fn rewrite(&self, expr: &Expr) -> Expr {
381 self.rewrite_inner(expr, None)
382 }
383
384 fn rewrite_inner(&self, expr: &Expr, exclude_hash: Option<u64>) -> Expr {
385 match expr {
386 Expr::Input { .. } | Expr::Const(_) | Expr::Temp(_) => expr.clone(),
387 Expr::Add(a, b) => {
388 let hash = expr.structural_hash();
389 if exclude_hash != Some(hash) {
390 if let Some((_, name, count)) = self.cache.get(&hash) {
391 if *count >= 2 {
392 return Expr::Temp(name.clone());
393 }
394 }
395 }
396 Expr::Add(
397 Box::new(self.rewrite_inner(a, None)),
398 Box::new(self.rewrite_inner(b, None)),
399 )
400 }
401 Expr::Sub(a, b) => {
402 let hash = expr.structural_hash();
403 if exclude_hash != Some(hash) {
404 if let Some((_, name, count)) = self.cache.get(&hash) {
405 if *count >= 2 {
406 return Expr::Temp(name.clone());
407 }
408 }
409 }
410 Expr::Sub(
411 Box::new(self.rewrite_inner(a, None)),
412 Box::new(self.rewrite_inner(b, None)),
413 )
414 }
415 Expr::Mul(a, b) => {
416 let hash = expr.structural_hash();
417 if exclude_hash != Some(hash) {
418 if let Some((_, name, count)) = self.cache.get(&hash) {
419 if *count >= 2 {
420 return Expr::Temp(name.clone());
421 }
422 }
423 }
424 Expr::Mul(
425 Box::new(self.rewrite_inner(a, None)),
426 Box::new(self.rewrite_inner(b, None)),
427 )
428 }
429 Expr::Neg(a) => {
430 let hash = expr.structural_hash();
431 if exclude_hash != Some(hash) {
432 if let Some((_, name, count)) = self.cache.get(&hash) {
433 if *count >= 2 {
434 return Expr::Temp(name.clone());
435 }
436 }
437 }
438 Expr::Neg(Box::new(self.rewrite_inner(a, None)))
439 }
440 }
441 }
442
443 fn get_assignments(&self) -> Vec<(String, Expr)> {
445 let mut result: Vec<(String, Expr)> = self
446 .cache
447 .values()
448 .filter(|(_, _, count)| *count >= 2)
449 .map(|(expr, name, _)| (name.clone(), expr.clone()))
450 .collect();
451 result.sort_by(|a, b| {
452 let na: usize = a.0[1..].parse().unwrap_or(0);
453 let nb: usize = b.0[1..].parse().unwrap_or(0);
454 na.cmp(&nb)
455 });
456 result
457 }
458}
459
460fn emit_real_scalar(expr: &Expr) -> TokenStream {
470 match expr {
471 Expr::Input { index, is_real } => {
472 if *is_real {
473 let name = format_ident!("x{index}");
474 quote! { #name }
475 } else {
476 let name = format_ident!("y{index}_im");
478 quote! { #name }
479 }
480 }
481 Expr::Const(v) => emit_const(*v),
482 Expr::Add(a, b) => {
483 let a = emit_real_scalar(a);
484 let b = emit_real_scalar(b);
485 quote! { (#a + #b) }
486 }
487 Expr::Sub(a, b) => {
488 let a = emit_real_scalar(a);
489 let b = emit_real_scalar(b);
490 quote! { (#a - #b) }
491 }
492 Expr::Mul(a, b) => {
493 let a = emit_real_scalar(a);
494 let b = emit_real_scalar(b);
495 quote! { (#a * #b) }
496 }
497 Expr::Neg(a) => {
498 let a = emit_real_scalar(a);
499 quote! { (-#a) }
500 }
501 Expr::Temp(name) => {
502 let id = format_ident!("{name}");
503 quote! { #id }
504 }
505 }
506}
507
508fn emit_hc2r_scalar(expr: &Expr) -> TokenStream {
510 match expr {
511 Expr::Input { index, is_real } => {
512 let name = if *is_real {
513 format_ident!("y{index}_re")
514 } else {
515 format_ident!("y{index}_im")
516 };
517 quote! { #name }
518 }
519 Expr::Const(v) => emit_const(*v),
520 Expr::Add(a, b) => {
521 let a = emit_hc2r_scalar(a);
522 let b = emit_hc2r_scalar(b);
523 quote! { (#a + #b) }
524 }
525 Expr::Sub(a, b) => {
526 let a = emit_hc2r_scalar(a);
527 let b = emit_hc2r_scalar(b);
528 quote! { (#a - #b) }
529 }
530 Expr::Mul(a, b) => {
531 let a = emit_hc2r_scalar(a);
532 let b = emit_hc2r_scalar(b);
533 quote! { (#a * #b) }
534 }
535 Expr::Neg(a) => {
536 let a = emit_hc2r_scalar(a);
537 quote! { (-#a) }
538 }
539 Expr::Temp(name) => {
540 let id = format_ident!("{name}");
541 quote! { #id }
542 }
543 }
544}
545
546fn emit_const(v: f64) -> TokenStream {
548 if (v - 0.0_f64).abs() < f64::EPSILON {
549 quote! { T::ZERO }
550 } else if (v - 1.0_f64).abs() < f64::EPSILON {
551 quote! { T::ONE }
552 } else if (v - (-1.0_f64)).abs() < f64::EPSILON {
553 quote! { (-T::ONE) }
554 } else {
555 quote! { T::from_f64(#v) }
556 }
557}