1use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19use regex::Regex;
20
/// A user-defined entity category (e.g. "PROTEIN") together with the evidence
/// used to recognize it: curated examples, raw pattern strings, and a
/// lowercased lookup dictionary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityType {
 /// Unique name of the entity type; extraction rules refer to it by this key.
 pub name: String,
 /// Human-readable description of what this type covers.
 pub description: String,
 /// Curated example surface forms, kept in their original casing.
 pub examples: Vec<String>,
 /// Raw pattern strings; how they are interpreted depends on the rules that use them.
 pub patterns: Vec<String>,
 /// Lowercased entries used for dictionary-based extraction (examples are mirrored here).
 pub dictionary: HashSet<String>,
}
35
36impl EntityType {
37 pub fn new(name: String, description: String) -> Self {
39 Self {
40 name,
41 description,
42 examples: Vec::new(),
43 patterns: Vec::new(),
44 dictionary: HashSet::new(),
45 }
46 }
47
48 pub fn add_example(&mut self, example: String) {
50 self.examples.push(example.clone());
51 self.dictionary.insert(example.to_lowercase());
52 }
53
54 pub fn add_pattern(&mut self, pattern: String) {
56 self.patterns.push(pattern);
57 }
58
59 pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61 for entry in entries {
62 self.dictionary.insert(entry.to_lowercase());
63 }
64 }
65}
66
/// One extraction rule binding a pattern to an entity type, with a confidence
/// threshold and a priority used to order rule application.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionRule {
 /// Unique rule name; also the cache key for compiled regex patterns.
 pub name: String,
 /// Name of the entity type this rule produces (matches `EntityType::name`).
 pub entity_type: String,
 /// How `pattern` is interpreted (exact match, regex, prefix, ...).
 pub rule_type: RuleType,
 /// Pattern text; its meaning depends on `rule_type`. Contextual rules
 /// expect the form "before|target|after".
 pub pattern: String,
 /// Minimum confidence a match produced by this rule should have.
 /// NOTE(review): not obviously enforced anywhere in this file — confirm intent.
 pub min_confidence: f32,
 /// Rules with higher priority are applied first during extraction.
 pub priority: i32,
}
83
/// The matching strategy an `ExtractionRule` uses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RuleType {
 /// Case-insensitive exact substring match of the pattern.
 ExactMatch,
 /// Regular-expression match (compiled when the rule is added).
 Regex,
 /// Whole word starting with the pattern (case-insensitive).
 Prefix,
 /// Whole word ending with the pattern (case-insensitive).
 Suffix,
 /// Whole word containing the pattern (case-insensitive).
 Contains,
 /// Lookup against the registered entity type's dictionary.
 Dictionary,
 /// Three-word window match; pattern is "before|target|after".
 Contextual,
}
102
/// A rule-based named-entity recognizer: registered entity types, an ordered
/// list of extraction rules, and a cache of pre-compiled regexes.
pub struct CustomNER {
 /// Registered entity types, keyed by `EntityType::name`.
 entity_types: HashMap<String, EntityType>,
 /// Extraction rules, kept sorted by descending priority.
 rules: Vec<ExtractionRule>,
 /// Compiled regexes for `RuleType::Regex` rules, keyed by rule name.
 compiled_patterns: HashMap<String, Regex>,
}
112
113impl CustomNER {
114 pub fn new() -> Self {
116 Self {
117 entity_types: HashMap::new(),
118 rules: Vec::new(),
119 compiled_patterns: HashMap::new(),
120 }
121 }
122
123 pub fn register_entity_type(&mut self, entity_type: EntityType) {
125 self.entity_types.insert(entity_type.name.clone(), entity_type);
126 }
127
128 pub fn add_rule(&mut self, rule: ExtractionRule) {
130 if rule.rule_type == RuleType::Regex {
132 if let Ok(regex) = Regex::new(&rule.pattern) {
133 self.compiled_patterns.insert(rule.name.clone(), regex);
134 }
135 }
136
137 self.rules.push(rule);
138 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
139 }
140
141 pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
143 let mut entities = Vec::new();
144
145 for rule in &self.rules {
147 let rule_entities = self.apply_rule(text, rule);
148 entities.extend(rule_entities);
149 }
150
151 self.resolve_overlaps(entities)
153 }
154
155 fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
157 match rule.rule_type {
158 RuleType::ExactMatch => self.extract_exact_match(text, rule),
159 RuleType::Regex => self.extract_regex(text, rule),
160 RuleType::Prefix => self.extract_prefix(text, rule),
161 RuleType::Suffix => self.extract_suffix(text, rule),
162 RuleType::Contains => self.extract_contains(text, rule),
163 RuleType::Dictionary => self.extract_dictionary(text, rule),
164 RuleType::Contextual => self.extract_contextual(text, rule),
165 }
166 }
167
168 fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
170 let mut entities = Vec::new();
171 let pattern = &rule.pattern;
172 let text_lower = text.to_lowercase();
173 let pattern_lower = pattern.to_lowercase();
174
175 let mut start = 0;
176 while let Some(pos) = text_lower[start..].find(&pattern_lower) {
177 let absolute_pos = start + pos;
178 entities.push(ExtractedEntity {
179 text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
180 entity_type: rule.entity_type.clone(),
181 start: absolute_pos,
182 end: absolute_pos + pattern.len(),
183 confidence: 1.0,
184 rule_name: rule.name.clone(),
185 });
186
187 start = absolute_pos + pattern.len();
188 }
189
190 entities
191 }
192
193 fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
195 let mut entities = Vec::new();
196
197 if let Some(regex) = self.compiled_patterns.get(&rule.name) {
198 for capture in regex.captures_iter(text) {
199 if let Some(matched) = capture.get(0) {
200 entities.push(ExtractedEntity {
201 text: matched.as_str().to_string(),
202 entity_type: rule.entity_type.clone(),
203 start: matched.start(),
204 end: matched.end(),
205 confidence: 0.9,
206 rule_name: rule.name.clone(),
207 });
208 }
209 }
210 }
211
212 entities
213 }
214
215 fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
217 let mut entities = Vec::new();
218 let words: Vec<&str> = text.split_whitespace().collect();
219 let mut pos = 0;
220
221 for word in words {
222 if word.to_lowercase().starts_with(&rule.pattern.to_lowercase()) {
223 entities.push(ExtractedEntity {
224 text: word.to_string(),
225 entity_type: rule.entity_type.clone(),
226 start: pos,
227 end: pos + word.len(),
228 confidence: 0.7,
229 rule_name: rule.name.clone(),
230 });
231 }
232 pos += word.len() + 1; }
234
235 entities
236 }
237
238 fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
240 let mut entities = Vec::new();
241 let words: Vec<&str> = text.split_whitespace().collect();
242 let mut pos = 0;
243
244 for word in words {
245 if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
246 entities.push(ExtractedEntity {
247 text: word.to_string(),
248 entity_type: rule.entity_type.clone(),
249 start: pos,
250 end: pos + word.len(),
251 confidence: 0.7,
252 rule_name: rule.name.clone(),
253 });
254 }
255 pos += word.len() + 1;
256 }
257
258 entities
259 }
260
261 fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
263 let mut entities = Vec::new();
264 let words: Vec<&str> = text.split_whitespace().collect();
265 let mut pos = 0;
266
267 for word in words {
268 if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
269 entities.push(ExtractedEntity {
270 text: word.to_string(),
271 entity_type: rule.entity_type.clone(),
272 start: pos,
273 end: pos + word.len(),
274 confidence: 0.6,
275 rule_name: rule.name.clone(),
276 });
277 }
278 pos += word.len() + 1;
279 }
280
281 entities
282 }
283
284 fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
286 let mut entities = Vec::new();
287
288 if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
289 let text_lower = text.to_lowercase();
290
291 for entry in &entity_type.dictionary {
292 let mut start = 0;
293 while let Some(pos) = text_lower[start..].find(entry) {
294 let absolute_pos = start + pos;
295 entities.push(ExtractedEntity {
296 text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
297 entity_type: rule.entity_type.clone(),
298 start: absolute_pos,
299 end: absolute_pos + entry.len(),
300 confidence: 0.95,
301 rule_name: rule.name.clone(),
302 });
303
304 start = absolute_pos + entry.len();
305 }
306 }
307 }
308
309 entities
310 }
311
312 fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
314 let parts: Vec<&str> = rule.pattern.split('|').collect();
317 if parts.len() != 3 {
318 return Vec::new();
319 }
320
321 let before = parts[0];
322 let target = parts[1];
323 let after = parts[2];
324
325 let mut entities = Vec::new();
326 let words: Vec<&str> = text.split_whitespace().collect();
327
328 for window in words.windows(3) {
329 if window[0].to_lowercase().contains(&before.to_lowercase())
330 && window[1].to_lowercase().contains(&target.to_lowercase())
331 && window[2].to_lowercase().contains(&after.to_lowercase())
332 {
333 if let Some(pos) = text.find(window[1]) {
335 entities.push(ExtractedEntity {
336 text: window[1].to_string(),
337 entity_type: rule.entity_type.clone(),
338 start: pos,
339 end: pos + window[1].len(),
340 confidence: 0.85,
341 rule_name: rule.name.clone(),
342 });
343 }
344 }
345 }
346
347 entities
348 }
349
350 fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
352 if entities.is_empty() {
353 return entities;
354 }
355
356 entities.sort_by(|a, b| {
358 a.start.cmp(&b.start)
359 .then(b.confidence.partial_cmp(&a.confidence).unwrap())
360 });
361
362 let mut result = Vec::new();
363 let mut last_end = 0;
364
365 for entity in entities {
366 if entity.start < last_end {
368 continue;
369 }
370
371 last_end = entity.end;
372 result.push(entity);
373 }
374
375 result
376 }
377
378 pub fn entity_types(&self) -> &HashMap<String, EntityType> {
380 &self.entity_types
381 }
382
383 pub fn rules(&self) -> &[ExtractionRule] {
385 &self.rules
386 }
387}
388
389impl Default for CustomNER {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
/// One entity span found in the input text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
 /// The matched text as it appears in the original input.
 pub text: String,
 /// Entity type name assigned by the matching rule.
 pub entity_type: String,
 /// Byte offset of the start of the span (inclusive).
 pub start: usize,
 /// Byte offset of the end of the span (exclusive).
 pub end: usize,
 /// Heuristic confidence score; each rule type assigns a fixed value.
 pub confidence: f32,
 /// Name of the rule that produced this entity.
 pub rule_name: String,
}
411
/// A collection of manually annotated examples, e.g. for building or
/// evaluating an NER system.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDataset {
 /// All annotated examples in the dataset.
 pub examples: Vec<AnnotatedExample>,
}
418
419impl TrainingDataset {
420 pub fn new() -> Self {
422 Self {
423 examples: Vec::new(),
424 }
425 }
426
427 pub fn add_example(&mut self, example: AnnotatedExample) {
429 self.examples.push(example);
430 }
431
432 pub fn statistics(&self) -> DatasetStatistics {
434 let total_examples = self.examples.len();
435 let mut entity_counts: HashMap<String, usize> = HashMap::new();
436
437 for example in &self.examples {
438 for entity in &example.entities {
439 *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
440 }
441 }
442
443 DatasetStatistics {
444 total_examples,
445 entity_counts,
446 }
447 }
448}
449
450impl Default for TrainingDataset {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
/// A text with its gold-standard entity annotations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnotatedExample {
 /// The raw example text.
 pub text: String,
 /// The annotated entity spans within `text`.
 pub entities: Vec<ExtractedEntity>,
}
464
/// Summary statistics over a `TrainingDataset`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
 /// Number of annotated examples in the dataset.
 pub total_examples: usize,
 /// Count of annotated entities per entity-type name.
 pub entity_counts: HashMap<String, usize>,
}
473
#[cfg(test)]
mod tests {
 use super::*;

