use std::{
any::Any,
borrow::Cow,
str::{from_utf8, from_utf8_unchecked},
};
use wide::u8x16;
/// Smallest byte value with the high bit set. The scanners treat any byte at or
/// above this value as non-ASCII, and also use it as the high-bit mask.
const NON_ASCII_BYTE_MIN: u8 = 0x80;
/// Keeps only the low 16 bits of a lane bitmask, one bit per byte of a 16-byte chunk.
const LANE_BITMASK: u32 = 0xFFFF;
/// Initial capacity, in bytes, of the scratch buffer that `AsciiLowerIter`
/// lowercases tokens into.
const TOKEN_SCRATCH_INITIAL_CAP: usize = 32;
/// Splits text into tokens and parses search-style queries built from them.
///
/// Implementors only have to provide [`Tokenizer::tokenize`] and
/// [`Tokenizer::as_any`]; every other method has a default that is expressed in
/// terms of `tokenize` and may be overridden with a cheaper implementation.
pub trait Tokenizer: Send + Sync + 'static {
    /// Returns a boxed iterator over the owned tokens of `text`, in order.
    fn tokenize<'a>(&'a self, text: &'a str) -> Box<dyn Iterator<Item = String> + 'a>;

    /// Invokes `f` once for every token of `text`, in order.
    ///
    /// The default implementation drains [`Tokenizer::tokenize`] and lends each
    /// owned token to the callback.
    fn tokenize_each(&self, text: &str, f: &mut dyn FnMut(&str)) {
        self.tokenize(text).for_each(|token| f(token.as_str()));
    }

    /// Exposes the concrete tokenizer as [`Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn Any;

    /// Invokes `f` once per token of `text`, handing each token over as a [`Cow`].
    ///
    /// The default implementation always produces `Cow::Owned` copies of the
    /// tokens reported by [`Tokenizer::tokenize_each`].
    fn tokenize_each_query<'q>(&self, text: &'q str, f: &mut dyn FnMut(Cow<'q, str>)) {
        self.tokenize_each(text, &mut |token: &str| f(Cow::Owned(String::from(token))));
    }

    /// Parses `query` into positive and negative tokens.
    ///
    /// The query is split on whitespace. A word that starts with `-` and has at
    /// least one more character has that single leading dash removed and its
    /// tokens recorded as negatives; every other word (including a bare `-`) is
    /// tokenized as-is and recorded as positives.
    fn parse<'q>(&self, query: &'q str) -> ParsedQuery<'q> {
        let mut parsed = ParsedQuery::default();
        for word in query.split_whitespace() {
            // `Some(term)` only when there is something left after the dash.
            let negated = word.strip_prefix('-').filter(|rest| !rest.is_empty());
            if let Some(term) = negated {
                self.tokenize_each_query(term, &mut |token| parsed.negatives.push(token));
            } else {
                self.tokenize_each_query(word, &mut |token| parsed.positives.push(token));
            }
        }
        parsed
    }
}
/// Stateless tokenizer that emits lowercased runs of ASCII letters and digits.
///
/// Every other ASCII byte acts as a separator, and a run that also contains
/// non-ASCII bytes is dropped instead of being emitted.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsciiLowerTokenizer;
impl AsciiLowerTokenizer {
pub fn new() -> Self {
Self
}
#[inline]
pub fn tokenize_each_inline<F: FnMut(&str)>(&self, text: &str, mut f: F) {
let bytes = text.as_bytes();
let mut buf: Vec<u8> = Vec::new();
let mut pos = 0;
while pos < bytes.len() {
pos = simd_skip_non_token(bytes, pos);
if pos >= bytes.len() {
return;
}
let start = pos;
let (end, had_upper, had_non_ascii) = simd_scan_token_run(bytes, pos);
pos = end;
if had_non_ascii || start == pos {
continue;
}
if !had_upper {
let s = unsafe { from_utf8_unchecked(&bytes[start..end]) };
f(s);
} else {
buf.clear();
buf.reserve(end - start);
for &b in &bytes[start..end] {
buf.push(b.to_ascii_lowercase());
}
let s = unsafe { from_utf8_unchecked(&buf) };
f(s);
}
}
}
}
/// Returns the index of the first ASCII alphanumeric byte at or after `pos`.
///
/// Every other byte (ASCII separators as well as non-ASCII bytes) is skipped.
/// When no such byte exists the result is `bytes.len()`, or `pos` itself if
/// `pos` is already at or past the end.
///
/// Full 16-byte chunks are classified with SIMD comparisons; the remaining
/// tail of fewer than 16 bytes is handled one byte at a time.
#[inline(always)]
fn simd_skip_non_token(bytes: &[u8], mut pos: usize) -> usize {
    const LANES: usize = 16;
    let mut lanes = [0u8; LANES];
    while pos + LANES <= bytes.len() {
        // Bounds-checked copy of one chunk; the loop condition guarantees the
        // range is in bounds, so no `unsafe` pointer read is required.
        lanes.copy_from_slice(&bytes[pos..pos + LANES]);
        let chunk = u8x16::from(lanes);
        let is_digit = chunk.simd_ge(u8x16::splat(b'0')) & chunk.simd_le(u8x16::splat(b'9'));
        let is_upper = chunk.simd_ge(u8x16::splat(b'A')) & chunk.simd_le(u8x16::splat(b'Z'));
        let is_lower = chunk.simd_ge(u8x16::splat(b'a')) & chunk.simd_le(u8x16::splat(b'z'));
        let is_token = is_digit | is_upper | is_lower;
        let mask = is_token.to_bitmask() & LANE_BITMASK;
        if mask != 0 {
            // The lowest set bit marks the first token byte inside this chunk.
            return pos + mask.trailing_zeros() as usize;
        }
        pos += LANES;
    }
    // Scalar tail for the final partial chunk.
    while pos < bytes.len() && !bytes[pos].is_ascii_alphanumeric() {
        pos += 1;
    }
    pos
}
/// Scans the token run that starts at `pos` and reports where it ends.
///
/// The run extends over ASCII alphanumeric bytes and over non-ASCII bytes
/// (`>= NON_ASCII_BYTE_MIN`); it stops at the first other byte or at the end of
/// `bytes`.
///
/// Returns `(end, had_upper, had_non_ascii)` where `end` is the index one past
/// the run, `had_upper` tells whether the run contains an ASCII uppercase
/// letter, and `had_non_ascii` tells whether it contains a non-ASCII byte.
/// Only bytes inside the run contribute to the two flags.
///
/// Full 16-byte chunks are classified with SIMD comparisons; the remaining
/// tail of fewer than 16 bytes is handled one byte at a time.
#[inline(always)]
fn simd_scan_token_run(bytes: &[u8], mut pos: usize) -> (usize, bool, bool) {
    const LANES: usize = 16;
    let mut had_upper = false;
    let mut had_non_ascii = false;
    let mut lanes = [0u8; LANES];
    while pos + LANES <= bytes.len() {
        // Bounds-checked copy of one chunk; the loop condition guarantees the
        // range is in bounds, so no `unsafe` pointer read is required.
        lanes.copy_from_slice(&bytes[pos..pos + LANES]);
        let chunk = u8x16::from(lanes);
        let is_digit = chunk.simd_ge(u8x16::splat(b'0')) & chunk.simd_le(u8x16::splat(b'9'));
        let is_upper = chunk.simd_ge(u8x16::splat(b'A')) & chunk.simd_le(u8x16::splat(b'Z'));
        let is_lower = chunk.simd_ge(u8x16::splat(b'a')) & chunk.simd_le(u8x16::splat(b'z'));
        let is_high =
            (chunk & u8x16::splat(NON_ASCII_BYTE_MIN)).simd_eq(u8x16::splat(NON_ASCII_BYTE_MIN));
        let is_token = is_digit | is_upper | is_lower;
        let is_extend = is_token | is_high;
        let extend_mask = is_extend.to_bitmask() & LANE_BITMASK;
        let upper_mask = is_upper.to_bitmask() & LANE_BITMASK;
        let high_mask = is_high.to_bitmask() & LANE_BITMASK;
        let non_extend = !extend_mask & LANE_BITMASK;
        if non_extend == 0 {
            // The whole chunk belongs to the run; keep scanning.
            had_upper |= upper_mask != 0;
            had_non_ascii |= high_mask != 0;
            pos += LANES;
            continue;
        }
        // The run ends inside this chunk: only lanes below the first separator
        // may contribute to the flags.
        let sep_idx = non_extend.trailing_zeros() as usize;
        let prefix_mask: u32 = (1u32 << sep_idx).wrapping_sub(1);
        had_upper |= (upper_mask & prefix_mask) != 0;
        had_non_ascii |= (high_mask & prefix_mask) != 0;
        return (pos + sep_idx, had_upper, had_non_ascii);
    }
    // Scalar tail for the final partial chunk.
    while pos < bytes.len() {
        let b = bytes[pos];
        if is_token_byte(b) {
            had_upper |= b.is_ascii_uppercase();
        } else if b >= NON_ASCII_BYTE_MIN {
            had_non_ascii = true;
        } else {
            break;
        }
        pos += 1;
    }
    (pos, had_upper, had_non_ascii)
}
/// Result of `Tokenizer::parse`: the tokens of a query, split by polarity.
///
/// Tokens are `Cow`s tied to the query's lifetime `'q`, so an implementation
/// may lend slices of the query instead of allocating.
#[derive(Debug, Default)]
pub struct ParsedQuery<'q> {
/// Tokens from query words that were not negated, in query order.
pub positives: Vec<Cow<'q, str>>,
/// Tokens from query words written with a leading `-`, in query order.
pub negatives: Vec<Cow<'q, str>>,
}
impl Tokenizer for AsciiLowerTokenizer {
    /// Returns an iterator that yields each token of `text` as an owned,
    /// lowercased `String`.
    fn tokenize<'a>(&'a self, text: &'a str) -> Box<dyn Iterator<Item = String> + 'a> {
        let iter = AsciiLowerIter::new(text.as_bytes());
        Box::new(iter)
    }

