1use serde::{Deserialize, Serialize};
39use std::borrow::Cow;
40
/// Coarse grouping of entity types.
///
/// Every [`EntityType`] reports one of these through `EntityType::category`,
/// and the predicates on this enum (`requires_ml`, `pattern_detectable`,
/// `is_relation`) classify the groups further. The enum is
/// `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityCategory {
    /// Category reported for `EntityType::Person`; rendered as `"agent"`.
    Agent,
    /// Category reported for `EntityType::Organization`; rendered as `"organization"`.
    Organization,
    /// Category reported for `EntityType::Location`; rendered as `"place"`.
    Place,
    /// Category used by custom types such as `WORK_OF_ART`; rendered as `"creative"`.
    Creative,
    /// Category reported for `EntityType::Date` and `EntityType::Time`.
    Temporal,
    /// Category reported for `Money`, `Percent`, `Quantity`, `Cardinal` and `Ordinal`.
    Numeric,
    /// Category reported for `Email`, `Url` and `Phone`.
    Contact,
    /// The only category for which `is_relation` returns `true`.
    Relation,
    /// Category reported for `EntityType::Other`; rendered as `"misc"`.
    Misc,
}
81
82impl EntityCategory {
83 #[must_use]
85 pub const fn requires_ml(&self) -> bool {
86 matches!(
87 self,
88 EntityCategory::Agent
89 | EntityCategory::Organization
90 | EntityCategory::Place
91 | EntityCategory::Creative
92 | EntityCategory::Relation
93 )
94 }
95
96 #[must_use]
98 pub const fn pattern_detectable(&self) -> bool {
99 matches!(
100 self,
101 EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
102 )
103 }
104
105 #[must_use]
107 pub const fn is_relation(&self) -> bool {
108 matches!(self, EntityCategory::Relation)
109 }
110
111 #[must_use]
113 pub const fn as_str(&self) -> &'static str {
114 match self {
115 EntityCategory::Agent => "agent",
116 EntityCategory::Organization => "organization",
117 EntityCategory::Place => "place",
118 EntityCategory::Creative => "creative",
119 EntityCategory::Temporal => "temporal",
120 EntityCategory::Numeric => "numeric",
121 EntityCategory::Contact => "contact",
122 EntityCategory::Relation => "relation",
123 EntityCategory::Misc => "misc",
124 }
125 }
126}
127
128impl std::fmt::Display for EntityCategory {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 write!(f, "{}", self.as_str())
131 }
132}
133
/// Named viewport attached to an entity (see `Entity::viewport`).
///
/// `General` is the `Default`. Strings are turned into viewports by the
/// `FromStr` implementation, which also accepts a few aliases per variant and
/// falls back to `Custom` for anything unrecognised.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum EntityViewport {
    /// Rendered as `"business"`; also parsed from `"financial"` / `"corporate"`.
    Business,
    /// Rendered as `"legal"`; also parsed from `"law"` / `"compliance"`.
    Legal,
    /// Rendered as `"technical"`; also parsed from `"engineering"` / `"tech"`.
    Technical,
    /// Rendered as `"academic"`; also parsed from `"research"` / `"scholarly"`.
    Academic,
    /// Rendered as `"personal"`; also parsed from `"biographical"` / `"private"`.
    Personal,
    /// Rendered as `"political"`; also parsed from `"policy"` / `"government"`.
    Political,
    /// Rendered as `"media"`; also parsed from `"press"`, `"pr"` and `"public_relations"`.
    Media,
    /// Rendered as `"historical"`; also parsed from `"history"` / `"past"`.
    Historical,
    /// Default viewport; rendered as `"general"` and parsed from `"generic"` or `""`.
    #[default]
    General,
    /// Any other viewport, identified by the contained string.
    Custom(String),
}
195
196impl EntityViewport {
197 #[must_use]
199 pub fn as_str(&self) -> &str {
200 match self {
201 EntityViewport::Business => "business",
202 EntityViewport::Legal => "legal",
203 EntityViewport::Technical => "technical",
204 EntityViewport::Academic => "academic",
205 EntityViewport::Personal => "personal",
206 EntityViewport::Political => "political",
207 EntityViewport::Media => "media",
208 EntityViewport::Historical => "historical",
209 EntityViewport::General => "general",
210 EntityViewport::Custom(s) => s,
211 }
212 }
213
214 #[must_use]
216 pub const fn is_professional(&self) -> bool {
217 matches!(
218 self,
219 EntityViewport::Business
220 | EntityViewport::Legal
221 | EntityViewport::Technical
222 | EntityViewport::Academic
223 | EntityViewport::Political
224 )
225 }
226}
227
228impl std::str::FromStr for EntityViewport {
229 type Err = std::convert::Infallible;
230
231 fn from_str(s: &str) -> Result<Self, Self::Err> {
232 Ok(match s.to_lowercase().as_str() {
233 "business" | "financial" | "corporate" => EntityViewport::Business,
234 "legal" | "law" | "compliance" => EntityViewport::Legal,
235 "technical" | "engineering" | "tech" => EntityViewport::Technical,
236 "academic" | "research" | "scholarly" => EntityViewport::Academic,
237 "personal" | "biographical" | "private" => EntityViewport::Personal,
238 "political" | "policy" | "government" => EntityViewport::Political,
239 "media" | "press" | "pr" | "public_relations" => EntityViewport::Media,
240 "historical" | "history" | "past" => EntityViewport::Historical,
241 "general" | "generic" | "" => EntityViewport::General,
242 other => EntityViewport::Custom(other.to_string()),
243 })
244 }
245}
246
247impl std::fmt::Display for EntityViewport {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 write!(f, "{}", self.as_str())
250 }
251}
252
/// Type of a recognised entity.
///
/// Built-in variants have fixed labels (see `as_label`); `Custom` and `Other`
/// carry their own label string. `from_label` converts a label string back
/// into a variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityType {
    /// Label `"PER"`; category `Agent`.
    Person,
    /// Label `"ORG"`; category `Organization`.
    Organization,
    /// Label `"LOC"`; category `Place`.
    Location,

    /// Label `"DATE"`; category `Temporal`.
    Date,
    /// Label `"TIME"`; category `Temporal`.
    Time,

    /// Label `"MONEY"`; category `Numeric`.
    Money,
    /// Label `"PERCENT"`; category `Numeric`.
    Percent,
    /// Label `"QUANTITY"`; category `Numeric`.
    Quantity,
    /// Label `"CARDINAL"`; category `Numeric`.
    Cardinal,
    /// Label `"ORDINAL"`; category `Numeric`.
    Ordinal,

    /// Label `"EMAIL"`; category `Contact`.
    Email,
    /// Label `"URL"`; category `Contact`.
    Url,
    /// Label `"PHONE"`; category `Contact`.
    Phone,

    /// Caller-defined type with an explicit category.
    Custom {
        /// Label of the custom type; returned verbatim by `as_label`.
        name: String,
        /// Category reported by `category()` for this type.
        category: EntityCategory,
    },

    /// Unrecognised label, kept as a string; always in category `Misc`.
    // The rename spells out the serialized variant name explicitly.
    #[serde(rename = "Other")]
    Other(String),
}
327
328impl EntityType {
329 #[must_use]
331 pub fn category(&self) -> EntityCategory {
332 match self {
333 EntityType::Person => EntityCategory::Agent,
335 EntityType::Organization => EntityCategory::Organization,
337 EntityType::Location => EntityCategory::Place,
339 EntityType::Date | EntityType::Time => EntityCategory::Temporal,
341 EntityType::Money
343 | EntityType::Percent
344 | EntityType::Quantity
345 | EntityType::Cardinal
346 | EntityType::Ordinal => EntityCategory::Numeric,
347 EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
349 EntityType::Custom { category, .. } => *category,
351 EntityType::Other(_) => EntityCategory::Misc,
353 }
354 }
355
356 #[must_use]
358 pub fn requires_ml(&self) -> bool {
359 self.category().requires_ml()
360 }
361
362 #[must_use]
364 pub fn pattern_detectable(&self) -> bool {
365 self.category().pattern_detectable()
366 }
367
368 #[must_use]
370 pub fn as_label(&self) -> &str {
371 match self {
372 EntityType::Person => "PER",
373 EntityType::Organization => "ORG",
374 EntityType::Location => "LOC",
375 EntityType::Date => "DATE",
376 EntityType::Time => "TIME",
377 EntityType::Money => "MONEY",
378 EntityType::Percent => "PERCENT",
379 EntityType::Quantity => "QUANTITY",
380 EntityType::Cardinal => "CARDINAL",
381 EntityType::Ordinal => "ORDINAL",
382 EntityType::Email => "EMAIL",
383 EntityType::Url => "URL",
384 EntityType::Phone => "PHONE",
385 EntityType::Custom { name, .. } => name.as_str(),
386 EntityType::Other(s) => s.as_str(),
387 }
388 }
389
390 #[must_use]
394 pub fn from_label(label: &str) -> Self {
395 let label = label
397 .strip_prefix("B-")
398 .or_else(|| label.strip_prefix("I-"))
399 .or_else(|| label.strip_prefix("E-"))
400 .or_else(|| label.strip_prefix("S-"))
401 .unwrap_or(label);
402
403 match label.to_uppercase().as_str() {
404 "PER" | "PERSON" => EntityType::Person,
406 "ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
407 "LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
408 "FACILITY" | "FAC" | "BUILDING" => {
410 EntityType::custom("BUILDING", EntityCategory::Place)
411 }
412 "PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
413 "EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
414 "CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
415 EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
416 }
417 "GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
418 "DATE" => EntityType::Date,
420 "TIME" => EntityType::Time,
421 "MONEY" | "CURRENCY" => EntityType::Money,
423 "PERCENT" | "PERCENTAGE" => EntityType::Percent,
424 "QUANTITY" => EntityType::Quantity,
425 "CARDINAL" => EntityType::Cardinal,
426 "ORDINAL" => EntityType::Ordinal,
427 "EMAIL" => EntityType::Email,
429 "URL" | "URI" => EntityType::Url,
430 "PHONE" | "TELEPHONE" => EntityType::Phone,
431 "MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::Other("MISC".to_string()),
433 "DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
435 "CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
436 "GENE" => EntityType::custom("GENE", EntityCategory::Misc),
437 "PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
438 other => EntityType::Other(other.to_string()),
440 }
441 }
442
443 #[must_use]
458 pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
459 EntityType::Custom {
460 name: name.into(),
461 category,
462 }
463 }
464}
465
466impl std::fmt::Display for EntityType {
467 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468 write!(f, "{}", self.as_label())
469 }
470}
471
472impl std::str::FromStr for EntityType {
473 type Err = std::convert::Infallible;
474
475 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
477 Ok(Self::from_label(s))
478 }
479}
480
/// Table that maps dataset-specific label strings to [`EntityType`] values.
///
/// Lookups are case-insensitive: keys are uppercased on insert and on lookup.
#[derive(Debug, Clone, Default)]
pub struct TypeMapper {
    /// Uppercased source label -> target entity type.
    mappings: std::collections::HashMap<String, EntityType>,
}
524
525impl TypeMapper {
526 #[must_use]
528 pub fn new() -> Self {
529 Self::default()
530 }
531
532 #[must_use]
534 pub fn mit_movie() -> Self {
535 let mut mapper = Self::new();
536 mapper.add("ACTOR", EntityType::Person);
538 mapper.add("DIRECTOR", EntityType::Person);
539 mapper.add("CHARACTER", EntityType::Person);
540 mapper.add(
541 "TITLE",
542 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
543 );
544 mapper.add("GENRE", EntityType::custom("GENRE", EntityCategory::Misc));
545 mapper.add("YEAR", EntityType::Date);
546 mapper.add("RATING", EntityType::custom("RATING", EntityCategory::Misc));
547 mapper.add("PLOT", EntityType::custom("PLOT", EntityCategory::Misc));
548 mapper
549 }
550
551 #[must_use]
553 pub fn mit_restaurant() -> Self {
554 let mut mapper = Self::new();
555 mapper.add("RESTAURANT_NAME", EntityType::Organization);
556 mapper.add("LOCATION", EntityType::Location);
557 mapper.add(
558 "CUISINE",
559 EntityType::custom("CUISINE", EntityCategory::Misc),
560 );
561 mapper.add("DISH", EntityType::custom("DISH", EntityCategory::Misc));
562 mapper.add("PRICE", EntityType::Money);
563 mapper.add(
564 "AMENITY",
565 EntityType::custom("AMENITY", EntityCategory::Misc),
566 );
567 mapper.add("HOURS", EntityType::Time);
568 mapper
569 }
570
571 #[must_use]
573 pub fn biomedical() -> Self {
574 let mut mapper = Self::new();
575 mapper.add(
576 "DISEASE",
577 EntityType::custom("DISEASE", EntityCategory::Agent),
578 );
579 mapper.add(
580 "CHEMICAL",
581 EntityType::custom("CHEMICAL", EntityCategory::Misc),
582 );
583 mapper.add("DRUG", EntityType::custom("DRUG", EntityCategory::Misc));
584 mapper.add("GENE", EntityType::custom("GENE", EntityCategory::Misc));
585 mapper.add(
586 "PROTEIN",
587 EntityType::custom("PROTEIN", EntityCategory::Misc),
588 );
589 mapper.add("DNA", EntityType::custom("DNA", EntityCategory::Misc));
591 mapper.add("RNA", EntityType::custom("RNA", EntityCategory::Misc));
592 mapper.add(
593 "cell_line",
594 EntityType::custom("CELL_LINE", EntityCategory::Misc),
595 );
596 mapper.add(
597 "cell_type",
598 EntityType::custom("CELL_TYPE", EntityCategory::Misc),
599 );
600 mapper
601 }
602
603 #[must_use]
605 pub fn social_media() -> Self {
606 let mut mapper = Self::new();
607 mapper.add("person", EntityType::Person);
609 mapper.add("corporation", EntityType::Organization);
610 mapper.add("location", EntityType::Location);
611 mapper.add("group", EntityType::Organization);
612 mapper.add(
613 "product",
614 EntityType::custom("PRODUCT", EntityCategory::Misc),
615 );
616 mapper.add(
617 "creative_work",
618 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
619 );
620 mapper.add("event", EntityType::custom("EVENT", EntityCategory::Misc));
621 mapper
622 }
623
624 #[must_use]
626 pub fn manufacturing() -> Self {
627 let mut mapper = Self::new();
628 mapper.add("MATE", EntityType::custom("MATERIAL", EntityCategory::Misc));
630 mapper.add("MANP", EntityType::custom("PROCESS", EntityCategory::Misc));
631 mapper.add("MACEQ", EntityType::custom("MACHINE", EntityCategory::Misc));
632 mapper.add(
633 "APPL",
634 EntityType::custom("APPLICATION", EntityCategory::Misc),
635 );
636 mapper.add("FEAT", EntityType::custom("FEATURE", EntityCategory::Misc));
637 mapper.add(
638 "PARA",
639 EntityType::custom("PARAMETER", EntityCategory::Misc),
640 );
641 mapper.add("PRO", EntityType::custom("PROPERTY", EntityCategory::Misc));
642 mapper.add(
643 "CHAR",
644 EntityType::custom("CHARACTERISTIC", EntityCategory::Misc),
645 );
646 mapper.add(
647 "ENAT",
648 EntityType::custom("ENABLING_TECHNOLOGY", EntityCategory::Misc),
649 );
650 mapper.add(
651 "CONPRI",
652 EntityType::custom("CONCEPT_PRINCIPLE", EntityCategory::Misc),
653 );
654 mapper.add(
655 "BIOP",
656 EntityType::custom("BIO_PROCESS", EntityCategory::Misc),
657 );
658 mapper.add(
659 "MANS",
660 EntityType::custom("MAN_STANDARD", EntityCategory::Misc),
661 );
662 mapper
663 }
664
665 pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
667 self.mappings.insert(source.into().to_uppercase(), target);
668 }
669
670 #[must_use]
672 pub fn map(&self, label: &str) -> Option<&EntityType> {
673 self.mappings.get(&label.to_uppercase())
674 }
675
676 #[must_use]
680 pub fn normalize(&self, label: &str) -> EntityType {
681 self.map(label)
682 .cloned()
683 .unwrap_or_else(|| EntityType::from_label(label))
684 }
685
686 #[must_use]
688 pub fn contains(&self, label: &str) -> bool {
689 self.mappings.contains_key(&label.to_uppercase())
690 }
691
692 pub fn labels(&self) -> impl Iterator<Item = &String> {
694 self.mappings.keys()
695 }
696}
697
/// How an entity was extracted.
///
/// `Neural` is the `Default`. Several older variants are deprecated in favour
/// of newer ones; their `Display` output is the name of the replacement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExtractionMethod {
    /// Displayed as `"pattern"`; confidence is interpreted as `"binary"`.
    Pattern,

    /// Default method; displayed as `"neural"`; confidence is a `"probability"`.
    #[default]
    Neural,

    /// Displayed as `"lexicon"`; confidence is interpreted as `"binary"`.
    #[deprecated(since = "0.2.0", note = "Use Neural or GatedEnsemble instead")]
    Lexicon,

    /// Displayed as `"soft_lexicon"`; reported as calibrated.
    SoftLexicon,

    /// Displayed as `"gated_ensemble"`; reported as calibrated.
    GatedEnsemble,

    /// Displayed as `"consensus"`; confidence is an `"agreement_ratio"`.
    Consensus,

    /// Displayed as `"heuristic"`; confidence is a `"heuristic_score"`.
    Heuristic,

    /// Displayed as `"unknown"`; confidence interpretation is `"unknown"`.
    Unknown,

    /// Legacy variant; displayed as `"heuristic"`.
    #[deprecated(since = "0.2.0", note = "Use Heuristic or Pattern instead")]
    Rule,

    /// Legacy variant; displayed as `"neural"` and reported as calibrated.
    #[deprecated(since = "0.2.0", note = "Use Neural instead")]
    ML,

    /// Legacy variant; displayed as `"consensus"`.
    #[deprecated(since = "0.2.0", note = "Use Consensus instead")]
    Ensemble,
}
763
764impl ExtractionMethod {
765 #[must_use]
792 pub const fn is_calibrated(&self) -> bool {
793 #[allow(deprecated)]
794 match self {
795 ExtractionMethod::Neural => true,
796 ExtractionMethod::GatedEnsemble => true,
797 ExtractionMethod::SoftLexicon => true,
798 ExtractionMethod::ML => true, ExtractionMethod::Pattern => false,
801 ExtractionMethod::Lexicon => false,
802 ExtractionMethod::Consensus => false,
803 ExtractionMethod::Heuristic => false,
804 ExtractionMethod::Unknown => false,
805 ExtractionMethod::Rule => false,
806 ExtractionMethod::Ensemble => false,
807 }
808 }
809
810 #[must_use]
817 pub const fn confidence_interpretation(&self) -> &'static str {
818 #[allow(deprecated)]
819 match self {
820 ExtractionMethod::Neural | ExtractionMethod::ML => "probability",
821 ExtractionMethod::GatedEnsemble | ExtractionMethod::SoftLexicon => "probability",
822 ExtractionMethod::Pattern | ExtractionMethod::Lexicon => "binary",
823 ExtractionMethod::Heuristic | ExtractionMethod::Rule => "heuristic_score",
824 ExtractionMethod::Consensus | ExtractionMethod::Ensemble => "agreement_ratio",
825 ExtractionMethod::Unknown => "unknown",
826 }
827 }
828}
829
830impl std::fmt::Display for ExtractionMethod {
831 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
832 #[allow(deprecated)]
833 match self {
834 ExtractionMethod::Pattern => write!(f, "pattern"),
835 ExtractionMethod::Neural => write!(f, "neural"),
836 ExtractionMethod::Lexicon => write!(f, "lexicon"),
837 ExtractionMethod::SoftLexicon => write!(f, "soft_lexicon"),
838 ExtractionMethod::GatedEnsemble => write!(f, "gated_ensemble"),
839 ExtractionMethod::Consensus => write!(f, "consensus"),
840 ExtractionMethod::Heuristic => write!(f, "heuristic"),
841 ExtractionMethod::Unknown => write!(f, "unknown"),
842 ExtractionMethod::Rule => write!(f, "heuristic"), ExtractionMethod::ML => write!(f, "neural"), ExtractionMethod::Ensemble => write!(f, "consensus"), }
846 }
847}
848
/// Lookup table from text to an entity type with an attached `f64` score.
///
/// Implementors must be `Send + Sync`.
pub trait Lexicon: Send + Sync {
    /// Returns the entity type and score stored for `text`, or `None` when
    /// the lexicon has no entry for it.
    fn lookup(&self, text: &str) -> Option<(EntityType, f64)>;

