Skip to main content

khive_fold/objective/
builtin.rs

1//! Built-in objective functions
2
3use crate::{Objective, ObjectiveContext, Selection};
4
/// Objective that selects the candidate with the highest score.
///
/// Scores come from a caller-supplied closure mapping a candidate to an `f64`.
/// `T` is carried purely as a type marker; no value of type `T` is stored.
pub struct MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Closure used to rate each candidate.
    scorer: F,
    /// Zero-sized marker tying the objective to its candidate type.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F: Fn(&T) -> f64 + Send + Sync> MaxScoreObjective<T, F> {
    /// Builds a max-score objective around `scorer`.
    pub fn new(scorer: F) -> Self {
        let _phantom = std::marker::PhantomData;
        Self { scorer, _phantom }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29    F: Fn(&T) -> f64 + Send + Sync,
30{
31    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32        (self.scorer)(candidate)
33    }
34
35    fn name(&self) -> &str {
36        "MaxScoreObjective"
37    }
38}
39
/// Objective that passes candidates whose score is at or above a threshold.
///
/// Scores come from a caller-supplied closure; `T` is carried purely as a
/// type marker and no value of type `T` is stored.
pub struct ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Closure used to rate each candidate.
    scorer: F,
    /// Inclusive cutoff a score must reach to pass.
    threshold: f64,
    /// Zero-sized marker tying the objective to its candidate type.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F: Fn(&T) -> f64 + Send + Sync> ThresholdObjective<T, F> {
    /// Builds a threshold objective from a scoring closure and its inclusive cutoff.
    pub fn new(scorer: F, threshold: f64) -> Self {
        let _phantom = std::marker::PhantomData;
        Self {
            scorer,
            threshold,
            _phantom,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66    F: Fn(&T) -> f64 + Send + Sync,
67{
68    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69        (self.scorer)(candidate)
70    }
71
72    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73        if !score.is_finite() {
74            return false;
75        }
76        let passes_obj = score >= self.threshold;
77        let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78        passes_obj && passes_ctx
79    }
80
81    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82        let score = (self.scorer)(candidate);
83        self.passes_score(score, context)
84    }
85
86    fn name(&self) -> &str {
87        "ThresholdObjective"
88    }
89}
90
/// Objective that returns the first candidate satisfying a predicate.
///
/// `T` is carried purely as a type marker; no value of type `T` is stored.
pub struct FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Test a candidate must satisfy to be selected.
    predicate: F,
    /// Zero-sized marker tying the objective to its candidate type.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F: Fn(&T) -> bool + Send + Sync> FirstMatchObjective<T, F> {
    /// Builds a first-match objective around `predicate`.
    pub fn new(predicate: F) -> Self {
        let _phantom = std::marker::PhantomData;
        Self {
            predicate,
            _phantom,
        }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115    F: Fn(&T) -> bool + Send + Sync,
116{
117    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118        if (self.predicate)(candidate) {
119            1.0
120        } else {
121            0.0
122        }
123    }
124
125    fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126        if candidates.is_empty() {
127            return Vec::new();
128        }
129
130        let limit = context
131            .max_candidates
132            .unwrap_or(candidates.len())
133            .min(candidates.len());
134
135        for (i, candidate) in candidates.iter().take(limit).enumerate() {
136            if (self.predicate)(candidate) {
137                return vec![Selection::new(candidate, 1.0, i)
138                    .with_considered(i + 1)
139                    .with_passed(1)];
140            }
141        }
142
143        Vec::new()
144    }
145
146    fn name(&self) -> &str {
147        "FirstMatchObjective"
148    }
149}
150
151/// Trait for items with timestamps
152pub trait HasTimestamp {
153    /// Returns the timestamp of this item
154    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
155}
156
/// Implemented by items that expose an importance value.
pub trait HasImportance {
    /// The item's importance, documented as lying between 0.0 and 1.0.
    fn importance(&self) -> f64;
}
162
/// Scores items by recency: newer items score higher.
///
/// The score halves for every `half_life_seconds` of age.
pub struct RecencyObjective {
    /// Time, in seconds, after which an item's score has halved.
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Shortest half-life, in seconds, that is ever stored.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Creates a recency objective with the given half-life in seconds.
    ///
    /// Valid half-lives shorter than one second are raised to one second.
    ///
    /// # Panics
    /// Panics if `half_life_seconds` is not positive and finite.
    pub fn new(half_life_seconds: f64) -> Self {
        let is_valid = half_life_seconds.is_finite() && half_life_seconds > 0.0;
        assert!(
            is_valid,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );

        let clamped = f64::max(half_life_seconds, Self::MIN_HALF_LIFE);
        Self {
            half_life_seconds: clamped,
        }
    }

    /// Creates a recency objective whose half-life is given in hours.
    ///
    /// # Panics
    /// Panics under the same conditions as [`RecencyObjective::new`], applied
    /// to the half-life converted to seconds.
    pub fn hours(hours: f64) -> Self {
        const SECONDS_PER_HOUR: f64 = 3600.0;
        Self::new(hours * SECONDS_PER_HOUR)
    }

