use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fmt::Debug;
use crate::error::Result;
use crate::lexical::query::SearchHit;
/// Receives scored matches one at a time and exposes what has been gathered.
///
/// Implementors must be `Send` and `Debug`.
pub trait Collector: Send + Debug {
/// Offers one matching document, identified by `doc_id`, together with its `score`.
fn collect(&mut self, doc_id: u64, score: f32) -> Result<()>;
/// Returns the hits gathered so far as a freshly built vector.
fn results(&self) -> Vec<SearchHit>;
/// Returns the hit count maintained by the collector.
fn total_hits(&self) -> u64;
/// Reports whether the collector can still take additional hits.
///
/// The bounded collectors in this file return `true` only while they hold
/// fewer hits than their limit; the unbounded ones always return `true`.
fn needs_more(&self) -> bool;
/// Returns the score threshold the collector currently reports.
fn min_score(&self) -> f32;
/// Discards gathered hits and counters so the collector can be reused.
fn reset(&mut self);
}
/// Collector that retains at most `max_docs` hits in a binary heap of `ScoredDoc`s.
#[derive(Debug)]
pub struct TopDocsCollector {
// Upper bound on the number of entries kept in `hits`.
max_docs: usize,
// Documents scoring below this value are counted but never stored.
min_score: f32,
// Retained hits; `ScoredDoc`'s reversed ordering keeps the lowest score on top.
hits: BinaryHeap<ScoredDoc>,
// Number of `collect` calls since construction or the last `reset`.
total_hits: u64,
}
/// Heap entry pairing a document id with its relevance score.
#[derive(Debug, Clone)]
struct ScoredDoc {
// Identifier of the matched document.
doc_id: u64,
// Score passed to `collect` for this document.
score: f32,
}
/// Heap entry used when hits are ordered by a field value instead of by score.
#[derive(Debug, Clone)]
struct FieldScoredDoc {
// Identifier of the matched document.
doc_id: u64,
// Score passed to `collect`; carried through to the resulting `SearchHit`.
score: f32,
// Value of the sort field for this document (`FieldValue::Null` when unavailable).
field_value: crate::lexical::core::field::FieldValue,
// Sort direction copied from the owning collector; consulted by the `Ord` impl.
ascending: bool,
}
/// Collector that retains at most `max_docs` hits ordered by the value of one field.
#[derive(Debug)]
pub struct TopFieldCollector<'a> {
// Upper bound on the number of entries kept in `hits`.
max_docs: usize,
// Documents scoring below this value are counted but never stored.
min_score: f32,
// Name of the field whose doc value drives the ordering.
field_name: String,
// `true` for ascending order of the field value, `false` for descending.
ascending: bool,
// Retained hits together with their looked-up field values.
hits: BinaryHeap<FieldScoredDoc>,
// Number of `collect` calls since construction or the last `reset`.
total_hits: u64,
// Index reader used to fetch the sort field's value for each document.
reader: &'a dyn crate::lexical::reader::LexicalIndexReader,
}
impl<'a> TopFieldCollector<'a> {
    /// Creates a collector that keeps the best `max_docs` hits ordered by
    /// `field_name`, with a minimum score of `0.0`.
    ///
    /// `ascending` selects the sort direction and `reader` supplies the
    /// per-document field values.
    pub fn new(
        max_docs: usize,
        field_name: String,
        ascending: bool,
        reader: &'a dyn crate::lexical::reader::LexicalIndexReader,
    ) -> Self {
        Self::with_min_score(max_docs, 0.0, field_name, ascending, reader)
    }

    /// Same as [`TopFieldCollector::new`], but documents scoring below
    /// `min_score` are counted without being retained.
    pub fn with_min_score(
        max_docs: usize,
        min_score: f32,
        field_name: String,
        ascending: bool,
        reader: &'a dyn crate::lexical::reader::LexicalIndexReader,
    ) -> Self {
        TopFieldCollector {
            max_docs,
            min_score,
            field_name,
            ascending,
            hits: BinaryHeap::new(),
            total_hits: 0,
            reader,
        }
    }

    /// Looks up the sort field's value for `doc_id`.
    ///
    /// Deliberately best-effort: a reader error or a missing value both map
    /// to `FieldValue::Null` so that one unreadable document does not abort
    /// the whole collection pass.
    fn get_field_value(&self, doc_id: u64) -> crate::lexical::core::field::FieldValue {
        match self.reader.get_doc_value(&self.field_name, doc_id) {
            Ok(Some(value)) => value,
            _ => crate::lexical::core::field::FieldValue::Null,
        }
    }

