1use std::cmp::Ordering;
3use std::collections::HashMap;
4
5use async_trait::async_trait;
6use chrono::Utc;
7
8use crate::{
9 error::VectorResult,
10 types::{RerankerConfig, SearchResult},
11};
12
/// A post-retrieval stage that receives a list of search results and hands
/// back a list of search results.
///
/// Implementations in this file either reorder the list, rescale the scores,
/// or pass the list through unchanged.
#[async_trait]
pub trait Reranker {
    /// Processes `results` for the given `query` vector.
    ///
    /// Takes ownership of `results` and returns the processed list wrapped in
    /// a [`VectorResult`], so a stage can report failure to the caller.
    async fn rerank(
        &self,
        query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>>;
}
23
24pub struct CrossEncoderReranker {
26 pub service_available: bool,
28}
29
30#[async_trait]
31impl Reranker for CrossEncoderReranker {
32 async fn rerank(
33 &self,
34 _query: &[f32],
35 results: Vec<SearchResult>,
36 ) -> VectorResult<Vec<SearchResult>> {
37 let _ = self.service_available;
38 Ok(results)
39 }
40}
41
42pub struct DiversityReranker {
44 pub lambda: f32,
46}
47
48#[async_trait]
49impl Reranker for DiversityReranker {
50 async fn rerank(
51 &self,
52 query: &[f32],
53 results: Vec<SearchResult>,
54 ) -> VectorResult<Vec<SearchResult>> {
55 Ok(mmr_select(query, &results, self.lambda, results.len()))
56 }
57}
58
59pub struct RecencyReranker {
61 pub recency_weight: f32,
63 pub half_life_days: f32,
65}
66
67#[async_trait]
68impl Reranker for RecencyReranker {
69 async fn rerank(
70 &self,
71 _query: &[f32],
72 mut results: Vec<SearchResult>,
73 ) -> VectorResult<Vec<SearchResult>> {
74 let now = Utc::now();
75 let half_life_days = self.half_life_days.max(0.001);
76
77 for result in &mut results {
78 let age_seconds = now
79 .signed_duration_since(result.created_at)
80 .num_seconds()
81 .max(0) as f32;
82 let age_days = age_seconds / 86_400.0;
83 let decay_factor = (-age_days / half_life_days).exp();
84 result.score *= 1.0 + self.recency_weight * decay_factor;
85 }
86
87 sort_results_desc(&mut results);
88 Ok(results)
89 }
90}
91
92pub struct CompositeReranker(pub Vec<Box<dyn Reranker + Send + Sync>>);
94
95#[async_trait]
96impl Reranker for CompositeReranker {
97 async fn rerank(
98 &self,
99 query: &[f32],
100 mut results: Vec<SearchResult>,
101 ) -> VectorResult<Vec<SearchResult>> {
102 for reranker in &self.0 {
103 results = reranker.rerank(query, results).await?;
104 }
105 Ok(results)
106 }
107}
108
109pub fn mmr_select(
111 query: &[f32],
112 candidates: &[SearchResult],
113 lambda: f32,
114 top_k: usize,
115) -> Vec<SearchResult> {
116 if candidates.is_empty() || top_k == 0 {
117 return Vec::new();
118 }
119
120 let lambda = lambda.clamp(0.0, 1.0);
121 let mut remaining = candidates.to_vec();
122 let mut selected = Vec::with_capacity(top_k.min(candidates.len()));
123
124 while !remaining.is_empty() && selected.len() < top_k {
125 let best_index = remaining
126 .iter()
127 .enumerate()
128 .map(|(index, candidate)| {
129 let relevance = query_relevance(query, candidate);
130 let max_similarity = selected
131 .iter()
132 .map(|selected| candidate_similarity(candidate, selected))
133 .fold(0.0, f32::max);
134 let score = lambda * relevance - (1.0 - lambda) * max_similarity;
135 (index, score)
136 })
137 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
138 .map(|(index, _)| index)
139 .unwrap_or(0);
140
141 selected.push(remaining.remove(best_index));
142 }
143
144 selected
145}
146
147pub async fn apply_reranker_config(
149 query: &[f32],
150 results: Vec<SearchResult>,
151 config: Option<&RerankerConfig>,
152) -> VectorResult<Vec<SearchResult>> {
153 match config {
154 None | Some(RerankerConfig::None) => Ok(results),
155 Some(RerankerConfig::Composite(configs)) => {
156 apply_composite_reranker(query, results, configs).await
157 }
158 Some(config) => build_reranker(config).rerank(query, results).await,
159 }
160}
161
162pub fn reranker_needs_vectors(config: Option<&RerankerConfig>) -> bool {
164 match config {
165 None | Some(RerankerConfig::None) => false,
166 Some(RerankerConfig::Diversity { .. }) => true,
167 Some(RerankerConfig::Recency { .. }) => false,
168 Some(RerankerConfig::Composite(configs)) => configs
169 .iter()
170 .any(|config| reranker_needs_vectors(Some(config))),
171 }
172}
173
174fn build_reranker(config: &RerankerConfig) -> Box<dyn Reranker + Send + Sync> {
175 match config {
176 RerankerConfig::None => Box::new(CompositeReranker(Vec::new())),
177 RerankerConfig::Diversity { lambda, .. } => Box::new(DiversityReranker { lambda: *lambda }),
178 RerankerConfig::Recency {
179 boost,
180 half_life_days,
181 ..
182 } => Box::new(RecencyReranker {
183 recency_weight: *boost,
184 half_life_days: *half_life_days,
185 }),
186 RerankerConfig::Composite(configs) => Box::new(CompositeReranker(
187 configs.iter().map(build_reranker).collect(),
188 )),
189 }
190}
191
192async fn apply_composite_reranker(
193 query: &[f32],
194 results: Vec<SearchResult>,
195 configs: &[RerankerConfig],
196) -> VectorResult<Vec<SearchResult>> {
197 if configs.is_empty() {
198 return Ok(results);
199 }
200
201 let mut current = results;
202 let mut aggregate_scores: HashMap<uuid::Uuid, f32> = HashMap::new();
203 let mut by_id: HashMap<uuid::Uuid, SearchResult> = HashMap::new();
204
205 for config in configs {
206 let reranked = build_reranker(config).rerank(query, current).await?;
207 let normalized = normalize_scores(reranked);
208 let stage_weight = stage_weight(config);
209 current = normalized.clone();
210
211 for result in normalized {
212 *aggregate_scores.entry(result.id).or_insert(0.0) += stage_weight * result.score;
213 by_id.insert(result.id, result);
214 }
215 }
216
217 let mut final_results = by_id
218 .into_iter()
219 .filter_map(|(id, mut result)| {
220 let final_score = aggregate_scores.get(&id).copied()?;
221 result.score = final_score;
222 Some(result)
223 })
224 .collect::<Vec<_>>();
225 sort_results_desc(&mut final_results);
226 Ok(final_results)
227}
228
229fn normalize_scores(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
230 if results.is_empty() {
231 return results;
232 }
233
234 let min_score = results
235 .iter()
236 .map(|result| result.score)
237 .fold(f32::INFINITY, f32::min);
238 let max_score = results
239 .iter()
240 .map(|result| result.score)
241 .fold(f32::NEG_INFINITY, f32::max);
242 let range = max_score - min_score;
243
244 if range.abs() < f32::EPSILON {
245 for result in &mut results {
246 result.score = 1.0;
247 }
248 return results;
249 }
250
251 for result in &mut results {
252 result.score = ((result.score - min_score) / range).clamp(0.0, 1.0);
253 }
254 results
255}
256
257fn stage_weight(config: &RerankerConfig) -> f32 {
258 match config {
259 RerankerConfig::None => 0.0,
260 RerankerConfig::Diversity { weight, .. } => *weight,
261 RerankerConfig::Recency { weight, .. } => *weight,
262 RerankerConfig::Composite(_) => 1.0,
263 }
264}
265
266fn query_relevance(query: &[f32], candidate: &SearchResult) -> f32 {
267 if let Some(vector) = candidate.vector.as_deref() {
268 cosine_similarity(query, vector)
269 .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
270 .unwrap_or(candidate.score)
271 } else {
272 candidate.score
273 }
274}
275
276fn candidate_similarity(left: &SearchResult, right: &SearchResult) -> f32 {
277 match (left.vector.as_deref(), right.vector.as_deref()) {
278 (Some(left), Some(right)) => cosine_similarity(left, right)
279 .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
280 .unwrap_or(0.0),
281 _ => 0.0,
282 }
283}
284
/// Cosine similarity of two equally sized vectors, in `[-1, 1]`.
///
/// Returns `None` when the similarity is undefined:
/// * the slices differ in length or are empty,
/// * either vector has zero magnitude,
/// * the inputs contain NaN or infinite components.
///
/// Sums are accumulated in `f64` so that large-magnitude components do not
/// overflow to infinity (which would turn the result into NaN) and tiny
/// components do not underflow to a zero norm.
fn cosine_similarity(left: &[f32], right: &[f32]) -> Option<f32> {
    if left.len() != right.len() || left.is_empty() {
        return None;
    }

    // Single pass: dot product and both squared norms.
    let (dot, left_sq, right_sq) = left.iter().zip(right.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, left_sq, right_sq), (&a, &b)| {
            let (a, b) = (f64::from(a), f64::from(b));
            (dot + a * b, left_sq + a * a, right_sq + b * b)
        },
    );

