hermes_core/query/
collector.rs1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15 pub segment_id: String,
17 pub doc_id: DocId,
19}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
35#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
37pub struct SearchResult {
38 pub doc_id: DocId,
39 pub score: Score,
40}
41
42#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct SearchHit {
45 pub address: DocAddress,
47 pub score: Score,
48}
49
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct SearchResponse {
53 pub hits: Vec<SearchHit>,
54 pub total_hits: u32,
55}
56
57impl PartialEq for SearchResult {
58 fn eq(&self, other: &Self) -> bool {
59 self.doc_id == other.doc_id
60 }
61}
62
63impl Eq for SearchResult {}
64
65impl PartialOrd for SearchResult {
66 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
67 Some(self.cmp(other))
68 }
69}
70
71impl Ord for SearchResult {
72 fn cmp(&self, other: &Self) -> Ordering {
73 other
74 .score
75 .partial_cmp(&self.score)
76 .unwrap_or(Ordering::Equal)
77 .then_with(|| self.doc_id.cmp(&other.doc_id))
78 }
79}
80
81pub struct TopKCollector {
83 heap: BinaryHeap<SearchResult>,
84 k: usize,
85}
86
87impl TopKCollector {
88 pub fn new(k: usize) -> Self {
89 Self {
90 heap: BinaryHeap::with_capacity(k + 1),
91 k,
92 }
93 }
94
95 pub fn collect(&mut self, doc_id: DocId, score: Score) {
96 if self.heap.len() < self.k {
97 self.heap.push(SearchResult { doc_id, score });
98 } else if let Some(min) = self.heap.peek()
99 && score > min.score
100 {
101 self.heap.pop();
102 self.heap.push(SearchResult { doc_id, score });
103 }
104 }
105
106 pub fn into_sorted_results(self) -> Vec<SearchResult> {
107 let mut results: Vec<_> = self.heap.into_vec();
108 results.sort_by(|a, b| {
109 b.score
110 .partial_cmp(&a.score)
111 .unwrap_or(Ordering::Equal)
112 .then_with(|| a.doc_id.cmp(&b.doc_id))
113 });
114 results
115 }
116}
117
118pub async fn search_segment(
120 reader: &SegmentReader,
121 query: &dyn Query,
122 limit: usize,
123) -> Result<Vec<SearchResult>> {
124 let mut scorer = query.scorer(reader).await?;
125 let mut collector = TopKCollector::new(limit);
126
127 let mut doc = scorer.doc();
128
129 while doc != TERMINATED {
130 collector.collect(doc, scorer.score());
131 doc = scorer.advance();
132 }
133
134 Ok(collector.into_sorted_results())
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Five scored documents offered to a 3-slot collector: only the three
    /// highest scores survive, returned best first.
    #[test]
    fn test_top_k_collector() {
        let mut top3 = TopKCollector::new(3);

        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            top3.collect(doc_id, score);
        }

        let ranked = top3.into_sorted_results();
        assert_eq!(ranked.len(), 3);

        // Scores 4.0, 3.0, 2.0 belong to docs 3, 1, 2 respectively.
        let ranked_ids: Vec<_> = ranked.iter().map(|hit| hit.doc_id).collect();
        assert_eq!(ranked_ids, vec![3, 1, 2]);
    }
}