Skip to main content

khive_fold/objective/
traits.rs

1//! Core objective function traits
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use uuid::Uuid;
7
8use khive_score::DeterministicScore;
9
10use super::context::ObjectiveContext;
11use super::selection::Selection;
12use crate::ordering::{HasId, ScoredEntry};
13use crate::{ObjectiveError, ObjectiveResult};
14
15const SMALL_TOP_N: usize = 96;
16
17#[derive(Debug, Clone, Copy)]
18struct RankedIndex {
19    score: f64,
20    det_score: DeterministicScore,
21    index: usize,
22}
23
24impl RankedIndex {
25    #[inline]
26    fn new(score: f64, index: usize) -> Self {
27        Self {
28            score,
29            det_score: DeterministicScore::from_f64(score),
30            index,
31        }
32    }
33}
34
35impl Eq for RankedIndex {}
36
37impl PartialEq for RankedIndex {
38    #[inline]
39    fn eq(&self, other: &Self) -> bool {
40        self.det_score == other.det_score && self.index == other.index
41    }
42}
43
44impl Ord for RankedIndex {
45    #[inline]
46    fn cmp(&self, other: &Self) -> Ordering {
47        self.det_score
48            .cmp(&other.det_score)
49            .then_with(|| other.index.cmp(&self.index))
50    }
51}
52
53impl PartialOrd for RankedIndex {
54    #[inline]
55    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
56        Some(self.cmp(other))
57    }
58}
59
60#[derive(Debug, Clone, Copy)]
61struct WorstRankedIndex(RankedIndex);
62
63impl Eq for WorstRankedIndex {}
64
65impl PartialEq for WorstRankedIndex {
66    #[inline]
67    fn eq(&self, other: &Self) -> bool {
68        self.0 == other.0
69    }
70}
71
72impl Ord for WorstRankedIndex {
73    #[inline]
74    fn cmp(&self, other: &Self) -> Ordering {
75        other.0.cmp(&self.0)
76    }
77}
78
79impl PartialOrd for WorstRankedIndex {
80    #[inline]
81    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82        Some(self.cmp(other))
83    }
84}
85
86#[derive(Debug, Clone, Copy)]
87struct WorstScoredEntry<T>(ScoredEntry<T>);
88
89impl<T> Eq for WorstScoredEntry<T> {}
90
91impl<T> PartialEq for WorstScoredEntry<T> {
92    #[inline]
93    fn eq(&self, other: &Self) -> bool {
94        self.0 == other.0
95    }
96}
97
98impl<T> Ord for WorstScoredEntry<T> {
99    #[inline]
100    fn cmp(&self, other: &Self) -> Ordering {
101        other.0.cmp(&self.0)
102    }
103}
104
105impl<T> PartialOrd for WorstScoredEntry<T> {
106    #[inline]
107    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108        Some(self.cmp(other))
109    }
110}
111
112#[inline]
113fn considered_limit(len: usize, context: &ObjectiveContext) -> usize {
114    context.max_candidates.unwrap_or(len).min(len)
115}
116
117/// Trait for objective functions that select from candidates.
118///
119/// An objective function is a measurement operator that collapses a space of
120/// possibilities into a single selection. Objectives are deterministic,
121/// composable, and introspectable.
122pub trait Objective<T>: Send + Sync {
123    /// Evaluate a single candidate.
124    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64;
125
126    /// Precision (inverse variance) estimate for the score of a candidate.
127    ///
128    /// Default is 1.0 (fully trusted). Override when score reliability varies
129    /// across candidates — e.g., an embedding model that returns a confidence
130    /// alongside the similarity score. The effective ranking value used by the
131    /// default `select` / `select_top` implementations is `score * precision`
132    /// (ADR-059, Predictive Coding).
133    ///
134    /// When overriding, return values in (0, 1]. Non-finite values are treated
135    /// as 1.0 by the default implementations.
136    #[inline]
137    fn precision(&self, _candidate: &T, _context: &ObjectiveContext) -> f64 {
138        1.0
139    }
140
141    /// Check if a score passes the threshold.
142    ///
143    /// Non-finite scores never pass.
144    #[inline]
145    fn passes_score(&self, score: f64, context: &ObjectiveContext) -> bool {
146        score.is_finite() && context.min_score.map(|min| score >= min).unwrap_or(true)
147    }
148
149    /// Check if a candidate passes the threshold.
150    #[inline]
151    fn passes(&self, candidate: &T, context: &ObjectiveContext) -> bool {
152        let score = self.score(candidate, context);
153        self.passes_score(score, context)
154    }
155
156    /// Score a batch of candidates and return the passing `(index, score)` pairs.
157    ///
158    /// The default implementation is a scalar fallback. Objectives with SIMD-friendly
159    /// layouts can override this hook for higher throughput.
160    fn batch_score(&self, candidates: &[T], context: &ObjectiveContext) -> Vec<(usize, f64)> {
161        let mut scored = Vec::with_capacity(candidates.len().min(256));
162        for (index, candidate) in candidates.iter().enumerate() {
163            let score = self.score(candidate, context);
164            if self.passes_score(score, context) {
165                scored.push((index, score));
166            }
167        }
168        scored
169    }
170
171    /// Select candidates from a list, returning all that pass in score-descending order.
172    ///
173    /// Returns an empty vector when no candidates pass the threshold or the input is empty.
174    /// Delegates to `select_top` using the full considered limit so callers get a ranked
175    /// list rather than a single item. Use `.into_iter().next()` for single-best access.
176    fn select<'a>(&self, candidates: &'a [T], context: &ObjectiveContext) -> Vec<Selection<&'a T>> {
177        if candidates.is_empty() {
178            return Vec::new();
179        }
180        let n = considered_limit(candidates.len(), context);
181        self.select_top(candidates, n, context)
182    }
183
184    /// Select the top N candidates.
185    ///
186    /// Ranking uses `score * precision` (ADR-059). Small `n` (≤96) uses a sorted
187    /// small-vector path with binary-search insertion. Large `n` uses a worst-first heap.
188    fn select_top<'a>(
189        &self,
190        candidates: &'a [T],
191        n: usize,
192        context: &ObjectiveContext,
193    ) -> Vec<Selection<&'a T>> {
194        if n == 0 || candidates.is_empty() {
195            return Vec::new();
196        }
197
198        let considered_limit = considered_limit(candidates.len(), context);
199
200        let mut considered = 0usize;
201        let mut passed = 0usize;
202
203        if n <= SMALL_TOP_N {
204            let mut top: Vec<RankedIndex> = Vec::with_capacity(n.min(considered_limit));
205
206            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
207                considered += 1;
208
209                let score = self.score(candidate, context);
210                if !self.passes_score(score, context) {
211                    continue;
212                }
213
214                passed += 1;
215                let precision = self.precision(candidate, context);
216                let effective = score
217                    * if precision.is_finite() {
218                        precision
219                    } else {
220                        1.0
221                    };
222                let entry = RankedIndex::new(effective, index);
223
224                if top.len() == n {
225                    let worst = *top.last().expect("non-empty top when len == n");
226                    if entry <= worst {
227                        continue;
228                    }
229                }
230
231                let pos = top.partition_point(|existing| *existing >= entry);
232                if pos < n {
233                    top.insert(pos, entry);
234                    if top.len() > n {
235                        top.pop();
236                    }
237                }
238            }
239
240            return top
241                .into_iter()
242                .map(|entry| {
243                    Selection::new(&candidates[entry.index], entry.score, entry.index)
244                        .with_considered(considered)
245                        .with_passed(passed)
246                })
247                .collect();
248        }
249
250        let mut heap: BinaryHeap<WorstRankedIndex> = BinaryHeap::with_capacity(n);
251
252        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
253            considered += 1;
254
255            let score = self.score(candidate, context);
256            if !self.passes_score(score, context) {
257                continue;
258            }
259
260            passed += 1;
261            let precision = self.precision(candidate, context);
262            let effective = score
263                * if precision.is_finite() {
264                    precision
265                } else {
266                    1.0
267                };
268            let entry = RankedIndex::new(effective, index);
269
270            if heap.len() < n {
271                heap.push(WorstRankedIndex(entry));
272                continue;
273            }
274
275            if let Some(mut worst) = heap.peek_mut() {
276                if entry > worst.0 {
277                    *worst = WorstRankedIndex(entry);
278                }
279            }
280        }
281
282        let mut scored: Vec<RankedIndex> = heap.into_iter().map(|entry| entry.0).collect();
283        scored.sort_unstable_by(|a, b| b.cmp(a));
284
285        scored
286            .into_iter()
287            .map(|entry| {
288                Selection::new(&candidates[entry.index], entry.score, entry.index)
289                    .with_considered(considered)
290                    .with_passed(passed)
291            })
292            .collect()
293    }
294
295    /// Get the name of this objective.
296    fn name(&self) -> &str {
297        std::any::type_name::<Self>()
298    }
299}
300
301/// Implement Objective for closures.
302impl<T, F> Objective<T> for F
303where
304    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
305{
306    #[inline]
307    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
308        self(candidate, context)
309    }
310}
311
312/// Create an objective from a scoring function.
313pub fn objective_fn<T, F>(f: F) -> impl Objective<T>
314where
315    F: Fn(&T, &ObjectiveContext) -> f64 + Send + Sync,
316{
317    f
318}
319
320// ============================================================================
321// Deterministic Objective Extension
322// ============================================================================
323
324/// Extension trait for deterministic selection with UUID tie-breaking.
325///
326/// Provides reproducible ordering when multiple candidates have equal scores.
327/// Requires candidates to implement `HasId` for UUID-based tie-breaking.
328///
329/// Ordering is determined by:
330/// 1. Score (descending) using canonical IEEE-754 total-order ranks
331/// 2. UUID (ascending) for tie-breaking
332pub trait DeterministicObjective<T>: Objective<T>
333where
334    T: HasId,
335{
336    /// Select the best candidate with deterministic tie-breaking.
337    fn select_deterministic<'a>(
338        &self,
339        candidates: &'a [T],
340        context: &ObjectiveContext,
341    ) -> ObjectiveResult<Selection<&'a T>>;
342
343    /// Select the top N candidates with deterministic ordering.
344    fn select_top_deterministic<'a>(
345        &self,
346        candidates: &'a [T],
347        n: usize,
348        context: &ObjectiveContext,
349    ) -> Vec<Selection<&'a T>>;
350}
351
352/// Blanket implementation of DeterministicObjective for any Objective<T> where T: HasId.
353impl<O, T> DeterministicObjective<T> for O
354where
355    O: Objective<T>,
356    T: HasId,
357{
358    fn select_deterministic<'a>(
359        &self,
360        candidates: &'a [T],
361        context: &ObjectiveContext,
362    ) -> ObjectiveResult<Selection<&'a T>> {
363        if candidates.is_empty() {
364            return Err(ObjectiveError::NoCandidates);
365        }
366
367        let considered_limit = considered_limit(candidates.len(), context);
368
369        let mut considered = 0usize;
370        let mut passed = 0usize;
371        let mut has_best = false;
372        let mut best_index = 0usize;
373        let mut best_score = 0.0f64;
374        let mut best_precision = 1.0f64;
375        let mut best_det = DeterministicScore::ZERO;
376        let mut best_id = Uuid::nil();
377
378        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
379            considered += 1;
380
381            let score = self.score(candidate, context);
382            if !self.passes_score(score, context) {
383                continue;
384            }
385
386            passed += 1;
387            let id = candidate.id();
388            let precision = self.precision(candidate, context);
389            let effective = score
390                * if precision.is_finite() {
391                    precision
392                } else {
393                    1.0
394                };
395            let det = DeterministicScore::from_f64(effective);
396
397            if !has_best || det > best_det || (det == best_det && id < best_id) {
398                has_best = true;
399                best_index = index;
400                best_score = score;
401                best_precision = precision;
402                best_det = det;
403                best_id = id;
404            }
405        }
406
407        if has_best {
408            Ok(
409                Selection::new(&candidates[best_index], best_score, best_index)
410                    .with_precision(best_precision)
411                    .with_considered(considered)
412                    .with_passed(passed),
413            )
414        } else {
415            Err(ObjectiveError::NoMatch("No candidate passed".into()))
416        }
417    }
418
419    fn select_top_deterministic<'a>(
420        &self,
421        candidates: &'a [T],
422        n: usize,
423        context: &ObjectiveContext,
424    ) -> Vec<Selection<&'a T>> {
425        if n == 0 || candidates.is_empty() {
426            return Vec::new();
427        }
428
429        if n == 1 {
430            return self
431                .select_deterministic(candidates, context)
432                .ok()
433                .into_iter()
434                .collect();
435        }
436
437        let considered_limit = considered_limit(candidates.len(), context);
438
439        let mut considered = 0usize;
440        let mut passed = 0usize;
441
442        if n <= SMALL_TOP_N {
443            let mut top: Vec<ScoredEntry<&T>> = Vec::with_capacity(n.min(considered_limit));
444
445            for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
446                considered += 1;
447
448                let score = self.score(candidate, context);
449                if !self.passes_score(score, context) {
450                    continue;
451                }
452
453                passed += 1;
454                let precision = self.precision(candidate, context);
455                let effective = score
456                    * if precision.is_finite() {
457                        precision
458                    } else {
459                        1.0
460                    };
461                let entry = ScoredEntry::new(candidate, effective, index);
462
463                if top.len() == n {
464                    let worst = *top.last().expect("non-empty top when len == n");
465                    if entry <= worst {
466                        continue;
467                    }
468                }
469
470                let pos = top.partition_point(|existing| *existing >= entry);
471                if pos < n {
472                    top.insert(pos, entry);
473                    if top.len() > n {
474                        top.pop();
475                    }
476                }
477            }
478
479            return top
480                .into_iter()
481                .map(|entry| {
482                    Selection::new(entry.into_candidate(), entry.score(), entry.index())
483                        .with_considered(considered)
484                        .with_passed(passed)
485                })
486                .collect();
487        }
488
489        let mut heap: BinaryHeap<WorstScoredEntry<&T>> = BinaryHeap::with_capacity(n);
490
491        for (index, candidate) in candidates.iter().take(considered_limit).enumerate() {
492            considered += 1;
493
494            let score = self.score(candidate, context);
495            if !self.passes_score(score, context) {
496                continue;
497            }
498
499            passed += 1;
500            let precision = self.precision(candidate, context);
501            let effective = score
502                * if precision.is_finite() {
503                    precision
504                } else {
505                    1.0
506                };
507            let entry = ScoredEntry::new(candidate, effective, index);
508
509            if heap.len() < n {
510                heap.push(WorstScoredEntry(entry));
511                continue;
512            }
513
514            if let Some(mut worst) = heap.peek_mut() {
515                if entry > worst.0 {
516                    *worst = WorstScoredEntry(entry);
517                }
518            }
519        }
520
521        let mut scored: Vec<ScoredEntry<&T>> = heap.into_iter().map(|entry| entry.0).collect();
522        scored.sort_unstable_by(|a, b| b.cmp(a));
523
524        scored
525            .into_iter()
526            .map(|entry| {
527                Selection::new(entry.into_candidate(), entry.score(), entry.index())
528                    .with_considered(considered)
529                    .with_passed(passed)
530            })
531            .collect()
532    }
533}