1use super::confidence::Confidence;
39use super::types::MentionType;
40use serde::{Deserialize, Serialize};
41use std::borrow::Cow;
42
/// Coarse grouping that every [`EntityType`] belongs to.
///
/// The category decides whether a type is reported as needing a model
/// ([`EntityCategory::requires_ml`]) or as detectable by patterns
/// ([`EntityCategory::pattern_detectable`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityCategory {
    /// Category of `EntityType::Person`; counted by `requires_ml`.
    Agent,
    /// Category of `EntityType::Organization`; counted by `requires_ml`.
    Organization,
    /// Category of `EntityType::Location`; counted by `requires_ml`.
    Place,
    /// Only reachable through custom types (no built-in variant maps here);
    /// counted by `requires_ml`.
    Creative,
    /// Category of `EntityType::Date` and `EntityType::Time`; counted by
    /// `pattern_detectable`.
    Temporal,
    /// Category of money, percent, quantity, cardinal and ordinal types;
    /// counted by `pattern_detectable`.
    Numeric,
    /// Category of email, URL and phone types; counted by
    /// `pattern_detectable`.
    Contact,
    /// The category tested by `is_relation`; counted by `requires_ml`.
    Relation,
    /// Catch-all: neither `requires_ml` nor `pattern_detectable` is true.
    Misc,
}
83
84impl EntityCategory {
85 #[must_use]
87 pub const fn requires_ml(&self) -> bool {
88 matches!(
89 self,
90 EntityCategory::Agent
91 | EntityCategory::Organization
92 | EntityCategory::Place
93 | EntityCategory::Creative
94 | EntityCategory::Relation
95 )
96 }
97
98 #[must_use]
100 pub const fn pattern_detectable(&self) -> bool {
101 matches!(
102 self,
103 EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
104 )
105 }
106
107 #[must_use]
109 pub const fn is_relation(&self) -> bool {
110 matches!(self, EntityCategory::Relation)
111 }
112
113 #[must_use]
115 pub const fn as_str(&self) -> &'static str {
116 match self {
117 EntityCategory::Agent => "agent",
118 EntityCategory::Organization => "organization",
119 EntityCategory::Place => "place",
120 EntityCategory::Creative => "creative",
121 EntityCategory::Temporal => "temporal",
122 EntityCategory::Numeric => "numeric",
123 EntityCategory::Contact => "contact",
124 EntityCategory::Relation => "relation",
125 EntityCategory::Misc => "misc",
126 }
127 }
128}
129
130impl std::fmt::Display for EntityCategory {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 write!(f, "{}", self.as_str())
133 }
134}
135
/// The type assigned to an extracted entity.
///
/// Built-in variants have a fixed label (see `as_label`) and category (see
/// `category`); anything else is carried as [`EntityType::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EntityType {
    /// Label `"PER"`.
    Person,
    /// Label `"ORG"`.
    Organization,
    /// Label `"LOC"`.
    Location,

    /// Label `"DATE"`.
    Date,
    /// Label `"TIME"`.
    Time,

    /// Label `"MONEY"`.
    Money,
    /// Label `"PERCENT"`.
    Percent,
    /// Label `"QUANTITY"`.
    Quantity,
    /// Label `"CARDINAL"`.
    Cardinal,
    /// Label `"ORDINAL"`.
    Ordinal,

    /// Label `"EMAIL"`.
    Email,
    /// Label `"URL"`.
    Url,
    /// Label `"PHONE"`.
    Phone,

    /// A type that has no dedicated variant.
    Custom {
        /// Name of the type; returned as-is by `as_label`.
        name: String,
        /// Category reported by `category` for this custom type.
        category: EntityCategory,
    },
}
206
207impl EntityType {
208 #[must_use]
210 pub fn category(&self) -> EntityCategory {
211 match self {
212 EntityType::Person => EntityCategory::Agent,
214 EntityType::Organization => EntityCategory::Organization,
216 EntityType::Location => EntityCategory::Place,
218 EntityType::Date | EntityType::Time => EntityCategory::Temporal,
220 EntityType::Money
222 | EntityType::Percent
223 | EntityType::Quantity
224 | EntityType::Cardinal
225 | EntityType::Ordinal => EntityCategory::Numeric,
226 EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
228 EntityType::Custom { category, .. } => *category,
230 }
231 }
232
233 #[must_use]
235 pub fn requires_ml(&self) -> bool {
236 self.category().requires_ml()
237 }
238
239 #[must_use]
241 pub fn pattern_detectable(&self) -> bool {
242 self.category().pattern_detectable()
243 }
244
245 #[must_use]
254 pub fn as_label(&self) -> &str {
255 match self {
256 EntityType::Person => "PER",
257 EntityType::Organization => "ORG",
258 EntityType::Location => "LOC",
259 EntityType::Date => "DATE",
260 EntityType::Time => "TIME",
261 EntityType::Money => "MONEY",
262 EntityType::Percent => "PERCENT",
263 EntityType::Quantity => "QUANTITY",
264 EntityType::Cardinal => "CARDINAL",
265 EntityType::Ordinal => "ORDINAL",
266 EntityType::Email => "EMAIL",
267 EntityType::Url => "URL",
268 EntityType::Phone => "PHONE",
269 EntityType::Custom { name, .. } => name.as_str(),
270 }
271 }
272
273 #[must_use]
285 pub fn from_label(label: &str) -> Self {
286 let label = label
288 .strip_prefix("B-")
289 .or_else(|| label.strip_prefix("I-"))
290 .or_else(|| label.strip_prefix("E-"))
291 .or_else(|| label.strip_prefix("S-"))
292 .unwrap_or(label);
293
294 match label.to_uppercase().as_str() {
295 "PER" | "PERSON" => EntityType::Person,
297 "ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
298 "LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
299 "FACILITY" | "FAC" | "BUILDING" => {
301 EntityType::custom("BUILDING", EntityCategory::Place)
302 }
303 "PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
304 "EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
305 "CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
306 EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
307 }
308 "GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
309 "DATE" => EntityType::Date,
311 "TIME" => EntityType::Time,
312 "MONEY" | "CURRENCY" => EntityType::Money,
314 "PERCENT" | "PERCENTAGE" => EntityType::Percent,
315 "QUANTITY" => EntityType::Quantity,
316 "CARDINAL" => EntityType::Cardinal,
317 "ORDINAL" => EntityType::Ordinal,
318 "EMAIL" => EntityType::Email,
320 "URL" | "URI" => EntityType::Url,
321 "PHONE" | "TELEPHONE" => EntityType::Phone,
322 "MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::custom("MISC", EntityCategory::Misc),
324 "DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
326 "CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
327 "GENE" => EntityType::custom("GENE", EntityCategory::Misc),
328 "PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
329 other => EntityType::custom(other, EntityCategory::Misc),
331 }
332 }
333
334 #[must_use]
348 pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
349 EntityType::Custom {
350 name: name.into(),
351 category,
352 }
353 }
354}
355
356impl std::fmt::Display for EntityType {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 write!(f, "{}", self.as_label())
359 }
360}
361
362impl std::str::FromStr for EntityType {
363 type Err = std::convert::Infallible;
364
365 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
367 Ok(Self::from_label(s))
368 }
369}
370
371impl Serialize for EntityType {
376 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
377 serializer.serialize_str(self.as_label())
378 }
379}
380
/// Deserializes an [`EntityType`] from either a string label or a
/// single-key object keyed by a variant name.
impl<'de> Deserialize<'de> for EntityType {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        /// Visitor accepting the string form and the object form.
        struct EntityTypeVisitor;

        impl<'de> serde::de::Visitor<'de> for EntityTypeVisitor {
            type Value = EntityType;

            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("a string label or a tagged enum object")
            }

            // String form: the value is treated as a label and parsed with
            // `EntityType::from_label`, so it never fails.
            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<EntityType, E> {
                Ok(EntityType::from_label(v))
            }

