// khive_fold/objective/builtin.rs

1//! Built-in objective functions
2
3use crate::{Objective, ObjectiveContext, Selection};
4
/// Selects candidate with highest score.
///
/// The score of a candidate is whatever the wrapped `scorer` closure returns
/// for it. The struct definition itself carries no trait bounds; the `Fn`
/// requirement lives on the impl blocks that actually call the closure, so
/// the type can be named in signatures without restating the bound.
pub struct MaxScoreObjective<T, F> {
    /// Closure mapping a candidate to its numeric score.
    scorer: F,
    /// Ties the candidate type `T` to the struct; no `T` value is stored.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Create a new max score objective.
    ///
    /// `scorer` is stored as-is and invoked once per `score` call.
    pub fn new(scorer: F) -> Self {
        Self {
            scorer,
            _phantom: std::marker::PhantomData,
        }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29    F: Fn(&T) -> f64 + Send + Sync,
30{
31    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32        (self.scorer)(candidate)
33    }
34
35    fn name(&self) -> &str {
36        "MaxScoreObjective"
37    }
38}
39
/// Passes candidates whose score is at or above a threshold.
///
/// The score comes from the wrapped `scorer` closure. The struct definition
/// carries no trait bounds; the `Fn` requirement lives on the impl blocks
/// that call the closure, so the type can be named without restating it.
pub struct ThresholdObjective<T, F> {
    /// Closure mapping a candidate to its numeric score.
    scorer: F,
    /// Minimum score (inclusive) a candidate must reach to pass.
    threshold: f64,
    /// Ties the candidate type `T` to the struct; no `T` value is stored.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Create a new threshold objective.
    ///
    /// `threshold` is stored unchanged; no validation is performed on it here.
    pub fn new(scorer: F, threshold: f64) -> Self {
        Self {
            scorer,
            threshold,
            _phantom: std::marker::PhantomData,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66    F: Fn(&T) -> f64 + Send + Sync,
67{
68    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69        (self.scorer)(candidate)
70    }
71
72    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73        if !score.is_finite() {
74            return false;
75        }
76        let passes_obj = score >= self.threshold;
77        let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78        passes_obj && passes_ctx
79    }
80
81    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82        let score = (self.scorer)(candidate);
83        self.passes_score(score, context)
84    }
85
86    fn name(&self) -> &str {
87        "ThresholdObjective"
88    }
89}
90
/// Returns first candidate that passes predicate.
///
/// The struct definition carries no trait bounds; the `Fn` requirement lives
/// on the impl blocks that call the predicate, so the type can be named
/// without restating it.
pub struct FirstMatchObjective<T, F> {
    /// Closure deciding whether a candidate matches.
    predicate: F,
    /// Ties the candidate type `T` to the struct; no `T` value is stored.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Create a new first match objective.
    ///
    /// `predicate` is stored as-is and invoked on candidates during scoring
    /// and selection.
    pub fn new(predicate: F) -> Self {
        Self {
            predicate,
            _phantom: std::marker::PhantomData,
        }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115    F: Fn(&T) -> bool + Send + Sync,
116{
117    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118        if (self.predicate)(candidate) {
119            1.0
120        } else {
121            0.0
122        }
123    }
124
125    fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126        if candidates.is_empty() {
127            return Vec::new();
128        }
129
130        let limit = context
131            .max_candidates
132            .unwrap_or(candidates.len())
133            .min(candidates.len());
134
135        for (i, candidate) in candidates.iter().take(limit).enumerate() {
136            if (self.predicate)(candidate) {
137                return vec![Selection::new(candidate, 1.0, i)
138                    .with_considered(i + 1)
139                    .with_passed(1)];
140            }
141        }
142
143        Vec::new()
144    }
145
146    fn name(&self) -> &str {
147        "FirstMatchObjective"
148    }
149}
150
/// Trait for items with timestamps.
///
/// Implemented by candidate types that can be scored by age; in this file
/// `RecencyObjective` reads it to compute how old a candidate is relative
/// to the context's `as_of` instant.
pub trait HasTimestamp {
    /// Returns the timestamp of this item, expressed in UTC.
    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
}
156
/// Trait for items with salience.
///
/// Implemented by candidate types that carry an importance value; in this
/// file `SalienceObjective` and `RelevanceObjective` read it when scoring.
pub trait HasSalience {
    /// Returns the salience value of this item (0.0 to 1.0).
    ///
    /// The range is a documented expectation only; the trait itself does not
    /// enforce it.
    fn salience(&self) -> f64;
}
162
/// Scores by recency (newer = higher score).
///
/// The configured half-life is the age, in seconds, at which a candidate's
/// score has decayed to one half.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RecencyObjective {
    /// Decay half-life in seconds; always finite and at least `MIN_HALF_LIFE`.
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Smallest half-life (in seconds) that is ever stored.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Create a new recency objective.
    ///
    /// Positive values below one second are accepted but silently raised to
    /// one second, so the stored half-life is never smaller than that.
    ///
    /// # Panics
    /// Panics if `half_life_seconds` is not positive and finite.
    pub fn new(half_life_seconds: f64) -> Self {
        assert!(
            half_life_seconds.is_finite() && half_life_seconds > 0.0,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );
        Self {
            // Validation above rules out NaN, so `max` is a plain lower clamp.
            half_life_seconds: half_life_seconds.max(Self::MIN_HALF_LIFE),
        }
    }

