1use super::confidence::Confidence;
39use super::types::{MentionType, PhiFeatures};
40use serde::{Deserialize, Serialize};
41use std::borrow::Cow;
42
/// Coarse grouping that every [`EntityType`] belongs to.
///
/// `EntityType::category` maps each type onto exactly one of these values,
/// and the category in turn answers `requires_ml` / `pattern_detectable`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityCategory {
    /// Category of `EntityType::Person`; `requires_ml()` is `true`.
    Agent,
    /// Category of `EntityType::Organization`; `requires_ml()` is `true`.
    Organization,
    /// Category of `EntityType::Location`; `requires_ml()` is `true`.
    Place,
    /// Not used by any built-in `EntityType` variant; `EntityType::from_label`
    /// assigns it to custom types such as `CREATIVE_WORK` and `EVENT`.
    Creative,
    /// Category of `EntityType::Date` and `EntityType::Time`;
    /// `pattern_detectable()` is `true`.
    Temporal,
    /// Category of money, percent, quantity, cardinal and ordinal types;
    /// `pattern_detectable()` is `true`.
    Numeric,
    /// Category of email, URL and phone types; `pattern_detectable()` is `true`.
    Contact,
    /// The only category for which `is_relation()` is `true`;
    /// `requires_ml()` is also `true`.
    Relation,
    /// Fallback category: neither `requires_ml()` nor `pattern_detectable()`.
    Misc,
}
83
84impl EntityCategory {
85 #[must_use]
87 pub const fn requires_ml(&self) -> bool {
88 matches!(
89 self,
90 EntityCategory::Agent
91 | EntityCategory::Organization
92 | EntityCategory::Place
93 | EntityCategory::Creative
94 | EntityCategory::Relation
95 )
96 }
97
98 #[must_use]
100 pub const fn pattern_detectable(&self) -> bool {
101 matches!(
102 self,
103 EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
104 )
105 }
106
107 #[must_use]
109 pub const fn is_relation(&self) -> bool {
110 matches!(self, EntityCategory::Relation)
111 }
112
113 #[must_use]
115 pub const fn as_str(&self) -> &'static str {
116 match self {
117 EntityCategory::Agent => "agent",
118 EntityCategory::Organization => "organization",
119 EntityCategory::Place => "place",
120 EntityCategory::Creative => "creative",
121 EntityCategory::Temporal => "temporal",
122 EntityCategory::Numeric => "numeric",
123 EntityCategory::Contact => "contact",
124 EntityCategory::Relation => "relation",
125 EntityCategory::Misc => "misc",
126 }
127 }
128}
129
130impl std::fmt::Display for EntityCategory {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 write!(f, "{}", self.as_str())
133 }
134}
135
/// Context ("viewport") attached to an entity.
///
/// Parsed from free-form strings by the `FromStr` impl, which recognises the
/// aliases listed on each variant and falls back to [`EntityViewport::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum EntityViewport {
    /// Parsed from `"business"`, `"financial"` or `"corporate"`.
    Business,
    /// Parsed from `"legal"`, `"law"` or `"compliance"`.
    Legal,
    /// Parsed from `"technical"`, `"engineering"` or `"tech"`.
    Technical,
    /// Parsed from `"academic"`, `"research"` or `"scholarly"`.
    Academic,
    /// Parsed from `"personal"`, `"biographical"` or `"private"`.
    Personal,
    /// Parsed from `"political"`, `"policy"` or `"government"`.
    Political,
    /// Parsed from `"media"`, `"press"`, `"pr"` or `"public_relations"`.
    Media,
    /// Parsed from `"historical"`, `"history"` or `"past"`.
    Historical,
    /// Default viewport; parsed from `"general"`, `"generic"` or the empty string.
    #[default]
    General,
    /// Any other name; `as_str` returns the stored string unchanged.
    Custom(String),
}
197
198impl EntityViewport {
199 #[must_use]
201 pub fn as_str(&self) -> &str {
202 match self {
203 EntityViewport::Business => "business",
204 EntityViewport::Legal => "legal",
205 EntityViewport::Technical => "technical",
206 EntityViewport::Academic => "academic",
207 EntityViewport::Personal => "personal",
208 EntityViewport::Political => "political",
209 EntityViewport::Media => "media",
210 EntityViewport::Historical => "historical",
211 EntityViewport::General => "general",
212 EntityViewport::Custom(s) => s,
213 }
214 }
215
216 #[must_use]
218 pub const fn is_professional(&self) -> bool {
219 matches!(
220 self,
221 EntityViewport::Business
222 | EntityViewport::Legal
223 | EntityViewport::Technical
224 | EntityViewport::Academic
225 | EntityViewport::Political
226 )
227 }
228}
229
230impl std::str::FromStr for EntityViewport {
231 type Err = std::convert::Infallible;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 Ok(match s.to_lowercase().as_str() {
235 "business" | "financial" | "corporate" => EntityViewport::Business,
236 "legal" | "law" | "compliance" => EntityViewport::Legal,
237 "technical" | "engineering" | "tech" => EntityViewport::Technical,
238 "academic" | "research" | "scholarly" => EntityViewport::Academic,
239 "personal" | "biographical" | "private" => EntityViewport::Personal,
240 "political" | "policy" | "government" => EntityViewport::Political,
241 "media" | "press" | "pr" | "public_relations" => EntityViewport::Media,
242 "historical" | "history" | "past" => EntityViewport::Historical,
243 "general" | "generic" | "" => EntityViewport::General,
244 other => EntityViewport::Custom(other.to_string()),
245 })
246 }
247}
248
249impl std::fmt::Display for EntityViewport {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(f, "{}", self.as_str())
252 }
253}
254
/// Type of a recognised entity.
///
/// Serialized as the flat label returned by `as_label` (see the manual
/// `Serialize` / `Deserialize` impls), so `Custom` loses its category on the
/// wire and is re-derived through `from_label` when read back.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EntityType {
    /// Label `"PER"`; category `Agent`.
    Person,
    /// Label `"ORG"`; category `Organization`.
    Organization,
    /// Label `"LOC"`; category `Place`.
    Location,

    /// Label `"DATE"`; category `Temporal`.
    Date,
    /// Label `"TIME"`; category `Temporal`.
    Time,

    /// Label `"MONEY"`; category `Numeric`.
    Money,
    /// Label `"PERCENT"`; category `Numeric`.
    Percent,
    /// Label `"QUANTITY"`; category `Numeric`.
    Quantity,
    /// Label `"CARDINAL"`; category `Numeric`.
    Cardinal,
    /// Label `"ORDINAL"`; category `Numeric`.
    Ordinal,

    /// Label `"EMAIL"`; category `Contact`.
    Email,
    /// Label `"URL"`; category `Contact`.
    Url,
    /// Label `"PHONE"`; category `Contact`.
    Phone,

    /// User-defined type; build it with `EntityType::custom`.
    Custom {
        /// Label returned by `as_label`.
        name: String,
        /// Category returned by `category`.
        category: EntityCategory,
    },

    /// Legacy free-form type; always reported as category `Misc`.
    #[deprecated(note = "use EntityType::custom(name, EntityCategory::Misc) instead")]
    Other(String),
}
333
334impl EntityType {
335 #[must_use]
337 pub fn category(&self) -> EntityCategory {
338 match self {
339 EntityType::Person => EntityCategory::Agent,
341 EntityType::Organization => EntityCategory::Organization,
343 EntityType::Location => EntityCategory::Place,
345 EntityType::Date | EntityType::Time => EntityCategory::Temporal,
347 EntityType::Money
349 | EntityType::Percent
350 | EntityType::Quantity
351 | EntityType::Cardinal
352 | EntityType::Ordinal => EntityCategory::Numeric,
353 EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
355 EntityType::Custom { category, .. } => *category,
357 #[allow(deprecated)]
359 EntityType::Other(_) => EntityCategory::Misc,
360 }
361 }
362
363 #[must_use]
365 pub fn requires_ml(&self) -> bool {
366 self.category().requires_ml()
367 }
368
369 #[must_use]
371 pub fn pattern_detectable(&self) -> bool {
372 self.category().pattern_detectable()
373 }
374
375 #[must_use]
384 pub fn as_label(&self) -> &str {
385 match self {
386 EntityType::Person => "PER",
387 EntityType::Organization => "ORG",
388 EntityType::Location => "LOC",
389 EntityType::Date => "DATE",
390 EntityType::Time => "TIME",
391 EntityType::Money => "MONEY",
392 EntityType::Percent => "PERCENT",
393 EntityType::Quantity => "QUANTITY",
394 EntityType::Cardinal => "CARDINAL",
395 EntityType::Ordinal => "ORDINAL",
396 EntityType::Email => "EMAIL",
397 EntityType::Url => "URL",
398 EntityType::Phone => "PHONE",
399 EntityType::Custom { name, .. } => name.as_str(),
400 #[allow(deprecated)]
401 EntityType::Other(s) => s.as_str(),
402 }
403 }
404
405 #[must_use]
417 pub fn from_label(label: &str) -> Self {
418 let label = label
420 .strip_prefix("B-")
421 .or_else(|| label.strip_prefix("I-"))
422 .or_else(|| label.strip_prefix("E-"))
423 .or_else(|| label.strip_prefix("S-"))
424 .unwrap_or(label);
425
426 match label.to_uppercase().as_str() {
427 "PER" | "PERSON" => EntityType::Person,
429 "ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
430 "LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
431 "FACILITY" | "FAC" | "BUILDING" => {
433 EntityType::custom("BUILDING", EntityCategory::Place)
434 }
435 "PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
436 "EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
437 "CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
438 EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
439 }
440 "GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
441 "DATE" => EntityType::Date,
443 "TIME" => EntityType::Time,
444 "MONEY" | "CURRENCY" => EntityType::Money,
446 "PERCENT" | "PERCENTAGE" => EntityType::Percent,
447 "QUANTITY" => EntityType::Quantity,
448 "CARDINAL" => EntityType::Cardinal,
449 "ORDINAL" => EntityType::Ordinal,
450 "EMAIL" => EntityType::Email,
452 "URL" | "URI" => EntityType::Url,
453 "PHONE" | "TELEPHONE" => EntityType::Phone,
454 "MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::custom("MISC", EntityCategory::Misc),
456 "DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
458 "CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
459 "GENE" => EntityType::custom("GENE", EntityCategory::Misc),
460 "PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
461 other => EntityType::custom(other, EntityCategory::Misc),
463 }
464 }
465
466 #[must_use]
480 pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
481 EntityType::Custom {
482 name: name.into(),
483 category,
484 }
485 }
486}
487
488impl std::fmt::Display for EntityType {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 write!(f, "{}", self.as_label())
491 }
492}
493
494impl std::str::FromStr for EntityType {
495 type Err = std::convert::Infallible;
496
497 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
499 Ok(Self::from_label(s))
500 }
501}
502
503impl Serialize for EntityType {
508 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
509 serializer.serialize_str(self.as_label())
510 }
511}
512
513impl<'de> Deserialize<'de> for EntityType {
514 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
515 struct EntityTypeVisitor;
516
517 impl<'de> serde::de::Visitor<'de> for EntityTypeVisitor {
518 type Value = EntityType;
519
520 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521 f.write_str("a string label or a tagged enum object")
522 }
523
524 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<EntityType, E> {
526 Ok(EntityType::from_label(v))
527 }
528
529 fn visit_map<A: serde::de::MapAccess<'de>>(
532 self,
533 mut map: A,
534 ) -> Result<EntityType, A::Error> {
535 let key: String = map
536 .next_key()?
537 .ok_or_else(|| serde::de::Error::custom("empty object"))?;
538 match key.as_str() {
539 "Custom" => {
540 #[derive(Deserialize)]
541 struct CustomFields {
542 name: String,
543 category: EntityCategory,
544 }
545 let fields: CustomFields = map.next_value()?;
546 Ok(EntityType::Custom {
547 name: fields.name,
548 category: fields.category,
549 })
550 }
551 "Other" => {
552 let val: String = map.next_value()?;
554 Ok(EntityType::custom(val, EntityCategory::Misc))
555 }
556 variant => {
558 let _: serde::de::IgnoredAny = map.next_value()?;
560 Ok(EntityType::from_label(variant))
561 }
562 }
563 }
564 }
565
566 deserializer.deserialize_any(EntityTypeVisitor)
567 }
568}
569
/// Maps dataset-specific labels onto [`EntityType`] values.
///
/// Keys are stored uppercased (see `TypeMapper::add`), so lookups are
/// case-insensitive.
#[derive(Debug, Clone, Default)]
pub struct TypeMapper {
    /// Uppercased source label -> target entity type.
    mappings: std::collections::HashMap<String, EntityType>,
}
613
614impl TypeMapper {
615 #[must_use]
617 pub fn new() -> Self {
618 Self::default()
619 }
620
621 #[must_use]
623 pub fn mit_movie() -> Self {
624 let mut mapper = Self::new();
625 mapper.add("ACTOR", EntityType::Person);
627 mapper.add("DIRECTOR", EntityType::Person);
628 mapper.add("CHARACTER", EntityType::Person);
629 mapper.add(
630 "TITLE",
631 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
632 );
633 mapper.add("GENRE", EntityType::custom("GENRE", EntityCategory::Misc));
634 mapper.add("YEAR", EntityType::Date);
635 mapper.add("RATING", EntityType::custom("RATING", EntityCategory::Misc));
636 mapper.add("PLOT", EntityType::custom("PLOT", EntityCategory::Misc));
637 mapper
638 }
639
640 #[must_use]
642 pub fn mit_restaurant() -> Self {
643 let mut mapper = Self::new();
644 mapper.add("RESTAURANT_NAME", EntityType::Organization);
645 mapper.add("LOCATION", EntityType::Location);
646 mapper.add(
647 "CUISINE",
648 EntityType::custom("CUISINE", EntityCategory::Misc),
649 );
650 mapper.add("DISH", EntityType::custom("DISH", EntityCategory::Misc));
651 mapper.add("PRICE", EntityType::Money);
652 mapper.add(
653 "AMENITY",
654 EntityType::custom("AMENITY", EntityCategory::Misc),
655 );
656 mapper.add("HOURS", EntityType::Time);
657 mapper
658 }
659
660 #[must_use]
662 pub fn biomedical() -> Self {
663 let mut mapper = Self::new();
664 mapper.add(
665 "DISEASE",
666 EntityType::custom("DISEASE", EntityCategory::Agent),
667 );
668 mapper.add(
669 "CHEMICAL",
670 EntityType::custom("CHEMICAL", EntityCategory::Misc),
671 );
672 mapper.add("DRUG", EntityType::custom("DRUG", EntityCategory::Misc));
673 mapper.add("GENE", EntityType::custom("GENE", EntityCategory::Misc));
674 mapper.add(
675 "PROTEIN",
676 EntityType::custom("PROTEIN", EntityCategory::Misc),
677 );
678 mapper.add("DNA", EntityType::custom("DNA", EntityCategory::Misc));
680 mapper.add("RNA", EntityType::custom("RNA", EntityCategory::Misc));
681 mapper.add(
682 "cell_line",
683 EntityType::custom("CELL_LINE", EntityCategory::Misc),
684 );
685 mapper.add(
686 "cell_type",
687 EntityType::custom("CELL_TYPE", EntityCategory::Misc),
688 );
689 mapper
690 }
691
692 #[must_use]
694 pub fn social_media() -> Self {
695 let mut mapper = Self::new();
696 mapper.add("person", EntityType::Person);
698 mapper.add("corporation", EntityType::Organization);
699 mapper.add("location", EntityType::Location);
700 mapper.add("group", EntityType::Organization);
701 mapper.add(
702 "product",
703 EntityType::custom("PRODUCT", EntityCategory::Misc),
704 );
705 mapper.add(
706 "creative_work",
707 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
708 );
709 mapper.add("event", EntityType::custom("EVENT", EntityCategory::Misc));
710 mapper
711 }
712
713 #[must_use]
715 pub fn manufacturing() -> Self {
716 let mut mapper = Self::new();
717 mapper.add("MATE", EntityType::custom("MATERIAL", EntityCategory::Misc));
719 mapper.add("MANP", EntityType::custom("PROCESS", EntityCategory::Misc));
720 mapper.add("MACEQ", EntityType::custom("MACHINE", EntityCategory::Misc));
721 mapper.add(
722 "APPL",
723 EntityType::custom("APPLICATION", EntityCategory::Misc),
724 );
725 mapper.add("FEAT", EntityType::custom("FEATURE", EntityCategory::Misc));
726 mapper.add(
727 "PARA",
728 EntityType::custom("PARAMETER", EntityCategory::Misc),
729 );
730 mapper.add("PRO", EntityType::custom("PROPERTY", EntityCategory::Misc));
731 mapper.add(
732 "CHAR",
733 EntityType::custom("CHARACTERISTIC", EntityCategory::Misc),
734 );
735 mapper.add(
736 "ENAT",
737 EntityType::custom("ENABLING_TECHNOLOGY", EntityCategory::Misc),
738 );
739 mapper.add(
740 "CONPRI",
741 EntityType::custom("CONCEPT_PRINCIPLE", EntityCategory::Misc),
742 );
743 mapper.add(
744 "BIOP",
745 EntityType::custom("BIO_PROCESS", EntityCategory::Misc),
746 );
747 mapper.add(
748 "MANS",
749 EntityType::custom("MAN_STANDARD", EntityCategory::Misc),
750 );
751 mapper
752 }
753
754 pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
756 self.mappings.insert(source.into().to_uppercase(), target);
757 }
758
759 #[must_use]
761 pub fn map(&self, label: &str) -> Option<&EntityType> {
762 self.mappings.get(&label.to_uppercase())
763 }
764
765 #[must_use]
769 pub fn normalize(&self, label: &str) -> EntityType {
770 self.map(label)
771 .cloned()
772 .unwrap_or_else(|| EntityType::from_label(label))
773 }
774
775 #[must_use]
777 pub fn contains(&self, label: &str) -> bool {
778 self.mappings.contains_key(&label.to_uppercase())
779 }
780
781 pub fn labels(&self) -> impl Iterator<Item = &String> {
783 self.mappings.keys()
784 }
785}
786
/// How an entity was extracted.
///
/// `is_calibrated` and `confidence_interpretation` describe how the
/// confidence value attached to each method is reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExtractionMethod {
    /// Displayed as `"pattern"`; confidence is interpreted as `"binary"`.
    Pattern,

    /// Default method; displayed as `"neural"`, confidence is a calibrated
    /// `"probability"`.
    #[default]
    Neural,

    /// Displayed as `"lexicon"`; confidence is interpreted as `"binary"`.
    #[deprecated(since = "0.2.0", note = "Use Neural or GatedEnsemble instead")]
    Lexicon,

    /// Displayed as `"soft_lexicon"`; calibrated `"probability"`.
    SoftLexicon,

    /// Displayed as `"gated_ensemble"`; calibrated `"probability"`.
    GatedEnsemble,

    /// Displayed as `"consensus"`; confidence is an `"agreement_ratio"`.
    Consensus,

    /// Displayed as `"heuristic"`; confidence is a `"heuristic_score"`.
    Heuristic,

    /// Displayed as `"unknown"`; confidence interpretation is `"unknown"`.
    Unknown,

    /// Legacy alias; displayed as `"heuristic"`.
    #[deprecated(since = "0.2.0", note = "Use Heuristic or Pattern instead")]
    Rule,

    /// Legacy alias; displayed as `"neural"` and treated as calibrated.
    #[deprecated(since = "0.2.0", note = "Use Neural instead")]
    ML,

    /// Legacy alias; displayed as `"consensus"`.
    #[deprecated(since = "0.2.0", note = "Use Consensus instead")]
    Ensemble,
}
852
853impl ExtractionMethod {
854 #[must_use]
881 pub const fn is_calibrated(&self) -> bool {
882 #[allow(deprecated)]
883 match self {
884 ExtractionMethod::Neural => true,
885 ExtractionMethod::GatedEnsemble => true,
886 ExtractionMethod::SoftLexicon => true,
887 ExtractionMethod::ML => true, ExtractionMethod::Pattern => false,
890 ExtractionMethod::Lexicon => false,
891 ExtractionMethod::Consensus => false,
892 ExtractionMethod::Heuristic => false,
893 ExtractionMethod::Unknown => false,
894 ExtractionMethod::Rule => false,
895 ExtractionMethod::Ensemble => false,
896 }
897 }
898
899 #[must_use]
906 pub const fn confidence_interpretation(&self) -> &'static str {
907 #[allow(deprecated)]
908 match self {
909 ExtractionMethod::Neural | ExtractionMethod::ML => "probability",
910 ExtractionMethod::GatedEnsemble | ExtractionMethod::SoftLexicon => "probability",
911 ExtractionMethod::Pattern | ExtractionMethod::Lexicon => "binary",
912 ExtractionMethod::Heuristic | ExtractionMethod::Rule => "heuristic_score",
913 ExtractionMethod::Consensus | ExtractionMethod::Ensemble => "agreement_ratio",
914 ExtractionMethod::Unknown => "unknown",
915 }
916 }
917}
918
919impl std::fmt::Display for ExtractionMethod {
920 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
921 #[allow(deprecated)]
922 match self {
923 ExtractionMethod::Pattern => write!(f, "pattern"),
924 ExtractionMethod::Neural => write!(f, "neural"),
925 ExtractionMethod::Lexicon => write!(f, "lexicon"),
926 ExtractionMethod::SoftLexicon => write!(f, "soft_lexicon"),
927 ExtractionMethod::GatedEnsemble => write!(f, "gated_ensemble"),
928 ExtractionMethod::Consensus => write!(f, "consensus"),
929 ExtractionMethod::Heuristic => write!(f, "heuristic"),
930 ExtractionMethod::Unknown => write!(f, "unknown"),
931 ExtractionMethod::Rule => write!(f, "heuristic"), ExtractionMethod::ML => write!(f, "neural"), ExtractionMethod::Ensemble => write!(f, "consensus"), }
935 }
936}
937
/// A lookup table from surface text to an entity type and confidence.
///
/// Implementors must be `Send + Sync`.
pub trait Lexicon: Send + Sync {
    /// Returns the entity type and confidence registered for `text`, or
    /// `None` when the lexicon has no entry for it.
    fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)>;

    /// Returns `true` when `lookup(text)` yields an entry.
    fn contains(&self, text: &str) -> bool {
        self.lookup(text).is_some()
    }

    /// Name identifying where this lexicon came from.
    fn source(&self) -> &str;

    /// Number of entries in the lexicon.
    fn len(&self) -> usize;

    /// Returns `true` when `len()` is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1013
