Skip to main content

oxide_rs/inference/
simd_dispatch.rs

//! SIMD runtime dispatch for CPU inference.
//!
//! Detects CPU SIMD capabilities (AVX2, AVX-512, NEON) at runtime and selects
//! the best available code path, without requiring a fork of Candle.
5
6use std::sync::OnceLock;
7
/// A single SIMD instruction-set family a CPU may provide.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CpuFeature {
    /// x86 AVX-512.
    Avx512,
    /// x86 AVX2.
    Avx2,
    /// Arm NEON (Advanced SIMD).
    Neon,
    /// No SIMD extension of interest.
    None,
}
15
/// Which SIMD code path the dispatcher should use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// Pick the best level supported by the detected CPU.
    Auto,
    /// x86 AVX-512 code path.
    Avx512,
    /// x86 AVX2 code path.
    Avx2,
    /// Arm NEON code path.
    Neon,
    /// Portable scalar code, no SIMD.
    Scalar,
}

impl SimdLevel {
    /// Parses a SIMD level name, ignoring case.
    ///
    /// Recognised names are `avx512`, `avx2`, `neon`, and `none`/`scalar`
    /// (both mapping to [`SimdLevel::Scalar`]). Every other input, including
    /// `auto` and the empty string, yields [`SimdLevel::Auto`]; this function
    /// never fails.
    pub fn from_str(s: &str) -> Self {
        const KNOWN: [(&str, SimdLevel); 5] = [
            ("avx512", SimdLevel::Avx512),
            ("avx2", SimdLevel::Avx2),
            ("neon", SimdLevel::Neon),
            ("none", SimdLevel::Scalar),
            ("scalar", SimdLevel::Scalar),
        ];

        let wanted = s.to_lowercase();
        KNOWN
            .iter()
            .find(|(name, _)| *name == wanted)
            .map_or(SimdLevel::Auto, |&(_, level)| level)
    }
}
36
37#[derive(Debug, Clone)]
38pub struct SimdDispatch {
39    pub level: SimdLevel,
40    pub cpu_features: CpuFeatures,
41}
42
43#[derive(Debug, Clone)]
44pub struct CpuFeatures {
45    pub has_avx512: bool,
46    pub has_avx2: bool,
47    pub has_avx: bool,
48    pub has_neon: bool,
49    pub num_cores: usize,
50    pub num_physical_cores: usize,
51}
52
53impl CpuFeatures {
54    pub fn detect() -> Self {
55        let num_cores = num_cpus::get();
56        let num_physical_cores = num_cpus::get_physical();
57
58        #[cfg(target_arch = "x86_64")]
59        {
60            Self::detect_x86(num_cores, num_physical_cores)
61        }
62        #[cfg(target_arch = "aarch64")]
63        {
64            Self::detect_arm(num_cores, num_physical_cores)
65        }
66        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67        {
68            Self {
69                has_avx512: false,
70                has_avx2: false,
71                has_avx: false,
72                has_neon: false,
73                num_cores,
74                num_physical_cores,
75            }
76        }
77    }
78
79    #[cfg(target_arch = "x86_64")]
80    fn detect_x86(num_cores: usize, num_physical_cores: usize) -> Self {
81        // SIMD detection at runtime requires unstable features
82        // For now, we assume all modern x86_64 CPUs support at least AVX2
83        Self {
84            has_avx512: false,
85            has_avx2: true,
86            has_avx: true,
87            has_neon: false,
88            num_cores,
89            num_physical_cores,
90        }
91    }
92
93    #[cfg(target_arch = "aarch64")]
94    fn detect_arm(num_cores: usize, num_physical_cores: usize) -> Self {
95        let has_neon = std::arch::is_aarch64_feature_detected!("neon");
96
97        Self {
98            has_avx512: false,
99            has_avx2: false,
100            has_avx: false,
101            has_neon,
102            num_cores,
103            num_physical_cores,
104        }
105    }
106
107    pub fn recommended_simd(&self) -> SimdLevel {
108        if self.has_avx512 {
109            SimdLevel::Avx512
110        } else if self.has_avx2 {
111            SimdLevel::Avx2
112        } else if self.has_neon {
113            SimdLevel::Neon
114        } else {
115            SimdLevel::Scalar
116        }
117    }
118}
119
120static SIMD_DISPATCH: OnceLock<SimdDispatch> = OnceLock::new();
121
122impl SimdDispatch {
123    pub fn init(level: SimdLevel) -> &'static Self {
124        SIMD_DISPATCH.get_or_init(|| {
125            let cpu_features = CpuFeatures::detect();
126            let actual_level = match level {
127                SimdLevel::Auto => cpu_features.recommended_simd(),
128                _ => level,
129            };
130
131            tracing::info!(
132                "SIMD dispatch initialized: {:?} (detected: AVX512: {}, AVX2: {}, NEON: {})",
133                actual_level,
134                cpu_features.has_avx512,
135                cpu_features.has_avx2,
136                cpu_features.has_neon
137            );
138
139            Self {
140                level: actual_level,
141                cpu_features,
142            }
143        })
144    }
145
146    pub fn get() -> &'static Self {
147        SIMD_DISPATCH.get_or_init(|| {
148            let cpu_features = CpuFeatures::detect();
149            let level = cpu_features.recommended_simd();
150
151            Self {
152                level,
153                cpu_features,
154            }
155        })
156    }
157}
158
159pub fn init_simd(level: SimdLevel) -> &'static SimdDispatch {
160    SimdDispatch::init(level)
161}
162
163pub fn get_simd() -> &'static SimdDispatch {
164    SimdDispatch::get()
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature set with the given SIMD flags and a single core.
    fn with_flags(avx512: bool, avx2: bool, neon: bool) -> CpuFeatures {
        CpuFeatures {
            has_avx512: avx512,
            has_avx2: avx2,
            has_avx: avx2 || avx512,
            has_neon: neon,
            num_cores: 1,
            num_physical_cores: 1,
        }
    }

    #[test]
    fn test_cpu_features_detection() {
        let features = CpuFeatures::detect();
        assert!(features.num_cores > 0);
        // x86 flags and NEON belong to different architectures.
        assert!(!(features.has_neon && (features.has_avx2 || features.has_avx512)));
    }

    #[test]
    fn test_simd_level_from_str() {
        assert_eq!(SimdLevel::from_str("avx512"), SimdLevel::Avx512);
        assert_eq!(SimdLevel::from_str("AVX2"), SimdLevel::Avx2);
        assert_eq!(SimdLevel::from_str("Neon"), SimdLevel::Neon);
        assert_eq!(SimdLevel::from_str("auto"), SimdLevel::Auto);
        assert_eq!(SimdLevel::from_str("none"), SimdLevel::Scalar);
        assert_eq!(SimdLevel::from_str("Scalar"), SimdLevel::Scalar);
        // Unrecognised names fall back to automatic selection.
        assert_eq!(SimdLevel::from_str(""), SimdLevel::Auto);
        assert_eq!(SimdLevel::from_str("sse4"), SimdLevel::Auto);
    }

    #[test]
    fn test_recommended_simd_priority() {
        assert_eq!(with_flags(true, true, false).recommended_simd(), SimdLevel::Avx512);
        assert_eq!(with_flags(false, true, false).recommended_simd(), SimdLevel::Avx2);
        assert_eq!(with_flags(false, false, true).recommended_simd(), SimdLevel::Neon);
        assert_eq!(with_flags(false, false, false).recommended_simd(), SimdLevel::Scalar);
    }

    #[test]
    fn test_get_uses_recommended_level() {
        // Assumes no test calls `init_simd` with an explicit level, so the
        // global is initialised with automatic selection.
        let dispatch = get_simd();
        assert_eq!(dispatch.level, dispatch.cpu_features.recommended_simd());
    }
}