Skip to main content

oxifft_codegen_impl/gen_simd/
runtime_dispatch.rs

1//! Centralized ISA runtime dispatch codegen for `OxiFFT` SIMD codelets.
2//!
3//! This module generates **cached** runtime ISA dispatchers that extend the
4//! inline dispatchers in [`super`] with an `AtomicU8`-based ISA level cache.
5//!
6//! # Motivation
7//!
8//! The basic dispatchers emitted by `super::gen_dispatcher` perform
9//! `is_x86_feature_detected!` / `is_aarch64_feature_detected!` on every call.
10//! While each call is cheap (typically one CPUID cache read), a hot codelet
11//! invoked millions of times per second may benefit from the cached path, which
12//! replaces repeated feature probes with a single `AtomicU8` load.
13//!
14//! # Priority order (high → low)
15//!
16//! ```text
17//! x86_64: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar
18//! aarch64: NEON > scalar
19//! other: scalar
20//! ```
21//!
22//! # Generated code shape
23//!
24//! For each `(size, precision)` pair, the proc-macro emits:
//! - ISA level constants (`ISA_SCALAR_LEVEL`, `ISA_SSE2_LEVEL`, … `ISA_UNDETECTED_LEVEL`)
//! - A `static DETECTED_ISA_{size}_{TY}: AtomicU8` initialized to `ISA_UNDETECTED_LEVEL`
//! - A private `detect_isa_{size}_{ty}() -> u8` function that probes the CPU once
//! - A public `codelet_simd_{size}_cached_{ty}(data, sign)` dispatcher that reads the cache first
29//!
30//! # Proc-macro entry
31//!
32//! ```ignore
33//! // Generates a cached dispatcher for size-4 f32.
34//! gen_dispatcher_codelet!(size = 4, ty = f32);
35//! ```
36
37use proc_macro2::TokenStream;
38use quote::{format_ident, quote};
39use syn::{
40    parse::{Parse, ParseStream},
41    LitInt, Token,
42};
43
44pub use super::multi_transform::Precision;
45
46// ============================================================================
47// Public types
48// ============================================================================
49
50/// Configuration for a cached runtime ISA dispatcher codelet.
51#[derive(Debug, Clone, Copy)]
52pub struct DispatcherConfig {
53    /// DFT size — must be one of 2, 4, 8, or 16.
54    pub size: usize,
55    /// Floating-point precision.
56    pub precision: Precision,
57}
58
59// ============================================================================
60// ISA level constants (used in generated code and in host-detection helper)
61// ============================================================================
62
/// Level reported when no SIMD extension is usable (scalar fallback).
pub const ISA_SCALAR: u8 = 0;
/// Level reported for SSE2.
pub const ISA_SSE2: u8 = 1;
/// Level reported for plain AVX, i.e. without AVX2 or FMA.
pub const ISA_AVX: u8 = 2;
/// Level reported for the AVX2 + FMA combination.
pub const ISA_AVX2_FMA: u8 = 3;
/// Level reported for AVX-512F.
pub const ISA_AVX512: u8 = 4;
/// Level reported for NEON on aarch64.
pub const ISA_NEON: u8 = 5;
/// Sentinel meaning "detection has not run yet"; this is the value the
/// `AtomicU8` cache holds before the first dispatch call.
pub const ISA_UNDETECTED: u8 = u8::MAX;
77
78// ============================================================================
79// Host-detection helper (used by tests and by the generated detection code)
80// ============================================================================
81
82/// Detect the best ISA available on the current host at runtime.
83///
84/// Returns one of the `ISA_*` constants.  Never returns `ISA_UNDETECTED`.
85///
86/// This function is also used in the unit tests to validate that we always
87/// detect a valid ISA on the host machine.
88#[must_use]
89pub fn detect_host_isa() -> u8 {
90    #[cfg(target_arch = "x86_64")]
91    {
92        if is_x86_feature_detected!("avx512f") {
93            return ISA_AVX512;
94        }
95        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
96            return ISA_AVX2_FMA;
97        }
98        if is_x86_feature_detected!("avx") {
99            return ISA_AVX;
100        }
101        if is_x86_feature_detected!("sse2") {
102            return ISA_SSE2;
103        }
104        return ISA_SCALAR;
105    }
106
107    #[cfg(target_arch = "aarch64")]
108    {
109        if std::arch::is_aarch64_feature_detected!("neon") {
110            return ISA_NEON;
111        }
112        return ISA_SCALAR;
113    }
114
115    // All other architectures (wasm32, riscv, etc.)
116    #[allow(unreachable_code)]
117    ISA_SCALAR
118}
119
120// ============================================================================
121// Code generation helpers
122// ============================================================================
123
124/// Build the `x86_64` ISA detection body emitted inside the detect function.
125fn build_detect_x86_body() -> TokenStream {
126    quote! {
127        #[cfg(target_arch = "x86_64")]
128        {
129            if is_x86_feature_detected!("avx512f") {
130                return ISA_AVX512_LEVEL;
131            }
132            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
133                return ISA_AVX2_FMA_LEVEL;
134            }
135            if is_x86_feature_detected!("avx") {
136                return ISA_AVX_LEVEL;
137            }
138            if is_x86_feature_detected!("sse2") {
139                return ISA_SSE2_LEVEL;
140            }
141            return ISA_SCALAR_LEVEL;
142        }
143    }
144}
145
146/// Build the aarch64 ISA detection body emitted inside the detect function.
147fn build_detect_aarch64_body() -> TokenStream {
148    quote! {
149        #[cfg(target_arch = "aarch64")]
150        {
151            if std::arch::is_aarch64_feature_detected!("neon") {
152                return ISA_NEON_LEVEL;
153            }
154            return ISA_SCALAR_LEVEL;
155        }
156    }
157}
158
159/// Build the `x86_64` dispatch branches for the cached dispatcher body.
160///
161/// For size-16 f32 only AVX-512 is available; for size-16 f64 no x86 SIMD
162/// path exists.  For all other sizes (2, 4, 8), all ISA levels are probed.
163///
164/// Each branch creates its own local `data_inner` reinterpretation so that
165/// the raw-pointer slice never aliases the original `data` borrow.
166fn build_x86_64_branches(config: DispatcherConfig) -> TokenStream {
167    let size = config.size;
168    let ty_str = config.precision.type_str();
169    let ty_tokens: TokenStream = ty_str
170        .parse()
171        .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
172    let avx512_fn = format_ident!("codelet_simd_{}_avx512_{}", size, ty_str);
173    let avx2_fn = format_ident!("codelet_simd_{}_avx2_{}", size, ty_str);
174    let sse2_fn = format_ident!("codelet_simd_{}_sse2_{}", size, ty_str);
175
176    if size == 16 {
177        if config.precision == Precision::F32 {
178            return quote! {
179                #[cfg(target_arch = "x86_64")]
180                {
181                    if cached_level == ISA_AVX512_LEVEL {
182                        // Safety: avx512f detected at runtime.
183                        // Layout: Complex<f32> is #[repr(C)] (re, im) — same as [f32; 2*N].
184                        let data_len = data.len() * 2;
185                        let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
186                        let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
187                        unsafe { super::#avx512_fn(data_inner, sign); }
188                        return;
189                    }
190                }
191            };
192        }
193        // size-16 f64: no dedicated SIMD on x86_64
194        return quote! {};
195    }
196
197    // Pure-AVX path only exists for f64 (no pure-AVX f32 emitter)
198    let avx_branch = if config.precision == Precision::F64 {
199        let avx_f64_fn = format_ident!("codelet_simd_{}_avx_f64", size);
200        quote! {
201            if cached_level == ISA_AVX_LEVEL {
202                // Safety: avx detected at runtime; function has #[target_feature(enable = "avx")].
203                let data_len = data.len() * 2;
204                let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
205                let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
206                unsafe { super::#avx_f64_fn(data_inner, sign); }
207                return;
208            }
209        }
210    } else {
211        quote! {}
212    };
213
214    quote! {
215        #[cfg(target_arch = "x86_64")]
216        {
217            if cached_level == ISA_AVX512_LEVEL {
218                // Safety: avx512f detected at runtime.
219                let data_len = data.len() * 2;
220                let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
221                let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
222                unsafe { super::#avx512_fn(data_inner, sign); }
223                return;
224            }
225            if cached_level == ISA_AVX2_FMA_LEVEL {
226                // Safety: avx2+fma detected at runtime.
227                let data_len = data.len() * 2;
228                let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
229                let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
230                unsafe { super::#avx2_fn(data_inner, sign); }
231                return;
232            }
233            #avx_branch
234            if cached_level == ISA_SSE2_LEVEL {
235                // Safety: sse2 detected at runtime.
236                let data_len = data.len() * 2;
237                let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
238                let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
239                unsafe { super::#sse2_fn(data_inner, sign); }
240                return;
241            }
242        }
243    }
244}
245
246/// Build the aarch64 dispatch branch for the cached dispatcher body.
247///
248/// Size-16 has no NEON path.
249fn build_aarch64_branch(config: DispatcherConfig) -> TokenStream {
250    if config.size == 16 {
251        return quote! {};
252    }
253    let ty_str = config.precision.type_str();
254    let ty_tokens: TokenStream = ty_str
255        .parse()
256        .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
257    let neon_fn = format_ident!("codelet_simd_{}_neon_{}", config.size, ty_str);
258    quote! {
259        #[cfg(target_arch = "aarch64")]
260        {
261            if cached_level == ISA_NEON_LEVEL {
262                // Safety: NEON detected at runtime; mandatory on aarch64.
263                let data_len = data.len() * 2;
264                let data_ptr = data.as_mut_ptr().cast::<#ty_tokens>();
265                let data_inner = unsafe { core::slice::from_raw_parts_mut(data_ptr, data_len) };
266                unsafe { super::#neon_fn(data_inner, sign); }
267                return;
268            }
269        }
270    }
271}
272
273// ============================================================================
274// Code generation
275// ============================================================================
276
277/// Generate a cached runtime ISA dispatcher `TokenStream`.
278///
279/// The emitted code:
280/// 1. Declares ISA constants (only once per invocation; the caller is
281///    responsible for deduplication if multiple sizes share a module).
282/// 2. Declares a `static DETECTED_ISA_{size}_{ty}: AtomicU8`.
283/// 3. Emits a private `detect_isa_{size}_{ty}() -> u8` probe function.
284/// 4. Emits a public `codelet_simd_{size}_cached_{ty}(data, sign)` dispatcher.
285///
286/// The dispatcher delegates to the same arch-specific inner functions that the
287/// basic (uncached) dispatcher in [`super`] uses, following the exact same
288/// naming convention: `codelet_simd_{size}_{isa}_{ty}`.
289///
290/// # Errors
291///
292/// Returns `syn::Error` when `config.size` is not one of 2, 4, 8, or 16.
293#[allow(clippy::too_many_lines)] // reason: token-stream assembly requires many local variables
294pub fn generate_dispatcher(config: DispatcherConfig) -> Result<TokenStream, syn::Error> {
295    let size = config.size;
296    if !matches!(size, 2 | 4 | 8 | 16) {
297        return Err(syn::Error::new(
298            proc_macro2::Span::call_site(),
299            format!(
300                "gen_dispatcher_codelet: unsupported size {size} (expected one of 2, 4, 8, 16)"
301            ),
302        ));
303    }
304
305    let ty_str = config.precision.type_str();
306    let ty_upper = ty_str.to_uppercase();
307    let size_str = size.to_string();
308
309    // AtomicU8 static name: DETECTED_ISA_4_F32
310    let static_name = format_ident!("DETECTED_ISA_{}_{}", size_str, ty_upper);
311    // Detect function name: detect_isa_4_f32
312    let detect_fn = format_ident!("detect_isa_{}_{}", size_str, ty_str);
313    // Cached dispatcher name: codelet_simd_4_cached_f32
314    let cached_fn = format_ident!("codelet_simd_{}_cached_{}", size_str, ty_str);
315    // Scalar fallback name: codelet_simd_4_scalar
316    let scalar_fn = format_ident!("codelet_simd_{}_scalar", size);
317
318    let detect_x86_body = build_detect_x86_body();
319    let detect_aarch64_body = build_detect_aarch64_body();
320    let x86_64_branches = build_x86_64_branches(config);
321    let aarch64_branch = build_aarch64_branch(config);
322
323    let ty_tokens: TokenStream = ty_str
324        .parse()
325        .unwrap_or_else(|_| unreachable!("ty_str is always f32 or f64"));
326
327    let fn_doc = format!(
328        "Cached runtime ISA dispatcher for size-{size} DFT ({ty_str}).\n\n\
329         On first call, probes CPU features and stores the ISA level in a\n\
330         thread-safe `AtomicU8` static.  Subsequent calls read the cache with\n\
331         `Relaxed` ordering (benign-racy: all threads converge on the same answer).\n\n\
332         Dispatch priority on `x86_64`: AVX-512F > AVX2+FMA > AVX > SSE2 > scalar.\n\
333         Dispatch priority on `aarch64`: NEON > scalar.\n\
334         Other architectures fall through to the scalar codelet."
335    );
336
337    let size_lit = size;
338
339    Ok(quote! {
340        // ISA level constants (private to the generated scope)
341        const ISA_SCALAR_LEVEL:     u8 = 0;
342        const ISA_SSE2_LEVEL:       u8 = 1;
343        const ISA_AVX_LEVEL:        u8 = 2;
344        const ISA_AVX2_FMA_LEVEL:   u8 = 3;
345        const ISA_AVX512_LEVEL:     u8 = 4;
346        const ISA_NEON_LEVEL:       u8 = 5;
347        const ISA_UNDETECTED_LEVEL: u8 = 255;
348
349        /// Cached ISA level for this (size, precision) pair.
350        ///
351        /// Initialized to `ISA_UNDETECTED_LEVEL`.  Written once on first dispatch call.
352        static #static_name: core::sync::atomic::AtomicU8 =
353            core::sync::atomic::AtomicU8::new(ISA_UNDETECTED_LEVEL);
354
355        /// Probe the CPU once and return the best ISA level for this target.
356        fn #detect_fn() -> u8 {
357            #detect_x86_body
358            #detect_aarch64_body
359            #[allow(unreachable_code)]
360            ISA_SCALAR_LEVEL
361        }
362
363        #[doc = #fn_doc]
364        #[inline]
365        pub fn #cached_fn(
366            data: &mut [crate::kernel::Complex<#ty_tokens>],
367            sign: i32,
368        ) {
369            debug_assert!(
370                data.len() >= #size_lit,
371                "codelet_simd_{}_cached_{}: need >= {} elements, got {}",
372                #size_lit,
373                stringify!(#ty_tokens),
374                #size_lit,
375                data.len(),
376            );
377
378            // Load cached ISA; detect on first call.
379            let cached_level = {
380                let level = #static_name.load(core::sync::atomic::Ordering::Relaxed);
381                if level == ISA_UNDETECTED_LEVEL {
382                    let detected = #detect_fn();
383                    // Relaxed store: benign-racy — all threads converge on the same value.
384                    #static_name.store(detected, core::sync::atomic::Ordering::Relaxed);
385                    detected
386                } else {
387                    level
388                }
389            };
390
391            // Architecture-specific SIMD paths.
392            //
393            // data_inner is created inside each cfg block so that the raw-pointer
394            // reinterpretation and the original `data` borrow never overlap.
395            // The scalar fallback uses `data` directly — no aliasing.
396            #x86_64_branches
397            #aarch64_branch
398
399            // Scalar fallback: use the original Complex slice directly.
400            // No reinterpretation needed — the scalar codelet accepts Complex<T>.
401            super::#scalar_fn(data, sign);
402        }
403    })
404}
405
406// ============================================================================
407// Proc-macro parse input
408// ============================================================================
409
410/// Parsed arguments from `gen_dispatcher_codelet!(size = 4, ty = f32)`.
411struct MacroArgs {
412    size: usize,
413    precision: Precision,
414}
415
416impl Parse for MacroArgs {
417    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
418        let mut size: Option<usize> = None;
419        let mut precision: Option<Precision> = None;
420
421        while !input.is_empty() {
422            let key: syn::Ident = input.parse()?;
423            let _eq: Token![=] = input.parse()?;
424            match key.to_string().as_str() {
425                "size" => {
426                    let lit: LitInt = input.parse()?;
427                    size = Some(lit.base10_parse::<usize>().map_err(|_| {
428                        syn::Error::new(lit.span(), "expected an integer literal for `size`")
429                    })?);
430                }
431                "ty" => {
432                    let ident: syn::Ident = input.parse()?;
433                    precision = Some(match ident.to_string().as_str() {
434                        "f32" => Precision::F32,
435                        "f64" => Precision::F64,
436                        other => {
437                            return Err(syn::Error::new(
438                                ident.span(),
439                                format!("unknown ty `{other}`, expected f32 or f64"),
440                            ));
441                        }
442                    });
443                }
444                other => {
445                    return Err(syn::Error::new(
446                        key.span(),
447                        format!("unknown key `{other}`, expected one of: size, ty"),
448                    ));
449                }
450            }
451            if input.peek(Token![,]) {
452                let _: Token![,] = input.parse()?;
453            }
454        }
455
456        let size = size.ok_or_else(|| {
457            syn::Error::new(proc_macro2::Span::call_site(), "missing `size` argument")
458        })?;
459        let precision = precision.ok_or_else(|| {
460            syn::Error::new(proc_macro2::Span::call_site(), "missing `ty` argument")
461        })?;
462
463        Ok(Self { size, precision })
464    }
465}
466
467/// Entry point for the `gen_dispatcher_codelet!` proc-macro.
468///
469/// Parses `size = N, ty = TY` and calls [`generate_dispatcher`].
470///
471/// # Errors
472///
473/// Returns a `syn::Error` when the input does not parse as valid key-value
474/// pairs, a required key is missing, or `size` / `ty` have unsupported values.
475pub fn generate_from_macro(input: TokenStream) -> Result<TokenStream, syn::Error> {
476    let args: MacroArgs = syn::parse2(input)?;
477    generate_dispatcher(DispatcherConfig {
478        size: args.size,
479        precision: args.precision,
480    })
481}
482
483// ============================================================================
484// Tests
485// ============================================================================
486
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared helpers ────────────────────────────────────────────────────

    /// Shorthand constructor for a dispatcher configuration.
    fn config(size: usize, precision: Precision) -> DispatcherConfig {
        DispatcherConfig { size, precision }
    }

    /// Generate a dispatcher and render it to a string, panicking with
    /// `expectation` if generation fails.
    fn render(size: usize, precision: Precision, expectation: &str) -> String {
        generate_dispatcher(config(size, precision))
            .expect(expectation)
            .to_string()
    }

    /// Tokenize `src` and feed it through the macro entry point.
    fn run_macro(src: &str) -> Result<TokenStream, syn::Error> {
        let input: TokenStream = src.parse().expect("valid token stream");
        generate_from_macro(input)
    }

    // ── DispatcherConfig construction ─────────────────────────────────────

    #[test]
    fn test_dispatcher_config_valid_f32() {
        let cfg = config(4, Precision::F32);
        assert_eq!(cfg.size, 4);
        assert_eq!(cfg.precision, Precision::F32);
    }

    #[test]
    fn test_dispatcher_config_valid_f64() {
        let cfg = config(8, Precision::F64);
        assert_eq!(cfg.size, 8);
        assert_eq!(cfg.precision, Precision::F64);
    }

    // ── ISA constants ─────────────────────────────────────────────────────

    #[test]
    fn test_isa_constants_are_ordered() {
        // Evaluated at compile time: a violated relation fails the build.
        const _: () = {
            assert!(ISA_SCALAR < ISA_SSE2);
            assert!(ISA_SSE2 < ISA_AVX);
            assert!(ISA_AVX < ISA_AVX2_FMA);
            assert!(ISA_AVX2_FMA < ISA_AVX512);
            assert!(ISA_NEON != ISA_SCALAR);
            assert!(ISA_UNDETECTED == 255);
        };
    }

    // ── generate_dispatcher: TokenStream checks ───────────────────────────

    #[test]
    fn test_generate_dispatcher_nonempty() {
        let ts = generate_dispatcher(config(4, Precision::F32))
            .expect("should generate for size 4 f32");
        assert!(!ts.is_empty(), "TokenStream must not be empty");
    }

    #[test]
    fn test_generate_dispatcher_nonempty_f64() {
        let ts = generate_dispatcher(config(8, Precision::F64))
            .expect("should generate for size 8 f64");
        assert!(!ts.is_empty(), "TokenStream must not be empty");
    }

    #[test]
    fn test_generate_dispatcher_contains_is_x86_feature_detected() {
        let s = render(4, Precision::F32, "should generate");
        assert!(
            s.contains("is_x86_feature_detected"),
            "generated code must contain is_x86_feature_detected! macro; got snippet: {}",
            &s[..s.len().min(500)]
        );
    }

    #[test]
    fn test_generate_dispatcher_contains_atomic_u8() {
        let s = render(4, Precision::F32, "should generate");
        assert!(
            s.contains("AtomicU8"),
            "generated code must contain AtomicU8 static; got snippet: {}",
            &s[..s.len().min(500)]
        );
    }

    #[test]
    fn test_generate_dispatcher_contains_isa_undetected() {
        let s = render(4, Precision::F32, "should generate");
        assert!(
            s.contains("ISA_UNDETECTED_LEVEL") || s.contains("255"),
            "generated code must reference ISA_UNDETECTED_LEVEL sentinel"
        );
    }

    #[test]
    fn test_generate_dispatcher_function_name_size4_f32() {
        let s = render(4, Precision::F32, "should generate");
        assert!(
            s.contains("codelet_simd_4_cached_f32"),
            "expected cached dispatcher name in output; snippet: {}",
            &s[..s.len().min(400)]
        );
    }

    #[test]
    fn test_generate_dispatcher_function_name_size8_f64() {
        let s = render(8, Precision::F64, "should generate");
        assert!(
            s.contains("codelet_simd_8_cached_f64"),
            "expected cached dispatcher name in output"
        );
    }

    #[test]
    fn test_generate_dispatcher_all_valid_sizes() {
        let sizes = [2_usize, 4, 8, 16];
        let precisions = [Precision::F32, Precision::F64];
        for size in sizes {
            for prec in precisions {
                let result = generate_dispatcher(config(size, prec));
                assert!(
                    result.is_ok(),
                    "size={size} prec={prec:?} should succeed, got: {:?}",
                    result.err()
                );
            }
        }
    }

    #[test]
    fn test_generate_dispatcher_unsupported_size_returns_error() {
        let result = generate_dispatcher(config(3, Precision::F32));
        assert!(result.is_err(), "size 3 must return Err");
    }

    #[test]
    fn test_generate_dispatcher_unsupported_size_6_returns_error() {
        let result = generate_dispatcher(config(6, Precision::F64));
        assert!(result.is_err(), "size 6 must return Err");
    }

    // ── detect_host_isa ───────────────────────────────────────────────────

    #[test]
    fn test_dispatcher_isa_detection() {
        // Whatever the host is, the detected level must be a real ISA level
        // and never the "not yet detected" sentinel.
        let isa = detect_host_isa();
        assert_ne!(
            isa, ISA_UNDETECTED,
            "detect_host_isa must never return ISA_UNDETECTED (255)"
        );
        let is_known_level = matches!(
            isa,
            ISA_SCALAR | ISA_SSE2 | ISA_AVX | ISA_AVX2_FMA | ISA_AVX512 | ISA_NEON
        );
        assert!(is_known_level, "detect_host_isa returned unknown level {isa}");
    }

    #[test]
    fn test_detect_host_isa_is_deterministic() {
        let first = detect_host_isa();
        let second = detect_host_isa();
        assert_eq!(first, second, "detect_host_isa must be deterministic");
    }

    // ── generate_from_macro ───────────────────────────────────────────────

    #[test]
    fn test_generate_from_macro_size4_f32() {
        let result = run_macro("size = 4, ty = f32");
        assert!(
            result.is_ok(),
            "size=4 ty=f32 must succeed: {:?}",
            result.err()
        );
        let s = result.expect("TokenStream").to_string();
        assert!(
            s.contains("codelet_simd_4_cached_f32"),
            "must contain cached dispatcher name"
        );
    }

    #[test]
    fn test_generate_from_macro_size8_f64() {
        let result = run_macro("size = 8, ty = f64");
        assert!(
            result.is_ok(),
            "size=8 ty=f64 must succeed: {:?}",
            result.err()
        );
        let s = result.expect("TokenStream").to_string();
        assert!(
            s.contains("codelet_simd_8_cached_f64"),
            "must contain cached dispatcher name"
        );
    }

    #[test]
    fn test_generate_from_macro_size2_f64() {
        let result = run_macro("size = 2, ty = f64");
        assert!(result.is_ok(), "size=2 ty=f64 must succeed");
    }

    #[test]
    fn test_generate_from_macro_size16_f32() {
        let result = run_macro("size = 16, ty = f32");
        assert!(result.is_ok(), "size=16 ty=f32 must succeed");
    }

    #[test]
    fn test_generate_from_macro_missing_size_returns_error() {
        let result = run_macro("ty = f32");
        assert!(result.is_err(), "missing size must return error");
    }

    #[test]
    fn test_generate_from_macro_missing_ty_returns_error() {
        let result = run_macro("size = 4");
        assert!(result.is_err(), "missing ty must return error");
    }

    #[test]
    fn test_generate_from_macro_unknown_ty_returns_error() {
        let result = run_macro("size = 4, ty = f16");
        assert!(result.is_err(), "unknown ty must return error");
    }

    #[test]
    fn test_generate_from_macro_unknown_key_returns_error() {
        let result = run_macro("size = 4, ty = f32, isa = avx2");
        assert!(result.is_err(), "unknown key must return error");
    }

    #[test]
    fn test_generate_from_macro_unsupported_size_returns_error() {
        let result = run_macro("size = 5, ty = f32");
        assert!(result.is_err(), "size=5 must return error");
    }
}