1use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16use crate::types::{Memory, MemoryType, SearchResult};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RerankConfig {
21 pub enabled: bool,
23 pub strategy: RerankStrategy,
25 pub original_score_weight: f32,
27 pub rerank_score_weight: f32,
29 pub recency_boost: f32,
31 pub recency_half_life_days: f32,
33 pub importance_boost: f32,
35 pub entity_match_boost: f32,
37 pub exact_match_boost: f32,
39 pub min_results: usize,
41 pub max_rerank_candidates: usize,
43}
44
45impl Default for RerankConfig {
46 fn default() -> Self {
47 Self {
48 enabled: true,
49 strategy: RerankStrategy::Heuristic,
50 original_score_weight: 0.6,
51 rerank_score_weight: 0.4,
52 recency_boost: 0.05,
53 recency_half_life_days: 30.0,
54 importance_boost: 0.1,
55 entity_match_boost: 0.15,
56 exact_match_boost: 0.2,
57 min_results: 3,
58 max_rerank_candidates: 100,
59 }
60 }
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
65#[serde(rename_all = "snake_case")]
66pub enum RerankStrategy {
67 None,
69 #[default]
71 Heuristic,
72 CrossEncoder,
74 MultiSignal,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct RerankResult {
81 pub result: SearchResult,
83 pub original_rank: usize,
85 pub new_rank: usize,
87 pub rerank_info: RerankInfo,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, Default)]
93pub struct RerankInfo {
94 pub original_score: f32,
96 pub final_score: f32,
98 pub rerank_score: f32,
100 pub components: RerankComponents,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize, Default)]
106pub struct RerankComponents {
107 pub term_overlap: f32,
109 pub recency: f32,
111 pub importance: f32,
113 pub entity_match: f32,
115 pub exact_match: f32,
117 pub type_relevance: f32,
119 pub tag_match: f32,
121}
122
123pub struct Reranker {
125 config: RerankConfig,
126}
127
128impl Reranker {
129 pub fn new() -> Self {
131 Self {
132 config: RerankConfig::default(),
133 }
134 }
135
136 pub fn with_config(config: RerankConfig) -> Self {
138 Self { config }
139 }
140
141 pub fn rerank(
143 &self,
144 results: Vec<SearchResult>,
145 query: &str,
146 query_entities: Option<&[String]>,
147 ) -> Vec<RerankResult> {
148 if !self.config.enabled || results.len() < self.config.min_results {
149 return results
151 .into_iter()
152 .enumerate()
153 .map(|(i, r)| RerankResult {
154 rerank_info: RerankInfo {
155 original_score: r.score,
156 final_score: r.score,
157 rerank_score: 0.0,
158 components: RerankComponents::default(),
159 },
160 result: r,
161 original_rank: i + 1,
162 new_rank: i + 1,
163 })
164 .collect();
165 }
166
167 match self.config.strategy {
168 RerankStrategy::None => self.no_rerank(results),
169 RerankStrategy::Heuristic => self.heuristic_rerank(results, query, query_entities),
170 RerankStrategy::CrossEncoder => {
171 self.heuristic_rerank(results, query, query_entities)
173 }
174 RerankStrategy::MultiSignal => self.multi_signal_rerank(results, query, query_entities),
175 }
176 }
177
178 fn no_rerank(&self, results: Vec<SearchResult>) -> Vec<RerankResult> {
180 results
181 .into_iter()
182 .enumerate()
183 .map(|(i, r)| RerankResult {
184 rerank_info: RerankInfo {
185 original_score: r.score,
186 final_score: r.score,
187 rerank_score: 0.0,
188 components: RerankComponents::default(),
189 },
190 result: r,
191 original_rank: i + 1,
192 new_rank: i + 1,
193 })
194 .collect()
195 }
196
197 fn heuristic_rerank(
199 &self,
200 results: Vec<SearchResult>,
201 query: &str,
202 query_entities: Option<&[String]>,
203 ) -> Vec<RerankResult> {
204 let query_terms = extract_terms(query);
205 let query_lower = query.to_lowercase();
206
207 let mut rerank_results: Vec<RerankResult> = results
208 .into_iter()
209 .enumerate()
210 .take(self.config.max_rerank_candidates)
211 .map(|(i, r)| {
212 let components = self.compute_rerank_components(
213 &r.memory,
214 &query_terms,
215 &query_lower,
216 query_entities,
217 );
218
219 let rerank_score = self.combine_components(&components);
220 let final_score = self.config.original_score_weight * r.score
221 + self.config.rerank_score_weight * rerank_score;
222
223 RerankResult {
224 rerank_info: RerankInfo {
225 original_score: r.score,
226 final_score,
227 rerank_score,
228 components,
229 },
230 result: r,
231 original_rank: i + 1,
232 new_rank: 0, }
234 })
235 .collect();
236
237 rerank_results.sort_by(|a, b| {
239 b.rerank_info
240 .final_score
241 .partial_cmp(&a.rerank_info.final_score)
242 .unwrap_or(std::cmp::Ordering::Equal)
243 });
244
245 for (i, result) in rerank_results.iter_mut().enumerate() {
247 result.new_rank = i + 1;
248 }
249
250 rerank_results
251 }
252
253 fn multi_signal_rerank(
255 &self,
256 results: Vec<SearchResult>,
257 query: &str,
258 query_entities: Option<&[String]>,
259 ) -> Vec<RerankResult> {
260 let query_terms = extract_terms(query);
261 let query_lower = query.to_lowercase();
262
263 let mut original_ranks: Vec<(usize, f32)> = results
265 .iter()
266 .enumerate()
267 .map(|(i, r)| (i, r.score))
268 .collect();
269 original_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270
271 let mut recency_ranks: Vec<(usize, f32)> = results
272 .iter()
273 .enumerate()
274 .map(|(i, r)| (i, self.compute_recency_score(&r.memory)))
275 .collect();
276 recency_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277
278 let mut term_ranks: Vec<(usize, f32)> = results
279 .iter()
280 .enumerate()
281 .map(|(i, r)| (i, compute_term_overlap(&r.memory.content, &query_terms)))
282 .collect();
283 term_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284
285 let k = 60.0;
287 let mut rrf_scores: Vec<(usize, f32)> = vec![];
288
289 for i in 0..results.len() {
290 let orig_rank = original_ranks
291 .iter()
292 .position(|(idx, _)| *idx == i)
293 .unwrap()
294 + 1;
295 let rec_rank = recency_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
296 let term_rank = term_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
297
298 let rrf_score = 1.0 / (k + orig_rank as f32)
299 + 0.5 / (k + rec_rank as f32)
300 + 0.5 / (k + term_rank as f32);
301
302 rrf_scores.push((i, rrf_score));
303 }
304
305 rrf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306
307 let mut rerank_results: Vec<RerankResult> = results
309 .into_iter()
310 .enumerate()
311 .map(|(i, r)| {
312 let components = self.compute_rerank_components(
313 &r.memory,
314 &query_terms,
315 &query_lower,
316 query_entities,
317 );
318 let rrf_score = rrf_scores
319 .iter()
320 .find(|(idx, _)| *idx == i)
321 .map(|(_, s)| *s)
322 .unwrap_or(0.0);
323 let new_rank = rrf_scores
324 .iter()
325 .position(|(idx, _)| *idx == i)
326 .unwrap_or(i)
327 + 1;
328
329 RerankResult {
330 rerank_info: RerankInfo {
331 original_score: r.score,
332 final_score: rrf_score,
333 rerank_score: rrf_score,
334 components,
335 },
336 result: r,
337 original_rank: i + 1,
338 new_rank,
339 }
340 })
341 .collect();
342
343 rerank_results.sort_by_key(|r| r.new_rank);
344 rerank_results
345 }
346
347 fn compute_rerank_components(
349 &self,
350 memory: &Memory,
351 query_terms: &HashSet<String>,
352 query_lower: &str,
353 query_entities: Option<&[String]>,
354 ) -> RerankComponents {
355 let content_lower = memory.content.to_lowercase();
356
357 RerankComponents {
358 term_overlap: compute_term_overlap(&memory.content, query_terms),
359 recency: self.compute_recency_score(memory),
360 importance: memory.importance * self.config.importance_boost,
361 entity_match: self.compute_entity_match_score(memory, query_entities),
362 exact_match: if content_lower.contains(query_lower) {
363 self.config.exact_match_boost
364 } else {
365 0.0
366 },
367 type_relevance: self.compute_type_relevance(memory),
368 tag_match: self.compute_tag_match_score(memory, query_terms),
369 }
370 }
371
372 fn combine_components(&self, components: &RerankComponents) -> f32 {
374 components.term_overlap * 0.25
376 + components.recency * 0.15
377 + components.importance * 0.15
378 + components.entity_match * 0.15
379 + components.exact_match * 0.15
380 + components.type_relevance * 0.05
381 + components.tag_match * 0.10
382 }
383
384 fn compute_recency_score(&self, memory: &Memory) -> f32 {
386 let now = chrono::Utc::now();
387 let age_days = (now - memory.created_at).num_days() as f32;
388
389 let decay = 0.5_f32.powf(age_days / self.config.recency_half_life_days);
391 self.config.recency_boost * decay
392 }
393
394 fn compute_entity_match_score(
396 &self,
397 memory: &Memory,
398 query_entities: Option<&[String]>,
399 ) -> f32 {
400 let Some(entities) = query_entities else {
401 return 0.0;
402 };
403
404 if entities.is_empty() {
405 return 0.0;
406 }
407
408 let content_lower = memory.content.to_lowercase();
409 let matches = entities
410 .iter()
411 .filter(|e| content_lower.contains(&e.to_lowercase()))
412 .count();
413
414 if matches > 0 {
415 self.config.entity_match_boost * (matches as f32 / entities.len() as f32)
416 } else {
417 0.0
418 }
419 }
420
421 fn compute_type_relevance(&self, memory: &Memory) -> f32 {
423 match memory.memory_type {
424 MemoryType::Decision => 0.1,
425 MemoryType::Preference => 0.08,
426 MemoryType::Learning => 0.06,
427 MemoryType::Context => 0.05,
428 MemoryType::Note => 0.04,
429 MemoryType::Todo => 0.03,
430 MemoryType::Issue => 0.03,
431 MemoryType::Credential => 0.02,
432 MemoryType::Custom => 0.04,
433 MemoryType::TranscriptChunk => 0.02, MemoryType::Episodic => 0.07,
435 MemoryType::Procedural => 0.06,
436 MemoryType::Summary => 0.05,
437 MemoryType::Checkpoint => 0.04,
438 MemoryType::Image | MemoryType::Audio | MemoryType::Video => 0.05,
439 }
440 }
441
442 fn compute_tag_match_score(&self, memory: &Memory, query_terms: &HashSet<String>) -> f32 {
444 if memory.tags.is_empty() || query_terms.is_empty() {
445 return 0.0;
446 }
447
448 let tag_set: HashSet<String> = memory.tags.iter().map(|t| t.to_lowercase()).collect();
449 let matches = query_terms.intersection(&tag_set).count();
450
451 if matches > 0 {
452 0.1 * (matches as f32 / query_terms.len().min(memory.tags.len()) as f32)
453 } else {
454 0.0
455 }
456 }
457}
458
459impl Default for Reranker {
460 fn default() -> Self {
461 Self::new()
462 }
463}
464
/// Splits `text` into a set of lowercase terms.
///
/// The text is lowercased, cut at every non-alphanumeric character, and only
/// pieces longer than two bytes are kept.
fn extract_terms(text: &str) -> HashSet<String> {
    let lowered = text.to_lowercase();
    let mut terms = HashSet::new();

    for token in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        // `len()` counts bytes, not characters.
        if token.len() >= 3 {
            terms.insert(token.to_owned());
        }
    }

    terms
}
473
474fn compute_term_overlap(content: &str, query_terms: &HashSet<String>) -> f32 {
476 if query_terms.is_empty() {
477 return 0.0;
478 }
479
480 let content_terms = extract_terms(content);
481 let matches = query_terms.intersection(&content_terms).count();
482
483 matches as f32 / query_terms.len() as f32
484}
485
486#[cfg(test)]
487mod tests {
488 use super::*;
489 use crate::types::{MatchInfo, MemoryScope, SearchStrategy, Visibility};
490 use chrono::Utc;
491 use std::collections::HashMap;
492
493 fn create_test_memory(content: &str, importance: f32) -> Memory {
494 Memory {
495 id: 1,
496 content: content.to_string(),
497 memory_type: MemoryType::Note,
498 importance,
499 tags: vec![],
500 access_count: 0,
501 created_at: Utc::now(),
502 updated_at: Utc::now(),
503 last_accessed_at: None,
504 owner_id: None,
505 visibility: Visibility::Private,
506 version: 1,
507 has_embedding: false,
508 metadata: HashMap::new(),
509 scope: MemoryScope::Global,
510 workspace: "default".to_string(),
511 tier: crate::types::MemoryTier::Permanent,
512 expires_at: None,
513 content_hash: None,
514 event_time: None,
515 event_duration_seconds: None,
516 trigger_pattern: None,
517 procedure_success_count: 0,
518 procedure_failure_count: 0,
519 summary_of_id: None,
520 lifecycle_state: crate::types::LifecycleState::Active,
521 media_url: None,
522 }
523 }
524
525 fn create_test_result(memory: Memory, score: f32) -> SearchResult {
526 SearchResult {
527 memory,
528 score,
529 match_info: MatchInfo {
530 strategy: SearchStrategy::Hybrid,
531 matched_terms: vec![],
532 highlights: vec![],
533 semantic_score: None,
534 keyword_score: Some(score),
535 },
536 }
537 }
538
/// A disabled reranker must hand results back in their incoming order.
#[test]
fn test_reranker_preserves_order_when_disabled() {
    let reranker = Reranker::with_config(RerankConfig {
        enabled: false,
        ..Default::default()
    });

    let first = create_test_result(create_test_memory("First result", 0.5), 0.9);
    let second = create_test_result(create_test_memory("Second result", 0.5), 0.8);
    let third = create_test_result(create_test_memory("Third result", 0.5), 0.7);

    let reranked = reranker.rerank(vec![first, second, third], "test query", None);

    for (index, expected_rank) in (1..=3).enumerate() {
        assert_eq!(reranked[index].new_rank, expected_rank);
    }
}
559
/// Content containing the full query must receive a positive exact-match component.
#[test]
fn test_exact_match_boost() {
    let reranker = Reranker::new();

    let unrelated_a = create_test_result(create_test_memory("Some unrelated content", 0.5), 0.9);
    let exact = create_test_result(
        create_test_memory("This contains test query exactly", 0.5),
        0.7,
    );
    let unrelated_b = create_test_result(create_test_memory("Another unrelated text", 0.5), 0.8);

    let reranked = reranker.rerank(vec![unrelated_a, exact, unrelated_b], "test query", None);

    let hit_index = reranked
        .iter()
        .position(|r| r.result.memory.content.contains("test query"))
        .unwrap();
    let exact_component = reranked[hit_index].rerank_info.components.exact_match;
    assert!(exact_component > 0.0);
}
582
/// A more important memory must get a larger importance component.
#[test]
fn test_importance_boost() {
    // Lower the threshold so that a two-element result set is actually reranked.
    let reranker = Reranker::with_config(RerankConfig {
        min_results: 2,
        ..Default::default()
    });

    let low_importance = Memory {
        id: 1,
        ..create_test_memory("Test content low", 0.2)
    };
    let high_importance = Memory {
        id: 2,
        ..create_test_memory("Test content high", 0.9)
    };

    let reranked = reranker.rerank(
        vec![
            create_test_result(low_importance, 0.8),
            create_test_result(high_importance, 0.75),
        ],
        "test",
        None,
    );

    let high_result = reranked.iter().find(|r| r.result.memory.id == 2).unwrap();
    let low_result = reranked.iter().find(|r| r.result.memory.id == 1).unwrap();

    let high_component = high_result.rerank_info.components.importance;
    let low_component = low_result.rerank_info.components.importance;
    assert!(high_component > low_component);
}
613
/// Content mentioning the query entities must get a positive entity-match component.
#[test]
fn test_entity_match_boost() {
    // Lower the threshold so that a two-element result set is actually reranked.
    let reranker = Reranker::with_config(RerankConfig {
        min_results: 2,
        ..Default::default()
    });

    let python_hit = create_test_result(
        create_test_memory("Content about Python programming", 0.5),
        0.8,
    );
    let rust_hit = create_test_result(
        create_test_memory("Content about Rust and systems", 0.5),
        0.75,
    );

    let entities = vec!["Rust".to_string(), "systems".to_string()];
    let reranked = reranker.rerank(
        vec![python_hit, rust_hit],
        "programming language",
        Some(&entities),
    );

    let rust_index = reranked
        .iter()
        .position(|r| r.result.memory.content.contains("Rust"))
        .unwrap();
    let entity_component = reranked[rust_index].rerank_info.components.entity_match;
    assert!(entity_component > 0.0);
}
643
/// Content sharing the query's terms must score higher than unrelated content.
#[test]
fn test_term_overlap() {
    let mut terms: HashSet<String> = HashSet::new();
    for word in ["rust", "programming", "memory"].iter() {
        terms.insert(word.to_string());
    }

    let high_overlap = compute_term_overlap("Rust programming with memory management", &terms);
    let low_overlap = compute_term_overlap("Python web development", &terms);

    assert!(high_overlap > low_overlap);
    // More than half of the query terms occur in the first text.
    assert!(high_overlap > 0.5);
}
657
/// The multi-signal strategy must keep every result and give each a positive final score.
#[test]
fn test_multi_signal_rerank() {
    let reranker = Reranker::with_config(RerankConfig {
        strategy: RerankStrategy::MultiSignal,
        ..Default::default()
    });

    let first = create_test_result(create_test_memory("First memory", 0.5), 0.9);
    let second = create_test_result(create_test_memory("Second memory", 0.5), 0.8);
    let third = create_test_result(
        create_test_memory("Third memory with exact query", 0.5),
        0.7,
    );

    let reranked = reranker.rerank(vec![first, second, third], "exact query", None);

    assert_eq!(reranked.len(), 3);
    assert!(reranked.iter().all(|r| r.rerank_info.final_score > 0.0));
}
684}