    /// Returns `true` when `lookup(text)` finds an entry.
    fn contains(&self, text: &str) -> bool {
        self.lookup(text).is_some()
    }

    /// Name of the source this lexicon was built from.
    fn source(&self) -> &str;

    /// Number of entries in the lexicon.
    fn len(&self) -> usize;

    /// Returns `true` when the lexicon has no entries (`len() == 0`).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
924
/// [`Lexicon`] implementation backed by a `HashMap`.
#[derive(Debug, Clone)]
pub struct HashMapLexicon {
    /// Entry text -> (entity type, confidence). Keys are matched exactly.
    entries: std::collections::HashMap<String, (EntityType, f64)>,
    /// Source name given at construction; returned by `Lexicon::source`.
    source: String,
}
934
935impl HashMapLexicon {
936 #[must_use]
938 pub fn new(source: impl Into<String>) -> Self {
939 Self {
940 entries: std::collections::HashMap::new(),
941 source: source.into(),
942 }
943 }
944
945 pub fn insert(&mut self, text: impl Into<String>, entity_type: EntityType, confidence: f64) {
947 self.entries.insert(text.into(), (entity_type, confidence));
948 }
949
950 pub fn from_iter<I, S>(source: impl Into<String>, entries: I) -> Self
952 where
953 I: IntoIterator<Item = (S, EntityType, f64)>,
954 S: Into<String>,
955 {
956 let mut lexicon = Self::new(source);
957 for (text, entity_type, confidence) in entries {
958 lexicon.insert(text, entity_type, confidence);
959 }
960 lexicon
961 }
962
963 pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, f64)> {
965 self.entries.iter().map(|(k, (t, c))| (k.as_str(), t, *c))
966 }
967}
968
969impl Lexicon for HashMapLexicon {
970 fn lookup(&self, text: &str) -> Option<(EntityType, f64)> {
971 self.entries.get(text).cloned()
972 }
973
974 fn source(&self) -> &str {
975 &self.source
976 }
977
978 fn len(&self) -> usize {
979 self.entries.len()
980 }
981}
982
/// Record of where an entity came from and how it was extracted.
///
/// The derived `Default` leaves `source` empty, uses the default
/// `ExtractionMethod`, and sets every optional field to `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Provenance {
    /// Name of the producing source (`"pattern"` for `Provenance::pattern`,
    /// the model name for `Provenance::ml`).
    pub source: Cow<'static, str>,
    /// Extraction method that produced the entity.
    pub method: ExtractionMethod,
    /// Name of the matching pattern, when one was recorded.
    pub pattern: Option<Cow<'static, str>>,
    /// Confidence as reported by the source, when one was recorded.
    pub raw_confidence: Option<f64>,
    /// Model version, set by `with_version`; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<Cow<'static, str>>,
    /// Timestamp string, set by `with_timestamp`; omitted from output when
    /// `None`. No format is enforced here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<String>,
}
1004
1005impl Provenance {
1006 #[must_use]
1008 pub fn pattern(pattern_name: &'static str) -> Self {
1009 Self {
1010 source: Cow::Borrowed("pattern"),
1011 method: ExtractionMethod::Pattern,
1012 pattern: Some(Cow::Borrowed(pattern_name)),
1013 raw_confidence: Some(1.0), model_version: None,
1015 timestamp: None,
1016 }
1017 }
1018
1019 #[must_use]
1030 pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: f64) -> Self {
1031 Self {
1032 source: model_name.into(),
1033 method: ExtractionMethod::Neural,
1034 pattern: None,
1035 raw_confidence: Some(confidence),
1036 model_version: None,
1037 timestamp: None,
1038 }
1039 }
1040
1041 #[deprecated(
1043 since = "0.2.1",
1044 note = "Use ml() instead, it now accepts owned strings"
1045 )]
1046 #[must_use]
1047 pub fn ml_owned(model_name: impl Into<String>, confidence: f64) -> Self {
1048 Self::ml(Cow::Owned(model_name.into()), confidence)
1049 }
1050
1051 #[must_use]
1053 pub fn ensemble(sources: &'static str) -> Self {
1054 Self {
1055 source: Cow::Borrowed(sources),
1056 method: ExtractionMethod::Consensus,
1057 pattern: None,
1058 raw_confidence: None,
1059 model_version: None,
1060 timestamp: None,
1061 }
1062 }
1063
1064 #[must_use]
1066 pub fn with_version(mut self, version: &'static str) -> Self {
1067 self.model_version = Some(Cow::Borrowed(version));
1068 self
1069 }
1070
1071 #[must_use]
1073 pub fn with_timestamp(mut self, timestamp: impl Into<String>) -> Self {
1074 self.timestamp = Some(timestamp.into());
1075 self
1076 }
1077}
1078
/// Location of an entity: text offsets, a bounding box, or both.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Span {
    /// A span given purely by text offsets.
    Text {
        /// Start offset.
        start: usize,
        /// End offset; `len()` computes `end - start`, saturating at zero.
        end: usize,
    },
    /// A rectangular region with no text offsets.
    BoundingBox {
        /// X coordinate of the box (units and origin are not fixed here).
        x: f32,
        /// Y coordinate of the box.
        y: f32,
        /// Width of the box.
        width: f32,
        /// Height of the box.
        height: f32,
        /// Page number, when known (`None` for `Span::bbox`).
        page: Option<u32>,
    },
    /// Text offsets together with a visual location.
    Hybrid {
        /// Start offset of the text part.
        start: usize,
        /// End offset of the text part.
        end: usize,
        /// Visual part of the span — presumably a `BoundingBox`; the type
        /// does not enforce this.
        bbox: Box<Span>,
    },
}
1145
1146impl Span {
1147 #[must_use]
1149 pub const fn text(start: usize, end: usize) -> Self {
1150 Self::Text { start, end }
1151 }
1152
1153 #[must_use]
1155 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
1156 Self::BoundingBox {
1157 x,
1158 y,
1159 width,
1160 height,
1161 page: None,
1162 }
1163 }
1164
1165 #[must_use]
1167 pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
1168 Self::BoundingBox {
1169 x,
1170 y,
1171 width,
1172 height,
1173 page: Some(page),
1174 }
1175 }
1176
1177 #[must_use]
1179 pub const fn is_text(&self) -> bool {
1180 matches!(self, Self::Text { .. } | Self::Hybrid { .. })
1181 }
1182
1183 #[must_use]
1185 pub const fn is_visual(&self) -> bool {
1186 matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
1187 }
1188
1189 #[must_use]
1191 pub const fn text_offsets(&self) -> Option<(usize, usize)> {
1192 match self {
1193 Self::Text { start, end } => Some((*start, *end)),
1194 Self::Hybrid { start, end, .. } => Some((*start, *end)),
1195 Self::BoundingBox { .. } => None,
1196 }
1197 }
1198
1199 #[must_use]
1201 pub fn len(&self) -> usize {
1202 match self {
1203 Self::Text { start, end } => end.saturating_sub(*start),
1204 Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
1205 Self::BoundingBox { .. } => 0,
1206 }
1207 }
1208
1209 #[must_use]
1211 pub fn is_empty(&self) -> bool {
1212 self.len() == 0
1213 }
1214}
1215
/// A span made of one or more `start..end` segments.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscontinuousSpan {
    /// The segments of the span. `DiscontinuousSpan::new` sorts them by
    /// `start`; overlap between segments is not checked.
    segments: Vec<std::ops::Range<usize>>,
}
1262
1263impl DiscontinuousSpan {
1264 #[must_use]
1268 pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
1269 segments.sort_by_key(|r| r.start);
1271 Self { segments }
1272 }
1273
1274 #[must_use]
1276 #[allow(clippy::single_range_in_vec_init)] pub fn contiguous(start: usize, end: usize) -> Self {
1278 Self {
1279 segments: vec![start..end],
1280 }
1281 }
1282
1283 #[must_use]
1285 pub fn num_segments(&self) -> usize {
1286 self.segments.len()
1287 }
1288
1289 #[must_use]
1291 pub fn is_discontinuous(&self) -> bool {
1292 self.segments.len() > 1
1293 }
1294
1295 #[must_use]
1297 pub fn is_contiguous(&self) -> bool {
1298 self.segments.len() <= 1
1299 }
1300
1301 #[must_use]
1303 pub fn segments(&self) -> &[std::ops::Range<usize>] {
1304 &self.segments
1305 }
1306
1307 #[must_use]
1309 pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
1310 if self.segments.is_empty() {
1311 return None;
1312 }
1313 let start = self.segments.first()?.start;
1314 let end = self.segments.last()?.end;
1315 Some(start..end)
1316 }
1317
1318 #[must_use]
1321 pub fn total_len(&self) -> usize {
1322 self.segments.iter().map(|r| r.end - r.start).sum()
1323 }
1324
1325 #[must_use]
1327 pub fn extract_text(&self, text: &str, separator: &str) -> String {
1328 self.segments
1329 .iter()
1330 .map(|r| {
1331 let start = r.start;
1332 let len = r.end.saturating_sub(r.start);
1333 text.chars().skip(start).take(len).collect::<String>()
1334 })
1335 .collect::<Vec<_>>()
1336 .join(separator)
1337 }
1338
1339 #[must_use]
1349 pub fn contains(&self, pos: usize) -> bool {
1350 self.segments.iter().any(|r| r.contains(&pos))
1351 }
1352
1353 #[must_use]
1355 pub fn to_span(&self) -> Option<Span> {
1356 self.bounding_range().map(|r| Span::Text {
1357 start: r.start,
1358 end: r.end,
1359 })
1360 }
1361}
1362
1363impl From<std::ops::Range<usize>> for DiscontinuousSpan {
1364 fn from(range: std::ops::Range<usize>) -> Self {
1365 Self::contiguous(range.start, range.end)
1366 }
1367}
1368
1369impl Default for Span {
1370 fn default() -> Self {
1371 Self::Text { start: 0, end: 0 }
1372 }
1373}
1374
/// Confidence split into three component scores.
///
/// The constructors clamp each component to `[0.0, 1.0]`; the fields are
/// public, so values assigned directly are not clamped.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HierarchicalConfidence {
    /// Linkage score.
    pub linkage: f32,
    /// Type score.
    pub type_score: f32,
    /// Boundary score.
    pub boundary: f32,
}
1397
1398impl HierarchicalConfidence {
1399 #[must_use]
1401 pub fn new(linkage: f32, type_score: f32, boundary: f32) -> Self {
1402 Self {
1403 linkage: linkage.clamp(0.0, 1.0),
1404 type_score: type_score.clamp(0.0, 1.0),
1405 boundary: boundary.clamp(0.0, 1.0),
1406 }
1407 }
1408
1409 #[must_use]
1412 pub fn from_single(confidence: f32) -> Self {
1413 let c = confidence.clamp(0.0, 1.0);
1414 Self {
1415 linkage: c,
1416 type_score: c,
1417 boundary: c,
1418 }
1419 }
1420
1421 #[must_use]
1424 pub fn combined(&self) -> f32 {
1425 (self.linkage * self.type_score * self.boundary).powf(1.0 / 3.0)
1426 }
1427
1428 #[must_use]
1430 pub fn as_f64(&self) -> f64 {
1431 self.combined() as f64
1432 }
1433
1434 #[must_use]
1436 pub fn passes_threshold(&self, linkage_min: f32, type_min: f32, boundary_min: f32) -> bool {
1437 self.linkage >= linkage_min && self.type_score >= type_min && self.boundary >= boundary_min
1438 }
1439}
1440
1441impl Default for HierarchicalConfidence {
1442 fn default() -> Self {
1443 Self {
1444 linkage: 1.0,
1445 type_score: 1.0,
1446 boundary: 1.0,
1447 }
1448 }
1449}
1450
1451impl From<f64> for HierarchicalConfidence {
1452 fn from(confidence: f64) -> Self {
1453 Self::from_single(confidence as f32)
1454 }
1455}
1456
1457impl From<f32> for HierarchicalConfidence {
1458 fn from(confidence: f32) -> Self {
1459 Self::from_single(confidence)
1460 }
1461}
1462
/// Several token sequences stored back-to-back without padding.
#[derive(Debug, Clone)]
pub struct RaggedBatch {
    /// All token ids, concatenated in sequence order.
    pub token_ids: Vec<u32>,
    /// Offsets into `token_ids`: sequence `i` occupies
    /// `cumulative_offsets[i]..cumulative_offsets[i + 1]`. As built by
    /// `from_sequences` it starts with `0` and has one more entry than there
    /// are sequences.
    pub cumulative_offsets: Vec<u32>,
    /// Length of the longest sequence in the batch.
    pub max_seq_len: usize,
}
1496
1497impl RaggedBatch {
1498 pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
1500 let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
1501 let mut token_ids = Vec::with_capacity(total_tokens);
1502 let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
1503 let mut max_seq_len = 0;
1504
1505 cumulative_offsets.push(0);
1506 for seq in sequences {
1507 token_ids.extend_from_slice(seq);
1508 let len = token_ids.len();
1512 if len > u32::MAX as usize {
1513 log::warn!(
1516 "Token count {} exceeds u32::MAX, truncating to {}",
1517 len,
1518 u32::MAX
1519 );
1520 cumulative_offsets.push(u32::MAX);
1521 } else {
1522 cumulative_offsets.push(len as u32);
1523 }
1524 max_seq_len = max_seq_len.max(seq.len());
1525 }
1526
1527 Self {
1528 token_ids,
1529 cumulative_offsets,
1530 max_seq_len,
1531 }
1532 }
1533
1534 #[must_use]
1536 pub fn batch_size(&self) -> usize {
1537 self.cumulative_offsets.len().saturating_sub(1)
1538 }
1539
1540 #[must_use]
1542 pub fn total_tokens(&self) -> usize {
1543 self.token_ids.len()
1544 }
1545
1546 #[must_use]
1548 pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
1549 if doc_idx + 1 < self.cumulative_offsets.len() {
1550 let start = self.cumulative_offsets[doc_idx] as usize;
1551 let end = self.cumulative_offsets[doc_idx + 1] as usize;
1552 Some(start..end)
1553 } else {
1554 None
1555 }
1556 }
1557
1558 #[must_use]
1560 pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
1561 self.doc_range(doc_idx).map(|r| &self.token_ids[r])
1562 }
1563
1564 #[must_use]
1566 pub fn padding_savings(&self) -> f64 {
1567 let padded_size = self.batch_size() * self.max_seq_len;
1568 if padded_size == 0 {
1569 return 0.0;
1570 }
1571 1.0 - (self.total_tokens() as f64 / padded_size as f64)
1572 }
1573}
1574
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
    /// Index of the document within the batch.
    pub doc_idx: u32,
    /// Start position within the document. The generators in this file
    /// produce token positions relative to the start of the document.
    pub start: u32,
    /// End position within the document; the generators in this file always
    /// emit `end > start`.
    pub end: u32,
}
1592
1593impl SpanCandidate {
1594 #[must_use]
1596 pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
1597 Self {
1598 doc_idx,
1599 start,
1600 end,
1601 }
1602 }
1603
1604 #[must_use]
1606 pub const fn width(&self) -> u32 {
1607 self.end.saturating_sub(self.start)
1608 }
1609}
1610
1611pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
1616 let mut candidates = Vec::new();
1617
1618 for doc_idx in 0..batch.batch_size() {
1619 if let Some(range) = batch.doc_range(doc_idx) {
1620 let doc_len = range.len();
1621 for start in 0..doc_len {
1623 let max_end = (start + max_width).min(doc_len);
1624 for end in (start + 1)..=max_end {
1625 candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
1626 }
1627 }
1628 }
1629 }
1630
1631 candidates
1632}
1633
1634pub fn generate_filtered_candidates(
1638 batch: &RaggedBatch,
1639 max_width: usize,
1640 linkage_mask: &[f32],
1641 threshold: f32,
1642) -> Vec<SpanCandidate> {
1643 let mut candidates = Vec::new();
1644 let mut mask_idx = 0;
1645
1646 for doc_idx in 0..batch.batch_size() {
1647 if let Some(range) = batch.doc_range(doc_idx) {
1648 let doc_len = range.len();
1649 for start in 0..doc_len {
1650 let max_end = (start + max_width).min(doc_len);
1651 for end in (start + 1)..=max_end {
1652 if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
1654 candidates.push(SpanCandidate::new(
1655 doc_idx as u32,
1656 start as u32,
1657 end as u32,
1658 ));
1659 }
1660 mask_idx += 1;
1661 }
1662 }
1663 }
1664 }
1665
1666 candidates
1667}
1668
/// A recognised entity with its type, location, confidence and optional metadata.
///
/// All optional fields default to `None` when deserializing and are left out
/// of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Surface text of the entity.
    pub text: String,
    /// Type of the entity.
    pub entity_type: EntityType,
    /// Start offset of the entity (`0` for entities built with `from_visual`).
    // NOTE(review): `DiscontinuousSpan::extract_text` counts offsets in
    // characters; confirm `start`/`end` use the same unit.
    pub start: usize,
    /// End offset of the entity; `total_len` computes `end - start`.
    pub end: usize,
    /// Overall confidence; `new`, `with_provenance` and `from_visual` clamp
    /// it to `[0.0, 1.0]`.
    pub confidence: f64,
    /// Normalized form of the text, if set (see `set_normalized`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<String>,
    /// Where and how the entity was extracted, if recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provenance: Option<Provenance>,
    /// Knowledge-base identifier, set by `link_to_kb`; `is_linked` reports
    /// whether it is present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kb_id: Option<String>,
    /// Canonical identifier, set by `set_canonical` / `with_canonical_id`;
    /// `has_coreference` reports whether it is present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub canonical_id: Option<super::types::CanonicalId>,
    /// Component confidence scores, when available.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hierarchical_confidence: Option<HierarchicalConfidence>,
    /// Visual location of the entity (set by `from_visual`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub visual_span: Option<Span>,
    /// Segments of a multi-part entity; `set_discontinuous_span` also
    /// updates `start`/`end` to the bounding range.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub discontinuous_span: Option<DiscontinuousSpan>,
    /// Start of the entity's validity period, if known (UTC).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// End of the entity's validity period, if known (UTC).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
    /// Viewport associated with the entity, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub viewport: Option<EntityViewport>,
}
1817
1818impl Entity {
1819 #[must_use]
1821 pub fn new(
1822 text: impl Into<String>,
1823 entity_type: EntityType,
1824 start: usize,
1825 end: usize,
1826 confidence: f64,
1827 ) -> Self {
1828 Self {
1829 text: text.into(),
1830 entity_type,
1831 start,
1832 end,
1833 confidence: confidence.clamp(0.0, 1.0),
1834 normalized: None,
1835 provenance: None,
1836 kb_id: None,
1837 canonical_id: None,
1838 hierarchical_confidence: None,
1839 visual_span: None,
1840 discontinuous_span: None,
1841 valid_from: None,
1842 valid_until: None,
1843 viewport: None,
1844 }
1845 }
1846
1847 #[must_use]
1849 pub fn with_provenance(
1850 text: impl Into<String>,
1851 entity_type: EntityType,
1852 start: usize,
1853 end: usize,
1854 confidence: f64,
1855 provenance: Provenance,
1856 ) -> Self {
1857 Self {
1858 text: text.into(),
1859 entity_type,
1860 start,
1861 end,
1862 confidence: confidence.clamp(0.0, 1.0),
1863 normalized: None,
1864 provenance: Some(provenance),
1865 kb_id: None,
1866 canonical_id: None,
1867 hierarchical_confidence: None,
1868 visual_span: None,
1869 discontinuous_span: None,
1870 valid_from: None,
1871 valid_until: None,
1872 viewport: None,
1873 }
1874 }
1875
1876 #[must_use]
1878 pub fn with_hierarchical_confidence(
1879 text: impl Into<String>,
1880 entity_type: EntityType,
1881 start: usize,
1882 end: usize,
1883 confidence: HierarchicalConfidence,
1884 ) -> Self {
1885 Self {
1886 text: text.into(),
1887 entity_type,
1888 start,
1889 end,
1890 confidence: confidence.as_f64(),
1891 normalized: None,
1892 provenance: None,
1893 kb_id: None,
1894 canonical_id: None,
1895 hierarchical_confidence: Some(confidence),
1896 visual_span: None,
1897 discontinuous_span: None,
1898 valid_from: None,
1899 valid_until: None,
1900 viewport: None,
1901 }
1902 }
1903
1904 #[must_use]
1906 pub fn from_visual(
1907 text: impl Into<String>,
1908 entity_type: EntityType,
1909 bbox: Span,
1910 confidence: f64,
1911 ) -> Self {
1912 Self {
1913 text: text.into(),
1914 entity_type,
1915 start: 0,
1916 end: 0,
1917 confidence: confidence.clamp(0.0, 1.0),
1918 normalized: None,
1919 provenance: None,
1920 kb_id: None,
1921 canonical_id: None,
1922 hierarchical_confidence: None,
1923 visual_span: Some(bbox),
1924 discontinuous_span: None,
1925 valid_from: None,
1926 valid_until: None,
1927 viewport: None,
1928 }
1929 }
1930
    /// Creates an entity with the given type and span, passing a confidence
    /// of `1.0` to [`Entity::new`].
    #[must_use]
    pub fn with_type(
        text: impl Into<String>,
        entity_type: EntityType,
        start: usize,
        end: usize,
    ) -> Self {
        Self::new(text, entity_type, start, end, 1.0)
    }
1941
    /// Stores `kb_id` as this entity's knowledge-base identifier, replacing
    /// any identifier set earlier.
    pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
        self.kb_id = Some(kb_id.into());
    }
