reflex/scoring/scorer.rs

1use std::cmp::Ordering;
2use tracing::{debug, info};
3
4use crate::embedding::{Reranker, RerankerConfig};
5use crate::storage::CacheEntry;
6
7use super::error::ScoringError;
8use super::types::{VerificationResult, VerifiedCandidate};
9
10/// L3 verifier that reranks candidates with a cross-encoder.
11pub struct CrossEncoderScorer {
12    reranker: Reranker,
13}
14
15impl std::fmt::Debug for CrossEncoderScorer {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("CrossEncoderScorer")
18            .field("reranker", &self.reranker)
19            .finish()
20    }
21}
22
23impl CrossEncoderScorer {
24    /// Creates a new scorer from a reranker config.
25    pub fn new(config: RerankerConfig) -> Result<Self, ScoringError> {
26        let reranker = Reranker::load(config)?;
27        Ok(Self { reranker })
28    }
29
30    /// Creates a scorer in stub mode (no model files required).
31    pub fn stub() -> Result<Self, ScoringError> {
32        Ok(Self {
33            reranker: Reranker::stub()?,
34        })
35    }
36
37    /// Returns `true` if a reranker model is loaded.
38    pub fn is_model_loaded(&self) -> bool {
39        self.reranker.is_model_loaded()
40    }
41
42    /// Returns the verification threshold.
43    pub fn threshold(&self) -> f32 {
44        self.reranker.threshold()
45    }
46
47    /// Returns the underlying reranker.
48    pub fn reranker(&self) -> &Reranker {
49        &self.reranker
50    }
51
52    /// Scores a query/candidate text pair.
53    pub fn score(&self, query: &str, candidate_text: &str) -> Result<f32, ScoringError> {
54        Ok(self.reranker.score(query, candidate_text)?)
55    }
56
57    /// Verifies candidates and returns the winning entry (if verified) plus a result.
58    pub fn verify_candidates(
59        &self,
60        query: &str,
61        candidates: Vec<(CacheEntry, f32)>,
62    ) -> Result<(Option<CacheEntry>, VerificationResult), ScoringError> {
63        if candidates.is_empty() {
64            debug!("No candidates provided for verification");
65            return Ok((None, VerificationResult::NoCandidates));
66        }
67
68        debug!(
69            query_len = query.len(),
70            num_candidates = candidates.len(),
71            "Starting L3 verification"
72        );
73
74        let mut verified_candidates = self.score_candidates(query, candidates)?;
75
76        verified_candidates.sort_by(|a, b| {
77            b.cross_encoder_score
78                .partial_cmp(&a.cross_encoder_score)
79                .unwrap_or(Ordering::Equal)
80        });
81
82        // SAFETY: candidates is non-empty (checked above), and score_candidates
83        // maps 1:1, so verified_candidates is guaranteed non-empty
84        let top = &verified_candidates[0];
85
86        debug!(
87            top_score = top.cross_encoder_score,
88            original_score = top.original_score,
89            threshold = self.threshold(),
90            "Top candidate after reranking"
91        );
92
93        let score = top.cross_encoder_score;
94
95        if score > self.threshold() {
96            let entry = top.entry.clone();
97
98            info!(
99                score = score,
100                threshold = self.threshold(),
101                "L3 verification passed - cache hit"
102            );
103
104            Ok((Some(entry), VerificationResult::Verified { score }))
105        } else {
106            debug!(
107                score = score,
108                threshold = self.threshold(),
109                "Top candidate below threshold - cache miss"
110            );
111
112            Ok((None, VerificationResult::Rejected { top_score: score }))
113        }
114    }
115
116    /// Scores all candidates and returns the full list with scores.
117    pub fn score_candidates(
118        &self,
119        query: &str,
120        candidates: Vec<(CacheEntry, f32)>,
121    ) -> Result<Vec<VerifiedCandidate>, ScoringError> {
122        candidates
123            .into_iter()
124            .map(|(entry, original_score)| {
125                let candidate_text = String::from_utf8_lossy(&entry.payload_blob);
126                let cross_encoder_score = self.reranker.score(query, &candidate_text)?;
127
128                Ok(VerifiedCandidate::new(
129                    entry,
130                    cross_encoder_score,
131                    original_score,
132                ))
133            })
134            .collect()
135    }
136
137    /// Verifies candidates and also returns scored details.
138    pub fn verify_candidates_with_details(
139        &self,
140        query: &str,
141        candidates: Vec<(CacheEntry, f32)>,
142    ) -> Result<(Vec<VerifiedCandidate>, VerificationResult), ScoringError> {
143        if candidates.is_empty() {
144            return Ok((vec![], VerificationResult::NoCandidates));
145        }
146
147        let mut scored = self.score_candidates(query, candidates)?;
148
149        scored.sort_by(|a, b| {
150            b.cross_encoder_score
151                .partial_cmp(&a.cross_encoder_score)
152                .unwrap_or(Ordering::Equal)
153        });
154
155        // SAFETY: candidates is non-empty (checked above), and score_candidates
156        // maps 1:1, so scored is guaranteed non-empty
157        let score = scored[0].cross_encoder_score;
158        let result = if score > self.threshold() {
159            VerificationResult::Verified { score }
160        } else {
161            VerificationResult::Rejected { top_score: score }
162        };
163
164        Ok((scored, result))
165    }
166
167    /// Reranks and returns the top `top_n` scored candidates.
168    pub fn rerank_top_n(
169        &self,
170        query: &str,
171        candidates: Vec<(CacheEntry, f32)>,
172        top_n: usize,
173    ) -> Result<Vec<VerifiedCandidate>, ScoringError> {
174        let mut scored = self.score_candidates(query, candidates)?;
175
176        scored.sort_by(|a, b| {
177            b.cross_encoder_score
178                .partial_cmp(&a.cross_encoder_score)
179                .unwrap_or(Ordering::Equal)
180        });
181
182        scored.truncate(top_n);
183        Ok(scored)
184    }
185}