1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
15const SMALL_TOP_N: usize = 96;
16
17#[derive(Debug, Clone, Copy)]
18struct RankedIndex {
19 score: f64,
20 det_score: DeterministicScore,
21 index: usize,
22}
23
24impl RankedIndex {
25 #[inline]
26 fn new(score: f64, index: usize) -> Self {
27 Self {
28 score,
29 det_score: DeterministicScore::from_f64(score),
30 index,
31 }
32 }
33}
34
35impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38 #[inline]
39 fn eq(&self, other: &Self) -> bool {
40 self.det_score == other.det_score && self.index == other.index
41 }
42}
43
44impl Ord for RankedIndex {
45 #[inline]
46 fn cmp(&self, other: &Self) -> Ordering {
47 self.det_score
48 .cmp(&other.det_score)
49 .then_with(|| other.index.cmp(&self.index))
50 }
51}
52
53impl PartialOrd for RankedIndex {
54 #[inline]
55 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56 Some(self.cmp(other))
57 }
58}
59
60#[derive(Debug, Clone, Copy)]
61struct WorstRankedIndex(RankedIndex);
62
63impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66 #[inline]
67 fn eq(&self, other: &Self) -> bool {
68 self.0 == other.0
69 }
70}
71
72impl Ord for WorstRankedIndex {
73 #[inline]
74 fn cmp(&self, other: &Self) -> Ordering {
75 other.0.cmp(&self.0)
76 }
77}
78
79impl PartialOrd for WorstRankedIndex {
80 #[inline]
81 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82 Some(self.cmp(other))
83 }
84}
85
86#[derive(Debug, Clone, Copy)]
87struct WorstScoredEntry<T>(ScoredEntry<T>);
88
89impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92 #[inline]
93 fn eq(&self, other: &Self) -> bool {
94 self.0 == other.0
95 }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99 #[inline]
100 fn cmp(&self, other: &Self) -> Ordering {
101 other.0.cmp(&self.0)
102 }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106 #[inline]
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114 context.max_candidates.unwrap_or(len).min(len)
115}
116
117pub trait Objective<T>: Send + Sync {
123 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
125
126 #[inline]
137 fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
138 1.0
139 }
140
141 #[inline]
145 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
146 score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
147 }
148
149 #[inline]
151 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
152 let score = self.score(candidate, context);
153 self.passes_score(score, context)
154 }
155
156 fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
161 let mut scored = Vec::with_capacity(candidates.len().min(256));
162 for (index, candidate) in candidates.iter().enumerate() {
163 let score = self.score(candidate, context);
164 if self.passes_score(score, context) {
165 scored.push((index, score));
166 }
167 }
168 scored
169 }
170
171 fn select<'a>(
177 &self,
178 candidates: &'a [T],
179 context: &ObjectiveContext,
180 ) -> ObjectiveResult<Selection<&'a T>> {
181 if candidates.is_empty() {
182 return Err(ObjectiveError::NoCandidates);
183 }
184
185 let considered_limit = considered_limit(candidates.len(), context);
186
187 let mut considered = 0usize;
188 let mut passed = 0usize;
189 let mut has_best = false;
190 let mut best_index = 0usize;
191 let mut best_score = 0.0f64;
192 let mut best_precision = 1.0f64;
193 let mut best_det = DeterministicScore::ZERO;
194
195 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
196 considered += 1;
197
198 let score = self.score(candidate, context);
199 if !self.passes_score(score, context) {
200 continue;
201 }
202
203 passed += 1;
204
205 let precision = self.precision(candidate, context);
206 let effective = score
207 * if precision.is_finite() {
208 precision
209 } else {
210 1.0
211 };
212 let det = DeterministicScore::from_f64(effective);
213 if !has_best || det > best_det {
214 has_best = true;
215 best_index = index;
216 best_score = score;
217 best_precision = precision;
218 best_det = det;
219 }
220 }
221
222 if has_best {
223 Ok(
224 Selection::new(&candidates[best_index], best_score, best_index)
225 .with_precision(best_precision)
226 .with_considered(considered)
227 .with_passed(passed),
228 )
229 } else {
230 Err(ObjectiveError::NoMatch("No candidate passed".into()))
231 }
232 }
233
234 fn select_top<'a>(
239 &self,
240 candidates: &'a [T],
241 n: usize,
242 context: &ObjectiveContext,
243 ) -> Vec<Selection<&'a T>> {
244 if n == 0 || candidates.is_empty() {
245 return Vec::new();
246 }
247
248 if n == 1 {
249 return self.select(candidates, context).ok().into_iter().collect();
250 }
251
252 let considered_limit = considered_limit(candidates.len(), context);
253
254 let mut considered = 0usize;
255 let mut passed = 0usize;
256
257 if n <= SMALL_TOP_N {
258 let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
259
260 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
261 considered += 1;
262
263 let score = self.score(candidate, context);
264 if !self.passes_score(score, context) {
265 continue;
266 }
267
268 passed += 1;
269 let precision = self.precision(candidate, context);
270 let effective = score
271 * if precision.is_finite() {
272 precision
273 } else {
274 1.0
275 };
276 let entry = RankedIndex::new(effective, index);
277
278 if top.len() == n {
279 let worst = *top.last().expect("non-empty top when len == n");
280 if entry <= worst {
281 continue;
282 }
283 }
284
285 let pos = top.partition_point(|existing| *existing >= entry);
286 if pos < n {
287 top.insert(pos, entry);
288 if top.len() > n {
289 top.pop();
290 }
291 }
292 }
293
294 return top
295 .into_iter()
296 .map(|entry| {
297 Selection::new(&candidates[entry.index], entry.score, entry.index)
298 .with_considered(considered)
299 .with_passed(passed)
300 })
301 .collect();
302 }
303
304 let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
305
306 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
307 considered += 1;
308
309 let score = self.score(candidate, context);
310 if !self.passes_score(score, context) {
311 continue;
312 }
313
314 passed += 1;
315 let precision = self.precision(candidate, context);
316 let effective = score
317 * if precision.is_finite() {
318 precision
319 } else {
320 1.0
321 };
322 let entry = RankedIndex::new(effective, index);
323
324 if heap.len() < n {
325 heap.push(WorstRankedIndex(entry));
326 continue;
327 }
328
329 if let Some(mut worst) = heap.peek_mut() {
330 if entry > worst.0 {
331 *worst = WorstRankedIndex(entry);
332 }
333 }
334 }
335
336 let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
337 scored.sort_unstable_by(|a, b| b.cmp(a));
338
339 scored
340 .into_iter()
341 .map(|entry| {
342 Selection::new(&candidates[entry.index], entry.score, entry.index)
343 .with_considered(considered)
344 .with_passed(passed)
345 })
346 .collect()
347 }
348
349 fn name(&self) -> &str {
351 std::any::type_name::<Self>()
352 }
353}
354
355impl<T, F> Objective<T> for F
357where
358 F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
359{
360 #[inline]
361 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
362 self(candidate, context)
363 }
364}
365
366pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
368where
369 F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
370{
371 f
372}
373
374pub trait DeterministicObjective<T>: Objective<T>
387where
388 T: HasId,
389{
390 fn select_deterministic<'a>(
392 &self,
393 candidates: &'a [T],
394 context: &ObjectiveContext,
395 ) -> ObjectiveResult<Selection<&'a T>>;
396
397 fn select_top_deterministic<'a>(
399 &self,
400 candidates: &'a [T],
401 n: usize,
402 context: &ObjectiveContext,
403 ) -> Vec<Selection<&'a T>>;
404}
405
406impl<O, T> DeterministicObjective<T> for O
408where
409 O: Objective<T>,
410 T: HasId,
411{
412 fn select_deterministic<'a>(
413 &self,
414 candidates: &'a [T],
415 context: &ObjectiveContext,
416 ) -> ObjectiveResult<Selection<&'a T>> {
417 if candidates.is_empty() {
418 return Err(ObjectiveError::NoCandidates);
419 }
420
421 let considered_limit = considered_limit(candidates.len(), context);
422
423 let mut considered = 0usize;
424 let mut passed = 0usize;
425 let mut has_best = false;
426 let mut best_index = 0usize;
427 let mut best_score = 0.0f64;
428 let mut best_precision = 1.0f64;
429 let mut best_det = DeterministicScore::ZERO;
430 let mut best_id = Uuid::nil();
431
432 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
433 considered += 1;
434
435 let score = self.score(candidate, context);
436 if !self.passes_score(score, context) {
437 continue;
438 }
439
440 passed += 1;
441 let id = candidate.id();
442 let precision = self.precision(candidate, context);
443 let effective = score
444 * if precision.is_finite() {
445 precision
446 } else {
447 1.0
448 };
449 let det = DeterministicScore::from_f64(effective);
450
451 if !has_best || det > best_det || (det == best_det && id < best_id) {
452 has_best = true;
453 best_index = index;
454 best_score = score;
455 best_precision = precision;
456 best_det = det;
457 best_id = id;
458 }
459 }
460
461 if has_best {
462 Ok(
463 Selection::new(&candidates[best_index], best_score, best_index)
464 .with_precision(best_precision)
465 .with_considered(considered)
466 .with_passed(passed),
467 )
468 } else {
469 Err(ObjectiveError::NoMatch("No candidate passed".into()))
470 }
471 }
472
473 fn select_top_deterministic<'a>(
474 &self,
475 candidates: &'a [T],
476 n: usize,
477 context: &ObjectiveContext,
478 ) -> Vec<Selection<&'a T>> {
479 if n == 0 || candidates.is_empty() {
480 return Vec::new();
481 }
482
483 if n == 1 {
484 return self
485 .select_deterministic(candidates, context)
486 .ok()
487 .into_iter()
488 .collect();
489 }
490
491 let considered_limit = considered_limit(candidates.len(), context);
492
493 let mut considered = 0usize;
494 let mut passed = 0usize;
495
496 if n <= SMALL_TOP_N {
497 let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
498
499 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
500 considered += 1;
501
502 let score = self.score(candidate, context);
503 if !self.passes_score(score, context) {
504 continue;
505 }
506
507 passed += 1;
508 let precision = self.precision(candidate, context);
509 let effective = score
510 * if precision.is_finite() {
511 precision
512 } else {
513 1.0
514 };
515 let entry = ScoredEntry::new(candidate, effective, index);
516
517 if top.len() == n {
518 let worst = *top.last().expect("non-empty top when len == n");
519 if entry <= worst {
520 continue;
521 }
522 }
523
524 let pos = top.partition_point(|existing| *existing >= entry);
525 if pos < n {
526 top.insert(pos, entry);
527 if top.len() > n {
528 top.pop();
529 }
530 }
531 }
532
533 return top
534 .into_iter()
535 .map(|entry| {
536 Selection::new(entry.into_candidate(), entry.score(), entry.index())
537 .with_considered(considered)
538 .with_passed(passed)
539 })
540 .collect();
541 }
542
543 let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
544
545 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
546 considered += 1;
547
548 let score = self.score(candidate, context);
549 if !self.passes_score(score, context) {
550 continue;
551 }
552
553 passed += 1;
554 let precision = self.precision(candidate, context);
555 let effective = score
556 * if precision.is_finite() {
557 precision
558 } else {
559 1.0
560 };
561 let entry = ScoredEntry::new(candidate, effective, index);
562
563 if heap.len() < n {
564 heap.push(WorstScoredEntry(entry));
565 continue;
566 }
567
568 if let Some(mut worst) = heap.peek_mut() {
569 if entry > worst.0 {
570 *worst = WorstScoredEntry(entry);
571 }
572 }
573 }
574
575 let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
576 scored.sort_unstable_by(|a, b| b.cmp(a));
577
578 scored
579 .into_iter()
580 .map(|entry| {
581 Selection::new(entry.into_candidate(), entry.score(), entry.index())
582 .with_considered(considered)
583 .with_passed(passed)
584 })
585 .collect()
586 }
587}