// hermes_core/query/collector.rs

//! Search result collection and response types.
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12/// Unique document address: segment_id (hex) + local doc_id within segment
13#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15    /// Segment ID as hex string (32 chars)
16    pub segment_id: String,
17    /// Document ID within the segment
18    pub doc_id: DocId,
19}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
35/// Search result with doc_id and score (internal use)
36#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
37pub struct SearchResult {
38    pub doc_id: DocId,
39    pub score: Score,
40}
41
42/// Search hit with unique document address and score
43#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct SearchHit {
45    /// Unique document address (segment_id + local doc_id)
46    pub address: DocAddress,
47    pub score: Score,
48}
49
50/// Search response with hits (IDs only, no documents)
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct SearchResponse {
53    pub hits: Vec<SearchHit>,
54    pub total_hits: u32,
55}
56
57impl PartialEq for SearchResult {
58    fn eq(&self, other: &Self) -> bool {
59        self.doc_id == other.doc_id
60    }
61}
62
63impl Eq for SearchResult {}
64
65impl PartialOrd for SearchResult {
66    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
67        Some(self.cmp(other))
68    }
69}
70
71impl Ord for SearchResult {
72    fn cmp(&self, other: &Self) -> Ordering {
73        other
74            .score
75            .partial_cmp(&self.score)
76            .unwrap_or(Ordering::Equal)
77            .then_with(|| self.doc_id.cmp(&other.doc_id))
78    }
79}
80
81/// Collector for top-k results
82pub struct TopKCollector {
83    heap: BinaryHeap<SearchResult>,
84    k: usize,
85}
86
87impl TopKCollector {
88    pub fn new(k: usize) -> Self {
89        Self {
90            heap: BinaryHeap::with_capacity(k + 1),
91            k,
92        }
93    }
94
95    pub fn collect(&mut self, doc_id: DocId, score: Score) {
96        if self.heap.len() < self.k {
97            self.heap.push(SearchResult { doc_id, score });
98        } else if let Some(min) = self.heap.peek()
99            && score > min.score
100        {
101            self.heap.pop();
102            self.heap.push(SearchResult { doc_id, score });
103        }
104    }
105
106    pub fn into_sorted_results(self) -> Vec<SearchResult> {
107        let mut results: Vec<_> = self.heap.into_vec();
108        results.sort_by(|a, b| {
109            b.score
110                .partial_cmp(&a.score)
111                .unwrap_or(Ordering::Equal)
112                .then_with(|| a.doc_id.cmp(&b.doc_id))
113        });
114        results
115    }
116}
117
118/// Execute a search query on a single segment (async)
119pub async fn search_segment(
120    reader: &SegmentReader,
121    query: &dyn Query,
122    limit: usize,
123) -> Result<Vec<SearchResult>> {
124    let mut scorer = query.scorer(reader).await?;
125    let mut collector = TopKCollector::new(limit);
126
127    let mut doc = scorer.doc();
128
129    while doc != TERMINATED {
130        collector.collect(doc, scorer.score());
131        doc = scorer.advance();
132    }
133
134    Ok(collector.into_sorted_results())
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts doc ids in result order, for compact assertions.
    fn doc_ids(results: &[SearchResult]) -> Vec<DocId> {
        results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0);
        collector.collect(1, 3.0);
        collector.collect(2, 2.0);
        collector.collect(3, 4.0);
        collector.collect(4, 0.5);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
        assert_eq!(results[2].doc_id, 2); // score 2.0
    }

    #[test]
    fn test_top_k_collector_zero_k() {
        let mut collector = TopKCollector::new(0);
        collector.collect(0, 1.0);
        collector.collect(1, 2.0);
        assert!(collector.into_sorted_results().is_empty());
    }

    #[test]
    fn test_top_k_collector_fewer_docs_than_k() {
        let mut collector = TopKCollector::new(5);
        collector.collect(7, 0.25);
        collector.collect(2, 0.75);
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![2, 7]);
    }

    #[test]
    fn test_top_k_collector_ties_keep_earlier_docs() {
        // On equal scores an already-kept document is never displaced.
        let mut collector = TopKCollector::new(2);
        collector.collect(0, 1.0);
        collector.collect(1, 1.0);
        collector.collect(2, 1.0);
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![0, 1]);
    }

    #[test]
    fn test_doc_address_round_trip() {
        let addr = DocAddress::new(0xabc, 9);
        assert_eq!(addr.segment_id.len(), 32);
        assert_eq!(addr.doc_id, 9);
        assert_eq!(addr.segment_id_u128(), Some(0xabc));
    }
}