1use std::cmp::Ordering;
3use std::collections::HashMap;
4
5use async_trait::async_trait;
6use chrono::Utc;
7
8use crate::{
9 error::VectorResult,
10 types::{RerankerConfig, SearchResult},
11};
12
13#[async_trait]
15pub trait Reranker {
16 async fn rerank(
18 &self,
19 query: &[f32],
20 results: Vec<SearchResult>,
21 ) -> VectorResult<Vec<SearchResult>>;
22}
23
24pub struct CrossEncoderReranker {
26 pub service_available: bool,
28}
29
30#[async_trait]
31impl Reranker for CrossEncoderReranker {
32 async fn rerank(
33 &self,
34 _query: &[f32],
35 results: Vec<SearchResult>,
36 ) -> VectorResult<Vec<SearchResult>> {
37 let _ = self.service_available;
38 Ok(results)
39 }
40}
41
42pub struct DiversityReranker {
44 pub lambda: f32,
46}
47
48#[async_trait]
49impl Reranker for DiversityReranker {
50 async fn rerank(
51 &self,
52 query: &[f32],
53 results: Vec<SearchResult>,
54 ) -> VectorResult<Vec<SearchResult>> {
55 Ok(mmr_select(query, &results, self.lambda, results.len()))
56 }
57}
58
59pub struct RecencyReranker {
61 pub recency_weight: f32,
63 pub half_life_days: f32,
65}
66
67#[async_trait]
68impl Reranker for RecencyReranker {
69 async fn rerank(
70 &self,
71 _query: &[f32],
72 mut results: Vec<SearchResult>,
73 ) -> VectorResult<Vec<SearchResult>> {
74 let now = Utc::now();
75 let half_life_days = self.half_life_days.max(0.001);
76
77 for result in &mut results {
78 let age_seconds = now
79 .signed_duration_since(result.created_at)
80 .num_seconds()
81 .max(0) as f32;
82 let age_days = age_seconds / 86_400.0;
83 let decay_factor = (-age_days / half_life_days).exp();
84 result.score *= 1.0 + self.recency_weight * decay_factor;
85 }
86
87 sort_results_desc(&mut results);
88 Ok(results)
89 }
90}
91
92pub struct CompositeReranker(pub Vec<Box<dyn Reranker + Send + Sync>>);
94
95#[async_trait]
96impl Reranker for CompositeReranker {
97 async fn rerank(
98 &self,
99 query: &[f32],
100 mut results: Vec<SearchResult>,
101 ) -> VectorResult<Vec<SearchResult>> {
102 for reranker in &self.0 {
103 results = reranker.rerank(query, results).await?;
104 }
105 Ok(results)
106 }
107}
108
109pub fn mmr_select(
111 query: &[f32],
112 candidates: &[SearchResult],
113 lambda: f32,
114 top_k: usize,
115) -> Vec<SearchResult> {
116 if candidates.is_empty() || top_k == 0 {
117 return Vec::new();
118 }
119
120 let lambda = lambda.clamp(0.0, 1.0);
121 let mut remaining = candidates.to_vec();
122 let mut selected = Vec::with_capacity(top_k.min(candidates.len()));
123
124 while !remaining.is_empty() && selected.len() < top_k {
125 let best_index = remaining
126 .iter()
127 .enumerate()
128 .map(|(index, candidate)| {
129 let relevance = query_relevance(query, candidate);
130 let max_similarity = selected
131 .iter()
132 .map(|selected| candidate_similarity(candidate, selected))
133 .fold(0.0, f32::max);
134 let score = lambda * relevance - (1.0 - lambda) * max_similarity;
135 (index, score)
136 })
137 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
138 .map(|(index, _)| index)
139 .unwrap_or(0);
140
141 selected.push(remaining.remove(best_index));
142 }
143
144 selected
145}
146
147pub async fn apply_reranker_config(
149 query: &[f32],
150 results: Vec<SearchResult>,
151 config: Option<&RerankerConfig>,
152) -> VectorResult<Vec<SearchResult>> {
153 match config {
154 None | Some(RerankerConfig::None) => Ok(results),
155 Some(RerankerConfig::Composite(configs)) => {
156 apply_composite_reranker(query, results, configs).await
157 }
158 Some(config) => build_reranker(config).rerank(query, results).await,
159 }
160}
161
162pub fn reranker_needs_vectors(config: Option<&RerankerConfig>) -> bool {
164 match config {
165 None | Some(RerankerConfig::None) => false,
166 Some(RerankerConfig::Diversity { .. }) => true,
167 Some(RerankerConfig::Recency { .. }) => false,
168 Some(RerankerConfig::Composite(configs)) => configs
169 .iter()
170 .any(|config| reranker_needs_vectors(Some(config))),
171 }
172}
173
174fn build_reranker(config: &RerankerConfig) -> Box<dyn Reranker + Send + Sync> {
175 match config {
176 RerankerConfig::None => Box::new(CompositeReranker(Vec::new())),
177 RerankerConfig::Diversity { lambda, .. } => Box::new(DiversityReranker { lambda: *lambda }),
178 RerankerConfig::Recency {
179 boost,
180 half_life_days,
181 ..
182 } => Box::new(RecencyReranker {
183 recency_weight: *boost,
184 half_life_days: *half_life_days,
185 }),
186 RerankerConfig::Composite(configs) => Box::new(CompositeReranker(
187 configs.iter().map(build_reranker).collect(),
188 )),
189 }
190}
191
192async fn apply_composite_reranker(
193 query: &[f32],
194 results: Vec<SearchResult>,
195 configs: &[RerankerConfig],
196) -> VectorResult<Vec<SearchResult>> {
197 if configs.is_empty() {
198 return Ok(results);
199 }
200
201 let mut current = results;
202 let mut aggregate_scores: HashMap<uuid::Uuid, f32> = HashMap::new();
203 let mut by_id: HashMap<uuid::Uuid, SearchResult> = HashMap::new();
204 let mut total_weight = 0.0f32;
205
206 for config in configs {
207 let reranked = build_reranker(config).rerank(query, current).await?;
208 let normalized = normalize_scores(reranked);
209 let stage_weight = stage_weight(config);
210 total_weight += stage_weight;
211 current = normalized.clone();
212
213 for result in normalized {
214 *aggregate_scores.entry(result.id).or_insert(0.0) += stage_weight * result.score;
215 by_id.insert(result.id, result);
216 }
217 }
218
219 let mut final_results = by_id
220 .into_iter()
221 .filter_map(|(id, mut result)| {
222 let denominator = if total_weight <= 0.0 {
223 1.0
224 } else {
225 total_weight
226 };
227 let final_score = aggregate_scores.get(&id).copied()? / denominator;
228 result.score = final_score;
229 Some(result)
230 })
231 .collect::<Vec<_>>();
232 sort_results_desc(&mut final_results);
233 Ok(final_results)
234}
235
236fn normalize_scores(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
237 if results.is_empty() {
238 return results;
239 }
240
241 let min_score = results
242 .iter()
243 .map(|result| result.score)
244 .fold(f32::INFINITY, f32::min);
245 let max_score = results
246 .iter()
247 .map(|result| result.score)
248 .fold(f32::NEG_INFINITY, f32::max);
249 let epsilon = 1e-9_f32;
250 let range = max_score - min_score;
251
252 if range.abs() < epsilon {
253 for result in &mut results {
254 result.score = 1.0;
255 }
256 return results;
257 }
258
259 for result in &mut results {
260 result.score = ((result.score - min_score) / (range + epsilon)).clamp(0.0, 1.0);
261 }
262 results
263}
264
265fn stage_weight(config: &RerankerConfig) -> f32 {
266 match config {
267 RerankerConfig::None => 0.0,
268 RerankerConfig::Diversity { weight, .. } => *weight,
269 RerankerConfig::Recency { weight, .. } => *weight,
270 RerankerConfig::Composite(_) => 1.0,
271 }
272}
273
274fn query_relevance(query: &[f32], candidate: &SearchResult) -> f32 {
275 if let Some(vector) = candidate.vector.as_deref() {
276 cosine_similarity(query, vector)
277 .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
278 .unwrap_or(candidate.score)
279 } else {
280 candidate.score
281 }
282}
283
284fn candidate_similarity(left: &SearchResult, right: &SearchResult) -> f32 {
285 match (left.vector.as_deref(), right.vector.as_deref()) {
286 (Some(left), Some(right)) => cosine_similarity(left, right)
287 .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
288 .unwrap_or(0.0),
289 _ => 0.0,
290 }
291}
292
/// Cosine similarity of two equal-length vectors.
///
/// Returns `None` when the lengths differ, the vectors are empty, or either
/// vector has zero norm, since the angle is undefined in those cases.
fn cosine_similarity(left: &[f32], right: &[f32]) -> Option<f32> {
    if left.is_empty() || left.len() != right.len() {
        return None;
    }

    let norm = |values: &[f32]| values.iter().map(|v| v * v).sum::<f32>().sqrt();
    let (left_norm, right_norm) = (norm(left), norm(right));
    if left_norm == 0.0 || right_norm == 0.0 {
        return None;
    }

    let dot = left.iter().zip(right).map(|(a, b)| a * b).sum::<f32>();
    Some(dot / (left_norm * right_norm))
}
307
308fn sort_results_desc(results: &mut [SearchResult]) {
309 results.sort_by(|left, right| {
310 right
311 .score
312 .partial_cmp(&left.score)
313 .unwrap_or(Ordering::Equal)
314 });
315}
316
#[cfg(test)]
mod tests {
    use chrono::{Duration, Utc};

