/// Precomputed per-term BM25 scorer.
///
/// Holds a borrowed 256-entry length-normalization table (indexed by the
/// encoded fieldnorm byte) and the constant `idf * (k1 + 1)`, so scoring a
/// posting reduces to one multiply and one divide.
///
/// Derives `Copy`/`Clone` (the payload is a shared reference plus an `f32`,
/// so copies are free) and `Debug` for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct Bm25BatchScorer<'a> {
    // Value added to `tf` in the score denominator, looked up by the
    // encoded fieldnorm byte.
    norm_table: &'a [f32; 256],
    // `idf * (k1 + 1)`, hoisted out of the per-posting formula.
    idf_k1p1: f32,
}
impl<'a> Bm25BatchScorer<'a> {
    /// Creates a scorer for one term.
    ///
    /// `norm_table` maps an encoded fieldnorm byte to the length-normalization
    /// value added to `tf` in the score denominator; `idf` and `k1` are folded
    /// into the single constant `idf * (k1 + 1)` up front.
    #[inline]
    pub fn new(norm_table: &'a [f32; 256], idf: f32, k1: f32) -> Self {
        Self {
            norm_table,
            idf_k1p1: idf * (k1 + 1.0),
        }
    }

    /// Scores one posting:
    /// `idf * (k1 + 1) * tf / (tf + norm_table[fieldnorm_byte])`.
    #[inline]
    pub fn score(&self, fieldnorm_byte: u8, tf: u16) -> f32 {
        let tf_f = tf as f32;
        let len_norm = self.norm_table[fieldnorm_byte as usize];
        self.idf_k1p1 * tf_f / (tf_f + len_norm)
    }

