use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::core::{DocId, SegmentId};
/// One hit produced by score-ranked collection.
#[derive(Clone, Debug)]
pub(crate) struct ScoreDoc {
    /// Document id; the collectors in this file compare it without consulting `segment_id`.
    pub(crate) doc_id: DocId,
    /// Segment the document belongs to.
    pub(crate) segment_id: SegmentId,
    /// Relevance score; higher scores rank first in `TopDocsCollector`.
    pub(crate) score: f32,
}
/// Heap adapter that inverts the ranking of [`ScoreDoc`] so that `BinaryHeap` (a max-heap)
/// keeps the *least* competitive retained hit on top, where it can be replaced in O(log k).
///
/// Under this ordering "greater" means "ranks worse": a lower score or, on a score tie, a
/// larger doc id. This mirrors the final order produced by
/// `TopDocsCollector::into_sorted_results` (score descending, doc id ascending), so the hit
/// evicted first is always the one that would be listed last.
struct MinScoreDoc(ScoreDoc);

impl PartialEq for MinScoreDoc {
    fn eq(&self, other: &Self) -> bool {
        // Equality must agree with `cmp`; comparing scores alone would call two hits with the
        // same score but different doc ids "equal" while `cmp` orders them.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinScoreDoc {}

impl PartialOrd for MinScoreDoc {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinScoreDoc {
    fn cmp(&self, other: &Self) -> Ordering {
        // Scores are compared reversed (lower score = greater = closer to the heap top).
        // Incomparable scores (NaN) are treated as ties.
        other
            .0
            .score
            .partial_cmp(&self.0.score)
            .unwrap_or(Ordering::Equal)
            // On a score tie the larger doc id ranks worse, so it must compare greater and be
            // evicted before the smaller id that `into_sorted_results` prefers.
            .then_with(|| self.0.doc_id.cmp(&other.0.doc_id))
    }
}
/// Collects the `limit` best hits by score using a bounded heap.
///
/// Hits rank by score descending, with ties broken by ascending doc id (the segment id is not
/// consulted). Memory is O(`limit`) and each `collect` is O(log `limit`).
pub struct TopDocsCollector {
    /// At most `limit` hits; the heap top is the worst retained hit (see `MinScoreDoc`).
    heap: BinaryHeap<MinScoreDoc>,
    /// Maximum number of hits retained.
    limit: usize,
}

impl TopDocsCollector {
    /// Creates a collector that keeps at most `limit` hits; `limit == 0` retains nothing.
    pub fn new(limit: usize) -> Self {
        Self {
            // The heap never grows past `limit`: once full, hits replace the top in place.
            heap: BinaryHeap::with_capacity(limit),
            limit,
        }
    }

    /// Offers a hit.
    ///
    /// While the collector is not full every hit is kept. Once full, the hit replaces the
    /// current worst hit only if it outranks it: a higher score, or an equal score and a
    /// smaller doc id. This is the same tie-break `into_sorted_results` uses, so the outcome
    /// does not depend on the order hits (or merged collectors) arrive in. A NaN score never
    /// displaces a retained hit once the collector is full.
    pub fn collect(&mut self, doc_id: DocId, segment_id: SegmentId, score: f32) {
        let candidate = MinScoreDoc(ScoreDoc {
            doc_id,
            segment_id,
            score,
        });
        if self.heap.len() < self.limit {
            self.heap.push(candidate);
            return;
        }
        // Full heap; `peek_mut` is `None` only when `limit == 0`.
        if let Some(mut worst) = self.heap.peek_mut() {
            let (new, old) = (&candidate.0, &worst.0);
            let outranks = new.score > old.score
                || (new.score == old.score && new.doc_id < old.doc_id);
            if outranks {
                // Assigning through `PeekMut` re-sifts the heap when the guard drops.
                *worst = candidate;
            }
        }
    }

    /// Score of the current worst retained hit once the collector is full, `0.0` before that.
    pub fn min_score(&self) -> f32 {
        if self.heap.len() < self.limit {
            0.0
        } else {
            self.heap.peek().map(|m| m.0.score).unwrap_or(0.0)
        }
    }

    /// Folds every hit retained by `other` into this collector via `collect`.
    pub fn merge(&mut self, other: TopDocsCollector) {
        for item in other.heap {
            self.collect(item.0.doc_id, item.0.segment_id, item.0.score);
        }
    }

    /// Consumes the collector and returns its hits ordered by score descending, then doc id
    /// ascending.
    pub(crate) fn into_sorted_results(self) -> Vec<ScoreDoc> {
        let mut results: Vec<ScoreDoc> = self.heap.into_iter().map(|m| m.0).collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results
    }
}
use crate::search::{SortField, SortValue, compare_sort_values_cascade};
/// One hit ranked by explicit sort values rather than by relevance score.
#[derive(Clone, Debug)]
pub(crate) struct FieldDoc {
    /// Document id, used as the final tie-break by `TopFieldCollector`.
    pub(crate) doc_id: DocId,
    /// Segment the document belongs to.
    pub(crate) segment_id: SegmentId,
    /// Relevance score; carried along but not used for ordering by `TopFieldCollector`.
    pub(crate) score: f32,
    /// Sort keys, compared field by field with `compare_sort_values_cascade`.
    pub(crate) sort_values: Vec<SortValue>,
}
struct MinFieldDoc {
doc: FieldDoc,
sort_fields: *const Vec<SortField>,
}
unsafe impl Send for MinFieldDoc {}
impl PartialEq for MinFieldDoc {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl Eq for MinFieldDoc {}
impl PartialOrd for MinFieldDoc {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MinFieldDoc {
fn cmp(&self, other: &Self) -> Ordering {
let sort_fields = unsafe { &*self.sort_fields };
compare_sort_values_cascade(&self.doc.sort_values, &other.doc.sort_values, sort_fields)
.then_with(|| self.doc.doc_id.cmp(&other.doc.doc_id))
}
}
pub struct TopFieldCollector {
heap: BinaryHeap<MinFieldDoc>,
limit: usize,
sort_fields: Vec<SortField>,
search_after: Option<Vec<SortValue>>,
}
impl TopFieldCollector {
pub fn new(limit: usize, sort_fields: Vec<SortField>) -> Self {
Self {
heap: BinaryHeap::with_capacity(limit + 1),
limit,
sort_fields,
search_after: None,
}
}
pub fn set_search_after(&mut self, cursor: Vec<SortValue>) {
self.search_after = Some(cursor);
}
#[inline]
pub fn is_competitive_primary(&self, primary: &SortValue) -> bool {
if self.heap.len() < self.limit {
return true;
}
if let Some(top) = self.heap.peek() {
if let Some(worst_primary) = top.doc.sort_values.first() {
let cmp = primary.compare(worst_primary, &self.sort_fields[0]);
return cmp == Ordering::Less;
}
}
false
}
#[inline]
pub fn is_competitive_keyword(&self, value: Option<&str>) -> bool {
if self.heap.len() < self.limit {
return true;
}
if let Some(top) = self.heap.peek() {
if let Some(worst) = top.doc.sort_values.first() {
let _sv = match value {
Some(_s) => SortValue::Str(String::new()), None => SortValue::Null,
};
let natural = match (value, worst) {
(None, SortValue::Null) => std::cmp::Ordering::Equal,
(None, _) => match self.sort_fields[0].missing {
crate::search::MissingValue::Last => match self.sort_fields[0].order {
crate::search::SortOrder::Asc => std::cmp::Ordering::Greater,
crate::search::SortOrder::Desc => std::cmp::Ordering::Less,
},
crate::search::MissingValue::First => match self.sort_fields[0].order {
crate::search::SortOrder::Asc => std::cmp::Ordering::Less,
crate::search::SortOrder::Desc => std::cmp::Ordering::Greater,
},
},
(Some(_), SortValue::Null) => match self.sort_fields[0].missing {
crate::search::MissingValue::Last => match self.sort_fields[0].order {
crate::search::SortOrder::Asc => std::cmp::Ordering::Less,
crate::search::SortOrder::Desc => std::cmp::Ordering::Greater,
},
crate::search::MissingValue::First => match self.sort_fields[0].order {
crate::search::SortOrder::Asc => std::cmp::Ordering::Greater,
crate::search::SortOrder::Desc => std::cmp::Ordering::Less,
},
},
(Some(s), SortValue::Str(w)) => s.cmp(w.as_str()),
_ => std::cmp::Ordering::Equal,
};
let cmp = match self.sort_fields[0].order {
crate::search::SortOrder::Asc => natural,
crate::search::SortOrder::Desc => natural.reverse(),
};
return cmp == Ordering::Less;
}
}
false
}
pub fn collect(
&mut self,
doc_id: DocId,
segment_id: SegmentId,
score: f32,
sort_values: Vec<SortValue>,
) {
if let Some(ref after) = self.search_after {
let cmp = compare_sort_values_cascade(&sort_values, after, &self.sort_fields);
if cmp != Ordering::Greater {
return;
}
}
let doc = FieldDoc {
doc_id,
segment_id,
score,
sort_values,
};
if self.heap.len() < self.limit {
self.heap.push(MinFieldDoc {
doc,
sort_fields: &self.sort_fields as *const Vec<SortField>,
});
return;
}
if let Some(top) = self.heap.peek() {
let cmp = compare_sort_values_cascade(
&doc.sort_values,
&top.doc.sort_values,
&self.sort_fields,
)
.then_with(|| doc.doc_id.cmp(&top.doc.doc_id));
if cmp != Ordering::Less {
return;
}
} else {
return;
}
let mut top = self.heap.peek_mut().unwrap();
*top = MinFieldDoc {
doc,
sort_fields: &self.sort_fields as *const Vec<SortField>,
};
}
pub fn merge(&mut self, other: TopFieldCollector) {
for item in other.heap {
self.collect(
item.doc.doc_id,
item.doc.segment_id,
item.doc.score,
item.doc.sort_values,
);
}
}
pub(crate) fn into_sorted_results(self) -> Vec<FieldDoc> {
let sort_fields = self.sort_fields;
let mut results: Vec<FieldDoc> = self.heap.into_iter().map(|m| m.doc).collect();
results.sort_by(|a, b| {
compare_sort_values_cascade(&a.sort_values, &b.sort_values, &sort_fields)
.then_with(|| a.doc_id.cmp(&b.doc_id))
});
results
}
}
use std::collections::HashMap;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum CollapseKey {
Null,
Keyword(String),
Numeric(i64),
}
impl CollapseKey {
pub fn to_json(&self) -> serde_json::Value {
match self {
CollapseKey::Null => serde_json::Value::Null,
CollapseKey::Keyword(s) => serde_json::json!(s),
CollapseKey::Numeric(n) => serde_json::json!(n),
}
}
}
/// Best hit found for one collapse group.
#[derive(Clone, Debug)]
pub(crate) struct CollapsedDoc {
    /// Document id of the group's representative hit.
    pub(crate) doc_id: DocId,
    /// Segment containing that hit.
    pub(crate) segment_id: SegmentId,
    /// Score of that hit; groups are ranked by it.
    pub(crate) score: f32,
    /// The group this hit represents.
    pub(crate) collapse_key: CollapseKey,
}
pub struct CollapsingCollector {
groups: HashMap<CollapseKey, CollapsedDoc>,
limit: usize,
}
impl CollapsingCollector {
pub fn new(limit: usize) -> Self {
Self {
groups: HashMap::new(),
limit,
}
}
pub(crate) fn collect(
&mut self,
doc_id: DocId,
segment_id: SegmentId,
score: f32,
key: CollapseKey,
) {
let entry = self.groups.entry(key.clone()).or_insert(CollapsedDoc {
doc_id,
segment_id,
score,
collapse_key: key,
});
if score > entry.score {
entry.doc_id = doc_id;
entry.segment_id = segment_id;
entry.score = score;
}
}
pub fn merge(&mut self, other: CollapsingCollector) {
for (key, doc) in other.groups {
self.collect(doc.doc_id, doc.segment_id, doc.score, key);
}
}
pub fn group_count(&self) -> usize {
self.groups.len()
}
pub(crate) fn into_sorted_results(self) -> Vec<CollapsedDoc> {
let mut results: Vec<CollapsedDoc> = self.groups.into_values().collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.doc_id.cmp(&b.doc_id))
});
results.truncate(self.limit);
results
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::{MissingValue, SortField, SortFieldType, SortOrder, SortValue};

    /// Prices offered to the field-sort tests, in collection order (doc ids 0..5).
    const CATALOG: [f64; 5] = [199.0, 79.0, 9.99, 59.0, 25.0];

    /// Doc ids of a score collector's results, best first.
    fn ranked_doc_ids(collector: TopDocsCollector) -> Vec<DocId> {
        collector
            .into_sorted_results()
            .into_iter()
            .map(|hit| hit.doc_id)
            .collect()
    }

    /// A single "price" sort key in the given direction, missing values last.
    fn price_sort(order: SortOrder) -> SortField {
        SortField {
            field: SortFieldType::Field("price".to_string()),
            order,
            missing: MissingValue::Last,
        }
    }

    /// Offers one hit per price, with doc ids assigned by position.
    fn offer_prices(collector: &mut TopFieldCollector, prices: &[f64]) {
        for (position, &price) in prices.iter().enumerate() {
            collector.collect(
                DocId::new(position as u32),
                SegmentId::new(0),
                1.0,
                vec![SortValue::F64(price)],
            );
        }
    }

    /// The first sort value of each hit, which must be an `F64`.
    fn prices(results: &[FieldDoc]) -> Vec<f64> {
        results
            .iter()
            .map(|hit| match hit.sort_values.first() {
                Some(SortValue::F64(price)) => *price,
                _ => panic!("expected F64 sort value"),
            })
            .collect()
    }

    #[test]
    fn collect_under_limit() {
        let mut collector = TopDocsCollector::new(10);
        for (id, score) in [(0, 1.0), (1, 2.0), (2, 0.5)] {
            collector.collect(DocId::new(id), SegmentId::new(1), score);
        }
        assert_eq!(
            ranked_doc_ids(collector),
            vec![DocId::new(1), DocId::new(0), DocId::new(2)]
        );
    }

    #[test]
    fn top_k_limiting() {
        let mut collector = TopDocsCollector::new(2);
        for (id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 0.5)] {
            collector.collect(DocId::new(id), SegmentId::new(1), score);
        }
        assert_eq!(ranked_doc_ids(collector), vec![DocId::new(1), DocId::new(2)]);
    }

    #[test]
    fn score_descending_order() {
        let mut collector = TopDocsCollector::new(5);
        for id in 0..5 {
            collector.collect(DocId::new(id), SegmentId::new(1), id as f32);
        }
        let results = collector.into_sorted_results();
        assert!(results
            .windows(2)
            .all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn empty_collector() {
        assert!(TopDocsCollector::new(10).into_sorted_results().is_empty());
    }

    #[test]
    fn multi_segment() {
        let mut collector = TopDocsCollector::new(3);
        for (doc, segment, score) in [(0, 1, 1.0), (0, 2, 2.0), (1, 1, 3.0), (1, 2, 0.5)] {
            collector.collect(DocId::new(doc), SegmentId::new(segment), score);
        }
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        let best = &results[0];
        assert_eq!(best.score, 3.0);
        assert_eq!(best.segment_id, SegmentId::new(1));
    }

    #[test]
    fn top_field_collector_asc_retains_smallest() {
        let mut collector = TopFieldCollector::new(2, vec![price_sort(SortOrder::Asc)]);
        offer_prices(&mut collector, &CATALOG);
        assert_eq!(prices(&collector.into_sorted_results()), vec![9.99, 25.0]);
    }

    #[test]
    fn top_field_collector_desc_retains_largest() {
        let mut collector = TopFieldCollector::new(2, vec![price_sort(SortOrder::Desc)]);
        offer_prices(&mut collector, &CATALOG);
        assert_eq!(prices(&collector.into_sorted_results()), vec![199.0, 79.0]);
    }

    #[test]
    fn top_field_collector_tiebreak_prefers_smaller_doc_id() {
        let mut collector = TopFieldCollector::new(1, vec![price_sort(SortOrder::Asc)]);
        for id in [7, 3] {
            collector.collect(
                DocId::new(id),
                SegmentId::new(0),
                1.0,
                vec![SortValue::F64(42.0)],
            );
        }
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, DocId::new(3));
    }
}