1use std::collections::HashMap;
12
13use uuid::Uuid;
14
15use khive_fold::objective::{Objective, ObjectiveContext};
16use khive_fold::ordering::HasId;
17
18#[derive(Debug, Clone)]
24pub struct RetrievalCandidate {
25 pub id: Uuid,
27 pub vector_score: Option<f64>,
29 pub text_score: Option<f64>,
31 pub graph_distance: Option<u32>,
33 pub rrf_score: Option<f64>,
35}
36
37impl HasId for RetrievalCandidate {
38 #[inline]
39 fn id(&self) -> Uuid {
40 self.id
41 }
42}
43
44pub struct VectorSimilarityObjective;
50
51impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
52 #[inline]
53 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
54 candidate.vector_score.unwrap_or(0.0)
55 }
56
57 fn name(&self) -> &str {
58 "VectorSimilarityObjective"
59 }
60}
61
62pub struct TextRelevanceObjective;
68
69impl Objective<RetrievalCandidate> for TextRelevanceObjective {
70 #[inline]
71 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
72 candidate.text_score.unwrap_or(0.0)
73 }
74
75 fn name(&self) -> &str {
76 "TextRelevanceObjective"
77 }
78}
79
80pub struct GraphProximityObjective {
95 pub max_distance: u32,
97}
98
99impl Objective<RetrievalCandidate> for GraphProximityObjective {
100 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
101 let d = match candidate.graph_distance {
102 Some(d) => d,
103 None => return 0.0,
104 };
105 if self.max_distance == 0 || d >= self.max_distance {
106 return 0.0;
107 }
108 1.0 - (d as f64 / self.max_distance as f64)
109 }
110
111 fn name(&self) -> &str {
112 "GraphProximityObjective"
113 }
114}
115
116pub struct RrfFusionObjective;
125
126impl Objective<RetrievalCandidate> for RrfFusionObjective {
127 #[inline]
128 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
129 candidate.rrf_score.unwrap_or(0.0)
130 }
131
132 fn name(&self) -> &str {
133 "RrfFusionObjective"
134 }
135}
136
137impl Objective<NoteCandidate> for RrfFusionObjective {
138 #[inline]
139 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
140 candidate.rrf_score.unwrap_or(0.0)
141 }
142
143 fn name(&self) -> &str {
144 "RrfFusionObjective"
145 }
146}
147
148#[derive(Debug, Clone)]
159pub struct NoteCandidate {
160 pub id: Uuid,
162 pub rrf_score: Option<f64>,
164 pub salience: f64,
166 pub decay_factor: f64,
168 pub age_days: f64,
170 pub effective_salience: f64,
176 pub rerank_scores: HashMap<String, f64>,
179}
180
181impl HasId for NoteCandidate {
182 #[inline]
183 fn id(&self) -> Uuid {
184 self.id
185 }
186}
187
188pub struct DecayAwareSalienceObjective {
200 pub decay_rate: f64,
203}
204
205impl DecayAwareSalienceObjective {
206 pub fn new(decay_rate: f64) -> Self {
210 Self { decay_rate }
211 }
212
213 pub fn default_memory() -> Self {
215 Self::new(0.01)
216 }
217}
218
219impl Objective<NoteCandidate> for DecayAwareSalienceObjective {
220 #[inline]
221 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
222 candidate.salience * (-candidate.decay_factor * candidate.age_days).exp()
225 }
226
227 fn name(&self) -> &str {
228 "DecayAwareSalienceObjective"
229 }
230}
231
232pub struct AmplifiedDecayAwareSalienceObjective {
247 pub alpha: f64,
249}
250
251impl AmplifiedDecayAwareSalienceObjective {
252 pub fn new(alpha: f64) -> Self {
254 Self { alpha }
255 }
256
257 pub fn default_memory() -> Self {
259 Self::new(1.5)
260 }
261}
262
263impl Objective<NoteCandidate> for AmplifiedDecayAwareSalienceObjective {
264 #[inline]
265 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
266 candidate.effective_salience.powf(self.alpha)
270 }
271
272 fn name(&self) -> &str {
273 "AmplifiedDecayAwareSalienceObjective"
274 }
275}
276
277pub struct TemporalRecencyObjective {
289 pub half_life_days: f64,
291}
292
293impl TemporalRecencyObjective {
294 pub fn default_memory() -> Self {
296 Self {
297 half_life_days: 30.0,
298 }
299 }
300}
301
302impl Objective<NoteCandidate> for TemporalRecencyObjective {
303 #[inline]
304 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
305 let k = std::f64::consts::LN_2 / self.half_life_days.max(f64::EPSILON);
306 (-k * candidate.age_days).exp()
307 }
308
309 fn name(&self) -> &str {
310 "TemporalRecencyObjective"
311 }
312}
313
314pub struct RerankerObjective {
325 pub reranker_name: String,
327}
328
329impl RerankerObjective {
330 pub fn new(name: impl Into<String>) -> Self {
332 Self {
333 reranker_name: name.into(),
334 }
335 }
336}
337
338impl Objective<NoteCandidate> for RerankerObjective {
339 #[inline]
340 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
341 candidate
342 .rerank_scores
343 .get(&self.reranker_name)
344 .copied()
345 .unwrap_or(0.0)
346 }
347
348 fn name(&self) -> &str {
349 "RerankerObjective"
350 }
351}
352
353pub struct MemoryRecallPipeline {
364 pipeline: khive_fold::WeightedObjective<NoteCandidate>,
365}
366
367impl MemoryRecallPipeline {
368 pub fn new(
375 relevance_weight: f64,
376 salience_weight: f64,
377 temporal_weight: f64,
378 half_life_days: f64,
379 salience_alpha: f64,
380 ) -> Self {
381 use khive_fold::WeightedObjective;
382 let pipeline = WeightedObjective::<NoteCandidate>::new()
383 .add(Box::new(RrfFusionObjective), relevance_weight)
384 .add(
385 Box::new(AmplifiedDecayAwareSalienceObjective::new(salience_alpha)),
386 salience_weight,
387 )
388 .add(
389 Box::new(TemporalRecencyObjective { half_life_days }),
390 temporal_weight,
391 );
392 Self { pipeline }
393 }
394
395 pub fn default_memory() -> Self {
399 Self::new(0.70, 0.20, 0.10, 30.0, 1.5)
400 }
401
402 pub fn score(&self, candidate: &NoteCandidate) -> f64 {
407 let ctx = ObjectiveContext::new();
408 use khive_fold::objective::Objective;
409 self.pipeline.score(candidate, &ctx).clamp(0.0, 1.0)
410 }
411}
412
#[cfg(test)]
mod tests {
    // Unit tests for the retrieval and memory-recall objectives, plus the
    // `MemoryRecallPipeline` facade.

