1use super::confidence::Confidence;
39use super::types::{MentionType, PhiFeatures};
40use serde::{Deserialize, Serialize};
41use std::borrow::Cow;
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
53#[non_exhaustive]
54pub enum EntityCategory {
55 Agent,
58 Organization,
61 Place,
64 Creative,
67 Temporal,
70 Numeric,
73 Contact,
76 Relation,
80 Misc,
82}
83
84impl EntityCategory {
85 #[must_use]
87 pub const fn requires_ml(&self) -> bool {
88 matches!(
89 self,
90 EntityCategory::Agent
91 | EntityCategory::Organization
92 | EntityCategory::Place
93 | EntityCategory::Creative
94 | EntityCategory::Relation
95 )
96 }
97
98 #[must_use]
100 pub const fn pattern_detectable(&self) -> bool {
101 matches!(
102 self,
103 EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
104 )
105 }
106
107 #[must_use]
109 pub const fn is_relation(&self) -> bool {
110 matches!(self, EntityCategory::Relation)
111 }
112
113 #[must_use]
115 pub const fn as_str(&self) -> &'static str {
116 match self {
117 EntityCategory::Agent => "agent",
118 EntityCategory::Organization => "organization",
119 EntityCategory::Place => "place",
120 EntityCategory::Creative => "creative",
121 EntityCategory::Temporal => "temporal",
122 EntityCategory::Numeric => "numeric",
123 EntityCategory::Contact => "contact",
124 EntityCategory::Relation => "relation",
125 EntityCategory::Misc => "misc",
126 }
127 }
128}
129
130impl std::fmt::Display for EntityCategory {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 write!(f, "{}", self.as_str())
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
173#[non_exhaustive]
174pub enum EntityViewport {
175 Business,
177 Legal,
179 Technical,
181 Academic,
183 Personal,
185 Political,
187 Media,
189 Historical,
191 #[default]
193 General,
194 Custom(String),
196}
197
198impl EntityViewport {
199 #[must_use]
201 pub fn as_str(&self) -> &str {
202 match self {
203 EntityViewport::Business => "business",
204 EntityViewport::Legal => "legal",
205 EntityViewport::Technical => "technical",
206 EntityViewport::Academic => "academic",
207 EntityViewport::Personal => "personal",
208 EntityViewport::Political => "political",
209 EntityViewport::Media => "media",
210 EntityViewport::Historical => "historical",
211 EntityViewport::General => "general",
212 EntityViewport::Custom(s) => s,
213 }
214 }
215
216 #[must_use]
218 pub const fn is_professional(&self) -> bool {
219 matches!(
220 self,
221 EntityViewport::Business
222 | EntityViewport::Legal
223 | EntityViewport::Technical
224 | EntityViewport::Academic
225 | EntityViewport::Political
226 )
227 }
228}
229
230impl std::str::FromStr for EntityViewport {
231 type Err = std::convert::Infallible;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 Ok(match s.to_lowercase().as_str() {
235 "business" | "financial" | "corporate" => EntityViewport::Business,
236 "legal" | "law" | "compliance" => EntityViewport::Legal,
237 "technical" | "engineering" | "tech" => EntityViewport::Technical,
238 "academic" | "research" | "scholarly" => EntityViewport::Academic,
239 "personal" | "biographical" | "private" => EntityViewport::Personal,
240 "political" | "policy" | "government" => EntityViewport::Political,
241 "media" | "press" | "pr" | "public_relations" => EntityViewport::Media,
242 "historical" | "history" | "past" => EntityViewport::Historical,
243 "general" | "generic" | "" => EntityViewport::General,
244 other => EntityViewport::Custom(other.to_string()),
245 })
246 }
247}
248
249impl std::fmt::Display for EntityViewport {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(f, "{}", self.as_str())
252 }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq, Hash)]
280#[non_exhaustive]
281pub enum EntityType {
282 Person,
285 Organization,
287 Location,
289
290 Date,
293 Time,
295
296 Money,
299 Percent,
301 Quantity,
303 Cardinal,
305 Ordinal,
307
308 Email,
311 Url,
313 Phone,
315
316 Custom {
319 name: String,
321 category: EntityCategory,
323 },
324
325 #[deprecated(note = "use EntityType::custom(name, EntityCategory::Misc) instead")]
331 Other(String),
332}
333
334impl EntityType {
335 #[must_use]
337 pub fn category(&self) -> EntityCategory {
338 match self {
339 EntityType::Person => EntityCategory::Agent,
341 EntityType::Organization => EntityCategory::Organization,
343 EntityType::Location => EntityCategory::Place,
345 EntityType::Date | EntityType::Time => EntityCategory::Temporal,
347 EntityType::Money
349 | EntityType::Percent
350 | EntityType::Quantity
351 | EntityType::Cardinal
352 | EntityType::Ordinal => EntityCategory::Numeric,
353 EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
355 EntityType::Custom { category, .. } => *category,
357 #[allow(deprecated)]
359 EntityType::Other(_) => EntityCategory::Misc,
360 }
361 }
362
363 #[must_use]
365 pub fn requires_ml(&self) -> bool {
366 self.category().requires_ml()
367 }
368
369 #[must_use]
371 pub fn pattern_detectable(&self) -> bool {
372 self.category().pattern_detectable()
373 }
374
375 #[must_use]
384 pub fn as_label(&self) -> &str {
385 match self {
386 EntityType::Person => "PER",
387 EntityType::Organization => "ORG",
388 EntityType::Location => "LOC",
389 EntityType::Date => "DATE",
390 EntityType::Time => "TIME",
391 EntityType::Money => "MONEY",
392 EntityType::Percent => "PERCENT",
393 EntityType::Quantity => "QUANTITY",
394 EntityType::Cardinal => "CARDINAL",
395 EntityType::Ordinal => "ORDINAL",
396 EntityType::Email => "EMAIL",
397 EntityType::Url => "URL",
398 EntityType::Phone => "PHONE",
399 EntityType::Custom { name, .. } => name.as_str(),
400 #[allow(deprecated)]
401 EntityType::Other(s) => s.as_str(),
402 }
403 }
404
405 #[must_use]
417 pub fn from_label(label: &str) -> Self {
418 let label = label
420 .strip_prefix("B-")
421 .or_else(|| label.strip_prefix("I-"))
422 .or_else(|| label.strip_prefix("E-"))
423 .or_else(|| label.strip_prefix("S-"))
424 .unwrap_or(label);
425
426 match label.to_uppercase().as_str() {
427 "PER" | "PERSON" => EntityType::Person,
429 "ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
430 "LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
431 "FACILITY" | "FAC" | "BUILDING" => {
433 EntityType::custom("BUILDING", EntityCategory::Place)
434 }
435 "PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
436 "EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
437 "CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
438 EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
439 }
440 "GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
441 "DATE" => EntityType::Date,
443 "TIME" => EntityType::Time,
444 "MONEY" | "CURRENCY" => EntityType::Money,
446 "PERCENT" | "PERCENTAGE" => EntityType::Percent,
447 "QUANTITY" => EntityType::Quantity,
448 "CARDINAL" => EntityType::Cardinal,
449 "ORDINAL" => EntityType::Ordinal,
450 "EMAIL" => EntityType::Email,
452 "URL" | "URI" => EntityType::Url,
453 "PHONE" | "TELEPHONE" => EntityType::Phone,
454 "MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::custom("MISC", EntityCategory::Misc),
456 "DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
458 "CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
459 "GENE" => EntityType::custom("GENE", EntityCategory::Misc),
460 "PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
461 other => EntityType::custom(other, EntityCategory::Misc),
463 }
464 }
465
466 #[must_use]
480 pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
481 EntityType::Custom {
482 name: name.into(),
483 category,
484 }
485 }
486}
487
488impl std::fmt::Display for EntityType {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 write!(f, "{}", self.as_label())
491 }
492}
493
494impl std::str::FromStr for EntityType {
495 type Err = std::convert::Infallible;
496
497 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
499 Ok(Self::from_label(s))
500 }
501}
502
503impl Serialize for EntityType {
508 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
509 serializer.serialize_str(self.as_label())
510 }
511}
512
513impl<'de> Deserialize<'de> for EntityType {
514 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
515 struct EntityTypeVisitor;
516
517 impl<'de> serde::de::Visitor<'de> for EntityTypeVisitor {
518 type Value = EntityType;
519
520 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521 f.write_str("a string label or a tagged enum object")
522 }
523
524 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<EntityType, E> {
526 Ok(EntityType::from_label(v))
527 }
528
529 fn visit_map<A: serde::de::MapAccess<'de>>(
532 self,
533 mut map: A,
534 ) -> Result<EntityType, A::Error> {
535 let key: String = map
536 .next_key()?
537 .ok_or_else(|| serde::de::Error::custom("empty object"))?;
538 match key.as_str() {
539 "Custom" => {
540 #[derive(Deserialize)]
541 struct CustomFields {
542 name: String,
543 category: EntityCategory,
544 }
545 let fields: CustomFields = map.next_value()?;
546 Ok(EntityType::Custom {
547 name: fields.name,
548 category: fields.category,
549 })
550 }
551 "Other" => {
552 let val: String = map.next_value()?;
554 Ok(EntityType::custom(val, EntityCategory::Misc))
555 }
556 variant => {
558 let _: serde::de::IgnoredAny = map.next_value()?;
560 Ok(EntityType::from_label(variant))
561 }
562 }
563 }
564 }
565
566 deserializer.deserialize_any(EntityTypeVisitor)
567 }
568}
569
570#[derive(Debug, Clone, Default)]
610pub struct TypeMapper {
611 mappings: std::collections::HashMap<String, EntityType>,
612}
613
614impl TypeMapper {
615 #[must_use]
617 pub fn new() -> Self {
618 Self::default()
619 }
620
621 #[must_use]
623 pub fn mit_movie() -> Self {
624 let mut mapper = Self::new();
625 mapper.add("ACTOR", EntityType::Person);
627 mapper.add("DIRECTOR", EntityType::Person);
628 mapper.add("CHARACTER", EntityType::Person);
629 mapper.add(
630 "TITLE",
631 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
632 );
633 mapper.add("GENRE", EntityType::custom("GENRE", EntityCategory::Misc));
634 mapper.add("YEAR", EntityType::Date);
635 mapper.add("RATING", EntityType::custom("RATING", EntityCategory::Misc));
636 mapper.add("PLOT", EntityType::custom("PLOT", EntityCategory::Misc));
637 mapper
638 }
639
640 #[must_use]
642 pub fn mit_restaurant() -> Self {
643 let mut mapper = Self::new();
644 mapper.add("RESTAURANT_NAME", EntityType::Organization);
645 mapper.add("LOCATION", EntityType::Location);
646 mapper.add(
647 "CUISINE",
648 EntityType::custom("CUISINE", EntityCategory::Misc),
649 );
650 mapper.add("DISH", EntityType::custom("DISH", EntityCategory::Misc));
651 mapper.add("PRICE", EntityType::Money);
652 mapper.add(
653 "AMENITY",
654 EntityType::custom("AMENITY", EntityCategory::Misc),
655 );
656 mapper.add("HOURS", EntityType::Time);
657 mapper
658 }
659
660 #[must_use]
662 pub fn biomedical() -> Self {
663 let mut mapper = Self::new();
664 mapper.add(
665 "DISEASE",
666 EntityType::custom("DISEASE", EntityCategory::Agent),
667 );
668 mapper.add(
669 "CHEMICAL",
670 EntityType::custom("CHEMICAL", EntityCategory::Misc),
671 );
672 mapper.add("DRUG", EntityType::custom("DRUG", EntityCategory::Misc));
673 mapper.add("GENE", EntityType::custom("GENE", EntityCategory::Misc));
674 mapper.add(
675 "PROTEIN",
676 EntityType::custom("PROTEIN", EntityCategory::Misc),
677 );
678 mapper.add("DNA", EntityType::custom("DNA", EntityCategory::Misc));
680 mapper.add("RNA", EntityType::custom("RNA", EntityCategory::Misc));
681 mapper.add(
682 "cell_line",
683 EntityType::custom("CELL_LINE", EntityCategory::Misc),
684 );
685 mapper.add(
686 "cell_type",
687 EntityType::custom("CELL_TYPE", EntityCategory::Misc),
688 );
689 mapper
690 }
691
692 #[must_use]
694 pub fn social_media() -> Self {
695 let mut mapper = Self::new();
696 mapper.add("person", EntityType::Person);
698 mapper.add("corporation", EntityType::Organization);
699 mapper.add("location", EntityType::Location);
700 mapper.add("group", EntityType::Organization);
701 mapper.add(
702 "product",
703 EntityType::custom("PRODUCT", EntityCategory::Misc),
704 );
705 mapper.add(
706 "creative_work",
707 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
708 );
709 mapper.add("event", EntityType::custom("EVENT", EntityCategory::Misc));
710 mapper
711 }
712
713 #[must_use]
715 pub fn manufacturing() -> Self {
716 let mut mapper = Self::new();
717 mapper.add("MATE", EntityType::custom("MATERIAL", EntityCategory::Misc));
719 mapper.add("MANP", EntityType::custom("PROCESS", EntityCategory::Misc));
720 mapper.add("MACEQ", EntityType::custom("MACHINE", EntityCategory::Misc));
721 mapper.add(
722 "APPL",
723 EntityType::custom("APPLICATION", EntityCategory::Misc),
724 );
725 mapper.add("FEAT", EntityType::custom("FEATURE", EntityCategory::Misc));
726 mapper.add(
727 "PARA",
728 EntityType::custom("PARAMETER", EntityCategory::Misc),
729 );
730 mapper.add("PRO", EntityType::custom("PROPERTY", EntityCategory::Misc));
731 mapper.add(
732 "CHAR",
733 EntityType::custom("CHARACTERISTIC", EntityCategory::Misc),
734 );
735 mapper.add(
736 "ENAT",
737 EntityType::custom("ENABLING_TECHNOLOGY", EntityCategory::Misc),
738 );
739 mapper.add(
740 "CONPRI",
741 EntityType::custom("CONCEPT_PRINCIPLE", EntityCategory::Misc),
742 );
743 mapper.add(
744 "BIOP",
745 EntityType::custom("BIO_PROCESS", EntityCategory::Misc),
746 );
747 mapper.add(
748 "MANS",
749 EntityType::custom("MAN_STANDARD", EntityCategory::Misc),
750 );
751 mapper
752 }
753
754 pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
756 self.mappings.insert(source.into().to_uppercase(), target);
757 }
758
759 #[must_use]
761 pub fn map(&self, label: &str) -> Option<&EntityType> {
762 self.mappings.get(&label.to_uppercase())
763 }
764
765 #[must_use]
769 pub fn normalize(&self, label: &str) -> EntityType {
770 self.map(label)
771 .cloned()
772 .unwrap_or_else(|| EntityType::from_label(label))
773 }
774
775 #[must_use]
777 pub fn contains(&self, label: &str) -> bool {
778 self.mappings.contains_key(&label.to_uppercase())
779 }
780
781 pub fn labels(&self) -> impl Iterator<Item = &String> {
783 self.mappings.keys()
784 }
785}
786
787#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
803#[non_exhaustive]
804pub enum ExtractionMethod {
805 Pattern,
808
809 #[default]
812 Neural,
813
814 #[deprecated(since = "0.2.0", note = "Use Neural or GatedEnsemble instead")]
818 Lexicon,
819
820 SoftLexicon,
824
825 GatedEnsemble,
829
830 Consensus,
832
833 Heuristic,
836
837 Unknown,
839
840 #[deprecated(since = "0.2.0", note = "Use Heuristic or Pattern instead")]
842 Rule,
843
844 #[deprecated(since = "0.2.0", note = "Use Neural instead")]
846 ML,
847
848 #[deprecated(since = "0.2.0", note = "Use Consensus instead")]
850 Ensemble,
851}
852
853impl ExtractionMethod {
854 #[must_use]
881 pub const fn is_calibrated(&self) -> bool {
882 #[allow(deprecated)]
883 match self {
884 ExtractionMethod::Neural => true,
885 ExtractionMethod::GatedEnsemble => true,
886 ExtractionMethod::SoftLexicon => true,
887 ExtractionMethod::ML => true, ExtractionMethod::Pattern => false,
890 ExtractionMethod::Lexicon => false,
891 ExtractionMethod::Consensus => false,
892 ExtractionMethod::Heuristic => false,
893 ExtractionMethod::Unknown => false,
894 ExtractionMethod::Rule => false,
895 ExtractionMethod::Ensemble => false,
896 }
897 }
898
899 #[must_use]
906 pub const fn confidence_interpretation(&self) -> &'static str {
907 #[allow(deprecated)]
908 match self {
909 ExtractionMethod::Neural | ExtractionMethod::ML => "probability",
910 ExtractionMethod::GatedEnsemble | ExtractionMethod::SoftLexicon => "probability",
911 ExtractionMethod::Pattern | ExtractionMethod::Lexicon => "binary",
912 ExtractionMethod::Heuristic | ExtractionMethod::Rule => "heuristic_score",
913 ExtractionMethod::Consensus | ExtractionMethod::Ensemble => "agreement_ratio",
914 ExtractionMethod::Unknown => "unknown",
915 }
916 }
917}
918
919impl std::fmt::Display for ExtractionMethod {
920 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
921 #[allow(deprecated)]
922 match self {
923 ExtractionMethod::Pattern => write!(f, "pattern"),
924 ExtractionMethod::Neural => write!(f, "neural"),
925 ExtractionMethod::Lexicon => write!(f, "lexicon"),
926 ExtractionMethod::SoftLexicon => write!(f, "soft_lexicon"),
927 ExtractionMethod::GatedEnsemble => write!(f, "gated_ensemble"),
928 ExtractionMethod::Consensus => write!(f, "consensus"),
929 ExtractionMethod::Heuristic => write!(f, "heuristic"),
930 ExtractionMethod::Unknown => write!(f, "unknown"),
931 ExtractionMethod::Rule => write!(f, "heuristic"), ExtractionMethod::ML => write!(f, "neural"), ExtractionMethod::Ensemble => write!(f, "consensus"), }
935 }
936}
937
938pub trait Lexicon: Send + Sync {
992 fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)>;
996
997 fn contains(&self, text: &str) -> bool {
999 self.lookup(text).is_some()
1000 }
1001
1002 fn source(&self) -> &str;
1004
1005 fn len(&self) -> usize;
1007
1008 fn is_empty(&self) -> bool {
1010 self.len() == 0
1011 }
1012}
1013
1014#[derive(Debug, Clone)]
1019pub struct HashMapLexicon {
1020 entries: std::collections::HashMap<String, (EntityType, Confidence)>,
1021 source: String,
1022}
1023
1024impl HashMapLexicon {
1025 #[must_use]
1027 pub fn new(source: impl Into<String>) -> Self {
1028 Self {
1029 entries: std::collections::HashMap::new(),
1030 source: source.into(),
1031 }
1032 }
1033
1034 pub fn insert(
1036 &mut self,
1037 text: impl Into<String>,
1038 entity_type: EntityType,
1039 confidence: impl Into<Confidence>,
1040 ) {
1041 self.entries
1042 .insert(text.into(), (entity_type, confidence.into()));
1043 }
1044
1045 pub fn from_iter<I, S, C>(source: impl Into<String>, entries: I) -> Self
1047 where
1048 I: IntoIterator<Item = (S, EntityType, C)>,
1049 S: Into<String>,
1050 C: Into<Confidence>,
1051 {
1052 let mut lexicon = Self::new(source);
1053 for (text, entity_type, conf) in entries {
1054 lexicon.insert(text, entity_type, conf);
1055 }
1056 lexicon
1057 }
1058
1059 pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, Confidence)> {
1061 self.entries.iter().map(|(k, (t, c))| (k.as_str(), t, *c))
1062 }
1063}
1064
1065impl Lexicon for HashMapLexicon {
1066 fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)> {
1067 self.entries.get(text).cloned()
1068 }
1069
1070 fn source(&self) -> &str {
1071 &self.source
1072 }
1073
1074 fn len(&self) -> usize {
1075 self.entries.len()
1076 }
1077}
1078
1079#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
1084pub struct Provenance {
1085 pub source: Cow<'static, str>,
1087 pub method: ExtractionMethod,
1089 pub pattern: Option<Cow<'static, str>>,
1091 pub raw_confidence: Option<Confidence>,
1093 #[serde(default, skip_serializing_if = "Option::is_none")]
1095 pub model_version: Option<Cow<'static, str>>,
1096 #[serde(default, skip_serializing_if = "Option::is_none")]
1098 pub timestamp: Option<String>,
1099}
1100
1101impl Provenance {
1102 #[must_use]
1104 pub fn pattern(pattern_name: &'static str) -> Self {
1105 Self {
1106 source: Cow::Borrowed("pattern"),
1107 method: ExtractionMethod::Pattern,
1108 pattern: Some(Cow::Borrowed(pattern_name)),
1109 raw_confidence: Some(Confidence::ONE), model_version: None,
1111 timestamp: None,
1112 }
1113 }
1114
1115 #[must_use]
1129 pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: impl Into<Confidence>) -> Self {
1130 Self {
1131 source: model_name.into(),
1132 method: ExtractionMethod::Neural,
1133 pattern: None,
1134 raw_confidence: Some(confidence.into()),
1135 model_version: None,
1136 timestamp: None,
1137 }
1138 }
1139
1140 #[deprecated(
1142 since = "0.2.1",
1143 note = "Use ml() instead, it now accepts owned strings"
1144 )]
1145 #[must_use]
1146 pub fn ml_owned(model_name: impl Into<String>, confidence: impl Into<Confidence>) -> Self {
1147 Self::ml(Cow::Owned(model_name.into()), confidence)
1148 }
1149
1150 #[must_use]
1152 pub fn ensemble(sources: &'static str) -> Self {
1153 Self {
1154 source: Cow::Borrowed(sources),
1155 method: ExtractionMethod::Consensus,
1156 pattern: None,
1157 raw_confidence: None,
1158 model_version: None,
1159 timestamp: None,
1160 }
1161 }
1162
1163 #[must_use]
1165 pub fn with_version(mut self, version: &'static str) -> Self {
1166 self.model_version = Some(Cow::Borrowed(version));
1167 self
1168 }
1169
1170 #[must_use]
1172 pub fn with_timestamp(mut self, timestamp: impl Into<String>) -> Self {
1173 self.timestamp = Some(timestamp.into());
1174 self
1175 }
1176}
1177
1178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1209pub enum Span {
1210 Text {
1215 start: usize,
1217 end: usize,
1219 },
1220 BoundingBox {
1223 x: f32,
1225 y: f32,
1227 width: f32,
1229 height: f32,
1231 page: Option<u32>,
1233 },
1234 Hybrid {
1236 start: usize,
1238 end: usize,
1240 bbox: Box<Span>,
1242 },
1243}
1244
1245impl Span {
1246 #[must_use]
1248 pub const fn text(start: usize, end: usize) -> Self {
1249 Self::Text { start, end }
1250 }
1251
1252 #[must_use]
1254 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
1255 Self::BoundingBox {
1256 x,
1257 y,
1258 width,
1259 height,
1260 page: None,
1261 }
1262 }
1263
1264 #[must_use]
1266 pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
1267 Self::BoundingBox {
1268 x,
1269 y,
1270 width,
1271 height,
1272 page: Some(page),
1273 }
1274 }
1275
1276 #[must_use]
1278 pub const fn is_text(&self) -> bool {
1279 matches!(self, Self::Text { .. } | Self::Hybrid { .. })
1280 }
1281
1282 #[must_use]
1284 pub const fn is_visual(&self) -> bool {
1285 matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
1286 }
1287
1288 #[must_use]
1290 pub const fn text_offsets(&self) -> Option<(usize, usize)> {
1291 match self {
1292 Self::Text { start, end } => Some((*start, *end)),
1293 Self::Hybrid { start, end, .. } => Some((*start, *end)),
1294 Self::BoundingBox { .. } => None,
1295 }
1296 }
1297
1298 #[must_use]
1300 pub fn len(&self) -> usize {
1301 match self {
1302 Self::Text { start, end } => end.saturating_sub(*start),
1303 Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
1304 Self::BoundingBox { .. } => 0,
1305 }
1306 }
1307
1308 #[must_use]
1310 pub fn is_empty(&self) -> bool {
1311 self.len() == 0
1312 }
1313}
1314
1315#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1356pub struct DiscontinuousSpan {
1357 segments: Vec<std::ops::Range<usize>>,
1360}
1361
1362impl DiscontinuousSpan {
1363 #[must_use]
1367 pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
1368 segments.sort_by_key(|r| r.start);
1370 Self { segments }
1371 }
1372
1373 #[must_use]
1375 #[allow(clippy::single_range_in_vec_init)] pub fn contiguous(start: usize, end: usize) -> Self {
1377 Self {
1378 segments: vec![start..end],
1379 }
1380 }
1381
1382 #[must_use]
1384 pub fn num_segments(&self) -> usize {
1385 self.segments.len()
1386 }
1387
1388 #[must_use]
1390 pub fn is_discontinuous(&self) -> bool {
1391 self.segments.len() > 1
1392 }
1393
1394 #[must_use]
1396 pub fn is_contiguous(&self) -> bool {
1397 self.segments.len() <= 1
1398 }
1399
1400 #[must_use]
1402 pub fn segments(&self) -> &[std::ops::Range<usize>] {
1403 &self.segments
1404 }
1405
1406 #[must_use]
1408 pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
1409 if self.segments.is_empty() {
1410 return None;
1411 }
1412 let start = self.segments.first()?.start;
1413 let end = self.segments.last()?.end;
1414 Some(start..end)
1415 }
1416
1417 #[must_use]
1420 pub fn total_len(&self) -> usize {
1421 self.segments.iter().map(|r| r.end - r.start).sum()
1422 }
1423
1424 #[must_use]
1426 pub fn extract_text(&self, text: &str, separator: &str) -> String {
1427 self.segments
1428 .iter()
1429 .map(|r| {
1430 let start = r.start;
1431 let len = r.end.saturating_sub(r.start);
1432 text.chars().skip(start).take(len).collect::<String>()
1433 })
1434 .collect::<Vec<_>>()
1435 .join(separator)
1436 }
1437
1438 #[must_use]
1448 pub fn contains(&self, pos: usize) -> bool {
1449 self.segments.iter().any(|r| r.contains(&pos))
1450 }
1451
1452 #[must_use]
1454 pub fn to_span(&self) -> Option<Span> {
1455 self.bounding_range().map(|r| Span::Text {
1456 start: r.start,
1457 end: r.end,
1458 })
1459 }
1460}
1461
1462impl From<std::ops::Range<usize>> for DiscontinuousSpan {
1463 fn from(range: std::ops::Range<usize>) -> Self {
1464 Self::contiguous(range.start, range.end)
1465 }
1466}
1467
1468impl Default for Span {
1469 fn default() -> Self {
1470 Self::Text { start: 0, end: 0 }
1471 }
1472}
1473
1474#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
1486pub struct HierarchicalConfidence {
1487 pub linkage: f32,
1490 pub type_score: f32,
1492 pub boundary: f32,
1495}
1496
1497impl HierarchicalConfidence {
1498 #[must_use]
1500 pub fn new(linkage: f32, type_score: f32, boundary: f32) -> Self {
1501 Self {
1502 linkage: linkage.clamp(0.0, 1.0),
1503 type_score: type_score.clamp(0.0, 1.0),
1504 boundary: boundary.clamp(0.0, 1.0),
1505 }
1506 }
1507
1508 #[must_use]
1511 pub fn from_single(confidence: f32) -> Self {
1512 let c = confidence.clamp(0.0, 1.0);
1513 Self {
1514 linkage: c,
1515 type_score: c,
1516 boundary: c,
1517 }
1518 }
1519
1520 #[must_use]
1523 pub fn combined(&self) -> f32 {
1524 (self.linkage * self.type_score * self.boundary).powf(1.0 / 3.0)
1525 }
1526
1527 #[must_use]
1529 pub fn as_f64(&self) -> f64 {
1530 self.combined() as f64
1531 }
1532
1533 #[must_use]
1535 pub fn passes_threshold(&self, linkage_min: f32, type_min: f32, boundary_min: f32) -> bool {
1536 self.linkage >= linkage_min && self.type_score >= type_min && self.boundary >= boundary_min
1537 }
1538}
1539
1540impl Default for HierarchicalConfidence {
1541 fn default() -> Self {
1542 Self {
1543 linkage: 1.0,
1544 type_score: 1.0,
1545 boundary: 1.0,
1546 }
1547 }
1548}
1549
1550impl From<f64> for HierarchicalConfidence {
1551 fn from(confidence: f64) -> Self {
1552 Self::from_single(confidence as f32)
1553 }
1554}
1555
1556impl From<f32> for HierarchicalConfidence {
1557 fn from(confidence: f32) -> Self {
1558 Self::from_single(confidence)
1559 }
1560}
1561
1562#[derive(Debug, Clone)]
1584pub struct RaggedBatch {
1585 pub token_ids: Vec<u32>,
1588 pub cumulative_offsets: Vec<u32>,
1592 pub max_seq_len: usize,
1594}
1595
1596impl RaggedBatch {
1597 pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
1599 let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
1600 let mut token_ids = Vec::with_capacity(total_tokens);
1601 let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
1602 let mut max_seq_len = 0;
1603
1604 cumulative_offsets.push(0);
1605 for seq in sequences {
1606 token_ids.extend_from_slice(seq);
1607 let len = token_ids.len();
1611 if len > u32::MAX as usize {
1612 log::warn!(
1615 "Token count {} exceeds u32::MAX, truncating to {}",
1616 len,
1617 u32::MAX
1618 );
1619 cumulative_offsets.push(u32::MAX);
1620 } else {
1621 cumulative_offsets.push(len as u32);
1622 }
1623 max_seq_len = max_seq_len.max(seq.len());
1624 }
1625
1626 Self {
1627 token_ids,
1628 cumulative_offsets,
1629 max_seq_len,
1630 }
1631 }
1632
1633 #[must_use]
1635 pub fn batch_size(&self) -> usize {
1636 self.cumulative_offsets.len().saturating_sub(1)
1637 }
1638
1639 #[must_use]
1641 pub fn total_tokens(&self) -> usize {
1642 self.token_ids.len()
1643 }
1644
1645 #[must_use]
1647 pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
1648 if doc_idx + 1 < self.cumulative_offsets.len() {
1649 let start = self.cumulative_offsets[doc_idx] as usize;
1650 let end = self.cumulative_offsets[doc_idx + 1] as usize;
1651 Some(start..end)
1652 } else {
1653 None
1654 }
1655 }
1656
1657 #[must_use]
1659 pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
1660 self.doc_range(doc_idx).map(|r| &self.token_ids[r])
1661 }
1662
1663 #[must_use]
1665 pub fn padding_savings(&self) -> f64 {
1666 let padded_size = self.batch_size() * self.max_seq_len;
1667 if padded_size == 0 {
1668 return 0.0;
1669 }
1670 1.0 - (self.total_tokens() as f64 / padded_size as f64)
1671 }
1672}
1673
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
    /// Index of the document within the batch.
    pub doc_idx: u32,
    /// Start position of the span within the document.
    pub start: u32,
    /// End position of the span within the document.
    pub end: u32,
}

