1use crate::{Objective, ObjectiveContext, Selection};
4
/// Objective that ranks candidates purely by a caller-supplied scoring closure.
///
/// The closure's return value is used verbatim as the candidate's score.
pub struct MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Ties the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Wraps `scorer` as an objective.
    pub fn new(scorer: F) -> Self {
        let _phantom = Default::default();
        Self { scorer, _phantom }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29 F: Fn(&T) -> f64 + Send + Sync,
30{
31 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32 (self.scorer)(candidate)
33 }
34
35 fn name(&self) -> &str {
36 "MaxScoreObjective"
37 }
38}
39
/// Objective that scores candidates with a closure and only accepts scores at or above a
/// fixed threshold.
pub struct ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Minimum score (inclusive) a candidate must reach to pass.
    threshold: f64,
    /// Ties the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Creates an objective that accepts candidates whose score is `>= threshold`.
    ///
    /// Infinite thresholds are allowed: `f64::NEG_INFINITY` accepts every finite score and
    /// `f64::INFINITY` accepts none.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is NaN. Every `score >= NaN` comparison is false, so a NaN
    /// threshold would silently reject every candidate.
    pub fn new(scorer: F, threshold: f64) -> Self {
        assert!(!threshold.is_nan(), "threshold must not be NaN");
        Self {
            scorer,
            threshold,
            _phantom: std::marker::PhantomData,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66 F: Fn(&T) -> f64 + Send + Sync,
67{
68 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69 (self.scorer)(candidate)
70 }
71
72 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73 if !score.is_finite() {
74 return false;
75 }
76 let passes_obj = score >= self.threshold;
77 let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78 passes_obj && passes_ctx
79 }
80
81 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82 let score = (self.scorer)(candidate);
83 self.passes_score(score, context)
84 }
85
86 fn name(&self) -> &str {
87 "ThresholdObjective"
88 }
89}
90
/// Objective that selects the first candidate satisfying a boolean predicate.
pub struct FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Returns `true` for acceptable candidates.
    predicate: F,
    /// Ties the otherwise-unused candidate type `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Wraps `predicate` as an objective.
    pub fn new(predicate: F) -> Self {
        let _phantom = Default::default();
        Self { predicate, _phantom }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115 F: Fn(&T) -> bool + Send + Sync,
116{
117 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118 if (self.predicate)(candidate) {
119 1.0
120 } else {
121 0.0
122 }
123 }
124
125 fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126 if candidates.is_empty() {
127 return Vec::new();
128 }
129
130 let limit = context
131 .max_candidates
132 .unwrap_or(candidates.len())
133 .min(candidates.len());
134
135 for (i, candidate) in candidates.iter().take(limit).enumerate() {
136 if (self.predicate)(candidate) {
137 return vec![Selection::new(candidate, 1.0, i)
138 .with_considered(i + 1)
139 .with_passed(1)];
140 }
141 }
142
143 Vec::new()
144 }
145
146 fn name(&self) -> &str {
147 "FirstMatchObjective"
148 }
149}
150
151pub trait HasTimestamp {
153 fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
155}
156
/// Implemented by items that expose a numeric importance value.
///
/// [`ImportanceObjective`] and [`RelevanceObjective`] use this value directly as (part of)
/// a candidate's score; the scale is up to the implementor.
pub trait HasImportance {
    /// Raw importance of this item.
    fn importance(&self) -> f64;
}
162
/// Scores timestamped items by exponential decay of their age.
pub struct RecencyObjective {
    /// Age, in seconds, at which an item's recency score halves.
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Smallest half-life that is ever stored; shorter (but still positive) inputs are raised
    /// to this value rather than rejected.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Creates an objective with the given half-life in seconds.
    ///
    /// Positive values below `MIN_HALF_LIFE` are clamped up to it.
    ///
    /// # Panics
    ///
    /// Panics if `half_life_seconds` is zero, negative, NaN, or infinite.
    pub fn new(half_life_seconds: f64) -> Self {
        let acceptable = half_life_seconds > 0.0 && half_life_seconds.is_finite();
        assert!(
            acceptable,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );

        // NaN was ruled out above, so this comparison is equivalent to `f64::max`.
        let half_life_seconds = if half_life_seconds < Self::MIN_HALF_LIFE {
            Self::MIN_HALF_LIFE
        } else {
            half_life_seconds
        };
        Self { half_life_seconds }
    }

    /// Creates an objective whose half-life is expressed in hours.
    ///
    /// # Panics
    ///
    /// Same conditions as [`RecencyObjective::new`], applied to the converted value.
    pub fn hours(hours: f64) -> Self {
        const SECONDS_PER_HOUR: f64 = 3600.0;
        Self::new(SECONDS_PER_HOUR * hours)
    }

    /// Creates an objective whose half-life is expressed in days.
    ///
    /// # Panics
    ///
    /// Same conditions as [`RecencyObjective::new`], applied to the converted value.
    pub fn days(days: f64) -> Self {
        const SECONDS_PER_DAY: f64 = 86400.0;
        Self::new(SECONDS_PER_DAY * days)
    }
}
195
196impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
197 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
198 let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
199 0.5f64.powf(age_seconds / self.half_life_seconds)
200 }
201
202 fn name(&self) -> &str {
203 "RecencyObjective"
204 }
205}
206
/// Scores items by their [`HasImportance::importance`] value, zeroing anything below a floor.
pub struct ImportanceObjective {
    /// Importance values below this floor score `0.0`.
    min_importance: f64,
}

