Skip to main content

khive_runtime/
objectives.rs

1//! Retrieval Objective implementations for khive-runtime.
2//!
3//! Domain-specific objectives that operate on pre-computed retrieval signals.
4//! Pure math: no IO, no async. The runtime layer materialises the signal data
5//! and feeds it in via the candidate struct.
6//!
7//! See ADR-061 — Retrieval Infrastructure.
8
9use uuid::Uuid;
10
11use khive_fold::objective::{Objective, ObjectiveContext};
12use khive_fold::ordering::HasId;
13
14/// Pre-computed retrieval signals for a single candidate entity.
15///
16/// All fields are `Option` — a missing signal scores 0.0. The runtime layer
17/// is responsible for populating whichever fields are available before handing
18/// the slice to an objective.
19#[derive(Debug, Clone)]
20pub struct RetrievalCandidate {
21    /// Stable entity UUID.
22    pub id: Uuid,
23    /// Cosine similarity to the query vector (0.0–1.0).
24    pub vector_score: Option<f64>,
25    /// BM25/FTS relevance score (0.0–1.0 normalised, or raw rank score).
26    pub text_score: Option<f64>,
27    /// Hop distance from the nearest anchor node (0 = anchor itself).
28    pub graph_distance: Option<u32>,
29    /// Pre-fused RRF score from `FusionStrategy::Rrf`.
30    pub rrf_score: Option<f64>,
31}
32
33impl HasId for RetrievalCandidate {
34    #[inline]
35    fn id(&self) -> Uuid {
36        self.id
37    }
38}
39
40// ── VectorSimilarityObjective ────────────────────────────────────────────────
41
/// Objective that ranks a candidate by cosine similarity to the query vector.
///
/// The score is the candidate's `vector_score` passed through unchanged, or
/// 0.0 when that signal is absent.
///
/// This is a stateless unit struct, so it is `Copy` and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorSimilarityObjective;
46
47impl Objective<RetrievalCandidate> for VectorSimilarityObjective {
48    #[inline]
49    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
50        candidate.vector_score.unwrap_or(0.0)
51    }
52
53    fn name(&self) -> &str {
54        "VectorSimilarityObjective"
55    }
56}
57
58// ── TextRelevanceObjective ───────────────────────────────────────────────────
59
/// Objective that ranks a candidate by BM25/FTS relevance.
///
/// The score is the candidate's `text_score` passed through unchanged, or
/// 0.0 when that signal is absent.
///
/// This is a stateless unit struct, so it is `Copy` and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TextRelevanceObjective;
64
65impl Objective<RetrievalCandidate> for TextRelevanceObjective {
66    #[inline]
67    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
68        candidate.text_score.unwrap_or(0.0)
69    }
70
71    fn name(&self) -> &str {
72        "TextRelevanceObjective"
73    }
74}
75
76// ── GraphProximityObjective ──────────────────────────────────────────────────
77
/// Scores a candidate by graph proximity to anchor nodes.
///
/// Score formula (linear decay), where `d` is the candidate's hop distance:
///
/// ```text
/// max_distance == 0 → score = 0.0   (degenerate; avoids dividing by zero)
/// d <  max_distance → score = 1.0 − (d as f64 / max_distance as f64)
/// d ≥  max_distance → score = 0.0
/// missing           → score = 0.0
/// ```
///
/// With a non-zero `max_distance`, direct anchor hits (d = 0) score 1.0. The
/// boundary `d == max_distance` scores 0.0, as does anything beyond it.
/// When `max_distance` is 0 every candidate — anchors included — scores 0.0.
///
/// `Default` is deliberately not derived: it would produce the degenerate
/// `max_distance == 0` configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GraphProximityObjective {
    /// Maximum hop distance to consider. Candidates at or beyond this
    /// distance score 0.0.
    pub max_distance: u32,
}
94
95impl Objective<RetrievalCandidate> for GraphProximityObjective {
96    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
97        let d = match candidate.graph_distance {
98            Some(d) => d,
99            None => return 0.0,
100        };
101        if self.max_distance == 0 || d >= self.max_distance {
102            return 0.0;
103        }
104        1.0 - (d as f64 / self.max_distance as f64)
105    }
106
107    fn name(&self) -> &str {
108        "GraphProximityObjective"
109    }
110}
111
112// ── RrfFusionObjective ───────────────────────────────────────────────────────
113
/// Objective that ranks a candidate by its pre-computed RRF fusion score.
///
/// The score is the candidate's `rrf_score` passed through unchanged, or 0.0
/// when that signal is absent.
///
/// This is a stateless unit struct, so it is `Copy` and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RrfFusionObjective;
118
119impl Objective<RetrievalCandidate> for RrfFusionObjective {
120    #[inline]
121    fn score(&self, candidate: &RetrievalCandidate, _context: &ObjectiveContext) -> f64 {
122        candidate.rrf_score.unwrap_or(0.0)
123    }
124
125    fn name(&self) -> &str {
126        "RrfFusionObjective"
127    }
128}
129
130// ────────────────────────────────────────────────────────────────────────────
131
#[cfg(test)]
mod tests {
    use super::*;
    use khive_fold::objective::{Objective, ObjectiveContext};
    use khive_fold::WeightedObjective;
    use uuid::Uuid;

