1use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
9
10use crate::query::{BoundQuery, Query, ScorerSupplier};
11use crate::search::searcher::Searcher;
12use crate::segment::reader::SegmentReader;
13
14pub struct BoostingQuery {
15 pub(crate) positive: Box<dyn Query>,
16 pub(crate) negative: Box<dyn Query>,
17 pub negative_boost: f32,
18}
19
20impl Query for BoostingQuery {
21 fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
22 let pos_weight = self.positive.bind(searcher, score_mode)?;
23 let neg_weight = self.negative.bind(searcher, ScoreMode::CompleteNoScores)?;
29 Ok(Box::new(BoundBoostingQuery {
30 positive: pos_weight,
31 negative: neg_weight,
32 negative_boost: self.negative_boost,
33 }))
34 }
35}
36
37struct BoundBoostingQuery {
38 positive: Box<dyn BoundQuery>,
39 negative: Box<dyn BoundQuery>,
40 negative_boost: f32,
41}
42
43impl BoundQuery for BoundBoostingQuery {
44 fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
45 let pos = match self.positive.scorer_supplier(reader)? {
46 Some(s) => s,
47 None => return Ok(None),
48 };
49 let neg = self.negative.scorer_supplier(reader)?;
50 Ok(Some(Box::new(BoostingScorerSupplier {
51 positive: pos,
52 negative: neg,
53 negative_boost: self.negative_boost,
54 })))
55 }
56}
57
58struct BoostingScorerSupplier {
59 positive: Box<dyn ScorerSupplier>,
60 negative: Option<Box<dyn ScorerSupplier>>,
61 negative_boost: f32,
62}
63
64impl ScorerSupplier for BoostingScorerSupplier {
65 fn cost(&self) -> u64 {
66 self.positive.cost()
67 }
68 fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
69 let positive = self.positive.scorer()?;
70 let negative = match self.negative {
71 Some(n) => Some(n.scorer()?),
72 None => None,
73 };
74 Ok(Box::new(BoostingScorer {
75 positive,
76 negative,
77 negative_boost: self.negative_boost,
78 }))
79 }
80}
81
82struct BoostingScorer {
83 positive: Box<dyn Scorer>,
84 negative: Option<Box<dyn Scorer>>,
85 negative_boost: f32,
86}
87
88impl BoostingScorer {
89 fn is_negatively_matched(&mut self) -> bool {
90 let Some(ref mut neg) = self.negative else {
91 return false;
92 };
93 let doc = self.positive.doc_id();
94 if neg.doc_id() < doc {
95 neg.advance(doc);
96 }
97 neg.doc_id() == doc
98 }
99}
100
101impl Scorer for BoostingScorer {
102 fn doc_id(&self) -> DocId {
103 self.positive.doc_id()
104 }
105 fn next(&mut self) -> DocId {
106 self.positive.next()
107 }
108 fn advance(&mut self, target: DocId) -> DocId {
109 self.positive.advance(target)
110 }
111
112 fn score(&mut self) -> f32 {
113 let base = self.positive.score();
114 if self.is_negatively_matched() {
115 base * self.negative_boost
116 } else {
117 base
118 }
119 }
120
121 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
122 None
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::query::term::TermQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Turns `terms` into tokens at consecutive positions, each with offsets `0..len`.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// Both documents contain "apple" in `text`; only the second is tagged "tech", so it
    /// must be the one demoted below the other.
    #[test]
    fn boosting_demotes_negative() {
        let schema = Mapping::builder()
            .field("text", FieldType::Text)
            .field("tag", FieldType::Keyword)
            .build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);

        let docs = [(["apple", "pie"], "food"), (["apple", "computer"], "tech")];
        for (words, label) in docs {
            builder.add_document(
                &[
                    (FieldId::new(0), make_tokens(&words)),
                    (FieldId::new(1), vec![Token::new(label, 0, label.len(), 0)]),
                ],
                b"{}",
            );
        }

        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let query = BoostingQuery {
            positive: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "apple".into(),
                analyzer: None,
            }),
            negative: Box::new(TermQuery {
                field: "tag".into(),
                value: "tech".into(),
            }),
            negative_boost: 0.5,
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 2);

        let (top, bottom) = (results.hits[0].score, results.hits[1].score);
        assert!(
            top > bottom,
            "non-demoted doc should score higher: {} > {}",
            top,
            bottom
        );
    }
}