            // Object form: only the FIRST key of the map is read; it selects
            // how the associated value is interpreted.
            fn visit_map<A: serde::de::MapAccess<'de>>(
                self,
                mut map: A,
            ) -> Result<EntityType, A::Error> {
                let key: String = map
                    .next_key()?
                    .ok_or_else(|| serde::de::Error::custom("empty object"))?;
                match key.as_str() {
                    // `{"Custom": {"name": ..., "category": ...}}` keeps both
                    // the name and the category.
                    "Custom" => {
                        #[derive(Deserialize)]
                        struct CustomFields {
                            name: String,
                            category: EntityCategory,
                        }
                        let fields: CustomFields = map.next_value()?;
                        Ok(EntityType::Custom {
                            name: fields.name,
                            category: fields.category,
                        })
                    }
                    // `{"Other": "<name>"}` becomes a custom type in the
                    // `Misc` category; the name is used as given.
                    "Other" => {
                        let val: String = map.next_value()?;
                        Ok(EntityType::custom(val, EntityCategory::Misc))
                    }
                    // Any other key: the value is consumed and discarded and
                    // the key itself is parsed as a label.
                    variant => {
                        let _: serde::de::IgnoredAny = map.next_value()?;
                        Ok(EntityType::from_label(variant))
                    }
                }
            }
        }

        // NOTE(review): `deserialize_any` relies on the input format telling
        // the visitor which shape it holds — confirm every format used with
        // this type supports that.
        deserializer.deserialize_any(EntityTypeVisitor)
    }
}
437
/// Lookup table from dataset-specific labels to [`EntityType`] values.
///
/// Keys are stored uppercased (see `add`), and lookups uppercase their
/// argument, so matching is case-insensitive.
#[derive(Debug, Clone, Default)]
pub struct TypeMapper {
    /// Uppercased source label -> target entity type.
    mappings: std::collections::HashMap<String, EntityType>,
}
481
482impl TypeMapper {
483 #[must_use]
485 pub fn new() -> Self {
486 Self::default()
487 }
488
489 #[must_use]
491 pub fn mit_movie() -> Self {
492 let mut mapper = Self::new();
493 mapper.add("ACTOR", EntityType::Person);
495 mapper.add("DIRECTOR", EntityType::Person);
496 mapper.add("CHARACTER", EntityType::Person);
497 mapper.add(
498 "TITLE",
499 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
500 );
501 mapper.add("GENRE", EntityType::custom("GENRE", EntityCategory::Misc));
502 mapper.add("YEAR", EntityType::Date);
503 mapper.add("RATING", EntityType::custom("RATING", EntityCategory::Misc));
504 mapper.add("PLOT", EntityType::custom("PLOT", EntityCategory::Misc));
505 mapper
506 }
507
508 #[must_use]
510 pub fn mit_restaurant() -> Self {
511 let mut mapper = Self::new();
512 mapper.add("RESTAURANT_NAME", EntityType::Organization);
513 mapper.add("LOCATION", EntityType::Location);
514 mapper.add(
515 "CUISINE",
516 EntityType::custom("CUISINE", EntityCategory::Misc),
517 );
518 mapper.add("DISH", EntityType::custom("DISH", EntityCategory::Misc));
519 mapper.add("PRICE", EntityType::Money);
520 mapper.add(
521 "AMENITY",
522 EntityType::custom("AMENITY", EntityCategory::Misc),
523 );
524 mapper.add("HOURS", EntityType::Time);
525 mapper
526 }
527
528 #[must_use]
530 pub fn biomedical() -> Self {
531 let mut mapper = Self::new();
532 mapper.add(
533 "DISEASE",
534 EntityType::custom("DISEASE", EntityCategory::Agent),
535 );
536 mapper.add(
537 "CHEMICAL",
538 EntityType::custom("CHEMICAL", EntityCategory::Misc),
539 );
540 mapper.add("DRUG", EntityType::custom("DRUG", EntityCategory::Misc));
541 mapper.add("GENE", EntityType::custom("GENE", EntityCategory::Misc));
542 mapper.add(
543 "PROTEIN",
544 EntityType::custom("PROTEIN", EntityCategory::Misc),
545 );
546 mapper.add("DNA", EntityType::custom("DNA", EntityCategory::Misc));
548 mapper.add("RNA", EntityType::custom("RNA", EntityCategory::Misc));
549 mapper.add(
550 "cell_line",
551 EntityType::custom("CELL_LINE", EntityCategory::Misc),
552 );
553 mapper.add(
554 "cell_type",
555 EntityType::custom("CELL_TYPE", EntityCategory::Misc),
556 );
557 mapper
558 }
559
560 #[must_use]
562 pub fn social_media() -> Self {
563 let mut mapper = Self::new();
564 mapper.add("person", EntityType::Person);
566 mapper.add("corporation", EntityType::Organization);
567 mapper.add("location", EntityType::Location);
568 mapper.add("group", EntityType::Organization);
569 mapper.add(
570 "product",
571 EntityType::custom("PRODUCT", EntityCategory::Misc),
572 );
573 mapper.add(
574 "creative_work",
575 EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
576 );
577 mapper.add("event", EntityType::custom("EVENT", EntityCategory::Misc));
578 mapper
579 }
580
581 #[must_use]
583 pub fn manufacturing() -> Self {
584 let mut mapper = Self::new();
585 mapper.add("MATE", EntityType::custom("MATERIAL", EntityCategory::Misc));
587 mapper.add("MANP", EntityType::custom("PROCESS", EntityCategory::Misc));
588 mapper.add("MACEQ", EntityType::custom("MACHINE", EntityCategory::Misc));
589 mapper.add(
590 "APPL",
591 EntityType::custom("APPLICATION", EntityCategory::Misc),
592 );
593 mapper.add("FEAT", EntityType::custom("FEATURE", EntityCategory::Misc));
594 mapper.add(
595 "PARA",
596 EntityType::custom("PARAMETER", EntityCategory::Misc),
597 );
598 mapper.add("PRO", EntityType::custom("PROPERTY", EntityCategory::Misc));
599 mapper.add(
600 "CHAR",
601 EntityType::custom("CHARACTERISTIC", EntityCategory::Misc),
602 );
603 mapper.add(
604 "ENAT",
605 EntityType::custom("ENABLING_TECHNOLOGY", EntityCategory::Misc),
606 );
607 mapper.add(
608 "CONPRI",
609 EntityType::custom("CONCEPT_PRINCIPLE", EntityCategory::Misc),
610 );
611 mapper.add(
612 "BIOP",
613 EntityType::custom("BIO_PROCESS", EntityCategory::Misc),
614 );
615 mapper.add(
616 "MANS",
617 EntityType::custom("MAN_STANDARD", EntityCategory::Misc),
618 );
619 mapper
620 }
621
622 pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
624 self.mappings.insert(source.into().to_uppercase(), target);
625 }
626
627 #[must_use]
629 pub fn map(&self, label: &str) -> Option<&EntityType> {
630 self.mappings.get(&label.to_uppercase())
631 }
632
633 #[must_use]
637 pub fn normalize(&self, label: &str) -> EntityType {
638 self.map(label)
639 .cloned()
640 .unwrap_or_else(|| EntityType::from_label(label))
641 }
642
643 #[must_use]
645 pub fn contains(&self, label: &str) -> bool {
646 self.mappings.contains_key(&label.to_uppercase())
647 }
648
649 pub fn labels(&self) -> impl Iterator<Item = &String> {
651 self.mappings.keys()
652 }
653}
654
/// How an entity was produced.
///
/// The method determines how a confidence value should be read; see
/// `is_calibrated` and `confidence_interpretation`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExtractionMethod {
    /// Confidence is interpreted as `"binary"`.
    Pattern,

    /// The default method; the only one reported as calibrated, with
    /// confidence interpreted as `"probability"`.
    #[default]
    Neural,

    /// Confidence is interpreted as `"agreement_ratio"`.
    Consensus,

    /// Confidence is interpreted as `"heuristic_score"`.
    Heuristic,

    /// Method not known; confidence interpretation is `"unknown"`.
    Unknown,
}
689
690impl ExtractionMethod {
691 #[must_use]
715 pub const fn is_calibrated(&self) -> bool {
716 match self {
717 ExtractionMethod::Neural => true,
718 ExtractionMethod::Pattern => false,
720 ExtractionMethod::Consensus => false,
721 ExtractionMethod::Heuristic => false,
722 ExtractionMethod::Unknown => false,
723 }
724 }
725
726 #[must_use]
733 pub const fn confidence_interpretation(&self) -> &'static str {
734 match self {
735 ExtractionMethod::Neural => "probability",
736 ExtractionMethod::Pattern => "binary",
737 ExtractionMethod::Heuristic => "heuristic_score",
738 ExtractionMethod::Consensus => "agreement_ratio",
739 ExtractionMethod::Unknown => "unknown",
740 }
741 }
742}
743
744impl std::fmt::Display for ExtractionMethod {
745 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
746 match self {
747 ExtractionMethod::Pattern => write!(f, "pattern"),
748 ExtractionMethod::Neural => write!(f, "neural"),
749 ExtractionMethod::Consensus => write!(f, "consensus"),
750 ExtractionMethod::Heuristic => write!(f, "heuristic"),
751 ExtractionMethod::Unknown => write!(f, "unknown"),
752 }
753 }
754}
755
756pub trait Lexicon: Send + Sync {
810 fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)>;
814
815 fn contains(&self, text: &str) -> bool {
817 self.lookup(text).is_some()
818 }
819
820 fn source(&self) -> &str;
822
823 fn len(&self) -> usize;
825
826 fn is_empty(&self) -> bool {
828 self.len() == 0
829 }
830}
831
/// [`Lexicon`] implementation backed by an in-memory `HashMap`.
///
/// Keys are stored and looked up exactly as given (no case folding is done
/// by `insert` or `lookup`).
#[derive(Debug, Clone)]
pub struct HashMapLexicon {
    /// Text -> (entity type, confidence).
    entries: std::collections::HashMap<String, (EntityType, Confidence)>,
    /// Source name returned by `Lexicon::source`.
    source: String,
}
841
842impl HashMapLexicon {
843 #[must_use]
845 pub fn new(source: impl Into<String>) -> Self {
846 Self {
847 entries: std::collections::HashMap::new(),
848 source: source.into(),
849 }
850 }
851
852 pub fn insert(
854 &mut self,
855 text: impl Into<String>,
856 entity_type: EntityType,
857 confidence: impl Into<Confidence>,
858 ) {
859 self.entries
860 .insert(text.into(), (entity_type, confidence.into()));
861 }
862
863 pub fn from_iter<I, S, C>(source: impl Into<String>, entries: I) -> Self
865 where
866 I: IntoIterator<Item = (S, EntityType, C)>,
867 S: Into<String>,
868 C: Into<Confidence>,
869 {
870 let mut lexicon = Self::new(source);
871 for (text, entity_type, conf) in entries {
872 lexicon.insert(text, entity_type, conf);
873 }
874 lexicon
875 }
876
877 pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, Confidence)> {
879 self.entries.iter().map(|(k, (t, c))| (k.as_str(), t, *c))
880 }
881}
882
883impl Lexicon for HashMapLexicon {
884 fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)> {
885 self.entries.get(text).cloned()
886 }
887
888 fn source(&self) -> &str {
889 &self.source
890 }
891
892 fn len(&self) -> usize {
893 self.entries.len()
894 }
895}
896
/// Record of where an entity came from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Provenance {
    /// Name of the producer (e.g. `"pattern"` or a model name, as set by the
    /// constructors).
    pub source: Cow<'static, str>,
    /// Extraction method used.
    pub method: ExtractionMethod,
    /// Pattern name, set by `Provenance::pattern`.
    pub pattern: Option<Cow<'static, str>>,
    /// Confidence as reported by the producer, when available.
    pub raw_confidence: Option<Confidence>,
    /// Optional model version; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_version: Option<Cow<'static, str>>,
    /// Optional timestamp string (format not enforced here); omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<String>,
}
918
919impl Provenance {
920 #[must_use]
922 pub fn pattern(pattern_name: &'static str) -> Self {
923 Self {
924 source: Cow::Borrowed("pattern"),
925 method: ExtractionMethod::Pattern,
926 pattern: Some(Cow::Borrowed(pattern_name)),
927 raw_confidence: Some(Confidence::ONE), model_version: None,
929 timestamp: None,
930 }
931 }
932
933 #[must_use]
947 pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: impl Into<Confidence>) -> Self {
948 Self {
949 source: model_name.into(),
950 method: ExtractionMethod::Neural,
951 pattern: None,
952 raw_confidence: Some(confidence.into()),
953 model_version: None,
954 timestamp: None,
955 }
956 }
957
958 #[must_use]
960 pub fn ensemble(sources: &'static str) -> Self {
961 Self {
962 source: Cow::Borrowed(sources),
963 method: ExtractionMethod::Consensus,
964 pattern: None,
965 raw_confidence: None,
966 model_version: None,
967 timestamp: None,
968 }
969 }
970
971 #[must_use]
973 pub fn with_version(mut self, version: &'static str) -> Self {
974 self.model_version = Some(Cow::Borrowed(version));
975 self
976 }
977
978 #[must_use]
980 pub fn with_timestamp(mut self, timestamp: impl Into<String>) -> Self {
981 self.timestamp = Some(timestamp.into());
982 self
983 }
984}
985
/// Location of an entity: in text, in a visual layout, or both.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Span {
    /// A text range; `len` is computed as `end - start` (saturating).
    Text {
        /// Start offset.
        start: usize,
        /// End offset.
        end: usize,
    },
    /// A rectangle, optionally tied to a page.
    // NOTE(review): coordinate units/origin are not defined in this file.
    BoundingBox {
        /// Horizontal position.
        x: f32,
        /// Vertical position.
        y: f32,
        /// Box width.
        width: f32,
        /// Box height.
        height: f32,
        /// Page number, if known.
        page: Option<u32>,
    },
    /// A text range together with a visual span.
    Hybrid {
        /// Start offset of the text part.
        start: usize,
        /// End offset of the text part.
        end: usize,
        /// The visual part (boxed because the type is recursive).
        bbox: Box<Span>,
    },
}
1052
1053impl Span {
1054 #[must_use]
1056 pub const fn text(start: usize, end: usize) -> Self {
1057 Self::Text { start, end }
1058 }
1059
1060 #[must_use]
1062 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
1063 Self::BoundingBox {
1064 x,
1065 y,
1066 width,
1067 height,
1068 page: None,
1069 }
1070 }
1071
1072 #[must_use]
1074 pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
1075 Self::BoundingBox {
1076 x,
1077 y,
1078 width,
1079 height,
1080 page: Some(page),
1081 }
1082 }
1083
1084 #[must_use]
1086 pub const fn is_text(&self) -> bool {
1087 matches!(self, Self::Text { .. } | Self::Hybrid { .. })
1088 }
1089
1090 #[must_use]
1092 pub const fn is_visual(&self) -> bool {
1093 matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
1094 }
1095
1096 #[must_use]
1098 pub const fn text_offsets(&self) -> Option<(usize, usize)> {
1099 match self {
1100 Self::Text { start, end } => Some((*start, *end)),
1101 Self::Hybrid { start, end, .. } => Some((*start, *end)),
1102 Self::BoundingBox { .. } => None,
1103 }
1104 }
1105
1106 #[must_use]
1108 pub fn len(&self) -> usize {
1109 match self {
1110 Self::Text { start, end } => end.saturating_sub(*start),
1111 Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
1112 Self::BoundingBox { .. } => 0,
1113 }
1114 }
1115
1116 #[must_use]
1118 pub fn is_empty(&self) -> bool {
1119 self.len() == 0
1120 }
1121}
1122
/// A span made of one or more separate ranges.
///
/// `DiscontinuousSpan::new` stores the ranges sorted by start, with empty
/// ranges dropped and overlapping or touching ranges merged.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscontinuousSpan {
    /// The ranges making up the span.
    segments: Vec<std::ops::Range<usize>>,
}
1169
1170impl DiscontinuousSpan {
1171 #[must_use]
1176 pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
1177 segments.retain(|r| r.start < r.end);
1179 segments.sort_by_key(|r| r.start);
1181 let mut merged: Vec<std::ops::Range<usize>> = Vec::with_capacity(segments.len());
1183 for seg in segments {
1184 if let Some(last) = merged.last_mut() {
1185 if seg.start <= last.end {
1186 last.end = last.end.max(seg.end);
1188 continue;
1189 }
1190 }
1191 merged.push(seg);
1192 }
1193 Self { segments: merged }
1194 }
1195
1196 #[must_use]
1198 #[allow(clippy::single_range_in_vec_init)] pub fn contiguous(start: usize, end: usize) -> Self {
1200 Self {
1201 segments: vec![start..end],
1202 }
1203 }
1204
1205 #[must_use]
1207 pub fn num_segments(&self) -> usize {
1208 self.segments.len()
1209 }
1210
1211 #[must_use]
1213 pub fn is_discontinuous(&self) -> bool {
1214 self.segments.len() > 1
1215 }
1216
1217 #[must_use]
1219 pub fn is_contiguous(&self) -> bool {
1220 self.segments.len() <= 1
1221 }
1222
1223 #[must_use]
1225 pub fn segments(&self) -> &[std::ops::Range<usize>] {
1226 &self.segments
1227 }
1228
1229 #[must_use]
1231 pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
1232 if self.segments.is_empty() {
1233 return None;
1234 }
1235 let start = self.segments.first()?.start;
1236 let end = self.segments.last()?.end;
1237 Some(start..end)
1238 }
1239
1240 #[must_use]
1243 pub fn total_len(&self) -> usize {
1244 self.segments.iter().map(|r| r.end - r.start).sum()
1245 }
1246
1247 #[must_use]
1249 pub fn extract_text(&self, text: &str, separator: &str) -> String {
1250 self.segments
1251 .iter()
1252 .map(|r| {
1253 let start = r.start;
1254 let len = r.end.saturating_sub(r.start);
1255 text.chars().skip(start).take(len).collect::<String>()
1256 })
1257 .collect::<Vec<_>>()
1258 .join(separator)
1259 }
1260
1261 #[must_use]
1271 pub fn contains(&self, pos: usize) -> bool {
1272 self.segments.iter().any(|r| r.contains(&pos))
1273 }
1274
1275 #[must_use]
1277 pub fn to_span(&self) -> Option<Span> {
1278 self.bounding_range().map(|r| Span::Text {
1279 start: r.start,
1280 end: r.end,
1281 })
1282 }
1283}
1284
1285impl From<std::ops::Range<usize>> for DiscontinuousSpan {
1286 fn from(range: std::ops::Range<usize>) -> Self {
1287 Self::contiguous(range.start, range.end)
1288 }
1289}
1290
1291impl Default for Span {
1292 fn default() -> Self {
1293 Self::Text { start: 0, end: 0 }
1294 }
1295}
1296
/// Confidence split into three components.
///
/// `combined` folds the three into one value (their geometric mean).
// NOTE(review): the exact meaning of each component is defined by the
// producer of these scores, not by this file.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HierarchicalConfidence {
    /// Linkage component.
    pub linkage: Confidence,
    /// Type component.
    pub type_score: Confidence,
    /// Boundary component.
    pub boundary: Confidence,
}
1319
1320impl HierarchicalConfidence {
1321 #[must_use]
1326 pub fn new(
1327 linkage: impl Into<Confidence>,
1328 type_score: impl Into<Confidence>,
1329 boundary: impl Into<Confidence>,
1330 ) -> Self {
1331 Self {
1332 linkage: linkage.into(),
1333 type_score: type_score.into(),
1334 boundary: boundary.into(),
1335 }
1336 }
1337
1338 #[must_use]
1341 pub fn from_single(confidence: impl Into<Confidence>) -> Self {
1342 let c = confidence.into();
1343 Self {
1344 linkage: c,
1345 type_score: c,
1346 boundary: c,
1347 }
1348 }
1349
1350 #[must_use]
1353 pub fn combined(&self) -> Confidence {
1354 let product = self.linkage.value() * self.type_score.value() * self.boundary.value();
1355 Confidence::new(product.powf(1.0 / 3.0))
1356 }
1357
1358 #[must_use]
1360 pub fn as_f64(&self) -> f64 {
1361 self.combined().value()
1362 }
1363
1364 #[must_use]
1366 pub fn passes_threshold(&self, linkage_min: f64, type_min: f64, boundary_min: f64) -> bool {
1367 self.linkage >= linkage_min && self.type_score >= type_min && self.boundary >= boundary_min
1368 }
1369}
1370
1371impl Default for HierarchicalConfidence {
1372 fn default() -> Self {
1373 Self {
1374 linkage: Confidence::ONE,
1375 type_score: Confidence::ONE,
1376 boundary: Confidence::ONE,
1377 }
1378 }
1379}
1380
1381impl From<f64> for HierarchicalConfidence {
1382 fn from(confidence: f64) -> Self {
1383 Self::from_single(confidence)
1384 }
1385}
1386
1387impl From<f32> for HierarchicalConfidence {
1388 fn from(confidence: f32) -> Self {
1389 Self::from_single(confidence)
1390 }
1391}
1392
1393impl From<Confidence> for HierarchicalConfidence {
1394 fn from(confidence: Confidence) -> Self {
1395 Self::from_single(confidence)
1396 }
1397}
1398
/// Several token sequences stored back-to-back without padding.
#[derive(Debug, Clone)]
pub struct RaggedBatch {
    /// All token ids, concatenated in document order.
    pub token_ids: Vec<u32>,
    /// Document boundaries into `token_ids`: as built by `from_sequences`,
    /// entry `i` is where document `i` starts and entry `i + 1` where it
    /// ends, so there is one more entry than there are documents.
    pub cumulative_offsets: Vec<u32>,
    /// Length of the longest sequence in the batch.
    pub max_seq_len: usize,
}
1432
1433impl RaggedBatch {
1434 pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
1436 let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
1437 let mut token_ids = Vec::with_capacity(total_tokens);
1438 let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
1439 let mut max_seq_len = 0;
1440
1441 cumulative_offsets.push(0);
1442 for seq in sequences {
1443 token_ids.extend_from_slice(seq);
1444 let len = token_ids.len();
1448 if len > u32::MAX as usize {
1449 log::warn!(
1452 "Token count {} exceeds u32::MAX, truncating to {}",
1453 len,
1454 u32::MAX
1455 );
1456 cumulative_offsets.push(u32::MAX);
1457 } else {
1458 cumulative_offsets.push(len as u32);
1459 }
1460 max_seq_len = max_seq_len.max(seq.len());
1461 }
1462
1463 Self {
1464 token_ids,
1465 cumulative_offsets,
1466 max_seq_len,
1467 }
1468 }
1469
1470 #[must_use]
1472 pub fn batch_size(&self) -> usize {
1473 self.cumulative_offsets.len().saturating_sub(1)
1474 }
1475
1476 #[must_use]
1478 pub fn total_tokens(&self) -> usize {
1479 self.token_ids.len()
1480 }
1481
1482 #[must_use]
1484 pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
1485 if doc_idx + 1 < self.cumulative_offsets.len() {
1486 let start = self.cumulative_offsets[doc_idx] as usize;
1487 let end = self.cumulative_offsets[doc_idx + 1] as usize;
1488 Some(start..end)
1489 } else {
1490 None
1491 }
1492 }
1493
1494 #[must_use]
1496 pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
1497 self.doc_range(doc_idx).map(|r| &self.token_ids[r])
1498 }
1499
1500 #[must_use]
1502 pub fn padding_savings(&self) -> f64 {
1503 let padded_size = self.batch_size() * self.max_seq_len;
1504 if padded_size == 0 {
1505 return 0.0;
1506 }
1507 1.0 - (self.total_tokens() as f64 / padded_size as f64)
1508 }
1509}
1510
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
    /// Index of the document within the batch.
    pub doc_idx: u32,
    /// Start position, relative to the start of the document.
    pub start: u32,
    /// End position (exclusive), relative to the start of the document.
    pub end: u32,
}
1528
1529impl SpanCandidate {
1530 #[must_use]
1532 pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
1533 Self {
1534 doc_idx,
1535 start,
1536 end,
1537 }
1538 }
1539
1540 #[must_use]
1542 pub const fn width(&self) -> u32 {
1543 self.end.saturating_sub(self.start)
1544 }
1545}
1546
1547pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
1552 let mut candidates = Vec::new();
1553
1554 for doc_idx in 0..batch.batch_size() {
1555 if let Some(range) = batch.doc_range(doc_idx) {
1556 let doc_len = range.len();
1557 for start in 0..doc_len {
1559 let max_end = (start + max_width).min(doc_len);
1560 for end in (start + 1)..=max_end {
1561 candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
1562 }
1563 }
1564 }
1565 }
1566
1567 candidates
1568}
1569
1570pub fn generate_filtered_candidates(
1574 batch: &RaggedBatch,
1575 max_width: usize,
1576 linkage_mask: &[f32],
1577 threshold: f32,
1578) -> Vec<SpanCandidate> {
1579 let mut candidates = Vec::new();
1580 let mut mask_idx = 0;
1581
1582 for doc_idx in 0..batch.batch_size() {
1583 if let Some(range) = batch.doc_range(doc_idx) {
1584 let doc_len = range.len();
1585 for start in 0..doc_len {
1586 let max_end = (start + max_width).min(doc_len);
1587 for end in (start + 1)..=max_end {
1588 if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
1590 candidates.push(SpanCandidate::new(
1591 doc_idx as u32,
1592 start as u32,
1593 end as u32,
1594 ));
1595 }
1596 mask_idx += 1;
1597 }
1598 }
1599 }
1600 }
1601
1602 candidates
1603}
1604
/// An extracted entity: its text, type, position and optional metadata.
///
/// `start` and `end` are private; read them with `start()` / `end()` and
/// change them with `set_start()` / `set_end()`. Optional fields are left out
/// of the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct Entity {
    /// Surface text of the entity.
    pub text: String,
    /// Type assigned to the entity.
    pub entity_type: EntityType,
    /// Start offset. `Entity::new` stores the smaller of the two offsets here.
    // NOTE(review): offset unit (bytes vs. characters) is not fixed in this file.
    start: usize,
    /// End offset. `Entity::new` stores the larger of the two offsets here.
    end: usize,
    /// Overall confidence of the extraction.
    pub confidence: Confidence,
    /// Optional normalized form of the text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<String>,
    /// Optional record of how the entity was produced.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provenance: Option<Provenance>,
    /// Optional knowledge-base identifier; set by `link_to_kb`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kb_id: Option<String>,
    /// Optional canonical identifier; set by `set_canonical` /
    /// `with_canonical_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub canonical_id: Option<super::types::CanonicalId>,
    /// Optional per-component confidence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hierarchical_confidence: Option<HierarchicalConfidence>,
    /// Optional visual location; set by `from_visual`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub visual_span: Option<Span>,
    /// Optional multi-segment span; `set_discontinuous_span` also syncs
    /// `start`/`end` to its bounding range.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub discontinuous_span: Option<DiscontinuousSpan>,
    /// Optional mention type.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mention_type: Option<MentionType>,
}
1714
1715impl<'de> Deserialize<'de> for Entity {
1716 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
1717 #[derive(Deserialize)]
1721 struct EntityHelper {
1722 text: String,
1723 entity_type: EntityType,
1724 start: usize,
1725 end: usize,
1726 confidence: Confidence,
1727 #[serde(default)]
1728 normalized: Option<String>,
1729 #[serde(default)]
1730 provenance: Option<Provenance>,
1731 #[serde(default)]
1732 kb_id: Option<String>,
1733 #[serde(default)]
1734 canonical_id: Option<super::types::CanonicalId>,
1735 #[serde(default)]
1736 hierarchical_confidence: Option<HierarchicalConfidence>,
1737 #[serde(default)]
1738 visual_span: Option<Span>,
1739 #[serde(default)]
1740 discontinuous_span: Option<DiscontinuousSpan>,
1741 #[serde(default)]
1742 mention_type: Option<MentionType>,
1743 }
1744
1745 let h = EntityHelper::deserialize(deserializer)?;
1746 let mut entity = Entity::new(h.text, h.entity_type, h.start, h.end, h.confidence);
1747 entity.normalized = h.normalized;
1748 entity.provenance = h.provenance;
1749 entity.kb_id = h.kb_id;
1750 entity.canonical_id = h.canonical_id;
1751 entity.hierarchical_confidence = h.hierarchical_confidence;
1752 entity.visual_span = h.visual_span;
1753 entity.discontinuous_span = h.discontinuous_span;
1754 entity.mention_type = h.mention_type;
1755 Ok(entity)
1756 }
1757}
1758
1759impl Entity {
1760 #[must_use]
1771 pub fn new(
1772 text: impl Into<String>,
1773 entity_type: EntityType,
1774 start: usize,
1775 end: usize,
1776 confidence: impl Into<Confidence>,
1777 ) -> Self {
1778 let (start, end) = if start > end {
1780 (end, start)
1781 } else {
1782 (start, end)
1783 };
1784 Self {
1785 text: text.into(),
1786 entity_type,
1787 start,
1788 end,
1789 confidence: confidence.into(),
1790 normalized: None,
1791 provenance: None,
1792 kb_id: None,
1793 canonical_id: None,
1794 hierarchical_confidence: None,
1795 visual_span: None,
1796 discontinuous_span: None,
1797 mention_type: None,
1798 }
1799 }
1800
    /// Returns the stored start offset.
    #[inline]
    #[must_use]
    pub fn start(&self) -> usize {
        self.start
    }
1807
    /// Returns the stored end offset.
    #[inline]
    #[must_use]
    pub fn end(&self) -> usize {
        self.end
    }
1814
    /// Overwrites the start offset.
    ///
    /// No ordering check against `end` is performed here.
    #[inline]
    pub fn set_start(&mut self, start: usize) {
        self.start = start;
    }
1820
    /// Overwrites the end offset.
    ///
    /// No ordering check against `start` is performed here.
    #[inline]
    pub fn set_end(&mut self, end: usize) {
        self.end = end;
    }
1826
1827 #[must_use]
1829 pub fn with_provenance(
1830 text: impl Into<String>,
1831 entity_type: EntityType,
1832 start: usize,
1833 end: usize,
1834 confidence: impl Into<Confidence>,
1835 provenance: Provenance,
1836 ) -> Self {
1837 let (start, end) = if start > end {
1838 (end, start)
1839 } else {
1840 (start, end)
1841 };
1842 Self {
1843 text: text.into(),
1844 entity_type,
1845 start,
1846 end,
1847 confidence: confidence.into(),
1848 normalized: None,
1849 provenance: Some(provenance),
1850 kb_id: None,
1851 canonical_id: None,
1852 hierarchical_confidence: None,
1853 visual_span: None,
1854 discontinuous_span: None,
1855 mention_type: None,
1856 }
1857 }
1858
1859 #[must_use]
1861 pub fn with_hierarchical_confidence(
1862 text: impl Into<String>,
1863 entity_type: EntityType,
1864 start: usize,
1865 end: usize,
1866 confidence: HierarchicalConfidence,
1867 ) -> Self {
1868 let (start, end) = if start > end {
1869 (end, start)
1870 } else {
1871 (start, end)
1872 };
1873 Self {
1874 text: text.into(),
1875 entity_type,
1876 start,
1877 end,
1878 confidence: Confidence::new(confidence.as_f64()),
1879 normalized: None,
1880 provenance: None,
1881 kb_id: None,
1882 canonical_id: None,
1883 hierarchical_confidence: Some(confidence),
1884 visual_span: None,
1885 discontinuous_span: None,
1886 mention_type: None,
1887 }
1888 }
1889
1890 #[must_use]
1892 pub fn from_visual(
1893 text: impl Into<String>,
1894 entity_type: EntityType,
1895 bbox: Span,
1896 confidence: impl Into<Confidence>,
1897 ) -> Self {
1898 Self {
1899 text: text.into(),
1900 entity_type,
1901 start: 0,
1902 end: 0,
1903 confidence: confidence.into(),
1904 normalized: None,
1905 provenance: None,
1906 kb_id: None,
1907 canonical_id: None,
1908 hierarchical_confidence: None,
1909 visual_span: Some(bbox),
1910 discontinuous_span: None,
1911 mention_type: None,
1912 }
1913 }
1914
    /// Builds an entity with the given type and span at confidence `1.0`.
    ///
    /// Thin wrapper around `Entity::new` that passes `1.0` as the confidence.
    #[must_use]
    pub fn with_type(
        text: impl Into<String>,
        entity_type: EntityType,
        start: usize,
        end: usize,
    ) -> Self {
        Self::new(text, entity_type, start, end, 1.0)
    }
1925
    /// Stores `kb_id` as this entity's knowledge-base identifier, replacing
    /// any previous value. After this call `is_linked` returns `true`.
    pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
        self.kb_id = Some(kb_id.into());
    }
