1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
/// Largest `n` for which top-`n` selection keeps its results in a sorted `Vec`
/// (binary-search insertion). Above this threshold a `BinaryHeap` holding the
/// current worst retained entry at its top is used instead.
const SMALL_TOP_N: usize = 96;
16
17#[derive(Debug, Clone, Copy)]
18struct RankedIndex {
19 score: f64,
20 det_score: DeterministicScore,
21 index: usize,
22}
23
24impl RankedIndex {
25 #[inline]
26 fn new(score: f64, index: usize) -> Self {
27 Self {
28 score,
29 det_score: DeterministicScore::from_f64(score),
30 index,
31 }
32 }
33}
34
35impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38 #[inline]
39 fn eq(&self, other: &Self) -> bool {
40 self.det_score == other.det_score && self.index == other.index
41 }
42}
43
44impl Ord for RankedIndex {
45 #[inline]
46 fn cmp(&self, other: &Self) -> Ordering {
47 self.det_score
48 .cmp(&other.det_score)
49 .then_with(|| other.index.cmp(&self.index))
50 }
51}
52
53impl PartialOrd for RankedIndex {
54 #[inline]
55 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56 Some(self.cmp(other))
57 }
58}
59
60#[derive(Debug, Clone, Copy)]
61struct WorstRankedIndex(RankedIndex);
62
63impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66 #[inline]
67 fn eq(&self, other: &Self) -> bool {
68 self.0 == other.0
69 }
70}
71
72impl Ord for WorstRankedIndex {
73 #[inline]
74 fn cmp(&self, other: &Self) -> Ordering {
75 other.0.cmp(&self.0)
76 }
77}
78
79impl PartialOrd for WorstRankedIndex {
80 #[inline]
81 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82 Some(self.cmp(other))
83 }
84}
85
86#[derive(Debug, Clone, Copy)]
87struct WorstScoredEntry<T>(ScoredEntry<T>);
88
89impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92 #[inline]
93 fn eq(&self, other: &Self) -> bool {
94 self.0 == other.0
95 }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99 #[inline]
100 fn cmp(&self, other: &Self) -> Ordering {
101 other.0.cmp(&self.0)
102 }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106 #[inline]
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114 context.max_candidates.unwrap_or(len).min(len)
115}
116
117pub trait Objective<T>: Send + Sync {
123 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
125
126 #[inline]
137 fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
138 1.0
139 }
140
141 #[inline]
145 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
146 score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
147 }
148
149 #[inline]
151 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
152 let score = self.score(candidate, context);
153 self.passes_score(score, context)
154 }
155
156 fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
161 let mut scored = Vec::with_capacity(candidates.len().min(256));
162 for (index, candidate) in candidates.iter().enumerate() {
163 let score = self.score(candidate, context);
164 if self.passes_score(score, context) {
165 scored.push((index, score));
166 }
167 }
168 scored
169 }
170
171 fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
177 if candidates.is_empty() {
178 return Vec::new();
179 }
180 let n = considered_limit(candidates.len(), context);
181 self.select_top(candidates, n, context)
182 }
183
184 fn select_top<'a>(
189 &self,
190 candidates: &'a [T],
191 n: usize,
192 context: &ObjectiveContext,
193 ) -> Vec<Selection<&'a T>> {
194 if n == 0 || candidates.is_empty() {
195 return Vec::new();
196 }
197
198 let considered_limit = considered_limit(candidates.len(), context);
199
200 let mut considered = 0usize;
201 let mut passed = 0usize;
202
203 if n <= SMALL_TOP_N {
204 let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
205
206 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
207 considered += 1;
208
209 let score = self.score(candidate, context);
210 if !self.passes_score(score, context) {
211 continue;
212 }
213
214 passed += 1;
215 let precision = self.precision(candidate, context);
216 let effective = score
217 * if precision.is_finite() {
218 precision
219 } else {
220 1.0
221 };
222 let entry = RankedIndex::new(effective, index);
223
224 if top.len() == n {
225 let worst = *top.last().expect("non-empty top when len == n");
226 if entry <= worst {
227 continue;
228 }
229 }
230
231 let pos = top.partition_point(|existing| *existing >= entry);
232 if pos < n {
233 top.insert(pos, entry);
234 if top.len() > n {
235 top.pop();
236 }
237 }
238 }
239
240 return top
241 .into_iter()
242 .map(|entry| {
243 Selection::new(&candidates[entry.index], entry.score, entry.index)
244 .with_considered(considered)
245 .with_passed(passed)
246 })
247 .collect();
248 }
249
250 let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
251
252 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
253 considered += 1;
254
255 let score = self.score(candidate, context);
256 if !self.passes_score(score, context) {
257 continue;
258 }
259
260 passed += 1;
261 let precision = self.precision(candidate, context);
262 let effective = score
263 * if precision.is_finite() {
264 precision
265 } else {
266 1.0
267 };
268 let entry = RankedIndex::new(effective, index);
269
270 if heap.len() < n {
271 heap.push(WorstRankedIndex(entry));
272 continue;
273 }
274
275 if let Some(mut worst) = heap.peek_mut() {
276 if entry > worst.0 {
277 *worst = WorstRankedIndex(entry);
278 }
279 }
280 }
281
282 let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
283 scored.sort_unstable_by(|a, b| b.cmp(a));
284
285 scored
286 .into_iter()
287 .map(|entry| {
288 Selection::new(&candidates[entry.index], entry.score, entry.index)
289 .with_considered(considered)
290 .with_passed(passed)
291 })
292 .collect()
293 }
294
295 fn name(&self) -> &str {
297 std::any::type_name::<Self>()
298 }
299}
300
301impl<T, F> Objective<T> for F
303where
304 F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
305{
306 #[inline]
307 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
308 self(candidate, context)
309 }
310}
311
312pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
314where
315 F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
316{
317 f
318}
319
320pub trait DeterministicObjective<T>: Objective<T>
333where
334 T: HasId,
335{
336 fn select_deterministic<'a>(
338 &self,
339 candidates: &'a [T],
340 context: &ObjectiveContext,
341 ) -> ObjectiveResult<Selection<&'a T>>;
342
343 fn select_top_deterministic<'a>(
345 &self,
346 candidates: &'a [T],
347 n: usize,
348 context: &ObjectiveContext,
349 ) -> Vec<Selection<&'a T>>;
350}
351
352impl<O, T> DeterministicObjective<T> for O
354where
355 O: Objective<T>,
356 T: HasId,
357{
358 fn select_deterministic<'a>(
359 &self,
360 candidates: &'a [T],
361 context: &ObjectiveContext,
362 ) -> ObjectiveResult<Selection<&'a T>> {
363 if candidates.is_empty() {
364 return Err(ObjectiveError::NoCandidates);
365 }
366
367 let considered_limit = considered_limit(candidates.len(), context);
368
369 let mut considered = 0usize;
370 let mut passed = 0usize;
371 let mut has_best = false;
372 let mut best_index = 0usize;
373 let mut best_score = 0.0f64;
374 let mut best_precision = 1.0f64;
375 let mut best_det = DeterministicScore::ZERO;
376 let mut best_id = Uuid::nil();
377
378 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
379 considered += 1;
380
381 let score = self.score(candidate, context);
382 if !self.passes_score(score, context) {
383 continue;
384 }
385
386 passed += 1;
387 let id = candidate.id();
388 let precision = self.precision(candidate, context);
389 let effective = score
390 * if precision.is_finite() {
391 precision
392 } else {
393 1.0
394 };
395 let det = DeterministicScore::from_f64(effective);
396
397 if !has_best || det > best_det || (det == best_det && id < best_id) {
398 has_best = true;
399 best_index = index;
400 best_score = score;
401 best_precision = precision;
402 best_det = det;
403 best_id = id;
404 }
405 }
406
407 if has_best {
408 Ok(
409 Selection::new(&candidates[best_index], best_score, best_index)
410 .with_precision(best_precision)
411 .with_considered(considered)
412 .with_passed(passed),
413 )
414 } else {
415 Err(ObjectiveError::NoMatch("No candidate passed".into()))
416 }
417 }
418
419 fn select_top_deterministic<'a>(
420 &self,
421 candidates: &'a [T],
422 n: usize,
423 context: &ObjectiveContext,
424 ) -> Vec<Selection<&'a T>> {
425 if n == 0 || candidates.is_empty() {
426 return Vec::new();
427 }
428
429 if n == 1 {
430 return self
431 .select_deterministic(candidates, context)
432 .ok()
433 .into_iter()
434 .collect();
435 }
436
437 let considered_limit = considered_limit(candidates.len(), context);
438
439 let mut considered = 0usize;
440 let mut passed = 0usize;
441
442 if n <= SMALL_TOP_N {
443 let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
444
445 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
446 considered += 1;
447
448 let score = self.score(candidate, context);
449 if !self.passes_score(score, context) {
450 continue;
451 }
452
453 passed += 1;
454 let precision = self.precision(candidate, context);
455 let effective = score
456 * if precision.is_finite() {
457 precision
458 } else {
459 1.0
460 };
461 let entry = ScoredEntry::new(candidate, effective, index);
462
463 if top.len() == n {
464 let worst = *top.last().expect("non-empty top when len == n");
465 if entry <= worst {
466 continue;
467 }
468 }
469
470 let pos = top.partition_point(|existing| *existing >= entry);
471 if pos < n {
472 top.insert(pos, entry);
473 if top.len() > n {
474 top.pop();
475 }
476 }
477 }
478
479 return top
480 .into_iter()
481 .map(|entry| {
482 Selection::new(entry.into_candidate(), entry.score(), entry.index())
483 .with_considered(considered)
484 .with_passed(passed)
485 })
486 .collect();
487 }
488
489 let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
490
491 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
492 considered += 1;
493
494 let score = self.score(candidate, context);
495 if !self.passes_score(score, context) {
496 continue;
497 }
498
499 passed += 1;
500 let precision = self.precision(candidate, context);
501 let effective = score
502 * if precision.is_finite() {
503 precision
504 } else {
505 1.0
506 };
507 let entry = ScoredEntry::new(candidate, effective, index);
508
509 if heap.len() < n {
510 heap.push(WorstScoredEntry(entry));
511 continue;
512 }
513
514 if let Some(mut worst) = heap.peek_mut() {
515 if entry > worst.0 {
516 *worst = WorstScoredEntry(entry);
517 }
518 }
519 }
520
521 let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
522 scored.sort_unstable_by(|a, b| b.cmp(a));
523
524 scored
525 .into_iter()
526 .map(|entry| {
527 Selection::new(entry.into_candidate(), entry.score(), entry.index())
528 .with_considered(considered)
529 .with_passed(passed)
530 })
531 .collect()
532 }
533}