use std::cmp::Ordering;
use std::collections::BinaryHeap;
use log::{debug, warn};
use crate::DocId;
/// Entry stored in the top-k score heap.
///
/// `score` and `doc_id` drive the heap ordering (see the `Ord` impl below);
/// `ordinal` is an opaque per-posting tag carried through unchanged into the
/// final results.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}
// Equality considers only score and doc_id — `ordinal` is deliberately
// excluded so it never influences heap ordering (consistent with `Ord`).
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}
// NOTE(review): `Eq` is asserted even though `score` is an f32 (`==` is not
// reflexive for NaN); this assumes scores fed into the heap are never NaN —
// confirm upstream. `Ord` below uses `total_cmp`, which IS total even for NaN.
impl Eq for HeapEntry {}
impl Ord for HeapEntry {
    // Reversed score comparison: `BinaryHeap` is a max-heap, so inverting the
    // score order makes the heap root the LOWEST score — i.e. the current
    // eviction candidate in a top-k collector. Ties on score fall back to
    // doc_id so the order is deterministic (larger doc_id pops first,
    // matching the ascending doc_id tie-break in `into_sorted_results`).
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .total_cmp(&self.score)
            .then(self.doc_id.cmp(&other.doc_id))
    }
}
impl PartialOrd for HeapEntry {
    // Delegates to `cmp` so `PartialOrd` and `Ord` can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Collects the top-`k` scored documents using a bounded heap.
pub struct ScoreCollector {
    // Behaves as a min-heap on score (via `HeapEntry`'s reversed `Ord`):
    // once full, the root is the k-th best score.
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
    // Cached copy of the heap root's score (0.0 while the heap holds fewer
    // than `k` entries) so the hot `threshold()` path avoids heap access.
    cached_threshold: f32,
}
impl ScoreCollector {
    /// Creates a collector that retains the `k` highest-scoring documents.
    ///
    /// Heap capacity is `k + 1` — one slot of slack for the push-then-pop
    /// eviction in [`Self::insert_with_ordinal`] — capped at one million
    /// entries so an enormous `k` cannot force a huge up-front allocation.
    pub fn new(k: usize) -> Self {
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
            cached_threshold: 0.0,
        }
    }

    /// Score a candidate must strictly exceed to enter the top-k.
    /// Returns 0.0 until `k` entries have been collected.
    #[inline]
    pub fn threshold(&self) -> f32 {
        self.cached_threshold
    }

