1use crate::{Objective, ObjectiveContext, Selection};
4
/// Objective that ranks candidates purely by a caller-supplied scoring closure.
///
/// The closure's output is used verbatim as the candidate's score; the
/// [`ObjectiveContext`] is not consulted.
pub struct MaxScoreObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Binds the candidate type `T` to this objective without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F: Fn(&T) -> f64 + Send + Sync> MaxScoreObjective<T, F> {
    /// Wraps `scorer` as an objective.
    pub fn new(scorer: F) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            scorer,
        }
    }
}
26
27impl<T: Send + Sync, F> Objective<T> for MaxScoreObjective<T, F>
28where
29 F: Fn(&T) -> f64 + Send + Sync,
30{
31 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
32 (self.scorer)(candidate)
33 }
34
35 fn name(&self) -> &str {
36 "MaxScoreObjective"
37 }
38}
39
/// Objective that scores candidates with a closure and only lets through
/// those whose score reaches a fixed threshold.
pub struct ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Maps a candidate to its score.
    scorer: F,
    /// Minimum score (inclusive) a candidate must reach to pass.
    threshold: f64,
    /// Binds the candidate type `T` to this objective without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> ThresholdObjective<T, F>
where
    F: Fn(&T) -> f64 + Send + Sync,
{
    /// Creates an objective that passes candidates scoring `>= threshold`.
    ///
    /// Infinite thresholds are allowed: `f64::NEG_INFINITY` admits every
    /// finite score, and `f64::INFINITY` admits none.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is NaN. Every comparison against NaN is false,
    /// so such an objective would silently reject every candidate.
    pub fn new(scorer: F, threshold: f64) -> Self {
        assert!(
            !threshold.is_nan(),
            "threshold must not be NaN, got {threshold}"
        );
        Self {
            scorer,
            threshold,
            _phantom: std::marker::PhantomData,
        }
    }
}
63
64impl<T: Send + Sync, F> Objective<T> for ThresholdObjective<T, F>
65where
66 F: Fn(&T) -> f64 + Send + Sync,
67{
68 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
69 (self.scorer)(candidate)
70 }
71
72 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
73 if !score.is_finite() {
74 return false;
75 }
76 let passes_obj = score >= self.threshold;
77 let passes_ctx = context.min_score.map(|min| score >= min).unwrap_or(true);
78 passes_obj && passes_ctx
79 }
80
81 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
82 let score = (self.scorer)(candidate);
83 self.passes_score(score, context)
84 }
85
86 fn name(&self) -> &str {
87 "ThresholdObjective"
88 }
89}
90
/// Objective that picks the first candidate, in input order, accepted by a
/// boolean predicate.
pub struct FirstMatchObjective<T, F>
where
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Decides whether a candidate is acceptable.
    predicate: F,
    /// Binds the candidate type `T` to this objective without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F: Fn(&T) -> bool + Send + Sync> FirstMatchObjective<T, F> {
    /// Wraps `predicate` as an objective.
    pub fn new(predicate: F) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            predicate,
        }
    }
}
112
113impl<T: Send + Sync, F> Objective<T> for FirstMatchObjective<T, F>
114where
115 F: Fn(&T) -> bool + Send + Sync,
116{
117 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
118 if (self.predicate)(candidate) {
119 1.0
120 } else {
121 0.0
122 }
123 }
124
125 fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
126 if candidates.is_empty() {
127 return Vec::new();
128 }
129
130 let limit = context
131 .max_candidates
132 .unwrap_or(candidates.len())
133 .min(candidates.len());
134
135 for (i, candidate) in candidates.iter().take(limit).enumerate() {
136 if (self.predicate)(candidate) {
137 return vec![Selection::new(candidate, 1.0, i)
138 .with_considered(i + 1)
139 .with_passed(1)];
140 }
141 }
142
143 Vec::new()
144 }
145
146 fn name(&self) -> &str {
147 "FirstMatchObjective"
148 }
149}
150
151pub trait HasTimestamp {
153 fn timestamp(&self) -> chrono::DateTime<chrono::Utc>;
155}
156
/// Implemented by candidates that expose a numeric importance value.
///
/// Required by [`SalienceObjective`] and [`RelevanceObjective`].
pub trait HasSalience {
    /// This candidate's salience; the objectives in this module use it as a
    /// raw score without rescaling.
    fn salience(&self) -> f64;
}
162
/// Objective that favours recent candidates via exponential decay.
///
/// A candidate whose age, relative to `ObjectiveContext::as_of`, equals one
/// half-life scores `0.5`. One with age zero (or a future timestamp) scores
/// `1.0`.
pub struct RecencyObjective {
    /// Decay half-life in seconds; never below [`Self::MIN_HALF_LIFE`].
    half_life_seconds: f64,
}

