1use crate::{Entity, EntityType, Model, Result};
105use std::collections::HashMap;
106
/// Configuration for [`BiLstmCrfNER`].
///
/// NOTE(review): the current implementation scores tokens with gazetteer and
/// shape heuristics, so these knobs describe the intended neural model rather
/// than affecting the heuristic output — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct BiLstmCrfConfig {
    /// LSTM hidden state size.
    pub hidden_size: usize,
    /// Number of stacked LSTM layers.
    pub num_layers: usize,
    /// Dropout probability applied between layers.
    pub dropout: f32,
    /// Whether character-level embeddings are enabled.
    pub use_char_embeddings: bool,
    /// Maximum number of tokens processed per sequence.
    pub max_seq_len: usize,
}

impl Default for BiLstmCrfConfig {
    /// Standard setup: 2 x 256 BiLSTM, 0.5 dropout, char embeddings enabled,
    /// 512-token sequences.
    fn default() -> Self {
        BiLstmCrfConfig {
            hidden_size: 256,
            num_layers: 2,
            dropout: 0.5,
            use_char_embeddings: true,
            max_seq_len: 512,
        }
    }
}
133
/// A CoNLL-style NER model with CRF-like Viterbi decoding.
///
/// Emission scores come from gazetteer/shape heuristics (see `get_emissions`);
/// the `transitions` matrix rewards valid BIO transitions and penalizes
/// invalid ones during decoding.
#[derive(Debug)]
pub struct BiLstmCrfNER {
    /// Model hyper-parameters; not consulted by the heuristic scorer.
    #[allow(dead_code)]
    config: BiLstmCrfConfig,
    /// BIO label set; index 0 is "O" (see `with_config`).
    labels: Vec<String>,
    /// Reverse lookup from label string to its index in `labels`.
    label_to_idx: HashMap<String, usize>,
    /// `transitions[from][to]`: score for moving between label indices.
    transitions: Vec<Vec<f64>>,
    /// Token vocabulary; currently left empty by the constructor.
    #[allow(dead_code)]
    vocab: HashMap<String, usize>,
    /// ONNX runtime session loaded via `from_onnx`, when the feature is on.
    #[cfg(feature = "onnx")]
    session: Option<ort::session::Session>,
}
163
164impl BiLstmCrfNER {
    /// Creates a model with [`BiLstmCrfConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(BiLstmCrfConfig::default())
    }
