oxide_rs/inference/
simd_dispatch.rs1use std::sync::OnceLock;
7
/// A single SIMD instruction-set extension.
///
/// NOTE(review): nothing in the visible part of this module uses this enum —
/// `CpuFeatures` tracks support as boolean flags instead. Confirm whether other
/// modules depend on it before removing it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CpuFeature {
    /// x86_64 AVX-512.
    Avx512,
    /// x86_64 AVX2.
    Avx2,
    /// AArch64 NEON (Advanced SIMD).
    Neon,
    /// No SIMD extension.
    None,
}
15
/// Which SIMD code path to dispatch to.
///
/// `Auto` is a request rather than a concrete path: `SimdDispatch::init`
/// replaces it with `CpuFeatures::recommended_simd()`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SimdLevel {
    /// Pick the highest level the detected CPU supports.
    Auto,
    /// AVX-512 code path.
    Avx512,
    /// AVX2 code path.
    Avx2,
    /// NEON code path.
    Neon,
    /// Portable scalar fallback.
    Scalar,
}
24
25impl SimdLevel {
26 pub fn from_str(s: &str) -> Self {
27 match s.to_lowercase().as_str() {
28 "avx512" => SimdLevel::Avx512,
29 "avx2" => SimdLevel::Avx2,
30 "neon" => SimdLevel::Neon,
31 "none" | "scalar" => SimdLevel::Scalar,
32 _ => SimdLevel::Auto,
33 }
34 }
35}
36
37#[derive(Debug, Clone)]
38pub struct SimdDispatch {
39 pub level: SimdLevel,
40 pub cpu_features: CpuFeatures,
41}
42
/// CPU capabilities relevant to SIMD dispatch, plus core counts.
///
/// The AVX flags only apply to x86_64 and `has_neon` only to aarch64; on any
/// other target `CpuFeatures::detect` leaves every flag `false`.
#[derive(Clone, Debug)]
pub struct CpuFeatures {
    /// AVX-512 support flag.
    pub has_avx512: bool,
    /// AVX2 support flag.
    pub has_avx2: bool,
    /// AVX support flag.
    pub has_avx: bool,
    /// NEON (Advanced SIMD) support flag.
    pub has_neon: bool,
    /// CPU count as reported by `num_cpus::get()`.
    pub num_cores: usize,
    /// Physical core count as reported by `num_cpus::get_physical()`.
    pub num_physical_cores: usize,
}
52
53impl CpuFeatures {
54 pub fn detect() -> Self {
55 let num_cores = num_cpus::get();
56 let num_physical_cores = num_cpus::get_physical();
57
58 #[cfg(target_arch = "x86_64")]
59 {
60 Self::detect_x86(num_cores, num_physical_cores)
61 }
62 #[cfg(target_arch = "aarch64")]
63 {
64 Self::detect_arm(num_cores, num_physical_cores)
65 }
66 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
67 {
68 Self {
69 has_avx512: false,
70 has_avx2: false,
71 has_avx: false,
72 has_neon: false,
73 num_cores,
74 num_physical_cores,
75 }
76 }
77 }
78
79 #[cfg(target_arch = "x86_64")]
80 fn detect_x86(num_cores: usize, num_physical_cores: usize) -> Self {
81 Self {
84 has_avx512: false,
85 has_avx2: true,
86 has_avx: true,
87 has_neon: false,
88 num_cores,
89 num_physical_cores,
90 }
91 }
92
93 #[cfg(target_arch = "aarch64")]
94 fn detect_arm(num_cores: usize, num_physical_cores: usize) -> Self {
95 let has_neon = std::arch::is_aarch64_feature_detected!("neon");
96
97 Self {
98 has_avx512: false,
99 has_avx2: false,
100 has_avx: false,
101 has_neon,
102 num_cores,
103 num_physical_cores,
104 }
105 }
106
107 pub fn recommended_simd(&self) -> SimdLevel {
108 if self.has_avx512 {
109 SimdLevel::Avx512
110 } else if self.has_avx2 {
111 SimdLevel::Avx2
112 } else if self.has_neon {
113 SimdLevel::Neon
114 } else {
115 SimdLevel::Scalar
116 }
117 }
118}
119
120static SIMD_DISPATCH: OnceLock<SimdDispatch> = OnceLock::new();
121
122impl SimdDispatch {
123 pub fn init(level: SimdLevel) -> &'static Self {
124 SIMD_DISPATCH.get_or_init(|| {
125 let cpu_features = CpuFeatures::detect();
126 let actual_level = match level {
127 SimdLevel::Auto => cpu_features.recommended_simd(),
128 _ => level,
129 };
130
131 tracing::info!(
132 "SIMD dispatch initialized: {:?} (detected: AVX512: {}, AVX2: {}, NEON: {})",
133 actual_level,
134 cpu_features.has_avx512,
135 cpu_features.has_avx2,
136 cpu_features.has_neon
137 );
138
139 Self {
140 level: actual_level,
141 cpu_features,
142 }
143 })
144 }
145
146 pub fn get() -> &'static Self {
147 SIMD_DISPATCH.get_or_init(|| {
148 let cpu_features = CpuFeatures::detect();
149 let level = cpu_features.recommended_simd();
150
151 Self {
152 level,
153 cpu_features,
154 }
155 })
156 }
157}
158
159pub fn init_simd(level: SimdLevel) -> &'static SimdDispatch {
160 SimdDispatch::init(level)
161}
162
163pub fn get_simd() -> &'static SimdDispatch {
164 SimdDispatch::get()
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must at least report a positive core count.
    #[test]
    fn test_cpu_features_detection() {
        let features = CpuFeatures::detect();
        assert!(features.num_cores > 0);
    }

    /// Covers every arm of `SimdLevel::from_str`, including the fallback.
    #[test]
    fn test_simd_level_from_str() {
        assert_eq!(SimdLevel::from_str("avx512"), SimdLevel::Avx512);
        assert_eq!(SimdLevel::from_str("AVX2"), SimdLevel::Avx2);
        assert_eq!(SimdLevel::from_str("auto"), SimdLevel::Auto);
        assert_eq!(SimdLevel::from_str("none"), SimdLevel::Scalar);
        assert_eq!(SimdLevel::from_str("Neon"), SimdLevel::Neon);
        assert_eq!(SimdLevel::from_str("scalar"), SimdLevel::Scalar);
        // Unrecognised names fall back to auto-detection.
        assert_eq!(SimdLevel::from_str("sse4"), SimdLevel::Auto);
        assert_eq!(SimdLevel::from_str(""), SimdLevel::Auto);
    }

    /// `recommended_simd` prefers AVX-512 > AVX2 > NEON > scalar.
    #[test]
    fn test_recommended_simd_priority() {
        let features = |avx512: bool, avx2: bool, neon: bool| CpuFeatures {
            has_avx512: avx512,
            has_avx2: avx2,
            has_avx: avx2,
            has_neon: neon,
            num_cores: 1,
            num_physical_cores: 1,
        };
        assert_eq!(features(true, true, false).recommended_simd(), SimdLevel::Avx512);
        assert_eq!(features(false, true, false).recommended_simd(), SimdLevel::Avx2);
        assert_eq!(features(false, false, true).recommended_simd(), SimdLevel::Neon);
        assert_eq!(features(false, false, false).recommended_simd(), SimdLevel::Scalar);
    }

    /// The recommendation for the real CPU is always a concrete level.
    #[test]
    fn test_recommended_simd_is_concrete() {
        let level = CpuFeatures::detect().recommended_simd();
        assert_ne!(level, SimdLevel::Auto);
    }
}