Skip to main content

graphrag_core/nlp/
custom_ner.rs

//! Custom NER Training Pipeline
//!
//! This module provides a framework for training custom Named Entity Recognition models:
//! - Pattern-based entity extraction
//! - Dictionary/gazetteer matching
//! - Rule-based extraction
//! - Active learning support
//! - Model fine-tuning preparation
//!
//! ## Use Cases
//!
//! - Domain-specific entities (medical terms, legal concepts, etc.)
//! - Company-specific terminology
//! - Custom product names
//! - Technical jargon extraction
16
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20
/// Definition of a custom entity type: its label together with the examples,
/// patterns and gazetteer entries associated with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityType {
    /// Type name (e.g., "PROTEIN", "DRUG", "DISEASE")
    pub name: String,
    /// Human-readable description of the type
    pub description: String,
    /// Example entities of this type, stored exactly as supplied
    pub examples: Vec<String>,
    /// Pattern strings (regex) associated with this type; stored as given, not compiled here
    pub patterns: Vec<String>,
    /// Dictionary/gazetteer entries; the methods on this type insert them lowercased
    pub dictionary: HashSet<String>,
}
35
36impl EntityType {
37    /// Create new entity type
38    pub fn new(name: String, description: String) -> Self {
39        Self {
40            name,
41            description,
42            examples: Vec::new(),
43            patterns: Vec::new(),
44            dictionary: HashSet::new(),
45        }
46    }
47
48    /// Add example entity
49    pub fn add_example(&mut self, example: String) {
50        self.examples.push(example.clone());
51        self.dictionary.insert(example.to_lowercase());
52    }
53
54    /// Add pattern (regex)
55    pub fn add_pattern(&mut self, pattern: String) {
56        self.patterns.push(pattern);
57    }
58
59    /// Add dictionary entries (bulk)
60    pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61        for entry in entries {
62            self.dictionary.insert(entry.to_lowercase());
63        }
64    }
65}
66
/// A single extraction rule: what to look for, how to match it, and which
/// entity type a match is labelled with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionRule {
    /// Rule name; regex rules cache their compiled pattern under this name,
    /// and every extracted entity records it as `rule_name`
    pub name: String,
    /// Entity type this rule extracts (dictionary rules also use it to look up
    /// the registered entity type whose dictionary is searched)
    pub entity_type: String,
    /// Rule type, selecting the matching strategy
    pub rule_type: RuleType,
    /// Rule pattern or configuration; its interpretation depends on `rule_type`
    pub pattern: String,
    /// Minimum confidence for matches.
    /// NOTE(review): no code in this file reads this field — confirm whether
    /// filtering is meant to happen here or in a caller.
    pub min_confidence: f32,
    /// Priority (higher = checked first)
    pub priority: i32,
}
83
/// Matching strategies available to an extraction rule.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RuleType {
    /// Case-insensitive search for the pattern anywhere in the text
    /// (substring search, not limited to word boundaries)
    ExactMatch,
    /// Regex pattern; each whole match becomes an entity
    Regex,
    /// Whitespace-separated words that start with the pattern (case-insensitive)
    Prefix,
    /// Whitespace-separated words that end with the pattern (case-insensitive)
    Suffix,
    /// Whitespace-separated words that contain the pattern (case-insensitive)
    Contains,
    /// Dictionary lookup: searches the text for the dictionary entries of the
    /// rule's registered entity type; the rule's own pattern is not used
    Dictionary,
    /// Context-based (requires surrounding words); the pattern has the form
    /// `before_word|target|after_word` and is matched against three consecutive words
    Contextual,
}
102
/// Custom NER model: a registry of entity types plus a prioritized list of
/// extraction rules.
pub struct CustomNER {
    /// Registered entity types, keyed by their `name`
    entity_types: HashMap<String, EntityType>,
    /// Extraction rules, re-sorted by descending priority whenever a rule is added
    rules: Vec<ExtractionRule>,
    /// Compiled regex patterns for regex rules, keyed by rule name
    compiled_patterns: HashMap<String, Regex>,
}
112
113impl CustomNER {
114    /// Create new custom NER model
115    pub fn new() -> Self {
116        Self {
117            entity_types: HashMap::new(),
118            rules: Vec::new(),
119            compiled_patterns: HashMap::new(),
120        }
121    }
122
123    /// Register entity type
124    pub fn register_entity_type(&mut self, entity_type: EntityType) {
125        self.entity_types
126            .insert(entity_type.name.clone(), entity_type);
127    }
128
129    /// Add extraction rule
130    pub fn add_rule(&mut self, rule: ExtractionRule) {
131        // Compile regex if needed
132        if rule.rule_type == RuleType::Regex {
133            if let Ok(regex) = Regex::new(&rule.pattern) {
134                self.compiled_patterns.insert(rule.name.clone(), regex);
135            }
136        }
137
138        self.rules.push(rule);
139        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
140    }
141
142    /// Extract entities from text
143    pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
144        let mut entities = Vec::new();
145
146        // Apply rules in priority order
147        for rule in &self.rules {
148            let rule_entities = self.apply_rule(text, rule);
149            entities.extend(rule_entities);
150        }
151
152        // Deduplicate and resolve conflicts
153        self.resolve_overlaps(entities)
154    }
155
156    /// Apply a single extraction rule
157    fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
158        match rule.rule_type {
159            RuleType::ExactMatch => self.extract_exact_match(text, rule),
160            RuleType::Regex => self.extract_regex(text, rule),
161            RuleType::Prefix => self.extract_prefix(text, rule),
162            RuleType::Suffix => self.extract_suffix(text, rule),
163            RuleType::Contains => self.extract_contains(text, rule),
164            RuleType::Dictionary => self.extract_dictionary(text, rule),
165            RuleType::Contextual => self.extract_contextual(text, rule),
166        }
167    }
168
169    /// Exact match extraction
170    fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
171        let mut entities = Vec::new();
172        let pattern = &rule.pattern;
173        let text_lower = text.to_lowercase();
174        let pattern_lower = pattern.to_lowercase();
175
176        let mut start = 0;
177        while let Some(pos) = text_lower[start..].find(&pattern_lower) {
178            let absolute_pos = start + pos;
179            entities.push(ExtractedEntity {
180                text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
181                entity_type: rule.entity_type.clone(),
182                start: absolute_pos,
183                end: absolute_pos + pattern.len(),
184                confidence: 1.0,
185                rule_name: rule.name.clone(),
186            });
187
188            start = absolute_pos + pattern.len();
189        }
190
191        entities
192    }
193
194    /// Regex extraction
195    fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
196        let mut entities = Vec::new();
197
198        if let Some(regex) = self.compiled_patterns.get(&rule.name) {
199            for capture in regex.captures_iter(text) {
200                if let Some(matched) = capture.get(0) {
201                    entities.push(ExtractedEntity {
202                        text: matched.as_str().to_string(),
203                        entity_type: rule.entity_type.clone(),
204                        start: matched.start(),
205                        end: matched.end(),
206                        confidence: 0.9,
207                        rule_name: rule.name.clone(),
208                    });
209                }
210            }
211        }
212
213        entities
214    }
215
216    /// Prefix match extraction
217    fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
218        let mut entities = Vec::new();
219        let words: Vec<&str> = text.split_whitespace().collect();
220        let mut pos = 0;
221
222        for word in words {
223            if word
224                .to_lowercase()
225                .starts_with(&rule.pattern.to_lowercase())
226            {
227                entities.push(ExtractedEntity {
228                    text: word.to_string(),
229                    entity_type: rule.entity_type.clone(),
230                    start: pos,
231                    end: pos + word.len(),
232                    confidence: 0.7,
233                    rule_name: rule.name.clone(),
234                });
235            }
236            pos += word.len() + 1; // +1 for space
237        }
238
239        entities
240    }
241
242    /// Suffix match extraction
243    fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
244        let mut entities = Vec::new();
245        let words: Vec<&str> = text.split_whitespace().collect();
246        let mut pos = 0;
247
248        for word in words {
249            if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
250                entities.push(ExtractedEntity {
251                    text: word.to_string(),
252                    entity_type: rule.entity_type.clone(),
253                    start: pos,
254                    end: pos + word.len(),
255                    confidence: 0.7,
256                    rule_name: rule.name.clone(),
257                });
258            }
259            pos += word.len() + 1;
260        }
261
262        entities
263    }
264
265    /// Contains substring extraction
266    fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
267        let mut entities = Vec::new();
268        let words: Vec<&str> = text.split_whitespace().collect();
269        let mut pos = 0;
270
271        for word in words {
272            if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
273                entities.push(ExtractedEntity {
274                    text: word.to_string(),
275                    entity_type: rule.entity_type.clone(),
276                    start: pos,
277                    end: pos + word.len(),
278                    confidence: 0.6,
279                    rule_name: rule.name.clone(),
280                });
281            }
282            pos += word.len() + 1;
283        }
284
285        entities
286    }
287
288    /// Dictionary-based extraction
289    fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
290        let mut entities = Vec::new();
291
292        if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
293            let text_lower = text.to_lowercase();
294
295            for entry in &entity_type.dictionary {
296                let mut start = 0;
297                while let Some(pos) = text_lower[start..].find(entry) {
298                    let absolute_pos = start + pos;
299                    entities.push(ExtractedEntity {
300                        text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
301                        entity_type: rule.entity_type.clone(),
302                        start: absolute_pos,
303                        end: absolute_pos + entry.len(),
304                        confidence: 0.95,
305                        rule_name: rule.name.clone(),
306                    });
307
308                    start = absolute_pos + entry.len();
309                }
310            }
311        }
312
313        entities
314    }
315
316    /// Contextual extraction (requires specific surrounding words)
317    fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
318        // Simplified contextual extraction
319        // Pattern format: "before_word|target|after_word"
320        let parts: Vec<&str> = rule.pattern.split('|').collect();
321        if parts.len() != 3 {
322            return Vec::new();
323        }
324
325        let before = parts[0];
326        let target = parts[1];
327        let after = parts[2];
328
329        let mut entities = Vec::new();
330        let words: Vec<&str> = text.split_whitespace().collect();
331
332        for window in words.windows(3) {
333            if window[0].to_lowercase().contains(&before.to_lowercase())
334                && window[1].to_lowercase().contains(&target.to_lowercase())
335                && window[2].to_lowercase().contains(&after.to_lowercase())
336            {
337                // Find position in original text
338                if let Some(pos) = text.find(window[1]) {
339                    entities.push(ExtractedEntity {
340                        text: window[1].to_string(),
341                        entity_type: rule.entity_type.clone(),
342                        start: pos,
343                        end: pos + window[1].len(),
344                        confidence: 0.85,
345                        rule_name: rule.name.clone(),
346                    });
347                }
348            }
349        }
350
351        entities
352    }
353
354    /// Resolve overlapping entities (keep higher confidence)
355    fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
356        if entities.is_empty() {
357            return entities;
358        }
359
360        // Sort by position, then by confidence (descending)
361        entities.sort_by(|a, b| {
362            a.start
363                .cmp(&b.start)
364                .then(b.confidence.partial_cmp(&a.confidence).unwrap())
365        });
366
367        let mut result = Vec::new();
368        let mut last_end = 0;
369
370        for entity in entities {
371            // Skip if overlaps with previous entity
372            if entity.start < last_end {
373                continue;
374            }
375
376            last_end = entity.end;
377            result.push(entity);
378        }
379
380        result
381    }
382
383    /// Get entity types
384    pub fn entity_types(&self) -> &HashMap<String, EntityType> {
385        &self.entity_types
386    }
387
388    /// Get rules
389    pub fn rules(&self) -> &[ExtractionRule] {
390        &self.rules
391    }
392}
393
394impl Default for CustomNER {
395    fn default() -> Self {
396        Self::new()
397    }
398}
399
/// An entity found in a text, together with where it was found and by which rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// Entity text as it appears in the source text
    pub text: String,
    /// Entity type (copied from the rule that produced the match)
    pub entity_type: String,
    /// Start position in text (byte offset, inclusive)
    pub start: usize,
    /// End position in text (byte offset, exclusive)
    pub end: usize,
    /// Confidence score (0.0 to 1.0). The extractors in this file assign a fixed
    /// value per rule type: 1.0 exact match, 0.95 dictionary, 0.9 regex,
    /// 0.85 contextual, 0.7 prefix/suffix, 0.6 contains.
    pub confidence: f32,
    /// Name of the rule that extracted this entity
    pub rule_name: String,
}
416
/// Training dataset for custom NER: a growable list of annotated examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDataset {
    /// Annotated examples, in the order they were added
    pub examples: Vec<AnnotatedExample>,
}
423
424impl TrainingDataset {
425    /// Create new training dataset
426    pub fn new() -> Self {
427        Self {
428            examples: Vec::new(),
429        }
430    }
431
432    /// Add annotated example
433    pub fn add_example(&mut self, example: AnnotatedExample) {
434        self.examples.push(example);
435    }
436
437    /// Get statistics
438    pub fn statistics(&self) -> DatasetStatistics {
439        let total_examples = self.examples.len();
440        let mut entity_counts: HashMap<String, usize> = HashMap::new();
441
442        for example in &self.examples {
443            for entity in &example.entities {
444                *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
445            }
446        }
447
448        DatasetStatistics {
449            total_examples,
450            entity_counts,
451        }
452    }
453}
454
455impl Default for TrainingDataset {
456    fn default() -> Self {
457        Self::new()
458    }
459}
460
/// One training example: a text plus the entities annotated in it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnotatedExample {
    /// Original text
    pub text: String,
    /// Entities annotated in `text`
    pub entities: Vec<ExtractedEntity>,
}
469
/// Summary figures for a training dataset, as computed by `TrainingDataset::statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Total number of examples in the dataset
    pub total_examples: usize,
    /// Number of annotated entities per entity type, summed over all examples
    pub entity_counts: HashMap<String, usize>,
}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an extraction rule from borrowed strings to keep the tests compact.
    fn rule(
        name: &str,
        entity_type: &str,
        rule_type: RuleType,
        pattern: &str,
        min_confidence: f32,
        priority: i32,
    ) -> ExtractionRule {
        ExtractionRule {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            rule_type,
            pattern: pattern.to_string(),
            min_confidence,
            priority,
        }
    }

    /// Builds a model and adds the given rules to it in order.
    fn ner_with(rules: Vec<ExtractionRule>) -> CustomNER {
        let mut ner = CustomNER::new();
        for r in rules {
            ner.add_rule(r);
        }
        ner
    }

    #[test]
    fn test_entity_type_creation() {
        let mut entity_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

        for example in ["hemoglobin", "insulin"] {
            entity_type.add_example(example.to_string());
        }

        // Each example is recorded and also lands in the dictionary.
        assert_eq!(entity_type.examples.len(), 2);
        assert_eq!(entity_type.dictionary.len(), 2);
    }

    #[test]
    fn test_exact_match_extraction() {
        let ner = ner_with(vec![rule(
            "protein_exact",
            "PROTEIN",
            RuleType::ExactMatch,
            "hemoglobin",
            0.9,
            1,
        )]);

        // The second occurrence is capitalized: matching ignores case.
        let entities =
            ner.extract("The protein hemoglobin is important. Hemoglobin carries oxygen.");

        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0].entity_type, "PROTEIN");
        assert_eq!(entities[0].text.to_lowercase(), "hemoglobin");
    }

    #[test]
    fn test_regex_extraction() {
        let ner = ner_with(vec![rule(
            "gene_pattern",
            "GENE",
            RuleType::Regex,
            r"[A-Z]{2,4}\d+",
            0.8,
            1,
        )]);

        let entities = ner.extract("The genes TP53 and BRCA1 are tumor suppressors.");

        assert!(entities.len() >= 2);
        for expected in ["TP53", "BRCA1"] {
            assert!(entities.iter().any(|e| e.text == expected));
        }
    }

    #[test]
    fn test_dictionary_extraction() {
        let mut protein_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
        protein_type.add_dictionary_entries(vec![
            "insulin".to_string(),
            "hemoglobin".to_string(),
            "collagen".to_string(),
        ]);

        let mut ner = CustomNER::new();
        ner.register_entity_type(protein_type);
        // Dictionary rules take their entries from the registered type, not from the pattern.
        ner.add_rule(rule(
            "protein_dict",
            "PROTEIN",
            RuleType::Dictionary,
            "",
            0.9,
            2,
        ));

        let entities = ner.extract("Insulin regulates blood sugar. Hemoglobin transports oxygen.");

        assert_eq!(entities.len(), 2);
    }

    #[test]
    fn test_prefix_extraction() {
        let ner = ner_with(vec![rule(
            "bio_prefix",
            "BIO_TERM",
            RuleType::Prefix,
            "bio",
            0.7,
            1,
        )]);

        let entities = ner.extract("Biology and biochemistry are fascinating subjects.");

        assert!(entities.len() >= 2);
    }

    #[test]
    fn test_overlap_resolution() {
        let ner = ner_with(vec![
            rule("rule1", "TYPE1", RuleType::ExactMatch, "test", 0.9, 1),
            rule("rule2", "TYPE2", RuleType::ExactMatch, "testing", 0.95, 2),
        ]);

        let entities = ner.extract("We are testing this code.");

        // Should only extract one entity (higher confidence/priority wins)
        assert_eq!(entities.len(), 1);
    }

    #[test]
    fn test_training_dataset() {
        let annotation = ExtractedEntity {
            text: "Insulin".to_string(),
            entity_type: "PROTEIN".to_string(),
            start: 0,
            end: 7,
            confidence: 1.0,
            rule_name: "manual".to_string(),
        };

        let mut dataset = TrainingDataset::new();
        dataset.add_example(AnnotatedExample {
            text: "Insulin regulates glucose.".to_string(),
            entities: vec![annotation],
        });

        let stats = dataset.statistics();
        assert_eq!(stats.total_examples, 1);
        assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
    }
}