1use std::collections::HashMap;
12
13use uuid::Uuid;
14
15use khive_fold::objective::{Objective, ObjectiveContext};
16use khive_fold::ordering::HasId;
17
18#[derive(Debug, Clone)]
24pub struct RetrievalCandidate {
25 pub id: Uuid,
27 pub vector_score: Option<f64>,
29 pub text_score: Option<f64>,
31 pub graph_distance: Option<u32>,
33 pub rrf_score: Option<f64>,
35}
36
37impl HasId for RetrievalCandidate {
38 #[inline]
39 fn id(&self) -> Uuid {
40 self.id
41 }
42}
43
/// Scores a [`RetrievalCandidate`] by its raw `vector_score`; a missing score
/// counts as `0.0`.
///
/// Stateless, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorSimilarityObjective;
50
51impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
52 #[inline]
53 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
54 candidate.vector_score.unwrap_or(0.0)
55 }
56
57 fn name(&self) -> &str {
58 "VectorSimilarityObjective"
59 }
60}
61
/// Scores a [`RetrievalCandidate`] by its raw `text_score`; a missing score
/// counts as `0.0`.
///
/// Stateless, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TextRelevanceObjective;
68
69impl Objective<RetrievalCandidate> for TextRelevanceObjective {
70 #[inline]
71 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
72 candidate.text_score.unwrap_or(0.0)
73 }
74
75 fn name(&self) -> &str {
76 "TextRelevanceObjective"
77 }
78}
79
/// Scores a [`RetrievalCandidate`] by graph proximity.
///
/// The score falls linearly from `1.0` at distance `0` towards `0.0` at
/// `max_distance`. Candidates at or beyond `max_distance`, and candidates with
/// no graph distance, score `0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GraphProximityObjective {
    /// Distance at which proximity reaches zero. A value of `0` makes every
    /// candidate score `0.0`.
    pub max_distance: u32,
}
98
99impl Objective<RetrievalCandidate> for GraphProximityObjective {
100 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
101 let d = match candidate.graph_distance {
102 Some(d) => d,
103 None => return 0.0,
104 };
105 if self.max_distance == 0 || d >= self.max_distance {
106 return 0.0;
107 }
108 1.0 - (d as f64 / self.max_distance as f64)
109 }
110
111 fn name(&self) -> &str {
112 "GraphProximityObjective"
113 }
114}
115
/// Scores a candidate by its precomputed RRF (reciprocal rank fusion) score;
/// a missing score counts as `0.0`.
///
/// Implemented for both [`RetrievalCandidate`] and [`NoteCandidate`].
/// Stateless, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RrfFusionObjective;
125
126impl Objective<RetrievalCandidate> for RrfFusionObjective {
127 #[inline]
128 fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
129 candidate.rrf_score.unwrap_or(0.0)
130 }
131
132 fn name(&self) -> &str {
133 "RrfFusionObjective"
134 }
135}
136
137impl Objective<NoteCandidate> for RrfFusionObjective {
138 #[inline]
139 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
140 candidate.rrf_score.unwrap_or(0.0)
141 }
142
143 fn name(&self) -> &str {
144 "RrfFusionObjective"
145 }
146}
147
148#[derive(Debug, Clone)]
159pub struct NoteCandidate {
160 pub id: Uuid,
162 pub rrf_score: Option<f64>,
164 pub salience: f64,
166 pub decay_factor: f64,
168 pub age_days: f64,
170 pub rerank_scores: HashMap<String, f64>,
173}
174
175impl HasId for NoteCandidate {
176 #[inline]
177 fn id(&self) -> Uuid {
178 self.id
179 }
180}
181
/// Scores a [`NoteCandidate`] by its salience, discounted exponentially with
/// the note's age using the note's own `decay_factor`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DecayAwareImportanceObjective {
    /// Configured decay rate.
    ///
    /// NOTE(review): the scoring in this module uses each note's
    /// `decay_factor`, not this field. Confirm whether anything else reads it
    /// before relying on it.
    pub decay_rate: f64,
}
198
199impl DecayAwareImportanceObjective {
200 pub fn new(decay_rate: f64) -> Self {
204 Self { decay_rate }
205 }
206
207 pub fn default_memory() -> Self {
209 Self::new(0.01)
210 }
211}
212
213impl Objective<NoteCandidate> for DecayAwareImportanceObjective {
214 #[inline]
215 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
216 candidate.salience * (-candidate.decay_factor * candidate.age_days).exp()
219 }
220
221 fn name(&self) -> &str {
222 "DecayAwareImportanceObjective"
223 }
224}
225
/// Scores a [`NoteCandidate`] by recency.
///
/// The score decays exponentially with `age_days`: a brand-new note scores
/// `1.0` and a note exactly one half-life old scores `0.5`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TemporalRecencyObjective {
    /// Age in days at which the recency score halves. Values below
    /// `f64::EPSILON` (including zero and negatives) are floored to it when
    /// scoring.
    pub half_life_days: f64,
}
241
242impl TemporalRecencyObjective {
243 pub fn default_memory() -> Self {
245 Self {
246 half_life_days: 30.0,
247 }
248 }
249}
250
251impl Objective<NoteCandidate> for TemporalRecencyObjective {
252 #[inline]
253 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
254 let k = std::f64::consts::LN_2 / self.half_life_days.max(f64::EPSILON);
255 (-k * candidate.age_days).exp()
256 }
257
258 fn name(&self) -> &str {
259 "TemporalRecencyObjective"
260 }
261}
262
/// Scores a [`NoteCandidate`] by one named entry in its `rerank_scores` map;
/// a missing entry counts as `0.0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RerankerObjective {
    /// Key looked up in `NoteCandidate::rerank_scores`.
    pub reranker_name: String,
}
277
278impl RerankerObjective {
279 pub fn new(name: impl Into<String>) -> Self {
281 Self {
282 reranker_name: name.into(),
283 }
284 }
285}
286
287impl Objective<NoteCandidate> for RerankerObjective {
288 #[inline]
289 fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
290 candidate
291 .rerank_scores
292 .get(&self.reranker_name)
293 .copied()
294 .unwrap_or(0.0)
295 }
296
297 fn name(&self) -> &str {
298 "RerankerObjective"
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::{Objective, ObjectiveContext};
    use khive_fold::WeightedObjective;
    use uuid::Uuid;

    /// A fresh objective context. None of the objectives in this module read it.
    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Builds a `RetrievalCandidate` with a random id and the given signals.
    fn candidate(
        vector: Option<f64>,
        text: Option<f64>,
        dist: Option<u32>,
        rrf: Option<f64>,
    ) -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: vector,
            text_score: text,
            graph_distance: dist,
            rrf_score: rrf,
        }
    }

    /// Builds a `NoteCandidate` with a random id and no rerank scores.
    fn note_candidate(
        rrf: Option<f64>,
        salience: f64,
        decay_factor: f64,
        age_days: f64,
    ) -> NoteCandidate {
        NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: rrf,
            salience,
            decay_factor,
            age_days,
            rerank_scores: HashMap::new(),
        }
    }

    // ---- VectorSimilarityObjective ----

    #[test]
    fn vector_present_returns_signal() {
        let c = candidate(Some(0.85), None, None, None);
        let score = VectorSimilarityObjective.score(&c, &ctx());
        assert!((score - 0.85).abs() < 1e-12);
    }

    #[test]
    fn vector_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn vector_zero_score_returns_zero() {
        let c = candidate(Some(0.0), None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // ---- TextRelevanceObjective ----

    #[test]
    fn text_present_returns_signal() {
        let c = candidate(None, Some(0.6), None, None);
        let score = TextRelevanceObjective.score(&c, &ctx());
        assert!((score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn text_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(TextRelevanceObjective.score(&c, &ctx()), 0.0);
    }

    // ---- GraphProximityObjective ----

    #[test]
    fn graph_anchor_hit_scores_one() {
        let c = candidate(None, None, Some(0), None);
        // Distance 0: 1 - 0/3 = 1.
        let obj = GraphProximityObjective { max_distance: 3 };
        assert!((obj.score(&c, &ctx()) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn graph_midpoint_scores_half() {
        let c = candidate(None, None, Some(1), None);
        // 1 - 1/2 = 0.5.
        let obj = GraphProximityObjective { max_distance: 2 };
        assert!((obj.score(&c, &ctx()) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn graph_one_of_four_scores_three_quarters() {
        let c = candidate(None, None, Some(1), None);
        // 1 - 1/4 = 0.75.
        let obj = GraphProximityObjective { max_distance: 4 };
        assert!((obj.score(&c, &ctx()) - 0.75).abs() < 1e-12);
    }

    #[test]
    fn graph_at_boundary_scores_zero() {
        let c = candidate(None, None, Some(3), None);
        // Distance equal to max_distance is already out of range.
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_beyond_boundary_scores_zero() {
        let c = candidate(None, None, Some(10), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        let c = candidate(None, None, None, None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        let c = candidate(None, None, Some(0), None);
        // Guards the division by max_distance.
        let obj = GraphProximityObjective { max_distance: 0 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // ---- RrfFusionObjective ----

    #[test]
    fn rrf_present_returns_signal() {
        let c = candidate(None, None, None, Some(0.0327));
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn rrf_note_present_returns_signal() {
        let c = note_candidate(Some(0.0327), 0.5, 0.01, 0.0);
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_note_absent_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    // ---- WeightedObjective composition ----

    #[test]
    fn weighted_composition_vector_and_text() {
        let c = candidate(Some(0.8), Some(0.6), None, None);
        // 0.5 * 0.8 + 0.5 * 0.6 = 0.7

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.7).abs() < 1e-12);
    }

    #[test]
    fn weighted_composition_with_graph() {
        let c = candidate(Some(1.0), Some(0.0), Some(1), None);
        // 0.4 * 1.0 + 0.3 * 0.0 + 0.3 * (1 - 1/4) = 0.4 + 0.225 = 0.625

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);

        let score = obj.score(&c, &ctx());
        assert!((score - 0.625).abs() < 1e-12);
    }

    #[test]
    fn weighted_all_absent_returns_zero() {
        let c = candidate(None, None, None, None);

        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // ---- HasId / selection ----

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate {
            id,
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;

        let candidates = vec![
            candidate(Some(0.3), None, None, None),
            candidate(Some(0.9), None, None, None),
            candidate(Some(0.6), None, None, None),
        ];

        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());

        assert_eq!(top.len(), 2);
        assert!((top[0].score - 0.9).abs() < 1e-12);
        assert!((top[1].score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn note_candidate_has_id_returns_uuid() {
        let id = Uuid::new_v4();
        let c = NoteCandidate {
            id,
            rrf_score: None,
            salience: 0.5,
            decay_factor: 0.01,
            age_days: 0.0,
            rerank_scores: HashMap::new(),
        };
        assert_eq!(c.id(), id);
    }

    // ---- DecayAwareImportanceObjective ----

    #[test]
    fn decay_aware_default_memory_rate() {
        assert_eq!(DecayAwareImportanceObjective::default_memory().decay_rate, 0.01);
    }

    #[test]
    fn decay_aware_zero_age_returns_full_salience() {
        let obj = DecayAwareImportanceObjective::new(0.01);
        let c = note_candidate(None, 0.8, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.8).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn decay_aware_uses_note_decay_factor_not_field() {
        // The objective's rate (0.99) is deliberately far from the note's
        // decay_factor (0.01): the expected value only matches when the
        // note's own factor is the one applied.
        let obj = DecayAwareImportanceObjective::new(0.99);
        let c = note_candidate(None, 1.0, 0.01, 100.0);
        let score = obj.score(&c, &ctx());
        let expected = (-0.01_f64 * 100.0).exp();
        assert!(
            (score - expected).abs() < 1e-12,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn decay_aware_high_decay_reduces_score_faster() {
        let obj = DecayAwareImportanceObjective::new(0.0);
        // Only the per-note decay_factor differs between the two notes.
        let slow = note_candidate(None, 1.0, 0.001, 100.0);
        let fast = note_candidate(None, 1.0, 0.1, 100.0);
        let score_slow = obj.score(&slow, &ctx());
        let score_fast = obj.score(&fast, &ctx());
        assert!(
            score_slow > score_fast,
            "slow decay should score higher: {score_slow} vs {score_fast}"
        );
    }

    // ---- TemporalRecencyObjective ----

    #[test]
    fn temporal_default_memory_half_life_is_thirty_days() {
        assert_eq!(TemporalRecencyObjective::default_memory().half_life_days, 30.0);
    }

    #[test]
    fn temporal_score_one_at_zero_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 1.0).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn temporal_score_half_at_half_life() {
        let half_life = 30.0;
        let obj = TemporalRecencyObjective {
            half_life_days: half_life,
        };
        let c = note_candidate(None, 0.5, 0.01, half_life);
        let score = obj.score(&c, &ctx());
        assert!(
            (score - 0.5).abs() < 1e-10,
            "expected 0.5 at half_life, got {score}"
        );
    }

    #[test]
    fn temporal_score_decreases_with_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let young = note_candidate(None, 1.0, 0.01, 10.0);
        let old = note_candidate(None, 1.0, 0.01, 100.0);
        let score_young = obj.score(&young, &ctx());
        let score_old = obj.score(&old, &ctx());
        assert!(
            score_young > score_old,
            "younger note should score higher: {score_young} vs {score_old}"
        );
    }

    #[test]
    fn temporal_zero_half_life_is_guarded() {
        // A zero half-life is floored to f64::EPSILON rather than dividing by
        // zero: age 0 still scores 1, any positive age collapses to 0.
        let obj = TemporalRecencyObjective { half_life_days: 0.0 };
        let fresh = note_candidate(None, 1.0, 0.01, 0.0);
        let aged = note_candidate(None, 1.0, 0.01, 1.0);
        assert_eq!(obj.score(&fresh, &ctx()), 1.0);
        assert_eq!(obj.score(&aged, &ctx()), 0.0);
    }

    // ---- RerankerObjective ----

    #[test]
    fn reranker_returns_named_score() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("cross_encoder".to_string(), 0.9);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert!((score - 0.9).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn reranker_absent_key_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert_eq!(score, 0.0);
    }

    #[test]
    fn reranker_different_keys_independent() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("salience".to_string(), 0.7);
        let obj_ce = RerankerObjective::new("cross_encoder");
        let obj_sal = RerankerObjective::new("salience");
        assert_eq!(obj_ce.score(&c, &ctx()), 0.0);
        assert!((obj_sal.score(&c, &ctx()) - 0.7).abs() < 1e-12);
    }

    #[test]
    fn reranker_name_is_fixed() {
        // The objective name does not include the reranker key.
        let obj = RerankerObjective::new("cross_encoder");
        assert_eq!(obj.name(), "RerankerObjective");
    }

    // ---- Memory pipeline ----

    #[test]
    fn memory_pipeline_weighted_composition() {
        // 0.70 * 0.5 (rrf) + 0.20 * 0.8 (salience at age 0)
        //   + 0.10 * 1.0 (recency at age 0) = 0.35 + 0.16 + 0.10 = 0.61
        let c = NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: Some(0.5),
            salience: 0.8,
            decay_factor: 0.01,
            age_days: 0.0,
            rerank_scores: HashMap::new(),
        };
        let pipeline = WeightedObjective::<NoteCandidate>::new()
            .add(Box::new(RrfFusionObjective), 0.70)
            .add(Box::new(DecayAwareImportanceObjective::new(0.0)), 0.20)
            .add(
                Box::new(TemporalRecencyObjective {
                    half_life_days: 30.0,
                }),
                0.10,
            );
        let score = pipeline.score(&c, &ctx());
        assert!((score - 0.61).abs() < 1e-10, "got {score}");
    }
}