1use anno_core::Entity;
4use std::collections::HashMap;
5
6pub trait EntitySliceExt {
31 fn above_confidence(&self, min: f64) -> impl Iterator<Item = &Entity>;
33
34 fn of_type(&self, ty: &anno_core::EntityType) -> impl Iterator<Item = &Entity>;
36
37 fn has_overlaps(&self) -> bool;
39
40 fn overlapping_pairs(&self) -> Vec<(&Entity, &Entity)>;
42
43 fn sorted_by_confidence(&self) -> Vec<&Entity>;
45
46 fn sorted_by_position(&self) -> Vec<&Entity>;
48
49 fn highest_confidence(&self) -> Option<&Entity>;
51
52 fn mean_confidence(&self) -> Option<f64>;
54
55 fn group_by_type(&self) -> HashMap<String, Vec<&Entity>>;
57
58 fn contains_position(&self, pos: usize) -> bool;
60
61 fn at_position(&self, pos: usize) -> Option<&Entity>;
63
64 fn named_only(&self) -> impl Iterator<Item = &Entity>;
66
67 fn structured_only(&self) -> impl Iterator<Item = &Entity>;
69}
70
71impl EntitySliceExt for [Entity] {
72 fn above_confidence(&self, min: f64) -> impl Iterator<Item = &Entity> {
73 self.iter().filter(move |e| e.confidence >= min)
74 }
75
76 fn of_type(&self, ty: &anno_core::EntityType) -> impl Iterator<Item = &Entity> {
77 let ty = ty.clone();
78 self.iter().filter(move |e| e.entity_type == ty)
79 }
80
81 fn has_overlaps(&self) -> bool {
82 for i in 0..self.len() {
83 for j in (i + 1)..self.len() {
84 if self[i].overlaps(&self[j]) {
85 return true;
86 }
87 }
88 }
89 false
90 }
91
92 fn overlapping_pairs(&self) -> Vec<(&Entity, &Entity)> {
93 let mut pairs = Vec::new();
94 for i in 0..self.len() {
95 for j in (i + 1)..self.len() {
96 if self[i].overlaps(&self[j]) {
97 pairs.push((&self[i], &self[j]));
98 }
99 }
100 }
101 pairs
102 }
103
104 fn sorted_by_confidence(&self) -> Vec<&Entity> {
105 let mut sorted: Vec<_> = self.iter().collect();
106 sorted.sort_by(|a, b| {
107 b.confidence
108 .partial_cmp(&a.confidence)
109 .unwrap_or(std::cmp::Ordering::Equal)
110 });
111 sorted
112 }
113
114 fn sorted_by_position(&self) -> Vec<&Entity> {
115 let mut sorted: Vec<_> = self.iter().collect();
116 sorted.sort_by_key(|e| (e.start, e.end));
117 sorted
118 }
119
120 fn highest_confidence(&self) -> Option<&Entity> {
121 self.iter().max_by(|a, b| {
122 a.confidence
123 .partial_cmp(&b.confidence)
124 .unwrap_or(std::cmp::Ordering::Equal)
125 })
126 }
127
128 fn mean_confidence(&self) -> Option<f64> {
129 if self.is_empty() {
130 return None;
131 }
132 let sum: f64 = self.iter().map(|e| e.confidence).sum();
133 Some(sum / self.len() as f64)
134 }
135
136 fn group_by_type(&self) -> HashMap<String, Vec<&Entity>> {
137 let mut groups: HashMap<String, Vec<&Entity>> = HashMap::new();
138 for entity in self {
139 groups
140 .entry(entity.entity_type.as_label().to_string())
141 .or_default()
142 .push(entity);
143 }
144 groups
145 }
146
147 fn contains_position(&self, pos: usize) -> bool {
148 self.iter().any(|e| pos >= e.start && pos < e.end)
149 }
150
151 fn at_position(&self, pos: usize) -> Option<&Entity> {
152 self.iter().find(|e| pos >= e.start && pos < e.end)
153 }
154
155 fn named_only(&self) -> impl Iterator<Item = &Entity> {
156 self.iter().filter(|e| e.is_named())
157 }
158
159 fn structured_only(&self) -> impl Iterator<Item = &Entity> {
160 self.iter().filter(|e| e.is_structured())
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use anno_core::EntityType;

    /// Four non-overlapping entities with distinct types and confidences.
    fn sample_entities() -> Vec<Entity> {
        vec![
            Entity::new("John", EntityType::Person, 0, 4, 0.9),
            Entity::new("$100", EntityType::Money, 10, 14, 0.95),
            Entity::new("Paris", EntityType::Location, 20, 25, 0.7),
            Entity::new("2024", EntityType::Date, 30, 34, 0.85),
        ]
    }

    /// Two entities whose spans overlap ("York" lies inside "New York").
    fn overlapping_entities() -> [Entity; 2] {
        [
            Entity::new("New York", EntityType::Location, 0, 8, 0.9),
            Entity::new("York", EntityType::Location, 4, 8, 0.8),
        ]
    }

    #[test]
    fn above_confidence_filters() {
        let entities = sample_entities();
        let high: Vec<_> = entities.above_confidence(0.85).collect();
        // The threshold is inclusive, so the 0.85 entity is kept.
        assert_eq!(high.len(), 3);
    }

    #[test]
    fn of_type_filters() {
        let entities = sample_entities();
        let people: Vec<_> = entities.of_type(&EntityType::Person).collect();
        assert_eq!(people.len(), 1);
        assert_eq!(people[0].text, "John");
    }

    #[test]
    fn has_overlaps_detects() {
        let entities = sample_entities();
        assert!(!entities.has_overlaps());

        let overlapping = overlapping_entities();
        assert!(overlapping.has_overlaps());
    }

    #[test]
    fn overlapping_pairs_lists_each_pair_once() {
        let entities = sample_entities();
        assert!(entities.overlapping_pairs().is_empty());

        let overlapping = overlapping_entities();
        let pairs = overlapping.overlapping_pairs();
        assert_eq!(pairs.len(), 1);
        // The earlier entity in the slice comes first in the pair.
        assert_eq!(pairs[0].0.text, "New York");
        assert_eq!(pairs[0].1.text, "York");
    }

    #[test]
    fn sorted_by_confidence_descending() {
        let entities = sample_entities();
        let sorted = entities.sorted_by_confidence();
        assert_eq!(sorted[0].text, "$100");
        assert_eq!(sorted[1].text, "John");
    }

    #[test]
    fn sorted_by_position_ascending() {
        let mut entities = sample_entities();
        entities.reverse();
        let sorted = entities.sorted_by_position();
        assert_eq!(sorted[0].text, "John");
        assert_eq!(sorted[1].text, "$100");
    }

    #[test]
    fn highest_confidence_finds_max() {
        let entities = sample_entities();
        let highest = entities.highest_confidence().unwrap();
        assert_eq!(highest.text, "$100");
    }

    #[test]
    fn mean_confidence_calculates() {
        let entities = sample_entities();
        let mean = entities.mean_confidence().unwrap();
        assert!((mean - 0.85).abs() < 1e-10);
    }

    #[test]
    fn group_by_type_groups() {
        let entities = sample_entities();
        let groups = entities.group_by_type();
        assert_eq!(groups.get("PER").map(|v| v.len()), Some(1));
        assert_eq!(groups.get("MONEY").map(|v| v.len()), Some(1));
    }

    #[test]
    fn position_queries() {
        let entities = sample_entities();
        assert!(entities.contains_position(2));
        assert!(!entities.contains_position(5));
        assert_eq!(entities.at_position(12).unwrap().text, "$100");
    }

    #[test]
    fn position_queries_use_half_open_spans() {
        let entities = sample_entities();
        // "John" spans 0..4: the start is included, the end is excluded.
        assert_eq!(entities.at_position(0).unwrap().text, "John");
        assert!(!entities.contains_position(4));
        assert!(entities.at_position(4).is_none());
    }

    #[test]
    fn named_and_structured_filters() {
        let entities = sample_entities();
        let named: Vec<_> = entities.named_only().collect();
        assert_eq!(named.len(), 2);
        let structured: Vec<_> = entities.structured_only().collect();
        assert_eq!(structured.len(), 2);
    }

    #[test]
    fn empty_slice_handles_gracefully() {
        let entities: Vec<Entity> = vec![];
        assert!(!entities.has_overlaps());
        assert!(entities.overlapping_pairs().is_empty());
        assert!(entities.sorted_by_confidence().is_empty());
        assert!(entities.sorted_by_position().is_empty());
        assert!(entities.highest_confidence().is_none());
        assert!(entities.mean_confidence().is_none());
        assert!(entities.group_by_type().is_empty());
        assert!(!entities.contains_position(0));
        assert!(entities.at_position(0).is_none());
    }
}