1953
    /// Sets the canonical id in place, converting the argument into a
    /// `CanonicalId` and replacing any previous value.
    pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
        self.canonical_id = Some(canonical_id.into());
    }
1960
1961 #[must_use]
1971 pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
1972 self.canonical_id = Some(canonical_id.into());
1973 self
1974 }
1975
    /// Returns `true` when a knowledge-base id (`kb_id`) has been set.
    #[must_use]
    pub fn is_linked(&self) -> bool {
        self.kb_id.is_some()
    }
1981
    /// Returns `true` when a canonical id has been assigned to this entity.
    #[must_use]
    pub fn has_coreference(&self) -> bool {
        self.canonical_id.is_some()
    }
1987
1988 #[must_use]
1994 pub fn is_discontinuous(&self) -> bool {
1995 self.discontinuous_span
1996 .as_ref()
1997 .map(|s| s.is_discontinuous())
1998 .unwrap_or(false)
1999 }
2000
2001 #[must_use]
2005 pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
2006 self.discontinuous_span
2007 .as_ref()
2008 .filter(|s| s.is_discontinuous())
2009 .map(|s| s.segments().to_vec())
2010 }
2011
2012 pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
2016 if let Some(bounding) = span.bounding_range() {
2018 self.start = bounding.start;
2019 self.end = bounding.end;
2020 }
2021 self.discontinuous_span = Some(span);
2022 }
2023
2024 #[must_use]
2032 pub fn total_len(&self) -> usize {
2033 if let Some(ref span) = self.discontinuous_span {
2034 span.segments().iter().map(|r| r.end - r.start).sum()
2035 } else {
2036 self.end.saturating_sub(self.start)
2037 }
2038 }
2039
    /// Stores a normalized form for this entity, replacing any previous one.
    pub fn set_normalized(&mut self, normalized: impl Into<String>) {
        self.normalized = Some(normalized.into());
    }
