// crous_simd/lib.rs

//! # crous-simd
//!
//! Optional SIMD-accelerated routines for Crous encoding/decoding.
//!
//! This crate provides optimized implementations of performance-critical
//! operations. On aarch64 (Apple Silicon, etc.) it uses NEON intrinsics.
//! No x86_64 SIMD path exists in this file yet: x86_64 and every other
//! architecture use the scalar fallbacks that all functions provide.
//!
//! # Feature flags
//! - `simd-varint` — enable the SIMD-accelerated varint boundary pre-scan
//!   (only takes effect on aarch64).
//!   Idea: https://github.com/as-com/varint-simd
//!
//! # Provided routines
//! - `batch_decode_varints` — decode multiple LEB128 varints sequentially
//! - `batch_decode_total_consumed` — total bytes consumed by such a batch
//! - `batch_decode_varints_simd` — SIMD pre-scan variant (feature `simd-varint`)
//! - `find_byte` — locate first occurrence of a byte (NEON on aarch64)
//! - `count_byte` — count occurrences of a byte (NEON on aarch64)
//! - `find_non_ascii` — locate first non-ASCII byte (for fast UTF-8 pre-scan)

21/// Batch-decode multiple varints from a contiguous buffer (scalar path).
22///
23/// Returns a vector of `(value, bytes_consumed)` pairs.
24pub fn batch_decode_varints(data: &[u8], count: usize) -> Vec<(u64, usize)> {
25    let mut results = Vec::with_capacity(count);
26    let mut offset = 0;
27    for _ in 0..count {
28        if offset >= data.len() {
29            break;
30        }
31        match crous_core::varint::decode_varint(data, offset) {
32            Ok((val, consumed)) => {
33                results.push((val, consumed));
34                offset += consumed;
35            }
36            Err(_) => break,
37        }
38    }
39    results
40}
41
42/// Total bytes consumed by a batch decode.
43pub fn batch_decode_total_consumed(data: &[u8], count: usize) -> usize {
44    let mut offset = 0;
45    for _ in 0..count {
46        if offset >= data.len() {
47            break;
48        }
49        match crous_core::varint::decode_varint(data, offset) {
50            Ok((_val, consumed)) => offset += consumed,
51            Err(_) => break,
52        }
53    }
54    offset
55}
56
// ── SIMD varint boundary pre-scan (feature = "simd-varint") ──────────
//
// The key insight from varint-simd: we can use SIMD to scan for the
// continuation bit (0x80) across 16 bytes at once to quickly find
// varint termination bytes, then extract values with scalar code.
// Citation: https://github.com/as-com/varint-simd

#[cfg(all(feature = "simd-varint", target_arch = "aarch64"))]
mod simd_varint_neon {
    use std::arch::aarch64::*;

    /// Longest varint, in bytes, that the pre-scan accepts.
    const MAX_VARINT_LEN: usize = 10;

    /// SIMD pre-scan: find the first byte without the high bit set (i.e. a
    /// varint terminator) at or after `offset`.
    ///
    /// Returns the length of the varint starting at `offset`
    /// (`1..=MAX_VARINT_LEN`), or `None` when
    /// - `offset` is at or past the end of `data`,
    /// - no terminator is found in the inspected bytes, or
    /// - the terminator lies beyond `MAX_VARINT_LEN` bytes (over-long varint).
    ///
    /// When fewer than 16 bytes remain, the scan is done with scalar code.
    ///
    /// # Safety
    /// Requires NEON, which is always available on aarch64. Bounds are checked
    /// internally, so an out-of-range `offset` yields `None` instead of being a
    /// caller obligation.
    #[inline]
    pub(crate) unsafe fn varint_len_neon(data: &[u8], offset: usize) -> Option<usize> {
        // `get` rejects `offset > data.len()`; a plain `len() - offset` would
        // underflow there and, in release builds, lead to an out-of-bounds load.
        let window = data.get(offset..)?;
        if window.len() < 16 {
            // Not enough bytes for a full vector load (this also covers the
            // empty window): scalar fallback for the tail.
            return scalar_varint_len(data, offset);
        }

        // SAFETY: `window` holds at least 16 readable bytes, so the 16-byte
        // load stays inside the slice.
        let chunk = unsafe { vld1q_u8(window.as_ptr()) };
        // Shift bit 7 (the continuation bit) down into bit 0 of every lane.
        let high_bits = unsafe { vshrq_n_u8::<7>(chunk) };
        // Lanes whose continuation bit is clear become 0xFF (terminators).
        let is_terminator = unsafe { vceqq_u8(high_bits, vdupq_n_u8(0)) };
        if unsafe { vmaxvq_u8(is_terminator) } == 0 {
            // All 16 bytes carry the continuation bit: malformed.
            return None;
        }

        let mut lanes = [0u8; 16];
        // SAFETY: `lanes` is exactly 16 writable bytes.
        unsafe { vst1q_u8(lanes.as_mut_ptr(), is_terminator) };
        lanes
            .iter()
            .position(|&lane| lane != 0)
            .map(|idx| idx + 1)
            // A terminator past the tenth byte means the varint is over-long.
            .filter(|&len| len <= MAX_VARINT_LEN)
    }

