use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::segment::SegmentReader;
use crate::structures::TERMINATED;
use crate::{DocId, Result, Score};
use super::Query;
/// Address of a single document: the segment it lives in plus its id inside
/// that segment.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DocAddress {
    /// Segment identifier, stored as a raw 128-bit integer.
    segment_id_raw: u128,
    /// Document id within the segment.
    pub doc_id: DocId,
}

impl DocAddress {
    /// Builds an address from a raw 128-bit segment id and a document id.
    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
        DocAddress {
            doc_id,
            segment_id_raw: segment_id,
        }
    }

    /// Returns the segment id rendered as 32 zero-padded lowercase hex digits.
    pub fn segment_id(&self) -> String {
        let raw = self.segment_id_raw;
        format!("{:032x}", raw)
    }

    /// Returns the raw 128-bit segment id.
    ///
    /// The current implementation always yields `Some`.
    pub fn segment_id_u128(&self) -> Option<u128> {
        let raw = self.segment_id_raw;
        Some(raw)
    }
}
impl serde::Serialize for DocAddress {
    /// Writes the address as a two-field struct named `DocAddress`:
    /// `segment_id` as a 32-digit zero-padded lowercase hex string, followed
    /// by `doc_id`.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        // The 128-bit id is emitted as text rather than as a number.
        let segment_hex = format!("{:032x}", self.segment_id_raw);

        let mut state = serializer.serialize_struct("DocAddress", 2)?;
        state.serialize_field("segment_id", &segment_hex)?;
        state.serialize_field("doc_id", &self.doc_id)?;
        state.end()
    }
}
impl<'de> serde::Deserialize<'de> for DocAddress {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<Self, D::Error> {
#[derive(serde::Deserialize)]
struct Helper {
segment_id: String,
doc_id: DocId,
}
let h = Helper::deserialize(deserializer)?;
let raw = u128::from_str_radix(&h.segment_id, 16).map_err(serde::de::Error::custom)?;
Ok(DocAddress {
segment_id_raw: raw,
doc_id: h.doc_id,
})
}
}
/// A position value paired with the score attributed to it.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Raw position value.
    pub position: u32,
    /// Score attached to this position.
    pub score: f32,
}