2054
2055 #[must_use]
2057 pub fn normalized_or_text(&self) -> &str {
2058 self.normalized.as_deref().unwrap_or(&self.text)
2059 }
2060
2061 #[must_use]
2063 pub fn method(&self) -> ExtractionMethod {
2064 self.provenance
2065 .as_ref()
2066 .map_or(ExtractionMethod::Unknown, |p| p.method)
2067 }
2068
2069 #[must_use]
2071 pub fn source(&self) -> Option<&str> {
2072 self.provenance.as_ref().map(|p| p.source.as_ref())
2073 }
2074
    /// Returns the [`EntityCategory`] reported by this entity's type.
    #[must_use]
    pub fn category(&self) -> EntityCategory {
        self.entity_type.category()
    }
2080
    /// Returns `true` when the entity's type reports itself as
    /// pattern-detectable (delegates to `EntityType::pattern_detectable`).
    #[must_use]
    pub fn is_structured(&self) -> bool {
        self.entity_type.pattern_detectable()
    }
2086
    /// Returns `true` when the entity's type reports that it requires ML
    /// (delegates to `EntityType::requires_ml`).
    #[must_use]
    pub fn is_named(&self) -> bool {
        self.entity_type.requires_ml()
    }
2092
2093 #[must_use]
2095 pub fn overlaps(&self, other: &Entity) -> bool {
2096 !(self.end <= other.start || other.end <= self.start)
2097 }
2098
2099 #[must_use]
2101 pub fn overlap_ratio(&self, other: &Entity) -> f64 {
2102 let intersection_start = self.start.max(other.start);
2103 let intersection_end = self.end.min(other.end);
2104
2105 if intersection_start >= intersection_end {
2106 return 0.0;
2107 }
2108
2109 let intersection = (intersection_end - intersection_start) as f64;
2110 let union = ((self.end - self.start) + (other.end - other.start)
2111 - (intersection_end - intersection_start)) as f64;
2112
2113 if union == 0.0 {
2114 return 1.0;
2115 }
2116
2117 intersection / union
2118 }
2119
    /// Stores a hierarchical confidence and keeps the scalar `confidence`
    /// field in sync by overwriting it with the value's `as_f64()` result.
    pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
        self.confidence = confidence.as_f64();
        self.hierarchical_confidence = Some(confidence);
    }
