1use regex::Regex;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20
/// A user-defined entity category: its name, description, curated examples,
/// raw pattern strings, and a lowercased dictionary used for lookup extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityType {
    /// Unique name of the entity type (e.g. "PROTEIN"); used as the registry key.
    pub name: String,
    /// Human-readable description of what the type covers.
    pub description: String,
    /// Example surface forms, kept in their original casing.
    pub examples: Vec<String>,
    /// Raw pattern strings; interpretation is left to the consuming rules.
    pub patterns: Vec<String>,
    /// Lowercased lookup set, fed by `add_example` and `add_dictionary_entries`.
    pub dictionary: HashSet<String>,
}
35
36impl EntityType {
37 pub fn new(name: String, description: String) -> Self {
39 Self {
40 name,
41 description,
42 examples: Vec::new(),
43 patterns: Vec::new(),
44 dictionary: HashSet::new(),
45 }
46 }
47
48 pub fn add_example(&mut self, example: String) {
50 self.examples.push(example.clone());
51 self.dictionary.insert(example.to_lowercase());
52 }
53
54 pub fn add_pattern(&mut self, pattern: String) {
56 self.patterns.push(pattern);
57 }
58
59 pub fn add_dictionary_entries(&mut self, entries: Vec<String>) {
61 for entry in entries {
62 self.dictionary.insert(entry.to_lowercase());
63 }
64 }
65}
66
/// A single extraction rule: how to match (`rule_type` + `pattern`), which
/// entity type matches produce, and how matches are ranked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionRule {
    /// Unique rule name; also the key under which a compiled regex is cached.
    pub name: String,
    /// Name of the entity type this rule emits (for dictionary rules, also
    /// the registered type whose dictionary is searched).
    pub entity_type: String,
    /// Matching strategy.
    pub rule_type: RuleType,
    /// Pattern text; meaning depends on `rule_type` (regex source, literal,
    /// affix, or "before|target|after" for contextual rules).
    pub pattern: String,
    /// Minimum confidence intended for matches produced by this rule.
    /// NOTE(review): not enforced anywhere in `extract` — confirm intent.
    pub min_confidence: f32,
    /// Higher-priority rules are applied first.
    pub priority: i32,
}
83
/// The matching strategy an [`ExtractionRule`] uses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RuleType {
    /// Case-insensitive literal match of the rule's pattern anywhere in the text.
    ExactMatch,
    /// Regular-expression match using the rule's pattern (compiled in `add_rule`).
    Regex,
    /// Whole-word match where the word starts with the pattern (case-insensitive).
    Prefix,
    /// Whole-word match where the word ends with the pattern (case-insensitive).
    Suffix,
    /// Whole-word match where the word contains the pattern (case-insensitive).
    Contains,
    /// Lookup of the registered entity type's dictionary entries in the text.
    Dictionary,
    /// Three-part "before|target|after" pattern matched over 3-word windows;
    /// the middle word is extracted.
    Contextual,
}
102
/// Rule-based named-entity extractor: registered entity types, prioritized
/// extraction rules, and a cache of pre-compiled regexes keyed by rule name.
pub struct CustomNER {
    /// Registered entity types, keyed by type name.
    entity_types: HashMap<String, EntityType>,
    /// Extraction rules, kept sorted by descending priority (see `add_rule`).
    rules: Vec<ExtractionRule>,
    /// Compiled regexes for `RuleType::Regex` rules, keyed by rule name.
    compiled_patterns: HashMap<String, Regex>,
}
112
113impl CustomNER {
114 pub fn new() -> Self {
116 Self {
117 entity_types: HashMap::new(),
118 rules: Vec::new(),
119 compiled_patterns: HashMap::new(),
120 }
121 }
122
123 pub fn register_entity_type(&mut self, entity_type: EntityType) {
125 self.entity_types
126 .insert(entity_type.name.clone(), entity_type);
127 }
128
129 pub fn add_rule(&mut self, rule: ExtractionRule) {
131 if rule.rule_type == RuleType::Regex {
133 if let Ok(regex) = Regex::new(&rule.pattern) {
134 self.compiled_patterns.insert(rule.name.clone(), regex);
135 }
136 }
137
138 self.rules.push(rule);
139 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
140 }
141
142 pub fn extract(&self, text: &str) -> Vec<ExtractedEntity> {
144 let mut entities = Vec::new();
145
146 for rule in &self.rules {
148 let rule_entities = self.apply_rule(text, rule);
149 entities.extend(rule_entities);
150 }
151
152 self.resolve_overlaps(entities)
154 }
155
156 fn apply_rule(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
158 match rule.rule_type {
159 RuleType::ExactMatch => self.extract_exact_match(text, rule),
160 RuleType::Regex => self.extract_regex(text, rule),
161 RuleType::Prefix => self.extract_prefix(text, rule),
162 RuleType::Suffix => self.extract_suffix(text, rule),
163 RuleType::Contains => self.extract_contains(text, rule),
164 RuleType::Dictionary => self.extract_dictionary(text, rule),
165 RuleType::Contextual => self.extract_contextual(text, rule),
166 }
167 }
168
169 fn extract_exact_match(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
171 let mut entities = Vec::new();
172 let pattern = &rule.pattern;
173 let text_lower = text.to_lowercase();
174 let pattern_lower = pattern.to_lowercase();
175
176 let mut start = 0;
177 while let Some(pos) = text_lower[start..].find(&pattern_lower) {
178 let absolute_pos = start + pos;
179 entities.push(ExtractedEntity {
180 text: text[absolute_pos..absolute_pos + pattern.len()].to_string(),
181 entity_type: rule.entity_type.clone(),
182 start: absolute_pos,
183 end: absolute_pos + pattern.len(),
184 confidence: 1.0,
185 rule_name: rule.name.clone(),
186 });
187
188 start = absolute_pos + pattern.len();
189 }
190
191 entities
192 }
193
194 fn extract_regex(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
196 let mut entities = Vec::new();
197
198 if let Some(regex) = self.compiled_patterns.get(&rule.name) {
199 for capture in regex.captures_iter(text) {
200 if let Some(matched) = capture.get(0) {
201 entities.push(ExtractedEntity {
202 text: matched.as_str().to_string(),
203 entity_type: rule.entity_type.clone(),
204 start: matched.start(),
205 end: matched.end(),
206 confidence: 0.9,
207 rule_name: rule.name.clone(),
208 });
209 }
210 }
211 }
212
213 entities
214 }
215
216 fn extract_prefix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
218 let mut entities = Vec::new();
219 let words: Vec<&str> = text.split_whitespace().collect();
220 let mut pos = 0;
221
222 for word in words {
223 if word
224 .to_lowercase()
225 .starts_with(&rule.pattern.to_lowercase())
226 {
227 entities.push(ExtractedEntity {
228 text: word.to_string(),
229 entity_type: rule.entity_type.clone(),
230 start: pos,
231 end: pos + word.len(),
232 confidence: 0.7,
233 rule_name: rule.name.clone(),
234 });
235 }
236 pos += word.len() + 1; }
238
239 entities
240 }
241
242 fn extract_suffix(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
244 let mut entities = Vec::new();
245 let words: Vec<&str> = text.split_whitespace().collect();
246 let mut pos = 0;
247
248 for word in words {
249 if word.to_lowercase().ends_with(&rule.pattern.to_lowercase()) {
250 entities.push(ExtractedEntity {
251 text: word.to_string(),
252 entity_type: rule.entity_type.clone(),
253 start: pos,
254 end: pos + word.len(),
255 confidence: 0.7,
256 rule_name: rule.name.clone(),
257 });
258 }
259 pos += word.len() + 1;
260 }
261
262 entities
263 }
264
265 fn extract_contains(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
267 let mut entities = Vec::new();
268 let words: Vec<&str> = text.split_whitespace().collect();
269 let mut pos = 0;
270
271 for word in words {
272 if word.to_lowercase().contains(&rule.pattern.to_lowercase()) {
273 entities.push(ExtractedEntity {
274 text: word.to_string(),
275 entity_type: rule.entity_type.clone(),
276 start: pos,
277 end: pos + word.len(),
278 confidence: 0.6,
279 rule_name: rule.name.clone(),
280 });
281 }
282 pos += word.len() + 1;
283 }
284
285 entities
286 }
287
288 fn extract_dictionary(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
290 let mut entities = Vec::new();
291
292 if let Some(entity_type) = self.entity_types.get(&rule.entity_type) {
293 let text_lower = text.to_lowercase();
294
295 for entry in &entity_type.dictionary {
296 let mut start = 0;
297 while let Some(pos) = text_lower[start..].find(entry) {
298 let absolute_pos = start + pos;
299 entities.push(ExtractedEntity {
300 text: text[absolute_pos..absolute_pos + entry.len()].to_string(),
301 entity_type: rule.entity_type.clone(),
302 start: absolute_pos,
303 end: absolute_pos + entry.len(),
304 confidence: 0.95,
305 rule_name: rule.name.clone(),
306 });
307
308 start = absolute_pos + entry.len();
309 }
310 }
311 }
312
313 entities
314 }
315
316 fn extract_contextual(&self, text: &str, rule: &ExtractionRule) -> Vec<ExtractedEntity> {
318 let parts: Vec<&str> = rule.pattern.split('|').collect();
321 if parts.len() != 3 {
322 return Vec::new();
323 }
324
325 let before = parts[0];
326 let target = parts[1];
327 let after = parts[2];
328
329 let mut entities = Vec::new();
330 let words: Vec<&str> = text.split_whitespace().collect();
331
332 for window in words.windows(3) {
333 if window[0].to_lowercase().contains(&before.to_lowercase())
334 && window[1].to_lowercase().contains(&target.to_lowercase())
335 && window[2].to_lowercase().contains(&after.to_lowercase())
336 {
337 if let Some(pos) = text.find(window[1]) {
339 entities.push(ExtractedEntity {
340 text: window[1].to_string(),
341 entity_type: rule.entity_type.clone(),
342 start: pos,
343 end: pos + window[1].len(),
344 confidence: 0.85,
345 rule_name: rule.name.clone(),
346 });
347 }
348 }
349 }
350
351 entities
352 }
353
354 fn resolve_overlaps(&self, mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
356 if entities.is_empty() {
357 return entities;
358 }
359
360 entities.sort_by(|a, b| {
362 a.start
363 .cmp(&b.start)
364 .then(b.confidence.partial_cmp(&a.confidence).unwrap())
365 });
366
367 let mut result = Vec::new();
368 let mut last_end = 0;
369
370 for entity in entities {
371 if entity.start < last_end {
373 continue;
374 }
375
376 last_end = entity.end;
377 result.push(entity);
378 }
379
380 result
381 }
382
383 pub fn entity_types(&self) -> &HashMap<String, EntityType> {
385 &self.entity_types
386 }
387
388 pub fn rules(&self) -> &[ExtractionRule] {
390 &self.rules
391 }
392}
393
impl Default for CustomNER {
    /// Equivalent to [`CustomNER::new`].
    fn default() -> Self {
        Self::new()
    }
}
399
/// A single entity span produced by rule extraction (or manual annotation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// The matched text, in its original casing.
    pub text: String,
    /// Name of the entity type assigned by the matching rule.
    pub entity_type: String,
    /// Byte offset of the match start within the source text.
    pub start: usize,
    /// Byte offset one past the match end within the source text.
    pub end: usize,
    /// Rule-assigned confidence score.
    pub confidence: f32,
    /// Name of the rule that produced this entity.
    pub rule_name: String,
}
416
/// A collection of annotated training examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDataset {
    /// The annotated examples in this dataset.
    pub examples: Vec<AnnotatedExample>,
}
423
424impl TrainingDataset {
425 pub fn new() -> Self {
427 Self {
428 examples: Vec::new(),
429 }
430 }
431
432 pub fn add_example(&mut self, example: AnnotatedExample) {
434 self.examples.push(example);
435 }
436
437 pub fn statistics(&self) -> DatasetStatistics {
439 let total_examples = self.examples.len();
440 let mut entity_counts: HashMap<String, usize> = HashMap::new();
441
442 for example in &self.examples {
443 for entity in &example.entities {
444 *entity_counts.entry(entity.entity_type.clone()).or_insert(0) += 1;
445 }
446 }
447
448 DatasetStatistics {
449 total_examples,
450 entity_counts,
451 }
452 }
453}
454
impl Default for TrainingDataset {
    /// Equivalent to [`TrainingDataset::new`].
    fn default() -> Self {
        Self::new()
    }
}
460
/// A text paired with its entity annotations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnotatedExample {
    /// The raw example text.
    pub text: String,
    /// Entity spans annotated within `text`.
    pub entities: Vec<ExtractedEntity>,
}
469
/// Summary statistics computed over a [`TrainingDataset`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Number of annotated examples in the dataset.
    pub total_examples: usize,
    /// Count of annotated entities per entity-type name.
    pub entity_counts: HashMap<String, usize>,
}
478
#[cfg(test)]
mod tests {
    use super::*;

