Skip to main content

hermes_core/query/
boost.rs

1//! Boost query - multiplies the score of the inner query
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::{DocId, Score};
7
8use super::{CountFuture, Query, Scorer, ScorerFuture};
9
10/// Boost query - multiplies the score of the inner query
11#[derive(Clone)]
12pub struct BoostQuery {
13    pub inner: Arc<dyn Query>,
14    pub boost: f32,
15}
16
17impl std::fmt::Debug for BoostQuery {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("BoostQuery")
20            .field("boost", &self.boost)
21            .finish()
22    }
23}
24
25impl BoostQuery {
26    pub fn new(query: impl Query + 'static, boost: f32) -> Self {
27        Self {
28            inner: Arc::new(query),
29            boost,
30        }
31    }
32}
33
34impl Query for BoostQuery {
35    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
36        let inner = self.inner.clone();
37        let boost = self.boost;
38        Box::pin(async move {
39            let inner_scorer = inner.scorer(reader, limit).await?;
40            Ok(Box::new(BoostScorer {
41                inner: inner_scorer,
42                boost,
43            }) as Box<dyn Scorer + 'a>)
44        })
45    }
46
47    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
48        let inner = self.inner.clone();
49        Box::pin(async move { inner.count_estimate(reader).await })
50    }
51}
52
53struct BoostScorer<'a> {
54    inner: Box<dyn Scorer + 'a>,
55    boost: f32,
56}
57
58impl Scorer for BoostScorer<'_> {
59    fn doc(&self) -> DocId {
60        self.inner.doc()
61    }
62
63    fn score(&self) -> Score {
64        self.inner.score() * self.boost
65    }
66
67    fn advance(&mut self) -> DocId {
68        self.inner.advance()
69    }
70
71    fn seek(&mut self, target: DocId) -> DocId {
72        self.inner.seek(target)
73    }
74
75    fn size_hint(&self) -> u32 {
76        self.inner.size_hint()
77    }
78}