1use serde::{Deserialize, Serialize};
39use std::borrow::Cow;
40
/// Coarse grouping of entity types.
///
/// `EntityType::category` maps every entity type onto one of these variants.
/// The `requires_ml` and `pattern_detectable` predicates defined on this type
/// split most variants into two groups; `Misc` belongs to neither.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityCategory {
    /// Category reported for `EntityType::Person`.
    Agent,
    /// Category reported for `EntityType::Organization`.
    Organization,
    /// Category reported for `EntityType::Location`.
    Place,
    /// Category used by custom types such as `"EVENT"` and `"CREATIVE_WORK"`.
    Creative,
    /// Category reported for `EntityType::Date` and `EntityType::Time`.
    Temporal,
    /// Category reported for money, percent, quantity, cardinal and ordinal types.
    Numeric,
    /// Category reported for email, URL and phone types.
    Contact,
    /// Relation category; the only variant for which `is_relation` returns `true`.
    Relation,
    /// Fallback category, reported for `EntityType::Other`.
    Misc,
}
81
82impl EntityCategory {
83 #[must_use]
85 pub const fn requires_ml(&self) -> bool {
86 matches!(
87 self,
88 EntityCategory::Agent
89 | EntityCategory::Organization
90 | EntityCategory::Place
91 | EntityCategory::Creative
92 | EntityCategory::Relation
93 )
94 }
95
96 #[must_use]
98 pub const fn pattern_detectable(&self) -> bool {
99 matches!(
100 self,
101 EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
102 )
103 }
104
105 #[must_use]
107 pub const fn is_relation(&self) -> bool {
108 matches!(self, EntityCategory::Relation)
109 }
110
111 #[must_use]
113 pub const fn as_str(&self) -> &'static str {
114 match self {
115 EntityCategory::Agent => "agent",
116 EntityCategory::Organization => "organization",
117 EntityCategory::Place => "place",
118 EntityCategory::Creative => "creative",
119 EntityCategory::Temporal => "temporal",
120 EntityCategory::Numeric => "numeric",
121 EntityCategory::Contact => "contact",
122 EntityCategory::Relation => "relation",
123 EntityCategory::Misc => "misc",
124 }
125 }
126}
127
128impl std::fmt::Display for EntityCategory {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 write!(f, "{}", self.as_str())
131 }
132}
133
/// Named context ("viewport") that can be attached to an entity.
///
/// `General` is the default. `Custom` carries a free-form name; the `FromStr`
/// implementation produces it for any input it does not recognise.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum EntityViewport {
    /// Business context (also parsed from `"financial"` and `"corporate"`).
    Business,
    /// Legal context (also parsed from `"law"` and `"compliance"`).
    Legal,
    /// Technical context (also parsed from `"engineering"` and `"tech"`).
    Technical,
    /// Academic context (also parsed from `"research"` and `"scholarly"`).
    Academic,
    /// Personal context (also parsed from `"biographical"` and `"private"`).
    Personal,
    /// Political context (also parsed from `"policy"` and `"government"`).
    Political,
    /// Media context (also parsed from `"press"`, `"pr"` and `"public_relations"`).
    Media,
    /// Historical context (also parsed from `"history"` and `"past"`).
    Historical,
    /// Default viewport (also parsed from `"generic"` and the empty string).
    #[default]
    General,
    /// User-defined viewport identified by its name.
    Custom(String),
}
195
196impl EntityViewport {
197 #[must_use]
199 pub fn as_str(&self) -> &str {
200 match self {
201 EntityViewport::Business => "business",
202 EntityViewport::Legal => "legal",
203 EntityViewport::Technical => "technical",
204 EntityViewport::Academic => "academic",
205 EntityViewport::Personal => "personal",
206 EntityViewport::Political => "political",
207 EntityViewport::Media => "media",
208 EntityViewport::Historical => "historical",
209 EntityViewport::General => "general",
210 EntityViewport::Custom(s) => s,
211 }
212 }
213
214 #[must_use]
216 pub const fn is_professional(&self) -> bool {
217 matches!(
218 self,
219 EntityViewport::Business
220 | EntityViewport::Legal
221 | EntityViewport::Technical
222 | EntityViewport::Academic
223 | EntityViewport::Political
224 )
225 }
226}
227
228impl std::str::FromStr for EntityViewport {
229 type Err = std::convert::Infallible;
230
231 fn from_str(s: &str) -> Result<Self, Self::Err> {
232 Ok(match s.to_lowercase().as_str() {
233 "business" | "financial" | "corporate" => EntityViewport::Business,
234 "legal" | "law" | "compliance" => EntityViewport::Legal,
235 "technical" | "engineering" | "tech" => EntityViewport::Technical,
236 "academic" | "research" | "scholarly" => EntityViewport::Academic,
237 "personal" | "biographical" | "private" => EntityViewport::Personal,
238 "political" | "policy" | "government" => EntityViewport::Political,
239 "media" | "press" | "pr" | "public_relations" => EntityViewport::Media,
240 "historical" | "history" | "past" => EntityViewport::Historical,
241 "general" | "generic" | "" => EntityViewport::General,
242 other => EntityViewport::Custom(other.to_string()),
243 })
244 }
245}
246
247impl std::fmt::Display for EntityViewport {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 write!(f, "{}", self.as_str())
250 }
251}
252
/// Type of an extracted entity.
///
/// The built-in variants cover common NER labels; `Custom` and `Other` carry
/// caller-supplied names. `EntityType::category` maps each variant to an
/// [`EntityCategory`], and `as_label` / `from_label` convert to and from
/// label strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityType {
    /// Person (label `"PER"`).
    Person,
    /// Organization (label `"ORG"`).
    Organization,
    /// Location (label `"LOC"`).
    Location,

    /// Date (label `"DATE"`).
    Date,
    /// Time (label `"TIME"`).
    Time,

    /// Monetary amount (label `"MONEY"`).
    Money,
    /// Percentage (label `"PERCENT"`).
    Percent,
    /// Quantity (label `"QUANTITY"`).
    Quantity,
    /// Cardinal number (label `"CARDINAL"`).
    Cardinal,
    /// Ordinal number (label `"ORDINAL"`).
    Ordinal,

    /// Email address (label `"EMAIL"`).
    Email,
    /// URL (label `"URL"`).
    Url,
    /// Phone number (label `"PHONE"`).
    Phone,

    /// Caller-defined type with an explicit category.
    Custom {
        /// Name of the type; returned verbatim by `as_label`.
        name: String,
        /// Category reported by `category()` for this type.
        category: EntityCategory,
    },

    /// Unrecognised type identified only by its name; its category is `Misc`.
    /// Serialized under the variant name `"Other"`.
    #[serde(rename = "Other")]
    Other(String),
}
327
328impl EntityType {
329 #[must_use]
331 pub fn category(&self) -> EntityCategory {
332 match self {
333 EntityType::Person => EntityCategory::Agent,
335 EntityType::Organization => EntityCategory::Organization,
337 EntityType::Location => EntityCategory::Place,
339 EntityType::Date | EntityType::Time => EntityCategory::Temporal,
341 EntityType::Money
343 | EntityType::Percent
344 | EntityType::Quantity
345 | EntityType::Cardinal
346 | EntityType::Ordinal => EntityCategory::Numeric,
347 EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
349 EntityType::Custom { category, .. } => *category,
351 EntityType::Other(_) => EntityCategory::Misc,
353 }
354 }
355
356 #[must_use]
358 pub fn requires_ml(&self) -> bool {
359 self.category().requires_ml()
360 }
361
362 #[must_use]
364 pub fn pattern_detectable(&self) -> bool {
365 self.category().pattern_detectable()
366 }
367
368 #[must_use]
377 pub fn as_label(&self) -> &str {
378 match self {
379 EntityType::Person => "PER",
380 EntityType::Organization => "ORG",
381 EntityType::Location => "LOC",
382 EntityType::Date => "DATE",
383 EntityType::Time => "TIME",
384 EntityType::Money => "MONEY",
385 EntityType::Percent => "PERCENT",
386 EntityType::Quantity => "QUANTITY",
387 EntityType::Cardinal => "CARDINAL",
388 EntityType::Ordinal => "ORDINAL",
389 EntityType::Email => "EMAIL",
390 EntityType::Url => "URL",
391 EntityType::Phone => "PHONE",
392 EntityType::Custom { name, .. } => name.as_str(),
393 EntityType::Other(s) => s.as_str(),
394 }
395 }
396
397 #[must_use]
409 pub fn from_label(label: &str) -> Self {
410 let label = label
412 .strip_prefix("B-")
413 .or_else(|| label.strip_prefix("I-"))
414 .or_else(|| label.strip_prefix("E-"))
415 .or_else(|| label.strip_prefix("S-"))
416 .unwrap_or(label);
417
418 match label.to_uppercase().as_str() {
419 "PER" | "PERSON" => EntityType::Person,
421 "ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
422 "LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
423 "FACILITY" | "FAC" | "BUILDING" => {
425 EntityType::custom("BUILDING", EntityCategory::Place)
426 }
427 "PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
428 "EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
429 "CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
430 EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
431 }
432 "GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
433 "DATE" => EntityType::Date,
435 "TIME" => EntityType::Time,
436 "MONEY" | "CURRENCY" => EntityType::Money,
438 "PERCENT" | "PERCENTAGE" => EntityType::Percent,
439 "QUANTITY" => EntityType::Quantity,
440 "CARDINAL" => EntityType::Cardinal,
441 "ORDINAL" => EntityType::Ordinal,
442 "EMAIL" => EntityType::Email,
444 "URL" | "URI" => EntityType::Url,
445 "PHONE" | "TELEPHONE" => EntityType::Phone,
446 "MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::Other("MISC".to_string()),
448 "DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
450 "CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
451 "GENE" => EntityType::custom("GENE", EntityCategory::Misc),
452 "PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
453 other => EntityType::Other(other.to_string()),
455 }
456 }
457
458 #[must_use]
472 pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
473 EntityType::Custom {
474 name: name.into(),
475 category,
476 }
477 }
478}
479
480impl std::fmt::Display for EntityType {
481 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482 write!(f, "{}", self.as_label())
483 }
484}
485
486impl std::str::FromStr for EntityType {
487 type Err = std::convert::Infallible;
488
489 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
491 Ok(Self::from_label(s))
492 }
493}
494
/// Lookup table from dataset-specific label strings to [`EntityType`] values.
///
/// Keys are stored uppercased (see `add`), and lookups uppercase the query,
/// so matching is case-insensitive.
#[derive(Debug, Clone, Default)]
pub struct TypeMapper {
    /// Uppercased source label -> target entity type.
    mappings: std::collections::HashMap<String, EntityType>,
}
538
539impl TypeMapper {
540 #[must_use]
542 pub fn new() -> Self {
543 Self::default()
544 }
545
546 #[must_use]
548 pub fn mit_movie() -> Self {
549 let mut mapper = Self::new();
550 mapper.add("ACTOR", EntityType::Person);
552 mapper.add("DIRECTOR", EntityType::Person);
553 mapper.add("CHARACTER", EntityType::Person);
554 mapper.add(
555 "TITLE",
556 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
557 );
558 mapper.add("GENRE", EntityType::custom("GENRE", EntityCategory::Misc));
559 mapper.add("YEAR", EntityType::Date);
560 mapper.add("RATING", EntityType::custom("RATING", EntityCategory::Misc));
561 mapper.add("PLOT", EntityType::custom("PLOT", EntityCategory::Misc));
562 mapper
563 }
564
565 #[must_use]
567 pub fn mit_restaurant() -> Self {
568 let mut mapper = Self::new();
569 mapper.add("RESTAURANT_NAME", EntityType::Organization);
570 mapper.add("LOCATION", EntityType::Location);
571 mapper.add(
572 "CUISINE",
573 EntityType::custom("CUISINE", EntityCategory::Misc),
574 );
575 mapper.add("DISH", EntityType::custom("DISH", EntityCategory::Misc));
576 mapper.add("PRICE", EntityType::Money);
577 mapper.add(
578 "AMENITY",
579 EntityType::custom("AMENITY", EntityCategory::Misc),
580 );
581 mapper.add("HOURS", EntityType::Time);
582 mapper
583 }
584
585 #[must_use]
587 pub fn biomedical() -> Self {
588 let mut mapper = Self::new();
589 mapper.add(
590 "DISEASE",
591 EntityType::custom("DISEASE", EntityCategory::Agent),
592 );
593 mapper.add(
594 "CHEMICAL",
595 EntityType::custom("CHEMICAL", EntityCategory::Misc),
596 );
597 mapper.add("DRUG", EntityType::custom("DRUG", EntityCategory::Misc));
598 mapper.add("GENE", EntityType::custom("GENE", EntityCategory::Misc));
599 mapper.add(
600 "PROTEIN",
601 EntityType::custom("PROTEIN", EntityCategory::Misc),
602 );
603 mapper.add("DNA", EntityType::custom("DNA", EntityCategory::Misc));
605 mapper.add("RNA", EntityType::custom("RNA", EntityCategory::Misc));
606 mapper.add(
607 "cell_line",
608 EntityType::custom("CELL_LINE", EntityCategory::Misc),
609 );
610 mapper.add(
611 "cell_type",
612 EntityType::custom("CELL_TYPE", EntityCategory::Misc),
613 );
614 mapper
615 }
616
617 #[must_use]
619 pub fn social_media() -> Self {
620 let mut mapper = Self::new();
621 mapper.add("person", EntityType::Person);
623 mapper.add("corporation", EntityType::Organization);
624 mapper.add("location", EntityType::Location);
625 mapper.add("group", EntityType::Organization);
626 mapper.add(
627 "product",
628 EntityType::custom("PRODUCT", EntityCategory::Misc),
629 );
630 mapper.add(
631 "creative_work",
632 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
633 );
634 mapper.add("event", EntityType::custom("EVENT", EntityCategory::Misc));
635 mapper
636 }
637
638 #[must_use]
640 pub fn manufacturing() -> Self {
641 let mut mapper = Self::new();
642 mapper.add("MATE", EntityType::custom("MATERIAL", EntityCategory::Misc));
644 mapper.add("MANP", EntityType::custom("PROCESS", EntityCategory::Misc));
645 mapper.add("MACEQ", EntityType::custom("MACHINE", EntityCategory::Misc));
646 mapper.add(
647 "APPL",
648 EntityType::custom("APPLICATION", EntityCategory::Misc),
649 );
650 mapper.add("FEAT", EntityType::custom("FEATURE", EntityCategory::Misc));
651 mapper.add(
652 "PARA",
653 EntityType::custom("PARAMETER", EntityCategory::Misc),
654 );
655 mapper.add("PRO", EntityType::custom("PROPERTY", EntityCategory::Misc));
656 mapper.add(
657 "CHAR",
658 EntityType::custom("CHARACTERISTIC", EntityCategory::Misc),
659 );
660 mapper.add(
661 "ENAT",
662 EntityType::custom("ENABLING_TECHNOLOGY", EntityCategory::Misc),
663 );
664 mapper.add(
665 "CONPRI",
666 EntityType::custom("CONCEPT_PRINCIPLE", EntityCategory::Misc),
667 );
668 mapper.add(
669 "BIOP",
670 EntityType::custom("BIO_PROCESS", EntityCategory::Misc),
671 );
672 mapper.add(
673 "MANS",
674 EntityType::custom("MAN_STANDARD", EntityCategory::Misc),
675 );
676 mapper
677 }
678
679 pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
681 self.mappings.insert(source.into().to_uppercase(), target);
682 }
683
684 #[must_use]
686 pub fn map(&self, label: &str) -> Option<&EntityType> {
687 self.mappings.get(&label.to_uppercase())
688 }
689
690 #[must_use]
694 pub fn normalize(&self, label: &str) -> EntityType {
695 self.map(label)
696 .cloned()
697 .unwrap_or_else(|| EntityType::from_label(label))
698 }
699
700 #[must_use]
702 pub fn contains(&self, label: &str) -> bool {
703 self.mappings.contains_key(&label.to_uppercase())
704 }
705
706 pub fn labels(&self) -> impl Iterator<Item = &String> {
708 self.mappings.keys()
709 }
710}
711
/// How an entity was extracted.
///
/// `Neural` is the default. Several variants are deprecated aliases; their
/// `Display` output and confidence interpretation match the variant named in
/// the deprecation note.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExtractionMethod {
    /// Pattern-based extraction; confidence is interpreted as `"binary"`.
    Pattern,

    /// Neural extraction (the default); confidence is interpreted as a
    /// `"probability"` and reported as calibrated.
    #[default]
    Neural,

    /// Deprecated lexicon lookup; confidence is interpreted as `"binary"`.
    #[deprecated(since = "0.2.0", note = "Use Neural or GatedEnsemble instead")]
    Lexicon,

    /// Soft-lexicon extraction; reported as calibrated.
    SoftLexicon,

    /// Gated-ensemble extraction; reported as calibrated.
    GatedEnsemble,

    /// Consensus of several extractors; confidence is an `"agreement_ratio"`.
    Consensus,

    /// Heuristic extraction; confidence is a `"heuristic_score"`.
    Heuristic,

    /// Extraction method not known.
    Unknown,

    /// Deprecated; displayed as `"heuristic"`.
    #[deprecated(since = "0.2.0", note = "Use Heuristic or Pattern instead")]
    Rule,

    /// Deprecated; displayed as `"neural"` and reported as calibrated.
    #[deprecated(since = "0.2.0", note = "Use Neural instead")]
    ML,

    /// Deprecated; displayed as `"consensus"`.
    #[deprecated(since = "0.2.0", note = "Use Consensus instead")]
    Ensemble,
}
777
778impl ExtractionMethod {
779 #[must_use]
806 pub const fn is_calibrated(&self) -> bool {
807 #[allow(deprecated)]
808 match self {
809 ExtractionMethod::Neural => true,
810 ExtractionMethod::GatedEnsemble => true,
811 ExtractionMethod::SoftLexicon => true,
812 ExtractionMethod::ML => true, ExtractionMethod::Pattern => false,
815 ExtractionMethod::Lexicon => false,
816 ExtractionMethod::Consensus => false,
817 ExtractionMethod::Heuristic => false,
818 ExtractionMethod::Unknown => false,
819 ExtractionMethod::Rule => false,
820 ExtractionMethod::Ensemble => false,
821 }
822 }
823
824 #[must_use]
831 pub const fn confidence_interpretation(&self) -> &'static str {
832 #[allow(deprecated)]
833 match self {
834 ExtractionMethod::Neural | ExtractionMethod::ML => "probability",
835 ExtractionMethod::GatedEnsemble | ExtractionMethod::SoftLexicon => "probability",
836 ExtractionMethod::Pattern | ExtractionMethod::Lexicon => "binary",
837 ExtractionMethod::Heuristic | ExtractionMethod::Rule => "heuristic_score",
838 ExtractionMethod::Consensus | ExtractionMethod::Ensemble => "agreement_ratio",
839 ExtractionMethod::Unknown => "unknown",
840 }
841 }
842}
843
844impl std::fmt::Display for ExtractionMethod {
845 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
846 #[allow(deprecated)]
847 match self {
848 ExtractionMethod::Pattern => write!(f, "pattern"),
849 ExtractionMethod::Neural => write!(f, "neural"),
850 ExtractionMethod::Lexicon => write!(f, "lexicon"),
851 ExtractionMethod::SoftLexicon => write!(f, "soft_lexicon"),
852 ExtractionMethod::GatedEnsemble => write!(f, "gated_ensemble"),
853 ExtractionMethod::Consensus => write!(f, "consensus"),
854 ExtractionMethod::Heuristic => write!(f, "heuristic"),
855 ExtractionMethod::Unknown => write!(f, "unknown"),
856 ExtractionMethod::Rule => write!(f, "heuristic"), ExtractionMethod::ML => write!(f, "neural"), ExtractionMethod::Ensemble => write!(f, "consensus"), }
860 }
861}
862
/// A lookup source that maps text to an entity type and a confidence value.
///
/// Implementors must be `Send + Sync`.
pub trait Lexicon: Send + Sync {
    /// Looks up `text`, returning its entity type and confidence if present.
    fn lookup(&self, text: &str) -> Option<(EntityType, f64)>;