impl ScoredPosition {
    /// Pairs `position` with `score`.
    pub fn new(position: u32, score: f32) -> Self {
        ScoredPosition { position, score }
    }
}
/// A single scored document produced by a search.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
/// Id of the matching document.
pub doc_id: DocId,
/// Score assigned to the document.
pub score: Score,
/// Raw id of the segment the document belongs to. Left out of the
/// serialized form when it is zero, and defaulted when missing on input.
#[serde(default, skip_serializing_if = "is_zero_u128")]
pub segment_id: u128,
/// Matched positions grouped per field id. Left out of the serialized form
/// when empty, and defaulted to an empty list when missing on input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
/// Serde `skip_serializing_if` predicate: true when the value is zero.
fn is_zero_u128(v: &u128) -> bool {
    matches!(*v, 0)
}
/// The ordinals that matched within one field.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
/// Id of the field the ordinals belong to.
pub field_id: u32,
/// Ordinals that matched in this field.
pub ordinals: Vec<u32>,
}
impl SearchResult {
    /// Summarises the matched positions as one [`MatchedField`] per entry of
    /// `positions`, each holding the distinct ordinals in ascending order.
    ///
    /// A position above `0xFFFFF` contributes its bits above the low 20
    /// (`position >> 20`); any smaller position is used as the ordinal
    /// unchanged.
    // NOTE(review): this presumably decodes an `(ordinal << 20) | offset`
    // packing, in which case ordinal 0 cannot be told apart from a plain
    // ordinal — confirm against the code that produces positions.
    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
        self.positions
            .iter()
            .map(|(field_id, scored_positions)| {
                let mut ordinals: Vec<u32> = scored_positions
                    .iter()
                    .map(|sp| {
                        if sp.position > 0xFFFFF {
                            sp.position >> 20
                        } else {
                            sp.position
                        }
                    })
                    .collect();
                // Sorting then removing adjacent duplicates yields the sorted
                // distinct ordinals without building a hash set per field.
                ordinals.sort_unstable();
                ordinals.dedup();
                MatchedField {
                    field_id: *field_id,
                    ordinals,
                }
            })
            .collect()
    }

    /// Returns the scored positions recorded for `field_id`, or `None` when
    /// that field has no entry. If the field appears more than once, the
    /// first entry wins.
    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
        self.positions
            .iter()
            .find_map(|(fid, positions)| (*fid == field_id).then(|| positions.as_slice()))
    }
}
/// A scored hit identified by a full document address.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
/// Segment and document id of the hit.
pub address: DocAddress,
/// Score of the hit.
pub score: Score,
/// Per-field matched ordinals. Left out of the serialized form when empty,
/// and defaulted to an empty list when missing on input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub matched_fields: Vec<MatchedField>,
}
/// A list of hits together with a total hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
/// The hits carried by this response.
pub hits: Vec<SearchHit>,
/// Total number of hits reported alongside `hits`.
pub total_hits: u32,
}
/// Equality is identity-based: two results are equal when they name the same
/// document in the same segment, whatever their scores.
///
/// Note that this is deliberately looser than [`Ord`] below, which looks at
/// the score first: two results that are `==` may still compare as unequal
/// when their scores differ.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.segment_id == other.segment_id && self.doc_id == other.doc_id
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders results by descending score: a higher score compares as `Less`.
/// Ties are broken by ascending segment id, then ascending document id.
///
/// Because the score order is reversed, the greatest element of a max-heap
/// such as `BinaryHeap` is the lowest-scoring result.
impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a total order even when a score is NaN. Mapping an
        // incomparable pair to `Equal` would make the ordering intransitive
        // (1.0 "==" NaN "==" 2.0 while 1.0 != 2.0), which `Ord` forbids.
        other
            .score
            .total_cmp(&self.score)
            .then_with(|| self.segment_id.cmp(&other.segment_id))
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}
/// Receives matching documents one at a time.
pub trait Collector {
/// Handles one matching document: its id, its score, and its matched
/// positions grouped per field id (the slice may be empty).
fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
/// Reports whether this collector wants matched positions passed to
/// `collect`. Defaults to `false`.
fn needs_positions(&self) -> bool {
false
}
}
pub struct TopKCollector {
heap: BinaryHeap<SearchResult>,
k: usize,
collect_positions: bool,
total_seen: u32,
}
impl TopKCollector {
pub fn new(k: usize) -> Self {
Self {
heap: BinaryHeap::with_capacity(k + 1),
k,
collect_positions: false,
total_seen: 0,
}
}
pub fn with_positions(k: usize) -> Self {
Self {
heap: BinaryHeap::with_capacity(k + 1),
k,
collect_positions: true,
total_seen: 0,
}
}
pub fn total_seen(&self) -> u32 {
self.total_seen
}
pub fn into_sorted_results(self) -> Vec<SearchResult> {
let mut results: Vec<_> = self.heap.into_vec();
results.sort_unstable_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(Ordering::Equal)
.then_with(|| a.doc_id.cmp(&b.doc_id))
});
results
}
pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
let total = self.total_seen;
(self.into_sorted_results(), total)
}
}
impl Collector for TopKCollector {
fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
self.total_seen = self.total_seen.saturating_add(1);
let dominated =
self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
if dominated {
return;
}
let positions = if self.collect_positions {
positions.to_vec()
} else {
Vec::new()
};
if self.heap.len() >= self.k {
self.heap.pop();
}
self.heap.push(SearchResult {
doc_id,
score,
segment_id: 0,
positions,
});
}
fn needs_positions(&self) -> bool {
self.collect_positions
}
}
/// Collector that only counts the documents it is shown.
#[derive(Default)]
pub struct CountCollector {
    /// Number of documents collected so far.
    count: u64,
}

impl CountCollector {
    /// Creates a collector whose count starts at zero.
    pub fn new() -> Self {
        CountCollector { count: 0 }
    }

    /// Number of documents collected so far.
    pub fn count(&self) -> u64 {
        self.count
    }
}

