1use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16use crate::types::{Memory, MemoryType, SearchResult};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RerankConfig {
21 pub enabled: bool,
23 pub strategy: RerankStrategy,
25 pub original_score_weight: f32,
27 pub rerank_score_weight: f32,
29 pub recency_boost: f32,
31 pub recency_half_life_days: f32,
33 pub importance_boost: f32,
35 pub entity_match_boost: f32,
37 pub exact_match_boost: f32,
39 pub min_results: usize,
41 pub max_rerank_candidates: usize,
43}
44
45impl Default for RerankConfig {
46 fn default() -> Self {
47 Self {
48 enabled: true,
49 strategy: RerankStrategy::Heuristic,
50 original_score_weight: 0.6,
51 rerank_score_weight: 0.4,
52 recency_boost: 0.05,
53 recency_half_life_days: 30.0,
54 importance_boost: 0.1,
55 entity_match_boost: 0.15,
56 exact_match_boost: 0.2,
57 min_results: 3,
58 max_rerank_candidates: 100,
59 }
60 }
61}
62
/// Algorithm used by [`Reranker::rerank`] to reorder results.
///
/// Serialized in `snake_case` (e.g. `"multi_signal"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RerankStrategy {
    /// Keep the input order and scores.
    None,
    /// Blend the original score with a heuristic score built from per-signal
    /// components. This is the default.
    #[default]
    Heuristic,
    /// Currently handled by the same code path as `Heuristic`.
    CrossEncoder,
    /// Reciprocal-rank fusion of the original-score, recency and term-overlap rankings.
    MultiSignal,
}
77
/// A search result after reranking, with its rank movement and scoring details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
    /// The underlying search result, moved in unchanged.
    pub result: SearchResult,
    /// 1-based position of this result in the input passed to `Reranker::rerank`.
    pub original_rank: usize,
    /// 1-based position of this result in the reranked output.
    pub new_rank: usize,
    /// Scores and signal breakdown behind `new_rank`.
    pub rerank_info: RerankInfo,
}
90
/// Score bookkeeping for one reranked result.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RerankInfo {
    /// The score the result arrived with (`SearchResult::score`).
    pub original_score: f32,
    /// The score results are ordered by. For pass-through results it equals
    /// `original_score`.
    pub final_score: f32,
    /// The strategy's own score before any blending (0.0 when no reranking was applied).
    pub rerank_score: f32,
    /// Per-signal breakdown (all zero when the signals were not computed).
    pub components: RerankComponents,
}
103
/// Individual heuristic signals computed for one memory against one query.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RerankComponents {
    /// Fraction of query terms that also occur in the memory content, in `[0, 1]`.
    pub term_overlap: f32,
    /// `recency_boost` scaled by a half-life decay on the memory's age.
    pub recency: f32,
    /// `memory.importance * importance_boost`.
    pub importance: f32,
    /// `entity_match_boost` scaled by the share of query entities found in the content.
    pub entity_match: f32,
    /// `exact_match_boost` when the content contains the query text, else 0.
    pub exact_match: f32,
    /// Fixed prior that depends only on the memory's `MemoryType`.
    pub type_relevance: f32,
    /// Share of query terms matching the memory's tags, scaled by 0.1.
    pub tag_match: f32,
}
122
/// Reorders search results using the strategy and weights in its [`RerankConfig`].
pub struct Reranker {
    // Immutable after construction; read by every scoring method.
    config: RerankConfig,
}
127
128impl Reranker {
129 pub fn new() -> Self {
131 Self {
132 config: RerankConfig::default(),
133 }
134 }
135
136 pub fn with_config(config: RerankConfig) -> Self {
138 Self { config }
139 }
140
141 pub fn rerank(
143 &self,
144 results: Vec<SearchResult>,
145 query: &str,
146 query_entities: Option<&[String]>,
147 ) -> Vec<RerankResult> {
148 if !self.config.enabled || results.len() < self.config.min_results {
149 return results
151 .into_iter()
152 .enumerate()
153 .map(|(i, r)| RerankResult {
154 rerank_info: RerankInfo {
155 original_score: r.score,
156 final_score: r.score,
157 rerank_score: 0.0,
158 components: RerankComponents::default(),
159 },
160 result: r,
161 original_rank: i + 1,
162 new_rank: i + 1,
163 })
164 .collect();
165 }
166
167 match self.config.strategy {
168 RerankStrategy::None => self.no_rerank(results),
169 RerankStrategy::Heuristic => self.heuristic_rerank(results, query, query_entities),
170 RerankStrategy::CrossEncoder => {
171 self.heuristic_rerank(results, query, query_entities)
173 }
174 RerankStrategy::MultiSignal => self.multi_signal_rerank(results, query, query_entities),
175 }
176 }
177
178 fn no_rerank(&self, results: Vec<SearchResult>) -> Vec<RerankResult> {
180 results
181 .into_iter()
182 .enumerate()
183 .map(|(i, r)| RerankResult {
184 rerank_info: RerankInfo {
185 original_score: r.score,
186 final_score: r.score,
187 rerank_score: 0.0,
188 components: RerankComponents::default(),
189 },
190 result: r,
191 original_rank: i + 1,
192 new_rank: i + 1,
193 })
194 .collect()
195 }
196
197 fn heuristic_rerank(
199 &self,
200 results: Vec<SearchResult>,
201 query: &str,
202 query_entities: Option<&[String]>,
203 ) -> Vec<RerankResult> {
204 let query_terms = extract_terms(query);
205 let query_lower = query.to_lowercase();
206
207 let mut rerank_results: Vec<RerankResult> = results
208 .into_iter()
209 .enumerate()
210 .take(self.config.max_rerank_candidates)
211 .map(|(i, r)| {
212 let components = self.compute_rerank_components(
213 &r.memory,
214 &query_terms,
215 &query_lower,
216 query_entities,
217 );
218
219 let rerank_score = self.combine_components(&components);
220 let final_score = self.config.original_score_weight * r.score
221 + self.config.rerank_score_weight * rerank_score;
222
223 RerankResult {
224 rerank_info: RerankInfo {
225 original_score: r.score,
226 final_score,
227 rerank_score,
228 components,
229 },
230 result: r,
231 original_rank: i + 1,
232 new_rank: 0, }
234 })
235 .collect();
236
237 rerank_results.sort_by(|a, b| {
239 b.rerank_info
240 .final_score
241 .partial_cmp(&a.rerank_info.final_score)
242 .unwrap_or(std::cmp::Ordering::Equal)
243 });
244
245 for (i, result) in rerank_results.iter_mut().enumerate() {
247 result.new_rank = i + 1;
248 }
249
250 rerank_results
251 }
252
253 fn multi_signal_rerank(
255 &self,
256 results: Vec<SearchResult>,
257 query: &str,
258 query_entities: Option<&[String]>,
259 ) -> Vec<RerankResult> {
260 let query_terms = extract_terms(query);
261 let query_lower = query.to_lowercase();
262
263 let mut original_ranks: Vec<(usize, f32)> = results
265 .iter()
266 .enumerate()
267 .map(|(i, r)| (i, r.score))
268 .collect();
269 original_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270
271 let mut recency_ranks: Vec<(usize, f32)> = results
272 .iter()
273 .enumerate()
274 .map(|(i, r)| (i, self.compute_recency_score(&r.memory)))
275 .collect();
276 recency_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277
278 let mut term_ranks: Vec<(usize, f32)> = results
279 .iter()
280 .enumerate()
281 .map(|(i, r)| (i, compute_term_overlap(&r.memory.content, &query_terms)))
282 .collect();
283 term_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284
285 let k = 60.0;
287 let mut rrf_scores: Vec<(usize, f32)> = vec![];
288
289 for i in 0..results.len() {
290 let orig_rank = original_ranks
291 .iter()
292 .position(|(idx, _)| *idx == i)
293 .unwrap()
294 + 1;
295 let rec_rank = recency_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
296 let term_rank = term_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
297
298 let rrf_score = 1.0 / (k + orig_rank as f32)
299 + 0.5 / (k + rec_rank as f32)
300 + 0.5 / (k + term_rank as f32);
301
302 rrf_scores.push((i, rrf_score));
303 }
304
305 rrf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306
307 let mut rerank_results: Vec<RerankResult> = results
309 .into_iter()
310 .enumerate()
311 .map(|(i, r)| {
312 let components = self.compute_rerank_components(
313 &r.memory,
314 &query_terms,
315 &query_lower,
316 query_entities,
317 );
318 let rrf_score = rrf_scores
319 .iter()
320 .find(|(idx, _)| *idx == i)
321 .map(|(_, s)| *s)
322 .unwrap_or(0.0);
323 let new_rank = rrf_scores
324 .iter()
325 .position(|(idx, _)| *idx == i)
326 .unwrap_or(i)
327 + 1;
328
329 RerankResult {
330 rerank_info: RerankInfo {
331 original_score: r.score,
332 final_score: rrf_score,
333 rerank_score: rrf_score,
334 components,
335 },
336 result: r,
337 original_rank: i + 1,
338 new_rank,
339 }
340 })
341 .collect();
342
343 rerank_results.sort_by_key(|r| r.new_rank);
344 rerank_results
345 }
346
347 fn compute_rerank_components(
349 &self,
350 memory: &Memory,
351 query_terms: &HashSet<String>,
352 query_lower: &str,
353 query_entities: Option<&[String]>,
354 ) -> RerankComponents {
355 let content_lower = memory.content.to_lowercase();
356
357 RerankComponents {
358 term_overlap: compute_term_overlap(&memory.content, query_terms),
359 recency: self.compute_recency_score(memory),
360 importance: memory.importance * self.config.importance_boost,
361 entity_match: self.compute_entity_match_score(memory, query_entities),
362 exact_match: if content_lower.contains(query_lower) {
363 self.config.exact_match_boost
364 } else {
365 0.0
366 },
367 type_relevance: self.compute_type_relevance(memory),
368 tag_match: self.compute_tag_match_score(memory, query_terms),
369 }
370 }
371
372 fn combine_components(&self, components: &RerankComponents) -> f32 {
374 components.term_overlap * 0.25
376 + components.recency * 0.15
377 + components.importance * 0.15
378 + components.entity_match * 0.15
379 + components.exact_match * 0.15
380 + components.type_relevance * 0.05
381 + components.tag_match * 0.10
382 }
383
384 fn compute_recency_score(&self, memory: &Memory) -> f32 {
386 let now = chrono::Utc::now();
387 let age_days = (now - memory.created_at).num_days() as f32;
388
389 let decay = 0.5_f32.powf(age_days / self.config.recency_half_life_days);
391 self.config.recency_boost * decay
392 }
393
394 fn compute_entity_match_score(
396 &self,
397 memory: &Memory,
398 query_entities: Option<&[String]>,
399 ) -> f32 {
400 let Some(entities) = query_entities else {
401 return 0.0;
402 };
403
404 if entities.is_empty() {
405 return 0.0;
406 }
407
408 let content_lower = memory.content.to_lowercase();
409 let matches = entities
410 .iter()
411 .filter(|e| content_lower.contains(&e.to_lowercase()))
412 .count();
413
414 if matches > 0 {
415 self.config.entity_match_boost * (matches as f32 / entities.len() as f32)
416 } else {
417 0.0
418 }
419 }
420
421 fn compute_type_relevance(&self, memory: &Memory) -> f32 {
423 match memory.memory_type {
424 MemoryType::Decision => 0.1,
425 MemoryType::Preference => 0.08,
426 MemoryType::Learning => 0.06,
427 MemoryType::Context => 0.05,
428 MemoryType::Note => 0.04,
429 MemoryType::Todo => 0.03,
430 MemoryType::Issue => 0.03,
431 MemoryType::Credential => 0.02,
432 MemoryType::Custom => 0.04,
433 MemoryType::TranscriptChunk => 0.02, MemoryType::Episodic => 0.07,
435 MemoryType::Procedural => 0.06,
436 MemoryType::Summary => 0.05,
437 MemoryType::Checkpoint => 0.04,
438 }
439 }
440
441 fn compute_tag_match_score(&self, memory: &Memory, query_terms: &HashSet<String>) -> f32 {
443 if memory.tags.is_empty() || query_terms.is_empty() {
444 return 0.0;
445 }
446
447 let tag_set: HashSet<String> = memory.tags.iter().map(|t| t.to_lowercase()).collect();
448 let matches = query_terms.intersection(&tag_set).count();
449
450 if matches > 0 {
451 0.1 * (matches as f32 / query_terms.len().min(memory.tags.len()) as f32)
452 } else {
453 0.0
454 }
455 }
456}
457
458impl Default for Reranker {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
/// Splits `text` into a set of lowercase terms for overlap matching.
///
/// The text is lowercased and split on every non-alphanumeric character.
/// Terms of two characters or fewer (e.g. "a", "is", "to") are discarded. Length
/// is counted in `char`s rather than UTF-8 bytes, so the cut-off treats accented
/// and non-Latin words the same way as ASCII ones.
fn extract_terms(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|term| term.chars().count() > 2)
        .map(String::from)
        .collect()
}
472
473fn compute_term_overlap(content: &str, query_terms: &HashSet<String>) -> f32 {
475 if query_terms.is_empty() {
476 return 0.0;
477 }
478
479 let content_terms = extract_terms(content);
480 let matches = query_terms.intersection(&content_terms).count();
481
482 matches as f32 / query_terms.len() as f32
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MatchInfo, MemoryScope, SearchStrategy, Visibility};
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a fresh, private, global `Note` memory (id 1, no tags).
    fn create_test_memory(content: &str, importance: f32) -> Memory {
        Memory {
            id: 1,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance,
            tags: Vec::new(),
            access_count: 0,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Visibility::Private,
            version: 1,
            has_embedding: false,
            metadata: HashMap::new(),
            scope: MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    /// Wraps `memory` in a hybrid-search result whose keyword score equals `score`.
    fn create_test_result(memory: Memory, score: f32) -> SearchResult {
        let match_info = MatchInfo {
            strategy: SearchStrategy::Hybrid,
            matched_terms: Vec::new(),
            highlights: Vec::new(),
            semantic_score: None,
            keyword_score: Some(score),
        };
        SearchResult {
            memory,
            score,
            match_info,
        }
    }

    /// Shorthand for a result around a fresh test memory.
    fn note_result(content: &str, importance: f32, score: f32) -> SearchResult {
        create_test_result(create_test_memory(content, importance), score)
    }

    #[test]
    fn test_reranker_preserves_order_when_disabled() {
        let reranker = Reranker::with_config(RerankConfig {
            enabled: false,
            ..Default::default()
        });

        let results = vec![
            note_result("First result", 0.5, 0.9),
            note_result("Second result", 0.5, 0.8),
            note_result("Third result", 0.5, 0.7),
        ];

        let ranks: Vec<usize> = reranker
            .rerank(results, "test query", None)
            .iter()
            .map(|r| r.new_rank)
            .collect();

        assert_eq!(ranks, vec![1, 2, 3]);
    }

    #[test]
    fn test_exact_match_boost() {
        let reranker = Reranker::new();

        let results = vec![
            note_result("Some unrelated content", 0.5, 0.9),
            note_result("This contains test query exactly", 0.5, 0.7),
            note_result("Another unrelated text", 0.5, 0.8),
        ];

        let reranked = reranker.rerank(results, "test query", None);
        let exact = reranked
            .iter()
            .find(|r| r.result.memory.content.contains("test query"))
            .expect("the exact-match result should be present");

        assert!(exact.rerank_info.components.exact_match > 0.0);
    }

    #[test]
    fn test_importance_boost() {
        let reranker = Reranker::with_config(RerankConfig {
            min_results: 2,
            ..Default::default()
        });

        let mut low = create_test_memory("Test content low", 0.2);
        low.id = 1;
        let mut high = create_test_memory("Test content high", 0.9);
        high.id = 2;

        let reranked = reranker.rerank(
            vec![create_test_result(low, 0.8), create_test_result(high, 0.75)],
            "test",
            None,
        );

        let high_result = reranked
            .iter()
            .find(|r| r.result.memory.id == 2)
            .expect("high-importance result should be present");
        let low_result = reranked
            .iter()
            .find(|r| r.result.memory.id == 1)
            .expect("low-importance result should be present");

        assert!(
            high_result.rerank_info.components.importance
                > low_result.rerank_info.components.importance
        );
    }

    #[test]
    fn test_entity_match_boost() {
        let reranker = Reranker::with_config(RerankConfig {
            min_results: 2,
            ..Default::default()
        });

        let results = vec![
            note_result("Content about Python programming", 0.5, 0.8),
            note_result("Content about Rust and systems", 0.5, 0.75),
        ];
        let entities = vec!["Rust".to_string(), "systems".to_string()];

        let reranked = reranker.rerank(results, "programming language", Some(&entities));
        let rust_result = reranked
            .iter()
            .find(|r| r.result.memory.content.contains("Rust"))
            .expect("the Rust result should be present");

        assert!(rust_result.rerank_info.components.entity_match > 0.0);
    }

    #[test]
    fn test_term_overlap() {
        let terms: HashSet<String> = ["rust", "programming", "memory"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        let high = compute_term_overlap("Rust programming with memory management", &terms);
        let low = compute_term_overlap("Python web development", &terms);

        assert!(high > low);
        assert!(high > 0.5);
    }

    #[test]
    fn test_multi_signal_rerank() {
        let reranker = Reranker::with_config(RerankConfig {
            strategy: RerankStrategy::MultiSignal,
            ..Default::default()
        });

        let results = vec![
            note_result("First memory", 0.5, 0.9),
            note_result("Second memory", 0.5, 0.8),
            note_result("Third memory with exact query", 0.5, 0.7),
        ];

        let reranked = reranker.rerank(results, "exact query", None);

        assert_eq!(reranked.len(), 3);
        // Every RRF score is a sum of positive reciprocal terms.
        assert!(reranked.iter().all(|r| r.rerank_info.final_score > 0.0));
    }
}