khive_fold/ordering/
scored_entry.rs1use std::cmp::Ordering;
4use std::hash::{Hash, Hasher};
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8
9use super::has_id::HasId;
10
11#[derive(Debug, Clone, Copy)]
13pub struct ScoredEntry<T> {
14 candidate: T,
16 score: f64,
18 id: Uuid,
20 index: usize,
22 det_score: DeterministicScore,
24}
25
26impl<T: HasId> ScoredEntry<T> {
27 #[inline]
29 pub fn new(candidate: T, score: f64, index: usize) -> Self {
30 let id = candidate.id();
31 let det_score = DeterministicScore::from_f64(score);
32 Self {
33 candidate,
34 score,
35 id,
36 index,
37 det_score,
38 }
39 }
40
41 #[inline]
43 pub fn candidate(&self) -> &T {
44 &self.candidate
45 }
46
47 #[inline]
49 pub fn into_candidate(self) -> T {
50 self.candidate
51 }
52
53 #[inline]
55 pub fn score(&self) -> f64 {
56 self.score
57 }
58
59 #[inline]
61 pub fn id(&self) -> Uuid {
62 self.id
63 }
64
65 #[inline]
67 pub fn index(&self) -> usize {
68 self.index
69 }
70
71 #[inline]
73 pub fn det_score(&self) -> DeterministicScore {
74 self.det_score
75 }
76}
77
78impl<T> Eq for ScoredEntry<T> {}
79
80impl<T> PartialEq for ScoredEntry<T> {
81 #[inline]
82 fn eq(&self, other: &Self) -> bool {
83 self.det_score == other.det_score && self.id == other.id
84 }
85}
86
87impl<T> Ord for ScoredEntry<T> {
88 #[inline]
91 fn cmp(&self, other: &Self) -> Ordering {
92 self.det_score
93 .cmp(&other.det_score)
94 .then_with(|| other.id.cmp(&self.id))
95 }
96}
97
98impl<T> PartialOrd for ScoredEntry<T> {
99 #[inline]
100 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
101 Some(self.cmp(other))
102 }
103}
104
105impl<T> Hash for ScoredEntry<T> {
106 fn hash<H: Hasher>(&self, state: &mut H) {
107 self.det_score.hash(state);
108 self.id.hash(state);
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::collections::BinaryHeap;

    #[derive(Debug)]
    struct Item {
        id: Uuid,
    }

    impl HasId for Item {
        fn id(&self) -> Uuid {
            self.id
        }
    }

    /// Hashes `value` with a fresh `DefaultHasher`.
    ///
    /// `DefaultHasher::new` always uses the same keys, so results from
    /// separate calls are comparable.
    fn hash_of<V: Hash>(value: &V) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_scored_entry_higher_score_pops_first() {
        let id_a = Uuid::from_u128(1);
        let id_b = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_a }, 0.3, 0));
        heap.push(ScoredEntry::new(Item { id: id_b }, 0.9, 1));

        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_b, "Higher score pops first");
        assert!((first.score() - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_scored_entry_uuid_tie_breaking() {
        let id_a = Uuid::from_u128(1);
        let id_b = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_b }, 0.5, 1));
        heap.push(ScoredEntry::new(Item { id: id_a }, 0.5, 0));

        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_a, "Lower UUID pops first on tie");
    }

    #[test]
    fn test_scored_entry_nan_pops_last() {
        let id_nan = Uuid::from_u128(1);
        let id_normal = Uuid::from_u128(2);

        let mut heap = BinaryHeap::new();
        heap.push(ScoredEntry::new(Item { id: id_nan }, f64::NAN, 0));
        heap.push(ScoredEntry::new(Item { id: id_normal }, 0.5, 1));

        let first = heap.pop().unwrap();
        assert_eq!(first.id(), id_normal, "Normal score pops before NaN");

        let second = heap.pop().unwrap();
        assert_eq!(second.id(), id_nan, "NaN pops last");
    }

    #[test]
    fn test_scored_entry_index_preserved() {
        let id = Uuid::from_u128(42);
        let entry = ScoredEntry::new(Item { id }, 0.7, 99);
        assert_eq!(entry.index(), 99);
    }

    #[test]
    fn test_scored_entry_equality_by_det_score_and_id() {
        // Same id and score but a different index: the entries must be equal,
        // and the Ord and Hash impls must agree with that equality.
        let shared_id = Uuid::from_u128(1);
        let a = ScoredEntry::new(Item { id: shared_id }, 0.5, 0);
        let b = ScoredEntry::new(Item { id: shared_id }, 0.5, 99);
        assert_eq!(a, b, "Equality ignores index");
        assert_eq!(a.cmp(&b), Ordering::Equal, "Ord agrees with Eq");
        assert_eq!(hash_of(&a), hash_of(&b), "Equal entries hash equally");
    }

    #[test]
    fn test_scored_entry_same_score_different_id_not_equal() {
        let a = ScoredEntry::new(Item { id: Uuid::from_u128(1) }, 0.5, 0);
        let b = ScoredEntry::new(Item { id: Uuid::from_u128(2) }, 0.5, 0);
        assert_ne!(a, b, "Different ids are never equal");
        // On a score tie the lower id ranks higher.
        assert_eq!(a.cmp(&b), Ordering::Greater);
        assert_eq!(b.partial_cmp(&a), Some(Ordering::Less));
    }

    #[test]
    fn test_scored_entry_candidate_accessors() {
        let id = Uuid::from_u128(7);
        let entry = ScoredEntry::new(Item { id }, 0.25, 3);
        assert_eq!(entry.id(), id);
        assert_eq!(entry.candidate().id, id);
        assert!((entry.score() - 0.25).abs() < 1e-10);
        assert_eq!(entry.into_candidate().id, id);
    }

    #[test]
    fn test_scored_entry_det_score_accessor() {
        let id = Uuid::from_u128(1);
        let entry = ScoredEntry::new(Item { id }, 0.75, 0);
        let det = entry.det_score();
        assert!((det.to_f64() - 0.75).abs() < 1e-9);
    }
}