use std::collections::HashMap;
use uuid::Uuid;
use khive_fold::objective::{Objective, ObjectiveContext};
use khive_fold::ordering::HasId;
/// One hit from the retrieval stage together with the raw per-channel signals
/// that the retrieval objectives in this module read.
///
/// Every signal is optional; presumably a candidate is not produced by every
/// retrieval channel (TODO confirm against the producer). The objectives here
/// score a missing signal as `0.0`.
#[derive(Debug, Clone)]
pub struct RetrievalCandidate {
    /// Identifier of the candidate; also exposed through [`HasId`].
    pub id: Uuid,
    /// Vector-similarity signal, read by [`VectorSimilarityObjective`].
    pub vector_score: Option<f64>,
    /// Text-relevance signal, read by [`TextRelevanceObjective`].
    pub text_score: Option<f64>,
    /// Graph distance from the query anchor (presumably in hops), read by
    /// [`GraphProximityObjective`].
    pub graph_distance: Option<u32>,
    /// Precomputed fusion score (RRF, per the name), read by
    /// [`RrfFusionObjective`].
    pub rrf_score: Option<f64>,
}

impl HasId for RetrievalCandidate {
    /// Candidates are identified by their stored `id`.
    #[inline]
    fn id(&self) -> Uuid {
        self.id
    }
}
/// Scores a [`RetrievalCandidate`] by its vector-similarity signal.
///
/// The stored `vector_score` is returned unchanged; a candidate without one
/// scores `0.0`.
pub struct VectorSimilarityObjective;

impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // `f64::default()` is `0.0`, so an absent signal contributes nothing.
        candidate.vector_score.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "VectorSimilarityObjective"
    }
}
/// Scores a [`RetrievalCandidate`] by its text-relevance signal.
///
/// The stored `text_score` is returned unchanged; a candidate without one
/// scores `0.0`.
pub struct TextRelevanceObjective;

impl Objective<RetrievalCandidate> for TextRelevanceObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // Score assigned when the text channel produced nothing for this candidate.
        const ABSENT: f64 = 0.0;
        candidate.text_score.unwrap_or(ABSENT)
    }

    fn name(&self) -> &str {
        "TextRelevanceObjective"
    }
}
/// Scores a [`RetrievalCandidate`] by how close it lies to the query anchor
/// in the graph.
///
/// The score falls linearly with distance. It is `1.0` at distance `0` and
/// reaches `0.0` at `max_distance`, staying there beyond it. A candidate with
/// no `graph_distance` scores `0.0`, and so does every candidate when
/// `max_distance == 0`.
pub struct GraphProximityObjective {
    /// Distance at (and beyond) which proximity contributes nothing.
    pub max_distance: u32,
}

impl Objective<RetrievalCandidate> for GraphProximityObjective {
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        let horizon = self.max_distance;
        candidate.graph_distance.map_or(0.0, |hops| {
            // `hops < horizon` can never hold when `horizon == 0`, so this
            // single check also rules out a division by zero.
            if hops < horizon {
                1.0 - f64::from(hops) / f64::from(horizon)
            } else {
                0.0
            }
        })
    }

    fn name(&self) -> &str {
        "GraphProximityObjective"
    }
}
/// Scores a candidate by its precomputed fusion score (RRF, per the name).
///
/// Implemented for both [`RetrievalCandidate`] and [`NoteCandidate`]. In both
/// cases the stored `rrf_score` is returned as-is, and a missing score counts
/// as `0.0`.
pub struct RrfFusionObjective;

impl Objective<RetrievalCandidate> for RrfFusionObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // `f64::default()` is `0.0`.
        candidate.rrf_score.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "RrfFusionObjective"
    }
}