impl SpanCandidate {
    /// Creates a candidate from its document index and bounds.
    #[must_use]
    pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
        SpanCandidate {
            doc_idx,
            start,
            end,
        }
    }

    /// Returns `end - start`, or `0` when `end` does not exceed `start`.
    #[must_use]
    pub const fn width(&self) -> u32 {
        if self.end > self.start {
            self.end - self.start
        } else {
            0
        }
    }
}
1709
1710pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
1715 let mut candidates = Vec::new();
1716
1717 for doc_idx in 0..batch.batch_size() {
1718 if let Some(range) = batch.doc_range(doc_idx) {
1719 let doc_len = range.len();
1720 for start in 0..doc_len {
1722 let max_end = (start + max_width).min(doc_len);
1723 for end in (start + 1)..=max_end {
1724 candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
1725 }
1726 }
1727 }
1728 }
1729
1730 candidates
1731}
1732
1733pub fn generate_filtered_candidates(
1737 batch: &RaggedBatch,
1738 max_width: usize,
1739 linkage_mask: &[f32],
1740 threshold: f32,
1741) -> Vec<SpanCandidate> {
1742 let mut candidates = Vec::new();
1743 let mut mask_idx = 0;
1744
1745 for doc_idx in 0..batch.batch_size() {
1746 if let Some(range) = batch.doc_range(doc_idx) {
1747 let doc_len = range.len();
1748 for start in 0..doc_len {
1749 let max_end = (start + max_width).min(doc_len);
1750 for end in (start + 1)..=max_end {
1751 if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
1753 candidates.push(SpanCandidate::new(
1754 doc_idx as u32,
1755 start as u32,
1756 end as u32,
1757 ));
1758 }
1759 mask_idx += 1;
1760 }
1761 }
1762 }
1763 }
1764
1765 candidates
1766}
1767
1768#[derive(Debug, Clone, Serialize, Deserialize)]
1819pub struct Entity {
1820 pub text: String,
1822 pub entity_type: EntityType,
1824 pub start: usize,
1829 pub end: usize,
1834 pub confidence: Confidence,
1839 #[serde(default, skip_serializing_if = "Option::is_none")]
1841 pub normalized: Option<String>,
1842 #[serde(default, skip_serializing_if = "Option::is_none")]
1844 pub provenance: Option<Provenance>,
1845 #[serde(default, skip_serializing_if = "Option::is_none")]
1848 pub kb_id: Option<String>,
1849 #[serde(default, skip_serializing_if = "Option::is_none")]
1853 pub canonical_id: Option<super::types::CanonicalId>,
1854 #[serde(default, skip_serializing_if = "Option::is_none")]
1857 pub hierarchical_confidence: Option<HierarchicalConfidence>,
1858 #[serde(default, skip_serializing_if = "Option::is_none")]
1861 pub visual_span: Option<Span>,
1862 #[serde(default, skip_serializing_if = "Option::is_none")]
1866 pub discontinuous_span: Option<DiscontinuousSpan>,
1867 #[serde(default, skip_serializing_if = "Option::is_none")]
1889 pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
1890 #[serde(default, skip_serializing_if = "Option::is_none")]
1895 pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
1896 #[serde(default, skip_serializing_if = "Option::is_none")]
1917 pub viewport: Option<EntityViewport>,
1918 #[serde(default, skip_serializing_if = "Option::is_none")]
1924 pub phi_features: Option<PhiFeatures>,
1925 #[serde(default, skip_serializing_if = "Option::is_none")]
1931 pub mention_type: Option<MentionType>,
1932}
1933
1934impl Entity {
1935 #[must_use]
1946 pub fn new(
1947 text: impl Into<String>,
1948 entity_type: EntityType,
1949 start: usize,
1950 end: usize,
1951 confidence: impl Into<Confidence>,
1952 ) -> Self {
1953 Self {
1954 text: text.into(),
1955 entity_type,
1956 start,
1957 end,
1958 confidence: confidence.into(),
1959 normalized: None,
1960 provenance: None,
1961 kb_id: None,
1962 canonical_id: None,
1963 hierarchical_confidence: None,
1964 visual_span: None,
1965 discontinuous_span: None,
1966 valid_from: None,
1967 valid_until: None,
1968 viewport: None,
1969 phi_features: None,
1970 mention_type: None,
1971 }
1972 }
1973
1974 #[must_use]
1976 pub fn with_provenance(
1977 text: impl Into<String>,
1978 entity_type: EntityType,
1979 start: usize,
1980 end: usize,
1981 confidence: impl Into<Confidence>,
1982 provenance: Provenance,
1983 ) -> Self {
1984 Self {
1985 text: text.into(),
1986 entity_type,
1987 start,
1988 end,
1989 confidence: confidence.into(),
1990 normalized: None,
1991 provenance: Some(provenance),
1992 kb_id: None,
1993 canonical_id: None,
1994 hierarchical_confidence: None,
1995 visual_span: None,
1996 discontinuous_span: None,
1997 valid_from: None,
1998 valid_until: None,
1999 viewport: None,
2000 phi_features: None,
2001 mention_type: None,
2002 }
2003 }
2004
2005 #[must_use]
2007 pub fn with_hierarchical_confidence(
2008 text: impl Into<String>,
2009 entity_type: EntityType,
2010 start: usize,
2011 end: usize,
2012 confidence: HierarchicalConfidence,
2013 ) -> Self {
2014 Self {
2015 text: text.into(),
2016 entity_type,
2017 start,
2018 end,
2019 confidence: Confidence::new(confidence.as_f64()),
2020 normalized: None,
2021 provenance: None,
2022 kb_id: None,
2023 canonical_id: None,
2024 hierarchical_confidence: Some(confidence),
2025 visual_span: None,
2026 discontinuous_span: None,
2027 valid_from: None,
2028 valid_until: None,
2029 viewport: None,
2030 phi_features: None,
2031 mention_type: None,
2032 }
2033 }
2034
2035 #[must_use]
2037 pub fn from_visual(
2038 text: impl Into<String>,
2039 entity_type: EntityType,
2040 bbox: Span,
2041 confidence: impl Into<Confidence>,
2042 ) -> Self {
2043 Self {
2044 text: text.into(),
2045 entity_type,
2046 start: 0,
2047 end: 0,
2048 confidence: confidence.into(),
2049 normalized: None,
2050 provenance: None,
2051 kb_id: None,
2052 canonical_id: None,
2053 hierarchical_confidence: None,
2054 visual_span: Some(bbox),
2055 discontinuous_span: None,
2056 valid_from: None,
2057 valid_until: None,
2058 viewport: None,
2059 phi_features: None,
2060 mention_type: None,
2061 }
2062 }
2063
2064 #[must_use]
2066 pub fn with_type(
2067 text: impl Into<String>,
2068 entity_type: EntityType,
2069 start: usize,
2070 end: usize,
2071 ) -> Self {
2072 Self::new(text, entity_type, start, end, 1.0)
2073 }
2074
2075 pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
2085 self.kb_id = Some(kb_id.into());
2086 }
2087
2088 pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
2092 self.canonical_id = Some(canonical_id.into());
2093 }
2094
2095 #[must_use]
2105 pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
2106 self.canonical_id = Some(canonical_id.into());
2107 self
2108 }
2109
2110 #[must_use]
2112 pub fn is_linked(&self) -> bool {
2113 self.kb_id.is_some()
2114 }
2115
2116 #[must_use]
2118 pub fn has_coreference(&self) -> bool {
2119 self.canonical_id.is_some()
2120 }
2121
2122 #[must_use]
2128 pub fn is_discontinuous(&self) -> bool {
2129 self.discontinuous_span
2130 .as_ref()
2131 .map(|s| s.is_discontinuous())
2132 .unwrap_or(false)
2133 }
2134
2135 #[must_use]
2139 pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
2140 self.discontinuous_span
2141 .as_ref()
2142 .filter(|s| s.is_discontinuous())
2143 .map(|s| s.segments().to_vec())
2144 }
2145
2146 pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
2150 if let Some(bounding) = span.bounding_range() {
2152 self.start = bounding.start;
2153 self.end = bounding.end;
2154 }
2155 self.discontinuous_span = Some(span);
2156 }
2157
2158 #[must_use]
2166 pub fn total_len(&self) -> usize {
2167 if let Some(ref span) = self.discontinuous_span {
2168 span.segments().iter().map(|r| r.end - r.start).sum()
2169 } else {
2170 self.end.saturating_sub(self.start)
2171 }
2172 }
2173
2174 pub fn set_normalized(&mut self, normalized: impl Into<String>) {
2186 self.normalized = Some(normalized.into());
2187 }
2188
2189 #[must_use]
2191 pub fn normalized_or_text(&self) -> &str {
2192 self.normalized.as_deref().unwrap_or(&self.text)
2193 }
2194
2195 #[must_use]
2197 pub fn method(&self) -> ExtractionMethod {
2198 self.provenance
2199 .as_ref()
2200 .map_or(ExtractionMethod::Unknown, |p| p.method)
2201 }
2202
2203 #[must_use]
2205 pub fn source(&self) -> Option<&str> {
2206 self.provenance.as_ref().map(|p| p.source.as_ref())
2207 }
2208
2209 #[must_use]
2211 pub fn category(&self) -> EntityCategory {
2212 self.entity_type.category()
2213 }
2214
2215 #[must_use]
2217 pub fn is_structured(&self) -> bool {
2218 self.entity_type.pattern_detectable()
2219 }
2220
2221 #[must_use]
2223 pub fn is_named(&self) -> bool {
2224 self.entity_type.requires_ml()
2225 }
2226
2227 #[must_use]
2229 pub fn overlaps(&self, other: &Entity) -> bool {
2230 !(self.end <= other.start || other.end <= self.start)
2231 }
2232
2233 #[must_use]
2235 pub fn overlap_ratio(&self, other: &Entity) -> f64 {
2236 let intersection_start = self.start.max(other.start);
2237 let intersection_end = self.end.min(other.end);
2238
2239 if intersection_start >= intersection_end {
2240 return 0.0;
2241 }
2242
2243 let intersection = (intersection_end - intersection_start) as f64;
2244 let union = ((self.end - self.start) + (other.end - other.start)
2245 - (intersection_end - intersection_start)) as f64;
2246
2247 if union == 0.0 {
2248 return 1.0;
2249 }
2250
2251 intersection / union
2252 }
2253
2254 pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
2256 self.confidence = Confidence::new(confidence.as_f64());
2257 self.hierarchical_confidence = Some(confidence);
2258 }
2259
2260 #[must_use]
2262 pub fn linkage_confidence(&self) -> f32 {
2263 self.hierarchical_confidence
2264 .map_or(f32::from(self.confidence), |h| h.linkage)
2265 }
2266
2267 #[must_use]
2269 pub fn type_confidence(&self) -> f32 {
2270 self.hierarchical_confidence
2271 .map_or(f32::from(self.confidence), |h| h.type_score)
2272 }
2273
2274 #[must_use]
2276 pub fn boundary_confidence(&self) -> f32 {
2277 self.hierarchical_confidence
2278 .map_or(f32::from(self.confidence), |h| h.boundary)
2279 }
2280
2281 #[must_use]
2283 pub fn is_visual(&self) -> bool {
2284 self.visual_span.is_some()
2285 }
2286
2287 #[must_use]
2289 pub const fn text_span(&self) -> (usize, usize) {
2290 (self.start, self.end)
2291 }
2292
2293 #[must_use]
2295 pub const fn span_len(&self) -> usize {
2296 self.end.saturating_sub(self.start)
2297 }
2298
2299 pub fn set_visual_span(&mut self, span: Span) {
2324 self.visual_span = Some(span);
2325 }
2326
2327 #[must_use]
2347 pub fn extract_text(&self, source_text: &str) -> String {
2348 let char_count = source_text.chars().count();
2352 self.extract_text_with_len(source_text, char_count)
2353 }
2354
2355 #[must_use]
2367 pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
2368 if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
2369 return String::new();
2370 }
2371 source_text
2372 .chars()
2373 .skip(self.start)
2374 .take(self.end - self.start)
2375 .collect()
2376 }
2377
2378 pub fn set_valid_from(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2394 self.valid_from = Some(dt);
2395 }
2396
2397 pub fn set_valid_until(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2399 self.valid_until = Some(dt);
2400 }
2401
2402 pub fn set_temporal_range(
2404 &mut self,
2405 from: chrono::DateTime<chrono::Utc>,
2406 until: chrono::DateTime<chrono::Utc>,
2407 ) {
2408 self.valid_from = Some(from);
2409 self.valid_until = Some(until);
2410 }
2411
2412 #[must_use]
2414 pub fn is_temporal(&self) -> bool {
2415 self.valid_from.is_some() || self.valid_until.is_some()
2416 }
2417
2418 #[must_use]
2440 pub fn valid_at(&self, timestamp: &chrono::DateTime<chrono::Utc>) -> bool {
2441 match (&self.valid_from, &self.valid_until) {
2442 (None, None) => true, (Some(from), None) => timestamp >= from, (None, Some(until)) => timestamp <= until, (Some(from), Some(until)) => timestamp >= from && timestamp <= until,
2446 }
2447 }
2448
2449 #[must_use]
2451 pub fn is_currently_valid(&self) -> bool {
2452 self.valid_at(&chrono::Utc::now())
2453 }
2454
2455 pub fn set_viewport(&mut self, viewport: EntityViewport) {
2470 self.viewport = Some(viewport);
2471 }
2472
2473 #[must_use]
2475 pub fn has_viewport(&self) -> bool {
2476 self.viewport.is_some()
2477 }
2478
2479 #[must_use]
2481 pub fn viewport_or_default(&self) -> EntityViewport {
2482 self.viewport.clone().unwrap_or_default()
2483 }
2484
2485 #[must_use]
2491 pub fn matches_viewport(&self, query_viewport: &EntityViewport) -> bool {
2492 match &self.viewport {
2493 None => true, Some(v) => v == query_viewport,
2495 }
2496 }
2497
2498 #[must_use]
2500 pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
2501 EntityBuilder::new(text, entity_type)
2502 }
2503
2504 #[must_use]
2537 pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
2538 let char_count = source_text.chars().count();
2540 self.validate_with_len(source_text, char_count)
2541 }
2542
2543 #[must_use]
2555 pub fn validate_with_len(
2556 &self,
2557 source_text: &str,
2558 text_char_count: usize,
2559 ) -> Vec<ValidationIssue> {
2560 let mut issues = Vec::new();
2561
2562 if self.start >= self.end {
2564 issues.push(ValidationIssue::InvalidSpan {
2565 start: self.start,
2566 end: self.end,
2567 reason: "start must be less than end".to_string(),
2568 });
2569 }
2570
2571 if self.end > text_char_count {
2572 issues.push(ValidationIssue::SpanOutOfBounds {
2573 end: self.end,
2574 text_len: text_char_count,
2575 });
2576 }
2577
2578 if self.start < self.end && self.end <= text_char_count {
2580 let actual = self.extract_text_with_len(source_text, text_char_count);
2581 if actual != self.text {
2582 issues.push(ValidationIssue::TextMismatch {
2583 expected: self.text.clone(),
2584 actual,
2585 start: self.start,
2586 end: self.end,
2587 });
2588 }
2589 }
2590
2591 if let EntityType::Custom { ref name, .. } = self.entity_type {
2595 if name.is_empty() {
2596 issues.push(ValidationIssue::InvalidType {
2597 reason: "Custom entity type has empty name".to_string(),
2598 });
2599 }
2600 }
2601
2602 if let Some(ref disc_span) = self.discontinuous_span {
2604 for (i, seg) in disc_span.segments().iter().enumerate() {
2605 if seg.start >= seg.end {
2606 issues.push(ValidationIssue::InvalidSpan {
2607 start: seg.start,
2608 end: seg.end,
2609 reason: format!("discontinuous segment {} is invalid", i),
2610 });
2611 }
2612 if seg.end > text_char_count {
2613 issues.push(ValidationIssue::SpanOutOfBounds {
2614 end: seg.end,
2615 text_len: text_char_count,
2616 });
2617 }
2618 }
2619 }
2620
2621 issues
2622 }
2623
2624 #[must_use]
2628 pub fn is_valid(&self, source_text: &str) -> bool {
2629 self.validate(source_text).is_empty()
2630 }
2631
2632 #[must_use]
2652 pub fn validate_batch(
2653 entities: &[Entity],
2654 source_text: &str,
2655 ) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
2656 entities
2657 .iter()
2658 .enumerate()
2659 .filter_map(|(idx, entity)| {
2660 let issues = entity.validate(source_text);
2661 if issues.is_empty() {
2662 None
2663 } else {
2664 Some((idx, issues))
2665 }
2666 })
2667 .collect()
2668 }
2669}
2670
/// A problem found while validating an entity against its source text.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// The span (or a discontinuous segment) is empty or inverted.
    InvalidSpan {
        /// Span start.
        start: usize,
        /// Span end.
        end: usize,
        /// Human-readable explanation.
        reason: String,
    },
    /// The span ends past the end of the text.
    SpanOutOfBounds {
        /// Offending span end.
        end: usize,
        /// Length of the text the span was checked against.
        text_len: usize,
    },
    /// The text found at the span differs from the entity's stored text.
    TextMismatch {
        /// Text stored on the entity.
        expected: String,
        /// Text actually found at the span.
        actual: String,
        /// Span start.
        start: usize,
        /// Span end.
        end: usize,
    },
    /// The confidence value lies outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// Offending value.
        value: f64,
    },
    /// The entity type is malformed.
    InvalidType {
        /// Human-readable explanation.
        reason: String,
    },
}

