1use crate::{Objective, ObjectiveContext, Selection};
4
5pub struct MaxScoreObjective<T, F>
7where
8 F: Fn(&T) -> f64 + Send + Sync,
9{
10 scorer: F,
11 _phantom: std::marker::PhantomData<T>,
12}
13
14impl<T, F> MaxScoreObjective<T, F>
15where
16 F: Fn(&T) -> f64 + Send + Sync,
17{
18 pub fn new(scorer: F) -> Self {
20 Self {
21 scorer,
22 _phantom: std::marker::PhantomData,
23 }
24 }
25}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29 F: Fn(&T) -> f64 + Send + Sync,
30{
31 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32 (self.scorer)(candidate)
33 }
34
35 fn name(&self) -> &str {
36 "MaxScoreObjective"
37 }
38}
39
40pub struct ThresholdObjective<T, F>
42where
43 F: Fn(&T) -> f64 + Send + Sync,
44{
45 scorer: F,
46 threshold: f64,
47 _phantom: std::marker::PhantomData<T>,
48}
49
50impl<T, F> ThresholdObjective<T, F>
51where
52 F: Fn(&T) -> f64 + Send + Sync,
53{
54 pub fn new(scorer: F, threshold: f64) -> Self {
56 Self {
57 scorer,
58 threshold,
59 _phantom: std::marker::PhantomData,
60 }
61 }
62}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66 F: Fn(&T) -> f64 + Send + Sync,
67{
68 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69 (self.scorer)(candidate)
70 }
71
72 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73 if !score.is_finite() {
74 return false;
75 }
76 let passes_obj = score >= self.threshold;
77 let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78 passes_obj && passes_ctx
79 }
80
81 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82 let score = (self.scorer)(candidate);
83 self.passes_score(score, context)
84 }
85
86 fn name(&self) -> &str {
87 "ThresholdObjective"
88 }
89}
90
91pub struct FirstMatchObjective<T, F>
93where
94 F: Fn(&T) -> bool + Send + Sync,
95{
96 predicate: F,
97 _phantom: std::marker::PhantomData<T>,
98}
99
100impl<T, F> FirstMatchObjective<T, F>
101where
102 F: Fn(&T) -> bool + Send + Sync,
103{
104 pub fn new(predicate: F) -> Self {
106 Self {
107 predicate,
108 _phantom: std::marker::PhantomData,
109 }
110 }
111}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115 F: Fn(&T) -> bool + Send + Sync,
116{
117 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118 if (self.predicate)(candidate) {
119 1.0
120 } else {
121 0.0
122 }
123 }
124
125 fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126 if candidates.is_empty() {
127 return Vec::new();
128 }
129
130 let limit = context
131 .max_candidates
132 .unwrap_or(candidates.len())
133 .min(candidates.len());
134
135 for (i, candidate) in candidates.iter().take(limit).enumerate() {
136 if (self.predicate)(candidate) {
137 return vec![Selection::new(candidate, 1.0, i)
138 .with_considered(i + 1)
139 .with_passed(1)];
140 }
141 }
142
143 Vec::new()
144 }
145
146 fn name(&self) -> &str {
147 "FirstMatchObjective"
148 }
149}
150
/// Implemented by candidates that carry a point in time.
///
/// [`RecencyObjective`] requires this trait and measures a candidate's age as
/// the context's `as_of` instant minus this timestamp. [`RelevanceObjective`]
/// requires it as well, because it delegates to a `RecencyObjective`.
pub trait HasTimestamp {
    /// The instant associated with this candidate, as a UTC `DateTime`.
    fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
}
156
/// Implemented by candidates that expose a numeric salience value.
///
/// [`SalienceObjective`] compares this value against a floor.
/// [`RelevanceObjective`] blends it with a recency score. The trait itself
/// places no bounds on the value.
/// NOTE(review): the relevance blending reads as if salience is expected to
/// lie roughly in `0.0..=1.0` (the same scale as recency); confirm with
/// implementors.
pub trait HasSalience {
    /// This candidate's salience.
    fn salience(&self) -> f64;
}
162
163pub struct RecencyObjective {
165 half_life_seconds: f64,
166}
167
168impl RecencyObjective {
169 const MIN_HALF_LIFE: f64 = 1.0;
170
171 pub fn new(half_life_seconds: f64) -> Self {
176 assert!(
177 half_life_seconds.is_finite() && half_life_seconds > 0.0,
178 "half_life_seconds must be positive and finite, got {half_life_seconds}"
179 );
180 Self {
181 half_life_seconds: half_life_seconds.max(Self::MIN_HALF_LIFE),
182 }
183 }
184
185 pub fn hours(hours: f64) -> Self {
187 Self::new(hours * 3600.0)
188 }
189
190 pub fn days(days: f64) -> Self {
192 Self::new(days * 86400.0)
193 }
194}
195
196impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
197 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
198 let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
199 0.5f64.powf(age_seconds / self.half_life_seconds)
200 }
201
202 fn name(&self) -> &str {
203 "RecencyObjective"
204 }
205}
206
207pub struct SalienceObjective {
209 min_salience: f64,
210}
211
212impl SalienceObjective {
213 pub fn new() -> Self {
215 Self { min_salience: 0.0 }
216 }
217
218 pub fn with_min(mut self, min: f64) -> Self {
220 self.min_salience = min;
221 self
222 }
223}
224
225impl Default for SalienceObjective {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231impl<T: HasSalience + Send + Sync> Objective<T> for SalienceObjective {
232 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
233 let salience = candidate.salience();
234 if salience >= self.min_salience {
235 salience
236 } else {
237 0.0
238 }
239 }
240
241 fn name(&self) -> &str {
242 "SalienceObjective"
243 }
244}
245
246pub struct RelevanceObjective {
248 recency_weight: f64,
249 salience_weight: f64,
250 recency: RecencyObjective,
251}
252
253impl RelevanceObjective {
254 pub fn new(recency_half_life: f64, recency_weight: f64, salience_weight: f64) -> Self {
259 assert!(
260 recency_weight.is_finite() && recency_weight >= 0.0,
261 "recency_weight must be finite and non-negative, got {recency_weight}"
262 );
263 assert!(
264 salience_weight.is_finite() && salience_weight >= 0.0,
265 "salience_weight must be finite and non-negative, got {salience_weight}"
266 );
267 Self {
268 recency_weight,
269 salience_weight,
270 recency: RecencyObjective::new(recency_half_life),
271 }
272 }
273
274 pub fn balanced(recency_half_life: f64) -> Self {
276 Self::new(recency_half_life, 0.5, 0.5)
277 }
278}
279
280impl<T: HasTimestamp + HasSalience + Send + Sync> Objective<T> for RelevanceObjective {
281 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
282 if let Some(v) = context
284 .extra
285 .get("relevance_score")
286 .and_then(|v| v.as_f64())
287 {
288 return v;
289 }
290
291 let recency_score = self.recency.score(candidate, context);
292 let salience_score = candidate.salience();
293
294 let total_weight = self.recency_weight + self.salience_weight;
295 if total_weight > 0.0 {
296 (self.recency_weight * recency_score + self.salience_weight * salience_score)
297 / total_weight
298 } else {
299 0.0
300 }
301 }
302
303 fn name(&self) -> &str {
304 "RelevanceObjective"
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_respects_context_min_score() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);
        let mut context = ObjectiveContext::new();
        context.min_score = Some(8.0);

        // Both the objective's own threshold and the context floor must hold.
        assert!(!objective.passes(&6, &context));
        assert!(objective.passes(&8, &context));

        let mut open = ObjectiveContext::new();
        open.min_score = None;
        assert!(!objective.passes_score(f64::NAN, &open));
        assert!(objective.passes_score(5.0, &open));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        // The first match sits at index 2, outside a two-candidate window.
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(result.is_empty());
    }

    #[test]
    fn test_first_match_empty_and_no_match() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let empty: Vec<i32> = Vec::new();
        assert!(objective.select(&empty, &ObjectiveContext::new()).is_empty());

        let no_match = vec![1, 2, 3];
        assert!(objective.select(&no_match, &ObjectiveContext::new()).is_empty());
    }

    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        salience: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasSalience for TestItem {
        fn salience(&self) -> f64 {
            self.salience
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            salience: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_recency_one_half_life_scores_half() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now - chrono::Duration::hours(1),
            salience: 0.5,
        };

        assert!((objective.score(&item, &context) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_recency_future_timestamp_scores_one() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        // Items stamped after `as_of` are treated as age zero.
        let item = TestItem {
            _value: 1,
            timestamp: now + chrono::Duration::hours(1),
            salience: 0.5,
        };

        assert_eq!(objective.score(&item, &context), 1.0);
    }

    #[test]
    fn test_recency_hours_and_days_agree() {
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = TestItem {
            _value: 1,
            timestamp: now - chrono::Duration::hours(6),
            salience: 0.5,
        };

        let by_hours = RecencyObjective::hours(24.0).score(&item, &context);
        let by_days = RecencyObjective::days(1.0).score(&item, &context);

        assert!((by_hours - by_days).abs() < 1e-12);
        assert!((by_days - 0.5f64.powf(0.25)).abs() < 1e-12);
    }

    #[test]
    fn test_recency_short_half_life_is_clamped_to_one_second() {
        // 0.25s is accepted by the constructor but raised to the 1s minimum.
        let objective = RecencyObjective::new(0.25);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = TestItem {
            _value: 1,
            timestamp: now - chrono::Duration::seconds(1),
            salience: 0.5,
        };

        assert!((objective.score(&item, &context) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_salience_objective_applies_min() {
        let objective = SalienceObjective::new().with_min(0.5);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let low = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.3,
        };
        let high = TestItem {
            _value: 2,
            timestamp: now,
            salience: 0.7,
        };

        assert_eq!(objective.score(&low, &context), 0.0);
        assert_eq!(objective.score(&high, &context), 0.7);
        assert_eq!(SalienceObjective::default().score(&low, &context), 0.3);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        let context =
            ObjectiveContext::at(now).with_extra(serde_json::json!({"relevance_score": 0.42}));

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.9,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    fn test_relevance_single_weight_selects_component() {
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.3,
        };

        let recency_only = RelevanceObjective::new(3600.0, 1.0, 0.0);
        let salience_only = RelevanceObjective::new(3600.0, 0.0, 2.0);

        assert!((recency_only.score(&item, &context) - 1.0).abs() < 1e-12);
        assert!((salience_only.score(&item, &context) - 0.3).abs() < 1e-12);
    }

    #[test]
    fn test_relevance_zero_weights_score_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.9,
        };

        assert_eq!(objective.score(&item, &context), 0.0);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "salience_weight must be finite and non-negative")]
    fn test_relevance_nan_salience_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}
506}