1use crate::{Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection};
4
/// Objective that scores each candidate with a caller-supplied function.
///
/// Only [`Objective::score`] is customised here; candidate selection is left to
/// the [`Objective`] trait's default behaviour, which selects the
/// highest-scoring candidate.
///
/// The `Fn` bound on `F` is placed on the `impl` blocks rather than on the
/// struct definition, so the type can be named in generic code without
/// restating the bound.
pub struct MaxScoreObjective<T, F> {
    /// Maps a candidate to its score.
    scorer: F,
    /// Marks the otherwise-unused candidate type `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Creates an objective that scores candidates with `scorer`.
    ///
    /// The `Fn(&T) -> f64` bound is kept on this constructor so closures
    /// passed directly to `new` get their signature inferred.
    pub fn new(scorer: F) -> Self {
        Self {
            scorer,
            _phantom: std::marker::PhantomData,
        }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29 F: Fn(&T) -> f64 + Send + Sync,
30{
31 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32 (self.scorer)(candidate)
33 }
34
35 fn name(&self) -> &str {
36 "MaxScoreObjective"
37 }
38}
39
/// Objective that scores candidates with a caller-supplied function and only
/// lets through scores that reach a fixed threshold.
///
/// A score passes when it is finite, is at least `threshold`, and is also at
/// least the context's `min_score` when one is set.
///
/// The `Fn` bound on `F` is placed on the `impl` blocks rather than on the
/// struct definition, so the type can be named in generic code without
/// restating the bound.
pub struct ThresholdObjective<T, F> {
    /// Maps a candidate to its score.
    scorer: F,
    /// Minimum score (inclusive) a candidate must reach to pass.
    threshold: f64,
    /// Marks the otherwise-unused candidate type `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Creates an objective scoring with `scorer` and passing scores that are
    /// `>= threshold`.
    ///
    /// NOTE(review): a NaN `threshold` makes every comparison false, so no
    /// candidate would ever pass.
    pub fn new(scorer: F, threshold: f64) -> Self {
        Self {
            scorer,
            threshold,
            _phantom: std::marker::PhantomData,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66 F: Fn(&T) -> f64 + Send + Sync,
67{
68 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69 (self.scorer)(candidate)
70 }
71
72 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73 if !score.is_finite() {
74 return false;
75 }
76 let passes_obj = score >= self.threshold;
77 let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78 passes_obj && passes_ctx
79 }
80
81 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82 let score = (self.scorer)(candidate);
83 self.passes_score(score, context)
84 }
85
86 fn name(&self) -> &str {
87 "ThresholdObjective"
88 }
89}
90
/// Objective that selects the first candidate satisfying a predicate.
///
/// Candidates are scored `1.0` when the predicate holds and `0.0` otherwise.
///
/// The `Fn` bound on `F` is placed on the `impl` blocks rather than on the
/// struct definition, so the type can be named in generic code without
/// restating the bound.
pub struct FirstMatchObjective<T, F> {
    /// Returns `true` for acceptable candidates.
    predicate: F,
    /// Marks the otherwise-unused candidate type `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Creates an objective that accepts candidates for which `predicate`
    /// returns `true`.
    pub fn new(predicate: F) -> Self {
        Self {
            predicate,
            _phantom: std::marker::PhantomData,
        }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115 F: Fn(&T) -> bool + Send + Sync,
116{
117 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118 if (self.predicate)(candidate) {
119 1.0
120 } else {
121 0.0
122 }
123 }
124
125 fn select<'a>(
126 &self,
127 candidates: &'a [T],
128 context: &ObjectiveContext,
129 ) -> ObjectiveResult<Selection<&'a T>> {
130 if candidates.is_empty() {
131 return Err(ObjectiveError::NoCandidates);
132 }
133
134 let limit = context
135 .max_candidates
136 .unwrap_or(candidates.len())
137 .min(candidates.len());
138
139 for (i, candidate) in candidates.iter().take(limit).enumerate() {
140 if (self.predicate)(candidate) {
141 return Ok(Selection::new(candidate, 1.0, i)
142 .with_considered(i + 1)
143 .with_passed(1));
144 }
145 }
146
147 Err(ObjectiveError::NoMatch(
148 "No candidate matched predicate".into(),
149 ))
150 }
151
152 fn name(&self) -> &str {
153 "FirstMatchObjective"
154 }
155}
156
/// Access to a timestamp for time-based objectives.
///
/// [`RecencyObjective`] (and, through it, [`RelevanceObjective`]) measures a
/// candidate's age as `context.as_of - timestamp()`, treating timestamps at or
/// after `as_of` as age zero.
pub trait HasTimestamp {
    /// Returns the instant this item is dated at, in UTC.
    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
}
162
/// Access to a numeric importance value for importance-based objectives.
///
/// Read by [`ImportanceObjective`] and blended with a recency score by
/// [`RelevanceObjective`]. Neither enforces a range. Because the recency
/// component lies in `[0, 1]`, importances on the same scale are presumably
/// expected. NOTE(review): confirm with implementors.
pub trait HasImportance {
    /// Returns this item's importance.
    fn importance(&self) -> f64;
}
168
/// Scores candidates by exponential decay of their age.
///
/// A candidate timestamped at or after `context.as_of` scores `1.0`. After
/// that, the score halves for every half-life of additional age.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RecencyObjective {
    /// Half-life of the decay in seconds; never below `MIN_HALF_LIFE`.
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Smallest half-life kept; shorter positive inputs are raised to this.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Creates an objective whose score halves every `half_life_seconds`.
    ///
    /// Positive half-lives shorter than one second are silently raised to one
    /// second.
    ///
    /// # Panics
    ///
    /// Panics if `half_life_seconds` is NaN, infinite, zero, or negative.
    pub fn new(half_life_seconds: f64) -> Self {
        assert!(
            half_life_seconds.is_finite() && half_life_seconds > 0.0,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );
        Self {
            half_life_seconds: half_life_seconds.max(Self::MIN_HALF_LIFE),
        }
    }

    /// Creates an objective with a half-life of `hours` hours.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`RecencyObjective::new`], applied
    /// to `hours * 3600.0`.
    pub fn hours(hours: f64) -> Self {
        Self::new(hours * 3600.0)
    }

    /// Creates an objective with a half-life of `days` days.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`RecencyObjective::new`], applied
    /// to `days * 86400.0`.
    pub fn days(days: f64) -> Self {
        Self::new(days * 86400.0)
    }
}
201
202impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
203 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
204 let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
205 0.5f64.powf(age_seconds / self.half_life_seconds)
206 }
207
208 fn name(&self) -> &str {
209 "RecencyObjective"
210 }
211}
212
/// Scores candidates by their own importance, zeroing those below a floor.
///
/// Construct with [`ImportanceObjective::new`] (floor `0.0`) and optionally
/// raise the floor with [`ImportanceObjective::with_min`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ImportanceObjective {
    /// Importances below this value score `0.0`.
    min_importance: f64,
}

