redact_core/recognizers/
registry.rs1use super::{Recognizer, RecognizerResult};
6use crate::types::EntityType;
7use anyhow::{Context, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone)]
13pub struct RecognizerRegistry {
14 recognizers: Vec<Arc<dyn Recognizer>>,
15 entity_map: HashMap<EntityType, Vec<usize>>,
16}
17
18impl RecognizerRegistry {
19 pub fn new() -> Self {
21 Self {
22 recognizers: Vec::new(),
23 entity_map: HashMap::new(),
24 }
25 }
26
27 pub fn add_recognizer(&mut self, recognizer: Arc<dyn Recognizer>) {
29 let index = self.recognizers.len();
30
31 for entity_type in recognizer.supported_entities() {
33 self.entity_map
34 .entry(entity_type.clone())
35 .or_default()
36 .push(index);
37 }
38
39 self.recognizers.push(recognizer);
40 }
41
42 pub fn recognizers(&self) -> &[Arc<dyn Recognizer>] {
44 &self.recognizers
45 }
46
47 pub fn recognizers_for_entity(&self, entity_type: &EntityType) -> Vec<Arc<dyn Recognizer>> {
49 if let Some(indices) = self.entity_map.get(entity_type) {
50 indices
51 .iter()
52 .map(|&idx| self.recognizers[idx].clone())
53 .collect()
54 } else {
55 Vec::new()
56 }
57 }
58
59 pub fn analyze(&self, text: &str, language: &str) -> Result<Vec<RecognizerResult>> {
61 let mut all_results = Vec::new();
62
63 for recognizer in &self.recognizers {
64 if !recognizer.supports_language(language) {
66 continue;
67 }
68
69 let results = recognizer.analyze(text, language).with_context(|| {
70 format!("Failed to analyze with recognizer: {}", recognizer.name())
71 })?;
72
73 all_results.extend(results);
74 }
75
76 all_results.sort();
78 let resolved = self.resolve_overlaps(all_results);
79
80 Ok(resolved)
81 }
82
83 pub fn analyze_with_entities(
85 &self,
86 text: &str,
87 language: &str,
88 entity_types: &[EntityType],
89 ) -> Result<Vec<RecognizerResult>> {
90 let mut all_results = Vec::new();
91
92 let mut used_recognizers = std::collections::HashSet::new();
94
95 for entity_type in entity_types {
96 if let Some(indices) = self.entity_map.get(entity_type) {
97 used_recognizers.extend(indices.iter().copied());
98 }
99 }
100
101 for idx in used_recognizers {
102 let recognizer = &self.recognizers[idx];
103
104 if !recognizer.supports_language(language) {
105 continue;
106 }
107
108 let results = recognizer.analyze(text, language).with_context(|| {
109 format!("Failed to analyze with recognizer: {}", recognizer.name())
110 })?;
111
112 let filtered: Vec<_> = results
114 .into_iter()
115 .filter(|r| entity_types.contains(&r.entity_type))
116 .collect();
117
118 all_results.extend(filtered);
119 }
120
121 all_results.sort();
122 let resolved = self.resolve_overlaps(all_results);
123
124 Ok(resolved)
125 }
126
127 fn resolve_overlaps(&self, results: Vec<RecognizerResult>) -> Vec<RecognizerResult> {
142 if results.is_empty() {
143 return results;
144 }
145
146 let mut resolved = Vec::new();
147 let mut consumed = vec![false; results.len()];
148
149 for i in 0..results.len() {
150 if consumed[i] {
151 continue;
152 }
153
154 let mut group: Vec<usize> = vec![i];
156 for j in (i + 1)..results.len() {
157 if consumed[j] {
158 continue;
159 }
160 let overlaps_group = group.iter().any(|&g| results[g].overlaps_with(&results[j]));
162 if overlaps_group {
163 group.push(j);
164 }
165 }
166
167 let mut best_idx = i;
169 let mut best = &results[i];
170
171 for &idx in &group[1..] {
172 let candidate = &results[idx];
173
174 if best.entity_type.is_suppressed_by(&candidate.entity_type) {
176 best = candidate;
177 best_idx = idx;
178 continue;
179 }
180
181 if candidate.entity_type.is_suppressed_by(&best.entity_type) {
182 continue;
183 }
184
185 let best_combined = Self::combined_score(best);
187 let candidate_combined = Self::combined_score(candidate);
188
189 if candidate_combined > best_combined {
190 best = candidate;
191 best_idx = idx;
192 } else if (candidate_combined - best_combined).abs() < 0.05 {
193 if candidate.len() > best.len() {
195 best = candidate;
196 best_idx = idx;
197 }
198 }
199 }
200
201 for &idx in &group {
203 consumed[idx] = true;
204 }
205
206 resolved.push(results[best_idx].clone());
207 }
208
209 resolved
210 }
211
212 fn combined_score(result: &RecognizerResult) -> f32 {
219 let specificity = result.entity_type.specificity_score() as f32 / 100.0;
220 0.6 * result.score + 0.4 * specificity
221 }
222
223 pub fn stats(&self) -> RegistryStats {
225 let mut entity_coverage = HashMap::new();
226 for (entity_type, indices) in &self.entity_map {
227 entity_coverage.insert(entity_type.clone(), indices.len());
228 }
229
230 RegistryStats {
231 recognizer_count: self.recognizers.len(),
232 entity_coverage,
233 }
234 }
235}
236
237impl Default for RecognizerRegistry {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct RegistryStats {
246 pub recognizer_count: usize,
247 pub entity_coverage: HashMap<EntityType, usize>,
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recognizers::pattern::PatternRecognizer;

    /// Builds a registry containing a single `PatternRecognizer`.
    fn pattern_registry() -> RecognizerRegistry {
        let mut registry = RecognizerRegistry::new();
        registry.add_recognizer(Arc::new(PatternRecognizer::new()));
        registry
    }

    #[test]
    fn test_registry_add_recognizer() {
        let registry = pattern_registry();

        assert_eq!(registry.recognizers().len(), 1);
    }

    #[test]
    fn test_registry_analyze() {
        let registry = pattern_registry();

        let text = "Email: john@example.com, Phone: (555) 123-4567";
        let results = registry.analyze(text, "en").unwrap();

        assert!(results.len() >= 2);
    }

    #[test]
    fn test_registry_analyze_with_entities() {
        let registry = pattern_registry();

        let text = "Email: john@example.com, Phone: (555) 123-4567";
        let results = registry
            .analyze_with_entities(text, "en", &[EntityType::EmailAddress])
            .unwrap();

        // `all` is vacuously true on an empty iterator, so assert that the
        // email was actually found before checking the filter.
        assert!(!results.is_empty(), "expected the email address to be detected");
        assert!(results
            .iter()
            .all(|r| r.entity_type == EntityType::EmailAddress));
    }

    #[test]
    fn test_registry_analyze_with_no_entities() {
        let registry = pattern_registry();

        let text = "Email: john@example.com, Phone: (555) 123-4567";
        let results = registry.analyze_with_entities(text, "en", &[]).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_overlap_resolution() {
        let registry = RecognizerRegistry::new();

        let mut results = vec![
            RecognizerResult::new(EntityType::Person, 0, 10, 0.8, "test1"),
            RecognizerResult::new(EntityType::Person, 5, 15, 0.9, "test2"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].score, 0.9);
    }

    #[test]
    fn test_overlap_resolution_empty() {
        let registry = RecognizerRegistry::new();

        assert!(registry.resolve_overlaps(Vec::new()).is_empty());
    }

    #[test]
    fn test_overlap_resolution_keeps_disjoint_spans() {
        let registry = RecognizerRegistry::new();

        let mut results = vec![
            RecognizerResult::new(EntityType::Person, 0, 5, 0.8, "test1"),
            RecognizerResult::new(EntityType::Person, 10, 15, 0.8, "test2"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        assert_eq!(resolved.len(), 2);
    }

    #[test]
    fn test_overlap_resolution_specificity() {
        let registry = RecognizerRegistry::new();

        let mut results = vec![
            RecognizerResult::new(EntityType::PhoneNumber, 0, 13, 0.75, "pattern"),
            RecognizerResult::new(EntityType::UkMobileNumber, 0, 13, 0.80, "pattern"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].entity_type, EntityType::UkMobileNumber);
    }

    #[test]
    fn test_overlap_resolution_suppression() {
        let registry = RecognizerRegistry::new();

        let mut results = vec![
            RecognizerResult::new(EntityType::PhoneNumber, 0, 13, 0.90, "pattern"),
            RecognizerResult::new(EntityType::UkMobileNumber, 0, 13, 0.80, "pattern"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].entity_type, EntityType::UkMobileNumber);
    }

    #[test]
    fn test_recognizers_for_entity() {
        let registry = pattern_registry();

        let recognizers = registry.recognizers_for_entity(&EntityType::EmailAddress);
        assert_eq!(recognizers.len(), 1);

        let recognizers = registry.recognizers_for_entity(&EntityType::Person);
        assert_eq!(recognizers.len(), 0);
    }

    #[test]
    fn test_registry_stats() {
        let registry = pattern_registry();

        let stats = registry.stats();
        assert_eq!(stats.recognizer_count, 1);
        assert!(!stats.entity_coverage.is_empty());
    }
}