2125
2126 #[must_use]
2128 pub fn linkage_confidence(&self) -> f32 {
2129 self.hierarchical_confidence
2130 .map_or(self.confidence as f32, |h| h.linkage)
2131 }
2132
2133 #[must_use]
2135 pub fn type_confidence(&self) -> f32 {
2136 self.hierarchical_confidence
2137 .map_or(self.confidence as f32, |h| h.type_score)
2138 }
2139
2140 #[must_use]
2142 pub fn boundary_confidence(&self) -> f32 {
2143 self.hierarchical_confidence
2144 .map_or(self.confidence as f32, |h| h.boundary)
2145 }
2146
    /// Returns `true` when a visual span is attached to this entity.
    #[must_use]
    pub fn is_visual(&self) -> bool {
        self.visual_span.is_some()
    }
2152
    /// Returns the text offsets as a `(start, end)` pair.
    #[must_use]
    pub const fn text_span(&self) -> (usize, usize) {
        (self.start, self.end)
    }
2158
    /// Returns `end - start`, saturating at `0` when `end < start`.
    #[must_use]
    pub const fn span_len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
2164
    /// Placeholder that is intentionally not implemented.
    ///
    /// The `_source_text` argument is ignored.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`; the panic message points callers
    /// at the `anno::offset` utilities instead.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub fn to_text_span(&self, _source_text: &str) -> serde_json::Value {
        unimplemented!("Use anno::offset utilities directly - see method docs")
    }
2193
    /// Attaches a visual span, replacing any previous one. Text offsets are
    /// not touched.
    pub fn set_visual_span(&mut self, span: Span) {
        self.visual_span = Some(span);
    }
2198
2199 #[must_use]
2219 pub fn extract_text(&self, source_text: &str) -> String {
2220 let char_count = source_text.chars().count();
2224 self.extract_text_with_len(source_text, char_count)
2225 }
2226
2227 #[must_use]
2239 pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
2240 if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
2241 return String::new();
2242 }
2243 source_text
2244 .chars()
2245 .skip(self.start)
2246 .take(self.end - self.start)
2247 .collect()
2248 }
2249
    /// Sets the lower bound (`valid_from`) of the entity's validity window.
    pub fn set_valid_from(&mut self, dt: chrono::DateTime<chrono::Utc>) {
        self.valid_from = Some(dt);
    }
2268
    /// Sets the upper bound (`valid_until`) of the entity's validity window.
    pub fn set_valid_until(&mut self, dt: chrono::DateTime<chrono::Utc>) {
        self.valid_until = Some(dt);
    }
2273
2274 pub fn set_temporal_range(
2276 &mut self,
2277 from: chrono::DateTime<chrono::Utc>,
2278 until: chrono::DateTime<chrono::Utc>,
2279 ) {
2280 self.valid_from = Some(from);
2281 self.valid_until = Some(until);
2282 }
2283
    /// Returns `true` when at least one of `valid_from` / `valid_until` is
    /// set.
    #[must_use]
    pub fn is_temporal(&self) -> bool {
        self.valid_from.is_some() || self.valid_until.is_some()
    }
2289
2290 #[must_use]
2312 pub fn valid_at(&self, timestamp: &chrono::DateTime<chrono::Utc>) -> bool {
2313 match (&self.valid_from, &self.valid_until) {
2314 (None, None) => true, (Some(from), None) => timestamp >= from, (None, Some(until)) => timestamp <= until, (Some(from), Some(until)) => timestamp >= from && timestamp <= until,
2318 }
2319 }
2320
    /// Returns [`Entity::valid_at`] evaluated at the current UTC time
    /// (`chrono::Utc::now()`).
    #[must_use]
    pub fn is_currently_valid(&self) -> bool {
        self.valid_at(&chrono::Utc::now())
    }
2326
    /// Assigns a viewport to this entity, replacing any previous one.
    pub fn set_viewport(&mut self, viewport: EntityViewport) {
        self.viewport = Some(viewport);
    }
2344
    /// Returns `true` when a viewport has been explicitly assigned.
    #[must_use]
    pub fn has_viewport(&self) -> bool {
        self.viewport.is_some()
    }
2350
2351 #[must_use]
2353 pub fn viewport_or_default(&self) -> EntityViewport {
2354 self.viewport.clone().unwrap_or_default()
2355 }
2356
2357 #[must_use]
2363 pub fn matches_viewport(&self, query_viewport: &EntityViewport) -> bool {
2364 match &self.viewport {
2365 None => true, Some(v) => v == query_viewport,
2367 }
2368 }
2369
    /// Starts an [`EntityBuilder`] for the given text and type; equivalent to
    /// calling `EntityBuilder::new(text, entity_type)`.
    #[must_use]
    pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
        EntityBuilder::new(text, entity_type)
    }
2375
2376 #[must_use]
2409 pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
2410 let char_count = source_text.chars().count();
2412 self.validate_with_len(source_text, char_count)
2413 }
2414
2415 #[must_use]
2427 pub fn validate_with_len(
2428 &self,
2429 source_text: &str,
2430 text_char_count: usize,
2431 ) -> Vec<ValidationIssue> {
2432 let mut issues = Vec::new();
2433
2434 if self.start >= self.end {
2436 issues.push(ValidationIssue::InvalidSpan {
2437 start: self.start,
2438 end: self.end,
2439 reason: "start must be less than end".to_string(),
2440 });
2441 }
2442
2443 if self.end > text_char_count {
2444 issues.push(ValidationIssue::SpanOutOfBounds {
2445 end: self.end,
2446 text_len: text_char_count,
2447 });
2448 }
2449
2450 if self.start < self.end && self.end <= text_char_count {
2452 let actual = self.extract_text_with_len(source_text, text_char_count);
2453 if actual != self.text {
2454 issues.push(ValidationIssue::TextMismatch {
2455 expected: self.text.clone(),
2456 actual,
2457 start: self.start,
2458 end: self.end,
2459 });
2460 }
2461 }
2462
2463 if !(0.0..=1.0).contains(&self.confidence) {
2465 issues.push(ValidationIssue::InvalidConfidence {
2466 value: self.confidence,
2467 });
2468 }
2469
2470 if let EntityType::Custom { ref name, .. } = self.entity_type {
2472 if name.is_empty() {
2473 issues.push(ValidationIssue::InvalidType {
2474 reason: "Custom entity type has empty name".to_string(),
2475 });
2476 }
2477 }
2478
2479 if let Some(ref disc_span) = self.discontinuous_span {
2481 for (i, seg) in disc_span.segments().iter().enumerate() {
2482 if seg.start >= seg.end {
2483 issues.push(ValidationIssue::InvalidSpan {
2484 start: seg.start,
2485 end: seg.end,
2486 reason: format!("discontinuous segment {} is invalid", i),
2487 });
2488 }
2489 if seg.end > text_char_count {
2490 issues.push(ValidationIssue::SpanOutOfBounds {
2491 end: seg.end,
2492 text_len: text_char_count,
2493 });
2494 }
2495 }
2496 }
2497
2498 issues
2499 }
2500
    /// Returns `true` when [`Entity::validate`] reports no issues for
    /// `source_text`.
    #[must_use]
    pub fn is_valid(&self, source_text: &str) -> bool {
        self.validate(source_text).is_empty()
    }