    use super::apply_reranker_config;
    use crate::types::{RerankerConfig, SearchResult};

    /// Builds a result with a fixed score of 0.5, empty metadata and no text.
    fn fixture_result(
        id: uuid::Uuid,
        vector: Vec<f32>,
        created_at: chrono::DateTime<Utc>,
    ) -> SearchResult {
        SearchResult {
            id,
            score: 0.5,
            vector: Some(vector),
            metadata: serde_json::json!({}),
            text: None,
            created_at,
        }
    }

    /// Parses a hard-coded UUID literal; panics only if the literal is malformed.
    fn uuid_from(literal: &str) -> uuid::Uuid {
        uuid::Uuid::parse_str(literal).unwrap()
    }

    /// Diversity followed by recency over three equally-scored results must
    /// rank them newest-first.
    #[tokio::test]
    async fn composite_diversity_then_recency_is_deterministic() {
        let now = Utc::now();
        let newest = uuid_from("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa");
        let middle = uuid_from("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb");
        let oldest = uuid_from("cccccccc-cccc-cccc-cccc-cccccccccccc");

        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            fixture_result(newest, vec![0.95, 0.05, 0.0], now),
            fixture_result(middle, vec![0.5, 0.5, 0.0], now - Duration::days(1)),
            fixture_result(oldest, vec![0.2, 0.8, 0.0], now - Duration::days(30)),
        ];

        let diversity = RerankerConfig::Diversity {
            lambda: 0.8,
            weight: 0.6,
        };
        let recency = RerankerConfig::Recency {
            boost: 0.5,
            half_life_days: 7.0,
            weight: 0.4,
        };
        let config = RerankerConfig::Composite(vec![diversity, recency]);

        let ordered_ids: Vec<_> = apply_reranker_config(&query, candidates, Some(&config))
            .await
            .unwrap()
            .into_iter()
            .map(|result| result.id)
            .collect();

        assert_eq!(ordered_ids, vec![newest, middle, oldest]);
    }
}