hermes_core/query/
term.rs1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::wand::WandData;
9use crate::{DocId, Score};
10
11use super::{Bm25Params, CountFuture, EmptyScorer, Query, Scorer, ScorerFuture};
12
13#[derive(Debug, Clone)]
15pub struct TermQuery {
16 pub field: Field,
17 pub term: Vec<u8>,
18 wand_data: Option<Arc<WandData>>,
20 field_name: Option<String>,
22}
23
24impl TermQuery {
25 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
26 Self {
27 field,
28 term: term.into(),
29 wand_data: None,
30 field_name: None,
31 }
32 }
33
34 pub fn text(field: Field, text: &str) -> Self {
35 Self {
36 field,
37 term: text.to_lowercase().into_bytes(),
38 wand_data: None,
39 field_name: None,
40 }
41 }
42
43 pub fn with_wand_data(
49 field: Field,
50 field_name: &str,
51 term: &str,
52 wand_data: Arc<WandData>,
53 ) -> Self {
54 Self {
55 field,
56 term: term.to_lowercase().into_bytes(),
57 wand_data: Some(wand_data),
58 field_name: Some(field_name.to_string()),
59 }
60 }
61
62 pub fn set_wand_data(&mut self, field_name: &str, wand_data: Arc<WandData>) {
64 self.wand_data = Some(wand_data);
65 self.field_name = Some(field_name.to_string());
66 }
67}
68
69impl Query for TermQuery {
70 fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
71 Box::pin(async move {
72 let postings = reader.get_postings(self.field, &self.term).await?;
73
74 match postings {
75 Some(posting_list) => {
76 let idf = if let (Some(wand_data), Some(field_name)) =
78 (&self.wand_data, &self.field_name)
79 {
80 let term_str = String::from_utf8_lossy(&self.term);
81 wand_data.get_idf(field_name, &term_str).unwrap_or_else(|| {
82 let num_docs = reader.num_docs() as f32;
84 let doc_freq = posting_list.doc_count() as f32;
85 ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln()
86 })
87 } else {
88 let num_docs = reader.num_docs() as f32;
90 let doc_freq = posting_list.doc_count() as f32;
91 ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln()
92 };
93
94 let avg_field_len = self
97 .wand_data
98 .as_ref()
99 .map(|w| w.avg_doc_len)
100 .unwrap_or_else(|| reader.avg_field_len(self.field));
101
102 Ok(Box::new(TermScorer::new(
103 posting_list,
104 idf,
105 avg_field_len,
106 Bm25Params::default(),
107 1.0, )) as Box<dyn Scorer + 'a>)
109 }
110 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
111 }
112 })
113 }
114
115 fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
116 Box::pin(async move {
117 match reader.get_postings(self.field, &self.term).await? {
118 Some(list) => Ok(list.doc_count()),
119 None => Ok(0),
120 }
121 })
122 }
123}
124
125struct TermScorer {
126 iterator: crate::structures::BlockPostingIterator<'static>,
127 idf: f32,
128 params: Bm25Params,
130 avg_field_len: f32,
132 field_boost: f32,
134}
135
136impl TermScorer {
137 pub fn new(
138 posting_list: BlockPostingList,
139 idf: f32,
140 avg_field_len: f32,
141 params: Bm25Params,
142 field_boost: f32,
143 ) -> Self {
144 Self {
145 iterator: posting_list.into_iterator(),
146 idf,
147 params,
148 avg_field_len,
149 field_boost,
150 }
151 }
152}
153
154impl Scorer for TermScorer {
155 fn doc(&self) -> DocId {
156 self.iterator.doc()
157 }
158
159 fn score(&self) -> Score {
160 let tf = self.iterator.term_freq() as f32;
161 let k1 = self.params.k1;
162 let b = self.params.b;
163
164 let length_norm = 1.0 - b + b * (tf / self.avg_field_len.max(1.0));
166 let tf_norm =
167 (tf * self.field_boost * (k1 + 1.0)) / (tf * self.field_boost + k1 * length_norm);
168
169 self.idf * tf_norm
170 }
171
172 fn advance(&mut self) -> DocId {
173 self.iterator.advance()
174 }
175
176 fn seek(&mut self, target: DocId) -> DocId {
177 self.iterator.seek(target)
178 }
179
180 fn size_hint(&self) -> u32 {
181 0
182 }
183}