impl Objective<NoteCandidate> for RrfFusionObjective {
    #[inline]
    fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
        let fused = candidate.rrf_score;
        fused.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "RrfFusionObjective"
    }
}
/// A note surfaced during memory recall, carrying every signal the
/// note-level objectives read.
#[derive(Debug, Clone)]
pub struct NoteCandidate {
    /// Identifier of the note; also exposed through [`HasId`].
    pub id: Uuid,
    /// Fusion score from retrieval, if any; read by [`RrfFusionObjective`].
    pub rrf_score: Option<f64>,
    /// Base salience; read by [`DecayAwareSalienceObjective`].
    pub salience: f64,
    /// Per-note exponential decay rate, applied as
    /// `exp(-decay_factor * age_days)` by [`DecayAwareSalienceObjective`].
    pub decay_factor: f64,
    /// Age of the note in days; used by the decay and recency objectives.
    pub age_days: f64,
    /// Decayed salience supplied by the producer and read directly by
    /// [`AmplifiedDecayAwareSalienceObjective`]. It is not recomputed here.
    /// Presumably it is `salience * exp(-decay_factor * age_days)`
    /// (TODO confirm with the producer).
    pub effective_salience: f64,
    /// Scores from named rerankers, keyed by reranker name; read by
    /// [`RerankerObjective`].
    pub rerank_scores: HashMap<String, f64>,
}

impl HasId for NoteCandidate {
    /// Notes are identified by their stored `id`.
    #[inline]
    fn id(&self) -> Uuid {
        self.id
    }
}
/// Scores a [`NoteCandidate`] by salience decayed exponentially with age:
/// `salience * exp(-decay_factor * age_days)`.
///
/// The decay rate is taken from the candidate's own `decay_factor`, **not**
/// from this objective's `decay_rate` field. That field is stored but never
/// consulted when scoring.
pub struct DecayAwareSalienceObjective {
    /// Configured decay rate. NOTE(review): unused by `score`, which reads
    /// `NoteCandidate::decay_factor` instead.
    pub decay_rate: f64,
}

impl DecayAwareSalienceObjective {
    /// Creates the objective with the given (currently unused) decay rate.
    pub fn new(decay_rate: f64) -> Self {
        Self { decay_rate }
    }

    /// Preset used for memory recall (`decay_rate = 0.01`).
    pub fn default_memory() -> Self {
        Self { decay_rate: 0.01 }
    }
}

impl Objective<NoteCandidate> for DecayAwareSalienceObjective {
    #[inline]
    fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
        let attenuation = (-(candidate.decay_factor * candidate.age_days)).exp();
        candidate.salience * attenuation
    }

    fn name(&self) -> &str {
        "DecayAwareSalienceObjective"
    }
}
/// Scores a [`NoteCandidate`] by its precomputed `effective_salience` raised
/// to the power `alpha`.
///
/// For positive saliences, `alpha > 1` widens the ratio between high- and
/// low-salience notes, `alpha < 1` narrows it, and `alpha == 1` passes the
/// value through. As with [`f64::powf`], a negative `effective_salience`
/// combined with a non-integer `alpha` yields NaN.
pub struct AmplifiedDecayAwareSalienceObjective {
    /// Exponent applied to `effective_salience`.
    pub alpha: f64,
}

impl AmplifiedDecayAwareSalienceObjective {
    /// Creates the objective with exponent `alpha`.
    pub fn new(alpha: f64) -> Self {
        Self { alpha }
    }

    /// Preset used for memory recall (`alpha = 1.5`).
    pub fn default_memory() -> Self {
        Self { alpha: 1.5 }
    }
}

