use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
use crate::resolver::resolve_candidates;
pub use gaze_types::{Candidate, DetectContext, Recognizer};
use gaze_types::{LocaleChain, PiiClass};
/// A pluggable check that classifies a raw string into a [`ValidationResult`].
///
/// The registry stores validators as `Arc<dyn Validator>`, so implementations
/// must be object-safe and `Send + Sync`.
pub trait Validator
where
    Self: Send + Sync,
{
    /// Identifier of this validator.
    fn id(&self) -> &str;

    /// Classifies `raw` as valid, invalid, or indeterminate.
    fn validate(&self, raw: &str) -> ValidationResult;
}
/// Three-way outcome of [`Validator::validate`].
///
/// The enum is `#[non_exhaustive]`: code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ValidationResult {
    /// The input passed the validator's check.
    Valid,
    /// The input failed the validator's check.
    Invalid,
    /// The validator could not reach a decision.
    Indeterminate,
}
/// Converts a raw string into a canonical form.
///
/// The registry stores canonicalizers as `Arc<dyn Canonicalizer>`, so
/// implementations must be object-safe and `Send + Sync`.
pub trait Canonicalizer
where
    Self: Send + Sync,
{
    /// Returns the canonical form of `raw`; `None` presumably means no
    /// canonical form could be produced.
    fn canonicalize(&self, raw: &str) -> Option<String>;
}
/// An ordered collection of recognizers, held alongside maps of validators
/// and canonicalizers.
///
/// Construct one with [`RecognizerRegistry::builder`]. Every member is held
/// behind an `Arc`, so the same recognizer or validator can be shared with
/// other registries.
pub struct RecognizerRegistry {
    /// Recognizers in registration order; detection walks them in this order.
    entries: Vec<Arc<dyn Recognizer>>,
    /// Validators, looked up by string key.
    validators: HashMap<String, Arc<dyn Validator>>,
    /// Canonicalizers, looked up by string key.
    canonicalizers: HashMap<String, Arc<dyn Canonicalizer>>,
}
impl std::fmt::Debug for RecognizerRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecognizerRegistry").finish_non_exhaustive()
}
}
// Unit tests for `RecognizerRegistry`: raw and resolved detection, the default
// `Recognizer::locales`, and locale filtering before detection.
#[cfg(test)]
mod tests {
use super::*;
use crate::{ConflictTier, DictionaryBundle, LocaleTag, PiiClass};
// Test double: ignores its input and always returns exactly one candidate of
// its configured `class`, with token family "counter". It does not override
// `locales`, so it relies on the trait's default.
struct StubRecognizer {
class: PiiClass,
}
impl Recognizer for StubRecognizer {
fn id(&self) -> &str {
"stub"
}
fn supported_class(&self) -> &PiiClass {
&self.class
}
fn detect(&self, _input: &str, _ctx: &DetectContext<'_>) -> Vec<Candidate> {
// NOTE(review): arguments follow `Candidate::new`'s positional signature
// (from gaze_types); `0..5` is presumably the matched span — confirm there.
vec![Candidate::new(
0..5,
self.class.clone(),
self.id(),
1.0,
0,
Some("canonical".to_string()),
self.token_family(),
"test",
ConflictTier::None,
Vec::new(),
)]
}
fn token_family(&self) -> &str {
"counter"
}
}
// A single registered recognizer yields exactly its one candidate from both
// `detect_all` (raw) and `detect_all_resolved`.
#[test]
fn registry_detect_all_uses_registered_recognizers() {
let registry = RecognizerRegistry::builder()
.register(StubRecognizer {
class: PiiClass::Email,
})
.build();
let dictionaries = DictionaryBundle::default();
let ctx = DetectContext::new(&[LocaleTag::Global], &dictionaries);
let candidates = registry.detect_all("input", &ctx);
assert_eq!(candidates.len(), 1);
assert_eq!(candidates[0].class, PiiClass::Email);
assert_eq!(candidates[0].token_family, "counter");
let candidates = registry.detect_all_resolved("input", &ctx);
assert_eq!(candidates.len(), 1);
assert_eq!(candidates[0].class, PiiClass::Email);
}
// A recognizer that does not override `locales` reports `[LocaleTag::Global]`.
#[test]
fn default_locale_is_global() {
let recognizer = StubRecognizer {
class: PiiClass::Email,
};
assert_eq!(recognizer.locales(), &[LocaleTag::Global]);
}
// A recognizer declared only for `DeDe` must not run under an `EnUs` -> `Global`
// chain, even though its `detect` would always return a candidate.
#[test]
fn registry_filters_recognizers_by_locale_before_detection() {
// Test double whose `locales` is exactly its single configured locale.
struct LocaleRecognizer {
locale: LocaleTag,
}
impl Recognizer for LocaleRecognizer {
fn id(&self) -> &str {
"locale"
}
fn supported_class(&self) -> &PiiClass {
&PiiClass::Email
}
fn detect(&self, _input: &str, _ctx: &DetectContext<'_>) -> Vec<Candidate> {
vec![Candidate::new(
0..5,
PiiClass::Email,
self.id(),
1.0,
0,
None,
"counter",
self.id(),
ConflictTier::None,
Vec::new(),
)]
}
fn token_family(&self) -> &str {
"counter"
}
fn locales(&self) -> &[LocaleTag] {
std::slice::from_ref(&self.locale)
}
}
let registry = RecognizerRegistry::builder()
.register(LocaleRecognizer {
locale: LocaleTag::DeDe,
})
.build();
let dictionaries = DictionaryBundle::default();
let ctx = DetectContext::new(&[LocaleTag::EnUs, LocaleTag::Global], &dictionaries);
assert!(registry.detect_all("input", &ctx).is_empty());
}
}
impl RecognizerRegistry {
    /// Returns an empty [`RecognizerRegistryBuilder`].
    pub fn builder() -> RecognizerRegistryBuilder {
        RecognizerRegistryBuilder::default()
    }