    /// Returns `true` when `new_doc` belongs in the retained set: either
    /// there is spare capacity, or it ranks strictly before the worst hit
    /// currently kept (the heap top).
    fn should_collect(&self, new_doc: &FieldScoredDoc) -> bool {
        if self.hits.len() < self.max_docs {
            return true;
        }
        // A full heap with no top only happens for `max_docs == 0`; nothing
        // may be kept in that case.
        self.hits.peek().is_some_and(|worst| new_doc < worst)
    }
}

impl<'a> Collector for TopFieldCollector<'a> {
    /// Counts the document and, if it scores at least `min_score`, offers it
    /// to the bounded heap, evicting the worst-ranked retained hit when full.
    fn collect(&mut self, doc_id: u64, score: f32) -> Result<()> {
        self.total_hits += 1;
        // Skip the reader lookup entirely when the hit can never be kept.
        if self.max_docs == 0 || score < self.min_score {
            return Ok(());
        }
        let candidate = FieldScoredDoc {
            doc_id,
            score,
            field_value: self.get_field_value(doc_id),
            ascending: self.ascending,
        };
        if self.should_collect(&candidate) {
            if self.hits.len() >= self.max_docs {
                // The heap top is the worst-ranked hit, so this drops exactly
                // the entry the candidate displaces.
                self.hits.pop();
            }
            self.hits.push(candidate);
        }
        Ok(())
    }

    /// Returns the retained hits in rank order (best first); hits with equal
    /// field values are ordered by ascending `doc_id`.
    fn results(&self) -> Vec<SearchHit> {
        // `into_sorted_vec` yields ascending `Ord` order, which is rank order.
        self.hits
            .clone()
            .into_sorted_vec()
            .into_iter()
            .map(|doc| SearchHit {
                doc_id: doc.doc_id,
                score: doc.score,
                document: None,
            })
            .collect()
    }

    /// Number of documents offered since construction or the last `reset`,
    /// including those rejected by the score threshold.
    fn total_hits(&self) -> u64 {
        self.total_hits
    }

    /// `true` while fewer than `max_docs` hits are retained.
    fn needs_more(&self) -> bool {
        self.hits.len() < self.max_docs
    }

    /// The configured minimum score.
    fn min_score(&self) -> f32 {
        self.min_score
    }

    /// Drops all retained hits and zeroes the hit counter.
    fn reset(&mut self) {
        self.hits.clear();
        self.total_hits = 0;
    }
}

impl FieldScoredDoc {
    /// Direction-independent ordering of two field values.
    ///
    /// Values of the same variant compare by content, `Null` sorts after every
    /// other value, and values of different non-null variants (or
    /// incomparable floats) are treated as equal.
    fn natural_order(
        a: &crate::lexical::core::field::FieldValue,
        b: &crate::lexical::core::field::FieldValue,
    ) -> Ordering {
        use crate::lexical::core::field::FieldValue;
        match (a, b) {
            (FieldValue::Text(x), FieldValue::Text(y)) => x.cmp(y),
            (FieldValue::Int64(x), FieldValue::Int64(y)) => x.cmp(y),
            (FieldValue::Float64(x), FieldValue::Float64(y)) => {
                x.partial_cmp(y).unwrap_or(Ordering::Equal)
            }
            (FieldValue::Bool(x), FieldValue::Bool(y)) => x.cmp(y),
            (FieldValue::DateTime(x), FieldValue::DateTime(y)) => x.cmp(y),
            // Latitude first, longitude as the tie-break.
            (FieldValue::Geo(alat, alon), FieldValue::Geo(blat, blon)) => alat
                .partial_cmp(blat)
                .unwrap_or(Ordering::Equal)
                .then_with(|| alon.partial_cmp(blon).unwrap_or(Ordering::Equal)),
            (FieldValue::Bytes(x, _), FieldValue::Bytes(y, _)) => x.cmp(y),
            (FieldValue::Null, FieldValue::Null) => Ordering::Equal,
            (FieldValue::Null, _) => Ordering::Greater,
            (_, FieldValue::Null) => Ordering::Less,
            _ => Ordering::Equal,
        }
    }

