Skip to main content

khive_fold/objective/
traits.rs

//! Core objective function traits
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
/// Largest `n` for which the top-N selection routines in this file keep their results in a
/// sorted vector maintained by insertion; for larger `n` they switch to a `BinaryHeap`.
const SMALL_TOP_N: usize = 96;
16
17#[derive(Debug, Clone, Copy)]
18struct RankedIndex {
19    score: f64,
20    det_score: DeterministicScore,
21    index: usize,
22}
23
24impl RankedIndex {
25    #[inline]
26    fn new(score: f64, index: usize) -> Self {
27        Self {
28            score,
29            det_score: DeterministicScore::from_f64(score),
30            index,
31        }
32    }
33}
34
35impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38    #[inline]
39    fn eq(&self, other: &Self) -> bool {
40        self.det_score == other.det_score && self.index == other.index
41    }
42}
43
44impl Ord for RankedIndex {
45    #[inline]
46    fn cmp(&self, other: &Self) -> Ordering {
47        self.det_score
48            .cmp(&other.det_score)
49            .then_with(|| other.index.cmp(&self.index))
50    }
51}
52
53impl PartialOrd for RankedIndex {
54    #[inline]
55    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56        Some(self.cmp(other))
57    }
58}
59
60#[derive(Debug, Clone, Copy)]
61struct WorstRankedIndex(RankedIndex);
62
63impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66    #[inline]
67    fn eq(&self, other: &Self) -> bool {
68        self.0 == other.0
69    }
70}
71
72impl Ord for WorstRankedIndex {
73    #[inline]
74    fn cmp(&self, other: &Self) -> Ordering {
75        other.0.cmp(&self.0)
76    }
77}
78
79impl PartialOrd for WorstRankedIndex {
80    #[inline]
81    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82        Some(self.cmp(other))
83    }
84}
85
86#[derive(Debug, Clone, Copy)]
87struct WorstScoredEntry<T>(ScoredEntry<T>);
88
89impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92    #[inline]
93    fn eq(&self, other: &Self) -> bool {
94        self.0 == other.0
95    }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99    #[inline]
100    fn cmp(&self, other: &Self) -> Ordering {
101        other.0.cmp(&self.0)
102    }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106    #[inline]
107    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108        Some(self.cmp(other))
109    }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114    context.max_candidates.unwrap_or(len).min(len)
115}
116
117/// Deterministic, composable objective function over a candidate set.
118pub trait Objective<T>: Send + Sync {
119    /// Evaluate a single candidate.
120    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
121
122    /// Precision (inverse variance) of the score estimate; default 1.0 (fully trusted).
123    #[inline]
124    fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
125        1.0
126    }
127
128    /// Check if a score passes the threshold; non-finite scores never pass.
129    #[inline]
130    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
131        score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
132    }
133
134    /// Check if a candidate passes the threshold.
135    #[inline]
136    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
137        let score = self.score(candidate, context);
138        self.passes_score(score, context)
139    }
140
141    /// Score a batch of candidates and return passing `(index, score)` pairs.
142    fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
143        let mut scored = Vec::with_capacity(candidates.len().min(256));
144        for (index, candidate) in candidates.iter().enumerate() {
145            let score = self.score(candidate, context);
146            if self.passes_score(score, context) {
147                scored.push((index, score));
148            }
149        }
150        scored
151    }
152
153    /// Select all passing candidates in score-descending order.
154    fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
155        if candidates.is_empty() {
156            return Vec::new();
157        }
158        let n = considered_limit(candidates.len(), context);
159        self.select_top(candidates, n, context)
160    }
161
162    /// Select the top N candidates by precision-weighted score.
163    fn select_top<'a>(
164        &self,
165        candidates: &'a [T],
166        n: usize,
167        context: &ObjectiveContext,
168    ) -> Vec<Selection<&'a T>> {
169        if n == 0 || candidates.is_empty() {
170            return Vec::new();
171        }
172
173        let considered_limit = considered_limit(candidates.len(), context);
174
175        let mut considered = 0usize;
176        let mut passed = 0usize;
177
178        if n <= SMALL_TOP_N {
179            let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
180
181            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
182                considered += 1;
183
184                let score = self.score(candidate, context);
185                if !self.passes_score(score, context) {
186                    continue;
187                }
188
189                passed += 1;
190                let precision = self.precision(candidate, context);
191                let effective = score
192                    * if precision.is_finite() {
193                        precision
194                    } else {
195                        1.0
196                    };
197                let entry = RankedIndex::new(effective, index);
198
199                if top.len() == n {
200                    let worst = *top.last().expect("non-empty top when len == n");
201                    if entry <= worst {
202                        continue;
203                    }
204                }
205
206                let pos = top.partition_point(|existing| *existing >= entry);
207                if pos < n {
208                    top.insert(pos, entry);
209                    if top.len() > n {
210                        top.pop();
211                    }
212                }
213            }
214
215            return top
216                .into_iter()
217                .map(|entry| {
218                    Selection::new(&candidates[entry.index], entry.score, entry.index)
219                        .with_considered(considered)
220                        .with_passed(passed)
221                })
222                .collect();
223        }
224
225        let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
226
227        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
228            considered += 1;
229
230            let score = self.score(candidate, context);
231            if !self.passes_score(score, context) {
232                continue;
233            }
234
235            passed += 1;
236            let precision = self.precision(candidate, context);
237            let effective = score
238                * if precision.is_finite() {
239                    precision
240                } else {
241                    1.0
242                };
243            let entry = RankedIndex::new(effective, index);
244
245            if heap.len() < n {
246                heap.push(WorstRankedIndex(entry));
247                continue;
248            }
249
250            if let Some(mut worst) = heap.peek_mut() {
251                if entry > worst.0 {
252                    *worst = WorstRankedIndex(entry);
253                }
254            }
255        }
256
257        let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
258        scored.sort_unstable_by(|a, b| b.cmp(a));
259
260        scored
261            .into_iter()
262            .map(|entry| {
263                Selection::new(&candidates[entry.index], entry.score, entry.index)
264                    .with_considered(considered)
265                    .with_passed(passed)
266            })
267            .collect()
268    }
269
270    /// Get the name of this objective.
271    fn name(&self) -> &str {
272        std::any::type_name::<Self>()
273    }
274}
275
276/// Implement Objective for closures.
277impl<T, F> Objective<T> for F
278where
279    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
280{
281    #[inline]
282    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
283        self(candidate, context)
284    }
285}
286
287/// Create an objective from a scoring function.
288pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
289where
290    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
291{
292    f
293}
294
// ============================================================================
// Deterministic Objective Extension
// ============================================================================
298
299/// Extension trait for deterministic selection with UUID tie-breaking on equal scores.
300pub trait DeterministicObjective<T>: Objective<T>
301where
302    T: HasId,
303{
304    /// Select the best candidate with deterministic tie-breaking.
305    fn select_deterministic<'a>(
306        &self,
307        candidates: &'a [T],
308        context: &ObjectiveContext,
309    ) -> ObjectiveResult<Selection<&'a T>>;
310
311    /// Select the top N candidates with deterministic ordering.
312    fn select_top_deterministic<'a>(
313        &self,
314        candidates: &'a [T],
315        n: usize,
316        context: &ObjectiveContext,
317    ) -> Vec<Selection<&'a T>>;
318}
319
320/// Blanket implementation of `DeterministicObjective` for any `Objective<T>` where `T: HasId`.
321impl<O, T> DeterministicObjective<T> for O
322where
323    O: Objective<T>,
324    T: HasId,
325{
326    fn select_deterministic<'a>(
327        &self,
328        candidates: &'a [T],
329        context: &ObjectiveContext,
330    ) -> ObjectiveResult<Selection<&'a T>> {
331        if candidates.is_empty() {
332            return Err(ObjectiveError::NoCandidates);
333        }
334
335        let considered_limit = considered_limit(candidates.len(), context);
336
337        let mut considered = 0usize;
338        let mut passed = 0usize;
339        let mut has_best = false;
340        let mut best_index = 0usize;
341        let mut best_score = 0.0f64;
342        let mut best_precision = 1.0f64;
343        let mut best_det = DeterministicScore::ZERO;
344        let mut best_id = Uuid::nil();
345
346        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
347            considered += 1;
348
349            let score = self.score(candidate, context);
350            if !self.passes_score(score, context) {
351                continue;
352            }
353
354            passed += 1;
355            let id = candidate.id();
356            let precision = self.precision(candidate, context);
357            let effective = score
358                * if precision.is_finite() {
359                    precision
360                } else {
361                    1.0
362                };
363            let det = DeterministicScore::from_f64(effective);
364
365            if !has_best || det > best_det || (det == best_det && id < best_id) {
366                has_best = true;
367                best_index = index;
368                best_score = score;
369                best_precision = precision;
370                best_det = det;
371                best_id = id;
372            }
373        }
374
375        if has_best {
376            Ok(
377                Selection::new(&candidates[best_index], best_score, best_index)
378                    .with_precision(best_precision)
379                    .with_considered(considered)
380                    .with_passed(passed),
381            )
382        } else {
383            Err(ObjectiveError::NoMatch("No candidate passed".into()))
384        }
385    }
386
387    fn select_top_deterministic<'a>(
388        &self,
389        candidates: &'a [T],
390        n: usize,
391        context: &ObjectiveContext,
392    ) -> Vec<Selection<&'a T>> {
393        if n == 0 || candidates.is_empty() {
394            return Vec::new();
395        }
396
397        if n == 1 {
398            return self
399                .select_deterministic(candidates, context)
400                .ok()
401                .into_iter()
402                .collect();
403        }
404
405        let considered_limit = considered_limit(candidates.len(), context);
406
407        let mut considered = 0usize;
408        let mut passed = 0usize;
409
410        if n <= SMALL_TOP_N {
411            let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
412
413            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
414                considered += 1;
415
416                let score = self.score(candidate, context);
417                if !self.passes_score(score, context) {
418                    continue;
419                }
420
421                passed += 1;
422                let precision = self.precision(candidate, context);
423                let effective = score
424                    * if precision.is_finite() {
425                        precision
426                    } else {
427                        1.0
428                    };
429                let entry = ScoredEntry::new(candidate, effective, index);
430
431                if top.len() == n {
432                    let worst = *top.last().expect("non-empty top when len == n");
433                    if entry <= worst {
434                        continue;
435                    }
436                }
437
438                let pos = top.partition_point(|existing| *existing >= entry);
439                if pos < n {
440                    top.insert(pos, entry);
441                    if top.len() > n {
442                        top.pop();
443                    }
444                }
445            }
446
447            return top
448                .into_iter()
449                .map(|entry| {
450                    Selection::new(entry.into_candidate(), entry.score(), entry.index())
451                        .with_considered(considered)
452                        .with_passed(passed)
453                })
454                .collect();
455        }
456
457        let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
458
459        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
460            considered += 1;
461
462            let score = self.score(candidate, context);
463            if !self.passes_score(score, context) {
464                continue;
465            }
466
467            passed += 1;
468            let precision = self.precision(candidate, context);
469            let effective = score
470                * if precision.is_finite() {
471                    precision
472                } else {
473                    1.0
474                };
475            let entry = ScoredEntry::new(candidate, effective, index);
476
477            if heap.len() < n {
478                heap.push(WorstScoredEntry(entry));
479                continue;
480            }
481
482            if let Some(mut worst) = heap.peek_mut() {
483                if entry > worst.0 {
484                    *worst = WorstScoredEntry(entry);
485                }
486            }
487        }
488
489        let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
490        scored.sort_unstable_by(|a, b| b.cmp(a));
491
492        scored
493            .into_iter()
494            .map(|entry| {
495                Selection::new(entry.into_candidate(), entry.score(), entry.index())
496                    .with_considered(considered)
497                    .with_passed(passed)
498            })
499            .collect()
500    }
501}