1use crate::hnsw::SearchResult;
8use crate::metadata::{Metadata, MetadataValue};
9use ipfrs_core::Cid;
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub enum ReRankingStrategy {
15 WeightedCombination(Vec<(ScoreComponent, f32)>),
17 ReciprocalRankFusion { k: f32 },
19 LearnToRank { model_name: String },
21 Custom,
23}
24
/// A single signal that can contribute to a result's weighted score.
#[derive(Debug, Clone)]
pub enum ScoreComponent {
    /// The `score` carried by the search result itself.
    VectorSimilarity,
    /// Numeric value of the metadata field named `field`.
    ///
    /// Scores `0.0` when the result has no cached metadata, the field is absent, or the
    /// value is a string, string array, or null. Booleans score `1.0` / `0.0`.
    MetadataScore { field: String },
    /// Exponential decay `exp(-age * decay_factor)`, where `age` is the difference between the
    /// current Unix time in seconds and the integer `timestamp` metadata field.
    ///
    /// Scores `0.0` when no integer `timestamp` is cached for the result.
    Recency { decay_factor: f32 },
    /// Numeric value of the `popularity` metadata field (`0.0` when unavailable).
    Popularity,
    /// Diversity signal. Currently a stub that always scores `0.0`; `threshold` is unused.
    Diversity { threshold: f32 },
    /// Named custom signal. Currently a stub that always scores `0.0`.
    Custom { name: String },
}
41
42#[derive(Debug, Clone)]
44pub struct ReRankingConfig {
45 pub strategy: ReRankingStrategy,
47 pub normalize_scores: bool,
49 pub top_k: Option<usize>,
51}
52
53impl Default for ReRankingConfig {
54 fn default() -> Self {
55 Self {
56 strategy: ReRankingStrategy::WeightedCombination(vec![(
57 ScoreComponent::VectorSimilarity,
58 1.0,
59 )]),
60 normalize_scores: true,
61 top_k: Some(100), }
63 }
64}
65
66#[derive(Debug, Clone)]
68pub struct ScoredResult {
69 pub result: SearchResult,
71 pub score_components: HashMap<String, f32>,
73 pub final_score: f32,
75}
76
77pub struct ReRanker {
79 config: ReRankingConfig,
80 metadata_cache: HashMap<Cid, Metadata>,
81}
82
83impl ReRanker {
84 pub fn new(config: ReRankingConfig) -> Self {
86 Self {
87 config,
88 metadata_cache: HashMap::new(),
89 }
90 }
91
92 pub fn with_defaults() -> Self {
94 Self::new(ReRankingConfig::default())
95 }
96
97 pub fn add_metadata(&mut self, cid: Cid, metadata: Metadata) {
99 self.metadata_cache.insert(cid, metadata);
100 }
101
102 pub fn rerank(&self, results: Vec<SearchResult>) -> Vec<ScoredResult> {
104 let limit = self
105 .config
106 .top_k
107 .unwrap_or(results.len())
108 .min(results.len());
109 let mut to_rerank: Vec<SearchResult> = results.into_iter().take(limit).collect();
110
111 match &self.config.strategy {
112 ReRankingStrategy::WeightedCombination(weights) => {
113 self.rerank_weighted(&mut to_rerank, weights)
114 }
115 ReRankingStrategy::ReciprocalRankFusion { k } => self.rerank_rrf(&mut to_rerank, *k),
116 ReRankingStrategy::LearnToRank { model_name: _ } => {
117 self.rerank_placeholder(&mut to_rerank)
119 }
120 ReRankingStrategy::Custom => self.rerank_placeholder(&mut to_rerank),
121 }
122 }
123
124 fn rerank_weighted(
126 &self,
127 results: &mut [SearchResult],
128 weights: &[(ScoreComponent, f32)],
129 ) -> Vec<ScoredResult> {
130 let mut scored_results: Vec<ScoredResult> = results
131 .iter()
132 .map(|r| {
133 let mut score_components = HashMap::new();
134 let mut final_score = 0.0;
135
136 for (component, weight) in weights {
137 let component_score = self.compute_component_score(r, component);
138 let component_name = self.component_name(component);
139 score_components.insert(component_name, component_score);
140 final_score += component_score * weight;
141 }
142
143 ScoredResult {
144 result: r.clone(),
145 score_components,
146 final_score,
147 }
148 })
149 .collect();
150
151 if self.config.normalize_scores {
153 self.normalize_scores(&mut scored_results);
154 }
155
156 scored_results.sort_by(|a, b| {
158 b.final_score
159 .partial_cmp(&a.final_score)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162
163 scored_results
164 }
165
166 fn rerank_rrf(&self, results: &mut [SearchResult], k: f32) -> Vec<ScoredResult> {
168 let scored_results: Vec<ScoredResult> = results
169 .iter()
170 .enumerate()
171 .map(|(rank, r)| {
172 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
173 let mut score_components = HashMap::new();
174 score_components.insert("vector_similarity".to_string(), r.score);
175 score_components.insert("rrf_score".to_string(), rrf_score);
176
177 ScoredResult {
178 result: r.clone(),
179 score_components,
180 final_score: rrf_score,
181 }
182 })
183 .collect();
184
185 scored_results
186 }
187
188 fn rerank_placeholder(&self, results: &mut [SearchResult]) -> Vec<ScoredResult> {
190 results
191 .iter()
192 .map(|r| {
193 let mut score_components = HashMap::new();
194 score_components.insert("vector_similarity".to_string(), r.score);
195
196 ScoredResult {
197 result: r.clone(),
198 score_components,
199 final_score: r.score,
200 }
201 })
202 .collect()
203 }
204
205 fn compute_component_score(&self, result: &SearchResult, component: &ScoreComponent) -> f32 {
207 match component {
208 ScoreComponent::VectorSimilarity => result.score,
209 ScoreComponent::MetadataScore { field } => {
210 if let Some(metadata) = self.metadata_cache.get(&result.cid) {
212 if let Some(value) = metadata.get(field) {
213 return self.metadata_value_to_score(value);
214 }
215 }
216 0.0
217 }
218 ScoreComponent::Recency { decay_factor } => {
219 if let Some(metadata) = self.metadata_cache.get(&result.cid) {
221 if let Some(MetadataValue::Integer(timestamp)) = metadata.get("timestamp") {
222 let age = Self::current_timestamp() - timestamp;
224 return (-(age as f32) * decay_factor).exp();
225 }
226 }
227 0.0
228 }
229 ScoreComponent::Popularity => {
230 if let Some(metadata) = self.metadata_cache.get(&result.cid) {
232 if let Some(value) = metadata.get("popularity") {
233 return self.metadata_value_to_score(value);
234 }
235 }
236 0.0
237 }
238 ScoreComponent::Diversity { threshold: _ } => {
239 0.0
242 }
243 ScoreComponent::Custom { name: _ } => {
244 0.0
246 }
247 }
248 }
249
250 fn metadata_value_to_score(&self, value: &MetadataValue) -> f32 {
252 match value {
253 MetadataValue::Integer(i) => *i as f32,
254 MetadataValue::Float(f) => *f as f32,
255 MetadataValue::Boolean(b) => {
256 if *b {
257 1.0
258 } else {
259 0.0
260 }
261 }
262 MetadataValue::Timestamp(t) => *t as f32,
263 MetadataValue::String(_) | MetadataValue::StringArray(_) | MetadataValue::Null => 0.0,
264 }
265 }
266
267 fn component_name(&self, component: &ScoreComponent) -> String {
269 match component {
270 ScoreComponent::VectorSimilarity => "vector_similarity".to_string(),
271 ScoreComponent::MetadataScore { field } => format!("metadata_{}", field),
272 ScoreComponent::Recency { .. } => "recency".to_string(),
273 ScoreComponent::Popularity => "popularity".to_string(),
274 ScoreComponent::Diversity { .. } => "diversity".to_string(),
275 ScoreComponent::Custom { name } => format!("custom_{}", name),
276 }
277 }
278
279 fn normalize_scores(&self, results: &mut [ScoredResult]) {
281 if results.is_empty() {
282 return;
283 }
284
285 let min_score = results
287 .iter()
288 .map(|r| r.final_score)
289 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
290 .unwrap_or(0.0);
291
292 let max_score = results
293 .iter()
294 .map(|r| r.final_score)
295 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
296 .unwrap_or(1.0);
297
298 let range = max_score - min_score;
299
300 if range > 0.0 {
301 for result in results.iter_mut() {
302 result.final_score = (result.final_score - min_score) / range;
303 }
304 }
305 }
306
307 fn current_timestamp() -> i64 {
309 use std::time::{SystemTime, UNIX_EPOCH};
310 SystemTime::now()
311 .duration_since(UNIX_EPOCH)
312 .unwrap()
313 .as_secs() as i64
314 }
315
316 pub fn weighted(components: Vec<(ScoreComponent, f32)>) -> ReRankingConfig {
318 ReRankingConfig {
319 strategy: ReRankingStrategy::WeightedCombination(components),
320 normalize_scores: true,
321 top_k: Some(100),
322 }
323 }
324
325 pub fn reciprocal_rank_fusion(k: f32) -> ReRankingConfig {
327 ReRankingConfig {
328 strategy: ReRankingStrategy::ReciprocalRankFusion { k },
329 normalize_scores: false,
330 top_k: Some(100),
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    const CID_A: &str = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi";
    const CID_B: &str = "bafybeihpjhkeuiq3k6nqa3fkgeigeri7iebtrsuyuey5y6vy36n345xmbi";
    const CID_C: &str = "bafybeif2pall7dybz7vecqka3zo24irdwabwdi4wc55jznaq75q7eaavvu";

    /// Parses one of the CID literals above.
    fn cid(text: &str) -> Cid {
        text.parse::<Cid>().expect("test CID literal must parse")
    }

    /// Builds a search result with the given CID and similarity score.
    fn hit(cid: Cid, score: f32) -> SearchResult {
        SearchResult { cid, score }
    }

    /// Float comparison with a tolerance suitable for the small values used in these tests.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_reranker_creation() {
        let reranker = ReRanker::with_defaults();

        assert!(matches!(
            reranker.config.strategy,
            ReRankingStrategy::WeightedCombination(_)
        ));
        assert!(reranker.config.normalize_scores);
        assert_eq!(reranker.config.top_k, Some(100));
    }

    #[test]
    fn test_weighted_reranking() {
        let config = ReRanker::weighted(vec![
            (ScoreComponent::VectorSimilarity, 0.7),
            (ScoreComponent::Popularity, 0.3),
        ]);
        let mut reranker = ReRanker::new(config);

        let cid1 = cid(CID_A);
        let cid2 = cid(CID_B);

        let mut metadata1 = Metadata::new();
        metadata1.set("popularity", MetadataValue::Float(0.5));
        reranker.add_metadata(cid1, metadata1);

        let mut metadata2 = Metadata::new();
        metadata2.set("popularity", MetadataValue::Float(0.9));
        reranker.add_metadata(cid2, metadata2);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7)]);
        assert_eq!(reranked.len(), 2);

        // cid1: 0.9 * 0.7 + 0.5 * 0.3 = 0.78; cid2: 0.7 * 0.7 + 0.9 * 0.3 = 0.76.
        assert_eq!(reranked[0].result.cid, cid1);
        assert_eq!(reranked[1].result.cid, cid2);

        // Raw component scores are recorded under their component names.
        assert!(approx_eq(reranked[0].score_components["vector_similarity"], 0.9));
        assert!(approx_eq(reranked[0].score_components["popularity"], 0.5));
        assert!(approx_eq(reranked[1].score_components["vector_similarity"], 0.7));
        assert!(approx_eq(reranked[1].score_components["popularity"], 0.9));

        // `weighted` enables normalization, so the best and worst scores map to 1 and 0.
        assert!(approx_eq(reranked[0].final_score, 1.0));
        assert!(approx_eq(reranked[1].final_score, 0.0));
    }

    #[test]
    fn test_rrf_reranking() {
        let config = ReRanker::reciprocal_rank_fusion(60.0);
        let reranker = ReRanker::new(config);

        let cid1 = cid(CID_A);
        let cid2 = cid(CID_B);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7)]);
        assert_eq!(reranked.len(), 2);

        // Input order is preserved and scored as 1 / (k + rank + 1).
        assert_eq!(reranked[0].result.cid, cid1);
        assert_eq!(reranked[1].result.cid, cid2);
        assert!(reranked[0].final_score > reranked[1].final_score);
        assert!(approx_eq(reranked[0].final_score, 1.0 / 61.0));
        assert!(approx_eq(reranked[1].final_score, 1.0 / 62.0));

        assert!(approx_eq(reranked[0].score_components["vector_similarity"], 0.9));
        assert!(approx_eq(
            reranked[0].score_components["rrf_score"],
            reranked[0].final_score
        ));
    }

    #[test]
    fn test_recency_scoring() {
        let config = ReRanker::weighted(vec![
            (ScoreComponent::VectorSimilarity, 0.5),
            (
                ScoreComponent::Recency {
                    decay_factor: 0.0001,
                },
                0.5,
            ),
        ]);
        let mut reranker = ReRanker::new(config);

        let cid1 = cid(CID_A);
        let current_time = ReRanker::current_timestamp();

        let mut metadata = Metadata::new();
        metadata.set("timestamp", MetadataValue::Integer(current_time - 100));
        reranker.add_metadata(cid1, metadata);

        let reranked = reranker.rerank(vec![hit(cid1, 0.8)]);
        assert_eq!(reranked.len(), 1);
        assert!(reranked[0].score_components.contains_key("recency"));

        // exp(-100 * 0.0001) is about 0.990; the bounds leave room for the clock ticking
        // between the two timestamp reads.
        let recency = reranked[0].score_components["recency"];
        assert!(recency > 0.98 && recency < 1.0);
    }

    #[test]
    fn test_normalize_scores() {
        let config = ReRankingConfig {
            strategy: ReRankingStrategy::WeightedCombination(vec![(
                ScoreComponent::VectorSimilarity,
                1.0,
            )]),
            normalize_scores: true,
            top_k: None,
        };
        let reranker = ReRanker::new(config);

        let cid1 = cid(CID_A);
        let cid2 = cid(CID_B);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.5)]);
        assert_eq!(reranked.len(), 2);

        assert!(reranked[0].final_score >= 0.0 && reranked[0].final_score <= 1.0);
        assert!(reranked[1].final_score >= 0.0 && reranked[1].final_score <= 1.0);

        // Min-max normalization maps the best score to 1 and the worst to 0.
        assert_eq!(reranked[0].result.cid, cid1);
        assert!(approx_eq(reranked[0].final_score, 1.0));
        assert_eq!(reranked[1].result.cid, cid2);
        assert!(approx_eq(reranked[1].final_score, 0.0));
    }

    #[test]
    fn test_top_k_reranking() {
        let config = ReRankingConfig {
            strategy: ReRankingStrategy::WeightedCombination(vec![(
                ScoreComponent::VectorSimilarity,
                1.0,
            )]),
            normalize_scores: false,
            top_k: Some(2),
        };
        let reranker = ReRanker::new(config);

        let cid1 = cid(CID_A);
        let cid2 = cid(CID_B);
        let cid3 = cid(CID_C);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7), hit(cid3, 0.5)]);

        // Only the first two inputs are re-ranked; the third is dropped.
        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].result.cid, cid1);
        assert_eq!(reranked[1].result.cid, cid2);

        // Normalization is off, so final scores equal the raw similarity scores.
        assert!(approx_eq(reranked[0].final_score, 0.9));
        assert!(approx_eq(reranked[1].final_score, 0.7));
    }
}