    /// Field-value ordering in the requested direction: `Less` means `self`
    /// appears before `other` in the final result list. Descending order is
    /// the exact reverse of ascending order, so `Null` comes last when
    /// ascending and first when descending.
    fn rank_order(&self, other: &Self) -> Ordering {
        let natural = Self::natural_order(&self.field_value, &other.field_value);
        if self.ascending {
            natural
        } else {
            natural.reverse()
        }
    }
}

impl PartialEq for FieldScoredDoc {
    /// Equality is derived from `cmp` so that `==` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for FieldScoredDoc {}

impl PartialOrd for FieldScoredDoc {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FieldScoredDoc {
    /// Rank order: the greater element is the one that appears later in the
    /// results.
    ///
    /// `BinaryHeap` is a max-heap, so with this ordering `peek`/`pop` expose
    /// the worst-ranked retained hit, which is the one a bounded top-N
    /// collector has to evict. Equal field values fall back to ascending
    /// `doc_id`, so earlier documents win ties.
    fn cmp(&self, other: &Self) -> Ordering {
        self.rank_order(other)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}
impl PartialEq for ScoredDoc {
    /// Equality is derived from `cmp`, so it stays reflexive even for NaN
    /// scores and always agrees with the `Ord` impl.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ScoredDoc {}

impl PartialOrd for ScoredDoc {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ScoredDoc {
    /// Reversed score order: a higher score compares as *less*, so a
    /// max-heap of `ScoredDoc`s keeps the lowest-scoring entry on top.
    /// Equal scores are ordered by descending `doc_id`.
    ///
    /// `f32::total_cmp` is used instead of `partial_cmp` so that NaN scores
    /// have a fixed position rather than comparing equal to everything,
    /// which would break the total order `BinaryHeap` relies on.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .total_cmp(&self.score)
            .then_with(|| other.doc_id.cmp(&self.doc_id))
    }
}
impl TopDocsCollector {
pub fn new(max_docs: usize) -> Self {
TopDocsCollector {
max_docs,
min_score: 0.0,
hits: BinaryHeap::new(),
total_hits: 0,
}
}
pub fn with_min_score(max_docs: usize, min_score: f32) -> Self {
TopDocsCollector {
max_docs,
min_score,
hits: BinaryHeap::new(),
total_hits: 0,
}
}
pub fn max_docs(&self) -> usize {
self.max_docs
}
pub fn current_min_score(&self) -> f32 {
if self.hits.len() < self.max_docs {
self.min_score
} else {
self.hits
.peek()
.map(|doc| doc.score)
.unwrap_or(self.min_score)
}
}
}
impl Collector for TopDocsCollector {
    /// Counts the document and, when it scores at least `min_score`, stores
    /// it if there is room or if it strictly beats the score on top of the
    /// heap (in which case that top entry is removed first).
    fn collect(&mut self, doc_id: u64, score: f32) -> Result<()> {
        self.total_hits += 1;
        if score < self.min_score {
            return Ok(());
        }

        let candidate = ScoredDoc { doc_id, score };
        if self.hits.len() < self.max_docs {
            self.hits.push(candidate);
            return Ok(());
        }

        // At capacity: only a strictly higher score displaces the heap top.
        let displaces_top = self
            .hits
            .peek()
            .is_some_and(|worst| score > worst.score);
        if displaces_top {
            self.hits.pop();
            self.hits.push(candidate);
        }
        Ok(())
    }

    /// Returns the retained hits sorted by descending score.
    fn results(&self) -> Vec<SearchHit> {
        let mut ranked: Vec<SearchHit> = Vec::with_capacity(self.hits.len());
        for doc in self.hits.iter() {
            ranked.push(SearchHit {
                doc_id: doc.doc_id,
                score: doc.score,
                document: None,
            });
        }
        ranked.sort_by(|lhs, rhs| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(Ordering::Equal)
        });
        ranked
    }

    /// Number of documents offered since construction or the last `reset`,
    /// including those rejected by the score threshold.
    fn total_hits(&self) -> u64 {
        self.total_hits
    }

    /// `true` while fewer than `max_docs` hits are retained.
    fn needs_more(&self) -> bool {
        self.hits.len() < self.max_docs
    }

    /// Delegates to [`TopDocsCollector::current_min_score`].
    fn min_score(&self) -> f32 {
        self.current_min_score()
    }