impl RecencyObjective {
    /// Smallest half-life, in seconds, that an objective will use.
    const MIN_HALF_LIFE: f64 = 1.0;

    /// Creates an objective with the given half-life in seconds.
    ///
    /// Positive values below one second are accepted but raised to
    /// [`Self::MIN_HALF_LIFE`].
    ///
    /// # Panics
    ///
    /// Panics if `half_life_seconds` is NaN, infinite, zero or negative.
    pub fn new(half_life_seconds: f64) -> Self {
        let valid = half_life_seconds.is_finite() && half_life_seconds > 0.0;
        assert!(
            valid,
            "half_life_seconds must be positive and finite, got {half_life_seconds}"
        );
        let half_life_seconds = half_life_seconds.max(Self::MIN_HALF_LIFE);
        Self { half_life_seconds }
    }

    /// Creates an objective whose half-life is given in hours.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Self::new`], applied to the converted value.
    pub fn hours(hours: f64) -> Self {
        const SECONDS_PER_HOUR: f64 = 3600.0;
        Self::new(hours * SECONDS_PER_HOUR)
    }

    /// Creates an objective whose half-life is given in days.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Self::new`], applied to the converted value.
    pub fn days(days: f64) -> Self {
        const SECONDS_PER_DAY: f64 = 86400.0;
        Self::new(days * SECONDS_PER_DAY)
    }
}
192
193impl<T: HasTimestamp + Send + Sync> Objective<T> for RecencyObjective {
194 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
195 let age_seconds = (context.as_of - candidate.timestamp()).num_seconds().max(0) as f64;
196 0.5f64.powf(age_seconds / self.half_life_seconds)
197 }
198
199 fn name(&self) -> &str {
200 "RecencyObjective"
201 }
202}
203
/// Objective that scores candidates by their own salience, zeroing out any
/// candidate below a configurable floor.
pub struct SalienceObjective {
    /// Salience values below this floor score `0.0`.
    min_salience: f64,
}

impl SalienceObjective {
    /// Creates an objective with a floor of `0.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the floor below which salience is scored as `0.0`.
    pub fn with_min(mut self, min: f64) -> Self {
        self.min_salience = min;
        self
    }
}

