use uuid::Uuid;
use khive_fold::objective::{Objective, ObjectiveContext};
use khive_fold::ordering::HasId;
/// A single retrieval result together with the per-signal evidence gathered for it.
///
/// Every signal is optional. The objectives in this module each read exactly one
/// of these fields and score a `None` as `0.0`, so a missing signal never
/// contributes to a candidate's score.
#[derive(Debug, Clone, PartialEq)]
pub struct RetrievalCandidate {
    /// Identifier of the retrieved item; also exposed through [`HasId`].
    pub id: Uuid,
    /// Vector-similarity signal, passed through unchanged by
    /// [`VectorSimilarityObjective`]. NOTE(review): presumably "higher is more
    /// similar" — confirm against the producer of this value.
    pub vector_score: Option<f64>,
    /// Text-relevance signal, passed through unchanged by [`TextRelevanceObjective`].
    pub text_score: Option<f64>,
    /// Graph distance (in hops) used by [`GraphProximityObjective`]; `0` yields
    /// the maximum proximity score there.
    pub graph_distance: Option<u32>,
    /// Fused ranking signal, passed through unchanged by [`RrfFusionObjective`].
    /// NOTE(review): presumably a reciprocal-rank-fusion score computed upstream — TODO confirm.
    pub rrf_score: Option<f64>,
}
impl HasId for RetrievalCandidate {
#[inline]
fn id(&self) -> Uuid {
self.id
}
}
/// Objective that scores a candidate by its vector-similarity signal.
///
/// The candidate's `vector_score` is returned as-is (no clamping or
/// normalization); a candidate without that signal scores `0.0`.
pub struct VectorSimilarityObjective;

impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // `f64::default()` is `0.0`, so a missing signal contributes nothing.
        candidate.vector_score.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "VectorSimilarityObjective"
    }
}
/// Objective that scores a candidate by its text-relevance signal.
///
/// The candidate's `text_score` is returned as-is (no clamping or
/// normalization); a candidate without that signal scores `0.0`.
pub struct TextRelevanceObjective;

impl Objective<RetrievalCandidate> for TextRelevanceObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // Absent relevance falls back to `f64::default()`, i.e. `0.0`.
        candidate.text_score.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "TextRelevanceObjective"
    }
}
/// Objective that rewards candidates that are few graph hops away.
///
/// For a recorded distance `d`, the score decays linearly:
/// `1.0 - d / max_distance` while `d < max_distance`, and `0.0` once
/// `d >= max_distance`. A candidate with no recorded distance scores `0.0`.
/// Setting `max_distance` to `0` makes every candidate score `0.0`.
pub struct GraphProximityObjective {
    /// Distance at (or beyond) which proximity contributes nothing.
    pub max_distance: u32,
}

impl Objective<RetrievalCandidate> for GraphProximityObjective {
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        let limit = self.max_distance;
        match candidate.graph_distance {
            // For unsigned values, `hops < limit` already forces `limit >= 1`. That one
            // guard therefore rules out `limit == 0`, which means no division by zero,
            // as well as any distance at or past the cutoff.
            Some(hops) if hops < limit => 1.0 - f64::from(hops) / f64::from(limit),
            _ => 0.0,
        }
    }

    fn name(&self) -> &str {
        "GraphProximityObjective"
    }
}
/// Objective that scores a candidate by its precomputed fused-rank signal.
///
/// The candidate's `rrf_score` is returned as-is; a candidate without that
/// signal scores `0.0`. This objective does not compute any fusion itself.
pub struct RrfFusionObjective;

