1use crate::retrieval::engine::RetrievedRecord;
13use ordered_float::OrderedFloat;
14
/// Strategy used by a reranker to reorder retrieved records.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RerankerType {
    /// Leave the retrieval order untouched.
    #[default]
    None,
    /// Blend the retrieval score with a score derived from recorded outcomes.
    OutcomeWeighted,
    /// Blend the retrieval score with an exponentially decaying recency score.
    Recency,
    /// Greedy selection that trades relevance against similarity to the
    /// records already picked (maximal marginal relevance).
    MMR,
    /// Blend the retrieval score with both the outcome and the recency score.
    Composite,
}
30
31#[derive(Debug, Clone)]
33pub struct RerankerConfig {
34 pub strategy: RerankerType,
36 pub original_weight: f32,
38 pub outcome_weight: f32,
40 pub recency_weight: f32,
42 pub recency_half_life: f64,
44 pub mmr_lambda: f32,
46 pub min_samples: u64,
48}
49
50impl Default for RerankerConfig {
51 fn default() -> Self {
52 Self {
53 strategy: RerankerType::OutcomeWeighted,
54 original_weight: 0.5,
55 outcome_weight: 0.3,
56 recency_weight: 0.2,
57 recency_half_life: 86400.0 * 7.0, mmr_lambda: 0.7,
59 min_samples: 3,
60 }
61 }
62}
63
64impl RerankerConfig {
65 #[must_use]
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 #[must_use]
73 pub const fn with_strategy(mut self, strategy: RerankerType) -> Self {
74 self.strategy = strategy;
75 self
76 }
77
78 #[must_use]
80 pub const fn with_outcome_weight(mut self, weight: f32) -> Self {
81 self.outcome_weight = weight;
82 self
83 }
84
85 #[must_use]
87 pub const fn with_mmr_lambda(mut self, lambda: f32) -> Self {
88 self.mmr_lambda = lambda;
89 self
90 }
91}
92
93pub struct Reranker {
95 config: RerankerConfig,
96}
97
98impl Reranker {
99 #[must_use]
101 pub fn new(config: RerankerConfig) -> Self {
102 Self { config }
103 }
104
105 #[must_use]
107 pub fn rerank(&self, results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
108 match self.config.strategy {
109 RerankerType::None => results,
110 RerankerType::OutcomeWeighted => self.rerank_by_outcome(results),
111 RerankerType::Recency => self.rerank_by_recency(results),
112 RerankerType::MMR => self.rerank_mmr(results),
113 RerankerType::Composite => self.rerank_composite(results),
114 }
115 }
116
117 fn rerank_by_outcome(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
119 for result in &mut results {
120 let outcome_score = self.compute_outcome_score(&result.record);
121 result.score = self.config.original_weight * result.score
122 + self.config.outcome_weight * outcome_score;
123 }
124
125 results.sort_by(|a, b| {
126 OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
127 });
128
129 results
130 }
131
132 fn compute_outcome_score(&self, record: &crate::types::MemoryRecord) -> f32 {
134 let stats = &record.stats;
135
136 if stats.count() < self.config.min_samples {
137 return record.outcome as f32;
139 }
140
141 if let Some((lower, _upper)) = stats.confidence_interval(0.90) {
145 lower.first().copied().unwrap_or(record.outcome as f32)
147 } else {
148 record.outcome as f32
149 }
150 }
151
152 fn rerank_by_recency(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
154 let now = current_time_secs();
155
156 for result in &mut results {
157 let age_secs = (now - result.record.created_at) as f64;
158 let recency_score = self.compute_recency_score(age_secs);
159
160 result.score = self.config.original_weight * result.score
161 + self.config.recency_weight * recency_score;
162 }
163
164 results.sort_by(|a, b| {
165 OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
166 });
167
168 results
169 }
170
171 fn compute_recency_score(&self, age_secs: f64) -> f32 {
173 let decay = (-age_secs / self.config.recency_half_life * std::f64::consts::LN_2).exp();
175 decay as f32
176 }
177
178 fn rerank_mmr(&self, results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
180 if results.len() <= 1 {
181 return results;
182 }
183
184 let lambda = self.config.mmr_lambda;
185 let mut reranked = Vec::with_capacity(results.len());
186 let mut remaining: Vec<_> = results.into_iter().collect();
187
188 remaining.sort_by(|a, b| OrderedFloat(b.score).cmp(&OrderedFloat(a.score)));
190 reranked.push(remaining.remove(0));
191
192 while !remaining.is_empty() {
194 let mut best_idx = 0;
195 let mut best_mmr = f32::NEG_INFINITY;
196
197 for (i, candidate) in remaining.iter().enumerate() {
198 let relevance = candidate.score;
200
201 let max_sim = reranked
203 .iter()
204 .map(|r| self.embedding_similarity(&candidate.record.embedding, &r.record.embedding))
205 .fold(0.0f32, f32::max);
206
207 let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
209
210 if mmr > best_mmr {
211 best_mmr = mmr;
212 best_idx = i;
213 }
214 }
215
216 reranked.push(remaining.remove(best_idx));
217 }
218
219 reranked
220 }
221
222 fn embedding_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
224 if a.len() != b.len() {
225 return 0.0;
226 }
227
228 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
229 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
230 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
231
232 if norm_a > 0.0 && norm_b > 0.0 {
233 dot / (norm_a * norm_b)
234 } else {
235 0.0
236 }
237 }
238
239 fn rerank_composite(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
241 let now = current_time_secs();
242
243 for result in &mut results {
244 let outcome_score = self.compute_outcome_score(&result.record);
245 let age_secs = (now - result.record.created_at) as f64;
246 let recency_score = self.compute_recency_score(age_secs);
247
248 result.score = self.config.original_weight * result.score
249 + self.config.outcome_weight * outcome_score
250 + self.config.recency_weight * recency_score;
251 }
252
253 results.sort_by(|a, b| {
254 OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
255 });
256
257 results
258 }
259}
260
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn current_time_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::{MemoryRecord, RecordStatus};

    /// Assembles an active `RetrievedRecord` fixture from the parts the tests vary.
    fn build_result(
        id: &str,
        score: f32,
        outcome: f64,
        created_at: u64,
        embedding: Vec<f32>,
        stats: OutcomeStats,
    ) -> RetrievedRecord {
        RetrievedRecord {
            record: MemoryRecord {
                id: id.into(),
                embedding,
                context: format!("Context {id}"),
                outcome,
                metadata: Default::default(),
                created_at,
                status: RecordStatus::Active,
                stats,
            },
            score,
            rank: 0,
            source_index: "test".into(),
        }
    }

    /// Fixture created `age_secs` seconds ago, with fresh (empty) statistics.
    fn create_test_result(id: &str, score: f32, outcome: f64, age_secs: u64) -> RetrievedRecord {
        let created_at = current_time_secs().saturating_sub(age_secs);

        build_result(
            id,
            score,
            outcome,
            created_at,
            vec![1.0, 0.0, 0.0],
            OutcomeStats::new(1),
        )
    }

    /// Fixture created now whose statistics have been fed every value in `outcomes`.
    fn create_result_with_stats(id: &str, score: f32, outcomes: &[f64]) -> RetrievedRecord {
        let mut stats = OutcomeStats::new(1);
        for &outcome in outcomes {
            stats.update_scalar(outcome);
        }

        build_result(
            id,
            score,
            outcomes.first().copied().unwrap_or(0.5),
            current_time_secs(),
            vec![1.0, 0.0, 0.0],
            stats,
        )
    }

    /// Fixture with an explicit embedding, a neutral outcome and a zero timestamp.
    fn create_result_with_embedding(id: &str, score: f32, embedding: Vec<f32>) -> RetrievedRecord {
        build_result(id, score, 0.5, 0, embedding, OutcomeStats::new(1))
    }

    #[test]
    fn test_config_defaults_and_builders() {
        let defaults = RerankerConfig::new();
        assert_eq!(defaults.strategy, RerankerType::OutcomeWeighted);
        assert_eq!(defaults.min_samples, 3);

        let config = RerankerConfig::new()
            .with_strategy(RerankerType::MMR)
            .with_outcome_weight(0.8)
            .with_mmr_lambda(0.25);

        assert_eq!(config.strategy, RerankerType::MMR);
        assert!((config.outcome_weight - 0.8).abs() < f32::EPSILON);
        assert!((config.mmr_lambda - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn test_no_reranking() {
        let reranker = Reranker::new(RerankerConfig::new().with_strategy(RerankerType::None));

        let results = vec![
            create_test_result("a", 0.9, 0.5, 0),
            create_test_result("b", 0.8, 0.9, 0),
        ];

        let reranked = reranker.rerank(results);

        assert_eq!(reranked[0].record.id.as_str(), "a");
        assert_eq!(reranked[1].record.id.as_str(), "b");
    }

    #[test]
    fn test_outcome_reranking() {
        let reranker = Reranker::new(
            RerankerConfig::new()
                .with_strategy(RerankerType::OutcomeWeighted)
                .with_outcome_weight(0.8),
        );

        // "a" starts with the better retrieval score but has clearly worse outcomes.
        let results = vec![
            create_result_with_stats("a", 0.9, &[0.3, 0.4, 0.3, 0.4]),
            create_result_with_stats("b", 0.8, &[0.9, 0.8, 0.9, 0.85]),
        ];

        let reranked = reranker.rerank(results);

        assert_eq!(reranked[0].record.id.as_str(), "b");
    }

    #[test]
    fn test_recency_reranking() {
        let reranker =
            Reranker::new(RerankerConfig::new().with_strategy(RerankerType::Recency));

        let results = vec![
            // Thirty days old.
            create_test_result("old", 0.9, 0.5, 86400 * 30),
            // One hour old.
            create_test_result("new", 0.8, 0.5, 3600),
        ];

        let reranked = reranker.rerank(results);

        assert_eq!(reranked[0].record.id.as_str(), "new");
    }

    #[test]
    fn test_mmr_diversity() {
        let reranker = Reranker::new(
            RerankerConfig::new()
                .with_strategy(RerankerType::MMR)
                .with_mmr_lambda(0.5),
        );

        let results = vec![
            create_result_with_embedding("a", 0.95, vec![1.0, 0.0, 0.0]),
            // Nearly identical to "a".
            create_result_with_embedding("b", 0.9, vec![0.99, 0.01, 0.0]),
            // Orthogonal to "a".
            create_result_with_embedding("c", 0.85, vec![0.0, 1.0, 0.0]),
        ];

        let reranked = reranker.rerank(results);

        // The highest-scoring record is always picked first.
        assert_eq!(reranked[0].record.id.as_str(), "a");

        // The orthogonal record beats the near-duplicate despite its lower score.
        assert_eq!(reranked[1].record.id.as_str(), "c");
    }

    #[test]
    fn test_composite_reranking() {
        let reranker =
            Reranker::new(RerankerConfig::new().with_strategy(RerankerType::Composite));

        let results = vec![
            create_test_result("a", 0.9, 0.5, 86400 * 30),
            create_test_result("b", 0.7, 0.9, 3600),
        ];

        let reranked = reranker.rerank(results);

        assert_eq!(reranked.len(), 2);
    }

    #[test]
    fn test_empty_results() {
        let reranker = Reranker::new(RerankerConfig::new());
        let results = Vec::new();
        let reranked = reranker.rerank(results);
        assert!(reranked.is_empty());
    }

    #[test]
    fn test_single_result() {
        let reranker = Reranker::new(RerankerConfig::new());
        let results = vec![create_test_result("a", 0.9, 0.5, 0)];
        let reranked = reranker.rerank(results);
        assert_eq!(reranked.len(), 1);
    }
}