    /// Absolute tolerance for floating-point score comparisons.
    const TOLERANCE: f64 = 1e-12;

    /// Fresh scoring context for a single assertion.
    fn ctx() -> ObjectiveContext {
        ObjectiveContext::new()
    }

    /// Candidate with a random id and no signals populated. Tests fill in the
    /// signals they care about via struct-update syntax.
    fn blank() -> RetrievalCandidate {
        RetrievalCandidate {
            id: Uuid::new_v4(),
            vector_score: None,
            text_score: None,
            graph_distance: None,
            rrf_score: None,
        }
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Scores a candidate with the given hop distance under a
    /// `GraphProximityObjective` limited to `max_distance`.
    fn graph_score(dist: Option<u32>, max_distance: u32) -> f64 {
        let c = RetrievalCandidate {
            graph_distance: dist,
            ..blank()
        };
        GraphProximityObjective { max_distance }.score(&c, &ctx())
    }

    // ── VectorSimilarityObjective ────────────────────────────────────────

    #[test]
    fn vector_present_returns_signal() {
        let c = RetrievalCandidate {
            vector_score: Some(0.85),
            ..blank()
        };
        assert_close(VectorSimilarityObjective.score(&c, &ctx()), 0.85);
    }

    #[test]
    fn vector_absent_returns_zero() {
        assert_eq!(VectorSimilarityObjective.score(&blank(), &ctx()), 0.0);
    }

    #[test]
    fn vector_zero_score_returns_zero() {
        let c = RetrievalCandidate {
            vector_score: Some(0.0),
            ..blank()
        };
        assert_eq!(VectorSimilarityObjective.score(&c, &ctx()), 0.0);
    }

    // ── TextRelevanceObjective ───────────────────────────────────────────

    #[test]
    fn text_present_returns_signal() {
        let c = RetrievalCandidate {
            text_score: Some(0.6),
            ..blank()
        };
        assert_close(TextRelevanceObjective.score(&c, &ctx()), 0.6);
    }

    #[test]
    fn text_absent_returns_zero() {
        assert_eq!(TextRelevanceObjective.score(&blank(), &ctx()), 0.0);
    }

    // ── GraphProximityObjective ──────────────────────────────────────────

    #[test]
    fn graph_anchor_hit_scores_one() {
        // d=0 → score = 1.0 − 0/max = 1.0
        assert_close(graph_score(Some(0), 3), 1.0);
    }

    #[test]
    fn graph_midpoint_scores_half() {
        // d=1, max=2 → score = 1.0 − 1/2 = 0.5
        assert_close(graph_score(Some(1), 2), 0.5);
    }

    #[test]
    fn graph_at_boundary_scores_zero() {
        // d == max_distance → score = 0.0 (boundary excluded)
        assert_eq!(graph_score(Some(3), 3), 0.0);
    }

    #[test]
    fn graph_beyond_boundary_scores_zero() {
        assert_eq!(graph_score(Some(10), 3), 0.0);
    }

    #[test]
    fn graph_absent_scores_zero() {
        assert_eq!(graph_score(None, 3), 0.0);
    }

    #[test]
    fn graph_max_distance_zero_always_scores_zero() {
        // max_distance=0 is degenerate; guard against divide-by-zero.
        assert_eq!(graph_score(Some(0), 0), 0.0);
    }

    // ── RrfFusionObjective ───────────────────────────────────────────────

    #[test]
    fn rrf_present_returns_signal() {
        let c = RetrievalCandidate {
            rrf_score: Some(0.0327),
            ..blank()
        };
        assert_close(RrfFusionObjective.score(&c, &ctx()), 0.0327);
    }

    #[test]
    fn rrf_absent_returns_zero() {
        assert_eq!(RrfFusionObjective.score(&blank(), &ctx()), 0.0);
    }

    // ── WeightedObjective composition ───────────────────────────────────

    #[test]
    fn weighted_composition_vector_and_text() {
        // Candidate with vector=0.8, text=0.6
        // Weighted(0.5*vector + 0.5*text) = 0.5*0.8 + 0.5*0.6 = 0.7
        let c = RetrievalCandidate {
            vector_score: Some(0.8),
            text_score: Some(0.6),
            ..blank()
        };

        let blended = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        // WeightedObjective divides by total weight (1.0), so result is 0.7
        assert_close(blended.score(&c, &ctx()), 0.7);
    }

    #[test]
    fn weighted_composition_with_graph() {
        // vector=1.0, text=0.0, graph d=1/max=4 → proximity = 1 - 1/4 = 0.75
        // weights: vector=0.4, text=0.3, graph=0.3
        // weighted sum = (0.4*1.0 + 0.3*0.0 + 0.3*0.75) / 1.0 = 0.4 + 0.0 + 0.225 = 0.625
        let c = RetrievalCandidate {
            vector_score: Some(1.0),
            text_score: Some(0.0),
            graph_distance: Some(1),
            ..blank()
        };

        let blended = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.4)
            .add(Box::new(TextRelevanceObjective), 0.3)
            .add(Box::new(GraphProximityObjective { max_distance: 4 }), 0.3);

        assert_close(blended.score(&c, &ctx()), 0.625);
    }

