Skip to main content

khive_fold/objective/
builtin.rs

1//! Built-in objective functions
2
3use crate::{Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection};
4
/// Selects the candidate with the highest score.
///
/// Scores come from the wrapped `scorer` closure. The `Objective` impl for this
/// type only supplies `score` and `name`, so selection itself is the trait's
/// default `select`.
///
/// The `Fn` bound lives on the `impl` blocks that need it rather than on the
/// struct, so the type can be named without repeating the bound.
pub struct MaxScoreObjective<T, F> {
    /// Maps a candidate to its score.
    scorer: F,
    /// Ties the objective to the candidate type `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
13
14impl<T, F> MaxScoreObjective<T, F>
15where
16    F: Fn(&T) -> f64 + Send + Sync,
17{
18    /// Create a new max score objective
19    pub fn new(scorer: F) -> Self {
20        Self {
21            scorer,
22            _phantom: std::marker::PhantomData,
23        }
24    }
25}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29    F: Fn(&T) -> f64 + Send + Sync,
30{
31    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32        (self.scorer)(candidate)
33    }
34
35    fn name(&self) -> &str {
36        "MaxScoreObjective"
37    }
38}
39
/// Passes candidates whose score reaches a threshold.
///
/// A candidate passes when its score is finite and at least `threshold`
/// (and at least the context's `min_score`, when one is set).
///
/// The `Fn` bound lives on the `impl` blocks that need it rather than on the
/// struct, so the type can be named without repeating the bound.
pub struct ThresholdObjective<T, F> {
    /// Maps a candidate to its score.
    scorer: F,
    /// Inclusive lower bound a score must reach to pass.
    threshold: f64,
    /// Ties the objective to the candidate type `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
49
50impl<T, F> ThresholdObjective<T, F>
51where
52    F: Fn(&T) -> f64 + Send + Sync,
53{
54    /// Create a new threshold objective
55    pub fn new(scorer: F, threshold: f64) -> Self {
56        Self {
57            scorer,
58            threshold,
59            _phantom: std::marker::PhantomData,
60        }
61    }
62}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66    F: Fn(&T) -> f64 + Send + Sync,
67{
68    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69        (self.scorer)(candidate)
70    }
71
72    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73        if !score.is_finite() {
74            return false;
75        }
76        let passes_obj = score >= self.threshold;
77        let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78        passes_obj && passes_ctx
79    }
80
81    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82        let score = (self.scorer)(candidate);
83        self.passes_score(score, context)
84    }
85
86    fn name(&self) -> &str {
87        "ThresholdObjective"
88    }
89}
90
/// Returns the first candidate that satisfies a predicate.
///
/// The `Fn` bound lives on the `impl` blocks that need it rather than on the
/// struct, so the type can be named without repeating the bound.
pub struct FirstMatchObjective<T, F> {
    /// Decides whether a candidate matches.
    predicate: F,
    /// Ties the objective to the candidate type `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
99
100impl<T, F> FirstMatchObjective<T, F>
101where
102    F: Fn(&T) -> bool + Send + Sync,
103{
104    /// Create a new first match objective
105    pub fn new(predicate: F) -> Self {
106        Self {
107            predicate,
108            _phantom: std::marker::PhantomData,
109        }
110    }
111}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115    F: Fn(&T) -> bool + Send + Sync,
116{
117    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118        if (self.predicate)(candidate) {
119            1.0
120        } else {
121            0.0
122        }
123    }
124
125    fn select<'a>(
126        &self,
127        candidates: &'a [T],
128        context: &ObjectiveContext,
129    ) -> ObjectiveResult<Selection<&'a T>> {
130        if candidates.is_empty() {
131            return Err(ObjectiveError::NoCandidates);
132        }
133
134        let limit = context
135            .max_candidates
136            .unwrap_or(candidates.len())
137            .min(candidates.len());
138
139        for (i, candidate) in candidates.iter().take(limit).enumerate() {
140            if (self.predicate)(candidate) {
141                return Ok(Selection::new(candidate, 1.0, i)
142                    .with_considered(i + 1)
143                    .with_passed(1));
144            }
145        }
146
147        Err(ObjectiveError::NoMatch(
148            "No candidate matched predicate".into(),
149        ))
150    }
151
152    fn name(&self) -> &str {
153        "FirstMatchObjective"
154    }
155}
156
157/// Trait for items with timestamps
158pub trait HasTimestamp {
159    /// Returns the timestamp of this item
160    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
161}
162
/// Implemented by candidates that carry an importance weight.
///
/// Read by `ImportanceObjective` and `RelevanceObjective`.
pub trait HasImportance {
    /// How important this item is. Expected to lie in `0.0..=1.0`, though
    /// nothing in this module clamps or validates the value.
    fn importance(&self) -> f64;
}
168
/// Scores candidates by recency: newer items score higher.
///
/// The score halves every `half_life_seconds`, so an item exactly one
/// half-life old scores `0.5` and a brand-new item scores `1.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RecencyObjective {
    /// Half-life of the decay, in seconds. The constructors guarantee it is
    /// finite and at least 1.0.
    half_life_seconds: f64,
}
173
174impl RecencyObjective {
175    const MIN_HALF_LIFE: f64 = 1.0;
176
177    /// Create a new recency objective.
178    ///
179    /// # Panics
180    /// Panics if `half_life_seconds` is not positive and finite.
181    pub fn new(half_life_seconds: f64) -> Self {
182        assert!(
183            half_life_seconds.is_finite() && half_life_seconds > 0.0,
184            "half_life_seconds must be positive and finite, got {half_life_seconds}"
185        );
186        Self {
187            half_life_seconds: half_life_seconds.max(Self::MIN_HALF_LIFE),
188        }
189    }
190
191    /// Create with hour half-life
192    pub fn hours(hours: f64) -> Self {
193        Self::new(hours * 3600.0)
194    }
195
196    /// Create with day half-life
197    pub fn days(days: f64) -> Self {
198        Self::new(days * 86400.0)
199    }
200}
201
202impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
203    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
204        let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
205        0.5f64.powf(age_seconds / self.half_life_seconds)
206    }
207
208    fn name(&self) -> &str {
209        "RecencyObjective"
210    }
211}
212
/// Scores candidates by their own importance value.
///
/// Candidates whose importance falls below `min_importance` score `0.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ImportanceObjective {
    /// Importance a candidate must reach to keep its score (default `0.0`).
    min_importance: f64,
}
217
218impl ImportanceObjective {
219    /// Create a new importance objective
220    pub fn new() -> Self {
221        Self {
222            min_importance: 0.0,
223        }
224    }
225
226    /// Set minimum importance
227    pub fn with_min(mut self, min: f64) -> Self {
228        self.min_importance = min;
229        self
230    }
231}
232
233impl Default for ImportanceObjective {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239impl<T: HasImportance + Send + Sync> Objective<T> for ImportanceObjective {
240    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
241        let importance = candidate.importance();
242        if importance >= self.min_importance {
243            importance
244        } else {
245            0.0
246        }
247    }
248
249    fn name(&self) -> &str {
250        "ImportanceObjective"
251    }
252}
253
254/// Combines recency and importance.
255pub struct RelevanceObjective {
256    recency_weight: f64,
257    importance_weight: f64,
258    recency: RecencyObjective,
259}
260
261impl RelevanceObjective {
262    /// Create a new relevance objective.
263    ///
264    /// # Panics
265    /// Panics if either weight is negative or non-finite.
266    pub fn new(recency_half_life: f64, recency_weight: f64, importance_weight: f64) -> Self {
267        assert!(
268            recency_weight.is_finite() && recency_weight >= 0.0,
269            "recency_weight must be finite and non-negative, got {recency_weight}"
270        );
271        assert!(
272            importance_weight.is_finite() && importance_weight >= 0.0,
273            "importance_weight must be finite and non-negative, got {importance_weight}"
274        );
275        Self {
276            recency_weight,
277            importance_weight,
278            recency: RecencyObjective::new(recency_half_life),
279        }
280    }
281
282    /// Create with default weights (0.5 each)
283    pub fn balanced(recency_half_life: f64) -> Self {
284        Self::new(recency_half_life, 0.5, 0.5)
285    }
286}
287
288impl<T: HasTimestamp + HasImportance + Send + Sync> Objective<T> for RelevanceObjective {
289    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
290        // If context carries a named relevance score, use it directly.
291        if let Some(v) = context
292            .extra
293            .get("relevance_score")
294            .and_then(|v| v.as_f64())
295        {
296            return v;
297        }
298
299        let recency_score = self.recency.score(candidate, context);
300        let importance_score = candidate.importance();
301
302        let total_weight = self.recency_weight + self.importance_weight;
303        if total_weight > 0.0 {
304            (self.recency_weight * recency_score + self.importance_weight * importance_score)
305                / total_weight
306        } else {
307            0.0
308        }
309    }
310
311    fn name(&self) -> &str {
312        "RelevanceObjective"
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TestItem` with the given timestamp and importance.
    fn test_item(timestamp: chrono::DateTime<chrono::Utc>, importance: f64) -> TestItem {
        TestItem {
            _value: 0,
            timestamp,
            importance,
        }
    }

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_rejects_nan_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::NAN, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_respects_context_min_score() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);
        let mut context = ObjectiveContext::new();
        context.min_score = Some(20.0);

        // 15 clears the objective's own threshold but not the context floor.
        assert!(!objective.passes(&15, &context));
        assert!(objective.passes(&25, &context));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        // Match is at index 2 (value 7), but max_candidates=2 limits scan to indices 0..1.
        let candidates = vec![1, 3, 7, 9, 2];
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(matches!(result, Err(ObjectiveError::NoMatch(_))));
    }

    #[test]
    fn test_first_match_empty_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);
        let candidates: Vec<i32> = Vec::new();

        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(matches!(result, Err(ObjectiveError::NoCandidates)));
    }

    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        importance: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasImportance for TestItem {
        fn importance(&self) -> f64 {
            self.importance
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();
        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            importance: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_recency_one_half_life_scores_half() {
        let now = chrono::Utc::now();
        let mut context = ObjectiveContext::new();
        context.as_of = now;

        let item = test_item(now - chrono::Duration::hours(1), 0.5);
        let score = RecencyObjective::hours(1.0).score(&item, &context);

        assert!((score - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_recency_future_timestamp_scores_one() {
        let now = chrono::Utc::now();
        let mut context = ObjectiveContext::new();
        context.as_of = now;

        // Negative ages are clamped to zero, so a future item scores exactly 1.0.
        let item = test_item(now + chrono::Duration::hours(1), 0.5);

        assert_eq!(RecencyObjective::hours(1.0).score(&item, &context), 1.0);
    }

    #[test]
    fn test_recency_sub_second_half_life_is_clamped() {
        assert_eq!(RecencyObjective::new(0.25).half_life_seconds, 1.0);
        assert_eq!(RecencyObjective::days(1.0).half_life_seconds, 86_400.0);
    }

    #[test]
    fn test_importance_below_min_scores_zero() {
        let objective = ImportanceObjective::new().with_min(0.5);
        let context = ObjectiveContext::new();
        let now = chrono::Utc::now();

        assert_eq!(objective.score(&test_item(now, 0.3), &context), 0.0);
        assert_eq!(objective.score(&test_item(now, 0.7), &context), 0.7);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();

        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context =
            ObjectiveContext::new().with_extra(serde_json::json!({"relevance_score": 0.42}));

        let now = chrono::Utc::now();
        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.9,
        };

        // The context relevance_score should override the recency+importance fusion.
        let score = objective.score(&item, &context);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    fn test_relevance_zero_weights_score_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let item = test_item(chrono::Utc::now(), 0.9);

        assert_eq!(objective.score(&item, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "importance_weight must be finite and non-negative")]
    fn test_relevance_nan_importance_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(matches!(result, Err(ObjectiveError::NoMatch(_))));
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}