 /// Convenience constructor shared by the extraction tests.
 fn make_rule(
 name: &str,
 entity_type: &str,
 rule_type: RuleType,
 pattern: &str,
 min_confidence: f32,
 priority: i32,
 ) -> ExtractionRule {
 ExtractionRule {
 name: name.to_string(),
 entity_type: entity_type.to_string(),
 rule_type,
 pattern: pattern.to_string(),
 min_confidence,
 priority,
 }
 }

 #[test]
 fn test_entity_type_creation() {
 let mut entity_type =
 EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

 for example in &["hemoglobin", "insulin"] {
 entity_type.add_example(example.to_string());
 }

 // Both examples are recorded and mirrored into the dictionary.
 assert_eq!(entity_type.examples.len(), 2);
 assert_eq!(entity_type.dictionary.len(), 2);
 }

 #[test]
 fn test_exact_match_extraction() {
 let mut ner = CustomNER::new();
 ner.add_rule(make_rule(
 "protein_exact",
 "PROTEIN",
 RuleType::ExactMatch,
 "hemoglobin",
 0.9,
 1,
 ));

 let entities =
 ner.extract("The protein hemoglobin is important. Hemoglobin carries oxygen.");

 // Matching is case-insensitive, so both occurrences are found.
 assert_eq!(entities.len(), 2);
 assert_eq!(entities[0].entity_type, "PROTEIN");
 assert_eq!(entities[0].text.to_lowercase(), "hemoglobin");
 }