impl std::fmt::Display for ValidationIssue {
    /// Writes a one-line, human-readable description of the issue.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSpan { start, end, reason } => {
                write!(f, "Invalid span [{}, {}): {}", start, end, reason)
            }
            Self::SpanOutOfBounds { end, text_len } => {
                write!(f, "Span end {} exceeds text length {}", end, text_len)
            }
            Self::TextMismatch {
                expected,
                actual,
                start,
                end,
            } => write!(
                f,
                "Text mismatch at [{}, {}): expected '{}', got '{}'",
                start, end, expected, actual
            ),
            Self::InvalidConfidence { value } => {
                write!(f, "Confidence {} outside [0.0, 1.0]", value)
            }
            Self::InvalidType { reason } => write!(f, "Invalid entity type: {}", reason),
        }
    }
}
2743
/// Fluent builder for [`Entity`].
///
/// Holds one private field per `Entity` field; `build` moves them into the finished
/// entity unchanged.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
    // Required fields, set by `new` / `span` / `confidence`.
    text: String,
    entity_type: EntityType,
    start: usize,
    end: usize,
    confidence: Confidence,
    // Optional metadata, `None` until the matching setter is called.
    normalized: Option<String>,
    provenance: Option<Provenance>,
    kb_id: Option<String>,
    canonical_id: Option<super::types::CanonicalId>,
    hierarchical_confidence: Option<HierarchicalConfidence>,
    visual_span: Option<Span>,
    discontinuous_span: Option<DiscontinuousSpan>,
    valid_from: Option<chrono::DateTime<chrono::Utc>>,
    valid_until: Option<chrono::DateTime<chrono::Utc>>,
    viewport: Option<EntityViewport>,
    phi_features: Option<PhiFeatures>,
    mention_type: Option<MentionType>,
}
2778
2779impl EntityBuilder {
2780 #[must_use]
2782 pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
2783 let text = text.into();
2784 let end = text.chars().count();
2785 Self {
2786 text,
2787 entity_type,
2788 start: 0,
2789 end,
2790 confidence: Confidence::ONE,
2791 normalized: None,
2792 provenance: None,
2793 kb_id: None,
2794 canonical_id: None,
2795 hierarchical_confidence: None,
2796 visual_span: None,
2797 discontinuous_span: None,
2798 valid_from: None,
2799 valid_until: None,
2800 viewport: None,
2801 phi_features: None,
2802 mention_type: None,
2803 }
2804 }
2805
2806 #[must_use]
2808 pub const fn span(mut self, start: usize, end: usize) -> Self {
2809 self.start = start;
2810 self.end = end;
2811 self
2812 }
2813
2814 #[must_use]
2816 pub fn confidence(mut self, confidence: impl Into<Confidence>) -> Self {
2817 self.confidence = confidence.into();
2818 self
2819 }
2820
2821 #[must_use]
2823 pub fn hierarchical_confidence(mut self, confidence: HierarchicalConfidence) -> Self {
2824 self.confidence = Confidence::new(confidence.as_f64());
2825 self.hierarchical_confidence = Some(confidence);
2826 self
2827 }
2828
2829 #[must_use]
2831 pub fn normalized(mut self, normalized: impl Into<String>) -> Self {
2832 self.normalized = Some(normalized.into());
2833 self
2834 }
2835
2836 #[must_use]
2838 pub fn provenance(mut self, provenance: Provenance) -> Self {
2839 self.provenance = Some(provenance);
2840 self
2841 }
2842
2843 #[must_use]
2845 pub fn kb_id(mut self, kb_id: impl Into<String>) -> Self {
2846 self.kb_id = Some(kb_id.into());
2847 self
2848 }
2849
2850 #[must_use]
2852 pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
2853 self.canonical_id = Some(super::types::CanonicalId::new(canonical_id));
2854 self
2855 }
2856
2857 #[must_use]
2859 pub fn visual_span(mut self, span: Span) -> Self {
2860 self.visual_span = Some(span);
2861 self
2862 }
2863
2864 #[must_use]
2868 pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
2869 if let Some(bounding) = span.bounding_range() {
2871 self.start = bounding.start;
2872 self.end = bounding.end;
2873 }
2874 self.discontinuous_span = Some(span);
2875 self
2876 }
2877
2878 #[must_use]
2892 pub fn valid_from(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2893 self.valid_from = Some(dt);
2894 self
2895 }
2896
2897 #[must_use]
2899 pub fn valid_until(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2900 self.valid_until = Some(dt);
2901 self
2902 }
2903
2904 #[must_use]
2906 pub fn temporal_range(
2907 mut self,
2908 from: chrono::DateTime<chrono::Utc>,
2909 until: chrono::DateTime<chrono::Utc>,
2910 ) -> Self {
2911 self.valid_from = Some(from);
2912 self.valid_until = Some(until);
2913 self
2914 }
2915
2916 #[must_use]
2929 pub fn viewport(mut self, viewport: EntityViewport) -> Self {
2930 self.viewport = Some(viewport);
2931 self
2932 }
2933
2934 #[must_use]
2936 pub fn phi_features(mut self, phi_features: PhiFeatures) -> Self {
2937 self.phi_features = Some(phi_features);
2938 self
2939 }
2940
2941 #[must_use]
2943 pub fn mention_type(mut self, mention_type: MentionType) -> Self {
2944 self.mention_type = Some(mention_type);
2945 self
2946 }
2947
2948 #[must_use]
2950 pub fn build(self) -> Entity {
2951 Entity {
2952 text: self.text,
2953 entity_type: self.entity_type,
2954 start: self.start,
2955 end: self.end,
2956 confidence: self.confidence,
2957 normalized: self.normalized,
2958 provenance: self.provenance,
2959 kb_id: self.kb_id,
2960 canonical_id: self.canonical_id,
2961 hierarchical_confidence: self.hierarchical_confidence,
2962 visual_span: self.visual_span,
2963 discontinuous_span: self.discontinuous_span,
2964 valid_from: self.valid_from,
2965 valid_until: self.valid_until,
2966 viewport: self.viewport,
2967 phi_features: self.phi_features,
2968 mention_type: self.mention_type,
2969 }
2970 }
2971}
2972
/// A typed, directed relation between two entities (`head` → `tail`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// First argument of the relation.
    pub head: Entity,
    /// Second argument of the relation.
    pub tail: Entity,
    /// Relation label, stored as a free-form string.
    pub relation_type: String,
    /// Optional `(start, end)` span of the text that triggered the relation.
    pub trigger_span: Option<(usize, usize)>,
    /// Confidence for the relation as a whole.
    pub confidence: Confidence,
}
3012
3013impl Relation {
3014 #[must_use]
3016 pub fn new(
3017 head: Entity,
3018 tail: Entity,
3019 relation_type: impl Into<String>,
3020 confidence: impl Into<Confidence>,
3021 ) -> Self {
3022 Self {
3023 head,
3024 tail,
3025 relation_type: relation_type.into(),
3026 trigger_span: None,
3027 confidence: confidence.into(),
3028 }
3029 }
3030
3031 #[must_use]
3033 pub fn with_trigger(
3034 head: Entity,
3035 tail: Entity,
3036 relation_type: impl Into<String>,
3037 trigger_start: usize,
3038 trigger_end: usize,
3039 confidence: impl Into<Confidence>,
3040 ) -> Self {
3041 Self {
3042 head,
3043 tail,
3044 relation_type: relation_type.into(),
3045 trigger_span: Some((trigger_start, trigger_end)),
3046 confidence: confidence.into(),
3047 }
3048 }
3049
3050 #[must_use]
3052 pub fn as_triple(&self) -> String {
3053 format!(
3054 "({}, {}, {})",
3055 self.head.text, self.relation_type, self.tail.text
3056 )
3057 }
3058
3059 #[must_use]
3062 pub fn span_distance(&self) -> usize {
3063 if self.head.end <= self.tail.start {
3064 self.tail.start.saturating_sub(self.head.end)
3065 } else if self.tail.end <= self.head.start {
3066 self.head.start.saturating_sub(self.tail.end)
3067 } else {
3068 0 }
3070 }
3071}
3072
3073#[cfg(test)]
3074mod tests {
3075 #![allow(clippy::unwrap_used)] use super::*;
3077
3078 #[test]
3079 fn test_entity_type_roundtrip() {
3080 let types = [
3081 EntityType::Person,
3082 EntityType::Organization,
3083 EntityType::Location,
3084 EntityType::Date,
3085 EntityType::Money,
3086 EntityType::Percent,
3087 ];
3088
3089 for t in types {
3090 let label = t.as_label();
3091 let parsed = EntityType::from_label(label);
3092 assert_eq!(t, parsed);
3093 }
3094 }
3095
3096 #[test]
3097 fn test_entity_overlap() {
3098 let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
3099 let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
3100 let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);
3101
3102 assert!(!e1.overlaps(&e2)); assert!(e1.overlaps(&e3)); assert!(e3.overlaps(&e2)); }
3106
3107 #[test]
3108 fn test_confidence_clamping() {
3109 let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
3110 assert!((e1.confidence - 1.0).abs() < f64::EPSILON);
3111
3112 let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
3113 assert!(e2.confidence.abs() < f64::EPSILON);
3114 }
3115
    /// Built-in types map to the expected category and detection strategy.
    #[test]
    fn test_entity_categories() {
        // Named entities: require ML, not pattern-detectable.
        assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
        assert_eq!(
            EntityType::Organization.category(),
            EntityCategory::Organization
        );
        assert_eq!(EntityType::Location.category(), EntityCategory::Place);
        assert!(EntityType::Person.requires_ml());
        assert!(!EntityType::Person.pattern_detectable());

        // Temporal types: pattern-detectable, no ML required.
        assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
        assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
        assert!(EntityType::Date.pattern_detectable());
        assert!(!EntityType::Date.requires_ml());

        // Numeric types.
        assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
        assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
        assert!(EntityType::Money.pattern_detectable());

        // Contact types.
        assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
        assert!(EntityType::Email.pattern_detectable());
    }