    /// Refreshes `cached_threshold` from the heap root (the minimum retained
    /// score, thanks to `HeapEntry`'s reversed ordering).
    #[inline]
    fn update_threshold(&mut self) {
        self.cached_threshold = if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        };
    }

    /// Inserts a document with ordinal 0. Returns `true` if it was retained.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Inserts a `(doc, score, ordinal)` candidate, evicting the current
    /// minimum when the heap is full. Returns `true` if the candidate was
    /// retained.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        // Bug fix: a zero-capacity collector can never retain anything.
        // Without this guard, `score > cached_threshold` (0.0) would push
        // the entry, immediately pop it again, and still report `true`.
        if self.k == 0 {
            return false;
        }
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            // The threshold only becomes meaningful once the heap is full.
            if self.heap.len() == self.k {
                self.update_threshold();
            }
            true
        } else if score > self.cached_threshold {
            // Push first, then pop: the heap briefly holds k + 1 entries
            // (covered by the extra capacity slot in `new`) and the pop
            // removes the overall minimum — either the newcomer or the
            // previous root.
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop();
            self.update_threshold();
            true
        } else {
            false
        }
    }

    /// Cheap pre-check: would a document with `score` enter the top-k?
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        // Mirrors the `k == 0` guard in `insert_with_ordinal` so this
        // predicate never claims entry into a zero-capacity collector.
        self.k > 0 && (self.heap.len() < self.k || score > self.cached_threshold)
    }

    /// Number of entries currently held.
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Pre-fills an empty heap with `k` sentinel entries (doc_id `u32::MAX`)
    /// at `initial_threshold`, so threshold pruning is effective from the
    /// very first real insert (e.g. when a previous segment already produced
    /// a top-k floor). Sentinels are filtered out of the final results.
    /// No-op unless the threshold is positive and the collector is empty.
    pub fn seed_threshold(&mut self, initial_threshold: f32) {
        if initial_threshold > 0.0 && self.heap.is_empty() {
            for _ in 0..self.k {
                self.heap.push(HeapEntry {
                    doc_id: u32::MAX,
                    score: initial_threshold,
                    ordinal: 0,
                });
            }
            self.update_threshold();
        }
    }

    /// Consumes the collector, returning `(doc_id, score, ordinal)` tuples
    /// sorted by descending score (ascending doc_id on ties). Seed sentinels
    /// (doc_id `u32::MAX`) are dropped.
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let mut results: Vec<(DocId, f32, u16)> = self
            .heap
            .into_vec()
            .into_iter()
            .filter(|e| e.doc_id != u32::MAX)
            .map(|e| (e.doc_id, e.score, e.ordinal))
            .collect();
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
        results
    }
}
/// A single scored hit returned by the executor, in descending-score order.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}
/// Top-k disjunction executor implementing MaxScore-style pruning (with
/// per-block upper bounds) over a set of term cursors.
pub struct MaxScoreExecutor<'a> {
    // Sorted by ascending `max_score` in `new`, so the low-impact
    // ("non-essential") cursors form a prefix that pruning can skip.
    cursors: Vec<TermCursor<'a>>,
    // prefix_sums[i] = sum of max_score over cursors[0..=i]; bounds the total
    // possible contribution of the non-essential prefix.
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Reciprocal of the clamped heap factor; scales the collector threshold
    // used for pruning (1.0 = exact, > 1.0 = more aggressive/approximate).
    inv_heap_factor: f32,
    // Optional document filter; failing docs are skipped before scoring.
    predicate: Option<super::DocPredicate<'a>>,
}
/// A posting-list cursor for one query term (text BM25 or sparse vector),
/// decoding one block at a time and exposing score upper bounds for pruning.
pub(crate) struct TermCursor<'a> {
    /// Upper bound of this term's score over the entire posting list.
    pub max_score: f32,
    num_blocks: usize,
    block_idx: usize,
    // Decoded contents of the current block (parallel arrays).
    doc_ids: Vec<u32>,
    scores: Vec<f32>,
    ordinals: Vec<u16>,
    // Position within the current block's arrays.
    pos: usize,
    block_loaded: bool,
    exhausted: bool,
    // When set, sparse ordinals are decoded lazily in `ordinal_mut` instead
    // of eagerly at block-load time.
    lazy_ordinals: bool,
    ordinals_loaded: bool,
    // Raw sparse block retained for deferred ordinal decoding.
    current_sparse_block: Option<crate::structures::SparseBlock>,
    variant: CursorVariant<'a>,
}
/// Backend-specific state for a `TermCursor`.
enum CursorVariant<'a> {
    /// BM25-style scored text posting list.
    Text {
        list: crate::structures::BlockPostingList,
        idf: f32,
        // Precomputed pieces of the scoring formula used in
        // `compute_deferred_scores`:
        //   score = (idf_times_k1_plus_1 * tf) / (denom_tf_coeff * tf + denom_const)
        idf_times_k1_plus_1: f32,
        denom_tf_coeff: f32,
        denom_const: f32,
        // Scratch buffer for decoded term frequencies.
        tfs: Vec<u32>,
        // (block_offset, tf_start, count) of a block whose doc ids are
        // decoded but whose tfs/scores have not been materialized yet.
        deferred_tf: Option<(usize, usize, usize)>,
    },
    /// Pre-weighted sparse-vector posting list (scores = weight * query_weight).
    Sparse {
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        // First skip entry for this dimension, and the base offset of its
        // block data inside the sparse index.
        skip_start: usize,
        block_data_offset: u64,
    },
}
/// Shared body for `ensure_block_loaded{,_sync}`: decodes the current block
/// if it is not already loaded. `$load_block_fn` is the sync or async sparse
/// block loader and `$($aw)*` expands to `.await` (or nothing) after the
/// call. Evaluates to `Ok(true)` while positioned on a block, `Ok(false)`
/// once the cursor is exhausted.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                deferred_tf,
                ..
            } => {
                // Decode doc ids only; tf decoding (and therefore scoring)
                // is deferred until `ensure_scores` is actually needed.
                if let Some(state) = list.decode_block_doc_ids_only($self.block_idx, &mut $self.doc_ids) {
                    *deferred_tf = Some(state);
                    $self.scores.clear();
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        // Sparse scores are stored weights scaled by the
                        // query weight — no deferred scoring needed.
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Keep the raw block; ordinals are decoded on the
                            // first call to `ordinal_mut`.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
/// Shared body for `advance{,_sync}`: steps to the next posting, loading the
/// current block first if necessary. Yields `u32::MAX` once exhausted.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        $self.$ensure_fn() $($aw)* ?;
        // The ensure call above may have discovered end-of-list.
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
/// Shared body for `seek{,_sync}`: positions the cursor on the first doc id
/// `>= $target`. `seek_prepare` handles the in-block fast paths and locates
/// the candidate block; `seek_finish` scans the loaded block and returns
/// `true` when the target lies past it, in which case the next block is
/// loaded before reading `doc()`.
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(doc) = $self.seek_prepare($target) {
            return Ok(doc);
        }
        $self.$ensure_fn() $($aw)* ?;
        if $self.seek_finish($target) {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
impl<'a> TermCursor<'a> {
    /// Builds a cursor over a BM25-scored text posting list.
    ///
    /// `idf` is the term's inverse document frequency; `avg_field_len` the
    /// corpus average field length used in the score denominator.
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        // Guard against zero/invalid average length in the divisor below.
        let safe_avg = avg_field_len.max(1.0);
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            lazy_ordinals: false,
            // Text lists carry no ordinals, so they count as always loaded.
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                // Precomputed pieces of the scoring formula evaluated in
                // `compute_deferred_scores`:
                //   score = (idf*(K1+1)*tf) / (denom_tf_coeff*tf + denom_const)
                idf_times_k1_plus_1: idf * (super::BM25_K1 + 1.0),
                denom_tf_coeff: 1.0 + super::BM25_K1 * (super::BM25_B / safe_avg),
                denom_const: super::BM25_K1 * (1.0 - super::BM25_B),
                tfs: Vec::with_capacity(128),
                deferred_tf: None,
            },
        }
    }

    /// Builds a cursor over one dimension of a sparse index.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            // abs() keeps the bound usable for negative query weights
            // (assumes `global_max_weight` bounds |stored weight| — TODO
            // confirm in the index writer).
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            lazy_ordinals: false,
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    // First doc id of block `idx`, read from skip/block metadata without
    // decoding the block.
    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    // Last doc id of block `idx` (metadata only).
    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    /// Current doc id, or `u32::MAX` when exhausted. If the current block is
    /// not decoded yet, this reads the block's first doc from metadata.
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            debug_assert!(self.pos < self.doc_ids.len());
            // SAFETY: while `block_loaded` is true, `pos < doc_ids.len()` —
            // `advance_pos`/`seek_finish` clear `block_loaded` on overflow.
            unsafe { *self.doc_ids.get_unchecked(self.pos) }
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    /// Ordinal of the current posting; 0 when no block is loaded or the
    /// variant carries no ordinals (e.g. text lists, lazy-not-yet-decoded).
    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: same `pos` invariant as in `doc()`, and `ordinals` is
        // non-empty per the check above.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Like `ordinal`, but decodes lazily-deferred sparse ordinals on first
    /// use (see `lazy_ordinals` handling at block load).
    #[inline]
    pub fn ordinal_mut(&mut self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        if !self.ordinals_loaded {
            if let Some(ref block) = self.current_sparse_block {
                block.decode_ordinals_into(&mut self.ordinals);
            }
            self.ordinals_loaded = true;
        }
        if self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: same `pos` invariant as in `doc()`.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Score of the current posting. Callers must run `ensure_scores` first
    /// for text cursors (scores are decoded lazily there).
    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        debug_assert!(self.pos < self.scores.len());
        // SAFETY: caller contract — `ensure_scores` has populated `scores`
        // for the loaded block, so `pos < scores.len()`.
        unsafe { *self.scores.get_unchecked(self.pos) }
    }

    /// Materializes scores for the loaded block if they were deferred
    /// (text variant decodes tfs and computes scores here).
    #[inline]
    pub fn ensure_scores(&mut self) {
        if self.block_loaded && self.scores.is_empty() {
            self.compute_deferred_scores();
        }
    }

    /// Upper bound of this term's score within the current block only
    /// (tighter than `max_score`; used for block-max pruning).
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    /// Abandons the current block and moves to the next one without decoding
    /// it. Returns the next block's first doc id (from metadata), or
    /// `u32::MAX` if the list is exhausted.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    // Advances within the decoded block; on overflow, steps to the next
    // block (left undecoded — `doc()` then reads its first doc from
    // metadata) or marks the cursor exhausted.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    // Decodes the deferred term frequencies for the current text block and
    // evaluates score = (idf*(K1+1)*tf) / (coeff*tf + const) for each
    // posting. No-op for sparse cursors or if nothing is deferred.
    #[inline(never)]
    fn compute_deferred_scores(&mut self) {
        if let CursorVariant::Text {
            list,
            idf_times_k1_plus_1,
            denom_tf_coeff,
            denom_const,
            tfs,
            deferred_tf,
            ..
        } = &mut self.variant
            && let Some((block_offset, tf_start, count)) = deferred_tf.take()
        {
            list.decode_block_tfs_deferred(block_offset, tf_start, count, tfs);
            let num_scale = *idf_times_k1_plus_1;
            let d_tf = *denom_tf_coeff;
            let d_const = *denom_const;
            self.scores.clear();
            self.scores.resize(count, 0.0);
            for i in 0..count {
                // SAFETY: `i < count == tfs.len()` after the decode above,
                // and `scores` was just resized to `count`.
                let tf = unsafe { *tfs.get_unchecked(i) } as f32;
                let score = (num_scale * tf) / (d_tf * tf + d_const);
                unsafe {
                    *self.scores.get_unchecked_mut(i) = score;
                }
            }
        }
    }

    /// Async variant: decode the current block if needed (see macro).
    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct, .await)
    }

    /// Sync variant of `ensure_block_loaded`.
    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct_sync,)
    }

    /// Async variant: step to the next posting (see macro).
    pub async fn advance(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded, .await)
    }

    /// Sync variant of `advance`.
    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded_sync,)
    }

    /// Async variant: position on the first doc >= `target` (see macro).
    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded, target, .await)
    }

    /// Sync variant of `seek`.
    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded_sync, target,)
    }

    // First half of a seek, pure in-memory work. Returns `Some(doc)` when
    // the seek resolved without loading anything new (target found in the
    // loaded block, or cursor exhausted); returns `None` after positioning
    // `block_idx` on the candidate block, leaving the load to the caller.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            // Fast path: target still inside the currently decoded block —
            // SIMD lower-bound scan from the current position.
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            // Already at or past the target.
            if self.doc_ids[self.pos] >= target {
                return Some(self.doc());
            }
        }
        // Locate the first block whose last doc can contain the target.
        let lo = match &self.variant {
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            CursorVariant::Sparse { .. } => {
                // Binary search over skip entries by last_doc.
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    // Second half of a seek: scan the (now loaded) block for the target.
    // Returns `true` if the target lies past this block and the caller must
    // load the next block before reading `doc()`.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}
/// Shared body for `execute{,_sync}`: the MaxScore main loop. `$ensure`,
/// `$advance` and `$seek` name the sync or async cursor methods, and
/// `$($aw)*` expands to `.await` (or nothing) after each fallible call.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }
        // Counters for the debug / slow-query log lines at the end.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        // Scratch: (ordinal, score) contributions for the doc being scored.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
        let _bms_start = std::time::Instant::now();
        let inv_heap_factor = $self.inv_heap_factor;
        // Pruning threshold scaled by 1/heap_factor (approximate mode prunes
        // harder); the 1e-6 epsilon absorbs float rounding on comparisons.
        let mut adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
        loop {
            // Cursors [0, partition) are "non-essential": even their summed
            // upper bounds cannot beat the current threshold on their own.
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }
            // Find the minimum current doc among essential cursors and a
            // bitmask of which cursors sit on it (supports up to 64 cursors).
            let mut min_doc = u32::MAX;
            let mut at_min_mask = 0u64;
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                match doc.cmp(&min_doc) {
                    std::cmp::Ordering::Less => {
                        min_doc = doc;
                        at_min_mask = 1u64 << (i as u32);
                    }
                    std::cmp::Ordering::Equal => {
                        at_min_mask |= 1u64 << (i as u32);
                    }
                    _ => {}
                }
            }
            if min_doc == u32::MAX {
                break;
            }
            // Upper bound on what ALL non-essential cursors could contribute.
            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // Prune 1: sum of whole-list upper bounds of the cursors at
            // min_doc. If that can't beat the threshold, advance past the doc
            // without scoring anything.
            if $self.collector.len() >= $self.collector.k {
                let mut present_upper: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    present_upper += $self.cursors[i].max_score;
                    mask &= mask - 1;
                }
                if present_upper + non_essential_upper <= adjusted_threshold {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }
            // Prune 2: tighter per-block upper bounds; if they fail too,
            // skip the whole current block of every cursor at min_doc.
            if $self.collector.len() >= $self.collector.k {
                let mut block_max_sum: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    block_max_sum += $self.cursors[i].current_block_max_score();
                    mask &= mask - 1;
                }
                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].skip_to_next_block();
                        $self.cursors[i].$ensure() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }
            // Optional document filter: reject before any scoring work.
            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    continue;
                }
            }
            // Score min_doc: collect (ordinal, score) from every essential
            // cursor positioned on it. The inner `while` consumes multiple
            // postings for the same doc (e.g. multi-valued fields).
            ordinal_scores.clear();
            {
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    $self.cursors[i].$ensure() $($aw)* ?;
                    $self.cursors[i].ensure_scores();
                    while $self.cursors[i].doc() == min_doc {
                        let ord = $self.cursors[i].ordinal_mut();
                        let sc = $self.cursors[i].score();
                        ordinal_scores.push((ord, sc));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    mask &= mask - 1;
                }
            }
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            // Prune 3: even with every non-essential term at its maximum the
            // doc can't enter the top-k — skip the seek phase entirely.
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }
            // Probe non-essential cursors from highest max_score down,
            // bailing out once the remaining prefix bound can't help.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }
                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    $self.cursors[i].ensure_scores();
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        let ord = $self.cursors[i].ordinal_mut();
                        ordinal_scores.push((ord, s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }
            // Collect: one heap insert per distinct ordinal, summing scores
            // within each ordinal group.
            if ordinal_scores.len() == 1 {
                // Common single-contribution fast path: no grouping needed.
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                    adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                // Group by ordinal; a manual swap avoids the sort in the
                // common two-entry case.
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                        adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }
        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();
        let _bms_elapsed_ms = _bms_start.elapsed().as_millis() as u64;
        // Queries over 500ms get a warn-level line with extra detail.
        if _bms_elapsed_ms > 500 {
            warn!(
                "slow MaxScore: {}ms, cursors={}, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                n,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        } else {
            debug!(
                "MaxScoreExecutor: {}ms, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        }
        Ok(results)
    }};
}
impl<'a> MaxScoreExecutor<'a> {
    /// Builds an executor from prepared cursors.
    ///
    /// Cursors are sorted by ascending `max_score` so the lowest-impact
    /// terms form the "non-essential" prefix that MaxScore pruning can skip;
    /// `prefix_sums[i]` then bounds the total contribution of
    /// `cursors[0..=i]`.
    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
        for c in &mut cursors {
            c.lazy_ordinals = true;
        }
        // Fix: `total_cmp` is a total order over f32, unlike the previous
        // `partial_cmp(..).unwrap_or(Equal)`, which would silently leave the
        // cursors (and thus `prefix_sums`) mis-sorted if a `max_score` were
        // ever NaN. Matches the `total_cmp` usage elsewhere in this file.
        cursors.sort_by(|a, b| a.max_score.total_cmp(&b.max_score));
        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }
        // heap_factor < 1.0 requests approximate (looser) pruning; clamp
        // away from zero so the reciprocal stays finite.
        let clamped_heap_factor = heap_factor.clamp(0.01, 1.0);
        debug!(
            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            clamped_heap_factor
        );
        Self {
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            inv_heap_factor: 1.0 / clamped_heap_factor,
            predicate: None,
        }
    }

    /// Builds an executor over sparse-vector query terms. Dimensions absent
    /// from the index are silently dropped.
    pub fn sparse(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let cursors: Vec<TermCursor<'a>> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_start, skip_count, global_max, block_data_offset) =
                    sparse_index.get_skip_range_full(dim_id)?;
                Some(TermCursor::sparse(
                    sparse_index,
                    qw,
                    skip_start,
                    skip_count,
                    global_max,
                    block_data_offset,
                ))
            })
            .collect();
        Self::new(cursors, k, heap_factor)
    }

    /// Builds an executor over BM25 text posting lists (each paired with its
    /// term's idf). Uses heap factor 1.0, i.e. exact pruning.
    pub fn text(
        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
        avg_field_len: f32,
        k: usize,
    ) -> Self {
        let cursors: Vec<TermCursor<'a>> = posting_lists
            .into_iter()
            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
            .collect();
        Self::new(cursors, k, 1.0)
    }

    /// Index of the first "essential" cursor: every cursor before it cannot,
    /// even in aggregate, beat the current (factor-adjusted) threshold.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.inv_heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Installs a document-level filter applied before any scoring work.
    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
        self.predicate = Some(predicate);
        self
    }

    /// Seeds the collector threshold (e.g. from a previously searched
    /// segment) so pruning is effective from the first document.
    pub fn seed_threshold(&mut self, initial_threshold: f32) {
        self.collector.seed_threshold(initial_threshold);
    }

    /// Runs the query to completion using async block loads.
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
    }

    /// Synchronous twin of [`Self::execute`].
    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fills a k=3 collector past capacity: threshold tracks the k-th best
    // score and results come out in descending-score order.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    // Entry into a full collector requires STRICTLY exceeding the k-th score.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    // The reversed `Ord` turns BinaryHeap into a min-heap: lowest scores
    // pop first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
}