1use std::collections::HashMap;
8
9use uuid::Uuid;
10
11use khive_fold::objective::{Objective, ObjectiveContext};
12use khive_fold::ordering::HasId;
13
14#[derive(Debug, Clone)]
20pub struct RetrievalCandidate {
21 pub id: Uuid,
23 pub vector_score: Option<f64>,
25 pub text_score: Option<f64>,
27 pub graph_distance: Option<u32>,
29 pub rrf_score: Option<f64>,
31}
32
33impl HasId for RetrievalCandidate {
34 #[inline]
35 fn id(&self) -> Uuid {
36 self.id
37 }
38}
39
/// Scores a [`RetrievalCandidate`] by its `vector_score`, treating a missing
/// score as `0.0`.
#[derive(Debug, Clone, Copy)]
pub struct VectorSimilarityObjective;
46
47impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
48 #[inline]
49 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
50 candidate.vector_score.unwrap_or(0.0)
51 }
52
53 fn name(&self) -> &str {
54 "VectorSimilarityObjective"
55 }
56}
57
/// Scores a [`RetrievalCandidate`] by its `text_score`, treating a missing
/// score as `0.0`.
#[derive(Debug, Clone, Copy)]
pub struct TextRelevanceObjective;
64
65impl Objective<RetrievalCandidate> for TextRelevanceObjective {
66 #[inline]
67 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
68 candidate.text_score.unwrap_or(0.0)
69 }
70
71 fn name(&self) -> &str {
72 "TextRelevanceObjective"
73 }
74}
75
/// Scores a [`RetrievalCandidate`] by graph distance, decaying linearly from
/// `1.0` at distance `0` down to `0.0` at `max_distance`.
#[derive(Debug, Clone, Copy)]
pub struct GraphProximityObjective {
    /// Exclusive distance cutoff: distances `>= max_distance` score `0.0`,
    /// and a value of `0` makes every candidate score `0.0`.
    pub max_distance: u32,
}
94
95impl Objective<RetrievalCandidate> for GraphProximityObjective {
96 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
97 let d = match candidate.graph_distance {
98 Some(d) => d,
99 None => return 0.0,
100 };
101 if self.max_distance == 0 || d >= self.max_distance {
102 return 0.0;
103 }
104 1.0 - (d as f64 / self.max_distance as f64)
105 }
106
107 fn name(&self) -> &str {
108 "GraphProximityObjective"
109 }
110}
111
/// Scores a candidate by its precomputed `rrf_score`, treating a missing
/// score as `0.0`. Implemented for both [`RetrievalCandidate`] and
/// [`NoteCandidate`].
#[derive(Debug, Clone, Copy)]
pub struct RrfFusionObjective;
121
122impl Objective<RetrievalCandidate> for RrfFusionObjective {
123 #[inline]
124 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
125 candidate.rrf_score.unwrap_or(0.0)
126 }
127
128 fn name(&self) -> &str {
129 "RrfFusionObjective"
130 }
131}
132
133impl Objective<NoteCandidate> for RrfFusionObjective {
134 #[inline]
135 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
136 candidate.rrf_score.unwrap_or(0.0)
137 }
138
139 fn name(&self) -> &str {
140 "RrfFusionObjective"
141 }
142}
143
144#[derive(Debug, Clone)]
153pub struct NoteCandidate {
154 pub id: Uuid,
156 pub rrf_score: Option<f64>,
158 pub salience: f64,
160 pub decay_factor: f64,
162 pub age_days: f64,
164 pub effective_salience: f64,
170 pub rerank_scores: HashMap<String, f64>,
173}
174
175impl HasId for NoteCandidate {
176 #[inline]
177 fn id(&self) -> Uuid {
178 self.id
179 }
180}
181
/// Scores a [`NoteCandidate`] as `salience * exp(-decay_factor * age_days)`,
/// using each note's own `decay_factor`.
#[derive(Debug, Clone, Copy)]
pub struct DecayAwareSalienceObjective {
    /// Rate supplied at construction. The `Objective` impl in this module does
    /// not read it; scoring uses the note's `decay_factor` instead.
    pub decay_rate: f64,
}
198
199impl DecayAwareSalienceObjective {
200 pub fn new(decay_rate: f64) -> Self {
204 Self { decay_rate }
205 }
206
207 pub fn default_memory() -> Self {
209 Self::new(0.01)
210 }
211}
212
213impl Objective<NoteCandidate> for DecayAwareSalienceObjective {
214 #[inline]
215 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
216 candidate.salience * (-candidate.decay_factor * candidate.age_days).exp()
218 }
219
220 fn name(&self) -> &str {
221 "DecayAwareSalienceObjective"
222 }
223}
224
/// Scores a [`NoteCandidate`] by raising its precomputed
/// `effective_salience` to the power `alpha`.
#[derive(Debug, Clone, Copy)]
pub struct AmplifiedDecayAwareSalienceObjective {
    /// Exponent applied to `effective_salience`; for salience in `(0, 1)`,
    /// values above `1.0` shrink low salience more than high salience.
    pub alpha: f64,
}
243
244impl AmplifiedDecayAwareSalienceObjective {
245 pub fn new(alpha: f64) -> Self {
247 Self { alpha }
248 }
249
250 pub fn default_memory() -> Self {
252 Self::new(1.5)
253 }
254}
255
256impl Objective<NoteCandidate> for AmplifiedDecayAwareSalienceObjective {
257 #[inline]
258 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
259 candidate.effective_salience.powf(self.alpha)
263 }
264
265 fn name(&self) -> &str {
266 "AmplifiedDecayAwareSalienceObjective"
267 }
268}
269
/// Scores a [`NoteCandidate`] by exponential recency:
/// `exp(-ln(2) / half_life_days * age_days)`, which is `1.0` at age zero and
/// `0.5` after one half-life.
#[derive(Debug, Clone, Copy)]
pub struct TemporalRecencyObjective {
    /// Age at which the score halves. Values below `f64::EPSILON` are replaced
    /// by `f64::EPSILON` when scoring.
    pub half_life_days: f64,
}
285
286impl TemporalRecencyObjective {
287 pub fn default_memory() -> Self {
289 Self {
290 half_life_days: 30.0,
291 }
292 }
293}
294
295impl Objective<NoteCandidate> for TemporalRecencyObjective {
296 #[inline]
297 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
298 let k = std::f64::consts::LN_2 / self.half_life_days.max(f64::EPSILON);
299 (-k * candidate.age_days).exp()
300 }
301
302 fn name(&self) -> &str {
303 "TemporalRecencyObjective"
304 }
305}
306
/// Scores a [`NoteCandidate`] by the entry in its `rerank_scores` map stored
/// under `reranker_name`, treating a missing entry as `0.0`.
#[derive(Debug, Clone)]
pub struct RerankerObjective {
    /// Key looked up in `NoteCandidate::rerank_scores`.
    pub reranker_name: String,
}
320
321impl RerankerObjective {
322 pub fn new(name: impl Into<String>) -> Self {
324 Self {
325 reranker_name: name.into(),
326 }
327 }
328}
329
330impl Objective<NoteCandidate> for RerankerObjective {
331 #[inline]
332 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
333 candidate
334 .rerank_scores
335 .get(&self.reranker_name)
336 .copied()
337 .unwrap_or(0.0)
338 }
339
340 fn name(&self) -> &str {
341 "RerankerObjective"
342 }
343}
344
/// Weighted composition of the memory-recall objectives, built by
/// `MemoryRecallPipeline::new` and evaluated by `MemoryRecallPipeline::score`.
pub struct MemoryRecallPipeline {
    // Composite of RRF relevance, amplified decayed salience, and temporal
    // recency, weighted as supplied to `new`.
    pipeline: khive_fold::WeightedObjective<NoteCandidate>,
}
356
357impl MemoryRecallPipeline {
358 pub fn new(
365 relevance_weight: f64,
366 salience_weight: f64,
367 temporal_weight: f64,
368 half_life_days: f64,
369 salience_alpha: f64,
370 ) -> Self {
371 use khive_fold::WeightedObjective;
372 let pipeline = WeightedObjective::<NoteCandidate>::new()
373 .add(Box::new(RrfFusionObjective), relevance_weight)
374 .add(
375 Box::new(AmplifiedDecayAwareSalienceObjective::new(salience_alpha)),
376 salience_weight,
377 )
378 .add(
379 Box::new(TemporalRecencyObjective { half_life_days }),
380 temporal_weight,
381 );
382 Self { pipeline }
383 }
384
385 pub fn default_memory() -> Self {
389 Self::new(0.70, 0.20, 0.10, 30.0, 1.5)
390 }
391
392 pub fn score(&self, candidate: &NoteCandidate) -> f64 {
397 let ctx = ObjectiveContext::new();
398 use khive_fold::objective::Objective;
399 self.pipeline.score(candidate, &ctx).clamp(0.0, 1.0)
400 }
401}
402
#[cfg(test)]
mod tests {
    // Unit tests for the retrieval and memory-recall objectives.
    // `use super::*` also brings in the parent's imports (`Objective`,
    // `ObjectiveContext`, `Uuid`, `HashMap`).
    use super::*;
    use khive_fold::WeightedObjective;

    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Builds a retrieval candidate with a fresh id and the given signals.
    fn candidate(
        vector: Option<f64>,
        text: Option<f64>,
        dist: Option<u32>,
        rrf: Option<f64>,
    ) -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: vector,
            text_score: text,
            graph_distance: dist,
            rrf_score: rrf,
        }
    }

    /// Builds a note candidate whose `effective_salience` is derived as
    /// `salience * exp(-decay_factor * age_days)`.
    fn note_candidate(
        rrf: Option<f64>,
        salience: f64,
        decay_factor: f64,
        age_days: f64,
    ) -> NoteCandidate {
        let effective_salience = salience * (-decay_factor * age_days).exp();
        NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: rrf,
            salience,
            decay_factor,
            age_days,
            effective_salience,
            rerank_scores: HashMap::new(),
        }
    }

    // --- VectorSimilarityObjective ---

    #[test]
    fn vector_present_returns_signal() {
        let c = candidate(Some(0.85), None, None, None);
        let score = VectorSimilarityObjective.score(&c, &ctx());
        assert!((score - 0.85).abs() < 1e-12);
    }

    #[test]
    fn vector_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn vector_zero_score_returns_zero() {
        let c = candidate(Some(0.0), None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // --- TextRelevanceObjective ---

    #[test]
    fn text_present_returns_signal() {
        let c = candidate(None, Some(0.6), None, None);
        let score = TextRelevanceObjective.score(&c, &ctx());
        assert!((score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn text_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(TextRelevanceObjective.score(&c, &ctx()), 0.0);
    }

    // --- GraphProximityObjective ---

    #[test]
    fn graph_anchor_hit_scores_one() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert!((obj.score(&c, &ctx()) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn graph_midpoint_scores_half() {
        let c = candidate(None, None, Some(1), None);
        let obj = GraphProximityObjective { max_distance: 2 };
        assert!((obj.score(&c, &ctx()) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn graph_at_boundary_scores_zero() {
        let c = candidate(None, None, Some(3), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_beyond_boundary_scores_zero() {
        let c = candidate(None, None, Some(10), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        let c = candidate(None, None, None, None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 0 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // --- RrfFusionObjective ---

    #[test]
    fn rrf_present_returns_signal() {
        let c = candidate(None, None, None, Some(0.0327));
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn rrf_note_candidate_present_returns_signal() {
        let c = note_candidate(Some(0.42), 0.5, 0.01, 0.0);
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.42).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn rrf_note_candidate_absent_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    // --- WeightedObjective composition over retrieval objectives ---

    #[test]
    fn weighted_composition_vector_and_text() {
        // 0.5 * 0.8 + 0.5 * 0.6 = 0.7
        let c = candidate(Some(0.8), Some(0.6), None, None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.7).abs() < 1e-12);
    }

    #[test]
    fn weighted_composition_with_graph() {
        // 0.4 * 1.0 + 0.3 * 0.0 + 0.3 * (1 - 1/4) = 0.625
        let c = candidate(Some(1.0), Some(0.0), Some(1), None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.625).abs() < 1e-12);
    }

    #[test]
    fn weighted_all_absent_returns_zero() {
        let c = candidate(None, None, None, None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // --- HasId / deterministic selection ---

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate {
            id,
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;

        let candidates = vec![
            candidate(Some(0.3), None, None, None),
            candidate(Some(0.9), None, None, None),
            candidate(Some(0.6), None, None, None),
        ];

        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());

        assert_eq!(top.len(), 2);
        assert!((top[0].score - 0.9).abs() < 1e-12);
        assert!((top[1].score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn note_candidate_has_id_returns_uuid() {
        let id = Uuid::new_v4();
        let c = NoteCandidate {
            id,
            rrf_score: None,
            salience: 0.5,
            decay_factor: 0.01,
            age_days: 0.0,
            effective_salience: 0.5,
            rerank_scores: HashMap::new(),
        };
        assert_eq!(c.id(), id);
    }

    // --- DecayAwareSalienceObjective ---

    #[test]
    fn decay_aware_zero_age_returns_full_salience() {
        let obj = DecayAwareSalienceObjective::new(0.01);
        let c = note_candidate(None, 0.8, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.8).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn decay_aware_uses_note_decay_factor_not_field() {
        // The objective's own decay_rate (0.99) must be ignored in favour of
        // the note's decay_factor (0.01).
        let obj = DecayAwareSalienceObjective::new(0.99);
        let c = note_candidate(None, 1.0, 0.01, 100.0);
        let score = obj.score(&c, &ctx());
        let expected = (-0.01_f64 * 100.0).exp();
        assert!(
            (score - expected).abs() < 1e-12,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn decay_aware_high_decay_reduces_score_faster() {
        let obj = DecayAwareSalienceObjective::new(0.0);
        let slow = note_candidate(None, 1.0, 0.001, 100.0);
        let fast = note_candidate(None, 1.0, 0.1, 100.0);
        let score_slow = obj.score(&slow, &ctx());
        let score_fast = obj.score(&fast, &ctx());
        assert!(
            score_slow > score_fast,
            "slow decay should score higher: {score_slow} vs {score_fast}"
        );
    }

    // --- AmplifiedDecayAwareSalienceObjective ---

    #[test]
    fn amplified_raises_effective_salience_to_alpha() {
        // effective_salience = 0.64 (no decay); 0.64^1.5 = 0.512
        let obj = AmplifiedDecayAwareSalienceObjective::new(1.5);
        let c = note_candidate(None, 0.64, 0.0, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.512).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn amplified_alpha_one_is_identity() {
        let obj = AmplifiedDecayAwareSalienceObjective::new(1.0);
        let c = note_candidate(None, 0.3, 0.02, 10.0);
        let score = obj.score(&c, &ctx());
        assert!(
            (score - c.effective_salience).abs() < 1e-12,
            "got {score}, expected {}",
            c.effective_salience
        );
    }

    // --- TemporalRecencyObjective ---

    #[test]
    fn temporal_score_one_at_zero_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 1.0).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn temporal_score_half_at_half_life() {
        let half_life = 30.0;
        let obj = TemporalRecencyObjective {
            half_life_days: half_life,
        };
        let c = note_candidate(None, 0.5, 0.01, half_life);
        let score = obj.score(&c, &ctx());
        assert!(
            (score - 0.5).abs() < 1e-10,
            "expected 0.5 at half_life, got {score}"
        );
    }

    #[test]
    fn temporal_score_decreases_with_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let young = note_candidate(None, 1.0, 0.01, 10.0);
        let old = note_candidate(None, 1.0, 0.01, 100.0);
        let score_young = obj.score(&young, &ctx());
        let score_old = obj.score(&old, &ctx());
        assert!(
            score_young > score_old,
            "younger note should score higher: {score_young} vs {score_old}"
        );
    }

    // --- RerankerObjective ---

    #[test]
    fn reranker_returns_named_score() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("cross_encoder".to_string(), 0.9);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert!((score - 0.9).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn reranker_absent_key_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert_eq!(score, 0.0);
    }

    #[test]
    fn reranker_different_keys_independent() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("salience".to_string(), 0.7);
        let obj_ce = RerankerObjective::new("cross_encoder");
        let obj_sal = RerankerObjective::new("salience");
        assert_eq!(obj_ce.score(&c, &ctx()), 0.0);
        assert!((obj_sal.score(&c, &ctx()) - 0.7).abs() < 1e-12);
    }

    // --- Memory recall composition ---

    #[test]
    fn memory_pipeline_weighted_composition() {
        // 0.70 * 0.5 + 0.20 * 0.8 + 0.10 * 1.0 = 0.61
        let c = NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: Some(0.5),
            salience: 0.8,
            decay_factor: 0.01,
            age_days: 0.0,
            effective_salience: 0.8,
            rerank_scores: HashMap::new(),
        };
        let pipeline = WeightedObjective::<NoteCandidate>::new()
            .add(Box::new(RrfFusionObjective), 0.70)
            .add(Box::new(DecayAwareSalienceObjective::new(0.0)), 0.20)
            .add(
                Box::new(TemporalRecencyObjective {
                    half_life_days: 30.0,
                }),
                0.10,
            );
        let score = pipeline.score(&c, &ctx());
        assert!((score - 0.61).abs() < 1e-10, "got {score}");
    }

    #[test]
    fn memory_recall_pipeline_default_matches_weighted_sum() {
        // Exercises the real MemoryRecallPipeline (amplified salience, alpha 1.5).
        let c = note_candidate(Some(0.5), 0.8, 0.0, 0.0);
        let score = MemoryRecallPipeline::default_memory().score(&c);
        let expected = 0.70 * 0.5 + 0.20 * 0.8_f64.powf(1.5) + 0.10 * 1.0;
        assert!(
            (score - expected).abs() < 1e-10,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn memory_recall_pipeline_score_stays_in_unit_interval() {
        // An out-of-range fused score must not push the pipeline above 1.0.
        let c = note_candidate(Some(5.0), 1.0, 0.0, 0.0);
        let score = MemoryRecallPipeline::default_memory().score(&c);
        assert!((0.0..=1.0).contains(&score), "got {score}");
    }
}