    use super::*;
    use khive_fold::objective::{Objective, ObjectiveContext};
    use khive_fold::WeightedObjective;
    use uuid::Uuid;

    /// Fresh objective context; the objectives under test ignore it.
    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Builds a `RetrievalCandidate` with a random id and the given signals.
    fn candidate(
        vector: Option<f64>,
        text: Option<f64>,
        dist: Option<u32>,
        rrf: Option<f64>,
    ) -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: vector,
            text_score: text,
            graph_distance: dist,
            rrf_score: rrf,
        }
    }

    /// Builds a `NoteCandidate` with a random id and no rerank scores.
    ///
    /// `effective_salience` is derived as
    /// `salience * exp(-decay_factor * age_days)`.
    fn note_candidate(
        rrf: Option<f64>,
        salience: f64,
        decay_factor: f64,
        age_days: f64,
    ) -> NoteCandidate {
        let effective_salience = salience * (-decay_factor * age_days).exp();
        NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: rrf,
            salience,
            decay_factor,
            age_days,
            effective_salience,
            rerank_scores: HashMap::new(),
        }
    }

    // ---- VectorSimilarityObjective ----

    #[test]
    fn vector_present_returns_signal() {
        let c = candidate(Some(0.85), None, None, None);
        let score = VectorSimilarityObjective.score(&c, &ctx());
        assert!((score - 0.85).abs() < 1e-12);
    }

    #[test]
    fn vector_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn vector_zero_score_returns_zero() {
        let c = candidate(Some(0.0), None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // ---- TextRelevanceObjective ----

    #[test]
    fn text_present_returns_signal() {
        let c = candidate(None, Some(0.6), None, None);
        let score = TextRelevanceObjective.score(&c, &ctx());
        assert!((score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn text_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(TextRelevanceObjective.score(&c, &ctx()), 0.0);
    }

    // ---- GraphProximityObjective ----

    #[test]
    fn graph_anchor_hit_scores_one() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert!((obj.score(&c, &ctx()) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn graph_midpoint_scores_half() {
        let c = candidate(None, None, Some(1), None);
        let obj = GraphProximityObjective { max_distance: 2 };
        assert!((obj.score(&c, &ctx()) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn graph_at_boundary_scores_zero() {
        let c = candidate(None, None, Some(3), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_beyond_boundary_scores_zero() {
        let c = candidate(None, None, Some(10), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        let c = candidate(None, None, None, None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 0 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // ---- RrfFusionObjective ----

    #[test]
    fn rrf_present_returns_signal() {
        let c = candidate(None, None, None, Some(0.0327));
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    // ---- Weighted composition over RetrievalCandidate ----

    #[test]
    fn weighted_composition_vector_and_text() {
        let c = candidate(Some(0.8), Some(0.6), None, None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.7).abs() < 1e-12);
    }

    #[test]
    fn weighted_composition_with_graph() {
        let c = candidate(Some(1.0), Some(0.0), Some(1), None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.625).abs() < 1e-12);
    }

    #[test]
    fn weighted_all_absent_returns_zero() {
        let c = candidate(None, None, None, None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate {
            id,
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;

        let candidates = vec![
            candidate(Some(0.3), None, None, None),
            candidate(Some(0.9), None, None, None),
            candidate(Some(0.6), None, None, None),
        ];

        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());

        assert_eq!(top.len(), 2);
        assert!((top[0].score - 0.9).abs() < 1e-12);
        assert!((top[1].score - 0.6).abs() < 1e-12);
    }

    // ---- NoteCandidate objectives ----

    #[test]
    fn note_candidate_has_id_returns_uuid() {
        let id = Uuid::new_v4();
        let c = NoteCandidate {
            id,
            rrf_score: None,
            salience: 0.5,
            decay_factor: 0.01,
            age_days: 0.0,
            effective_salience: 0.5,
            rerank_scores: HashMap::new(),
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn decay_aware_zero_age_returns_full_salience() {
        let obj = DecayAwareSalienceObjective::new(0.01);
        let c = note_candidate(None, 0.8, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.8).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn decay_aware_uses_note_decay_factor_not_field() {
        // The objective's own decay_rate (0.99) must be ignored in favour of
        // the note's decay_factor (0.01).
        let obj = DecayAwareSalienceObjective::new(0.99);
        let c = note_candidate(None, 1.0, 0.01, 100.0);
        let score = obj.score(&c, &ctx());
        let expected = (-0.01_f64 * 100.0).exp();
        assert!(
            (score - expected).abs() < 1e-12,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn decay_aware_high_decay_reduces_score_faster() {
        let obj = DecayAwareSalienceObjective::new(0.0);
        let slow = note_candidate(None, 1.0, 0.001, 100.0);
        let fast = note_candidate(None, 1.0, 0.1, 100.0);
        let score_slow = obj.score(&slow, &ctx());
        let score_fast = obj.score(&fast, &ctx());
        assert!(
            score_slow > score_fast,
            "slow decay should score higher: {score_slow} vs {score_fast}"
        );
    }

    #[test]
    fn amplified_salience_raises_effective_salience_to_alpha() {
        let obj = AmplifiedDecayAwareSalienceObjective::new(2.0);
        // Zero decay and zero age keep effective_salience at 0.5.
        let c = note_candidate(None, 0.5, 0.0, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.25).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn temporal_score_one_at_zero_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 1.0).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn temporal_score_half_at_half_life() {
        let half_life = 30.0;
        let obj = TemporalRecencyObjective {
            half_life_days: half_life,
        };
        let c = note_candidate(None, 0.5, 0.01, half_life);
        let score = obj.score(&c, &ctx());
        assert!(
            (score - 0.5).abs() < 1e-10,
            "expected 0.5 at half_life, got {score}"
        );
    }

    #[test]
    fn temporal_score_decreases_with_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let young = note_candidate(None, 1.0, 0.01, 10.0);
        let old = note_candidate(None, 1.0, 0.01, 100.0);
        let score_young = obj.score(&young, &ctx());
        let score_old = obj.score(&old, &ctx());
        assert!(
            score_young > score_old,
            "younger note should score higher: {score_young} vs {score_old}"
        );
    }

    #[test]
    fn reranker_returns_named_score() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("cross_encoder".to_string(), 0.9);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert!((score - 0.9).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn reranker_absent_key_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert_eq!(score, 0.0);
    }

    #[test]
    fn reranker_different_keys_independent() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("salience".to_string(), 0.7);
        let obj_ce = RerankerObjective::new("cross_encoder");
        let obj_sal = RerankerObjective::new("salience");
        assert_eq!(obj_ce.score(&c, &ctx()), 0.0);
        assert!((obj_sal.score(&c, &ctx()) - 0.7).abs() < 1e-12);
    }

    // ---- Memory pipeline ----

    #[test]
    fn memory_pipeline_weighted_composition() {
        let c = NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: Some(0.5),
            salience: 0.8,
            decay_factor: 0.01,
            age_days: 0.0,
            effective_salience: 0.8,
            rerank_scores: HashMap::new(),
        };
        let pipeline = WeightedObjective::<NoteCandidate>::new()
            .add(Box::new(RrfFusionObjective), 0.70)
            .add(Box::new(DecayAwareSalienceObjective::new(0.0)), 0.20)
            .add(
                Box::new(TemporalRecencyObjective {
                    half_life_days: 30.0,
                }),
                0.10,
            );
        // 0.70 * 0.5 + 0.20 * 0.8 + 0.10 * 1.0 = 0.61
        let score = pipeline.score(&c, &ctx());
        assert!((score - 0.61).abs() < 1e-10, "got {score}");
    }

    #[test]
    fn memory_recall_pipeline_default_matches_manual_composition() {
        // Age 0 => recency 1.0; zero age keeps effective_salience at 0.8.
        let c = note_candidate(Some(0.5), 0.8, 0.01, 0.0);
        let score = MemoryRecallPipeline::default_memory().score(&c);
        let expected = 0.70 * 0.5 + 0.20 * 0.8_f64.powf(1.5) + 0.10 * 1.0;
        assert!(
            (score - expected).abs() < 1e-10,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn memory_recall_pipeline_clamps_to_unit_interval() {
        // Raw blend is 0.70 * 10.0 + 0.20 * 1.0 + 0.10 * 1.0 = 7.3 > 1.0.
        let c = note_candidate(Some(10.0), 1.0, 0.0, 0.0);
        let score = MemoryRecallPipeline::default_memory().score(&c);
        assert_eq!(score, 1.0);
    }
}