use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::segment::SegmentReader;
use crate::structures::TERMINATED;
use crate::{DocId, Result, Score};
use super::Query;
/// Address of a document: a segment identifier paired with a document id.
///
/// `segment_id` is kept as text; `DocAddress::new` fills it with the
/// zero-padded, 32-digit lowercase hexadecimal form of a `u128`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
/// Segment identifier in textual (hexadecimal) form.
pub segment_id: String,
/// Document id that accompanies the segment identifier.
pub doc_id: DocId,
}
impl DocAddress {
    /// Builds an address from a numeric segment id and a document id.
    ///
    /// The segment id is rendered as exactly 32 zero-padded lowercase hex digits.
    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
        let segment_id = format!("{segment_id:032x}");
        Self { segment_id, doc_id }
    }

    /// Parses the stored segment id back into a `u128`.
    ///
    /// Returns `None` when the text is not a valid base-16 `u128`.
    pub fn segment_id_u128(&self) -> Option<u128> {
        let parsed = u128::from_str_radix(self.segment_id.as_str(), 16);
        parsed.ok()
    }
}
/// A single scored match collected from one segment.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
/// Id of the matching document.
pub doc_id: DocId,
/// Score assigned to the document.
pub score: Score,
/// Matched positions as `(field_id, positions)` pairs.
///
/// `extract_ordinals` reads the bits above bit 20 of each position as an ordinal.
/// The field is skipped during serialization when empty and defaults to empty
/// when missing on deserialization.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub positions: Vec<(u32, Vec<u32>)>,
}
/// Ordinals that matched within one field, as produced by
/// `SearchResult::extract_ordinals`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
/// Id of the field the ordinals belong to.
pub field_id: u32,
/// Matched ordinals; `extract_ordinals` emits them distinct and in ascending order.
pub ordinals: Vec<u32>,
}
impl SearchResult {
    /// Returns, for every `(field_id, positions)` entry, the distinct ordinals that matched.
    ///
    /// The ordinal of an encoded position is the value of its bits above bit 20
    /// (`position >> 20`). One [`MatchedField`] is produced per entry of
    /// `self.positions`, in the same order, with its ordinals de-duplicated and
    /// sorted ascending. An entry with no positions yields an empty ordinal list.
    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
        // Amount an encoded position is shifted right to obtain its ordinal.
        const ORDINAL_SHIFT: u32 = 20;

        self.positions
            .iter()
            .map(|(field_id, encoded)| {
                // Sorting first makes equal ordinals adjacent, so `dedup` removes
                // every duplicate without needing an intermediate hash set.
                let mut ordinals: Vec<u32> =
                    encoded.iter().map(|&pos| pos >> ORDINAL_SHIFT).collect();
                ordinals.sort_unstable();
                ordinals.dedup();
                MatchedField {
                    field_id: *field_id,
                    ordinals,
                }
            })
            .collect()
    }
}
/// A search hit identified by a full [`DocAddress`] rather than a bare document id.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
/// Segment and document id of the hit.
pub address: DocAddress,
/// Score of the hit.
pub score: Score,
/// Per-field matched ordinals; skipped during serialization when empty and
/// defaulted to empty when missing on deserialization.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub matched_fields: Vec<MatchedField>,
}
/// Serializable response carrying a list of hits and a hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
/// Hits included in this response.
pub hits: Vec<SearchHit>,
/// Hit count reported with the response.
// NOTE(review): presumably the total number of matches, which may exceed `hits.len()` — confirm with the producer.
pub total_hits: u32,
}
/// Equality is derived from [`Ord::cmp`], so `a == b` holds exactly when
/// `a.cmp(&b)` is `Ordering::Equal` — the consistency `Ord` and `Eq` require.
///
/// Two results are equal when both their score (compared with `total_cmp`) and
/// their `doc_id` match; `positions` do not take part in the comparison.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    /// Always `Some`: delegates to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders results by score in reverse, then by ascending `doc_id`.
///
/// A higher score compares as `Less`, so in a max-heap such as `BinaryHeap`
/// the top element is the lowest-scoring result, and among equal scores the
/// one with the largest `doc_id`.
impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order over floats (NaN included), which
        // `Ord` requires; mapping incomparable scores to `Equal` would make NaN
        // "equal" to every score and break transitivity.
        other
            .score
            .total_cmp(&self.score)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}