    #[test]
    fn weighted_all_absent_returns_zero() {
        let blended = WeightedObjective::<RetrievalCandidate>::new()
            .add(Box::new(VectorSimilarityObjective), 0.5)
            .add(Box::new(TextRelevanceObjective), 0.5);

        // 0.0 * 0.5 + 0.0 * 0.5 = 0.0
        assert_eq!(blended.score(&blank(), &ctx()), 0.0);
    }

    // ── HasId ────────────────────────────────────────────────────────────

    #[test]
    fn has_id_returns_candidate_uuid() {
        let id = Uuid::new_v4();
        let c = RetrievalCandidate { id, ..blank() };
        assert_eq!(c.id(), id);
    }

    // ── select_top via DeterministicObjective ────────────────────────────

    #[test]
    fn select_top_orders_by_vector_score() {
        use khive_fold::DeterministicObjective;

        // Deliberately unsorted input: the two best scores are 0.9 and 0.6.
        let candidates = vec![
            RetrievalCandidate {
                vector_score: Some(0.3),
                ..blank()
            },
            RetrievalCandidate {
                vector_score: Some(0.9),
                ..blank()
            },
            RetrievalCandidate {
                vector_score: Some(0.6),
                ..blank()
            },
        ];

        let top = VectorSimilarityObjective.select_top_deterministic(&candidates, 2, &ctx());

        assert_eq!(top.len(), 2);
        assert_close(top[0].score, 0.9);
        assert_close(top[1].score, 0.6);
    }
}