/// [`Lexicon`] backed by an in-memory `HashMap`.
///
/// Keys are stored and looked up exactly as given (no case folding).
#[derive(Debug, Clone)]
pub struct HashMapLexicon {
    /// Surface text -> (entity type, confidence).
    entries: std::collections::HashMap<String, (EntityType, Confidence)>,
    /// Value returned by `Lexicon::source`.
    source: String,
}
1023
1024impl HashMapLexicon {
1025 #[must_use]
1027 pub fn new(source: impl Into<String>) -> Self {
1028 Self {
1029 entries: std::collections::HashMap::new(),
1030 source: source.into(),
1031 }
1032 }
1033
1034 pub fn insert(
1036 &mut self,
1037 text: impl Into<String>,
1038 entity_type: EntityType,
1039 confidence: impl Into<Confidence>,
1040 ) {
1041 self.entries
1042 .insert(text.into(), (entity_type, confidence.into()));
1043 }
1044
1045 pub fn from_iter<I, S, C>(source: impl Into<String>, entries: I) -> Self
1047 where
1048 I: IntoIterator<Item = (S, EntityType, C)>,
1049 S: Into<String>,
1050 C: Into<Confidence>,
1051 {
1052 let mut lexicon = Self::new(source);
1053 for (text, entity_type, conf) in entries {
1054 lexicon.insert(text, entity_type, conf);
1055 }
1056 lexicon
1057 }
1058
1059 pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, Confidence)> {
1061 self.entries.iter().map(|(k, (t, c))| (k.as_str(), t, *c))
1062 }
1063}
1064
1065impl Lexicon for HashMapLexicon {
1066 fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)> {
1067 self.entries.get(text).cloned()
1068 }
1069
1070 fn source(&self) -> &str {
1071 &self.source
1072 }
1073
1074 fn len(&self) -> usize {
1075 self.entries.len()
1076 }
1077}
1078
/// Record of where an entity came from and how it was extracted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Provenance {
    /// Name of the producer: `"pattern"` for `Provenance::pattern`, the model
    /// name for `Provenance::ml`, the caller-supplied string for `ensemble`.
    pub source: Cow<'static, str>,
    /// Extraction method that produced the entity.
    pub method: ExtractionMethod,
    /// Name of the matching pattern; only set by `Provenance::pattern`.
    pub pattern: Option<Cow<'static, str>>,
    /// Confidence as reported by the producer, if any.
    pub raw_confidence: Option<Confidence>,
    /// Model version; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<Cow<'static, str>>,
    /// Free-form timestamp string; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<String>,
}
1100
1101impl Provenance {
1102 #[must_use]
1104 pub fn pattern(pattern_name: &'static str) -> Self {
1105 Self {
1106 source: Cow::Borrowed("pattern"),
1107 method: ExtractionMethod::Pattern,
1108 pattern: Some(Cow::Borrowed(pattern_name)),
1109 raw_confidence: Some(Confidence::ONE), model_version: None,
1111 timestamp: None,
1112 }
1113 }
1114
1115 #[must_use]
1129 pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: impl Into<Confidence>) -> Self {
1130 Self {
1131 source: model_name.into(),
1132 method: ExtractionMethod::Neural,
1133 pattern: None,
1134 raw_confidence: Some(confidence.into()),
1135 model_version: None,
1136 timestamp: None,
1137 }
1138 }
1139
1140 #[deprecated(
1142 since = "0.2.1",
1143 note = "Use ml() instead, it now accepts owned strings"
1144 )]
1145 #[must_use]
1146 pub fn ml_owned(model_name: impl Into<String>, confidence: impl Into<Confidence>) -> Self {
1147 Self::ml(Cow::Owned(model_name.into()), confidence)
1148 }
1149
1150 #[must_use]
1152 pub fn ensemble(sources: &'static str) -> Self {
1153 Self {
1154 source: Cow::Borrowed(sources),
1155 method: ExtractionMethod::Consensus,
1156 pattern: None,
1157 raw_confidence: None,
1158 model_version: None,
1159 timestamp: None,
1160 }
1161 }
1162
1163 #[must_use]
1165 pub fn with_version(mut self, version: &'static str) -> Self {
1166 self.model_version = Some(Cow::Borrowed(version));
1167 self
1168 }
1169
1170 #[must_use]
1172 pub fn with_timestamp(mut self, timestamp: impl Into<String>) -> Self {
1173 self.timestamp = Some(timestamp.into());
1174 self
1175 }
1176}
1177
/// Location of an entity: in text, on a page, or both.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Span {
    /// A text range; `len()` is `end - start` (saturating at zero).
    Text {
        /// Start offset of the range.
        start: usize,
        /// End offset of the range — presumably exclusive; TODO confirm.
        end: usize,
    },
    /// A rectangular region with no text offsets.
    BoundingBox {
        /// Horizontal position (units/origin not visible here — TODO confirm).
        x: f32,
        /// Vertical position (units/origin not visible here — TODO confirm).
        y: f32,
        /// Width of the box.
        width: f32,
        /// Height of the box.
        height: f32,
        /// Page number, when known (`Span::bbox` leaves it `None`).
        page: Option<u32>,
    },
    /// A text range together with a visual location.
    Hybrid {
        /// Start offset of the text range.
        start: usize,
        /// End offset of the text range.
        end: usize,
        /// Visual part of the span; presumably a `BoundingBox`, but the type
        /// does not enforce it.
        bbox: Box<Span>,
    },
}
1244
1245impl Span {
1246 #[must_use]
1248 pub const fn text(start: usize, end: usize) -> Self {
1249 Self::Text { start, end }
1250 }
1251
1252 #[must_use]
1254 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
1255 Self::BoundingBox {
1256 x,
1257 y,
1258 width,
1259 height,
1260 page: None,
1261 }
1262 }
1263
1264 #[must_use]
1266 pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
1267 Self::BoundingBox {
1268 x,
1269 y,
1270 width,
1271 height,
1272 page: Some(page),
1273 }
1274 }
1275
1276 #[must_use]
1278 pub const fn is_text(&self) -> bool {
1279 matches!(self, Self::Text { .. } | Self::Hybrid { .. })
1280 }
1281
1282 #[must_use]
1284 pub const fn is_visual(&self) -> bool {
1285 matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
1286 }
1287
1288 #[must_use]
1290 pub const fn text_offsets(&self) -> Option<(usize, usize)> {
1291 match self {
1292 Self::Text { start, end } => Some((*start, *end)),
1293 Self::Hybrid { start, end, .. } => Some((*start, *end)),
1294 Self::BoundingBox { .. } => None,
1295 }
1296 }
1297
1298 #[must_use]
1300 pub fn len(&self) -> usize {
1301 match self {
1302 Self::Text { start, end } => end.saturating_sub(*start),
1303 Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
1304 Self::BoundingBox { .. } => 0,
1305 }
1306 }
1307
1308 #[must_use]
1310 pub fn is_empty(&self) -> bool {
1311 self.len() == 0
1312 }
1313}
1314
/// A span made of one or more separate text ranges.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscontinuousSpan {
    /// The ranges making up the span. `DiscontinuousSpan::new` sorts them by
    /// start offset; deserialized values keep whatever order they were stored in.
    segments: Vec<std::ops::Range<usize>>,
}
1361
1362impl DiscontinuousSpan {
1363 #[must_use]
1367 pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
1368 segments.sort_by_key(|r| r.start);
1370 Self { segments }
1371 }
1372
1373 #[must_use]
1375 #[allow(clippy::single_range_in_vec_init)] pub fn contiguous(start: usize, end: usize) -> Self {
1377 Self {
1378 segments: vec![start..end],
1379 }
1380 }
1381
1382 #[must_use]
1384 pub fn num_segments(&self) -> usize {
1385 self.segments.len()
1386 }
1387
1388 #[must_use]
1390 pub fn is_discontinuous(&self) -> bool {
1391 self.segments.len() > 1
1392 }
1393
1394 #[must_use]
1396 pub fn is_contiguous(&self) -> bool {
1397 self.segments.len() <= 1
1398 }
1399
1400 #[must_use]
1402 pub fn segments(&self) -> &[std::ops::Range<usize>] {
1403 &self.segments
1404 }
1405
1406 #[must_use]
1408 pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
1409 if self.segments.is_empty() {
1410 return None;
1411 }
1412 let start = self.segments.first()?.start;
1413 let end = self.segments.last()?.end;
1414 Some(start..end)
1415 }
1416
1417 #[must_use]
1420 pub fn total_len(&self) -> usize {
1421 self.segments.iter().map(|r| r.end - r.start).sum()
1422 }
1423
1424 #[must_use]
1426 pub fn extract_text(&self, text: &str, separator: &str) -> String {
1427 self.segments
1428 .iter()
1429 .map(|r| {
1430 let start = r.start;
1431 let len = r.end.saturating_sub(r.start);
1432 text.chars().skip(start).take(len).collect::<String>()
1433 })
1434 .collect::<Vec<_>>()
1435 .join(separator)
1436 }
1437
1438 #[must_use]
1448 pub fn contains(&self, pos: usize) -> bool {
1449 self.segments.iter().any(|r| r.contains(&pos))
1450 }
1451
1452 #[must_use]
1454 pub fn to_span(&self) -> Option<Span> {
1455 self.bounding_range().map(|r| Span::Text {
1456 start: r.start,
1457 end: r.end,
1458 })
1459 }
1460}
1461
1462impl From<std::ops::Range<usize>> for DiscontinuousSpan {
1463 fn from(range: std::ops::Range<usize>) -> Self {
1464 Self::contiguous(range.start, range.end)
1465 }
1466}
1467
1468impl Default for Span {
1469 fn default() -> Self {
1470 Self::Text { start: 0, end: 0 }
1471 }
1472}
1473
/// Confidence split into three independent scores.
///
/// `combined()` reduces them to a single value (their geometric mean).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HierarchicalConfidence {
    /// Linkage score; compared against `linkage_min` in `passes_threshold`.
    pub linkage: Confidence,
    /// Type score; compared against `type_min` in `passes_threshold`.
    pub type_score: Confidence,
    /// Boundary score; compared against `boundary_min` in `passes_threshold`.
    pub boundary: Confidence,
}
1496
1497impl HierarchicalConfidence {
1498 #[must_use]
1503 pub fn new(
1504 linkage: impl Into<Confidence>,
1505 type_score: impl Into<Confidence>,
1506 boundary: impl Into<Confidence>,
1507 ) -> Self {
1508 Self {
1509 linkage: linkage.into(),
1510 type_score: type_score.into(),
1511 boundary: boundary.into(),
1512 }
1513 }
1514
1515 #[must_use]
1518 pub fn from_single(confidence: impl Into<Confidence>) -> Self {
1519 let c = confidence.into();
1520 Self {
1521 linkage: c,
1522 type_score: c,
1523 boundary: c,
1524 }
1525 }
1526
1527 #[must_use]
1530 pub fn combined(&self) -> Confidence {
1531 let product = self.linkage.value() * self.type_score.value() * self.boundary.value();
1532 Confidence::new(product.powf(1.0 / 3.0))
1533 }
1534
1535 #[must_use]
1537 pub fn as_f64(&self) -> f64 {
1538 self.combined().value()
1539 }
1540
1541 #[must_use]
1543 pub fn passes_threshold(&self, linkage_min: f64, type_min: f64, boundary_min: f64) -> bool {
1544 self.linkage >= linkage_min && self.type_score >= type_min && self.boundary >= boundary_min
1545 }
1546}
1547
1548impl Default for HierarchicalConfidence {
1549 fn default() -> Self {
1550 Self {
1551 linkage: Confidence::ONE,
1552 type_score: Confidence::ONE,
1553 boundary: Confidence::ONE,
1554 }
1555 }
1556}
1557
1558impl From<f64> for HierarchicalConfidence {
1559 fn from(confidence: f64) -> Self {
1560 Self::from_single(confidence)
1561 }
1562}
1563
1564impl From<f32> for HierarchicalConfidence {
1565 fn from(confidence: f32) -> Self {
1566 Self::from_single(confidence)
1567 }
1568}
1569
1570impl From<Confidence> for HierarchicalConfidence {
1571 fn from(confidence: Confidence) -> Self {
1572 Self::from_single(confidence)
1573 }
1574}
1575
/// A batch of variable-length token sequences stored without padding.
#[derive(Debug, Clone)]
pub struct RaggedBatch {
    /// All sequences concatenated back to back.
    pub token_ids: Vec<u32>,
    /// Sequence boundaries within `token_ids`: entry `i` is where sequence `i`
    /// starts and entry `i + 1` where it ends. `from_sequences` produces
    /// `batch_size + 1` entries, the first being `0`.
    pub cumulative_offsets: Vec<u32>,
    /// Length of the longest sequence in the batch.
    pub max_seq_len: usize,
}
1609
1610impl RaggedBatch {
1611 pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
1613 let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
1614 let mut token_ids = Vec::with_capacity(total_tokens);
1615 let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
1616 let mut max_seq_len = 0;
1617
1618 cumulative_offsets.push(0);
1619 for seq in sequences {
1620 token_ids.extend_from_slice(seq);
1621 let len = token_ids.len();
1625 if len > u32::MAX as usize {
1626 log::warn!(
1629 "Token count {} exceeds u32::MAX, truncating to {}",
1630 len,
1631 u32::MAX
1632 );
1633 cumulative_offsets.push(u32::MAX);
1634 } else {
1635 cumulative_offsets.push(len as u32);
1636 }
1637 max_seq_len = max_seq_len.max(seq.len());
1638 }
1639
1640 Self {
1641 token_ids,
1642 cumulative_offsets,
1643 max_seq_len,
1644 }
1645 }
1646
1647 #[must_use]
1649 pub fn batch_size(&self) -> usize {
1650 self.cumulative_offsets.len().saturating_sub(1)
1651 }
1652
1653 #[must_use]
1655 pub fn total_tokens(&self) -> usize {
1656 self.token_ids.len()
1657 }
1658
1659 #[must_use]
1661 pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
1662 if doc_idx + 1 < self.cumulative_offsets.len() {
1663 let start = self.cumulative_offsets[doc_idx] as usize;
1664 let end = self.cumulative_offsets[doc_idx + 1] as usize;
1665 Some(start..end)
1666 } else {
1667 None
1668 }
1669 }
1670
1671 #[must_use]
1673 pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
1674 self.doc_range(doc_idx).map(|r| &self.token_ids[r])
1675 }
1676
1677 #[must_use]
1679 pub fn padding_savings(&self) -> f64 {
1680 let padded_size = self.batch_size() * self.max_seq_len;
1681 if padded_size == 0 {
1682 return 0.0;
1683 }
1684 1.0 - (self.total_tokens() as f64 / padded_size as f64)
1685 }
1686}
1687
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
    /// Index of the document within the batch.
    pub doc_idx: u32,
    /// Start position of the span within the document.
    pub start: u32,
    /// End position of the span within the document.
    pub end: u32,
}

