hermes_core/query/
term.rs1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{
11 Bm25Params, CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo,
12};
13
14#[derive(Clone)]
16pub struct TermQuery {
17 pub field: Field,
18 pub term: Vec<u8>,
19 global_stats: Option<Arc<GlobalStats>>,
21}
22
23impl std::fmt::Debug for TermQuery {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("TermQuery")
26 .field("field", &self.field)
27 .field("term", &String::from_utf8_lossy(&self.term))
28 .field("has_global_stats", &self.global_stats.is_some())
29 .finish()
30 }
31}
32
33impl TermQuery {
34 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
35 Self {
36 field,
37 term: term.into(),
38 global_stats: None,
39 }
40 }
41
42 pub fn text(field: Field, text: &str) -> Self {
43 Self {
44 field,
45 term: text.to_lowercase().into_bytes(),
46 global_stats: None,
47 }
48 }
49
50 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
52 Self {
53 field,
54 term: text.to_lowercase().into_bytes(),
55 global_stats: Some(stats),
56 }
57 }
58
59 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
61 self.global_stats = Some(stats);
62 }
63}
64
65impl Query for TermQuery {
66 fn scorer<'a>(&'a self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
67 Box::pin(async move {
68 let postings = reader.get_postings(self.field, &self.term).await?;
69
70 match postings {
71 Some(posting_list) => {
72 let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
74 let term_str = String::from_utf8_lossy(&self.term);
75 let global_idf = stats.text_idf(self.field, &term_str);
76
77 if global_idf > 0.0 {
80 (global_idf, stats.avg_field_len(self.field))
81 } else {
82 let num_docs = reader.num_docs() as f32;
84 let doc_freq = posting_list.doc_count() as f32;
85 let idf = super::bm25_idf(doc_freq, num_docs);
86 (idf, reader.avg_field_len(self.field))
87 }
88 } else {
89 let num_docs = reader.num_docs() as f32;
91 let doc_freq = posting_list.doc_count() as f32;
92 let idf = super::bm25_idf(doc_freq, num_docs);
93 (idf, reader.avg_field_len(self.field))
94 };
95
96 Ok(Box::new(TermScorer::new(
97 posting_list,
98 idf,
99 avg_field_len,
100 Bm25Params::default(),
101 1.0, )) as Box<dyn Scorer + 'a>)
103 }
104 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
105 }
106 })
107 }
108
109 fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
110 Box::pin(async move {
111 match reader.get_postings(self.field, &self.term).await? {
112 Some(list) => Ok(list.doc_count()),
113 None => Ok(0),
114 }
115 })
116 }
117
118 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
119 Some(TermQueryInfo {
120 field: self.field,
121 term: self.term.clone(),
122 })
123 }
124}
125
126struct TermScorer {
127 iterator: crate::structures::BlockPostingIterator<'static>,
128 idf: f32,
129 params: Bm25Params,
131 avg_field_len: f32,
133 field_boost: f32,
135}
136
137impl TermScorer {
138 pub fn new(
139 posting_list: BlockPostingList,
140 idf: f32,
141 avg_field_len: f32,
142 params: Bm25Params,
143 field_boost: f32,
144 ) -> Self {
145 Self {
146 iterator: posting_list.into_iterator(),
147 idf,
148 params,
149 avg_field_len,
150 field_boost,
151 }
152 }
153}
154
155impl Scorer for TermScorer {
156 fn doc(&self) -> DocId {
157 self.iterator.doc()
158 }
159
160 fn score(&self) -> Score {
161 let tf = self.iterator.term_freq() as f32;
162 let k1 = self.params.k1;
163 let b = self.params.b;
164
165 let length_norm = 1.0 - b + b * (tf / self.avg_field_len.max(1.0));
167 let tf_norm =
168 (tf * self.field_boost * (k1 + 1.0)) / (tf * self.field_boost + k1 * length_norm);
169
170 self.idf * tf_norm
171 }
172
173 fn advance(&mut self) -> DocId {
174 self.iterator.advance()
175 }
176
177 fn seek(&mut self, target: DocId) -> DocId {
178 self.iterator.seek(target)
179 }
180
181 fn size_hint(&self) -> u32 {
182 0
183 }
184}