    /// Returns `true` if `lookup(text)` finds an entry.
    fn contains(&self, text: &str) -> bool {
        self.lookup(text).is_some()
    }

    /// Returns an identifier describing where this lexicon came from.
    fn source(&self) -> &str;

    /// Returns the number of entries in the lexicon.
    fn len(&self) -> usize;

    /// Returns `true` if the lexicon has no entries.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
938
/// [`Lexicon`] implementation backed by a `HashMap`.
///
/// Keys are stored and looked up exactly as given (no case folding).
#[derive(Debug, Clone)]
pub struct HashMapLexicon {
    /// Entry text -> (entity type, confidence).
    entries: std::collections::HashMap<String, (EntityType, f64)>,
    /// Identifier returned by `Lexicon::source`.
    source: String,
}
948
949impl HashMapLexicon {
950 #[must_use]
952 pub fn new(source: impl Into<String>) -> Self {
953 Self {
954 entries: std::collections::HashMap::new(),
955 source: source.into(),
956 }
957 }
958
959 pub fn insert(&mut self, text: impl Into<String>, entity_type: EntityType, confidence: f64) {
961 self.entries.insert(text.into(), (entity_type, confidence));
962 }
963
964 pub fn from_iter<I, S>(source: impl Into<String>, entries: I) -> Self
966 where
967 I: IntoIterator<Item = (S, EntityType, f64)>,
968 S: Into<String>,
969 {
970 let mut lexicon = Self::new(source);
971 for (text, entity_type, confidence) in entries {
972 lexicon.insert(text, entity_type, confidence);
973 }
974 lexicon
975 }
976
977 pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, f64)> {
979 self.entries.iter().map(|(k, (t, c))| (k.as_str(), t, *c))
980 }
981}
982
983impl Lexicon for HashMapLexicon {
984 fn lookup(&self, text: &str) -> Option<(EntityType, f64)> {
985 self.entries.get(text).cloned()
986 }
987
988 fn source(&self) -> &str {
989 &self.source
990 }
991
992 fn len(&self) -> usize {
993 self.entries.len()
994 }
995}
996
/// Record of where an extracted entity came from.
///
/// `model_version` and `timestamp` are omitted from serialized output when
/// they are `None`, and default to `None` when absent on input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Provenance {
    /// Name of the producing source (e.g. a model name, or `"pattern"`).
    pub source: Cow<'static, str>,
    /// Extraction method that produced the entity.
    pub method: ExtractionMethod,
    /// Name of the pattern that matched, if any.
    pub pattern: Option<Cow<'static, str>>,
    /// Confidence as reported by the source, if any.
    pub raw_confidence: Option<f64>,
    /// Version string of the producing model, if recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<Cow<'static, str>>,
    /// Timestamp string, if recorded; the format is not constrained here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<String>,
}
1018
1019impl Provenance {
1020 #[must_use]
1022 pub fn pattern(pattern_name: &'static str) -> Self {
1023 Self {
1024 source: Cow::Borrowed("pattern"),
1025 method: ExtractionMethod::Pattern,
1026 pattern: Some(Cow::Borrowed(pattern_name)),
1027 raw_confidence: Some(1.0), model_version: None,
1029 timestamp: None,
1030 }
1031 }
1032
1033 #[must_use]
1047 pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: f64) -> Self {
1048 Self {
1049 source: model_name.into(),
1050 method: ExtractionMethod::Neural,
1051 pattern: None,
1052 raw_confidence: Some(confidence),
1053 model_version: None,
1054 timestamp: None,
1055 }
1056 }
1057
1058 #[deprecated(
1060 since = "0.2.1",
1061 note = "Use ml() instead, it now accepts owned strings"
1062 )]
1063 #[must_use]
1064 pub fn ml_owned(model_name: impl Into<String>, confidence: f64) -> Self {
1065 Self::ml(Cow::Owned(model_name.into()), confidence)
1066 }
1067
1068 #[must_use]
1070 pub fn ensemble(sources: &'static str) -> Self {
1071 Self {
1072 source: Cow::Borrowed(sources),
1073 method: ExtractionMethod::Consensus,
1074 pattern: None,
1075 raw_confidence: None,
1076 model_version: None,
1077 timestamp: None,
1078 }
1079 }
1080
1081 #[must_use]
1083 pub fn with_version(mut self, version: &'static str) -> Self {
1084 self.model_version = Some(Cow::Borrowed(version));
1085 self
1086 }
1087
1088 #[must_use]
1090 pub fn with_timestamp(mut self, timestamp: impl Into<String>) -> Self {
1091 self.timestamp = Some(timestamp.into());
1092 self
1093 }
1094}
1095
/// Location of an entity: a text range, a bounding box, or both.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Span {
    /// Text range from `start` to `end`; `len()` is `end - start`, saturating
    /// at zero.
    Text {
        /// Start offset.
        start: usize,
        /// End offset.
        end: usize,
    },
    /// Rectangular region, optionally tied to a page.
    BoundingBox {
        /// X coordinate of the box.
        x: f32,
        /// Y coordinate of the box.
        y: f32,
        /// Width of the box.
        width: f32,
        /// Height of the box.
        height: f32,
        /// Page number, when known.
        page: Option<u32>,
    },
    /// Text range paired with a visual span; counts as both text and visual.
    Hybrid {
        /// Start offset of the text range.
        start: usize,
        /// End offset of the text range.
        end: usize,
        /// Associated visual span (the type permits any `Span` here).
        bbox: Box<Span>,
    },
}
1162
1163impl Span {
1164 #[must_use]
1166 pub const fn text(start: usize, end: usize) -> Self {
1167 Self::Text { start, end }
1168 }
1169
1170 #[must_use]
1172 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
1173 Self::BoundingBox {
1174 x,
1175 y,
1176 width,
1177 height,
1178 page: None,
1179 }
1180 }
1181
1182 #[must_use]
1184 pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
1185 Self::BoundingBox {
1186 x,
1187 y,
1188 width,
1189 height,
1190 page: Some(page),
1191 }
1192 }
1193
1194 #[must_use]
1196 pub const fn is_text(&self) -> bool {
1197 matches!(self, Self::Text { .. } | Self::Hybrid { .. })
1198 }
1199
1200 #[must_use]
1202 pub const fn is_visual(&self) -> bool {
1203 matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
1204 }
1205
1206 #[must_use]
1208 pub const fn text_offsets(&self) -> Option<(usize, usize)> {
1209 match self {
1210 Self::Text { start, end } => Some((*start, *end)),
1211 Self::Hybrid { start, end, .. } => Some((*start, *end)),
1212 Self::BoundingBox { .. } => None,
1213 }
1214 }
1215
1216 #[must_use]
1218 pub fn len(&self) -> usize {
1219 match self {
1220 Self::Text { start, end } => end.saturating_sub(*start),
1221 Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
1222 Self::BoundingBox { .. } => 0,
1223 }
1224 }
1225
1226 #[must_use]
1228 pub fn is_empty(&self) -> bool {
1229 self.len() == 0
1230 }
1231}
1232
/// A span made of one or more text ranges ("segments").
///
/// `DiscontinuousSpan::new` sorts segments by start offset. Values obtained
/// through deserialization do not pass through `new`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscontinuousSpan {
    /// The segments of this span.
    segments: Vec<std::ops::Range<usize>>,
}
1279
1280impl DiscontinuousSpan {
1281 #[must_use]
1285 pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
1286 segments.sort_by_key(|r| r.start);
1288 Self { segments }
1289 }
1290
1291 #[must_use]
1293 #[allow(clippy::single_range_in_vec_init)] pub fn contiguous(start: usize, end: usize) -> Self {
1295 Self {
1296 segments: vec![start..end],
1297 }
1298 }
1299
1300 #[must_use]
1302 pub fn num_segments(&self) -> usize {
1303 self.segments.len()
1304 }
1305
1306 #[must_use]
1308 pub fn is_discontinuous(&self) -> bool {
1309 self.segments.len() > 1
1310 }
1311
1312 #[must_use]
1314 pub fn is_contiguous(&self) -> bool {
1315 self.segments.len() <= 1
1316 }
1317
1318 #[must_use]
1320 pub fn segments(&self) -> &[std::ops::Range<usize>] {
1321 &self.segments
1322 }
1323
1324 #[must_use]
1326 pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
1327 if self.segments.is_empty() {
1328 return None;
1329 }
1330 let start = self.segments.first()?.start;
1331 let end = self.segments.last()?.end;
1332 Some(start..end)
1333 }
1334
1335 #[must_use]
1338 pub fn total_len(&self) -> usize {
1339 self.segments.iter().map(|r| r.end - r.start).sum()
1340 }
1341
1342 #[must_use]
1344 pub fn extract_text(&self, text: &str, separator: &str) -> String {
1345 self.segments
1346 .iter()
1347 .map(|r| {
1348 let start = r.start;
1349 let len = r.end.saturating_sub(r.start);
1350 text.chars().skip(start).take(len).collect::<String>()
1351 })
1352 .collect::<Vec<_>>()
1353 .join(separator)
1354 }
1355
1356 #[must_use]
1366 pub fn contains(&self, pos: usize) -> bool {
1367 self.segments.iter().any(|r| r.contains(&pos))
1368 }
1369
1370 #[must_use]
1372 pub fn to_span(&self) -> Option<Span> {
1373 self.bounding_range().map(|r| Span::Text {
1374 start: r.start,
1375 end: r.end,
1376 })
1377 }
1378}
1379
1380impl From<std::ops::Range<usize>> for DiscontinuousSpan {
1381 fn from(range: std::ops::Range<usize>) -> Self {
1382 Self::contiguous(range.start, range.end)
1383 }
1384}
1385
1386impl Default for Span {
1387 fn default() -> Self {
1388 Self::Text { start: 0, end: 0 }
1389 }
1390}
1391
/// Confidence split into three component scores.
///
/// The constructors `new` and `from_single` clamp each score to `[0.0, 1.0]`;
/// the fields are public, so values set directly are not clamped.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HierarchicalConfidence {
    /// Linkage score.
    pub linkage: f32,
    /// Type score.
    pub type_score: f32,
    /// Boundary score.
    pub boundary: f32,
}
1414
1415impl HierarchicalConfidence {
1416 #[must_use]
1418 pub fn new(linkage: f32, type_score: f32, boundary: f32) -> Self {
1419 Self {
1420 linkage: linkage.clamp(0.0, 1.0),
1421 type_score: type_score.clamp(0.0, 1.0),
1422 boundary: boundary.clamp(0.0, 1.0),
1423 }
1424 }
1425
1426 #[must_use]
1429 pub fn from_single(confidence: f32) -> Self {
1430 let c = confidence.clamp(0.0, 1.0);
1431 Self {
1432 linkage: c,
1433 type_score: c,
1434 boundary: c,
1435 }
1436 }
1437
1438 #[must_use]
1441 pub fn combined(&self) -> f32 {
1442 (self.linkage * self.type_score * self.boundary).powf(1.0 / 3.0)
1443 }
1444
1445 #[must_use]
1447 pub fn as_f64(&self) -> f64 {
1448 self.combined() as f64
1449 }
1450
1451 #[must_use]
1453 pub fn passes_threshold(&self, linkage_min: f32, type_min: f32, boundary_min: f32) -> bool {
1454 self.linkage >= linkage_min && self.type_score >= type_min && self.boundary >= boundary_min
1455 }
1456}
1457
1458impl Default for HierarchicalConfidence {
1459 fn default() -> Self {
1460 Self {
1461 linkage: 1.0,
1462 type_score: 1.0,
1463 boundary: 1.0,
1464 }
1465 }
1466}
1467
1468impl From<f64> for HierarchicalConfidence {
1469 fn from(confidence: f64) -> Self {
1470 Self::from_single(confidence as f32)
1471 }
1472}
1473
1474impl From<f32> for HierarchicalConfidence {
1475 fn from(confidence: f32) -> Self {
1476 Self::from_single(confidence)
1477 }
1478}
1479
/// Batch of variable-length token sequences stored back to back, without
/// padding.
#[derive(Debug, Clone)]
pub struct RaggedBatch {
    /// All token ids, concatenated in sequence order.
    pub token_ids: Vec<u32>,
    /// Offsets into `token_ids`: entry `i` is where sequence `i` starts and
    /// entry `i + 1` is where it ends, so there is one more entry than there
    /// are sequences.
    pub cumulative_offsets: Vec<u32>,
    /// Length of the longest sequence in the batch.
    pub max_seq_len: usize,
}
1513
1514impl RaggedBatch {
1515 pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
1517 let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
1518 let mut token_ids = Vec::with_capacity(total_tokens);
1519 let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
1520 let mut max_seq_len = 0;
1521
1522 cumulative_offsets.push(0);
1523 for seq in sequences {
1524 token_ids.extend_from_slice(seq);
1525 let len = token_ids.len();
1529 if len > u32::MAX as usize {
1530 log::warn!(
1533 "Token count {} exceeds u32::MAX, truncating to {}",
1534 len,
1535 u32::MAX
1536 );
1537 cumulative_offsets.push(u32::MAX);
1538 } else {
1539 cumulative_offsets.push(len as u32);
1540 }
1541 max_seq_len = max_seq_len.max(seq.len());
1542 }
1543
1544 Self {
1545 token_ids,
1546 cumulative_offsets,
1547 max_seq_len,
1548 }
1549 }
1550
1551 #[must_use]
1553 pub fn batch_size(&self) -> usize {
1554 self.cumulative_offsets.len().saturating_sub(1)
1555 }
1556
1557 #[must_use]
1559 pub fn total_tokens(&self) -> usize {
1560 self.token_ids.len()
1561 }
1562
1563 #[must_use]
1565 pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
1566 if doc_idx + 1 < self.cumulative_offsets.len() {
1567 let start = self.cumulative_offsets[doc_idx] as usize;
1568 let end = self.cumulative_offsets[doc_idx + 1] as usize;
1569 Some(start..end)
1570 } else {
1571 None
1572 }
1573 }
1574
1575 #[must_use]
1577 pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
1578 self.doc_range(doc_idx).map(|r| &self.token_ids[r])
1579 }
1580
1581 #[must_use]
1583 pub fn padding_savings(&self) -> f64 {
1584 let padded_size = self.batch_size() * self.max_seq_len;
1585 if padded_size == 0 {
1586 return 0.0;
1587 }
1588 1.0 - (self.total_tokens() as f64 / padded_size as f64)
1589 }
1590}
1591
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
    /// Index of the document within the batch.
    pub doc_idx: u32,
    /// Start position within the document.
    pub start: u32,
    /// End position within the document; the generators in this file treat it
    /// as exclusive.
    pub end: u32,
}
1609
1610impl SpanCandidate {
1611 #[must_use]
1613 pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
1614 Self {
1615 doc_idx,
1616 start,
1617 end,
1618 }
1619 }
1620
1621 #[must_use]
1623 pub const fn width(&self) -> u32 {
1624 self.end.saturating_sub(self.start)
1625 }
1626}
1627
1628pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
1633 let mut candidates = Vec::new();
1634
1635 for doc_idx in 0..batch.batch_size() {
1636 if let Some(range) = batch.doc_range(doc_idx) {
1637 let doc_len = range.len();
1638 for start in 0..doc_len {
1640 let max_end = (start + max_width).min(doc_len);
1641 for end in (start + 1)..=max_end {
1642 candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
1643 }
1644 }
1645 }
1646 }
1647
1648 candidates
1649}
1650
1651pub fn generate_filtered_candidates(
1655 batch: &RaggedBatch,
1656 max_width: usize,
1657 linkage_mask: &[f32],
1658 threshold: f32,
1659) -> Vec<SpanCandidate> {
1660 let mut candidates = Vec::new();
1661 let mut mask_idx = 0;
1662
1663 for doc_idx in 0..batch.batch_size() {
1664 if let Some(range) = batch.doc_range(doc_idx) {
1665 let doc_len = range.len();
1666 for start in 0..doc_len {
1667 let max_end = (start + max_width).min(doc_len);
1668 for end in (start + 1)..=max_end {
1669 if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
1671 candidates.push(SpanCandidate::new(
1672 doc_idx as u32,
1673 start as u32,
1674 end as u32,
1675 ));
1676 }
1677 mask_idx += 1;
1678 }
1679 }
1680 }
1681 }
1682
1683 candidates
1684}
1685
/// An extracted entity.
///
/// All `Option` fields are omitted from serialized output when `None` and
/// default to `None` when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Surface text of the entity.
    pub text: String,
    /// Type of the entity.
    pub entity_type: EntityType,
    /// Start offset of the entity.
    // NOTE(review): this chunk does not establish whether offsets are byte or
    // character based — confirm against callers.
    pub start: usize,
    /// End offset of the entity.
    pub end: usize,
    /// Overall confidence; the constructors taking an `f64` confidence clamp
    /// it to `[0.0, 1.0]`.
    pub confidence: f64,
    /// Normalized form of the text, if one was set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<String>,
    /// Provenance record, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provenance: Option<Provenance>,
    /// Knowledge-base identifier; `is_linked` reports whether it is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kb_id: Option<String>,
    /// Canonical identifier; `has_coreference` reports whether it is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub canonical_id: Option<super::types::CanonicalId>,
    /// Component-wise confidence, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hierarchical_confidence: Option<HierarchicalConfidence>,
    /// Visual location of the entity, if available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub visual_span: Option<Span>,
    /// Segment list for entities made of several text ranges.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub discontinuous_span: Option<DiscontinuousSpan>,
    /// Optional UTC timestamp; presumably the start of the period during
    /// which the entity is valid — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional UTC timestamp; presumably the end of the period during which
    /// the entity is valid — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
    /// Viewport (context) associated with the entity, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub viewport: Option<EntityViewport>,
}
1834
1835impl Entity {
1836 #[must_use]
1847 pub fn new(
1848 text: impl Into<String>,
1849 entity_type: EntityType,
1850 start: usize,
1851 end: usize,
1852 confidence: f64,
1853 ) -> Self {
1854 Self {
1855 text: text.into(),
1856 entity_type,
1857 start,
1858 end,
1859 confidence: confidence.clamp(0.0, 1.0),
1860 normalized: None,
1861 provenance: None,
1862 kb_id: None,
1863 canonical_id: None,
1864 hierarchical_confidence: None,
1865 visual_span: None,
1866 discontinuous_span: None,
1867 valid_from: None,
1868 valid_until: None,
1869 viewport: None,
1870 }
1871 }
1872
1873 #[must_use]
1875 pub fn with_provenance(
1876 text: impl Into<String>,
1877 entity_type: EntityType,
1878 start: usize,
1879 end: usize,
1880 confidence: f64,
1881 provenance: Provenance,
1882 ) -> Self {
1883 Self {
1884 text: text.into(),
1885 entity_type,
1886 start,
1887 end,
1888 confidence: confidence.clamp(0.0, 1.0),
1889 normalized: None,
1890 provenance: Some(provenance),
1891 kb_id: None,
1892 canonical_id: None,
1893 hierarchical_confidence: None,
1894 visual_span: None,
1895 discontinuous_span: None,
1896 valid_from: None,
1897 valid_until: None,
1898 viewport: None,
1899 }
1900 }
1901
1902 #[must_use]
1904 pub fn with_hierarchical_confidence(
1905 text: impl Into<String>,
1906 entity_type: EntityType,
1907 start: usize,
1908 end: usize,
1909 confidence: HierarchicalConfidence,
1910 ) -> Self {
1911 Self {
1912 text: text.into(),
1913 entity_type,
1914 start,
1915 end,
1916 confidence: confidence.as_f64(),
1917 normalized: None,
1918 provenance: None,
1919 kb_id: None,
1920 canonical_id: None,
1921 hierarchical_confidence: Some(confidence),
1922 visual_span: None,
1923 discontinuous_span: None,
1924 valid_from: None,
1925 valid_until: None,
1926 viewport: None,
1927 }
1928 }
1929
1930 #[must_use]
1932 pub fn from_visual(
1933 text: impl Into<String>,
1934 entity_type: EntityType,
1935 bbox: Span,
1936 confidence: f64,
1937 ) -> Self {
1938 Self {
1939 text: text.into(),
1940 entity_type,
1941 start: 0,
1942 end: 0,
1943 confidence: confidence.clamp(0.0, 1.0),
1944 normalized: None,
1945 provenance: None,
1946 kb_id: None,
1947 canonical_id: None,
1948 hierarchical_confidence: None,
1949 visual_span: Some(bbox),
1950 discontinuous_span: None,
1951 valid_from: None,
1952 valid_until: None,
1953 viewport: None,
1954 }
1955 }
1956
1957 #[must_use]
1959 pub fn with_type(
1960 text: impl Into<String>,
1961 entity_type: EntityType,
1962 start: usize,
1963 end: usize,
1964 ) -> Self {
1965 Self::new(text, entity_type, start, end, 1.0)
1966 }
1967
1968 pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
1978 self.kb_id = Some(kb_id.into());
1979 }
1980
1981 pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
1985 self.canonical_id = Some(canonical_id.into());
1986 }
1987
1988 #[must_use]
1998 pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
1999 self.canonical_id = Some(canonical_id.into());
2000 self
2001 }
2002
2003 #[must_use]
2005 pub fn is_linked(&self) -> bool {
2006 self.kb_id.is_some()
2007 }
2008
2009 #[must_use]
2011 pub fn has_coreference(&self) -> bool {
2012 self.canonical_id.is_some()
2013 }
2014
2015 #[must_use]
2021 pub fn is_discontinuous(&self) -> bool {
2022 self.discontinuous_span
2023 .as_ref()
2024 .map(|s| s.is_discontinuous())
2025 .unwrap_or(false)
2026 }
2027
2028 #[must_use]
2032 pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
2033 self.discontinuous_span
2034 .as_ref()
2035 .filter(|s| s.is_discontinuous())
2036 .map(|s| s.segments().to_vec())
2037 }
2038
2039 pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
2043 if let Some(bounding) = span.bounding_range() {
2045 self.start = bounding.start;
2046 self.end = bounding.end;
2047 }
2048 self.discontinuous_span = Some(span);
2049 }
2050
2051 #[must_use]
2059 pub fn total_len(&self) -> usize {
2060 if let Some(ref span) = self.discontinuous_span {
2061 span.segments().iter().map(|r| r.end - r.start).sum()
2062 } else {
2063 self.end.saturating_sub(self.start)
2064 }
2065 }
2066
2067 pub fn set_normalized(&mut self, normalized: impl Into<String>) {
2079 self.normalized = Some(normalized.into());
2080 }
2081
2082 #[must_use]
2084 pub fn normalized_or_text(&self) -> &str {
2085 self.normalized.as_deref().unwrap_or(&self.text)
2086 }
2087
2088 #[must_use]
2090 pub fn method(&self) -> ExtractionMethod {
2091 self.provenance
2092 .as_ref()
2093 .map_or(ExtractionMethod::Unknown, |p| p.method)
2094 }
2095
2096 #[must_use]
2098 pub fn source(&self) -> Option<&str> {
2099 self.provenance.as_ref().map(|p| p.source.as_ref())
2100 }
2101
2102 #[must_use]
2104 pub fn category(&self) -> EntityCategory {
2105 self.entity_type.category()
2106 }
2107
2108 #[must_use]
2110 pub fn is_structured(&self) -> bool {
2111 self.entity_type.pattern_detectable()
2112 }
2113
2114 #[must_use]
2116 pub fn is_named(&self) -> bool {
2117 self.entity_type.requires_ml()
2118 }
2119
2120 #[must_use]
2122 pub fn overlaps(&self, other: &Entity) -> bool {
2123 !(self.end <= other.start || other.end <= self.start)
2124 }
2125
2126 #[must_use]
2128 pub fn overlap_ratio(&self, other: &Entity) -> f64 {
2129 let intersection_start = self.start.max(other.start);
2130 let intersection_end = self.end.min(other.end);
2131
2132 if intersection_start >= intersection_end {
2133 return 0.0;
2134 }
2135
2136 let intersection = (intersection_end - intersection_start) as f64;
2137 let union = ((self.end - self.start) + (other.end - other.start)
2138 - (intersection_end - intersection_start)) as f64;
2139
2140 if union == 0.0 {
2141 return 1.0;
2142 }
2143
2144 intersection / union
2145 }
2146
2147 pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
2149 self.confidence = confidence.as_f64();
2150 self.hierarchical_confidence = Some(confidence);
2151 }
2152
2153 #[must_use]
2155 pub fn linkage_confidence(&self) -> f32 {
2156 self.hierarchical_confidence
2157 .map_or(self.confidence as f32, |h| h.linkage)
2158 }
2159
2160 #[must_use]
2162 pub fn type_confidence(&self) -> f32 {
2163 self.hierarchical_confidence
2164 .map_or(self.confidence as f32, |h| h.type_score)
2165 }
2166
2167 #[must_use]
2169 pub fn boundary_confidence(&self) -> f32 {
2170 self.hierarchical_confidence
2171 .map_or(self.confidence as f32, |h| h.boundary)
2172 }
2173
2174 #[must_use]
2176 pub fn is_visual(&self) -> bool {
2177 self.visual_span.is_some()
2178 }
2179
2180 #[must_use]
2182 pub const fn text_span(&self) -> (usize, usize) {
2183 (self.start, self.end)
2184 }
2185
2186 #[must_use]
2188 pub const fn span_len(&self) -> usize {
2189 self.end.saturating_sub(self.start)
2190 }
2191
    /// Unimplemented placeholder, hidden from the generated docs.
    ///
    /// The `_source_text` argument is ignored. The panic message directs
    /// callers to the `anno::offset` utilities instead.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub fn to_text_span(&self, _source_text: &str) -> serde_json::Value {
        unimplemented!("Use anno::offset utilities directly - see method docs")
    }