1938
    /// Stores `canonical_id` on this entity, replacing any previous value.
    /// After this call `has_coreference` returns `true`.
    pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
        self.canonical_id = Some(canonical_id.into());
    }
1945
1946 #[must_use]
1956 pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
1957 self.canonical_id = Some(canonical_id.into());
1958 self
1959 }
1960
    /// Returns `true` when a knowledge-base identifier (`kb_id`) is set.
    #[must_use]
    pub fn is_linked(&self) -> bool {
        self.kb_id.is_some()
    }
1966
    /// Returns `true` when a canonical identifier (`canonical_id`) is set.
    #[must_use]
    pub fn has_coreference(&self) -> bool {
        self.canonical_id.is_some()
    }
1972
1973 #[must_use]
1979 pub fn is_discontinuous(&self) -> bool {
1980 self.discontinuous_span
1981 .as_ref()
1982 .map(|s| s.is_discontinuous())
1983 .unwrap_or(false)
1984 }
1985
1986 #[must_use]
1990 pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
1991 self.discontinuous_span
1992 .as_ref()
1993 .filter(|s| s.is_discontinuous())
1994 .map(|s| s.segments().to_vec())
1995 }
1996
1997 pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
2001 if let Some(bounding) = span.bounding_range() {
2003 self.start = bounding.start;
2004 self.end = bounding.end;
2005 }
2006 self.discontinuous_span = Some(span);
2007 }
2008
2009 #[must_use]
2017 pub fn total_len(&self) -> usize {
2018 if let Some(ref span) = self.discontinuous_span {
2019 span.segments().iter().map(|r| r.end - r.start).sum()
2020 } else {
2021 self.end.saturating_sub(self.start)
2022 }
2023 }
2024
    /// Stores a normalized form of the entity text, replacing any previous
    /// value. `normalized_or_text` returns this value once it is set.
    pub fn set_normalized(&mut self, normalized: impl Into<String>) {
        self.normalized = Some(normalized.into());
    }
