Skip to main content

khive_fold/ordering/
scored_entry.rs

1//! ScoredEntry wrapper for heap operations
2
3use std::cmp::Ordering;
4use std::hash::{Hash, Hasher};
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8
9use super::has_id::HasId;
10
11/// A wrapper for scored candidates that implements deterministic `Ord`.
12///
13/// This struct caches the score as a `DeterministicScore` (i64 fixed-point) for
14/// cross-platform deterministic ordering, along with the UUID for tie-breaking.
15///
16/// # Ordering
17///
18/// `ScoredEntry` orders by:
19/// 1. Score descending (higher scores first)
20/// 2. UUID ascending (lower UUIDs first for tie-breaking)
21#[derive(Debug, Clone, Copy)]
22pub struct ScoredEntry<T> {
23    /// The candidate being scored
24    candidate: T,
25    /// Cached raw score value (human-readable)
26    score: f64,
27    /// Cached UUID for tie-breaking
28    id: Uuid,
29    /// Original index in the candidate list
30    index: usize,
31    /// Deterministic fixed-point score for ordering
32    det_score: DeterministicScore,
33}
34
35impl<T: HasId> ScoredEntry<T> {
36    /// Create a new scored entry.
37    #[inline]
38    pub fn new(candidate: T, score: f64, index: usize) -> Self {
39        let id = candidate.id();
40        let det_score = DeterministicScore::from_f64(score);
41        Self {
42            candidate,
43            score,
44            id,
45            index,
46            det_score,
47        }
48    }
49
50    /// Get the candidate reference.
51    #[inline]
52    pub fn candidate(&self) -> &T {
53        &self.candidate
54    }
55
56    /// Consume and return the candidate.
57    #[inline]
58    pub fn into_candidate(self) -> T {
59        self.candidate
60    }
61
62    /// Get the cached score.
63    #[inline]
64    pub fn score(&self) -> f64 {
65        self.score
66    }
67
68    /// Get the cached UUID.
69    #[inline]
70    pub fn id(&self) -> Uuid {
71        self.id
72    }
73
74    /// Get the original index.
75    #[inline]
76    pub fn index(&self) -> usize {
77        self.index
78    }
79
80    /// Get the deterministic fixed-point score.
81    #[inline]
82    pub fn det_score(&self) -> DeterministicScore {
83        self.det_score
84    }
85}
86
87impl<T> Eq for ScoredEntry<T> {}
88
89impl<T> PartialEq for ScoredEntry<T> {
90    #[inline]
91    fn eq(&self, other: &Self) -> bool {
92        self.det_score == other.det_score && self.id == other.id
93    }
94}
95
96impl<T> Ord for ScoredEntry<T> {
97    /// Compare entries for use with BinaryHeap and sorted collections.
98    ///
99    /// Higher scores are "greater" (popped first from max-heap).
100    /// On equal scores, lower UUIDs are "greater" (popped first from max-heap).
101    /// NaN scores map to DeterministicScore::ZERO and sort after all normal values.
102    #[inline]
103    fn cmp(&self, other: &Self) -> Ordering {
104        self.det_score
105            .cmp(&other.det_score)
106            .then_with(|| other.id.cmp(&self.id))
107    }
108}
109
110impl<T> PartialOrd for ScoredEntry<T> {
111    #[inline]
112    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
113        Some(self.cmp(other))
114    }
115}
116
117impl<T> Hash for ScoredEntry<T> {
118    fn hash<H: Hasher>(&self, state: &mut H) {
119        self.det_score.hash(state);
120        self.id.hash(state);
121    }
122}
123
124#[cfg(test)]
125mod tests {
126    use super::*;
127    use std::collections::BinaryHeap;
128
129    #[derive(Debug)]
130    struct Item {
131        id: Uuid,
132    }
133
134    impl HasId for Item {
135        fn id(&self) -> Uuid {
136            self.id
137        }
138    }
139
140    #[test]
141    fn test_scored_entry_higher_score_pops_first() {
142        let id_a = Uuid::from_u128(1);
143        let id_b = Uuid::from_u128(2);
144
145        let mut heap = BinaryHeap::new();
146        heap.push(ScoredEntry::new(Item { id: id_a }, 0.3, 0));
147        heap.push(ScoredEntry::new(Item { id: id_b }, 0.9, 1));
148
149        let first = heap.pop().unwrap();
150        assert_eq!(first.id(), id_b, "Higher score pops first");
151        assert!((first.score() - 0.9).abs() < 1e-10);
152    }
153
154    #[test]
155    fn test_scored_entry_uuid_tie_breaking() {
156        let id_a = Uuid::from_u128(1);
157        let id_b = Uuid::from_u128(2);
158
159        let mut heap = BinaryHeap::new();
160        heap.push(ScoredEntry::new(Item { id: id_b }, 0.5, 1));
161        heap.push(ScoredEntry::new(Item { id: id_a }, 0.5, 0));
162
163        // Lower UUID (id_a=1) should pop first
164        let first = heap.pop().unwrap();
165        assert_eq!(first.id(), id_a, "Lower UUID pops first on tie");
166    }
167
168    #[test]
169    fn test_scored_entry_nan_pops_last() {
170        let id_nan = Uuid::from_u128(1);
171        let id_normal = Uuid::from_u128(2);
172
173        let mut heap = BinaryHeap::new();
174        heap.push(ScoredEntry::new(Item { id: id_nan }, f64::NAN, 0));
175        heap.push(ScoredEntry::new(Item { id: id_normal }, 0.5, 1));
176
177        let first = heap.pop().unwrap();
178        assert_eq!(first.id(), id_normal, "Normal score pops before NaN");
179
180        let second = heap.pop().unwrap();
181        assert_eq!(second.id(), id_nan, "NaN pops last");
182    }
183
184    #[test]
185    fn test_scored_entry_index_preserved() {
186        let id = Uuid::from_u128(42);
187        let entry = ScoredEntry::new(Item { id }, 0.7, 99);
188        assert_eq!(entry.index(), 99);
189    }
190
191    #[test]
192    fn test_scored_entry_equality_by_det_score_and_id() {
193        let id_a = Uuid::from_u128(1);
194        let id_b = Uuid::from_u128(1); // same UUID
195        let a = ScoredEntry::new(Item { id: id_a }, 0.5, 0);
196        let b = ScoredEntry::new(Item { id: id_b }, 0.5, 99); // different index
197        assert_eq!(a, b, "Equality ignores index");
198    }
199
200    #[test]
201    fn test_scored_entry_det_score_accessor() {
202        let id = Uuid::from_u128(1);
203        let entry = ScoredEntry::new(Item { id }, 0.75, 0);
204        let det = entry.det_score();
205        assert!((det.to_f64() - 0.75).abs() < 1e-9);
206    }
207}