// cuda_rust_wasm/simd/detection.rs

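//! Runtime SIMD feature detection
//!
//! Detects the SIMD instruction sets available on the current platform and returns
//! a capabilities struct that can be queried to select optimal code paths. x86_64
//! features are queried from the CPU at runtime; wasm32 SIMD support is fixed when
//! the module is compiled (the `simd128` target feature).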
5
6use std::fmt;
7
/// SIMD instruction-set levels that runtime detection can report.
///
/// The derived `PartialOrd`/`Ord` follow declaration order. That order ranks
/// levels sensibly within one architecture family (`Sse2 < Sse41 < Avx2 < Avx512`,
/// `Neon < Sve`), but comparisons across families (e.g. `Avx512 < Neon`) carry
/// no meaning and should not be used to pick a code path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimdLevel {
    /// No SIMD support, scalar fallback only
    Scalar,
    /// SSE2 (128-bit, x86_64 baseline)
    Sse2,
    /// SSE4.1 (128-bit, enhanced integer ops)
    Sse41,
    /// AVX2 (256-bit, integer + float)
    Avx2,
    /// AVX-512 Foundation (512-bit)
    Avx512,
    /// ARM NEON (128-bit)
    Neon,
    /// ARM SVE (scalable vector extension)
    Sve,
    /// WebAssembly SIMD (128-bit)
    WasmSimd128,
}

impl fmt::Display for SimdLevel {
    /// Writes the conventional name of the level (e.g. `"SSE4.1"`).
    ///
    /// Goes through [`fmt::Formatter::pad`] so width, fill, alignment and
    /// precision flags are honoured like they are for `str`, e.g.
    /// `format!("{:>6}", SimdLevel::Neon)` yields `"  NEON"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SimdLevel::Scalar => "Scalar",
            SimdLevel::Sse2 => "SSE2",
            SimdLevel::Sse41 => "SSE4.1",
            SimdLevel::Avx2 => "AVX2",
            SimdLevel::Avx512 => "AVX-512",
            SimdLevel::Neon => "NEON",
            SimdLevel::Sve => "SVE",
            SimdLevel::WasmSimd128 => "WASM SIMD128",
        };
        f.pad(name)
    }
}
43
44/// Runtime SIMD capabilities of the current platform
45#[derive(Debug, Clone)]
46pub struct SimdCapabilities {
47    /// Whether SSE2 is available (x86_64 baseline, always true on x86_64)
48    pub has_sse2: bool,
49    /// Whether SSE4.1 is available
50    pub has_sse41: bool,
51    /// Whether AVX2 is available (256-bit integer + float)
52    pub has_avx2: bool,
53    /// Whether AVX-512 Foundation is available
54    pub has_avx512f: bool,
55    /// Whether FMA (fused multiply-add) is available
56    pub has_fma: bool,
57    /// Whether ARM NEON is available
58    pub has_neon: bool,
59    /// Whether ARM SVE is available
60    pub has_sve: bool,
61    /// Whether WASM SIMD128 is available
62    pub has_wasm_simd128: bool,
63    /// The maximum vector width in bytes supported by the platform
64    pub max_vector_width_bytes: usize,
65    /// The best available SIMD level
66    pub best_level: SimdLevel,
67}
68
69impl SimdCapabilities {
70    /// Detect SIMD capabilities at runtime for the current platform.
71    pub fn detect() -> Self {
72        let mut caps = SimdCapabilities {
73            has_sse2: false,
74            has_sse41: false,
75            has_avx2: false,
76            has_avx512f: false,
77            has_fma: false,
78            has_neon: false,
79            has_sve: false,
80            has_wasm_simd128: false,
81            max_vector_width_bytes: 0,
82            best_level: SimdLevel::Scalar,
83        };
84
85        #[cfg(target_arch = "x86_64")]
86        {
87            caps.detect_x86_64();
88        }
89
90        #[cfg(target_arch = "aarch64")]
91        {
92            caps.detect_aarch64();
93        }
94
95        #[cfg(target_arch = "wasm32")]
96        {
97            caps.detect_wasm();
98        }
99
100        // If nothing was detected, we still have scalar
101        if caps.max_vector_width_bytes == 0 {
102            // Scalar: process one f32 at a time
103            caps.max_vector_width_bytes = 4;
104        }
105
106        caps
107    }
108
109    /// Detect x86_64 SIMD features using `is_x86_feature_detected!`
110    #[cfg(target_arch = "x86_64")]
111    fn detect_x86_64(&mut self) {
112        // SSE2 is always available on x86_64
113        self.has_sse2 = true;
114        self.best_level = SimdLevel::Sse2;
115        self.max_vector_width_bytes = 16; // 128-bit
116
117        if is_x86_feature_detected!("sse4.1") {
118            self.has_sse41 = true;
119            self.best_level = SimdLevel::Sse41;
120        }
121
122        if is_x86_feature_detected!("fma") {
123            self.has_fma = true;
124        }
125
126        if is_x86_feature_detected!("avx2") {
127            self.has_avx2 = true;
128            self.best_level = SimdLevel::Avx2;
129            self.max_vector_width_bytes = 32; // 256-bit
130        }
131
132        if is_x86_feature_detected!("avx512f") {
133            self.has_avx512f = true;
134            self.best_level = SimdLevel::Avx512;
135            self.max_vector_width_bytes = 64; // 512-bit
136        }
137    }
138
139    /// Detect aarch64 SIMD features
140    #[cfg(target_arch = "aarch64")]
141    fn detect_aarch64(&mut self) {
142        // NEON is mandatory on aarch64
143        self.has_neon = true;
144        self.best_level = SimdLevel::Neon;
145        self.max_vector_width_bytes = 16; // 128-bit
146
147        // SVE detection via std::arch feature detection
148        #[cfg(target_feature = "sve")]
149        {
150            self.has_sve = true;
151            self.best_level = SimdLevel::Sve;
152            // SVE vector length is implementation-defined (128-2048 bits)
153            // Use a conservative estimate; actual length queried at runtime would
154            // require inline assembly (cntb instruction).
155            self.max_vector_width_bytes = 32; // conservative 256-bit estimate
156        }
157    }
158
159    /// Detect WebAssembly SIMD features
160    #[cfg(target_arch = "wasm32")]
161    fn detect_wasm(&mut self) {
162        #[cfg(target_feature = "simd128")]
163        {
164            self.has_wasm_simd128 = true;
165            self.best_level = SimdLevel::WasmSimd128;
166            self.max_vector_width_bytes = 16; // 128-bit
167        }
168    }
169
170    /// Returns the number of f32 elements that can be processed in a single
171    /// SIMD operation.
172    pub fn f32_lane_count(&self) -> usize {
173        self.max_vector_width_bytes / std::mem::size_of::<f32>()
174    }
175
176    /// Returns true if any SIMD acceleration is available beyond scalar.
177    pub fn has_simd(&self) -> bool {
178        self.best_level != SimdLevel::Scalar
179    }
180
181    /// Returns a human-readable summary of detected capabilities.
182    pub fn summary(&self) -> String {
183        let mut features = Vec::new();
184
185        if self.has_sse2 {
186            features.push("SSE2");
187        }
188        if self.has_sse41 {
189            features.push("SSE4.1");
190        }
191        if self.has_avx2 {
192            features.push("AVX2");
193        }
194        if self.has_avx512f {
195            features.push("AVX-512F");
196        }
197        if self.has_fma {
198            features.push("FMA");
199        }
200        if self.has_neon {
201            features.push("NEON");
202        }
203        if self.has_sve {
204            features.push("SVE");
205        }
206        if self.has_wasm_simd128 {
207            features.push("WASM SIMD128");
208        }
209
210        if features.is_empty() {
211            "Scalar only (no SIMD)".to_string()
212        } else {
213            format!(
214                "Best: {} | Features: {} | Vector width: {} bytes ({} f32 lanes)",
215                self.best_level,
216                features.join(", "),
217                self.max_vector_width_bytes,
218                self.f32_lane_count()
219            )
220        }
221    }
222}
223
224impl fmt::Display for SimdCapabilities {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        write!(f, "{}", self.summary())
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection always yields at least the one-f32 scalar width.
    #[test]
    fn test_detect_returns_valid_capabilities() {
        let detected = SimdCapabilities::detect();
        let (width, lanes) = (detected.max_vector_width_bytes, detected.f32_lane_count());
        assert!(width >= 4);
        assert!(lanes >= 1);
    }

    /// The textual summary is never blank.
    #[test]
    fn test_summary_not_empty() {
        let text = SimdCapabilities::detect().summary();
        assert!(!text.is_empty());
    }

    /// `Display` produces non-blank output.
    #[test]
    fn test_display_impl() {
        let rendered = SimdCapabilities::detect().to_string();
        assert!(!rendered.is_empty());
    }

    /// SSE2 belongs to the x86_64 baseline, so it must always be reported there.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_has_sse2() {
        let detected = SimdCapabilities::detect();
        assert!(detected.has_sse2);
        assert!(detected.has_simd());
    }

    /// The detector treats NEON as always present on aarch64.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_has_neon() {
        let detected = SimdCapabilities::detect();
        assert!(detected.has_neon);
        assert!(detected.has_simd());
    }
}