/// Advances `cursor` to the first index at or after its current position whose
/// element is `>= target`.
///
/// The search probes at exponentially growing distances from the cursor until
/// it either passes the end of `arr` or hits an element `>= target`, then
/// scans the final bracket with the widest available implementation (AVX2,
/// SSE2, or scalar).
///
/// The result is only meaningful when `arr` is sorted in non-decreasing
/// order; this is not checked.
///
/// # Returns
///
/// * `true`  – an element `>= target` exists; `cursor` is its index (left
///   untouched when the element under the cursor already qualifies).
/// * `false` – no such element; `cursor` is set to `arr.len()`. This is also
///   the outcome when `arr` is empty or `cursor` starts at or past the end.
#[inline]
pub fn simd_gallop_to(arr: &[u32], cursor: &mut usize, target: u32) -> bool {
    let len = arr.len();
    let start = *cursor;

    // Nothing to search: empty input or an exhausted cursor.
    if len == 0 || start >= len {
        *cursor = len;
        return false;
    }

    // Fast path: the cursor already sits on a qualifying element.
    if arr[start] >= target {
        return true;
    }

    // Gallop: `below` always indexes an element known to be `< target`,
    // `probe` is the next candidate, and the stride doubles every round.
    let mut below = start;
    let mut stride = 1usize;
    let mut probe = below + stride;
    while probe < len && arr[probe] < target {
        below = probe;
        stride = stride.saturating_mul(2);
        probe = below.saturating_add(stride);
    }

    // Make the bracket half-open and include the probe that stopped the loop.
    let end = (probe + 1).min(len);

    #[cfg(target_arch = "x86_64")]
    {
        if std::arch::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified at runtime on the line above,
            // and `below < end <= arr.len()` by construction.
            return unsafe { gallop_scan_avx2(arr, below, end, target, cursor) };
        }
        if std::arch::is_x86_feature_detected!("sse2") {
            // SAFETY: SSE2 support was verified at runtime on the line above,
            // and `below < end <= arr.len()` by construction.
            return unsafe { gallop_scan_sse2(arr, below, end, target, cursor) };
        }
    }

    gallop_scan_scalar(arr, below, end, target, cursor)
}
/// Builds a bitmask of the scores in a block that are strictly greater than
/// `theta`.
///
/// Bit `i` of the returned mask is set when `scores[i] > theta`; the second
/// tuple element is the number of set bits. `doc_ids` is only used to check
/// that it has the same length as `scores`.
///
/// Uses the AVX2 implementation when the CPU supports it, otherwise the
/// scalar one.
///
/// # Panics
///
/// Panics if `scores` holds more than 64 elements, or if `doc_ids` and
/// `scores` differ in length.
#[inline]
pub fn simd_block_filter(doc_ids: &[u32], scores: &[f32], theta: f32) -> (u64, usize) {
    let block_len = scores.len();
    assert!(block_len <= 64, "Block size must be <= 64 elements");
    assert_eq!(doc_ids.len(), block_len, "doc_ids and scores must have same length");

    if block_len == 0 {
        return (0, 0);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if std::arch::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified at runtime on the line above.
            return unsafe { block_filter_avx2(scores, theta) };
        }
    }

    block_filter_scalar(scores, theta)
}
/// AVX2 scan of `arr[lo..hi]` for the first element `>= target`.
///
/// Processes eight elements per iteration; whatever is left over (fewer than
/// eight elements) is handed to `gallop_scan_scalar`, which also sets
/// `cursor` to `arr.len()` and returns `false` when nothing qualifies.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and that `hi <= arr.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gallop_scan_avx2(
    arr: &[u32],
    lo: usize,
    hi: usize,
    target: u32,
    cursor: &mut usize,
) -> bool {
    use std::arch::x86_64::*;

    // AVX2 only offers a signed 32-bit compare; flipping the top bit of both
    // operands maps unsigned order onto signed order.
    const SIGN_FLIP: u32 = 0x8000_0000;
    const LANES: usize = 8;

    let base = arr.as_ptr();
    let mut idx = lo;

    // SAFETY: every load reads `LANES` elements starting at `idx`, and the
    // loop condition keeps `idx + LANES <= hi`, which the caller guarantees
    // is within `arr`. AVX2 availability is part of this function's contract.
    unsafe {
        let flip = _mm256_set1_epi32(SIGN_FLIP as i32);
        let needle = _mm256_set1_epi32((target ^ SIGN_FLIP) as i32);

        while idx + LANES <= hi {
            let chunk = _mm256_loadu_si256(base.add(idx).cast::<__m256i>());
            // Lane is all-ones where `element < target`.
            let below = _mm256_cmpgt_epi32(needle, _mm256_xor_si256(chunk, flip));
            // One bit per byte, i.e. four bits per 32-bit lane.
            let below_bits = _mm256_movemask_epi8(below) as u32;

            if below_bits != u32::MAX {
                // The run of leading "below" lanes ends at the first hit.
                let lane = below_bits.trailing_ones() as usize / 4;
                *cursor = idx + lane;
                return true;
            }
            idx += LANES;
        }
    }

    gallop_scan_scalar(arr, idx, hi, target, cursor)
}
/// SSE2 scan of `arr[lo..hi]` for the first element `>= target`.
///
/// Processes four elements per iteration; whatever is left over (fewer than
/// four elements) is handed to `gallop_scan_scalar`, which also sets
/// `cursor` to `arr.len()` and returns `false` when nothing qualifies.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 and that `hi <= arr.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn gallop_scan_sse2(
    arr: &[u32],
    lo: usize,
    hi: usize,
    target: u32,
    cursor: &mut usize,
) -> bool {
    use std::arch::x86_64::*;

    // SSE2 only offers a signed 32-bit compare; flipping the top bit of both
    // operands maps unsigned order onto signed order.
    const SIGN_FLIP: u32 = 0x8000_0000;
    const LANES: usize = 4;

    let base = arr.as_ptr();
    let mut idx = lo;

    // SAFETY: every load reads `LANES` elements starting at `idx`, and the
    // loop condition keeps `idx + LANES <= hi`, which the caller guarantees
    // is within `arr`. SSE2 availability is part of this function's contract.
    unsafe {
        let flip = _mm_set1_epi32(SIGN_FLIP as i32);
        let needle = _mm_set1_epi32((target ^ SIGN_FLIP) as i32);

        while idx + LANES <= hi {
            let chunk = _mm_loadu_si128(base.add(idx).cast::<__m128i>());
            // Lane is all-ones where `element < target`.
            let below = _mm_cmpgt_epi32(needle, _mm_xor_si128(chunk, flip));
            // One bit per byte, i.e. four bits per 32-bit lane (16 bits total).
            let below_bits = _mm_movemask_epi8(below) as u32;

            if below_bits != 0xFFFF {
                // The run of leading "below" lanes ends at the first hit.
                let lane = below_bits.trailing_ones() as usize / 4;
                *cursor = idx + lane;
                return true;
            }
            idx += LANES;
        }
    }

    gallop_scan_scalar(arr, idx, hi, target, cursor)
}
/// AVX2 implementation of the block filter: returns a bitmask with bit `i`
/// set when `scores[i] > theta`, together with the number of set bits.
///
/// Full groups of eight scores are compared with one vector instruction each;
/// the remaining (fewer than eight) scores are compared one at a time. NaN
/// scores never qualify in either path.
///
/// `scores` must hold at most 64 elements so that every index fits in the
/// `u64` mask; this is checked with a `debug_assert!`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn block_filter_avx2(scores: &[f32], theta: f32) -> (u64, usize) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    debug_assert!(
        scores.len() <= 64,
        "block_filter_avx2 supports at most 64 scores"
    );

    let mut mask = 0u64;
    let mut full_chunks = scores.chunks_exact(LANES);

    // SAFETY: AVX2 availability is part of this function's contract, and each
    // chunk yielded by `chunks_exact(LANES)` is exactly eight contiguous
    // `f32`s, which is what the unaligned 256-bit load reads.
    unsafe {
        let threshold = _mm256_set1_ps(theta);
        for (chunk_idx, chunk) in full_chunks.by_ref().enumerate() {
            let lanes = _mm256_loadu_ps(chunk.as_ptr());
            // Ordered greater-than: false whenever either operand is NaN.
            let above = _mm256_cmp_ps::<_CMP_GT_OS>(lanes, threshold);
            // One bit per lane, shifted to the chunk's position in the block.
            mask |= (_mm256_movemask_ps(above) as u64) << (chunk_idx * LANES);
        }
    }

    // Scalar tail for the elements that did not fill a whole vector.
    let tail = full_chunks.remainder();
    let tail_start = scores.len() - tail.len();
    for (offset, &score) in tail.iter().enumerate() {
        if score > theta {
            mask |= 1u64 << (tail_start + offset);
        }
    }

    (mask, mask.count_ones() as usize)
}
/// Scalar scan of `arr[lo..hi]` for the first element `>= target`.
///
/// On a hit, `cursor` is set to the element's index in `arr` and `true` is
/// returned. Otherwise `cursor` is set to `arr.len()` and `false` is
/// returned.
///
/// The range is bounds-checked: `hi` is clamped to `arr.len()`, and an empty
/// or inverted range (`lo >= hi`) is treated as "not found". This keeps the
/// function sound as a safe `fn` regardless of what the caller passes.
fn gallop_scan_scalar(
    arr: &[u32],
    lo: usize,
    hi: usize,
    target: u32,
    cursor: &mut usize,
) -> bool {
    let end = hi.min(arr.len());

    // `get` yields `None` for an inverted range instead of panicking.
    let hit = arr
        .get(lo..end)
        .and_then(|window| window.iter().position(|&value| value >= target));

    match hit {
        Some(offset) => {
            *cursor = lo + offset;
            true
        }
        None => {
            *cursor = arr.len();
            false
        }
    }
}
/// Scalar implementation of the block filter.
///
/// Returns a bitmask with bit `i` set when `scores[i] > theta`, together with
/// the number of set bits. NaN scores never qualify because the comparison
/// is a plain `>`.
fn block_filter_scalar(scores: &[f32], theta: f32) -> (u64, usize) {
    let selected = scores
        .iter()
        .enumerate()
        .filter(|&(_, &score)| score > theta)
        .fold(0u64, |bits, (index, _)| bits | (1u64 << index));

    (selected, selected.count_ones() as usize)
}
// Unit tests for `simd_gallop_to` and `simd_block_filter`.
//
// All tests go through the public entry points, so which implementation
// (AVX2 / SSE2 / scalar) actually runs depends on the machine executing them.
#[cfg(test)]
mod tests {
use super::*;
// Empty input: nothing can be found and the cursor stays at 0 (== len).
#[test]
fn test_gallop_empty_array() {
let arr: Vec<u32> = vec![];
let mut cursor = 0;
assert!(!simd_gallop_to(&arr, &mut cursor, 10));
assert_eq!(cursor, 0);
}
// A cursor that starts beyond the end is reset to `arr.len()`.
#[test]
fn test_gallop_cursor_past_end() {
let arr = vec![1, 2, 3, 4, 5];
let mut cursor = 10;
assert!(!simd_gallop_to(&arr, &mut cursor, 3));
assert_eq!(cursor, arr.len());
}
// The element under the cursor already satisfies `>= target`: the cursor
// must not move.
#[test]
fn test_gallop_already_satisfied() {
let arr = vec![1, 5, 10, 15, 20];
let mut cursor = 2;
assert!(simd_gallop_to(&arr, &mut cursor, 10));
assert_eq!(cursor, 2); }
// Target larger than every element: returns false with cursor == len.
#[test]
fn test_gallop_target_past_end() {
let arr = vec![1, 5, 10, 15, 20];
let mut cursor = 0;
assert!(!simd_gallop_to(&arr, &mut cursor, 100));
assert_eq!(cursor, arr.len());
}
// Target present in the array: the cursor lands on its index.
#[test]
fn test_gallop_exact_match() {
let arr = vec![1, 5, 10, 15, 20, 25, 30];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 15));
assert_eq!(cursor, 3); }
// Target absent: the cursor lands on the next larger element (15 for 12).
#[test]
fn test_gallop_between_elements() {
let arr = vec![1, 5, 10, 15, 20, 25, 30];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 12));
assert_eq!(cursor, 3); }
// One-element array with a target below, equal to, and above the element.
#[test]
fn test_gallop_single_element() {
let arr = vec![42];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 40));
assert_eq!(cursor, 0);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 42));
assert_eq!(cursor, 0);
cursor = 0;
assert!(!simd_gallop_to(&arr, &mut cursor, 50));
assert_eq!(cursor, 1);
}
// Three-element array, target between the second and third element.
#[test]
fn test_gallop_small_array() {
let arr = vec![1, 3, 5];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 4));
assert_eq!(cursor, 2); }
// 10_000 even numbers; every search restarts from cursor 0. The last case
// uses an odd target (1001), which must resolve to the next even value.
#[test]
fn test_gallop_large_array() {
let arr: Vec<u32> = (0..10000).map(|i| i * 2).collect();
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 100));
assert_eq!(cursor, 50);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 10000));
assert_eq!(cursor, 5000);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 19500));
assert_eq!(cursor, 9750);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 1001));
assert_eq!(cursor, 501); }
// The cursor is carried over between calls with increasing targets.
#[test]
fn test_gallop_sequential() {
let arr: Vec<u32> = (0..1000).map(|i| i * 10).collect();
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 500));
assert_eq!(cursor, 50);
assert!(simd_gallop_to(&arr, &mut cursor, 750));
assert_eq!(cursor, 75);
assert!(simd_gallop_to(&arr, &mut cursor, 1000));
assert_eq!(cursor, 100);
assert!(simd_gallop_to(&arr, &mut cursor, 9000));
assert_eq!(cursor, 900);
}
// Values close to and equal to `u32::MAX`, i.e. with the top bit set.
#[test]
fn test_gallop_u32_max_boundary() {
let arr = vec![
u32::MAX - 1000,
u32::MAX - 500,
u32::MAX - 100,
u32::MAX - 10,
u32::MAX,
];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, u32::MAX - 200));
assert_eq!(cursor, 2);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, u32::MAX));
assert_eq!(cursor, 4);
cursor = 4;
assert!(simd_gallop_to(&arr, &mut cursor, u32::MAX - 1));
assert_eq!(cursor, 4);
}
// Empty block: empty mask, zero count.
#[test]
fn test_filter_empty() {
let doc_ids: Vec<u32> = vec![];
let scores: Vec<f32> = vec![];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(mask, 0);
assert_eq!(count, 0);
}
// Every score exceeds theta: all four low bits set.
#[test]
fn test_filter_all_above() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![10.0, 20.0, 30.0, 40.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(mask, 0b1111);
assert_eq!(count, 4);
}
// No score exceeds theta: empty mask.
#[test]
fn test_filter_all_below() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![1.0, 2.0, 3.0, 4.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 10.0);
assert_eq!(mask, 0);
assert_eq!(count, 0);
}
// Mixed block: bits 0, 2, 4 and 5 qualify for theta = 7.0.
#[test]
fn test_filter_mixed() {
let doc_ids = vec![1, 2, 3, 4, 5, 6];
let scores = vec![10.0, 5.0, 15.0, 3.0, 20.0, 7.5];
let theta = 7.0;
let (mask, count) = simd_block_filter(&doc_ids, &scores, theta);
assert_eq!(count, 4);
assert_eq!(mask & (1 << 0), 1 << 0); assert_eq!(mask & (1 << 1), 0); assert_eq!(mask & (1 << 2), 1 << 2); assert_eq!(mask & (1 << 3), 0); assert_eq!(mask & (1 << 4), 1 << 4); assert_eq!(mask & (1 << 5), 1 << 5); }
// The comparison is strict: scores equal to theta are excluded.
#[test]
fn test_filter_exact_boundary() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![10.0, 5.0, 5.0, 15.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(count, 2);
assert_eq!(mask & (1 << 0), 1 << 0);
assert_eq!(mask & (1 << 1), 0);
assert_eq!(mask & (1 << 2), 0);
assert_eq!(mask & (1 << 3), 1 << 3);
}
// A NaN score must never qualify (bit 1 stays clear).
#[test]
fn test_filter_nan_scores() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![10.0, f32::NAN, 15.0, 5.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 7.0);
assert_eq!(count, 2);
assert_eq!(mask & (1 << 0), 1 << 0);
assert_eq!(mask & (1 << 1), 0); assert_eq!(mask & (1 << 2), 1 << 2);
assert_eq!(mask & (1 << 3), 0);
}
// One-element block, once above and once below theta.
#[test]
fn test_filter_single_element() {
let doc_ids = vec![1];
let scores = vec![10.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(mask, 1);
assert_eq!(count, 1);
let (mask2, count2) = simd_block_filter(&doc_ids, &scores, 15.0);
assert_eq!(mask2, 0);
assert_eq!(count2, 0);
}
// Maximum block size (64): the upper 32 scores exceed theta = 31.5.
#[test]
fn test_filter_full_64() {
let doc_ids: Vec<u32> = (0..64).collect();
let scores: Vec<f32> = (0..64).map(|i| i as f32).collect();
let (mask, count) = simd_block_filter(&doc_ids, &scores, 31.5);
assert_eq!(count, 32);
for i in 0..32 {
assert_eq!(mask & (1u64 << i), 0, "bit {} should not be set", i);
}
for i in 32..64 {
assert_eq!(mask & (1u64 << i), 1u64 << i, "bit {} should be set", i);
}
}
// Exactly eight scores with alternating below/above values.
#[test]
fn test_filter_8_elements() {
let doc_ids: Vec<u32> = (0..8).collect();
let scores = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(count, 4);
assert_eq!(mask, 0b10101010);
}
// More than 64 scores must trigger the block-size assertion.
#[test]
#[should_panic(expected = "Block size must be <= 64 elements")]
fn test_filter_too_large() {
let doc_ids: Vec<u32> = (0..65).collect();
let scores: Vec<f32> = vec![1.0; 65];
let _ = simd_block_filter(&doc_ids, &scores, 0.5);
}
// `doc_ids` and `scores` of different lengths must trigger the length
// assertion.
#[test]
#[should_panic(expected = "doc_ids and scores must have same length")]
fn test_filter_mismatched_lengths() {
let doc_ids = vec![1, 2, 3];
let scores = vec![1.0, 2.0];
let _ = simd_block_filter(&doc_ids, &scores, 0.5);
}
// Timing smoke test, compiled only when debug assertions are disabled.
// Runs 10_000 searches over 100_000 elements, each restarting from cursor 0,
// and requires the whole loop to finish in under 100 ms.
// NOTE(review): the wall-clock threshold depends on the host machine and
// its load, so this test may be flaky on slow or busy runners.
#[test]
#[cfg(not(debug_assertions))]
fn test_gallop_performance() {
use std::time::Instant;
let arr: Vec<u32> = (0..100_000).map(|i| i * 2).collect();
let queries: Vec<u32> = (0..10_000).map(|i| i * 20 + 7).collect();
let start = Instant::now();
let mut cursor = 0;
let mut found_count = 0;
for &target in &queries {
if simd_gallop_to(&arr, &mut cursor, target) {
found_count += 1;
}
cursor = 0; }
let elapsed = start.elapsed();
assert_eq!(found_count, 10_000);
assert!(
elapsed.as_millis() < 100,
"Performance test too slow: {:?}",
elapsed
);
println!(
"Gallop performance: {} queries in {:?} ({:.2} ns/query)",
queries.len(),
elapsed,
elapsed.as_nanos() as f64 / queries.len() as f64
);
}
// Timing smoke test, compiled only when debug assertions are disabled.
// Filters the same 64-element block 10_000 times and requires the whole loop
// to finish in under 100 ms.
// NOTE(review): the wall-clock threshold depends on the host machine and
// its load, so this test may be flaky on slow or busy runners.
#[test]
#[cfg(not(debug_assertions))]
fn test_filter_performance() {
use std::time::Instant;
let doc_ids: Vec<u32> = (0..64).collect();
let scores: Vec<f32> = (0..64).map(|i| i as f32 * 1.5).collect();
let theta = 50.0;
let start = Instant::now();
let mut total_count = 0;
for _ in 0..10_000 {
let (_, count) = simd_block_filter(&doc_ids, &scores, theta);
total_count += count;
}
let elapsed = start.elapsed();
assert!(total_count > 0);
assert!(
elapsed.as_millis() < 100,
"Performance test too slow: {:?}",
elapsed
);
println!(
"Filter performance: 10K calls in {:?} ({:.2} ns/call)",
elapsed,
elapsed.as_nanos() as f64 / 10_000.0
);
}
// Target 0 is satisfied by any first element, whether it is 0 or larger.
#[test]
fn test_gallop_target_zero() {
let arr = vec![0, 10, 20, 30];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 0));
assert_eq!(cursor, 0);
let arr2 = vec![5, 10, 15];
let mut cursor2 = 0;
assert!(simd_gallop_to(&arr2, &mut cursor2, 0));
assert_eq!(cursor2, 0); }
// With repeated values the cursor must land on the first occurrence.
#[test]
fn test_gallop_duplicates() {
let arr = vec![1, 5, 5, 5, 10, 10, 20];
let mut cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 5));
assert_eq!(cursor, 1);
cursor = 0;
assert!(simd_gallop_to(&arr, &mut cursor, 7));
assert_eq!(cursor, 4); }
// Array lengths 4, 7, 9, 8 and 16: at, just below and just above the
// 4- and 8-element widths used by the SSE2 and AVX2 scan loops.
#[test]
fn test_gallop_simd_boundary_sizes() {
let arr4 = vec![10, 20, 30, 40];
let mut cursor = 0;
assert!(simd_gallop_to(&arr4, &mut cursor, 25));
assert_eq!(cursor, 2);
let arr7 = vec![1, 2, 3, 4, 5, 6, 7];
cursor = 0;
assert!(simd_gallop_to(&arr7, &mut cursor, 5));
assert_eq!(cursor, 4);
let arr9 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
cursor = 0;
assert!(simd_gallop_to(&arr9, &mut cursor, 9));
assert_eq!(cursor, 8);
let arr8 = vec![10, 20, 30, 40, 50, 60, 70, 80];
cursor = 0;
assert!(simd_gallop_to(&arr8, &mut cursor, 45));
assert_eq!(cursor, 4);
let arr16: Vec<u32> = (1..=16).collect();
cursor = 0;
assert!(simd_gallop_to(&arr16, &mut cursor, 15));
assert_eq!(cursor, 14);
}
// Search starting from the middle of the array, continuing from the
// previous result, and finally running off the end.
#[test]
fn test_gallop_mid_array_advance() {
let arr: Vec<u32> = (0..100).map(|i| i * 10).collect();
let mut cursor = 50;
assert!(simd_gallop_to(&arr, &mut cursor, 750));
assert_eq!(cursor, 75);
assert!(simd_gallop_to(&arr, &mut cursor, 900));
assert_eq!(cursor, 90);
assert!(!simd_gallop_to(&arr, &mut cursor, 1000));
assert_eq!(cursor, arr.len());
}
// Infinite scores and infinite thresholds: +inf exceeds any finite theta,
// nothing exceeds theta = +inf, and -inf does not exceed theta = -inf.
#[test]
fn test_filter_infinity() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![f32::INFINITY, 10.0, f32::NEG_INFINITY, 5.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 100.0);
assert_eq!(count, 1);
assert_ne!(mask & (1 << 0), 0);
let (mask2, count2) = simd_block_filter(&doc_ids, &scores, f32::INFINITY);
assert_eq!(count2, 0);
assert_eq!(mask2, 0);
let (mask3, count3) = simd_block_filter(&doc_ids, &scores, f32::NEG_INFINITY);
assert_eq!(count3, 3); assert_ne!(mask3 & (1 << 0), 0); assert_ne!(mask3 & (1 << 1), 0); assert_eq!(mask3 & (1 << 2), 0); assert_ne!(mask3 & (1 << 3), 0); }
// Negative scores with a negative threshold.
#[test]
fn test_filter_negative_scores() {
let doc_ids = vec![1, 2, 3, 4];
let scores = vec![-5.0, -2.0, 3.0, -10.0];
let theta = -3.0;
let (mask, count) = simd_block_filter(&doc_ids, &scores, theta);
assert_eq!(count, 2);
assert_ne!(mask & (1 << 1), 0); assert_ne!(mask & (1 << 2), 0); assert_eq!(mask & (1 << 0), 0); assert_eq!(mask & (1 << 3), 0); }
// 63 scores (one short of the maximum); score 31 equals theta and must
// not qualify.
#[test]
fn test_filter_63_elements() {
let doc_ids: Vec<u32> = (0..63).collect();
let scores: Vec<f32> = (0..63).map(|i| i as f32).collect();
let (mask, count) = simd_block_filter(&doc_ids, &scores, 31.0);
assert_eq!(count, 31);
for i in 0..=31 {
assert_eq!(mask & (1u64 << i), 0, "score {} should not qualify", i);
}
for i in 32..63 {
assert_ne!(mask & (1u64 << i), 0, "score {} should qualify", i);
}
}
// Seven scores: one fewer than the eight-element AVX2 vector width.
#[test]
fn test_filter_7_elements() {
let doc_ids: Vec<u32> = (0..7).collect();
let scores = vec![1.0, 5.0, 3.0, 7.0, 2.0, 6.0, 4.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 4.0);
assert_eq!(count, 3);
assert_ne!(mask & (1 << 1), 0);
assert_ne!(mask & (1 << 3), 0);
assert_ne!(mask & (1 << 5), 0);
}
// Nine scores: one more than the eight-element AVX2 vector width.
#[test]
fn test_filter_9_elements() {
let doc_ids: Vec<u32> = (0..9).collect();
let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let (mask, count) = simd_block_filter(&doc_ids, &scores, 5.0);
assert_eq!(count, 4);
for i in 0..=4 {
assert_eq!(mask & (1u64 << i), 0);
}
for i in 5..9 {
assert_ne!(mask & (1u64 << i), 0);
}
}
// Cross-checks `simd_gallop_to` against a reference computed with
// `Iterator::position` for 200 targets, some of which exceed the largest
// element.
#[test]
fn test_gallop_consistency_simd_vs_scalar() {
let arr: Vec<u32> = (0..1000).map(|i| i * 7 + 3).collect();
let targets: Vec<u32> = (0..200).map(|i| i * 35).collect();
for &target in &targets {
let mut cursor_test = 0;
let result = simd_gallop_to(&arr, &mut cursor_test, target);
let expected = arr.iter().position(|&v| v >= target);
match expected {
Some(pos) => {
assert!(result, "should find target {}", target);
assert_eq!(cursor_test, pos, "wrong position for target {}", target);
}
None => {
assert!(!result, "should not find target {}", target);
assert_eq!(cursor_test, arr.len());
}
}
}
}
// Cross-checks `simd_block_filter` against a reference mask and count built
// with a plain loop, for thresholds from -10.0 to 14.5 in steps of 0.5.
#[test]
fn test_filter_consistency_simd_vs_scalar() {
let doc_ids: Vec<u32> = (0..64).collect();
let scores: Vec<f32> = (0..64).map(|i| (i as f32) * 0.5 - 10.0).collect();
for theta_i in -20i32..30 {
let theta = theta_i as f32 * 0.5;
let (mask, count) = simd_block_filter(&doc_ids, &scores, theta);
let mut expected_mask = 0u64;
let mut expected_count = 0;
for (i, &s) in scores.iter().enumerate() {
if s > theta {
expected_mask |= 1u64 << i;
expected_count += 1;
}
}
assert_eq!(mask, expected_mask, "mask mismatch for theta={}", theta);
assert_eq!(count, expected_count, "count mismatch for theta={}", theta);
}
}
}