// anno/backends/label_prompt.rs
//! Label prompt normalization for zero-shot NER systems.
//!
//! GLiNER and similar label-text-conditioned models are sensitive to how
//! entity type labels are phrased. "Person", "person", "PERSON", "human",
//! "individual" may all give different results.
//!
//! This module provides tools to:
//! - Normalize label strings to canonical forms
//! - Expand labels to synonyms/aliases for better coverage
//! - Map between different ontologies (e.g., OntoNotes → CoNLL)
//!
//! # Research Background
//!
//! GLiNER critique (2025): "Performance is sensitive to how labels are phrased;
//! semantically similar but poorly written label prompts can cause large drops,
//! especially for fine-grained or rare types."
//!
//! # Usage
//!
//! ```rust
//! use anno::backends::label_prompt::{LabelNormalizer, StandardNormalizer};
//!
//! let normalizer = StandardNormalizer::default();
//!
//! // Canonical form
//! assert_eq!(normalizer.normalize("PERSON"), "person");
//! assert_eq!(normalizer.normalize("ORG"), "organization");
//!
//! // Expansions for better zero-shot coverage
//! let expansions = normalizer.expand("person");
//! // Returns: ["person", "human", "individual", "people", ...]
//! ```
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trait for label prompt normalization.
///
/// Implementors decide how raw label strings (possibly carrying BIO prefixes,
/// mixed casing, or ontology-specific names) collapse into a canonical form,
/// and how a canonical label fans out into synonyms for zero-shot prompting.
pub trait LabelNormalizer: Send + Sync {
    /// Normalize a label to its canonical form.
    fn normalize(&self, label: &str) -> String;

    /// Expand a label to synonyms/aliases for better coverage.
    fn expand(&self, label: &str) -> Vec<String>;

    /// Check if two labels are equivalent (i.e., share a canonical form).
    fn equivalent(&self, a: &str, b: &str) -> bool {
        let (ca, cb) = (self.normalize(a), self.normalize(b));
        ca == cb
    }

