1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
/// Largest `n` for which the top-n selection routines in this file keep a sorted
/// `Vec` (binary-search insertion) instead of switching to a `BinaryHeap`.
const SMALL_TOP_N: usize = 96;
16
/// A candidate position paired with the score used to rank it.
///
/// Ordering and equality (implemented below) look only at `det_score` and
/// `index`; `score` is carried along so it can be reported back to the caller.
#[derive(Debug, Clone, Copy)]
struct RankedIndex {
    /// The floating-point score this entry was built from.
    score: f64,
    /// `DeterministicScore::from_f64(score)`, used for comparisons.
    det_score: DeterministicScore,
    /// Position of the candidate in the input slice.
    index: usize,
}
23
24impl RankedIndex {
25 #[inline]
26 fn new(score: f64, index: usize) -> Self {
27 Self {
28 score,
29 det_score: DeterministicScore::from_f64(score),
30 index,
31 }
32 }
33}
34
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl.
impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38 #[inline]
39 fn eq(&self, other: &Self) -> bool {
40 self.det_score == other.det_score && self.index == other.index
41 }
42}
43
44impl Ord for RankedIndex {
45 #[inline]
46 fn cmp(&self, other: &Self) -> Ordering {
47 self.det_score
48 .cmp(&other.det_score)
49 .then_with(|| other.index.cmp(&self.index))
50 }
51}
52
53impl PartialOrd for RankedIndex {
54 #[inline]
55 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56 Some(self.cmp(other))
57 }
58}
59
/// Wrapper around [`RankedIndex`] whose `Ord` impl (below) is reversed, so that a
/// `BinaryHeap` of these keeps the lowest-ranked entry at the top.
#[derive(Debug, Clone, Copy)]
struct WorstRankedIndex(RankedIndex);
62
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl.
impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66 #[inline]
67 fn eq(&self, other: &Self) -> bool {
68 self.0 == other.0
69 }
70}
71
72impl Ord for WorstRankedIndex {
73 #[inline]
74 fn cmp(&self, other: &Self) -> Ordering {
75 other.0.cmp(&self.0)
76 }
77}
78
79impl PartialOrd for WorstRankedIndex {
80 #[inline]
81 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82 Some(self.cmp(other))
83 }
84}
85
/// Wrapper around [`ScoredEntry`] whose `Ord` impl (below) is reversed, so that a
/// `BinaryHeap` of these keeps the lowest-ranked entry at the top.
#[derive(Debug, Clone, Copy)]
struct WorstScoredEntry<T>(ScoredEntry<T>);
88
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl.
impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92 #[inline]
93 fn eq(&self, other: &Self) -> bool {
94 self.0 == other.0
95 }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99 #[inline]
100 fn cmp(&self, other: &Self) -> Ordering {
101 other.0.cmp(&self.0)
102 }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106 #[inline]
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114 context.max_candidates.unwrap_or(len).min(len)
115}
116
117pub trait Objective<T>: Send + Sync {
119 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
121
122 #[inline]
124 fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
125 1.0
126 }
127
128 #[inline]
130 fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
131 score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
132 }
133
134 #[inline]
136 fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
137 let score = self.score(candidate, context);
138 self.passes_score(score, context)
139 }
140
141 fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
143 let mut scored = Vec::with_capacity(candidates.len().min(256));
144 for (index, candidate) in candidates.iter().enumerate() {
145 let score = self.score(candidate, context);
146 if self.passes_score(score, context) {
147 scored.push((index, score));
148 }
149 }
150 scored
151 }
152
153 fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
155 if candidates.is_empty() {
156 return Vec::new();
157 }
158 let n = considered_limit(candidates.len(), context);
159 self.select_top(candidates, n, context)
160 }
161
162 fn select_top<'a>(
164 &self,
165 candidates: &'a [T],
166 n: usize,
167 context: &ObjectiveContext,
168 ) -> Vec<Selection<&'a T>> {
169 if n == 0 || candidates.is_empty() {
170 return Vec::new();
171 }
172
173 let considered_limit = considered_limit(candidates.len(), context);
174
175 let mut considered = 0usize;
176 let mut passed = 0usize;
177
178 if n <= SMALL_TOP_N {
179 let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
180
181 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
182 considered += 1;
183
184 let score = self.score(candidate, context);
185 if !self.passes_score(score, context) {
186 continue;
187 }
188
189 passed += 1;
190 let precision = self.precision(candidate, context);
191 let effective = score
192 * if precision.is_finite() {
193 precision
194 } else {
195 1.0
196 };
197 let entry = RankedIndex::new(effective, index);
198
199 if top.len() == n {
200 let worst = *top.last().expect("non-empty top when len == n");
201 if entry <= worst {
202 continue;
203 }
204 }
205
206 let pos = top.partition_point(|existing| *existing >= entry);
207 if pos < n {
208 top.insert(pos, entry);
209 if top.len() > n {
210 top.pop();
211 }
212 }
213 }
214
215 return top
216 .into_iter()
217 .map(|entry| {
218 Selection::new(&candidates[entry.index], entry.score, entry.index)
219 .with_considered(considered)
220 .with_passed(passed)
221 })
222 .collect();
223 }
224
225 let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
226
227 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
228 considered += 1;
229
230 let score = self.score(candidate, context);
231 if !self.passes_score(score, context) {
232 continue;
233 }
234
235 passed += 1;
236 let precision = self.precision(candidate, context);
237 let effective = score
238 * if precision.is_finite() {
239 precision
240 } else {
241 1.0
242 };
243 let entry = RankedIndex::new(effective, index);
244
245 if heap.len() < n {
246 heap.push(WorstRankedIndex(entry));
247 continue;
248 }
249
250 if let Some(mut worst) = heap.peek_mut() {
251 if entry > worst.0 {
252 *worst = WorstRankedIndex(entry);
253 }
254 }
255 }
256
257 let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
258 scored.sort_unstable_by(|a, b| b.cmp(a));
259
260 scored
261 .into_iter()
262 .map(|entry| {
263 Selection::new(&candidates[entry.index], entry.score, entry.index)
264 .with_considered(considered)
265 .with_passed(passed)
266 })
267 .collect()
268 }
269
270 fn name(&self) -> &str {
272 std::any::type_name::<Self>()
273 }
274}
275
276impl<T, F> Objective<T> for F
278where
279 F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
280{
281 #[inline]
282 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
283 self(candidate, context)
284 }
285}
286
/// Returns `f` unchanged, typed as an opaque `impl Objective<T>`.
///
/// This relies on the blanket [`Objective`] impl for
/// `Fn(&T, &ObjectiveContext) -> f64 + Send + Sync` callables.
pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
where
    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
{
    f
}
294
/// Selection methods for objectives whose candidates expose an id via [`HasId`].
///
/// A blanket implementation in this file provides both methods for every
/// [`Objective`] over such candidates.
pub trait DeterministicObjective<T>: Objective<T>
where
    T: HasId,
{
    /// Selects a single best candidate, or returns an error when there is none.
    fn select_deterministic<'a>(
        &self,
        candidates: &'a [T],
        context: &ObjectiveContext,
    ) -> ObjectiveResult<Selection<&'a T>>;

    /// Selects up to `n` best candidates.
    fn select_top_deterministic<'a>(
        &self,
        candidates: &'a [T],
        n: usize,
        context: &ObjectiveContext,
    ) -> Vec<Selection<&'a T>>;
}
319
320impl<O, T> DeterministicObjective<T> for O
322where
323 O: Objective<T>,
324 T: HasId,
325{
326 fn select_deterministic<'a>(
327 &self,
328 candidates: &'a [T],
329 context: &ObjectiveContext,
330 ) -> ObjectiveResult<Selection<&'a T>> {
331 if candidates.is_empty() {
332 return Err(ObjectiveError::NoCandidates);
333 }
334
335 let considered_limit = considered_limit(candidates.len(), context);
336
337 let mut considered = 0usize;
338 let mut passed = 0usize;
339 let mut has_best = false;
340 let mut best_index = 0usize;
341 let mut best_score = 0.0f64;
342 let mut best_precision = 1.0f64;
343 let mut best_det = DeterministicScore::ZERO;
344 let mut best_id = Uuid::nil();
345
346 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
347 considered += 1;
348
349 let score = self.score(candidate, context);
350 if !self.passes_score(score, context) {
351 continue;
352 }
353
354 passed += 1;
355 let id = candidate.id();
356 let precision = self.precision(candidate, context);
357 let effective = score
358 * if precision.is_finite() {
359 precision
360 } else {
361 1.0
362 };
363 let det = DeterministicScore::from_f64(effective);
364
365 if !has_best || det > best_det || (det == best_det && id < best_id) {
366 has_best = true;
367 best_index = index;
368 best_score = score;
369 best_precision = precision;
370 best_det = det;
371 best_id = id;
372 }
373 }
374
375 if has_best {
376 Ok(
377 Selection::new(&candidates[best_index], best_score, best_index)
378 .with_precision(best_precision)
379 .with_considered(considered)
380 .with_passed(passed),
381 )
382 } else {
383 Err(ObjectiveError::NoMatch("No candidate passed".into()))
384 }
385 }
386
387 fn select_top_deterministic<'a>(
388 &self,
389 candidates: &'a [T],
390 n: usize,
391 context: &ObjectiveContext,
392 ) -> Vec<Selection<&'a T>> {
393 if n == 0 || candidates.is_empty() {
394 return Vec::new();
395 }
396
397 if n == 1 {
398 return self
399 .select_deterministic(candidates, context)
400 .ok()
401 .into_iter()
402 .collect();
403 }
404
405 let considered_limit = considered_limit(candidates.len(), context);
406
407 let mut considered = 0usize;
408 let mut passed = 0usize;
409
410 if n <= SMALL_TOP_N {
411 let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
412
413 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
414 considered += 1;
415
416 let score = self.score(candidate, context);
417 if !self.passes_score(score, context) {
418 continue;
419 }
420
421 passed += 1;
422 let precision = self.precision(candidate, context);
423 let effective = score
424 * if precision.is_finite() {
425 precision
426 } else {
427 1.0
428 };
429 let entry = ScoredEntry::new(candidate, effective, index);
430
431 if top.len() == n {
432 let worst = *top.last().expect("non-empty top when len == n");
433 if entry <= worst {
434 continue;
435 }
436 }
437
438 let pos = top.partition_point(|existing| *existing >= entry);
439 if pos < n {
440 top.insert(pos, entry);
441 if top.len() > n {
442 top.pop();
443 }
444 }
445 }
446
447 return top
448 .into_iter()
449 .map(|entry| {
450 Selection::new(entry.into_candidate(), entry.score(), entry.index())
451 .with_considered(considered)
452 .with_passed(passed)
453 })
454 .collect();
455 }
456
457 let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
458
459 for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
460 considered += 1;
461
462 let score = self.score(candidate, context);
463 if !self.passes_score(score, context) {
464 continue;
465 }
466
467 passed += 1;
468 let precision = self.precision(candidate, context);
469 let effective = score
470 * if precision.is_finite() {
471 precision
472 } else {
473 1.0
474 };
475 let entry = ScoredEntry::new(candidate, effective, index);
476
477 if heap.len() < n {
478 heap.push(WorstScoredEntry(entry));
479 continue;
480 }
481
482 if let Some(mut worst) = heap.peek_mut() {
483 if entry > worst.0 {
484 *worst = WorstScoredEntry(entry);
485 }
486 }
487 }
488
489 let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
490 scored.sort_unstable_by(|a, b| b.cmp(a));
491
492 scored
493 .into_iter()
494 .map(|entry| {
495 Selection::new(entry.into_candidate(), entry.score(), entry.index())
496 .with_considered(considered)
497 .with_passed(passed)
498 })
499 .collect()
500 }
501}