2220
2221 pub fn set_visual_span(&mut self, span: Span) {
2223 self.visual_span = Some(span);
2224 }
2225
2226 #[must_use]
2246 pub fn extract_text(&self, source_text: &str) -> String {
2247 let char_count = source_text.chars().count();
2251 self.extract_text_with_len(source_text, char_count)
2252 }
2253
2254 #[must_use]
2266 pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
2267 if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
2268 return String::new();
2269 }
2270 source_text
2271 .chars()
2272 .skip(self.start)
2273 .take(self.end - self.start)
2274 .collect()
2275 }
2276
2277 pub fn set_valid_from(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2293 self.valid_from = Some(dt);
2294 }
2295
2296 pub fn set_valid_until(&mut self, dt: chrono::DateTime<chrono::Utc>) {
2298 self.valid_until = Some(dt);
2299 }
2300
2301 pub fn set_temporal_range(
2303 &mut self,
2304 from: chrono::DateTime<chrono::Utc>,
2305 until: chrono::DateTime<chrono::Utc>,
2306 ) {
2307 self.valid_from = Some(from);
2308 self.valid_until = Some(until);
2309 }
2310
2311 #[must_use]
2313 pub fn is_temporal(&self) -> bool {
2314 self.valid_from.is_some() || self.valid_until.is_some()
2315 }
2316
2317 #[must_use]
2339 pub fn valid_at(&self, timestamp: &chrono::DateTime<chrono::Utc>) -> bool {
2340 match (&self.valid_from, &self.valid_until) {
2341 (None, None) => true, (Some(from), None) => timestamp >= from, (None, Some(until)) => timestamp <= until, (Some(from), Some(until)) => timestamp >= from && timestamp <= until,
2345 }
2346 }
2347
2348 #[must_use]
2350 pub fn is_currently_valid(&self) -> bool {
2351 self.valid_at(&chrono::Utc::now())
2352 }
2353
2354 pub fn set_viewport(&mut self, viewport: EntityViewport) {
2369 self.viewport = Some(viewport);
2370 }
2371
2372 #[must_use]
2374 pub fn has_viewport(&self) -> bool {
2375 self.viewport.is_some()
2376 }
2377
2378 #[must_use]
2380 pub fn viewport_or_default(&self) -> EntityViewport {
2381 self.viewport.clone().unwrap_or_default()
2382 }
2383
2384 #[must_use]
2390 pub fn matches_viewport(&self, query_viewport: &EntityViewport) -> bool {
2391 match &self.viewport {
2392 None => true, Some(v) => v == query_viewport,
2394 }
2395 }
2396
2397 #[must_use]
2399 pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
2400 EntityBuilder::new(text, entity_type)
2401 }
2402
2403 #[must_use]
2436 pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
2437 let char_count = source_text.chars().count();
2439 self.validate_with_len(source_text, char_count)
2440 }
2441
2442 #[must_use]
2454 pub fn validate_with_len(
2455 &self,
2456 source_text: &str,
2457 text_char_count: usize,
2458 ) -> Vec<ValidationIssue> {
2459 let mut issues = Vec::new();
2460
2461 if self.start >= self.end {
2463 issues.push(ValidationIssue::InvalidSpan {
2464 start: self.start,
2465 end: self.end,
2466 reason: "start must be less than end".to_string(),
2467 });
2468 }
2469
2470 if self.end > text_char_count {
2471 issues.push(ValidationIssue::SpanOutOfBounds {
2472 end: self.end,
2473 text_len: text_char_count,
2474 });
2475 }
2476
2477 if self.start < self.end && self.end <= text_char_count {
2479 let actual = self.extract_text_with_len(source_text, text_char_count);
2480 if actual != self.text {
2481 issues.push(ValidationIssue::TextMismatch {
2482 expected: self.text.clone(),
2483 actual,
2484 start: self.start,
2485 end: self.end,
2486 });
2487 }
2488 }
2489
2490 if !(0.0..=1.0).contains(&self.confidence) {
2492 issues.push(ValidationIssue::InvalidConfidence {
2493 value: self.confidence,
2494 });
2495 }
2496
2497 if let EntityType::Custom { ref name, .. } = self.entity_type {
2499 if name.is_empty() {
2500 issues.push(ValidationIssue::InvalidType {
2501 reason: "Custom entity type has empty name".to_string(),
2502 });
2503 }
2504 }
2505
2506 if let Some(ref disc_span) = self.discontinuous_span {
2508 for (i, seg) in disc_span.segments().iter().enumerate() {
2509 if seg.start >= seg.end {
2510 issues.push(ValidationIssue::InvalidSpan {
2511 start: seg.start,
2512 end: seg.end,
2513 reason: format!("discontinuous segment {} is invalid", i),
2514 });
2515 }
2516 if seg.end > text_char_count {
2517 issues.push(ValidationIssue::SpanOutOfBounds {
2518 end: seg.end,
2519 text_len: text_char_count,
2520 });
2521 }
2522 }
2523 }
2524
2525 issues
2526 }
2527
2528 #[must_use]
2532 pub fn is_valid(&self, source_text: &str) -> bool {
2533 self.validate(source_text).is_empty()
2534 }
2535
2536 #[must_use]
2556 pub fn validate_batch(
2557 entities: &[Entity],
2558 source_text: &str,
2559 ) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
2560 entities
2561 .iter()
2562 .enumerate()
2563 .filter_map(|(idx, entity)| {
2564 let issues = entity.validate(source_text);
2565 if issues.is_empty() {
2566 None
2567 } else {
2568 Some((idx, issues))
2569 }
2570 })
2571 .collect()
2572 }
2573}
2574
/// A single problem found while validating an entity against its source text.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// A span or discontinuous segment whose `start` is not less than its `end`.
    InvalidSpan {
        /// Start offset of the offending span.
        start: usize,
        /// End offset of the offending span.
        end: usize,
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
    /// A span or segment whose end lies beyond the text length.
    SpanOutOfBounds {
        /// End offset that exceeded the text length.
        end: usize,
        /// Character count of the text it was checked against.
        text_len: usize,
    },
    /// The text extracted at the span differs from the entity's stored text.
    TextMismatch {
        /// Text stored on the entity.
        expected: String,
        /// Text actually found in the source at `[start, end)`.
        actual: String,
        /// Start offset of the compared span.
        start: usize,
        /// End offset of the compared span.
        end: usize,
    },
    /// A confidence value outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// The offending confidence value.
        value: f64,
    },
    /// A problem with the entity type itself.
    InvalidType {
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
}
2616
2617impl std::fmt::Display for ValidationIssue {
2618 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2619 match self {
2620 ValidationIssue::InvalidSpan { start, end, reason } => {
2621 write!(f, "Invalid span [{}, {}): {}", start, end, reason)
2622 }
2623 ValidationIssue::SpanOutOfBounds { end, text_len } => {
2624 write!(f, "Span end {} exceeds text length {}", end, text_len)
2625 }
2626 ValidationIssue::TextMismatch {
2627 expected,
2628 actual,
2629 start,
2630 end,
2631 } => {
2632 write!(
2633 f,
2634 "Text mismatch at [{}, {}): expected '{}', got '{}'",
2635 start, end, expected, actual
2636 )
2637 }
2638 ValidationIssue::InvalidConfidence { value } => {
2639 write!(f, "Confidence {} outside [0.0, 1.0]", value)
2640 }
2641 ValidationIssue::InvalidType { reason } => {
2642 write!(f, "Invalid entity type: {}", reason)
2643 }
2644 }
2645 }
2646}
2647
/// Fluent builder for [`Entity`].
///
/// Each field mirrors the `Entity` field of the same name; `build` copies
/// them across one-to-one.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
    // Entity text as given to `new`.
    text: String,
    // Entity type as given to `new`.
    entity_type: EntityType,
    // Span start; 0 until `span` or `discontinuous_span` sets it.
    start: usize,
    // Span end; 0 until `span` or `discontinuous_span` sets it.
    end: usize,
    // Scalar confidence; 1.0 until overridden.
    confidence: f64,
    // Optional normalized form of the text.
    normalized: Option<String>,
    // Optional extraction provenance.
    provenance: Option<Provenance>,
    // Optional knowledge-base identifier.
    kb_id: Option<String>,
    // Optional canonical identifier.
    canonical_id: Option<super::types::CanonicalId>,
    // Optional hierarchical confidence.
    hierarchical_confidence: Option<HierarchicalConfidence>,
    // Optional visual span.
    visual_span: Option<Span>,
    // Optional discontinuous span.
    discontinuous_span: Option<DiscontinuousSpan>,
    // Optional start of the validity window.
    valid_from: Option<chrono::DateTime<chrono::Utc>>,
    // Optional end of the validity window.
    valid_until: Option<chrono::DateTime<chrono::Utc>>,
    // Optional viewport.
    viewport: Option<EntityViewport>,
}
2680
2681impl EntityBuilder {
2682 #[must_use]
2684 pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
2685 Self {
2686 text: text.into(),
2687 entity_type,
2688 start: 0,
2689 end: 0,
2690 confidence: 1.0,
2691 normalized: None,
2692 provenance: None,
2693 kb_id: None,
2694 canonical_id: None,
2695 hierarchical_confidence: None,
2696 visual_span: None,
2697 discontinuous_span: None,
2698 valid_from: None,
2699 valid_until: None,
2700 viewport: None,
2701 }
2702 }
2703
2704 #[must_use]
2706 pub const fn span(mut self, start: usize, end: usize) -> Self {
2707 self.start = start;
2708 self.end = end;
2709 self
2710 }
2711
2712 #[must_use]
2714 pub fn confidence(mut self, confidence: f64) -> Self {
2715 self.confidence = confidence.clamp(0.0, 1.0);
2716 self
2717 }
2718
2719 #[must_use]
2721 pub fn hierarchical_confidence(mut self, confidence: HierarchicalConfidence) -> Self {
2722 self.confidence = confidence.as_f64();
2723 self.hierarchical_confidence = Some(confidence);
2724 self
2725 }
2726
2727 #[must_use]
2729 pub fn normalized(mut self, normalized: impl Into<String>) -> Self {
2730 self.normalized = Some(normalized.into());
2731 self
2732 }
2733
2734 #[must_use]
2736 pub fn provenance(mut self, provenance: Provenance) -> Self {
2737 self.provenance = Some(provenance);
2738 self
2739 }
2740
2741 #[must_use]
2743 pub fn kb_id(mut self, kb_id: impl Into<String>) -> Self {
2744 self.kb_id = Some(kb_id.into());
2745 self
2746 }
2747
2748 #[must_use]
2750 pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
2751 self.canonical_id = Some(super::types::CanonicalId::new(canonical_id));
2752 self
2753 }
2754
2755 #[must_use]
2757 pub fn visual_span(mut self, span: Span) -> Self {
2758 self.visual_span = Some(span);
2759 self
2760 }
2761
2762 #[must_use]
2766 pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
2767 if let Some(bounding) = span.bounding_range() {
2769 self.start = bounding.start;
2770 self.end = bounding.end;
2771 }
2772 self.discontinuous_span = Some(span);
2773 self
2774 }
2775
2776 #[must_use]
2790 pub fn valid_from(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2791 self.valid_from = Some(dt);
2792 self
2793 }
2794
2795 #[must_use]
2797 pub fn valid_until(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2798 self.valid_until = Some(dt);
2799 self
2800 }
2801
2802 #[must_use]
2804 pub fn temporal_range(
2805 mut self,
2806 from: chrono::DateTime<chrono::Utc>,
2807 until: chrono::DateTime<chrono::Utc>,
2808 ) -> Self {
2809 self.valid_from = Some(from);
2810 self.valid_until = Some(until);
2811 self
2812 }
2813
2814 #[must_use]
2827 pub fn viewport(mut self, viewport: EntityViewport) -> Self {
2828 self.viewport = Some(viewport);
2829 self
2830 }
2831
2832 #[must_use]
2834 pub fn build(self) -> Entity {
2835 Entity {
2836 text: self.text,
2837 entity_type: self.entity_type,
2838 start: self.start,
2839 end: self.end,
2840 confidence: self.confidence,
2841 normalized: self.normalized,
2842 provenance: self.provenance,
2843 kb_id: self.kb_id,
2844 canonical_id: self.canonical_id,
2845 hierarchical_confidence: self.hierarchical_confidence,
2846 visual_span: self.visual_span,
2847 discontinuous_span: self.discontinuous_span,
2848 valid_from: self.valid_from,
2849 valid_until: self.valid_until,
2850 viewport: self.viewport,
2851 }
2852 }
2853}
2854
/// A typed relation between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// Head entity of the relation.
    pub head: Entity,
    /// Tail entity of the relation.
    pub tail: Entity,
    /// Label of the relation.
    pub relation_type: String,
    /// Optional `(start, end)` offsets of the trigger, when one was given.
    pub trigger_span: Option<(usize, usize)>,
    /// Confidence of the relation; the constructors clamp it into `[0.0, 1.0]`.
    pub confidence: f64,
}
2894
2895impl Relation {
2896 #[must_use]
2898 pub fn new(
2899 head: Entity,
2900 tail: Entity,
2901 relation_type: impl Into<String>,
2902 confidence: f64,
2903 ) -> Self {
2904 Self {
2905 head,
2906 tail,
2907 relation_type: relation_type.into(),
2908 trigger_span: None,
2909 confidence: confidence.clamp(0.0, 1.0),
2910 }
2911 }
2912
2913 #[must_use]
2915 pub fn with_trigger(
2916 head: Entity,
2917 tail: Entity,
2918 relation_type: impl Into<String>,
2919 trigger_start: usize,
2920 trigger_end: usize,
2921 confidence: f64,
2922 ) -> Self {
2923 Self {
2924 head,
2925 tail,
2926 relation_type: relation_type.into(),
2927 trigger_span: Some((trigger_start, trigger_end)),
2928 confidence: confidence.clamp(0.0, 1.0),
2929 }
2930 }
2931
2932 #[must_use]
2934 pub fn as_triple(&self) -> String {
2935 format!(
2936 "({}, {}, {})",
2937 self.head.text, self.relation_type, self.tail.text
2938 )
2939 }
2940
2941 #[must_use]
2944 pub fn span_distance(&self) -> usize {
2945 if self.head.end <= self.tail.start {
2946 self.tail.start.saturating_sub(self.head.end)
2947 } else if self.tail.end <= self.head.start {
2948 self.head.start.saturating_sub(self.tail.end)
2949 } else {
2950 0 }
2952 }
2953}
2954
2955#[cfg(test)]
2956mod tests {
2957 #![allow(clippy::unwrap_used)] use super::*;
2959
2960 #[test]
2961 fn test_entity_type_roundtrip() {
2962 let types = [
2963 EntityType::Person,
2964 EntityType::Organization,
2965 EntityType::Location,
2966 EntityType::Date,
2967 EntityType::Money,
2968 EntityType::Percent,
2969 ];
2970
2971 for t in types {
2972 let label = t.as_label();
2973 let parsed = EntityType::from_label(label);
2974 assert_eq!(t, parsed);
2975 }
2976 }
2977
2978 #[test]
2979 fn test_entity_overlap() {
2980 let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
2981 let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
2982 let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);
2983
2984 assert!(!e1.overlaps(&e2)); assert!(e1.overlaps(&e3)); assert!(e3.overlaps(&e2)); }
2988
2989 #[test]
2990 fn test_confidence_clamping() {
2991 let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
2992 assert!((e1.confidence - 1.0).abs() < f64::EPSILON);
2993
2994 let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
2995 assert!(e2.confidence.abs() < f64::EPSILON);
2996 }
2997
2998 #[test]
2999 fn test_entity_categories() {
3000 assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
3002 assert_eq!(
3003 EntityType::Organization.category(),
3004 EntityCategory::Organization
3005 );
3006 assert_eq!(EntityType::Location.category(), EntityCategory::Place);
3007 assert!(EntityType::Person.requires_ml());
3008 assert!(!EntityType::Person.pattern_detectable());
3009
3010 assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
3012 assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
3013 assert!(EntityType::Date.pattern_detectable());
3014 assert!(!EntityType::Date.requires_ml());
3015
3016 assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
3018 assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
3019 assert!(EntityType::Money.pattern_detectable());
3020
3021 assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
3023 assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
3024 assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
3025 assert!(EntityType::Email.pattern_detectable());
3026 }
3027
3028 #[test]
3029 fn test_new_types_roundtrip() {
3030 let types = [
3031 EntityType::Time,
3032 EntityType::Email,
3033 EntityType::Url,
3034 EntityType::Phone,
3035 EntityType::Quantity,
3036 EntityType::Cardinal,
3037 EntityType::Ordinal,
3038 ];
3039
3040 for t in types {
3041 let label = t.as_label();
3042 let parsed = EntityType::from_label(label);
3043 assert_eq!(t, parsed, "Roundtrip failed for {}", label);
3044 }
3045 }
3046
3047 #[test]
3048 fn test_custom_entity_type() {
3049 let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
3050 assert_eq!(disease.as_label(), "DISEASE");
3051 assert!(disease.requires_ml());
3052
3053 let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
3054 assert_eq!(product_id.as_label(), "PRODUCT_ID");
3055 assert!(!product_id.requires_ml());
3056 assert!(!product_id.pattern_detectable());
3057 }
3058
3059 #[test]
3060 fn test_entity_normalization() {
3061 let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
3062 assert!(e.normalized.is_none());
3063 assert_eq!(e.normalized_or_text(), "Jan 15");
3064
3065 e.set_normalized("2024-01-15");
3066 assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
3067 assert_eq!(e.normalized_or_text(), "2024-01-15");
3068 }
3069
3070 #[test]
3071 fn test_entity_helpers() {
3072 let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
3073 assert!(named.is_named());
3074 assert!(!named.is_structured());
3075 assert_eq!(named.category(), EntityCategory::Agent);
3076
3077 let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
3078 assert!(!structured.is_named());
3079 assert!(structured.is_structured());
3080 assert_eq!(structured.category(), EntityCategory::Numeric);
3081 }
3082
3083 #[test]
3084 fn test_knowledge_linking() {
3085 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3086 assert!(!entity.is_linked());
3087 assert!(!entity.has_coreference());
3088
3089 entity.link_to_kb("Q7186"); assert!(entity.is_linked());
3091 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3092
3093 entity.set_canonical(42);
3094 assert!(entity.has_coreference());
3095 assert_eq!(
3096 entity.canonical_id,
3097 Some(crate::core::types::CanonicalId::new(42))
3098 );
3099 }
3100
3101 #[test]
3102 fn test_relation_creation() {
3103 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3104 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3105
3106 let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
3107 assert_eq!(relation.relation_type, "WORKED_AT");
3108 assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
3109 assert!(relation.trigger_span.is_none());
3110
3111 let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
3113 assert_eq!(relation2.trigger_span, Some((13, 19)));
3114 }
3115
3116 #[test]
3117 fn test_relation_span_distance() {
3118 let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
3120 let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
3121 let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
3122 assert_eq!(relation.span_distance(), 13);
3123 }
3124
3125 #[test]
3126 fn test_relation_category() {
3127 let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
3129 assert_eq!(rel_type.category(), EntityCategory::Relation);
3130 assert!(rel_type.category().is_relation());
3131 assert!(rel_type.requires_ml()); }
3133
3134 #[test]
3139 fn test_span_text() {
3140 let span = Span::text(10, 20);
3141 assert!(span.is_text());
3142 assert!(!span.is_visual());
3143 assert_eq!(span.text_offsets(), Some((10, 20)));
3144 assert_eq!(span.len(), 10);
3145 assert!(!span.is_empty());
3146 }
3147
3148 #[test]
3149 fn test_span_bbox() {
3150 let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
3151 assert!(!span.is_text());
3152 assert!(span.is_visual());
3153 assert_eq!(span.text_offsets(), None);
3154 assert_eq!(span.len(), 0); }
3156
3157 #[test]
3158 fn test_span_bbox_with_page() {
3159 let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
3160 if let Span::BoundingBox { page, .. } = span {
3161 assert_eq!(page, Some(5));
3162 } else {
3163 panic!("Expected BoundingBox");
3164 }
3165 }
3166
3167 #[test]
3168 fn test_span_hybrid() {
3169 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3170 let hybrid = Span::Hybrid {
3171 start: 10,
3172 end: 20,
3173 bbox: Box::new(bbox),
3174 };
3175 assert!(hybrid.is_text());
3176 assert!(hybrid.is_visual());
3177 assert_eq!(hybrid.text_offsets(), Some((10, 20)));
3178 assert_eq!(hybrid.len(), 10);
3179 }
3180
3181 #[test]
3186 fn test_hierarchical_confidence_new() {
3187 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3188 assert!((hc.linkage - 0.9).abs() < f32::EPSILON);
3189 assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
3190 assert!((hc.boundary - 0.7).abs() < f32::EPSILON);
3191 }
3192
3193 #[test]
3194 fn test_hierarchical_confidence_clamping() {
3195 let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
3196 assert!((hc.linkage - 1.0).abs() < f32::EPSILON);
3197 assert!(hc.type_score.abs() < f32::EPSILON);
3198 assert!((hc.boundary - 0.5).abs() < f32::EPSILON);
3199 }
3200
3201 #[test]
3202 fn test_hierarchical_confidence_from_single() {
3203 let hc = HierarchicalConfidence::from_single(0.8);
3204 assert!((hc.linkage - 0.8).abs() < f32::EPSILON);
3205 assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
3206 assert!((hc.boundary - 0.8).abs() < f32::EPSILON);
3207 }
3208
3209 #[test]
3210 fn test_hierarchical_confidence_combined() {
3211 let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
3212 assert!((hc.combined() - 1.0).abs() < f32::EPSILON);
3213
3214 let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
3215 assert!((hc2.combined() - 0.8).abs() < f32::EPSILON);
3216
3217 let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
3219 assert!((hc3.combined() - 0.5).abs() < 0.001);
3220 }
3221
3222 #[test]
3223 fn test_hierarchical_confidence_threshold() {
3224 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3225 assert!(hc.passes_threshold(0.5, 0.5, 0.5));
3226 assert!(hc.passes_threshold(0.9, 0.8, 0.7));
3227 assert!(!hc.passes_threshold(0.95, 0.8, 0.7)); assert!(!hc.passes_threshold(0.9, 0.85, 0.7)); }
3230
3231 #[test]
3232 fn test_hierarchical_confidence_from_f64() {
3233 let hc: HierarchicalConfidence = 0.85_f64.into();
3234 assert!((hc.linkage - 0.85).abs() < 0.001);
3235 }
3236
3237 #[test]
3242 fn test_ragged_batch_from_sequences() {
3243 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3244 let batch = RaggedBatch::from_sequences(&seqs);
3245
3246 assert_eq!(batch.batch_size(), 3);
3247 assert_eq!(batch.total_tokens(), 9);
3248 assert_eq!(batch.max_seq_len, 4);
3249 assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
3250 }
3251
3252 #[test]
3253 fn test_ragged_batch_doc_range() {
3254 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3255 let batch = RaggedBatch::from_sequences(&seqs);
3256
3257 assert_eq!(batch.doc_range(0), Some(0..3));
3258 assert_eq!(batch.doc_range(1), Some(3..5));
3259 assert_eq!(batch.doc_range(2), None);
3260 }
3261
3262 #[test]
3263 fn test_ragged_batch_doc_tokens() {
3264 let seqs = vec![vec![1, 2, 3], vec![4, 5]];
3265 let batch = RaggedBatch::from_sequences(&seqs);
3266
3267 assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
3268 assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
3269 }
3270
3271 #[test]
3272 fn test_ragged_batch_padding_savings() {
3273 let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
3277 let batch = RaggedBatch::from_sequences(&seqs);
3278 let savings = batch.padding_savings();
3279 assert!((savings - 0.25).abs() < 0.001);
3280 }
3281
3282 #[test]
3287 fn test_span_candidate() {
3288 let sc = SpanCandidate::new(0, 5, 10);
3289 assert_eq!(sc.doc_idx, 0);
3290 assert_eq!(sc.start, 5);
3291 assert_eq!(sc.end, 10);
3292 assert_eq!(sc.width(), 5);
3293 }
3294
3295 #[test]
3296 fn test_generate_span_candidates() {
3297 let seqs = vec![vec![1, 2, 3]]; let batch = RaggedBatch::from_sequences(&seqs);
3299 let candidates = generate_span_candidates(&batch, 2);
3300
3301 assert_eq!(candidates.len(), 5);
3304
3305 for c in &candidates {
3307 assert_eq!(c.doc_idx, 0);
3308 assert!(c.end as usize <= 3);
3309 assert!(c.width() as usize <= 2);
3310 }
3311 }
3312
3313 #[test]
3314 fn test_generate_filtered_candidates() {
3315 let seqs = vec![vec![1, 2, 3]];
3316 let batch = RaggedBatch::from_sequences(&seqs);
3317
3318 let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
3321 let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);
3322
3323 assert_eq!(candidates.len(), 2);
3324 }
3325
3326 #[test]
3331 fn test_entity_builder_basic() {
3332 let entity = Entity::builder("John", EntityType::Person)
3333 .span(0, 4)
3334 .confidence(0.95)
3335 .build();
3336
3337 assert_eq!(entity.text, "John");
3338 assert_eq!(entity.entity_type, EntityType::Person);
3339 assert_eq!(entity.start, 0);
3340 assert_eq!(entity.end, 4);
3341 assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
3342 }
3343
3344 #[test]
3345 fn test_entity_builder_full() {
3346 let entity = Entity::builder("Marie Curie", EntityType::Person)
3347 .span(0, 11)
3348 .confidence(0.95)
3349 .kb_id("Q7186")
3350 .canonical_id(42)
3351 .normalized("Marie Salomea Skłodowska Curie")
3352 .provenance(Provenance::ml("bert", 0.95))
3353 .build();
3354
3355 assert_eq!(entity.text, "Marie Curie");
3356 assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
3357 assert_eq!(
3358 entity.canonical_id,
3359 Some(crate::core::types::CanonicalId::new(42))
3360 );
3361 assert_eq!(
3362 entity.normalized.as_deref(),
3363 Some("Marie Salomea Skłodowska Curie")
3364 );
3365 assert!(entity.provenance.is_some());
3366 }
3367
3368 #[test]
3369 fn test_entity_builder_hierarchical() {
3370 let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
3371 let entity = Entity::builder("test", EntityType::Person)
3372 .span(0, 4)
3373 .hierarchical_confidence(hc)
3374 .build();
3375
3376 assert!(entity.hierarchical_confidence.is_some());
3377 assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
3378 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3379 assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
3380 }
3381
3382 #[test]
3383 fn test_entity_builder_visual() {
3384 let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
3385 let entity = Entity::builder("receipt item", EntityType::Money)
3386 .visual_span(bbox)
3387 .confidence(0.9)
3388 .build();
3389
3390 assert!(entity.is_visual());
3391 assert!(entity.visual_span.is_some());
3392 }
3393
3394 #[test]
3399 fn test_entity_hierarchical_confidence_helpers() {
3400 let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);
3401
3402 assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
3404 assert!((entity.type_confidence() - 0.8).abs() < 0.001);
3405 assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);
3406
3407 entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
3409 assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
3410 assert!((entity.type_confidence() - 0.85).abs() < 0.001);
3411 assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
3412 }
3413
3414 #[test]
3415 fn test_entity_from_visual() {
3416 let entity = Entity::from_visual(
3417 "receipt total",
3418 EntityType::Money,
3419 Span::bbox(0.5, 0.8, 0.2, 0.05),
3420 0.92,
3421 );
3422
3423 assert!(entity.is_visual());
3424 assert_eq!(entity.start, 0);
3425 assert_eq!(entity.end, 0);
3426 assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
3427 }
3428
3429 #[test]
3430 fn test_entity_span_helpers() {
3431 let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
3432 assert_eq!(entity.text_span(), (10, 20));
3433 assert_eq!(entity.span_len(), 10);
3434 }
3435
3436 #[test]
3441 fn test_provenance_pattern() {
3442 let prov = Provenance::pattern("EMAIL");
3443 assert_eq!(prov.method, ExtractionMethod::Pattern);
3444 assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
3445 assert_eq!(prov.raw_confidence, Some(1.0)); }
3447
3448 #[test]
3449 fn test_provenance_ml() {
3450 let prov = Provenance::ml("bert-ner", 0.87);
3451 assert_eq!(prov.method, ExtractionMethod::Neural);
3452 assert_eq!(prov.source.as_ref(), "bert-ner");
3453 assert_eq!(prov.raw_confidence, Some(0.87));
3454 }
3455
3456 #[test]
3457 fn test_provenance_with_version() {
3458 let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");
3459
3460 assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
3461 assert_eq!(prov.source.as_ref(), "gliner");
3462 }
3463
3464 #[test]
3465 fn test_provenance_with_timestamp() {
3466 let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");
3467
3468 assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
3469 }
3470
3471 #[test]
3472 fn test_provenance_builder_chain() {
3473 let prov = Provenance::ml("modernbert-ner", 0.95)
3474 .with_version("v1.0.0")
3475 .with_timestamp("2024-11-27T12:00:00Z");
3476
3477 assert_eq!(prov.method, ExtractionMethod::Neural);
3478 assert_eq!(prov.source.as_ref(), "modernbert-ner");
3479 assert_eq!(prov.raw_confidence, Some(0.95));
3480 assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
3481 assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
3482 }
3483
/// Verifies that a `Provenance` with version and timestamp survives a JSON round trip.
#[test]
fn test_provenance_serialization() {
    let original = Provenance::ml("test", 0.9)
        .with_version("v1.0")
        .with_timestamp("2024-01-01");

    // The encoded form must mention both the field name and its value.
    let encoded = serde_json::to_string(&original).unwrap();
    for needle in &["model_version", "v1.0"] {
        assert!(encoded.contains(needle));
    }

    // Decoding must bring back both optional fields unchanged.
    let decoded: Provenance = serde_json::from_str(&encoded).unwrap();
    let version = decoded.model_version.as_deref();
    let timestamp = decoded.timestamp.as_deref();
    assert_eq!(version, Some("v1.0"));
    assert_eq!(timestamp, Some("2024-01-01"));
}
3498}
3499
3500#[cfg(test)]
3501mod proptests {
3502 #![allow(clippy::unwrap_used)] use super::*;
3504 use proptest::prelude::*;
3505
3506 proptest! {
3507 #[test]
3508 fn confidence_always_clamped(conf in -10.0f64..10.0) {
3509 let e = Entity::new("test", EntityType::Person, 0, 4, conf);
3510 prop_assert!(e.confidence >= 0.0);
3511 prop_assert!(e.confidence <= 1.0);
3512 }
3513
3514 #[test]
3515 fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
3516 let et = EntityType::from_label(&label);
3517 let back = EntityType::from_label(et.as_label());
3518 prop_assert!(matches!(back, EntityType::Other(_)) || back == et);
3520 }
3521
3522 #[test]
3523 fn overlap_is_symmetric(
3524 s1 in 0usize..100,
3525 len1 in 1usize..50,
3526 s2 in 0usize..100,
3527 len2 in 1usize..50,
3528 ) {
3529 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3530 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3531 prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
3532 }
3533
3534 #[test]
3535 fn overlap_ratio_bounded(
3536 s1 in 0usize..100,
3537 len1 in 1usize..50,
3538 s2 in 0usize..100,
3539 len2 in 1usize..50,
3540 ) {
3541 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3542 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3543 let ratio = e1.overlap_ratio(&e2);
3544 prop_assert!(ratio >= 0.0);
3545 prop_assert!(ratio <= 1.0);
3546 }
3547
3548 #[test]
3549 fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
3550 let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
3551 let ratio = e.overlap_ratio(&e);
3552 prop_assert!((ratio - 1.0).abs() < 1e-10);
3553 }
3554
3555 #[test]
3556 fn hierarchical_confidence_always_clamped(
3557 linkage in -2.0f32..2.0,
3558 type_score in -2.0f32..2.0,
3559 boundary in -2.0f32..2.0,
3560 ) {
3561 let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
3562 prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
3563 prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
3564 prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
3565 prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
3566 }
3567
3568 #[test]
3569 fn span_candidate_width_consistent(
3570 doc in 0u32..10,
3571 start in 0u32..100,
3572 end in 1u32..100,
3573 ) {
3574 let actual_end = start.max(end);
3575 let sc = SpanCandidate::new(doc, start, actual_end);
3576 prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
3577 }
3578
3579 #[test]
3580 fn ragged_batch_preserves_tokens(
3581 seq_lens in proptest::collection::vec(1usize..10, 1..5),
3582 ) {
3583 let mut counter = 0u32;
3585 let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
3586 let seq: Vec<u32> = (counter..counter + len as u32).collect();
3587 counter += len as u32;
3588 seq
3589 }).collect();
3590
3591 let batch = RaggedBatch::from_sequences(&seqs);
3592
3593 prop_assert_eq!(batch.batch_size(), seqs.len());
3595 prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
3596
3597 for (i, seq) in seqs.iter().enumerate() {
3599 let doc_tokens = batch.doc_tokens(i).unwrap();
3600 prop_assert_eq!(doc_tokens, seq.as_slice());
3601 }
3602 }
3603
3604 #[test]
3605 fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
3606 let end = start + len;
3607 let span = Span::text(start, end);
3608 let (s, e) = span.text_offsets().unwrap();
3609 prop_assert_eq!(s, start);
3610 prop_assert_eq!(e, end);
3611 prop_assert_eq!(span.len(), len);
3612 }
3613
3614 #[test]
3620 fn entity_span_validity(
3621 start in 0usize..10000,
3622 len in 1usize..500,
3623 conf in 0.0f64..=1.0,
3624 ) {
3625 let end = start + len;
3626 let text_content: String = "x".repeat(end);
3628 let entity_text: String = text_content.chars().skip(start).take(len).collect();
3629 let e = Entity::new(&entity_text, EntityType::Person, start, end, conf);
3630 let issues = e.validate(&text_content);
3631 for issue in &issues {
3633 match issue {
3634 ValidationIssue::InvalidSpan { .. } => {
3635 prop_assert!(false, "start < end should never produce InvalidSpan");
3636 }
3637 ValidationIssue::SpanOutOfBounds { .. } => {
3638 prop_assert!(false, "span within text should never produce SpanOutOfBounds");
3639 }
3640 _ => {} }
3642 }
3643 }
3644
3645 #[test]
3647 fn entity_type_label_roundtrip_standard(
3648 idx in 0usize..13,
3649 ) {
3650 let standard_types = [
3651 EntityType::Person,
3652 EntityType::Organization,
3653 EntityType::Location,
3654 EntityType::Date,
3655 EntityType::Time,
3656 EntityType::Money,
3657 EntityType::Percent,
3658 EntityType::Quantity,
3659 EntityType::Cardinal,
3660 EntityType::Ordinal,
3661 EntityType::Email,
3662 EntityType::Url,
3663 EntityType::Phone,
3664 ];
3665 let et = &standard_types[idx];
3666 let label = et.as_label();
3667 let roundtripped = EntityType::from_label(label);
3668 prop_assert_eq!(&roundtripped, et,
3669 "from_label(as_label()) must roundtrip for {:?} (label={:?})", et, label);
3670 }
3671
3672 #[test]
3674 fn span_containment_property(
3675 a_start in 0usize..5000,
3676 a_len in 1usize..5000,
3677 b_offset in 0usize..5000,
3678 b_len in 1usize..5000,
3679 ) {
3680 let a_end = a_start + a_len;
3681 let b_start = a_start + (b_offset % a_len); let b_end_candidate = b_start + b_len;
3683
3684 if b_start >= a_start && b_end_candidate <= a_end {
3686 prop_assert!(a_start <= b_start);
3688 prop_assert!(a_end >= b_end_candidate);
3689
3690 let ea = Entity::new("a", EntityType::Person, a_start, a_end, 1.0);
3692 let eb = Entity::new("b", EntityType::Person, b_start, b_end_candidate, 1.0);
3693 prop_assert!(ea.overlaps(&eb),
3694 "containing span must overlap contained span");
3695 }
3696 }
3697
3698 #[test]
3700 fn entity_serde_roundtrip(
3701 start in 0usize..10000,
3702 len in 1usize..500,
3703 conf in 0.0f64..=1.0,
3704 type_idx in 0usize..5,
3705 ) {
3706 let end = start + len;
3707 let types = [
3708 EntityType::Person,
3709 EntityType::Organization,
3710 EntityType::Location,
3711 EntityType::Date,
3712 EntityType::Email,
3713 ];
3714 let et = types[type_idx].clone();
3715 let text = format!("entity_{}", start);
3716 let e = Entity::new(&text, et, start, end, conf);
3717
3718 let json = serde_json::to_string(&e).unwrap();
3719 let e2: Entity = serde_json::from_str(&json).unwrap();
3720
3721 prop_assert_eq!(&e.text, &e2.text);
3722 prop_assert_eq!(&e.entity_type, &e2.entity_type);
3723 prop_assert_eq!(e.start, e2.start);
3724 prop_assert_eq!(e.end, e2.end);
3725 prop_assert!((e.confidence - e2.confidence).abs() < 1e-10,
3727 "confidence roundtrip: {} vs {}", e.confidence, e2.confidence);
3728 prop_assert_eq!(&e.normalized, &e2.normalized);
3729 prop_assert_eq!(&e.kb_id, &e2.kb_id);
3730 }
3731
3732 #[test]
3734 fn discontinuous_span_total_length(
3735 segments in proptest::collection::vec(
3736 (0usize..5000, 1usize..500),
3737 1..6
3738 ),
3739 ) {
3740 let ranges: Vec<std::ops::Range<usize>> = segments.iter()
3741 .map(|&(start, len)| start..start + len)
3742 .collect();
3743 let expected_sum: usize = ranges.iter().map(|r| r.end - r.start).sum();
3744 let span = DiscontinuousSpan::new(ranges);
3745 prop_assert_eq!(span.total_len(), expected_sum,
3746 "total_len must equal sum of segment lengths");
3747 }
3748 }
3749
/// Verifies the string form returned by `EntityViewport::as_str` for every
/// built-in variant and for a `Custom` value.
#[test]
fn test_entity_viewport_as_str() {
    let expectations = [
        (EntityViewport::Business, "business"),
        (EntityViewport::Legal, "legal"),
        (EntityViewport::Technical, "technical"),
        (EntityViewport::Academic, "academic"),
        (EntityViewport::Personal, "personal"),
        (EntityViewport::Political, "political"),
        (EntityViewport::Media, "media"),
        (EntityViewport::Historical, "historical"),
        (EntityViewport::General, "general"),
        // A custom viewport reports the string it wraps.
        (EntityViewport::Custom("custom".to_string()), "custom"),
    ];

    for (viewport, expected) in &expectations {
        assert_eq!(viewport.as_str(), *expected);
    }
}
3770
/// Verifies which viewports `is_professional` accepts and which it rejects.
#[test]
fn test_entity_viewport_is_professional() {
    let professional = [
        EntityViewport::Business,
        EntityViewport::Legal,
        EntityViewport::Technical,
        EntityViewport::Academic,
        EntityViewport::Political,
    ];
    let non_professional = [
        EntityViewport::Personal,
        EntityViewport::Media,
        EntityViewport::Historical,
        EntityViewport::General,
        EntityViewport::Custom("test".to_string()),
    ];

    for viewport in &professional {
        assert!(
            viewport.is_professional(),
            "{:?} should be professional",
            viewport
        );
    }
    for viewport in &non_professional {
        assert!(
            !viewport.is_professional(),
            "{:?} should not be professional",
            viewport
        );
    }
}
3785
/// Verifies that parsing a string into an `EntityViewport` recognizes each
/// canonical name and its aliases, maps the empty string to `General`, and
/// wraps any other name in `Custom`.
#[test]
fn test_entity_viewport_from_str() {
    let cases = [
        // Business and its aliases.
        ("business", EntityViewport::Business),
        ("financial", EntityViewport::Business),
        ("corporate", EntityViewport::Business),
        // Legal.
        ("legal", EntityViewport::Legal),
        ("law", EntityViewport::Legal),
        // Technical.
        ("technical", EntityViewport::Technical),
        ("engineering", EntityViewport::Technical),
        // Academic.
        ("academic", EntityViewport::Academic),
        ("research", EntityViewport::Academic),
        // Personal.
        ("personal", EntityViewport::Personal),
        ("biographical", EntityViewport::Personal),
        // Political.
        ("political", EntityViewport::Political),
        ("policy", EntityViewport::Political),
        // Media.
        ("media", EntityViewport::Media),
        ("press", EntityViewport::Media),
        // Historical.
        ("historical", EntityViewport::Historical),
        ("history", EntityViewport::Historical),
        // General, including the empty string.
        ("general", EntityViewport::General),
        ("generic", EntityViewport::General),
        ("", EntityViewport::General),
        // Anything unrecognized is kept verbatim as a custom viewport.
        (
            "custom_viewport",
            EntityViewport::Custom("custom_viewport".to_string()),
        ),
    ];

    for (input, expected) in &cases {
        let parsed = input.parse::<EntityViewport>().unwrap();
        assert_eq!(&parsed, expected);
    }
}
3883
/// Verifies that viewport parsing ignores letter case.
#[test]
fn test_entity_viewport_from_str_case_insensitive() {
    // Upper, title and mixed case must all resolve to the same variant.
    for spelling in &["BUSINESS", "Business", "BuSiNeSs"] {
        let parsed = spelling.parse::<EntityViewport>().unwrap();
        assert_eq!(parsed, EntityViewport::Business);
    }
}
3899
/// Verifies the `Display` output of built-in and custom viewports.
#[test]
fn test_entity_viewport_display() {
    let business = EntityViewport::Business.to_string();
    let academic = EntityViewport::Academic.to_string();
    let custom = EntityViewport::Custom("test".to_string()).to_string();

    assert_eq!(business, "business");
    assert_eq!(academic, "academic");
    // A custom viewport displays the string it wraps.
    assert_eq!(custom, "test");
}
3909
/// Checks the viewport accessors on an entity before and after a viewport is assigned.
#[test]
fn test_entity_viewport_methods() {
    let academic = EntityViewport::Academic;
    let mut scientist = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.9);

    // Before assignment: nothing is set, the default is General, and the
    // entity still matches a specific viewport.
    assert!(!scientist.has_viewport());
    assert_eq!(scientist.viewport_or_default(), EntityViewport::General);
    assert!(scientist.matches_viewport(&academic));

    // After assigning Academic: it is reported back, and only Academic matches
    // out of the two viewports probed here.
    scientist.set_viewport(academic.clone());
    assert!(scientist.has_viewport());
    assert_eq!(scientist.viewport_or_default(), academic);
    assert!(scientist.matches_viewport(&academic));
    assert!(!scientist.matches_viewport(&EntityViewport::Business));
}
3926
/// Verifies that a viewport supplied through the entity builder ends up on the built entity.
#[test]
fn test_entity_builder_with_viewport() {
    let built = Entity::builder("Marie Curie", EntityType::Person)
        .span(0, 11)
        .viewport(EntityViewport::Academic)
        .build();

    let stored_viewport = built.viewport.as_ref();
    assert_eq!(stored_viewport, Some(&EntityViewport::Academic));
    assert!(built.has_viewport());
}
3937
/// Verifies which categories `requires_ml` reports as needing ML and which it does not.
#[test]
fn test_entity_category_requires_ml() {
    let needs_ml = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Creative,
        EntityCategory::Relation,
    ];
    let no_ml = [
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
        EntityCategory::Misc,
    ];

    for category in &needs_ml {
        assert!(category.requires_ml(), "{:?} should require ML", category);
    }
    for category in &no_ml {
        assert!(!category.requires_ml(), "{:?} should not require ML", category);
    }
}
3955
/// Verifies which categories `pattern_detectable` accepts and which it rejects.
#[test]
fn test_entity_category_pattern_detectable() {
    let detectable = [
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
    ];
    let not_detectable = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Creative,
        EntityCategory::Relation,
        EntityCategory::Misc,
    ];

    for category in &detectable {
        assert!(
            category.pattern_detectable(),
            "{:?} should be pattern-detectable",
            category
        );
    }
    for category in &not_detectable {
        assert!(
            !category.pattern_detectable(),
            "{:?} should not be pattern-detectable",
            category
        );
    }
}
3969
/// Verifies that `is_relation` is true for `Relation` and false for every other category.
#[test]
fn test_entity_category_is_relation() {
    assert!(EntityCategory::Relation.is_relation());

    let non_relations = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
        EntityCategory::Creative,
        EntityCategory::Misc,
    ];
    for category in &non_relations {
        assert!(!category.is_relation(), "{:?} should not be a relation", category);
    }
}
3983
/// Verifies the string returned by `EntityCategory::as_str` for every category.
#[test]
fn test_entity_category_as_str() {
    let expectations = [
        (EntityCategory::Agent, "agent"),
        (EntityCategory::Organization, "organization"),
        (EntityCategory::Place, "place"),
        (EntityCategory::Creative, "creative"),
        (EntityCategory::Temporal, "temporal"),
        (EntityCategory::Numeric, "numeric"),
        (EntityCategory::Contact, "contact"),
        (EntityCategory::Relation, "relation"),
        (EntityCategory::Misc, "misc"),
    ];

    for (category, expected) in &expectations {
        assert_eq!(category.as_str(), *expected);
    }
}
3996
/// Verifies the `Display` output of a sample of categories.
#[test]
fn test_entity_category_display() {
    let agent = EntityCategory::Agent.to_string();
    let temporal = EntityCategory::Temporal.to_string();
    let relation = EntityCategory::Relation.to_string();

    assert_eq!(agent, "agent");
    assert_eq!(temporal, "temporal");
    assert_eq!(relation, "relation");
}
4003}