    /// Runs every recognizer whose declared locales intersect `ctx.locale_chain`
    /// and returns all of their candidates, unresolved.
    ///
    /// Candidates appear in recognizer registration order. Overlaps and
    /// duplicates are left as-is; see [`Self::detect_all_resolved`] for the
    /// resolved variant. Recognizers receive `ctx` itself.
    pub fn detect_all(&self, input: &str, ctx: &DetectContext<'_>) -> Vec<Candidate> {
        self.entries
            .iter()
            .filter(|recognizer| {
                LocaleChain::from(ctx.locale_chain).intersects(recognizer.locales())
            })
            .flat_map(|recognizer| recognizer.detect(input, ctx))
            .collect()
    }

    /// Detects candidates per PII class with locale fallback, then resolves them.
    ///
    /// Classes are visited in `PiiClass`'s `Ord` order, one for each distinct
    /// class among the registered recognizers. For each class, the locales of
    /// `ctx.locale_chain` are tried one at a time, in chain order. Only that
    /// class's recognizers whose locales intersect the single locale run, and
    /// they run against a one-locale `DetectContext`.
    ///
    /// Candidates scoring below `min_score` for the class are dropped, and so are
    /// NaN scores, which fail the `>=` comparison. The first locale that leaves at
    /// least one candidate wins; later locales are not consulted for that class.
    ///
    /// All surviving candidates are passed to `resolve_candidates`. That step is
    /// presumably overlap/conflict resolution — see `crate::resolver`.
    ///
    /// Each per-locale context is seeded from `ctx.degraded`. Whatever the
    /// recognizers record there is copied back to `ctx.degraded`, so callers
    /// observe degradation just as they do with [`Self::detect_all`].
    pub fn detect_all_resolved(&self, input: &str, ctx: &DetectContext<'_>) -> Vec<Candidate> {
        let classes = self
            .entries
            .iter()
            .map(|recognizer| recognizer.supported_class().clone())
            .collect::<BTreeSet<_>>();
        let mut candidates = Vec::new();
        for class in classes {
            // Loop-invariant for this class; computed once instead of per candidate.
            let threshold = min_score(&class);
            for locale in ctx.locale_chain {
                let locale_ctx = DetectContext::new(std::slice::from_ref(locale), ctx.dictionaries);
                // Let recognizers see any degradation already recorded by the caller.
                locale_ctx.degraded.set(ctx.degraded.get());
                let class_candidates = self
                    .entries
                    .iter()
                    .filter(|recognizer| recognizer.supported_class() == &class)
                    .filter(|recognizer| {
                        LocaleChain::from(locale_ctx.locale_chain)
                            .intersects(recognizer.locales())
                    })
                    .flat_map(|recognizer| recognizer.detect(input, &locale_ctx))
                    .filter(|candidate| candidate.score >= threshold)
                    .collect::<Vec<_>>();
                // Propagate degradation recorded during this probe back to the
                // caller; otherwise it would be dropped along with `locale_ctx`.
                ctx.degraded.set(locale_ctx.degraded.get());
                if !class_candidates.is_empty() {
                    candidates.extend(class_candidates);
                    break;
                }
            }
        }
        resolve_candidates(candidates)
    }

    /// Returns the validators held by this registry.
    pub fn validators(&self) -> &HashMap<String, Arc<dyn Validator>> {
        &self.validators
    }

    /// Returns the canonicalizers held by this registry.
    pub fn canonicalizers(&self) -> &HashMap<String, Arc<dyn Canonicalizer>> {
        &self.canonicalizers
    }
}
/// Score floor a candidate of the given class must reach to be kept by
/// `RecognizerRegistry::detect_all_resolved`.
///
/// Every class currently shares the same floor of `0.0`. The class parameter is
/// presumably reserved for future per-class thresholds.
fn min_score(_: &PiiClass) -> f32 {
    const UNIFORM_FLOOR: f32 = 0.0;
    UNIFORM_FLOOR
}
/// Incrementally assembles a [`RecognizerRegistry`].
///
/// Obtain one from [`RecognizerRegistry::builder`] (or `Default`) and finish
/// with [`RecognizerRegistryBuilder::build`].
#[derive(Default)]
pub struct RecognizerRegistryBuilder {
    /// Recognizers collected so far, in registration order.
    entries: Vec<Arc<dyn Recognizer>>,
    /// Validators handed over unchanged to the built registry.
    validators: HashMap<String, Arc<dyn Validator>>,
    /// Canonicalizers handed over unchanged to the built registry.
    canonicalizers: HashMap<String, Arc<dyn Canonicalizer>>,
}
impl RecognizerRegistryBuilder {
pub fn register<R: Recognizer + 'static>(mut self, r: R) -> Self {
self.entries.push(Arc::new(r));
self
}
pub fn register_arc(mut self, r: Arc<dyn Recognizer>) -> Self {
self.entries.push(r);
self
}
pub fn build(self) -> RecognizerRegistry {
RecognizerRegistry {
entries: self.entries,
validators: self.validators,
canonicalizers: self.canonicalizers,
}
}
}