/// Sink for matching documents produced while scanning a segment.
pub trait Collector {
/// Receives one matching document with its score and matched positions.
///
/// `positions` holds `(field_id, positions)` pairs. `collect_segment` passes an
/// empty slice when `needs_positions` returned `false`.
fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]);
/// Reports whether this collector wants matched positions; defaults to `false`.
fn needs_positions(&self) -> bool {
false
}
}
/// Collector that retains at most `k` of the highest-scoring documents.
pub struct TopKCollector {
/// Retained results. `SearchResult`'s `Ord` is reversed on score, so the top of
/// this max-heap is the lowest-scoring retained result.
heap: BinaryHeap<SearchResult>,
/// Maximum number of results kept.
k: usize,
/// Whether matched positions are copied into the stored results.
collect_positions: bool,
}
impl TopKCollector {
    /// Upper bound on the number of heap slots reserved up front.
    ///
    /// `k` comes straight from callers; reserving `k` slots eagerly would
    /// overflow or exhaust memory for very large limits, so the heap starts with
    /// at most this many slots and grows on demand.
    const MAX_INITIAL_CAPACITY: usize = 1024;

    /// Creates a collector keeping the `k` best results, without positions.
    pub fn new(k: usize) -> Self {
        Self::with_mode(k, false)
    }

    /// Creates a collector keeping the `k` best results together with their
    /// matched positions.
    pub fn with_positions(k: usize) -> Self {
        Self::with_mode(k, true)
    }

    /// Shared constructor behind [`TopKCollector::new`] and
    /// [`TopKCollector::with_positions`].
    fn with_mode(k: usize, collect_positions: bool) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(k.min(Self::MAX_INITIAL_CAPACITY)),
            k,
            collect_positions,
        }
    }

    /// Consumes the collector and returns its results, best first.
    ///
    /// Results are ordered by descending score; equal scores are ordered by
    /// ascending `doc_id`. Scores that cannot be compared are treated as equal.
    pub fn into_sorted_results(self) -> Vec<SearchResult> {
        let mut results: Vec<_> = self.heap.into_vec();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results
    }
}
impl Collector for TopKCollector {
    /// Offers one document to the collector.
    ///
    /// While fewer than `k` results are held the document is always kept.
    /// Afterwards it replaces the current lowest-scoring result only when its
    /// score is strictly greater, so a tie with the lowest score never evicts.
    /// `positions` are copied only for documents that are actually kept, and
    /// only when the collector was built to collect positions.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
        let keep_positions = self.collect_positions;
        // Built lazily so that rejected documents never pay for cloning positions.
        let make_result = || SearchResult {
            doc_id,
            score,
            positions: if keep_positions {
                positions.to_vec()
            } else {
                Vec::new()
            },
        };

        if self.heap.len() < self.k {
            self.heap.push(make_result());
            return;
        }