    /// Forwards to [`AsciiLowerTokenizer::tokenize_each_inline`].
    fn tokenize_each(&self, text: &str, f: &mut dyn FnMut(&str)) {
        self.tokenize_each_inline(text, |token: &str| f(token));
    }

    /// Exposes this tokenizer as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Calls `f` once per token of `text`, borrowing from `text` when possible.
    ///
    /// Tokens without uppercase letters are passed as `Cow::Borrowed` slices of
    /// `text`; tokens with uppercase letters are passed as lowercased
    /// `Cow::Owned` strings. Runs containing non-ASCII bytes are skipped.
    fn tokenize_each_query<'q>(&self, text: &'q str, f: &mut dyn FnMut(Cow<'q, str>)) {
        let input = text.as_bytes();
        let mut cursor = 0;
        while cursor < input.len() {
            let run_start = simd_skip_non_token(input, cursor);
            if run_start >= input.len() {
                break;
            }
            let (run_end, saw_upper, saw_non_ascii) = simd_scan_token_run(input, run_start);
            cursor = run_end;
            if saw_non_ascii || run_start == run_end {
                continue;
            }
            // The run holds only ASCII alphanumerics, so this validation cannot fail.
            let token = from_utf8(&input[run_start..run_end]).expect("ASCII-only by construction");
            let item = if saw_upper {
                Cow::Owned(token.to_ascii_lowercase())
            } else {
                Cow::Borrowed(token)
            };
            f(item);
        }
    }
}
/// Iterator state behind `AsciiLowerTokenizer::tokenize`.
struct AsciiLowerIter<'a> {
/// Raw bytes of the text being tokenized.
src: &'a [u8],
/// Index into `src` of the next unread byte.
pos: usize,
/// Scratch buffer for the lowercased bytes of the token being built; it is
/// cleared and reused for every token.
buf: Vec<u8>,
}
impl<'a> AsciiLowerIter<'a> {
fn new(src: &'a [u8]) -> Self {
Self {
src,
pos: 0,
buf: Vec::with_capacity(TOKEN_SCRATCH_INITIAL_CAP),
}
}
}
impl Iterator for AsciiLowerIter<'_> {
    type Item = String;

    /// Yields the next lowercased token, or `None` once the input is exhausted.
    ///
    /// Bytes that are not ASCII alphanumerics are skipped before a token
    /// starts. A token then extends over ASCII alphanumerics and non-ASCII
    /// bytes; if any non-ASCII byte was consumed the whole run is discarded and
    /// scanning resumes after it.
    fn next(&mut self) -> Option<String> {
        // Copy the slice reference out so the scans below do not borrow `self`.
        let src = self.src;
        while self.pos < src.len() {
            // Skip everything that cannot start a token.
            let gap = src[self.pos..]
                .iter()
                .take_while(|&&byte| !is_token_byte(byte))
                .count();
            self.pos += gap;
            if self.pos == src.len() {
                break;
            }

            // Collect the run, lowercasing token bytes as they are copied.
            self.buf.clear();
            let mut saw_non_ascii = false;
            for &byte in &src[self.pos..] {
                if is_token_byte(byte) {
                    self.buf.push(byte.to_ascii_lowercase());
                } else if byte >= NON_ASCII_BYTE_MIN {
                    saw_non_ascii = true;
                } else {
                    break;
                }
                self.pos += 1;
            }

            if saw_non_ascii || self.buf.is_empty() {
                continue;
            }
            let token = from_utf8(&self.buf)
                .map(String::from)
                .expect("ASCII-only by construction");
            return Some(token);
        }
        None
    }
}
/// Returns `true` when `b` is an ASCII digit or an ASCII letter of either case.
#[inline]
fn is_token_byte(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z')
}
// Unit tests for `AsciiLowerTokenizer` and the `Tokenizer` trait defaults.
//
// The first group drives the boxed-iterator path (`tokenize`), the second
// group drives query parsing (`parse`), and the final test drives
// `tokenize_each` through a trait object.
#[cfg(test)]
mod tests {
use super::*;
// Collects every token produced by the iterator-based `tokenize` path.
fn tokens(text: &str) -> Vec<String> {
AsciiLowerTokenizer.tokenize(text).collect()
}
#[test]
fn empty_input_yields_nothing() {
assert_eq!(tokens(""), Vec::<String>::new());
}
#[test]
fn whitespace_only_yields_nothing() {
assert_eq!(tokens(" \t\n\r"), Vec::<String>::new());
}
#[test]
fn single_token_lowercased() {
assert_eq!(tokens("Hello"), vec!["hello"]);
}
#[test]
fn multiple_tokens_split_on_whitespace() {
assert_eq!(
tokens("Rust async runtime"),
vec!["rust", "async", "runtime"]
);
}
#[test]
fn punctuation_splits_tokens() {
assert_eq!(
tokens("hello,world!foo;bar.baz?"),
vec!["hello", "world", "foo", "bar", "baz"]
);
}
#[test]
fn case_folding_applies_to_uppercase_only() {
assert_eq!(tokens("ABC abc XyZ"), vec!["abc", "abc", "xyz"]);
}
#[test]
fn alphanumerics_kept_together() {
assert_eq!(tokens("foo123 bar456"), vec!["foo123", "bar456"]);
}
#[test]
fn pure_numeric_tokens_kept() {
assert_eq!(tokens("404 200 500"), vec!["404", "200", "500"]);
}
#[test]
fn underscore_is_a_separator_in_v1() {
assert_eq!(tokens("foo_bar"), vec!["foo", "bar"]);
}
#[test]
fn dash_is_a_separator() {
assert_eq!(tokens("rust-async"), vec!["rust", "async"]);
}
// A run that mixes ASCII letters with non-ASCII bytes is dropped as a whole.
#[test]
fn non_ascii_token_is_dropped() {
assert_eq!(tokens("café"), Vec::<String>::new());
}
#[test]
fn non_ascii_token_drops_only_that_token() {
assert_eq!(tokens("hello café world"), vec!["hello", "world"]);
}
#[test]
fn cjk_input_yields_nothing() {
assert_eq!(tokens("日本語"), Vec::<String>::new());
}
// NOTE(review): despite the name, this asserts that only the emoji is dropped
// and the surrounding ASCII tokens are still produced.
#[test]
fn emoji_input_yields_nothing() {
assert_eq!(tokens("hello 🚀 world"), vec!["hello", "world"]);
}
#[test]
fn multiple_consecutive_separators_are_collapsed() {
assert_eq!(tokens("foo,,,bar"), vec!["foo", "bar"]);
assert_eq!(tokens("foo bar"), vec!["foo", "bar"]);
}
#[test]
fn leading_and_trailing_separators_are_skipped() {
assert_eq!(tokens(" foo bar "), vec!["foo", "bar"]);
assert_eq!(tokens("...foo..."), vec!["foo"]);
}
// Compile-time check: the tokenizer satisfies the trait's `Send + Sync` bounds.
#[test]
fn tokenizer_is_send_and_sync() {
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<AsciiLowerTokenizer>();
}
#[test]
fn tokenizer_used_via_dyn_trait() {
let tok: Box<dyn Tokenizer> = Box::new(AsciiLowerTokenizer);
let v: Vec<String> = tok.tokenize("Hello WORLD").collect();
assert_eq!(v, vec!["hello", "world"]);
}
// 8 tokens per repetition of `chunk`, repeated 20_000 times.
#[test]
fn stress_long_input_does_not_panic() {
let chunk = "lorem ipsum dolor sit amet, consectetur adipiscing elit. ";
let big = chunk.repeat(20_000);
let count = AsciiLowerTokenizer.tokenize(&big).count();
assert_eq!(count, 8 * 20_000);
}
// Parses a query with `AsciiLowerTokenizer`, which overrides `tokenize_each_query`.
fn parse(query: &str) -> ParsedQuery<'_> {
AsciiLowerTokenizer.parse(query)
}
// `PlainTok` overrides nothing beyond the required methods, so this exercises
// the trait's default `tokenize_each` / `tokenize_each_query`, which always
// hand out owned tokens.
#[test]
fn parse_default_trait_impl_matches_override() {
struct PlainTok;
impl Tokenizer for PlainTok {
fn tokenize<'a>(&'a self, text: &'a str) -> Box<dyn Iterator<Item = String> + 'a> {
AsciiLowerTokenizer.tokenize(text)
}
fn as_any(&self) -> &dyn Any {
self
}
}
let p = PlainTok.parse("Rust -PYTHON");
assert_eq!(p.positives, vec!["rust"]);
assert_eq!(p.negatives, vec!["python"]);
assert!(matches!(p.positives[0], Cow::Owned(_)));
}
#[test]
fn parse_positives_only() {
let p = parse("rust async");
assert_eq!(p.positives, vec!["rust", "async"]);
assert!(p.negatives.is_empty());
}
#[test]
fn parse_single_negative() {
let p = parse("rust -python");
assert_eq!(p.positives, vec!["rust"]);
assert_eq!(p.negatives, vec!["python"]);
}
#[test]
fn parse_multiple_negatives() {
let p = parse("rust async -python -php");
assert_eq!(p.positives, vec!["rust", "async"]);
assert_eq!(p.negatives, vec!["python", "php"]);
}
#[test]
fn parse_negation_only() {
let p = parse("-python");
assert!(p.positives.is_empty());
assert_eq!(p.negatives, vec!["python"]);
}
// Only a dash at the start of a whitespace-separated word negates it.
#[test]
fn parse_interior_hyphen_is_not_negation() {
let p = parse("a-b");
assert_eq!(p.positives, vec!["a", "b"]);
assert!(p.negatives.is_empty());
}
#[test]
fn parse_bare_dash_contributes_nothing() {
let p = parse("rust - python");
assert_eq!(p.positives, vec!["rust", "python"]);
assert!(p.negatives.is_empty());
}
#[test]
fn parse_double_dash_strips_one_then_tokenizes() {
let p = parse("--py");
assert!(p.positives.is_empty());
assert_eq!(p.negatives, vec!["py"]);
}
#[test]
fn parse_negated_term_is_normalized() {
let p = parse("rust -PYTHON");
assert_eq!(p.negatives, vec!["python"]);
}
#[test]
fn parse_empty_query() {
let p = parse("");
assert!(p.positives.is_empty());
assert!(p.negatives.is_empty());
}
// Already-lowercase tokens are lent from the query rather than copied.
#[test]
fn parse_lowercase_tokens_borrow_the_query() {
let p = parse("rust -python");
assert!(matches!(p.positives[0], Cow::Borrowed(_)));
assert!(matches!(p.negatives[0], Cow::Borrowed(_)));
}
// Only tokens that need lowercasing are allocated.
#[test]
fn parse_uppercase_token_is_the_only_copy() {
let p = parse("rust -PYTHON");
assert!(matches!(p.positives[0], Cow::Borrowed(_)));
assert!(matches!(p.negatives[0], Cow::Owned(_)));
}
#[test]
fn dyn_tokenize_each_lowercases_and_splits() {
let tok = AsciiLowerTokenizer::new();
let dynt: &dyn Tokenizer = &tok;
let mut out = Vec::new();
dynt.tokenize_each("Hello, World rust", &mut |s| out.push(s.to_string()));
assert_eq!(out, vec!["hello", "world", "rust"]);
}
}