memspan 0.1.0

SIMD-accelerated byte-class scanning for lexers and parsers. Backends: AVX-512, AVX2, SSE4.1, NEON, WASM SIMD128. no_std compatible.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#![allow(unsafe_op_in_unsafe_fn)]
//! AVX-512BW (512-bit) implementations of `skip_until`, `skip_while`, and the
//! specialized ASCII-class scanners.
//!
//! AVX-512BW is special: comparison intrinsics like `_mm512_cmpeq_epi8_mask`
//! and `_mm512_cmple_epu8_mask` return a `__mmask64` (u64) **directly** — no
//! `movemask` conversion. Position extraction is therefore simply
//! `bits.trailing_zeros()`.
//!
//! Chunk size is 64 bytes; the 2× unrolled loop covers 128 bytes per iteration.

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::Needles;

/// Bytes per AVX-512 vector, and therefore per scan chunk.
const CHUNK: usize = 64;

/// Range check [lo, hi] for AVX-512BW (2 ops, same cost as NEON).
///
/// `_mm512_sub_epi8` does wrapping u8 subtraction (identical bit pattern to
/// unsigned). `_mm512_cmple_epu8_mask` then does an unsigned ≤ comparison and
/// returns a `u64` mask directly — no further conversion needed.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
pub(crate) unsafe fn range_mask(chunk: __m512i, lo: u8, hi: u8) -> u64 {
  // Bias every lane down by `lo`; bytes inside [lo, hi] land in [0, hi - lo].
  let biased = _mm512_sub_epi8(chunk, _mm512_set1_epi8(lo as i8));
  let span = _mm512_set1_epi8(hi.wrapping_sub(lo) as i8);
  _mm512_cmple_epu8_mask(biased, span)
}

// ── per-class mask functions (return u64) ────────────────────────────────────

/// Lanes matching a binary digit (`'0'` or `'1'`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn binary_mask(c: __m512i) -> u64 {
  range_mask(c, b'0', b'1')
}

/// Lanes matching an octal digit (`'0'`..=`'7'`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn octal_digit_mask(c: __m512i) -> u64 {
  range_mask(c, b'0', b'7')
}

/// Lanes matching a decimal digit (`'0'`..=`'9'`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn digit_mask(c: __m512i) -> u64 {
  range_mask(c, b'0', b'9')
}

/// Lanes matching a hex digit (`[0-9a-fA-F]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn hex_digit_mask(c: __m512i) -> u64 {
  // Setting bit 5 folds 'A'-'F' onto 'a'-'f', so one range test covers
  // both cases; digits are tested on the unmodified bytes.
  let folded = _mm512_or_si512(c, _mm512_set1_epi8(0x20u8 as i8));
  range_mask(folded, b'a', b'f') | digit_mask(c)
}

/// Lanes matching ASCII whitespace (space, tab, LF, CR), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn whitespace_mask(c: __m512i) -> u64 {
  // Four equality masks, OR-ed; `cmpeq` already yields a u64 mask on
  // AVX-512BW so this stays branch-free and cheap.
  let space = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b' ' as i8));
  let tab = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b'\t' as i8));
  let lf = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b'\n' as i8));
  let cr = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b'\r' as i8));
  (space | tab) | (lf | cr)
}

/// Lanes matching an ASCII letter (`[A-Za-z]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn alpha_mask(c: __m512i) -> u64 {
  // Case-fold by setting bit 5, then a single lowercase range test.
  let folded = _mm512_or_si512(c, _mm512_set1_epi8(0x20u8 as i8));
  range_mask(folded, b'a', b'z')
}

/// Lanes matching `[0-9A-Za-z]`, as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn alphanumeric_mask(c: __m512i) -> u64 {
  alpha_mask(c) | digit_mask(c)
}

/// Lanes that can start an identifier (`[A-Za-z_]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn ident_start_mask(c: __m512i) -> u64 {
  let underscore = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b'_' as i8));
  underscore | alpha_mask(c)
}

