1pub fn batch_decode_varints(data: &[u8], count: usize) -> Vec<(u64, usize)> {
25 let mut results = Vec::with_capacity(count);
26 let mut offset = 0;
27 for _ in 0..count {
28 if offset >= data.len() {
29 break;
30 }
31 match crous_core::varint::decode_varint(data, offset) {
32 Ok((val, consumed)) => {
33 results.push((val, consumed));
34 offset += consumed;
35 }
36 Err(_) => break,
37 }
38 }
39 results
40}
41
42pub fn batch_decode_total_consumed(data: &[u8], count: usize) -> usize {
44 let mut offset = 0;
45 for _ in 0..count {
46 if offset >= data.len() {
47 break;
48 }
49 match crous_core::varint::decode_varint(data, offset) {
50 Ok((_val, consumed)) => offset += consumed,
51 Err(_) => break,
52 }
53 }
54 offset
55}
56
#[cfg(all(feature = "simd-varint", target_arch = "aarch64"))]
mod simd_varint_neon {
    use std::arch::aarch64::*;

    /// Longest LEB128 encoding of a `u64`: ceil(64 / 7) = 10 bytes.
    const MAX_VARINT_LEN: usize = 10;

    /// Number of bytes covered by one NEON `uint8x16_t` load.
    const LANES: usize = 16;

    /// Returns the byte length of the varint that starts at `data[offset]`.
    ///
    /// The length is the 1-based position of the first byte whose high bit is
    /// clear. Returns `None` in two cases:
    /// - `offset` is at or beyond the end of `data`;
    /// - no such byte appears within the first `MAX_VARINT_LEN` bytes.
    ///
    /// When at least 16 bytes remain, the terminator search is a single 16-byte
    /// NEON compare. Shorter tails use `scalar_varint_len`.
    ///
    /// # Safety
    ///
    /// Any `data`/`offset` pair is accepted. The vector load is issued only after
    /// checking that 16 bytes are readable starting at `offset`. The function
    /// relies on NEON, which standard `aarch64` Rust targets enable by default.
    /// It stays `unsafe` so existing call sites are unaffected.
    #[inline]
    pub(crate) unsafe fn varint_len_neon(data: &[u8], offset: usize) -> Option<usize> {
        // `get` rejects `offset > data.len()` instead of underflowing the
        // `data.len() - offset` subtraction; an underflow would otherwise lead
        // to an out-of-bounds vector load in release builds.
        let tail = data.get(offset..)?;
        if tail.is_empty() {
            return None;
        }
        if tail.len() < LANES {
            return scalar_varint_len(data, offset);
        }

        // SAFETY: `tail` has at least `LANES` (16) readable bytes.
        let chunk = unsafe { vld1q_u8(tail.as_ptr()) };
        // A byte terminates the varint when its high bit is clear. Such lanes
        // become 0xFF, all others 0x00.
        let is_terminator = unsafe { vceqq_u8(vshrq_n_u8::<7>(chunk), vdupq_n_u8(0)) };
        if unsafe { vmaxvq_u8(is_terminator) } == 0 {
            return None;
        }
        let mut mask = [0u8; LANES];
        // SAFETY: `mask` provides exactly 16 writable bytes.
        unsafe { vst1q_u8(mask.as_mut_ptr(), is_terminator) };
        // Only the first MAX_VARINT_LEN lanes can hold a valid terminator. A
        // first terminator further out means the encoding is too long -> None.
        mask[..MAX_VARINT_LEN]
            .iter()
            .position(|&lane| lane != 0)
            .map(|idx| idx + 1)
    }