    /// Create with hour half-life.
    ///
    /// # Panics
    /// Panics if `hours * 3600.0` is not positive and finite (see [`Self::new`]).
    pub fn hours(hours: f64) -> Self {
        Self::new(hours * 3600.0)
    }

    /// Create with day half-life.
    ///
    /// # Panics
    /// Panics if `days * 86400.0` is not positive and finite (see [`Self::new`]).
    pub fn days(days: f64) -> Self {
        Self::new(days * 86400.0)
    }
}
195
196impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
197    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
198        let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
199        0.5f64.powf(age_seconds / self.half_life_seconds)
200    }
201
202    fn name(&self) -> &str {
203        "RecencyObjective"
204    }
205}
206
/// Scores by salience field.
///
/// Candidates whose salience falls below the configured minimum score zero.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SalienceObjective {
    /// Salience below this value is scored as `0.0`.
    min_salience: f64,
}

impl SalienceObjective {
    /// Create a new salience objective with a minimum salience of `0.0`.
    pub fn new() -> Self {
        Self { min_salience: 0.0 }
    }

    /// Set minimum salience.
    ///
    /// Consumes `self` and returns the updated objective, builder-style.
    /// `min` is stored unchanged; no validation is performed on it.
    pub fn with_min(mut self, min: f64) -> Self {
        self.min_salience = min;
        self
    }
}

