Skip to main content

khive_fold/objective/
traits.rs

1//! Core objective function traits
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
15const SMALL_TOP_N: usize = 96;
16
17#[derive(Debug, Clone, Copy)]
18struct RankedIndex {
19    score: f64,
20    det_score: DeterministicScore,
21    index: usize,
22}
23
24impl RankedIndex {
25    #[inline]
26    fn new(score: f64, index: usize) -> Self {
27        Self {
28            score,
29            det_score: DeterministicScore::from_f64(score),
30            index,
31        }
32    }
33}
34
35impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38    #[inline]
39    fn eq(&self, other: &Self) -> bool {
40        self.det_score == other.det_score && self.index == other.index
41    }
42}
43
44impl Ord for RankedIndex {
45    #[inline]
46    fn cmp(&self, other: &Self) -> Ordering {
47        self.det_score
48            .cmp(&other.det_score)
49            .then_with(|| other.index.cmp(&self.index))
50    }
51}
52
53impl PartialOrd for RankedIndex {
54    #[inline]
55    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56        Some(self.cmp(other))
57    }
58}
59
60#[derive(Debug, Clone, Copy)]
61struct WorstRankedIndex(RankedIndex);
62
63impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66    #[inline]
67    fn eq(&self, other: &Self) -> bool {
68        self.0 == other.0
69    }
70}
71
72impl Ord for WorstRankedIndex {
73    #[inline]
74    fn cmp(&self, other: &Self) -> Ordering {
75        other.0.cmp(&self.0)
76    }
77}
78
79impl PartialOrd for WorstRankedIndex {
80    #[inline]
81    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82        Some(self.cmp(other))
83    }
84}
85
86#[derive(Debug, Clone, Copy)]
87struct WorstScoredEntry<T>(ScoredEntry<T>);
88
89impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92    #[inline]
93    fn eq(&self, other: &Self) -> bool {
94        self.0 == other.0
95    }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99    #[inline]
100    fn cmp(&self, other: &Self) -> Ordering {
101        other.0.cmp(&self.0)
102    }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106    #[inline]
107    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108        Some(self.cmp(other))
109    }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114    context.max_candidates.unwrap_or(len).min(len)
115}
116
117/// Trait for objective functions that select from candidates.
118///
119/// An objective function is a measurement operator that collapses a space of
120/// possibilities into a single selection. Objectives are deterministic,
121/// composable, and introspectable.
122pub trait Objective<T>: Send + Sync {
123    /// Evaluate a single candidate.
124    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
125
126    /// Precision (inverse variance) estimate for the score of a candidate.
127    ///
128    /// Default is 1.0 (fully trusted). Override when score reliability varies
129    /// across candidates — e.g., an embedding model that returns a confidence
130    /// alongside the similarity score. The effective ranking value used by the
131    /// default `select` / `select_top` implementations is `score * precision`
132    /// (ADR-059, Predictive Coding).
133    ///
134    /// When overriding, return values in (0, 1]. Non-finite values are treated
135    /// as 1.0 by the default implementations.
136    #[inline]
137    fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
138        1.0
139    }
140
141    /// Check if a score passes the threshold.
142    ///
143    /// Non-finite scores never pass.
144    #[inline]
145    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
146        score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
147    }
148
149    /// Check if a candidate passes the threshold.
150    #[inline]
151    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
152        let score = self.score(candidate, context);
153        self.passes_score(score, context)
154    }
155
156    /// Score a batch of candidates and return the passing `(index, score)` pairs.
157    ///
158    /// The default implementation is a scalar fallback. Objectives with SIMD-friendly
159    /// layouts can override this hook for higher throughput.
160    fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
161        let mut scored = Vec::with_capacity(candidates.len().min(256));
162        for (index, candidate) in candidates.iter().enumerate() {
163            let score = self.score(candidate, context);
164            if self.passes_score(score, context) {
165                scored.push((index, score));
166            }
167        }
168        scored
169    }
170
171    /// Select the best candidate from a list.
172    ///
173    /// Ranking uses `score * precision` so that unreliable high-scores do not
174    /// dominate over lower-scoring but precise candidates (ADR-059). When all
175    /// precisions are 1.0 (the default), ranking is identical to raw score order.
176    fn select<'a>(
177        &self,
178        candidates: &'a [T],
179        context: &ObjectiveContext,
180    ) -> ObjectiveResult<Selection<&'a T>> {
181        if candidates.is_empty() {
182            return Err(ObjectiveError::NoCandidates);
183        }
184
185        let considered_limit = considered_limit(candidates.len(), context);
186
187        let mut considered = 0usize;
188        let mut passed = 0usize;
189        let mut has_best = false;
190        let mut best_index = 0usize;
191        let mut best_score = 0.0f64;
192        let mut best_precision = 1.0f64;
193        let mut best_det = DeterministicScore::ZERO;
194
195        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
196            considered += 1;
197
198            let score = self.score(candidate, context);
199            if !self.passes_score(score, context) {
200                continue;
201            }
202
203            passed += 1;
204
205            let precision = self.precision(candidate, context);
206            let effective = score
207                * if precision.is_finite() {
208                    precision
209                } else {
210                    1.0
211                };
212            let det = DeterministicScore::from_f64(effective);
213            if !has_best || det > best_det {
214                has_best = true;
215                best_index = index;
216                best_score = score;
217                best_precision = precision;
218                best_det = det;
219            }
220        }
221
222        if has_best {
223            Ok(
224                Selection::new(&candidates[best_index], best_score, best_index)
225                    .with_precision(best_precision)
226                    .with_considered(considered)
227                    .with_passed(passed),
228            )
229        } else {
230            Err(ObjectiveError::NoMatch("No candidate passed".into()))
231        }
232    }
233
234    /// Select the top N candidates.
235    ///
236    /// Ranking uses `score * precision` (ADR-059). Small `n` (≤96) uses a sorted
237    /// small-vector path with binary-search insertion. Large `n` uses a worst-first heap.
238    fn select_top<'a>(
239        &self,
240        candidates: &'a [T],
241        n: usize,
242        context: &ObjectiveContext,
243    ) -> Vec<Selection<&'a T>> {
244        if n == 0 || candidates.is_empty() {
245            return Vec::new();
246        }
247
248        if n == 1 {
249            return self.select(candidates, context).ok().into_iter().collect();
250        }
251
252        let considered_limit = considered_limit(candidates.len(), context);
253
254        let mut considered = 0usize;
255        let mut passed = 0usize;
256
257        if n <= SMALL_TOP_N {
258            let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
259
260            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
261                considered += 1;
262
263                let score = self.score(candidate, context);
264                if !self.passes_score(score, context) {
265                    continue;
266                }
267
268                passed += 1;
269                let precision = self.precision(candidate, context);
270                let effective = score
271                    * if precision.is_finite() {
272                        precision
273                    } else {
274                        1.0
275                    };
276                let entry = RankedIndex::new(effective, index);
277
278                if top.len() == n {
279                    let worst = *top.last().expect("non-empty top when len == n");
280                    if entry <= worst {
281                        continue;
282                    }
283                }
284
285                let pos = top.partition_point(|existing| *existing >= entry);
286                if pos < n {
287                    top.insert(pos, entry);
288                    if top.len() > n {
289                        top.pop();
290                    }
291                }
292            }
293
294            return top
295                .into_iter()
296                .map(|entry| {
297                    Selection::new(&candidates[entry.index], entry.score, entry.index)
298                        .with_considered(considered)
299                        .with_passed(passed)
300                })
301                .collect();
302        }
303
304        let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
305
306        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
307            considered += 1;
308
309            let score = self.score(candidate, context);
310            if !self.passes_score(score, context) {
311                continue;
312            }
313
314            passed += 1;
315            let precision = self.precision(candidate, context);
316            let effective = score
317                * if precision.is_finite() {
318                    precision
319                } else {
320                    1.0
321                };
322            let entry = RankedIndex::new(effective, index);
323
324            if heap.len() < n {
325                heap.push(WorstRankedIndex(entry));
326                continue;
327            }
328
329            if let Some(mut worst) = heap.peek_mut() {
330                if entry > worst.0 {
331                    *worst = WorstRankedIndex(entry);
332                }
333            }
334        }
335
336        let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
337        scored.sort_unstable_by(|a, b| b.cmp(a));
338
339        scored
340            .into_iter()
341            .map(|entry| {
342                Selection::new(&candidates[entry.index], entry.score, entry.index)
343                    .with_considered(considered)
344                    .with_passed(passed)
345            })
346            .collect()
347    }
348
349    /// Get the name of this objective.
350    fn name(&self) -> &str {
351        std::any::type_name::<Self>()
352    }
353}
354
355/// Implement Objective for closures.
356impl<T, F> Objective<T> for F
357where
358    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
359{
360    #[inline]
361    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
362        self(candidate, context)
363    }
364}
365
366/// Create an objective from a scoring function.
367pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
368where
369    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
370{
371    f
372}
373
374// ============================================================================
375// Deterministic Objective Extension
376// ============================================================================
377
378/// Extension trait for deterministic selection with UUID tie-breaking.
379///
380/// Provides reproducible ordering when multiple candidates have equal scores.
381/// Requires candidates to implement `HasId` for UUID-based tie-breaking.
382///
383/// Ordering is determined by:
384/// 1. Score (descending) using canonical IEEE-754 total-order ranks
385/// 2. UUID (ascending) for tie-breaking
386pub trait DeterministicObjective<T>: Objective<T>
387where
388    T: HasId,
389{
390    /// Select the best candidate with deterministic tie-breaking.
391    fn select_deterministic<'a>(
392        &self,
393        candidates: &'a [T],
394        context: &ObjectiveContext,
395    ) -> ObjectiveResult<Selection<&'a T>>;
396
397    /// Select the top N candidates with deterministic ordering.
398    fn select_top_deterministic<'a>(
399        &self,
400        candidates: &'a [T],
401        n: usize,
402        context: &ObjectiveContext,
403    ) -> Vec<Selection<&'a T>>;
404}
405
406/// Blanket implementation of DeterministicObjective for any Objective<T> where T: HasId.
407impl<O, T> DeterministicObjective<T> for O
408where
409    O: Objective<T>,
410    T: HasId,
411{
412    fn select_deterministic<'a>(
413        &self,
414        candidates: &'a [T],
415        context: &ObjectiveContext,
416    ) -> ObjectiveResult<Selection<&'a T>> {
417        if candidates.is_empty() {
418            return Err(ObjectiveError::NoCandidates);
419        }
420
421        let considered_limit = considered_limit(candidates.len(), context);
422
423        let mut considered = 0usize;
424        let mut passed = 0usize;
425        let mut has_best = false;
426        let mut best_index = 0usize;
427        let mut best_score = 0.0f64;
428        let mut best_precision = 1.0f64;
429        let mut best_det = DeterministicScore::ZERO;
430        let mut best_id = Uuid::nil();
431
432        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
433            considered += 1;
434
435            let score = self.score(candidate, context);
436            if !self.passes_score(score, context) {
437                continue;
438            }
439
440            passed += 1;
441            let id = candidate.id();
442            let precision = self.precision(candidate, context);
443            let effective = score
444                * if precision.is_finite() {
445                    precision
446                } else {
447                    1.0
448                };
449            let det = DeterministicScore::from_f64(effective);
450
451            if !has_best || det > best_det || (det == best_det && id < best_id) {
452                has_best = true;
453                best_index = index;
454                best_score = score;
455                best_precision = precision;
456                best_det = det;
457                best_id = id;
458            }
459        }
460
461        if has_best {
462            Ok(
463                Selection::new(&candidates[best_index], best_score, best_index)
464                    .with_precision(best_precision)
465                    .with_considered(considered)
466                    .with_passed(passed),
467            )
468        } else {
469            Err(ObjectiveError::NoMatch("No candidate passed".into()))
470        }
471    }
472
473    fn select_top_deterministic<'a>(
474        &self,
475        candidates: &'a [T],
476        n: usize,
477        context: &ObjectiveContext,
478    ) -> Vec<Selection<&'a T>> {
479        if n == 0 || candidates.is_empty() {
480            return Vec::new();
481        }
482
483        if n == 1 {
484            return self
485                .select_deterministic(candidates, context)
486                .ok()
487                .into_iter()
488                .collect();
489        }
490
491        let considered_limit = considered_limit(candidates.len(), context);
492
493        let mut considered = 0usize;
494        let mut passed = 0usize;
495
496        if n <= SMALL_TOP_N {
497            let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
498
499            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
500                considered += 1;
501
502                let score = self.score(candidate, context);
503                if !self.passes_score(score, context) {
504                    continue;
505                }
506
507                passed += 1;
508                let precision = self.precision(candidate, context);
509                let effective = score
510                    * if precision.is_finite() {
511                        precision
512                    } else {
513                        1.0
514                    };
515                let entry = ScoredEntry::new(candidate, effective, index);
516
517                if top.len() == n {
518                    let worst = *top.last().expect("non-empty top when len == n");
519                    if entry <= worst {
520                        continue;
521                    }
522                }
523
524                let pos = top.partition_point(|existing| *existing >= entry);
525                if pos < n {
526                    top.insert(pos, entry);
527                    if top.len() > n {
528                        top.pop();
529                    }
530                }
531            }
532
533            return top
534                .into_iter()
535                .map(|entry| {
536                    Selection::new(entry.into_candidate(), entry.score(), entry.index())
537                        .with_considered(considered)
538                        .with_passed(passed)
539                })
540                .collect();
541        }
542
543        let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
544
545        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
546            considered += 1;
547
548            let score = self.score(candidate, context);
549            if !self.passes_score(score, context) {
550                continue;
551            }
552
553            passed += 1;
554            let precision = self.precision(candidate, context);
555            let effective = score
556                * if precision.is_finite() {
557                    precision
558                } else {
559                    1.0
560                };
561            let entry = ScoredEntry::new(candidate, effective, index);
562
563            if heap.len() < n {
564                heap.push(WorstScoredEntry(entry));
565                continue;
566            }
567
568            if let Some(mut worst) = heap.peek_mut() {
569                if entry > worst.0 {
570                    *worst = WorstScoredEntry(entry);
571                }
572            }
573        }
574
575        let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
576        scored.sort_unstable_by(|a, b| b.cmp(a));
577
578        scored
579            .into_iter()
580            .map(|entry| {
581                Selection::new(entry.into_candidate(), entry.score(), entry.index())
582                    .with_considered(considered)
583                    .with_passed(passed)
584            })
585            .collect()
586    }
587}