    /// Scalar counterpart of `varint_len_neon`: inspects at most
    /// `MAX_VARINT_LEN` bytes starting at `offset` and returns the varint
    /// length, or `None` if no terminator is found or `offset` is out of range.
    fn scalar_varint_len(data: &[u8], offset: usize) -> Option<usize> {
        data.get(offset..)?
            .iter()
            .take(MAX_VARINT_LEN)
            .position(|&byte| byte & 0x80 == 0)
            .map(|idx| idx + 1)
    }
}

122/// Batch-decode varints using SIMD pre-scan to determine boundaries first.
123///
124/// This amortizes branch misprediction by scanning continuation bits in bulk.
125/// Falls back to `batch_decode_varints` when the `simd-varint` feature is disabled
126/// or the platform is unsupported.
127///
128/// Citation: SIMD varint idea — https://github.com/as-com/varint-simd
129pub fn batch_decode_varints_simd(data: &[u8], count: usize) -> Vec<(u64, usize)> {
130    #[cfg(all(feature = "simd-varint", target_arch = "aarch64"))]
131    {
132        let mut results = Vec::with_capacity(count);
133        let mut offset = 0;
134        for _ in 0..count {
135            if offset >= data.len() {
136                break;
137            }
138            // Use SIMD to find varint length, then decode scalar
139            let vlen = unsafe { simd_varint_neon::varint_len_neon(data, offset) };
140            match vlen {
141                Some(len) => {
142                    // Fast scalar decode knowing the exact length
143                    match crous_core::varint::decode_varint(data, offset) {
144                        Ok((val, consumed)) => {
145                            debug_assert_eq!(consumed, len);
146                            results.push((val, consumed));
147                            offset += consumed;
148                        }
149                        Err(_) => break,
150                    }
151                }
152                None => {
153                    // Fallback to scalar
154                    match crous_core::varint::decode_varint(data, offset) {
155                        Ok((val, consumed)) => {
156                            results.push((val, consumed));
157                            offset += consumed;
158                        }
159                        Err(_) => break,
160                    }
161                }
162            }
163        }
164        results
165    }
166    #[cfg(not(all(feature = "simd-varint", target_arch = "aarch64")))]
167    {
168        batch_decode_varints(data, count)
169    }
170}
171
// ── SIMD byte scanning (aarch64 NEON) ────────────────────────────────

#[cfg(target_arch = "aarch64")]
mod neon {
    use std::arch::aarch64::*;

    /// Number of bytes processed per 128-bit NEON vector.
    const LANES: usize = 16;

    /// Returns the index of the first non-zero lane of `mask`, or `None` when
    /// every lane is zero.
    ///
    /// Shared by the scanning routines below, whose comparison intrinsics set
    /// matching lanes to `0xFF` and all others to `0`.
    #[inline]
    fn first_set_lane(mask: uint8x16_t) -> Option<usize> {
        // Cheap horizontal max first: most chunks contain no match at all.
        // SAFETY: NEON is part of the aarch64 baseline; the intrinsic only
        // reads the register value.
        if unsafe { vmaxvq_u8(mask) } == 0 {
            return None;
        }
        let mut lanes = [0u8; LANES];
        // SAFETY: `lanes` is exactly 16 writable bytes.
        unsafe { vst1q_u8(lanes.as_mut_ptr(), mask) };
        lanes.iter().position(|&lane| lane != 0)
    }

    /// Find the first occurrence of `needle` in `data` using NEON.
    ///
    /// Full 16-byte chunks are compared in one vector operation each; the
    /// remaining tail (fewer than 16 bytes) is scanned with scalar code.
    ///
    /// # Safety
    /// Caller must ensure NEON is available (always true on aarch64).
    #[inline]
    pub(crate) unsafe fn find_byte_neon(data: &[u8], needle: u8) -> Option<usize> {
        let needle_vec = unsafe { vdupq_n_u8(needle) };
        let mut chunks = data.chunks_exact(LANES);

        for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
            // SAFETY: `chunks_exact` yields slices of exactly `LANES` bytes, so
            // the 16-byte load stays in bounds.
            let matches = unsafe { vceqq_u8(vld1q_u8(chunk.as_ptr()), needle_vec) };
            if let Some(lane) = first_set_lane(matches) {
                return Some(chunk_idx * LANES + lane);
            }
        }