    let denominator = left_sq.sqrt() * right_sq.sqrt();
    if denominator == 0.0 {
        return None;
    }

    let similarity = dot / denominator;
    if similarity.is_finite() {
        // Rounding can push the ratio a hair outside the mathematical range.
        Some(similarity.clamp(-1.0, 1.0) as f32)
    } else {
        None
    }
}
299
300fn sort_results_desc(results: &mut [SearchResult]) {
301 results.sort_by(|left, right| {
302 right
303 .score
304 .partial_cmp(&left.score)
305 .unwrap_or(Ordering::Equal)
306 });
307}
308
// Tests for the reranking pipeline entry point.
#[cfg(test)]
mod tests {
    use chrono::{Duration, Utc};

    use super::apply_reranker_config;
    use crate::types::{RerankerConfig, SearchResult};

    /// Builds a `SearchResult` with the given id, vector and creation time.
    /// Every fixture starts with the same score (0.5), empty JSON metadata
    /// and no text.
    fn fixture_result(
        id: uuid::Uuid,
        vector: Vec<f32>,
        created_at: chrono::DateTime<Utc>,
    ) -> SearchResult {
        SearchResult {
            id,
            score: 0.5,
            vector: Some(vector),
            metadata: serde_json::json!({}),
            text: None,
            created_at,
        }
    }

