Skip to main content

hermes_core/query/
boost.rs

//! Boost query - multiplies the score of the inner query.
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::{DocId, Score};
7
8use super::{CountFuture, Query, Scorer, ScorerFuture};
9
10/// Boost query - multiplies the score of the inner query
11#[derive(Clone)]
12pub struct BoostQuery {
13    pub inner: Arc<dyn Query>,
14    pub boost: f32,
15}
16
17impl std::fmt::Debug for BoostQuery {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("BoostQuery")
20            .field("boost", &self.boost)
21            .finish()
22    }
23}
24
25impl std::fmt::Display for BoostQuery {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        write!(f, "{}^{}", self.inner, self.boost)
28    }
29}
30
31impl BoostQuery {
32    pub fn new(query: impl Query + 'static, boost: f32) -> Self {
33        Self {
34            inner: Arc::new(query),
35            boost,
36        }
37    }
38}
39
40impl Query for BoostQuery {
41    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
42        let inner = self.inner.clone();
43        let boost = self.boost;
44        Box::pin(async move {
45            let inner_scorer = inner.scorer(reader, limit).await?;
46            Ok(Box::new(BoostScorer {
47                inner: inner_scorer,
48                boost,
49            }) as Box<dyn Scorer + 'a>)
50        })
51    }
52
53    #[cfg(feature = "sync")]
54    fn scorer_sync<'a>(
55        &self,
56        reader: &'a SegmentReader,
57        limit: usize,
58    ) -> crate::Result<Box<dyn Scorer + 'a>> {
59        let inner_scorer = self.inner.scorer_sync(reader, limit)?;
60        Ok(Box::new(BoostScorer {
61            inner: inner_scorer,
62            boost: self.boost,
63        }) as Box<dyn Scorer + 'a>)
64    }
65
66    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
67        let inner = self.inner.clone();
68        Box::pin(async move { inner.count_estimate(reader).await })
69    }
70
71    fn is_filter(&self) -> bool {
72        self.inner.is_filter()
73    }
74
75    fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
76        self.inner.as_doc_predicate(reader)
77    }
78}
79
80struct BoostScorer<'a> {
81    inner: Box<dyn Scorer + 'a>,
82    boost: f32,
83}
84
85impl super::docset::DocSet for BoostScorer<'_> {
86    fn doc(&self) -> DocId {
87        self.inner.doc()
88    }
89
90    fn advance(&mut self) -> DocId {
91        self.inner.advance()
92    }
93
94    fn seek(&mut self, target: DocId) -> DocId {
95        self.inner.seek(target)
96    }
97
98    fn size_hint(&self) -> u32 {
99        self.inner.size_hint()
100    }
101}
102
103impl Scorer for BoostScorer<'_> {
104    fn score(&self) -> Score {
105        self.inner.score() * self.boost
106    }
107}