174
175 #[must_use]
177 pub fn with_config(config: BiLstmCrfConfig) -> Self {
178 let labels = vec![
179 "O".to_string(),
180 "B-PER".to_string(),
181 "I-PER".to_string(),
182 "B-ORG".to_string(),
183 "I-ORG".to_string(),
184 "B-LOC".to_string(),
185 "I-LOC".to_string(),
186 "B-MISC".to_string(),
187 "I-MISC".to_string(),
188 ];
189
190 let label_to_idx: HashMap<String, usize> = labels
191 .iter()
192 .enumerate()
193 .map(|(i, l)| (l.clone(), i))
194 .collect();
195
196 let n = labels.len();
199 let mut transitions = vec![vec![0.0; n]; n];
200
201 for i in 0..n {
203 for j in 0..n {
204 let from_label = &labels[i];
205 let to_label = &labels[j];
206
207 if let Some(entity_type) = to_label.strip_prefix("I-") {
208 let valid_prev = format!("B-{}", entity_type);
209 let valid_cont = format!("I-{}", entity_type);
210
211 if from_label == &valid_prev || from_label == &valid_cont {
212 transitions[i][j] = 1.0; } else {
214 transitions[i][j] = -10.0; }
216 } else {
217 transitions[i][j] = 0.0;
219 }
220 }
221 }
222
223 Self {
224 config,
225 labels,
226 label_to_idx,
227 transitions,
228 vocab: HashMap::new(),
229 #[cfg(feature = "onnx")]
230 session: None,
231 }
232 }
233
    /// Loads an ONNX model from `model_path` and stores its session.
    ///
    /// The returned model keeps the default configuration and label set.
    /// NOTE(review): `extract_entities` currently uses only the heuristic
    /// path and never consults the stored session — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns a model-init error if the session builder cannot be created,
    /// the optimization level cannot be set, or the model file fails to load.
    #[cfg(feature = "onnx")]
    pub fn from_onnx(model_path: &str) -> Result<Self> {
        use crate::Error;
        use ort::session::{builder::GraphOptimizationLevel, Session};

        let session = Session::builder()
            .map_err(|e| Error::model_init(format!("Failed to create session builder: {}", e)))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| Error::model_init(format!("Failed to set optimization level: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| Error::model_init(format!("Failed to load ONNX model: {}", e)))?;

        let mut model = Self::new();
        model.session = Some(session);
        Ok(model)
    }
251
    /// Splits `text` on Unicode whitespace into borrowed tokens; byte offsets
    /// are recovered later via `calculate_token_positions`.
    fn tokenize(text: &str) -> Vec<&str> {
        text.split_whitespace().collect()
    }
256
257 fn get_emissions(&self, tokens: &[&str]) -> Vec<Vec<f64>> {
263 let n_labels = self.labels.len();
264 let mut emissions = vec![vec![0.0; n_labels]; tokens.len()];
265
266 const PERSON_NAMES: &[&str] = &[
268 "john",
269 "mary",
270 "james",
271 "david",
272 "michael",
273 "robert",
274 "william",
275 "richard",
276 "sarah",
277 "jennifer",
278 "elizabeth",
279 "lisa",
280 "marie",
281 "jane",
282 "emily",
283 "anna",
284 "barack",
285 "donald",
286 "joe",
287 "george",
288 "bill",
289 "hillary",
290 "elon",
291 "jeff",
292 "mr",
293 "mrs",
294 "ms",
295 "dr",
296 "prof",
297 "sir",
298 "lord",
299 "president",
300 "ceo",
301 ];
302 const ORG_NAMES: &[&str] = &[
303 "google",
304 "apple",
305 "microsoft",
306 "amazon",
307 "facebook",
308 "meta",
309 "tesla",
310 "ibm",
311 "intel",
312 "nvidia",
313 "oracle",
314 "cisco",
315 "adobe",
316 "netflix",
317 "uber",
318 "university",
319 "institute",
320 "corporation",
321 "company",
322 "inc",
323 "corp",
324 "ltd",
325 "llc",
326 "foundation",
327 "association",
328 "organization",
329 "department",
330 "agency",
331 "fbi",
332 "cia",
333 "nsa",
334 "nasa",
335 "un",
336 "nato",
337 "who",
338 "imf",
339 "eu",
340 "usa",
341 ];
342 const LOC_NAMES: &[&str] = &[
343 "new",
344 "york",
345 "california",
346 "texas",
347 "florida",
348 "london",
349 "paris",
350 "berlin",
351 "tokyo",
352 "beijing",
353 "moscow",
354 "washington",
355 "chicago",
356 "boston",
357 "seattle",
358 "san",
359 "francisco",
360 "los",
361 "angeles",
362 "las",
363 "vegas",
364 "united",
365 "states",
366 "america",
367 "china",
368 "russia",
369 "germany",
370 "france",
371 "japan",
372 "india",
373 "brazil",
374 "city",
375 "county",
376 "state",
377 "country",
378 "river",
379 "mountain",
380 "lake",
381 "ocean",
382 ];
383
384 for (i, token) in tokens.iter().enumerate() {
385 let lower = token.to_lowercase();
386 let is_capitalized = token.chars().next().is_some_and(|c| c.is_uppercase());
387 let is_all_caps = token
388 .chars()
389 .all(|c| c.is_uppercase() || !c.is_alphabetic())
390 && token.len() > 1;
391 let has_digit = token.chars().any(|c| c.is_ascii_digit());
392 let is_first = i == 0;
393
394 emissions[i][0] = 1.5;
396
397 if PERSON_NAMES.contains(&lower.as_str()) {
399 emissions[i][self.label_to_idx["B-PER"]] += 2.0;
400 emissions[i][self.label_to_idx["I-PER"]] += 1.0;
401 }
402 if ORG_NAMES.contains(&lower.as_str()) {
403 emissions[i][self.label_to_idx["B-ORG"]] += 2.0;
404 emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
405 }
406 if LOC_NAMES.contains(&lower.as_str()) {
407 emissions[i][self.label_to_idx["B-LOC"]] += 2.0;
408 emissions[i][self.label_to_idx["I-LOC"]] += 1.0;
409 }
410
411 if is_capitalized && !has_digit && !is_first {
413 emissions[i][self.label_to_idx["B-PER"]] += 0.8;
414 emissions[i][self.label_to_idx["B-ORG"]] += 0.6;
415 emissions[i][self.label_to_idx["B-LOC"]] += 0.5;
416 }
417
418 if lower.ends_with("inc.")
420 || lower.ends_with("corp.")
421 || lower.ends_with("ltd.")
422 || lower.ends_with("llc")
423 || lower.ends_with("co.")
424 {
425 emissions[i][self.label_to_idx["B-ORG"]] += 1.5;
426 emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
427 }
428
429 if is_all_caps && token.len() >= 2 && token.len() <= 5 && !has_digit {
431 emissions[i][self.label_to_idx["B-ORG"]] += 1.2;
432 }
433
434 if ["mr.", "mrs.", "ms.", "dr.", "prof."].contains(&lower.as_str()) {
436 emissions[i][self.label_to_idx["B-PER"]] += 1.5;
437 }
438
439 if i > 0 && tokens[i - 1].to_lowercase() == "the" && is_capitalized {
441 emissions[i][self.label_to_idx["B-ORG"]] += 0.5;
442 emissions[i][self.label_to_idx["B-LOC"]] += 0.3;
443 }
444
445 if i > 0 {
447 let prev_cap = tokens[i - 1]
448 .chars()
449 .next()
450 .is_some_and(|c| c.is_uppercase());
451 if prev_cap && is_capitalized && !is_first {
452 emissions[i][self.label_to_idx["I-PER"]] += 0.6;
454 emissions[i][self.label_to_idx["I-ORG"]] += 0.6;
455 emissions[i][self.label_to_idx["I-LOC"]] += 0.4;
456 }
457 }
458 }
459
460 emissions
461 }
462
463 fn viterbi_decode(&self, emissions: &[Vec<f64>]) -> Vec<usize> {
465 if emissions.is_empty() {
466 return vec![];
467 }
468
469 let n = emissions.len();
470 let m = self.labels.len();
471
472 let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
474 let mut backpointers = vec![vec![0usize; m]; n];
475
476 for j in 0..m {
478 scores[0][j] = emissions[0][j];
479 }
480
481 for i in 1..n {
483 for j in 0..m {
484 let mut best_score = f64::NEG_INFINITY;
485 let mut best_prev = 0;
486
487 #[allow(clippy::needless_range_loop)]
488 for k in 0..m {
489 let score = scores[i - 1][k] + self.transitions[k][j] + emissions[i][j];
490 if score > best_score {
491 best_score = score;
492 best_prev = k;
493 }
494 }
495
496 scores[i][j] = best_score;
497 backpointers[i][j] = best_prev;
498 }
499 }
500
501 let mut path = vec![0usize; n];
503 let mut best_final = 0;
504 let mut best_score = f64::NEG_INFINITY;
505
506 for (j, &score) in scores[n - 1].iter().enumerate() {
507 if score > best_score {
508 best_score = score;
509 best_final = j;
510 }
511 }
512
513 path[n - 1] = best_final;
514 for i in (0..n - 1).rev() {
515 path[i] = backpointers[i + 1][path[i + 1]];
516 }
517
518 path
519 }
520
    /// Converts a decoded BIO label sequence into entities with character
    /// offsets into `text`.
    ///
    /// An in-progress entity is tracked as (start token index, end token
    /// index, type, collected words) and flushed when a new `B-` label
    /// starts, when the sequence leaves the entity (an `O` label), or at the
    /// end of input.
    ///
    /// NOTE(review): an `I-X` label extends the current entity even when `X`
    /// differs from the entity's type; the transition matrix makes such
    /// sequences unlikely but does not forbid them — confirm this is intended.
    fn labels_to_entities(
        &self,
        text: &str,
        tokens: &[&str],
        label_indices: &[usize],
    ) -> Vec<Entity> {
        use crate::offset::SpanConverter;

        let converter = SpanConverter::new(text);
        let mut entities = Vec::new();

        // Byte span of each token in `text`, used to derive entity offsets.
        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);

        // Entity currently being assembled, if any.
        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;

        for (i, (&label_idx, &token)) in label_indices.iter().zip(tokens.iter()).enumerate() {
            let label = &self.labels[label_idx];

            if let Some(entity_suffix) = label.strip_prefix("B-") {
                // A new entity begins: flush any entity in progress first.
                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
                    current_entity.take()
                {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_token_idx,
                        end_token_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }

                // Map the BIO suffix to the crate's entity type.
                let entity_type = match entity_suffix {
                    "PER" => EntityType::Person,
                    "ORG" => EntityType::Organization,
                    "LOC" => EntityType::Location,
                    other => EntityType::Other(other.to_string()),
                };
                current_entity = Some((i, i, entity_type, vec![token]));
            } else if label.starts_with("I-") && current_entity.is_some() {
                // Continuation label: extend the current span by this token.
                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
                    words.push(token);
                    *end_idx = i;
                }
            } else {
                // "O" (or a dangling "I-" with no open entity): flush any
                // entity in progress.
                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
                    current_entity.take()
                {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_token_idx,
                        end_token_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }
            }
        }

        // Flush a trailing entity that runs to the end of the input.
        if let Some((start_token_idx, end_token_idx, entity_type, words)) = current_entity.take() {
            Self::push_entity_from_positions(
                &converter,
                &token_positions,
                start_token_idx,
                end_token_idx,
                &words,
                entity_type,
                &mut entities,
            );
        }

        entities
    }