    // `add_example` should both record the example and index it (lowercased)
    // in the dictionary.
    #[test]
    fn test_entity_type_creation() {
        let mut entity_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());

        entity_type.add_example("hemoglobin".to_string());
        entity_type.add_example("insulin".to_string());

        assert_eq!(entity_type.examples.len(), 2);
        assert_eq!(entity_type.dictionary.len(), 2);
    }

    // Exact-match rules are case-insensitive: both "hemoglobin" and
    // "Hemoglobin" must be found.
    #[test]
    fn test_exact_match_extraction() {
        let mut ner = CustomNER::new();

        let rule = ExtractionRule {
            name: "protein_exact".to_string(),
            entity_type: "PROTEIN".to_string(),
            rule_type: RuleType::ExactMatch,
            pattern: "hemoglobin".to_string(),
            min_confidence: 0.9,
            priority: 1,
        };

        ner.add_rule(rule);

        let text = "The protein hemoglobin is important. Hemoglobin carries oxygen.";
        let entities = ner.extract(text);

        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0].entity_type, "PROTEIN");
        assert_eq!(entities[0].text.to_lowercase(), "hemoglobin");
    }

    // Regex rules compile in `add_rule` and match against the original text.
    #[test]
    fn test_regex_extraction() {
        let mut ner = CustomNER::new();

        let rule = ExtractionRule {
            name: "gene_pattern".to_string(),
            entity_type: "GENE".to_string(),
            rule_type: RuleType::Regex,
            // 2-4 uppercase letters followed by digits, e.g. TP53 / BRCA1.
            pattern: r"[A-Z]{2,4}\d+".to_string(),
            min_confidence: 0.8,
            priority: 1,
        };

        ner.add_rule(rule);

        let text = "The genes TP53 and BRCA1 are tumor suppressors.";
        let entities = ner.extract(text);

        assert!(entities.len() >= 2);
        assert!(entities.iter().any(|e| e.text == "TP53"));
        assert!(entities.iter().any(|e| e.text == "BRCA1"));
    }

    // Dictionary rules ignore `pattern` and search the registered entity
    // type's dictionary entries instead.
    #[test]
    fn test_dictionary_extraction() {
        let mut ner = CustomNER::new();

        let mut protein_type = EntityType::new("PROTEIN".to_string(), "Protein names".to_string());
        protein_type.add_dictionary_entries(vec![
            "insulin".to_string(),
            "hemoglobin".to_string(),
            "collagen".to_string(),
        ]);

        ner.register_entity_type(protein_type);

        let rule = ExtractionRule {
            name: "protein_dict".to_string(),
            entity_type: "PROTEIN".to_string(),
            rule_type: RuleType::Dictionary,
            pattern: "".to_string(),
            min_confidence: 0.9,
            priority: 2,
        };

        ner.add_rule(rule);

        // "Insulin" and "Hemoglobin" are present (capitalized); "collagen" is not.
        let text = "Insulin regulates blood sugar. Hemoglobin transports oxygen.";
        let entities = ner.extract(text);

        assert_eq!(entities.len(), 2);
    }

    // Prefix rules match whole words starting with the pattern, case-insensitively.
    #[test]
    fn test_prefix_extraction() {
        let mut ner = CustomNER::new();

        let rule = ExtractionRule {
            name: "bio_prefix".to_string(),
            entity_type: "BIO_TERM".to_string(),
            rule_type: RuleType::Prefix,
            pattern: "bio".to_string(),
            min_confidence: 0.7,
            priority: 1,
        };

        ner.add_rule(rule);

        // "Biology" and "biochemistry" both start with "bio".
        let text = "Biology and biochemistry are fascinating subjects.";
        let entities = ner.extract(text);

        assert!(entities.len() >= 2);
    }

    // "test" and "testing" both match inside "testing"; overlap resolution
    // must keep only one of the two overlapping spans.
    #[test]
    fn test_overlap_resolution() {
        let mut ner = CustomNER::new();

        let rule1 = ExtractionRule {
            name: "rule1".to_string(),
            entity_type: "TYPE1".to_string(),
            rule_type: RuleType::ExactMatch,
            pattern: "test".to_string(),
            min_confidence: 0.9,
            priority: 1,
        };

        let rule2 = ExtractionRule {
            name: "rule2".to_string(),
            entity_type: "TYPE2".to_string(),
            rule_type: RuleType::ExactMatch,
            pattern: "testing".to_string(),
            min_confidence: 0.95,
            priority: 2,
        };

        ner.add_rule(rule1);
        ner.add_rule(rule2);

        let text = "We are testing this code.";
        let entities = ner.extract(text);

        assert_eq!(entities.len(), 1);
    }

    // `statistics` counts examples and tallies entities per type name.
    #[test]
    fn test_training_dataset() {
        let mut dataset = TrainingDataset::new();

        let example = AnnotatedExample {
            text: "Insulin regulates glucose.".to_string(),
            entities: vec![ExtractedEntity {
                text: "Insulin".to_string(),
                entity_type: "PROTEIN".to_string(),
                start: 0,
                end: 7,
                confidence: 1.0,
                rule_name: "manual".to_string(),
            }],
        };

        dataset.add_example(example);

        let stats = dataset.statistics();
        assert_eq!(stats.total_examples, 1);
        assert_eq!(stats.entity_counts.get("PROTEIN"), Some(&1));
    }
}