impl Default for SalienceObjective {
    /// Equivalent to [`SalienceObjective::new`].
    fn default() -> Self {
        Self::new()
    }
}
230
231impl<T: HasSalience + Send + Sync> Objective<T> for SalienceObjective {
232    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
233        let salience = candidate.salience();
234        if salience >= self.min_salience {
235            salience
236        } else {
237            0.0
238        }
239    }
240
241    fn name(&self) -> &str {
242        "SalienceObjective"
243    }
244}
245
246/// Combines recency and salience.
247pub struct RelevanceObjective {
248    recency_weight: f64,
249    salience_weight: f64,
250    recency: RecencyObjective,
251}
252
253impl RelevanceObjective {
254    /// Create a new relevance objective.
255    ///
256    /// # Panics
257    /// Panics if either weight is negative or non-finite.
258    pub fn new(recency_half_life: f64, recency_weight: f64, salience_weight: f64) -> Self {
259        assert!(
260            recency_weight.is_finite() && recency_weight >= 0.0,
261            "recency_weight must be finite and non-negative, got {recency_weight}"
262        );
263        assert!(
264            salience_weight.is_finite() && salience_weight >= 0.0,
265            "salience_weight must be finite and non-negative, got {salience_weight}"
266        );
267        Self {
268            recency_weight,
269            salience_weight,
270            recency: RecencyObjective::new(recency_half_life),
271        }
272    }
273
274    /// Create with default weights (0.5 each)
275    pub fn balanced(recency_half_life: f64) -> Self {
276        Self::new(recency_half_life, 0.5, 0.5)
277    }
278}
279
280impl<T: HasTimestamp + HasSalience + Send + Sync> Objective<T> for RelevanceObjective {
281    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
282        // If context carries a named relevance score, use it directly.
283        if let Some(v) = context
284            .extra
285            .get("relevance_score")
286            .and_then(|v| v.as_f64())
287        {
288            return v;
289        }
290
291        let recency_score = self.recency.score(candidate, context);
292        let salience_score = candidate.salience();
293
294        let total_weight = self.recency_weight + self.salience_weight;
295        if total_weight > 0.0 {
296            (self.recency_weight * recency_score + self.salience_weight * salience_score)
297                / total_weight
298        } else {
299            0.0
300        }
301    }
302
303    fn name(&self) -> &str {
304        "RelevanceObjective"
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal candidate type implementing both capability traits.
    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        salience: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasSalience for TestItem {
        fn salience(&self) -> f64 {
            self.salience
        }
    }

    /// Builds a `TestItem` without repeating the struct literal in every test.
    fn item(value: i32, timestamp: chrono::DateTime<chrono::Utc>, salience: f64) -> TestItem {
        TestItem {
            _value: value,
            timestamp,
            salience,
        }
    }

    // ----- MaxScoreObjective -----

    #[test]
    fn test_max_score_objective() {
        let by_value = MaxScoreObjective::new(|n: &i32| *n as f64);
        let pool = vec![1, 5, 3, 8, 2];

        let ranked = by_value.select(&pool, &ObjectiveContext::new());
        let best = ranked.into_iter().next().unwrap();

        assert_eq!(*best.item, 8);
    }

    // ----- ThresholdObjective -----

    #[test]
    fn test_threshold_objective() {
        let at_least_five = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);
        let ctx = ObjectiveContext::new();

        assert!(at_least_five.passes(&10, &ctx));
        assert!(!at_least_five.passes(&3, &ctx));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let always_infinite = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);
        let ctx = ObjectiveContext::new();

        assert!(!always_infinite.passes(&10, &ctx));
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let at_least_ten = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);
        let pool = vec![1, 5, 3];

        let picked = at_least_ten.select(&pool, &ObjectiveContext::new());

        assert!(picked.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let at_least_five = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);
        let pool = vec![1, 10, 3, 15];

        let ranked = at_least_five.select(&pool, &ObjectiveContext::new());
        let best = ranked.into_iter().next().unwrap();

        assert_eq!(*best.item, 15);
        assert_eq!(best.score, 15.0);
        assert_eq!(best.passed, 2);
    }

    // ----- FirstMatchObjective -----

    #[test]
    fn test_first_match_objective() {
        let above_five = FirstMatchObjective::new(|n: &i32| *n > 5);
        let pool = vec![1, 3, 7, 9, 2];

        let picked = above_five.select(&pool, &ObjectiveContext::new());
        let first = picked.into_iter().next().unwrap();

        assert_eq!(*first.item, 7);
        assert_eq!(first.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let above_five = FirstMatchObjective::new(|n: &i32| *n > 5);

        // The only early match sits at index 2 (value 7); capping the scan at
        // two candidates restricts it to indices 0 and 1, so nothing is found.
        let pool = vec![1, 3, 7, 9, 2];
        let capped = ObjectiveContext::new().with_max_candidates(2);

        let picked = above_five.select(&pool, &capped);

        assert!(picked.is_empty());
    }

    // ----- RecencyObjective -----

    #[test]
    fn test_recency_objective() {
        let one_hour = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() gives epoch per ADR-024.
        let ctx = ObjectiveContext::at(now);

        let fresh = item(1, now, 0.5);
        let stale = item(2, now - chrono::Duration::hours(2), 0.5);

        let fresh_score = one_hour.score(&fresh, &ctx);
        let stale_score = one_hour.score(&stale, &ctx);

        assert!(fresh_score > stale_score);
        assert!((fresh_score - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        let _ = RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        let _ = RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        let _ = RecencyObjective::new(f64::NAN);
    }

    // ----- RelevanceObjective -----

    #[test]
    fn test_relevance_objective() {
        let balanced = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() gives epoch per ADR-024.
        let ctx = ObjectiveContext::at(now);

        let candidate = item(1, now, 0.8);
        let score = balanced.score(&candidate, &ctx);

        assert!(score > 0.8);
        assert!(score < 1.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let balanced = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() gives epoch per ADR-024.
        let ctx =
            ObjectiveContext::at(now).with_extra(serde_json::json!({"relevance_score": 0.42}));

        let candidate = item(1, now, 0.9);

        // The context relevance_score should override the recency+salience fusion.
        let score = balanced.score(&candidate, &ctx);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        let _ = RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "salience_weight must be finite and non-negative")]
    fn test_relevance_nan_salience_weight_panics() {
        let _ = RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }
}