    /// Creates a recency objective whose half-life is given in days.
    ///
    /// # Panics
    /// Panics under the same conditions as [`RecencyObjective::new`], applied
    /// to the half-life converted to seconds.
    pub fn days(days: f64) -> Self {
        const SECONDS_PER_DAY: f64 = 86400.0;
        Self::new(days * SECONDS_PER_DAY)
    }
}
195
196impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
197    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
198        let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
199        0.5f64.powf(age_seconds / self.half_life_seconds)
200    }
201
202    fn name(&self) -> &str {
203        "RecencyObjective"
204    }
205}
206
/// Scores items by their importance value.
///
/// Items whose importance falls below the configured minimum score zero.
pub struct ImportanceObjective {
    /// Importance below this value is scored as `0.0`.
    min_importance: f64,
}

impl ImportanceObjective {
    /// Creates an importance objective with a minimum importance of `0.0`.
    pub fn new() -> Self {
        ImportanceObjective {
            min_importance: 0.0,
        }
    }

    /// Returns the objective with its minimum importance replaced by `min`.
    pub fn with_min(self, min: f64) -> Self {
        Self {
            min_importance: min,
        }
    }
}

impl Default for ImportanceObjective {
    /// Same as [`ImportanceObjective::new`].
    fn default() -> Self {
        ImportanceObjective::new()
    }
}
232
233impl<T: HasImportance + Send + Sync> Objective<T> for ImportanceObjective {
234    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
235        let importance = candidate.importance();
236        if importance >= self.min_importance {
237            importance
238        } else {
239            0.0
240        }
241    }
242
243    fn name(&self) -> &str {
244        "ImportanceObjective"
245    }
246}
247
248/// Combines recency and importance.
249pub struct RelevanceObjective {
250    recency_weight: f64,
251    importance_weight: f64,
252    recency: RecencyObjective,
253}
254
255impl RelevanceObjective {
256    /// Create a new relevance objective.
257    ///
258    /// # Panics
259    /// Panics if either weight is negative or non-finite.
260    pub fn new(recency_half_life: f64, recency_weight: f64, importance_weight: f64) -> Self {
261        assert!(
262            recency_weight.is_finite() && recency_weight >= 0.0,
263            "recency_weight must be finite and non-negative, got {recency_weight}"
264        );
265        assert!(
266            importance_weight.is_finite() && importance_weight >= 0.0,
267            "importance_weight must be finite and non-negative, got {importance_weight}"
268        );
269        Self {
270            recency_weight,
271            importance_weight,
272            recency: RecencyObjective::new(recency_half_life),
273        }
274    }
275
276    /// Create with default weights (0.5 each)
277    pub fn balanced(recency_half_life: f64) -> Self {
278        Self::new(recency_half_life, 0.5, 0.5)
279    }
280}
281
282impl<T: HasTimestamp + HasImportance + Send + Sync> Objective<T> for RelevanceObjective {
283    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
284        // If context carries a named relevance score, use it directly.
285        if let Some(v) = context
286            .extra
287            .get("relevance_score")
288            .and_then(|v| v.as_f64())
289        {
290            return v;
291        }
292
293        let recency_score = self.recency.score(candidate, context);
294        let importance_score = candidate.importance();
295
296        let total_weight = self.recency_weight + self.importance_weight;
297        if total_weight > 0.0 {
298            (self.recency_weight * recency_score + self.importance_weight * importance_score)
299                / total_weight
300        } else {
301            0.0
302        }
303    }
304
305    fn name(&self) -> &str {
306        "RelevanceObjective"
307    }
308}
309
/// Unit tests for the built-in objectives.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        // The match is at index 2 (value 7), but max_candidates=2 limits the
        // scan to indices 0 and 1, so nothing is found.
        let candidates = vec![1, 3, 7, 9, 2];
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(result.is_empty());
    }

    /// Minimal item implementing both `HasTimestamp` and `HasImportance`.
    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        importance: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasImportance for TestItem {
        fn importance(&self) -> f64 {
            self.importance
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();
        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            importance: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();

        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context =
            ObjectiveContext::new().with_extra(serde_json::json!({"relevance_score": 0.42}));

        let now = chrono::Utc::now();
        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.9,
        };

        // The context relevance_score should override the recency+importance fusion.
        let score = objective.score(&item, &context);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "importance_weight must be finite and non-negative")]
    fn test_relevance_nan_importance_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }

    #[test]
    fn test_recency_clamps_sub_second_half_life() {
        // A 0.5 s half-life is raised to the 1 s minimum, so an item exactly
        // one second old has decayed by one half-life (0.5), not two (0.25).
        let objective = RecencyObjective::new(0.5);
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: context.as_of - chrono::Duration::seconds(1),
            importance: 0.5,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_importance_objective_applies_minimum() {
        let objective = ImportanceObjective::new().with_min(0.5);
        let context = ObjectiveContext::new();
        let now = chrono::Utc::now();

        let item_with = |importance: f64| TestItem {
            _value: 1,
            timestamp: now,
            importance,
        };

        // Below the minimum scores zero; at or above it passes through unchanged.
        assert_eq!(objective.score(&item_with(0.3), &context), 0.0);
        assert_eq!(objective.score(&item_with(0.5), &context), 0.5);
        assert_eq!(objective.score(&item_with(0.8), &context), 0.8);
    }

    #[test]
    fn test_importance_default_passes_importance_through() {
        let objective = ImportanceObjective::default();
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.25,
        };

        assert_eq!(objective.score(&item, &context), 0.25);
    }

    #[test]
    fn test_relevance_zero_weights_scores_zero() {
        // With no weight on either component the score falls back to 0.0
        // instead of dividing by zero.
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.9,
        };

        assert_eq!(objective.score(&item, &context), 0.0);
    }
}