1use regex::Regex;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct EntityType {
24 pub name: String,
26 pub description: String,
28 pub examples: Vec<String>,
30 pub patterns: Vec<String>,
32 pub dictionary: HashSet<String>,
34}
35
36impl EntityType {
37 pub fn new(name: String, description: String) -> Self {
39 Self {
40 name,
41 description,
42 examples: Vec::new(),
43 patterns: Vec::new(),
44 dictionary: HashSet::new(),
45 }
46 }
47
48 pub fn add_example(&mut self, example: String) {
50 self.examples.push(example.clone());
51 self.dictionary.insert(example.to_lowercase());
52 }
53
54 pub fn add_pattern(&mut self, pattern: String) {
56 self.patterns.push(pattern);
57 }
58
59 pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61 for entry in entries {
62 self.dictionary.insert(entry.to_lowercase());
63 }
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExtractionRule {
70 pub name: String,
72 pub entity_type: String,
74 pub rule_type: RuleType,
76 pub pattern: String,
78 pub min_confidence: f32,
80 pub priority: i32,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub enum RuleType {
87 ExactMatch,
89 Regex,
91 Prefix,
93 Suffix,
95 Contains,
97 Dictionary,
99 Contextual,
101}
102
103pub struct CustomNER {
105 entity_types: HashMap<String, EntityType>,
107 rules: Vec<ExtractionRule>,
109 compiled_patterns: HashMap<String, Regex>,
111}
112
113impl CustomNER {
114 pub fn new() -> Self {
116 Self {
117 entity_types: HashMap::new(),
118 rules: Vec::new(),
119 compiled_patterns: HashMap::new(),
120 }
121 }
122
123 pub fn register_entity_type(&mut self, entity_type: EntityType) {
125 self.entity_types
126 .insert(entity_type.name.clone(), entity_type);
127 }
128
129 pub fn add_rule(&mut self, rule: ExtractionRule) {
131 if rule.rule_type == RuleType::Regex {
133 if let Ok(regex) = Regex::new(&rule.pattern) {
134 self.compiled_patterns.insert(rule.name.clone(), regex);
135 }
136 }
137
138 self.rules.push(rule);
139 self.rules
140 .sort_by_key(|rule| std::cmp::Reverse(rule.priority));
141 }
142
143 pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
145 let mut entities = Vec::new();
146
147 for rule in &self.rules {
149 let rule_entities = self.apply_rule(text, rule);
150 entities.extend(rule_entities);
151 }
152
153 self.resolve_overlaps(entities)
155 }
156
157 fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
159 match rule.rule_type {
160 RuleType::ExactMatch => self.extract_exact_match(text, rule),
161 RuleType::Regex => self.extract_regex(text, rule),
162 RuleType::Prefix => self.extract_prefix(text, rule),
163 RuleType::Suffix => self.extract_suffix(text, rule),
164 RuleType::Contains => self.extract_contains(text, rule),
165 RuleType::Dictionary => self.extract_dictionary(text, rule),
166 RuleType::Contextual => self.extract_contextual(text, rule),
167 }
168 }
169
170 fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
172 let mut entities = Vec::new();
173 let pattern = &rule.pattern;
174 let text_lower = text.to_lowercase();
175 let pattern_lower = pattern.to_lowercase();
176
177 let mut start = 0;
178 while let Some(pos) = text_lower[start..].find(&pattern_lower) {
179 let absolute_pos = start + pos;
180 entities.push(ExtractedEntity {
181 text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
182 entity_type: rule.entity_type.clone(),
183 start: absolute_pos,
184 end: absolute_pos + pattern.len(),
185 confidence: 1.0,
186 rule_name: rule.name.clone(),
187 });
188
189 start = absolute_pos + pattern.len();
190 }
191
192 entities
193 }
194
195 fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
197 let mut entities = Vec::new();
198
199 if let Some(regex) = self.compiled_patterns.get(&rule.name) {
200 for capture in regex.captures_iter(text) {
201 if let Some(matched) = capture.get(0) {
202 entities.push(ExtractedEntity {
203 text: matched.as_str().to_string(),
204 entity_type: rule.entity_type.clone(),
205 start: matched.start(),
206 end: matched.end(),
207 confidence: 0.9,
208 rule_name: rule.name.clone(),
209 });
210 }
211 }
212 }
213
214 entities
215 }
216
217 fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
219 let mut entities = Vec::new();
220 let words: Vec<&str> = text.split_whitespace().collect();
221 let mut pos = 0;
222
223 for word in words {
224 if word
225 .to_lowercase()
226 .starts_with(&rule.pattern.to_lowercase())
227 {
228 entities.push(ExtractedEntity {
229 text: word.to_string(),
230 entity_type: rule.entity_type.clone(),
231 start: pos,
232 end: pos + word.len(),
233 confidence: 0.7,
234 rule_name: rule.name.clone(),
235 });
236 }
237 pos += word.len() + 1; }
239
240 entities
241 }
242
243 fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
245 let mut entities = Vec::new();
246 let words: Vec<&str> = text.split_whitespace().collect();
247 let mut pos = 0;
248
249 for word in words {
250 if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
251 entities.push(ExtractedEntity {
252 text: word.to_string(),
253 entity_type: rule.entity_type.clone(),
254 start: pos,
255 end: pos + word.len(),
256 confidence: 0.7,
257 rule_name: rule.name.clone(),
258 });
259 }
260 pos += word.len() + 1;
261 }
262
263 entities
264 }
265
266 fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
268 let mut entities = Vec::new();
269 let words: Vec<&str> = text.split_whitespace().collect();
270 let mut pos = 0;
271
272 for word in words {
273 if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
274 entities.push(ExtractedEntity {
275 text: word.to_string(),
276 entity_type: rule.entity_type.clone(),
277 start: pos,
278 end: pos + word.len(),
279 confidence: 0.6,
280 rule_name: rule.name.clone(),
281 });
282 }
283 pos += word.len() + 1;
284 }
285
286 entities
287 }
288
289 fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
291 let mut entities = Vec::new();
292
293 if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
294 let text_lower = text.to_lowercase();
295
296 for entry in &entity_type.dictionary {
297 let mut start = 0;
298 while let Some(pos) = text_lower[start..].find(entry) {
299 let absolute_pos = start + pos;
300 entities.push(ExtractedEntity {
301 text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
302 entity_type: rule.entity_type.clone(),
303 start: absolute_pos,
304 end: absolute_pos + entry.len(),
305 confidence: 0.95,
306 rule_name: rule.name.clone(),
307 });
308
309 start = absolute_pos + entry.len();
310 }
311 }
312 }
313
314 entities
315 }
316
317 fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
319 let parts: Vec<&str> = rule.pattern.split('|').collect();
322 if parts.len() != 3 {
323 return Vec::new();
324 }
325
326 let before = parts[0];
327 let target = parts[1];
328 let after = parts[2];
329
330 let mut entities = Vec::new();
331 let words: Vec<&str> = text.split_whitespace().collect();
332
333 for window in words.windows(3) {
334 if window[0].to_lowercase().contains(&before.to_lowercase())
335 && window[1].to_lowercase().contains(&target.to_lowercase())
336 && window[2].to_lowercase().contains(&after.to_lowercase())
337 {
338 if let Some(pos) = text.find(window[1]) {
340 entities.push(ExtractedEntity {
341 text: window[1].to_string(),
342 entity_type: rule.entity_type.clone(),
343 start: pos,
344 end: pos + window[1].len(),
345 confidence: 0.85,
346 rule_name: rule.name.clone(),
347 });
348 }
349 }
350 }
351
352 entities
353 }
354
355 fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
357 if entities.is_empty() {
358 return entities;
359 }
360
361 entities.sort_by(|a, b| {
363 a.start.cmp(&b.start).then(
364 b.confidence
365 .partial_cmp(&a.confidence)
366 .unwrap_or(std::cmp::Ordering::Equal),
367 )
368 });
369
370 let mut result = Vec::new();
371 let mut last_end = 0;
372
373 for entity in entities {
374 if entity.start < last_end {
376 continue;
377 }
378
379 last_end = entity.end;
380 result.push(entity);
381 }
382
383 result
384 }
385
386 pub fn entity_types(&self) -> &HashMap<String, EntityType> {
388 &self.entity_types
389 }
390
391 pub fn rules(&self) -> &[ExtractionRule] {
393 &self.rules
394 }
395}
396
397impl Default for CustomNER {
398 fn default() -> Self {
399 Self::new()
400 }
401}
402
403#[derive(Debug, Clone, Serialize, Deserialize)]
405pub struct ExtractedEntity {
406 pub text: String,
408 pub entity_type: String,
410 pub start: usize,
412 pub end: usize,
414 pub confidence: f32,
416 pub rule_name: String,
418}
419
420#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct TrainingDataset {
423 pub examples: Vec<AnnotatedExample>,
425}
426
427impl TrainingDataset {
428 pub fn new() -> Self {
430 Self {
431 examples: Vec::new(),
432 }
433 }
434
435 pub fn add_example(&mut self, example: AnnotatedExample) {
437 self.examples.push(example);
438 }
439
440 pub fn statistics(&self) -> DatasetStatistics {
442 let total_examples = self.examples.len();
443 let mut entity_counts: HashMap<String, usize> = HashMap::new();
444
445 for example in &self.examples {
446 for entity in &example.entities {
447 *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
448 }
449 }
450
451 DatasetStatistics {
452 total_examples,
453 entity_counts,
454 }
455 }
456}
457
458impl Default for TrainingDataset {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
464#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct AnnotatedExample {
467 pub text: String,
469 pub entities: Vec<ExtractedEntity>,
471}
472
473#[derive(Debug, Clone, Serialize, Deserialize)]
475pub struct DatasetStatistics {
476 pub total_examples: usize,
478 pub entity_counts: HashMap<String, usize>,
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExtractionRule` from borrowed strings to keep the tests
    /// compact.
    fn make_rule(
        name: &str,
        entity_type: &str,
        rule_type: RuleType,
        pattern: &str,
        min_confidence: f32,
        priority: i32,
    ) -> ExtractionRule {
        ExtractionRule {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            rule_type,
            pattern: pattern.to_string(),
            min_confidence,
            priority,
        }
    }

    /// Creates an extractor and adds `rules` to it in the given order.
    fn ner_with_rules(rules: Vec<ExtractionRule>) -> CustomNER {
        let mut ner = CustomNER::new();
        for rule in rules {
            ner.add_rule(rule);
        }
        ner
    }

    // Examples are stored verbatim and mirrored into the dictionary.
    #[test]
    fn test_entity_type_creation() {
        let mut entity_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

        for example in ["hemoglobin", "insulin"].iter() {
            entity_type.add_example(example.to_string());
        }

        assert_eq!(entity_type.examples.len(), 2);
        assert_eq!(entity_type.dictionary.len(), 2);
    }

    // Exact matching is case-insensitive and finds every occurrence.
    #[test]
    fn test_exact_match_extraction() {
        let ner = ner_with_rules(vec![make_rule(
            "protein_exact",
            "PROTEIN",
            RuleType::ExactMatch,
            "hemoglobin",
            0.9,
            1,
        )]);

        let entities =
            ner.extract("The protein hemoglobin is important. Hemoglobin carries oxygen.");

        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0].entity_type, "PROTEIN");
        assert_eq!(entities[0].text.to_lowercase(), "hemoglobin");
    }

    // Regex rules report each match of the compiled pattern.
    #[test]
    fn test_regex_extraction() {
        let ner = ner_with_rules(vec![make_rule(
            "gene_pattern",
            "GENE",
            RuleType::Regex,
            r"[A-Z]{2,4}\d+",
            0.8,
            1,
        )]);

        let entities = ner.extract("The genes TP53 and BRCA1 are tumor suppressors.");

        assert!(entities.len() >= 2);
        for gene in ["TP53", "BRCA1"].iter() {
            assert!(entities.iter().any(|e| e.text == *gene));
        }
    }

    // Dictionary rules look up the entries of the registered entity type.
    #[test]
    fn test_dictionary_extraction() {
        let mut ner = CustomNER::new();

        let mut protein_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
        protein_type.add_dictionary_entries(
            ["insulin", "hemoglobin", "collagen"]
                .iter()
                .map(|entry| entry.to_string())
                .collect(),
        );
        ner.register_entity_type(protein_type);

        ner.add_rule(make_rule(
            "protein_dict",
            "PROTEIN",
            RuleType::Dictionary,
            "",
            0.9,
            2,
        ));

        let entities = ner.extract("Insulin regulates blood sugar. Hemoglobin transports oxygen.");

        assert_eq!(entities.len(), 2);
    }

    // Prefix rules match whole words that start with the pattern.
    #[test]
    fn test_prefix_extraction() {
        let ner = ner_with_rules(vec![make_rule(
            "bio_prefix",
            "BIO_TERM",
            RuleType::Prefix,
            "bio",
            0.7,
            1,
        )]);

        let entities = ner.extract("Biology and biochemistry are fascinating subjects.");

        assert!(entities.len() >= 2);
    }

    // Two rules matching at the same position yield a single entity.
    #[test]
    fn test_overlap_resolution() {
        let ner = ner_with_rules(vec![
            make_rule("rule1", "TYPE1", RuleType::ExactMatch, "test", 0.9, 1),
            make_rule("rule2", "TYPE2", RuleType::ExactMatch, "testing", 0.95, 2),
        ]);

        let entities = ner.extract("We are testing this code.");

        assert_eq!(entities.len(), 1);
    }

    // Statistics count examples and entities per type.
    #[test]
    fn test_training_dataset() {
        let insulin = ExtractedEntity {
            text: "Insulin".to_string(),
            entity_type: "PROTEIN".to_string(),
            start: 0,
            end: 7,
            confidence: 1.0,
            rule_name: "manual".to_string(),
        };

        let mut dataset = TrainingDataset::new();
        dataset.add_example(AnnotatedExample {
            text: "Insulin regulates glucose.".to_string(),
            entities: vec![insulin],
        });

        let stats = dataset.statistics();
        assert_eq!(stats.total_examples, 1);
        assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
    }
}