2039
2040 #[must_use]
2042 pub fn normalized_or_text(&self) -> &str {
2043 self.normalized.as_deref().unwrap_or(&self.text)
2044 }
2045
2046 #[must_use]
2048 pub fn method(&self) -> ExtractionMethod {
2049 self.provenance
2050 .as_ref()
2051 .map_or(ExtractionMethod::Unknown, |p| p.method)
2052 }
2053
2054 #[must_use]
2056 pub fn source(&self) -> Option<&str> {
2057 self.provenance.as_ref().map(|p| p.source.as_ref())
2058 }
2059
    /// Returns the `EntityCategory` of this entity's type (delegates to
    /// `EntityType::category`).
    #[must_use]
    pub fn category(&self) -> EntityCategory {
        self.entity_type.category()
    }
2065
    /// Returns `true` when the entity's type is pattern-detectable
    /// (delegates to `EntityType::pattern_detectable`).
    #[must_use]
    pub fn is_structured(&self) -> bool {
        self.entity_type.pattern_detectable()
    }
2071
    /// Returns `true` when the entity's type requires ML-based detection
    /// (delegates to `EntityType::requires_ml`).
    #[must_use]
    pub fn is_named(&self) -> bool {
        self.entity_type.requires_ml()
    }
2077
2078 #[must_use]
2080 pub fn overlaps(&self, other: &Entity) -> bool {
2081 !(self.end <= other.start || other.end <= self.start)
2082 }
2083
2084 #[must_use]
2086 pub fn overlap_ratio(&self, other: &Entity) -> f64 {
2087 let intersection_start = self.start.max(other.start);
2088 let intersection_end = self.end.min(other.end);
2089
2090 if intersection_start >= intersection_end {
2091 return 0.0;
2092 }
2093
2094 let intersection = (intersection_end - intersection_start) as f64;
2095 let union = ((self.end - self.start) + (other.end - other.start)
2096 - (intersection_end - intersection_start)) as f64;
2097
2098 if union == 0.0 {
2099 return 1.0;
2100 }
2101
2102 intersection / union
2103 }
2104
    /// Stores a hierarchical confidence on this entity.
    ///
    /// The scalar `confidence` field is overwritten as well, using the value
    /// returned by `confidence.as_f64()`, so both representations stay in sync.
    pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
        self.confidence = Confidence::new(confidence.as_f64());
        self.hierarchical_confidence = Some(confidence);
    }
2110
2111 #[must_use]
2113 pub fn linkage_confidence(&self) -> Confidence {
2114 self.hierarchical_confidence
2115 .map_or(self.confidence, |h| h.linkage)
2116 }
2117
2118 #[must_use]
2120 pub fn type_confidence(&self) -> Confidence {
2121 self.hierarchical_confidence
2122 .map_or(self.confidence, |h| h.type_score)
2123 }
2124
2125 #[must_use]
2127 pub fn boundary_confidence(&self) -> Confidence {
2128 self.hierarchical_confidence
2129 .map_or(self.confidence, |h| h.boundary)
2130 }
2131
    /// Returns `true` when a visual span is attached to this entity.
    #[must_use]
    pub fn is_visual(&self) -> bool {
        self.visual_span.is_some()
    }
