Skip to main content

claw_vector/search/
rerank.rs

1// search/rerank.rs — post-retrieval reranking strategies.
2use std::cmp::Ordering;
3use std::collections::HashMap;
4
5use async_trait::async_trait;
6use chrono::Utc;
7
8use crate::{
9    error::VectorResult,
10    types::{RerankerConfig, SearchResult},
11};
12
/// Post-retrieval reranker interface.
///
/// Implementations take ownership of the candidate list and hand back a new
/// list wrapped in [`VectorResult`]. The implementations in this file either
/// return the list unchanged, reorder it, or rescale the scores and re-sort.
#[async_trait]
pub trait Reranker {
    /// Rerank a list of search results for the provided query vector.
    ///
    /// * `query` - the query embedding. Some implementations in this file
    ///   ignore it.
    /// * `results` - the candidates to rerank, consumed by the call.
    ///
    /// Returns the reranked candidates, or an error propagated through
    /// [`VectorResult`].
    async fn rerank(
        &self,
        query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>>;
}
23
24/// Placeholder cross-encoder reranker with a future gRPC hook.
25pub struct CrossEncoderReranker {
26    /// Whether a scoring backend is currently available.
27    pub service_available: bool,
28}
29
30#[async_trait]
31impl Reranker for CrossEncoderReranker {
32    async fn rerank(
33        &self,
34        _query: &[f32],
35        results: Vec<SearchResult>,
36    ) -> VectorResult<Vec<SearchResult>> {
37        let _ = self.service_available;
38        Ok(results)
39    }
40}
41
42/// Diversity-promoting reranker based on maximal marginal relevance.
43pub struct DiversityReranker {
44    /// Relevance-vs-diversity balance in the range `[0.0, 1.0]`.
45    pub lambda: f32,
46}
47
48#[async_trait]
49impl Reranker for DiversityReranker {
50    async fn rerank(
51        &self,
52        query: &[f32],
53        results: Vec<SearchResult>,
54    ) -> VectorResult<Vec<SearchResult>> {
55        Ok(mmr_select(query, &results, self.lambda, results.len()))
56    }
57}
58
59/// Recency-based reranker.
60pub struct RecencyReranker {
61    /// Weight applied to the recency boost.
62    pub recency_weight: f32,
63    /// Exponential decay half-life in days.
64    pub half_life_days: f32,
65}
66
67#[async_trait]
68impl Reranker for RecencyReranker {
69    async fn rerank(
70        &self,
71        _query: &[f32],
72        mut results: Vec<SearchResult>,
73    ) -> VectorResult<Vec<SearchResult>> {
74        let now = Utc::now();
75        let half_life_days = self.half_life_days.max(0.001);
76
77        for result in &mut results {
78            let age_seconds = now
79                .signed_duration_since(result.created_at)
80                .num_seconds()
81                .max(0) as f32;
82            let age_days = age_seconds / 86_400.0;
83            let decay_factor = (-age_days / half_life_days).exp();
84            result.score *= 1.0 + self.recency_weight * decay_factor;
85        }
86
87        sort_results_desc(&mut results);
88        Ok(results)
89    }
90}
91
92/// Apply multiple rerankers in sequence.
93pub struct CompositeReranker(pub Vec<Box<dyn Reranker + Send + Sync>>);
94
95#[async_trait]
96impl Reranker for CompositeReranker {
97    async fn rerank(
98        &self,
99        query: &[f32],
100        mut results: Vec<SearchResult>,
101    ) -> VectorResult<Vec<SearchResult>> {
102        for reranker in &self.0 {
103            results = reranker.rerank(query, results).await?;
104        }
105        Ok(results)
106    }
107}
108
109/// Select a diverse subset of candidates using maximal marginal relevance.
110pub fn mmr_select(
111    query: &[f32],
112    candidates: &[SearchResult],
113    lambda: f32,
114    top_k: usize,
115) -> Vec<SearchResult> {
116    if candidates.is_empty() || top_k == 0 {
117        return Vec::new();
118    }
119
120    let lambda = lambda.clamp(0.0, 1.0);
121    let mut remaining = candidates.to_vec();
122    let mut selected = Vec::with_capacity(top_k.min(candidates.len()));
123
124    while !remaining.is_empty() && selected.len() < top_k {
125        let best_index = remaining
126            .iter()
127            .enumerate()
128            .map(|(index, candidate)| {
129                let relevance = query_relevance(query, candidate);
130                let max_similarity = selected
131                    .iter()
132                    .map(|selected| candidate_similarity(candidate, selected))
133                    .fold(0.0, f32::max);
134                let score = lambda * relevance - (1.0 - lambda) * max_similarity;
135                (index, score)
136            })
137            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
138            .map(|(index, _)| index)
139            .unwrap_or(0);
140
141        selected.push(remaining.remove(best_index));
142    }
143
144    selected
145}
146
147/// Apply a configured reranker chain to search results.
148pub async fn apply_reranker_config(
149    query: &[f32],
150    results: Vec<SearchResult>,
151    config: Option<&RerankerConfig>,
152) -> VectorResult<Vec<SearchResult>> {
153    match config {
154        None | Some(RerankerConfig::None) => Ok(results),
155        Some(RerankerConfig::Composite(configs)) => {
156            apply_composite_reranker(query, results, configs).await
157        }
158        Some(config) => build_reranker(config).rerank(query, results).await,
159    }
160}
161
162/// Return `true` when a configured reranker requires access to raw vectors.
163pub fn reranker_needs_vectors(config: Option<&RerankerConfig>) -> bool {
164    match config {
165        None | Some(RerankerConfig::None) => false,
166        Some(RerankerConfig::Diversity { .. }) => true,
167        Some(RerankerConfig::Recency { .. }) => false,
168        Some(RerankerConfig::Composite(configs)) => configs
169            .iter()
170            .any(|config| reranker_needs_vectors(Some(config))),
171    }
172}
173
174fn build_reranker(config: &RerankerConfig) -> Box<dyn Reranker + Send + Sync> {
175    match config {
176        RerankerConfig::None => Box::new(CompositeReranker(Vec::new())),
177        RerankerConfig::Diversity { lambda, .. } => Box::new(DiversityReranker { lambda: *lambda }),
178        RerankerConfig::Recency {
179            boost,
180            half_life_days,
181            ..
182        } => Box::new(RecencyReranker {
183            recency_weight: *boost,
184            half_life_days: *half_life_days,
185        }),
186        RerankerConfig::Composite(configs) => Box::new(CompositeReranker(
187            configs.iter().map(build_reranker).collect(),
188        )),
189    }
190}
191
192async fn apply_composite_reranker(
193    query: &[f32],
194    results: Vec<SearchResult>,
195    configs: &[RerankerConfig],
196) -> VectorResult<Vec<SearchResult>> {
197    if configs.is_empty() {
198        return Ok(results);
199    }
200
201    let mut current = results;
202    let mut aggregate_scores: HashMap<uuid::Uuid, f32> = HashMap::new();
203    let mut by_id: HashMap<uuid::Uuid, SearchResult> = HashMap::new();
204    let mut total_weight = 0.0f32;
205
206    for config in configs {
207        let reranked = build_reranker(config).rerank(query, current).await?;
208        let normalized = normalize_scores(reranked);
209        let stage_weight = stage_weight(config);
210        total_weight += stage_weight;
211        current = normalized.clone();
212
213        for result in normalized {
214            *aggregate_scores.entry(result.id).or_insert(0.0) += stage_weight * result.score;
215            by_id.insert(result.id, result);
216        }
217    }
218
219    let mut final_results = by_id
220        .into_iter()
221        .filter_map(|(id, mut result)| {
222            let denominator = if total_weight <= 0.0 {
223                1.0
224            } else {
225                total_weight
226            };
227            let final_score = aggregate_scores.get(&id).copied()? / denominator;
228            result.score = final_score;
229            Some(result)
230        })
231        .collect::<Vec<_>>();
232    sort_results_desc(&mut final_results);
233    Ok(final_results)
234}
235
236fn normalize_scores(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
237    if results.is_empty() {
238        return results;
239    }
240
241    let min_score = results
242        .iter()
243        .map(|result| result.score)
244        .fold(f32::INFINITY, f32::min);
245    let max_score = results
246        .iter()
247        .map(|result| result.score)
248        .fold(f32::NEG_INFINITY, f32::max);
249    let epsilon = 1e-9_f32;
250    let range = max_score - min_score;
251
252    if range.abs() < epsilon {
253        for result in &mut results {
254            result.score = 1.0;
255        }
256        return results;
257    }
258
259    for result in &mut results {
260        result.score = ((result.score - min_score) / (range + epsilon)).clamp(0.0, 1.0);
261    }
262    results
263}
264
265fn stage_weight(config: &RerankerConfig) -> f32 {
266    match config {
267        RerankerConfig::None => 0.0,
268        RerankerConfig::Diversity { weight, .. } => *weight,
269        RerankerConfig::Recency { weight, .. } => *weight,
270        RerankerConfig::Composite(_) => 1.0,
271    }
272}
273
274fn query_relevance(query: &[f32], candidate: &SearchResult) -> f32 {
275    if let Some(vector) = candidate.vector.as_deref() {
276        cosine_similarity(query, vector)
277            .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
278            .unwrap_or(candidate.score)
279    } else {
280        candidate.score
281    }
282}
283
284fn candidate_similarity(left: &SearchResult, right: &SearchResult) -> f32 {
285    match (left.vector.as_deref(), right.vector.as_deref()) {
286        (Some(left), Some(right)) => cosine_similarity(left, right)
287            .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
288            .unwrap_or(0.0),
289        _ => 0.0,
290    }
291}
292
/// Cosine similarity of two equally sized, non-empty vectors.
///
/// Returns `None` when the lengths differ, the vectors are empty, either
/// vector has zero norm, or the result is not finite (NaN or infinite
/// components, or overflow/underflow in the intermediate sums). Otherwise
/// returns the cosine clamped into `[-1.0, 1.0]`, which removes small
/// rounding overshoot.
fn cosine_similarity(left: &[f32], right: &[f32]) -> Option<f32> {
    if left.len() != right.len() || left.is_empty() {
        return None;
    }

    let dot: f32 = left.iter().zip(right.iter()).map(|(a, b)| a * b).sum();
    let left_norm = left.iter().map(|value| value * value).sum::<f32>().sqrt();
    let right_norm = right.iter().map(|value| value * value).sum::<f32>().sqrt();
    if left_norm == 0.0 || right_norm == 0.0 {
        return None;
    }

    let cosine = dot / (left_norm * right_norm);
    // A non-finite value would poison every comparison that uses it, so
    // report "no similarity available" and let the caller fall back.
    if cosine.is_finite() {
        Some(cosine.clamp(-1.0, 1.0))
    } else {
        None
    }
}
307
308fn sort_results_desc(results: &mut [SearchResult]) {
309    results.sort_by(|left, right| {
310        right
311            .score
312            .partial_cmp(&left.score)
313            .unwrap_or(Ordering::Equal)
314    });
315}
316
#[cfg(test)]
mod tests {
    use chrono::{Duration, Utc};