impl Objective<NoteCandidate> for AmplifiedDecayAwareSalienceObjective {
    #[inline]
    fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
        let exponent = self.alpha;
        f64::powf(candidate.effective_salience, exponent)
    }

    fn name(&self) -> &str {
        "AmplifiedDecayAwareSalienceObjective"
    }
}
/// Scores a [`NoteCandidate`] by recency using exponential half-life decay:
/// `exp(-ln(2) / half_life_days * age_days)`.
///
/// The score is `1.0` at age zero and `0.5` after one half-life. Negative ages
/// (for example a note timestamped in the future because of clock skew) are
/// treated as age zero, so the score never exceeds `1.0`.
pub struct TemporalRecencyObjective {
    /// Half-life in days. Values below `f64::EPSILON` (including zero and
    /// negatives) are floored to `f64::EPSILON` when scoring, which makes any
    /// positive age decay to effectively zero.
    pub half_life_days: f64,
}

impl TemporalRecencyObjective {
    /// Creates a recency objective with the given half-life in days.
    pub fn new(half_life_days: f64) -> Self {
        Self { half_life_days }
    }

    /// Preset used for memory recall: a 30-day half-life.
    pub fn default_memory() -> Self {
        Self::new(30.0)
    }
}

impl Objective<NoteCandidate> for TemporalRecencyObjective {
    #[inline]
    fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
        // Floor the half-life so a zero or negative configuration cannot divide by zero.
        let k = std::f64::consts::LN_2 / self.half_life_days.max(f64::EPSILON);
        // For age < 0, exp(-k * age) exceeds 1 and grows without bound (up to
        // infinity). One future-dated note would then swamp every other
        // component of a weighted blend. Clamp to zero, but use an explicit
        // comparison instead of `max` so that a NaN age still propagates as
        // NaN rather than silently becoming "brand new".
        let age = if candidate.age_days < 0.0 {
            0.0
        } else {
            candidate.age_days
        };
        (-k * age).exp()
    }

    fn name(&self) -> &str {
        "TemporalRecencyObjective"
    }
}
/// Scores a [`NoteCandidate`] by the output of one named reranker.
///
/// Looks up `reranker_name` in the candidate's `rerank_scores`. A candidate
/// that this reranker did not score gets `0.0`.
pub struct RerankerObjective {
    /// Key into `NoteCandidate::rerank_scores` that selects which reranker to read.
    pub reranker_name: String,
}

impl RerankerObjective {
    /// Creates an objective that reads the reranker stored under `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let reranker_name: String = name.into();
        Self { reranker_name }
    }
}

impl Objective<NoteCandidate> for RerankerObjective {
    #[inline]
    fn score(&self, candidate: &NoteCandidate, _context: &ObjectiveContext) -> f64 {
        let key: &str = &self.reranker_name;
        // `f64::default()` is `0.0` for a missing key.
        candidate.rerank_scores.get(key).copied().unwrap_or_default()
    }

    /// The same name is reported for every reranker key; it does not include
    /// `reranker_name`.
    fn name(&self) -> &str {
        "RerankerObjective"
    }
}
/// Memory-recall ranking: a weighted blend of fused relevance, amplified
/// salience and recency, clamped to `[0.0, 1.0]`.
///
/// The components are combined by `khive_fold::WeightedObjective` using the
/// weights passed to [`MemoryRecallPipeline::new`]:
/// - [`RrfFusionObjective`]: the note's `rrf_score`;
/// - [`AmplifiedDecayAwareSalienceObjective`]: `effective_salience.powf(salience_alpha)`;
/// - [`TemporalRecencyObjective`]: half-life decay over `age_days`.
pub struct MemoryRecallPipeline {
    pipeline: khive_fold::WeightedObjective<NoteCandidate>,
}

impl MemoryRecallPipeline {
    /// Builds a pipeline from explicit weights and component parameters.
    ///
    /// * `relevance_weight`: weight of the RRF relevance component.
    /// * `salience_weight`: weight of the amplified-salience component.
    /// * `temporal_weight`: weight of the recency component.
    /// * `half_life_days`: half-life, in days, of the recency component.
    /// * `salience_alpha`: exponent applied to `effective_salience`.
    pub fn new(
        relevance_weight: f64,
        salience_weight: f64,
        temporal_weight: f64,
        half_life_days: f64,
        salience_alpha: f64,
    ) -> Self {
        use khive_fold::WeightedObjective;
        let pipeline = WeightedObjective::<NoteCandidate>::new()
            .add(Box::new(RrfFusionObjective), relevance_weight)
            .add(
                Box::new(AmplifiedDecayAwareSalienceObjective::new(salience_alpha)),
                salience_weight,
            )
            .add(
                Box::new(TemporalRecencyObjective { half_life_days }),
                temporal_weight,
            );
        Self { pipeline }
    }

