Skip to main content

graphrag_core/nlp/
custom_ner.rs

//! Custom NER Training Pipeline
//!
//! This module provides a framework for training custom Named Entity Recognition models:
//! - Pattern-based entity extraction
//! - Dictionary/gazetteer matching
//! - Rule-based extraction
//! - Active learning support
//! - Model fine-tuning preparation
//!
//! ## Use Cases
//!
//! - Domain-specific entities (medical terms, legal concepts, etc.)
//! - Company-specific terminology
//! - Custom product names
//! - Technical jargon extraction
16
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19use regex::Regex;
20
21/// Entity type definition
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct EntityType {
24    /// Type name (e.g., "PROTEIN", "DRUG", "DISEASE")
25    pub name: String,
26    /// Type description
27    pub description: String,
28    /// Example entities of this type
29    pub examples: Vec<String>,
30    /// Patterns for recognition
31    pub patterns: Vec<String>,
32    /// Dictionary/gazetteer entries
33    pub dictionary: HashSet<String>,
34}
35
36impl EntityType {
37    /// Create new entity type
38    pub fn new(name: String, description: String) -> Self {
39        Self {
40            name,
41            description,
42            examples: Vec::new(),
43            patterns: Vec::new(),
44            dictionary: HashSet::new(),
45        }
46    }
47
48    /// Add example entity
49    pub fn add_example(&mut self, example: String) {
50        self.examples.push(example.clone());
51        self.dictionary.insert(example.to_lowercase());
52    }
53
54    /// Add pattern (regex)
55    pub fn add_pattern(&mut self, pattern: String) {
56        self.patterns.push(pattern);
57    }
58
59    /// Add dictionary entries (bulk)
60    pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61        for entry in entries {
62            self.dictionary.insert(entry.to_lowercase());
63        }
64    }
65}
66
67/// Extraction rule
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExtractionRule {
70    /// Rule name
71    pub name: String,
72    /// Entity type this rule extracts
73    pub entity_type: String,
74    /// Rule type
75    pub rule_type: RuleType,
76    /// Rule pattern or configuration
77    pub pattern: String,
78    /// Minimum confidence for matches
79    pub min_confidence: f32,
80    /// Priority (higher = checked first)
81    pub priority: i32,
82}
83
84/// Rule types
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub enum RuleType {
87    /// Exact string match
88    ExactMatch,
89    /// Regex pattern
90    Regex,
91    /// Prefix match
92    Prefix,
93    /// Suffix match
94    Suffix,
95    /// Contains substring
96    Contains,
97    /// Dictionary lookup
98    Dictionary,
99    /// Context-based (requires surrounding words)
100    Contextual,
101}
102
103/// Custom NER model
104pub struct CustomNER {
105    /// Entity types
106    entity_types: HashMap<String, EntityType>,
107    /// Extraction rules
108    rules: Vec<ExtractionRule>,
109    /// Compiled regex patterns
110    compiled_patterns: HashMap<String, Regex>,
111}
112
113impl CustomNER {
114    /// Create new custom NER model
115    pub fn new() -> Self {
116        Self {
117            entity_types: HashMap::new(),
118            rules: Vec::new(),
119            compiled_patterns: HashMap::new(),
120        }
121    }
122
123    /// Register entity type
124    pub fn register_entity_type(&mut self, entity_type: EntityType) {
125        self.entity_types.insert(entity_type.name.clone(), entity_type);
126    }
127
128    /// Add extraction rule
129    pub fn add_rule(&mut self, rule: ExtractionRule) {
130        // Compile regex if needed
131        if rule.rule_type == RuleType::Regex {
132            if let Ok(regex) = Regex::new(&rule.pattern) {
133                self.compiled_patterns.insert(rule.name.clone(), regex);
134            }
135        }
136
137        self.rules.push(rule);
138        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
139    }
140
141    /// Extract entities from text
142    pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
143        let mut entities = Vec::new();
144
145        // Apply rules in priority order
146        for rule in &self.rules {
147            let rule_entities = self.apply_rule(text, rule);
148            entities.extend(rule_entities);
149        }
150
151        // Deduplicate and resolve conflicts
152        self.resolve_overlaps(entities)
153    }
154
155    /// Apply a single extraction rule
156    fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
157        match rule.rule_type {
158            RuleType::ExactMatch => self.extract_exact_match(text, rule),
159            RuleType::Regex => self.extract_regex(text, rule),
160            RuleType::Prefix => self.extract_prefix(text, rule),
161            RuleType::Suffix => self.extract_suffix(text, rule),
162            RuleType::Contains => self.extract_contains(text, rule),
163            RuleType::Dictionary => self.extract_dictionary(text, rule),
164            RuleType::Contextual => self.extract_contextual(text, rule),
165        }
166    }
167
168    /// Exact match extraction
169    fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
170        let mut entities = Vec::new();
171        let pattern = &rule.pattern;
172        let text_lower = text.to_lowercase();
173        let pattern_lower = pattern.to_lowercase();
174
175        let mut start = 0;
176        while let Some(pos) = text_lower[start..].find(&pattern_lower) {
177            let absolute_pos = start + pos;
178            entities.push(ExtractedEntity {
179                text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
180                entity_type: rule.entity_type.clone(),
181                start: absolute_pos,
182                end: absolute_pos + pattern.len(),
183                confidence: 1.0,
184                rule_name: rule.name.clone(),
185            });
186
187            start = absolute_pos + pattern.len();
188        }
189
190        entities
191    }
192
193    /// Regex extraction
194    fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
195        let mut entities = Vec::new();
196
197        if let Some(regex) = self.compiled_patterns.get(&rule.name) {
198            for capture in regex.captures_iter(text) {
199                if let Some(matched) = capture.get(0) {
200                    entities.push(ExtractedEntity {
201                        text: matched.as_str().to_string(),
202                        entity_type: rule.entity_type.clone(),
203                        start: matched.start(),
204                        end: matched.end(),
205                        confidence: 0.9,
206                        rule_name: rule.name.clone(),
207                    });
208                }
209            }
210        }
211
212        entities
213    }
214
215    /// Prefix match extraction
216    fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
217        let mut entities = Vec::new();
218        let words: Vec<&str> = text.split_whitespace().collect();
219        let mut pos = 0;
220
221        for word in words {
222            if word.to_lowercase().starts_with(&rule.pattern.to_lowercase()) {
223                entities.push(ExtractedEntity {
224                    text: word.to_string(),
225                    entity_type: rule.entity_type.clone(),
226                    start: pos,
227                    end: pos + word.len(),
228                    confidence: 0.7,
229                    rule_name: rule.name.clone(),
230                });
231            }
232            pos += word.len() + 1; // +1 for space
233        }
234
235        entities
236    }
237
238    /// Suffix match extraction
239    fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
240        let mut entities = Vec::new();
241        let words: Vec<&str> = text.split_whitespace().collect();
242        let mut pos = 0;
243
244        for word in words {
245            if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
246                entities.push(ExtractedEntity {
247                    text: word.to_string(),
248                    entity_type: rule.entity_type.clone(),
249                    start: pos,
250                    end: pos + word.len(),
251                    confidence: 0.7,
252                    rule_name: rule.name.clone(),
253                });
254            }
255            pos += word.len() + 1;
256        }
257
258        entities
259    }
260
261    /// Contains substring extraction
262    fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
263        let mut entities = Vec::new();
264        let words: Vec<&str> = text.split_whitespace().collect();
265        let mut pos = 0;
266
267        for word in words {
268            if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
269                entities.push(ExtractedEntity {
270                    text: word.to_string(),
271                    entity_type: rule.entity_type.clone(),
272                    start: pos,
273                    end: pos + word.len(),
274                    confidence: 0.6,
275                    rule_name: rule.name.clone(),
276                });
277            }
278            pos += word.len() + 1;
279        }
280
281        entities
282    }
283
284    /// Dictionary-based extraction
285    fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
286        let mut entities = Vec::new();
287
288        if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
289            let text_lower = text.to_lowercase();
290
291            for entry in &entity_type.dictionary {
292                let mut start = 0;
293                while let Some(pos) = text_lower[start..].find(entry) {
294                    let absolute_pos = start + pos;
295                    entities.push(ExtractedEntity {
296                        text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
297                        entity_type: rule.entity_type.clone(),
298                        start: absolute_pos,
299                        end: absolute_pos + entry.len(),
300                        confidence: 0.95,
301                        rule_name: rule.name.clone(),
302                    });
303
304                    start = absolute_pos + entry.len();
305                }
306            }
307        }
308
309        entities
310    }
311
312    /// Contextual extraction (requires specific surrounding words)
313    fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
314        // Simplified contextual extraction
315        // Pattern format: "before_word|target|after_word"
316        let parts: Vec<&str> = rule.pattern.split('|').collect();
317        if parts.len() != 3 {
318            return Vec::new();
319        }
320
321        let before = parts[0];
322        let target = parts[1];
323        let after = parts[2];
324
325        let mut entities = Vec::new();
326        let words: Vec<&str> = text.split_whitespace().collect();
327
328        for window in words.windows(3) {
329            if window[0].to_lowercase().contains(&before.to_lowercase())
330                && window[1].to_lowercase().contains(&target.to_lowercase())
331                && window[2].to_lowercase().contains(&after.to_lowercase())
332            {
333                // Find position in original text
334                if let Some(pos) = text.find(window[1]) {
335                    entities.push(ExtractedEntity {
336                        text: window[1].to_string(),
337                        entity_type: rule.entity_type.clone(),
338                        start: pos,
339                        end: pos + window[1].len(),
340                        confidence: 0.85,
341                        rule_name: rule.name.clone(),
342                    });
343                }
344            }
345        }
346
347        entities
348    }
349
350    /// Resolve overlapping entities (keep higher confidence)
351    fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
352        if entities.is_empty() {
353            return entities;
354        }
355
356        // Sort by position, then by confidence (descending)
357        entities.sort_by(|a, b| {
358            a.start.cmp(&b.start)
359                .then(b.confidence.partial_cmp(&a.confidence).unwrap())
360        });
361
362        let mut result = Vec::new();
363        let mut last_end = 0;
364
365        for entity in entities {
366            // Skip if overlaps with previous entity
367            if entity.start < last_end {
368                continue;
369            }
370
371            last_end = entity.end;
372            result.push(entity);
373        }
374
375        result
376    }
377
378    /// Get entity types
379    pub fn entity_types(&self) -> &HashMap<String, EntityType> {
380        &self.entity_types
381    }
382
383    /// Get rules
384    pub fn rules(&self) -> &[ExtractionRule] {
385        &self.rules
386    }
387}
388
389impl Default for CustomNER {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
395/// Extracted entity
396#[derive(Debug, Clone, Serialize, Deserialize)]
397pub struct ExtractedEntity {
398    /// Entity text
399    pub text: String,
400    /// Entity type
401    pub entity_type: String,
402    /// Start position in text
403    pub start: usize,
404    /// End position in text
405    pub end: usize,
406    /// Confidence score (0.0 to 1.0)
407    pub confidence: f32,
408    /// Rule that extracted this entity
409    pub rule_name: String,
410}
411
412/// Training dataset for custom NER
413#[derive(Debug, Clone, Serialize, Deserialize)]
414pub struct TrainingDataset {
415    /// Annotated examples
416    pub examples: Vec<AnnotatedExample>,
417}
418
419impl TrainingDataset {
420    /// Create new training dataset
421    pub fn new() -> Self {
422        Self {
423            examples: Vec::new(),
424        }
425    }
426
427    /// Add annotated example
428    pub fn add_example(&mut self, example: AnnotatedExample) {
429        self.examples.push(example);
430    }
431
432    /// Get statistics
433    pub fn statistics(&self) -> DatasetStatistics {
434        let total_examples = self.examples.len();
435        let mut entity_counts: HashMap<String, usize> = HashMap::new();
436
437        for example in &self.examples {
438            for entity in &example.entities {
439                *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
440            }
441        }
442
443        DatasetStatistics {
444            total_examples,
445            entity_counts,
446        }
447    }
448}
449
450impl Default for TrainingDataset {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
456/// Annotated text example
457#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct AnnotatedExample {
459    /// Original text
460    pub text: String,
461    /// Annotated entities
462    pub entities: Vec<ExtractedEntity>,
463}
464
465/// Dataset statistics
466#[derive(Debug, Clone, Serialize, Deserialize)]
467pub struct DatasetStatistics {
468    /// Total examples
469    pub total_examples: usize,
470    /// Entity type counts
471    pub entity_counts: HashMap<String, usize>,
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExtractionRule` from borrowed strings to keep the tests short.
    fn make_rule(
        name: &str,
        entity_type: &str,
        rule_type: RuleType,
        pattern: &str,
        min_confidence: f32,
        priority: i32,
    ) -> ExtractionRule {
        ExtractionRule {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            rule_type,
            pattern: pattern.to_string(),
            min_confidence,
            priority,
        }
    }

    /// Builds a recogniser with the given rules added in order.
    fn ner_with(rules: Vec<ExtractionRule>) -> CustomNER {
        let mut ner = CustomNER::new();
        for rule in rules {
            ner.add_rule(rule);
        }
        ner
    }

    #[test]
    fn test_entity_type_creation() {
        let mut proteins = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

        for sample in ["hemoglobin", "insulin"] {
            proteins.add_example(sample.to_string());
        }

        assert_eq!(proteins.examples.len(), 2);
        assert_eq!(proteins.dictionary.len(), 2);
    }

    #[test]
    fn test_exact_match_extraction() {
        let ner = ner_with(vec![make_rule(
            "protein_exact",
            "PROTEIN",
            RuleType::ExactMatch,
            "hemoglobin",
            0.9,
            1,
        )]);

        let found =
            ner.extract("The protein hemoglobin is important. Hemoglobin carries oxygen.");

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].entity_type, "PROTEIN");
        assert_eq!(found[0].text.to_lowercase(), "hemoglobin");
    }

    #[test]
    fn test_regex_extraction() {
        let ner = ner_with(vec![make_rule(
            "gene_pattern",
            "GENE",
            RuleType::Regex,
            r"[A-Z]{2,4}\d+",
            0.8,
            1,
        )]);

        let found = ner.extract("The genes TP53 and BRCA1 are tumor suppressors.");

        assert!(found.len() >= 2);
        for gene in ["TP53", "BRCA1"] {
            assert!(found.iter().any(|e| e.text == gene));
        }
    }

    #[test]
    fn test_dictionary_extraction() {
        let mut proteins = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
        proteins.add_dictionary_entries(
            ["insulin", "hemoglobin", "collagen"]
                .iter()
                .map(|entry| entry.to_string())
                .collect(),
        );

        let mut ner = CustomNER::new();
        ner.register_entity_type(proteins);
        ner.add_rule(make_rule(
            "protein_dict",
            "PROTEIN",
            RuleType::Dictionary,
            "",
            0.9,
            2,
        ));

        let found = ner.extract("Insulin regulates blood sugar. Hemoglobin transports oxygen.");

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_prefix_extraction() {
        let ner = ner_with(vec![make_rule(
            "bio_prefix",
            "BIO_TERM",
            RuleType::Prefix,
            "bio",
            0.7,
            1,
        )]);

        let found = ner.extract("Biology and biochemistry are fascinating subjects.");

        assert!(found.len() >= 2);
    }

    #[test]
    fn test_overlap_resolution() {
        let ner = ner_with(vec![
            make_rule("rule1", "TYPE1", RuleType::ExactMatch, "test", 0.9, 1),
            make_rule("rule2", "TYPE2", RuleType::ExactMatch, "testing", 0.95, 2),
        ]);

        let found = ner.extract("We are testing this code.");

        // Both rules match at the same position; only one entity may survive.
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_training_dataset() {
        let annotation = ExtractedEntity {
            text: "Insulin".to_string(),
            entity_type: "PROTEIN".to_string(),
            start: 0,
            end: 7,
            confidence: 1.0,
            rule_name: "manual".to_string(),
        };

        let mut dataset = TrainingDataset::new();
        dataset.add_example(AnnotatedExample {
            text: "Insulin regulates glucose.".to_string(),
            entities: vec![annotation],
        });

        let stats = dataset.statistics();
        assert_eq!(stats.total_examples, 1);
        assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
    }
}