// khive_fold/ordering/scored_entry.rs

use std::cmp::Ordering;
use std::hash::{Hash, Hasher};

use khive_score::DeterministicScore;
use uuid::Uuid;

use super::has_id::HasId;

11#[derive(Debug, Clone, Copy)]
22pub struct ScoredEntry<T> {
23 candidate: T,
25 score: f64,
27 id: Uuid,
29 index: usize,
31 det_score: DeterministicScore,
33}
34
35impl<T: HasId> ScoredEntry<T> {
36 #[inline]
38 pub fn new(candidate: T, score: f64, index: usize) -> Self {
39 let id = candidate.id();
40 let det_score = DeterministicScore::from_f64(score);
41 Self {
42 candidate,
43 score,
44 id,
45 index,
46 det_score,
47 }
48 }
49
50 #[inline]
52 pub fn candidate(&self) -> &T {
53 &self.candidate
54 }
55
56 #[inline]
58 pub fn into_candidate(self) -> T {
59 self.candidate
60 }
61
62 #[inline]
64 pub fn score(&self) -> f64 {
65 self.score
66 }
67
68 #[inline]
70 pub fn id(&self) -> Uuid {
71 self.id
72 }
73
74 #[inline]
76 pub fn index(&self) -> usize {
77 self.index
78 }
79
80 #[inline]
82 pub fn det_score(&self) -> DeterministicScore {
83 self.det_score
84 }
85}
86
87impl<T> Eq for ScoredEntry<T> {}
88
89impl<T> PartialEq for ScoredEntry<T> {
90 #[inline]
91 fn eq(&self, other: &Self) -> bool {
92 self.det_score == other.det_score && self.id == other.id
93 }
94}
95
96impl<T> Ord for ScoredEntry<T> {
97 #[inline]
103 fn cmp(&self, other: &Self) -> Ordering {
104 self.det_score
105 .cmp(&other.det_score)
106 .then_with(|| other.id.cmp(&self.id))
107 }
108}
109
110impl<T> PartialOrd for ScoredEntry<T> {
111 #[inline]
112 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
113 Some(self.cmp(other))
114 }
115}
116
117impl<T> Hash for ScoredEntry<T> {
118 fn hash<H: Hasher>(&self, state: &mut H) {
119 self.det_score.hash(state);
120 self.id.hash(state);
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::collections::BinaryHeap;

    /// Minimal [`HasId`] implementor used to build entries in these tests.
    #[derive(Debug)]
    struct Item {
        id: Uuid,
    }

    impl HasId for Item {
        fn id(&self) -> Uuid {
            self.id
        }
    }

    /// Hashes `value` with std's `DefaultHasher` so two hashes can be compared.
    fn hash_of<V: Hash>(value: &V) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_scored_entry_higher_score_pops_first() {
        let id_a = Uuid::from_u128(1);
        let id_b = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_a }, 0.3, 0));
        heap.push(ScoredEntry::new(Item { id: id_b }, 0.9, 1));

        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_b, "Higher score pops first");
        assert!((first.score() - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_scored_entry_uuid_tie_breaking() {
        let id_a = Uuid::from_u128(1);
        let id_b = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_b }, 0.5, 1));
        heap.push(ScoredEntry::new(Item { id: id_a }, 0.5, 0));

        // Equal scores: the smaller Uuid ranks greater, so the max-heap yields it first.
        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_a, "Lower UUID pops first on tie");
    }

    #[test]
    fn test_scored_entry_nan_pops_last() {
        let id_nan = Uuid::from_u128(1);
        let id_normal = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_nan }, f64::NAN, 0));
        heap.push(ScoredEntry::new(Item { id: id_normal }, 0.5, 1));

        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_normal, "Normal score pops before NaN");

        let second = heap.pop().unwrap();
        assert_eq!(second.id(), id_nan, "NaN pops last");
    }

    #[test]
    fn test_scored_entry_index_preserved() {
        let id = Uuid::from_u128(42);
        let entry = ScoredEntry::new(Item { id }, 0.7, 99);
        assert_eq!(entry.index(), 99);
    }

    #[test]
    fn test_scored_entry_equality_by_det_score_and_id() {
        let id_a = Uuid::from_u128(1);
        let id_b = Uuid::from_u128(1);
        let a = ScoredEntry::new(Item { id: id_a }, 0.5, 0);
        let b = ScoredEntry::new(Item { id: id_b }, 0.5, 99);
        assert_eq!(a, b, "Equality ignores index");
    }

    #[test]
    fn test_scored_entry_det_score_accessor() {
        let id = Uuid::from_u128(1);
        let entry = ScoredEntry::new(Item { id }, 0.75, 0);
        let det = entry.det_score();
        assert!((det.to_f64() - 0.75).abs() < 1e-9);
    }

    #[test]
    fn test_scored_entry_eq_ord_hash_agree() {
        // Same score and id, different index: Eq, Ord and Hash must all agree.
        let id = Uuid::from_u128(7);
        let a = ScoredEntry::new(Item { id }, 0.5, 0);
        let b = ScoredEntry::new(Item { id }, 0.5, 3);
        assert_eq!(a, b);
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(hash_of(&a), hash_of(&b), "Equal entries must hash equally");
    }

    #[test]
    fn test_scored_entry_candidate_accessors() {
        let id = Uuid::from_u128(5);
        let entry = ScoredEntry::new(Item { id }, 0.25, 2);
        assert_eq!(entry.candidate().id, id);
        assert_eq!(entry.id(), id);
        assert_eq!(entry.into_candidate().id, id);
    }
}