    /// Scalar fallback for `varint_len_neon` when fewer than 16 bytes remain.
    ///
    /// Scans at most `MAX_VARINT_LEN` bytes from `offset` for the first byte
    /// with its high bit clear. Returns `None` if `offset` is past the end of
    /// `data` or if no terminator is found.
    fn scalar_varint_len(data: &[u8], offset: usize) -> Option<usize> {
        data.get(offset..)?
            .iter()
            .take(MAX_VARINT_LEN)
            .position(|&byte| byte & 0x80 == 0)
            .map(|idx| idx + 1)
    }
}
121
122pub fn batch_decode_varints_simd(data: &[u8], count: usize) -> Vec<(u64, usize)> {
130 #[cfg(all(feature = "simd-varint", target_arch = "aarch64"))]
131 {
132 let mut results = Vec::with_capacity(count);
133 let mut offset = 0;
134 for _ in 0..count {
135 if offset >= data.len() {
136 break;
137 }
138 let vlen = unsafe { simd_varint_neon::varint_len_neon(data, offset) };
140 match vlen {
141 Some(len) => {
142 match crous_core::varint::decode_varint(data, offset) {
144 Ok((val, consumed)) => {
145 debug_assert_eq!(consumed, len);
146 results.push((val, consumed));
147 offset += consumed;
148 }
149 Err(_) => break,
150 }
151 }
152 None => {
153 match crous_core::varint::decode_varint(data, offset) {
155 Ok((val, consumed)) => {
156 results.push((val, consumed));
157 offset += consumed;
158 }
159 Err(_) => break,
160 }
161 }
162 }
163 }
164 results
165 }
166 #[cfg(not(all(feature = "simd-varint", target_arch = "aarch64")))]
167 {
168 batch_decode_varints(data, count)
169 }
170}
171
#[cfg(target_arch = "aarch64")]
mod neon {
    use std::arch::aarch64::*;

    /// Bytes per NEON `uint8x16_t` register.
    const LANES: usize = 16;

    /// Returns the index of the first non-zero lane of `mask`, or `None` if
    /// every lane is zero.
    ///
    /// `mask` is expected to be a NEON comparison result (lanes 0x00 or 0xFF).
    /// Any non-zero lane counts as set.
    ///
    /// # Safety
    ///
    /// No obligations beyond NEON availability: the vector is only spilled into
    /// a local 16-byte array.
    #[inline]
    unsafe fn first_set_lane(mask: uint8x16_t) -> Option<usize> {
        // The horizontal max is a cheap "any lane set?" test before spilling.
        if unsafe { vmaxvq_u8(mask) } == 0 {
            return None;
        }
        let mut lanes = [0u8; LANES];
        // SAFETY: `lanes` is exactly 16 writable bytes.
        unsafe { vst1q_u8(lanes.as_mut_ptr(), mask) };
        lanes.iter().position(|&lane| lane != 0)
    }

    /// Returns the index of the first byte of `data` equal to `needle`.
    ///
    /// Full 16-byte blocks are compared with NEON. The trailing `len % 16`
    /// bytes are scanned one at a time.
    ///
    /// # Safety
    ///
    /// Sound for any slice: vector loads are issued only on 16-byte blocks
    /// produced by `chunks_exact`, which lie entirely inside `data`.
    #[inline]
    pub(crate) unsafe fn find_byte_neon(data: &[u8], needle: u8) -> Option<usize> {
        let splat = unsafe { vdupq_n_u8(needle) };
        let blocks = data.chunks_exact(LANES);
        let tail = blocks.remainder();

        for (block_idx, block) in blocks.enumerate() {
            // SAFETY: `chunks_exact` guarantees `block` holds exactly 16 bytes.
            let hits = unsafe { vceqq_u8(vld1q_u8(block.as_ptr()), splat) };
            if let Some(lane) = unsafe { first_set_lane(hits) } {
                return Some(block_idx * LANES + lane);
            }
        }

        let tail_start = data.len() - tail.len();
        tail.iter()
            .position(|&byte| byte == needle)
            .map(|k| tail_start + k)
    }

    /// Counts the bytes of `data` equal to `needle`.
    ///
    /// Full 16-byte blocks are compared with NEON. The trailing `len % 16`
    /// bytes are counted one at a time.
    ///
    /// # Safety
    ///
    /// Sound for any slice: vector loads are issued only on 16-byte blocks
    /// produced by `chunks_exact`, which lie entirely inside `data`.
    #[inline]
    pub(crate) unsafe fn count_byte_neon(data: &[u8], needle: u8) -> usize {
        let splat = unsafe { vdupq_n_u8(needle) };
        let blocks = data.chunks_exact(LANES);
        let tail = blocks.remainder();

        let vector_hits: usize = blocks
            .map(|block| {
                // SAFETY: `chunks_exact` guarantees `block` holds exactly 16 bytes.
                let eq = unsafe { vceqq_u8(vld1q_u8(block.as_ptr()), splat) };
                // Each matching lane is 0xFF. The widening lane sum (at most
                // 16 * 255 = 4080, fits in u16) divided by 255 is the match count.
                usize::from(unsafe { vaddlvq_u8(eq) }) / 255
            })
            .sum();

        vector_hits + tail.iter().filter(|&&byte| byte == needle).count()
    }