    /// Get the canonical name for a label (for display).
    ///
    /// Defaults to the normalized form.
    fn canonical_name(&self, label: &str) -> String {
        self.normalize(label)
    }
}
55
/// Standard label normalizer with common NER ontology mappings.
///
/// Maps ontology-specific aliases (e.g. "PER", "GPE", "B-ORG") onto a small
/// set of canonical lowercase labels, and expands canonical labels into
/// synonym lists intended to improve zero-shot coverage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardNormalizer {
    /// Canonical form mappings (alias → canonical). Keys are stored lowercased.
    canonical_map: HashMap<String, String>,
    /// Expansion mappings (canonical → synonyms, canonical form included first).
    expansion_map: HashMap<String, Vec<String>>,
    /// Whether to lowercase during normalization (defaults to `true` in `new`).
    pub lowercase: bool,
}
66
67impl Default for StandardNormalizer {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl StandardNormalizer {
74    /// Create a new standard normalizer with default mappings.
75    #[must_use]
76    pub fn new() -> Self {
77        let mut canonical_map = HashMap::new();
78        let mut expansion_map = HashMap::new();
79
80        // Person variations
81        for alias in &[
82            "PER",
83            "PERSON",
84            "Person",
85            "per",
86            "people",
87            "PEOPLE",
88            "human",
89            "HUMAN",
90            "individual",
91            "INDIVIDUAL",
92            "B-PER",
93            "I-PER",
94            "S-PER",
95            "E-PER",
96        ] {
97            canonical_map.insert(alias.to_lowercase(), "person".to_string());
98        }
99        expansion_map.insert(
100            "person".to_string(),
101            vec![
102                "person".to_string(),
103                "human".to_string(),
104                "individual".to_string(),
105                "human being".to_string(),
106                "people".to_string(),
107            ],
108        );
109
110        // Organization variations
111        for alias in &[
112            "ORG",
113            "ORGANIZATION",
114            "Organization",
115            "org",
116            "organisation",
117            "ORGANISATION",
118            "company",
119            "COMPANY",
120            "institution",
121            "INSTITUTION",
122            "B-ORG",
123            "I-ORG",
124            "S-ORG",
125            "E-ORG",
126            "CORP",
127            "corporation",
128        ] {
129            canonical_map.insert(alias.to_lowercase(), "organization".to_string());
130        }
131        expansion_map.insert(
132            "organization".to_string(),
133            vec![
134                "organization".to_string(),
135                "organisation".to_string(),
136                "company".to_string(),
137                "institution".to_string(),
138                "corporation".to_string(),
139                "agency".to_string(),
140                "group".to_string(),
141            ],
142        );
143
144        // Location variations
145        for alias in &[
146            "LOC",
147            "LOCATION",
148            "Location",
149            "loc",
150            "GPE",
151            "gpe",
152            "place",
153            "PLACE",
154            "GEO",
155            "geo",
156            "geographic location",
157            "B-LOC",
158            "I-LOC",
159            "S-LOC",
160            "E-LOC",
161            "B-GPE",
162            "I-GPE",
163            "FAC",
164            "facility",
165        ] {
166            canonical_map.insert(alias.to_lowercase(), "location".to_string());
167        }
168        expansion_map.insert(
169            "location".to_string(),
170            vec![
171                "location".to_string(),
172                "place".to_string(),
173                "geographic location".to_string(),
174                "geopolitical entity".to_string(),
175                "country".to_string(),
176                "city".to_string(),
177                "region".to_string(),
178            ],
179        );
180
181        // Miscellaneous variations
182        for alias in &[
183            "MISC",
184            "Misc",
185            "misc",
186            "miscellaneous",
187            "MISCELLANEOUS",
188            "OTHER",
189            "other",
190            "B-MISC",
191            "I-MISC",
192            "S-MISC",
193            "E-MISC",
194        ] {
195            canonical_map.insert(alias.to_lowercase(), "miscellaneous".to_string());
196        }
197        expansion_map.insert(
198            "miscellaneous".to_string(),
199            vec![
200                "miscellaneous".to_string(),
201                "other entity".to_string(),
202                "named entity".to_string(),
203            ],
204        );
205
206        // Date/Time variations
207        for alias in &[
208            "DATE", "Date", "date", "TIME", "Time", "time", "DATETIME", "datetime", "temporal",
209            "TEMPORAL", "B-DATE", "I-DATE", "B-TIME", "I-TIME",
210        ] {
211            canonical_map.insert(alias.to_lowercase(), "date".to_string());
212        }
213        expansion_map.insert(
214            "date".to_string(),
215            vec![
216                "date".to_string(),
217                "time".to_string(),
218                "temporal expression".to_string(),
219                "datetime".to_string(),
220            ],
221        );
222
223        // Money variations
224        for alias in &[
225            "MONEY", "Money", "money", "CURRENCY", "currency", "monetary", "B-MONEY", "I-MONEY",
226        ] {
227            canonical_map.insert(alias.to_lowercase(), "money".to_string());
228        }
229        expansion_map.insert(
230            "money".to_string(),
231            vec![
232                "money".to_string(),
233                "monetary value".to_string(),
234                "currency amount".to_string(),
235                "price".to_string(),
236            ],
237        );
238
239        // Event variations
240        for alias in &[
241            "EVENT",
242            "Event",
243            "event",
244            "HAPPENING",
245            "occurrence",
246            "B-EVENT",
247            "I-EVENT",
248        ] {
249            canonical_map.insert(alias.to_lowercase(), "event".to_string());
250        }
251        expansion_map.insert(
252            "event".to_string(),
253            vec![
254                "event".to_string(),
255                "occurrence".to_string(),
256                "happening".to_string(),
257                "incident".to_string(),
258            ],
259        );
260
261        // Product variations
262        for alias in &[
263            "PRODUCT",
264            "Product",
265            "product",
266            "PROD",
267            "B-PRODUCT",
268            "I-PRODUCT",
269        ] {
270            canonical_map.insert(alias.to_lowercase(), "product".to_string());
271        }
272        expansion_map.insert(
273            "product".to_string(),
274            vec![
275                "product".to_string(),
276                "commercial product".to_string(),
277                "item".to_string(),
278                "goods".to_string(),
279            ],
280        );
281
282        // Work of art variations
283        for alias in &[
284            "WORK_OF_ART",
285            "WorkOfArt",
286            "work_of_art",
287            "WORK",
288            "artwork",
289            "B-WORK_OF_ART",
290            "I-WORK_OF_ART",
291            "creative work",
292        ] {
293            canonical_map.insert(alias.to_lowercase(), "work_of_art".to_string());
294        }
295        expansion_map.insert(
296            "work_of_art".to_string(),
297            vec![
298                "work of art".to_string(),
299                "creative work".to_string(),
300                "artwork".to_string(),
301                "artistic creation".to_string(),
302            ],
303        );
304
305        Self {
306            canonical_map,
307            expansion_map,
308            lowercase: true,
309        }
310    }
311
312    /// Add a custom mapping.
313    pub fn add_mapping(&mut self, alias: &str, canonical: &str) {
314        self.canonical_map
315            .insert(alias.to_lowercase(), canonical.to_string());
316    }
317
318    /// Add custom expansions for a canonical label.
319    pub fn add_expansions(&mut self, canonical: &str, expansions: Vec<String>) {
320        self.expansion_map.insert(canonical.to_string(), expansions);
321    }
322}
323
324impl LabelNormalizer for StandardNormalizer {
325    fn normalize(&self, label: &str) -> String {
326        let key = if self.lowercase {
327            label.to_lowercase()
328        } else {
329            label.to_string()
330        };
331
332        // Strip BIO prefix if present
333        let stripped = key
334            .strip_prefix("b-")
335            .or_else(|| key.strip_prefix("i-"))
336            .or_else(|| key.strip_prefix("s-"))
337            .or_else(|| key.strip_prefix("e-"))
338            .unwrap_or(&key);
339
340        self.canonical_map
341            .get(stripped)
342            .cloned()
343            .unwrap_or_else(|| stripped.to_string())
344    }
345
346    fn expand(&self, label: &str) -> Vec<String> {
347        let canonical = self.normalize(label);
348        self.expansion_map
349            .get(&canonical)
350            .cloned()
351            .unwrap_or_else(|| vec![canonical])
352    }
353}
354
/// Hierarchical entity type system.
///
/// Supports type hierarchies like: Person → Athlete → Tennis Player.
/// Type names are stored lowercased; each type has at most one parent
/// (single-inheritance tree, enforced by the `parent` map).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalTypeSystem {
    /// Parent → children mapping
    children: HashMap<String, Vec<String>>,
    /// Child → parent mapping (a type absent here is a root)
    parent: HashMap<String, String>,
    /// All types in the system, in insertion order
    all_types: Vec<String>,
}
367
368impl Default for HierarchicalTypeSystem {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374impl HierarchicalTypeSystem {
375    /// Create a new empty type system.
376    #[must_use]
377    pub fn new() -> Self {
378        Self {
379            children: HashMap::new(),
380            parent: HashMap::new(),
381            all_types: Vec::new(),
382        }
383    }
384
385    /// Create a type system with standard NER hierarchy.
386    #[must_use]
387    pub fn standard_ner() -> Self {
388        let mut sys = Self::new();
389
390        // Person hierarchy
391        sys.add_type("person", None);
392        sys.add_type("politician", Some("person"));
393        sys.add_type("athlete", Some("person"));
394        sys.add_type("artist", Some("person"));
395        sys.add_type("scientist", Some("person"));
396        sys.add_type("businessperson", Some("person"));
397
398        // Organization hierarchy
399        sys.add_type("organization", None);
400        sys.add_type("company", Some("organization"));
401        sys.add_type("government", Some("organization"));
402        sys.add_type("educational", Some("organization"));
403        sys.add_type("sports_team", Some("organization"));
404        sys.add_type("political_party", Some("organization"));
405
406        // Location hierarchy
407        sys.add_type("location", None);
408        sys.add_type("country", Some("location"));
409        sys.add_type("city", Some("location"));
410        sys.add_type("state", Some("location"));
411        sys.add_type("facility", Some("location"));
412        sys.add_type("natural_feature", Some("location"));
413
414        sys
415    }
416
417    /// Add a type to the hierarchy.
418    pub fn add_type(&mut self, type_name: &str, parent_type: Option<&str>) {
419        let type_lower = type_name.to_lowercase();
420
421        if !self.all_types.contains(&type_lower) {
422            self.all_types.push(type_lower.clone());
423        }
424
425        if let Some(parent) = parent_type {
426            let parent_lower = parent.to_lowercase();
427            self.parent.insert(type_lower.clone(), parent_lower.clone());
428            self.children
429                .entry(parent_lower)
430                .or_default()
431                .push(type_lower);
432        }
433    }
434
435    /// Get all ancestors of a type (from specific to general).
436    #[must_use]
437    pub fn ancestors(&self, type_name: &str) -> Vec<String> {
438        let mut result = Vec::new();
439        let mut current = type_name.to_lowercase();
440
441        while let Some(parent) = self.parent.get(&current) {
442            result.push(parent.clone());
443            current = parent.clone();
444        }
445
446        result
447    }
448
449    /// Get all descendants of a type.
450    #[must_use]
451    pub fn descendants(&self, type_name: &str) -> Vec<String> {
452        let mut result = Vec::new();
453        let mut queue = vec![type_name.to_lowercase()];
454
455        while let Some(current) = queue.pop() {
456            if let Some(children) = self.children.get(&current) {
457                for child in children {
458                    result.push(child.clone());
459                    queue.push(child.clone());
460                }
461            }
462        }
463
464        result
465    }
466
467    /// Check if type_a is a subtype of type_b.
468    #[must_use]
469    pub fn is_subtype(&self, type_a: &str, type_b: &str) -> bool {
470        let a_lower = type_a.to_lowercase();
471        let b_lower = type_b.to_lowercase();
472
473        if a_lower == b_lower {
474            return true;
475        }
476
477        self.ancestors(&a_lower).contains(&b_lower)
478    }
479
480    /// Get the most specific common ancestor of two types.
481    #[must_use]
482    pub fn common_ancestor(&self, type_a: &str, type_b: &str) -> Option<String> {
483        let ancestors_a: std::collections::HashSet<_> = std::iter::once(type_a.to_lowercase())
484            .chain(self.ancestors(type_a))
485            .collect();
486
487        let current = type_b.to_lowercase();
488        if ancestors_a.contains(&current) {
489            return Some(current);
490        }
491
492        self.ancestors(type_b)
493            .into_iter()
494            .find(|ancestor| ancestors_a.contains(ancestor))
495            .map(|s| s.to_string())
496    }
497
498    /// Get all root types (types with no parent).
499    #[must_use]
500    pub fn roots(&self) -> Vec<String> {
501        self.all_types
502            .iter()
503            .filter(|t| !self.parent.contains_key(*t))
504            .cloned()
505            .collect()
506    }
507}
508
/// Ontology mapper for cross-dataset type normalization.
///
/// Translates type names from one annotation scheme to another (e.g.
/// CoNLL-2003 ↔ OntoNotes) via an explicit, case-sensitive lookup table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OntologyMapper {
    /// Source ontology name
    pub source: String,
    /// Target ontology name
    pub target: String,
    /// Mappings from source types to target types (exact-case keys)
    mappings: HashMap<String, String>,
}
519
520impl OntologyMapper {
521    /// Create a new ontology mapper.
522    #[must_use]
523    pub fn new(source: &str, target: &str) -> Self {
524        Self {
525            source: source.to_string(),
526            target: target.to_string(),
527            mappings: HashMap::new(),
528        }
529    }
530
531    /// Create a CoNLL-2003 → OntoNotes mapper.
532    #[must_use]
533    pub fn conll_to_ontonotes() -> Self {
534        let mut mapper = Self::new("conll2003", "ontonotes");
535        mapper.add("PER", "PERSON");
536        mapper.add("ORG", "ORG");
537        mapper.add("LOC", "GPE"); // CoNLL LOC ≈ OntoNotes GPE for most cases
538        mapper.add("MISC", "MISC"); // No direct equivalent
539        mapper
540    }
541
542    /// Create an OntoNotes → CoNLL-2003 mapper.
543    #[must_use]
544    pub fn ontonotes_to_conll() -> Self {
545        let mut mapper = Self::new("ontonotes", "conll2003");
546        mapper.add("PERSON", "PER");
547        mapper.add("ORG", "ORG");
548        mapper.add("GPE", "LOC");
549        mapper.add("LOC", "LOC");
550        mapper.add("FAC", "LOC");
551        mapper.add("NORP", "MISC");
552        mapper.add("WORK_OF_ART", "MISC");
553        mapper.add("EVENT", "MISC");
554        mapper.add("PRODUCT", "MISC");
555        mapper.add("LAW", "MISC");
556        mapper.add("LANGUAGE", "MISC");
557        // Numeric types typically not in CoNLL
558        mapper.add("DATE", "MISC");
559        mapper.add("TIME", "MISC");
560        mapper.add("MONEY", "MISC");
561        mapper.add("QUANTITY", "MISC");
562        mapper.add("PERCENT", "MISC");
563        mapper.add("CARDINAL", "MISC");
564        mapper.add("ORDINAL", "MISC");
565        mapper
566    }
567
568    /// Add a mapping.
569    pub fn add(&mut self, source_type: &str, target_type: &str) {
570        self.mappings
571            .insert(source_type.to_string(), target_type.to_string());
572    }
573
574    /// Map a type from source to target ontology.
575    #[must_use]
576    pub fn map(&self, source_type: &str) -> Option<String> {
577        self.mappings.get(source_type).cloned()
578    }
579
580    /// Map a type, falling back to original if no mapping exists.
581    #[must_use]
582    pub fn map_or_keep(&self, source_type: &str) -> String {
583        self.map(source_type)
584            .unwrap_or_else(|| source_type.to_string())
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_standard_normalizer() {
        let norm = StandardNormalizer::default();
        let cases = [
            ("PER", "person"),
            ("PERSON", "person"),
            ("B-PER", "person"),
            ("ORG", "organization"),
            ("organisation", "organization"),
            ("LOC", "location"),
            ("GPE", "location"),
        ];
        for (raw, expected) in cases {
            assert_eq!(norm.normalize(raw), expected);
        }
    }

    #[test]
    fn test_expansion() {
        let norm = StandardNormalizer::default();
        let expansions = norm.expand("PER");
        for synonym in ["person", "human"] {
            assert!(expansions.contains(&synonym.to_string()));
        }
    }

    #[test]
    fn test_hierarchical_types() {
        let sys = HierarchicalTypeSystem::standard_ner();

        assert!(sys.is_subtype("athlete", "person"));
        assert!(sys.is_subtype("person", "person"));
        assert!(!sys.is_subtype("person", "athlete"));

        assert_eq!(sys.ancestors("athlete"), vec!["person"]);
        assert!(sys
            .descendants("person")
            .contains(&"athlete".to_string()));
    }

    #[test]
    fn test_ontology_mapper() {
        let mapper = OntologyMapper::conll_to_ontonotes();
        assert_eq!(mapper.map("PER"), Some("PERSON".to_string()));
        assert_eq!(mapper.map("LOC"), Some("GPE".to_string()));
    }

    #[test]
    fn test_normalizer_bio_prefix_stripping() {
        let norm = StandardNormalizer::default();
        // All BIO/BIOES prefixes should be stripped.
        for tag in ["B-PER", "I-PER", "E-PER", "S-PER"] {
            assert_eq!(norm.normalize(tag), "person");
        }
    }

    #[test]
    fn test_normalizer_case_insensitive() {
        let norm = StandardNormalizer::default();
        for spelling in ["per", "Per", "PER", "PERSON"] {
            assert_eq!(norm.normalize(spelling), "person");
        }
    }

    #[test]
    fn test_expansion_all_types() {
        let norm = StandardNormalizer::default();

        let per = norm.expand("PER");
        assert!(per.len() >= 2);
        assert!(per.contains(&"person".to_string()));

        assert!(norm.expand("ORG").contains(&"organization".to_string()));
        assert!(norm.expand("LOC").contains(&"location".to_string()));
    }

    #[test]
    fn test_hierarchical_athletes() {
        let sys = HierarchicalTypeSystem::standard_ner();

        // Both subtypes roll up to person ...
        assert!(sys.is_subtype("athlete", "person"));
        assert!(sys.is_subtype("politician", "person"));
        // ... but never to an unrelated root.
        assert!(!sys.is_subtype("athlete", "organization"));
    }

    #[test]
    fn test_mapper_bidirectional() {
        let mapper = OntologyMapper::conll_to_ontonotes();

        let known = [
            ("PER", "PERSON"),
            ("ORG", "ORG"),
            ("LOC", "GPE"),
            ("MISC", "MISC"),
        ];
        for (src, dst) in known {
            assert_eq!(mapper.map(src), Some(dst.to_string()));
        }

        // Unknown type returns None.
        assert_eq!(mapper.map("UNKNOWN_TYPE"), None);
    }

    #[test]
    fn test_mapper_or_keep() {
        let mapper = OntologyMapper::conll_to_ontonotes();

        // Known type gets mapped; unknown type is kept as-is.
        assert_eq!(mapper.map_or_keep("PER"), "PERSON");
        assert_eq!(mapper.map_or_keep("CUSTOM_TYPE"), "CUSTOM_TYPE");
    }
}