2137
    /// Returns the contiguous text span as a `(start, end)` pair.
    #[must_use]
    pub const fn text_span(&self) -> (usize, usize) {
        (self.start, self.end)
    }
2143
    /// Returns the length of the contiguous span, `end - start`.
    ///
    /// Uses saturating subtraction, so an inverted span yields `0` rather
    /// than underflowing.
    #[must_use]
    pub const fn span_len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
2149
    /// Attaches a visual span to this entity, replacing any previous one.
    /// The text offsets (`start`/`end`) are not modified.
    pub fn set_visual_span(&mut self, span: Span) {
        self.visual_span = Some(span);
    }
2177
2178 #[must_use]
2198 pub fn extract_text(&self, source_text: &str) -> String {
2199 let char_count = source_text.chars().count();
2203 self.extract_text_with_len(source_text, char_count)
2204 }
2205
2206 #[must_use]
2218 pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
2219 if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
2220 return String::new();
2221 }
2222 source_text
2223 .chars()
2224 .skip(self.start)
2225 .take(self.end - self.start)
2226 .collect()
2227 }
2228
    /// Starts an `EntityBuilder` for the given text and entity type
    /// (delegates to `EntityBuilder::new`).
    #[must_use]
    pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
        EntityBuilder::new(text, entity_type)
    }
2234
2235 #[must_use]
2268 pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
2269 let char_count = source_text.chars().count();
2271 self.validate_with_len(source_text, char_count)
2272 }
2273
2274 #[must_use]
2286 pub fn validate_with_len(
2287 &self,
2288 source_text: &str,
2289 text_char_count: usize,
2290 ) -> Vec<ValidationIssue> {
2291 let mut issues = Vec::new();
2292
2293 if self.start >= self.end {
2295 issues.push(ValidationIssue::InvalidSpan {
2296 start: self.start,
2297 end: self.end,
2298 reason: "start must be less than end".to_string(),
2299 });
2300 }
2301
2302 if self.end > text_char_count {
2303 issues.push(ValidationIssue::SpanOutOfBounds {
2304 end: self.end,
2305 text_len: text_char_count,
2306 });
2307 }
2308
2309 if self.start < self.end && self.end <= text_char_count {
2311 let actual = self.extract_text_with_len(source_text, text_char_count);
2312 if actual != self.text {
2313 issues.push(ValidationIssue::TextMismatch {
2314 expected: self.text.clone(),
2315 actual,
2316 start: self.start,
2317 end: self.end,
2318 });
2319 }
2320 }
2321
2322 if let EntityType::Custom { ref name, .. } = self.entity_type {
2326 if name.is_empty() {
2327 issues.push(ValidationIssue::InvalidType {
2328 reason: "Custom entity type has empty name".to_string(),
2329 });
2330 }
2331 }
2332
2333 if let Some(ref disc_span) = self.discontinuous_span {
2335 for (i, seg) in disc_span.segments().iter().enumerate() {
2336 if seg.start >= seg.end {
2337 issues.push(ValidationIssue::InvalidSpan {
2338 start: seg.start,
2339 end: seg.end,
2340 reason: format!("discontinuous segment {} is invalid", i),
2341 });
2342 }
2343 if seg.end > text_char_count {
2344 issues.push(ValidationIssue::SpanOutOfBounds {
2345 end: seg.end,
2346 text_len: text_char_count,
2347 });
2348 }
2349 }
2350 }
2351
2352 issues
2353 }
2354
    /// Returns `true` when `validate` reports no issues for this entity
    /// against `source_text`.
    #[must_use]
    pub fn is_valid(&self, source_text: &str) -> bool {
        self.validate(source_text).is_empty()
    }
2362
2363 #[must_use]
2383 pub fn validate_batch(
2384 entities: &[Entity],
2385 source_text: &str,
2386 ) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
2387 entities
2388 .iter()
2389 .enumerate()
2390 .filter_map(|(idx, entity)| {
2391 let issues = entity.validate(source_text);
2392 if issues.is_empty() {
2393 None
2394 } else {
2395 Some((idx, issues))
2396 }
2397 })
2398 .collect()
2399 }
2400}
2401
/// A single problem found while validating an `Entity` against its source
/// text.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// A span (the entity's own or one discontinuous segment) is empty or
    /// inverted, i.e. `start >= end`.
    InvalidSpan {
        /// Start offset of the offending span.
        start: usize,
        /// End offset of the offending span.
        end: usize,
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
    /// A span ends past the end of the source text.
    SpanOutOfBounds {
        /// End offset of the offending span.
        end: usize,
        /// Length of the source text, in characters.
        text_len: usize,
    },
    /// The text stored on the entity differs from the text found at its span.
    TextMismatch {
        /// Text stored on the entity.
        expected: String,
        /// Text extracted from the source at `[start, end)`.
        actual: String,
        /// Start offset of the entity span.
        start: usize,
        /// End offset of the entity span.
        end: usize,
    },
    /// A confidence value lies outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// The out-of-range value.
        value: f64,
    },
    /// The entity type is malformed (e.g. a custom type with an empty name).
    InvalidType {
        /// Human-readable explanation of what is wrong.
        reason: String,
    },
}
2443
2444impl std::fmt::Display for ValidationIssue {
2445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2446 match self {
2447 ValidationIssue::InvalidSpan { start, end, reason } => {
2448 write!(f, "Invalid span [{}, {}): {}", start, end, reason)
2449 }
2450 ValidationIssue::SpanOutOfBounds { end, text_len } => {
2451 write!(f, "Span end {} exceeds text length {}", end, text_len)
2452 }
2453 ValidationIssue::TextMismatch {
2454 expected,
2455 actual,
2456 start,
2457 end,
2458 } => {
2459 write!(
2460 f,
2461 "Text mismatch at [{}, {}): expected '{}', got '{}'",
2462 start, end, expected, actual
2463 )
2464 }
2465 ValidationIssue::InvalidConfidence { value } => {
2466 write!(f, "Confidence {} outside [0.0, 1.0]", value)
2467 }
2468 ValidationIssue::InvalidType { reason } => {
2469 write!(f, "Invalid entity type: {}", reason)
2470 }
2471 }
2472 }
2473}
2474
/// Step-by-step builder for `Entity`.
///
/// Holds one field per `Entity` field; `build` moves them into the final
/// entity unchanged.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
    // Surface text of the entity.
    text: String,
    // Type assigned to the entity.
    entity_type: EntityType,
    // Span start offset; defaults to 0 in `new`.
    start: usize,
    // Span end offset; defaults to the character count of `text` in `new`.
    end: usize,
    // Scalar confidence; defaults to `Confidence::ONE` in `new`.
    confidence: Confidence,
    // Optional normalized form of the text.
    normalized: Option<String>,
    // Optional record of how the entity was extracted.
    provenance: Option<Provenance>,
    // Optional knowledge-base identifier.
    kb_id: Option<String>,
    // Optional canonical identifier.
    canonical_id: Option<super::types::CanonicalId>,
    // Optional per-aspect confidence breakdown.
    hierarchical_confidence: Option<HierarchicalConfidence>,
    // Optional visual location of the entity.
    visual_span: Option<Span>,
    // Optional multi-segment span.
    discontinuous_span: Option<DiscontinuousSpan>,
    // Optional mention type.
    mention_type: Option<MentionType>,
}
2505
2506impl EntityBuilder {
2507 #[must_use]
2509 pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
2510 let text = text.into();
2511 let end = text.chars().count();
2512 Self {
2513 text,
2514 entity_type,
2515 start: 0,
2516 end,
2517 confidence: Confidence::ONE,
2518 normalized: None,
2519 provenance: None,
2520 kb_id: None,
2521 canonical_id: None,
2522 hierarchical_confidence: None,
2523 visual_span: None,
2524 discontinuous_span: None,
2525 mention_type: None,
2526 }
2527 }
2528
2529 #[must_use]
2531 pub const fn span(mut self, start: usize, end: usize) -> Self {
2532 self.start = start;
2533 self.end = end;
2534 self
2535 }
2536
2537 #[must_use]
2539 pub fn confidence(mut self, confidence: impl Into<Confidence>) -> Self {
2540 self.confidence = confidence.into();
2541 self
2542 }
2543
2544 #[must_use]
2546 pub fn hierarchical_confidence(mut self, confidence: HierarchicalConfidence) -> Self {
2547 self.confidence = Confidence::new(confidence.as_f64());
2548 self.hierarchical_confidence = Some(confidence);
2549 self
2550 }
2551
2552 #[must_use]
2554 pub fn normalized(mut self, normalized: impl Into<String>) -> Self {
2555 self.normalized = Some(normalized.into());
2556 self
2557 }
2558
2559 #[must_use]
2561 pub fn provenance(mut self, provenance: Provenance) -> Self {
2562 self.provenance = Some(provenance);
2563 self
2564 }
2565
2566 #[must_use]
2568 pub fn kb_id(mut self, kb_id: impl Into<String>) -> Self {
2569 self.kb_id = Some(kb_id.into());
2570 self
2571 }
2572
2573 #[must_use]
2575 pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
2576 self.canonical_id = Some(super::types::CanonicalId::new(canonical_id));
2577 self
2578 }
2579
2580 #[must_use]
2582 pub fn visual_span(mut self, span: Span) -> Self {
2583 self.visual_span = Some(span);
2584 self
2585 }
2586
2587 #[must_use]
2591 pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
2592 if let Some(bounding) = span.bounding_range() {
2594 self.start = bounding.start;
2595 self.end = bounding.end;
2596 }
2597 self.discontinuous_span = Some(span);
2598 self
2599 }
2600
2601 #[must_use]
2603 pub fn mention_type(mut self, mention_type: MentionType) -> Self {
2604 self.mention_type = Some(mention_type);
2605 self
2606 }
2607
2608 #[must_use]
2610 pub fn build(self) -> Entity {
2611 Entity {
2612 text: self.text,
2613 entity_type: self.entity_type,
2614 start: self.start,
2615 end: self.end,
2616 confidence: self.confidence,
2617 normalized: self.normalized,
2618 provenance: self.provenance,
2619 kb_id: self.kb_id,
2620 canonical_id: self.canonical_id,
2621 hierarchical_confidence: self.hierarchical_confidence,
2622 visual_span: self.visual_span,
2623 discontinuous_span: self.discontinuous_span,
2624 mention_type: self.mention_type,
2625 }
2626 }
2627}
2628
/// A typed, directed relation between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// The entity the relation starts from (first element of the triple).
    pub head: Entity,
    /// The entity the relation points to (last element of the triple).
    pub tail: Entity,
    /// Label of the relation, e.g. `"WORKED_AT"`.
    pub relation_type: String,
    /// Optional `(start, end)` offsets of the text that triggers the relation.
    // NOTE(review): presumably the same offset unit as entity spans — confirm.
    pub trigger_span: Option<(usize, usize)>,
    /// Confidence assigned to the relation.
    pub confidence: Confidence,
}
2668
2669impl Relation {
2670 #[must_use]
2672 pub fn new(
2673 head: Entity,
2674 tail: Entity,
2675 relation_type: impl Into<String>,
2676 confidence: impl Into<Confidence>,
2677 ) -> Self {
2678 Self {
2679 head,
2680 tail,
2681 relation_type: relation_type.into(),
2682 trigger_span: None,
2683 confidence: confidence.into(),
2684 }
2685 }
2686
2687 #[must_use]
2689 pub fn with_trigger(
2690 head: Entity,
2691 tail: Entity,
2692 relation_type: impl Into<String>,
2693 trigger_start: usize,
2694 trigger_end: usize,
2695 confidence: impl Into<Confidence>,
2696 ) -> Self {
2697 Self {
2698 head,
2699 tail,
2700 relation_type: relation_type.into(),
2701 trigger_span: Some((trigger_start, trigger_end)),
2702 confidence: confidence.into(),
2703 }
2704 }
2705
2706 #[must_use]
2708 pub fn as_triple(&self) -> String {
2709 format!(
2710 "({}, {}, {})",
2711 self.head.text, self.relation_type, self.tail.text
2712 )
2713 }
2714
2715 #[must_use]
2718 pub fn span_distance(&self) -> usize {
2719 if self.head.end <= self.tail.start {
2720 self.tail.start.saturating_sub(self.head.end)
2721 } else if self.tail.end <= self.head.start {
2722 self.head.start.saturating_sub(self.tail.end)
2723 } else {
2724 0 }
2726 }
2727}
2728
// Unit tests for entities, relations, spans, confidence, batching and
// provenance.
#[cfg(test)]
mod tests {
    // `unwrap` is acceptable in tests: a failure should abort the test.
    #![allow(clippy::unwrap_used)]
    use super::*;

