// hermes_core/query/boost.rs

use std::sync::Arc;

use crate::segment::SegmentReader;
use crate::{DocId, Score};

use super::{CountFuture, Query, Scorer, ScorerFuture};
9
10#[derive(Clone)]
12pub struct BoostQuery {
13 pub inner: Arc<dyn Query>,
14 pub boost: f32,
15}
16
17impl std::fmt::Debug for BoostQuery {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("BoostQuery")
20 .field("boost", &self.boost)
21 .finish()
22 }
23}
24
25impl BoostQuery {
26 pub fn new(query: impl Query + 'static, boost: f32) -> Self {
27 Self {
28 inner: Arc::new(query),
29 boost,
30 }
31 }
32}
33
34impl Query for BoostQuery {
35 fn scorer<'a>(
36 &self,
37 reader: &'a SegmentReader,
38 limit: usize,
39 predicate: Option<super::DocPredicate<'a>>,
40 ) -> ScorerFuture<'a> {
41 let inner = self.inner.clone();
42 let boost = self.boost;
43 Box::pin(async move {
44 let inner_scorer = inner.scorer(reader, limit, predicate).await?;
45 Ok(Box::new(BoostScorer {
46 inner: inner_scorer,
47 boost,
48 }) as Box<dyn Scorer + 'a>)
49 })
50 }
51
52 #[cfg(feature = "sync")]
53 fn scorer_sync<'a>(
54 &self,
55 reader: &'a SegmentReader,
56 limit: usize,
57 predicate: Option<super::DocPredicate<'a>>,
58 ) -> crate::Result<Box<dyn Scorer + 'a>> {
59 let inner_scorer = self.inner.scorer_sync(reader, limit, predicate)?;
60 Ok(Box::new(BoostScorer {
61 inner: inner_scorer,
62 boost: self.boost,
63 }) as Box<dyn Scorer + 'a>)
64 }
65
66 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
67 let inner = self.inner.clone();
68 Box::pin(async move { inner.count_estimate(reader).await })
69 }
70}
71
72struct BoostScorer<'a> {
73 inner: Box<dyn Scorer + 'a>,
74 boost: f32,
75}
76
77impl Scorer for BoostScorer<'_> {
78 fn doc(&self) -> DocId {
79 self.inner.doc()
80 }
81
82 fn score(&self) -> Score {
83 self.inner.score() * self.boost
84 }
85
86 fn advance(&mut self) -> DocId {
87 self.inner.advance()
88 }
89
90 fn seek(&mut self, target: DocId) -> DocId {
91 self.inner.seek(target)
92 }
93
94 fn size_hint(&self) -> u32 {
95 self.inner.size_hint()
96 }
97}