oxifft_codegen_impl/gen_simd/
runtime_dispatch.rs1use proc_macro2::TokenStream;
38use quote::{format_ident, quote};
39use syn::{
40 parse::{Parse, ParseStream},
41 LitInt, Token,
42};
43
44pub use super::multi_transform::Precision;
45
46#[derive(Debug, Clone, Copy)]
52pub struct DispatcherConfig {
53 pub size: usize,
55 pub precision: Precision,
57}
58
/// Level reported by [`detect_host_isa`] when no SIMD extension is detected.
pub const ISA_SCALAR: u8 = 0;
/// Level reported on `x86_64` when only SSE2 is detected.
pub const ISA_SSE2: u8 = 1;
/// Level reported on `x86_64` when AVX (but not AVX2+FMA) is detected.
pub const ISA_AVX: u8 = 2;
/// Level reported on `x86_64` when both AVX2 and FMA are detected.
pub const ISA_AVX2_FMA: u8 = 3;
/// Level reported on `x86_64` when AVX-512F is detected.
pub const ISA_AVX512: u8 = 4;
/// Level reported on `aarch64` when NEON is detected.
pub const ISA_NEON: u8 = 5;
/// Sentinel (255) distinct from every detected level; it matches the initial
/// value of the cache static emitted by [`generate_dispatcher`].
pub const ISA_UNDETECTED: u8 = u8::MAX;
77
78#[must_use]
89pub fn detect_host_isa() -> u8 {
90 #[cfg(target_arch = "x86_64")]
91 {
92 if is_x86_feature_detected!("avx512f") {
93 return ISA_AVX512;
94 }
95 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
96 return ISA_AVX2_FMA;
97 }
98 if is_x86_feature_detected!("avx") {
99 return ISA_AVX;
100 }
101 if is_x86_feature_detected!("sse2") {
102 return ISA_SSE2;
103 }
104 return ISA_SCALAR;
105 }
106
107 #[cfg(target_arch = "aarch64")]
108 {
109 if std::arch::is_aarch64_feature_detected!("neon") {
110 return ISA_NEON;
111 }
112 return ISA_SCALAR;
113 }
114
115 #[allow(unreachable_code)]
117 ISA_SCALAR
118}
119
120fn build_detect_x86_body() -> TokenStream {
126 quote! {
127 #[cfg(target_arch = "x86_64")]
128 {
129 if is_x86_feature_detected!("avx512f") {
130 return ISA_AVX512_LEVEL;
131 }
132 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
133 return ISA_AVX2_FMA_LEVEL;
134 }
135 if is_x86_feature_detected!("avx") {
136 return ISA_AVX_LEVEL;
137 }
138 if is_x86_feature_detected!("sse2") {
139 return ISA_SSE2_LEVEL;
140 }
141 return ISA_SCALAR_LEVEL;
142 }
143 }
144}
145
146fn build_detect_aarch64_body() -> TokenStream {
148 quote! {
149 #[cfg(target_arch = "aarch64")]
150 {
151 if std::arch::is_aarch64_feature_detected!("neon") {
152 return ISA_NEON_LEVEL;
153 }
154 return ISA_SCALAR_LEVEL;
155 }
156 }
157}
158
159fn build_x86_64_branches(config: DispatcherConfig) -> TokenStream {
167 let size = config.size;
168 let ty_str = config.precision.type_str();
169 let ty_tokens: TokenStream = ty_str
170 .parse()
171 .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
172 let avx512_fn = format_ident!("codelet_simd_{}_avx512_{}", size, ty_str);
173 let avx2_fn = format_ident!("codelet_simd_{}_avx2_{}", size, ty_str);
174 let sse2_fn = format_ident!("codelet_simd_{}_sse2_{}", size, ty_str);
175
176 if size == 16 {
177 if config.precision == Precision::F32 {
178 return quote! {
179 #[cfg(target_arch = "x86_64")]
180 {
181 if cached_level == ISA_AVX512_LEVEL {
182 let data_len = data.len() * 2;
185 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
186 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
187 unsafe { super::#avx512_fn(data_inner, sign); }
188 return;
189 }
190 }
191 };
192 }
193 return quote! {};
195 }
196
197 let avx_branch = if config.precision == Precision::F64 {
199 let avx_f64_fn = format_ident!("codelet_simd_{}_avx_f64", size);
200 quote! {
201 if cached_level == ISA_AVX_LEVEL {
202 let data_len = data.len() * 2;
204 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
205 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
206 unsafe { super::#avx_f64_fn(data_inner, sign); }
207 return;
208 }
209 }
210 } else {
211 quote! {}
212 };
213
214 quote! {
215 #[cfg(target_arch = "x86_64")]
216 {
217 if cached_level == ISA_AVX512_LEVEL {
218 let data_len = data.len() * 2;
220 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
221 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
222 unsafe { super::#avx512_fn(data_inner, sign); }
223 return;
224 }
225 if cached_level == ISA_AVX2_FMA_LEVEL {
226 let data_len = data.len() * 2;
228 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
229 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
230 unsafe { super::#avx2_fn(data_inner, sign); }
231 return;
232 }
233 #avx_branch
234 if cached_level == ISA_SSE2_LEVEL {
235 let data_len = data.len() * 2;
237 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
238 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
239 unsafe { super::#sse2_fn(data_inner, sign); }
240 return;
241 }
242 }
243 }
244}
245
246fn build_aarch64_branch(config: DispatcherConfig) -> TokenStream {
250 if config.size == 16 {
251 return quote! {};
252 }
253 let ty_str = config.precision.type_str();
254 let ty_tokens: TokenStream = ty_str
255 .parse()
256 .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
257 let neon_fn = format_ident!("codelet_simd_{}_neon_{}", config.size, ty_str);
258 quote! {
259 #[cfg(target_arch = "aarch64")]
260 {
261 if cached_level == ISA_NEON_LEVEL {
262 let data_len = data.len() * 2;
264 let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
265 let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
266 unsafe { super::#neon_fn(data_inner, sign); }
267 return;
268 }
269 }
270 }
271}
272
273#[allow(clippy::too_many_lines)] pub fn generate_dispatcher(config: DispatcherConfig) -> Result<TokenStream, syn::Error> {
295 let size = config.size;
296 if !matches!(size, 2 | 4 | 8 | 16) {
297 return Err(syn::Error::new(
298 proc_macro2::Span::call_site(),
299 format!(
300 "gen_dispatcher_codelet: unsupported size {size} (expected one of 2, 4, 8, 16)"
301 ),
302 ));
303 }
304
305 let ty_str = config.precision.type_str();
306 let ty_upper = ty_str.to_uppercase();
307 let size_str = size.to_string();
308
309 let static_name = format_ident!("DETECTED_ISA_{}_{}", size_str, ty_upper);
311 let detect_fn = format_ident!("detect_isa_{}_{}", size_str, ty_str);
313 let cached_fn = format_ident!("codelet_simd_{}_cached_{}", size_str, ty_str);
315 let scalar_fn = format_ident!("codelet_simd_{}_scalar", size);
317
318 let detect_x86_body = build_detect_x86_body();
319 let detect_aarch64_body = build_detect_aarch64_body();
320 let x86_64_branches = build_x86_64_branches(config);
321 let aarch64_branch = build_aarch64_branch(config);
322
323 let ty_tokens: TokenStream = ty_str
324 .parse()
325 .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
326
327 let fn_doc = format!(
328 "Cached runtime ISA dispatcher for size-{size} DFT ({ty_str}).\n\n\
329 On first call, probes CPU features and stores the ISA level in a\n\
330 thread-safe `AtomicU8` static. Subsequent calls read the cache with\n\
331 `Relaxed` ordering (benign-racy: all threads converge on the same answer).\n\n\
332 Dispatch priority on `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.\n\
333 Dispatch priority on `aarch64`: NEON > scalar.\n\
334 Other architectures fall through to the scalar codelet."
335 );
336
337 let size_lit = size;
338
339 Ok(quote! {
340 const ISA_SCALAR_LEVEL: u8 = 0;
342 const ISA_SSE2_LEVEL: u8 = 1;
343 const ISA_AVX_LEVEL: u8 = 2;
344 const ISA_AVX2_FMA_LEVEL: u8 = 3;
345 const ISA_AVX512_LEVEL: u8 = 4;
346 const ISA_NEON_LEVEL: u8 = 5;
347 const ISA_UNDETECTED_LEVEL: u8 = 255;
348
349 static #static_name: core::sync::atomic::AtomicU8 =
353 core::sync::atomic::AtomicU8::new(ISA_UNDETECTED_LEVEL);
354
355 fn #detect_fn() -> u8 {
357 #detect_x86_body
358 #detect_aarch64_body
359 #[allow(unreachable_code)]
360 ISA_SCALAR_LEVEL
361 }
362
363 #[doc = #fn_doc]
364 #[inline]
365 pub fn #cached_fn(
366 data: &mut [crate::kernel::Complex<#ty_tokens>],
367 sign: i32,
368 ) {
369 debug_assert!(
370 data.len() >= #size_lit,
371 "codelet_simd_{}_cached_{}: need >= {} elements, got {}",
372 #size_lit,
373 stringify!(#ty_tokens),
374 #size_lit,
375 data.len(),
376 );
377
378 let cached_level = {
380 let level = #static_name.load(core::sync::atomic::Ordering::Relaxed);
381 if level == ISA_UNDETECTED_LEVEL {
382 let detected = #detect_fn();
383 #static_name.store(detected, core::sync::atomic::Ordering::Relaxed);
385 detected
386 } else {
387 level
388 }
389 };
390
391 #x86_64_branches
397 #aarch64_branch
398
399 super::#scalar_fn(data, sign);
402 }
403 })
404}
405
406struct MacroArgs {
412 size: usize,
413 precision: Precision,
414}
415
416impl Parse for MacroArgs {
417 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
418 let mut size: Option<usize> = None;
419 let mut precision: Option<Precision> = None;
420
421 while !input.is_empty() {
422 let key: syn::Ident = input.parse()?;
423 let _eq: Token![=] = input.parse()?;
424 match key.to_string().as_str() {
425 "size" => {
426 let lit: LitInt = input.parse()?;
427 size = Some(lit.base10_parse::<usize>().map_err(|_| {
428 syn::Error::new(lit.span(), "expected an integer literal for `size`")
429 })?);
430 }
431 "ty" => {
432 let ident: syn::Ident = input.parse()?;
433 precision = Some(match ident.to_string().as_str() {
434 "f32" => Precision::F32,
435 "f64" => Precision::F64,
436 other => {
437 return Err(syn::Error::new(
438 ident.span(),
439 format!("unknown ty `{other}`, expected f32 or f64"),
440 ));
441 }
442 });
443 }
444 other => {
445 return Err(syn::Error::new(
446 key.span(),
447 format!("unknown key `{other}`, expected one of: size, ty"),
448 ));
449 }
450 }
451 if input.peek(Token![,]) {
452 let _: Token![,] = input.parse()?;
453 }
454 }
455
456 let size = size.ok_or_else(|| {
457 syn::Error::new(proc_macro2::Span::call_site(), "missing `size` argument")
458 })?;
459 let precision = precision.ok_or_else(|| {
460 syn::Error::new(proc_macro2::Span::call_site(), "missing `ty` argument")
461 })?;
462
463 Ok(Self { size, precision })
464 }
465}
466
467pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
476 let args: MacroArgs = syn::parse2(input)?;
477 generate_dispatcher(DispatcherConfig {
478 size: args.size,
479 precision: args.precision,
480 })
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the generator for one size / precision pair.
    fn generate(size: usize, precision: Precision) -> Result<TokenStream, syn::Error> {
        generate_dispatcher(DispatcherConfig { size, precision })
    }

    /// Feeds a textual argument list through the macro entry point.
    fn expand(args: &str) -> Result<TokenStream, syn::Error> {
        let input: TokenStream = args.parse().expect("valid token stream");
        generate_from_macro(input)
    }

    /// Leading part of `text` (at most `limit` bytes) for failure messages.
    fn snippet(text: &str, limit: usize) -> &str {
        &text[..text.len().min(limit)]
    }

    // ---- DispatcherConfig --------------------------------------------------

    #[test]
    fn test_dispatcher_config_valid_f32() {
        let DispatcherConfig { size, precision } = DispatcherConfig {
            size: 4,
            precision: Precision::F32,
        };
        assert_eq!(size, 4);
        assert_eq!(precision, Precision::F32);
    }

    #[test]
    fn test_dispatcher_config_valid_f64() {
        let DispatcherConfig { size, precision } = DispatcherConfig {
            size: 8,
            precision: Precision::F64,
        };
        assert_eq!(size, 8);
        assert_eq!(precision, Precision::F64);
    }

    // ---- ISA constants -----------------------------------------------------

    #[test]
    fn test_isa_constants_are_ordered() {
        // Checked at compile time; the test body itself has nothing to run.
        const _: () = assert!(ISA_SCALAR < ISA_SSE2);
        const _: () = assert!(ISA_SSE2 < ISA_AVX);
        const _: () = assert!(ISA_AVX < ISA_AVX2_FMA);
        const _: () = assert!(ISA_AVX2_FMA < ISA_AVX512);
        const _: () = assert!(ISA_NEON != ISA_SCALAR);
        const _: () = assert!(ISA_UNDETECTED == 255);
    }

    // ---- generate_dispatcher -----------------------------------------------

    #[test]
    fn test_generate_dispatcher_nonempty() {
        let tokens = generate(4, Precision::F32).expect("should generate for size 4 f32");
        assert!(!tokens.is_empty(), "TokenStream must not be empty");
    }

    #[test]
    fn test_generate_dispatcher_nonempty_f64() {
        let tokens = generate(8, Precision::F64).expect("should generate for size 8 f64");
        assert!(!tokens.is_empty(), "TokenStream must not be empty");
    }

    #[test]
    fn test_generate_dispatcher_contains_is_x86_feature_detected() {
        let rendered = generate(4, Precision::F32)
            .expect("should generate")
            .to_string();
        assert!(
            rendered.contains("is_x86_feature_detected"),
            "generated code must contain is_x86_feature_detected! macro; got snippet: {}",
            snippet(&rendered, 500)
        );
    }

    #[test]
    fn test_generate_dispatcher_contains_atomic_u8() {
        let rendered = generate(4, Precision::F32)
            .expect("should generate")
            .to_string();
        assert!(
            rendered.contains("AtomicU8"),
            "generated code must contain AtomicU8 static; got snippet: {}",
            snippet(&rendered, 500)
        );
    }

    #[test]
    fn test_generate_dispatcher_contains_isa_undetected() {
        let rendered = generate(4, Precision::F32)
            .expect("should generate")
            .to_string();
        let mentions_sentinel =
            rendered.contains("ISA_UNDETECTED_LEVEL") || rendered.contains("255");
        assert!(
            mentions_sentinel,
            "generated code must reference ISA_UNDETECTED_LEVEL sentinel"
        );
    }

    #[test]
    fn test_generate_dispatcher_function_name_size4_f32() {
        let rendered = generate(4, Precision::F32)
            .expect("should generate")
            .to_string();
        assert!(
            rendered.contains("codelet_simd_4_cached_f32"),
            "expected cached dispatcher name in output; snippet: {}",
            snippet(&rendered, 400)
        );
    }

    #[test]
    fn test_generate_dispatcher_function_name_size8_f64() {
        let rendered = generate(8, Precision::F64)
            .expect("should generate")
            .to_string();
        assert!(
            rendered.contains("codelet_simd_8_cached_f64"),
            "expected cached dispatcher name in output"
        );
    }

    #[test]
    fn test_generate_dispatcher_all_valid_sizes() {
        for size in [2_usize, 4, 8, 16] {
            for prec in [Precision::F32, Precision::F64] {
                let result = generate(size, prec);
                assert!(
                    result.is_ok(),
                    "size={size} prec={prec:?} should succeed, got: {:?}",
                    result.err()
                );
            }
        }
    }

    #[test]
    fn test_generate_dispatcher_unsupported_size_returns_error() {
        assert!(generate(3, Precision::F32).is_err(), "size 3 must return Err");
    }

    #[test]
    fn test_generate_dispatcher_unsupported_size_6_returns_error() {
        assert!(generate(6, Precision::F64).is_err(), "size 6 must return Err");
    }

    // ---- detect_host_isa ---------------------------------------------------

    #[test]
    fn test_dispatcher_isa_detection() {
        let isa = detect_host_isa();
        assert_ne!(
            isa, ISA_UNDETECTED,
            "detect_host_isa must never return ISA_UNDETECTED (255)"
        );
        let is_known = matches!(
            isa,
            ISA_SCALAR | ISA_SSE2 | ISA_AVX | ISA_AVX2_FMA | ISA_AVX512 | ISA_NEON
        );
        assert!(is_known, "detect_host_isa returned unknown level {isa}");
    }

    #[test]
    fn test_detect_host_isa_is_deterministic() {
        let (first, second) = (detect_host_isa(), detect_host_isa());
        assert_eq!(first, second, "detect_host_isa must be deterministic");
    }

    // ---- generate_from_macro -----------------------------------------------

    #[test]
    fn test_generate_from_macro_size4_f32() {
        let result = expand("size = 4, ty = f32");
        assert!(
            result.is_ok(),
            "size=4 ty=f32 must succeed: {:?}",
            result.err()
        );
        let rendered = result.expect("TokenStream").to_string();
        assert!(
            rendered.contains("codelet_simd_4_cached_f32"),
            "must contain cached dispatcher name"
        );
    }

    #[test]
    fn test_generate_from_macro_size8_f64() {
        let result = expand("size = 8, ty = f64");
        assert!(
            result.is_ok(),
            "size=8 ty=f64 must succeed: {:?}",
            result.err()
        );
        let rendered = result.expect("TokenStream").to_string();
        assert!(
            rendered.contains("codelet_simd_8_cached_f64"),
            "must contain cached dispatcher name"
        );
    }

    #[test]
    fn test_generate_from_macro_size2_f64() {
        assert!(
            expand("size = 2, ty = f64").is_ok(),
            "size=2 ty=f64 must succeed"
        );
    }

    #[test]
    fn test_generate_from_macro_size16_f32() {
        assert!(
            expand("size = 16, ty = f32").is_ok(),
            "size=16 ty=f32 must succeed"
        );
    }

    #[test]
    fn test_generate_from_macro_missing_size_returns_error() {
        assert!(expand("ty = f32").is_err(), "missing size must return error");
    }

    #[test]
    fn test_generate_from_macro_missing_ty_returns_error() {
        assert!(expand("size = 4").is_err(), "missing ty must return error");
    }

    #[test]
    fn test_generate_from_macro_unknown_ty_returns_error() {
        assert!(
            expand("size = 4, ty = f16").is_err(),
            "unknown ty must return error"
        );
    }

    #[test]
    fn test_generate_from_macro_unknown_key_returns_error() {
        assert!(
            expand("size = 4, ty = f32, isa = avx2").is_err(),
            "unknown key must return error"
        );
    }

    #[test]
    fn test_generate_from_macro_unsupported_size_returns_error() {
        assert!(
            expand("size = 5, ty = f32").is_err(),
            "size=5 must return error"
        );
    }
}