impl Collector for CountCollector {
    /// Bumps the count; the id, score and positions are ignored.
    #[inline]
    fn collect(&mut self, _doc_id: DocId, _score: Score, _positions: &[(u32, Vec<ScoredPosition>)]) {
        self.count += 1;
    }
}
/// Runs `query` against one segment and returns its best `limit` results,
/// sorted, together with the number of documents the collector saw.
///
/// Matched positions are not collected. `limit` is handed both to the top-k
/// collector and to the query's scorer.
pub async fn search_segment_with_count(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::new(limit);
    collect_segment_with_limit(reader, query, &mut top_k, limit).await?;
    let results_and_total = top_k.into_results_with_count();
    Ok(results_and_total)
}
/// Runs `query` against one segment and returns its best `limit` results,
/// sorted and carrying their matched positions, together with the number of
/// documents the collector saw.
///
/// `limit` is handed both to the top-k collector and to the query's scorer.
pub async fn search_segment_with_positions_and_count(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::with_positions(limit);
    collect_segment_with_limit(reader, query, &mut top_k, limit).await?;
    let results_and_total = top_k.into_results_with_count();
    Ok(results_and_total)
}
/// A pair of collectors acts as one collector that forwards every document
/// to both members, first to `.0` and then to `.1`.
impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
        let (first, second) = self;
        first.collect(doc_id, score, positions);
        second.collect(doc_id, score, positions);
    }

    /// Positions are needed as soon as either member asks for them.
    fn needs_positions(&self) -> bool {
        let (first, second) = self;
        first.needs_positions() || second.needs_positions()
    }
}
/// A triple of collectors acts as one collector that forwards every document
/// to all three members in order: `.0`, `.1`, then `.2`.
impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
        let (first, second, third) = self;
        first.collect(doc_id, score, positions);
        second.collect(doc_id, score, positions);
        third.collect(doc_id, score, positions);
    }

    /// Positions are needed as soon as any member asks for them.
    fn needs_positions(&self) -> bool {
        let (first, second, third) = self;
        first.needs_positions() || second.needs_positions() || third.needs_positions()
    }
}
/// Feeds the documents of one segment that match `query` into `collector`,
/// without a caller-supplied limit.
///
/// Delegates to [`collect_segment_with_limit`] with `usize::MAX / 2` as the
/// limit.
pub async fn collect_segment<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
) -> Result<()> {
    // Stand-in for "no limit".
    // NOTE(review): presumably halved to leave headroom for arithmetic on the limit — confirm.
    let unbounded = usize::MAX / 2;
    collect_segment_with_limit(reader, query, collector, unbounded).await
}
/// Builds the scorer for `query` over `reader` and feeds every document it
/// yields into `collector`.
///
/// `limit` is passed through to `Query::scorer`; this function does not
/// itself cap how many documents reach the collector.
///
/// # Errors
///
/// Propagates any error from creating the scorer.
pub async fn collect_segment_with_limit<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
    limit: usize,
) -> Result<()> {
    let mut segment_scorer = query.scorer(reader, limit).await?;
    drive_scorer(segment_scorer.as_mut(), collector);
    Ok(())
}
/// Walks `scorer` from its current document until it reports `TERMINATED`,
/// handing each document and its score to `collector`.
///
/// Matched positions are fetched only when the collector asks for them (a
/// scorer that reports none yields an empty list); otherwise an empty slice
/// is passed.
fn drive_scorer<C: Collector>(scorer: &mut dyn super::Scorer, collector: &mut C) {
    // Asked once up front; the answer is reused for every document.
    let wants_positions = collector.needs_positions();

    let mut current = scorer.doc();
    loop {
        if current == TERMINATED {
            break;
        }
        if wants_positions {
            let matched = scorer.matched_positions().unwrap_or_default();
            collector.collect(current, scorer.score(), &matched);
        } else {
            collector.collect(current, scorer.score(), &[]);
        }
        current = scorer.advance();
    }
}
/// Blocking counterpart of `search_segment_with_count`: returns the best
/// `limit` results for `query` in one segment, sorted, together with the
/// number of documents the collector saw. Positions are not collected.
#[cfg(feature = "sync")]
pub fn search_segment_with_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::new(limit);
    collect_segment_with_limit_sync(reader, query, &mut top_k, limit)?;
    let results_and_total = top_k.into_results_with_count();
    Ok(results_and_total)
}
/// Blocking counterpart of `search_segment_with_positions_and_count`: returns
/// the best `limit` results for `query` in one segment, sorted and carrying
/// their matched positions, together with the number of documents the
/// collector saw.
#[cfg(feature = "sync")]
pub fn search_segment_with_positions_and_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::with_positions(limit);
    collect_segment_with_limit_sync(reader, query, &mut top_k, limit)?;
    let results_and_total = top_k.into_results_with_count();
    Ok(results_and_total)
}
/// Blocking counterpart of `collect_segment_with_limit`: builds the scorer
/// through `Query::scorer_sync` and feeds every document it yields into
/// `collector`.
///
/// `limit` is passed through to the scorer; this function does not itself
/// cap how many documents reach the collector.
///
/// # Errors
///
/// Propagates any error from creating the scorer.
#[cfg(feature = "sync")]
pub fn collect_segment_with_limit_sync<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
    limit: usize,
) -> Result<()> {
    let mut segment_scorer = query.scorer_sync(reader, limit)?;
    drive_scorer(segment_scorer.as_mut(), collector);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The three best of five documents come back ordered by descending score.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);
        assert_eq!(collector.total_seen(), 5);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    /// Every collected document is counted.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();
        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);
        assert_eq!(collector.count(), 3);
    }

    /// A tuple of collectors forwards each document to every member.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();
        {
            // Drive both collectors through the tuple `Collector` impl rather
            // than calling each one by hand.
            let mut multi = (&mut top_k, &mut count);
            assert!(!multi.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                multi.collect(doc_id, score, &[]);
            }
        }
        assert_eq!(count.count(), 5);
        assert_eq!(top_k.total_seen(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }
}