impl Objective<RetrievalCandidate> for RrfFusionObjective {
    #[inline]
    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
        // A missing fused score defaults to `0.0` via `f64::default()`.
        candidate.rrf_score.unwrap_or_default()
    }

    fn name(&self) -> &str {
        "RrfFusionObjective"
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the retrieval objectives. They cover:
    // - each single-signal objective in isolation;
    // - composition through `WeightedObjective`;
    // - the `HasId` impl;
    // - top-k selection via `DeterministicObjective`.
    use super::*;
    use khive_fold::objective::{Objective, ObjectiveContext};
    use khive_fold::WeightedObjective;
    use uuid::Uuid;

    /// Builds a context with `ObjectiveContext::new()`. The objectives defined in
    /// this module ignore their context argument.
    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Builds a candidate with a random id and the given signals, in field
    /// order: vector score, text score, graph distance, RRF score.
    fn candidate(
        vector: Option<f64>,
        text: Option<f64>,
        dist: Option<u32>,
        rrf: Option<f64>,
    ) -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: vector,
            text_score: text,
            graph_distance: dist,
            rrf_score: rrf,
        }
    }

    // --- VectorSimilarityObjective: passes `vector_score` through, `None` -> 0.0 ---

    #[test]
    fn vector_present_returns_signal() {
        let c = candidate(Some(0.85), None, None, None);
        let score = VectorSimilarityObjective.score(&c, &ctx());
        assert!((score - 0.85).abs() < 1e-12);
    }

    #[test]
    fn vector_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // An explicit zero must be indistinguishable from "absent" in the score.
    #[test]
    fn vector_zero_score_returns_zero() {
        let c = candidate(Some(0.0), None, None, None);
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // --- TextRelevanceObjective: passes `text_score` through, `None` -> 0.0 ---

    #[test]
    fn text_present_returns_signal() {
        let c = candidate(None, Some(0.6), None, None);
        let score = TextRelevanceObjective.score(&c, &ctx());
        assert!((score - 0.6).abs() < 1e-12);
    }

    #[test]
    fn text_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(TextRelevanceObjective.score(&c, &ctx()), 0.0);
    }

    // --- GraphProximityObjective: linear decay 1 - d/max, zero at/after max ---

    // Distance 0 (the anchor itself) gets the maximum score of 1.0.
    #[test]
    fn graph_anchor_hit_scores_one() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert!((obj.score(&c, &ctx()) - 1.0).abs() < 1e-12);
    }

    // 1 - 1/2 = 0.5.
    #[test]
    fn graph_midpoint_scores_half() {
        let c = candidate(None, None, Some(1), None);
        let obj = GraphProximityObjective { max_distance: 2 };
        assert!((obj.score(&c, &ctx()) - 0.5).abs() < 1e-12);
    }

    // The cutoff is exclusive: d == max_distance already scores 0.0.
    #[test]
    fn graph_at_boundary_scores_zero() {
        let c = candidate(None, None, Some(3), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // Distances past the cutoff are clamped to 0.0 rather than going negative.
    #[test]
    fn graph_beyond_boundary_scores_zero() {
        let c = candidate(None, None, Some(10), None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        let c = candidate(None, None, None, None);
        let obj = GraphProximityObjective { max_distance: 3 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // max_distance == 0 disables the objective; even an anchor hit scores 0.0
    // (and no division by zero occurs).
    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        let c = candidate(None, None, Some(0), None);
        let obj = GraphProximityObjective { max_distance: 0 };
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // --- RrfFusionObjective: passes `rrf_score` through, `None` -> 0.0 ---

    #[test]
    fn rrf_present_returns_signal() {
        let c = candidate(None, None, None, Some(0.0327));
        let score = RrfFusionObjective.score(&c, &ctx());
        assert!((score - 0.0327).abs() < 1e-12);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        assert_eq!(RrfFusionObjective.score(&c, &ctx()), 0.0);
    }

    // --- Composition through WeightedObjective ---

    // Expected: 0.5 * 0.8 + 0.5 * 0.6 = 0.7.
    #[test]
    fn weighted_composition_vector_and_text() {
        let c = candidate(Some(0.8), Some(0.6), None, None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.7).abs() < 1e-12);
    }

    // Graph proximity at distance 1 with max 4 is 1 - 1/4 = 0.75, so the expected
    // value is 0.4 * 1.0 + 0.3 * 0.0 + 0.3 * 0.75 = 0.625.
    #[test]
    fn weighted_composition_with_graph() {
        let c = candidate(Some(1.0), Some(0.0), Some(1), None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);
        let score = obj.score(&c, &ctx());
        assert!((score - 0.625).abs() < 1e-12);
    }

    // With every signal absent, each component scores 0.0, so the weighted total is 0.0.
    #[test]
    fn weighted_all_absent_returns_zero() {
        let c = candidate(None, None, None, None);
        let obj = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);
        assert_eq!(obj.score(&c, &ctx()), 0.0);
    }

    // --- HasId ---

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate {
            id,
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        };
        assert_eq!(c.id(), id);
    }

    // --- Deterministic top-k selection ---

    // Selecting the top 2 of three candidates should keep the two highest vector
    // scores (0.9, then 0.6) in descending order and drop 0.3.
    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;
        let candidates = vec![
            candidate(Some(0.3), None, None, None),
            candidate(Some(0.9), None, None, None),
            candidate(Some(0.6), None, None, None),
        ];
        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());
        assert_eq!(top.len(), 2);
        assert!((top[0].score - 0.9).abs() < 1e-12);
        assert!((top[1].score - 0.6).abs() < 1e-12);
    }
}