/// Lanes that can continue an identifier (`[0-9A-Za-z_]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn ident_mask(c: __m512i) -> u64 {
  let underscore = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(b'_' as i8));
  underscore | alphanumeric_mask(c)
}

/// Lanes matching a lowercase ASCII letter (`[a-z]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn lower_mask(c: __m512i) -> u64 {
  range_mask(c, b'a', b'z')
}

/// Lanes matching an uppercase ASCII letter (`[A-Z]`), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn upper_mask(c: __m512i) -> u64 {
  range_mask(c, b'A', b'Z')
}

/// Lanes holding an ASCII byte (0x00..=0x7F), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn ascii_mask(c: __m512i) -> u64 {
  range_mask(c, 0x00, 0x7F)
}

/// Lanes holding a non-ASCII byte (0x80..=0xFF), as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn non_ascii_mask(c: __m512i) -> u64 {
  range_mask(c, 0x80, 0xFF)
}

/// Lanes holding a printable, non-space ASCII byte (0x21..=0x7E), as a `u64`
/// bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn ascii_graphic_mask(c: __m512i) -> u64 {
  range_mask(c, 0x21, 0x7E)
}

/// Lanes holding an ASCII control byte (C0 range 0x00..=0x1F, or DEL 0x7F),
/// as a `u64` bitmask.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
unsafe fn ascii_control_mask(c: __m512i) -> u64 {
  // DEL (0x7F) sits outside the C0 block, so it needs a separate equality.
  let del = _mm512_cmpeq_epi8_mask(c, _mm512_set1_epi8(0x7F_u8 as i8));
  del | range_mask(c, 0x00, 0x1F)
}

// ── skip_ascii_class macro ───────────────────────────────────────────────────

// Generates a `skip_while`-style scanner specialized for one ASCII class:
// returns the length of the longest prefix of `input` whose bytes all satisfy
// `$mask`. `$prefix_len` is the scalar fallback in the parent module, used
// for short inputs and for the first chunk (which is often where lexers stop,
// so it is cheaper to avoid the vector setup there — NOTE(review): presumed
// rationale, matches the generic `skip_while` below).
macro_rules! skip_ascii_class {
  ($name:ident, $prefix_len:ident, $mask:ident) => {
    #[cfg_attr(not(tarpaulin), inline)]
    #[target_feature(enable = "avx512bw")]
    pub(super) unsafe fn $name(input: &[u8]) -> usize {
      let len = input.len();
      // Too short for a single 64-byte load: pure scalar scan.
      if len < CHUNK {
        return super::$prefix_len(input);
      }

      let ptr = input.as_ptr();

      // Scalar scan of the first chunk; only enter the SIMD loop if the
      // entire first chunk matched.
      let first = super::$prefix_len(&input[..CHUNK]);
      if first != CHUNK {
        return first;
      }

      let mut cur = CHUNK;

      // 2x unrolled main loop: 128 bytes per iteration.
      while cur + 2 * CHUNK <= len {
        let c0 = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
        let c1 = _mm512_loadu_si512(ptr.add(cur + CHUNK).cast::<__m512i>());
        let m0 = $mask(c0);
        let m1 = $mask(c1);
        // All-ones means all match; any zero means a non-match in m0|m1 position.
        // For skip_while: miss iff NOT all-ones. Use AND: zero bit = non-match.
        let combined = m0 & m1;
        if combined != !0u64 {
          if m0 != !0u64 {
            return cur + (!m0).trailing_zeros() as usize;
          }
          return cur + CHUNK + (!m1).trailing_zeros() as usize;
        }
        cur += 2 * CHUNK;
      }

      // At most one full chunk remains.
      while cur + CHUNK <= len {
        let chunk = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
        let bits = $mask(chunk);
        if bits != !0u64 {
          return cur + (!bits).trailing_zeros() as usize;
        }
        cur += CHUNK;
      }

      if cur == len {
        return len;
      }

      // Partial tail (< CHUNK bytes): reload the LAST 64 bytes of the input,
      // overlapping bytes already scanned, and mask off the low `already`
      // bits so only fresh positions can report a non-match. `len >= CHUNK`
      // here, so the overlapped load never reads before the buffer.
      let overlap_start = len - CHUNK;
      let chunk = _mm512_loadu_si512(ptr.add(overlap_start).cast::<__m512i>());
      let bits = $mask(chunk);
      let already = cur - overlap_start;
      let scan_mask = (!0u64) << already;
      let non_match = (!bits) & scan_mask;
      if non_match != 0 {
        overlap_start + non_match.trailing_zeros() as usize
      } else {
        len
      }
    }
  };
}