impl ImportanceObjective {
    /// Creates an objective with a minimum importance of `0.0`.
    pub fn new() -> Self {
        Self {
            min_importance: 0.0,
        }
    }

    /// Returns this objective with the minimum importance set to `min`.
    ///
    /// This is a by-value builder on a `Copy` type. Discarding the result would
    /// silently leave the original unchanged, hence `#[must_use]`. A NaN `min`
    /// makes every candidate score `0.0`.
    #[must_use]
    pub fn with_min(mut self, min: f64) -> Self {
        self.min_importance = min;
        self
    }
}

impl Default for ImportanceObjective {
    /// Equivalent to [`ImportanceObjective::new`].
    fn default() -> Self {
        Self::new()
    }
}
238
239impl<T: HasImportance + Send + Sync> Objective<T> for ImportanceObjective {
240 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
241 let importance = candidate.importance();
242 if importance >= self.min_importance {
243 importance
244 } else {
245 0.0
246 }
247 }
248
249 fn name(&self) -> &str {
250 "ImportanceObjective"
251 }
252}
253
254pub struct RelevanceObjective {
256 recency_weight: f64,
257 importance_weight: f64,
258 recency: RecencyObjective,
259}
260
261impl RelevanceObjective {
262 pub fn new(recency_half_life: f64, recency_weight: f64, importance_weight: f64) -> Self {
267 assert!(
268 recency_weight.is_finite() && recency_weight >= 0.0,
269 "recency_weight must be finite and non-negative, got {recency_weight}"
270 );
271 assert!(
272 importance_weight.is_finite() && importance_weight >= 0.0,
273 "importance_weight must be finite and non-negative, got {importance_weight}"
274 );
275 Self {
276 recency_weight,
277 importance_weight,
278 recency: RecencyObjective::new(recency_half_life),
279 }
280 }
281
282 pub fn balanced(recency_half_life: f64) -> Self {
284 Self::new(recency_half_life, 0.5, 0.5)
285 }
286}
287
288impl<T: HasTimestamp + HasImportance + Send + Sync> Objective<T> for RelevanceObjective {
289 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
290 if let Some(v) = context
292 .extra
293 .get("relevance_score")
294 .and_then(|v| v.as_f64())
295 {
296 return v;
297 }
298
299 let recency_score = self.recency.score(candidate, context);
300 let importance_score = candidate.importance();
301
302 let total_weight = self.recency_weight + self.importance_weight;
303 if total_weight > 0.0 {
304 (self.recency_weight * recency_score + self.importance_weight * importance_score)
305 / total_weight
306 } else {
307 0.0
308 }
309 }
310
311 fn name(&self) -> &str {
312 "RelevanceObjective"
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        // Only 1 and 3 are inspected, and neither matches.
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(matches!(result, Err(ObjectiveError::NoMatch(_))));
    }

    /// A match that falls inside the `max_candidates` window is still found.
    #[test]
    fn test_first_match_within_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let context = ObjectiveContext::new().with_max_candidates(3);
        let selection = objective.select(&candidates, &context).unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
        assert_eq!(selection.passed, 1);
    }

    /// Empty input is reported as `NoCandidates`, not `NoMatch`.
    #[test]
    fn test_first_match_empty_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates: Vec<i32> = Vec::new();
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(matches!(result, Err(ObjectiveError::NoCandidates)));
    }

    /// Minimal item implementing both timestamp and importance traits.
    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        importance: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasImportance for TestItem {
        fn importance(&self) -> f64 {
            self.importance
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();
        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            importance: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    /// An item exactly one half-life older than `as_of` scores one half.
    #[test]
    fn test_recency_halves_after_one_half_life() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: context.as_of - chrono::Duration::hours(1),
            importance: 0.5,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.5).abs() < 1e-12);
    }

    /// Importances below the floor score zero; others pass through unchanged.
    #[test]
    fn test_importance_objective_applies_minimum() {
        let objective = ImportanceObjective::new().with_min(0.5);
        let context = ObjectiveContext::new();

        let low = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.3,
        };
        let high = TestItem {
            _value: 2,
            timestamp: chrono::Utc::now(),
            importance: 0.7,
        };

        assert_eq!(objective.score(&low, &context), 0.0);
        assert_eq!(objective.score(&high, &context), 0.7);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();

        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context =
            ObjectiveContext::new().with_extra(serde_json::json!({"relevance_score": 0.42}));

        let now = chrono::Utc::now();
        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.9,
        };

        let score = objective.score(&item, &context);
        // The context override wins over the computed blend.
        assert!((score - 0.42).abs() < 1e-9);
    }

    /// With both weights zero there is nothing to average, so the score is 0.
    #[test]
    fn test_relevance_zero_weights_scores_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);

        let item = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.9,
        };

        assert_eq!(objective.score(&item, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "importance_weight must be finite and non-negative")]
    fn test_relevance_nan_importance_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(matches!(result, Err(ObjectiveError::NoMatch(_))));
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}