2508
2509 #[must_use]
2529 pub fn validate_batch(
2530 entities: &[Entity],
2531 source_text: &str,
2532 ) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
2533 entities
2534 .iter()
2535 .enumerate()
2536 .filter_map(|(idx, entity)| {
2537 let issues = entity.validate(source_text);
2538 if issues.is_empty() {
2539 None
2540 } else {
2541 Some((idx, issues))
2542 }
2543 })
2544 .collect()
2545 }
2546}
2547
/// A single problem found while validating an [`Entity`] against its source
/// text.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// A span (the entity's own, or one discontinuous segment) whose `start`
    /// is not strictly less than its `end`.
    InvalidSpan {
        /// Start offset of the offending span.
        start: usize,
        /// End offset of the offending span.
        end: usize,
        /// Human-readable explanation of what is wrong with the span.
        reason: String,
    },
    /// A span whose end lies beyond the source text.
    SpanOutOfBounds {
        /// End offset that exceeded the text.
        end: usize,
        /// Length of the source text the span was checked against.
        text_len: usize,
    },
    /// The entity's stored text differs from the text found at its span.
    TextMismatch {
        /// Text stored on the entity.
        expected: String,
        /// Text actually extracted from the source at the span.
        actual: String,
        /// Start offset of the span that was compared.
        start: usize,
        /// End offset of the span that was compared.
        end: usize,
    },
    /// A confidence value outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// The offending confidence value.
        value: f64,
    },
    /// A problem with the entity's type.
    InvalidType {
        /// Human-readable explanation of the problem.
        reason: String,
    },
}
2589
2590impl std::fmt::Display for ValidationIssue {
2591 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2592 match self {
2593 ValidationIssue::InvalidSpan { start, end, reason } => {
2594 write!(f, "Invalid span [{}, {}): {}", start, end, reason)
2595 }
2596 ValidationIssue::SpanOutOfBounds { end, text_len } => {
2597 write!(f, "Span end {} exceeds text length {}", end, text_len)
2598 }
2599 ValidationIssue::TextMismatch {
2600 expected,
2601 actual,
2602 start,
2603 end,
2604 } => {
2605 write!(
2606 f,
2607 "Text mismatch at [{}, {}): expected '{}', got '{}'",
2608 start, end, expected, actual
2609 )
2610 }
2611 ValidationIssue::InvalidConfidence { value } => {
2612 write!(f, "Confidence {} outside [0.0, 1.0]", value)
2613 }
2614 ValidationIssue::InvalidType { reason } => {
2615 write!(f, "Invalid entity type: {}", reason)
2616 }
2617 }
2618 }
2619}
2620
/// Step-by-step constructor for [`Entity`].
///
/// Holds one value per `Entity` field; `build` moves them into the finished
/// entity unchanged.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
    /// Surface text of the entity.
    text: String,
    /// Type assigned to the entity.
    entity_type: EntityType,
    /// Start offset of the span (`0` until set).
    start: usize,
    /// End offset of the span (`0` until set).
    end: usize,
    /// Scalar confidence (`1.0` until set).
    confidence: f64,
    /// Optional normalized form of the text.
    normalized: Option<String>,
    /// Optional extraction provenance.
    provenance: Option<Provenance>,
    /// Optional knowledge-base identifier.
    kb_id: Option<String>,
    /// Optional canonical id.
    canonical_id: Option<super::types::CanonicalId>,
    /// Optional hierarchical confidence.
    hierarchical_confidence: Option<HierarchicalConfidence>,
    /// Optional visual span.
    visual_span: Option<Span>,
    /// Optional discontinuous span.
    discontinuous_span: Option<DiscontinuousSpan>,
    /// Optional lower bound of the validity window.
    valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional upper bound of the validity window.
    valid_until: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional viewport.
    viewport: Option<EntityViewport>,
}
2653
2654impl EntityBuilder {
2655 #[must_use]
2657 pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
2658 Self {
2659 text: text.into(),
2660 entity_type,
2661 start: 0,
2662 end: 0,
2663 confidence: 1.0,
2664 normalized: None,
2665 provenance: None,
2666 kb_id: None,
2667 canonical_id: None,
2668 hierarchical_confidence: None,
2669 visual_span: None,
2670 discontinuous_span: None,
2671 valid_from: None,
2672 valid_until: None,
2673 viewport: None,
2674 }
2675 }
2676
2677 #[must_use]
2679 pub const fn span(mut self, start: usize, end: usize) -> Self {
2680 self.start = start;
2681 self.end = end;
2682 self
2683 }
2684
2685 #[must_use]
2687 pub fn confidence(mut self, confidence: f64) -> Self {
2688 self.confidence = confidence.clamp(0.0, 1.0);
2689 self
2690 }
2691
2692 #[must_use]
2694 pub fn hierarchical_confidence(mut self, confidence: HierarchicalConfidence) -> Self {
2695 self.confidence = confidence.as_f64();
2696 self.hierarchical_confidence = Some(confidence);
2697 self
2698 }
2699
2700 #[must_use]
2702 pub fn normalized(mut self, normalized: impl Into<String>) -> Self {
2703 self.normalized = Some(normalized.into());
2704 self
2705 }
2706
2707 #[must_use]
2709 pub fn provenance(mut self, provenance: Provenance) -> Self {
2710 self.provenance = Some(provenance);
2711 self
2712 }
2713
2714 #[must_use]
2716 pub fn kb_id(mut self, kb_id: impl Into<String>) -> Self {
2717 self.kb_id = Some(kb_id.into());
2718 self
2719 }
2720
2721 #[must_use]
2723 pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
2724 self.canonical_id = Some(super::types::CanonicalId::new(canonical_id));
2725 self
2726 }
2727
2728 #[must_use]
2730 pub fn visual_span(mut self, span: Span) -> Self {
2731 self.visual_span = Some(span);
2732 self
2733 }
2734
2735 #[must_use]
2739 pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
2740 if let Some(bounding) = span.bounding_range() {
2742 self.start = bounding.start;
2743 self.end = bounding.end;
2744 }
2745 self.discontinuous_span = Some(span);
2746 self
2747 }
2748
2749 #[must_use]
2763 pub fn valid_from(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2764 self.valid_from = Some(dt);
2765 self
2766 }
2767
2768 #[must_use]
2770 pub fn valid_until(mut self, dt: chrono::DateTime<chrono::Utc>) -> Self {
2771 self.valid_until = Some(dt);
2772 self
2773 }
2774
2775 #[must_use]
2777 pub fn temporal_range(
2778 mut self,
2779 from: chrono::DateTime<chrono::Utc>,
2780 until: chrono::DateTime<chrono::Utc>,
2781 ) -> Self {
2782 self.valid_from = Some(from);
2783 self.valid_until = Some(until);
2784 self
2785 }
2786
2787 #[must_use]
2800 pub fn viewport(mut self, viewport: EntityViewport) -> Self {
2801 self.viewport = Some(viewport);
2802 self
2803 }
2804
2805 #[must_use]
2807 pub fn build(self) -> Entity {
2808 Entity {
2809 text: self.text,
2810 entity_type: self.entity_type,
2811 start: self.start,
2812 end: self.end,
2813 confidence: self.confidence,
2814 normalized: self.normalized,
2815 provenance: self.provenance,
2816 kb_id: self.kb_id,
2817 canonical_id: self.canonical_id,
2818 hierarchical_confidence: self.hierarchical_confidence,
2819 visual_span: self.visual_span,
2820 discontinuous_span: self.discontinuous_span,
2821 valid_from: self.valid_from,
2822 valid_until: self.valid_until,
2823 viewport: self.viewport,
2824 }
2825 }
2826}
2827
/// A typed, directed link between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// Entity on the head side of the relation.
    pub head: Entity,
    /// Entity on the tail side of the relation.
    pub tail: Entity,
    /// Label of the relation.
    pub relation_type: String,
    /// Optional `(start, end)` offsets of the trigger for this relation.
    pub trigger_span: Option<(usize, usize)>,
    /// Confidence of the relation; the constructors clamp it into
    /// `[0.0, 1.0]`.
    pub confidence: f64,
}
2867
2868impl Relation {
2869 #[must_use]
2871 pub fn new(
2872 head: Entity,
2873 tail: Entity,
2874 relation_type: impl Into<String>,
2875 confidence: f64,
2876 ) -> Self {
2877 Self {
2878 head,
2879 tail,
2880 relation_type: relation_type.into(),
2881 trigger_span: None,
2882 confidence: confidence.clamp(0.0, 1.0),
2883 }
2884 }
2885
2886 #[must_use]
2888 pub fn with_trigger(
2889 head: Entity,
2890 tail: Entity,
2891 relation_type: impl Into<String>,
2892 trigger_start: usize,
2893 trigger_end: usize,
2894 confidence: f64,
2895 ) -> Self {
2896 Self {
2897 head,
2898 tail,
2899 relation_type: relation_type.into(),
2900 trigger_span: Some((trigger_start, trigger_end)),
2901 confidence: confidence.clamp(0.0, 1.0),
2902 }
2903 }
2904
2905 #[must_use]
2907 pub fn as_triple(&self) -> String {
2908 format!(
2909 "({}, {}, {})",
2910 self.head.text, self.relation_type, self.tail.text
2911 )
2912 }
2913
2914 #[must_use]
2917 pub fn span_distance(&self) -> usize {
2918 if self.head.end <= self.tail.start {
2919 self.tail.start.saturating_sub(self.head.end)
2920 } else if self.tail.end <= self.head.start {
2921 self.head.start.saturating_sub(self.tail.end)
2922 } else {
2923 0 }
2925 }
2926}
2927
#[cfg(test)]
mod tests {
    // Unit tests for the entity, relation, span, confidence, batching and
    // provenance types of this module. `unwrap` is acceptable in test code.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Built-in entity types survive an `as_label` -> `from_label` round trip.
    #[test]
    fn test_entity_type_roundtrip() {
        let types = [
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            EntityType::Money,
            EntityType::Percent,
        ];

        for t in types {
            let label = t.as_label();
            let parsed = EntityType::from_label(label);
            assert_eq!(t, parsed);
        }
    }

    /// Disjoint spans do not overlap; a containing span overlaps both parts.
    #[test]
    fn test_entity_overlap() {
        let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
        let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
        let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);