    /// Drops all retained hits and zeroes the hit counter.
    fn reset(&mut self) {
        self.total_hits = 0;
        self.hits.clear();
    }
}
/// Collector that keeps only a running count and a minimum-score threshold;
/// it stores no hits.
#[derive(Debug)]
pub struct CountCollector {
    // Running count of accepted documents.
    count: u64,
    // Score threshold associated with this collector.
    min_score: f32,
}

impl CountCollector {
    /// Creates a counter starting at zero with a minimum score of `0.0`.
    pub fn new() -> Self {
        Self::with_min_score(0.0)
    }

    /// Creates a counter starting at zero with the given `min_score`.
    pub fn with_min_score(min_score: f32) -> Self {
        Self {
            count: 0,
            min_score,
        }
    }

    /// Returns the current count.
    pub fn count(&self) -> u64 {
        self.count
    }
}

impl Default for CountCollector {
    /// Equivalent to [`CountCollector::new`].
    fn default() -> Self {
        Self::with_min_score(0.0)
    }
}
impl Collector for CountCollector {
    /// Increments the count when `score` is at least the minimum score;
    /// the document id is ignored.
    fn collect(&mut self, _doc_id: u64, score: f32) -> Result<()> {
        let accepted = score >= self.min_score;
        self.count += u64::from(accepted);
        Ok(())
    }

    /// Always empty: this collector stores no hits.
    fn results(&self) -> Vec<SearchHit> {
        vec![]
    }

    /// Same value as the running count.
    fn total_hits(&self) -> u64 {
        self.count
    }

    /// Always `true`: counting has no capacity limit.
    fn needs_more(&self) -> bool {
        true
    }

    /// The configured minimum score.
    fn min_score(&self) -> f32 {
        self.min_score
    }

    /// Sets the count back to zero.
    fn reset(&mut self) {
        self.count = 0;
    }
}
/// Collector holding an unbounded list of hits and a minimum-score threshold.
#[derive(Debug)]
pub struct AllDocsCollector {
    // Stored hits.
    hits: Vec<SearchHit>,
    // Score threshold associated with this collector.
    min_score: f32,
}

impl AllDocsCollector {
    /// Creates an empty collector with a minimum score of `0.0`.
    pub fn new() -> Self {
        Self::with_min_score(0.0)
    }

    /// Creates an empty collector with the given `min_score`.
    pub fn with_min_score(min_score: f32) -> Self {
        Self {
            hits: Vec::new(),
            min_score,
        }
    }
}

impl Default for AllDocsCollector {
    /// Equivalent to [`AllDocsCollector::new`].
    fn default() -> Self {
        Self::with_min_score(0.0)
    }
}
impl Collector for AllDocsCollector {
    /// Stores the document as a hit when `score` is at least the minimum
    /// score; lower-scoring documents are dropped.
    fn collect(&mut self, doc_id: u64, score: f32) -> Result<()> {
        let accepted = score >= self.min_score;
        if accepted {
            let hit = SearchHit {
                doc_id,
                score,
                document: None,
            };
            self.hits.push(hit);
        }
        Ok(())
    }

    /// Returns a copy of the stored hits sorted by descending score; hits
    /// with equal scores keep their insertion order.
    fn results(&self) -> Vec<SearchHit> {
        let mut ranked = self.hits.to_vec();
        ranked.sort_by(|lhs, rhs| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(Ordering::Equal)
        });
        ranked
    }

    /// Number of stored hits.
    fn total_hits(&self) -> u64 {
        self.hits.len() as u64
    }

    /// Always `true`: there is no capacity limit.
    fn needs_more(&self) -> bool {
        true
    }

    /// The configured minimum score.
    fn min_score(&self) -> f32 {
        self.min_score
    }

