Skip to main content

graphrag_core/nlp/
custom_ner.rs

//! Custom NER Training Pipeline
//!
//! This module provides a framework for training custom Named Entity Recognition models:
//! - Pattern-based entity extraction
//! - Dictionary/gazetteer matching
//! - Rule-based extraction
//! - Active learning support
//! - Model fine-tuning preparation
//!
//! ## Use Cases
//!
//! - Domain-specific entities (medical terms, legal concepts, etc.)
//! - Company-specific terminology
//! - Custom product names
//! - Technical jargon extraction
16
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20
21/// Entity type definition
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct EntityType {
24    /// Type name (e.g., "PROTEIN", "DRUG", "DISEASE")
25    pub name: String,
26    /// Type description
27    pub description: String,
28    /// Example entities of this type
29    pub examples: Vec<String>,
30    /// Patterns for recognition
31    pub patterns: Vec<String>,
32    /// Dictionary/gazetteer entries
33    pub dictionary: HashSet<String>,
34}
35
36impl EntityType {
37    /// Create new entity type
38    pub fn new(name: String, description: String) -> Self {
39        Self {
40            name,
41            description,
42            examples: Vec::new(),
43            patterns: Vec::new(),
44            dictionary: HashSet::new(),
45        }
46    }
47
48    /// Add example entity
49    pub fn add_example(&mut self, example: String) {
50        self.examples.push(example.clone());
51        self.dictionary.insert(example.to_lowercase());
52    }
53
54    /// Add pattern (regex)
55    pub fn add_pattern(&mut self, pattern: String) {
56        self.patterns.push(pattern);
57    }
58
59    /// Add dictionary entries (bulk)
60    pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61        for entry in entries {
62            self.dictionary.insert(entry.to_lowercase());
63        }
64    }
65}
66
67/// Extraction rule
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExtractionRule {
70    /// Rule name
71    pub name: String,
72    /// Entity type this rule extracts
73    pub entity_type: String,
74    /// Rule type
75    pub rule_type: RuleType,
76    /// Rule pattern or configuration
77    pub pattern: String,
78    /// Minimum confidence for matches
79    pub min_confidence: f32,
80    /// Priority (higher = checked first)
81    pub priority: i32,
82}
83
84/// Rule types
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub enum RuleType {
87    /// Exact string match
88    ExactMatch,
89    /// Regex pattern
90    Regex,
91    /// Prefix match
92    Prefix,
93    /// Suffix match
94    Suffix,
95    /// Contains substring
96    Contains,
97    /// Dictionary lookup
98    Dictionary,
99    /// Context-based (requires surrounding words)
100    Contextual,
101}
102
103/// Custom NER model
104pub struct CustomNER {
105    /// Entity types
106    entity_types: HashMap<String, EntityType>,
107    /// Extraction rules
108    rules: Vec<ExtractionRule>,
109    /// Compiled regex patterns
110    compiled_patterns: HashMap<String, Regex>,
111}
112
113impl CustomNER {
114    /// Create new custom NER model
115    pub fn new() -> Self {
116        Self {
117            entity_types: HashMap::new(),
118            rules: Vec::new(),
119            compiled_patterns: HashMap::new(),
120        }
121    }
122
123    /// Register entity type
124    pub fn register_entity_type(&mut self, entity_type: EntityType) {
125        self.entity_types
126            .insert(entity_type.name.clone(), entity_type);
127    }
128
129    /// Add extraction rule
130    pub fn add_rule(&mut self, rule: ExtractionRule) {
131        // Compile regex if needed
132        if rule.rule_type == RuleType::Regex {
133            if let Ok(regex) = Regex::new(&rule.pattern) {
134                self.compiled_patterns.insert(rule.name.clone(), regex);
135            }
136        }
137
138        self.rules.push(rule);
139        self.rules
140            .sort_by_key(|rule| std::cmp::Reverse(rule.priority));
141    }
142
143    /// Extract entities from text
144    pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
145        let mut entities = Vec::new();
146
147        // Apply rules in priority order
148        for rule in &self.rules {
149            let rule_entities = self.apply_rule(text, rule);
150            entities.extend(rule_entities);
151        }
152
153        // Deduplicate and resolve conflicts
154        self.resolve_overlaps(entities)
155    }
156
157    /// Apply a single extraction rule
158    fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
159        match rule.rule_type {
160            RuleType::ExactMatch => self.extract_exact_match(text, rule),
161            RuleType::Regex => self.extract_regex(text, rule),
162            RuleType::Prefix => self.extract_prefix(text, rule),
163            RuleType::Suffix => self.extract_suffix(text, rule),
164            RuleType::Contains => self.extract_contains(text, rule),
165            RuleType::Dictionary => self.extract_dictionary(text, rule),
166            RuleType::Contextual => self.extract_contextual(text, rule),
167        }
168    }
169
170    /// Exact match extraction
171    fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
172        let mut entities = Vec::new();
173        let pattern = &rule.pattern;
174        let text_lower = text.to_lowercase();
175        let pattern_lower = pattern.to_lowercase();
176
177        let mut start = 0;
178        while let Some(pos) = text_lower[start..].find(&pattern_lower) {
179            let absolute_pos = start + pos;
180            entities.push(ExtractedEntity {
181                text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
182                entity_type: rule.entity_type.clone(),
183                start: absolute_pos,
184                end: absolute_pos + pattern.len(),
185                confidence: 1.0,
186                rule_name: rule.name.clone(),
187            });
188
189            start = absolute_pos + pattern.len();
190        }
191
192        entities
193    }
194
195    /// Regex extraction
196    fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
197        let mut entities = Vec::new();
198
199        if let Some(regex) = self.compiled_patterns.get(&rule.name) {
200            for capture in regex.captures_iter(text) {
201                if let Some(matched) = capture.get(0) {
202                    entities.push(ExtractedEntity {
203                        text: matched.as_str().to_string(),
204                        entity_type: rule.entity_type.clone(),
205                        start: matched.start(),
206                        end: matched.end(),
207                        confidence: 0.9,
208                        rule_name: rule.name.clone(),
209                    });
210                }
211            }
212        }
213
214        entities
215    }
216
217    /// Prefix match extraction
218    fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
219        let mut entities = Vec::new();
220        let words: Vec<&str> = text.split_whitespace().collect();
221        let mut pos = 0;
222
223        for word in words {
224            if word
225                .to_lowercase()
226                .starts_with(&rule.pattern.to_lowercase())
227            {
228                entities.push(ExtractedEntity {
229                    text: word.to_string(),
230                    entity_type: rule.entity_type.clone(),
231                    start: pos,
232                    end: pos + word.len(),
233                    confidence: 0.7,
234                    rule_name: rule.name.clone(),
235                });
236            }
237            pos += word.len() + 1; // +1 for space
238        }
239
240        entities
241    }
242
243    /// Suffix match extraction
244    fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
245        let mut entities = Vec::new();
246        let words: Vec<&str> = text.split_whitespace().collect();
247        let mut pos = 0;
248
249        for word in words {
250            if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
251                entities.push(ExtractedEntity {
252                    text: word.to_string(),
253                    entity_type: rule.entity_type.clone(),
254                    start: pos,
255                    end: pos + word.len(),
256                    confidence: 0.7,
257                    rule_name: rule.name.clone(),
258                });
259            }
260            pos += word.len() + 1;
261        }
262
263        entities
264    }
265
266    /// Contains substring extraction
267    fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
268        let mut entities = Vec::new();
269        let words: Vec<&str> = text.split_whitespace().collect();
270        let mut pos = 0;
271
272        for word in words {
273            if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
274                entities.push(ExtractedEntity {
275                    text: word.to_string(),
276                    entity_type: rule.entity_type.clone(),
277                    start: pos,
278                    end: pos + word.len(),
279                    confidence: 0.6,
280                    rule_name: rule.name.clone(),
281                });
282            }
283            pos += word.len() + 1;
284        }
285
286        entities
287    }
288
289    /// Dictionary-based extraction
290    fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
291        let mut entities = Vec::new();
292
293        if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
294            let text_lower = text.to_lowercase();
295
296            for entry in &entity_type.dictionary {
297                let mut start = 0;
298                while let Some(pos) = text_lower[start..].find(entry) {
299                    let absolute_pos = start + pos;
300                    entities.push(ExtractedEntity {
301                        text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
302                        entity_type: rule.entity_type.clone(),
303                        start: absolute_pos,
304                        end: absolute_pos + entry.len(),
305                        confidence: 0.95,
306                        rule_name: rule.name.clone(),
307                    });
308
309                    start = absolute_pos + entry.len();
310                }
311            }
312        }
313
314        entities
315    }
316
317    /// Contextual extraction (requires specific surrounding words)
318    fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
319        // Simplified contextual extraction
320        // Pattern format: "before_word|target|after_word"
321        let parts: Vec<&str> = rule.pattern.split('|').collect();
322        if parts.len() != 3 {
323            return Vec::new();
324        }
325
326        let before = parts[0];
327        let target = parts[1];
328        let after = parts[2];
329
330        let mut entities = Vec::new();
331        let words: Vec<&str> = text.split_whitespace().collect();
332
333        for window in words.windows(3) {
334            if window[0].to_lowercase().contains(&before.to_lowercase())
335                && window[1].to_lowercase().contains(&target.to_lowercase())
336                && window[2].to_lowercase().contains(&after.to_lowercase())
337            {
338                // Find position in original text
339                if let Some(pos) = text.find(window[1]) {
340                    entities.push(ExtractedEntity {
341                        text: window[1].to_string(),
342                        entity_type: rule.entity_type.clone(),
343                        start: pos,
344                        end: pos + window[1].len(),
345                        confidence: 0.85,
346                        rule_name: rule.name.clone(),
347                    });
348                }
349            }
350        }
351
352        entities
353    }
354
355    /// Resolve overlapping entities (keep higher confidence)
356    fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
357        if entities.is_empty() {
358            return entities;
359        }
360
361        // Sort by position, then by confidence (descending)
362        entities.sort_by(|a, b| {
363            a.start.cmp(&b.start).then(
364                b.confidence
365                    .partial_cmp(&a.confidence)
366                    .unwrap_or(std::cmp::Ordering::Equal),
367            )
368        });
369
370        let mut result = Vec::new();
371        let mut last_end = 0;
372
373        for entity in entities {
374            // Skip if overlaps with previous entity
375            if entity.start < last_end {
376                continue;
377            }
378
379            last_end = entity.end;
380            result.push(entity);
381        }
382
383        result
384    }
385
386    /// Get entity types
387    pub fn entity_types(&self) -> &HashMap<String, EntityType> {
388        &self.entity_types
389    }
390
391    /// Get rules
392    pub fn rules(&self) -> &[ExtractionRule] {
393        &self.rules
394    }
395}
396
397impl Default for CustomNER {
398    fn default() -> Self {
399        Self::new()
400    }
401}
402
403/// Extracted entity
404#[derive(Debug, Clone, Serialize, Deserialize)]
405pub struct ExtractedEntity {
406    /// Entity text
407    pub text: String,
408    /// Entity type
409    pub entity_type: String,
410    /// Start position in text
411    pub start: usize,
412    /// End position in text
413    pub end: usize,
414    /// Confidence score (0.0 to 1.0)
415    pub confidence: f32,
416    /// Rule that extracted this entity
417    pub rule_name: String,
418}
419
420/// Training dataset for custom NER
421#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct TrainingDataset {
423    /// Annotated examples
424    pub examples: Vec<AnnotatedExample>,
425}
426
427impl TrainingDataset {
428    /// Create new training dataset
429    pub fn new() -> Self {
430        Self {
431            examples: Vec::new(),
432        }
433    }
434
435    /// Add annotated example
436    pub fn add_example(&mut self, example: AnnotatedExample) {
437        self.examples.push(example);
438    }
439
440    /// Get statistics
441    pub fn statistics(&self) -> DatasetStatistics {
442        let total_examples = self.examples.len();
443        let mut entity_counts: HashMap<String, usize> = HashMap::new();
444
445        for example in &self.examples {
446            for entity in &example.entities {
447                *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
448            }
449        }
450
451        DatasetStatistics {
452            total_examples,
453            entity_counts,
454        }
455    }
456}
457
458impl Default for TrainingDataset {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
464/// Annotated text example
465#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct AnnotatedExample {
467    /// Original text
468    pub text: String,
469    /// Annotated entities
470    pub entities: Vec<ExtractedEntity>,
471}
472
473/// Dataset statistics
474#[derive(Debug, Clone, Serialize, Deserialize)]
475pub struct DatasetStatistics {
476    /// Total examples
477    pub total_examples: usize,
478    /// Entity type counts
479    pub entity_counts: HashMap<String, usize>,
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an [`ExtractionRule`] from string slices.
    fn rule(
        name: &str,
        entity_type: &str,
        rule_type: RuleType,
        pattern: &str,
        min_confidence: f32,
        priority: i32,
    ) -> ExtractionRule {
        ExtractionRule {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            rule_type,
            pattern: pattern.to_string(),
            min_confidence,
            priority,
        }
    }

    /// Builds a model that holds exactly `rules`, added in the given order.
    fn ner_with(rules: Vec<ExtractionRule>) -> CustomNER {
        let mut ner = CustomNER::new();
        for rule in rules {
            ner.add_rule(rule);
        }
        ner
    }

    #[test]
    fn test_entity_type_creation() {
        let mut entity_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

        for example in ["hemoglobin", "insulin"] {
            entity_type.add_example(example.to_string());
        }

        assert_eq!(entity_type.examples.len(), 2);
        assert_eq!(entity_type.dictionary.len(), 2);
    }

    #[test]
    fn test_exact_match_extraction() {
        let ner = ner_with(vec![rule(
            "protein_exact",
            "PROTEIN",
            RuleType::ExactMatch,
            "hemoglobin",
            0.9,
            1,
        )]);

        let entities =
            ner.extract("The protein hemoglobin is important. Hemoglobin carries oxygen.");

        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0].entity_type, "PROTEIN");
        assert_eq!(entities[0].text.to_lowercase(), "hemoglobin");
    }

    #[test]
    fn test_regex_extraction() {
        let ner = ner_with(vec![rule(
            "gene_pattern",
            "GENE",
            RuleType::Regex,
            r"[A-Z]{2,4}\d+",
            0.8,
            1,
        )]);

        let entities = ner.extract("The genes TP53 and BRCA1 are tumor suppressors.");

        assert!(entities.len() >= 2);
        assert!(entities.iter().any(|e| e.text == "TP53"));
        assert!(entities.iter().any(|e| e.text == "BRCA1"));
    }

    #[test]
    fn test_dictionary_extraction() {
        let mut ner = CustomNER::new();

        let mut protein_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
        protein_type.add_dictionary_entries(vec![
            "insulin".to_string(),
            "hemoglobin".to_string(),
            "collagen".to_string(),
        ]);
        ner.register_entity_type(protein_type);

        ner.add_rule(rule(
            "protein_dict",
            "PROTEIN",
            RuleType::Dictionary,
            "",
            0.9,
            2,
        ));

        let entities = ner.extract("Insulin regulates blood sugar. Hemoglobin transports oxygen.");

        assert_eq!(entities.len(), 2);
    }

    #[test]
    fn test_prefix_extraction() {
        let ner = ner_with(vec![rule(
            "bio_prefix",
            "BIO_TERM",
            RuleType::Prefix,
            "bio",
            0.7,
            1,
        )]);

        let entities = ner.extract("Biology and biochemistry are fascinating subjects.");

        assert!(entities.len() >= 2);
    }

    #[test]
    fn test_overlap_resolution() {
        let ner = ner_with(vec![
            rule("rule1", "TYPE1", RuleType::ExactMatch, "test", 0.9, 1),
            rule("rule2", "TYPE2", RuleType::ExactMatch, "testing", 0.95, 2),
        ]);

        let entities = ner.extract("We are testing this code.");

        // Both rules match at the same offset; only one entity may survive.
        assert_eq!(entities.len(), 1);
    }

    #[test]
    fn test_training_dataset() {
        let mut dataset = TrainingDataset::new();

        let insulin = ExtractedEntity {
            text: "Insulin".to_string(),
            entity_type: "PROTEIN".to_string(),
            start: 0,
            end: 7,
            confidence: 1.0,
            rule_name: "manual".to_string(),
        };
        dataset.add_example(AnnotatedExample {
            text: "Insulin regulates glucose.".to_string(),
            entities: vec![insulin],
        });

        let stats = dataset.statistics();
        assert_eq!(stats.total_examples, 1);
        assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
    }
}