1use crate::{Dodecet, Result};
6
7#[cfg(target_arch = "x86_64")]
8use std::arch::x86_64::*;
9
10#[cfg(target_arch = "aarch64")]
11use std::arch::aarch64::*;
12
13pub struct SimdOps;
15
16impl SimdOps {
17 #[cfg(target_arch = "x86_64")]
19 #[target_feature(enable = "avx2")]
20 #[inline]
21 pub unsafe fn normalize_avx2(dodecets: &[Dodecet], mut output: &mut [f32]) -> Result<()> {
22 if dodecets.len() != output.len() {
23 return Err(crate::DodecetError::InvalidLength);
24 }
25
26 const MAX_DODECET_F32: f32 = 4095.0;
27
28 let chunks = dodecets.chunks_exact(8);
29 let remainder = chunks.remainder();
30
31 for chunk in chunks {
33 let values = [
35 chunk[0].value() as f32,
36 chunk[1].value() as f32,
37 chunk[2].value() as f32,
38 chunk[3].value() as f32,
39 chunk[4].value() as f32,
40 chunk[5].value() as f32,
41 chunk[6].value() as f32,
42 chunk[7].value() as f32,
43 ];
44
45 let vec = _mm256_loadu_ps(values.as_ptr());
47
48 let max_vec = _mm256_set1_ps(MAX_DODECET_F32);
50 let normalized = _mm256_div_ps(vec, max_vec);
51
52 _mm256_storeu_ps(output.as_mut_ptr(), normalized);
54
55 output = &mut output[8..];
56 }
57
58 for (i, &d) in remainder.iter().enumerate() {
60 output[i] = d.value() as f32 / MAX_DODECET_F32;
61 }
62
63 Ok(())
64 }
65
66 #[cfg(target_arch = "aarch64")]
68 #[target_feature(enable = "neon")]
69 #[inline]
70 pub unsafe fn normalize_neon(dodecets: &[Dodecet], mut output: &mut [f32]) -> Result<()> {
71 if dodecets.len() != output.len() {
72 return Err(crate::DodecetError::InvalidLength);
73 }
74
75 const MAX_DODECET_F32: f32 = 4095.0;
76
77 let chunks = dodecets.chunks_exact(4);
78 let remainder = chunks.remainder();
79
80 for chunk in chunks {
82 let values = [
84 chunk[0].value() as f32,
85 chunk[1].value() as f32,
86 chunk[2].value() as f32,
87 chunk[3].value() as f32,
88 ];
89
90 let vec = vld1q_f32(values.as_ptr());
92
93 let max_vec = vdupq_n_f32(MAX_DODECET_F32);
95 let normalized = vdivq_f32(vec, max_vec);
96
97 vst1q_f32(output.as_mut_ptr(), normalized);
99
100 output = &mut output[4..];
101 }
102
103 for (i, &d) in remainder.iter().enumerate() {
105 output[i] = d.value() as f32 / MAX_DODECET_F32;
106 }
107
108 Ok(())
109 }
110
111 #[inline]
113 pub fn normalize_scalar(dodecets: &[Dodecet], output: &mut [f32]) -> Result<()> {
114 if dodecets.len() != output.len() {
115 return Err(crate::DodecetError::InvalidLength);
116 }
117
118 const MAX_DODECET_F32: f32 = 4095.0;
119
120 for (i, &d) in dodecets.iter().enumerate() {
121 output[i] = d.value() as f32 / MAX_DODECET_F32;
122 }
123
124 Ok(())
125 }
126
127 pub fn normalize_auto(dodecets: &[Dodecet], output: &mut [f32]) -> Result<()> {
129 #[cfg(target_arch = "x86_64")]
130 {
131 if is_x86_feature_detected!("avx2") {
132 return unsafe { Self::normalize_avx2(dodecets, output) };
133 }
134 }
135
136 #[cfg(target_arch = "aarch64")]
137 {
138 if std::arch::is_aarch64_feature_detected!("neon") {
139 return unsafe { Self::normalize_neon(dodecets, output) };
140 }
141 }
142
143 Self::normalize_scalar(dodecets, output)
145 }
146}
147
148pub struct SimdHex;
150
151impl SimdHex {
152 pub fn encode(dodecets: &[Dodecet]) -> String {
154 let mut output = String::with_capacity(dodecets.len() * 3);
156
157 for d in dodecets {
158 let value = d.value();
159 let hex = [
160 b"0123456789ABCDEF"[(value >> 8) as usize],
161 b"0123456789ABCDEF"[((value >> 4) & 0xF) as usize],
162 b"0123456789ABCDEF"[(value & 0xF) as usize],
163 ];
164
165 output.push_str(unsafe { std::str::from_utf8_unchecked(&hex) });
166 }
167
168 output
169 }
170
171 pub fn decode(hex: &str) -> Result<Vec<Dodecet>> {
173 if hex.len() % 3 != 0 {
174 return Err(crate::DodecetError::InvalidHex);
175 }
176
177 let mut dodecets = Vec::with_capacity(hex.len() / 3);
178
179 for chunk in hex.as_bytes().chunks_exact(3) {
180 let mut value: u16 = 0;
181
182 for &byte in chunk {
183 let digit = if byte.is_ascii_digit() {
184 byte - b'0'
185 } else if byte.is_ascii_uppercase() {
186 byte - b'A' + 10
187 } else if byte.is_ascii_lowercase() {
188 byte - b'a' + 10
189 } else {
190 return Err(crate::DodecetError::InvalidHex);
191 };
192
193 value = (value << 4) | (digit as u16);
194 }
195
196 dodecets.push(Dodecet::from_hex(value));
197 }
198
199 Ok(dodecets)
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every slot of `output` holds the matching dodecet's value
    /// divided by 4095, within a small tolerance.
    fn assert_normalized(dodecets: &[Dodecet], output: &[f32]) {
        assert_eq!(dodecets.len(), output.len());
        for (d, &actual) in dodecets.iter().zip(output) {
            let expected = d.value() as f32 / 4095.0;
            assert!((actual - expected).abs() < 0.0001);
        }
    }

    #[test]
    fn test_normalize_auto() {
        let dodecets: Vec<Dodecet> = (0..16).map(Dodecet::from_hex).collect();
        let mut output = vec![0.0f32; 16];

        SimdOps::normalize_auto(&dodecets, &mut output).unwrap();

        assert_normalized(&dodecets, &output);
    }

    /// 19 is not a multiple of 4 or 8, so both the vector loop and the scalar
    /// remainder loop run on either SIMD backend.
    #[test]
    fn test_normalize_auto_with_remainder() {
        let dodecets: Vec<Dodecet> = (0..19u16).map(|i| Dodecet::from_hex(i * 215)).collect();
        let mut output = vec![0.0f32; 19];

        SimdOps::normalize_auto(&dodecets, &mut output).unwrap();

        assert_normalized(&dodecets, &output);
    }

    #[test]
    fn test_normalize_scalar_matches_auto() {
        let dodecets: Vec<Dodecet> = (0..19u16).map(|i| Dodecet::from_hex(i * 215)).collect();
        let mut auto = vec![0.0f32; 19];
        let mut scalar = vec![0.0f32; 19];

        SimdOps::normalize_auto(&dodecets, &mut auto).unwrap();
        SimdOps::normalize_scalar(&dodecets, &mut scalar).unwrap();

        assert_normalized(&dodecets, &scalar);
        for (a, s) in auto.iter().zip(scalar.iter()) {
            assert!((a - s).abs() < 0.0001);
        }
    }

    #[test]
    fn test_normalize_empty() {
        let mut output: Vec<f32> = Vec::new();
        SimdOps::normalize_auto(&[], &mut output).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_normalize_length_mismatch() {
        let dodecets: Vec<Dodecet> = (0..8).map(Dodecet::from_hex).collect();
        let mut output = vec![0.0f32; 7];

        assert!(SimdOps::normalize_auto(&dodecets, &mut output).is_err());
        assert!(SimdOps::normalize_scalar(&dodecets, &mut output).is_err());
    }

    #[test]
    fn test_simd_hex_encode() {
        let dodecets: Vec<Dodecet> = vec![
            Dodecet::from_hex(0xABC),
            Dodecet::from_hex(0x123),
            Dodecet::from_hex(0x456),
        ];

        let hex = SimdHex::encode(&dodecets);
        assert_eq!(hex, "ABC123456");
    }

    #[test]
    fn test_simd_hex_decode() {
        let hex = "ABC123456";
        let dodecets = SimdHex::decode(hex).unwrap();

        assert_eq!(dodecets.len(), 3);
        assert_eq!(dodecets[0].value(), 0xABC);
        assert_eq!(dodecets[1].value(), 0x123);
        assert_eq!(dodecets[2].value(), 0x456);
    }

    #[test]
    fn test_simd_hex_decode_lowercase() {
        let dodecets = SimdHex::decode("abc00f").unwrap();

        assert_eq!(dodecets.len(), 2);
        assert_eq!(dodecets[0].value(), 0xABC);
        assert_eq!(dodecets[1].value(), 0x00F);
    }

    #[test]
    fn test_simd_hex_decode_rejects_malformed_input() {
        // Length not a multiple of three.
        assert!(SimdHex::decode("AB").is_err());
        assert!(SimdHex::decode("ABCD").is_err());
        // Bytes that are not alphanumeric at all.
        assert!(SimdHex::decode("12!").is_err());
        assert!(SimdHex::decode("1 2").is_err());
    }

    #[test]
    fn test_simd_hex_empty() {
        assert_eq!(SimdHex::encode(&[]), "");
        assert!(SimdHex::decode("").unwrap().is_empty());
    }

    #[test]
    fn test_simd_hex_roundtrip() {
        let original: Vec<Dodecet> = (0..100).map(|i| Dodecet::from_hex(i % 4096)).collect();
        let hex = SimdHex::encode(&original);
        let decoded = SimdHex::decode(&hex).unwrap();

        assert_eq!(original.len(), decoded.len());
        for (o, d) in original.iter().zip(decoded.iter()) {
            assert_eq!(o.value(), d.value());
        }
    }
}