hermes_core/query/
term.rs1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
12#[derive(Clone)]
14pub struct TermQuery {
15 pub field: Field,
16 pub term: Vec<u8>,
17 global_stats: Option<Arc<GlobalStats>>,
19}
20
21impl std::fmt::Debug for TermQuery {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("TermQuery")
24 .field("field", &self.field)
25 .field("term", &String::from_utf8_lossy(&self.term))
26 .field("has_global_stats", &self.global_stats.is_some())
27 .finish()
28 }
29}
30
31impl TermQuery {
32 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
33 Self {
34 field,
35 term: term.into(),
36 global_stats: None,
37 }
38 }
39
40 pub fn text(field: Field, text: &str) -> Self {
41 Self {
42 field,
43 term: text.to_lowercase().into_bytes(),
44 global_stats: None,
45 }
46 }
47
48 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
50 Self {
51 field,
52 term: text.to_lowercase().into_bytes(),
53 global_stats: Some(stats),
54 }
55 }
56
57 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
59 self.global_stats = Some(stats);
60 }
61}
62
63impl Query for TermQuery {
64 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
65 let field = self.field;
66 let term = self.term.clone();
67 let global_stats = self.global_stats.clone();
68 Box::pin(async move {
69 let postings = reader.get_postings(field, &term).await?;
70
71 match postings {
72 Some(posting_list) => {
73 let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
75 let term_str = String::from_utf8_lossy(&term);
76 let global_idf = stats.text_idf(field, &term_str);
77
78 if global_idf > 0.0 {
81 (global_idf, stats.avg_field_len(field))
82 } else {
83 let num_docs = reader.num_docs() as f32;
85 let doc_freq = posting_list.doc_count() as f32;
86 let idf = super::bm25_idf(doc_freq, num_docs);
87 (idf, reader.avg_field_len(field))
88 }
89 } else {
90 let num_docs = reader.num_docs() as f32;
92 let doc_freq = posting_list.doc_count() as f32;
93 let idf = super::bm25_idf(doc_freq, num_docs);
94 (idf, reader.avg_field_len(field))
95 };
96
97 let positions = reader.get_positions(field, &term).await.ok().flatten();
99
100 let mut scorer = TermScorer::new(
101 posting_list,
102 idf,
103 avg_field_len,
104 1.0, );
106
107 if let Some(pos) = positions {
108 scorer = scorer.with_positions(field.0, pos);
109 }
110
111 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
112 }
113 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
114 }
115 })
116 }
117
118 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
119 let field = self.field;
120 let term = self.term.clone();
121 Box::pin(async move {
122 match reader.get_postings(field, &term).await? {
123 Some(list) => Ok(list.doc_count()),
124 None => Ok(0),
125 }
126 })
127 }
128
129 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
130 Some(TermQueryInfo {
131 field: self.field,
132 term: self.term.clone(),
133 })
134 }
135}
136
137struct TermScorer {
138 iterator: crate::structures::BlockPostingIterator<'static>,
139 idf: f32,
140 avg_field_len: f32,
142 field_boost: f32,
144 field_id: u32,
146 positions: Option<crate::structures::PositionPostingList>,
148}
149
150impl TermScorer {
151 pub fn new(
152 posting_list: BlockPostingList,
153 idf: f32,
154 avg_field_len: f32,
155 field_boost: f32,
156 ) -> Self {
157 Self {
158 iterator: posting_list.into_iterator(),
159 idf,
160 avg_field_len,
161 field_boost,
162 field_id: 0,
163 positions: None,
164 }
165 }
166
167 pub fn with_positions(
168 mut self,
169 field_id: u32,
170 positions: crate::structures::PositionPostingList,
171 ) -> Self {
172 self.field_id = field_id;
173 self.positions = Some(positions);
174 self
175 }
176}
177
178impl Scorer for TermScorer {
179 fn doc(&self) -> DocId {
180 self.iterator.doc()
181 }
182
183 fn score(&self) -> Score {
184 let tf = self.iterator.term_freq() as f32;
185 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
188 }
189
190 fn advance(&mut self) -> DocId {
191 self.iterator.advance()
192 }
193
194 fn seek(&mut self, target: DocId) -> DocId {
195 self.iterator.seek(target)
196 }
197
198 fn size_hint(&self) -> u32 {
199 0
200 }
201
202 fn matched_positions(&self) -> Option<super::MatchedPositions> {
203 let positions = self.positions.as_ref()?;
204 let doc_id = self.iterator.doc();
205 let pos = positions.get_positions(doc_id)?;
206 Some(vec![(self.field_id, pos.to_vec())])
207 }
208}