    /// Scores `tfs.len()` postings into `scores`, dispatching to the AVX2
    /// kernel when the CPU supports it and to the scalar loop otherwise.
    ///
    /// # Panics
    /// Panics if `fieldnorm_bytes.len() != tfs.len()` or if `scores` holds
    /// fewer than `tfs.len()` slots.
    pub fn batch_score(&self, fieldnorm_bytes: &[u8], tfs: &[u16], scores: &mut [f32]) {
        let n = tfs.len();
        assert_eq!(fieldnorm_bytes.len(), n);
        assert!(scores.len() >= n);
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 availability was just verified at runtime, and
                // the asserts above establish the length contract.
                unsafe {
                    self.batch_score_avx2(fieldnorm_bytes, tfs, scores, n);
                }
                return;
            }
        }
        self.batch_score_scalar(fieldnorm_bytes, tfs, scores, n);
    }

    /// Portable scalar path. Zipped iterators replace the indexed loop so no
    /// per-element bounds checks are paid on the three slices; the arithmetic
    /// is identical to `score`.
    #[inline]
    fn batch_score_scalar(&self, fieldnorm_bytes: &[u8], tfs: &[u16], scores: &mut [f32], n: usize) {
        let dst = &mut scores[..n];
        for ((out, &tf), &fnorm) in dst.iter_mut().zip(&tfs[..n]).zip(&fieldnorm_bytes[..n]) {
            let tf_f = tf as f32;
            let len_norm = self.norm_table[fnorm as usize];
            *out = self.idf_k1p1 * tf_f / (tf_f + len_norm);
        }
    }

    /// AVX2 kernel: scores 8 postings per iteration, then hands the tail to
    /// the scalar path.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and that
    /// `fieldnorm_bytes.len() == tfs.len() == n` with `scores.len() >= n`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_score_avx2(
        &self,
        fieldnorm_bytes: &[u8],
        tfs: &[u16],
        scores: &mut [f32],
        n: usize,
    ) {
        use std::arch::x86_64::*;
        // SAFETY: the avx2 target feature is guaranteed by the caller.
        let idf_k1p1_vec = unsafe { _mm256_set1_ps(self.idf_k1p1) };
        let table = self.norm_table;
        let chunks = n / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;
            let fn_slice = &fieldnorm_bytes[base..base + 8];
            // SAFETY: `base + 8 <= n <= tfs.len()` and `scores.len() >= n`, so
            // the unaligned 128-bit load and 256-bit store stay in bounds.
            unsafe {
                // There is no byte-indexed AVX2 gather, so the 8 norm-table
                // lookups happen in scalar code and are packed with set_ps
                // (which takes elements high-to-low, hence the reversed order).
                let norms = _mm256_set_ps(
                    table[fn_slice[7] as usize],
                    table[fn_slice[6] as usize],
                    table[fn_slice[5] as usize],
                    table[fn_slice[4] as usize],
                    table[fn_slice[3] as usize],
                    table[fn_slice[2] as usize],
                    table[fn_slice[1] as usize],
                    table[fn_slice[0] as usize],
                );
                // Widen 8 u16 term frequencies to i32, then convert to f32.
                let tfs_u16 = _mm_loadu_si128(tfs.as_ptr().add(base) as *const __m128i);
                let tfs_i32 = _mm256_cvtepu16_epi32(tfs_u16);
                let tfs_f32 = _mm256_cvtepi32_ps(tfs_i32);
                let denom = _mm256_add_ps(tfs_f32, norms);
                let numer = _mm256_mul_ps(idf_k1p1_vec, tfs_f32);
                let result = _mm256_div_ps(numer, denom);
                _mm256_storeu_ps(scores.as_mut_ptr().add(base), result);
            }
        }
        // Tail (< 8 postings): reuse the scalar path instead of duplicating
        // the formula here.
        let rem_base = chunks * 8;
        self.batch_score_scalar(
            &fieldnorm_bytes[rem_base..n],
            &tfs[rem_base..n],
            &mut scores[rem_base..n],
            n - rem_base,
        );
    }

    /// Scores `doc_id` and issues a prefetch for `next_doc_id`'s fieldnorm
    /// byte so it is likely in cache when it is scored next.
    ///
    /// # Panics
    /// Panics if `doc_id` indexes past `fieldnorm_bytes` (out-of-range
    /// `next_doc_id` values are ignored by the prefetch helper).
    #[inline]
    pub fn score_with_prefetch(
        &self,
        fieldnorm_bytes: &[u8],
        doc_id: u32,
        tf: u16,
        next_doc_id: Option<u32>,
    ) -> f32 {
        if let Some(next) = next_doc_id {
            prefetch_fieldnorm(fieldnorm_bytes, next);
        }
        let tf_f = tf as f32;
        let len_norm = self.norm_table[fieldnorm_bytes[doc_id as usize] as usize];
        self.idf_k1p1 * tf_f / (tf_f + len_norm)
    }

    /// Batch variant of [`Self::score_with_prefetch`]: while scoring posting
    /// `i`, the fieldnorm byte for posting `i + PREFETCH_DISTANCE` is
    /// prefetched to hide memory latency on scattered doc ids.
    ///
    /// # Panics
    /// Panics if `tfs.len() != doc_ids.len()`, if `scores` is shorter than
    /// `doc_ids`, or if any doc id indexes past `fieldnorm_bytes`.
    pub fn batch_score_with_prefetch(
        &self,
        fieldnorm_bytes: &[u8],
        doc_ids: &[u32],
        tfs: &[u16],
        scores: &mut [f32],
    ) {
        let n = doc_ids.len();
        assert_eq!(tfs.len(), n);
        assert!(scores.len() >= n);
        // Lookahead of 4 postings; tune if profiling suggests otherwise.
        const PREFETCH_DISTANCE: usize = 4;
        for i in 0..n {
            if i + PREFETCH_DISTANCE < n {
                prefetch_fieldnorm(fieldnorm_bytes, doc_ids[i + PREFETCH_DISTANCE]);
            }
            let tf_f = tfs[i] as f32;
            let len_norm = self.norm_table[fieldnorm_bytes[doc_ids[i] as usize] as usize];
            scores[i] = self.idf_k1p1 * tf_f / (tf_f + len_norm);
        }
    }
}
/// Issues a best-effort CPU prefetch hint for the fieldnorm byte of `doc_id`.
///
/// Out-of-range ids are silently ignored, so callers may speculatively
/// prefetch ahead of the postings they are scoring without their own bounds
/// check. On architectures with no prefetch intrinsic wired up here this is
/// a no-op. Never reads or writes the slice contents.
#[inline]
pub fn prefetch_fieldnorm(fieldnorm_bytes: &[u8], doc_id: u32) {
    let idx = doc_id as usize;
    if idx >= fieldnorm_bytes.len() {
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: `idx` was bounds-checked above, so the pointer lies inside
        // the slice. Prefetch itself has no memory-safety requirements.
        unsafe {
            // `_mm_prefetch` takes its strategy as a const generic parameter
            // on current rustc; the old two-argument call form no longer
            // compiles.
            std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(
                fieldnorm_bytes.as_ptr().add(idx) as *const i8,
            );
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // NOTE(review): `_prefetch` is nightly-only (`stdarch_aarch64_prefetch`);
        // confirm aarch64 builds use nightly and that this signature is still
        // the three-argument form on the pinned toolchain.
        // SAFETY: `idx` was bounds-checked above, so the pointer lies inside
        // the slice.
        unsafe {
            std::arch::aarch64::_prefetch(
                fieldnorm_bytes.as_ptr().add(idx) as *const i8,
                std::arch::aarch64::_PREFETCH_READ,
                std::arch::aarch64::_PREFETCH_LOCALITY3,
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scoring::FieldnormEncoder;

    /// Shared fixture: a norm table for avg_dl = 150, k1 = 1.2, b = 0.75,
    /// plus the idf/k1 values used to build scorers in every test.
    fn make_test_scorer() -> ([f32; 256], f32, f32) {
        let avg_dl = 150.0f32;
        let k1 = 1.2f32;
        let b = 0.75f32;
        let idf = 3.5f32;
        let norm_table = FieldnormEncoder::build_norm_table(avg_dl, k1, b);
        (norm_table, idf, k1)
    }

    #[test]
    fn test_single_score() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        // A zero term frequency contributes no score.
        assert_eq!(scorer.score(0, 0), 0.0);
        let s = scorer.score(FieldnormEncoder::encode(100), 1);
        assert!(s > 0.0 && s.is_finite());
        // Score grows with term frequency (saturating, but monotone).
        let s1 = scorer.score(FieldnormEncoder::encode(100), 1);
        let s5 = scorer.score(FieldnormEncoder::encode(100), 5);
        assert!(s5 > s1);
        // At equal tf, shorter documents score higher (length normalization).
        let short = scorer.score(FieldnormEncoder::encode(50), 2);
        let long = scorer.score(FieldnormEncoder::encode(500), 2);
        assert!(short > long);
    }

    /// The batch path (SIMD or scalar) must agree element-wise with `score`.
    #[test]
    fn test_batch_score_correctness() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let doc_lengths: Vec<u32> = (1..=100).map(|i| i * 10).collect();
        let fieldnorm_bytes: Vec<u8> = doc_lengths.iter()
            .map(|&l| FieldnormEncoder::encode(l))
            .collect();
        let tfs: Vec<u16> = (1..=100).map(|i| (i % 20 + 1) as u16).collect();
        let n = tfs.len();
        let expected: Vec<f32> = (0..n)
            .map(|i| scorer.score(fieldnorm_bytes[i], tfs[i]))
            .collect();
        let mut batch_scores = vec![0.0f32; n];
        scorer.batch_score(&fieldnorm_bytes, &tfs, &mut batch_scores);
        for i in 0..n {
            assert!(
                (batch_scores[i] - expected[i]).abs() < 1e-5,
                "batch[{i}]={} != expected[{i}]={}",
                batch_scores[i], expected[i]
            );
        }
    }

    /// Sizes 0..8 exercise the remainder-only path of the SIMD kernel.
    #[test]
    fn test_batch_score_small_arrays() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        for size in 0..8 {
            let fieldnorm_bytes: Vec<u8> = (0..size).map(|i| FieldnormEncoder::encode(i as u32 * 50 + 50)).collect();
            let tfs: Vec<u16> = (0..size).map(|i| (i + 1) as u16).collect();
            let mut scores = vec![0.0f32; size];
            scorer.batch_score(&fieldnorm_bytes, &tfs, &mut scores);
            for i in 0..size {
                let expected = scorer.score(fieldnorm_bytes[i], tfs[i]);
                assert!(
                    (scores[i] - expected).abs() < 1e-5,
                    "size={size}, scores[{i}]={} != expected={}",
                    scores[i], expected
                );
            }
        }
    }

    /// n = 64 exercises the pure-SIMD path with no scalar remainder.
    #[test]
    fn test_batch_score_exact_multiple_of_8() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let n = 64;
        let fieldnorm_bytes: Vec<u8> = (0..n).map(|i| FieldnormEncoder::encode(i as u32 * 10 + 10)).collect();
        let tfs: Vec<u16> = (0..n).map(|i| (i % 10 + 1) as u16).collect();
        let mut scores = vec![0.0f32; n];
        scorer.batch_score(&fieldnorm_bytes, &tfs, &mut scores);
        for i in 0..n {
            let expected = scorer.score(fieldnorm_bytes[i], tfs[i]);
            assert!((scores[i] - expected).abs() < 1e-5);
        }
    }

    /// Prefetching is a pure hint: results must be bit-identical to `score`.
    #[test]
    fn test_score_with_prefetch() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let fieldnorm_bytes: Vec<u8> = (0..1000)
            .map(|i| FieldnormEncoder::encode(i as u32 * 5 + 10))
            .collect();
        let s1 = scorer.score_with_prefetch(&fieldnorm_bytes, 42, 3, Some(100));
        let s2 = scorer.score(fieldnorm_bytes[42], 3);
        assert_eq!(s1, s2);
        // The None (no lookahead) case must also match.
        let s3 = scorer.score_with_prefetch(&fieldnorm_bytes, 42, 3, None);
        assert_eq!(s3, s2);
    }

    /// Batch-with-prefetch must agree with the single-posting scorer for
    /// scattered doc ids.
    #[test]
    fn test_batch_score_with_prefetch() {
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let n = 1000;
        let fieldnorm_bytes: Vec<u8> = (0..n)
            .map(|i| FieldnormEncoder::encode(i as u32 * 5 + 10))
            .collect();
        let doc_ids: Vec<u32> = (0..50).map(|i| i * 19 + 3).collect();
        let tfs: Vec<u16> = (0..50).map(|i| (i % 8 + 1) as u16).collect();
        let mut scores = vec![0.0f32; 50];
        scorer.batch_score_with_prefetch(&fieldnorm_bytes, &doc_ids, &tfs, &mut scores);
        for i in 0..50 {
            let expected = scorer.score(fieldnorm_bytes[doc_ids[i] as usize], tfs[i]);
            assert!(
                (scores[i] - expected).abs() < 1e-5,
                "prefetch_batch[{i}]={} != expected={}",
                scores[i], expected
            );
        }
    }

    /// Out-of-bounds doc ids (including u32::MAX) must be ignored, not panic.
    #[test]
    fn test_prefetch_fieldnorm_bounds() {
        let fieldnorm_bytes = vec![0u8; 100];
        prefetch_fieldnorm(&fieldnorm_bytes, 99);
        prefetch_fieldnorm(&fieldnorm_bytes, 100);
        prefetch_fieldnorm(&fieldnorm_bytes, u32::MAX);
    }

    /// Release-only smoke benchmark. Asserts only that the batch path is not
    /// drastically slower than scalar (>= 0.5x) to tolerate noisy CI hosts.
    #[test]
    #[cfg(not(debug_assertions))]
    fn test_batch_score_performance() {
        use std::time::Instant;
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let n = 1_000_000usize;
        let fieldnorm_bytes: Vec<u8> = (0..n)
            .map(|i| FieldnormEncoder::encode((i % 500 + 1) as u32))
            .collect();
        let tfs: Vec<u16> = (0..n).map(|i| (i % 20 + 1) as u16).collect();
        let mut scores = vec![0.0f32; n];
        // Warm-up pass so both timed loops see hot caches.
        scorer.batch_score(&fieldnorm_bytes, &tfs, &mut scores);
        let start = Instant::now();
        for _ in 0..10 {
            scorer.batch_score(&fieldnorm_bytes, &tfs, &mut scores);
        }
        let batch_time = start.elapsed() / 10;
        let start = Instant::now();
        for _ in 0..10 {
            scorer.batch_score_scalar(&fieldnorm_bytes, &tfs, &mut scores, n);
        }
        let scalar_time = start.elapsed() / 10;
        let speedup = scalar_time.as_nanos() as f64 / batch_time.as_nanos().max(1) as f64;
        eprintln!(
            "BM25 batch scoring ({n} postings): batch(SIMD)={:?}, scalar={:?}, speedup={:.1}x",
            batch_time, scalar_time, speedup
        );
        assert!(speedup >= 0.5, "batch scoring much slower than scalar: {speedup:.2}x");
        assert!(scores.iter().all(|&s| s > 0.0));
    }

    /// Release-only benchmark for the prefetching path; prints timings and
    /// sanity-checks the scores rather than asserting a speedup.
    #[test]
    #[cfg(not(debug_assertions))]
    fn test_prefetch_performance() {
        use std::time::Instant;
        let (norm_table, idf, k1) = make_test_scorer();
        let scorer = Bm25BatchScorer::new(&norm_table, idf, k1);
        let n_docs = 10_000_000usize;
        let fieldnorm_bytes: Vec<u8> = (0..n_docs)
            .map(|i| FieldnormEncoder::encode((i % 500 + 1) as u32))
            .collect();
        let n_postings = 100_000usize;
        // Build the strided id set and sort it in place: the unsorted order
        // is never used (the former clone was redundant), and `sort_unstable`
        // is the right choice for primitive keys.
        let mut doc_ids_sorted: Vec<u32> = (0..n_postings)
            .map(|i| (i as u64 * 97 % n_docs as u64) as u32)
            .collect();
        doc_ids_sorted.sort_unstable();
        let tfs: Vec<u16> = (0..n_postings).map(|i| (i % 8 + 1) as u16).collect();
        let mut scores = vec![0.0f32; n_postings];
        let start = Instant::now();
        for _ in 0..5 {
            scorer.batch_score_with_prefetch(
                &fieldnorm_bytes,
                &doc_ids_sorted,
                &tfs,
                &mut scores,
            );
        }
        let prefetch_time = start.elapsed() / 5;
        let start = Instant::now();
        for _ in 0..5 {
            for i in 0..n_postings {
                let tf_f = tfs[i] as f32;
                let len_norm = norm_table[fieldnorm_bytes[doc_ids_sorted[i] as usize] as usize];
                scores[i] = scorer.idf_k1p1 * tf_f / (tf_f + len_norm);
            }
        }
        let no_prefetch_time = start.elapsed() / 5;
        let speedup = no_prefetch_time.as_nanos() as f64 / prefetch_time.as_nanos().max(1) as f64;
        eprintln!(
            "Phrase scoring ({n_postings} postings, {n_docs} docs): \
             prefetch={:?}, no_prefetch={:?}, speedup={:.2}x",
            prefetch_time, no_prefetch_time, speedup
        );
        assert!(scores.iter().all(|&s| s > 0.0));
    }
}