// One specialized scanner per ASCII class; each pairs a scalar fallback
// (`prefix_len_*`, from the parent module) with its per-chunk mask function.
skip_ascii_class!(skip_binary, prefix_len_binary, binary_mask);
skip_ascii_class!(skip_octal_digits, prefix_len_octal_digits, octal_digit_mask);
skip_ascii_class!(skip_digits, prefix_len_digits, digit_mask);
skip_ascii_class!(skip_hex_digits, prefix_len_hex_digits, hex_digit_mask);
skip_ascii_class!(skip_whitespace, prefix_len_whitespace, whitespace_mask);
skip_ascii_class!(skip_alpha, prefix_len_alpha, alpha_mask);
skip_ascii_class!(
  skip_alphanumeric,
  prefix_len_alphanumeric,
  alphanumeric_mask
);
skip_ascii_class!(skip_ident_start, prefix_len_ident_start, ident_start_mask);
skip_ascii_class!(skip_ident, prefix_len_ident, ident_mask);
skip_ascii_class!(skip_lower, prefix_len_lower, lower_mask);
skip_ascii_class!(skip_upper, prefix_len_upper, upper_mask);
skip_ascii_class!(skip_ascii, prefix_len_ascii, ascii_mask);
skip_ascii_class!(skip_non_ascii, prefix_len_non_ascii, non_ascii_mask);
skip_ascii_class!(
  skip_ascii_graphic,
  prefix_len_ascii_graphic,
  ascii_graphic_mask
);
skip_ascii_class!(
  skip_ascii_control,
  prefix_len_ascii_control,
  ascii_control_mask
);

// ── count_matches / find_last ────────────────────────────────────────────────

/// Number of bytes in `input` that match `needles`.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
pub(super) unsafe fn count_matches<Nd>(input: &[u8], needles: Nd) -> usize
where
  Nd: Needles,
{
  let len = input.len();
  // Short inputs: scalar byte-at-a-time count via the tail helper.
  if len < CHUNK {
    let mut total = 0usize;
    for &b in input {
      if needles.tail_find(core::slice::from_ref(&b)).is_some() {
        total += 1;
      }
    }
    return total;
  }

  let ptr = input.as_ptr();
  let mut total = 0usize;
  let mut pos = 0usize;

  // 2x unrolled: popcount both 64-bit masks per 128-byte iteration.
  while pos + 2 * CHUNK <= len {
    let lo = _mm512_loadu_si512(ptr.add(pos).cast::<__m512i>());
    let hi = _mm512_loadu_si512(ptr.add(pos + CHUNK).cast::<__m512i>());
    total += needles.eq_any_mask_avx512(lo).count_ones() as usize
      + needles.eq_any_mask_avx512(hi).count_ones() as usize;
    pos += 2 * CHUNK;
  }

  // At most one full chunk remains.
  while pos + CHUNK <= len {
    let v = _mm512_loadu_si512(ptr.add(pos).cast::<__m512i>());
    total += needles.eq_any_mask_avx512(v).count_ones() as usize;
    pos += CHUNK;
  }

  // Partial tail: reload the last CHUNK bytes (overlapping already-counted
  // positions) and zero the low bits that were counted before.
  if pos < len {
    let start = len - CHUNK;
    let v = _mm512_loadu_si512(ptr.add(start).cast::<__m512i>());
    let fresh = (!0u64) << (pos - start);
    total += (needles.eq_any_mask_avx512(v) & fresh).count_ones() as usize;
  }

  total
}