        // e1 ends at 4 and e2 starts at 5: no shared position.
        assert!(!e1.overlaps(&e2));
        // e3 covers 0..10 and therefore intersects both e1 and e2.
        assert!(e1.overlaps(&e3));
        assert!(e3.overlaps(&e2));
    }

    /// Out-of-range confidences are clamped into [0.0, 1.0] by `Entity::new`.
    #[test]
    fn test_confidence_clamping() {
        let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
        assert!((e1.confidence - 1.0).abs() < f64::EPSILON);

        let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
        assert!(e2.confidence.abs() < f64::EPSILON);
    }

    /// Built-in types map to the expected categories and detection flags.
    #[test]
    fn test_entity_categories() {
        // Named entities: ML required, not pattern-detectable.
        assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
        assert_eq!(
            EntityType::Organization.category(),
            EntityCategory::Organization
        );
        assert_eq!(EntityType::Location.category(), EntityCategory::Place);
        assert!(EntityType::Person.requires_ml());
        assert!(!EntityType::Person.pattern_detectable());

        // Temporal entities: pattern-detectable.
        assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
        assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
        assert!(EntityType::Date.pattern_detectable());
        assert!(!EntityType::Date.requires_ml());

        // Numeric entities: pattern-detectable.
        assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
        assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
        assert!(EntityType::Money.pattern_detectable());

        // Contact entities: pattern-detectable.
        assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
        assert!(EntityType::Email.pattern_detectable());
    }

    /// The remaining built-in types also survive a label round trip.
    #[test]
    fn test_new_types_roundtrip() {
        let types = [
            EntityType::Time,
            EntityType::Email,
            EntityType::Url,
            EntityType::Phone,
            EntityType::Quantity,
            EntityType::Cardinal,
            EntityType::Ordinal,
        ];

        for t in types {
            let label = t.as_label();
            let parsed = EntityType::from_label(label);
            assert_eq!(t, parsed, "Roundtrip failed for {}", label);
        }
    }

    /// Custom types keep their label and inherit flags from their category.
    #[test]
    fn test_custom_entity_type() {
        let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
        assert_eq!(disease.as_label(), "DISEASE");
        assert!(disease.requires_ml());

        // `Misc` is neither ML-required nor pattern-detectable.
        let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
        assert_eq!(product_id.as_label(), "PRODUCT_ID");
        assert!(!product_id.requires_ml());
        assert!(!product_id.pattern_detectable());
    }

    /// `normalized_or_text` falls back to the raw text until a normalized
    /// form is set.
    #[test]
    fn test_entity_normalization() {
        let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
        assert!(e.normalized.is_none());
        assert_eq!(e.normalized_or_text(), "Jan 15");

        e.set_normalized("2024-01-15");
        assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
        assert_eq!(e.normalized_or_text(), "2024-01-15");
    }

    /// `is_named` / `is_structured` / `category` reflect the entity's type.
    #[test]
    fn test_entity_helpers() {
        let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
        assert!(named.is_named());
        assert!(!named.is_structured());
        assert_eq!(named.category(), EntityCategory::Agent);

        let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
        assert!(!structured.is_named());
        assert!(structured.is_structured());
        assert_eq!(structured.category(), EntityCategory::Numeric);
    }

    /// Linking to a knowledge base and setting a canonical id are
    /// independent and both start unset.
    #[test]
    fn test_knowledge_linking() {
        let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        assert!(!entity.is_linked());
        assert!(!entity.has_coreference());

        entity.link_to_kb("Q7186");
        assert!(entity.is_linked());
        assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));

        entity.set_canonical(42);
        assert!(entity.has_coreference());
        assert_eq!(
            entity.canonical_id,
            Some(crate::core::types::CanonicalId::new(42))
        );
    }

    /// Both `Relation` constructors populate their fields as expected.
    #[test]
    fn test_relation_creation() {
        let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);

        let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
        assert_eq!(relation.relation_type, "WORKED_AT");
        assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
        assert!(relation.trigger_span.is_none());

        // Same entities, this time with explicit trigger offsets.
        let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
        assert_eq!(relation2.trigger_span, Some((13, 19)));
    }

    /// The distance is the gap between head end (11) and tail start (24).
    #[test]
    fn test_relation_span_distance() {
        let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
        let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
        assert_eq!(relation.span_distance(), 13);
    }

    /// A custom type in the `Relation` category is a relation and needs ML.
    #[test]
    fn test_relation_category() {
        let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
        assert_eq!(rel_type.category(), EntityCategory::Relation);
        assert!(rel_type.category().is_relation());
        assert!(rel_type.requires_ml());
    }

    // ---- Span ----

    /// A text span exposes its offsets and length and is not visual.
    #[test]
    fn test_span_text() {
        let span = Span::text(10, 20);
        assert!(span.is_text());
        assert!(!span.is_visual());
        assert_eq!(span.text_offsets(), Some((10, 20)));
        assert_eq!(span.len(), 10);
        assert!(!span.is_empty());
    }

    /// A bounding box is visual only: no text offsets and zero length.
    #[test]
    fn test_span_bbox() {
        let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
        assert!(!span.is_text());
        assert!(span.is_visual());
        assert_eq!(span.text_offsets(), None);
        assert_eq!(span.len(), 0);
    }

    /// `bbox_on_page` records the page number on the bounding box.
    #[test]
    fn test_span_bbox_with_page() {
        let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
        if let Span::BoundingBox { page, .. } = span {
            assert_eq!(page, Some(5));
        } else {
            panic!("Expected BoundingBox");
        }
    }

    /// A hybrid span counts as both text and visual.
    #[test]
    fn test_span_hybrid() {
        let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
        let hybrid = Span::Hybrid {
            start: 10,
            end: 20,
            bbox: Box::new(bbox),
        };
        assert!(hybrid.is_text());
        assert!(hybrid.is_visual());
        assert_eq!(hybrid.text_offsets(), Some((10, 20)));
        assert_eq!(hybrid.len(), 10);
    }

    // ---- HierarchicalConfidence ----

    /// `new` stores the three components as given when they are in range.
    #[test]
    fn test_hierarchical_confidence_new() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        assert!((hc.linkage - 0.9).abs() < f32::EPSILON);
        assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
        assert!((hc.boundary - 0.7).abs() < f32::EPSILON);
    }

    /// Out-of-range components are clamped into [0.0, 1.0].
    #[test]
    fn test_hierarchical_confidence_clamping() {
        let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
        assert!((hc.linkage - 1.0).abs() < f32::EPSILON);
        assert!(hc.type_score.abs() < f32::EPSILON);
        assert!((hc.boundary - 0.5).abs() < f32::EPSILON);
    }

    /// `from_single` copies one score into all three components.
    #[test]
    fn test_hierarchical_confidence_from_single() {
        let hc = HierarchicalConfidence::from_single(0.8);
        assert!((hc.linkage - 0.8).abs() < f32::EPSILON);
        assert!((hc.type_score - 0.8).abs() < f32::EPSILON);
        assert!((hc.boundary - 0.8).abs() < f32::EPSILON);
    }

    /// With three equal components, `combined` returns that same value.
    #[test]
    fn test_hierarchical_confidence_combined() {
        let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
        assert!((hc.combined() - 1.0).abs() < f32::EPSILON);

        let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
        assert!((hc2.combined() - 0.8).abs() < f32::EPSILON);

        let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
        assert!((hc3.combined() - 0.5).abs() < 0.001);
    }

    /// Thresholds pass when every component meets its minimum (equality
    /// included) and fail as soon as one component falls short.
    #[test]
    fn test_hierarchical_confidence_threshold() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        assert!(hc.passes_threshold(0.5, 0.5, 0.5));
        assert!(hc.passes_threshold(0.9, 0.8, 0.7));
        // Linkage 0.9 is below the 0.95 requirement.
        assert!(!hc.passes_threshold(0.95, 0.8, 0.7));
        // Type score 0.8 is below the 0.85 requirement.
        assert!(!hc.passes_threshold(0.9, 0.85, 0.7));
    }

    /// An `f64` converts into a `HierarchicalConfidence` via `From`/`Into`.
    #[test]
    fn test_hierarchical_confidence_from_f64() {
        let hc: HierarchicalConfidence = 0.85_f64.into();
        assert!((hc.linkage - 0.85).abs() < 0.001);
    }

    // ---- RaggedBatch ----

    /// Three sequences of lengths 3, 2 and 4 give 9 tokens and cumulative
    /// offsets [0, 3, 5, 9].
    #[test]
    fn test_ragged_batch_from_sequences() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.total_tokens(), 9);
        assert_eq!(batch.max_seq_len, 4);
        assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
    }

    /// `doc_range` returns each document's token range, or `None` for an
    /// index past the batch.
    #[test]
    fn test_ragged_batch_doc_range() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.doc_range(0), Some(0..3));
        assert_eq!(batch.doc_range(1), Some(3..5));
        assert_eq!(batch.doc_range(2), None);
    }

    /// `doc_tokens` returns the original tokens of each document.
    #[test]
    fn test_ragged_batch_doc_tokens() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
        assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
    }

    /// Expected savings of 0.25 are consistent with 9 real tokens versus
    /// 3 * 4 = 12 padded slots (presumably 1 - 9/12; confirm against
    /// `padding_savings`).
    #[test]
    fn test_ragged_batch_padding_savings() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
        let batch = RaggedBatch::from_sequences(&seqs);
        let savings = batch.padding_savings();
        assert!((savings - 0.25).abs() < 0.001);
    }

    // ---- Span candidates ----

    /// `SpanCandidate::new` stores its arguments; width is `end - start`.
    #[test]
    fn test_span_candidate() {
        let sc = SpanCandidate::new(0, 5, 10);
        assert_eq!(sc.doc_idx, 0);
        assert_eq!(sc.start, 5);
        assert_eq!(sc.end, 10);
        assert_eq!(sc.width(), 5);
    }

    /// One 3-token document with max width 2 yields 5 candidates
    /// (presumably three of width 1 plus two of width 2).
    #[test]
    fn test_generate_span_candidates() {
        let seqs = vec![vec![1, 2, 3]];
        let batch = RaggedBatch::from_sequences(&seqs);
        let candidates = generate_span_candidates(&batch, 2);

        assert_eq!(candidates.len(), 5);

        // Every candidate stays inside the document and the width limit.
        for c in &candidates {
            assert_eq!(c.doc_idx, 0);
            assert!(c.end as usize <= 3);
            assert!(c.width() as usize <= 2);
        }
    }

    /// Only candidates whose mask score clears the 0.5 threshold survive;
    /// two of the five mask entries do.
    #[test]
    fn test_generate_filtered_candidates() {
        let seqs = vec![vec![1, 2, 3]];
        let batch = RaggedBatch::from_sequences(&seqs);

        let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
        let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);

        assert_eq!(candidates.len(), 2);
    }

    // ---- EntityBuilder ----

    /// The builder carries text, type, span and confidence into the entity.
    #[test]
    fn test_entity_builder_basic() {
        let entity = Entity::builder("John", EntityType::Person)
            .span(0, 4)
            .confidence(0.95)
            .build();

        assert_eq!(entity.text, "John");
        assert_eq!(entity.entity_type, EntityType::Person);
        assert_eq!(entity.start, 0);
        assert_eq!(entity.end, 4);
        assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
    }

    /// Optional fields set on the builder appear on the built entity.
    #[test]
    fn test_entity_builder_full() {
        let entity = Entity::builder("Marie Curie", EntityType::Person)
            .span(0, 11)
            .confidence(0.95)
            .kb_id("Q7186")
            .canonical_id(42)
            .normalized("Marie Salomea Skłodowska Curie")
            .provenance(Provenance::ml("bert", 0.95))
            .build();

        assert_eq!(entity.text, "Marie Curie");
        assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
        assert_eq!(
            entity.canonical_id,
            Some(crate::core::types::CanonicalId::new(42))
        );
        assert_eq!(
            entity.normalized.as_deref(),
            Some("Marie Salomea Skłodowska Curie")
        );
        assert!(entity.provenance.is_some());
    }

    /// A hierarchical confidence set on the builder feeds the per-component
    /// accessors of the entity.
    #[test]
    fn test_entity_builder_hierarchical() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        let entity = Entity::builder("test", EntityType::Person)
            .span(0, 4)
            .hierarchical_confidence(hc)
            .build();

        assert!(entity.hierarchical_confidence.is_some());
        assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
        assert!((entity.type_confidence() - 0.8).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
    }

    /// A visual span set on the builder makes the entity visual.
    #[test]
    fn test_entity_builder_visual() {
        let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
        let entity = Entity::builder("receipt item", EntityType::Money)
            .visual_span(bbox)
            .confidence(0.9)
            .build();

        assert!(entity.is_visual());
        assert!(entity.visual_span.is_some());
    }

    // ---- Entity helpers ----

    /// Without a hierarchical confidence the accessors fall back to the
    /// scalar confidence; once set, they return the individual components.
    #[test]
    fn test_entity_hierarchical_confidence_helpers() {
        let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);

        // Fallback: every accessor reports the scalar 0.8.
        assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
        assert!((entity.type_confidence() - 0.8).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);

        entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
        assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
        assert!((entity.type_confidence() - 0.85).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
    }

    /// `from_visual` produces a visual entity with zeroed text offsets.
    #[test]
    fn test_entity_from_visual() {
        let entity = Entity::from_visual(
            "receipt total",
            EntityType::Money,
            Span::bbox(0.5, 0.8, 0.2, 0.05),
            0.92,
        );

        assert!(entity.is_visual());
        assert_eq!(entity.start, 0);
        assert_eq!(entity.end, 0);
        assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
    }

    /// `text_span` returns the offsets and `span_len` their difference.
    #[test]
    fn test_entity_span_helpers() {
        let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
        assert_eq!(entity.text_span(), (10, 20));
        assert_eq!(entity.span_len(), 10);
    }

    // ---- Provenance ----

    /// Pattern provenance records the pattern name and a raw confidence
    /// of 1.0.
    #[test]
    fn test_provenance_pattern() {
        let prov = Provenance::pattern("EMAIL");
        assert_eq!(prov.method, ExtractionMethod::Pattern);
        assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
        assert_eq!(prov.raw_confidence, Some(1.0));
    }

    /// ML provenance records the model name as source and the raw score.
    #[test]
    fn test_provenance_ml() {
        let prov = Provenance::ml("bert-ner", 0.87);
        assert_eq!(prov.method, ExtractionMethod::Neural);
        assert_eq!(prov.source.as_ref(), "bert-ner");
        assert_eq!(prov.raw_confidence, Some(0.87));
    }

    /// `with_version` sets the model version without changing the source.
    #[test]
    fn test_provenance_with_version() {
        let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");

        assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
        assert_eq!(prov.source.as_ref(), "gliner");
    }

    /// `with_timestamp` stores the timestamp string as given.
    #[test]
    fn test_provenance_with_timestamp() {
        let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");

        assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
    }

    /// The `with_*` methods chain and leave earlier fields intact.
    #[test]
    fn test_provenance_builder_chain() {
        let prov = Provenance::ml("modernbert-ner", 0.95)
            .with_version("v1.0.0")
            .with_timestamp("2024-11-27T12:00:00Z");

        assert_eq!(prov.method, ExtractionMethod::Neural);
        assert_eq!(prov.source.as_ref(), "modernbert-ner");
        assert_eq!(prov.raw_confidence, Some(0.95));
        assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
        assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
    }

    /// Provenance survives a JSON round trip, including optional fields.
    #[test]
    fn test_provenance_serialization() {
        let prov = Provenance::ml("test", 0.9)
            .with_version("v1.0")
            .with_timestamp("2024-01-01");

        let json = serde_json::to_string(&prov).unwrap();
        assert!(json.contains("model_version"));
        assert!(json.contains("v1.0"));

        let restored: Provenance = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.model_version.as_deref(), Some("v1.0"));
        assert_eq!(restored.timestamp.as_deref(), Some("2024-01-01"));
    }
}
3472
3473#[cfg(test)]
3474mod proptests {
3475 #![allow(clippy::unwrap_used)] use super::*;
3477 use proptest::prelude::*;
3478
3479 proptest! {
3480 #[test]
3481 fn confidence_always_clamped(conf in -10.0f64..10.0) {
3482 let e = Entity::new("test", EntityType::Person, 0, 4, conf);
3483 prop_assert!(e.confidence >= 0.0);
3484 prop_assert!(e.confidence <= 1.0);
3485 }
3486
3487 #[test]
3488 fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
3489 let et = EntityType::from_label(&label);
3490 let back = EntityType::from_label(et.as_label());
3491 prop_assert!(matches!(back, EntityType::Other(_)) || back == et);
3493 }
3494
3495 #[test]
3496 fn overlap_is_symmetric(
3497 s1 in 0usize..100,
3498 len1 in 1usize..50,
3499 s2 in 0usize..100,
3500 len2 in 1usize..50,
3501 ) {
3502 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3503 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3504 prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
3505 }
3506
3507 #[test]
3508 fn overlap_ratio_bounded(
3509 s1 in 0usize..100,
3510 len1 in 1usize..50,
3511 s2 in 0usize..100,
3512 len2 in 1usize..50,
3513 ) {
3514 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3515 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3516 let ratio = e1.overlap_ratio(&e2);
3517 prop_assert!(ratio >= 0.0);
3518 prop_assert!(ratio <= 1.0);
3519 }
3520
3521 #[test]
3522 fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
3523 let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
3524 let ratio = e.overlap_ratio(&e);
3525 prop_assert!((ratio - 1.0).abs() < 1e-10);
3526 }
3527
3528 #[test]
3529 fn hierarchical_confidence_always_clamped(
3530 linkage in -2.0f32..2.0,
3531 type_score in -2.0f32..2.0,
3532 boundary in -2.0f32..2.0,
3533 ) {
3534 let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
3535 prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
3536 prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
3537 prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
3538 prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
3539 }
3540
3541 #[test]
3542 fn span_candidate_width_consistent(
3543 doc in 0u32..10,
3544 start in 0u32..100,
3545 end in 1u32..100,
3546 ) {
3547 let actual_end = start.max(end);
3548 let sc = SpanCandidate::new(doc, start, actual_end);
3549 prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
3550 }
3551
3552 #[test]
3553 fn ragged_batch_preserves_tokens(
3554 seq_lens in proptest::collection::vec(1usize..10, 1..5),
3555 ) {
3556 let mut counter = 0u32;
3558 let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
3559 let seq: Vec<u32> = (counter..counter + len as u32).collect();
3560 counter += len as u32;
3561 seq
3562 }).collect();
3563
3564 let batch = RaggedBatch::from_sequences(&seqs);
3565
3566 prop_assert_eq!(batch.batch_size(), seqs.len());
3568 prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
3569
3570 for (i, seq) in seqs.iter().enumerate() {
3572 let doc_tokens = batch.doc_tokens(i).unwrap();
3573 prop_assert_eq!(doc_tokens, seq.as_slice());
3574 }
3575 }
3576
3577 #[test]
3578 fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
3579 let end = start + len;
3580 let span = Span::text(start, end);
3581 let (s, e) = span.text_offsets().unwrap();
3582 prop_assert_eq!(s, start);
3583 prop_assert_eq!(e, end);
3584 prop_assert_eq!(span.len(), len);
3585 }
3586 }
3587
3588 #[test]
3593 fn test_entity_viewport_as_str() {
3594 assert_eq!(EntityViewport::Business.as_str(), "business");
3595 assert_eq!(EntityViewport::Legal.as_str(), "legal");
3596 assert_eq!(EntityViewport::Technical.as_str(), "technical");
3597 assert_eq!(EntityViewport::Academic.as_str(), "academic");
3598 assert_eq!(EntityViewport::Personal.as_str(), "personal");
3599 assert_eq!(EntityViewport::Political.as_str(), "political");
3600 assert_eq!(EntityViewport::Media.as_str(), "media");
3601 assert_eq!(EntityViewport::Historical.as_str(), "historical");
3602 assert_eq!(EntityViewport::General.as_str(), "general");
3603 assert_eq!(
3604 EntityViewport::Custom("custom".to_string()).as_str(),
3605 "custom"
3606 );
3607 }
3608
3609 #[test]
3610 fn test_entity_viewport_is_professional() {
3611 assert!(EntityViewport::Business.is_professional());
3612 assert!(EntityViewport::Legal.is_professional());
3613 assert!(EntityViewport::Technical.is_professional());
3614 assert!(EntityViewport::Academic.is_professional());
3615 assert!(EntityViewport::Political.is_professional());
3616
3617 assert!(!EntityViewport::Personal.is_professional());
3618 assert!(!EntityViewport::Media.is_professional());
3619 assert!(!EntityViewport::Historical.is_professional());
3620 assert!(!EntityViewport::General.is_professional());
3621 assert!(!EntityViewport::Custom("test".to_string()).is_professional());
3622 }
3623
3624 #[test]
3625 fn test_entity_viewport_from_str() {
3626 assert_eq!(
3627 "business".parse::<EntityViewport>().unwrap(),
3628 EntityViewport::Business
3629 );
3630 assert_eq!(
3631 "financial".parse::<EntityViewport>().unwrap(),
3632 EntityViewport::Business
3633 );
3634 assert_eq!(
3635 "corporate".parse::<EntityViewport>().unwrap(),
3636 EntityViewport::Business
3637 );
3638
3639 assert_eq!(
3640 "legal".parse::<EntityViewport>().unwrap(),
3641 EntityViewport::Legal
3642 );
3643 assert_eq!(
3644 "law".parse::<EntityViewport>().unwrap(),
3645 EntityViewport::Legal
3646 );
3647
3648 assert_eq!(
3649 "technical".parse::<EntityViewport>().unwrap(),
3650 EntityViewport::Technical
3651 );
3652 assert_eq!(
3653 "engineering".parse::<EntityViewport>().unwrap(),
3654 EntityViewport::Technical
3655 );
3656
3657 assert_eq!(
3658 "academic".parse::<EntityViewport>().unwrap(),
3659 EntityViewport::Academic
3660 );
3661 assert_eq!(
3662 "research".parse::<EntityViewport>().unwrap(),
3663 EntityViewport::Academic
3664 );
3665
3666 assert_eq!(
3667 "personal".parse::<EntityViewport>().unwrap(),
3668 EntityViewport::Personal
3669 );
3670 assert_eq!(
3671 "biographical".parse::<EntityViewport>().unwrap(),
3672 EntityViewport::Personal
3673 );
3674
3675 assert_eq!(
3676 "political".parse::<EntityViewport>().unwrap(),
3677 EntityViewport::Political
3678 );
3679 assert_eq!(
3680 "policy".parse::<EntityViewport>().unwrap(),
3681 EntityViewport::Political
3682 );
3683
3684 assert_eq!(
3685 "media".parse::<EntityViewport>().unwrap(),
3686 EntityViewport::Media
3687 );
3688 assert_eq!(
3689 "press".parse::<EntityViewport>().unwrap(),
3690 EntityViewport::Media
3691 );
3692
3693 assert_eq!(
3694 "historical".parse::<EntityViewport>().unwrap(),
3695 EntityViewport::Historical
3696 );
3697 assert_eq!(
3698 "history".parse::<EntityViewport>().unwrap(),
3699 EntityViewport::Historical
3700 );
3701
3702 assert_eq!(
3703 "general".parse::<EntityViewport>().unwrap(),
3704 EntityViewport::General
3705 );
3706 assert_eq!(
3707 "generic".parse::<EntityViewport>().unwrap(),
3708 EntityViewport::General
3709 );
3710 assert_eq!(
3711 "".parse::<EntityViewport>().unwrap(),
3712 EntityViewport::General
3713 );
3714
3715 assert_eq!(
3717 "custom_viewport".parse::<EntityViewport>().unwrap(),
3718 EntityViewport::Custom("custom_viewport".to_string())
3719 );
3720 }
3721
3722 #[test]
3723 fn test_entity_viewport_from_str_case_insensitive() {
3724 assert_eq!(
3725 "BUSINESS".parse::<EntityViewport>().unwrap(),
3726 EntityViewport::Business
3727 );
3728 assert_eq!(
3729 "Business".parse::<EntityViewport>().unwrap(),
3730 EntityViewport::Business
3731 );
3732 assert_eq!(
3733 "BuSiNeSs".parse::<EntityViewport>().unwrap(),
3734 EntityViewport::Business
3735 );
3736 }
3737
3738 #[test]
3739 fn test_entity_viewport_display() {
3740 assert_eq!(format!("{}", EntityViewport::Business), "business");
3741 assert_eq!(format!("{}", EntityViewport::Academic), "academic");
3742 assert_eq!(
3743 format!("{}", EntityViewport::Custom("test".to_string())),
3744 "test"
3745 );
3746 }
3747
3748 #[test]
3749 fn test_entity_viewport_methods() {
3750 let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.9);
3751
3752 assert!(!entity.has_viewport());
3754 assert_eq!(entity.viewport_or_default(), EntityViewport::General);
3755 assert!(entity.matches_viewport(&EntityViewport::Academic)); entity.set_viewport(EntityViewport::Academic);
3759 assert!(entity.has_viewport());
3760 assert_eq!(entity.viewport_or_default(), EntityViewport::Academic);
3761 assert!(entity.matches_viewport(&EntityViewport::Academic));
3762 assert!(!entity.matches_viewport(&EntityViewport::Business));
3763 }
3764
3765 #[test]
3766 fn test_entity_builder_with_viewport() {
3767 let entity = Entity::builder("Marie Curie", EntityType::Person)
3768 .span(0, 11)
3769 .viewport(EntityViewport::Academic)
3770 .build();
3771
3772 assert_eq!(entity.viewport, Some(EntityViewport::Academic));
3773 assert!(entity.has_viewport());
3774 }
3775
3776 #[test]
3781 fn test_entity_category_requires_ml() {
3782 assert!(EntityCategory::Agent.requires_ml());
3783 assert!(EntityCategory::Organization.requires_ml());
3784 assert!(EntityCategory::Place.requires_ml());
3785 assert!(EntityCategory::Creative.requires_ml());
3786 assert!(EntityCategory::Relation.requires_ml());
3787
3788 assert!(!EntityCategory::Temporal.requires_ml());
3789 assert!(!EntityCategory::Numeric.requires_ml());
3790 assert!(!EntityCategory::Contact.requires_ml());
3791 assert!(!EntityCategory::Misc.requires_ml());
3792 }
3793
3794 #[test]
3795 fn test_entity_category_pattern_detectable() {
3796 assert!(EntityCategory::Temporal.pattern_detectable());
3797 assert!(EntityCategory::Numeric.pattern_detectable());
3798 assert!(EntityCategory::Contact.pattern_detectable());
3799
3800 assert!(!EntityCategory::Agent.pattern_detectable());
3801 assert!(!EntityCategory::Organization.pattern_detectable());
3802 assert!(!EntityCategory::Place.pattern_detectable());
3803 assert!(!EntityCategory::Creative.pattern_detectable());
3804 assert!(!EntityCategory::Relation.pattern_detectable());
3805 assert!(!EntityCategory::Misc.pattern_detectable());
3806 }
3807
3808 #[test]
3809 fn test_entity_category_is_relation() {
3810 assert!(EntityCategory::Relation.is_relation());
3811
3812 assert!(!EntityCategory::Agent.is_relation());
3813 assert!(!EntityCategory::Organization.is_relation());
3814 assert!(!EntityCategory::Place.is_relation());
3815 assert!(!EntityCategory::Temporal.is_relation());
3816 assert!(!EntityCategory::Numeric.is_relation());
3817 assert!(!EntityCategory::Contact.is_relation());
3818 assert!(!EntityCategory::Creative.is_relation());
3819 assert!(!EntityCategory::Misc.is_relation());
3820 }
3821
3822 #[test]
3823 fn test_entity_category_as_str() {
3824 assert_eq!(EntityCategory::Agent.as_str(), "agent");
3825 assert_eq!(EntityCategory::Organization.as_str(), "organization");
3826 assert_eq!(EntityCategory::Place.as_str(), "place");
3827 assert_eq!(EntityCategory::Creative.as_str(), "creative");
3828 assert_eq!(EntityCategory::Temporal.as_str(), "temporal");
3829 assert_eq!(EntityCategory::Numeric.as_str(), "numeric");
3830 assert_eq!(EntityCategory::Contact.as_str(), "contact");
3831 assert_eq!(EntityCategory::Relation.as_str(), "relation");
3832 assert_eq!(EntityCategory::Misc.as_str(), "misc");
3833 }
3834
3835 #[test]
3836 fn test_entity_category_display() {
3837 assert_eq!(format!("{}", EntityCategory::Agent), "agent");
3838 assert_eq!(format!("{}", EntityCategory::Temporal), "temporal");
3839 assert_eq!(format!("{}", EntityCategory::Relation), "relation");
3840 }
3841}