pub mod ascii {
/// Lowercases the ASCII letters of `input` and returns the result as a new `String`.
///
/// Only the bytes `b'A'..=b'Z'` are mapped to their lowercase counterparts.
/// Every other byte, including the bytes of multi-byte UTF-8 sequences, is
/// copied unchanged, so non-ASCII uppercase characters such as `É` are left
/// as they are.
///
/// The work is delegated to [`str::to_ascii_lowercase`], so no `unsafe`
/// reconstruction of a `String` from raw bytes is required.
pub fn to_lowercase_optimized(input: &str) -> String {
    input.to_ascii_lowercase()
}
/// Lowercases `input` using the standard library's Unicode-aware
/// [`str::to_lowercase`], returning a newly allocated `String`.
pub fn to_lowercase_fallback(input: &str) -> String {
    str::to_lowercase(input)
}
/// Lowercases `input`, picking an implementation based on its content.
///
/// Input that is entirely ASCII and at least 16 bytes long is handled by
/// [`to_lowercase_optimized`]; anything shorter, or containing non-ASCII
/// bytes, is handled by [`to_lowercase_fallback`].
pub fn to_lowercase(input: &str) -> String {
    // Guard clause: short or non-ASCII input takes the fallback route.
    if input.len() < 16 || !input.is_ascii() {
        return to_lowercase_fallback(input);
    }
    to_lowercase_optimized(input)
}
/// Returns the index of the first ASCII whitespace byte in `input`, or `None`
/// if there is none.
///
/// "Whitespace" is whatever [`u8::is_ascii_whitespace`] accepts: space,
/// horizontal tab, line feed, form feed and carriage return. The same
/// predicate is applied to every byte regardless of its position or the
/// length of `input`, so a form feed is reported wherever it occurs.
pub fn find_whitespace_optimized(input: &[u8]) -> Option<usize> {
    input.iter().position(u8::is_ascii_whitespace)
}
/// Searches `input` for whitespace by forwarding to
/// [`find_whitespace_optimized`] and returning its result unchanged.
///
/// Despite the name, this function contains no SIMD-specific code of its own;
/// it is a thin wrapper around the other search routine.
pub fn find_whitespace_simd(input: &[u8]) -> Option<usize> {
find_whitespace_optimized(input)
}
}
pub mod numeric {
use wide::f32x8;
/// Computes `tf * (k1 + 1) / (tf + k1 * norm)` element-wise over two slices.
///
/// Elements are processed eight at a time using `f32x8` lanes; the final
/// `len % 8` elements are computed one by one with the same formula. The
/// returned vector has one entry per input element, in input order.
///
/// # Panics
///
/// Panics if `term_freqs` and `norm_factors` have different lengths.
pub fn batch_bm25_tf(term_freqs: &[f32], k1: f32, norm_factors: &[f32]) -> Vec<f32> {
    assert_eq!(term_freqs.len(), norm_factors.len());

    let k1_plus_one = k1 + 1.0;
    let k1_plus_one_lanes = f32x8::splat(k1_plus_one);
    let k1_lanes = f32x8::splat(k1);

    let tf_blocks = term_freqs.chunks_exact(8);
    let norm_blocks = norm_factors.chunks_exact(8);
    // Both slices have equal length, so their leftovers line up index for index.
    let tf_tail = tf_blocks.remainder();
    let norm_tail = norm_blocks.remainder();

    let mut scores = Vec::with_capacity(term_freqs.len());

    // Vectorised part: eight elements per iteration.
    for (tf_block, norm_block) in tf_blocks.zip(norm_blocks) {
        let tf = f32x8::from(tf_block);
        let norm = f32x8::from(norm_block);
        let lanes = (tf * k1_plus_one_lanes) / (tf + k1_lanes * norm);
        scores.extend_from_slice(&lanes.to_array());
    }

    // Scalar part: whatever did not fill a complete block of eight.
    scores.extend(
        tf_tail
            .iter()
            .zip(norm_tail)
            .map(|(&tf, &norm)| (tf * k1_plus_one) / (tf + k1 * norm)),
    );

    scores
}
/// Multiplies three equally sized slices element-wise: `idf * tf * boost`.
///
/// Elements are processed eight at a time using `f32x8` lanes; the final
/// `len % 8` elements are multiplied one by one. The returned vector has one
/// entry per input element, in input order.
///
/// # Panics
///
/// Panics if `idf_scores` or `boosts` differ in length from `tf_scores`.
pub fn batch_bm25_final_score(
    tf_scores: &[f32],
    idf_scores: &[f32],
    boosts: &[f32],
) -> Vec<f32> {
    assert_eq!(tf_scores.len(), idf_scores.len());
    assert_eq!(tf_scores.len(), boosts.len());

    let tf_blocks = tf_scores.chunks_exact(8);
    let idf_blocks = idf_scores.chunks_exact(8);
    let boost_blocks = boosts.chunks_exact(8);
    // All three slices have equal length, so their leftovers line up index for index.
    let tf_tail = tf_blocks.remainder();
    let idf_tail = idf_blocks.remainder();
    let boost_tail = boost_blocks.remainder();

    let mut scores = Vec::with_capacity(tf_scores.len());

    // Vectorised part: eight elements per iteration.
    for ((tf_block, idf_block), boost_block) in tf_blocks.zip(idf_blocks).zip(boost_blocks) {
        let product = f32x8::from(idf_block) * f32x8::from(tf_block) * f32x8::from(boost_block);
        scores.extend_from_slice(&product.to_array());
    }

    // Scalar part: whatever did not fill a complete block of eight.
    scores.extend(
        tf_tail
            .iter()
            .zip(idf_tail)
            .zip(boost_tail)
            .map(|((&tf, &idf), &boost)| idf * tf * boost),
    );

    scores
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_optimized_lowercase_ascii() {
    // All-uppercase ASCII text, well over the 16-byte threshold of the ASCII route.
    assert_eq!(
        ascii::to_lowercase("HELLO WORLD THIS IS A TEST STRING FOR OPTIMIZATION"),
        "hello world this is a test string for optimization"
    );
}
#[test]
fn test_optimized_lowercase_mixed() {
    // Mixed case, digits and spaces: only the uppercase letters may change.
    assert_eq!(
        ascii::to_lowercase("Hello World 123 ABC def"),
        "hello world 123 abc def"
    );
}
#[test]
fn test_optimized_lowercase_empty() {
    // An empty input must produce an empty output.
    assert_eq!(ascii::to_lowercase(""), "");
}
#[test]
fn test_optimized_lowercase_short() {
    // Three bytes is below the 16-byte threshold checked by `to_lowercase`.
    assert_eq!(ascii::to_lowercase("ABC"), "abc");
}
#[test]
fn test_fallback_for_unicode() {
    // Non-ASCII input must agree with the standard library's Unicode lowercasing.
    let input = "Héllo Wörld";
    assert_eq!(ascii::to_lowercase(input), input.to_lowercase());
}
#[test]
fn test_simd_batch_bm25_tf() {
    // Ten elements: one full block of eight plus a two-element tail.
    let k1 = 1.2;
    let tfs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let norms = vec![1.0; 10];

    let result = numeric::batch_bm25_tf(&tfs, k1, &norms);

    assert_eq!(result.len(), 10);
    for (i, (&r, &tf)) in result.iter().zip(&tfs).enumerate() {
        let expected = (tf * (k1 + 1.0)) / (tf + k1 * 1.0);
        assert!((r - expected).abs() < 1e-5, "mismatch at index {i}");
    }
}
#[test]
fn test_simd_batch_bm25_tf_exact_multiple() {
    // Exactly eight elements: a single full block and an empty tail.
    let tfs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let norms = vec![0.5; 8];
    let k1 = 1.5;
    let result = numeric::batch_bm25_tf(&tfs, k1, &norms);
    assert_eq!(result.len(), 8);
    for (i, &r) in result.iter().enumerate() {
        let tf = tfs[i];
        let expected = (tf * (k1 + 1.0)) / (tf + k1 * 0.5);
        // Report the failing index, matching the other tolerance checks in this module.
        assert!((r - expected).abs() < 1e-5, "mismatch at index {i}");
    }
}
#[test]
fn test_simd_batch_bm25_final_score() {
    // Nine elements: one full block of eight plus a single-element tail.
    let tfs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let idfs = vec![0.5, 1.0, 1.5, 2.0, 0.5, 1.0, 1.5, 2.0, 0.5];
    let boosts = vec![1.0; 9];

    let result = numeric::batch_bm25_final_score(&tfs, &idfs, &boosts);

    assert_eq!(result.len(), 9);
    let mut i = 0;
    while i < result.len() {
        let expected = idfs[i] * tfs[i] * boosts[i];
        let r = result[i];
        assert!((r - expected).abs() < 1e-5, "mismatch at index {i}");
        i += 1;
    }
}
#[test]
fn test_simd_batch_bm25_tf_empty() {
    // Empty inputs must yield an empty result.
    assert!(numeric::batch_bm25_tf(&[], 1.2, &[]).is_empty());
}
#[test]
fn test_simd_batch_bm25_final_score_empty() {
    // Empty inputs must yield an empty result.
    assert!(numeric::batch_bm25_final_score(&[], &[], &[]).is_empty());
}
}