/// Index of the last byte in `input` matching `needles`, or `None` if no byte
/// matches.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
pub(super) unsafe fn find_last<Nd>(input: &[u8], needles: Nd) -> Option<usize>
where
  Nd: Needles,
{
  let len = input.len();
  // Short inputs: scalar scan, remembering the most recent hit.
  if len < CHUNK {
    let mut last = None;
    for (i, &b) in input.iter().enumerate() {
      if needles.tail_find(core::slice::from_ref(&b)).is_some() {
        last = Some(i);
      }
    }
    return last;
  }

  let ptr = input.as_ptr();
  let mut last: Option<usize> = None;
  let mut cur = 0;

  // Forward scan; every non-empty mask overwrites `last`, so the final value
  // is the highest-index match. `63 - leading_zeros` = highest set bit.
  while cur + 2 * CHUNK <= len {
    let c0 = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
    let c1 = _mm512_loadu_si512(ptr.add(cur + CHUNK).cast::<__m512i>());
    let b0 = needles.eq_any_mask_avx512(c0);
    let b1 = needles.eq_any_mask_avx512(c1);
    if b0 != 0 {
      last = Some(cur + (63 - b0.leading_zeros()) as usize);
    }
    if b1 != 0 {
      last = Some(cur + CHUNK + (63 - b1.leading_zeros()) as usize);
    }
    cur += 2 * CHUNK;
  }

  // At most one full chunk remains.
  while cur + CHUNK <= len {
    let chunk = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
    let bits = needles.eq_any_mask_avx512(chunk);
    if bits != 0 {
      last = Some(cur + (63 - bits.leading_zeros()) as usize);
    }
    cur += CHUNK;
  }

  // Partial tail: overlapped reload of the last CHUNK bytes; mask off the low
  // `already` bits so positions scanned above cannot be reported again.
  if cur < len {
    let overlap_start = len - CHUNK;
    let chunk = _mm512_loadu_si512(ptr.add(overlap_start).cast::<__m512i>());
    let bits = needles.eq_any_mask_avx512(chunk);
    let already = cur - overlap_start;
    let scan_mask = (!0u64) << already;
    let hit_bits = bits & scan_mask;
    if hit_bits != 0 {
      last = Some(overlap_start + (63 - hit_bits.leading_zeros()) as usize);
    }
  }

  last
}

// ── generic skip_until / skip_while ─────────────────────────────────────────