impl SpanCandidate {
    /// Builds a candidate from its three components.
    #[must_use]
    pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
        SpanCandidate { doc_idx, start, end }
    }

    /// `end - start`, or `0` when `end` is not greater than `start`.
    #[must_use]
    pub const fn width(&self) -> u32 {
        if self.end > self.start {
            self.end - self.start
        } else {
            0
        }
    }
}
1723
1724pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
1729 let mut candidates = Vec::new();
1730
1731 for doc_idx in 0..batch.batch_size() {
1732 if let Some(range) = batch.doc_range(doc_idx) {
1733 let doc_len = range.len();
1734 for start in 0..doc_len {
1736 let max_end = (start + max_width).min(doc_len);
1737 for end in (start + 1)..=max_end {
1738 candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
1739 }
1740 }
1741 }
1742 }
1743
1744 candidates
1745}
1746
1747pub fn generate_filtered_candidates(
1751 batch: &RaggedBatch,
1752 max_width: usize,
1753 linkage_mask: &[f32],
1754 threshold: f32,
1755) -> Vec<SpanCandidate> {
1756 let mut candidates = Vec::new();
1757 let mut mask_idx = 0;
1758
1759 for doc_idx in 0..batch.batch_size() {
1760 if let Some(range) = batch.doc_range(doc_idx) {
1761 let doc_len = range.len();
1762 for start in 0..doc_len {
1763 let max_end = (start + max_width).min(doc_len);
1764 for end in (start + 1)..=max_end {
1765 if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
1767 candidates.push(SpanCandidate::new(
1768 doc_idx as u32,
1769 start as u32,
1770 end as u32,
1771 ));
1772 }
1773 mask_idx += 1;
1774 }
1775 }
1776 }
1777 }
1778
1779 candidates
1780}
1781
/// A recognised entity together with its optional metadata.
///
/// Every `Option` field is skipped during serialization when it is `None`
/// and defaults to `None` when missing on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Surface text of the entity.
    pub text: String,
    /// Type of the entity.
    pub entity_type: EntityType,
    /// Start offset of the entity in the source text
    /// (character vs. byte offsets is not visible here — TODO confirm).
    pub start: usize,
    /// End offset of the entity — presumably exclusive; TODO confirm.
    pub end: usize,
    /// Overall confidence. `Entity::with_hierarchical_confidence` derives it
    /// from the hierarchical scores.
    pub confidence: Confidence,
    /// Normalized form of the text, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<String>,
    /// How and by what the entity was extracted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provenance: Option<Provenance>,
    /// Knowledge-base identifier, if any (presumably set by entity linking).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kb_id: Option<String>,
    /// Canonical identifier, if any (presumably set by coreference or
    /// clustering — TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub canonical_id: Option<super::types::CanonicalId>,
    /// Per-aspect confidence scores, when available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hierarchical_confidence: Option<HierarchicalConfidence>,
    /// Visual location of the entity, when available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub visual_span: Option<Span>,
    /// Segments of an entity that is not one contiguous range.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub discontinuous_span: Option<DiscontinuousSpan>,
    /// Start of the period in which the entity is valid, if bounded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// End of the period in which the entity is valid, if bounded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
    /// Context in which the entity is viewed.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub viewport: Option<EntityViewport>,
    /// Phi features of the mention (see `PhiFeatures`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phi_features: Option<PhiFeatures>,
    /// Kind of mention (see `MentionType`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mention_type: Option<MentionType>,
}
1947
1948impl Entity {
1949 #[must_use]
1960 pub fn new(
1961 text: impl Into<String>,
1962 entity_type: EntityType,
1963 start: usize,
1964 end: usize,
1965 confidence: impl Into<Confidence>,
1966 ) -> Self {
1967 Self {
1968 text: text.into(),
1969 entity_type,
1970 start,
1971 end,
1972 confidence: confidence.into(),
1973 normalized: None,
1974 provenance: None,
1975 kb_id: None,
1976 canonical_id: None,
1977 hierarchical_confidence: None,
1978 visual_span: None,
1979 discontinuous_span: None,
1980 valid_from: None,
1981 valid_until: None,
1982 viewport: None,
1983 phi_features: None,
1984 mention_type: None,
1985 }
1986 }
1987
1988 #[must_use]
1990 pub fn with_provenance(
1991 text: impl Into<String>,
1992 entity_type: EntityType,
1993 start: usize,
1994 end: usize,
1995 confidence: impl Into<Confidence>,
1996 provenance: Provenance,
1997 ) -> Self {
1998 Self {
1999 text: text.into(),
2000 entity_type,
2001 start,
2002 end,
2003 confidence: confidence.into(),
2004 normalized: None,
2005 provenance: Some(provenance),
2006 kb_id: None,
2007 canonical_id: None,
2008 hierarchical_confidence: None,
2009 visual_span: None,
2010 discontinuous_span: None,
2011 valid_from: None,
2012 valid_until: None,
2013 viewport: None,
2014 phi_features: None,
2015 mention_type: None,
2016 }
2017 }
2018
2019 #[must_use]
2021 pub fn with_hierarchical_confidence(
2022 text: impl Into<String>,
2023 entity_type: EntityType,
2024 start: usize,
2025 end: usize,
2026 confidence: HierarchicalConfidence,
2027 ) -> Self {
2028 Self {
2029 text: text.into(),
2030 entity_type,
2031 start,
2032 end,
2033 confidence: Confidence::new(confidence.as_f64()),
2034 normalized: None,
2035 provenance: None,
2036 kb_id: None,
2037 canonical_id: None,
2038 hierarchical_confidence: Some(confidence),
2039 visual_span: None,
2040 discontinuous_span: None,
2041 valid_from: None,
2042 valid_until: None,
2043 viewport: None,
2044 phi_features: None,
2045 mention_type: None,
2046 }
2047 }
2048
2049 #[must_use]
2051 pub fn from_visual(
2052 text: impl Into<String>,
2053 entity_type: EntityType,
2054 bbox: Span,
2055 confidence: impl Into<Confidence>,
2056 ) -> Self {
2057 Self {
2058 text: text.into(),
2059 entity_type,
2060 start: 0,
2061 end: 0,
2062 confidence: confidence.into(),
2063 normalized: None,
2064 provenance: None,
2065 kb_id: None,
2066 canonical_id: None,
2067 hierarchical_confidence: None,
2068 visual_span: Some(bbox),
2069 discontinuous_span: None,
2070 valid_from: None,
2071 valid_until: None,
2072 viewport: None,
2073 phi_features: None,
2074 mention_type: None,
2075 }
2076 }
2077
2078 #[must_use]
2080 pub fn with_type(
2081 text: impl Into<String>,
2082 entity_type: EntityType,
2083 start: usize,
2084 end: usize,
2085 ) -> Self {
2086 Self::new(text, entity_type, start, end, 1.0)
2087 }
2088
2089 pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
2099 self.kb_id = Some(kb_id.into());
2100 }
2101
2102 pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
2106 self.canonical_id = Some(canonical_id.into());
2107 }
2108
2109 #[must_use]
2119 pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
2120 self.canonical_id = Some(canonical_id.into());
2121 self
2122 }
2123
2124 #[must_use]
2126 pub fn is_linked(&self) -> bool {
2127 self.kb_id.is_some()
2128 }
2129
2130 #[must_use]
2132 pub fn has_coreference(&self) -> bool {
2133 self.canonical_id.is_some()
2134 }
2135
2136 #[must_use]
2142 pub fn is_discontinuous(&self) -> bool {
2143 self.discontinuous_span
2144 .as_ref()
2145 .map(|s| s.is_discontinuous())
2146 .unwrap_or(false)
2147 }
2148
2149 #[must_use]
2153 pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
2154 self.discontinuous_span
2155 .as_ref()
2156 .filter(|s| s.is_discontinuous())
2157 .map(|s| s.segments().to_vec())
2158 }
2159
2160 pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
2164 if let Some(bounding) = span.bounding_range() {
2166 self.start = bounding.start;
2167 self.end = bounding.end;
2168 }
2169 self.discontinuous_span = Some(span);
2170 }
2171
2172 #[must_use]
2180 pub fn total_len(&self) -> usize {
2181 if let Some(ref span) = self.discontinuous_span {
2182 span.segments().iter().map(|r| r.end - r.start).sum()
2183 } else {
2184 self.end.saturating_sub(self.start)
2185 }
2186 }
2187
2188 pub fn set_normalized(&mut self, normalized: impl Into<String>) {
2200 self.normalized = Some(normalized.into());
2201 }
2202
2203 #[must_use]
2205 pub fn normalized_or_text(&self) -> &str {
2206 self.normalized.as_deref().unwrap_or(&self.text)
2207 }
2208
2209 #[must_use]
2211 pub fn method(&self) -> ExtractionMethod {
2212 self.provenance
2213 .as_ref()
2214 .map_or(ExtractionMethod::Unknown, |p| p.method)
2215 }
2216
2217 #[must_use]
2219 pub fn source(&self) -> Option<&str> {
2220 self.provenance.as_ref().map(|p| p.source.as_ref())
2221 }
2222
2223 #[must_use]
2225 pub fn category(&self) -> EntityCategory {
2226 self.entity_type.category()
2227 }
2228
2229 #[must_use]
2231 pub fn is_structured(&self) -> bool {
2232 self.entity_type.pattern_detectable()
2233 }
2234
2235 #[must_use]
2237 pub fn is_named(&self) -> bool {
2238 self.entity_type.requires_ml()
2239 }
2240
2241 #[must_use]
2243 pub fn overlaps(&self, other: &Entity) -> bool {
2244 !(self.end <= other.start || other.end <= self.start)
2245 }
2246
2247 #[must_use]
2249 pub fn overlap_ratio(&self, other: &Entity) -> f64 {
2250 let intersection_start = self.start.max(other.start);
2251 let intersection_end = self.end.min(other.end);
2252
2253 if intersection_start >= intersection_end {
2254 return 0.0;
2255 }
2256
2257 let intersection = (intersection_end - intersection_start) as f64;
2258 let union = ((self.end - self.start) + (other.end - other.start)
2259 - (intersection_end - intersection_start)) as f64;
2260
2261 if union == 0.0 {
2262 return 1.0;
2263 }
2264
2265 intersection / union
2266 }
2267
2268 pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
2270 self.confidence = Confidence::new(confidence.as_f64());
2271 self.hierarchical_confidence = Some(confidence);
2272 }
2273
2274 #[must_use]
2276 pub fn linkage_confidence(&self) -> Confidence {
2277 self.hierarchical_confidence
2278 .map_or(self.confidence, |h| h.linkage)
2279 }
2280
2281 #[must_use]
2283 pub fn type_confidence(&self) -> Confidence {
2284 self.hierarchical_confidence
2285 .map_or(self.confidence, |h| h.type_score)
2286 }
2287
2288 #[must_use]
2290 pub fn boundary_confidence(&self) -> Confidence {
2291 self.hierarchical_confidence
2292 .map_or(self.confidence, |h| h.boundary)
2293 }
2294
2295 #[must_use]
2297 pub fn is_visual(&self) -> bool {
2298 self.visual_span.is_some()
2299 }
2300
2301 #[must_use]
2303 pub const fn text_span(&self) -> (usize, usize) {
2304 (self.start, self.end)
2305 }
2306
2307 #[must_use]
2309 pub const fn span_len(&self) -> usize {
2310 self.end.saturating_sub(self.start)
2311 }
2312
2313 pub fn set_visual_span(&mut self, span: Span) {
2338 self.visual_span = Some(span);
2339 }
2340
2341 #[must_use]
2361 pub fn extract_text(&self, source_text: &str) -> String {
2362 let char_count = source_text.chars().count();
2366 self.extract_text_with_len(source_text, char_count)
2367 }
2368
2369 #[must_use]
2381 pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
2382 if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
2383 return String::new();
2384 }
2385 source_text
2386 .chars()
2387 .skip(self.start)
2388 .take(self.end - self.start)
2389 .collect()
2390 }
2391
2392 pub fn set_valid_from(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2408 self.valid_from = Some(dt);
2409 }
2410
2411 pub fn set_valid_until(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2413 self.valid_until = Some(dt);
2414 }
2415
2416 pub fn set_temporal_range(
2418 &mut self,
2419 from: chrono::DateTime<chrono::Utc>,
2420 until: chrono::DateTime<chrono::Utc>,
2421 ) {
2422 self.valid_from = Some(from);
2423 self.valid_until = Some(until);
2424 }
2425
2426 #[must_use]
2428 pub fn is_temporal(&self) -> bool {
2429 self.valid_from.is_some() || self.valid_until.is_some()
2430 }
2431
2432 #[must_use]
2454 pub fn valid_at(&self, timestamp: &chrono::DateTime<chrono::Utc>) -> bool {
2455 match (&self.valid_from, &self.valid_until) {
2456 (None, None) => true, (Some(from), None) => timestamp >= from, (None, Some(until)) => timestamp <= until, (Some(from), Some(until)) => timestamp >= from && timestamp <= until,
2460 }
2461 }
2462
2463 #[must_use]
2465 pub fn is_currently_valid(&self) -> bool {
2466 self.valid_at(&chrono::Utc::now())
2467 }
2468
2469 pub fn set_viewport(&mut self, viewport: EntityViewport) {
2484 self.viewport = Some(viewport);
2485 }
2486
2487 #[must_use]
2489 pub fn has_viewport(&self) -> bool {
2490 self.viewport.is_some()
2491 }
2492
2493 #[must_use]
2495 pub fn viewport_or_default(&self) -> EntityViewport {
2496 self.viewport.clone().unwrap_or_default()
2497 }
2498
2499 #[must_use]
2505 pub fn matches_viewport(&self, query_viewport: &EntityViewport) -> bool {
2506 match &self.viewport {
2507 None => true, Some(v) => v == query_viewport,
2509 }
2510 }
2511
2512 #[must_use]
2514 pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
2515 EntityBuilder::new(text, entity_type)
2516 }
2517
2518 #[must_use]
2551 pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
2552 let char_count = source_text.chars().count();
2554 self.validate_with_len(source_text, char_count)
2555 }
2556
2557 #[must_use]
2569 pub fn validate_with_len(
2570 &self,
2571 source_text: &str,
2572 text_char_count: usize,
2573 ) -> Vec<ValidationIssue> {
2574 let mut issues = Vec::new();
2575
2576 if self.start >= self.end {
2578 issues.push(ValidationIssue::InvalidSpan {
2579 start: self.start,
2580 end: self.end,
2581 reason: "start must be less than end".to_string(),
2582 });
2583 }
2584
2585 if self.end > text_char_count {
2586 issues.push(ValidationIssue::SpanOutOfBounds {
2587 end: self.end,
2588 text_len: text_char_count,
2589 });
2590 }
2591
2592 if self.start < self.end && self.end <= text_char_count {
2594 let actual = self.extract_text_with_len(source_text, text_char_count);
2595 if actual != self.text {
2596 issues.push(ValidationIssue::TextMismatch {
2597 expected: self.text.clone(),
2598 actual,
2599 start: self.start,
2600 end: self.end,
2601 });
2602 }
2603 }
2604
2605 if let EntityType::Custom { ref name, .. } = self.entity_type {
2609 if name.is_empty() {
2610 issues.push(ValidationIssue::InvalidType {
2611 reason: "Custom entity type has empty name".to_string(),
2612 });
2613 }
2614 }
2615
2616 if let Some(ref disc_span) = self.discontinuous_span {
2618 for (i, seg) in disc_span.segments().iter().enumerate() {
2619 if seg.start >= seg.end {
2620 issues.push(ValidationIssue::InvalidSpan {
2621 start: seg.start,
2622 end: seg.end,
2623 reason: format!("discontinuous segment {} is invalid", i),
2624 });
2625 }
2626 if seg.end > text_char_count {
2627 issues.push(ValidationIssue::SpanOutOfBounds {
2628 end: seg.end,
2629 text_len: text_char_count,
2630 });
2631 }
2632 }
2633 }
2634
2635 issues
2636 }
2637
2638 #[must_use]
2642 pub fn is_valid(&self, source_text: &str) -> bool {
2643 self.validate(source_text).is_empty()
2644 }
2645
2646 #[must_use]
2666 pub fn validate_batch(
2667 entities: &[Entity],
2668 source_text: &str,
2669 ) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
2670 entities
2671 .iter()
2672 .enumerate()
2673 .filter_map(|(idx, entity)| {
2674 let issues = entity.validate(source_text);
2675 if issues.is_empty() {
2676 None
2677 } else {
2678 Some((idx, issues))
2679 }
2680 })
2681 .collect()
2682 }
2683}
2684
/// A problem found while validating an entity against its source text.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// The span `[start, end)` is malformed.
    InvalidSpan {
        /// Start offset of the offending span.
        start: usize,
        /// End offset of the offending span.
        end: usize,
        /// Human-readable explanation.
        reason: String,
    },
    /// A span ends past the end of the text.
    SpanOutOfBounds {
        /// End offset of the offending span.
        end: usize,
        /// Length of the text the span was checked against.
        text_len: usize,
    },
    /// The text found at the span differs from the entity's stored text.
    TextMismatch {
        /// Text stored on the entity.
        expected: String,
        /// Text actually found at the span.
        actual: String,
        /// Start offset of the span.
        start: usize,
        /// End offset of the span.
        end: usize,
    },
    /// A confidence value lies outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// The offending value.
        value: f64,
    },
    /// The entity type is not acceptable.
    InvalidType {
        /// Human-readable explanation.
        reason: String,
    },
}