        // Scalar tail.
        let tail = chunks.remainder();
        let tail_start = data.len() - tail.len();
        tail.iter()
            .position(|&byte| byte == needle)
            .map(|pos| tail_start + pos)
    }

    /// Count occurrences of `needle` in `data` using NEON.
    ///
    /// # Safety
    /// Caller must ensure NEON is available.
    #[inline]
    pub(crate) unsafe fn count_byte_neon(data: &[u8], needle: u8) -> usize {
        let needle_vec = unsafe { vdupq_n_u8(needle) };
        let mut chunks = data.chunks_exact(LANES);
        let mut total: usize = 0;

        for chunk in chunks.by_ref() {
            // SAFETY: `chunks_exact` yields slices of exactly `LANES` bytes.
            let matches = unsafe { vceqq_u8(vld1q_u8(chunk.as_ptr()), needle_vec) };
            // Every matching lane holds 0xFF (255). The widening horizontal sum
            // is therefore 255 * matches (at most 16 * 255 = 4080, which fits
            // the u16 result), and dividing by 255 recovers the match count.
            let lane_sum = usize::from(unsafe { vaddlvq_u8(matches) });
            total += lane_sum / 255;
        }

        // Scalar tail.
        total + chunks.remainder().iter().filter(|&&byte| byte == needle).count()
    }

    /// Find the first non-ASCII byte (byte >= 0x80) using NEON.
    ///
    /// Returns `None` if all bytes are ASCII.
    ///
    /// # Safety
    /// Caller must ensure NEON is available.
    #[inline]
    pub(crate) unsafe fn find_non_ascii_neon(data: &[u8]) -> Option<usize> {
        let threshold = unsafe { vdupq_n_u8(0x80) };
        let mut chunks = data.chunks_exact(LANES);

        for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
            // Unsigned `>= 0x80` is true exactly for lanes with the high bit set.
            // SAFETY: `chunks_exact` yields slices of exactly `LANES` bytes.
            let high_bits = unsafe { vcgeq_u8(vld1q_u8(chunk.as_ptr()), threshold) };
            if let Some(lane) = first_set_lane(high_bits) {
                return Some(chunk_idx * LANES + lane);
            }
        }

        // Scalar tail.
        let tail = chunks.remainder();
        let tail_start = data.len() - tail.len();
        tail.iter()
            .position(|&byte| byte >= 0x80)
            .map(|pos| tail_start + pos)
    }
}

// ── Public API ───────────────────────────────────────────────────────

/// Scan a byte slice for a specific byte using SIMD where available.
///
/// On aarch64, uses NEON intrinsics for 16-byte-at-a-time scanning.
/// Falls back to a scalar scan on other architectures.
///
/// Returns the index of the first byte equal to `needle`, or `None` if
/// `data` does not contain it (including when `data` is empty).
#[inline]
#[must_use]
pub fn find_byte(data: &[u8], needle: u8) -> Option<usize> {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64, which is the only
        // requirement `find_byte_neon` places on its caller.
        unsafe { neon::find_byte_neon(data, needle) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter().position(|&byte| byte == needle)
    }
}

/// Count the number of occurrences of `needle` in `data`.
///
/// On aarch64, uses NEON intrinsics for fast counting.
/// Falls back to a scalar count on other architectures.
///
/// Returns `0` for an empty slice.
#[inline]
#[must_use]
pub fn count_byte(data: &[u8], needle: u8) -> usize {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64, which is the only
        // requirement `count_byte_neon` places on its caller.
        unsafe { neon::count_byte_neon(data, needle) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter().filter(|&&byte| byte == needle).count()
    }
}