    use super::apply_reranker_config;
    use crate::types::{RerankerConfig, SearchResult};

    /// Build a result with a fixed base score of `0.5`, empty JSON metadata
    /// and no text, so that only the vector and creation time vary.
    fn fixture_result(
        id: uuid::Uuid,
        vector: Vec<f32>,
        created_at: chrono::DateTime<Utc>,
    ) -> SearchResult {
        let metadata = serde_json::json!({});
        SearchResult {
            id,
            score: 0.5,
            vector: Some(vector),
            metadata,
            text: None,
            created_at,
        }
    }

    /// A diversity stage followed by a recency stage must produce the
    /// expected ordering: newest first, oldest last.
    #[tokio::test]
    async fn composite_diversity_then_recency_is_deterministic() {
        let parse_id = |text: &str| uuid::Uuid::parse_str(text).unwrap();
        let first_id = parse_id("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa");
        let second_id = parse_id("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb");
        let third_id = parse_id("cccccccc-cccc-cccc-cccc-cccccccccccc");

        let now = Utc::now();
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            fixture_result(first_id, vec![0.95, 0.05, 0.0], now),
            fixture_result(second_id, vec![0.5, 0.5, 0.0], now - Duration::days(1)),
            fixture_result(third_id, vec![0.2, 0.8, 0.0], now - Duration::days(30)),
        ];

        let diversity_stage = RerankerConfig::Diversity {
            lambda: 0.8,
            weight: 0.6,
        };
        let recency_stage = RerankerConfig::Recency {
            boost: 0.5,
            half_life_days: 7.0,
            weight: 0.4,
        };
        let config = RerankerConfig::Composite(vec![diversity_stage, recency_stage]);

        let reranked = apply_reranker_config(&query, candidates, Some(&config))
            .await
            .unwrap();
        let ids: Vec<_> = reranked.into_iter().map(|item| item.id).collect();

        assert_eq!(ids, vec![first_id, second_id, third_id]);
    }
}
371}