impl std::fmt::Display for ValidationIssue {
    /// Writes a one-line, human-readable description of the issue.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSpan { start, end, reason } => {
                write!(f, "Invalid span [{start}, {end}): {reason}")
            }
            Self::SpanOutOfBounds { end, text_len } => {
                write!(f, "Span end {end} exceeds text length {text_len}")
            }
            Self::TextMismatch {
                expected,
                actual,
                start,
                end,
            } => write!(
                f,
                "Text mismatch at [{start}, {end}): expected '{expected}', got '{actual}'"
            ),
            Self::InvalidConfidence { value } => {
                write!(f, "Confidence {value} outside [0.0, 1.0]")
            }
            Self::InvalidType { reason } => write!(f, "Invalid entity type: {reason}"),
        }
    }
}
2757
/// Step-by-step builder for [`Entity`].
///
/// Holds one private field per `Entity` field; `build` copies them across
/// one-to-one.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
    // Required values, supplied to `new`.
    text: String,
    entity_type: EntityType,
    // Span; `new` defaults it to `[0, number of chars in text)`.
    start: usize,
    end: usize,
    // `new` defaults this to `Confidence::ONE`.
    confidence: Confidence,
    // Optional metadata; every one of these starts as `None`.
    normalized: Option<String>,
    provenance: Option<Provenance>,
    kb_id: Option<String>,
    canonical_id: Option<super::types::CanonicalId>,
    hierarchical_confidence: Option<HierarchicalConfidence>,
    visual_span: Option<Span>,
    discontinuous_span: Option<DiscontinuousSpan>,
    valid_from: Option<chrono::DateTime<chrono::Utc>>,
    valid_until: Option<chrono::DateTime<chrono::Utc>>,
    viewport: Option<EntityViewport>,
    phi_features: Option<PhiFeatures>,
    mention_type: Option<MentionType>,
}
2792
2793impl EntityBuilder {
2794 #[must_use]
2796 pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
2797 let text = text.into();
2798 let end = text.chars().count();
2799 Self {
2800 text,
2801 entity_type,
2802 start: 0,
2803 end,
2804 confidence: Confidence::ONE,
2805 normalized: None,
2806 provenance: None,
2807 kb_id: None,
2808 canonical_id: None,
2809 hierarchical_confidence: None,
2810 visual_span: None,
2811 discontinuous_span: None,
2812 valid_from: None,
2813 valid_until: None,
2814 viewport: None,
2815 phi_features: None,
2816 mention_type: None,
2817 }
2818 }
2819
/// Sets the text span to `[start, end)`, overriding the default span that
/// `new` derives from the text length. The values are stored as given; no
/// ordering check is performed here.
#[must_use]
pub const fn span(mut self, start: usize, end: usize) -> Self {
    self.start = start;
    self.end = end;
    self
}
2827
2828 #[must_use]
2830 pub fn confidence(mut self, confidence: impl Into<Confidence>) -> Self {
2831 self.confidence = confidence.into();
2832 self
2833 }
2834
2835 #[must_use]
2837 pub fn hierarchical_confidence(mut self, confidence: HierarchicalConfidence) -> Self {
2838 self.confidence = Confidence::new(confidence.as_f64());
2839 self.hierarchical_confidence = Some(confidence);
2840 self
2841 }
2842
2843 #[must_use]
2845 pub fn normalized(mut self, normalized: impl Into<String>) -> Self {
2846 self.normalized = Some(normalized.into());
2847 self
2848 }
2849
2850 #[must_use]
2852 pub fn provenance(mut self, provenance: Provenance) -> Self {
2853 self.provenance = Some(provenance);
2854 self
2855 }
2856
2857 #[must_use]
2859 pub fn kb_id(mut self, kb_id: impl Into<String>) -> Self {
2860 self.kb_id = Some(kb_id.into());
2861 self
2862 }
2863
/// Sets the canonical id from a raw `u64`, wrapping it with
/// `CanonicalId::new`.
#[must_use]
pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
    self.canonical_id = Some(super::types::CanonicalId::new(canonical_id));
    self
}
2870
2871 #[must_use]
2873 pub fn visual_span(mut self, span: Span) -> Self {
2874 self.visual_span = Some(span);
2875 self
2876 }
2877
2878 #[must_use]
2882 pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
2883 if let Some(bounding) = span.bounding_range() {
2885 self.start = bounding.start;
2886 self.end = bounding.end;
2887 }
2888 self.discontinuous_span = Some(span);
2889 self
2890 }
2891
2892 #[must_use]
2906 pub fn valid_from(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2907 self.valid_from = Some(dt);
2908 self
2909 }
2910
2911 #[must_use]
2913 pub fn valid_until(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2914 self.valid_until = Some(dt);
2915 self
2916 }
2917
2918 #[must_use]
2920 pub fn temporal_range(
2921 mut self,
2922 from: chrono::DateTime<chrono::Utc>,
2923 until: chrono::DateTime<chrono::Utc>,
2924 ) -> Self {
2925 self.valid_from = Some(from);
2926 self.valid_until = Some(until);
2927 self
2928 }
2929
2930 #[must_use]
2943 pub fn viewport(mut self, viewport: EntityViewport) -> Self {
2944 self.viewport = Some(viewport);
2945 self
2946 }
2947
2948 #[must_use]
2950 pub fn phi_features(mut self, phi_features: PhiFeatures) -> Self {
2951 self.phi_features = Some(phi_features);
2952 self
2953 }
2954
2955 #[must_use]
2957 pub fn mention_type(mut self, mention_type: MentionType) -> Self {
2958 self.mention_type = Some(mention_type);
2959 self
2960 }
2961
2962 #[must_use]
2964 pub fn build(self) -> Entity {
2965 Entity {
2966 text: self.text,
2967 entity_type: self.entity_type,
2968 start: self.start,
2969 end: self.end,
2970 confidence: self.confidence,
2971 normalized: self.normalized,
2972 provenance: self.provenance,
2973 kb_id: self.kb_id,
2974 canonical_id: self.canonical_id,
2975 hierarchical_confidence: self.hierarchical_confidence,
2976 visual_span: self.visual_span,
2977 discontinuous_span: self.discontinuous_span,
2978 valid_from: self.valid_from,
2979 valid_until: self.valid_until,
2980 viewport: self.viewport,
2981 phi_features: self.phi_features,
2982 mention_type: self.mention_type,
2983 }
2984 }
2985}
2986
/// A typed, directed relation between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// First entity of the triple rendered by `as_triple`.
    pub head: Entity,
    /// Last entity of the triple rendered by `as_triple`.
    pub tail: Entity,
    /// Relation label (the tests use `"WORKED_AT"`).
    pub relation_type: String,
    /// Optional `(start, end)` pair set by `with_trigger` — presumably the
    /// offsets of the trigger phrase in the source text; confirm with callers.
    pub trigger_span: Option<(usize, usize)>,
    /// Confidence of the relation.
    pub confidence: Confidence,
}
3026
3027impl Relation {
3028 #[must_use]
3030 pub fn new(
3031 head: Entity,
3032 tail: Entity,
3033 relation_type: impl Into<String>,
3034 confidence: impl Into<Confidence>,
3035 ) -> Self {
3036 Self {
3037 head,
3038 tail,
3039 relation_type: relation_type.into(),
3040 trigger_span: None,
3041 confidence: confidence.into(),
3042 }
3043 }
3044
3045 #[must_use]
3047 pub fn with_trigger(
3048 head: Entity,
3049 tail: Entity,
3050 relation_type: impl Into<String>,
3051 trigger_start: usize,
3052 trigger_end: usize,
3053 confidence: impl Into<Confidence>,
3054 ) -> Self {
3055 Self {
3056 head,
3057 tail,
3058 relation_type: relation_type.into(),
3059 trigger_span: Some((trigger_start, trigger_end)),
3060 confidence: confidence.into(),
3061 }
3062 }
3063
3064 #[must_use]
3066 pub fn as_triple(&self) -> String {
3067 format!(
3068 "({}, {}, {})",
3069 self.head.text, self.relation_type, self.tail.text
3070 )
3071 }
3072
3073 #[must_use]
3076 pub fn span_distance(&self) -> usize {
3077 if self.head.end <= self.tail.start {
3078 self.tail.start.saturating_sub(self.head.end)
3079 } else if self.tail.end <= self.head.start {
3080 self.head.start.saturating_sub(self.tail.end)
3081 } else {
3082 0 }
3084 }
3085}
3086
3087#[cfg(test)]
3088mod tests {
3089 #![allow(clippy::unwrap_used)] use super::*;
3091
3092 #[test]
3093 fn test_entity_type_roundtrip() {
3094 let types = [
3095 EntityType::Person,
3096 EntityType::Organization,
3097 EntityType::Location,
3098 EntityType::Date,
3099 EntityType::Money,
3100 EntityType::Percent,
3101 ];
3102
3103 for t in types {
3104 let label = t.as_label();
3105 let parsed = EntityType::from_label(label);
3106 assert_eq!(t, parsed);
3107 }
3108 }
3109
3110 #[test]
3111 fn test_entity_overlap() {
3112 let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
3113 let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
3114 let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);
3115
3116 assert!(!e1.overlaps(&e2)); assert!(e1.overlaps(&e3)); assert!(e3.overlaps(&e2)); }
3120
3121 #[test]
3122 fn test_confidence_clamping() {
3123 let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
3124 assert!((e1.confidence - 1.0).abs() < f64::EPSILON);
3125
3126 let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
3127 assert!(e2.confidence.abs() < f64::EPSILON);
3128 }
3129
3130 #[test]
3131 fn test_entity_categories() {
3132 assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
3134 assert_eq!(
3135 EntityType::Organization.category(),
3136 EntityCategory::Organization
3137 );
3138 assert_eq!(EntityType::Location.category(), EntityCategory::Place);
3139 assert!(EntityType::Person.requires_ml());
3140 assert!(!EntityType::Person.pattern_detectable());
3141
3142 assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
3144 assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
3145 assert!(EntityType::Date.pattern_detectable());
3146 assert!(!EntityType::Date.requires_ml());
3147
3148 assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
3150 assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
3151 assert!(EntityType::Money.pattern_detectable());
3152
3153 assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
3155 assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
3156 assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
3157 assert!(EntityType::Email.pattern_detectable());
3158 }
3159
3160 #[test]
3161 fn test_new_types_roundtrip() {
3162 let types = [
3163 EntityType::Time,
3164 EntityType::Email,
3165 EntityType::Url,
3166 EntityType::Phone,
3167 EntityType::Quantity,
3168 EntityType::Cardinal,
3169 EntityType::Ordinal,
3170 ];
3171
3172 for t in types {
3173 let label = t.as_label();
3174 let parsed = EntityType::from_label(label);
3175 assert_eq!(t, parsed, "Roundtrip failed for {}", label);
3176 }
3177 }
3178
3179 #[test]
3180 fn test_custom_entity_type() {
3181 let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
3182 assert_eq!(disease.as_label(), "DISEASE");
3183 assert!(disease.requires_ml());
3184
3185 let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
3186 assert_eq!(product_id.as_label(), "PRODUCT_ID");
3187 assert!(!product_id.requires_ml());
3188 assert!(!product_id.pattern_detectable());
3189 }
3190
3191 #[test]
3192 fn test_entity_normalization() {
3193 let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
3194 assert!(e.normalized.is_none());
3195 assert_eq!(e.normalized_or_text(), "Jan 15");
3196
3197 e.set_normalized("2024-01-15");
3198 assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
3199 assert_eq!(e.normalized_or_text(), "2024-01-15");
3200 }
3201
3202 #[test]
3203 fn test_entity_helpers() {
3204 let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
3205 assert!(named.is_named());
3206 assert!(!named.is_structured());
3207 assert_eq!(named.category(), EntityCategory::Agent);
3208
3209 let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
3210 assert!(!structured.is_named());
3211 assert!(structured.is_structured());
3212 assert_eq!(structured.category(), EntityCategory::Numeric);
3213 }
3214
3215 #[test]
3216 fn test_knowledge_linking() {
3217 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3218 assert!(!entity.is_linked());
3219 assert!(!entity.has_coreference());
3220
3221 entity.link_to_kb("Q7186"); assert!(entity.is_linked());
3223 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3224
3225 entity.set_canonical(42);
3226 assert!(entity.has_coreference());
3227 assert_eq!(
3228 entity.canonical_id,
3229 Some(crate::core::types::CanonicalId::new(42))
3230 );
3231 }
3232
3233 #[test]
3234 fn test_relation_creation() {
3235 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3236 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3237
3238 let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
3239 assert_eq!(relation.relation_type, "WORKED_AT");
3240 assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
3241 assert!(relation.trigger_span.is_none());
3242
3243 let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
3245 assert_eq!(relation2.trigger_span, Some((13, 19)));
3246 }
3247
3248 #[test]
3249 fn test_relation_span_distance() {
3250 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3252 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3253 let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
3254 assert_eq!(relation.span_distance(), 13);
3255 }
3256
3257 #[test]
3258 fn test_relation_category() {
3259 let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
3261 assert_eq!(rel_type.category(), EntityCategory::Relation);
3262 assert!(rel_type.category().is_relation());
3263 assert!(rel_type.requires_ml()); }
3265
3266 #[test]
3271 fn test_span_text() {
3272 let span = Span::text(10, 20);
3273 assert!(span.is_text());
3274 assert!(!span.is_visual());
3275 assert_eq!(span.text_offsets(), Some((10, 20)));
3276 assert_eq!(span.len(), 10);
3277 assert!(!span.is_empty());
3278 }
3279
3280 #[test]
3281 fn test_span_bbox() {
3282 let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
3283 assert!(!span.is_text());
3284 assert!(span.is_visual());
3285 assert_eq!(span.text_offsets(), None);
3286 assert_eq!(span.len(), 0); }
3288
3289 #[test]
3290 fn test_span_bbox_with_page() {
3291 let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
3292 if let Span::BoundingBox { page, .. } = span {
3293 assert_eq!(page, Some(5));
3294 } else {
3295 panic!("Expected BoundingBox");
3296 }
3297 }
3298
3299 #[test]
3300 fn test_span_hybrid() {
3301 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3302 let hybrid = Span::Hybrid {
3303 start: 10,
3304 end: 20,
3305 bbox: Box::new(bbox),
3306 };
3307 assert!(hybrid.is_text());
3308 assert!(hybrid.is_visual());
3309 assert_eq!(hybrid.text_offsets(), Some((10, 20)));
3310 assert_eq!(hybrid.len(), 10);
3311 }
3312
3313 #[test]
3318 fn test_hierarchical_confidence_new() {
3319 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3320 assert!((hc.linkage - 0.9).abs() < f64::EPSILON);
3321 assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
3322 assert!((hc.boundary - 0.7).abs() < f64::EPSILON);
3323 }
3324
3325 #[test]
3326 fn test_hierarchical_confidence_clamping() {
3327 let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
3328 assert_eq!(hc.linkage, 1.0);
3329 assert_eq!(hc.type_score, 0.0);
3330 assert_eq!(hc.boundary, 0.5);
3331 }
3332
3333 #[test]
3334 fn test_hierarchical_confidence_from_single() {
3335 let hc = HierarchicalConfidence::from_single(0.8);
3336 assert!((hc.linkage - 0.8).abs() < f64::EPSILON);
3337 assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
3338 assert!((hc.boundary - 0.8).abs() < f64::EPSILON);
3339 }
3340
3341 #[test]
3342 fn test_hierarchical_confidence_combined() {
3343 let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
3344 assert!((hc.combined() - 1.0).abs() < f64::EPSILON);
3345
3346 let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
3347 assert!((hc2.combined() - 0.8).abs() < 0.001);
3348
3349 let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
3351 assert!((hc3.combined() - 0.5).abs() < 0.001);
3352 }
3353
3354 #[test]
3355 fn test_hierarchical_confidence_threshold() {
3356 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3357 assert!(hc.passes_threshold(0.5, 0.5, 0.5));
3358 assert!(hc.passes_threshold(0.9, 0.8, 0.7));
3359 assert!(!hc.passes_threshold(0.95, 0.8, 0.7)); assert!(!hc.passes_threshold(0.9, 0.85, 0.7)); }
3362
3363 #[test]
3364 fn test_hierarchical_confidence_from_f64() {
3365 let hc: HierarchicalConfidence = 0.85_f64.into();
3366 assert!((hc.linkage - 0.85).abs() < 0.001);
3367 }
3368
3369 #[test]
3374 fn test_ragged_batch_from_sequences() {
3375 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3376 let batch = RaggedBatch::from_sequences(&seqs);
3377
3378 assert_eq!(batch.batch_size(), 3);
3379 assert_eq!(batch.total_tokens(), 9);
3380 assert_eq!(batch.max_seq_len, 4);
3381 assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
3382 }
3383
3384 #[test]
3385 fn test_ragged_batch_doc_range() {
3386 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3387 let batch = RaggedBatch::from_sequences(&seqs);
3388
3389 assert_eq!(batch.doc_range(0), Some(0..3));
3390 assert_eq!(batch.doc_range(1), Some(3..5));
3391 assert_eq!(batch.doc_range(2), None);
3392 }
3393
3394 #[test]
3395 fn test_ragged_batch_doc_tokens() {
3396 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3397 let batch = RaggedBatch::from_sequences(&seqs);
3398
3399 assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
3400 assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
3401 }
3402
3403 #[test]
3404 fn test_ragged_batch_padding_savings() {
3405 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3409 let batch = RaggedBatch::from_sequences(&seqs);
3410 let savings = batch.padding_savings();
3411 assert!((savings - 0.25).abs() < 0.001);
3412 }
3413
3414 #[test]
3419 fn test_span_candidate() {
3420 let sc = SpanCandidate::new(0, 5, 10);
3421 assert_eq!(sc.doc_idx, 0);
3422 assert_eq!(sc.start, 5);
3423 assert_eq!(sc.end, 10);
3424 assert_eq!(sc.width(), 5);
3425 }
3426
3427 #[test]
3428 fn test_generate_span_candidates() {
3429 let seqs = vec![vec![1, 2, 3]]; let batch = RaggedBatch::from_sequences(&seqs);
3431 let candidates = generate_span_candidates(&batch, 2);
3432
3433 assert_eq!(candidates.len(), 5);
3436
3437 for c in &candidates {
3439 assert_eq!(c.doc_idx, 0);
3440 assert!(c.end as usize <= 3);
3441 assert!(c.width() as usize <= 2);
3442 }
3443 }
3444
3445 #[test]
3446 fn test_generate_filtered_candidates() {
3447 let seqs = vec![vec![1, 2, 3]];
3448 let batch = RaggedBatch::from_sequences(&seqs);
3449
3450 let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
3453 let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);
3454
3455 assert_eq!(candidates.len(), 2);
3456 }
3457
3458 #[test]
3463 fn test_entity_builder_basic() {
3464 let entity = Entity::builder("John", EntityType::Person)
3465 .span(0, 4)
3466 .confidence(0.95)
3467 .build();
3468
3469 assert_eq!(entity.text, "John");
3470 assert_eq!(entity.entity_type, EntityType::Person);
3471 assert_eq!(entity.start, 0);
3472 assert_eq!(entity.end, 4);
3473 assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
3474 }
3475
3476 #[test]
3477 fn test_entity_builder_full() {
3478 let entity = Entity::builder("Marie Curie", EntityType::Person)
3479 .span(0, 11)
3480 .confidence(0.95)
3481 .kb_id("Q7186")
3482 .canonical_id(42)
3483 .normalized("Marie Salomea Skłodowska Curie")
3484 .provenance(Provenance::ml("bert", 0.95))
3485 .build();
3486
3487 assert_eq!(entity.text, "Marie Curie");
3488 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3489 assert_eq!(
3490 entity.canonical_id,
3491 Some(crate::core::types::CanonicalId::new(42))
3492 );
3493 assert_eq!(
3494 entity.normalized.as_deref(),
3495 Some("Marie Salomea Skłodowska Curie")
3496 );
3497 assert!(entity.provenance.is_some());
3498 }
3499
3500 #[test]
3501 fn test_entity_builder_hierarchical() {
3502 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3503 let entity = Entity::builder("test", EntityType::Person)
3504 .span(0, 4)
3505 .hierarchical_confidence(hc)
3506 .build();
3507
3508 assert!(entity.hierarchical_confidence.is_some());
3509 assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
3510 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3511 assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
3512 }
3513
3514 #[test]
3515 fn test_entity_builder_visual() {
3516 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3517 let entity = Entity::builder("receipt item", EntityType::Money)
3518 .visual_span(bbox)
3519 .confidence(0.9)
3520 .build();
3521
3522 assert!(entity.is_visual());
3523 assert!(entity.visual_span.is_some());
3524 }
3525
3526 #[test]
3531 fn test_entity_hierarchical_confidence_helpers() {
3532 let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);
3533
3534 assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
3536 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3537 assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);
3538
3539 entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
3541 assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
3542 assert!((entity.type_confidence() - 0.85).abs() < 0.001);
3543 assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
3544 }
3545
3546 #[test]
3547 fn test_entity_from_visual() {
3548 let entity = Entity::from_visual(
3549 "receipt total",
3550 EntityType::Money,
3551 Span::bbox(0.5, 0.8, 0.2, 0.05),
3552 0.92,
3553 );
3554
3555 assert!(entity.is_visual());
3556 assert_eq!(entity.start, 0);
3557 assert_eq!(entity.end, 0);
3558 assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
3559 }
3560
3561 #[test]
3562 fn test_entity_span_helpers() {
3563 let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
3564 assert_eq!(entity.text_span(), (10, 20));
3565 assert_eq!(entity.span_len(), 10);
3566 }
3567
3568 #[test]
3573 fn test_provenance_pattern() {
3574 let prov = Provenance::pattern("EMAIL");
3575 assert_eq!(prov.method, ExtractionMethod::Pattern);
3576 assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
3577 assert_eq!(prov.raw_confidence, Some(Confidence::new(1.0))); }
3579
3580 #[test]
3581 fn test_provenance_ml() {
3582 let prov = Provenance::ml("bert-ner", 0.87);
3583 assert_eq!(prov.method, ExtractionMethod::Neural);
3584 assert_eq!(prov.source.as_ref(), "bert-ner");
3585 assert_eq!(prov.raw_confidence, Some(Confidence::new(0.87)));
3586 }
3587
3588 #[test]
3589 fn test_provenance_with_version() {
3590 let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");
3591
3592 assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
3593 assert_eq!(prov.source.as_ref(), "gliner");
3594 }
3595
3596 #[test]
3597 fn test_provenance_with_timestamp() {
3598 let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");
3599
3600 assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
3601 }
3602
3603 #[test]
3604 fn test_provenance_builder_chain() {
3605 let prov = Provenance::ml("modernbert-ner", 0.95)
3606 .with_version("v1.0.0")
3607 .with_timestamp("2024-11-27T12:00:00Z");
3608
3609 assert_eq!(prov.method, ExtractionMethod::Neural);
3610 assert_eq!(prov.source.as_ref(), "modernbert-ner");
3611 assert_eq!(prov.raw_confidence, Some(Confidence::new(0.95)));
3612 assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
3613 assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
3614 }
3615
3616 #[test]
3617 fn test_provenance_serialization() {
3618 let prov = Provenance::ml("test", 0.9)
3619 .with_version("v1.0")
3620 .with_timestamp("2024-01-01");
3621
3622 let json = serde_json::to_string(&prov).unwrap();
3623 assert!(json.contains("model_version"));
3624 assert!(json.contains("v1.0"));
3625
3626 let restored: Provenance = serde_json::from_str(&json).unwrap();
3627 assert_eq!(restored.model_version.as_deref(), Some("v1.0"));
3628 assert_eq!(restored.timestamp.as_deref(), Some("2024-01-01"));
3629 }
3630}
3631
3632#[cfg(test)]
3633mod proptests {
3634 #![allow(clippy::unwrap_used)] use super::*;
3636 use proptest::prelude::*;
3637
3638 proptest! {
3639 #[test]
3640 fn confidence_always_clamped(conf in -10.0f64..10.0) {
3641 let e = Entity::new("test", EntityType::Person, 0, 4, conf);
3642 prop_assert!(e.confidence >= 0.0);
3643 prop_assert!(e.confidence <= 1.0);
3644 }
3645
3646 #[test]
3647 fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
3648 let et = EntityType::from_label(&label);
3649 let back = EntityType::from_label(et.as_label());
3650 let is_custom = matches!(back, EntityType::Custom { .. });
3652 prop_assert!(is_custom || back == et);
3653 }
3654
3655 #[test]
3656 fn overlap_is_symmetric(
3657 s1 in 0usize..100,
3658 len1 in 1usize..50,
3659 s2 in 0usize..100,
3660 len2 in 1usize..50,
3661 ) {
3662 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3663 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3664 prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
3665 }
3666
3667 #[test]
3668 fn overlap_ratio_bounded(
3669 s1 in 0usize..100,
3670 len1 in 1usize..50,
3671 s2 in 0usize..100,
3672 len2 in 1usize..50,
3673 ) {
3674 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3675 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3676 let ratio = e1.overlap_ratio(&e2);
3677 prop_assert!(ratio >= 0.0);
3678 prop_assert!(ratio <= 1.0);
3679 }
3680
3681 #[test]
3682 fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
3683 let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
3684 let ratio = e.overlap_ratio(&e);
3685 prop_assert!((ratio - 1.0).abs() < 1e-10);
3686 }
3687
3688 #[test]
3689 fn hierarchical_confidence_always_clamped(
3690 linkage in -2.0f32..2.0,
3691 type_score in -2.0f32..2.0,
3692 boundary in -2.0f32..2.0,
3693 ) {
3694 let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
3695 prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
3696 prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
3697 prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
3698 prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
3699 }
3700
3701 #[test]
3702 fn span_candidate_width_consistent(
3703 doc in 0u32..10,
3704 start in 0u32..100,
3705 end in 1u32..100,
3706 ) {
3707 let actual_end = start.max(end);
3708 let sc = SpanCandidate::new(doc, start, actual_end);
3709 prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
3710 }
3711
3712 #[test]
3713 fn ragged_batch_preserves_tokens(
3714 seq_lens in proptest::collection::vec(1usize..10, 1..5),
3715 ) {
3716 let mut counter = 0u32;
3718 let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
3719 let seq: Vec<u32> = (counter..counter + len as u32).collect();
3720 counter += len as u32;
3721 seq
3722 }).collect();
3723
3724 let batch = RaggedBatch::from_sequences(&seqs);
3725
3726 prop_assert_eq!(batch.batch_size(), seqs.len());
3728 prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
3729
3730 for (i, seq) in seqs.iter().enumerate() {
3732 let doc_tokens = batch.doc_tokens(i).unwrap();
3733 prop_assert_eq!(doc_tokens, seq.as_slice());
3734 }
3735 }
3736
3737 #[test]
3738 fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
3739 let end = start + len;
3740 let span = Span::text(start, end);
3741 let (s, e) = span.text_offsets().unwrap();
3742 prop_assert_eq!(s, start);
3743 prop_assert_eq!(e, end);
3744 prop_assert_eq!(span.len(), len);
3745 }
3746
3747 #[test]
3753 fn entity_span_validity(
3754 start in 0usize..10000,
3755 len in 1usize..500,
3756 conf in 0.0f64..=1.0,
3757 ) {
3758 let end = start + len;
3759 let text_content: String = "x".repeat(end);
3761 let entity_text: String = text_content.chars().skip(start).take(len).collect();
3762 let e = Entity::new(&entity_text, EntityType::Person, start, end, conf);
3763 let issues = e.validate(&text_content);
3764 for issue in &issues {
3766 match issue {
3767 ValidationIssue::InvalidSpan { .. } => {
3768 prop_assert!(false, "start < end should never produce InvalidSpan");
3769 }
3770 ValidationIssue::SpanOutOfBounds { .. } => {
3771 prop_assert!(false, "span within text should never produce SpanOutOfBounds");
3772 }
3773 _ => {} }
3775 }
3776 }
3777
3778 #[test]
3780 fn entity_type_label_roundtrip_standard(
3781 idx in 0usize..13,
3782 ) {
3783 let standard_types = [
3784 EntityType::Person,
3785 EntityType::Organization,
3786 EntityType::Location,
3787 EntityType::Date,
3788 EntityType::Time,
3789 EntityType::Money,
3790 EntityType::Percent,
3791 EntityType::Quantity,
3792 EntityType::Cardinal,
3793 EntityType::Ordinal,
3794 EntityType::Email,
3795 EntityType::Url,
3796 EntityType::Phone,
3797 ];
3798 let et = &standard_types[idx];
3799 let label = et.as_label();
3800 let roundtripped = EntityType::from_label(label);
3801 prop_assert_eq!(&roundtripped, et,
3802 "from_label(as_label()) must roundtrip for {:?} (label={:?})", et, label);
3803 }
3804
3805 #[test]
3807 fn span_containment_property(
3808 a_start in 0usize..5000,
3809 a_len in 1usize..5000,
3810 b_offset in 0usize..5000,
3811 b_len in 1usize..5000,
3812 ) {
3813 let a_end = a_start + a_len;
3814 let b_start = a_start + (b_offset % a_len); let b_end_candidate = b_start + b_len;
3816
3817 if b_start >= a_start && b_end_candidate <= a_end {
3819 prop_assert!(a_start <= b_start);
3821 prop_assert!(a_end >= b_end_candidate);
3822
3823 let ea = Entity::new("a", EntityType::Person, a_start, a_end, 1.0);
3825 let eb = Entity::new("b", EntityType::Person, b_start, b_end_candidate, 1.0);
3826 prop_assert!(ea.overlaps(&eb),
3827 "containing span must overlap contained span");
3828 }
3829 }
3830
3831 #[test]
3833 fn entity_serde_roundtrip(
3834 start in 0usize..10000,
3835 len in 1usize..500,
3836 conf in 0.0f64..=1.0,
3837 type_idx in 0usize..5,
3838 ) {
3839 let end = start + len;
3840 let types = [
3841 EntityType::Person,
3842 EntityType::Organization,
3843 EntityType::Location,
3844 EntityType::Date,
3845 EntityType::Email,
3846 ];
3847 let et = types[type_idx].clone();
3848 let text = format!("entity_{}", start);
3849 let e = Entity::new(&text, et, start, end, conf);
3850
3851 let json = serde_json::to_string(&e).unwrap();
3852 let e2: Entity = serde_json::from_str(&json).unwrap();
3853
3854 prop_assert_eq!(&e.text, &e2.text);
3855 prop_assert_eq!(&e.entity_type, &e2.entity_type);
3856 prop_assert_eq!(e.start, e2.start);
3857 prop_assert_eq!(e.end, e2.end);
3858 prop_assert!((e.confidence - e2.confidence).abs() < 1e-10,
3860 "confidence roundtrip: {} vs {}", e.confidence, e2.confidence);
3861 prop_assert_eq!(&e.normalized, &e2.normalized);
3862 prop_assert_eq!(&e.kb_id, &e2.kb_id);
3863 }
3864
3865 #[test]
3867 fn discontinuous_span_total_length(
3868 segments in proptest::collection::vec(
3869 (0usize..5000, 1usize..500),
3870 1..6
3871 ),
3872 ) {
3873 let ranges: Vec<std::ops::Range<usize>> = segments.iter()
3874 .map(|&(start, len)| start..start + len)
3875 .collect();
3876 let expected_sum: usize = ranges.iter().map(|r| r.end - r.start).sum();
3877 let span = DiscontinuousSpan::new(ranges);
3878 prop_assert_eq!(span.total_len(), expected_sum,
3879 "total_len must equal sum of segment lengths");
3880 }
3881 }
3882
3883 #[test]
3888 fn test_entity_viewport_as_str() {
3889 assert_eq!(EntityViewport::Business.as_str(), "business");
3890 assert_eq!(EntityViewport::Legal.as_str(), "legal");
3891 assert_eq!(EntityViewport::Technical.as_str(), "technical");
3892 assert_eq!(EntityViewport::Academic.as_str(), "academic");
3893 assert_eq!(EntityViewport::Personal.as_str(), "personal");
3894 assert_eq!(EntityViewport::Political.as_str(), "political");
3895 assert_eq!(EntityViewport::Media.as_str(), "media");
3896 assert_eq!(EntityViewport::Historical.as_str(), "historical");
3897 assert_eq!(EntityViewport::General.as_str(), "general");
3898 assert_eq!(
3899 EntityViewport::Custom("custom".to_string()).as_str(),
3900 "custom"
3901 );
3902 }
3903
3904 #[test]
3905 fn test_entity_viewport_is_professional() {
3906 assert!(EntityViewport::Business.is_professional());
3907 assert!(EntityViewport::Legal.is_professional());
3908 assert!(EntityViewport::Technical.is_professional());
3909 assert!(EntityViewport::Academic.is_professional());
3910 assert!(EntityViewport::Political.is_professional());
3911
3912 assert!(!EntityViewport::Personal.is_professional());
3913 assert!(!EntityViewport::Media.is_professional());
3914 assert!(!EntityViewport::Historical.is_professional());
3915 assert!(!EntityViewport::General.is_professional());
3916 assert!(!EntityViewport::Custom("test".to_string()).is_professional());
3917 }
3918
3919 #[test]
3920 fn test_entity_viewport_from_str() {
3921 assert_eq!(
3922 "business".parse::<EntityViewport>().unwrap(),
3923 EntityViewport::Business
3924 );
3925 assert_eq!(
3926 "financial".parse::<EntityViewport>().unwrap(),
3927 EntityViewport::Business
3928 );
3929 assert_eq!(
3930 "corporate".parse::<EntityViewport>().unwrap(),
3931 EntityViewport::Business
3932 );
3933
3934 assert_eq!(
3935 "legal".parse::<EntityViewport>().unwrap(),
3936 EntityViewport::Legal
3937 );
3938 assert_eq!(
3939 "law".parse::<EntityViewport>().unwrap(),
3940 EntityViewport::Legal
3941 );
3942
3943 assert_eq!(
3944 "technical".parse::<EntityViewport>().unwrap(),
3945 EntityViewport::Technical
3946 );
3947 assert_eq!(
3948 "engineering".parse::<EntityViewport>().unwrap(),
3949 EntityViewport::Technical
3950 );
3951
3952 assert_eq!(
3953 "academic".parse::<EntityViewport>().unwrap(),
3954 EntityViewport::Academic
3955 );
3956 assert_eq!(
3957 "research".parse::<EntityViewport>().unwrap(),
3958 EntityViewport::Academic
3959 );
3960
3961 assert_eq!(
3962 "personal".parse::<EntityViewport>().unwrap(),
3963 EntityViewport::Personal
3964 );
3965 assert_eq!(
3966 "biographical".parse::<EntityViewport>().unwrap(),
3967 EntityViewport::Personal
3968 );
3969
3970 assert_eq!(
3971 "political".parse::<EntityViewport>().unwrap(),
3972 EntityViewport::Political
3973 );
3974 assert_eq!(
3975 "policy".parse::<EntityViewport>().unwrap(),
3976 EntityViewport::Political
3977 );
3978
3979 assert_eq!(
3980 "media".parse::<EntityViewport>().unwrap(),
3981 EntityViewport::Media
3982 );
3983 assert_eq!(
3984 "press".parse::<EntityViewport>().unwrap(),
3985 EntityViewport::Media
3986 );
3987
3988 assert_eq!(
3989 "historical".parse::<EntityViewport>().unwrap(),
3990 EntityViewport::Historical
3991 );
3992 assert_eq!(
3993 "history".parse::<EntityViewport>().unwrap(),
3994 EntityViewport::Historical
3995 );
3996
3997 assert_eq!(
3998 "general".parse::<EntityViewport>().unwrap(),
3999 EntityViewport::General
4000 );
4001 assert_eq!(
4002 "generic".parse::<EntityViewport>().unwrap(),
4003 EntityViewport::General
4004 );
4005 assert_eq!(
4006 "".parse::<EntityViewport>().unwrap(),
4007 EntityViewport::General
4008 );
4009
4010 assert_eq!(
4012 "custom_viewport".parse::<EntityViewport>().unwrap(),
4013 EntityViewport::Custom("custom_viewport".to_string())
4014 );
4015 }
4016
4017 #[test]
4018 fn test_entity_viewport_from_str_case_insensitive() {
4019 assert_eq!(
4020 "BUSINESS".parse::<EntityViewport>().unwrap(),
4021 EntityViewport::Business
4022 );
4023 assert_eq!(
4024 "Business".parse::<EntityViewport>().unwrap(),
4025 EntityViewport::Business
4026 );
4027 assert_eq!(
4028 "BuSiNeSs".parse::<EntityViewport>().unwrap(),
4029 EntityViewport::Business
4030 );
4031 }
4032
4033 #[test]
4034 fn test_entity_viewport_display() {
4035 assert_eq!(format!("{}", EntityViewport::Business), "business");
4036 assert_eq!(format!("{}", EntityViewport::Academic), "academic");
4037 assert_eq!(
4038 format!("{}", EntityViewport::Custom("test".to_string())),
4039 "test"
4040 );
4041 }
4042
4043 #[test]
4044 fn test_entity_viewport_methods() {
4045 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.9);
4046
4047 assert!(!entity.has_viewport());
4049 assert_eq!(entity.viewport_or_default(), EntityViewport::General);
4050 assert!(entity.matches_viewport(&EntityViewport::Academic)); entity.set_viewport(EntityViewport::Academic);
4054 assert!(entity.has_viewport());
4055 assert_eq!(entity.viewport_or_default(), EntityViewport::Academic);
4056 assert!(entity.matches_viewport(&EntityViewport::Academic));
4057 assert!(!entity.matches_viewport(&EntityViewport::Business));
4058 }
4059
4060 #[test]
4061 fn test_entity_builder_with_viewport() {
4062 let entity = Entity::builder("Marie Curie", EntityType::Person)
4063 .span(0, 11)
4064 .viewport(EntityViewport::Academic)
4065 .build();
4066
4067 assert_eq!(entity.viewport, Some(EntityViewport::Academic));
4068 assert!(entity.has_viewport());
4069 }
4070
4071 #[test]
4076 fn test_entity_category_requires_ml() {
4077 assert!(EntityCategory::Agent.requires_ml());
4078 assert!(EntityCategory::Organization.requires_ml());
4079 assert!(EntityCategory::Place.requires_ml());
4080 assert!(EntityCategory::Creative.requires_ml());
4081 assert!(EntityCategory::Relation.requires_ml());
4082
4083 assert!(!EntityCategory::Temporal.requires_ml());
4084 assert!(!EntityCategory::Numeric.requires_ml());
4085 assert!(!EntityCategory::Contact.requires_ml());
4086 assert!(!EntityCategory::Misc.requires_ml());
4087 }
4088
4089 #[test]
4090 fn test_entity_category_pattern_detectable() {
4091 assert!(EntityCategory::Temporal.pattern_detectable());
4092 assert!(EntityCategory::Numeric.pattern_detectable());
4093 assert!(EntityCategory::Contact.pattern_detectable());
4094
4095 assert!(!EntityCategory::Agent.pattern_detectable());
4096 assert!(!EntityCategory::Organization.pattern_detectable());
4097 assert!(!EntityCategory::Place.pattern_detectable());
4098 assert!(!EntityCategory::Creative.pattern_detectable());
4099 assert!(!EntityCategory::Relation.pattern_detectable());
4100 assert!(!EntityCategory::Misc.pattern_detectable());
4101 }
4102
4103 #[test]
4104 fn test_entity_category_is_relation() {
4105 assert!(EntityCategory::Relation.is_relation());
4106
4107 assert!(!EntityCategory::Agent.is_relation());
4108 assert!(!EntityCategory::Organization.is_relation());
4109 assert!(!EntityCategory::Place.is_relation());
4110 assert!(!EntityCategory::Temporal.is_relation());
4111 assert!(!EntityCategory::Numeric.is_relation());
4112 assert!(!EntityCategory::Contact.is_relation());
4113 assert!(!EntityCategory::Creative.is_relation());
4114 assert!(!EntityCategory::Misc.is_relation());
4115 }
4116
4117 #[test]
4118 fn test_entity_category_as_str() {
4119 assert_eq!(EntityCategory::Agent.as_str(), "agent");
4120 assert_eq!(EntityCategory::Organization.as_str(), "organization");
4121 assert_eq!(EntityCategory::Place.as_str(), "place");
4122 assert_eq!(EntityCategory::Creative.as_str(), "creative");
4123 assert_eq!(EntityCategory::Temporal.as_str(), "temporal");
4124 assert_eq!(EntityCategory::Numeric.as_str(), "numeric");
4125 assert_eq!(EntityCategory::Contact.as_str(), "contact");
4126 assert_eq!(EntityCategory::Relation.as_str(), "relation");
4127 assert_eq!(EntityCategory::Misc.as_str(), "misc");
4128 }
4129
4130 #[test]
4131 fn test_entity_category_display() {
4132 assert_eq!(format!("{}", EntityCategory::Agent), "agent");
4133 assert_eq!(format!("{}", EntityCategory::Temporal), "temporal");
4134 assert_eq!(format!("{}", EntityCategory::Relation), "relation");
4135 }
4136
4137 #[test]
4142 fn test_entity_type_serializes_to_flat_string() {
4143 assert_eq!(
4144 serde_json::to_string(&EntityType::Person).unwrap(),
4145 r#""PER""#
4146 );
4147 assert_eq!(
4148 serde_json::to_string(&EntityType::Organization).unwrap(),
4149 r#""ORG""#
4150 );
4151 assert_eq!(
4152 serde_json::to_string(&EntityType::Location).unwrap(),
4153 r#""LOC""#
4154 );
4155 assert_eq!(
4156 serde_json::to_string(&EntityType::Date).unwrap(),
4157 r#""DATE""#
4158 );
4159 assert_eq!(
4160 serde_json::to_string(&EntityType::Money).unwrap(),
4161 r#""MONEY""#
4162 );
4163 }
4164
4165 #[test]
4166 fn test_custom_entity_type_serializes_flat() {
4167 let misc = EntityType::custom("MISC", EntityCategory::Misc);
4168 assert_eq!(serde_json::to_string(&misc).unwrap(), r#""MISC""#);
4169
4170 let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
4171 assert_eq!(serde_json::to_string(&disease).unwrap(), r#""DISEASE""#);
4172 }
4173
4174 #[test]
4175 fn test_entity_type_deserializes_from_flat_string() {
4176 let per: EntityType = serde_json::from_str(r#""PER""#).unwrap();
4177 assert_eq!(per, EntityType::Person);
4178
4179 let org: EntityType = serde_json::from_str(r#""ORG""#).unwrap();
4180 assert_eq!(org, EntityType::Organization);
4181
4182 let misc: EntityType = serde_json::from_str(r#""MISC""#).unwrap();
4183 assert_eq!(misc, EntityType::custom("MISC", EntityCategory::Misc));
4184 }
4185
4186 #[test]
4187 fn test_entity_type_deserializes_backward_compat_custom() {
4188 let json = r#"{"Custom":{"name":"MISC","category":"Misc"}}"#;
4190 let et: EntityType = serde_json::from_str(json).unwrap();
4191 assert_eq!(et, EntityType::custom("MISC", EntityCategory::Misc));
4192 }
4193
4194 #[test]
4195 fn test_entity_type_deserializes_backward_compat_other() {
4196 let json = r#"{"Other":"foo"}"#;
4198 let et: EntityType = serde_json::from_str(json).unwrap();
4199 assert_eq!(et, EntityType::custom("foo", EntityCategory::Misc));
4200 }
4201
4202 #[test]
4203 fn test_entity_type_serde_roundtrip() {
4204 let types = vec![
4205 EntityType::Person,
4206 EntityType::Organization,
4207 EntityType::Location,
4208 EntityType::Date,
4209 EntityType::Time,
4210 EntityType::Money,
4211 EntityType::Percent,
4212 EntityType::Quantity,
4213 EntityType::Cardinal,
4214 EntityType::Ordinal,
4215 EntityType::Email,
4216 EntityType::Url,
4217 EntityType::Phone,
4218 EntityType::custom("MISC", EntityCategory::Misc),
4219 EntityType::custom("DISEASE", EntityCategory::Agent),
4220 ];
4221
4222 for t in &types {
4223 let json = serde_json::to_string(t).unwrap();
4224 let back: EntityType = serde_json::from_str(&json).unwrap();
4225 assert_eq!(
4228 t.as_label(),
4229 back.as_label(),
4230 "roundtrip failed for {:?}",
4231 t
4232 );
4233 }
4234 }
4235}