Skip to main content

entrenar/efficiency/device/
simd.rs

1//! SIMD capability detection and abstraction.
2
3use serde::{Deserialize, Serialize};
4
5/// SIMD capability levels
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
7pub enum SimdCapability {
8    /// No SIMD support
9    #[default]
10    None,
11    /// SSE4.1/4.2 (128-bit)
12    Sse4,
13    /// AVX2 (256-bit)
14    Avx2,
15    /// AVX-512 (512-bit)
16    Avx512,
17    /// ARM NEON (128-bit)
18    Neon,
19}
20
21impl SimdCapability {
22    /// Returns the vector width in bits
23    pub fn vector_width_bits(&self) -> u32 {
24        match self {
25            Self::None => 0,
26            Self::Sse4 => 128,
27            Self::Avx2 => 256,
28            Self::Avx512 => 512,
29            Self::Neon => 128,
30        }
31    }
32
33    /// Detect SIMD capability of current CPU
34    #[cfg(target_arch = "x86_64")]
35    pub fn detect() -> Self {
36        if is_x86_feature_detected!("avx512f") {
37            Self::Avx512
38        } else if is_x86_feature_detected!("avx2") {
39            Self::Avx2
40        } else if is_x86_feature_detected!("sse4.1") {
41            Self::Sse4
42        } else {
43            Self::None
44        }
45    }
46
47    /// Detect SIMD capability of current CPU (ARM)
48    #[cfg(target_arch = "aarch64")]
49    pub fn detect() -> Self {
50        // NEON is mandatory on aarch64
51        Self::Neon
52    }
53
54    /// Detect SIMD capability (fallback for other architectures)
55    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
56    pub fn detect() -> Self {
57        Self::None
58    }
59}
60
61impl std::fmt::Display for SimdCapability {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            Self::None => write!(f, "none"),
65            Self::Sse4 => write!(f, "SSE4"),
66            Self::Avx2 => write!(f, "AVX2"),
67            Self::Avx512 => write!(f, "AVX-512"),
68            Self::Neon => write!(f, "NEON"),
69        }
70    }
71}
72
/// Unit tests for [`SimdCapability`]: default value, vector widths, display
/// names, detection, and the derived trait implementations.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_capability_default() {
        assert_eq!(SimdCapability::default(), SimdCapability::None);
    }

    #[test]
    fn test_vector_width_bits_none() {
        assert_eq!(SimdCapability::None.vector_width_bits(), 0);
    }

    #[test]
    fn test_vector_width_bits_sse4() {
        assert_eq!(SimdCapability::Sse4.vector_width_bits(), 128);
    }

    #[test]
    fn test_vector_width_bits_avx2() {
        assert_eq!(SimdCapability::Avx2.vector_width_bits(), 256);
    }

    #[test]
    fn test_vector_width_bits_avx512() {
        assert_eq!(SimdCapability::Avx512.vector_width_bits(), 512);
    }

    #[test]
    fn test_vector_width_bits_neon() {
        assert_eq!(SimdCapability::Neon.vector_width_bits(), 128);
    }

    #[test]
    fn test_simd_capability_display_none() {
        assert_eq!(SimdCapability::None.to_string(), "none");
    }

    #[test]
    fn test_simd_capability_display_sse4() {
        assert_eq!(SimdCapability::Sse4.to_string(), "SSE4");
    }

    #[test]
    fn test_simd_capability_display_avx2() {
        assert_eq!(SimdCapability::Avx2.to_string(), "AVX2");
    }

    #[test]
    fn test_simd_capability_display_avx512() {
        assert_eq!(SimdCapability::Avx512.to_string(), "AVX-512");
    }

    #[test]
    fn test_simd_capability_display_neon() {
        assert_eq!(SimdCapability::Neon.to_string(), "NEON");
    }

    #[test]
    fn test_simd_capability_detect() {
        let detected = SimdCapability::detect();

        // Detection must give the same answer every time it is asked.
        assert_eq!(detected, SimdCapability::detect());

        // The result must be plausible for the architecture under test.
        if cfg!(target_arch = "aarch64") {
            assert_eq!(detected, SimdCapability::Neon);
        } else if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
            assert_ne!(detected, SimdCapability::Neon);
        } else {
            assert_eq!(detected, SimdCapability::None);
        }
    }

    #[test]
    fn test_simd_capability_clone() {
        let cap = SimdCapability::Avx2;
        // The type is `Copy`, but this test exists to exercise `Clone` itself.
        #[allow(clippy::clone_on_copy)]
        let cloned = cap.clone();
        assert_eq!(cap, cloned);
    }

    #[test]
    fn test_simd_capability_eq() {
        assert_eq!(SimdCapability::Avx2, SimdCapability::Avx2);
        assert_ne!(SimdCapability::Avx2, SimdCapability::Avx512);
    }

    #[test]
    fn test_simd_capability_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(SimdCapability::Avx2);
        set.insert(SimdCapability::Avx2);
        assert_eq!(set.len(), 1);
        set.insert(SimdCapability::Avx512);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_simd_capability_serde() {
        // Round-trip every variant so a rename or attribute change on any of
        // them is caught, not just on one.
        for cap in [
            SimdCapability::None,
            SimdCapability::Sse4,
            SimdCapability::Avx2,
            SimdCapability::Avx512,
            SimdCapability::Neon,
        ] {
            let json = serde_json::to_string(&cap).expect("JSON serialization should succeed");
            let deserialized: SimdCapability =
                serde_json::from_str(&json).expect("JSON deserialization should succeed");
            assert_eq!(cap, deserialized);
        }
    }

    #[test]
    fn test_simd_capability_debug() {
        assert_eq!(format!("{:?}", SimdCapability::None), "None");
        assert_eq!(format!("{:?}", SimdCapability::Sse4), "Sse4");
        assert_eq!(format!("{:?}", SimdCapability::Avx2), "Avx2");
        assert_eq!(format!("{:?}", SimdCapability::Avx512), "Avx512");
        assert_eq!(format!("{:?}", SimdCapability::Neon), "Neon");
    }
}