    /// Removes every stored hit.
    fn reset(&mut self) {
        self.hits.clear();
    }
}
// Unit tests for the score-based collectors and for `ScoredDoc` ordering.
#[cfg(test)]
mod tests {
use super::*;
// A collector limited to three hits keeps the three best scores and still
// counts every offered document.
#[test]
fn test_top_docs_collector() {
let mut collector = TopDocsCollector::new(3);
assert_eq!(collector.max_docs(), 3);
assert_eq!(collector.total_hits(), 0);
assert!(collector.needs_more());
collector.collect(1, 0.5).unwrap();
collector.collect(2, 0.8).unwrap();
collector.collect(3, 0.3).unwrap();
assert_eq!(collector.total_hits(), 3);
assert!(!collector.needs_more());
// A fourth, higher-scoring document displaces the lowest retained score.
collector.collect(4, 0.9).unwrap();
assert_eq!(collector.total_hits(), 4);
let results = collector.results();
assert_eq!(results.len(), 3);
assert!(results[0].score >= results[1].score);
assert!(results[1].score >= results[2].score);
assert_eq!(results[0].doc_id, 4);
assert_eq!(results[0].score, 0.9);
}
// Documents below the minimum score are counted in `total_hits` but are
// absent from the results.
#[test]
fn test_top_docs_collector_with_min_score() {
let mut collector = TopDocsCollector::with_min_score(3, 0.5);
assert_eq!(collector.min_score(), 0.5);
collector.collect(1, 0.3).unwrap(); collector.collect(2, 0.8).unwrap(); collector.collect(3, 0.6).unwrap();
assert_eq!(collector.total_hits(), 3);
let results = collector.results();
assert_eq!(results.len(), 2);
assert!(!results.iter().any(|hit| hit.score == 0.3));
}
// The counting collector tallies documents and never returns hits.
#[test]
fn test_count_collector() {
let mut collector = CountCollector::new();
assert_eq!(collector.count(), 0);
assert_eq!(collector.total_hits(), 0);
assert!(collector.needs_more());
collector.collect(1, 0.5).unwrap();
collector.collect(2, 0.8).unwrap();
collector.collect(3, 0.3).unwrap();
assert_eq!(collector.count(), 3);
assert_eq!(collector.total_hits(), 3);
let results = collector.results();
assert!(results.is_empty());
}
// With a minimum score, only documents at or above it are counted, and
// `total_hits` reports the same number as `count`.
#[test]
fn test_count_collector_with_min_score() {
let mut collector = CountCollector::with_min_score(0.5);
collector.collect(1, 0.3).unwrap(); collector.collect(2, 0.8).unwrap(); collector.collect(3, 0.6).unwrap();
assert_eq!(collector.count(), 2); assert_eq!(collector.total_hits(), 2);
}
// The unbounded collector returns every document, sorted by descending score.
#[test]
fn test_all_docs_collector() {
let mut collector = AllDocsCollector::new();
assert_eq!(collector.total_hits(), 0);
assert!(collector.needs_more());
collector.collect(1, 0.5).unwrap();
collector.collect(2, 0.8).unwrap();
collector.collect(3, 0.3).unwrap();
assert_eq!(collector.total_hits(), 3);
let results = collector.results();
assert_eq!(results.len(), 3);
assert!(results[0].score >= results[1].score);
assert!(results[1].score >= results[2].score);
assert_eq!(results[0].doc_id, 2); assert_eq!(results[1].doc_id, 1); assert_eq!(results[2].doc_id, 3); }
// With a minimum score, low-scoring documents are neither stored nor counted.
#[test]
fn test_all_docs_collector_with_min_score() {
let mut collector = AllDocsCollector::with_min_score(0.5);
collector.collect(1, 0.3).unwrap(); collector.collect(2, 0.8).unwrap(); collector.collect(3, 0.6).unwrap();
assert_eq!(collector.total_hits(), 2);
let results = collector.results();
assert_eq!(results.len(), 2);
assert!(!results.iter().any(|hit| hit.score == 0.3));
}
// `reset` clears both the retained hits and the hit counter.
#[test]
fn test_collector_reset() {
let mut collector = TopDocsCollector::new(3);
collector.collect(1, 0.5).unwrap();
collector.collect(2, 0.8).unwrap();
assert_eq!(collector.total_hits(), 2);
assert_eq!(collector.results().len(), 2);
collector.reset();
assert_eq!(collector.total_hits(), 0);
assert_eq!(collector.results().len(), 0);
assert!(collector.needs_more());
}
// `ScoredDoc` orders by reversed score (higher score compares as less), with
// ties broken by descending doc id, so a max-heap exposes the lowest score.
#[test]
fn test_scored_doc_ordering() {
let doc1 = ScoredDoc {
doc_id: 1,
score: 0.5,
};
let doc2 = ScoredDoc {
doc_id: 2,
score: 0.8,
};
let doc3 = ScoredDoc {
doc_id: 3,
score: 0.8,
};
assert!(doc2 < doc1);
assert!(doc3 < doc2);
let mut heap = BinaryHeap::new();
heap.push(doc1);
heap.push(doc2);
heap.push(doc3);
assert_eq!(heap.peek().unwrap().score, 0.5);
}
}