Skip to main content

dodecet_encoder/
simd.rs

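//! SIMD-optimized operations for dodecet processing.
//!
//! Provides vectorized implementations using platform-specific SIMD
//! intrinsics (AVX2 on x86_64, NEON on aarch64), plus a scalar fallback
//! that is available on every platform.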
4
5use crate::{Dodecet, Result};
6
7#[cfg(target_arch = "x86_64")]
8use std::arch::x86_64::*;
9
10#[cfg(target_arch = "aarch64")]
11use std::arch::aarch64::*;
12
13/// SIMD-optimized dodecet array operations
14pub struct SimdOps;
15
16impl SimdOps {
17    /// Convert array of dodecets to normalized floats using SIMD
18    #[cfg(target_arch = "x86_64")]
19    #[target_feature(enable = "avx2")]
20    #[inline]
21    pub unsafe fn normalize_avx2(dodecets: &[Dodecet], mut output: &mut [f32]) -> Result<()> {
22        if dodecets.len() != output.len() {
23            return Err(crate::DodecetError::InvalidLength);
24        }
25
26        const MAX_DODECET_F32: f32 = 4095.0;
27
28        let chunks = dodecets.chunks_exact(8);
29        let remainder = chunks.remainder();
30
31        // Process 8 dodecets at a time using AVX2
32        for chunk in chunks {
33            // Load dodecet values
34            let values = [
35                chunk[0].value() as f32,
36                chunk[1].value() as f32,
37                chunk[2].value() as f32,
38                chunk[3].value() as f32,
39                chunk[4].value() as f32,
40                chunk[5].value() as f32,
41                chunk[6].value() as f32,
42                chunk[7].value() as f32,
43            ];
44
45            // Load into AVX2 register
46            let vec = _mm256_loadu_ps(values.as_ptr());
47
48            // Divide by max value
49            let max_vec = _mm256_set1_ps(MAX_DODECET_F32);
50            let normalized = _mm256_div_ps(vec, max_vec);
51
52            // Store result
53            _mm256_storeu_ps(output.as_mut_ptr(), normalized);
54
55            output = &mut output[8..];
56        }
57
58        // Handle remainder
59        for (i, &d) in remainder.iter().enumerate() {
60            output[i] = d.value() as f32 / MAX_DODECET_F32;
61        }
62
63        Ok(())
64    }
65
66    /// Convert array of dodecets to normalized floats using SIMD (ARM NEON)
67    #[cfg(target_arch = "aarch64")]
68    #[target_feature(enable = "neon")]
69    #[inline]
70    pub unsafe fn normalize_neon(dodecets: &[Dodecet], mut output: &mut [f32]) -> Result<()> {
71        if dodecets.len() != output.len() {
72            return Err(crate::DodecetError::InvalidLength);
73        }
74
75        const MAX_DODECET_F32: f32 = 4095.0;
76
77        let chunks = dodecets.chunks_exact(4);
78        let remainder = chunks.remainder();
79
80        // Process 4 dodecets at a time using NEON
81        for chunk in chunks {
82            // Load dodecet values
83            let values = [
84                chunk[0].value() as f32,
85                chunk[1].value() as f32,
86                chunk[2].value() as f32,
87                chunk[3].value() as f32,
88            ];
89
90            // Load into NEON register
91            let vec = vld1q_f32(values.as_ptr());
92
93            // Divide by max value
94            let max_vec = vdupq_n_f32(MAX_DODECET_F32);
95            let normalized = vdivq_f32(vec, max_vec);
96
97            // Store result
98            vst1q_f32(output.as_mut_ptr(), normalized);
99
100            output = &mut output[4..];
101        }
102
103        // Handle remainder
104        for (i, &d) in remainder.iter().enumerate() {
105            output[i] = d.value() as f32 / MAX_DODECET_F32;
106        }
107
108        Ok(())
109    }
110
111    /// Fallback scalar implementation (available on all platforms)
112    #[inline]
113    pub fn normalize_scalar(dodecets: &[Dodecet], output: &mut [f32]) -> Result<()> {
114        if dodecets.len() != output.len() {
115            return Err(crate::DodecetError::InvalidLength);
116        }
117
118        const MAX_DODECET_F32: f32 = 4095.0;
119
120        for (i, &d) in dodecets.iter().enumerate() {
121            output[i] = d.value() as f32 / MAX_DODECET_F32;
122        }
123
124        Ok(())
125    }
126
127    /// Batch normalize with automatic SIMD detection
128    pub fn normalize_auto(dodecets: &[Dodecet], output: &mut [f32]) -> Result<()> {
129        #[cfg(target_arch = "x86_64")]
130        {
131            if is_x86_feature_detected!("avx2") {
132                return unsafe { Self::normalize_avx2(dodecets, output) };
133            }
134        }
135
136        #[cfg(target_arch = "aarch64")]
137        {
138            if std::arch::is_aarch64_feature_detected!("neon") {
139                return unsafe { Self::normalize_neon(dodecets, output) };
140            }
141        }
142
143        // Fallback to scalar implementation on all platforms
144        Self::normalize_scalar(dodecets, output)
145    }
146}
147
148/// SIMD-optimized hex encoding
149pub struct SimdHex;
150
151impl SimdHex {
152    /// Encode dodecets to hex string using SIMD
153    pub fn encode(dodecets: &[Dodecet]) -> String {
154        // Pre-allocate output string
155        let mut output = String::with_capacity(dodecets.len() * 3);
156
157        for d in dodecets {
158            let value = d.value();
159            let hex = [
160                b"0123456789ABCDEF"[(value >> 8) as usize],
161                b"0123456789ABCDEF"[((value >> 4) & 0xF) as usize],
162                b"0123456789ABCDEF"[(value & 0xF) as usize],
163            ];
164
165            output.push_str(unsafe { std::str::from_utf8_unchecked(&hex) });
166        }
167
168        output
169    }
170
171    /// Decode hex string to dodecets using SIMD
172    pub fn decode(hex: &str) -> Result<Vec<Dodecet>> {
173        if hex.len() % 3 != 0 {
174            return Err(crate::DodecetError::InvalidHex);
175        }
176
177        let mut dodecets = Vec::with_capacity(hex.len() / 3);
178
179        for chunk in hex.as_bytes().chunks_exact(3) {
180            let mut value: u16 = 0;
181
182            for &byte in chunk {
183                let digit = if byte.is_ascii_digit() {
184                    byte - b'0'
185                } else if byte.is_ascii_uppercase() {
186                    byte - b'A' + 10
187                } else if byte.is_ascii_lowercase() {
188                    byte - b'a' + 10
189                } else {
190                    return Err(crate::DodecetError::InvalidHex);
191                };
192
193                value = (value << 4) | (digit as u16);
194            }
195
196            dodecets.push(Dodecet::from_hex(value));
197        }
198
199        Ok(dodecets)
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Normalizes `len` sequential dodecets with `normalize_auto` and checks every
    /// output element against the scalar formula.
    fn assert_normalizes(len: u16) {
        let dodecets: Vec<Dodecet> = (0..len).map(Dodecet::from_hex).collect();
        let mut output = vec![0.0f32; dodecets.len()];

        SimdOps::normalize_auto(&dodecets, &mut output).unwrap();

        for (i, &d) in dodecets.iter().enumerate() {
            let expected = d.value() as f32 / 4095.0;
            assert!(
                (output[i] - expected).abs() < 0.0001,
                "len {} index {}: got {}, expected {}",
                len,
                i,
                output[i],
                expected
            );
        }
    }

    #[test]
    fn test_normalize_auto() {
        // 8 and 16 fill whole vectors on both vector widths; the other lengths leave
        // a scalar tail (or are entirely tail), and 0 covers the empty input.
        for &len in &[0u16, 1, 3, 5, 7, 8, 9, 16, 19] {
            assert_normalizes(len);
        }
    }

    #[test]
    fn test_normalize_auto_length_mismatch() {
        let dodecets: Vec<Dodecet> = (0..4).map(Dodecet::from_hex).collect();
        let mut output = vec![0.0f32; 3];

        assert!(SimdOps::normalize_auto(&dodecets, &mut output).is_err());
        assert!(SimdOps::normalize_scalar(&dodecets, &mut output).is_err());
    }

    #[test]
    fn test_simd_hex_encode() {
        let dodecets: Vec<Dodecet> = vec![
            Dodecet::from_hex(0xABC),
            Dodecet::from_hex(0x123),
            Dodecet::from_hex(0x456),
        ];

        let hex = SimdHex::encode(&dodecets);
        assert_eq!(hex, "ABC123456");
    }

    #[test]
    fn test_simd_hex_decode() {
        let hex = "ABC123456";
        let dodecets = SimdHex::decode(hex).unwrap();

        assert_eq!(dodecets.len(), 3);
        assert_eq!(dodecets[0].value(), 0xABC);
        assert_eq!(dodecets[1].value(), 0x123);
        assert_eq!(dodecets[2].value(), 0x456);
    }

    #[test]
    fn test_simd_hex_decode_lowercase() {
        let dodecets = SimdHex::decode("abc").unwrap();

        assert_eq!(dodecets.len(), 1);
        assert_eq!(dodecets[0].value(), 0xABC);
    }

    #[test]
    fn test_simd_hex_decode_rejects_malformed_input() {
        // Length not a multiple of three.
        assert!(SimdHex::decode("AB").is_err());
        // A byte that is neither a digit nor a letter.
        assert!(SimdHex::decode("12!").is_err());
    }

    #[test]
    fn test_simd_hex_roundtrip() {
        // Cover the whole 12-bit value range rather than a small prefix of it.
        let original: Vec<Dodecet> = (0..=0xFFF).map(Dodecet::from_hex).collect();
        let hex = SimdHex::encode(&original);
        let decoded = SimdHex::decode(&hex).unwrap();

        assert_eq!(hex.len(), original.len() * 3);
        assert_eq!(original.len(), decoded.len());
        for (o, d) in original.iter().zip(decoded.iter()) {
            assert_eq!(o.value(), d.value());
        }
    }
}