    /// Returns the index of the first byte of `data` that is `>= 0x80`,
    /// i.e. not 7-bit ASCII.
    ///
    /// Full 16-byte blocks are compared with NEON. The trailing `len % 16`
    /// bytes are scanned one at a time.
    ///
    /// # Safety
    ///
    /// Sound for any slice: vector loads are issued only on 16-byte blocks
    /// produced by `chunks_exact`, which lie entirely inside `data`.
    #[inline]
    pub(crate) unsafe fn find_non_ascii_neon(data: &[u8]) -> Option<usize> {
        let high = unsafe { vdupq_n_u8(0x80) };
        let blocks = data.chunks_exact(LANES);
        let tail = blocks.remainder();

        for (block_idx, block) in blocks.enumerate() {
            // SAFETY: `chunks_exact` guarantees `block` holds exactly 16 bytes.
            let non_ascii = unsafe { vcgeq_u8(vld1q_u8(block.as_ptr()), high) };
            if let Some(lane) = unsafe { first_set_lane(non_ascii) } {
                return Some(block_idx * LANES + lane);
            }
        }

        let tail_start = data.len() - tail.len();
        tail.iter()
            .position(|&byte| byte >= 0x80)
            .map(|k| tail_start + k)
    }
}
292
/// Returns the index of the first occurrence of `needle` in `data`, or `None`
/// if it does not occur.
///
/// On `aarch64` this dispatches to the NEON kernel `neon::find_byte_neon`.
/// Other targets use a plain iterator search.
#[inline]
pub fn find_byte(data: &[u8], needle: u8) -> Option<usize> {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `find_byte_neon` only loads 16-byte blocks that lie entirely
        // inside `data`, so it is sound for any slice.
        unsafe { neon::find_byte_neon(data, needle) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter()
            .enumerate()
            .find(|&(_, &byte)| byte == needle)
            .map(|(idx, _)| idx)
    }
}
311
/// Counts how many bytes of `data` are equal to `needle`.
///
/// On `aarch64` this dispatches to the NEON kernel `neon::count_byte_neon`.
/// Other targets fold over the slice.
#[inline]
pub fn count_byte(data: &[u8], needle: u8) -> usize {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `count_byte_neon` only loads 16-byte blocks that lie entirely
        // inside `data`, so it is sound for any slice.
        unsafe { neon::count_byte_neon(data, needle) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter()
            .fold(0usize, |acc, &byte| acc + usize::from(byte == needle))
    }
}
326
/// Returns the index of the first byte in `data` that is not 7-bit ASCII
/// (value `>= 0x80`), or `None` if the whole slice is ASCII.
///
/// On `aarch64` this dispatches to the NEON kernel `neon::find_non_ascii_neon`.
/// Other targets use an iterator search.
#[inline]
pub fn find_non_ascii(data: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `find_non_ascii_neon` only loads 16-byte blocks that lie
        // entirely inside `data`, so it is sound for any slice.
        unsafe { neon::find_non_ascii_neon(data) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        data.iter().position(|byte| !byte.is_ascii())
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_decode_basic() {
        let mut data = Vec::new();
        for v in [0u64, 1, 127, 128, 300] {
            crous_core::varint::encode_varint_vec(v, &mut data);
        }
        let results = batch_decode_varints(&data, 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert_eq!(results[2].0, 127);
        assert_eq!(results[3].0, 128);
        assert_eq!(results[4].0, 300);
    }

    /// Asking for more varints than the input holds stops at the end of the
    /// data instead of erroring. Empty input yields nothing.
    #[test]
    fn batch_decode_stops_when_data_runs_out() {
        let mut data = Vec::new();
        for v in [7u64, 300] {
            crous_core::varint::encode_varint_vec(v, &mut data);
        }
        let results = batch_decode_varints(&data, 10);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 7);
        assert_eq!(results[1].0, 300);
        assert_eq!(batch_decode_total_consumed(&data, 10), data.len());
        assert!(batch_decode_varints(&[], 3).is_empty());
        assert_eq!(batch_decode_total_consumed(&[], 3), 0);
    }

    #[test]
    fn batch_decode_simd_matches_scalar() {
        let mut data = Vec::new();
        let values = [0u64, 1, 42, 127, 128, 255, 300, 16384, u64::MAX];
        for v in &values {
            crous_core::varint::encode_varint_vec(*v, &mut data);
        }
        let scalar = batch_decode_varints(&data, values.len());
        let simd = batch_decode_varints_simd(&data, values.len());
        assert_eq!(scalar.len(), simd.len());
        for (s, d) in scalar.iter().zip(simd.iter()) {
            assert_eq!(s.0, d.0, "value mismatch");
            assert_eq!(s.1, d.1, "consumed mismatch");
        }
    }

    /// Input much longer than one 16-byte vector, with every encoded length
    /// from 1 to 10 bytes, so both the vector probe and the short-tail path run
    /// when the SIMD feature is enabled.
    #[test]
    fn batch_decode_simd_matches_scalar_long_input() {
        let values: Vec<u64> = (0..64u32)
            .map(|i| (1u64 << i) - 1)
            .chain(std::iter::once(u64::MAX))
            .collect();
        let mut data = Vec::new();
        for v in &values {
            crous_core::varint::encode_varint_vec(*v, &mut data);
        }
        let scalar = batch_decode_varints(&data, values.len());
        let simd = batch_decode_varints_simd(&data, values.len());
        assert_eq!(scalar, simd);
        let decoded: Vec<u64> = scalar.iter().map(|&(v, _)| v).collect();
        assert_eq!(decoded, values);
    }

    #[test]
    fn find_byte_basic() {
        assert_eq!(find_byte(b"hello", b'l'), Some(2));
        assert_eq!(find_byte(b"hello", b'z'), None);
    }

    #[test]
    fn find_byte_long() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        assert_eq!(find_byte(&data, 0), Some(0));
        assert_eq!(find_byte(&data, 42), Some(42));
        assert_eq!(find_byte(&data, 255), Some(255));

        let zeros = vec![0u8; 100];
        assert_eq!(find_byte(&zeros, 1), None);
    }

    #[test]
    fn count_byte_basic() {
        assert_eq!(count_byte(b"hello", b'l'), 2);
        assert_eq!(count_byte(b"hello", b'z'), 0);
        assert_eq!(count_byte(b"hello", b'o'), 1);
    }

    #[test]
    fn count_byte_long() {
        let data = vec![0xABu8; 200];
        assert_eq!(count_byte(&data, 0xAB), 200);
        assert_eq!(count_byte(&data, 0x00), 0);
    }

    #[test]
    fn find_non_ascii_basic() {
        assert_eq!(find_non_ascii(b"hello"), None);
        assert_eq!(find_non_ascii(b"hello\x80"), Some(5));
        assert_eq!(find_non_ascii(b"\xff"), Some(0));
    }

    #[test]
    fn find_non_ascii_long() {
        let mut data = vec![b'a'; 100];
        assert_eq!(find_non_ascii(&data), None);
        data[50] = 0x80;
        assert_eq!(find_non_ascii(&data), Some(50));
    }

    /// Empty slices must not touch the vector path at all.
    #[test]
    fn byte_scans_on_empty_input() {
        assert_eq!(find_byte(&[], 0), None);
        assert_eq!(count_byte(&[], 0), 0);
        assert_eq!(find_non_ascii(&[]), None);
    }

    /// Every prefix length 0..=100 exercises full 16-byte blocks, the scalar
    /// tail, and matches that straddle block boundaries. The results must
    /// agree with straightforward iterator scans.
    #[test]
    fn byte_scans_match_iterators_on_mixed_lengths() {
        let data: Vec<u8> = (0..100u32).map(|i| (i * 37 % 251) as u8).collect();
        for len in 0..=data.len() {
            let slice = &data[..len];
            for needle in [0u8, 37, 74, 200, 250] {
                assert_eq!(
                    find_byte(slice, needle),
                    slice.iter().position(|&b| b == needle)
                );
                assert_eq!(
                    count_byte(slice, needle),
                    slice.iter().filter(|&&b| b == needle).count()
                );
            }
            assert_eq!(
                find_non_ascii(slice),
                slice.iter().position(|&b| b >= 0x80)
            );
        }
    }
}