    /// Runs a composite of a diversity stage followed by a recency stage over
    /// three fixtures and checks the exact order of the returned ids.
    #[tokio::test]
    async fn composite_diversity_then_recency_is_deterministic() {
        let now = Utc::now();
        // Fixed ids so the expected order can be asserted exactly.
        let first_id = uuid::Uuid::parse_str("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa").unwrap();
        let second_id = uuid::Uuid::parse_str("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb").unwrap();
        let third_id = uuid::Uuid::parse_str("cccccccc-cccc-cccc-cccc-cccccccccccc").unwrap();
        let query = vec![1.0, 0.0, 0.0];
        // Fixtures are created now, one day ago and thirty days ago; their
        // vectors point progressively further away from the query direction.
        let results = vec![
            fixture_result(first_id, vec![0.95, 0.05, 0.0], now),
            fixture_result(second_id, vec![0.5, 0.5, 0.0], now - Duration::days(1)),
            fixture_result(third_id, vec![0.2, 0.8, 0.0], now - Duration::days(30)),
        ];

        let config = RerankerConfig::Composite(vec![
            RerankerConfig::Diversity {
                lambda: 0.8,
                weight: 0.6,
            },
            RerankerConfig::Recency {
                boost: 0.5,
                half_life_days: 7.0,
                weight: 0.4,
            },
        ]);

        let reranked = apply_reranker_config(&query, results, Some(&config))
            .await
            .unwrap();
        let ids = reranked.into_iter().map(|item| item.id).collect::<Vec<_>>();

        // The newest fixture is expected first and the oldest last.
        assert_eq!(ids, vec![first_id, second_id, third_id]);
    }
}