impl Default for SalienceObjective {
    /// Equivalent to [`SalienceObjective::new`]: a floor of `0.0`.
    fn default() -> Self {
        Self { min_salience: 0.0 }
    }
}
227
228impl<T: HasSalience + Send + Sync> Objective<T> for SalienceObjective {
229 fn score(&self, candidate: &T, _context: &ObjectiveContext) -> f64 {
230 let salience = candidate.salience();
231 if salience >= self.min_salience {
232 salience
233 } else {
234 0.0
235 }
236 }
237
238 fn name(&self) -> &str {
239 "SalienceObjective"
240 }
241}
242
243pub struct RelevanceObjective {
245 recency_weight: f64,
246 salience_weight: f64,
247 recency: RecencyObjective,
248}
249
250impl RelevanceObjective {
251 pub fn new(recency_half_life: f64, recency_weight: f64, salience_weight: f64) -> Self {
253 assert!(
254 recency_weight.is_finite() && recency_weight >= 0.0,
255 "recency_weight must be finite and non-negative, got {recency_weight}"
256 );
257 assert!(
258 salience_weight.is_finite() && salience_weight >= 0.0,
259 "salience_weight must be finite and non-negative, got {salience_weight}"
260 );
261 Self {
262 recency_weight,
263 salience_weight,
264 recency: RecencyObjective::new(recency_half_life),
265 }
266 }
267
268 pub fn balanced(recency_half_life: f64) -> Self {
270 Self::new(recency_half_life, 0.5, 0.5)
271 }
272}
273
274impl<T: HasTimestamp + HasSalience + Send + Sync> Objective<T> for RelevanceObjective {
275 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
276 if let Some(v) = context
278 .extra
279 .get("relevance_score")
280 .and_then(|v| v.as_f64())
281 {
282 return v;
283 }
284
285 let recency_score = self.recency.score(candidate, context);
286 let salience_score = candidate.salience();
287
288 let total_weight = self.recency_weight + self.salience_weight;
289 if total_weight > 0.0 {
290 (self.recency_weight * recency_score + self.salience_weight * salience_score)
291 / total_weight
292 } else {
293 0.0
294 }
295 }
296
297 fn name(&self) -> &str {
298 "RelevanceObjective"
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_score_objective() {
        let objective = MaxScoreObjective::new(|n: &i32| *n as f64);

        let candidates = vec![1, 5, 3, 8, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 8);
    }

    #[test]
    fn test_threshold_objective() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&10, &ObjectiveContext::new()));
        assert!(!objective.passes(&3, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_accepts_score_equal_to_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        assert!(objective.passes(&5, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_infinite_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::INFINITY, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_threshold_objective_rejects_nan_scores() {
        let objective = ThresholdObjective::new(|_n: &i32| f64::NAN, 5.0);

        assert!(!objective.passes(&10, &ObjectiveContext::new()));
    }

    #[test]
    fn test_first_match_objective() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 7);
        assert_eq!(selection.index, 2);
    }

    #[test]
    fn test_first_match_respects_max_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);

        let candidates = vec![1, 3, 7, 9, 2];
        // The first match sits at index 2, outside a two-candidate window.
        let context = ObjectiveContext::new().with_max_candidates(2);
        let result = objective.select(&candidates, &context);

        assert!(result.is_empty());
    }

    #[test]
    fn test_first_match_empty_candidates() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);
        let candidates: Vec<i32> = Vec::new();

        assert!(objective
            .select(&candidates, &ObjectiveContext::new())
            .is_empty());
    }

    #[test]
    fn test_first_match_no_match() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 100);
        let candidates = vec![1, 3, 7];

        assert!(objective
            .select(&candidates, &ObjectiveContext::new())
            .is_empty());
    }

    #[test]
    fn test_first_match_score_is_binary() {
        let objective = FirstMatchObjective::new(|n: &i32| *n > 5);
        let context = ObjectiveContext::new();

        assert_eq!(objective.score(&7, &context), 1.0);
        assert_eq!(objective.score(&2, &context), 0.0);
    }

    #[derive(Clone)]
    struct TestItem {
        _value: i32,
        timestamp: chrono::DateTime<chrono::Utc>,
        salience: f64,
    }

    impl HasTimestamp for TestItem {
        fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
            self.timestamp
        }
    }

    impl HasSalience for TestItem {
        fn salience(&self) -> f64 {
            self.salience
        }
    }

    /// Builds a [`TestItem`] with the given timestamp and salience.
    fn item_at(timestamp: chrono::DateTime<chrono::Utc>, salience: f64) -> TestItem {
        TestItem {
            _value: 0,
            timestamp,
            salience,
        }
    }

    #[test]
    fn test_recency_objective() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let old = now - chrono::Duration::hours(2);

        let new_item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.5,
        };
        let old_item = TestItem {
            _value: 2,
            timestamp: old,
            salience: 0.5,
        };

        let new_score = objective.score(&new_item, &context);
        let old_score = objective.score(&old_item, &context);

        assert!(new_score > old_score);
        assert!((new_score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_recency_one_half_life_halves_score() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = item_at(now - chrono::Duration::hours(1), 0.5);

        assert!((objective.score(&item, &context) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_recency_future_timestamp_scores_one() {
        let objective = RecencyObjective::hours(1.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = item_at(now + chrono::Duration::hours(1), 0.5);

        assert_eq!(objective.score(&item, &context), 1.0);
    }

    #[test]
    fn test_recency_sub_second_half_life_is_clamped() {
        // 0.5 s is accepted but raised to the 1 s minimum half-life.
        let objective = RecencyObjective::new(0.5);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);
        let item = item_at(now - chrono::Duration::seconds(1), 0.5);

        assert!((objective.score(&item, &context) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_salience_objective_applies_minimum() {
        let objective = SalienceObjective::new().with_min(0.5);
        let context = ObjectiveContext::new();
        let now = chrono::Utc::now();

        assert_eq!(objective.score(&item_at(now, 0.3), &context), 0.0);
        assert!((objective.score(&item_at(now, 0.7), &context) - 0.7).abs() < 1e-12);
    }

    #[test]
    fn test_salience_objective_default_zeroes_negative_and_nan() {
        let objective = SalienceObjective::default();
        let context = ObjectiveContext::new();
        let now = chrono::Utc::now();

        assert_eq!(objective.score(&item_at(now, -0.2), &context), 0.0);
        assert_eq!(objective.score(&item_at(now, f64::NAN), &context), 0.0);
    }

    #[test]
    fn test_relevance_objective() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.8,
        };

        let score = objective.score(&item, &context);

        assert!(score > 0.8 && score < 1.0);
    }

    #[test]
    fn test_relevance_zero_weights_scores_zero() {
        let objective = RelevanceObjective::new(3600.0, 0.0, 0.0);
        let now = chrono::Utc::now();
        let context = ObjectiveContext::at(now);

        assert_eq!(objective.score(&item_at(now, 0.9), &context), 0.0);
    }

    #[test]
    fn test_relevance_uses_context_relevance_score() {
        let objective = RelevanceObjective::balanced(3600.0);
        let now = chrono::Utc::now();
        let context =
            ObjectiveContext::at(now).with_extra(serde_json::json!({"relevance_score": 0.42}));

        let item = TestItem {
            _value: 1,
            timestamp: now,
            salience: 0.9,
        };

        let score = objective.score(&item, &context);
        assert!((score - 0.42).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "recency_weight must be finite and non-negative")]
    fn test_relevance_negative_recency_weight_panics() {
        RelevanceObjective::new(3600.0, -0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "salience_weight must be finite and non-negative")]
    fn test_relevance_nan_salience_weight_panics() {
        RelevanceObjective::new(3600.0, 0.5, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_zero_half_life_panics() {
        RecencyObjective::new(0.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_negative_half_life_panics() {
        RecencyObjective::new(-1.0);
    }

    #[test]
    #[should_panic(expected = "half_life_seconds must be positive and finite")]
    fn test_recency_nan_half_life_panics() {
        RecencyObjective::new(f64::NAN);
    }

    #[test]
    fn test_threshold_no_match_below_threshold() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 10.0);

        let candidates = vec![1, 5, 3];
        let result = objective.select(&candidates, &ObjectiveContext::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_threshold_selects_best_above() {
        let objective = ThresholdObjective::new(|n: &i32| *n as f64, 5.0);

        let candidates = vec![1, 10, 3, 15];
        let selection = objective
            .select(&candidates, &ObjectiveContext::new())
            .into_iter()
            .next()
            .unwrap();

        assert_eq!(*selection.item, 15);
        assert_eq!(selection.score, 15.0);
        assert_eq!(selection.passed, 2);
    }
}
500}