    /// Default memory-recall configuration. The relevance, salience and
    /// recency weights are 0.70, 0.20 and 0.10; the half-life is 30 days and
    /// `alpha = 1.5`.
    pub fn default_memory() -> Self {
        Self::new(0.70, 0.20, 0.10, 30.0, 1.5)
    }

    /// Scores `candidate`, always returning a value in `[0.0, 1.0]`.
    ///
    /// The combined score is clamped to the unit interval. A NaN result is
    /// mapped to `0.0`, because `f64::clamp` would otherwise pass it through
    /// unchanged. NaN can arise, for example, from `powf` on a negative
    /// effective salience. The mapping keeps the scores totally ordered for
    /// callers that sort by them.
    pub fn score(&self, candidate: &NoteCandidate) -> f64 {
        use khive_fold::objective::Objective;
        let ctx = ObjectiveContext::new();
        let raw = self.pipeline.score(candidate, &ctx);
        if raw.is_nan() {
            0.0
        } else {
            raw.clamp(0.0, 1.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::{Objective, ObjectiveContext};
    use khive_fold::WeightedObjective;
    use uuid::Uuid;

    /// Fresh scoring context; none of the objectives under test read it.
    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Builds a retrieval candidate with a random id and the given signals.
    fn candidate(
        vector: Option<f64>,
        text: Option<f64>,
        dist: Option<u32>,
        rrf: Option<f64>,
    ) -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: vector,
            text_score: text,
            graph_distance: dist,
            rrf_score: rrf,
        }
    }

    /// Builds a note candidate whose `effective_salience` is
    /// `salience * exp(-decay_factor * age_days)`.
    fn note_candidate(
        rrf: Option<f64>,
        salience: f64,
        decay_factor: f64,
        age_days: f64,
    ) -> NoteCandidate {
        let effective_salience = salience * (-decay_factor * age_days).exp();
        NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: rrf,
            salience,
            decay_factor,
            age_days,
            effective_salience,
            rerank_scores: HashMap::new(),
        }
    }

    #[test]
    fn vector_present_returns_signal() {
        let c = candidate(Some(0.85), None, None, None);
        let score = VectorSimilarityObjective.score(&c, &ctx());
        assert!((score - 0.85).abs() < 1e-12);
    }