3145
3146 #[test]
3147 fn test_new_types_roundtrip() {
3148 let types = [
3149 EntityType::Time,
3150 EntityType::Email,
3151 EntityType::Url,
3152 EntityType::Phone,
3153 EntityType::Quantity,
3154 EntityType::Cardinal,
3155 EntityType::Ordinal,
3156 ];
3157
3158 for t in types {
3159 let label = t.as_label();
3160 let parsed = EntityType::from_label(label);
3161 assert_eq!(t, parsed, "Roundtrip failed for {}", label);
3162 }
3163 }
3164
3165 #[test]
3166 fn test_custom_entity_type() {
3167 let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
3168 assert_eq!(disease.as_label(), "DISEASE");
3169 assert!(disease.requires_ml());
3170
3171 let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
3172 assert_eq!(product_id.as_label(), "PRODUCT_ID");
3173 assert!(!product_id.requires_ml());
3174 assert!(!product_id.pattern_detectable());
3175 }
3176
3177 #[test]
3178 fn test_entity_normalization() {
3179 let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
3180 assert!(e.normalized.is_none());
3181 assert_eq!(e.normalized_or_text(), "Jan 15");
3182
3183 e.set_normalized("2024-01-15");
3184 assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
3185 assert_eq!(e.normalized_or_text(), "2024-01-15");
3186 }
3187
3188 #[test]
3189 fn test_entity_helpers() {
3190 let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
3191 assert!(named.is_named());
3192 assert!(!named.is_structured());
3193 assert_eq!(named.category(), EntityCategory::Agent);
3194
3195 let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
3196 assert!(!structured.is_named());
3197 assert!(structured.is_structured());
3198 assert_eq!(structured.category(), EntityCategory::Numeric);
3199 }
3200
3201 #[test]
3202 fn test_knowledge_linking() {
3203 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3204 assert!(!entity.is_linked());
3205 assert!(!entity.has_coreference());
3206
3207 entity.link_to_kb("Q7186"); assert!(entity.is_linked());
3209 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3210
3211 entity.set_canonical(42);
3212 assert!(entity.has_coreference());
3213 assert_eq!(
3214 entity.canonical_id,
3215 Some(crate::core::types::CanonicalId::new(42))
3216 );
3217 }
3218
3219 #[test]
3220 fn test_relation_creation() {
3221 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3222 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3223
3224 let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
3225 assert_eq!(relation.relation_type, "WORKED_AT");
3226 assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
3227 assert!(relation.trigger_span.is_none());
3228
3229 let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
3231 assert_eq!(relation2.trigger_span, Some((13, 19)));
3232 }
3233
3234 #[test]
3235 fn test_relation_span_distance() {
3236 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3238 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3239 let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
3240 assert_eq!(relation.span_distance(), 13);
3241 }
3242
3243 #[test]
3244 fn test_relation_category() {
3245 let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
3247 assert_eq!(rel_type.category(), EntityCategory::Relation);
3248 assert!(rel_type.category().is_relation());
3249 assert!(rel_type.requires_ml()); }
3251
3252 #[test]
3257 fn test_span_text() {
3258 let span = Span::text(10, 20);
3259 assert!(span.is_text());
3260 assert!(!span.is_visual());
3261 assert_eq!(span.text_offsets(), Some((10, 20)));
3262 assert_eq!(span.len(), 10);
3263 assert!(!span.is_empty());
3264 }
3265
3266 #[test]
3267 fn test_span_bbox() {
3268 let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
3269 assert!(!span.is_text());
3270 assert!(span.is_visual());
3271 assert_eq!(span.text_offsets(), None);
3272 assert_eq!(span.len(), 0); }
3274
3275 #[test]
3276 fn test_span_bbox_with_page() {
3277 let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
3278 if let Span::BoundingBox { page, .. } = span {
3279 assert_eq!(page, Some(5));
3280 } else {
3281 panic!("Expected BoundingBox");
3282 }
3283 }
3284
3285 #[test]
3286 fn test_span_hybrid() {
3287 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3288 let hybrid = Span::Hybrid {
3289 start: 10,
3290 end: 20,
3291 bbox: Box::new(bbox),
3292 };
3293 assert!(hybrid.is_text());
3294 assert!(hybrid.is_visual());
3295 assert_eq!(hybrid.text_offsets(), Some((10, 20)));
3296 assert_eq!(hybrid.len(), 10);
3297 }
3298
3299 #[test]
3304 fn test_hierarchical_confidence_new() {
3305 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3306 assert!((hc.linkage - 0.9).abs() < f32::EPSILON);
3307 assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
3308 assert!((hc.boundary - 0.7).abs() < f32::EPSILON);
3309 }
3310
3311 #[test]
3312 fn test_hierarchical_confidence_clamping() {
3313 let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
3314 assert!((hc.linkage - 1.0).abs() < f32::EPSILON);
3315 assert!(hc.type_score.abs() < f32::EPSILON);
3316 assert!((hc.boundary - 0.5).abs() < f32::EPSILON);
3317 }
3318
3319 #[test]
3320 fn test_hierarchical_confidence_from_single() {
3321 let hc = HierarchicalConfidence::from_single(0.8);
3322 assert!((hc.linkage - 0.8).abs() < f32::EPSILON);
3323 assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
3324 assert!((hc.boundary - 0.8).abs() < f32::EPSILON);
3325 }
3326
3327 #[test]
3328 fn test_hierarchical_confidence_combined() {
3329 let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
3330 assert!((hc.combined() - 1.0).abs() < f32::EPSILON);
3331
3332 let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
3333 assert!((hc2.combined() - 0.8).abs() < f32::EPSILON);
3334
3335 let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
3337 assert!((hc3.combined() - 0.5).abs() < 0.001);
3338 }
3339
3340 #[test]
3341 fn test_hierarchical_confidence_threshold() {
3342 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3343 assert!(hc.passes_threshold(0.5, 0.5, 0.5));
3344 assert!(hc.passes_threshold(0.9, 0.8, 0.7));
3345 assert!(!hc.passes_threshold(0.95, 0.8, 0.7)); assert!(!hc.passes_threshold(0.9, 0.85, 0.7)); }
3348
3349 #[test]
3350 fn test_hierarchical_confidence_from_f64() {
3351 let hc: HierarchicalConfidence = 0.85_f64.into();
3352 assert!((hc.linkage - 0.85).abs() < 0.001);
3353 }
3354
3355 #[test]
3360 fn test_ragged_batch_from_sequences() {
3361 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3362 let batch = RaggedBatch::from_sequences(&seqs);
3363
3364 assert_eq!(batch.batch_size(), 3);
3365 assert_eq!(batch.total_tokens(), 9);
3366 assert_eq!(batch.max_seq_len, 4);
3367 assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
3368 }
3369
3370 #[test]
3371 fn test_ragged_batch_doc_range() {
3372 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3373 let batch = RaggedBatch::from_sequences(&seqs);
3374
3375 assert_eq!(batch.doc_range(0), Some(0..3));
3376 assert_eq!(batch.doc_range(1), Some(3..5));
3377 assert_eq!(batch.doc_range(2), None);
3378 }
3379
3380 #[test]
3381 fn test_ragged_batch_doc_tokens() {
3382 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3383 let batch = RaggedBatch::from_sequences(&seqs);
3384
3385 assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
3386 assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
3387 }
3388
3389 #[test]
3390 fn test_ragged_batch_padding_savings() {
3391 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3395 let batch = RaggedBatch::from_sequences(&seqs);
3396 let savings = batch.padding_savings();
3397 assert!((savings - 0.25).abs() < 0.001);
3398 }
3399
3400 #[test]
3405 fn test_span_candidate() {
3406 let sc = SpanCandidate::new(0, 5, 10);
3407 assert_eq!(sc.doc_idx, 0);
3408 assert_eq!(sc.start, 5);
3409 assert_eq!(sc.end, 10);
3410 assert_eq!(sc.width(), 5);
3411 }
3412
3413 #[test]
3414 fn test_generate_span_candidates() {
3415 let seqs = vec![vec![1, 2, 3]]; let batch = RaggedBatch::from_sequences(&seqs);
3417 let candidates = generate_span_candidates(&batch, 2);
3418
3419 assert_eq!(candidates.len(), 5);
3422
3423 for c in &candidates {
3425 assert_eq!(c.doc_idx, 0);
3426 assert!(c.end as usize <= 3);
3427 assert!(c.width() as usize <= 2);
3428 }
3429 }
3430
3431 #[test]
3432 fn test_generate_filtered_candidates() {
3433 let seqs = vec![vec![1, 2, 3]];
3434 let batch = RaggedBatch::from_sequences(&seqs);
3435
3436 let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
3439 let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);
3440
3441 assert_eq!(candidates.len(), 2);
3442 }
3443
3444 #[test]
3449 fn test_entity_builder_basic() {
3450 let entity = Entity::builder("John", EntityType::Person)
3451 .span(0, 4)
3452 .confidence(0.95)
3453 .build();
3454
3455 assert_eq!(entity.text, "John");
3456 assert_eq!(entity.entity_type, EntityType::Person);
3457 assert_eq!(entity.start, 0);
3458 assert_eq!(entity.end, 4);
3459 assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
3460 }
3461
3462 #[test]
3463 fn test_entity_builder_full() {
3464 let entity = Entity::builder("Marie Curie", EntityType::Person)
3465 .span(0, 11)
3466 .confidence(0.95)
3467 .kb_id("Q7186")
3468 .canonical_id(42)
3469 .normalized("Marie Salomea Skłodowska Curie")
3470 .provenance(Provenance::ml("bert", 0.95))
3471 .build();
3472
3473 assert_eq!(entity.text, "Marie Curie");
3474 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3475 assert_eq!(
3476 entity.canonical_id,
3477 Some(crate::core::types::CanonicalId::new(42))
3478 );
3479 assert_eq!(
3480 entity.normalized.as_deref(),
3481 Some("Marie Salomea Skłodowska Curie")
3482 );
3483 assert!(entity.provenance.is_some());
3484 }
3485
3486 #[test]
3487 fn test_entity_builder_hierarchical() {
3488 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3489 let entity = Entity::builder("test", EntityType::Person)
3490 .span(0, 4)
3491 .hierarchical_confidence(hc)
3492 .build();
3493
3494 assert!(entity.hierarchical_confidence.is_some());
3495 assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
3496 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3497 assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
3498 }
3499
3500 #[test]
3501 fn test_entity_builder_visual() {
3502 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3503 let entity = Entity::builder("receipt item", EntityType::Money)
3504 .visual_span(bbox)
3505 .confidence(0.9)
3506 .build();
3507
3508 assert!(entity.is_visual());
3509 assert!(entity.visual_span.is_some());
3510 }
3511
3512 #[test]
3517 fn test_entity_hierarchical_confidence_helpers() {
3518 let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);
3519
3520 assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
3522 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3523 assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);
3524
3525 entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
3527 assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
3528 assert!((entity.type_confidence() - 0.85).abs() < 0.001);
3529 assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
3530 }
3531
3532 #[test]
3533 fn test_entity_from_visual() {
3534 let entity = Entity::from_visual(
3535 "receipt total",
3536 EntityType::Money,
3537 Span::bbox(0.5, 0.8, 0.2, 0.05),
3538 0.92,
3539 );
3540
3541 assert!(entity.is_visual());
3542 assert_eq!(entity.start, 0);
3543 assert_eq!(entity.end, 0);
3544 assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
3545 }
3546
3547 #[test]
3548 fn test_entity_span_helpers() {
3549 let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
3550 assert_eq!(entity.text_span(), (10, 20));
3551 assert_eq!(entity.span_len(), 10);
3552 }
3553
3554 #[test]
3559 fn test_provenance_pattern() {
3560 let prov = Provenance::pattern("EMAIL");
3561 assert_eq!(prov.method, ExtractionMethod::Pattern);
3562 assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
3563 assert_eq!(prov.raw_confidence, Some(Confidence::new(1.0))); }
3565
3566 #[test]
3567 fn test_provenance_ml() {
3568 let prov = Provenance::ml("bert-ner", 0.87);
3569 assert_eq!(prov.method, ExtractionMethod::Neural);
3570 assert_eq!(prov.source.as_ref(), "bert-ner");
3571 assert_eq!(prov.raw_confidence, Some(Confidence::new(0.87)));
3572 }
3573
3574 #[test]
3575 fn test_provenance_with_version() {
3576 let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");
3577
3578 assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
3579 assert_eq!(prov.source.as_ref(), "gliner");
3580 }
3581
3582 #[test]
3583 fn test_provenance_with_timestamp() {
3584 let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");
3585
3586 assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
3587 }
3588
3589 #[test]
3590 fn test_provenance_builder_chain() {
3591 let prov = Provenance::ml("modernbert-ner", 0.95)
3592 .with_version("v1.0.0")
3593 .with_timestamp("2024-11-27T12:00:00Z");
3594
3595 assert_eq!(prov.method, ExtractionMethod::Neural);
3596 assert_eq!(prov.source.as_ref(), "modernbert-ner");
3597 assert_eq!(prov.raw_confidence, Some(Confidence::new(0.95)));
3598 assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
3599 assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
3600 }
3601
3602 #[test]
3603 fn test_provenance_serialization() {
3604 let prov = Provenance::ml("test", 0.9)
3605 .with_version("v1.0")
3606 .with_timestamp("2024-01-01");
3607
3608 let json = serde_json::to_string(&prov).unwrap();
3609 assert!(json.contains("model_version"));
3610 assert!(json.contains("v1.0"));
3611
3612 let restored: Provenance = serde_json::from_str(&json).unwrap();
3613 assert_eq!(restored.model_version.as_deref(), Some("v1.0"));
3614 assert_eq!(restored.timestamp.as_deref(), Some("2024-01-01"));
3615 }
3616}
3617
3618#[cfg(test)]
3619mod proptests {
3620 #![allow(clippy::unwrap_used)] use super::*;
3622 use proptest::prelude::*;
3623
3624 proptest! {
3625 #[test]
3626 fn confidence_always_clamped(conf in -10.0f64..10.0) {
3627 let e = Entity::new("test", EntityType::Person, 0, 4, conf);
3628 prop_assert!(e.confidence >= 0.0);
3629 prop_assert!(e.confidence <= 1.0);
3630 }
3631
3632 #[test]
3633 fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
3634 let et = EntityType::from_label(&label);
3635 let back = EntityType::from_label(et.as_label());
3636 let is_custom = matches!(back, EntityType::Custom { .. });
3638 prop_assert!(is_custom || back == et);
3639 }
3640
3641 #[test]
3642 fn overlap_is_symmetric(
3643 s1 in 0usize..100,
3644 len1 in 1usize..50,
3645 s2 in 0usize..100,
3646 len2 in 1usize..50,
3647 ) {
3648 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3649 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3650 prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
3651 }
3652
3653 #[test]
3654 fn overlap_ratio_bounded(
3655 s1 in 0usize..100,
3656 len1 in 1usize..50,
3657 s2 in 0usize..100,
3658 len2 in 1usize..50,
3659 ) {
3660 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3661 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3662 let ratio = e1.overlap_ratio(&e2);
3663 prop_assert!(ratio >= 0.0);
3664 prop_assert!(ratio <= 1.0);
3665 }
3666
3667 #[test]
3668 fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
3669 let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
3670 let ratio = e.overlap_ratio(&e);
3671 prop_assert!((ratio - 1.0).abs() < 1e-10);
3672 }
3673
3674 #[test]
3675 fn hierarchical_confidence_always_clamped(
3676 linkage in -2.0f32..2.0,
3677 type_score in -2.0f32..2.0,
3678 boundary in -2.0f32..2.0,
3679 ) {
3680 let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
3681 prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
3682 prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
3683 prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
3684 prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
3685 }
3686
3687 #[test]
3688 fn span_candidate_width_consistent(
3689 doc in 0u32..10,
3690 start in 0u32..100,
3691 end in 1u32..100,
3692 ) {
3693 let actual_end = start.max(end);
3694 let sc = SpanCandidate::new(doc, start, actual_end);
3695 prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
3696 }
3697
3698 #[test]
3699 fn ragged_batch_preserves_tokens(
3700 seq_lens in proptest::collection::vec(1usize..10, 1..5),
3701 ) {
3702 let mut counter = 0u32;
3704 let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
3705 let seq: Vec<u32> = (counter..counter + len as u32).collect();
3706 counter += len as u32;
3707 seq
3708 }).collect();
3709
3710 let batch = RaggedBatch::from_sequences(&seqs);
3711
3712 prop_assert_eq!(batch.batch_size(), seqs.len());
3714 prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
3715
3716 for (i, seq) in seqs.iter().enumerate() {
3718 let doc_tokens = batch.doc_tokens(i).unwrap();
3719 prop_assert_eq!(doc_tokens, seq.as_slice());
3720 }
3721 }
3722
3723 #[test]
3724 fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
3725 let end = start + len;
3726 let span = Span::text(start, end);
3727 let (s, e) = span.text_offsets().unwrap();
3728 prop_assert_eq!(s, start);
3729 prop_assert_eq!(e, end);
3730 prop_assert_eq!(span.len(), len);
3731 }
3732
3733 #[test]
3739 fn entity_span_validity(
3740 start in 0usize..10000,
3741 len in 1usize..500,
3742 conf in 0.0f64..=1.0,
3743 ) {
3744 let end = start + len;
3745 let text_content: String = "x".repeat(end);
3747 let entity_text: String = text_content.chars().skip(start).take(len).collect();
3748 let e = Entity::new(&entity_text, EntityType::Person, start, end, conf);
3749 let issues = e.validate(&text_content);
3750 for issue in &issues {
3752 match issue {
3753 ValidationIssue::InvalidSpan { .. } => {
3754 prop_assert!(false, "start < end should never produce InvalidSpan");
3755 }
3756 ValidationIssue::SpanOutOfBounds { .. } => {
3757 prop_assert!(false, "span within text should never produce SpanOutOfBounds");
3758 }
3759 _ => {} }
3761 }
3762 }
3763
3764 #[test]
3766 fn entity_type_label_roundtrip_standard(
3767 idx in 0usize..13,
3768 ) {
3769 let standard_types = [
3770 EntityType::Person,
3771 EntityType::Organization,
3772 EntityType::Location,
3773 EntityType::Date,
3774 EntityType::Time,
3775 EntityType::Money,
3776 EntityType::Percent,
3777 EntityType::Quantity,
3778 EntityType::Cardinal,
3779 EntityType::Ordinal,
3780 EntityType::Email,
3781 EntityType::Url,
3782 EntityType::Phone,
3783 ];
3784 let et = &standard_types[idx];
3785 let label = et.as_label();
3786 let roundtripped = EntityType::from_label(label);
3787 prop_assert_eq!(&roundtripped, et,
3788 "from_label(as_label()) must roundtrip for {:?} (label={:?})", et, label);
3789 }
3790
3791 #[test]
3793 fn span_containment_property(
3794 a_start in 0usize..5000,
3795 a_len in 1usize..5000,
3796 b_offset in 0usize..5000,
3797 b_len in 1usize..5000,
3798 ) {
3799 let a_end = a_start + a_len;
3800 let b_start = a_start + (b_offset % a_len); let b_end_candidate = b_start + b_len;
3802
3803 if b_start >= a_start && b_end_candidate <= a_end {
3805 prop_assert!(a_start <= b_start);
3807 prop_assert!(a_end >= b_end_candidate);
3808
3809 let ea = Entity::new("a", EntityType::Person, a_start, a_end, 1.0);
3811 let eb = Entity::new("b", EntityType::Person, b_start, b_end_candidate, 1.0);
3812 prop_assert!(ea.overlaps(&eb),
3813 "containing span must overlap contained span");
3814 }
3815 }
3816
3817 #[test]
3819 fn entity_serde_roundtrip(
3820 start in 0usize..10000,
3821 len in 1usize..500,
3822 conf in 0.0f64..=1.0,
3823 type_idx in 0usize..5,
3824 ) {
3825 let end = start + len;
3826 let types = [
3827 EntityType::Person,
3828 EntityType::Organization,
3829 EntityType::Location,
3830 EntityType::Date,
3831 EntityType::Email,
3832 ];
3833 let et = types[type_idx].clone();
3834 let text = format!("entity_{}", start);
3835 let e = Entity::new(&text, et, start, end, conf);
3836
3837 let json = serde_json::to_string(&e).unwrap();
3838 let e2: Entity = serde_json::from_str(&json).unwrap();
3839
3840 prop_assert_eq!(&e.text, &e2.text);
3841 prop_assert_eq!(&e.entity_type, &e2.entity_type);
3842 prop_assert_eq!(e.start, e2.start);
3843 prop_assert_eq!(e.end, e2.end);
3844 prop_assert!((e.confidence - e2.confidence).abs() < 1e-10,
3846 "confidence roundtrip: {} vs {}", e.confidence, e2.confidence);
3847 prop_assert_eq!(&e.normalized, &e2.normalized);
3848 prop_assert_eq!(&e.kb_id, &e2.kb_id);
3849 }
3850
3851 #[test]
3853 fn discontinuous_span_total_length(
3854 segments in proptest::collection::vec(
3855 (0usize..5000, 1usize..500),
3856 1..6
3857 ),
3858 ) {
3859 let ranges: Vec<std::ops::Range<usize>> = segments.iter()
3860 .map(|&(start, len)| start..start + len)
3861 .collect();
3862 let expected_sum: usize = ranges.iter().map(|r| r.end - r.start).sum();
3863 let span = DiscontinuousSpan::new(ranges);
3864 prop_assert_eq!(span.total_len(), expected_sum,
3865 "total_len must equal sum of segment lengths");
3866 }
3867 }
3868
3869 #[test]
3874 fn test_entity_viewport_as_str() {
3875 assert_eq!(EntityViewport::Business.as_str(), "business");
3876 assert_eq!(EntityViewport::Legal.as_str(), "legal");
3877 assert_eq!(EntityViewport::Technical.as_str(), "technical");
3878 assert_eq!(EntityViewport::Academic.as_str(), "academic");
3879 assert_eq!(EntityViewport::Personal.as_str(), "personal");
3880 assert_eq!(EntityViewport::Political.as_str(), "political");
3881 assert_eq!(EntityViewport::Media.as_str(), "media");
3882 assert_eq!(EntityViewport::Historical.as_str(), "historical");
3883 assert_eq!(EntityViewport::General.as_str(), "general");
3884 assert_eq!(
3885 EntityViewport::Custom("custom".to_string()).as_str(),
3886 "custom"
3887 );
3888 }
3889
3890 #[test]
3891 fn test_entity_viewport_is_professional() {
3892 assert!(EntityViewport::Business.is_professional());
3893 assert!(EntityViewport::Legal.is_professional());
3894 assert!(EntityViewport::Technical.is_professional());
3895 assert!(EntityViewport::Academic.is_professional());
3896 assert!(EntityViewport::Political.is_professional());
3897
3898 assert!(!EntityViewport::Personal.is_professional());
3899 assert!(!EntityViewport::Media.is_professional());
3900 assert!(!EntityViewport::Historical.is_professional());
3901 assert!(!EntityViewport::General.is_professional());
3902 assert!(!EntityViewport::Custom("test".to_string()).is_professional());
3903 }
3904
3905 #[test]
3906 fn test_entity_viewport_from_str() {
3907 assert_eq!(
3908 "business".parse::<EntityViewport>().unwrap(),
3909 EntityViewport::Business
3910 );
3911 assert_eq!(
3912 "financial".parse::<EntityViewport>().unwrap(),
3913 EntityViewport::Business
3914 );
3915 assert_eq!(
3916 "corporate".parse::<EntityViewport>().unwrap(),
3917 EntityViewport::Business
3918 );
3919
3920 assert_eq!(
3921 "legal".parse::<EntityViewport>().unwrap(),
3922 EntityViewport::Legal
3923 );
3924 assert_eq!(
3925 "law".parse::<EntityViewport>().unwrap(),
3926 EntityViewport::Legal
3927 );
3928
3929 assert_eq!(
3930 "technical".parse::<EntityViewport>().unwrap(),
3931 EntityViewport::Technical
3932 );
3933 assert_eq!(
3934 "engineering".parse::<EntityViewport>().unwrap(),
3935 EntityViewport::Technical
3936 );
3937
3938 assert_eq!(
3939 "academic".parse::<EntityViewport>().unwrap(),
3940 EntityViewport::Academic
3941 );
3942 assert_eq!(
3943 "research".parse::<EntityViewport>().unwrap(),
3944 EntityViewport::Academic
3945 );
3946
3947 assert_eq!(
3948 "personal".parse::<EntityViewport>().unwrap(),
3949 EntityViewport::Personal
3950 );
3951 assert_eq!(
3952 "biographical".parse::<EntityViewport>().unwrap(),
3953 EntityViewport::Personal
3954 );
3955
3956 assert_eq!(
3957 "political".parse::<EntityViewport>().unwrap(),
3958 EntityViewport::Political
3959 );
3960 assert_eq!(
3961 "policy".parse::<EntityViewport>().unwrap(),
3962 EntityViewport::Political
3963 );
3964
3965 assert_eq!(
3966 "media".parse::<EntityViewport>().unwrap(),
3967 EntityViewport::Media
3968 );
3969 assert_eq!(
3970 "press".parse::<EntityViewport>().unwrap(),
3971 EntityViewport::Media
3972 );
3973
3974 assert_eq!(
3975 "historical".parse::<EntityViewport>().unwrap(),
3976 EntityViewport::Historical
3977 );
3978 assert_eq!(
3979 "history".parse::<EntityViewport>().unwrap(),
3980 EntityViewport::Historical
3981 );
3982
3983 assert_eq!(
3984 "general".parse::<EntityViewport>().unwrap(),
3985 EntityViewport::General
3986 );
3987 assert_eq!(
3988 "generic".parse::<EntityViewport>().unwrap(),
3989 EntityViewport::General
3990 );
3991 assert_eq!(
3992 "".parse::<EntityViewport>().unwrap(),
3993 EntityViewport::General
3994 );
3995
3996 assert_eq!(
3998 "custom_viewport".parse::<EntityViewport>().unwrap(),
3999 EntityViewport::Custom("custom_viewport".to_string())
4000 );
4001 }
4002
4003 #[test]
4004 fn test_entity_viewport_from_str_case_insensitive() {
4005 assert_eq!(
4006 "BUSINESS".parse::<EntityViewport>().unwrap(),
4007 EntityViewport::Business
4008 );
4009 assert_eq!(
4010 "Business".parse::<EntityViewport>().unwrap(),
4011 EntityViewport::Business
4012 );
4013 assert_eq!(
4014 "BuSiNeSs".parse::<EntityViewport>().unwrap(),
4015 EntityViewport::Business
4016 );
4017 }
4018
4019 #[test]
4020 fn test_entity_viewport_display() {
4021 assert_eq!(format!("{}", EntityViewport::Business), "business");
4022 assert_eq!(format!("{}", EntityViewport::Academic), "academic");
4023 assert_eq!(
4024 format!("{}", EntityViewport::Custom("test".to_string())),
4025 "test"
4026 );
4027 }
4028
4029 #[test]
4030 fn test_entity_viewport_methods() {
4031 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.9);
4032
4033 assert!(!entity.has_viewport());
4035 assert_eq!(entity.viewport_or_default(), EntityViewport::General);
4036 assert!(entity.matches_viewport(&EntityViewport::Academic)); entity.set_viewport(EntityViewport::Academic);
4040 assert!(entity.has_viewport());
4041 assert_eq!(entity.viewport_or_default(), EntityViewport::Academic);
4042 assert!(entity.matches_viewport(&EntityViewport::Academic));
4043 assert!(!entity.matches_viewport(&EntityViewport::Business));
4044 }
4045
4046 #[test]
4047 fn test_entity_builder_with_viewport() {
4048 let entity = Entity::builder("Marie Curie", EntityType::Person)
4049 .span(0, 11)
4050 .viewport(EntityViewport::Academic)
4051 .build();
4052
4053 assert_eq!(entity.viewport, Some(EntityViewport::Academic));
4054 assert!(entity.has_viewport());
4055 }
4056
4057 #[test]
4062 fn test_entity_category_requires_ml() {
4063 assert!(EntityCategory::Agent.requires_ml());
4064 assert!(EntityCategory::Organization.requires_ml());
4065 assert!(EntityCategory::Place.requires_ml());
4066 assert!(EntityCategory::Creative.requires_ml());
4067 assert!(EntityCategory::Relation.requires_ml());
4068
4069 assert!(!EntityCategory::Temporal.requires_ml());
4070 assert!(!EntityCategory::Numeric.requires_ml());
4071 assert!(!EntityCategory::Contact.requires_ml());
4072 assert!(!EntityCategory::Misc.requires_ml());
4073 }
4074
4075 #[test]
4076 fn test_entity_category_pattern_detectable() {
4077 assert!(EntityCategory::Temporal.pattern_detectable());
4078 assert!(EntityCategory::Numeric.pattern_detectable());
4079 assert!(EntityCategory::Contact.pattern_detectable());
4080
4081 assert!(!EntityCategory::Agent.pattern_detectable());
4082 assert!(!EntityCategory::Organization.pattern_detectable());
4083 assert!(!EntityCategory::Place.pattern_detectable());
4084 assert!(!EntityCategory::Creative.pattern_detectable());
4085 assert!(!EntityCategory::Relation.pattern_detectable());
4086 assert!(!EntityCategory::Misc.pattern_detectable());
4087 }
4088
4089 #[test]
4090 fn test_entity_category_is_relation() {
4091 assert!(EntityCategory::Relation.is_relation());
4092
4093 assert!(!EntityCategory::Agent.is_relation());
4094 assert!(!EntityCategory::Organization.is_relation());
4095 assert!(!EntityCategory::Place.is_relation());
4096 assert!(!EntityCategory::Temporal.is_relation());
4097 assert!(!EntityCategory::Numeric.is_relation());
4098 assert!(!EntityCategory::Contact.is_relation());
4099 assert!(!EntityCategory::Creative.is_relation());
4100 assert!(!EntityCategory::Misc.is_relation());
4101 }
4102
4103 #[test]
4104 fn test_entity_category_as_str() {
4105 assert_eq!(EntityCategory::Agent.as_str(), "agent");
4106 assert_eq!(EntityCategory::Organization.as_str(), "organization");
4107 assert_eq!(EntityCategory::Place.as_str(), "place");
4108 assert_eq!(EntityCategory::Creative.as_str(), "creative");
4109 assert_eq!(EntityCategory::Temporal.as_str(), "temporal");
4110 assert_eq!(EntityCategory::Numeric.as_str(), "numeric");
4111 assert_eq!(EntityCategory::Contact.as_str(), "contact");
4112 assert_eq!(EntityCategory::Relation.as_str(), "relation");
4113 assert_eq!(EntityCategory::Misc.as_str(), "misc");
4114 }
4115
4116 #[test]
4117 fn test_entity_category_display() {
4118 assert_eq!(format!("{}", EntityCategory::Agent), "agent");
4119 assert_eq!(format!("{}", EntityCategory::Temporal), "temporal");
4120 assert_eq!(format!("{}", EntityCategory::Relation), "relation");
4121 }
4122
4123 #[test]
4128 fn test_entity_type_serializes_to_flat_string() {
4129 assert_eq!(
4130 serde_json::to_string(&EntityType::Person).unwrap(),
4131 r#""PER""#
4132 );
4133 assert_eq!(
4134 serde_json::to_string(&EntityType::Organization).unwrap(),
4135 r#""ORG""#
4136 );
4137 assert_eq!(
4138 serde_json::to_string(&EntityType::Location).unwrap(),
4139 r#""LOC""#
4140 );
4141 assert_eq!(
4142 serde_json::to_string(&EntityType::Date).unwrap(),
4143 r#""DATE""#
4144 );
4145 assert_eq!(
4146 serde_json::to_string(&EntityType::Money).unwrap(),
4147 r#""MONEY""#
4148 );
4149 }
4150
4151 #[test]
4152 fn test_custom_entity_type_serializes_flat() {
4153 let misc = EntityType::custom("MISC", EntityCategory::Misc);
4154 assert_eq!(serde_json::to_string(&misc).unwrap(), r#""MISC""#);
4155
4156 let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
4157 assert_eq!(serde_json::to_string(&disease).unwrap(), r#""DISEASE""#);
4158 }
4159
4160 #[test]
4161 fn test_entity_type_deserializes_from_flat_string() {
4162 let per: EntityType = serde_json::from_str(r#""PER""#).unwrap();
4163 assert_eq!(per, EntityType::Person);
4164
4165 let org: EntityType = serde_json::from_str(r#""ORG""#).unwrap();
4166 assert_eq!(org, EntityType::Organization);
4167
4168 let misc: EntityType = serde_json::from_str(r#""MISC""#).unwrap();
4169 assert_eq!(misc, EntityType::custom("MISC", EntityCategory::Misc));
4170 }
4171
4172 #[test]
4173 fn test_entity_type_deserializes_backward_compat_custom() {
4174 let json = r#"{"Custom":{"name":"MISC","category":"Misc"}}"#;
4176 let et: EntityType = serde_json::from_str(json).unwrap();
4177 assert_eq!(et, EntityType::custom("MISC", EntityCategory::Misc));
4178 }
4179
4180 #[test]
4181 fn test_entity_type_deserializes_backward_compat_other() {
4182 let json = r#"{"Other":"foo"}"#;
4184 let et: EntityType = serde_json::from_str(json).unwrap();
4185 assert_eq!(et, EntityType::custom("foo", EntityCategory::Misc));
4186 }
4187
4188 #[test]
4189 fn test_entity_type_serde_roundtrip() {
4190 let types = vec![
4191 EntityType::Person,
4192 EntityType::Organization,
4193 EntityType::Location,
4194 EntityType::Date,
4195 EntityType::Time,
4196 EntityType::Money,
4197 EntityType::Percent,
4198 EntityType::Quantity,
4199 EntityType::Cardinal,
4200 EntityType::Ordinal,
4201 EntityType::Email,
4202 EntityType::Url,
4203 EntityType::Phone,
4204 EntityType::custom("MISC", EntityCategory::Misc),
4205 EntityType::custom("DISEASE", EntityCategory::Agent),
4206 ];
4207
4208 for t in &types {
4209 let json = serde_json::to_string(t).unwrap();
4210 let back: EntityType = serde_json::from_str(&json).unwrap();
4211 assert_eq!(
4214 t.as_label(),
4215 back.as_label(),
4216 "roundtrip failed for {:?}",
4217 t
4218 );
4219 }
4220 }
4221}