        // The heap top is the weakest retained result. Overwriting it through
        // `peek_mut` restores the heap order once the guard is dropped, which
        // avoids a separate pop followed by a push.
        match self.heap.peek_mut() {
            Some(mut weakest) if score > weakest.score => *weakest = make_result(),
            _ => {}
        }
    }

    /// Returns whether this collector was built to store matched positions.
    fn needs_positions(&self) -> bool {
        self.collect_positions
    }
}
/// Collector that only counts how many documents it was given.
#[derive(Default)]
pub struct CountCollector {
/// Number of `collect` calls received so far.
count: u64,
}
impl CountCollector {
    /// Creates a collector whose counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns how many documents have been collected so far.
    pub fn count(&self) -> u64 {
        let Self { count: seen } = self;
        *seen
    }
}
impl Collector for CountCollector {
    /// Counts the document; its id, score and positions are all ignored.
    #[inline]
    fn collect(&mut self, _: DocId, _: Score, _: &[(u32, Vec<u32>)]) {
        self.count += 1;
    }
}
/// Runs `query` against one segment and returns up to `limit` best results.
///
/// Documents are gathered with a [`TopKCollector`] that does not store
/// positions; the returned list is the collector's sorted output.
///
/// # Errors
/// Returns whatever error [`collect_segment`] reports.
pub async fn search_segment(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<Vec<SearchResult>> {
    let mut top_k = TopKCollector::new(limit);
    let outcome = collect_segment(reader, query, &mut top_k).await;
    outcome.map(|()| top_k.into_sorted_results())
}
/// Like [`search_segment`], but each returned result also carries the matched
/// positions reported for its document.
///
/// Documents are gathered with a [`TopKCollector`] built by
/// [`TopKCollector::with_positions`].
///
/// # Errors
/// Returns whatever error [`collect_segment`] reports.
pub async fn search_segment_with_positions(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<Vec<SearchResult>> {
    let mut top_k = TopKCollector::with_positions(limit);
    let outcome = collect_segment(reader, query, &mut top_k).await;
    outcome.map(|()| top_k.into_sorted_results())
}
/// Counts the documents in one segment that match `query`.
///
/// Every document handed to the collector by [`collect_segment`] adds one to
/// the total.
///
/// # Errors
/// Returns whatever error [`collect_segment`] reports.
pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
    let mut counter = CountCollector::new();
    let outcome = collect_segment(reader, query, &mut counter).await;
    outcome.map(|()| counter.count())
}
/// Fans every document out to two collectors, first to last.
impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
    /// Forwards the document to the first collector, then to the second.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
        let (first, second) = self;
        first.collect(doc_id, score, positions);
        second.collect(doc_id, score, positions);
    }

    /// `true` as soon as one collector wants positions; later collectors are
    /// not asked once an earlier one has answered `true`.
    fn needs_positions(&self) -> bool {
        let (first, second) = self;
        first.needs_positions() || second.needs_positions()
    }
}
/// Fans every document out to three collectors, first to last.
impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
    /// Forwards the document to each collector in tuple order.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
        let (first, second, third) = self;
        first.collect(doc_id, score, positions);
        second.collect(doc_id, score, positions);
        third.collect(doc_id, score, positions);
    }

    /// `true` as soon as one collector wants positions; later collectors are
    /// not asked once an earlier one has answered `true`.
    fn needs_positions(&self) -> bool {
        let (first, second, third) = self;
        first.needs_positions() || second.needs_positions() || third.needs_positions()
    }
}
/// Feeds every document matched by `query` in this segment to `collector`.
///
/// The collector is asked once, up front, whether it wants positions. For each
/// document the scorer yields before reaching `TERMINATED`, the collector
/// receives the document id, the scorer's score and either the scorer's
/// matched positions (empty when the scorer reports none) or an empty slice
/// when positions were not requested.
///
/// # Errors
/// Returns the error from `query.scorer` if the scorer cannot be created.
pub async fn collect_segment<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
) -> Result<()> {
    let want_positions = collector.needs_positions();
    // NOTE(review): `usize::MAX / 2` is presumably a "no practical limit" hint
    // for the scorer — confirm against `Query::scorer`.
    let mut scorer = query.scorer(reader, usize::MAX / 2).await?;
    let mut current = scorer.doc();
    loop {
        if current == TERMINATED {
            return Ok(());
        }
        let matched = match want_positions {
            true => scorer.matched_positions().unwrap_or_default(),
            false => Vec::new(),
        };
        collector.collect(current, scorer.score(), &matched);
        current = scorer.advance();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the `k` best scores survive, returned best first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    /// Every collected document adds one to the count.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();
        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);
        assert_eq!(collector.count(), 3);
    }

    /// Drives both collectors through the `(&mut A, &mut B)` tuple impl, so the
    /// fan-out itself is what is being tested.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();
        {
            let mut multi = (&mut top_k, &mut count);
            // Neither collector asks for positions, so the pair must not either.
            assert!(!multi.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                multi.collect(doc_id, score, &[]);
            }
        }
        assert_eq!(count.count(), 5);
        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }
}