use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
/// Query that matches exactly the documents matched by `positive`, scaling the
/// score of every document that is *also* matched by `negative` by
/// `negative_boost`.
///
/// The negative clause never removes documents from the result set; it only
/// changes scores.
pub struct BoostingQuery {
    /// Clause that decides which documents match and supplies the base score.
    pub(crate) positive: Box<dyn Query>,
    /// Clause whose matches get their score scaled; its own scores are unused.
    pub(crate) negative: Box<dyn Query>,
    /// Multiplier applied to the score of documents matched by `negative`.
    ///
    /// NOTE(review): this value is not validated anywhere in this file; values
    /// above 1.0 would promote rather than demote, and negative values flip the
    /// score's sign — confirm whether callers restrict it to `[0, 1]`.
    pub negative_boost: f32,
}

impl Query for BoostingQuery {
    /// Binds both clauses against `searcher`.
    ///
    /// The positive clause inherits the caller's `score_mode`. The negative
    /// clause is always bound with `ScoreMode::CompleteNoScores` because only
    /// whether it matches is consulted, never its score.
    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        // Field initializers run in written order: positive first, then negative.
        let bound = BoundBoostingQuery {
            positive: self.positive.bind(searcher, score_mode)?,
            negative: self.negative.bind(searcher, ScoreMode::CompleteNoScores)?,
            negative_boost: self.negative_boost,
        };
        Ok(Box::new(bound))
    }
}
/// A `BoostingQuery` bound to a searcher: the bound positive and negative
/// clauses plus the score multiplier for negatively-matched documents.
struct BoundBoostingQuery {
    positive: Box<dyn BoundQuery>,
    negative: Box<dyn BoundQuery>,
    negative_boost: f32,
}

impl BoundQuery for BoundBoostingQuery {
    /// Builds the per-segment scorer supplier for `reader`.
    ///
    /// Returns `Ok(None)` when the positive clause yields no supplier for this
    /// segment, because only positive matches can be returned. A missing
    /// negative supplier is allowed and simply means nothing gets demoted.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let Some(positive) = self.positive.scorer_supplier(reader)? else {
            return Ok(None);
        };
        let negative = self.negative.scorer_supplier(reader)?;

        let supplier = BoostingScorerSupplier {
            positive,
            negative,
            negative_boost: self.negative_boost,
        };
        Ok(Some(Box::new(supplier)))
    }
}
/// Per-segment supplier that defers building the positive and negative scorers
/// until `scorer()` is called.
struct BoostingScorerSupplier {
    positive: Box<dyn ScorerSupplier>,
    /// `None` when the negative clause produced no supplier for this segment.
    negative: Option<Box<dyn ScorerSupplier>>,
    negative_boost: f32,
}

impl ScorerSupplier for BoostingScorerSupplier {
    /// Reports only the positive clause's cost: the negative clause never adds
    /// or removes matches, it only rescales scores.
    fn cost(&self) -> u64 {
        self.positive.cost()
    }

    /// Materializes the positive scorer first, then (if present) the negative
    /// one, propagating the first error encountered.
    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
        let Self {
            positive,
            negative,
            negative_boost,
        } = *self;

        let positive = positive.scorer()?;
        let negative = negative.map(|supplier| supplier.scorer()).transpose()?;

        Ok(Box::new(BoostingScorer {
            positive,
            negative,
            negative_boost,
        }))
    }
}
/// Scorer that iterates the positive clause's matches and multiplies the score
/// of any document also matched by the negative clause by `negative_boost`.
///
/// All iteration (`doc_id` / `next` / `advance`) is forwarded unchanged to the
/// positive scorer; the negative scorer is only consulted from `score()`.
struct BoostingScorer {
    /// Drives iteration and provides the base score.
    positive: Box<dyn Scorer>,
    /// Checked lazily for the current document; `None` means nothing is demoted.
    negative: Option<Box<dyn Scorer>>,
    /// Multiplier for negatively-matched documents.
    negative_boost: f32,
}

impl BoostingScorer {
    /// Reports whether the negative clause matches the positive scorer's
    /// current document.
    ///
    /// The negative scorer is only moved forward, and only when it is behind
    /// the current document, so calling this repeatedly for the same document
    /// does not re-position it.
    fn is_negatively_matched(&mut self) -> bool {
        match self.negative.as_deref_mut() {
            None => false,
            Some(negative) => {
                let current = self.positive.doc_id();
                if negative.doc_id() < current {
                    negative.advance(current);
                }
                negative.doc_id() == current
            }
        }
    }
}

impl Scorer for BoostingScorer {
    fn doc_id(&self) -> DocId {
        self.positive.doc_id()
    }

    fn next(&mut self) -> DocId {
        self.positive.next()
    }

    fn advance(&mut self, target: DocId) -> DocId {
        self.positive.advance(target)
    }

    /// The positive score, scaled by `negative_boost` when the negative clause
    /// also matches the current document.
    fn score(&mut self) -> f32 {
        let raw = self.positive.score();
        if !self.is_negatively_matched() {
            return raw;
        }
        raw * self.negative_boost
    }

    /// No two-phase iterator is exposed; callers iterate via `next`/`advance`,
    /// which forward to the positive scorer.
    ///
    /// NOTE(review): this hides any two-phase iterator the positive scorer may
    /// have; forwarding it might be possible, but that depends on the
    /// `TwoPhaseIterator` contract — confirm before changing.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::query::term::TermQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Builds one token per term, assigning positions 0, 1, 2, ... in order.
    ///
    /// NOTE(review): every token gets start offset 0 and an end offset equal to
    /// its own length; presumably offsets do not influence scoring here.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// Both documents match `text:apple`, but only the second is tagged
    /// `tech`, so the boosting query must rank the first one higher.
    #[test]
    fn boosting_demotes_negative() {
        let schema = Mapping::builder()
            .field("text", FieldType::Text)
            .field("tag", FieldType::Keyword)
            .build();

        let mut segment = SegmentBuilder::new(SegmentId::new(1), &schema);
        for (words, tag) in [(["apple", "pie"], "food"), (["apple", "computer"], "tech")] {
            segment.add_document(
                &[
                    (FieldId::new(0), make_tokens(&words)),
                    (FieldId::new(1), vec![Token::new(tag, 0, 4, 0)]),
                ],
                b"{}",
            );
        }

        let reader = SegmentReader::open(segment.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let boosting = BoostingQuery {
            positive: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "apple".into(),
                analyzer: None,
            }),
            negative: Box::new(TermQuery {
                field: "tag".into(),
                value: "tech".into(),
            }),
            negative_boost: 0.5,
        };
        let results = searcher.search_query(&boosting, 10, 0).unwrap();

        assert_eq!(results.total_hits.value, 2);
        let top = &results.hits[0];
        let runner_up = &results.hits[1];
        assert!(
            top.score > runner_up.score,
            "non-demoted doc should score higher: {} > {}",
            top.score,
            runner_up.score
        );
    }
}