1use std::cmp::Ordering;
2use tracing::{debug, info};
3
4use crate::embedding::{Reranker, RerankerConfig};
5use crate::storage::CacheEntry;
6
7use super::error::ScoringError;
8use super::types::{VerificationResult, VerifiedCandidate};
9
10pub struct CrossEncoderScorer {
12 reranker: Reranker,
13}
14
15impl std::fmt::Debug for CrossEncoderScorer {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.debug_struct("CrossEncoderScorer")
18 .field("reranker", &self.reranker)
19 .finish()
20 }
21}
22
23impl CrossEncoderScorer {
24 pub fn new(config: RerankerConfig) -> Result<Self, ScoringError> {
26 let reranker = Reranker::load(config)?;
27 Ok(Self { reranker })
28 }
29
30 pub fn stub() -> Result<Self, ScoringError> {
32 Ok(Self {
33 reranker: Reranker::stub()?,
34 })
35 }
36
37 pub fn is_model_loaded(&self) -> bool {
39 self.reranker.is_model_loaded()
40 }
41
42 pub fn threshold(&self) -> f32 {
44 self.reranker.threshold()
45 }
46
47 pub fn reranker(&self) -> &Reranker {
49 &self.reranker
50 }
51
52 pub fn score(&self, query: &str, candidate_text: &str) -> Result<f32, ScoringError> {
54 Ok(self.reranker.score(query, candidate_text)?)
55 }
56
57 pub fn verify_candidates(
59 &self,
60 query: &str,
61 candidates: Vec<(CacheEntry, f32)>,
62 ) -> Result<(Option<CacheEntry>, VerificationResult), ScoringError> {
63 if candidates.is_empty() {
64 debug!("No candidates provided for verification");
65 return Ok((None, VerificationResult::NoCandidates));
66 }
67
68 debug!(
69 query_len = query.len(),
70 num_candidates = candidates.len(),
71 "Starting L3 verification"
72 );
73
74 let mut verified_candidates = self.score_candidates(query, candidates)?;
75
76 verified_candidates.sort_by(|a, b| {
77 b.cross_encoder_score
78 .partial_cmp(&a.cross_encoder_score)
79 .unwrap_or(Ordering::Equal)
80 });
81
82 let top = &verified_candidates[0];
85
86 debug!(
87 top_score = top.cross_encoder_score,
88 original_score = top.original_score,
89 threshold = self.threshold(),
90 "Top candidate after reranking"
91 );
92
93 let score = top.cross_encoder_score;
94
95 if score > self.threshold() {
96 let entry = top.entry.clone();
97
98 info!(
99 score = score,
100 threshold = self.threshold(),
101 "L3 verification passed - cache hit"
102 );
103
104 Ok((Some(entry), VerificationResult::Verified { score }))
105 } else {
106 debug!(
107 score = score,
108 threshold = self.threshold(),
109 "Top candidate below threshold - cache miss"
110 );
111
112 Ok((None, VerificationResult::Rejected { top_score: score }))
113 }
114 }
115
116 pub fn score_candidates(
118 &self,
119 query: &str,
120 candidates: Vec<(CacheEntry, f32)>,
121 ) -> Result<Vec<VerifiedCandidate>, ScoringError> {
122 candidates
123 .into_iter()
124 .map(|(entry, original_score)| {
125 let candidate_text = String::from_utf8_lossy(&entry.payload_blob);
126 let cross_encoder_score = self.reranker.score(query, &candidate_text)?;
127
128 Ok(VerifiedCandidate::new(
129 entry,
130 cross_encoder_score,
131 original_score,
132 ))
133 })
134 .collect()
135 }
136
137 pub fn verify_candidates_with_details(
139 &self,
140 query: &str,
141 candidates: Vec<(CacheEntry, f32)>,
142 ) -> Result<(Vec<VerifiedCandidate>, VerificationResult), ScoringError> {
143 if candidates.is_empty() {
144 return Ok((vec![], VerificationResult::NoCandidates));
145 }
146
147 let mut scored = self.score_candidates(query, candidates)?;
148
149 scored.sort_by(|a, b| {
150 b.cross_encoder_score
151 .partial_cmp(&a.cross_encoder_score)
152 .unwrap_or(Ordering::Equal)
153 });
154
155 let score = scored[0].cross_encoder_score;
158 let result = if score > self.threshold() {
159 VerificationResult::Verified { score }
160 } else {
161 VerificationResult::Rejected { top_score: score }
162 };
163
164 Ok((scored, result))
165 }
166
167 pub fn rerank_top_n(
169 &self,
170 query: &str,
171 candidates: Vec<(CacheEntry, f32)>,
172 top_n: usize,
173 ) -> Result<Vec<VerifiedCandidate>, ScoringError> {
174 let mut scored = self.score_candidates(query, candidates)?;
175
176 scored.sort_by(|a, b| {
177 b.cross_encoder_score
178 .partial_cmp(&a.cross_encoder_score)
179 .unwrap_or(Ordering::Equal)
180 });
181
182 scored.truncate(top_n);
183 Ok(scored)
184 }
185}