Skip to main content

anno/types/
ext.rs

1//! Extension traits for entity collections.
2
3use anno_core::Entity;
4use std::collections::HashMap;
5
6/// Extension methods for slices of entities.
7///
8/// This trait adds useful operations to `[Entity]` and `Vec<Entity>`
9/// without requiring you to wrap them in a newtype.
10///
11/// # Example
12///
13/// ```rust
14/// use anno::{Entity, EntityType};
15/// use anno::types::EntitySliceExt;
16///
17/// let entities = vec![
18///     Entity::new("John", EntityType::Person, 0, 4, 0.9),
19///     Entity::new("$100", EntityType::Money, 10, 14, 0.95),
20///     Entity::new("Paris", EntityType::Location, 20, 25, 0.7),
21/// ];
22///
23/// // Filter by confidence
24/// let high_conf: Vec<_> = entities.above_confidence(0.8).collect();
25/// assert_eq!(high_conf.len(), 2);
26///
27/// // Check for overlaps
28/// assert!(!entities.has_overlaps());
29/// ```
30pub trait EntitySliceExt {
31    /// Filter entities by minimum confidence threshold.
32    fn above_confidence(&self, min: f64) -> impl Iterator<Item = &Entity>;
33
34    /// Filter entities by type.
35    fn of_type(&self, ty: &anno_core::EntityType) -> impl Iterator<Item = &Entity>;
36
37    /// Check if any entities overlap with each other.
38    fn has_overlaps(&self) -> bool;
39
40    /// Find all overlapping pairs of entities.
41    fn overlapping_pairs(&self) -> Vec<(&Entity, &Entity)>;
42
43    /// Get entities sorted by confidence (descending).
44    fn sorted_by_confidence(&self) -> Vec<&Entity>;
45
46    /// Get entities sorted by position (ascending).
47    fn sorted_by_position(&self) -> Vec<&Entity>;
48
49    /// Get the entity with highest confidence.
50    fn highest_confidence(&self) -> Option<&Entity>;
51
52    /// Calculate average confidence across all entities.
53    fn mean_confidence(&self) -> Option<f64>;
54
55    /// Group entities by type.
56    fn group_by_type(&self) -> HashMap<String, Vec<&Entity>>;
57
58    /// Check if a position falls within any entity span.
59    fn contains_position(&self, pos: usize) -> bool;
60
61    /// Get entity at a specific position (if any).
62    fn at_position(&self, pos: usize) -> Option<&Entity>;
63
64    /// Filter to only named entities (Person, Org, Location).
65    fn named_only(&self) -> impl Iterator<Item = &Entity>;
66
67    /// Filter to only structured entities (Date, Money, Email, etc.).
68    fn structured_only(&self) -> impl Iterator<Item = &Entity>;
69}
70
71impl EntitySliceExt for [Entity] {
72    fn above_confidence(&self, min: f64) -> impl Iterator<Item = &Entity> {
73        self.iter().filter(move |e| e.confidence >= min)
74    }
75
76    fn of_type(&self, ty: &anno_core::EntityType) -> impl Iterator<Item = &Entity> {
77        let ty = ty.clone();
78        self.iter().filter(move |e| e.entity_type == ty)
79    }
80
81    fn has_overlaps(&self) -> bool {
82        for i in 0..self.len() {
83            for j in (i + 1)..self.len() {
84                if self[i].overlaps(&self[j]) {
85                    return true;
86                }
87            }
88        }
89        false
90    }
91
92    fn overlapping_pairs(&self) -> Vec<(&Entity, &Entity)> {
93        let mut pairs = Vec::new();
94        for i in 0..self.len() {
95            for j in (i + 1)..self.len() {
96                if self[i].overlaps(&self[j]) {
97                    pairs.push((&self[i], &self[j]));
98                }
99            }
100        }
101        pairs
102    }
103
104    fn sorted_by_confidence(&self) -> Vec<&Entity> {
105        let mut sorted: Vec<_> = self.iter().collect();
106        sorted.sort_by(|a, b| {
107            b.confidence
108                .partial_cmp(&a.confidence)
109                .unwrap_or(std::cmp::Ordering::Equal)
110        });
111        sorted
112    }
113
114    fn sorted_by_position(&self) -> Vec<&Entity> {
115        let mut sorted: Vec<_> = self.iter().collect();
116        sorted.sort_by_key(|e| (e.start, e.end));
117        sorted
118    }
119
120    fn highest_confidence(&self) -> Option<&Entity> {
121        self.iter().max_by(|a, b| {
122            a.confidence
123                .partial_cmp(&b.confidence)
124                .unwrap_or(std::cmp::Ordering::Equal)
125        })
126    }
127
128    fn mean_confidence(&self) -> Option<f64> {
129        if self.is_empty() {
130            return None;
131        }
132        let sum: f64 = self.iter().map(|e| e.confidence).sum();
133        Some(sum / self.len() as f64)
134    }
135
136    fn group_by_type(&self) -> HashMap<String, Vec<&Entity>> {
137        let mut groups: HashMap<String, Vec<&Entity>> = HashMap::new();
138        for entity in self {
139            groups
140                .entry(entity.entity_type.as_label().to_string())
141                .or_default()
142                .push(entity);
143        }
144        groups
145    }
146
147    fn contains_position(&self, pos: usize) -> bool {
148        self.iter().any(|e| pos >= e.start && pos < e.end)
149    }
150
151    fn at_position(&self, pos: usize) -> Option<&Entity> {
152        self.iter().find(|e| pos >= e.start && pos < e.end)
153    }
154
155    fn named_only(&self) -> impl Iterator<Item = &Entity> {
156        self.iter().filter(|e| e.is_named())
157    }
158
159    fn structured_only(&self) -> impl Iterator<Item = &Entity> {
160        self.iter().filter(|e| e.is_structured())
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use anno_core::EntityType;

    /// Four non-overlapping entities in position order with distinct
    /// confidences: John 0.9, $100 0.95, Paris 0.7, 2024 0.85.
    fn sample_entities() -> Vec<Entity> {
        vec![
            Entity::new("John", EntityType::Person, 0, 4, 0.9),
            Entity::new("$100", EntityType::Money, 10, 14, 0.95),
            Entity::new("Paris", EntityType::Location, 20, 25, 0.7),
            Entity::new("2024", EntityType::Date, 30, 34, 0.85),
        ]
    }

    /// Two entities whose spans share positions 4..8.
    fn overlapping_entities() -> Vec<Entity> {
        vec![
            Entity::new("New York", EntityType::Location, 0, 8, 0.9),
            Entity::new("York", EntityType::Location, 4, 8, 0.8),
        ]
    }

    #[test]
    fn above_confidence_filters() {
        let entities = sample_entities();
        // The threshold is inclusive: 2024 (0.85) is kept.
        let high: Vec<_> = entities.above_confidence(0.85).collect();
        assert_eq!(high.len(), 3);
        assert_eq!(entities.above_confidence(1.0).count(), 0);
    }

    #[test]
    fn of_type_filters() {
        let entities = sample_entities();
        let people: Vec<_> = entities.of_type(&EntityType::Person).collect();
        assert_eq!(people.len(), 1);
        assert_eq!(people[0].text, "John");

        let money: Vec<_> = entities.of_type(&EntityType::Money).collect();
        assert_eq!(money.len(), 1);
        assert_eq!(money[0].text, "$100");
    }

    #[test]
    fn has_overlaps_detects() {
        assert!(!sample_entities().has_overlaps());
        assert!(overlapping_entities().has_overlaps());
    }

    #[test]
    fn overlapping_pairs_lists_each_pair_once() {
        assert!(sample_entities().overlapping_pairs().is_empty());

        let overlapping = overlapping_entities();
        let pairs = overlapping.overlapping_pairs();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].0.text, "New York");
        assert_eq!(pairs[0].1.text, "York");
    }

    #[test]
    fn sorted_by_confidence_descending() {
        let entities = sample_entities();
        let sorted = entities.sorted_by_confidence();
        assert_eq!(sorted.len(), 4);
        assert_eq!(sorted[0].text, "$100");
        assert_eq!(sorted[1].text, "John");
        assert_eq!(sorted[2].text, "2024");
        assert_eq!(sorted[3].text, "Paris");
    }

    #[test]
    fn sorted_by_position_ascending() {
        let mut entities = sample_entities();
        entities.reverse();
        let sorted = entities.sorted_by_position();
        assert_eq!(sorted.len(), 4);
        assert_eq!(sorted[0].text, "John");
        assert_eq!(sorted[1].text, "$100");
        assert_eq!(sorted[2].text, "Paris");
        assert_eq!(sorted[3].text, "2024");
    }

    #[test]
    fn highest_confidence_finds_max() {
        let entities = sample_entities();
        let highest = entities.highest_confidence().unwrap();
        assert_eq!(highest.text, "$100");
    }

    #[test]
    fn mean_confidence_calculates() {
        let entities = sample_entities();
        let mean = entities.mean_confidence().unwrap();
        assert!((mean - 0.85).abs() < 1e-10);
    }

    #[test]
    fn group_by_type_groups() {
        let entities = sample_entities();
        let groups = entities.group_by_type();
        assert_eq!(groups.get("PER").map(|v| v.len()), Some(1));
        assert_eq!(groups.get("MONEY").map(|v| v.len()), Some(1));
        // Every entity lands in exactly one bucket.
        let total: usize = groups.values().map(|v| v.len()).sum();
        assert_eq!(total, entities.len());
    }

    #[test]
    fn position_queries() {
        let entities = sample_entities();
        assert!(entities.contains_position(2));
        assert!(!entities.contains_position(5));
        assert_eq!(entities.at_position(12).unwrap().text, "$100");
    }

    #[test]
    fn position_queries_use_half_open_spans() {
        let entities = sample_entities();
        // `start` is inside the span...
        assert!(entities.contains_position(0));
        assert_eq!(entities.at_position(24).unwrap().text, "Paris");
        // ...but `end` is not.
        assert!(!entities.contains_position(4));
        assert!(entities.at_position(25).is_none());
    }

    #[test]
    fn named_and_structured_filters() {
        let entities = sample_entities();
        let named: Vec<_> = entities.named_only().collect();
        assert_eq!(named.len(), 2);
        let structured: Vec<_> = entities.structured_only().collect();
        assert_eq!(structured.len(), 2);
    }

    #[test]
    fn empty_slice_handles_gracefully() {
        let entities: Vec<Entity> = Vec::new();
        assert!(!entities.has_overlaps());
        assert!(entities.overlapping_pairs().is_empty());
        assert!(entities.highest_confidence().is_none());
        assert!(entities.mean_confidence().is_none());
        assert!(entities.sorted_by_confidence().is_empty());
        assert!(entities.sorted_by_position().is_empty());
        assert!(entities.group_by_type().is_empty());
        assert!(!entities.contains_position(0));
        assert!(entities.at_position(0).is_none());
        assert_eq!(entities.above_confidence(0.0).count(), 0);
    }
}