use super::scoring::{ScoreCollector, ScoredDoc};
use crate::segment::{
BMP_SUPERBLOCK_SIZE, BmpIndex, accumulate_u4_weighted, block_term_postings,
compute_block_masks_4bit, find_dim_in_block_data,
};
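/// Best-effort prefetch of the cache line at `ptr` (a no-op on architectures
/// without an explicit prefetch instruction).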
#[inline(always)]
fn prefetch_read<T>(ptr: *const T) {
#[cfg(target_arch = "aarch64")]
unsafe {
std::arch::asm!(
"prfm pldl1keep, [{0}]",
in(reg) ptr,
options(nostack, preserves_flags)
);
}
#[cfg(target_arch = "x86_64")]
unsafe {
std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
{
let _ = ptr;
}
}
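/// Per-thread scratch buffers reused across BMP executions to avoid
/// reallocating on every query.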
#[derive(Default)]
struct BmpScratch {
sb_ubs: Vec<f32>,
sb_order: Vec<u32>,
sb_priorities: Vec<f32>,
sb_suffix_max: Vec<f32>,
local_block_ubs: Vec<f32>,
local_block_masks: Vec<u64>,
local_block_order: Vec<u32>,
phase1_local_block_ubs: Vec<f32>,
acc: Vec<u32>,
compact_sb_grid: Vec<u8>,
compact_grid: Vec<u8>,
}
impl BmpScratch {
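/// Grow the scratch buffers to fit `num_superblocks` superblocks, `sb_size`
/// blocks per superblock, and `block_size` slots per block. This only grows
/// buffers; it never shrinks them.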
fn ensure_capacity_sb(&mut self, num_superblocks: usize, sb_size: usize, block_size: usize) {
if self.sb_ubs.len() < num_superblocks {
self.sb_ubs.resize(num_superblocks, 0.0);
}
if self.sb_order.capacity() < num_superblocks {
self.sb_order.reserve(num_superblocks - self.sb_order.len());
}
if self.sb_priorities.len() < num_superblocks {
self.sb_priorities.resize(num_superblocks, 0.0);
}
if self.sb_suffix_max.len() < num_superblocks + 1 {
self.sb_suffix_max.resize(num_superblocks + 1, 0.0);
}
if self.local_block_ubs.len() < sb_size {
self.local_block_ubs.resize(sb_size, 0.0);
}
if self.local_block_masks.len() < sb_size {
self.local_block_masks.resize(sb_size, 0u64);
}
if self.local_block_order.len() < sb_size {
self.local_block_order.resize(sb_size, 0);
}
if self.phase1_local_block_ubs.len() < sb_size {
self.phase1_local_block_ubs.resize(sb_size, 0.0);
}
if self.acc.len() < block_size {
self.acc.resize(block_size, 0);
}
}
}
thread_local! {
static BMP_SCRATCH: std::cell::RefCell<BmpScratch> =
std::cell::RefCell::new(BmpScratch::default());
}
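/// Execute a BMP (block-max pruning) query and return the top-`k` scored
/// documents. `heap_factor` below 1.0 prunes more aggressively, trading
/// recall for speed; `max_superblocks == 0` places no limit on how many
/// superblocks are visited.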
pub fn execute_bmp(
index: &BmpIndex,
query_terms: &[(u32, f32)],
k: usize,
heap_factor: f32,
max_superblocks: usize,
) -> crate::Result<Vec<ScoredDoc>> {
execute_bmp_inner(index, query_terms, k, heap_factor, max_superblocks, None)
}
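/// Same as [`execute_bmp`], but only documents for which `predicate` returns
/// `true` may enter the result set.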
pub fn execute_bmp_filtered(
index: &BmpIndex,
query_terms: &[(u32, f32)],
k: usize,
heap_factor: f32,
max_superblocks: usize,
predicate: &dyn Fn(crate::DocId) -> bool,
) -> crate::Result<Vec<ScoredDoc>> {
execute_bmp_inner(
index,
query_terms,
k,
heap_factor,
max_superblocks,
Some(predicate),
)
}
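/// Shared implementation behind [`execute_bmp`] and [`execute_bmp_filtered`].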
fn execute_bmp_inner(
index: &BmpIndex,
query_terms: &[(u32, f32)],
k: usize,
heap_factor: f32,
max_superblocks: usize,
predicate: Option<&dyn Fn(crate::DocId) -> bool>,
) -> crate::Result<Vec<ScoredDoc>> {
if query_terms.is_empty() || index.num_blocks == 0 {
return Ok(Vec::new());
}
let alpha = heap_factor.clamp(0.01, 1.0);
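// Impacts are stored as 8-bit integers relative to `max_weight_scale`, so
// pre-scaling query weights by `max_weight_scale / 255` lets
// `scaled_weight * impact` stand in for the original weight product.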
let scale = index.max_weight_scale / 255.0;
let num_blocks = index.num_blocks as usize;
let block_size = index.bmp_block_size as usize;
let num_superblocks_total = index.num_superblocks as usize;
let dims = index.dims();
let mut query_info: Vec<(u32, usize, f32)> = Vec::with_capacity(query_terms.len());
for &(dim_id, weight) in query_terms {
if dim_id < dims {
let scaled = weight * scale;
query_info.push((dim_id, dim_id as usize, scaled));
}
}
if query_info.is_empty() {
return Ok(Vec::new());
}
query_info.sort_unstable_by_key(|&(dim_id, _, _)| dim_id);
let resolved: Vec<(usize, f32)> = query_info.iter().map(|&(_, idx, w)| (idx, w)).collect();
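// Quantize query weights to 14-bit integers so per-slot scores can be
// accumulated in u32 without overflow (at most 16383 * 255 per term, with at
// most 64 terms per the u64 block masks); `dequant` maps the sums back to f32.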
let max_scaled = query_info.iter().map(|q| q.2.abs()).fold(0.0f32, f32::max);
let (quant_scale, dequant) = if max_scaled > 0.0 {
(16383.0 / max_scaled, max_scaled / 16383.0)
} else {
(0.0, 0.0)
};
let query_by_dim_u16: Vec<(u32, u16)> = query_info
.iter()
.map(|&(dim_id, _, w)| {
(
dim_id,
(w.abs() * quant_scale).round().clamp(0.0, 16383.0) as u16,
)
})
.collect();
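// Two-phase scoring: for queries with enough terms, score the three heaviest
// dims first; the partial score plus the remaining upper bound lets weak
// blocks be rejected before the lighter dims are scored.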
const PHASE1_DIMS: usize = 3;
const MIN_DIMS_FOR_TWO_PHASE: usize = 6;
let two_phase_active = query_by_dim_u16.len() >= MIN_DIMS_FOR_TWO_PHASE;
let phase1_mask: u64 = if two_phase_active {
let mut weight_indices: Vec<(u16, usize)> = query_by_dim_u16
.iter()
.enumerate()
.map(|(i, &(_, w))| (w, i))
.collect();
weight_indices.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
weight_indices[..PHASE1_DIMS]
.iter()
.fold(0u64, |m, &(_, i)| m | (1u64 << i))
} else {
u64::MAX
};
let phase1_grid_indices: Vec<usize> = if two_phase_active {
(0..query_by_dim_u16.len())
.filter(|&i| phase1_mask & (1u64 << i) != 0)
.collect()
} else {
Vec::new()
};
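// Real documents may expand into several virtual docs (one per ordinal), so
// collect up to `k * ordinals_per_doc` candidates, capped at `10 * k`.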
let ordinals_per_doc = if index.num_real_docs() > 0 {
(index.num_virtual_docs as f32 / index.num_real_docs() as f32).ceil() as usize
} else {
1
};
let collector_k = (k * ordinals_per_doc).min(k * 10);
let t_start = std::time::Instant::now();
let result = BMP_SCRATCH.with(|cell| {
let scratch = &mut *cell.borrow_mut();
scratch.ensure_capacity_sb(
num_superblocks_total,
BMP_SUPERBLOCK_SIZE as usize,
block_size,
);
let prs = index.packed_row_size();
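// For small query footprints, gather just the query dims' rows of the
// superblock and block grids into compact scratch buffers so the hot loops
// scan contiguous memory; otherwise index the full grids by raw dim id.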
const COMPACT_GRID_MAX: usize = 128 * 1024;
let compact_sb_size = resolved.len() * num_superblocks_total;
let compact_grid_size = resolved.len() * prs;
let use_compact = compact_sb_size + compact_grid_size <= COMPACT_GRID_MAX;
let grid_dims: Vec<(usize, f32)>;
let sb_int_weights: Vec<(usize, u16)>;
let sb_grid_slice: &[u8];
let grid_slice: &[u8];
if use_compact {
let dim_indices: Vec<usize> = resolved.iter().map(|&(idx, _)| idx).collect();
index.extract_compact_grids(
&dim_indices,
&mut scratch.compact_sb_grid,
&mut scratch.compact_grid,
);
grid_dims = (0..resolved.len()).map(|i| (i, resolved[i].1)).collect();
sb_int_weights = (0..resolved.len())
.map(|i| (i, query_by_dim_u16[i].1))
.collect();
sb_grid_slice = &scratch.compact_sb_grid;
grid_slice = &scratch.compact_grid;
} else {
grid_dims = resolved.clone();
sb_int_weights = resolved
.iter()
.enumerate()
.map(|(i, &(idx, _))| (idx, query_by_dim_u16[i].1))
.collect();
sb_grid_slice = index.sb_grid_slice();
grid_slice = index.grid_slice();
if scratch.compact_grid.capacity() > COMPACT_GRID_MAX {
scratch.compact_sb_grid = Vec::new();
scratch.compact_grid = Vec::new();
}
}
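// Superblock upper bounds from the superblock grid, plus a visiting priority
// that slightly favors superblocks covered by more of the query dims.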
compute_sb_ubs_int(
sb_grid_slice,
num_superblocks_total,
&sb_int_weights,
dequant,
&mut scratch.sb_ubs,
);
let nqd = sb_int_weights.len();
for sb in 0..num_superblocks_total {
let base_ub = scratch.sb_ubs[sb];
if base_ub == 0.0 {
scratch.sb_priorities[sb] = 0.0;
continue;
}
let mut coverage = 0u32;
for &(local_idx, _) in &sb_int_weights {
if sb_grid_slice[local_idx * num_superblocks_total + sb] > 0 {
coverage += 1;
}
}
let cf = coverage as f32 / nqd as f32;
scratch.sb_priorities[sb] = base_ub * (1.0 + cf * 0.05);
}
sort_sb_desc_into(
&scratch.sb_priorities[..num_superblocks_total],
&mut scratch.sb_order,
);
if scratch.sb_order.is_empty() {
return Vec::new();
}
compute_suffix_max_ubs(
&scratch.sb_ubs,
&scratch.sb_order,
&mut scratch.sb_suffix_max,
);
let mut blocks_scored = 0u32;
let mut sbs_scored = 0u32;
let mut collector = ScoreCollector::new(collector_k);
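// Visit superblocks in priority order; stop once even the best remaining
// upper bound (the suffix max) cannot beat the current threshold.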
for (idx, &sb_id) in scratch.sb_order.iter().enumerate() {
if max_superblocks > 0 && idx >= max_superblocks {
break;
}
if collector.len() >= collector_k
&& scratch.sb_suffix_max[idx] * alpha <= collector.threshold()
{
break;
}
let sb_ub = scratch.sb_ubs[sb_id as usize];
if collector.len() >= collector_k && sb_ub * alpha <= collector.threshold() {
continue;
}
let block_start = sb_id as usize * BMP_SUPERBLOCK_SIZE as usize;
let block_end = (block_start + BMP_SUPERBLOCK_SIZE as usize).min(num_blocks);
let count = block_end - block_start;
{
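// Prefetch this superblock's block-data start offsets (one prefetch per
// 64-byte stride).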
let bds_base = index.block_data_starts_ptr(0);
for b in (block_start..block_end + 1).step_by(8) {
prefetch_read(unsafe { bds_base.add(b * 8) });
}
}
compute_block_ubs_compact(
grid_slice,
prs,
&grid_dims,
block_start,
block_end,
&mut scratch.local_block_ubs,
);
compute_block_masks_4bit(
grid_slice,
prs,
&grid_dims,
block_start,
block_end - block_start,
&mut scratch.local_block_masks,
);
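// Upper bounds restricted to the phase-1 dims, used later to bound what the
// remaining dims can still contribute per block.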
if two_phase_active {
let phase1_dims: Vec<(usize, f32)> = phase1_grid_indices
.iter()
.map(|&i| grid_dims[i])
.collect();
compute_block_ubs_compact(
grid_slice,
prs,
&phase1_dims,
block_start,
block_end,
&mut scratch.phase1_local_block_ubs,
);
}
sort_local_blocks_desc(
&scratch.local_block_ubs[..count],
&mut scratch.local_block_order,
);
score_superblock_blocks(
index,
block_start,
count,
&scratch.local_block_order,
&scratch.local_block_ubs,
&scratch.local_block_masks,
&query_by_dim_u16,
dequant,
alpha,
collector_k,
predicate,
&mut collector,
&mut blocks_scored,
&mut scratch.acc,
phase1_mask,
if two_phase_active {
Some(&scratch.phase1_local_block_ubs)
} else {
None
},
);
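// Warm the cache with the next superblock's block-data start offsets.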
if let Some(&next_sb) = scratch.sb_order.get(idx + 1) {
let next_start = next_sb as usize * BMP_SUPERBLOCK_SIZE as usize;
let next_end = (next_start + BMP_SUPERBLOCK_SIZE as usize).min(num_blocks);
let bds_base = index.block_data_starts_ptr(0);
for b in (next_start..next_end + 1).step_by(8) {
prefetch_read(unsafe { bds_base.add(b * 8) });
}
}
sbs_scored += 1;
}
let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
let threshold = collector.threshold();
if elapsed_ms > 500.0 {
log::warn!(
"slow BMP: {:.1}ms, sbs={}/{}, blocks={}/{}, returned={}, threshold={:.4}, alpha={:.2}",
elapsed_ms,
sbs_scored,
num_superblocks_total,
blocks_scored,
num_blocks,
collector.len(),
threshold,
alpha,
);
} else {
log::debug!(
"BMP execute: {:.1}ms, sbs={}/{}, blocks={}/{}, returned={}, threshold={:.4}, alpha={:.2}",
elapsed_ms,
sbs_scored,
num_superblocks_total,
blocks_scored,
num_blocks,
collector.len(),
threshold,
alpha,
);
}
collector_to_results(collector)
});
Ok(result)
}
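/// Maximum accumulated value among the slots marked in `touched`.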
#[inline(always)]
fn max_touched_acc(acc: &[u32], touched: &[u64; 4]) -> u32 {
let mut max_val = 0u32;
for word in 0..4 {
let mut bits = touched[word];
while bits != 0 {
let bit = bits.trailing_zeros() as usize;
bits &= bits - 1;
max_val = max_val.max(acc[word * 64 + bit]);
}
}
max_val
}
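/// Reset the accumulator slots marked in `touched` back to zero.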
#[inline(always)]
fn zero_touched_acc(acc: &mut [u32], touched: &[u64; 4]) {
for word in 0..4 {
let mut bits = touched[word];
while bits != 0 {
let bit = bits.trailing_zeros() as usize;
bits &= bits - 1;
acc[word * 64 + bit] = 0;
}
}
}
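/// Accumulate integer scores for one block: for every query dim enabled in
/// `block_mask`, locate the dim in the block's term list and add
/// `weight * impact` into `acc`, recording each written slot in `touched`.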
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn score_block_bsearch_int(
num_terms: u16,
dim_ptr: *const u8,
ps_ptr: *const u8,
post_ptr: *const u8,
query_by_dim_u16: &[(u32, u16)],
block_mask: u64,
acc: &mut [u32],
touched: &mut [u64; 4],
) {
for (q, &(dim_id, w)) in query_by_dim_u16.iter().enumerate() {
if block_mask & (1u64 << q) == 0 {
continue;
}
if let Some(local_term) = find_dim_in_block_data(dim_ptr, num_terms, dim_id) {
let postings = unsafe { block_term_postings(ps_ptr, post_ptr, local_term) };
for p in postings {
let slot = p.local_slot as usize;
unsafe {
*acc.get_unchecked_mut(slot) += w as u32 * p.impact as u32;
}
touched[slot / 64] |= 1u64 << (slot % 64);
}
}
}
}
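/// Score one superblock's blocks in descending upper-bound order, pruning
/// blocks whose upper bound (or phase-1 partial score plus remainder) cannot
/// beat the collector threshold, and push surviving docs into `collector`.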
#[allow(clippy::too_many_arguments)]
fn score_superblock_blocks(
index: &BmpIndex,
block_start: usize,
count: usize,
local_order: &[u32],
local_ubs: &[f32],
local_masks: &[u64],
query_by_dim_u16: &[(u32, u16)],
dequant: f32,
alpha: f32,
k: usize,
predicate: Option<&dyn Fn(crate::DocId) -> bool>,
collector: &mut ScoreCollector,
blocks_scored: &mut u32,
acc: &mut [u32],
phase1_mask: u64,
phase1_local_ubs: Option<&[f32]>,
) {
let block_size = index.bmp_block_size as usize;
let num_vdocs_total = index.num_virtual_docs as usize;
let two_phase = phase1_mask != u64::MAX && phase1_local_ubs.is_some();
for &li in local_order.iter().take(4) {
let li = li as usize;
if li >= count {
break;
}
prefetch_read(index.block_data_ptr((block_start + li) as u32));
}
for (order_idx, &local_idx) in local_order.iter().enumerate() {
if local_idx as usize >= count {
break;
}
let ub = local_ubs[local_idx as usize];
if collector.len() >= k && ub * alpha <= collector.threshold() {
break;
}
let block_id = (block_start + local_idx as usize) as u32;
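// Bitmask of slots whose document passes the filter predicate; skip the
// block entirely if nothing passes.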
let pred_mask: [u64; 4] = if let Some(pred) = predicate {
let base = block_id as usize * block_size;
let end = (base + block_size).min(num_vdocs_total);
let mut mask = [0u64; 4];
for slot in 0..(end - base) {
let doc_id = index.doc_id_for_virtual((base + slot) as u32);
if doc_id != u32::MAX && pred(doc_id) {
mask[slot / 64] |= 1u64 << (slot % 64);
}
}
if mask == [0u64; 4] {
continue;
}
mask
} else {
[u64::MAX; 4]
};
if order_idx + 1 < local_order.len() {
let next_local = local_order[order_idx + 1] as usize;
if next_local < count {
prefetch_read(index.block_data_ptr((block_start + next_local) as u32));
if order_idx + 2 < local_order.len() {
let next2_local = local_order[order_idx + 2] as usize;
if next2_local < count {
prefetch_read(index.block_data_ptr((block_start + next2_local) as u32));
}
}
}
}
let (num_terms, dim_ptr, ps_ptr, post_ptr) = index.parse_block(block_id);
let mask = local_masks[local_idx as usize];
let mut touched = [0u64; 4];
if num_terms > 0 {
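// Phase 1 (only once the heap is full): score just the heaviest dims. If
// the best partial score plus the upper bound of the remaining dims cannot
// beat the threshold, drop the block without scoring the lighter dims.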
if two_phase && collector.len() >= k {
score_block_bsearch_int(
num_terms,
dim_ptr,
ps_ptr,
post_ptr,
query_by_dim_u16,
mask & phase1_mask,
acc,
&mut touched,
);
let max_partial = max_touched_acc(acc, &touched) as f32 * dequant;
let phase1_ub = phase1_local_ubs.unwrap()[local_idx as usize];
let remaining_ub = (ub - phase1_ub).max(0.0);
if (max_partial + remaining_ub) * alpha <= collector.threshold() {
zero_touched_acc(acc, &touched);
*blocks_scored += 1;
continue;
}
score_block_bsearch_int(
num_terms,
dim_ptr,
ps_ptr,
post_ptr,
query_by_dim_u16,
mask & !phase1_mask,
acc,
&mut touched,
);
} else {
score_block_bsearch_int(
num_terms,
dim_ptr,
ps_ptr,
post_ptr,
query_by_dim_u16,
mask,
acc,
&mut touched,
);
}
}
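// Harvest: zero slots rejected by the predicate, convert surviving integer
// scores back to f32, map virtual ids to (doc_id, ordinal), and offer them
// to the collector. Accumulator slots are reset as they are read.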
let base = block_id as usize * block_size;
for word in 0..4 {
let mut reject = touched[word] & !pred_mask[word];
while reject != 0 {
let bit = reject.trailing_zeros() as usize;
reject &= reject - 1;
acc[word * 64 + bit] = 0;
}
let mut scan = touched[word] & pred_mask[word];
while scan != 0 {
let bit = scan.trailing_zeros() as usize;
scan &= scan - 1;
let i = word * 64 + bit;
let score_u32 = acc[i];
acc[i] = 0;
if score_u32 == 0 {
continue;
}
let virtual_id = base + i;
if virtual_id >= num_vdocs_total {
continue;
}
let (doc_id, ordinal) = index.virtual_to_doc(virtual_id as u32);
if doc_id == u32::MAX {
continue;
}
let score = score_u32 as f32 * dequant;
if collector.would_enter(score) {
collector.insert_with_ordinal(doc_id, score, ordinal);
}
}
}
*blocks_scored += 1;
}
}
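/// Indices of blocks with a non-zero upper bound, sorted by upper bound
/// descending.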
fn sort_local_blocks_desc(local_ubs: &[f32], out: &mut Vec<u32>) {
out.clear();
for (i, &ub) in local_ubs.iter().enumerate() {
if ub > 0.0 {
out.push(i as u32);
}
}
out.sort_unstable_by(|&a, &b| {
local_ubs[b as usize]
.partial_cmp(&local_ubs[a as usize])
.unwrap_or(std::cmp::Ordering::Equal)
});
}
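/// Indices of entries with a positive value, sorted by value descending.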
fn sort_sb_desc_into(values: &[f32], out: &mut Vec<u32>) {
out.clear();
for (i, &v) in values.iter().enumerate() {
if v > 0.0 {
out.push(i as u32);
}
}
out.sort_unstable_by(|&a, &b| {
values[b as usize]
.partial_cmp(&values[a as usize])
.unwrap_or(std::cmp::Ordering::Equal)
});
}
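/// `out[i]` becomes the maximum upper bound over `order[i..]`; the final
/// entry `out[order.len()]` is zero.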
fn compute_suffix_max_ubs(sb_ubs: &[f32], order: &[u32], out: &mut [f32]) {
let n = order.len();
out[n] = 0.0;
for i in (0..n).rev() {
let ub = sb_ubs[order[i] as usize];
out[i] = ub.max(out[i + 1]);
}
}
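/// Drain the collector into sorted `ScoredDoc` results.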
fn collector_to_results(collector: ScoreCollector) -> Vec<ScoredDoc> {
collector
.into_sorted_results()
.into_iter()
.map(|(doc_id, score, ordinal)| ScoredDoc {
doc_id,
score,
ordinal,
})
.collect()
}
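/// Integer dot product of the quantized query weights with each superblock's
/// grid column, converted back to f32 via `dequant`.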
#[inline]
fn compute_sb_ubs_int(
compact_sb_grid: &[u8],
nsb: usize,
int_weights: &[(usize, u16)],
dequant: f32,
out: &mut [f32],
) {
debug_assert!(out.len() >= nsb);
for sb in 0..nsb {
let mut acc: u32 = 0;
for &(local_idx, w) in int_weights {
let val = compact_sb_grid[local_idx * nsb + sb];
acc += w as u32 * val as u32;
}
out[sb] = acc as f32 * dequant;
}
}
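/// Per-block upper bounds for blocks in `[block_start, block_end)`,
/// accumulated from the 4-bit grid rows of the given dims.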
#[inline]
fn compute_block_ubs_compact(
compact_grid: &[u8],
prs: usize,
compact_dims: &[(usize, f32)],
block_start: usize,
block_end: usize,
out: &mut [f32],
) {
let count = block_end - block_start;
debug_assert!(out.len() >= count);
out[..count].fill(0.0);
for &(local_idx, weight) in compact_dims {
let row = &compact_grid[local_idx * prs..(local_idx + 1) * prs];
accumulate_u4_weighted(row, block_start, count, weight, &mut out[..count]);
}
}