    // ---- Entity construction and (de)serialization ----

    // An inverted span passed to `Entity::new` is stored in ascending order.
    #[test]
    fn entity_new_swaps_inverted_span() {
        let e = Entity::new("test", EntityType::Person, 10, 5, 0.9);
        assert_eq!(e.start(), 5);
        assert_eq!(e.end(), 10);
    }

    // Deserialization applies the same span normalization as `Entity::new`.
    #[test]
    fn entity_deserialize_swaps_inverted_span() {
        let json = r#"{"text":"test","entity_type":"PER","start":10,"end":5,"confidence":0.9}"#;
        let e: Entity = serde_json::from_str(json).unwrap();
        assert_eq!(e.start(), 5);
        assert_eq!(e.end(), 10);
    }

    // Serializing then deserializing preserves the core fields.
    #[test]
    fn entity_serde_round_trip() {
        let original = Entity::new("Berlin", EntityType::Location, 10, 16, 0.95);
        let json = serde_json::to_string(&original).unwrap();
        let restored: Entity = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.text, original.text);
        assert_eq!(restored.entity_type, original.entity_type);
        assert_eq!(restored.start(), original.start());
        assert_eq!(restored.end(), original.end());
        assert!((restored.confidence.value() - original.confidence.value()).abs() < f64::EPSILON);
    }

    // ---- Entity types and categories ----

    // `as_label` / `from_label` round-trip for the listed built-in types.
    #[test]
    fn test_entity_type_roundtrip() {
        let types = [
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            EntityType::Money,
            EntityType::Percent,
        ];

        for t in types {
            let label = t.as_label();
            let parsed = EntityType::from_label(label);
            assert_eq!(t, parsed);
        }
    }

    // Disjoint spans do not overlap; a containing span overlaps both parts.
    #[test]
    fn test_entity_overlap() {
        let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
        let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
        let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);

        assert!(!e1.overlaps(&e2));
        assert!(e1.overlaps(&e3));
        assert!(e3.overlaps(&e2));
    }

    // Out-of-range confidences are clamped into [0.0, 1.0].
    #[test]
    fn test_confidence_clamping() {
        let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
        assert!((e1.confidence - 1.0).abs() < f64::EPSILON);

        let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
        assert!(e2.confidence.abs() < f64::EPSILON);
    }

    // Each built-in type maps to the expected category and detection mode.
    #[test]
    fn test_entity_categories() {
        // Named entities: ML-detected.
        assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
        assert_eq!(
            EntityType::Organization.category(),
            EntityCategory::Organization
        );
        assert_eq!(EntityType::Location.category(), EntityCategory::Place);
        assert!(EntityType::Person.requires_ml());
        assert!(!EntityType::Person.pattern_detectable());

        // Temporal entities: pattern-detected.
        assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
        assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
        assert!(EntityType::Date.pattern_detectable());
        assert!(!EntityType::Date.requires_ml());

        // Numeric entities: pattern-detected.
        assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
        assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
        assert!(EntityType::Money.pattern_detectable());

        // Contact entities: pattern-detected.
        assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
        assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
        assert!(EntityType::Email.pattern_detectable());
    }

    // `as_label` / `from_label` round-trip for the remaining built-in types.
    #[test]
    fn test_new_types_roundtrip() {
        let types = [
            EntityType::Time,
            EntityType::Email,
            EntityType::Url,
            EntityType::Phone,
            EntityType::Quantity,
            EntityType::Cardinal,
            EntityType::Ordinal,
        ];

        for t in types {
            let label = t.as_label();
            let parsed = EntityType::from_label(label);
            assert_eq!(t, parsed, "Roundtrip failed for {}", label);
        }
    }

    // Custom types keep their label and inherit detection mode from their
    // category.
    #[test]
    fn test_custom_entity_type() {
        let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
        assert_eq!(disease.as_label(), "DISEASE");
        assert!(disease.requires_ml());

        let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
        assert_eq!(product_id.as_label(), "PRODUCT_ID");
        assert!(!product_id.requires_ml());
        assert!(!product_id.pattern_detectable());
    }

    // ---- Entity helpers ----

    // `normalized_or_text` falls back to the raw text until a normalized form
    // is set.
    #[test]
    fn test_entity_normalization() {
        let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
        assert!(e.normalized.is_none());
        assert_eq!(e.normalized_or_text(), "Jan 15");

        e.set_normalized("2024-01-15");
        assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
        assert_eq!(e.normalized_or_text(), "2024-01-15");
    }

    // `is_named` / `is_structured` / `category` reflect the entity type.
    #[test]
    fn test_entity_helpers() {
        let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
        assert!(named.is_named());
        assert!(!named.is_structured());
        assert_eq!(named.category(), EntityCategory::Agent);

        let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
        assert!(!structured.is_named());
        assert!(structured.is_structured());
        assert_eq!(structured.category(), EntityCategory::Numeric);
    }

    // Linking to a knowledge base and setting a canonical id are independent.
    #[test]
    fn test_knowledge_linking() {
        let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        assert!(!entity.is_linked());
        assert!(!entity.has_coreference());

        entity.link_to_kb("Q7186");
        assert!(entity.is_linked());
        assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));

        entity.set_canonical(42);
        assert!(entity.has_coreference());
        assert_eq!(
            entity.canonical_id,
            Some(crate::core::types::CanonicalId::new(42))
        );
    }

    // ---- Relations ----

    // `Relation::new` leaves the trigger unset; `with_trigger` records it.
    #[test]
    fn test_relation_creation() {
        let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);

        let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
        assert_eq!(relation.relation_type, "WORKED_AT");
        assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
        assert!(relation.trigger_span.is_none());

        let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
        assert_eq!(relation2.trigger_span, Some((13, 19)));
    }

    // Head ends at 11 and tail starts at 24, so the gap is 13.
    #[test]
    fn test_relation_span_distance() {
        let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
        let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
        let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
        assert_eq!(relation.span_distance(), 13);
    }

    // A custom type in the `Relation` category is flagged as a relation and
    // as ML-detected.
    #[test]
    fn test_relation_category() {
        let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
        assert_eq!(rel_type.category(), EntityCategory::Relation);
        assert!(rel_type.category().is_relation());
        assert!(rel_type.requires_ml());
    }

    // ---- Span ----

    // A text span exposes its offsets and length and is not visual.
    #[test]
    fn test_span_text() {
        let span = Span::text(10, 20);
        assert!(span.is_text());
        assert!(!span.is_visual());
        assert_eq!(span.text_offsets(), Some((10, 20)));
        assert_eq!(span.len(), 10);
        assert!(!span.is_empty());
    }

    // A bounding-box span has no text offsets and a text length of zero.
    #[test]
    fn test_span_bbox() {
        let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
        assert!(!span.is_text());
        assert!(span.is_visual());
        assert_eq!(span.text_offsets(), None);
        assert_eq!(span.len(), 0);
    }

    // `bbox_on_page` records the page number on the bounding box.
    #[test]
    fn test_span_bbox_with_page() {
        let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
        if let Span::BoundingBox { page, .. } = span {
            assert_eq!(page, Some(5));
        } else {
            panic!("Expected BoundingBox");
        }
    }

    // A hybrid span counts as both textual and visual.
    #[test]
    fn test_span_hybrid() {
        let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
        let hybrid = Span::Hybrid {
            start: 10,
            end: 20,
            bbox: Box::new(bbox),
        };
        assert!(hybrid.is_text());
        assert!(hybrid.is_visual());
        assert_eq!(hybrid.text_offsets(), Some((10, 20)));
        assert_eq!(hybrid.len(), 10);
    }

    // ---- HierarchicalConfidence ----

    // The three components are stored in argument order.
    #[test]
    fn test_hierarchical_confidence_new() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        assert!((hc.linkage - 0.9).abs() < f64::EPSILON);
        assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
        assert!((hc.boundary - 0.7).abs() < f64::EPSILON);
    }

    // Each component is clamped into [0.0, 1.0] independently.
    #[test]
    fn test_hierarchical_confidence_clamping() {
        let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
        assert_eq!(hc.linkage, 1.0);
        assert_eq!(hc.type_score, 0.0);
        assert_eq!(hc.boundary, 0.5);
    }

    // `from_single` copies one value into all three components.
    #[test]
    fn test_hierarchical_confidence_from_single() {
        let hc = HierarchicalConfidence::from_single(0.8);
        assert!((hc.linkage - 0.8).abs() < f64::EPSILON);
        assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
        assert!((hc.boundary - 0.8).abs() < f64::EPSILON);
    }

    // With three equal components, `combined` equals that shared value.
    #[test]
    fn test_hierarchical_confidence_combined() {
        let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
        assert!((hc.combined() - 1.0).abs() < f64::EPSILON);

        let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
        assert!((hc2.combined() - 0.8).abs() < 0.001);

        let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
        assert!((hc3.combined() - 0.5).abs() < 0.001);
    }

    // Thresholds are inclusive and every component must pass its own one.
    #[test]
    fn test_hierarchical_confidence_threshold() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        assert!(hc.passes_threshold(0.5, 0.5, 0.5));
        assert!(hc.passes_threshold(0.9, 0.8, 0.7));
        assert!(!hc.passes_threshold(0.95, 0.8, 0.7));
        assert!(!hc.passes_threshold(0.9, 0.85, 0.7));
    }

    // An `f64` converts into a hierarchical confidence.
    #[test]
    fn test_hierarchical_confidence_from_f64() {
        let hc: HierarchicalConfidence = 0.85_f64.into();
        assert!((hc.linkage - 0.85).abs() < 0.001);
    }

    // ---- RaggedBatch ----

    // Three sequences of lengths 3, 2 and 4 give 9 tokens and offsets
    // [0, 3, 5, 9].
    #[test]
    fn test_ragged_batch_from_sequences() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.total_tokens(), 9);
        assert_eq!(batch.max_seq_len, 4);
        assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
    }

    // `doc_range` returns each document's token range, `None` past the end.
    #[test]
    fn test_ragged_batch_doc_range() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.doc_range(0), Some(0..3));
        assert_eq!(batch.doc_range(1), Some(3..5));
        assert_eq!(batch.doc_range(2), None);
    }

    // `doc_tokens` returns the token slice of each document.
    #[test]
    fn test_ragged_batch_doc_tokens() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5]];
        let batch = RaggedBatch::from_sequences(&seqs);

        assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
        assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
    }

    // 9 real tokens; presumably compared against 3 x 4 = 12 padded slots,
    // i.e. 3 / 12 = 0.25 saved — TODO confirm against `padding_savings`.
    #[test]
    fn test_ragged_batch_padding_savings() {
        let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
        let batch = RaggedBatch::from_sequences(&seqs);
        let savings = batch.padding_savings();
        assert!((savings - 0.25).abs() < 0.001);
    }

    // ---- Span candidates ----

    // `SpanCandidate::new` stores its arguments; `width` is `end - start`.
    #[test]
    fn test_span_candidate() {
        let sc = SpanCandidate::new(0, 5, 10);
        assert_eq!(sc.doc_idx, 0);
        assert_eq!(sc.start, 5);
        assert_eq!(sc.end, 10);
        assert_eq!(sc.width(), 5);
    }

    // Three tokens with max width 2: presumably three width-1 spans plus two
    // width-2 spans, five in total.
    #[test]
    fn test_generate_span_candidates() {
        let seqs = vec![vec![1, 2, 3]];
        let batch = RaggedBatch::from_sequences(&seqs);
        let candidates = generate_span_candidates(&batch, 2);

        assert_eq!(candidates.len(), 5);

        for c in &candidates {
            assert_eq!(c.doc_idx, 0);
            assert!(c.end as usize <= 3);
            assert!(c.width() as usize <= 2);
        }
    }

    // Only the two candidates whose mask score exceeds 0.5 are kept.
    #[test]
    fn test_generate_filtered_candidates() {
        let seqs = vec![vec![1, 2, 3]];
        let batch = RaggedBatch::from_sequences(&seqs);

        let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
        let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);

        assert_eq!(candidates.len(), 2);
    }

    // ---- EntityBuilder ----

    // Span and confidence set on the builder appear on the built entity.
    #[test]
    fn test_entity_builder_basic() {
        let entity = Entity::builder("John", EntityType::Person)
            .span(0, 4)
            .confidence(0.95)
            .build();

        assert_eq!(entity.text, "John");
        assert_eq!(entity.entity_type, EntityType::Person);
        assert_eq!(entity.start(), 0);
        assert_eq!(entity.end(), 4);
        assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
    }

    // Every optional builder field is carried through to the entity.
    #[test]
    fn test_entity_builder_full() {
        let entity = Entity::builder("Marie Curie", EntityType::Person)
            .span(0, 11)
            .confidence(0.95)
            .kb_id("Q7186")
            .canonical_id(42)
            .normalized("Marie Salomea Skłodowska Curie")
            .provenance(Provenance::ml("bert", 0.95))
            .build();

        assert_eq!(entity.text, "Marie Curie");
        assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
        assert_eq!(
            entity.canonical_id,
            Some(crate::core::types::CanonicalId::new(42))
        );
        assert_eq!(
            entity.normalized.as_deref(),
            Some("Marie Salomea Skłodowska Curie")
        );
        assert!(entity.provenance.is_some());
    }

    // A hierarchical confidence set on the builder drives the per-aspect
    // accessors.
    #[test]
    fn test_entity_builder_hierarchical() {
        let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
        let entity = Entity::builder("test", EntityType::Person)
            .span(0, 4)
            .hierarchical_confidence(hc)
            .build();

        assert!(entity.hierarchical_confidence.is_some());
        assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
        assert!((entity.type_confidence() - 0.8).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
    }

    // A visual span set on the builder makes the entity visual.
    #[test]
    fn test_entity_builder_visual() {
        let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
        let entity = Entity::builder("receipt item", EntityType::Money)
            .visual_span(bbox)
            .confidence(0.9)
            .build();

        assert!(entity.is_visual());
        assert!(entity.visual_span.is_some());
    }

    // ---- Entity confidence and span accessors ----

    // Per-aspect accessors fall back to the scalar confidence until a
    // hierarchical confidence is set.
    #[test]
    fn test_entity_hierarchical_confidence_helpers() {
        let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);

        assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
        assert!((entity.type_confidence() - 0.8).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);

        entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
        assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
        assert!((entity.type_confidence() - 0.85).abs() < 0.001);
        assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
    }

    // `from_visual` yields a visual entity with a zero-length text span.
    #[test]
    fn test_entity_from_visual() {
        let entity = Entity::from_visual(
            "receipt total",
            EntityType::Money,
            Span::bbox(0.5, 0.8, 0.2, 0.05),
            0.92,
        );

        assert!(entity.is_visual());
        assert_eq!(entity.start(), 0);
        assert_eq!(entity.end(), 0);
        assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
    }

    // `text_span` and `span_len` reflect the stored offsets.
    #[test]
    fn test_entity_span_helpers() {
        let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
        assert_eq!(entity.text_span(), (10, 20));
        assert_eq!(entity.span_len(), 10);
    }

    // ---- Provenance ----

    // Pattern provenance records the pattern name with raw confidence 1.0.
    #[test]
    fn test_provenance_pattern() {
        let prov = Provenance::pattern("EMAIL");
        assert_eq!(prov.method, ExtractionMethod::Pattern);
        assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
        assert_eq!(prov.raw_confidence, Some(Confidence::new(1.0)));
    }

    // ML provenance records the model name and its raw confidence.
    #[test]
    fn test_provenance_ml() {
        let prov = Provenance::ml("bert-ner", 0.87);
        assert_eq!(prov.method, ExtractionMethod::Neural);
        assert_eq!(prov.source.as_ref(), "bert-ner");
        assert_eq!(prov.raw_confidence, Some(Confidence::new(0.87)));
    }

    // `with_version` sets the model version and keeps the source.
    #[test]
    fn test_provenance_with_version() {
        let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");

        assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
        assert_eq!(prov.source.as_ref(), "gliner");
    }

    // `with_timestamp` stores the timestamp string as given.
    #[test]
    fn test_provenance_with_timestamp() {
        let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");

        assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
    }

    // The `with_*` methods can be chained without losing earlier fields.
    #[test]
    fn test_provenance_builder_chain() {
        let prov = Provenance::ml("modernbert-ner", 0.95)
            .with_version("v1.0.0")
            .with_timestamp("2024-11-27T12:00:00Z");

        assert_eq!(prov.method, ExtractionMethod::Neural);
        assert_eq!(prov.source.as_ref(), "modernbert-ner");
        assert_eq!(prov.raw_confidence, Some(Confidence::new(0.95)));
        assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
        assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
    }

    // Provenance survives a JSON round trip, including the optional fields.
    #[test]
    fn test_provenance_serialization() {
        let prov = Provenance::ml("test", 0.9)
            .with_version("v1.0")
            .with_timestamp("2024-01-01");

        let json = serde_json::to_string(&prov).unwrap();
        assert!(json.contains("model_version"));
        assert!(json.contains("v1.0"));

        let restored: Provenance = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.model_version.as_deref(), Some("v1.0"));
        assert_eq!(restored.timestamp.as_deref(), Some("2024-01-01"));
    }

    // ---- Serialization compatibility ----

    // The serialized form must not contain the `valid_from`, `valid_until`
    // or `phi_features` keys.
    #[test]
    fn entity_serde_roundtrip_no_temporal_fields() {
        let entity = Entity::new("Berlin", EntityType::Location, 0, 6, 0.95);
        let json = serde_json::to_string(&entity).unwrap();
        assert!(!json.contains("valid_from"));
        assert!(!json.contains("valid_until"));
        assert!(!json.contains("phi_features"));
        let recovered: Entity = serde_json::from_str(&json).unwrap();
        assert_eq!(recovered.text, "Berlin");
        assert_eq!(recovered.start(), 0);
        assert_eq!(recovered.end(), 6);
    }

    // JSON carrying extra keys the struct does not define still deserializes.
    #[test]
    fn entity_deserialize_ignores_unknown_fields() {
        let json = r#"{"text":"Berlin","entity_type":"LOC","start":0,"end":6,"confidence":0.95,"valid_from":null,"phi_features":null}"#;
        let entity: Entity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.text, "Berlin");
    }
}
3322
3323#[cfg(test)]
3324mod proptests {
3325 #![allow(clippy::unwrap_used)] use super::*;
3327 use proptest::prelude::*;
3328
3329 proptest! {
3330 #[test]
3331 fn confidence_always_clamped(conf in -10.0f64..10.0) {
3332 let e = Entity::new("test", EntityType::Person, 0, 4, conf);
3333 prop_assert!(e.confidence >= 0.0);
3334 prop_assert!(e.confidence <= 1.0);
3335 }
3336
3337 #[test]
3338 fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
3339 let et = EntityType::from_label(&label);
3340 let back = EntityType::from_label(et.as_label());
3341 let is_custom = matches!(back, EntityType::Custom { .. });
3343 prop_assert!(is_custom || back == et);
3344 }
3345
3346 #[test]
3347 fn overlap_is_symmetric(
3348 s1 in 0usize..100,
3349 len1 in 1usize..50,
3350 s2 in 0usize..100,
3351 len2 in 1usize..50,
3352 ) {
3353 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3354 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3355 prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
3356 }
3357
3358 #[test]
3359 fn overlap_ratio_bounded(
3360 s1 in 0usize..100,
3361 len1 in 1usize..50,
3362 s2 in 0usize..100,
3363 len2 in 1usize..50,
3364 ) {
3365 let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
3366 let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
3367 let ratio = e1.overlap_ratio(&e2);
3368 prop_assert!(ratio >= 0.0);
3369 prop_assert!(ratio <= 1.0);
3370 }
3371
3372 #[test]
3373 fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
3374 let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
3375 let ratio = e.overlap_ratio(&e);
3376 prop_assert!((ratio - 1.0).abs() < 1e-10);
3377 }
3378
3379 #[test]
3380 fn hierarchical_confidence_always_clamped(
3381 linkage in -2.0f32..2.0,
3382 type_score in -2.0f32..2.0,
3383 boundary in -2.0f32..2.0,
3384 ) {
3385 let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
3386 prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
3387 prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
3388 prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
3389 prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
3390 }
3391
3392 #[test]
3393 fn span_candidate_width_consistent(
3394 doc in 0u32..10,
3395 start in 0u32..100,
3396 end in 1u32..100,
3397 ) {
3398 let actual_end = start.max(end);
3399 let sc = SpanCandidate::new(doc, start, actual_end);
3400 prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
3401 }
3402
3403 #[test]
3404 fn ragged_batch_preserves_tokens(
3405 seq_lens in proptest::collection::vec(1usize..10, 1..5),
3406 ) {
3407 let mut counter = 0u32;
3409 let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
3410 let seq: Vec<u32> = (counter..counter + len as u32).collect();
3411 counter += len as u32;
3412 seq
3413 }).collect();
3414
3415 let batch = RaggedBatch::from_sequences(&seqs);
3416
3417 prop_assert_eq!(batch.batch_size(), seqs.len());
3419 prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
3420
3421 for (i, seq) in seqs.iter().enumerate() {
3423 let doc_tokens = batch.doc_tokens(i).unwrap();
3424 prop_assert_eq!(doc_tokens, seq.as_slice());
3425 }
3426 }
3427
3428 #[test]
3429 fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
3430 let end = start + len;
3431 let span = Span::text(start, end);
3432 let (s, e) = span.text_offsets().unwrap();
3433 prop_assert_eq!(s, start);
3434 prop_assert_eq!(e, end);
3435 prop_assert_eq!(span.len(), len);
3436 }
3437
3438 #[test]
3444 fn entity_span_validity(
3445 start in 0usize..10000,
3446 len in 1usize..500,
3447 conf in 0.0f64..=1.0,
3448 ) {
3449 let end = start + len;
3450 let text_content: String = "x".repeat(end);
3452 let entity_text: String = text_content.chars().skip(start).take(len).collect();
3453 let e = Entity::new(&entity_text, EntityType::Person, start, end, conf);
3454 let issues = e.validate(&text_content);
3455 for issue in &issues {
3457 match issue {
3458 ValidationIssue::InvalidSpan { .. } => {
3459 prop_assert!(false, "start < end should never produce InvalidSpan");
3460 }
3461 ValidationIssue::SpanOutOfBounds { .. } => {
3462 prop_assert!(false, "span within text should never produce SpanOutOfBounds");
3463 }
3464 _ => {} }
3466 }
3467 }
3468
3469 #[test]
3471 fn entity_type_label_roundtrip_standard(
3472 idx in 0usize..13,
3473 ) {
3474 let standard_types = [
3475 EntityType::Person,
3476 EntityType::Organization,
3477 EntityType::Location,
3478 EntityType::Date,
3479 EntityType::Time,
3480 EntityType::Money,
3481 EntityType::Percent,
3482 EntityType::Quantity,
3483 EntityType::Cardinal,
3484 EntityType::Ordinal,
3485 EntityType::Email,
3486 EntityType::Url,
3487 EntityType::Phone,
3488 ];
3489 let et = &standard_types[idx];
3490 let label = et.as_label();
3491 let roundtripped = EntityType::from_label(label);
3492 prop_assert_eq!(&roundtripped, et,
3493 "from_label(as_label()) must roundtrip for {:?} (label={:?})", et, label);
3494 }
3495
3496 #[test]
3498 fn span_containment_property(
3499 a_start in 0usize..5000,
3500 a_len in 1usize..5000,
3501 b_offset in 0usize..5000,
3502 b_len in 1usize..5000,
3503 ) {
3504 let a_end = a_start + a_len;
3505 let b_start = a_start + (b_offset % a_len); let b_end_candidate = b_start + b_len;
3507
3508 if b_start >= a_start && b_end_candidate <= a_end {
3510 prop_assert!(a_start <= b_start);
3512 prop_assert!(a_end >= b_end_candidate);
3513
3514 let ea = Entity::new("a", EntityType::Person, a_start, a_end, 1.0);
3516 let eb = Entity::new("b", EntityType::Person, b_start, b_end_candidate, 1.0);
3517 prop_assert!(ea.overlaps(&eb),
3518 "containing span must overlap contained span");
3519 }
3520 }
3521
3522 #[test]
3524 fn entity_serde_roundtrip(
3525 start in 0usize..10000,
3526 len in 1usize..500,
3527 conf in 0.0f64..=1.0,
3528 type_idx in 0usize..5,
3529 ) {
3530 let end = start + len;
3531 let types = [
3532 EntityType::Person,
3533 EntityType::Organization,
3534 EntityType::Location,
3535 EntityType::Date,
3536 EntityType::Email,
3537 ];
3538 let et = types[type_idx].clone();
3539 let text = format!("entity_{}", start);
3540 let e = Entity::new(&text, et, start, end, conf);
3541
3542 let json = serde_json::to_string(&e).unwrap();
3543 let e2: Entity = serde_json::from_str(&json).unwrap();
3544
3545 prop_assert_eq!(&e.text, &e2.text);
3546 prop_assert_eq!(&e.entity_type, &e2.entity_type);
3547 prop_assert_eq!(e.start(), e2.start());
3548 prop_assert_eq!(e.end(), e2.end());
3549 prop_assert!((e.confidence - e2.confidence).abs() < 1e-10,
3551 "confidence roundtrip: {} vs {}", e.confidence, e2.confidence);
3552 prop_assert_eq!(&e.normalized, &e2.normalized);
3553 prop_assert_eq!(&e.kb_id, &e2.kb_id);
3554 }
3555
3556 #[test]
3559 fn discontinuous_span_total_length(
3560 segments in proptest::collection::vec(
3561 (0usize..5000, 1usize..500),
3562 1..6
3563 ),
3564 ) {
3565 let ranges: Vec<std::ops::Range<usize>> = segments.iter()
3566 .map(|&(start, len)| start..start + len)
3567 .collect();
3568 let span = DiscontinuousSpan::new(ranges);
3569 let expected_sum: usize = span.segments().iter().map(|r| r.end - r.start).sum();
3571 prop_assert_eq!(span.total_len(), expected_sum,
3572 "total_len must equal sum of merged segment lengths");
3573 for w in span.segments().windows(2) {
3575 prop_assert!(w[0].end <= w[1].start,
3576 "segments must not overlap: {:?} vs {:?}", w[0], w[1]);
3577 }
3578 }
3579 }
3580
/// `requires_ml` must be true for Agent, Organization, Place, Creative and
/// Relation, and false for every other category.
#[test]
fn test_entity_category_requires_ml() {
    let needs_model = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Creative,
        EntityCategory::Relation,
    ];
    for category in needs_model.iter() {
        assert!(category.requires_ml(), "{:?} should require ML", category);
    }

    let no_model = [
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
        EntityCategory::Misc,
    ];
    for category in no_model.iter() {
        assert!(!category.requires_ml(), "{:?} should not require ML", category);
    }
}
3598
/// `pattern_detectable` must be true for Temporal, Numeric and Contact, and
/// false for every other category.
#[test]
fn test_entity_category_pattern_detectable() {
    let by_pattern = [
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
    ];
    for category in by_pattern.iter() {
        assert!(
            category.pattern_detectable(),
            "{:?} should be pattern-detectable",
            category
        );
    }

    let not_by_pattern = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Creative,
        EntityCategory::Relation,
        EntityCategory::Misc,
    ];
    for category in not_by_pattern.iter() {
        assert!(
            !category.pattern_detectable(),
            "{:?} should not be pattern-detectable",
            category
        );
    }
}
3612
/// `is_relation` must be true for `Relation` only.
#[test]
fn test_entity_category_is_relation() {
    assert!(EntityCategory::Relation.is_relation());

    let non_relations = [
        EntityCategory::Agent,
        EntityCategory::Organization,
        EntityCategory::Place,
        EntityCategory::Temporal,
        EntityCategory::Numeric,
        EntityCategory::Contact,
        EntityCategory::Creative,
        EntityCategory::Misc,
    ];
    for category in non_relations.iter() {
        assert!(!category.is_relation(), "{:?} should not be a relation", category);
    }
}
3626
/// `as_str` must return the expected lowercase name for every category.
#[test]
fn test_entity_category_as_str() {
    let expected_names = [
        (EntityCategory::Agent, "agent"),
        (EntityCategory::Organization, "organization"),
        (EntityCategory::Place, "place"),
        (EntityCategory::Creative, "creative"),
        (EntityCategory::Temporal, "temporal"),
        (EntityCategory::Numeric, "numeric"),
        (EntityCategory::Contact, "contact"),
        (EntityCategory::Relation, "relation"),
        (EntityCategory::Misc, "misc"),
    ];
    for (category, name) in expected_names.iter() {
        assert_eq!(category.as_str(), *name);
    }
}
3639
/// The `Display` output of a category must be its lowercase name; three
/// categories are spot-checked.
#[test]
fn test_entity_category_display() {
    let samples = [
        (EntityCategory::Agent, "agent"),
        (EntityCategory::Temporal, "temporal"),
        (EntityCategory::Relation, "relation"),
    ];
    for (category, rendered) in samples.iter() {
        assert_eq!(category.to_string(), *rendered);
    }
}
3646
/// Built-in entity types must serialize to a bare JSON string holding their
/// short label (for example `"PER"`), not to a nested object.
#[test]
fn test_entity_type_serializes_to_flat_string() {
    let expectations = [
        (EntityType::Person, r#""PER""#),
        (EntityType::Organization, r#""ORG""#),
        (EntityType::Location, r#""LOC""#),
        (EntityType::Date, r#""DATE""#),
        (EntityType::Money, r#""MONEY""#),
    ];
    for (entity_type, expected_json) in expectations.iter() {
        let encoded = serde_json::to_string(entity_type).unwrap();
        assert_eq!(encoded, *expected_json);
    }
}
3674
/// Custom entity types must serialize to a bare JSON string containing only
/// their name; the category does not appear in the output.
#[test]
fn test_custom_entity_type_serializes_flat() {
    let cases = [
        (EntityType::custom("MISC", EntityCategory::Misc), r#""MISC""#),
        (EntityType::custom("DISEASE", EntityCategory::Agent), r#""DISEASE""#),
    ];
    for (custom_type, expected_json) in cases.iter() {
        let encoded = serde_json::to_string(custom_type).unwrap();
        assert_eq!(encoded, *expected_json);
    }
}
3683
/// A bare JSON string must deserialize into the matching entity type; the
/// label `"MISC"` is expected to become a custom type in the Misc category.
#[test]
fn test_entity_type_deserializes_from_flat_string() {
    let cases = [
        (r#""PER""#, EntityType::Person),
        (r#""ORG""#, EntityType::Organization),
        (r#""MISC""#, EntityType::custom("MISC", EntityCategory::Misc)),
    ];
    for (json, expected) in cases.iter() {
        let parsed = serde_json::from_str::<EntityType>(json).unwrap();
        assert_eq!(&parsed, expected);
    }
}
3695
/// The object-shaped `{"Custom": {...}}` encoding must still deserialize into
/// the equivalent custom entity type.
#[test]
fn test_entity_type_deserializes_backward_compat_custom() {
    let legacy_json = r#"{"Custom":{"name":"MISC","category":"Misc"}}"#;
    let parsed = serde_json::from_str::<EntityType>(legacy_json).unwrap();
    let expected = EntityType::custom("MISC", EntityCategory::Misc);
    assert_eq!(parsed, expected);
}
3703
/// The object-shaped `{"Other": "<name>"}` encoding must deserialize into a
/// custom entity type with that name in the Misc category.
#[test]
fn test_entity_type_deserializes_backward_compat_other() {
    let legacy_json = r#"{"Other":"foo"}"#;
    let parsed = serde_json::from_str::<EntityType>(legacy_json).unwrap();
    let expected = EntityType::custom("foo", EntityCategory::Misc);
    assert_eq!(parsed, expected);
}
3711
/// Every built-in entity type, plus two custom ones, must keep its label
/// after being serialized to JSON and deserialized again.
#[test]
fn test_entity_type_serde_roundtrip() {
    let samples = [
        EntityType::Person,
        EntityType::Organization,
        EntityType::Location,
        EntityType::Date,
        EntityType::Time,
        EntityType::Money,
        EntityType::Percent,
        EntityType::Quantity,
        EntityType::Cardinal,
        EntityType::Ordinal,
        EntityType::Email,
        EntityType::Url,
        EntityType::Phone,
        EntityType::custom("MISC", EntityCategory::Misc),
        EntityType::custom("DISEASE", EntityCategory::Agent),
    ];

    for original in samples.iter() {
        let encoded = serde_json::to_string(original).unwrap();
        let decoded = serde_json::from_str::<EntityType>(&encoded).unwrap();
        // Only the labels are compared here, not the full values.
        assert_eq!(
            original.as_label(),
            decoded.as_label(),
            "roundtrip failed for {:?}",
            original
        );
    }
}
3745}