608
609 fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
611 let mut positions = Vec::with_capacity(tokens.len());
612 let mut byte_pos = 0;
613
614 for token in tokens {
615 if let Some(rel_pos) = text[byte_pos..].find(token) {
617 let start = byte_pos + rel_pos;
618 let end = start + token.len();
619 positions.push((start, end));
620 byte_pos = end; } else {
622 positions.push((byte_pos, byte_pos));
624 }
625 }
626
627 positions
628 }
629
630 fn push_entity_from_positions(
632 converter: &crate::offset::SpanConverter,
633 positions: &[(usize, usize)],
634 start_token_idx: usize,
635 end_token_idx: usize,
636 words: &[&str],
637 entity_type: EntityType,
638 entities: &mut Vec<Entity>,
639 ) {
640 if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
641 return;
642 }
643
644 let byte_start = positions[start_token_idx].0;
645 let byte_end = positions[end_token_idx].1;
646 let char_start = converter.byte_to_char(byte_start);
647 let char_end = converter.byte_to_char(byte_end);
648 let entity_text = words.join(" ");
649
650 entities.push(Entity::new(
651 entity_text,
652 entity_type,
653 char_start,
654 char_end,
655 0.75, ));
657 }
658}
659
impl Default for BiLstmCrfNER {
    /// Equivalent to [`BiLstmCrfNER::new`].
    fn default() -> Self {
        Self::new()
    }
}
665
666impl Model for BiLstmCrfNER {
667 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
668 if text.trim().is_empty() {
669 return Ok(vec![]);
670 }
671
672 let tokens = Self::tokenize(text);
673 if tokens.is_empty() {
674 return Ok(vec![]);
675 }
676
677 let emissions = self.get_emissions(&tokens);
679
680 let label_indices = self.viterbi_decode(&emissions);
682
683 let entities = self.labels_to_entities(text, &tokens, &label_indices);
685
686 Ok(entities)
687 }
688
689 fn supported_types(&self) -> Vec<EntityType> {
690 vec![
691 EntityType::Person,
692 EntityType::Organization,
693 EntityType::Location,
694 EntityType::Other("MISC".to_string()),
695 ]
696 }
697
698 fn is_available(&self) -> bool {
699 true }
701}
702
// Opt in to the crate's sealed capability-trait hierarchy.
impl crate::sealed::Sealed for BiLstmCrfNER {}
// Marker: this model supports named-entity extraction.
impl crate::NamedEntityCapable for BiLstmCrfNER {}
impl crate::BatchCapable for BiLstmCrfNER {
    /// Suggested number of texts per batch for batched extraction.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(32)
    }
}
710
#[cfg(test)]
mod tests {
    use super::*;