/// Find the first byte with the high bit set (non-ASCII).
///
/// This is useful for fast UTF-8 pre-scanning: if this returns `None`,
/// the entire slice is pure ASCII and valid UTF-8.
///
/// On aarch64, uses NEON intrinsics; other architectures use a scalar scan.
/// Returns the index of the first byte `>= 0x80`, or `None` if there is none.
#[inline]
#[must_use]
pub fn find_non_ascii(data: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64, which is the only
        // requirement `find_non_ascii_neon` places on its caller.
        unsafe { neon::find_non_ascii_neon(data) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter().position(|&byte| byte >= 0x80)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` back-to-back with the crous-core varint encoder.
    fn encode_all(values: &[u64]) -> Vec<u8> {
        let mut data = Vec::new();
        for &v in values {
            crous_core::varint::encode_varint_vec(v, &mut data);
        }
        data
    }

    #[test]
    fn batch_decode_basic() {
        let data = encode_all(&[0, 1, 127, 128, 300]);
        let results = batch_decode_varints(&data, 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert_eq!(results[2].0, 127);
        assert_eq!(results[3].0, 128);
        assert_eq!(results[4].0, 300);
    }

    #[test]
    fn batch_decode_stops_at_end_of_buffer() {
        // Asking for more varints than the buffer holds returns what is there.
        let data = encode_all(&[0, 1, 127, 128, 300]);
        assert_eq!(batch_decode_varints(&data, 100).len(), 5);
        assert!(batch_decode_varints(&[], 4).is_empty());
        assert!(batch_decode_varints(&data, 0).is_empty());
    }

    #[test]
    fn batch_decode_total_consumed_matches_decode() {
        let values = [0u64, 1, 127, 128, 300, 16384, u64::MAX];
        let data = encode_all(&values);
        assert_eq!(batch_decode_total_consumed(&data, values.len()), data.len());
        assert_eq!(batch_decode_total_consumed(&data, 0), 0);
        assert_eq!(batch_decode_total_consumed(&[], 3), 0);

        // A partial batch consumes exactly the bytes of the decoded prefix.
        let decoded = batch_decode_varints(&data, 4);
        let prefix: usize = decoded.iter().map(|&(_, consumed)| consumed).sum();
        assert_eq!(batch_decode_total_consumed(&data, 4), prefix);
    }

    #[test]
    fn batch_decode_simd_matches_scalar() {
        let values = [0u64, 1, 42, 127, 128, 255, 300, 16384, u64::MAX];
        let data = encode_all(&values);
        let scalar = batch_decode_varints(&data, values.len());
        let simd = batch_decode_varints_simd(&data, values.len());
        assert_eq!(scalar.len(), simd.len());
        for (s, d) in scalar.iter().zip(simd.iter()) {
            assert_eq!(s.0, d.0, "value mismatch");
            assert_eq!(s.1, d.1, "consumed mismatch");
        }
    }

    #[test]
    fn find_byte_basic() {
        assert_eq!(find_byte(b"hello", b'l'), Some(2));
        assert_eq!(find_byte(b"hello", b'z'), None);
        assert_eq!(find_byte(&[], 0), None);
    }

    #[test]
    fn find_byte_long() {
        // Test with data longer than 16 bytes to exercise SIMD path
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        assert_eq!(find_byte(&data, 0), Some(0));
        assert_eq!(find_byte(&data, 42), Some(42));
        assert_eq!(find_byte(&data, 255), Some(255));

        let zeros = vec![0u8; 100];
        assert_eq!(find_byte(&zeros, 1), None);
    }

    #[test]
    fn find_byte_chunk_boundaries() {
        // 40 bytes = two full 16-byte chunks plus an 8-byte scalar tail.
        for pos in [0usize, 15, 16, 31, 32, 39] {
            let mut data = vec![0u8; 40];
            data[pos] = 1;
            assert_eq!(find_byte(&data, 1), Some(pos));
        }
    }

    #[test]
    fn count_byte_basic() {
        assert_eq!(count_byte(b"hello", b'l'), 2);
        assert_eq!(count_byte(b"hello", b'z'), 0);
        assert_eq!(count_byte(b"hello", b'o'), 1);
        assert_eq!(count_byte(&[], b'o'), 0);
    }

    #[test]
    fn count_byte_long() {
        let data = vec![0xABu8; 200];
        assert_eq!(count_byte(&data, 0xAB), 200);
        assert_eq!(count_byte(&data, 0x00), 0);
    }

    #[test]
    fn count_byte_mixed() {
        // Every third byte is 7: indices 0, 3, ..., 99 → 34 matches.
        let data: Vec<u8> = (0..100).map(|i| if i % 3 == 0 { 7 } else { 0 }).collect();
        assert_eq!(count_byte(&data, 7), 34);
        assert_eq!(count_byte(&data, 0), 66);
    }

    #[test]
    fn find_non_ascii_basic() {
        assert_eq!(find_non_ascii(b"hello"), None);
        assert_eq!(find_non_ascii(b"hello\x80"), Some(5));
        assert_eq!(find_non_ascii(b"\xff"), Some(0));
        assert_eq!(find_non_ascii(&[]), None);
    }

    #[test]
    fn find_non_ascii_long() {
        let mut data = vec![b'a'; 100];
        assert_eq!(find_non_ascii(&data), None);
        data[50] = 0x80;
        assert_eq!(find_non_ascii(&data), Some(50));
    }

    #[test]
    fn find_non_ascii_in_tail() {
        // 35 bytes = two full chunks plus a 3-byte tail; hit the last byte.
        let mut data = vec![b'a'; 35];
        data[34] = 0xC3;
        assert_eq!(find_non_ascii(&data), Some(34));
    }
}
424}