entrenar/efficiency/device/
simd.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
7pub enum SimdCapability {
8 #[default]
10 None,
11 Sse4,
13 Avx2,
15 Avx512,
17 Neon,
19}
20
21impl SimdCapability {
22 pub fn vector_width_bits(&self) -> u32 {
24 match self {
25 Self::None => 0,
26 Self::Sse4 => 128,
27 Self::Avx2 => 256,
28 Self::Avx512 => 512,
29 Self::Neon => 128,
30 }
31 }
32
33 #[cfg(target_arch = "x86_64")]
35 pub fn detect() -> Self {
36 if is_x86_feature_detected!("avx512f") {
37 Self::Avx512
38 } else if is_x86_feature_detected!("avx2") {
39 Self::Avx2
40 } else if is_x86_feature_detected!("sse4.1") {
41 Self::Sse4
42 } else {
43 Self::None
44 }
45 }
46
47 #[cfg(target_arch = "aarch64")]
49 pub fn detect() -> Self {
50 Self::Neon
52 }
53
54 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
56 pub fn detect() -> Self {
57 Self::None
58 }
59}
60
61impl std::fmt::Display for SimdCapability {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Self::None => write!(f, "none"),
65 Self::Sse4 => write!(f, "SSE4"),
66 Self::Avx2 => write!(f, "AVX2"),
67 Self::Avx512 => write!(f, "AVX-512"),
68 Self::Neon => write!(f, "NEON"),
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_capability_default() {
        assert_eq!(SimdCapability::default(), SimdCapability::None);
    }

    #[test]
    fn test_vector_width_bits_none() {
        assert_eq!(SimdCapability::None.vector_width_bits(), 0);
    }

    #[test]
    fn test_vector_width_bits_sse4() {
        assert_eq!(SimdCapability::Sse4.vector_width_bits(), 128);
    }

    #[test]
    fn test_vector_width_bits_avx2() {
        assert_eq!(SimdCapability::Avx2.vector_width_bits(), 256);
    }

    #[test]
    fn test_vector_width_bits_avx512() {
        assert_eq!(SimdCapability::Avx512.vector_width_bits(), 512);
    }

    #[test]
    fn test_vector_width_bits_neon() {
        assert_eq!(SimdCapability::Neon.vector_width_bits(), 128);
    }

    #[test]
    fn test_simd_capability_display_none() {
        assert_eq!(SimdCapability::None.to_string(), "none");
    }

    #[test]
    fn test_simd_capability_display_sse4() {
        assert_eq!(SimdCapability::Sse4.to_string(), "SSE4");
    }

    #[test]
    fn test_simd_capability_display_avx2() {
        assert_eq!(SimdCapability::Avx2.to_string(), "AVX2");
    }

    #[test]
    fn test_simd_capability_display_avx512() {
        assert_eq!(SimdCapability::Avx512.to_string(), "AVX-512");
    }

    #[test]
    fn test_simd_capability_display_neon() {
        assert_eq!(SimdCapability::Neon.to_string(), "NEON");
    }

    #[test]
    fn test_simd_capability_detect() {
        let detected = SimdCapability::detect();
        // Detection must give the same answer every time within one process.
        assert_eq!(detected, SimdCapability::detect());
        // Whatever is detected must map to one of the known register widths.
        assert!(matches!(detected.vector_width_bits(), 0 | 128 | 256 | 512));
        // x86 hosts can never report the ARM-only NEON level; other hosts can
        // only report NEON or no SIMD at all.
        if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
            assert_ne!(detected, SimdCapability::Neon);
        } else {
            assert!(matches!(
                detected,
                SimdCapability::Neon | SimdCapability::None
            ));
        }
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercise the derived Clone impl, not just Copy
    fn test_simd_capability_clone() {
        let cap = SimdCapability::Avx2;
        let cloned = cap.clone();
        let copied = cap;
        assert_eq!(cap, cloned);
        assert_eq!(cap, copied);
    }

    #[test]
    fn test_simd_capability_eq() {
        assert_eq!(SimdCapability::Avx2, SimdCapability::Avx2);
        assert_ne!(SimdCapability::Avx2, SimdCapability::Avx512);
    }

    #[test]
    fn test_simd_capability_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(SimdCapability::Avx2);
        set.insert(SimdCapability::Avx2);
        assert_eq!(set.len(), 1);
        set.insert(SimdCapability::Avx512);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_simd_capability_serde() {
        let cap = SimdCapability::Avx512;
        let json = serde_json::to_string(&cap).expect("JSON serialization should succeed");
        let deserialized: SimdCapability =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(cap, deserialized);
    }

    #[test]
    fn test_simd_capability_serde_all_variants() {
        let all = [
            SimdCapability::None,
            SimdCapability::Sse4,
            SimdCapability::Avx2,
            SimdCapability::Avx512,
            SimdCapability::Neon,
        ];
        for cap in all {
            let json = serde_json::to_string(&cap).expect("JSON serialization should succeed");
            // Unit variants serialize as their Rust variant name, which is also
            // what the derived Debug prints; pin that wire format.
            assert_eq!(json, format!("\"{cap:?}\""));
            let back: SimdCapability =
                serde_json::from_str(&json).expect("JSON deserialization should succeed");
            assert_eq!(back, cap);
        }
    }

    #[test]
    fn test_simd_capability_debug() {
        assert_eq!(format!("{:?}", SimdCapability::None), "None");
        assert_eq!(format!("{:?}", SimdCapability::Sse4), "Sse4");
        assert_eq!(format!("{:?}", SimdCapability::Avx2), "Avx2");
        assert_eq!(format!("{:?}", SimdCapability::Avx512), "Avx512");
        assert_eq!(format!("{:?}", SimdCapability::Neon), "Neon");
    }
}