    #[test]
    fn vector_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn vector_zero_score_returns_zero() {
        let c = candidate(Some(0.0), None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn text_present_returns_signal() {
        let c = candidate(None, Some(0.6), None, None);
        let score = TextRelevanceObjective.score(&c, &ctx());
        assert!((score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn text_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(TextRelevanceObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_anchor_hit_scores_one() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert!((obj.score(&c, &ctx()) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn graph_midpoint_scores_half() {
        let c = candidate(None, None, Some(1), None);
        let obj = GraphProximityObjective { max_distance: 2 };
        assert!((obj.score(&c, &ctx()) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn graph_at_boundary_scores_zero() {
        let c = candidate(None, None, Some(3), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_beyond_boundary_scores_zero() {
        let c = candidate(None, None, Some(10), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        let c = candidate(None, None, None, None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 0 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn rrf_present_returns_signal() {
        let c = candidate(None, None, None, Some(0.0327));
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn rrf_note_present_returns_signal() {
        let c = note_candidate(Some(0.25), 0.5, 0.01, 0.0);
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.25).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn rrf_note_absent_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn weighted_composition_vector_and_text() {
        let c = candidate(Some(0.8), Some(0.6), None, None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.7).abs() < 1e-12);
    }

    #[test]
    fn weighted_composition_with_graph() {
        let c = candidate(Some(1.0), Some(0.0), Some(1), None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);
        let score = obj.score(&c, &ctx());
        // 0.4 * 1.0 + 0.3 * 0.0 + 0.3 * (1 - 1/4) = 0.625
        assert!((score - 0.625).abs() < 1e-12);
    }

    #[test]
    fn weighted_all_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate {
            id,
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;
        let candidates = vec![
            candidate(Some(0.3), None, None, None),
            candidate(Some(0.9), None, None, None),
            candidate(Some(0.6), None, None, None),
        ];
        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());
        assert_eq!(top.len(), 2);
        assert!((top[0].score - 0.9).abs() < 1e-12);
        assert!((top[1].score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn note_candidate_has_id_returns_uuid() {
        let id = Uuid::new_v4();
        let c = NoteCandidate {
            id,
            rrf_score: None,
            salience: 0.5,
            decay_factor: 0.01,
            age_days: 0.0,
            effective_salience: 0.5,
            rerank_scores: HashMap::new(),
        };
        assert_eq!(c.id(), id);
    }

    #[test]
    fn decay_aware_zero_age_returns_full_salience() {
        let obj = DecayAwareSalienceObjective::new(0.01);
        let c = note_candidate(None, 0.8, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.8).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn decay_aware_uses_note_decay_factor_not_field() {
        // The objective's own `decay_rate` (0.99) must be ignored in favour of
        // the note's `decay_factor` (0.01).
        let obj = DecayAwareSalienceObjective::new(0.99);
        let c = note_candidate(None, 1.0, 0.01, 100.0);
        let score = obj.score(&c, &ctx());
        let expected = (-0.01_f64 * 100.0).exp();
        assert!(
            (score - expected).abs() < 1e-12,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn decay_aware_high_decay_reduces_score_faster() {
        let obj = DecayAwareSalienceObjective::new(0.0);
        let slow = note_candidate(None, 1.0, 0.001, 100.0);
        let fast = note_candidate(None, 1.0, 0.1, 100.0);
        let score_slow = obj.score(&slow, &ctx());
        let score_fast = obj.score(&fast, &ctx());
        assert!(
            score_slow > score_fast,
            "slow decay should score higher: {score_slow} vs {score_fast}"
        );
    }

    #[test]
    fn amplified_alpha_one_is_identity() {
        let obj = AmplifiedDecayAwareSalienceObjective::new(1.0);
        let c = note_candidate(None, 0.64, 0.0, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.64).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn amplified_raises_effective_salience_to_alpha() {
        let obj = AmplifiedDecayAwareSalienceObjective::new(1.5);
        let c = note_candidate(None, 0.64, 0.0, 0.0);
        let score = obj.score(&c, &ctx());
        // 0.64^1.5 = 0.8^3 = 0.512
        assert!((score - 0.512).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn amplified_reads_effective_not_raw_salience() {
        let mut c = note_candidate(None, 0.9, 0.0, 0.0);
        c.effective_salience = 0.25;
        let obj = AmplifiedDecayAwareSalienceObjective::new(2.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.0625).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn default_memory_presets() {
        assert_eq!(DecayAwareSalienceObjective::default_memory().decay_rate, 0.01);
        assert_eq!(AmplifiedDecayAwareSalienceObjective::default_memory().alpha, 1.5);
        assert_eq!(TemporalRecencyObjective::default_memory().half_life_days, 30.0);
    }

    #[test]
    fn temporal_score_one_at_zero_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let score = obj.score(&c, &ctx());
        assert!((score - 1.0).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn temporal_score_half_at_half_life() {
        let half_life = 30.0;
        let obj = TemporalRecencyObjective {
            half_life_days: half_life,
        };
        let c = note_candidate(None, 0.5, 0.01, half_life);
        let score = obj.score(&c, &ctx());
        assert!(
            (score - 0.5).abs() < 1e-10,
            "expected 0.5 at half_life, got {score}"
        );
    }

    #[test]
    fn temporal_score_decreases_with_age() {
        let obj = TemporalRecencyObjective {
            half_life_days: 30.0,
        };
        let young = note_candidate(None, 1.0, 0.01, 10.0);
        let old = note_candidate(None, 1.0, 0.01, 100.0);
        let score_young = obj.score(&young, &ctx());
        let score_old = obj.score(&old, &ctx());
        assert!(
            score_young > score_old,
            "younger note should score higher: {score_young} vs {score_old}"
        );
    }

    #[test]
    fn reranker_returns_named_score() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("cross_encoder".to_string(), 0.9);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert!((score - 0.9).abs() < 1e-12, "got {score}");
    }

    #[test]
    fn reranker_absent_key_returns_zero() {
        let c = note_candidate(None, 0.5, 0.01, 0.0);
        let obj = RerankerObjective::new("cross_encoder");
        let score = obj.score(&c, &ctx());
        assert_eq!(score, 0.0);
    }

    #[test]
    fn reranker_different_keys_independent() {
        let mut c = note_candidate(None, 0.5, 0.01, 0.0);
        c.rerank_scores.insert("salience".to_string(), 0.7);
        let obj_ce = RerankerObjective::new("cross_encoder");
        let obj_sal = RerankerObjective::new("salience");
        assert_eq!(obj_ce.score(&c, &ctx()), 0.0);
        assert!((obj_sal.score(&c, &ctx()) - 0.7).abs() < 1e-12);
    }

    #[test]
    fn memory_pipeline_weighted_composition() {
        let c = NoteCandidate {
            id: Uuid::new_v4(),
            rrf_score: Some(0.5),
            salience: 0.8,
            decay_factor: 0.01,
            age_days: 0.0,
            // Zero age, so the decayed salience equals the raw salience.
            effective_salience: 0.8,
            rerank_scores: HashMap::new(),
        };
        let pipeline = WeightedObjective::<NoteCandidate>::new()
            .add(Box::new(RrfFusionObjective), 0.70)
            .add(Box::new(DecayAwareSalienceObjective::new(0.0)), 0.20)
            .add(
                Box::new(TemporalRecencyObjective {
                    half_life_days: 30.0,
                }),
                0.10,
            );
        let score = pipeline.score(&c, &ctx());
        // 0.70 * 0.5 + 0.20 * 0.8 + 0.10 * 1.0 = 0.61
        assert!((score - 0.61).abs() < 1e-10, "got {score}");
    }

    #[test]
    fn memory_recall_pipeline_default_weights() {
        let c = note_candidate(Some(0.5), 0.8, 0.01, 0.0);
        let expected = 0.70 * 0.5 + 0.20 * 0.8_f64.powf(1.5) + 0.10;
        let score = MemoryRecallPipeline::default_memory().score(&c);
        assert!(
            (score - expected).abs() < 1e-10,
            "got {score}, expected {expected}"
        );
    }

    #[test]
    fn memory_recall_pipeline_clamps_to_unit_interval() {
        let pipeline = MemoryRecallPipeline::default_memory();
        // 0.70 * 10.0 + 0.20 * 1.0 + 0.10 * 1.0 = 7.3, clamped down to 1.0.
        let high = note_candidate(Some(10.0), 1.0, 0.0, 0.0);
        assert_eq!(pipeline.score(&high), 1.0);
        // 0.70 * -10.0 + 0.20 * 0.0 + 0.10 * 1.0 = -6.9, clamped up to 0.0.
        let low = note_candidate(Some(-10.0), 0.0, 0.0, 0.0);
        assert_eq!(pipeline.score(&low), 0.0);
    }
}