// redact_core/recognizers/registry.rs

// Copyright 2026 Censgate LLC.
// Licensed under the Apache License, Version 2.0. See the LICENSE file
// in the project root for license information.

5use super::{Recognizer, RecognizerResult};
6use crate::types::EntityType;
7use anyhow::{Context, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// Registry for managing multiple recognizers
12#[derive(Debug, Clone)]
13pub struct RecognizerRegistry {
14    recognizers: Vec<Arc<dyn Recognizer>>,
15    entity_map: HashMap<EntityType, Vec<usize>>,
16}
17
18impl RecognizerRegistry {
19    /// Create a new empty registry
20    pub fn new() -> Self {
21        Self {
22            recognizers: Vec::new(),
23            entity_map: HashMap::new(),
24        }
25    }
26
27    /// Add a recognizer to the registry
28    pub fn add_recognizer(&mut self, recognizer: Arc<dyn Recognizer>) {
29        let index = self.recognizers.len();
30
31        // Map entity types to recognizer index
32        for entity_type in recognizer.supported_entities() {
33            self.entity_map
34                .entry(entity_type.clone())
35                .or_default()
36                .push(index);
37        }
38
39        self.recognizers.push(recognizer);
40    }
41
42    /// Get all recognizers
43    pub fn recognizers(&self) -> &[Arc<dyn Recognizer>] {
44        &self.recognizers
45    }
46
47    /// Get recognizers that support a specific entity type
48    pub fn recognizers_for_entity(&self, entity_type: &EntityType) -> Vec<Arc<dyn Recognizer>> {
49        if let Some(indices) = self.entity_map.get(entity_type) {
50            indices
51                .iter()
52                .map(|&idx| self.recognizers[idx].clone())
53                .collect()
54        } else {
55            Vec::new()
56        }
57    }
58
59    /// Analyze text using all recognizers
60    pub fn analyze(&self, text: &str, language: &str) -> Result<Vec<RecognizerResult>> {
61        let mut all_results = Vec::new();
62
63        for recognizer in &self.recognizers {
64            // Skip recognizers that don't support the language
65            if !recognizer.supports_language(language) {
66                continue;
67            }
68
69            let results = recognizer.analyze(text, language).with_context(|| {
70                format!("Failed to analyze with recognizer: {}", recognizer.name())
71            })?;
72
73            all_results.extend(results);
74        }
75
76        // Sort and resolve overlaps
77        all_results.sort();
78        let resolved = self.resolve_overlaps(all_results);
79
80        Ok(resolved)
81    }
82
83    /// Analyze text using only specific entity types
84    pub fn analyze_with_entities(
85        &self,
86        text: &str,
87        language: &str,
88        entity_types: &[EntityType],
89    ) -> Result<Vec<RecognizerResult>> {
90        let mut all_results = Vec::new();
91
92        // Get unique recognizers that support the requested entities
93        let mut used_recognizers = std::collections::HashSet::new();
94
95        for entity_type in entity_types {
96            if let Some(indices) = self.entity_map.get(entity_type) {
97                used_recognizers.extend(indices.iter().copied());
98            }
99        }
100
101        for idx in used_recognizers {
102            let recognizer = &self.recognizers[idx];
103
104            if !recognizer.supports_language(language) {
105                continue;
106            }
107
108            let results = recognizer.analyze(text, language).with_context(|| {
109                format!("Failed to analyze with recognizer: {}", recognizer.name())
110            })?;
111
112            // Filter to only requested entity types
113            let filtered: Vec<_> = results
114                .into_iter()
115                .filter(|r| entity_types.contains(&r.entity_type))
116                .collect();
117
118            all_results.extend(filtered);
119        }
120
121        all_results.sort();
122        let resolved = self.resolve_overlaps(all_results);
123
124        Ok(resolved)
125    }
126
127    /// Resolve overlapping detections using a multi-factor scoring approach.
128    ///
129    /// When multiple patterns match the same text span, we use the following
130    /// priority order to determine which detection to keep:
131    ///
132    /// 1. **Suppression rules**: Specific entity types suppress generic ones
133    ///    (e.g., UK_MOBILE_NUMBER suppresses PHONE_NUMBER)
134    /// 2. **Combined score**: Weighted combination of confidence and specificity
135    /// 3. **Span length**: Longer matches preferred (more context = more reliable)
136    ///
137    /// This approach reduces false positives by preferring:
138    /// - Country/domain-specific patterns over generic ones
139    /// - Validated patterns (checksums) over unvalidated
140    /// - Higher confidence matches
141    fn resolve_overlaps(&self, results: Vec<RecognizerResult>) -> Vec<RecognizerResult> {
142        if results.is_empty() {
143            return results;
144        }
145
146        let mut resolved = Vec::new();
147        let mut consumed = vec![false; results.len()];
148
149        for i in 0..results.len() {
150            if consumed[i] {
151                continue;
152            }
153
154            // Collect all overlapping results (including current)
155            let mut group: Vec<usize> = vec![i];
156            for j in (i + 1)..results.len() {
157                if consumed[j] {
158                    continue;
159                }
160                // Check if j overlaps with any result already in the group
161                let overlaps_group = group.iter().any(|&g| results[g].overlaps_with(&results[j]));
162                if overlaps_group {
163                    group.push(j);
164                }
165            }
166
167            // Find the best result in this overlapping group
168            let mut best_idx = i;
169            let mut best = &results[i];
170
171            for &idx in &group[1..] {
172                let candidate = &results[idx];
173
174                // Check suppression rules first
175                if best.entity_type.is_suppressed_by(&candidate.entity_type) {
176                    best = candidate;
177                    best_idx = idx;
178                    continue;
179                }
180
181                if candidate.entity_type.is_suppressed_by(&best.entity_type) {
182                    continue;
183                }
184
185                // Calculate combined scores
186                let best_combined = Self::combined_score(best);
187                let candidate_combined = Self::combined_score(candidate);
188
189                if candidate_combined > best_combined {
190                    best = candidate;
191                    best_idx = idx;
192                } else if (candidate_combined - best_combined).abs() < 0.05 {
193                    // Scores are close - prefer longer match (more context)
194                    if candidate.len() > best.len() {
195                        best = candidate;
196                        best_idx = idx;
197                    }
198                }
199            }
200
201            // Mark all in group as consumed
202            for &idx in &group {
203                consumed[idx] = true;
204            }
205
206            resolved.push(results[best_idx].clone());
207        }
208
209        resolved
210    }
211
212    /// Calculate a combined score from confidence and entity specificity.
213    ///
214    /// Formula: 0.6 * confidence + 0.4 * (specificity / 100)
215    ///
216    /// This weights confidence higher but gives meaningful boost to
217    /// more specific entity types.
218    fn combined_score(result: &RecognizerResult) -> f32 {
219        let specificity = result.entity_type.specificity_score() as f32 / 100.0;
220        0.6 * result.score + 0.4 * specificity
221    }
222
223    /// Get statistics about the registry
224    pub fn stats(&self) -> RegistryStats {
225        let mut entity_coverage = HashMap::new();
226        for (entity_type, indices) in &self.entity_map {
227            entity_coverage.insert(entity_type.clone(), indices.len());
228        }
229
230        RegistryStats {
231            recognizer_count: self.recognizers.len(),
232            entity_coverage,
233        }
234    }
235}
236
237impl Default for RecognizerRegistry {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243/// Statistics about a recognizer registry
244#[derive(Debug, Clone)]
245pub struct RegistryStats {
246    pub recognizer_count: usize,
247    pub entity_coverage: HashMap<EntityType, usize>,
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recognizers::pattern::PatternRecognizer;

    #[test]
    fn test_registry_add_recognizer() {
        let mut registry = RecognizerRegistry::new();
        let recognizer = Arc::new(PatternRecognizer::new());

        registry.add_recognizer(recognizer);

        assert_eq!(registry.recognizers().len(), 1);
    }

    #[test]
    fn test_registry_analyze() {
        let mut registry = RecognizerRegistry::new();
        let recognizer = Arc::new(PatternRecognizer::new());
        registry.add_recognizer(recognizer);

        let text = "Email: john@example.com, Phone: (555) 123-4567";
        let results = registry.analyze(text, "en").unwrap();

        assert!(results.len() >= 2);
    }

    #[test]
    fn test_registry_analyze_with_entities() {
        let mut registry = RecognizerRegistry::new();
        let recognizer = Arc::new(PatternRecognizer::new());
        registry.add_recognizer(recognizer);

        let text = "Email: john@example.com, Phone: (555) 123-4567";
        let results = registry
            .analyze_with_entities(text, "en", &[EntityType::EmailAddress])
            .unwrap();

        // The email must actually be detected; without this, the `all` check
        // below would pass vacuously on an empty result set.
        assert!(!results.is_empty());
        // Should only get email results
        assert!(results
            .iter()
            .all(|r| r.entity_type == EntityType::EmailAddress));
    }

    #[test]
    fn test_overlap_resolution() {
        let registry = RecognizerRegistry::new();

        // Create overlapping results with same entity type
        let mut results = vec![
            RecognizerResult::new(EntityType::Person, 0, 10, 0.8, "test1"),
            RecognizerResult::new(EntityType::Person, 5, 15, 0.9, "test2"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        // Should keep only the higher combined score result
        // Both have same specificity (Person = 85), so higher confidence wins
        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].score, 0.9);
    }

    #[test]
    fn test_overlap_resolution_specificity() {
        let registry = RecognizerRegistry::new();

        // UK mobile (specificity 70) has both higher confidence and higher
        // specificity than generic phone (specificity 50) here, so it should
        // win. Note: UkMobileNumber also suppresses PhoneNumber (see the
        // suppression test), so this case does not isolate the specificity term.
        let mut results = vec![
            RecognizerResult::new(EntityType::PhoneNumber, 0, 13, 0.75, "pattern"),
            RecognizerResult::new(EntityType::UkMobileNumber, 0, 13, 0.80, "pattern"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].entity_type, EntityType::UkMobileNumber);
    }

    #[test]
    fn test_overlap_resolution_suppression() {
        let registry = RecognizerRegistry::new();

        // Generic phone should be suppressed by UK mobile at same location
        let mut results = vec![
            RecognizerResult::new(EntityType::PhoneNumber, 0, 13, 0.90, "pattern"),
            RecognizerResult::new(EntityType::UkMobileNumber, 0, 13, 0.80, "pattern"),
        ];

        results.sort();
        let resolved = registry.resolve_overlaps(results);

        // UK mobile wins due to suppression rule, even with lower confidence
        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].entity_type, EntityType::UkMobileNumber);
    }

    #[test]
    fn test_recognizers_for_entity() {
        let mut registry = RecognizerRegistry::new();
        let recognizer = Arc::new(PatternRecognizer::new());
        registry.add_recognizer(recognizer);

        let recognizers = registry.recognizers_for_entity(&EntityType::EmailAddress);
        assert_eq!(recognizers.len(), 1);

        let recognizers = registry.recognizers_for_entity(&EntityType::Person);
        assert_eq!(recognizers.len(), 0); // Pattern recognizer doesn't support Person
    }

    #[test]
    fn test_registry_stats() {
        let mut registry = RecognizerRegistry::new();
        let recognizer = Arc::new(PatternRecognizer::new());
        registry.add_recognizer(recognizer);

        let stats = registry.stats();
        assert_eq!(stats.recognizer_count, 1);
        assert!(!stats.entity_coverage.is_empty());
    }
}