Skip to main content

khive_fold/objective/
builtin.rs

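//! Built-in objective functions.
//!
//! Closure-backed objectives ([`MaxScoreObjective`], [`ThresholdObjective`],
//! [`FirstMatchObjective`]) and objectives over timestamped / salient items
//! ([`RecencyObjective`], [`SalienceObjective`], [`RelevanceObjective`]).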
2
3use crate::{Objective, ObjectiveContext, Selection};
4
/// Objective that prefers the candidate with the highest score.
///
/// Scores come from a caller-supplied closure. This objective adds no threshold of its
/// own; ranking is driven purely by the closure's output.
pub struct MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Binds the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Wraps `scorer` in a new max-score objective.
    pub fn new(scorer: F) -> Self {
        Self {
            _phantom: Default::default(),
            scorer,
        }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29    F: Fn(&T) -> f64 + Send + Sync,
30{
31    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32        (self.scorer)(candidate)
33    }
34
35    fn name(&self) -> &str {
36        "MaxScoreObjective"
37    }
38}
39
/// Objective that accepts only candidates scoring at or above a fixed threshold.
///
/// A caller-supplied `scorer` maps each candidate to an `f64`. A candidate passes when
/// its score is finite and `>= threshold` (inclusive), and additionally `>=` the
/// context's `min_score` when one is set. NaN and infinite scores never pass.
pub struct ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Inclusive lower bound a score must reach to pass.
    threshold: f64,
    /// Binds the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Creates a threshold objective from a scorer and an inclusive `threshold`.
    ///
    /// `threshold` is not validated: a NaN threshold means no candidate ever passes,
    /// since no score compares `>=` NaN.
    pub fn new(scorer: F, threshold: f64) -> Self {
        Self {
            scorer,
            threshold,
            _phantom: std::marker::PhantomData,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66    F: Fn(&T) -> f64 + Send + Sync,
67{
68    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69        (self.scorer)(candidate)
70    }
71
72    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73        if !score.is_finite() {
74            return false;
75        }
76        let passes_obj = score >= self.threshold;
77        let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78        passes_obj && passes_ctx
79    }
80
81    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82        let score = (self.scorer)(candidate);
83        self.passes_score(score, context)
84    }
85
86    fn name(&self) -> &str {
87        "ThresholdObjective"
88    }
89}
90
/// Objective that picks the earliest candidate accepted by a predicate.
///
/// Candidates are examined in slice order (optionally capped by the context's
/// `max_candidates`), and scanning stops at the first match.
pub struct FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Returns `true` for acceptable candidates.
    predicate: F,
    /// Binds the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Wraps `predicate` in a new first-match objective.
    pub fn new(predicate: F) -> Self {
        Self {
            _phantom: Default::default(),
            predicate,
        }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115    F: Fn(&T) -> bool + Send + Sync,
116{
117    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118        if (self.predicate)(candidate) {
119            1.0
120        } else {
121            0.0
122        }
123    }
124
125    fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126        if candidates.is_empty() {
127            return Vec::new();
128        }
129
130        let limit = context
131            .max_candidates
132            .unwrap_or(candidates.len())
133            .min(candidates.len());
134
135        for (i, candidate) in candidates.iter().take(limit).enumerate() {
136            if (self.predicate)(candidate) {
137                return vec![Selection::new(candidate, 1.0, i)
138                    .with_considered(i + 1)
139                    .with_passed(1)];
140            }
141        }
142
143        Vec::new()
144    }
145
146    fn name(&self) -> &str {
147        "FirstMatchObjective"
148    }
149}
150
151/// Trait for items with timestamps
152pub trait HasTimestamp {
153    /// Returns the timestamp of this item
154    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
155}
156
/// Items that carry a salience (importance) value, consumed by salience-based objectives.
pub trait HasSalience {
    /// The item's salience, expected to lie in `0.0..=1.0`.
    ///
    /// The range is a convention only; implementors are not checked.
    fn salience(&self) -> f64;
}
162
/// Scores candidates by recency: newer items score higher.
///
/// The score halves every `half_life_seconds`. An item of age `a` seconds scores
/// `0.5^(a / half_life_seconds)`, so a brand-new item scores `1.0` and older items
/// decay towards `0.0`.
#[derive(Debug, Clone, Copy)]
pub struct RecencyObjective {
    /// Half-life in seconds; always finite and at least `MIN_HALF_LIFE`.
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Smallest half-life, in seconds, that `new` will store. Smaller positive inputs
    /// are raised to this value.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Creates a recency objective with the given half-life in seconds.
    ///
    /// Half-lives shorter than one second are silently clamped up to one second.
    ///
    /// # Panics
    ///
    /// Panics if `half_life_seconds` is NaN, infinite, zero, or negative.
    pub fn new(half_life_seconds: f64) -> Self {
        assert!(
            half_life_seconds.is_finite() && half_life_seconds > 0.0,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );
        Self {
            half_life_seconds: half_life_seconds.max(Self::MIN_HALF_LIFE),
        }
    }

    /// Creates a recency objective whose half-life is `hours` hours.
    ///
    /// # Panics
    ///
    /// Panics if `hours` is not positive and finite, or if `hours * 3600` overflows
    /// to infinity.
    pub fn hours(hours: f64) -> Self {
        Self::new(hours * 3600.0)
    }

    /// Creates a recency objective whose half-life is `days` days.
    ///
    /// # Panics
    ///
    /// Panics if `days` is not positive and finite, or if `days * 86400` overflows
    /// to infinity.
    pub fn days(days: f64) -> Self {
        Self::new(days * 86400.0)
    }
}
192
193impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
194    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
195        let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
196        0.5f64.powf(age_seconds / self.half_life_seconds)
197    }
198
199    fn name(&self) -> &str {
200        "RecencyObjective"
201    }
202}
203
/// Scores candidates by their own salience value.
///
/// A candidate whose salience is below the configured minimum (default `0.0`) scores
/// `0.0`; otherwise its salience is used unchanged as the score.
#[derive(Debug, Clone, Copy)]
pub struct SalienceObjective {
    /// Saliences below this value score `0.0`.
    min_salience: f64,
}

impl SalienceObjective {
    /// Creates a salience objective with a minimum salience of `0.0`.
    pub fn new() -> Self {
        Self { min_salience: 0.0 }
    }

    /// Returns `self` with the minimum salience set to `min`.
    ///
    /// `min` is not validated. A NaN `min` makes every candidate score `0.0`, because
    /// no salience compares `>=` NaN.
    #[must_use]
    pub fn with_min(mut self, min: f64) -> Self {
        self.min_salience = min;
        self
    }
}

impl Default for SalienceObjective {
    /// Equivalent to [`SalienceObjective::new`].
    fn default() -> Self {
        Self::new()
    }
}
227
228impl<T: HasSalience + Send + Sync> Objective<T> for SalienceObjective {
229    fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
230        let salience = candidate.salience();
231        if salience >= self.min_salience {
232            salience
233        } else {
234            0.0
235        }
236    }
237
238    fn name(&self) -> &str {
239        "SalienceObjective"
240    }
241}
242
243/// Combines recency and salience.
244pub struct RelevanceObjective {
245    recency_weight: f64,
246    salience_weight: f64,
247    recency: RecencyObjective,
248}
249
250impl RelevanceObjective {
251    /// Create a new relevance objective. Panics if either weight is negative or non-finite.
252    pub fn new(recency_half_life: f64, recency_weight: f64, salience_weight: f64) -> Self {
253        assert!(
254            recency_weight.is_finite() && recency_weight >= 0.0,
255            "recency_weight must be finite and non-negative, got {recency_weight}"
256        );
257        assert!(
258            salience_weight.is_finite() && salience_weight >= 0.0,
259            "salience_weight must be finite and non-negative, got {salience_weight}"
260        );
261        Self {
262            recency_weight,
263            salience_weight,
264            recency: RecencyObjective::new(recency_half_life),
265        }
266    }
267
268    /// Create with equal weights (0.5 each). Panics if `recency_half_life` is not positive and finite.
269    pub fn balanced(recency_half_life: f64) -> Self {
270        Self::new(recency_half_life, 0.5, 0.5)
271    }
272}
273
274impl<T: HasTimestamp + HasSalience + Send + Sync> Objective<T> for RelevanceObjective {
275    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
276        // If context carries a named relevance score, use it directly.
277        if let Some(v) = context
278            .extra
279            .get("relevance_score")
280            .and_then(|v| v.as_f64())
281        {
282            return v;
283        }
284
285        let recency_score = self.recency.score(candidate, context);
286        let salience_score = candidate.salience();
287
288        let total_weight = self.recency_weight + self.salience_weight;
289        if total_weight > 0.0 {
290            (self.recency_weight * recency_score + self.salience_weight * salience_score)
291                / total_weight
292        } else {
293            0.0
294        }
295    }
296
297    fn name(&self) -> &str {
298        "RelevanceObjective"
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_is_inclusive() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&5, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_nan_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::NAN, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes_score(f64::NEG_INFINITY, &ObjectiveContext::new()));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        // Match is at index 2 (value 7), but max_candidates=2 limits scan to indices 0..1.
        let candidates = vec![1, 3, 7, 9, 2];
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(result.is_empty());
    }

    #[test]
    fn test_first_match_no_match_returns_empty() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 100);

        let candidates = vec![1, 3, 7];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_first_match_empty_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates: Vec<i32> = Vec::new();
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_first_match_reports_single_pass() {
        let objective = FirstMatchObjective::new(|n: &i32| *n % 2 == 0);

        let candidates = vec![1, 3, 4, 6];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 4);
        assert_eq!(selection.index, 2);
        assert_eq!(selection.score, 1.0);
        assert_eq!(selection.passed, 1);
    }

    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        salience: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasSalience for TestItem {
        fn salience(&self) -> f64 {
            self.salience
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() defaults to the Unix epoch.
        let context = ObjectiveContext::at(now);

        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            salience: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_recency_score_halves_after_one_half_life() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now - chrono::Duration::hours(1),
            salience: 0.5,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_recency_future_timestamp_scores_one() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        // Items newer than `as_of` are clamped to age zero.
        let item = TestItem {
            _value: 1,
            timestamp: now + chrono::Duration::hours(1),
            salience: 0.5,
        };

        assert_eq!(objective.score(&item, &context), 1.0);
    }

    #[test]
    fn test_recency_sub_second_half_life_is_clamped() {
        let objective = RecencyObjective::new(0.25);

        assert_eq!(objective.half_life_seconds, RecencyObjective::MIN_HALF_LIFE);
    }

    #[test]
    fn test_salience_objective_applies_minimum() {
        let objective = SalienceObjective::new().with_min(0.5);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let make = |salience: f64| TestItem {
            _value: 0,
            timestamp: now,
            salience,
        };

        assert_eq!(objective.score(&make(0.4), &context), 0.0);
        assert_eq!(objective.score(&make(0.7), &context), 0.7);
        assert_eq!(objective.score(&make(f64::NAN), &context), 0.0);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() defaults to the Unix epoch.
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_weighted_mean() {
        let objective = RelevanceObjective::new(3600.0, 1.0, 3.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.2,
        };

        // Recency is 1.0 for a brand-new item: (1.0 * 1.0 + 3.0 * 0.2) / 4.0 = 0.4.
        let score = objective.score(&item, &context);
        assert!((score - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_relevance_zero_weights_score_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.9,
        };

        assert_eq!(objective.score(&item, &context), 0.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        // Pass current time explicitly — ObjectiveContext::new() defaults to the Unix epoch.
        let context =
            ObjectiveContext::at(now).with_extra(serde_json::json!({"relevance_score": 0.42}));

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.9,
        };

        // The context relevance_score should override the recency+salience fusion.
        let score = objective.score(&item, &context);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "salience_weight must be finite and non-negative")]
    fn test_relevance_nan_salience_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}