    // Entities extracted from a simple sentence carry confidences in (0, 1].
    #[test]
    fn test_basic_extraction() {
        let ner = BiLstmCrfNER::new();
        let entities = ner
            .extract_entities("John Smith works at Google Inc.", None)
            .unwrap();

        assert!(entities
            .iter()
            .all(|e| e.confidence > 0.0 && e.confidence <= 1.0));
    }

    // Empty input yields no entities.
    #[test]
    fn test_empty_input() {
        let ner = BiLstmCrfNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    // Whitespace-only input is treated as empty.
    #[test]
    fn test_whitespace_only() {
        let ner = BiLstmCrfNER::new();
        let entities = ner.extract_entities(" \n\t ", None).unwrap();
        assert!(entities.is_empty());
    }

    // Decoding must not emit "O" directly followed by an "I-" label, even
    // when the emissions favour it, because the transition matrix penalizes
    // that move.
    #[test]
    fn test_viterbi_respects_bio_constraints() {
        let ner = BiLstmCrfNER::new();

        // Token 0 prefers "O" (index 0); token 1 prefers "I-PER" (index 2).
        let emissions = vec![
            vec![0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
            vec![0.1, 0.1, 0.8, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        ];

        let path = ner.viterbi_decode(&emissions);

        if path[0] == 0 {
            assert!(
                path[1] == 0 || ner.labels[path[1]].starts_with("B-"),
                "Invalid BIO sequence: O followed by {}",
                ner.labels[path[1]]
            );
        }
    }

    // Entity offsets are character offsets and stay within the text.
    #[test]
    fn test_unicode_offsets() {
        let ner = BiLstmCrfNER::new();
        let text = "北京 Google Inc.";
        let char_count = text.chars().count();

        let entities = ner.extract_entities(text, None).unwrap();

        for entity in &entities {
            assert!(entity.start <= entity.end);
            assert!(entity.end <= char_count);
        }
    }

    // A custom configuration is stored as given.
    #[test]
    fn test_config() {
        let config = BiLstmCrfConfig {
            hidden_size: 512,
            num_layers: 3,
            dropout: 0.3,
            use_char_embeddings: false,
            max_seq_len: 256,
        };

        let ner = BiLstmCrfNER::with_config(config.clone());
        assert_eq!(ner.config.hidden_size, 512);
        assert_eq!(ner.config.num_layers, 3);
    }

    // The transition matrix is square over the label set.
    #[test]
    fn test_transition_matrix_shape() {
        let ner = BiLstmCrfNER::new();
        let n = ner.labels.len();

        assert_eq!(ner.transitions.len(), n);
        for row in &ner.transitions {
            assert_eq!(row.len(), n);
        }
    }

    // The model reports PER/ORG/LOC support.
    #[test]
    fn test_supported_types() {
        let ner = BiLstmCrfNER::new();
        let types = ner.supported_types();

        assert!(types.contains(&EntityType::Person));
        assert!(types.contains(&EntityType::Organization));
        assert!(types.contains(&EntityType::Location));
    }

    // Repeated surface forms must map to distinct byte positions, because
    // token positions are found left-to-right with a moving cursor.
    #[test]
    fn test_duplicate_entity_offsets() {
        let ner = BiLstmCrfNER::new();

        let text = "Google bought Google for $1 billion.";

        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = BiLstmCrfNER::calculate_token_positions(text, &tokens);

        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );

        let entities = ner.extract_entities(text, None).unwrap();

        let google_entities: Vec<_> = entities
            .iter()
            .filter(|e| e.text.contains("Google"))
            .collect();

        if google_entities.len() >= 2 {
            assert_ne!(
                google_entities[0].start, google_entities[1].start,
                "Duplicate entities should have different start positions"
            );
        }
    }

    // Token positions are byte offsets, also for multi-byte (CJK) tokens.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京 Osaka";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = BiLstmCrfNER::calculate_token_positions(text, &tokens);

        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
        assert_eq!(positions[3], (20, 25), "Osaka at bytes 20-25");
    }
}