impl ImportanceObjective {
    /// Creates an objective with a floor of `0.0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_importance: 0.0,
        }
    }

    /// Replaces the importance floor with `min`.
    ///
    /// Infinite floors are allowed: `f64::NEG_INFINITY` disables filtering and
    /// `f64::INFINITY` zeroes every finite importance.
    ///
    /// # Panics
    ///
    /// Panics if `min` is NaN. Every `importance >= NaN` comparison is false, so a NaN floor
    /// would silently score every candidate `0.0`.
    #[must_use]
    pub fn with_min(mut self, min: f64) -> Self {
        assert!(!min.is_nan(), "min_importance must not be NaN");
        self.min_importance = min;
        self
    }
}

impl Default for ImportanceObjective {
    /// Equivalent to [`ImportanceObjective::new`].
    fn default() -> Self {
        Self::new()
    }
}
232
233impl<T: HasImportance + Send + Sync> Objective<T> for ImportanceObjective {
234 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
235 let importance = candidate.importance();
236 if importance >= self.min_importance {
237 importance
238 } else {
239 0.0
240 }
241 }
242
243 fn name(&self) -> &str {
244 "ImportanceObjective"
245 }
246}
247
248pub struct RelevanceObjective {
250 recency_weight: f64,
251 importance_weight: f64,
252 recency: RecencyObjective,
253}
254
255impl RelevanceObjective {
256 pub fn new(recency_half_life: f64, recency_weight: f64, importance_weight: f64) -> Self {
261 assert!(
262 recency_weight.is_finite() && recency_weight >= 0.0,
263 "recency_weight must be finite and non-negative, got {recency_weight}"
264 );
265 assert!(
266 importance_weight.is_finite() && importance_weight >= 0.0,
267 "importance_weight must be finite and non-negative, got {importance_weight}"
268 );
269 Self {
270 recency_weight,
271 importance_weight,
272 recency: RecencyObjective::new(recency_half_life),
273 }
274 }
275
276 pub fn balanced(recency_half_life: f64) -> Self {
278 Self::new(recency_half_life, 0.5, 0.5)
279 }
280}
281
282impl<T: HasTimestamp + HasImportance + Send + Sync> Objective<T> for RelevanceObjective {
283 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
284 if let Some(v) = context
286 .extra
287 .get("relevance_score")
288 .and_then(|v| v.as_f64())
289 {
290 return v;
291 }
292
293 let recency_score = self.recency.score(candidate, context);
294 let importance_score = candidate.importance();
295
296 let total_weight = self.recency_weight + self.importance_weight;
297 if total_weight > 0.0 {
298 (self.recency_weight * recency_score + self.importance_weight * importance_score)
299 / total_weight
300 } else {
301 0.0
302 }
303 }
304
305 fn name(&self) -> &str {
306 "RelevanceObjective"
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        // The first match (7) sits at index 2, outside a window of two candidates.
        let candidates = vec![1, 3, 7, 9, 2];
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(result.is_empty());
    }

    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        importance: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasImportance for TestItem {
        fn importance(&self) -> f64 {
            self.importance
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();
        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            importance: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_recency_one_half_life_scores_half() {
        let objective = RecencyObjective::hours(1.0);
        let context = ObjectiveContext::new();

        // Pin the age to exactly one half-life relative to the context's reference time, so
        // the result does not depend on wall-clock timing.
        let item = TestItem {
            _value: 1,
            timestamp: context.as_of - chrono::Duration::hours(1),
            importance: 0.5,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_importance_objective_applies_min() {
        let objective = ImportanceObjective::new().with_min(0.5);
        let context = ObjectiveContext::new();
        let now = chrono::Utc::now();

        let low = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.3,
        };
        let high = TestItem {
            _value: 2,
            timestamp: now,
            importance: 0.8,
        };

        assert_eq!(objective.score(&low, &context), 0.0);
        assert_eq!(objective.score(&high, &context), 0.8);
    }

    #[test]
    fn test_importance_objective_default_passes_through() {
        let objective = ImportanceObjective::default();
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.25,
        };

        assert_eq!(objective.score(&item, &context), 0.25);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context = ObjectiveContext::new();

        let now = chrono::Utc::now();

        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_zero_weights_scores_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let context = ObjectiveContext::new();

        let item = TestItem {
            _value: 1,
            timestamp: chrono::Utc::now(),
            importance: 0.9,
        };

        assert_eq!(objective.score(&item, &context), 0.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let context =
            ObjectiveContext::new().with_extra(serde_json::json!({"relevance_score": 0.42}));

        let now = chrono::Utc::now();
        let item = TestItem {
            _value: 1,
            timestamp: now,
            importance: 0.9,
        };

        let score = objective.score(&item, &context);
        // The override wins even though the computed blend would be roughly 0.95.
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "importance_weight must be finite and non-negative")]
    fn test_relevance_nan_importance_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}
506}