/// Index of the first byte in `input` matching `needles`, or `None` if no
/// byte matches.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
pub(super) unsafe fn skip_until<Nd>(input: &[u8], needles: Nd) -> Option<usize>
where
  Nd: Needles,
{
  let len = input.len();
  // Too short for a single 64-byte load: pure scalar scan.
  if len < CHUNK {
    return needles.tail_find(input);
  }

  // Scalar scan of the first chunk before paying for vector setup.
  let early = needles.tail_find(&input[..CHUNK]);
  if early.is_some() {
    return early;
  }

  let ptr = input.as_ptr();
  let mut pos = CHUNK;

  // 2x unrolled main loop: 128 bytes per iteration; first set bit in a mask
  // (trailing_zeros) is the earliest match within that chunk.
  while pos + 2 * CHUNK <= len {
    let lo = _mm512_loadu_si512(ptr.add(pos).cast::<__m512i>());
    let hi = _mm512_loadu_si512(ptr.add(pos + CHUNK).cast::<__m512i>());
    let lo_bits = needles.eq_any_mask_avx512(lo);
    let hi_bits = needles.eq_any_mask_avx512(hi);
    if lo_bits != 0 {
      return Some(pos + lo_bits.trailing_zeros() as usize);
    }
    if hi_bits != 0 {
      return Some(pos + CHUNK + hi_bits.trailing_zeros() as usize);
    }
    pos += 2 * CHUNK;
  }

  // At most one full chunk remains.
  while pos + CHUNK <= len {
    let v = _mm512_loadu_si512(ptr.add(pos).cast::<__m512i>());
    let bits = needles.eq_any_mask_avx512(v);
    if bits != 0 {
      return Some(pos + bits.trailing_zeros() as usize);
    }
    pos += CHUNK;
  }

  if pos == len {
    return None;
  }

  // Partial tail: reload the LAST 64 bytes (overlapping what was already
  // scanned — safe since len >= CHUNK) and drop the already-covered low bits.
  let start = len - CHUNK;
  let v = _mm512_loadu_si512(ptr.add(start).cast::<__m512i>());
  let fresh = needles.eq_any_mask_avx512(v) & ((!0u64) << (pos - start));
  if fresh == 0 {
    None
  } else {
    Some(start + fresh.trailing_zeros() as usize)
  }
}

/// Length of the longest prefix of `input` whose bytes all match `needles`.
#[cfg_attr(not(tarpaulin), inline)]
#[target_feature(enable = "avx512bw")]
pub(super) unsafe fn skip_while<Nd>(input: &[u8], needles: Nd) -> usize
where
  Nd: Needles,
{
  let len = input.len();
  // Too short for a single 64-byte load: pure scalar scan.
  if len < CHUNK {
    return needles.prefix_len(input);
  }

  let ptr = input.as_ptr();

  // Scalar scan of the first chunk; only enter the SIMD loop if the entire
  // first chunk matched.
  let first = needles.prefix_len(&input[..CHUNK]);
  if first != CHUNK {
    return first;
  }

  let mut cur = CHUNK;

  // 2x unrolled main loop. A set bit = match, so the scan stops at the first
  // ZERO bit: compare against all-ones and locate it via (!mask).trailing_zeros().
  while cur + 2 * CHUNK <= len {
    let c0 = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
    let c1 = _mm512_loadu_si512(ptr.add(cur + CHUNK).cast::<__m512i>());
    let m0 = needles.eq_any_mask_avx512(c0);
    let m1 = needles.eq_any_mask_avx512(c1);
    // AND of both masks: any zero bit means a non-match somewhere in the pair.
    let combined = m0 & m1;
    if combined != !0u64 {
      if m0 != !0u64 {
        return cur + (!m0).trailing_zeros() as usize;
      }
      return cur + CHUNK + (!m1).trailing_zeros() as usize;
    }
    cur += 2 * CHUNK;
  }

  // At most one full chunk remains.
  while cur + CHUNK <= len {
    let chunk = _mm512_loadu_si512(ptr.add(cur).cast::<__m512i>());
    let bits = needles.eq_any_mask_avx512(chunk);
    if bits != !0u64 {
      return cur + (!bits).trailing_zeros() as usize;
    }
    cur += CHUNK;
  }

  if cur == len {
    return len;
  }

  // Partial tail: overlapped reload of the last CHUNK bytes (safe since
  // len >= CHUNK); mask off the low `already` bits so positions scanned above
  // cannot report a non-match again.
  let overlap_start = len - CHUNK;
  let chunk = _mm512_loadu_si512(ptr.add(overlap_start).cast::<__m512i>());
  let bits = needles.eq_any_mask_avx512(chunk);
  let already = cur - overlap_start;
  let scan_mask = (!0u64) << already;
  let non_match = (!bits) & scan_mask;
  if non_match != 0 {
    overlap_start + non_match.trailing_zeros() as usize
  } else {
    len
  }
}