 #[test]
 fn test_regex_extraction() {
 let mut ner = CustomNER::new();
 ner.add_rule(make_rule(
 "gene_pattern",
 "GENE",
 RuleType::Regex,
 r"[A-Z]{2,4}\d+",
 0.8,
 1,
 ));

 let entities = ner.extract("The genes TP53 and BRCA1 are tumor suppressors.");

 assert!(entities.len() >= 2);
 assert!(entities.iter().any(|e| e.text == "TP53"));
 assert!(entities.iter().any(|e| e.text == "BRCA1"));
 }

 #[test]
 fn test_dictionary_extraction() {
 let mut ner = CustomNER::new();

 let mut protein_type =
 EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
 protein_type.add_dictionary_entries(
 ["insulin", "hemoglobin", "collagen"]
 .iter()
 .map(|s| s.to_string())
 .collect(),
 );
 ner.register_entity_type(protein_type);

 ner.add_rule(make_rule(
 "protein_dict",
 "PROTEIN",
 RuleType::Dictionary,
 "",
 0.9,
 2,
 ));

 let entities =
 ner.extract("Insulin regulates blood sugar. Hemoglobin transports oxygen.");
 assert_eq!(entities.len(), 2);
 }

 #[test]
 fn test_prefix_extraction() {
 let mut ner = CustomNER::new();
 ner.add_rule(make_rule(
 "bio_prefix",
 "BIO_TERM",
 RuleType::Prefix,
 "bio",
 0.7,
 1,
 ));

 let entities = ner.extract("Biology and biochemistry are fascinating subjects.");
 assert!(entities.len() >= 2);
 }

 #[test]
 fn test_overlap_resolution() {
 let mut ner = CustomNER::new();
 ner.add_rule(make_rule("rule1", "TYPE1", RuleType::ExactMatch, "test", 0.9, 1));
 ner.add_rule(make_rule("rule2", "TYPE2", RuleType::ExactMatch, "testing", 0.95, 2));

 let entities = ner.extract("We are testing this code.");

 // "test" and "testing" cover overlapping spans; only one survives.
 assert_eq!(entities.len(), 1);
 }

 #[test]
 fn test_training_dataset() {
 let mut dataset = TrainingDataset::new();
 dataset.add_example(AnnotatedExample {
 text: "Insulin regulates glucose.".to_string(),
 entities: vec![ExtractedEntity {
 text: "Insulin".to_string(),
 entity_type: "PROTEIN".to_string(),
 start: 0,
 end: 7,
 confidence: 1.0,
 rule_name: "manual".to_string(),
 }],
 });

 let stats = dataset.statistics();
 assert_eq!(stats.total_examples, 1);
 assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
 }
}