1use super::entity::{
90 DiscontinuousSpan, Entity, EntityType, HierarchicalConfidence, Provenance, Span,
91};
92use serde::{Deserialize, Serialize};
93use std::collections::HashMap;
94
/// The kind of sign carried by a signal or [`Location`].
///
/// The default modality is [`Modality::Symbolic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum Modality {
    /// Supports geometric features only (see [`Modality::supports_geometric_features`]).
    Iconic,
    /// Supports linguistic features only (see [`Modality::supports_linguistic_features`]).
    #[default]
    Symbolic,
    /// Supports both linguistic and geometric features.
    Hybrid,
}
133
134impl Modality {
135 #[must_use]
137 pub const fn supports_linguistic_features(&self) -> bool {
138 matches!(self, Self::Symbolic | Self::Hybrid)
139 }
140
141 #[must_use]
143 pub const fn supports_geometric_features(&self) -> bool {
144 matches!(self, Self::Iconic | Self::Hybrid)
145 }
146}
147
/// Where a signal is located, across several kinds of source material.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Location {
    /// A contiguous text span `[start, end)`.
    ///
    /// Signal validation in this file counts these offsets in characters
    /// (`chars()`), not bytes.
    Text {
        /// Inclusive start offset.
        start: usize,
        /// Exclusive end offset.
        end: usize,
    },
    /// An axis-aligned rectangle, optionally tied to a page.
    ///
    /// NOTE(review): coordinate units (pixels vs. normalized) are not visible
    /// here — confirm against the producer.
    BoundingBox {
        /// Horizontal origin of the box.
        x: f32,
        /// Vertical origin of the box.
        y: f32,
        /// Extent along the x axis.
        width: f32,
        /// Extent along the y axis.
        height: f32,
        /// Page the box belongs to; boxes on different pages never overlap.
        page: Option<u32>,
    },
    /// A time interval (presumably in seconds, judging by the field names).
    Temporal {
        /// Start of the interval.
        start_sec: f64,
        /// End of the interval.
        end_sec: f64,
        /// Optional frame index associated with the interval.
        frame: Option<u64>,
    },
    /// A 3D box.
    Cuboid {
        /// Center point of the cuboid.
        center: [f32; 3],
        /// Size along each of the three axes.
        dimensions: [f32; 3],
        /// Orientation; presumably a quaternion — TODO confirm component order.
        rotation: [f32; 4],
    },
    /// A range on a named sequence.
    Genomic {
        /// Name of the contig/sequence.
        contig: String,
        /// Start coordinate.
        start: u64,
        /// End coordinate.
        end: u64,
        /// Optional strand marker (presumably `'+'` / `'-'`; verify against callers).
        strand: Option<char>,
    },
    /// A text mention made of several non-adjacent `(start, end)` segments.
    Discontinuous {
        /// The individual `(start, end)` offset pairs.
        segments: Vec<(usize, usize)>,
    },
    /// A text span that also carries a visual location.
    TextWithBbox {
        /// Inclusive start offset of the text span.
        start: usize,
        /// Exclusive end offset of the text span.
        end: usize,
        /// The associated visual location (boxed because the type is recursive).
        bbox: Box<Location>,
    },
}
248
249impl Location {
250 #[must_use]
252 pub const fn text(start: usize, end: usize) -> Self {
253 Self::Text { start, end }
254 }
255
256 #[must_use]
258 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
259 Self::BoundingBox {
260 x,
261 y,
262 width,
263 height,
264 page: None,
265 }
266 }
267
268 #[must_use]
270 pub const fn modality(&self) -> Modality {
271 match self {
272 Self::Text { .. } | Self::Genomic { .. } | Self::Discontinuous { .. } => {
273 Modality::Symbolic
274 }
275 Self::BoundingBox { .. } | Self::Cuboid { .. } => Modality::Iconic,
276 Self::Temporal { .. } => Modality::Iconic, Self::TextWithBbox { .. } => Modality::Hybrid,
278 }
279 }
280
281 #[must_use]
283 pub fn text_offsets(&self) -> Option<(usize, usize)> {
284 match self {
285 Self::Text { start, end } => Some((*start, *end)),
286 Self::TextWithBbox { start, end, .. } => Some((*start, *end)),
287 Self::Discontinuous { segments } => {
288 let start = segments.iter().map(|(s, _)| *s).min()?;
289 let end = segments.iter().map(|(_, e)| *e).max()?;
290 Some((start, end))
291 }
292 _ => None,
293 }
294 }
295
296 #[must_use]
298 pub fn overlaps(&self, other: &Self) -> bool {
299 match (self, other) {
300 (Self::Text { start: s1, end: e1 }, Self::Text { start: s2, end: e2 }) => {
301 s1 < e2 && s2 < e1
302 }
303 (
304 Self::BoundingBox {
305 x: x1,
306 y: y1,
307 width: w1,
308 height: h1,
309 page: p1,
310 },
311 Self::BoundingBox {
312 x: x2,
313 y: y2,
314 width: w2,
315 height: h2,
316 page: p2,
317 },
318 ) => {
319 if p1 != p2 {
321 return false;
322 }
323 x1 < &(x2 + w2) && &(x1 + w1) > x2 && y1 < &(y2 + h2) && &(y1 + h1) > y2
325 }
326 _ => false, }
328 }
329
330 #[must_use]
334 pub fn iou(&self, other: &Self) -> Option<f64> {
335 match (self, other) {
336 (Self::Text { start: s1, end: e1 }, Self::Text { start: s2, end: e2 }) => {
337 let intersection_start = (*s1).max(*s2);
338 let intersection_end = (*e1).min(*e2);
339 if intersection_start >= intersection_end {
340 return Some(0.0);
341 }
342 let intersection = (intersection_end - intersection_start) as f64;
343 let union = ((*e1).max(*e2) - (*s1).min(*s2)) as f64;
344 if union == 0.0 {
345 Some(0.0)
346 } else {
347 Some(intersection / union)
348 }
349 }
350 (
351 Self::BoundingBox {
352 x: x1,
353 y: y1,
354 width: w1,
355 height: h1,
356 page: p1,
357 },
358 Self::BoundingBox {
359 x: x2,
360 y: y2,
361 width: w2,
362 height: h2,
363 page: p2,
364 },
365 ) => {
366 if p1 != p2 {
367 return Some(0.0);
368 }
369 let x_overlap = (x1 + w1).min(x2 + w2) - x1.max(*x2);
370 let y_overlap = (y1 + h1).min(y2 + h2) - y1.max(*y2);
371 if x_overlap <= 0.0 || y_overlap <= 0.0 {
372 return Some(0.0);
373 }
374 let intersection = (x_overlap * y_overlap) as f64;
375 let area1 = (*w1 * *h1) as f64;
376 let area2 = (*w2 * *h2) as f64;
377 let union = area1 + area2 - intersection;
378 if union == 0.0 {
379 Some(0.0)
380 } else {
381 Some(intersection / union)
382 }
383 }
384 _ => None,
385 }
386 }
387}
388
389impl Default for Location {
390 fn default() -> Self {
391 Self::Text { start: 0, end: 0 }
392 }
393}
394
395impl From<&Span> for Location {
396 fn from(span: &Span) -> Self {
397 match span {
398 Span::Text { start, end } => Self::Text {
399 start: *start,
400 end: *end,
401 },
402 Span::BoundingBox {
403 x,
404 y,
405 width,
406 height,
407 page,
408 } => Self::BoundingBox {
409 x: *x,
410 y: *y,
411 width: *width,
412 height: *height,
413 page: *page,
414 },
415 Span::Hybrid { start, end, bbox } => Self::TextWithBbox {
416 start: *start,
417 end: *end,
418 bbox: Box::new(Location::from(bbox.as_ref())),
419 },
420 }
421 }
422}
423
424impl From<Span> for Location {
425 fn from(span: Span) -> Self {
426 Self::from(&span)
427 }
428}
429
430impl Location {
439 #[must_use]
444 pub fn to_span(&self) -> Option<Span> {
445 match self {
446 Self::Text { start, end } => Some(Span::Text {
447 start: *start,
448 end: *end,
449 }),
450 Self::BoundingBox {
451 x,
452 y,
453 width,
454 height,
455 page,
456 } => Some(Span::BoundingBox {
457 x: *x,
458 y: *y,
459 width: *width,
460 height: *height,
461 page: *page,
462 }),
463 Self::TextWithBbox { start, end, bbox } => {
464 let inner_span = bbox.to_span()?;
465 Some(Span::Hybrid {
466 start: *start,
467 end: *end,
468 bbox: Box::new(inner_span),
469 })
470 }
471 Self::Temporal { .. }
473 | Self::Cuboid { .. }
474 | Self::Genomic { .. }
475 | Self::Discontinuous { .. } => None,
476 }
477 }
478}
479
480pub use super::types::SignalId;
486
/// A single detected mention: a labelled, located piece of surface content.
///
/// Generic over the location type `L`, which defaults to [`Location`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Signal<L = Location> {
    /// Identifier of the signal. `GroundedDocument::add_signal` overwrites it
    /// with a document-assigned ID.
    pub id: SignalId,
    /// Where the signal occurs.
    pub location: L,
    /// The surface form; `validate_against` compares it with the source text
    /// at `location`.
    pub surface: String,
    /// The type label assigned to the signal.
    pub label: super::types::TypeLabel,
    /// Detection confidence; `Signal::new` clamps it to `[0.0, 1.0]`.
    pub confidence: f32,
    /// Optional finer-grained confidence breakdown.
    pub hierarchical: Option<HierarchicalConfidence>,
    /// Optional record of where the signal came from.
    pub provenance: Option<Provenance>,
    /// Modality of the signal; `Signal::new` sets it to `Modality::default()`.
    pub modality: Modality,
    /// Optional normalized form of the surface text.
    pub normalized: Option<String>,
    /// Whether the signal is marked as negated (`false` by default).
    pub negated: bool,
    /// Optional quantifier attached to the signal.
    pub quantifier: Option<Quantifier>,
}
543
/// Quantification attached to a signal via `Signal::with_quantifier`.
///
/// NOTE(review): the per-variant glosses below are inferred from the variant
/// names only — confirm against the code that produces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quantifier {
    /// Presumably "all"/"every".
    Universal,
    /// Presumably "some"/"a".
    Existential,
    /// Presumably "no"/"none". This is a quantifier value in its own right and
    /// is distinct from `Option::None` (no quantifier recorded).
    None,
    /// Presumably a definite reference ("the ...").
    Definite,
    /// Presumably a bare noun phrase with no determiner.
    Bare,
}
561
562impl<L> Signal<L> {
563 #[must_use]
573 pub fn new(
574 id: impl Into<SignalId>,
575 location: L,
576 surface: impl Into<String>,
577 label: impl Into<super::types::TypeLabel>,
578 confidence: f32,
579 ) -> Self {
580 Self {
581 id: id.into(),
582 location,
583 surface: surface.into(),
584 label: label.into(),
585 confidence: confidence.clamp(0.0, 1.0),
586 hierarchical: None,
587 provenance: None,
588 modality: Modality::default(),
589 normalized: None,
590 negated: false,
591 quantifier: None,
592 }
593 }
594
595 #[must_use]
597 pub fn label(&self) -> &str {
598 self.label.as_str()
599 }
600
601 #[must_use]
603 pub fn type_label(&self) -> super::types::TypeLabel {
604 self.label.clone()
605 }
606
607 #[must_use]
609 pub fn surface(&self) -> &str {
610 &self.surface
611 }
612
613 #[must_use]
615 pub fn is_confident(&self, threshold: f32) -> bool {
616 self.confidence >= threshold
617 }
618
619 #[must_use]
621 pub fn with_modality(mut self, modality: Modality) -> Self {
622 self.modality = modality;
623 self
624 }
625
626 #[must_use]
628 pub fn negated(mut self) -> Self {
629 self.negated = true;
630 self
631 }
632
633 #[must_use]
635 pub fn with_quantifier(mut self, q: Quantifier) -> Self {
636 self.quantifier = Some(q);
637 self
638 }
639
640 #[must_use]
642 pub fn with_provenance(mut self, p: Provenance) -> Self {
643 self.provenance = Some(p);
644 self
645 }
646}
647
648impl Signal<Location> {
649 #[must_use]
651 pub fn text_offsets(&self) -> Option<(usize, usize)> {
652 self.location.text_offsets()
653 }
654
655 #[must_use]
672 pub fn validate_against(&self, source_text: &str) -> Option<SignalValidationError> {
673 let (start, end) = self.location.text_offsets()?;
674
675 let char_count = source_text.chars().count();
676
677 if end > char_count {
679 return Some(SignalValidationError::OutOfBounds {
680 signal_id: self.id,
681 end,
682 text_len: char_count,
683 });
684 }
685
686 if start >= end {
687 return Some(SignalValidationError::InvalidSpan {
688 signal_id: self.id,
689 start,
690 end,
691 });
692 }
693
694 let actual: String = source_text.chars().skip(start).take(end - start).collect();
696
697 if actual != self.surface {
698 return Some(SignalValidationError::TextMismatch {
699 signal_id: self.id,
700 expected: self.surface.clone(),
701 actual,
702 start,
703 end,
704 });
705 }
706
707 None
708 }
709
710 #[must_use]
712 pub fn is_valid(&self, source_text: &str) -> bool {
713 self.validate_against(source_text).is_none()
714 }
715
716 #[must_use]
731 pub fn from_text(
732 source: &str,
733 surface: &str,
734 label: impl Into<super::types::TypeLabel>,
735 confidence: f32,
736 ) -> Option<Self> {
737 Self::from_text_nth(source, surface, label, confidence, 0)
738 }
739
740 #[must_use]
742 pub fn from_text_nth(
743 source: &str,
744 surface: &str,
745 label: impl Into<super::types::TypeLabel>,
746 confidence: f32,
747 occurrence: usize,
748 ) -> Option<Self> {
749 for (count, (byte_idx, _)) in source.match_indices(surface).enumerate() {
751 if count == occurrence {
752 let start = source[..byte_idx].chars().count();
754 let end = start + surface.chars().count();
755
756 return Some(Self::new(
757 SignalId::ZERO,
758 Location::text(start, end),
759 surface,
760 label,
761 confidence,
762 ));
763 }
764 }
765
766 None
767 }
768}
769
/// A problem found when checking a signal against its source text
/// (see `Signal::validate_against`).
#[derive(Debug, Clone, PartialEq)]
pub enum SignalValidationError {
    /// The span's end offset lies beyond the end of the text.
    OutOfBounds {
        /// The offending signal.
        signal_id: SignalId,
        /// The span's end offset.
        end: usize,
        /// Length of the text, counted in characters.
        text_len: usize,
    },
    /// The span is empty or inverted (`start >= end`).
    InvalidSpan {
        /// The offending signal.
        signal_id: SignalId,
        /// The span's start offset.
        start: usize,
        /// The span's end offset.
        end: usize,
    },
    /// The text at the span differs from the signal's surface.
    TextMismatch {
        /// The offending signal.
        signal_id: SignalId,
        /// The surface stored on the signal.
        expected: String,
        /// The text actually found at `[start, end)`.
        actual: String,
        /// The span's start offset.
        start: usize,
        /// The span's end offset.
        end: usize,
    },
}
805
806impl std::fmt::Display for SignalValidationError {
807 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
808 match self {
809 Self::OutOfBounds {
810 signal_id,
811 end,
812 text_len,
813 } => {
814 write!(
815 f,
816 "S{}: end offset {} exceeds text length {}",
817 signal_id, end, text_len
818 )
819 }
820 Self::InvalidSpan {
821 signal_id,
822 start,
823 end,
824 } => {
825 write!(f, "S{}: invalid span [{}, {})", signal_id, start, end)
826 }
827 Self::TextMismatch {
828 signal_id,
829 expected,
830 actual,
831 start,
832 end,
833 } => {
834 write!(
835 f,
836 "S{}: text mismatch at [{}, {}): expected '{}', found '{}'",
837 signal_id, start, end, expected, actual
838 )
839 }
840 }
841 }
842}
843
// Marker impl so the error works with `?` and `Box<dyn Error>`; the default
// trait methods are used as-is.
impl std::error::Error for SignalValidationError {}
845
846pub use super::types::TrackId;
852
/// A reference from a [`Track`] to one of its member signals.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SignalRef {
    /// ID of the referenced signal.
    pub signal_id: SignalId,
    /// Caller-supplied position of the signal within the track.
    pub position: u32,
}
861
/// A document-qualified reference to a track
/// (see `GroundedDocument::track_ref`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TrackRef {
    /// ID of the document that owns the track.
    pub doc_id: String,
    /// ID of the track within that document.
    pub track_id: TrackId,
}
874
/// A group of signals within one document that belong together.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Track {
    /// Identifier of the track. `GroundedDocument::add_track` overwrites it
    /// with a document-assigned ID.
    pub id: TrackId,
    /// References to the member signals.
    pub signals: Vec<SignalRef>,
    /// Optional type label shared by the track.
    pub entity_type: Option<super::types::TypeLabel>,
    /// Representative surface form for the track.
    pub canonical_surface: String,
    /// Identity this track is linked to, if any.
    pub identity_id: Option<IdentityId>,
    /// Confidence in the grouping; `Track::new` sets `1.0` and
    /// `set_cluster_confidence` clamps to `[0.0, 1.0]`.
    pub cluster_confidence: f32,
    /// Optional embedding vector for the track.
    pub embedding: Option<Vec<f32>>,
}
914
915impl Track {
916 #[must_use]
918 pub fn new(id: impl Into<TrackId>, canonical_surface: impl Into<String>) -> Self {
919 Self {
920 id: id.into(),
921 signals: Vec::new(),
922 entity_type: None,
923 canonical_surface: canonical_surface.into(),
924 identity_id: None,
925 cluster_confidence: 1.0,
926 embedding: None,
927 }
928 }
929
930 pub fn add_signal(&mut self, signal_id: impl Into<SignalId>, position: u32) {
932 let signal_id = signal_id.into();
933 self.signals.push(SignalRef {
934 signal_id,
935 position,
936 });
937 }
938
939 #[must_use]
941 pub fn len(&self) -> usize {
942 self.signals.len()
943 }
944
945 #[must_use]
947 pub fn is_empty(&self) -> bool {
948 self.signals.is_empty()
949 }
950
951 #[must_use]
953 pub fn is_singleton(&self) -> bool {
954 self.signals.len() == 1
955 }
956
957 #[must_use]
959 pub const fn id(&self) -> TrackId {
960 self.id
961 }
962
963 #[must_use]
965 pub fn signals(&self) -> &[SignalRef] {
966 &self.signals
967 }
968
969 #[must_use]
971 pub fn canonical_surface(&self) -> &str {
972 &self.canonical_surface
973 }
974
975 #[must_use]
977 pub const fn identity_id(&self) -> Option<IdentityId> {
978 self.identity_id
979 }
980
981 #[must_use]
983 pub const fn cluster_confidence(&self) -> f32 {
984 self.cluster_confidence
985 }
986
987 pub fn set_cluster_confidence(&mut self, confidence: f32) {
989 self.cluster_confidence = confidence.clamp(0.0, 1.0);
990 }
991
992 pub fn set_identity_id(&mut self, identity_id: IdentityId) {
994 self.identity_id = Some(identity_id);
995 }
996
997 pub fn clear_identity_id(&mut self) {
999 self.identity_id = None;
1000 }
1001
1002 #[must_use]
1004 pub fn with_identity(mut self, identity_id: IdentityId) -> Self {
1005 self.identity_id = Some(identity_id);
1006 self
1007 }
1008
1009 #[must_use]
1013 pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
1014 let s = entity_type.into();
1015 self.entity_type = Some(super::types::TypeLabel::from(s.as_str()));
1016 self
1017 }
1018
1019 #[must_use]
1033 pub fn with_type_label(mut self, label: super::types::TypeLabel) -> Self {
1034 self.entity_type = Some(label);
1035 self
1036 }
1037
1038 #[must_use]
1043 pub fn type_label(&self) -> Option<super::types::TypeLabel> {
1044 self.entity_type.clone()
1045 }
1046
1047 #[must_use]
1049 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
1050 self.embedding = Some(embedding);
1051 self
1052 }
1053
1054 pub fn compute_spread(&self, doc: &GroundedDocument) -> Option<usize> {
1058 if self.signals.is_empty() {
1059 return Some(0);
1060 }
1061
1062 let positions: Vec<usize> = self
1063 .signals
1064 .iter()
1065 .filter_map(|sr| {
1066 doc.signals
1067 .iter()
1068 .find(|s| s.id == sr.signal_id)
1069 .and_then(|s| s.location.text_offsets())
1070 .map(|(start, _)| start)
1071 })
1072 .collect();
1073
1074 if positions.is_empty() {
1075 return None;
1076 }
1077
1078 let min_pos = *positions.iter().min().expect("positions non-empty");
1079 let max_pos = *positions.iter().max().expect("positions non-empty");
1080 Some(max_pos.saturating_sub(min_pos))
1081 }
1082
1083 pub fn collect_variations(&self, doc: &GroundedDocument) -> Vec<String> {
1087 let mut variations: std::collections::HashSet<String> = std::collections::HashSet::new();
1088
1089 for sr in &self.signals {
1090 if let Some(signal) = doc.signals.iter().find(|s| s.id == sr.signal_id) {
1091 variations.insert(signal.surface.clone());
1092 }
1093 }
1094
1095 variations.into_iter().collect()
1096 }
1097
1098 pub fn confidence_stats(&self, doc: &GroundedDocument) -> Option<(f32, f32, f32)> {
1102 let confidences: Vec<f32> = self
1103 .signals
1104 .iter()
1105 .filter_map(|sr| {
1106 doc.signals
1107 .iter()
1108 .find(|s| s.id == sr.signal_id)
1109 .map(|s| s.confidence)
1110 })
1111 .collect();
1112
1113 if confidences.is_empty() {
1114 return None;
1115 }
1116
1117 let min = confidences.iter().cloned().fold(f32::INFINITY, f32::min);
1118 let max = confidences
1119 .iter()
1120 .cloned()
1121 .fold(f32::NEG_INFINITY, f32::max);
1122 let mean = confidences.iter().sum::<f32>() / confidences.len() as f32;
1123
1124 Some((min, max, mean))
1125 }
1126
1127 pub fn compute_stats(&self, doc: &GroundedDocument, text_len: usize) -> TrackStats {
1131 let chain_length = self.signals.len();
1132 let spread = self.compute_spread(doc).unwrap_or(0);
1133 let variations = self.collect_variations(doc);
1134 let (min_conf, max_conf, mean_conf) = self.confidence_stats(doc).unwrap_or((0.0, 0.0, 0.0));
1135
1136 let positions: Vec<usize> = self
1138 .signals
1139 .iter()
1140 .filter_map(|sr| {
1141 doc.signals
1142 .iter()
1143 .find(|s| s.id == sr.signal_id)
1144 .and_then(|s| s.location.text_offsets())
1145 .map(|(start, _)| start)
1146 })
1147 .collect();
1148
1149 let first_position = positions.iter().min().copied().unwrap_or(0);
1150 let last_position = positions.iter().max().copied().unwrap_or(0);
1151 let relative_spread = if text_len > 0 {
1152 spread as f64 / text_len as f64
1153 } else {
1154 0.0
1155 };
1156
1157 TrackStats {
1158 chain_length,
1159 variation_count: variations.len(),
1160 variations,
1161 spread,
1162 relative_spread,
1163 first_position,
1164 last_position,
1165 min_confidence: min_conf,
1166 max_confidence: max_conf,
1167 mean_confidence: mean_conf,
1168 has_embedding: self.embedding.is_some(),
1169 }
1170 }
1171}
1172
/// Summary statistics for a track, as produced by `Track::compute_stats`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrackStats {
    /// Number of signal references in the track.
    pub chain_length: usize,
    /// Number of distinct surface forms (`variations.len()`).
    pub variation_count: usize,
    /// The distinct surface forms.
    pub variations: Vec<String>,
    /// Difference between the largest and smallest start offset.
    pub spread: usize,
    /// `spread` divided by the text length (`0.0` when the length is zero).
    pub relative_spread: f64,
    /// Smallest start offset among the track's signals (`0` if none resolve).
    pub first_position: usize,
    /// Largest start offset among the track's signals (`0` if none resolve).
    pub last_position: usize,
    /// Lowest signal confidence (`0.0` if no signal resolves).
    pub min_confidence: f32,
    /// Highest signal confidence (`0.0` if no signal resolves).
    pub max_confidence: f32,
    /// Mean signal confidence (`0.0` if no signal resolves).
    pub mean_confidence: f32,
    /// Whether the track carries an embedding.
    pub has_embedding: bool,
}
1199
1200pub use super::types::IdentityId;
1206
/// How an [`Identity`] was established.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IdentitySource {
    /// Derived from tracks in one or more documents.
    CrossDocCoref {
        /// The tracks that were grouped into this identity.
        track_refs: Vec<TrackRef>,
    },
    /// Taken from a knowledge-base entry (set by `Identity::from_kb`).
    KnowledgeBase {
        /// Name of the knowledge base.
        kb_name: String,
        /// Entry ID within that knowledge base.
        kb_id: String,
    },
    /// Backed by both tracks and a knowledge-base entry.
    Hybrid {
        /// The tracks that were grouped into this identity.
        track_refs: Vec<TrackRef>,
        /// Name of the knowledge base.
        kb_name: String,
        /// Entry ID within that knowledge base.
        kb_id: String,
    },
}
1238
/// A resolved identity that tracks can be linked to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Identity {
    /// Identifier of the identity. `GroundedDocument::add_identity` overwrites
    /// it with a document-assigned ID.
    pub id: IdentityId,
    /// Preferred name for the identity.
    pub canonical_name: String,
    /// Optional type label.
    pub entity_type: Option<super::types::TypeLabel>,
    /// Knowledge-base entry ID, if known.
    pub kb_id: Option<String>,
    /// Knowledge-base name, if known.
    pub kb_name: Option<String>,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Optional embedding vector.
    pub embedding: Option<Vec<f32>>,
    /// Alternative names, in the order they were added.
    pub aliases: Vec<String>,
    /// Confidence in the identity; constructors set `1.0` and
    /// `set_confidence` clamps to `[0.0, 1.0]`.
    pub confidence: f32,
    /// How the identity was established. Omitted from serialized output when
    /// `None`, and defaulted to `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<IdentitySource>,
}
1287
1288impl Identity {
1289 #[must_use]
1291 pub fn new(id: impl Into<IdentityId>, canonical_name: impl Into<String>) -> Self {
1292 Self {
1293 id: id.into(),
1294 canonical_name: canonical_name.into(),
1295 entity_type: None,
1296 kb_id: None,
1297 kb_name: None,
1298 description: None,
1299 embedding: None,
1300 aliases: Vec::new(),
1301 confidence: 1.0,
1302 source: None,
1303 }
1304 }
1305
1306 #[must_use]
1308 pub fn from_kb(
1309 id: impl Into<IdentityId>,
1310 canonical_name: impl Into<String>,
1311 kb_name: impl Into<String>,
1312 kb_id: impl Into<String>,
1313 ) -> Self {
1314 let kb_name_str = kb_name.into();
1315 let kb_id_str = kb_id.into();
1316 Self {
1317 id: id.into(),
1318 canonical_name: canonical_name.into(),
1319 entity_type: None,
1320 kb_id: Some(kb_id_str.clone()),
1321 kb_name: Some(kb_name_str.clone()),
1322 description: None,
1323 embedding: None,
1324 aliases: Vec::new(),
1325 confidence: 1.0,
1326 source: Some(IdentitySource::KnowledgeBase {
1327 kb_name: kb_name_str,
1328 kb_id: kb_id_str,
1329 }),
1330 }
1331 }
1332
1333 pub fn add_alias(&mut self, alias: impl Into<String>) {
1335 self.aliases.push(alias.into());
1336 }
1337
1338 #[must_use]
1340 pub const fn id(&self) -> IdentityId {
1341 self.id
1342 }
1343
1344 #[must_use]
1346 pub fn canonical_name(&self) -> &str {
1347 &self.canonical_name
1348 }
1349
1350 #[must_use]
1352 pub fn kb_id(&self) -> Option<&str> {
1353 self.kb_id.as_deref()
1354 }
1355
1356 #[must_use]
1358 pub fn kb_name(&self) -> Option<&str> {
1359 self.kb_name.as_deref()
1360 }
1361
1362 #[must_use]
1364 pub fn aliases(&self) -> &[String] {
1365 &self.aliases
1366 }
1367
1368 #[must_use]
1370 pub const fn confidence(&self) -> f32 {
1371 self.confidence
1372 }
1373
1374 pub fn set_confidence(&mut self, confidence: f32) {
1376 self.confidence = confidence.clamp(0.0, 1.0);
1377 }
1378
1379 #[must_use]
1381 pub fn source(&self) -> Option<&IdentitySource> {
1382 self.source.as_ref()
1383 }
1384
1385 #[must_use]
1387 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
1388 self.embedding = Some(embedding);
1389 self
1390 }
1391
1392 #[must_use]
1396 pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
1397 let s = entity_type.into();
1398 self.entity_type = Some(super::types::TypeLabel::from(s.as_str()));
1399 self
1400 }
1401
1402 #[must_use]
1407 pub fn with_type_label(mut self, label: super::types::TypeLabel) -> Self {
1408 self.entity_type = Some(label);
1409 self
1410 }
1411
1412 #[must_use]
1417 pub fn type_label(&self) -> Option<super::types::TypeLabel> {
1418 self.entity_type.clone()
1419 }
1420
1421 #[must_use]
1423 pub fn with_description(mut self, description: impl Into<String>) -> Self {
1424 self.description = Some(description.into());
1425 self
1426 }
1427
1428 }
1430
/// A document together with its signals, tracks and identities.
///
/// The public collections hold the data; the private maps are lookup indices
/// maintained by the mutating methods, and the `next_*` counters supply the
/// IDs assigned by `add_signal`, `add_track` and `add_identity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroundedDocument {
    /// Identifier of the document.
    pub id: String,
    /// The document's text; `validate` checks signals against it.
    pub text: String,
    /// All signals, in insertion order.
    pub signals: Vec<Signal<Location>>,
    /// Tracks keyed by their ID.
    pub tracks: HashMap<TrackId, Track>,
    /// Identities keyed by their ID.
    pub identities: HashMap<IdentityId, Identity>,
    /// Index from a signal to the track that contains it.
    signal_to_track: HashMap<SignalId, TrackId>,
    /// Index from a track to the identity it is linked to.
    track_to_identity: HashMap<TrackId, IdentityId>,
    /// ID that the next added signal will receive.
    next_signal_id: SignalId,
    /// ID that the next added track will receive.
    next_track_id: TrackId,
    /// ID that the next added identity will receive.
    next_identity_id: IdentityId,
}
1516
1517impl GroundedDocument {
1518 #[must_use]
1520 pub fn new(id: impl Into<String>, text: impl Into<String>) -> Self {
1521 Self {
1522 id: id.into(),
1523 text: text.into(),
1524 signals: Vec::new(),
1525 tracks: HashMap::new(),
1526 identities: HashMap::new(),
1527 signal_to_track: HashMap::new(),
1528 track_to_identity: HashMap::new(),
1529 next_signal_id: SignalId::ZERO,
1530 next_track_id: TrackId::ZERO,
1531 next_identity_id: IdentityId::ZERO,
1532 }
1533 }
1534
1535 pub fn add_signal(&mut self, mut signal: Signal<Location>) -> SignalId {
1541 let id = self.next_signal_id;
1542 signal.id = id;
1543 self.signals.push(signal);
1544 self.next_signal_id += 1;
1545 id
1546 }
1547
1548 #[must_use]
1550 pub fn get_signal(&self, id: impl Into<SignalId>) -> Option<&Signal<Location>> {
1551 let id = id.into();
1552 self.signals.iter().find(|s| s.id == id)
1553 }
1554
    /// Returns all signals in insertion order.
    pub fn signals(&self) -> &[Signal<Location>] {
        &self.signals
    }
1559
1560 pub fn add_track(&mut self, mut track: Track) -> TrackId {
1566 let id = self.next_track_id;
1567 track.id = id;
1568
1569 for signal_ref in &track.signals {
1571 self.signal_to_track.insert(signal_ref.signal_id, id);
1572 }
1573
1574 self.tracks.insert(id, track);
1575 self.next_track_id += 1;
1576 id
1577 }
1578
1579 #[must_use]
1581 pub fn get_track(&self, id: impl Into<TrackId>) -> Option<&Track> {
1582 self.tracks.get(&id.into())
1583 }
1584
1585 #[must_use]
1587 pub fn get_track_mut(&mut self, id: impl Into<TrackId>) -> Option<&mut Track> {
1588 self.tracks.get_mut(&id.into())
1589 }
1590
1591 pub fn add_signal_to_track(
1596 &mut self,
1597 signal_id: impl Into<SignalId>,
1598 track_id: impl Into<TrackId>,
1599 position: u32,
1600 ) -> bool {
1601 let signal_id = signal_id.into();
1602 let track_id = track_id.into();
1603 if let Some(track) = self.tracks.get_mut(&track_id) {
1604 track.add_signal(signal_id, position);
1605 self.signal_to_track.insert(signal_id, track_id);
1606 true
1607 } else {
1608 false
1609 }
1610 }
1611
1612 #[must_use]
1614 pub fn track_for_signal(&self, signal_id: SignalId) -> Option<&Track> {
1615 let track_id = self.signal_to_track.get(&signal_id)?;
1616 self.tracks.get(track_id)
1617 }
1618
    /// Iterates over all tracks. The order is that of the underlying
    /// `HashMap` and is therefore unspecified.
    pub fn tracks(&self) -> impl Iterator<Item = &Track> {
        self.tracks.values()
    }
1623
1624 pub fn add_identity(&mut self, mut identity: Identity) -> IdentityId {
1630 let id = self.next_identity_id;
1631 identity.id = id;
1632 self.identities.insert(id, identity);
1633 self.next_identity_id += 1;
1634 id
1635 }
1636
1637 pub fn link_track_to_identity(
1639 &mut self,
1640 track_id: impl Into<TrackId>,
1641 identity_id: impl Into<IdentityId>,
1642 ) {
1643 let track_id = track_id.into();
1644 let identity_id = identity_id.into();
1645 if let Some(track) = self.tracks.get_mut(&track_id) {
1646 track.identity_id = Some(identity_id);
1647 self.track_to_identity.insert(track_id, identity_id);
1648 }
1649 }
1650
    /// Looks up an identity by ID.
    #[must_use]
    pub fn get_identity(&self, id: IdentityId) -> Option<&Identity> {
        self.identities.get(&id)
    }
1656
1657 #[must_use]
1659 pub fn identity_for_track(&self, track_id: TrackId) -> Option<&Identity> {
1660 let identity_id = self.track_to_identity.get(&track_id)?;
1661 self.identities.get(identity_id)
1662 }
1663
1664 #[must_use]
1666 pub fn identity_for_signal(&self, signal_id: SignalId) -> Option<&Identity> {
1667 let track_id = self.signal_to_track.get(&signal_id)?;
1668 self.identity_for_track(*track_id)
1669 }
1670
    /// Iterates over all identities. The order is that of the underlying
    /// `HashMap` and is therefore unspecified.
    pub fn identities(&self) -> impl Iterator<Item = &Identity> {
        self.identities.values()
    }
1675
1676 #[must_use]
1681 pub fn track_ref(&self, track_id: TrackId) -> Option<TrackRef> {
1682 if self.tracks.contains_key(&track_id) {
1684 Some(TrackRef {
1685 doc_id: self.id.clone(),
1686 track_id,
1687 })
1688 } else {
1689 None
1690 }
1691 }
1692
1693 #[must_use]
1699 pub fn to_entities(&self) -> Vec<Entity> {
1700 self.signals
1701 .iter()
1702 .map(|signal| {
1703 let (start, end) = signal.location.text_offsets().unwrap_or((0, 0));
1704 let track = self.track_for_signal(signal.id);
1705 let identity = track.and_then(|t| self.identity_for_track(t.id));
1706
1707 Entity {
1708 text: signal.surface.clone(),
1709 entity_type: EntityType::from_label(signal.label.as_str()),
1710 start,
1711 end,
1712 confidence: signal.confidence as f64,
1713 normalized: signal.normalized.clone(),
1714 provenance: signal.provenance.clone(),
1715 kb_id: identity.and_then(|i| i.kb_id.clone()),
1716 canonical_id: track.map(|t| super::types::CanonicalId::new(t.id.get())),
1717 hierarchical_confidence: signal.hierarchical,
1718 visual_span: match &signal.location {
1719 Location::BoundingBox {
1720 x,
1721 y,
1722 width,
1723 height,
1724 page,
1725 } => Some(Span::BoundingBox {
1726 x: *x,
1727 y: *y,
1728 width: *width,
1729 height: *height,
1730 page: *page,
1731 }),
1732 Location::TextWithBbox { bbox, .. } => {
1733 if let Location::BoundingBox {
1734 x,
1735 y,
1736 width,
1737 height,
1738 page,
1739 } = bbox.as_ref()
1740 {
1741 Some(Span::BoundingBox {
1742 x: *x,
1743 y: *y,
1744 width: *width,
1745 height: *height,
1746 page: *page,
1747 })
1748 } else {
1749 None
1750 }
1751 }
1752 _ => None,
1753 },
1754 discontinuous_span: match &signal.location {
1755 Location::Discontinuous { segments } => Some(DiscontinuousSpan::new(
1756 segments.iter().map(|(s, e)| (*s)..(*e)).collect(),
1757 )),
1758 _ => None,
1759 },
1760 valid_from: None,
1761 valid_until: None,
1762 viewport: None,
1763 }
1764 })
1765 .collect()
1766 }
1767
1768 #[must_use]
1770 pub fn from_entities(
1771 id: impl Into<String>,
1772 text: impl Into<String>,
1773 entities: &[Entity],
1774 ) -> Self {
1775 let mut doc = Self::new(id, text);
1776
1777 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1783 enum TrackKey {
1784 Canonical(super::types::CanonicalId),
1785 Singleton(usize),
1786 }
1787
1788 let mut tracks_map: HashMap<TrackKey, Vec<SignalId>> = HashMap::new();
1789 let mut signal_to_entity_idx: HashMap<SignalId, usize> = HashMap::new();
1790
1791 for (idx, entity) in entities.iter().enumerate() {
1792 let location = if let Some(disc) = &entity.discontinuous_span {
1793 Location::Discontinuous {
1794 segments: disc.segments().iter().map(|r| (r.start, r.end)).collect(),
1795 }
1796 } else if let Some(visual) = &entity.visual_span {
1797 Location::from(visual)
1798 } else {
1799 Location::text(entity.start, entity.end)
1800 };
1801
1802 let mut signal = Signal::new(
1803 SignalId::new(idx as u64),
1804 location,
1805 &entity.text,
1806 entity.entity_type.as_label(),
1807 entity.confidence as f32,
1808 );
1809 signal.normalized = entity.normalized.clone();
1810 signal.provenance = entity.provenance.clone();
1811 signal.hierarchical = entity.hierarchical_confidence;
1812
1813 let signal_id = doc.add_signal(signal);
1814 signal_to_entity_idx.insert(signal_id, idx);
1815
1816 let key = match entity.canonical_id {
1817 Some(cid) => TrackKey::Canonical(cid),
1818 None => TrackKey::Singleton(idx),
1819 };
1820 tracks_map.entry(key).or_default().push(signal_id);
1821 }
1822
1823 for (_key, signal_ids) in tracks_map {
1825 if let Some(first_signal) = signal_ids.first().and_then(|id| doc.get_signal(*id)) {
1826 let mut track = Track::new(doc.next_track_id, &first_signal.surface);
1827 track.entity_type =
1828 Some(super::types::TypeLabel::from(first_signal.label.as_str()));
1829
1830 for (pos, &signal_id) in signal_ids.iter().enumerate() {
1831 track.add_signal(signal_id, pos as u32);
1832 }
1833
1834 let kb_id = signal_ids.iter().find_map(|sid| {
1837 let ent_idx = signal_to_entity_idx.get(sid).copied()?;
1838 entities.get(ent_idx)?.kb_id.clone()
1839 });
1840 if let Some(kb_id) = kb_id {
1841 let identity = Identity::from_kb(
1842 doc.next_identity_id,
1843 &track.canonical_surface,
1844 "unknown",
1845 kb_id,
1846 );
1847 let identity_id = doc.add_identity(identity);
1848 track = track.with_identity(identity_id);
1849 }
1850
1851 doc.add_track(track);
1852 }
1853 }
1854
1855 doc
1856 }
1857
1858 #[must_use]
1860 pub fn signals_with_label(&self, label: &str) -> Vec<&Signal<Location>> {
1861 let want = super::types::TypeLabel::from(label);
1862 self.signals.iter().filter(|s| s.label == want).collect()
1863 }
1864
1865 #[must_use]
1867 pub fn confident_signals(&self, threshold: f32) -> Vec<&Signal<Location>> {
1868 self.signals
1869 .iter()
1870 .filter(|s| s.confidence >= threshold)
1871 .collect()
1872 }
1873
1874 pub fn linked_tracks(&self) -> impl Iterator<Item = &Track> {
1876 self.tracks.values().filter(|t| t.identity_id.is_some())
1877 }
1878
1879 pub fn unlinked_tracks(&self) -> impl Iterator<Item = &Track> {
1881 self.tracks.values().filter(|t| t.identity_id.is_none())
1882 }
1883
1884 #[must_use]
1886 pub fn untracked_signal_count(&self) -> usize {
1887 self.signals
1888 .iter()
1889 .filter(|s| !self.signal_to_track.contains_key(&s.id))
1890 .count()
1891 }
1892
1893 #[must_use]
1895 pub fn untracked_signals(&self) -> Vec<&Signal<Location>> {
1896 self.signals
1897 .iter()
1898 .filter(|s| !self.signal_to_track.contains_key(&s.id))
1899 .collect()
1900 }
1901
1902 #[must_use]
1908 pub fn signals_by_modality(&self, modality: Modality) -> Vec<&Signal<Location>> {
1909 self.signals
1910 .iter()
1911 .filter(|s| s.modality == modality)
1912 .collect()
1913 }
1914
    /// Returns the signals whose modality is exactly `Modality::Symbolic`
    /// (hybrid signals are not included).
    #[must_use]
    pub fn text_signals(&self) -> Vec<&Signal<Location>> {
        self.signals_by_modality(Modality::Symbolic)
    }
1920
    /// Returns the signals whose modality is exactly `Modality::Iconic`
    /// (hybrid signals are not included).
    #[must_use]
    pub fn visual_signals(&self) -> Vec<&Signal<Location>> {
        self.signals_by_modality(Modality::Iconic)
    }
1926
1927 #[must_use]
1929 pub fn overlapping_signals(&self, location: &Location) -> Vec<&Signal<Location>> {
1930 self.signals
1931 .iter()
1932 .filter(|s| s.location.overlaps(location))
1933 .collect()
1934 }
1935
1936 #[must_use]
1938 pub fn signals_in_range(&self, start: usize, end: usize) -> Vec<&Signal<Location>> {
1939 self.signals
1940 .iter()
1941 .filter(|s| {
1942 if let Some((s_start, s_end)) = s.location.text_offsets() {
1943 s_start >= start && s_end <= end
1944 } else {
1945 false
1946 }
1947 })
1948 .collect()
1949 }
1950
1951 #[must_use]
1953 pub fn negated_signals(&self) -> Vec<&Signal<Location>> {
1954 self.signals.iter().filter(|s| s.negated).collect()
1955 }
1956
1957 #[must_use]
1959 pub fn quantified_signals(&self, quantifier: Quantifier) -> Vec<&Signal<Location>> {
1960 self.signals
1961 .iter()
1962 .filter(|s| s.quantifier == Some(quantifier))
1963 .collect()
1964 }
1965
1966 #[must_use]
1988 pub fn validate(&self) -> Vec<SignalValidationError> {
1989 self.signals
1990 .iter()
1991 .filter_map(|s| s.validate_against(&self.text))
1992 .collect()
1993 }
1994
1995 #[must_use]
2019 pub fn validate_invariants(&self) -> Vec<String> {
2020 let mut errors = Vec::new();
2021
2022 let mut seen_ids = std::collections::HashSet::new();
2024 for signal in &self.signals {
2025 if !seen_ids.insert(signal.id) {
2026 errors.push(format!("Duplicate signal ID: {}", signal.id));
2027 }
2028 }
2029
2030 let signal_ids: std::collections::HashSet<_> = self.signals.iter().map(|s| s.id).collect();
2032
2033 for (track_id, track) in &self.tracks {
2035 for signal_ref in &track.signals {
2036 if !signal_ids.contains(&signal_ref.signal_id) {
2037 errors.push(format!(
2038 "Track {} references non-existent signal {}",
2039 track_id, signal_ref.signal_id
2040 ));
2041 }
2042 }
2043 }
2044
2045 for (signal_id, track_id) in &self.signal_to_track {
2047 if let Some(track) = self.tracks.get(track_id) {
2049 if !track.signals.iter().any(|r| r.signal_id == *signal_id) {
2051 errors.push(format!(
2052 "signal_to_track[{}] = {} but track doesn't contain signal",
2053 signal_id, track_id
2054 ));
2055 }
2056 } else {
2057 errors.push(format!(
2058 "signal_to_track[{}] = {} but track doesn't exist",
2059 signal_id, track_id
2060 ));
2061 }
2062 }
2063
2064 for (track_id, identity_id) in &self.track_to_identity {
2066 if let Some(track) = self.tracks.get(track_id) {
2068 if track.identity_id != Some(*identity_id) {
2069 errors.push(format!(
2070 "track_to_identity[{}] = {} but track.identity_id = {:?}",
2071 track_id, identity_id, track.identity_id
2072 ));
2073 }
2074 } else {
2075 errors.push(format!(
2076 "track_to_identity[{}] = {} but track doesn't exist",
2077 track_id, identity_id
2078 ));
2079 }
2080
2081 if !self.identities.contains_key(identity_id) {
2083 errors.push(format!(
2084 "track_to_identity[{}] = {} but identity doesn't exist",
2085 track_id, identity_id
2086 ));
2087 }
2088 }
2089
2090 for (track_id, track) in &self.tracks {
2092 if let Some(identity_id) = track.identity_id {
2093 if !self.identities.contains_key(&identity_id) {
2094 errors.push(format!(
2095 "Track {} references non-existent identity {}",
2096 track_id, identity_id
2097 ));
2098 }
2099 }
2100 }
2101
2102 errors
2103 }
2104
2105 #[must_use]
2107 pub fn invariants_hold(&self) -> bool {
2108 self.validate_invariants().is_empty()
2109 }
2110
2111 #[must_use]
2113 pub fn is_valid(&self) -> bool {
2114 self.signals.iter().all(|s| s.is_valid(&self.text))
2115 }
2116
2117 pub fn add_signal_validated(
2121 &mut self,
2122 signal: Signal<Location>,
2123 ) -> Result<SignalId, SignalValidationError> {
2124 if let Some(err) = signal.validate_against(&self.text) {
2125 return Err(err);
2126 }
2127 Ok(self.add_signal(signal))
2128 }
2129
2130 pub fn add_signal_from_text(
2144 &mut self,
2145 surface: &str,
2146 label: impl Into<super::types::TypeLabel>,
2147 confidence: f32,
2148 ) -> Option<SignalId> {
2149 let signal = Signal::from_text(&self.text, surface, label, confidence)?;
2150 Some(self.add_signal(signal))
2151 }
2152
2153 pub fn add_signal_from_text_nth(
2155 &mut self,
2156 surface: &str,
2157 label: impl Into<super::types::TypeLabel>,
2158 confidence: f32,
2159 occurrence: usize,
2160 ) -> Option<SignalId> {
2161 let signal = Signal::from_text_nth(&self.text, surface, label, confidence, occurrence)?;
2162 Some(self.add_signal(signal))
2163 }
2164
2165 #[must_use]
2171 pub fn stats(&self) -> DocumentStats {
2172 let signal_count = self.signals.len();
2173 let track_count = self.tracks.len();
2174 let identity_count = self.identities.len();
2175
2176 let linked_track_count = self
2177 .tracks
2178 .values()
2179 .filter(|t| t.identity_id.is_some())
2180 .count();
2181 let untracked_count = self.untracked_signal_count();
2182
2183 let avg_track_size = if track_count > 0 {
2184 self.tracks.values().map(|t| t.len()).sum::<usize>() as f32 / track_count as f32
2185 } else {
2186 0.0
2187 };
2188
2189 let singleton_count = self.tracks.values().filter(|t| t.is_singleton()).count();
2190
2191 let avg_confidence = if signal_count > 0 {
2192 self.signals.iter().map(|s| s.confidence).sum::<f32>() / signal_count as f32
2193 } else {
2194 0.0
2195 };
2196
2197 let negated_count = self.signals.iter().filter(|s| s.negated).count();
2198
2199 let symbolic_count = self
2201 .signals
2202 .iter()
2203 .filter(|s| s.modality == Modality::Symbolic)
2204 .count();
2205 let iconic_count = self
2206 .signals
2207 .iter()
2208 .filter(|s| s.modality == Modality::Iconic)
2209 .count();
2210 let hybrid_count = self
2211 .signals
2212 .iter()
2213 .filter(|s| s.modality == Modality::Hybrid)
2214 .count();
2215
2216 DocumentStats {
2217 signal_count,
2218 track_count,
2219 identity_count,
2220 linked_track_count,
2221 untracked_count,
2222 avg_track_size,
2223 singleton_count,
2224 avg_confidence,
2225 negated_count,
2226 symbolic_count,
2227 iconic_count,
2228 hybrid_count,
2229 }
2230 }
2231
2232 pub fn add_signals(
2240 &mut self,
2241 signals: impl IntoIterator<Item = Signal<Location>>,
2242 ) -> Vec<SignalId> {
2243 signals.into_iter().map(|s| self.add_signal(s)).collect()
2244 }
2245
2246 pub fn create_track_from_signals(
2250 &mut self,
2251 canonical: impl Into<String>,
2252 signal_ids: &[SignalId],
2253 ) -> Option<TrackId> {
2254 if signal_ids.is_empty() {
2255 return None;
2256 }
2257
2258 let mut track = Track::new(TrackId::ZERO, canonical);
2259 for (pos, &id) in signal_ids.iter().enumerate() {
2260 track.add_signal(id, pos as u32);
2261 }
2262 Some(self.add_track(track))
2263 }
2264
2265 pub fn merge_tracks(&mut self, track_ids: &[TrackId]) -> Option<TrackId> {
2270 if track_ids.is_empty() {
2271 return None;
2272 }
2273
2274 let mut all_signals: Vec<SignalRef> = Vec::new();
2276 let mut canonical = String::new();
2277 let mut entity_type = None;
2278
2279 for &track_id in track_ids {
2280 if let Some(track) = self.tracks.get(&track_id) {
2281 if canonical.is_empty() {
2282 canonical = track.canonical_surface.clone();
2283 entity_type = track.entity_type.clone();
2284 }
2285 all_signals.extend(track.signals.iter().cloned());
2286 }
2287 }
2288
2289 if all_signals.is_empty() {
2290 return None;
2291 }
2292
2293 all_signals.sort_by_key(|s| s.position);
2295
2296 for &track_id in track_ids {
2298 self.tracks.remove(&track_id);
2299 }
2300
2301 let mut new_track = Track::new(TrackId::ZERO, canonical);
2303 new_track.entity_type = entity_type;
2304 for (pos, signal_ref) in all_signals.iter().enumerate() {
2305 new_track.add_signal(signal_ref.signal_id, pos as u32);
2306 }
2307
2308 Some(self.add_track(new_track))
2309 }
2310
2311 #[must_use]
2313 pub fn find_overlapping_signal_pairs(&self) -> Vec<(SignalId, SignalId)> {
2314 let mut pairs = Vec::new();
2315 let signals: Vec<_> = self.signals.iter().collect();
2316
2317 for i in 0..signals.len() {
2318 for j in (i + 1)..signals.len() {
2319 if signals[i].location.overlaps(&signals[j].location) {
2320 pairs.push((signals[i].id, signals[j].id));
2321 }
2322 }
2323 }
2324
2325 pairs
2326 }
2327}
2328
/// Summary counters and averages describing a grounded document.
#[derive(Debug, Clone, Copy, Default)]
pub struct DocumentStats {
    /// Total number of signals.
    pub signal_count: usize,
    /// Total number of tracks.
    pub track_count: usize,
    /// Total number of identities.
    pub identity_count: usize,
    /// Number of tracks that have an identity link.
    pub linked_track_count: usize,
    /// Number of signals that belong to no track.
    pub untracked_count: usize,
    /// Mean number of signals per track (`0.0` without tracks).
    pub avg_track_size: f32,
    /// Number of singleton tracks.
    pub singleton_count: usize,
    /// Mean signal confidence (`0.0` without signals).
    pub avg_confidence: f32,
    /// Number of negated signals.
    pub negated_count: usize,
    /// Number of signals with symbolic modality.
    pub symbolic_count: usize,
    /// Number of signals with iconic modality.
    pub iconic_count: usize,
    /// Number of signals with hybrid modality.
    pub hybrid_count: usize,
}

impl std::fmt::Display for DocumentStats {
    /// Writes a multi-line, human-readable report. The "Negated" line is only
    /// emitted when at least one signal is negated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `DocumentStats` is `Copy`, so the fields can be pulled out by value
        // and referenced by name in the format strings below.
        let Self {
            signal_count,
            track_count,
            identity_count,
            linked_track_count,
            untracked_count,
            avg_track_size,
            singleton_count,
            avg_confidence,
            negated_count,
            symbolic_count,
            iconic_count,
            hybrid_count,
        } = *self;

        writeln!(f, "Document Statistics:")?;
        writeln!(
            f,
            " Signals: {signal_count} (avg confidence: {avg_confidence:.2})"
        )?;
        writeln!(
            f,
            " Tracks: {track_count} (avg size: {avg_track_size:.1}, singletons: {singleton_count})"
        )?;
        writeln!(
            f,
            " Identities: {identity_count} ({linked_track_count} tracks linked)"
        )?;
        writeln!(f, " Untracked signals: {untracked_count}")?;
        writeln!(
            f,
            " Modalities: {symbolic_count} symbolic, {iconic_count} iconic, {hybrid_count} hybrid"
        )?;
        if negated_count > 0 {
            writeln!(f, " Negated: {negated_count}")?;
        }
        Ok(())
    }
}
2388
2389#[derive(Debug, Clone)]
2399struct IntervalNode {
2400 signal_id: SignalId,
2402 start: usize,
2404 end: usize,
2406 max_end: usize,
2408 left: Option<Box<IntervalNode>>,
2410 right: Option<Box<IntervalNode>>,
2412}
2413
2414impl IntervalNode {
2415 fn new(signal_id: SignalId, start: usize, end: usize) -> Self {
2416 Self {
2417 signal_id,
2418 start,
2419 end,
2420 max_end: end,
2421 left: None,
2422 right: None,
2423 }
2424 }
2425
2426 fn insert(&mut self, signal_id: SignalId, start: usize, end: usize) {
2427 self.max_end = self.max_end.max(end);
2428
2429 if start < self.start {
2430 if let Some(ref mut left) = self.left {
2431 left.insert(signal_id, start, end);
2432 } else {
2433 self.left = Some(Box::new(IntervalNode::new(signal_id, start, end)));
2434 }
2435 } else if let Some(ref mut right) = self.right {
2436 right.insert(signal_id, start, end);
2437 } else {
2438 self.right = Some(Box::new(IntervalNode::new(signal_id, start, end)));
2439 }
2440 }
2441
2442 fn query_overlap(&self, query_start: usize, query_end: usize, results: &mut Vec<SignalId>) {
2443 if self.start < query_end && query_start < self.end {
2445 results.push(self.signal_id);
2446 }
2447
2448 if let Some(ref left) = self.left {
2450 if left.max_end > query_start {
2451 left.query_overlap(query_start, query_end, results);
2452 }
2453 }
2454
2455 if let Some(ref right) = self.right {
2457 if self.start < query_end {
2458 right.query_overlap(query_start, query_end, results);
2459 }
2460 }
2461 }
2462
2463 fn query_containing(&self, query_start: usize, query_end: usize, results: &mut Vec<SignalId>) {
2464 if self.start <= query_start && self.end >= query_end {
2466 results.push(self.signal_id);
2467 }
2468
2469 if let Some(ref left) = self.left {
2471 if left.max_end >= query_end {
2472 left.query_containing(query_start, query_end, results);
2473 }
2474 }
2475
2476 if let Some(ref right) = self.right {
2478 if self.start <= query_start {
2479 right.query_containing(query_start, query_end, results);
2480 }
2481 }
2482 }
2483
2484 fn query_contained_in(
2485 &self,
2486 range_start: usize,
2487 range_end: usize,
2488 results: &mut Vec<SignalId>,
2489 ) {
2490 if self.start >= range_start && self.end <= range_end {
2492 results.push(self.signal_id);
2493 }
2494
2495 if let Some(ref left) = self.left {
2497 left.query_contained_in(range_start, range_end, results);
2498 }
2499
2500 if let Some(ref right) = self.right {
2502 if self.start < range_end {
2503 right.query_contained_in(range_start, range_end, results);
2504 }
2505 }
2506 }
2507}
2508
2509#[derive(Debug, Clone, Default)]
2525pub struct TextSpatialIndex {
2526 root: Option<IntervalNode>,
2527 size: usize,
2528}
2529
2530impl TextSpatialIndex {
2531 #[must_use]
2533 pub fn new() -> Self {
2534 Self::default()
2535 }
2536
2537 #[must_use]
2539 pub fn from_signals(signals: &[Signal<Location>]) -> Self {
2540 let mut index = Self::new();
2541 for signal in signals {
2542 if let Some((start, end)) = signal.location.text_offsets() {
2543 index.insert(signal.id, start, end);
2544 }
2545 }
2546 index
2547 }
2548
2549 pub fn insert(&mut self, signal_id: SignalId, start: usize, end: usize) {
2551 if let Some(ref mut root) = self.root {
2552 root.insert(signal_id, start, end);
2553 } else {
2554 self.root = Some(IntervalNode::new(signal_id, start, end));
2555 }
2556 self.size += 1;
2557 }
2558
2559 #[must_use]
2561 pub fn query_overlap(&self, start: usize, end: usize) -> Vec<SignalId> {
2562 let mut results = Vec::new();
2563 if let Some(ref root) = self.root {
2564 root.query_overlap(start, end, &mut results);
2565 }
2566 results
2567 }
2568
2569 #[must_use]
2571 pub fn query_containing(&self, start: usize, end: usize) -> Vec<SignalId> {
2572 let mut results = Vec::new();
2573 if let Some(ref root) = self.root {
2574 root.query_containing(start, end, &mut results);
2575 }
2576 results
2577 }
2578
2579 #[must_use]
2581 pub fn query_contained_in(&self, start: usize, end: usize) -> Vec<SignalId> {
2582 let mut results = Vec::new();
2583 if let Some(ref root) = self.root {
2584 root.query_contained_in(start, end, &mut results);
2585 }
2586 results
2587 }
2588
2589 #[must_use]
2591 pub fn len(&self) -> usize {
2592 self.size
2593 }
2594
2595 #[must_use]
2597 pub fn is_empty(&self) -> bool {
2598 self.size == 0
2599 }
2600}
2601
2602impl GroundedDocument {
2603 #[must_use]
2622 pub fn build_text_index(&self) -> TextSpatialIndex {
2623 TextSpatialIndex::from_signals(&self.signals)
2624 }
2625
2626 #[must_use]
2631 pub fn query_signals_in_range_indexed(
2632 &self,
2633 start: usize,
2634 end: usize,
2635 ) -> Vec<&Signal<Location>> {
2636 let index = self.build_text_index();
2637 let ids = index.query_contained_in(start, end);
2638 ids.iter().filter_map(|&id| self.get_signal(id)).collect()
2639 }
2640
2641 #[must_use]
2643 pub fn query_overlapping_signals_indexed(
2644 &self,
2645 start: usize,
2646 end: usize,
2647 ) -> Vec<&Signal<Location>> {
2648 let index = self.build_text_index();
2649 let ids = index.query_overlap(start, end);
2650 ids.iter().filter_map(|&id| self.get_signal(id)).collect()
2651 }
2652
2653 #[must_use]
2666 pub fn to_coref_document(&self) -> super::coref::CorefDocument {
2667 use super::coref::{CorefChain, CorefDocument, Mention};
2668 use std::collections::HashMap;
2669
2670 let signal_by_id: HashMap<SignalId, &Signal<Location>> =
2672 self.signals.iter().map(|s| (s.id, s)).collect();
2673
2674 let mut chains: Vec<CorefChain> = Vec::new();
2675
2676 for track in self.tracks.values() {
2677 let mut mentions: Vec<Mention> = Vec::new();
2678
2679 for sref in &track.signals {
2680 let Some(signal) = signal_by_id.get(&sref.signal_id) else {
2681 continue;
2682 };
2683
2684 let Some((start, end)) = signal.location.text_offsets() else {
2685 continue;
2686 };
2687
2688 let mut m = Mention::new(signal.surface.clone(), start, end);
2689 m.entity_type = Some(signal.label.to_string());
2690 mentions.push(m);
2691 }
2692
2693 if mentions.is_empty() {
2694 continue;
2695 }
2696
2697 let mut chain = CorefChain::new(mentions);
2698 chain.entity_type = track.entity_type.as_ref().map(|t| t.to_string());
2699 chains.push(chain);
2700 }
2701
2702 chains.sort_by_key(|c| c.mentions.first().map(|m| m.start).unwrap_or(usize::MAX));
2704
2705 CorefDocument::with_id(&self.text, &self.id, chains)
2706 }
2707}
2708
2709pub fn render_document_html(doc: &GroundedDocument) -> String {
2717 let mut html = String::new();
2718 let stats = doc.stats();
2719
2720 html.push_str(r#"<!DOCTYPE html>
2721<html>
2722<head>
2723<meta charset="UTF-8">
2724<meta name="color-scheme" content="dark light">
2725<title>grounded::GroundedDocument</title>
2726<style>
2727:root{
2728 /* Allow UA widgets (inputs/scrollbars) to match the theme */
2729 color-scheme: light dark;
2730 /* Dark (default) */
2731 --bg:#0a0a0a;
2732 --panel-bg:#0d0d0d;
2733 --text:#b0b0b0;
2734 --text-strong:#fff;
2735 --muted:#666;
2736 --border:#222;
2737 --border-strong:#333;
2738 --hover:#111;
2739 --input-bg:#080808;
2740 --active:#fff;
2741 --track-strong:rgba(255,255,255,0.35);
2742 --track-soft:rgba(255,255,255,0.18);
2743 /* Entity colors (dark) */
2744 --per-bg:#1a1a2e; --per-br:#4a4a8a; --per-tx:#8888cc;
2745 --org-bg:#1a2e1a; --org-br:#4a8a4a; --org-tx:#88cc88;
2746 --loc-bg:#2e2e1a; --loc-br:#8a8a4a; --loc-tx:#cccc88;
2747 --mis-bg:#1a1a1a; --mis-br:#4a4a4a; --mis-tx:#999;
2748 --dat-bg:#2e1a1a; --dat-br:#8a4a4a; --dat-tx:#cc8888;
2749 --badge-y-bg:#1a2e1a; --badge-y-tx:#4a8a4a; --badge-y-br:#2a4a2a;
2750 --badge-n-bg:#2e2e1a; --badge-n-tx:#8a8a4a; --badge-n-br:#4a4a2a;
2751}
2752@media (prefers-color-scheme: light){
2753 :root{
2754 --bg:#ffffff;
2755 --panel-bg:#f7f7f7;
2756 --text:#222;
2757 --text-strong:#000;
2758 --muted:#555;
2759 --border:#d6d6d6;
2760 --border-strong:#c6c6c6;
2761 --hover:#f0f0f0;
2762 --input-bg:#ffffff;
2763 --active:#000;
2764 --track-strong:rgba(0,0,0,0.25);
2765 --track-soft:rgba(0,0,0,0.12);
2766 /* Entity colors (light) */
2767 --per-bg:#e9e9ff; --per-br:#6c6cff; --per-tx:#2b2b7a;
2768 --org-bg:#e9f7e9; --org-br:#2f8a2f; --org-tx:#1f5a1f;
2769 --loc-bg:#fff7db; --loc-br:#8a7a2f; --loc-tx:#5a4d12;
2770 --mis-bg:#f2f2f2; --mis-br:#8a8a8a; --mis-tx:#333;
2771 --dat-bg:#ffe9e9; --dat-br:#8a2f2f; --dat-tx:#5a1f1f;
2772 --badge-y-bg:#e9f7e9; --badge-y-tx:#1f5a1f; --badge-y-br:#9ad19a;
2773 --badge-n-bg:#fff7db; --badge-n-tx:#5a4d12; --badge-n-br:#e2d39a;
2774 }
2775}
2776html[data-theme='dark']{
2777 --bg:#0a0a0a; --panel-bg:#0d0d0d; --text:#b0b0b0; --text-strong:#fff;
2778 --muted:#666; --border:#222; --border-strong:#333; --hover:#111;
2779 --input-bg:#080808; --active:#fff;
2780 --track-strong:rgba(255,255,255,0.35); --track-soft:rgba(255,255,255,0.18);
2781 --per-bg:#1a1a2e; --per-br:#4a4a8a; --per-tx:#8888cc;
2782 --org-bg:#1a2e1a; --org-br:#4a8a4a; --org-tx:#88cc88;
2783 --loc-bg:#2e2e1a; --loc-br:#8a8a4a; --loc-tx:#cccc88;
2784 --mis-bg:#1a1a1a; --mis-br:#4a4a4a; --mis-tx:#999;
2785 --dat-bg:#2e1a1a; --dat-br:#8a4a4a; --dat-tx:#cc8888;
2786 --badge-y-bg:#1a2e1a; --badge-y-tx:#4a8a4a; --badge-y-br:#2a4a2a;
2787 --badge-n-bg:#2e2e1a; --badge-n-tx:#8a8a4a; --badge-n-br:#4a4a2a;
2788}
2789html[data-theme='light']{
2790 --bg:#ffffff; --panel-bg:#f7f7f7; --text:#222; --text-strong:#000;
2791 --muted:#555; --border:#d6d6d6; --border-strong:#c6c6c6; --hover:#f0f0f0;
2792 --input-bg:#ffffff; --active:#000;
2793 --track-strong:rgba(0,0,0,0.25); --track-soft:rgba(0,0,0,0.12);
2794 --per-bg:#e9e9ff; --per-br:#6c6cff; --per-tx:#2b2b7a;
2795 --org-bg:#e9f7e9; --org-br:#2f8a2f; --org-tx:#1f5a1f;
2796 --loc-bg:#fff7db; --loc-br:#8a7a2f; --loc-tx:#5a4d12;
2797 --mis-bg:#f2f2f2; --mis-br:#8a8a8a; --mis-tx:#333;
2798 --dat-bg:#ffe9e9; --dat-br:#8a2f2f; --dat-tx:#5a1f1f;
2799 --badge-y-bg:#e9f7e9; --badge-y-tx:#1f5a1f; --badge-y-br:#9ad19a;
2800 --badge-n-bg:#fff7db; --badge-n-tx:#5a4d12; --badge-n-br:#e2d39a;
2801}
2802
2803*{box-sizing:border-box;margin:0;padding:0}
2804body{font:12px/1.4 monospace;background:var(--bg);color:var(--text);padding:8px}
2805h1,h2,h3{color:var(--text-strong);font-weight:normal;border-bottom:1px solid var(--border-strong);padding:4px 0;margin:16px 0 8px}
2806h1{font-size:14px}h2{font-size:12px}h3{font-size:11px;color:var(--muted)}
2807 a{color:inherit}
2808 a:hover{text-decoration:underline}
2809table{width:100%;border-collapse:collapse;font-size:11px;margin:4px 0}
2810th,td{padding:4px 8px;text-align:left;border:1px solid var(--border)}
2811th{background:var(--hover);color:var(--muted);font-weight:normal;text-transform:uppercase;font-size:10px}
2812tr:hover{background:var(--hover)}
2813.grid{display:grid;grid-template-columns:repeat(auto-fit,minmax(300px,1fr));gap:8px}
2814.panel{border:1px solid var(--border);background:var(--panel-bg);padding:8px}
2815.panel-h{display:flex;align-items:center;gap:8px}
2816.toggle{cursor:pointer;user-select:none;color:var(--muted);border:1px solid var(--border);background:var(--bg);padding:2px 6px;font-size:10px}
2817.panel-collapsed table,.panel-collapsed .panel-body{display:none}
2818.toolbar{display:flex;gap:8px;align-items:center;margin:8px 0 0}
2819.toolbar input{width:100%;max-width:520px;background:var(--input-bg);border:1px solid var(--border);color:var(--text);padding:6px 8px;font:12px monospace}
2820.muted{color:var(--muted)}
2821.panel-body{white-space:pre-wrap;word-break:break-word}
2822.text-box{background:var(--input-bg);border:1px solid var(--border);padding:8px;white-space:pre-wrap;word-break:break-word;line-height:1.6}
2823.e{padding:1px 2px;border-bottom:1px solid}
2824.seg{cursor:pointer}
2825.e-per{background:var(--per-bg);border-color:var(--per-br);color:var(--per-tx)}
2826.e-org{background:var(--org-bg);border-color:var(--org-br);color:var(--org-tx)}
2827.e-loc{background:var(--loc-bg);border-color:var(--loc-br);color:var(--loc-tx)}
2828.e-misc{background:var(--mis-bg);border-color:var(--mis-br);color:var(--mis-tx)}
2829.e-date{background:var(--dat-bg);border-color:var(--dat-br);color:var(--dat-tx)}
2830.e-track{box-shadow:inset 0 0 0 1px var(--track-strong)}
2831.e-track-hover{box-shadow:inset 0 0 0 1px var(--track-soft)}
2832.e-active{outline:2px solid var(--active);outline-offset:1px}
2833.conf{color:var(--muted);font-size:10px}
2834.badge{display:inline-block;padding:1px 4px;font-size:9px;text-transform:uppercase}
2835.badge-y{background:var(--badge-y-bg);color:var(--badge-y-tx);border:1px solid var(--badge-y-br)}
2836.badge-n{background:var(--badge-n-bg);color:var(--badge-n-tx);border:1px solid var(--badge-n-br)}
2837.stats{display:flex;gap:16px;padding:8px 0;border-bottom:1px solid var(--border);margin-bottom:8px}
2838.stat{text-align:center}.stat-v{font-size:18px;color:var(--text-strong)}.stat-l{font-size:9px;color:var(--muted);text-transform:uppercase}
2839.id{color:var(--muted);font-size:9px}
2840.kb{color:var(--muted)}
2841.arrow{color:var(--muted)}
2842</style>
2843</head>
2844<body>
2845"#);
2846
2847 html.push_str(&format!(
2849 r#"<div class="panel-h" style="justify-content:space-between"><h1>doc_id="{}" len={}</h1><span class="toggle" id="theme-toggle" title="toggle theme (auto → dark → light)">theme: auto</span></div>"#,
2850 html_escape(&doc.id),
2851 doc.text.len()
2852 ));
2853
2854 html.push_str(r#"<div class="stats">"#);
2855 html.push_str(&format!(
2856 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">signals</div></div>"#,
2857 stats.signal_count
2858 ));
2859 html.push_str(&format!(
2860 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">tracks</div></div>"#,
2861 stats.track_count
2862 ));
2863 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">identities</div></div>"#, stats.identity_count));
2864 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{:.2}</div><div class="stat-l">avg_conf</div></div>"#, stats.avg_confidence));
2865 html.push_str(&format!(
2866 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">linked</div></div>"#,
2867 stats.linked_track_count
2868 ));
2869 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">untracked</div></div>"#, stats.untracked_count));
2870 if stats.iconic_count > 0 || stats.hybrid_count > 0 {
2871 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}/{}/{}</div><div class="stat-l">sym/ico/hyb</div></div>"#,
2872 stats.symbolic_count, stats.iconic_count, stats.hybrid_count));
2873 }
2874 html.push_str(r#"</div>"#);
2875
2876 html.push_str(r#"<h2>text</h2>"#);
2878 html.push_str(r#"<div class="text-box">"#);
2879 html.push_str(&annotate_text_html(
2880 &doc.text,
2881 doc.signals(),
2882 &doc.signal_to_track,
2883 ));
2884 html.push_str(r#"</div>"#);
2885
2886 html.push_str(
2888 r#"<h2>selection</h2><div class="panel" id="selection-panel" role="region" aria-label="selection"><div class="panel-h"><h3>selection</h3><span class="muted" id="selection-hint" role="status" aria-live="polite">click a mention / row to see coref track details</span></div><pre class="panel-body" id="selection-body" role="textbox" aria-readonly="true" aria-label="selection details">—</pre></div>"#,
2889 );
2890
2891 html.push_str(r#"<div class="grid">"#);
2893
2894 html.push_str(r#"<div class="panel" id="panel-signals"><div class="panel-h"><h3>signals (level 1)</h3><span class="toggle" data-toggle="panel-signals">toggle</span></div><div class="toolbar"><input id="signal-filter" type="text" placeholder="filter signals: id / label / surface (e.g. 'PER', 'S12', 'Paris')" /><span class="muted" id="signal-filter-count"></span></div><table id="signals-table">"#);
2896 html.push_str(r#"<tr><th>id</th><th>span</th><th>surface</th><th>label</th><th>conf</th><th>track</th></tr>"#);
2897 for signal in doc.signals() {
2898 let (span, start_opt, end_opt) = if let Some((s, e)) = signal.location.text_offsets() {
2899 (format!("[{},{})", s, e), Some(s), Some(e))
2900 } else {
2901 ("bbox".to_string(), None, None)
2902 };
2903 let track_id_num = doc.signal_to_track.get(&signal.id).copied();
2904 let track_id = track_id_num
2905 .map(|t| format!("T{}", t))
2906 .unwrap_or_else(|| "-".to_string());
2907 let track_attr = track_id_num
2908 .map(|t| format!(r#" data-track="{}""#, t))
2909 .unwrap_or_default();
2910 let offs_attr = match (start_opt, end_opt) {
2911 (Some(s), Some(e)) => format!(r#" data-start="{}" data-end="{}""#, s, e),
2912 _ => String::new(),
2913 };
2914 let neg = if signal.negated { " NEG" } else { "" };
2915 html.push_str(&format!(
2916 r#"<tr data-sid="S{sid}" data-label="{label}" data-surface="{surface}"{track_attr}{offs_attr} data-conf="{conf:.2}"><td class="id"><a href='#S{sid}'>S{sid}</a></td><td>{span}</td><td>{surface}</td><td>{label}{neg}</td><td class="conf">{conf:.2}</td><td class="id">{track}</td></tr>"#,
2917 sid = signal.id,
2918 span = span,
2919 surface = html_escape(&signal.surface),
2920 label = html_escape(signal.label.as_str()),
2921 neg = neg,
2922 conf = signal.confidence,
2923 track = track_id,
2924 track_attr = track_attr,
2925 offs_attr = offs_attr
2926 ));
2927 }
2928 html.push_str(r#"</table></div>"#);
2929
2930 html.push_str(r#"<div class="panel" id="panel-tracks"><div class="panel-h"><h3>tracks (level 2)</h3><span class="toggle" data-toggle="panel-tracks">toggle</span></div><table id="tracks-table">"#);
2932 html.push_str(r#"<tr><th>id</th><th>canonical</th><th>type</th><th>|S|</th><th>signals</th><th>identity</th></tr>"#);
2933 for track in doc.tracks() {
2934 let entity_type = track
2935 .entity_type
2936 .as_ref()
2937 .map(|t| t.as_str())
2938 .unwrap_or("-");
2939 let signals: Vec<String> = track
2940 .signals
2941 .iter()
2942 .map(|s| format!("S{}", s.signal_id))
2943 .collect();
2944 let identity = doc
2945 .identity_for_track(track.id)
2946 .map(|i| format!("I{}", i.id))
2947 .unwrap_or_else(|| "-".to_string());
2948 let linked_badge = if track.identity_id.is_some() {
2949 r#"<span class="badge badge-y">y</span>"#
2950 } else {
2951 r#"<span class="badge badge-n">n</span>"#
2952 };
2953 html.push_str(&format!(
2954 r#"<tr data-tid="{tid}"><td class="id">T{tid}</td><td>{canonical_surface}</td><td>{etype}</td><td>{n}</td><td class="id">{sigs}</td><td class="id">{ident} {badge}</td></tr>"#,
2955 tid = track.id,
2956 canonical_surface = html_escape(&track.canonical_surface),
2957 etype = html_escape(entity_type),
2958 n = track.len(),
2959 sigs = html_escape(&signals.join(" ")),
2960 ident = identity,
2961 badge = linked_badge
2962 ));
2963 }
2964 html.push_str(r#"</table></div>"#);
2965
2966 html.push_str(r#"<div class="panel" id="panel-identities"><div class="panel-h"><h3>identities (level 3)</h3><span class="toggle" data-toggle="panel-identities">toggle</span></div><table>"#);
2968 html.push_str(r#"<tr><th>id</th><th>name</th><th>type</th><th>kb</th><th>kb_id</th><th>aliases</th></tr>"#);
2969 for identity in doc.identities() {
2970 let kb = identity.kb_name.as_deref().unwrap_or("-");
2971 let kb_id = identity.kb_id.as_deref().unwrap_or("-");
2972 let entity_type = identity
2973 .entity_type
2974 .as_ref()
2975 .map(|t| t.as_str())
2976 .unwrap_or("-");
2977 let aliases = if identity.aliases.is_empty() {
2978 "-".to_string()
2979 } else {
2980 identity.aliases.join(", ")
2981 };
2982 html.push_str(&format!(
2983 r#"<tr><td class="id">I{}</td><td>{}</td><td>{}</td><td class="kb">{}</td><td class="kb">{}</td><td>{}</td></tr>"#,
2984 identity.id, html_escape(&identity.canonical_name), entity_type, kb, kb_id, html_escape(&aliases)
2985 ));
2986 }
2987 html.push_str(r#"</table></div>"#);
2988
2989 html.push_str(r#"</div>"#); html.push_str(r#"<h2>hierarchy trace</h2><div class="panel"><table>"#);
2993 html.push_str(r#"<tr><th>signal</th><th></th><th>track</th><th></th><th>identity</th><th>kb_id</th></tr>"#);
2994 for signal in doc.signals() {
2995 let track = doc.track_for_signal(signal.id);
2996 let identity = doc.identity_for_signal(signal.id);
2997
2998 let track_str = track
2999 .map(|t| format!("T{} \"{}\"", t.id, html_escape(&t.canonical_surface)))
3000 .unwrap_or_else(|| "-".to_string());
3001 let identity_str = identity
3002 .map(|i| format!("I{} \"{}\"", i.id, html_escape(&i.canonical_name)))
3003 .unwrap_or_else(|| "-".to_string());
3004 let kb_str = identity
3005 .and_then(|i| i.kb_id.as_ref())
3006 .map(|s| s.as_str())
3007 .unwrap_or("-");
3008
3009 html.push_str(&format!(
3010 r#"<tr><td>S{} "{}"</td><td class="arrow">→</td><td>{}</td><td class="arrow">→</td><td>{}</td><td class="kb">{}</td></tr>"#,
3011 signal.id, html_escape(&signal.surface), track_str, identity_str, kb_str
3012 ));
3013 }
3014 html.push_str(r#"</table></div>"#);
3015
3016 html.push_str(r#"<script>
3019(() => {
3020 // Index signal metadata from the signals table, and map signal/track → text elements.
3021 const signalMeta = new Map();
3022 document.querySelectorAll('#signals-table tr[data-sid]').forEach((row) => {
3023 const sid = row.getAttribute('data-sid');
3024 if (!sid) return;
3025 signalMeta.set(sid, {
3026 sid,
3027 label: row.getAttribute('data-label') || '',
3028 surface: row.getAttribute('data-surface') || '',
3029 conf: row.getAttribute('data-conf') || '',
3030 start: row.getAttribute('data-start'),
3031 end: row.getAttribute('data-end'),
3032 track: row.getAttribute('data-track'),
3033 });
3034 });
3035
3036 const signalEls = new Map();
3037 const addSignalEl = (sid, el) => {
3038 if (!sid || !el) return;
3039 const arr = signalEls.get(sid) || [];
3040 arr.push(el);
3041 signalEls.set(sid, arr);
3042 };
3043 // Old-style inline spans (non-overlapping renderer).
3044 document.querySelectorAll('span.e[data-sid]').forEach((el) => {
3045 addSignalEl(el.getAttribute('data-sid'), el);
3046 });
3047 // Segmented spans (overlap/discontinuous-safe renderer).
3048 document.querySelectorAll('span.seg[data-sids]').forEach((el) => {
3049 const raw = (el.getAttribute('data-sids') || '').trim();
3050 if (!raw) return;
3051 raw.split(/\s+/).filter(Boolean).forEach((sid) => addSignalEl(sid, el));
3052 });
3053
3054 const trackEls = new Map();
3055 for (const [sid, els] of signalEls.entries()) {
3056 const meta = signalMeta.get(sid);
3057 const tid = meta ? meta.track : null;
3058 if (!tid) continue;
3059 const arr = trackEls.get(tid) || [];
3060 els.forEach((el) => arr.push(el));
3061 trackEls.set(tid, arr);
3062 }
3063
3064 const selectionBody = document.getElementById('selection-body');
3065 const selectionHint = document.getElementById('selection-hint');
3066 const defaultHint = selectionHint ? (selectionHint.textContent || '') : '';
3067 const setSelection = (text) => {
3068 if (!selectionBody) return;
3069 selectionBody.textContent = text;
3070 };
3071 const setHint = (text) => {
3072 if (!selectionHint) return;
3073 selectionHint.textContent = text || defaultHint;
3074 };
3075
3076 // Theme toggle: auto (prefers-color-scheme) → dark → light.
3077 const themeBtn = document.getElementById('theme-toggle');
3078 const themeKey = 'anno-theme';
3079 const applyTheme = (theme) => {
3080 const t = theme || 'auto';
3081 if (t === 'auto') {
3082 delete document.documentElement.dataset.theme;
3083 } else {
3084 document.documentElement.dataset.theme = t;
3085 }
3086 if (themeBtn) themeBtn.textContent = `theme: ${t}`;
3087 };
3088 const readTheme = () => {
3089 try { return localStorage.getItem(themeKey) || 'auto'; } catch (_) { return 'auto'; }
3090 };
3091 const writeTheme = (t) => {
3092 try { localStorage.setItem(themeKey, t); } catch (_) { /* ignore */ }
3093 };
3094 applyTheme(readTheme());
3095 if (themeBtn) {
3096 themeBtn.addEventListener('click', () => {
3097 const cur = readTheme();
3098 const next = cur === 'auto' ? 'dark' : (cur === 'dark' ? 'light' : 'auto');
3099 writeTheme(next);
3100 applyTheme(next);
3101 });
3102 }
3103
3104 let activeSignalEls = [];
3105 let activeSignalRow = null;
3106 const clearActive = () => {
3107 if (activeSignalEls && activeSignalEls.length) {
3108 activeSignalEls.forEach((el) => el.classList.remove('e-active'));
3109 }
3110 if (activeSignalRow) activeSignalRow.classList.remove('e-active');
3111 activeSignalEls = [];
3112 activeSignalRow = null;
3113 };
3114
3115 let activeTrack = null;
3116 let hoverTrack = null;
3117
3118 const removeTrackClass = (tid, cls) => {
3119 if (!tid) return;
3120 const els = trackEls.get(tid);
3121 if (!els) return;
3122 els.forEach((el) => el.classList.remove(cls));
3123 };
3124
3125 const addTrackClass = (tid, cls) => {
3126 if (!tid) return;
3127 const els = trackEls.get(tid);
3128 if (!els) return;
3129 els.forEach((el) => el.classList.add(cls));
3130 };
3131
3132 const trackSize = (tid) => {
3133 const els = tid ? trackEls.get(tid) : null;
3134 return els ? els.length : 0;
3135 };
3136
3137 const getTrackSelectionText = (tid) => {
3138 if (!tid) return 'track: - (untracked)';
3139 const row = document.querySelector(`#tracks-table tr[data-tid='${tid}']`);
3140 if (!row) return `track T${tid}`;
3141 const cells = row.querySelectorAll('td');
3142 const canonical = (cells[1]?.textContent || '').trim();
3143 const etype = (cells[2]?.textContent || '').trim();
3144 const count = (cells[3]?.textContent || '').trim();
3145 const sigs = (cells[4]?.textContent || '').trim();
3146 const lines = [];
3147 lines.push(`track T${tid} canonical="${canonical}" type="${etype}" mentions=${count}`);
3148 if (sigs) lines.push(`track signals: ${sigs}`);
3149 return lines.join('\n');
3150 };
3151
3152 const renderTrackSelection = (tid) => setSelection(getTrackSelectionText(tid));
3153
3154 const renderSignalSelectionBySid = (sid) => {
3155 const meta = signalMeta.get(sid);
3156 const label = meta ? (meta.label || '') : '';
3157 const conf = meta ? (meta.conf || '') : '';
3158 const start = meta ? meta.start : null;
3159 const end = meta ? meta.end : null;
3160 const tid = meta ? meta.track : null;
3161 const lines = [];
3162 if (start !== null && end !== null) {
3163 lines.push(`signal ${sid} label=${label} conf=${conf} span=[${start},${end})`);
3164 } else {
3165 lines.push(`signal ${sid} label=${label} conf=${conf}`);
3166 }
3167 if (meta && meta.surface) lines.push(`surface: ${meta.surface}`);
3168 lines.push('');
3169 lines.push(getTrackSelectionText(tid));
3170 setSelection(lines.join('\n'));
3171 };
3172
3173 const setActiveTrack = (tid) => {
3174 const next = tid || null;
3175 if (activeTrack === next) return;
3176 removeTrackClass(activeTrack, 'e-track');
3177 activeTrack = next;
3178 if (activeTrack) addTrackClass(activeTrack, 'e-track');
3179 if (hoverTrack && activeTrack && hoverTrack === activeTrack) {
3180 removeTrackClass(hoverTrack, 'e-track-hover');
3181 }
3182 };
3183
3184 const setHoverTrack = (tid) => {
3185 const next = tid || null;
3186 if (hoverTrack === next) return;
3187 removeTrackClass(hoverTrack, 'e-track-hover');
3188 hoverTrack = next;
3189 if (!hoverTrack) {
3190 setHint('');
3191 return;
3192 }
3193 if (activeTrack && hoverTrack === activeTrack) {
3194 setHint(`selected track T${hoverTrack} (${trackSize(hoverTrack)} mentions)`);
3195 return;
3196 }
3197 addTrackClass(hoverTrack, 'e-track-hover');
3198 setHint(`hover track T${hoverTrack} (${trackSize(hoverTrack)} mentions)`);
3199 };
3200
3201 const emitToParentSpan = (start, end) => {
3202 try {
3203 if (!window.parent || window.parent === window) return;
3204 if (start === null || end === null) return;
3205 window.parent.postMessage({ type: 'anno:activate-span', start: Number(start), end: Number(end) }, '*');
3206 } catch (_) {
3207 // ignore: best-effort bridge for iframe containers
3208 }
3209 };
3210
3211 const activateBySpan = (start, end, emit) => {
3212 if (start === null || end === null || start === undefined || end === undefined) return;
3213 // Prefer an exact signal span if present; otherwise fall back to the table row metadata.
3214 const el = document.querySelector(`span.e[data-sid][data-start='${start}'][data-end='${end}']`);
3215 if (el) {
3216 const sid = el.getAttribute('data-sid');
3217 if (sid) activateSignal(sid, emit);
3218 return;
3219 }
3220 const row = document.querySelector(`#signals-table tr[data-start='${start}'][data-end='${end}']`);
3221 if (!row) return;
3222 const sid = row.getAttribute('data-sid');
3223 if (!sid) return;
3224 activateSignal(sid, emit);
3225 };
3226
3227 const activateSignal = (sid, emit) => {
3228 clearActive();
3229 const els = signalEls.get(sid) || [];
3230 if (!els.length) return;
3231 els.forEach((el) => el.classList.add('e-active'));
3232 activeSignalEls = els;
3233 const row = document.querySelector(`#signals-table tr[data-sid='${sid}']`);
3234 if (row) {
3235 row.classList.add('e-active');
3236 activeSignalRow = row;
3237 }
3238 const primaryEl = els[0];
3239 primaryEl.scrollIntoView({ block: 'center', behavior: 'smooth' });
3240 const meta = signalMeta.get(sid);
3241 const tid = meta ? meta.track : primaryEl.getAttribute('data-track');
3242 setActiveTrack(tid);
3243 renderSignalSelectionBySid(sid);
3244 if (emit && meta && meta.start !== null && meta.end !== null) {
3245 emitToParentSpan(meta.start, meta.end);
3246 }
3247 };
3248
3249 // Table click
3250 const signalsTable = document.getElementById('signals-table');
3251 if (signalsTable) {
3252 signalsTable.addEventListener('click', (ev) => {
3253 const a = ev.target && ev.target.closest ? ev.target.closest("a[href^='#S']") : null;
3254 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-sid]') : null;
3255 const sid = (a && a.getAttribute('href') ? a.getAttribute('href').slice(1) : null) || (row ? row.getAttribute('data-sid') : null);
3256 if (!sid) return;
3257 ev.preventDefault();
3258 activateSignal(sid, true);
3259 history.replaceState(null, '', '#' + sid);
3260 });
3261
3262 // Hover a signals row → preview track highlight
3263 signalsTable.addEventListener('mouseover', (ev) => {
3264 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-sid]') : null;
3265 if (!row) return;
3266 const tid = row.getAttribute('data-track');
3267 setHoverTrack(tid);
3268 });
3269 signalsTable.addEventListener('mouseout', (ev) => {
3270 const to = ev.relatedTarget;
3271 if (to && signalsTable.contains(to)) return;
3272 setHoverTrack(null);
3273 });
3274 }
3275
3276 // Clicking an inline entity should also toggle active highlight.
3277 const pickPrimarySid = (el) => {
3278 if (!el) return null;
3279 const p = el.getAttribute('data-primary');
3280 if (p) return p;
3281 const raw = (el.getAttribute('data-sids') || '').trim();
3282 if (!raw) return null;
3283 const sids = raw.split(/\s+/).filter(Boolean);
3284 if (!sids.length) return null;
3285 // Prefer the shortest mention span from metadata.
3286 let best = sids[0];
3287 let bestLen = null;
3288 for (const sid of sids) {
3289 const meta = signalMeta.get(sid);
3290 const s = meta && meta.start !== null ? Number(meta.start) : null;
3291 const e = meta && meta.end !== null ? Number(meta.end) : null;
3292 const len = (s !== null && e !== null) ? (e - s) : null;
3293 if (len === null) continue;
3294 if (bestLen === null || len < bestLen) {
3295 best = sid;
3296 bestLen = len;
3297 }
3298 }
3299 return best;
3300 };
3301
3302 document.addEventListener('click', (ev) => {
3303 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3304 if (span) {
3305 activateSignal(span.getAttribute('data-sid'), true);
3306 return;
3307 }
3308 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3309 if (!seg) return;
3310 activateSignal(pickPrimarySid(seg), true);
3311 });
3312
3313 // Hover an inline entity → preview highlight its track
3314 document.addEventListener('mouseover', (ev) => {
3315 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3316 if (span) {
3317 setHoverTrack(span.getAttribute('data-track'));
3318 return;
3319 }
3320 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3321 if (!seg) return;
3322 const sid = pickPrimarySid(seg);
3323 const meta = sid ? signalMeta.get(sid) : null;
3324 setHoverTrack(meta ? meta.track : null);
3325 });
3326 document.addEventListener('mouseout', (ev) => {
3327 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3328 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3329 if (!span && !seg) return;
3330 const to = ev.relatedTarget;
3331 if (to && to.closest && (to.closest('span.e[data-sid]') || to.closest('span.seg[data-sids]'))) return;
3332 setHoverTrack(null);
3333 });
3334
3335 // Clicking a track row → select track (highlight + details)
3336 const tracksTable = document.getElementById('tracks-table');
3337 if (tracksTable) {
3338 tracksTable.addEventListener('click', (ev) => {
3339 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-tid]') : null;
3340 if (!row) return;
3341 const tid = row.getAttribute('data-tid');
3342 setActiveTrack(tid);
3343 renderTrackSelection(tid);
3344 });
3345 tracksTable.addEventListener('mouseover', (ev) => {
3346 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-tid]') : null;
3347 if (!row) return;
3348 setHoverTrack(row.getAttribute('data-tid'));
3349 });
3350 tracksTable.addEventListener('mouseout', (ev) => {
3351 const to = ev.relatedTarget;
3352 if (to && tracksTable.contains(to)) return;
3353 setHoverTrack(null);
3354 });
3355 }
3356
3357 // Filter
3358 const input = document.getElementById('signal-filter');
3359 const countEl = document.getElementById('signal-filter-count');
3360 if (input && signalsTable) {
3361 const update = () => {
3362 const q = (input.value || '').trim().toLowerCase();
3363 let shown = 0;
3364 const rows = signalsTable.querySelectorAll('tr[data-sid]');
3365 rows.forEach(row => {
3366 const sid = (row.getAttribute('data-sid') || '').toLowerCase();
3367 const label = (row.getAttribute('data-label') || '').toLowerCase();
3368 const surface = (row.getAttribute('data-surface') || '').toLowerCase();
3369 const ok = !q || sid.includes(q) || label.includes(q) || surface.includes(q);
3370 row.style.display = ok ? '' : 'none';
3371 if (ok) shown += 1;
3372 });
3373 if (countEl) countEl.textContent = shown + ' shown';
3374 };
3375 input.addEventListener('input', update);
3376 update();
3377 }
3378
3379 // Panel toggles
3380 document.querySelectorAll('[data-toggle]').forEach(btn => {
3381 btn.addEventListener('click', () => {
3382 const id = btn.getAttribute('data-toggle');
3383 const panel = id ? document.getElementById(id) : null;
3384 if (!panel) return;
3385 panel.classList.toggle('panel-collapsed');
3386 });
3387 });
3388
3389 // If URL hash is #S123, focus it.
3390 const hash = (location.hash || '').slice(1);
3391 if (hash && hash.startsWith('S')) activateSignal(hash, false);
3392
3393 // Optional: allow parent pages (e.g., dataset explorers) to sync selection across iframes.
3394 window.addEventListener('message', (ev) => {
3395 const data = ev && ev.data ? ev.data : null;
3396 if (!data || data.type !== 'anno:activate-span') return;
3397 if (typeof data.start !== 'number' || typeof data.end !== 'number') return;
3398 activateBySpan(data.start, data.end, false);
3399 });
3400})();
3401</script>"#);
3402
3403 html.push_str(r#"</body></html>"#);
3404 html
3405}
3406
/// Escapes `s` for embedding in HTML text content or a double-quoted attribute value.
///
/// `&`, `<`, `>` and `"` are replaced by `&amp;`, `&lt;`, `&gt;` and `&quot;`; every other
/// character is copied through unchanged. The input is processed in a single pass, so the
/// ampersands introduced by the replacements are never escaped a second time.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
3413
3414fn annotate_text_html(
3415 text: &str,
3416 signals: &[Signal<Location>],
3417 signal_to_track: &std::collections::HashMap<SignalId, TrackId>,
3418) -> String {
3419 let char_count = text.chars().count();
3420 if char_count == 0 {
3421 return String::new();
3422 }
3423
3424 #[derive(Debug, Clone)]
3425 struct SigMeta {
3426 sid: String,
3427 label: String,
3428 conf: f32,
3429 track_id: Option<TrackId>,
3430 covered_len: usize,
3431 }
3432
3433 #[derive(Debug, Clone)]
3434 struct Event {
3435 pos: usize,
3436 meta_idx: usize,
3437 delta: i32, }
3439
3440 let mut metas: Vec<SigMeta> = Vec::new();
3442 let mut events: Vec<Event> = Vec::new();
3443 let mut boundaries: Vec<usize> = vec![0, char_count];
3444
3445 for s in signals {
3446 let raw_segments: Vec<(usize, usize)> = match &s.location {
3447 Location::Text { start, end } => vec![(*start, *end)],
3448 Location::TextWithBbox { start, end, .. } => vec![(*start, *end)],
3449 Location::Discontinuous { segments } => segments.clone(),
3450 _ => Vec::new(),
3451 };
3452 if raw_segments.is_empty() {
3453 continue;
3454 }
3455
3456 let mut cleaned: Vec<(usize, usize)> = Vec::new();
3457 let mut covered_len = 0usize;
3458 for (start, end) in raw_segments {
3459 let start = start.min(char_count);
3460 let end = end.min(char_count);
3461 if start >= end {
3462 continue;
3463 }
3464 covered_len = covered_len.saturating_add(end - start);
3465 cleaned.push((start, end));
3466 }
3467 if cleaned.is_empty() {
3468 continue;
3469 }
3470
3471 let meta_idx = metas.len();
3472 let track_id = signal_to_track.get(&s.id).copied();
3473 metas.push(SigMeta {
3474 sid: format!("S{}", s.id),
3475 label: s.label.to_string(),
3476 conf: s.confidence,
3477 track_id,
3478 covered_len,
3479 });
3480
3481 for (start, end) in cleaned {
3482 boundaries.push(start);
3483 boundaries.push(end);
3484 events.push(Event {
3485 pos: start,
3486 meta_idx,
3487 delta: 1,
3488 });
3489 events.push(Event {
3490 pos: end,
3491 meta_idx,
3492 delta: -1,
3493 });
3494 }
3495 }
3496
3497 if metas.is_empty() {
3498 return html_escape(text);
3499 }
3500
3501 boundaries.sort_unstable();
3502 boundaries.dedup();
3503 events.sort_by(|a, b| a.pos.cmp(&b.pos).then_with(|| a.delta.cmp(&b.delta)));
3504
3505 let mut active_counts: Vec<u32> = vec![0; metas.len()];
3506 let mut active: Vec<usize> = Vec::new();
3507 let mut ev_idx = 0usize;
3508
3509 let mut result = String::new();
3510
3511 for bi in 0..boundaries.len().saturating_sub(1) {
3512 let pos = boundaries[bi];
3513 while ev_idx < events.len() && events[ev_idx].pos == pos {
3515 let e = &events[ev_idx];
3516 let idx = e.meta_idx;
3517 if e.delta < 0 {
3518 if active_counts[idx] > 0 {
3519 active_counts[idx] -= 1;
3520 if active_counts[idx] == 0 {
3521 active.retain(|&x| x != idx);
3522 }
3523 }
3524 } else {
3525 active_counts[idx] += 1;
3526 if active_counts[idx] == 1 {
3527 active.push(idx);
3528 }
3529 }
3530 ev_idx += 1;
3531 }
3532
3533 let next = boundaries[bi + 1];
3534 if next <= pos {
3535 continue;
3536 }
3537
3538 let seg_text: String = text.chars().skip(pos).take(next - pos).collect();
3539 if active.is_empty() {
3540 result.push_str(&html_escape(&seg_text));
3541 continue;
3542 }
3543
3544 let primary_idx = active
3546 .iter()
3547 .copied()
3548 .min_by(|a, b| {
3549 metas[*a]
3550 .covered_len
3551 .cmp(&metas[*b].covered_len)
3552 .then_with(|| {
3553 metas[*b]
3554 .conf
3555 .partial_cmp(&metas[*a].conf)
3556 .unwrap_or(std::cmp::Ordering::Equal)
3557 })
3558 })
3559 .unwrap_or(active[0]);
3560 let primary = &metas[primary_idx];
3561
3562 let class = match primary.label.to_uppercase().as_str() {
3563 "PER" | "PERSON" => "e-per",
3564 "ORG" | "ORGANIZATION" | "COMPANY" => "e-org",
3565 "LOC" | "LOCATION" | "GPE" => "e-loc",
3566 "DATE" | "TIME" => "e-date",
3567 _ => "e-misc",
3568 };
3569
3570 let mut sids: Vec<&str> = active.iter().map(|i| metas[*i].sid.as_str()).collect();
3571 sids.sort_unstable();
3572 let data_sids = sids.join(" ");
3573
3574 let mut title = format!(
3575 "sids=[{}] primary={} [{}..{})",
3576 data_sids, primary.sid, pos, next
3577 );
3578 if let Some(t) = primary.track_id {
3579 title.push_str(&format!(" track=T{}", t));
3580 }
3581
3582 result.push_str(&format!(
3583 r#"<span class="e seg {class}" data-sids="{sids}" data-start="{start}" data-end="{end}" data-primary="{primary}" title="{title}">{text}</span>"#,
3584 class = class,
3585 sids = html_escape(&data_sids),
3586 start = pos,
3587 end = next,
3588 primary = html_escape(&primary.sid),
3589 title = html_escape(&title),
3590 text = html_escape(&seg_text),
3591 ));
3592 }
3593
3594 result
3595}
3596
3597#[derive(Debug, Clone)]
3603pub struct EvalComparison {
3604 pub text: String,
3606 pub gold: Vec<Signal<Location>>,
3608 pub predicted: Vec<Signal<Location>>,
3610 pub matches: Vec<EvalMatch>,
3612}
3613
3614#[derive(Debug, Clone)]
3616pub enum EvalMatch {
3617 Correct {
3619 gold_id: SignalId,
3621 pred_id: SignalId,
3623 },
3624 TypeMismatch {
3626 gold_id: SignalId,
3628 pred_id: SignalId,
3630 gold_label: String,
3632 pred_label: String,
3634 },
3635 BoundaryError {
3637 gold_id: SignalId,
3639 pred_id: SignalId,
3641 iou: f64,
3643 },
3644 Spurious {
3646 pred_id: SignalId,
3648 },
3649 Missed {
3651 gold_id: SignalId,
3653 },
3654}
3655
3656impl EvalComparison {
3657 #[must_use]
3677 pub fn compare(
3678 text: &str,
3679 gold: Vec<Signal<Location>>,
3680 predicted: Vec<Signal<Location>>,
3681 ) -> Self {
3682 let mut matches = Vec::new();
3683 let mut gold_matched = vec![false; gold.len()];
3684 let mut pred_matched = vec![false; predicted.len()];
3685
3686 for (pi, pred) in predicted.iter().enumerate() {
3688 let pred_offsets = match pred.location.text_offsets() {
3689 Some(o) => o,
3690 None => continue,
3691 };
3692
3693 for (gi, g) in gold.iter().enumerate() {
3694 if gold_matched[gi] {
3695 continue;
3696 }
3697 let gold_offsets = match g.location.text_offsets() {
3698 Some(o) => o,
3699 None => continue,
3700 };
3701
3702 if pred_offsets == gold_offsets {
3704 if pred.label == g.label {
3705 matches.push(EvalMatch::Correct {
3706 gold_id: g.id,
3707 pred_id: pred.id,
3708 });
3709 } else {
3710 matches.push(EvalMatch::TypeMismatch {
3711 gold_id: g.id,
3712 pred_id: pred.id,
3713 gold_label: g.label.to_string(),
3714 pred_label: pred.label.to_string(),
3715 });
3716 }
3717 gold_matched[gi] = true;
3718 pred_matched[pi] = true;
3719 break;
3720 }
3721 }
3722 }
3723
3724 for (pi, pred) in predicted.iter().enumerate() {
3726 if pred_matched[pi] {
3727 continue;
3728 }
3729 let pred_offsets = match pred.location.text_offsets() {
3730 Some(o) => o,
3731 None => continue,
3732 };
3733
3734 for (gi, g) in gold.iter().enumerate() {
3735 if gold_matched[gi] {
3736 continue;
3737 }
3738 let gold_offsets = match g.location.text_offsets() {
3739 Some(o) => o,
3740 None => continue,
3741 };
3742
3743 if pred_offsets.0 < gold_offsets.1 && pred_offsets.1 > gold_offsets.0 {
3745 let iou = pred.location.iou(&g.location).unwrap_or(0.0);
3746 matches.push(EvalMatch::BoundaryError {
3747 gold_id: g.id,
3748 pred_id: pred.id,
3749 iou,
3750 });
3751 gold_matched[gi] = true;
3752 pred_matched[pi] = true;
3753 break;
3754 }
3755 }
3756 }
3757
3758 for (pi, pred) in predicted.iter().enumerate() {
3760 if !pred_matched[pi] {
3761 matches.push(EvalMatch::Spurious { pred_id: pred.id });
3762 }
3763 }
3764
3765 for (gi, g) in gold.iter().enumerate() {
3767 if !gold_matched[gi] {
3768 matches.push(EvalMatch::Missed { gold_id: g.id });
3769 }
3770 }
3771
3772 Self {
3773 text: text.to_string(),
3774 gold,
3775 predicted,
3776 matches,
3777 }
3778 }
3779
3780 #[must_use]
3782 pub fn correct_count(&self) -> usize {
3783 self.matches
3784 .iter()
3785 .filter(|m| matches!(m, EvalMatch::Correct { .. }))
3786 .count()
3787 }
3788
3789 #[must_use]
3791 pub fn error_count(&self) -> usize {
3792 self.matches.len() - self.correct_count()
3793 }
3794
3795 #[must_use]
3797 pub fn precision(&self) -> f64 {
3798 if self.predicted.is_empty() {
3799 0.0
3800 } else {
3801 self.correct_count() as f64 / self.predicted.len() as f64
3802 }
3803 }
3804
3805 #[must_use]
3807 pub fn recall(&self) -> f64 {
3808 if self.gold.is_empty() {
3809 0.0
3810 } else {
3811 self.correct_count() as f64 / self.gold.len() as f64
3812 }
3813 }
3814
3815 #[must_use]
3817 pub fn f1(&self) -> f64 {
3818 let p = self.precision();
3819 let r = self.recall();
3820 if p + r > 0.0 {
3821 2.0 * p * r / (p + r)
3822 } else {
3823 0.0
3824 }
3825 }
3826}
3827
3828pub fn render_eval_html(cmp: &EvalComparison) -> String {
3832 render_eval_html_with_title(cmp, "eval comparison")
3833}
3834
3835#[must_use]
3839pub fn render_eval_html_with_title(cmp: &EvalComparison, title: &str) -> String {
3840 let mut html = String::new();
3841 let title = html_escape(title);
3842
3843 html.push_str(
3844 r#"<!DOCTYPE html>
3845<html>
3846<head>
3847<meta charset="UTF-8">
3848<meta name="color-scheme" content="dark light">
3849"#,
3850 );
3851 html.push_str(&format!("<title>{}</title>", title));
3852 html.push_str(r#"
3853:root{
3854 color-scheme: light dark;
3855 --bg:#0a0a0a;
3856 --panel-bg:#0d0d0d;
3857 --text:#b0b0b0;
3858 --text-strong:#fff;
3859 --muted:#666;
3860 --border:#222;
3861 --border-strong:#333;
3862 --hover:#111;
3863 --input-bg:#080808;
3864 --active:#ddd;
3865 /* Eval entity colors (dark) */
3866 --gold-bg:#1a2e1a; --gold-br:#4a8a4a; --gold-tx:#88cc88;
3867 --pred-bg:#1a1a2e; --pred-br:#4a4a8a; --pred-tx:#8888cc;
3868 /* Match row borders */
3869 --m-ok:#4a8a4a;
3870 --m-type:#8a8a4a;
3871 --m-bound:#4a8a8a;
3872 --m-fp:#8a4a4a;
3873 --m-fn:#8a4a8a;
3874}
3875@media (prefers-color-scheme: light){
3876 :root{
3877 --bg:#ffffff;
3878 --panel-bg:#f7f7f7;
3879 --text:#222;
3880 --text-strong:#000;
3881 --muted:#555;
3882 --border:#d6d6d6;
3883 --border-strong:#c6c6c6;
3884 --hover:#f0f0f0;
3885 --input-bg:#ffffff;
3886 --active:#000;
3887 --gold-bg:#e9f7e9; --gold-br:#2f8a2f; --gold-tx:#1f5a1f;
3888 --pred-bg:#e9e9ff; --pred-br:#6c6cff; --pred-tx:#2b2b7a;
3889 --m-ok:#2f8a2f;
3890 --m-type:#8a7a2f;
3891 --m-bound:#2f7a8a;
3892 --m-fp:#8a2f2f;
3893 --m-fn:#6a2f8a;
3894 }
3895}
3896html[data-theme='dark']{
3897 --bg:#0a0a0a; --panel-bg:#0d0d0d; --text:#b0b0b0; --text-strong:#fff;
3898 --muted:#666; --border:#222; --border-strong:#333; --hover:#111; --input-bg:#080808; --active:#ddd;
3899 --gold-bg:#1a2e1a; --gold-br:#4a8a4a; --gold-tx:#88cc88;
3900 --pred-bg:#1a1a2e; --pred-br:#4a4a8a; --pred-tx:#8888cc;
3901 --m-ok:#4a8a4a; --m-type:#8a8a4a; --m-bound:#4a8a8a; --m-fp:#8a4a4a; --m-fn:#8a4a8a;
3902}
3903html[data-theme='light']{
3904 --bg:#ffffff; --panel-bg:#f7f7f7; --text:#222; --text-strong:#000;
3905 --muted:#555; --border:#d6d6d6; --border-strong:#c6c6c6; --hover:#f0f0f0; --input-bg:#ffffff; --active:#000;
3906 --gold-bg:#e9f7e9; --gold-br:#2f8a2f; --gold-tx:#1f5a1f;
3907 --pred-bg:#e9e9ff; --pred-br:#6c6cff; --pred-tx:#2b2b7a;
3908 --m-ok:#2f8a2f; --m-type:#8a7a2f; --m-bound:#2f7a8a; --m-fp:#8a2f2f; --m-fn:#6a2f8a;
3909}
3910
3911<style>
3912*{box-sizing:border-box;margin:0;padding:0}
3913body{font:12px/1.4 monospace;background:var(--bg);color:var(--text);padding:8px}
3914h1,h2{color:var(--text-strong);font-weight:normal;border-bottom:1px solid var(--border-strong);padding:4px 0;margin:16px 0 8px}
3915h1{font-size:14px}h2{font-size:12px}
3916table{width:100%;border-collapse:collapse;font-size:11px;margin:4px 0}
3917th,td{padding:4px 8px;text-align:left;border:1px solid var(--border)}
3918th{background:var(--hover);color:var(--muted);font-weight:normal;text-transform:uppercase;font-size:10px}
3919tr:hover{background:var(--hover)}
3920.grid{display:grid;grid-template-columns:1fr 1fr;gap:8px}
3921.panel{border:1px solid var(--border);background:var(--panel-bg);padding:8px}
3922.text-box{background:var(--input-bg);border:1px solid var(--border);padding:8px;white-space:pre-wrap;word-break:break-word;line-height:1.6}
3923.stats{display:flex;gap:24px;padding:8px 0;border-bottom:1px solid var(--border);margin-bottom:8px}
3924.stat{text-align:center}.stat-v{font-size:18px;color:var(--text-strong)}.stat-l{font-size:9px;color:var(--muted);text-transform:uppercase}
3925/* Entities */
3926.e{padding:1px 2px;border-bottom:2px solid}
3927.seg{cursor:pointer}
3928.e-gold{background:var(--gold-bg);border-color:var(--gold-br);color:var(--gold-tx)}
3929.e-pred{background:var(--pred-bg);border-color:var(--pred-br);color:var(--pred-tx)}
3930.e-active{outline:1px solid var(--active);outline-offset:1px}
3931/* Match types */
3932.correct{background:#1a2e1a;border-color:#4a8a4a}
3933.type-err{background:#2e2e1a;border-color:#8a8a4a}
3934.boundary{background:#1a2e2e;border-color:#4a8a8a}
3935.spurious{background:#2e1a1a;border-color:#8a4a4a}
3936.missed{background:#2e1a2e;border-color:#8a4a8a}
3937.match-row.correct{border-left:3px solid var(--m-ok)}
3938.match-row.type-err{border-left:3px solid var(--m-type)}
3939.match-row.boundary{border-left:3px solid var(--m-bound)}
3940.match-row.spurious{border-left:3px solid var(--m-fp)}
3941.match-row.missed{border-left:3px solid var(--m-fn)}
3942.match-row.active{outline:1px solid var(--muted)}
3943.sel{color:var(--muted);margin:6px 0 12px}
3944.metric{font-size:14px;color:var(--muted)}.metric b{color:var(--text-strong)}
3945</style>
3946</head>
3947<body>
3948"#);
3949
3950 html.push_str(&format!(
3952 "<div class=\"panel-h\" style=\"justify-content:space-between\"><h1>{}</h1><span class=\"toggle\" id=\"theme-toggle\" title=\"toggle theme (auto → dark → light)\">theme: auto</span></div>",
3953 title
3954 ));
3955
3956 html.push_str("<div class=\"stats\">");
3958 html.push_str(&format!(
3959 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">gold</div></div>",
3960 cmp.gold.len()
3961 ));
3962 html.push_str(&format!(
3963 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">predicted</div></div>",
3964 cmp.predicted.len()
3965 ));
3966 html.push_str(&format!(
3967 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">correct</div></div>",
3968 cmp.correct_count()
3969 ));
3970 html.push_str(&format!(
3971 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">errors</div></div>",
3972 cmp.error_count()
3973 ));
3974 html.push_str(&format!(
3975 "<div class=\"metric\">P=<b>{:.1}%</b> R=<b>{:.1}%</b> F1=<b>{:.1}%</b></div>",
3976 cmp.precision() * 100.0,
3977 cmp.recall() * 100.0,
3978 cmp.f1() * 100.0
3979 ));
3980 html.push_str("</div>");
3981
3982 html.push_str("<div id=\"selection\" class=\"sel\">click a match row to select spans</div>");
3984
3985 html.push_str("<div class=\"grid\">");
3987
3988 html.push_str("<div class=\"panel\"><h2>gold (ground truth)</h2><div class=\"text-box\">");
3990 let gold_spans: Vec<EvalHtmlSpan> = cmp
3991 .gold
3992 .iter()
3993 .map(|s| {
3994 let (start, end) = s.location.text_offsets().unwrap_or((0, 0));
3995 EvalHtmlSpan {
3996 start,
3997 end,
3998 label: s.label.to_string(),
3999 class: "e-gold",
4000 id: format!("G{}", s.id),
4001 }
4002 })
4003 .collect();
4004 html.push_str(&annotate_text_spans(&cmp.text, &gold_spans));
4005 html.push_str("</div></div>");
4006
4007 html.push_str("<div class=\"panel\"><h2>predicted</h2><div class=\"text-box\">");
4009 let pred_spans: Vec<EvalHtmlSpan> = cmp
4010 .predicted
4011 .iter()
4012 .map(|s| {
4013 let (start, end) = s.location.text_offsets().unwrap_or((0, 0));
4014 EvalHtmlSpan {
4015 start,
4016 end,
4017 label: s.label.to_string(),
4018 class: "e-pred",
4019 id: format!("P{}", s.id),
4020 }
4021 })
4022 .collect();
4023 html.push_str(&annotate_text_spans(&cmp.text, &pred_spans));
4024 html.push_str("</div></div>");
4025
4026 html.push_str("</div>");
4027
4028 html.push_str("<h2>matches</h2><table>");
4030 html.push_str("<tr><th>type</th><th>gold</th><th>predicted</th><th>notes</th></tr>");
4031
4032 for (mi, m) in cmp.matches.iter().enumerate() {
4033 let (class, mtype, gold_text, pred_text, notes, gid, pid) = match m {
4034 EvalMatch::Correct { gold_id, pred_id } => {
4035 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4036 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4037 (
4038 "correct",
4039 "✓",
4040 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4041 .unwrap_or_default(),
4042 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4043 .unwrap_or_default(),
4044 String::new(),
4045 Some(format!("G{}", gold_id)),
4046 Some(format!("P{}", pred_id)),
4047 )
4048 }
4049 EvalMatch::TypeMismatch {
4050 gold_id,
4051 pred_id,
4052 gold_label,
4053 pred_label,
4054 } => {
4055 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4056 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4057 (
4058 "type-err",
4059 "type",
4060 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4061 .unwrap_or_default(),
4062 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4063 .unwrap_or_default(),
4064 format!("{} → {}", gold_label, pred_label),
4065 Some(format!("G{}", gold_id)),
4066 Some(format!("P{}", pred_id)),
4067 )
4068 }
4069 EvalMatch::BoundaryError {
4070 gold_id,
4071 pred_id,
4072 iou,
4073 } => {
4074 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4075 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4076 (
4077 "boundary",
4078 "bound",
4079 g.map(|s| format!("[{}] \"{}\"", s.label, s.surface()))
4080 .unwrap_or_default(),
4081 p.map(|s| format!("[{}] \"{}\"", s.label, s.surface()))
4082 .unwrap_or_default(),
4083 format!("IoU={:.2}", iou),
4084 Some(format!("G{}", gold_id)),
4085 Some(format!("P{}", pred_id)),
4086 )
4087 }
4088 EvalMatch::Spurious { pred_id } => {
4089 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4090 (
4091 "spurious",
4092 "FP",
4093 String::new(),
4094 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4095 .unwrap_or_default(),
4096 "false positive".to_string(),
4097 None,
4098 Some(format!("P{}", pred_id)),
4099 )
4100 }
4101 EvalMatch::Missed { gold_id } => {
4102 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4103 (
4104 "missed",
4105 "FN",
4106 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4107 .unwrap_or_default(),
4108 String::new(),
4109 "false negative".to_string(),
4110 Some(format!("G{}", gold_id)),
4111 None,
4112 )
4113 }
4114 };
4115
4116 let mut data_attrs = String::new();
4117 if let Some(gid) = gid.as_deref() {
4118 data_attrs.push_str(&format!(" data-gid=\"{}\"", html_escape(gid)));
4119 }
4120 if let Some(pid) = pid.as_deref() {
4121 data_attrs.push_str(&format!(" data-pid=\"{}\"", html_escape(pid)));
4122 }
4123
4124 html.push_str(&format!(
4125 "<tr id=\"M{mid}\" class=\"match-row {class}\"{attrs}><td><a class=\"match-link\" href=\"#M{mid}\">{mtype}</a></td><td>{gold}</td><td>{pred}</td><td>{notes}</td></tr>",
4126 mid = mi,
4127 class = class,
4128 attrs = data_attrs,
4129 mtype = html_escape(mtype),
4130 gold = html_escape(&gold_text),
4131 pred = html_escape(&pred_text),
4132 notes = html_escape(¬es)
4133 ));
4134 }
4135 html.push_str("</table>");
4136
4137 html.push_str(
4138 r#"<script>
4139(() => {
4140 // Theme toggle: auto (prefers-color-scheme) → dark → light.
4141 const themeBtn = document.getElementById('theme-toggle');
4142 const themeKey = 'anno-theme';
4143 const applyTheme = (theme) => {
4144 const t = theme || 'auto';
4145 if (t === 'auto') {
4146 delete document.documentElement.dataset.theme;
4147 } else {
4148 document.documentElement.dataset.theme = t;
4149 }
4150 if (themeBtn) themeBtn.textContent = `theme: ${t}`;
4151 };
4152 const readTheme = () => {
4153 try { return localStorage.getItem(themeKey) || 'auto'; } catch (_) { return 'auto'; }
4154 };
4155 const writeTheme = (t) => {
4156 try { localStorage.setItem(themeKey, t); } catch (_) { /* ignore */ }
4157 };
4158 applyTheme(readTheme());
4159 if (themeBtn) {
4160 themeBtn.addEventListener('click', () => {
4161 const cur = readTheme();
4162 const next = cur === 'auto' ? 'dark' : (cur === 'dark' ? 'light' : 'auto');
4163 writeTheme(next);
4164 applyTheme(next);
4165 });
4166 }
4167
4168 function clearActive() {
4169 document.querySelectorAll(".e-active").forEach((el) => el.classList.remove("e-active"));
4170 document.querySelectorAll("tr.match-row.active").forEach((el) => el.classList.remove("active"));
4171 }
4172
4173 function findSpanEls(eid) {
4174 if (!eid) return [];
4175 // New segmented renderer: one span can be split across multiple elements.
4176 const els = Array.from(document.querySelectorAll(`span.e[data-eids~='${eid}']`));
4177 if (els.length) return els;
4178 // Back-compat: older HTML used a single element id.
4179 const single = document.getElementById(eid);
4180 return single ? [single] : [];
4181 }
4182
4183 function activate(gid, pid, row) {
4184 clearActive();
4185 const gEls = findSpanEls(gid);
4186 const pEls = findSpanEls(pid);
4187 const sel = document.getElementById("selection");
4188 gEls.forEach((el) => el.classList.add("e-active"));
4189 pEls.forEach((el) => el.classList.add("e-active"));
4190 if (row) row.classList.add("active");
4191 if (sel) {
4192 const parts = [];
4193 if (gEls.length) {
4194 const lbl = gEls[0].dataset && gEls[0].dataset.label ? ` [${gEls[0].dataset.label}]` : "";
4195 parts.push(`gold ${gid}${lbl}`);
4196 }
4197 if (pEls.length) {
4198 const lbl = pEls[0].dataset && pEls[0].dataset.label ? ` [${pEls[0].dataset.label}]` : "";
4199 parts.push(`pred ${pid}${lbl}`);
4200 }
4201 sel.textContent = parts.length ? parts.join(" | ") : "no selection";
4202 }
4203 if (row && row.id) {
4204 // Keep deep links stable without triggering navigation jump.
4205 // NOTE: single quotes avoid the Rust raw-string delimiter issue with quote+hash.
4206 history.replaceState(null, "", '#' + row.id);
4207 }
4208 const target = gEls[0] || pEls[0];
4209 if (target) target.scrollIntoView({ behavior: "smooth", block: "center" });
4210 }
4211
4212 document.querySelectorAll("tr.match-row[data-gid], tr.match-row[data-pid]").forEach((tr) => {
4213 tr.addEventListener("click", () => activate(tr.dataset.gid, tr.dataset.pid, tr));
4214 });
4215
4216 document.querySelectorAll("a.match-link").forEach((a) => {
4217 a.addEventListener("click", (ev) => {
4218 ev.preventDefault();
4219 const tr = a.closest("tr.match-row");
4220 if (!tr) return;
4221 activate(tr.dataset.gid, tr.dataset.pid, tr);
4222 });
4223 });
4224
4225 // Auto-select a match row if the URL has a deep link (e.g. #M12).
4226 const hash = (location.hash || "").slice(1);
4227 if (hash && hash.startsWith("M")) {
4228 const tr = document.getElementById(hash);
4229 if (tr && tr.classList && tr.classList.contains("match-row")) {
4230 activate(tr.dataset.gid, tr.dataset.pid, tr);
4231 }
4232 }
4233})();
4234</script>"#,
4235 );
4236
4237 html.push_str("</body></html>");
4238 html
4239}
4240
/// One highlightable span handed to `annotate_text_spans`.
///
/// `start`/`end` are character offsets into the rendered text (they are
/// clamped against `text.chars().count()` by the renderer) and describe the
/// half-open range `[start, end)`.
#[derive(Debug, Clone)]
struct EvalHtmlSpan {
    /// First character covered by the span (inclusive).
    start: usize,
    /// One past the last character covered by the span (exclusive).
    end: usize,
    /// Label emitted in the `data-label` attribute of rendered segments.
    label: String,
    /// CSS class appended to `e seg` on rendered segments.
    class: &'static str,
    /// Identifier listed in the `data-eids` attribute of rendered segments.
    id: String,
}
4250
4251fn annotate_text_spans(text: &str, spans: &[EvalHtmlSpan]) -> String {
4252 let char_count = text.chars().count();
4253 if char_count == 0 || spans.is_empty() {
4254 return html_escape(text);
4255 }
4256
4257 #[derive(Debug, Clone)]
4258 struct Meta {
4259 id: String,
4260 label: String,
4261 class: &'static str,
4262 len: usize,
4263 }
4264 #[derive(Debug, Clone)]
4265 struct Event {
4266 pos: usize,
4267 meta_idx: usize,
4268 delta: i32,
4269 }
4270
4271 let mut metas: Vec<Meta> = Vec::with_capacity(spans.len());
4272 let mut events: Vec<Event> = Vec::new();
4273 let mut boundaries: Vec<usize> = vec![0, char_count];
4274
4275 for s in spans {
4276 let start = s.start.min(char_count);
4277 let end = s.end.min(char_count);
4278 if start >= end {
4279 continue;
4280 }
4281 let meta_idx = metas.len();
4282 metas.push(Meta {
4283 id: s.id.clone(),
4284 label: s.label.to_string(),
4285 class: s.class,
4286 len: end - start,
4287 });
4288 boundaries.push(start);
4289 boundaries.push(end);
4290 events.push(Event {
4291 pos: start,
4292 meta_idx,
4293 delta: 1,
4294 });
4295 events.push(Event {
4296 pos: end,
4297 meta_idx,
4298 delta: -1,
4299 });
4300 }
4301
4302 if metas.is_empty() {
4303 return html_escape(text);
4304 }
4305
4306 boundaries.sort_unstable();
4307 boundaries.dedup();
4308 events.sort_by(|a, b| a.pos.cmp(&b.pos).then_with(|| a.delta.cmp(&b.delta)));
4309
4310 let mut active_counts: Vec<u32> = vec![0; metas.len()];
4311 let mut active: Vec<usize> = Vec::new();
4312 let mut ev_idx = 0usize;
4313 let mut result = String::new();
4314
4315 for bi in 0..boundaries.len().saturating_sub(1) {
4316 let pos = boundaries[bi];
4317 while ev_idx < events.len() && events[ev_idx].pos == pos {
4318 let e = &events[ev_idx];
4319 let idx = e.meta_idx;
4320 if e.delta < 0 {
4321 if active_counts[idx] > 0 {
4322 active_counts[idx] -= 1;
4323 if active_counts[idx] == 0 {
4324 active.retain(|&x| x != idx);
4325 }
4326 }
4327 } else {
4328 active_counts[idx] += 1;
4329 if active_counts[idx] == 1 {
4330 active.push(idx);
4331 }
4332 }
4333 ev_idx += 1;
4334 }
4335
4336 let next = boundaries[bi + 1];
4337 if next <= pos {
4338 continue;
4339 }
4340
4341 let seg_text: String = text.chars().skip(pos).take(next - pos).collect();
4342 if active.is_empty() {
4343 result.push_str(&html_escape(&seg_text));
4344 continue;
4345 }
4346
4347 let primary_idx = active
4348 .iter()
4349 .copied()
4350 .min_by_key(|i| metas[*i].len)
4351 .unwrap_or(active[0]);
4352 let primary = &metas[primary_idx];
4353 let mut eids: Vec<&str> = active.iter().map(|i| metas[*i].id.as_str()).collect();
4354 eids.sort_unstable();
4355 let data_eids = eids.join(" ");
4356
4357 let title = format!(
4358 "eids=[{}] primary={} [{}..{})",
4359 data_eids, primary.id, pos, next
4360 );
4361 result.push_str(&format!(
4362 "<span class=\"e seg {class}\" data-eids=\"{eids}\" data-label=\"{label}\" data-start=\"{start}\" data-end=\"{end}\" title=\"{title}\">{text}</span>",
4363 class = primary.class,
4364 eids = html_escape(&data_eids),
4365 label = html_escape(&primary.label),
4366 start = pos,
4367 end = next,
4368 title = html_escape(&title),
4369 text = html_escape(&seg_text)
4370 ));
4371 }
4372
4373 result
4374}
4375
/// Options for text processing.
///
/// `Default` yields an empty label list and a threshold of `0.0`.
#[derive(Debug, Clone, Default)]
pub struct ProcessOptions {
    /// NOTE(review): presumably the labels to extract; confirm against the caller.
    pub labels: Vec<String>,
    /// NOTE(review): presumably a minimum confidence cut-off; confirm against the caller.
    pub threshold: f32,
}
4388
/// Outcome of processing a text: the produced document plus validation data.
#[derive(Debug)]
pub struct ProcessResult {
    /// The grounded document that was produced.
    pub document: GroundedDocument,
    /// NOTE(review): presumably `true` when `errors` is empty; confirm against the producer.
    pub valid: bool,
    /// Signal validation errors collected for the document.
    pub errors: Vec<SignalValidationError>,
}

impl ProcessResult {
    /// Renders the contained document as HTML via `render_document_html`.
    #[must_use]
    pub fn to_html(&self) -> String {
        render_document_html(&self.document)
    }
}
4407
/// API placeholder only; the real entry point is `anno::process_text`.
///
/// # Panics
///
/// Always panics with `unimplemented!`; both arguments are ignored.
#[allow(dead_code)]
#[doc(hidden)]
pub fn process_text(
    _text: &str,
    _model: Option<&dyn std::any::Any>,
) -> super::Result<ProcessResult> {
    unimplemented!("Use anno::process_text instead - this stub documents the API only")
}
4435
/// A collection of grounded documents together with the identities shared
/// across them.
#[derive(Debug, Clone)]
pub struct Corpus {
    /// Documents keyed by a string id.
    documents: std::collections::HashMap<String, GroundedDocument>,
    /// Corpus-level identities keyed by their id.
    identities: std::collections::HashMap<IdentityId, Identity>,
    /// Id that will be handed to the next identity created by the corpus.
    next_identity_id: IdentityId,
}
4459
4460impl Corpus {
4461 #[must_use]
4463 pub fn new() -> Self {
4464 Self {
4465 documents: std::collections::HashMap::new(),
4466 identities: std::collections::HashMap::new(),
4467 next_identity_id: IdentityId::ZERO,
4468 }
4469 }
4470
4471 #[must_use]
4473 pub fn identities(&self) -> &std::collections::HashMap<IdentityId, Identity> {
4474 &self.identities
4475 }
4476
4477 #[must_use]
4479 pub fn get_identity(&self, id: IdentityId) -> Option<&Identity> {
4480 self.identities.get(&id)
4481 }
4482
4483 pub fn add_identity(&mut self, mut identity: Identity) -> IdentityId {
4488 let id = self.next_identity_id;
4489 identity.id = id;
4490 self.identities.insert(id, identity);
4491 self.next_identity_id += 1;
4492 id
4493 }
4494
4495 #[must_use]
4499 pub fn next_identity_id(&self) -> IdentityId {
4500 self.next_identity_id
4501 }
4502
4503 pub fn documents(&self) -> impl Iterator<Item = &GroundedDocument> {
4507 self.documents.values()
4508 }
4509
4510 #[must_use]
4514 pub fn get_document(&self, doc_id: &str) -> Option<&GroundedDocument> {
4515 self.documents.get(doc_id)
4516 }
4517
4518 pub fn get_document_mut(&mut self, doc_id: &str) -> Option<&mut GroundedDocument> {
4522 self.documents.get_mut(doc_id)
4523 }
4524
4525 pub fn add_document(&mut self, document: GroundedDocument) -> String {
4530 let doc_id = document.id.clone();
4531 self.documents.insert(doc_id.clone(), document);
4532 doc_id
4533 }
4534
4535 pub fn link_track_to_kb(
4557 &mut self,
4558 track_ref: &TrackRef,
4559 kb_name: impl Into<String>,
4560 kb_id: impl Into<String>,
4561 canonical_name: impl Into<String>,
4562 ) -> super::Result<IdentityId> {
4563 use super::error::Error;
4564
4565 let doc = self.documents.get_mut(&track_ref.doc_id).ok_or_else(|| {
4566 Error::track_ref(format!(
4567 "Document '{}' not found in corpus",
4568 track_ref.doc_id
4569 ))
4570 })?;
4571 let track = doc.get_track(track_ref.track_id).ok_or_else(|| {
4572 Error::track_ref(format!(
4573 "Track {} not found in document '{}'",
4574 track_ref.track_id, track_ref.doc_id
4575 ))
4576 })?;
4577
4578 let kb_name_str = kb_name.into();
4579 let kb_id_str = kb_id.into();
4580 let canonical_name_str = canonical_name.into();
4581
4582 let identity_id = if let Some(existing_id) = track.identity_id {
4584 if let Some(identity) = self.identities.get_mut(&existing_id) {
4586 identity.kb_id = Some(kb_id_str.clone());
4587 identity.kb_name = Some(kb_name_str.clone());
4588 identity.canonical_name = canonical_name_str.clone();
4589
4590 identity.source = Some(match identity.source.take() {
4592 Some(IdentitySource::CrossDocCoref { track_refs }) => IdentitySource::Hybrid {
4593 track_refs,
4594 kb_name: kb_name_str.clone(),
4595 kb_id: kb_id_str.clone(),
4596 },
4597 _ => IdentitySource::KnowledgeBase {
4598 kb_name: kb_name_str.clone(),
4599 kb_id: kb_id_str.clone(),
4600 },
4601 });
4602
4603 existing_id
4604 } else {
4605 let new_id = self.next_identity_id;
4613 self.next_identity_id += 1;
4614
4615 let identity = Identity {
4616 id: new_id,
4617 canonical_name: canonical_name_str,
4618 entity_type: track.entity_type.clone(),
4619 kb_id: Some(kb_id_str.clone()),
4620 kb_name: Some(kb_name_str.clone()),
4621 description: None,
4622 embedding: track.embedding.clone(),
4623 aliases: Vec::new(),
4624 confidence: track.cluster_confidence,
4625 source: Some(IdentitySource::KnowledgeBase {
4626 kb_name: kb_name_str,
4627 kb_id: kb_id_str,
4628 }),
4629 };
4630
4631 self.identities.insert(new_id, identity);
4632 doc.link_track_to_identity(track_ref.track_id, new_id);
4635 new_id
4636 }
4637 } else {
4638 let new_id = self.next_identity_id;
4640 self.next_identity_id += 1;
4641
4642 let identity = Identity {
4643 id: new_id,
4644 canonical_name: canonical_name_str,
4645 entity_type: track.entity_type.clone(),
4646 kb_id: Some(kb_id_str.clone()),
4647 kb_name: Some(kb_name_str.clone()),
4648 description: None,
4649 embedding: track.embedding.clone(),
4650 aliases: Vec::new(),
4651 confidence: track.cluster_confidence,
4652 source: Some(IdentitySource::KnowledgeBase {
4653 kb_name: kb_name_str,
4654 kb_id: kb_id_str,
4655 }),
4656 };
4657
4658 self.identities.insert(new_id, identity);
4659 doc.link_track_to_identity(track_ref.track_id, new_id);
4660 new_id
4661 };
4662
4663 Ok(identity_id)
4664 }
4665}
4666
/// The default corpus is the empty corpus produced by `Corpus::new`.
impl Default for Corpus {
    fn default() -> Self {
        Self::new()
    }
}
4672
4673#[cfg(test)]
4674mod tests {
4675 #![allow(clippy::unwrap_used)] use super::*;
4677
4678 #[test]
4679 fn test_render_eval_html_has_interactive_hooks_and_is_unicode_safe() {
4680 let text = "習近平在北京會見了普京。";
4682
4683 let gold: Vec<Signal<Location>> = vec![
4684 Signal::new(SignalId::new(0), Location::text(0, 3), "習近平", "PER", 1.0),
4685 Signal::new(SignalId::new(1), Location::text(4, 6), "北京", "LOC", 1.0),
4686 ];
4687
4688 let predicted: Vec<Signal<Location>> = vec![
4690 Signal::new(SignalId::new(0), Location::text(0, 3), "習近平", "PER", 0.9),
4691 Signal::new(SignalId::new(1), Location::text(4, 6), "北京", "PER", 0.7),
4692 ];
4693
4694 let cmp = EvalComparison::compare(text, gold, predicted);
4695 let html = render_eval_html_with_title(&cmp, "test");
4696
4697 assert!(html.contains("id=\"selection\""));
4699
4700 assert!(html.contains("data-eids=\"G0\""));
4702 assert!(html.contains("data-eids=\"P0\""));
4703
4704 assert!(html.contains("class=\"match-link\""));
4706 assert!(html.contains("href=\"#M0\""));
4707 assert!(html.contains("data-gid=\"G0\""));
4708 assert!(html.contains("data-pid=\"P0\""));
4709
4710 assert!(html.contains("北京"));
4712 }
4713
4714 fn find_char_span(text: &str, needle: &str) -> Option<(usize, usize)> {
4715 let hay: Vec<char> = text.chars().collect();
4716 let pat: Vec<char> = needle.chars().collect();
4717 if pat.is_empty() || hay.len() < pat.len() {
4718 return None;
4719 }
4720 for i in 0..=(hay.len() - pat.len()) {
4721 if hay[i..(i + pat.len())] == pat[..] {
4722 return Some((i, i + pat.len()));
4723 }
4724 }
4725 None
4726 }
4727
4728 #[test]
4729 fn test_annotate_text_html_supports_overlaps_discontinuous_and_unicode() {
4730 let text = "Marie Curie met Cher in Paris. 習近平在北京會見了普京。 \
4732التقى محمد بن سلمان في الرياض. Путин встретился с Си Цзиньпином в Москве. \
4733प्रधान मंत्री शर्मा दिल्ली में मिले। severe pain ... in abdomen.";
4734
4735 let (m0s, m0e) = find_char_span(text, "Marie Curie").unwrap();
4737 let (m1s, m1e) = find_char_span(text, "Curie").unwrap();
4738
4739 let pain = find_char_span(text, "pain").unwrap();
4741 let abdomen = find_char_span(text, "abdomen").unwrap();
4742
4743 let signals: Vec<Signal<Location>> = vec![
4744 Signal::new(
4745 SignalId::new(0),
4746 Location::text(m0s, m0e),
4747 "Marie Curie",
4748 "PER",
4749 0.9,
4750 ),
4751 Signal::new(
4752 SignalId::new(1),
4753 Location::text(m1s, m1e),
4754 "Curie",
4755 "PER",
4756 0.8,
4757 ),
4758 Signal::new(
4759 SignalId::new(2),
4760 Location::Discontinuous {
4761 segments: vec![pain, abdomen],
4762 },
4763 "pain … abdomen",
4764 "SYMPTOM",
4765 0.7,
4766 ),
4767 ];
4768
4769 let html = annotate_text_html(text, &signals, &std::collections::HashMap::new());
4770
4771 assert!(html.contains("data-sids=\"S0 S1\"") || html.contains("data-sids=\"S1 S0\""));
4773
4774 assert!(html.contains("data-sids=\"S2\""));
4776
4777 assert!(html.contains("北京"));
4779 assert!(html.contains("Москве"));
4780 assert!(html.contains("शर्मा"));
4781 assert!(html.contains("محمد"));
4782 }
4783
4784 #[test]
4785 fn test_location_text_iou() {
4786 let l1 = Location::text(0, 10);
4787 let l2 = Location::text(5, 15);
4788 let iou = l1.iou(&l2).unwrap();
4789 assert!((iou - 0.333).abs() < 0.01);
4793 }
4794
4795 #[test]
4796 fn test_location_bbox_iou() {
4797 let b1 = Location::bbox(0.0, 0.0, 0.5, 0.5);
4798 let b2 = Location::bbox(0.25, 0.25, 0.5, 0.5);
4799 let iou = b1.iou(&b2).unwrap();
4800 assert!((iou - 0.143).abs() < 0.01);
4804 }
4805
4806 #[test]
4807 fn test_location_different_types_no_iou() {
4808 let text = Location::text(0, 10);
4809 let bbox = Location::bbox(0.0, 0.0, 0.5, 0.5);
4810 assert!(text.iou(&bbox).is_none());
4811 }
4812
4813 #[test]
4814 fn test_signal_creation() {
4815 let signal: Signal<Location> =
4816 Signal::new(0, Location::text(0, 11), "Marie Curie", "Person", 0.95);
4817 assert_eq!(signal.surface, "Marie Curie");
4818 assert_eq!(signal.label, "Person".into());
4819 assert!((signal.confidence - 0.95).abs() < 0.001);
4820 assert!(!signal.negated);
4821 }
4822
4823 #[test]
4824 fn test_signal_with_linguistic_features() {
4825 let signal: Signal<Location> =
4826 Signal::new(0, Location::text(0, 10), "not a doctor", "Occupation", 0.8)
4827 .negated()
4828 .with_quantifier(Quantifier::Existential)
4829 .with_modality(Modality::Symbolic);
4830
4831 assert!(signal.negated);
4832 assert_eq!(signal.quantifier, Some(Quantifier::Existential));
4833 assert!(signal.modality.supports_linguistic_features());
4834 }
4835
4836 #[test]
4837 fn test_track_formation() {
4838 let mut track = Track::new(0, "Marie Curie");
4839 track.add_signal(0, 0);
4840 track.add_signal(1, 1);
4841 track.add_signal(2, 2);
4842
4843 assert_eq!(track.len(), 3);
4844 assert!(!track.is_singleton());
4845 assert!(!track.is_empty());
4846 }
4847
4848 #[test]
4849 fn test_identity_creation() {
4850 let identity = Identity::from_kb(0, "Marie Curie", "wikidata", "Q7186")
4851 .with_type("Person")
4852 .with_embedding(vec![0.1, 0.2, 0.3]);
4853
4854 assert_eq!(identity.canonical_name, "Marie Curie");
4855 assert_eq!(identity.kb_id, Some("Q7186".to_string()));
4856 assert_eq!(identity.kb_name, Some("wikidata".to_string()));
4857 assert!(identity.embedding.is_some());
4858 }
4859
4860 #[test]
4861 fn test_grounded_document_hierarchy() {
4862 let mut doc = GroundedDocument::new(
4863 "doc1",
4864 "Marie Curie won the Nobel Prize. She was a physicist.",
4865 );
4866
4867 let s1 = doc.add_signal(Signal::new(
4869 0,
4870 Location::text(0, 12),
4871 "Marie Curie",
4872 "Person",
4873 0.95,
4874 ));
4875 let s2 = doc.add_signal(Signal::new(
4876 1,
4877 Location::text(38, 41),
4878 "She",
4879 "Person",
4880 0.88,
4881 ));
4882 let s3 = doc.add_signal(Signal::new(
4883 2,
4884 Location::text(17, 29),
4885 "Nobel Prize",
4886 "Award",
4887 0.92,
4888 ));
4889
4890 let mut track1 = Track::new(0, "Marie Curie");
4892 track1.add_signal(s1, 0);
4893 track1.add_signal(s2, 1);
4894 let track1_id = doc.add_track(track1);
4895
4896 let mut track2 = Track::new(1, "Nobel Prize");
4897 track2.add_signal(s3, 0);
4898 doc.add_track(track2);
4899
4900 let identity = Identity::from_kb(0, "Marie Curie", "wikidata", "Q7186");
4902 let identity_id = doc.add_identity(identity);
4903 doc.link_track_to_identity(track1_id, identity_id);
4904
4905 assert_eq!(doc.signals().len(), 3);
4907 assert_eq!(doc.tracks().count(), 2);
4908 assert_eq!(doc.identities().count(), 1);
4909
4910 let track = doc.track_for_signal(s1).unwrap();
4912 assert_eq!(track.canonical_surface, "Marie Curie");
4913 assert_eq!(track.len(), 2);
4914
4915 let identity = doc.identity_for_track(track1_id).unwrap();
4917 assert_eq!(identity.kb_id, Some("Q7186".to_string()));
4918
4919 let identity = doc.identity_for_signal(s1).unwrap();
4921 assert_eq!(identity.canonical_name, "Marie Curie");
4922 }
4923
4924 #[test]
4925 fn test_modality_features() {
4926 assert!(Modality::Symbolic.supports_linguistic_features());
4927 assert!(!Modality::Symbolic.supports_geometric_features());
4928
4929 assert!(!Modality::Iconic.supports_linguistic_features());
4930 assert!(Modality::Iconic.supports_geometric_features());
4931
4932 assert!(Modality::Hybrid.supports_linguistic_features());
4933 assert!(Modality::Hybrid.supports_geometric_features());
4934 }
4935
4936 #[test]
4937 fn test_location_from_span() {
4938 let span = Span::Text { start: 0, end: 10 };
4939 let location = Location::from(&span);
4940 assert_eq!(location.text_offsets(), Some((0, 10)));
4941
4942 let span = Span::BoundingBox {
4943 x: 0.1,
4944 y: 0.2,
4945 width: 0.3,
4946 height: 0.4,
4947 page: Some(1),
4948 };
4949 let location = Location::from(&span);
4950 assert!(matches!(location, Location::BoundingBox { .. }));
4951 }
4952
4953 #[test]
4954 fn test_entity_roundtrip() {
4955 use super::EntityType;
4956
4957 let entities = vec![
4958 Entity::new("Marie Curie", EntityType::Person, 0, 12, 0.95),
4959 Entity::new(
4960 "Nobel Prize",
4961 EntityType::Other("Award".to_string()),
4962 17,
4963 29,
4964 0.92,
4965 ),
4966 ];
4967
4968 let doc =
4969 GroundedDocument::from_entities("doc1", "Marie Curie won the Nobel Prize.", &entities);
4970 let converted = doc.to_entities();
4971
4972 assert_eq!(converted.len(), 2);
4973 assert_eq!(converted[0].text, "Marie Curie");
4974 assert_eq!(converted[1].text, "Nobel Prize");
4975 }
4976
4977 #[test]
4978 fn test_signal_confidence_threshold() {
4979 let signal: Signal<Location> = Signal::new(0, Location::text(0, 10), "test", "Type", 0.75);
4980 assert!(signal.is_confident(0.5));
4981 assert!(signal.is_confident(0.75));
4982 assert!(!signal.is_confident(0.8));
4983 }
4984
4985 #[test]
4986 fn test_document_filtering() {
4987 let mut doc = GroundedDocument::new("doc1", "Test text");
4988
4989 doc.add_signal(Signal::new(0, Location::text(0, 4), "high", "Person", 0.95));
4991 doc.add_signal(Signal::new(1, Location::text(5, 8), "low", "Person", 0.3));
4992 doc.add_signal(Signal::new(
4993 2,
4994 Location::text(9, 12),
4995 "org",
4996 "Organization",
4997 0.8,
4998 ));
4999
5000 let confident = doc.confident_signals(0.5);
5002 assert_eq!(confident.len(), 2);
5003
5004 let persons = doc.signals_with_label("Person");
5006 assert_eq!(persons.len(), 2);
5007
5008 let orgs = doc.signals_with_label("Organization");
5009 assert_eq!(orgs.len(), 1);
5010 }
5011
5012 #[test]
5013 fn test_untracked_signals() {
5014 let mut doc = GroundedDocument::new("doc1", "Test");
5015
5016 let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "a", "T", 0.9));
5017 let s2 = doc.add_signal(Signal::new(1, Location::text(5, 8), "b", "T", 0.9));
5018 let _s3 = doc.add_signal(Signal::new(2, Location::text(9, 12), "c", "T", 0.9));
5019
5020 let mut track = Track::new(0, "a");
5022 track.add_signal(s1, 0);
5023 track.add_signal(s2, 1);
5024 doc.add_track(track);
5025
5026 assert_eq!(doc.untracked_signal_count(), 1);
5028 let untracked = doc.untracked_signals();
5029 assert_eq!(untracked.len(), 1);
5030 assert_eq!(untracked[0].surface, "c");
5031 }
5032
5033 #[test]
5034 fn test_linked_unlinked_tracks() {
5035 let mut doc = GroundedDocument::new("doc1", "Test");
5036
5037 let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "a", "T", 0.9));
5038 let s2 = doc.add_signal(Signal::new(1, Location::text(5, 8), "b", "T", 0.9));
5039
5040 let mut track1 = Track::new(0, "a");
5041 track1.add_signal(s1, 0);
5042 let track1_id = doc.add_track(track1);
5043
5044 let mut track2 = Track::new(1, "b");
5045 track2.add_signal(s2, 0);
5046 doc.add_track(track2);
5047
5048 let identity = Identity::new(0, "Entity A");
5050 let identity_id = doc.add_identity(identity);
5051 doc.link_track_to_identity(track1_id, identity_id);
5052
5053 assert_eq!(doc.linked_tracks().count(), 1);
5054 assert_eq!(doc.unlinked_tracks().count(), 1);
5055 }
5056
5057 #[test]
5058 fn test_location_overlaps() {
5059 let l1 = Location::text(0, 10);
5060 let l2 = Location::text(5, 15);
5061 let l3 = Location::text(15, 20);
5062
5063 assert!(l1.overlaps(&l2));
5064 assert!(!l1.overlaps(&l3));
5065 assert!(!l2.overlaps(&l3)); let b1 = Location::bbox(0.0, 0.0, 0.5, 0.5);
5069 let b2 = Location::bbox(0.4, 0.4, 0.5, 0.5);
5070 let b3 = Location::bbox(0.6, 0.6, 0.2, 0.2);
5071
5072 assert!(b1.overlaps(&b2));
5073 assert!(!b1.overlaps(&b3));
5074 }
5075
5076 #[test]
5077 fn test_iou_edge_cases() {
5078 let l1 = Location::text(0, 5);
5080 let l2 = Location::text(10, 15);
5081 assert_eq!(l1.iou(&l2), Some(0.0));
5082
5083 let l3 = Location::text(0, 10);
5085 let l4 = Location::text(0, 10);
5086 assert_eq!(l3.iou(&l4), Some(1.0));
5087
5088 let l5 = Location::text(0, 20);
5090 let l6 = Location::text(5, 15);
5091 let iou = l5.iou(&l6).unwrap();
5092 assert!((iou - 0.5).abs() < 0.001);
5094 }
5095
5096 #[test]
5100 fn test_document_stats() {
5101 let mut doc = GroundedDocument::new("doc1", "Test document with entities.");
5102
5103 let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9));
5105 let mut negated = Signal::new(0, Location::text(5, 13), "document", "Type", 0.8);
5106 negated.negated = true;
5107 let s2 = doc.add_signal(negated);
5108 let _s3 = doc.add_signal(Signal::new(
5109 0,
5110 Location::text(19, 27),
5111 "entities",
5112 "Type",
5113 0.7,
5114 ));
5115
5116 let mut track = Track::new(0, "Test");
5118 track.add_signal(s1, 0);
5119 track.add_signal(s2, 1);
5120 doc.add_track(track);
5121
5122 let identity = Identity::new(0, "Test Entity");
5124 let identity_id = doc.add_identity(identity);
5125 doc.link_track_to_identity(0, identity_id);
5126
5127 let stats = doc.stats();
5128
5129 assert_eq!(stats.signal_count, 3);
5130 assert_eq!(stats.track_count, 1);
5131 assert_eq!(stats.identity_count, 1);
5132 assert_eq!(stats.linked_track_count, 1);
5133 assert_eq!(stats.untracked_count, 1); assert_eq!(stats.negated_count, 1);
5135 assert!((stats.avg_confidence - 0.8).abs() < 0.01); assert!((stats.avg_track_size - 2.0).abs() < 0.01);
5137 }
5138
5139 #[test]
5140 fn test_batch_operations() {
5141 let mut doc = GroundedDocument::new("doc1", "Test document.");
5142
5143 let signals = vec![
5145 Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9),
5146 Signal::new(0, Location::text(5, 13), "document", "Type", 0.8),
5147 ];
5148 let ids = doc.add_signals(signals);
5149
5150 assert_eq!(ids.len(), 2);
5151 assert_eq!(doc.signals().len(), 2);
5152
5153 let track_id = doc.create_track_from_signals("Test", &ids);
5155 assert!(track_id.is_some());
5156
5157 let track = doc.get_track(track_id.unwrap()).unwrap();
5158 assert_eq!(track.len(), 2);
5159 assert_eq!(track.canonical_surface, "Test");
5160 }
5161
5162 #[test]
5163 fn test_merge_tracks() {
5164 let mut doc = GroundedDocument::new("doc1", "John Smith works at Acme. He is great.");
5165
5166 let s1 = doc.add_signal(Signal::new(
5168 0,
5169 Location::text(0, 10),
5170 "John Smith",
5171 "Person",
5172 0.9,
5173 ));
5174 let s2 = doc.add_signal(Signal::new(0, Location::text(26, 28), "He", "Person", 0.8));
5175
5176 let mut track1 = Track::new(0, "John Smith");
5178 track1.add_signal(s1, 0);
5179 let track1_id = doc.add_track(track1);
5180
5181 let mut track2 = Track::new(0, "He");
5182 track2.add_signal(s2, 0);
5183 let track2_id = doc.add_track(track2);
5184
5185 assert_eq!(doc.tracks().count(), 2);
5186
5187 let merged_id = doc.merge_tracks(&[track1_id, track2_id]);
5189 assert!(merged_id.is_some());
5190
5191 assert_eq!(doc.tracks().count(), 1);
5193 let merged = doc.get_track(merged_id.unwrap()).unwrap();
5194 assert_eq!(merged.len(), 2);
5195 assert_eq!(merged.canonical_surface, "John Smith"); }
5197
5198 #[test]
5199 fn test_find_overlapping_pairs() {
5200 let mut doc = GroundedDocument::new("doc1", "New York City is great.");
5201
5202 doc.add_signal(Signal::new(
5204 0,
5205 Location::text(0, 13),
5206 "New York City",
5207 "Location",
5208 0.9,
5209 ));
5210 doc.add_signal(Signal::new(
5211 0,
5212 Location::text(0, 8),
5213 "New York",
5214 "Location",
5215 0.85,
5216 ));
5217 doc.add_signal(Signal::new(0, Location::text(17, 22), "great", "Adj", 0.7)); let pairs = doc.find_overlapping_signal_pairs();
5220
5221 assert_eq!(pairs.len(), 1);
5223 }
5224
5225 #[test]
5226 fn test_signals_in_range() {
5227 let mut doc = GroundedDocument::new("doc1", "John went to Paris and Berlin last year.");
5228
5229 doc.add_signal(Signal::new(0, Location::text(0, 4), "John", "Person", 0.9));
5230 doc.add_signal(Signal::new(
5231 0,
5232 Location::text(13, 18),
5233 "Paris",
5234 "Location",
5235 0.9,
5236 ));
5237 doc.add_signal(Signal::new(
5238 0,
5239 Location::text(23, 29),
5240 "Berlin",
5241 "Location",
5242 0.9,
5243 ));
5244 doc.add_signal(Signal::new(
5245 0,
5246 Location::text(30, 39),
5247 "last year",
5248 "Date",
5249 0.8,
5250 ));
5251
5252 let in_range = doc.signals_in_range(10, 30);
5254 assert_eq!(in_range.len(), 2); let surfaces: Vec<_> = in_range.iter().map(|s| &s.surface).collect();
5257 assert!(surfaces.contains(&&"Paris".to_string()));
5258 assert!(surfaces.contains(&&"Berlin".to_string()));
5259 }
5260
5261 #[test]
5262 fn test_modality_filtering() {
5263 let mut doc = GroundedDocument::new("doc1", "Test");
5264
5265 let mut text_signal = Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9);
5267 text_signal.modality = Modality::Symbolic;
5268 doc.add_signal(text_signal);
5269
5270 let mut visual_signal =
5272 Signal::new(0, Location::bbox(0.0, 0.0, 0.5, 0.5), "Box", "Type", 0.8);
5273 visual_signal.modality = Modality::Iconic;
5274 doc.add_signal(visual_signal);
5275
5276 assert_eq!(doc.text_signals().len(), 1);
5277 assert_eq!(doc.visual_signals().len(), 1);
5278 assert_eq!(doc.signals_by_modality(Modality::Hybrid).len(), 0);
5279 }
5280
5281 #[test]
5282 fn test_quantifier_variants() {
5283 let quantifiers = [
5285 Quantifier::Universal,
5286 Quantifier::Existential,
5287 Quantifier::None,
5288 Quantifier::Definite,
5289 Quantifier::Bare,
5290 ];
5291
5292 for q in quantifiers {
5293 let signal: Signal<Location> =
5294 Signal::new(0, Location::text(0, 5), "test", "Type", 0.9).with_quantifier(q);
5295
5296 assert_eq!(signal.quantifier, Some(q));
5297 }
5298 }
5299
5300 #[test]
5301 fn test_location_modality_derivation() {
5302 assert_eq!(Location::text(0, 10).modality(), Modality::Symbolic);
5303 assert_eq!(
5304 Location::bbox(0.0, 0.0, 0.5, 0.5).modality(),
5305 Modality::Iconic
5306 );
5307
5308 let temporal = Location::Temporal {
5309 start_sec: 0.0,
5310 end_sec: 5.0,
5311 frame: None,
5312 };
5313 assert_eq!(temporal.modality(), Modality::Iconic);
5314
5315 let genomic = Location::Genomic {
5316 contig: "chr1".into(),
5317 start: 0,
5318 end: 1000,
5319 strand: Some('+'),
5320 };
5321 assert_eq!(genomic.modality(), Modality::Symbolic);
5322
5323 let hybrid = Location::TextWithBbox {
5324 start: 0,
5325 end: 10,
5326 bbox: Box::new(Location::bbox(0.0, 0.0, 0.5, 0.5)),
5327 };
5328 assert_eq!(hybrid.modality(), Modality::Hybrid);
5329 }
5330
5331 }
5334
5335#[cfg(test)]
5343mod proptests {
5344 #![allow(clippy::unwrap_used)] use super::*;
5346 use proptest::prelude::*;
5347
5348 fn confidence_strategy() -> impl Strategy<Value = f32> {
5354 0.0f32..=1.0
5355 }
5356
5357 fn label_strategy() -> impl Strategy<Value = String> {
5359 prop_oneof![
5360 Just("Person".to_string()),
5361 Just("Organization".to_string()),
5362 Just("Location".to_string()),
5363 Just("Date".to_string()),
5364 "[A-Z][a-z]{2,10}".prop_map(|s| s),
5365 ]
5366 }
5367
5368 fn surface_strategy() -> impl Strategy<Value = String> {
5370 "[A-Za-z ]{1,50}".prop_map(|s| s.trim().to_string())
5371 }
5372
5373 proptest! {
5378 #[test]
5380 fn iou_symmetric(
5381 start1 in 0usize..1000,
5382 len1 in 1usize..500,
5383 start2 in 0usize..1000,
5384 len2 in 1usize..500,
5385 ) {
5386 let a = Location::text(start1, start1 + len1);
5387 let b = Location::text(start2, start2 + len2);
5388
5389 let iou_ab = a.iou(&b);
5390 let iou_ba = b.iou(&a);
5391
5392 prop_assert_eq!(iou_ab, iou_ba, "IoU must be symmetric");
5393 }
5394
5395 #[test]
5397 fn iou_bounded(
5398 start1 in 0usize..1000,
5399 len1 in 1usize..500,
5400 start2 in 0usize..1000,
5401 len2 in 1usize..500,
5402 ) {
5403 let a = Location::text(start1, start1 + len1);
5404 let b = Location::text(start2, start2 + len2);
5405
5406 if let Some(iou) = a.iou(&b) {
5407 prop_assert!(iou >= 0.0, "IoU must be non-negative: got {}", iou);
5408 prop_assert!(iou <= 1.0, "IoU must be at most 1: got {}", iou);
5409 }
5410 }
5411
5412 #[test]
5414 fn iou_self_identity(start in 0usize..1000, len in 1usize..500) {
5415 let loc = Location::text(start, start + len);
5416 let iou = loc.iou(&loc).unwrap();
5417 prop_assert!(
5418 (iou - 1.0).abs() < 1e-6,
5419 "Self-IoU must be 1.0, got {}",
5420 iou
5421 );
5422 }
5423
5424 #[test]
5426 fn iou_non_overlapping_zero(
5427 start1 in 0usize..500,
5428 len1 in 1usize..100,
5429 ) {
5430 let end1 = start1 + len1;
5431 let start2 = end1 + 100; let len2 = 50;
5433
5434 let a = Location::text(start1, end1);
5435 let b = Location::text(start2, start2 + len2);
5436
5437 let iou = a.iou(&b).expect("bbox iou should be defined");
5438 prop_assert!(
5439 iou.abs() < 1e-6,
5440 "Non-overlapping IoU must be 0, got {}",
5441 iou
5442 );
5443 }
5444
5445 #[test]
5447 fn bbox_iou_symmetric_bounded(
5448 x1 in 0.0f32..0.8,
5449 y1 in 0.0f32..0.8,
5450 w1 in 0.05f32..0.2,
5451 h1 in 0.05f32..0.2,
5452 x2 in 0.0f32..0.8,
5453 y2 in 0.0f32..0.8,
5454 w2 in 0.05f32..0.2,
5455 h2 in 0.05f32..0.2,
5456 ) {
5457 let a = Location::bbox(x1, y1, w1, h1);
5458 let b = Location::bbox(x2, y2, w2, h2);
5459
5460 let iou_ab = a.iou(&b);
5461 let iou_ba = b.iou(&a);
5462
5463 prop_assert_eq!(iou_ab, iou_ba, "BBox IoU must be symmetric");
5465
5466 if let Some(iou) = iou_ab {
5468 prop_assert!(
5469 (0.0..=1.0).contains(&iou),
5470 "BBox IoU out of bounds: {}",
5471 iou
5472 );
5473 }
5474 }
5475 }
5476
5477 proptest! {
5482 #[test]
5484 fn signal_confidence_clamped(raw_conf in -10.0f32..10.0) {
5485 let signal: Signal<Location> = Signal::new(
5486 0,
5487 Location::text(0, 10),
5488 "test",
5489 "Type",
5490 raw_conf,
5491 );
5492
5493 prop_assert!(signal.confidence >= 0.0, "Confidence below 0: {}", signal.confidence);
5494 prop_assert!(signal.confidence <= 1.0, "Confidence above 1: {}", signal.confidence);
5495 }
5496
5497 #[test]
5499 fn signal_preserves_data(
5500 surface in surface_strategy(),
5501 label in label_strategy(),
5502 conf in confidence_strategy(),
5503 start in 0usize..1000,
5504 len in 1usize..100,
5505 ) {
5506 let signal: Signal<Location> = Signal::new(
5507 0,
5508 Location::text(start, start + len),
5509 &surface,
5510 label.as_str(),
5511 conf,
5512 );
5513
5514 prop_assert_eq!(&signal.surface, &surface);
5515 let want = crate::TypeLabel::from(label.as_str());
5516 prop_assert_eq!(signal.label, want);
5517 }
5518
5519 #[test]
5523 fn signal_negation_stable(conf in confidence_strategy()) {
5524 let signal: Signal<Location> = Signal::new(
5525 0,
5526 Location::text(0, 10),
5527 "test",
5528 "Type",
5529 conf,
5530 )
5531 .negated();
5532
5533 prop_assert!(signal.negated, "Signal should be negated after .negated()");
5534 }
5535
5536 #[test]
5538 fn symbolic_supports_linguistic(
5539 start in 0usize..1000,
5540 len in 1usize..100,
5541 ) {
5542 let loc = Location::text(start, start + len);
5543 prop_assert!(
5544 loc.modality().supports_linguistic_features(),
5545 "Text locations must support linguistic features"
5546 );
5547 }
5548
5549 #[test]
5551 fn iconic_supports_geometric(
5552 x in 0.0f32..0.9,
5553 y in 0.0f32..0.9,
5554 w in 0.01f32..0.5,
5555 h in 0.01f32..0.5,
5556 ) {
5557 let loc = Location::bbox(x, y, w, h);
5558 prop_assert!(
5559 loc.modality().supports_geometric_features(),
5560 "BBox locations must support geometric features"
5561 );
5562 }
5563 }
5564
5565 proptest! {
5570 #[test]
5572 fn track_length_monotonic(signal_count in 1usize..20) {
5573 let mut track = Track::new(0, "test");
5574
5575 for i in 0..signal_count {
5576 track.add_signal(i, i as u32);
5577 prop_assert_eq!(
5578 track.len(),
5579 i + 1,
5580 "Track length should be {} after adding {} signals",
5581 i + 1,
5582 i + 1
5583 );
5584 }
5585 }
5586
5587 #[test]
5589 fn track_not_empty_after_add(canonical in surface_strategy()) {
5590 let mut track = Track::new(0, &canonical);
5591 prop_assert!(track.is_empty(), "New track should be empty");
5592
5593 track.add_signal(0, 0);
5594 prop_assert!(!track.is_empty(), "Track should not be empty after add");
5595 }
5596
5597 #[test]
5599 fn track_positions_stored(signal_count in 1usize..10) {
5600 let mut track = Track::new(0, "test");
5601
5602 for i in 0..signal_count {
5603 track.add_signal(i, i as u32);
5604 }
5605
5606 for (idx, signal_ref) in track.signals.iter().enumerate() {
5607 prop_assert_eq!(
5608 signal_ref.position as usize,
5609 idx,
5610 "Signal position mismatch at index {}",
5611 idx
5612 );
5613 }
5614 }
5615 }
5616
5617 proptest! {
5622 #[test]
5624 fn document_signal_ids_monotonic(signal_count in 1usize..20) {
5625 let mut doc = GroundedDocument::new("test", "test text");
5626
5627 let mut prev_id: Option<SignalId> = None;
5628 for i in 0..signal_count {
5629 let id = doc.add_signal(Signal::new(
5630 999, Location::text(i * 10, i * 10 + 5),
5632 format!("entity_{}", i),
5633 "Type",
5634 0.9,
5635 ));
5636
5637 if let Some(prev) = prev_id {
5638 prop_assert!(id > prev, "Signal IDs should be monotonically increasing");
5639 }
5640 prev_id = Some(id);
5641 }
5642 }
5643
5644 #[test]
5646 fn document_track_membership_consistent(signal_count in 1usize..5) {
5647 let mut doc = GroundedDocument::new("test", "test text");
5648
5649 let mut signal_ids = Vec::new();
5651 for i in 0..signal_count {
5652 let id = doc.add_signal(Signal::new(
5653 0,
5654 Location::text(i * 10, i * 10 + 5),
5655 format!("entity_{}", i),
5656 "Type",
5657 0.9,
5658 ));
5659 signal_ids.push(id);
5660 }
5661
5662 let mut track = Track::new(0, "canonical");
5664 for (pos, &id) in signal_ids.iter().enumerate() {
5665 track.add_signal(id, pos as u32);
5666 }
5667 let track_id = doc.add_track(track);
5668
5669 for &signal_id in &signal_ids {
5671 let found_track = doc.track_for_signal(signal_id);
5672 prop_assert!(found_track.is_some(), "Signal should be in a track");
5673 prop_assert_eq!(
5674 found_track.unwrap().id,
5675 track_id,
5676 "Signal should be in the correct track"
5677 );
5678 }
5679 }
5680
5681 #[test]
5683 fn document_identity_transitivity(signal_count in 1usize..3) {
5684 let mut doc = GroundedDocument::new("test", "test text");
5685
5686 let mut signal_ids = Vec::new();
5688 for i in 0..signal_count {
5689 let id = doc.add_signal(Signal::new(
5690 0,
5691 Location::text(i * 10, i * 10 + 5),
5692 format!("entity_{}", i),
5693 "Type",
5694 0.9,
5695 ));
5696 signal_ids.push(id);
5697 }
5698
5699 let mut track = Track::new(0, "canonical");
5701 for (pos, &id) in signal_ids.iter().enumerate() {
5702 track.add_signal(id, pos as u32);
5703 }
5704 let track_id = doc.add_track(track);
5705
5706 let identity = Identity::from_kb(0, "Entity", "wikidata", "Q123");
5707 let identity_id = doc.add_identity(identity);
5708 doc.link_track_to_identity(track_id, identity_id);
5709
5710 for &signal_id in &signal_ids {
5712 let identity = doc.identity_for_signal(signal_id);
5713 prop_assert!(identity.is_some(), "Should find identity through signal");
5714 prop_assert_eq!(
5715 identity.unwrap().id,
5716 identity_id,
5717 "Should find correct identity"
5718 );
5719 }
5720 }
5721
5722 #[test]
5724 fn document_untracked_signals(total in 2usize..10, tracked in 0usize..10) {
5725 let tracked = tracked.min(total - 1); let mut doc = GroundedDocument::new("test", "test text");
5727
5728 let mut signal_ids = Vec::new();
5730 for i in 0..total {
5731 let id = doc.add_signal(Signal::new(
5732 0,
5733 Location::text(i * 10, i * 10 + 5),
5734 format!("entity_{}", i),
5735 "Type",
5736 0.9,
5737 ));
5738 signal_ids.push(id);
5739 }
5740
5741 let mut track = Track::new(0, "canonical");
5743 for (pos, &id) in signal_ids.iter().take(tracked).enumerate() {
5744 track.add_signal(id, pos as u32);
5745 }
5746 if tracked > 0 {
5747 doc.add_track(track);
5748 }
5749
5750 prop_assert_eq!(
5752 doc.untracked_signal_count(),
5753 total - tracked,
5754 "Wrong untracked count"
5755 );
5756 }
5757 }
5758
5759 proptest! {
5764 #[test]
5766 fn entity_roundtrip_preserves_text(
5767 text in surface_strategy(),
5768 start in 0usize..1000,
5769 len in 1usize..100,
5770 conf in 0.0f64..=1.0,
5771 ) {
5772 use super::EntityType;
5773
5774 let end = start + len;
5775 let entity = super::Entity::new(&text, EntityType::Person, start, end, conf);
5776
5777 let doc = GroundedDocument::from_entities("test", "x".repeat(end + 10), &[entity]);
5778 let converted = doc.to_entities();
5779
5780 prop_assert_eq!(converted.len(), 1, "Should have exactly one entity");
5781 prop_assert_eq!(&converted[0].text, &text, "Text should be preserved");
5782 prop_assert_eq!(converted[0].start, start, "Start should be preserved");
5783 prop_assert_eq!(converted[0].end, end, "End should be preserved");
5784 }
5785
5786 }
5789
5790 proptest! {
#[test]
/// Pins the feature-support table for every `Modality` variant.
///
/// The `_dummy` input is unused by the body.
/// NOTE(review): presumably it exists only because `proptest!` tests need an input — confirm.
fn modality_feature_consistency(_dummy in 0..1) {
// Iconic: geometric only.
prop_assert!(Modality::Iconic.supports_geometric_features());
prop_assert!(!Modality::Iconic.supports_linguistic_features());

// Symbolic: linguistic only.
prop_assert!(Modality::Symbolic.supports_linguistic_features());
prop_assert!(!Modality::Symbolic.supports_geometric_features());

// Hybrid: both feature families.
prop_assert!(Modality::Hybrid.supports_linguistic_features());
prop_assert!(Modality::Hybrid.supports_geometric_features());
}
5810 }
5811
5812 proptest! {
5817 #[test]
5819 fn overlap_symmetric(
5820 start1 in 0usize..1000,
5821 len1 in 1usize..100,
5822 start2 in 0usize..1000,
5823 len2 in 1usize..100,
5824 ) {
5825 let a = Location::text(start1, start1 + len1);
5826 let b = Location::text(start2, start2 + len2);
5827
5828 prop_assert_eq!(
5829 a.overlaps(&b),
5830 b.overlaps(&a),
5831 "Overlap must be symmetric"
5832 );
5833 }
5834
5835 #[test]
5837 fn overlap_reflexive(start in 0usize..1000, len in 1usize..100) {
5838 let loc = Location::text(start, start + len);
5839 prop_assert!(loc.overlaps(&loc), "Location must overlap with itself");
5840 }
5841
5842 #[test]
5844 fn iou_implies_overlap(
5845 start1 in 0usize..500,
5846 len1 in 1usize..100,
5847 start2 in 0usize..500,
5848 len2 in 1usize..100,
5849 ) {
5850 let a = Location::text(start1, start1 + len1);
5851 let b = Location::text(start2, start2 + len2);
5852
5853 if let Some(iou) = a.iou(&b) {
5854 if iou > 0.0 {
5855 prop_assert!(
5856 a.overlaps(&b),
5857 "IoU > 0 should imply overlap"
5858 );
5859 }
5860 }
5861 }
5862 }
5863
5864 proptest! {
5869 #[test]
5871 fn stats_signal_count_accurate(signal_count in 0usize..20) {
5872 let mut doc = GroundedDocument::new("test", "test");
5873 for i in 0..signal_count {
5874 doc.add_signal(Signal::new(
5875 0,
5876 Location::text(i * 10, i * 10 + 5),
5877 "entity",
5878 "Type",
5879 0.9,
5880 ));
5881 }
5882
5883 let stats = doc.stats();
5884 prop_assert_eq!(stats.signal_count, signal_count);
5885 }
5886
5887 #[test]
5889 fn stats_track_count_accurate(track_count in 0usize..10) {
5890 let mut doc = GroundedDocument::new("test", "test");
5891 for i in 0..track_count {
5892 let id = doc.add_signal(Signal::new(
5893 0,
5894 Location::text(i * 10, i * 10 + 5),
5895 "entity",
5896 "Type",
5897 0.9,
5898 ));
5899 let mut track = Track::new(0, format!("track_{}", i));
5900 track.add_signal(id, 0);
5901 doc.add_track(track);
5902 }
5903
5904 let stats = doc.stats();
5905 prop_assert_eq!(stats.track_count, track_count);
5906 }
5907
5908 #[test]
5910 fn stats_avg_confidence_bounded(
5911 confidences in proptest::collection::vec(0.0f32..=1.0, 1..10)
5912 ) {
5913 let mut doc = GroundedDocument::new("test", "test");
5914 for (i, conf) in confidences.iter().enumerate() {
5915 doc.add_signal(Signal::new(
5916 0,
5917 Location::text(i * 10, i * 10 + 5),
5918 "entity",
5919 "Type",
5920 *conf,
5921 ));
5922 }
5923
5924 let stats = doc.stats();
5925 prop_assert!(stats.avg_confidence >= 0.0);
5926 prop_assert!(stats.avg_confidence <= 1.0);
5927 }
5928 }
5929
5930 proptest! {
5935 #[test]
5937 fn batch_add_returns_all_ids(count in 1usize..10) {
5938 let mut doc = GroundedDocument::new("test", "test");
5939 let signals: Vec<Signal<Location>> = (0..count)
5940 .map(|i| Signal::new(0, Location::text(i * 10, i * 10 + 5), "e", "T", 0.9))
5941 .collect();
5942
5943 let ids = doc.add_signals(signals);
5944 prop_assert_eq!(ids.len(), count);
5945 prop_assert_eq!(doc.signals().len(), count);
5946 }
5947
5948 #[test]
5950 fn create_track_valid(signal_count in 1usize..5) {
5951 let mut doc = GroundedDocument::new("test", "test");
5952 let mut signal_ids = Vec::new();
5953 for i in 0..signal_count {
5954 let id = doc.add_signal(Signal::new(
5955 0,
5956 Location::text(i * 10, i * 10 + 5),
5957 "entity",
5958 "Type",
5959 0.9,
5960 ));
5961 signal_ids.push(id);
5962 }
5963
5964 let track_id = doc.create_track_from_signals("canonical", &signal_ids);
5965 prop_assert!(track_id.is_some());
5966
5967 let track = doc.get_track(track_id.unwrap());
5968 prop_assert!(track.is_some());
5969 prop_assert_eq!(track.unwrap().len(), signal_count);
5970 }
5971
5972 #[test]
5974 fn create_track_empty_returns_none(_dummy in 0..1) {
5975 let mut doc = GroundedDocument::new("test", "test");
5976 let track_id = doc.create_track_from_signals("canonical", &[]);
5977 prop_assert!(track_id.is_none());
5978 }
5979 }
5980
5981 proptest! {
5986 #[test]
5988 fn signals_in_range_within_bounds(
5989 range_start in 0usize..100,
5990 range_len in 10usize..50,
5991 ) {
5992 let range_end = range_start + range_len;
5993 let mut doc = GroundedDocument::new("test", "x".repeat(200));
5994
5995 doc.add_signal(Signal::new(0, Location::text(range_start + 2, range_start + 5), "inside", "T", 0.9));
5997 doc.add_signal(Signal::new(0, Location::text(0, 5), "before", "T", 0.9));
5998 doc.add_signal(Signal::new(0, Location::text(190, 195), "after", "T", 0.9));
5999
6000 let in_range = doc.signals_in_range(range_start, range_end);
6001
6002 for signal in &in_range {
6003 if let Some((start, end)) = signal.location.text_offsets() {
6004 prop_assert!(start >= range_start, "Signal start {} < range start {}", start, range_start);
6005 prop_assert!(end <= range_end, "Signal end {} > range end {}", end, range_end);
6006 }
6007 }
6008 }
6009
6010 #[test]
6012 fn overlapping_signals_symmetric(
6013 start1 in 10usize..50,
6014 len1 in 5usize..20,
6015 start2 in 10usize..50,
6016 len2 in 5usize..20,
6017 ) {
6018 let mut doc = GroundedDocument::new("test", "x".repeat(100));
6019
6020 let loc1 = Location::text(start1, start1 + len1);
6021 let loc2 = Location::text(start2, start2 + len2);
6022
6023 doc.add_signal(Signal::new(0, loc1.clone(), "A", "T", 0.9));
6024 doc.add_signal(Signal::new(0, loc2.clone(), "B", "T", 0.9));
6025
6026 let overlaps_loc1 = doc.overlapping_signals(&loc1);
6027 let overlaps_loc2 = doc.overlapping_signals(&loc2);
6028
6029 if loc1.overlaps(&loc2) {
6031 prop_assert!(overlaps_loc1.len() >= 2, "Should find both when overlapping");
6032 prop_assert!(overlaps_loc2.len() >= 2, "Should find both when overlapping");
6033 }
6034 }
6035 }
6036
6037 proptest! {
6042 #[test]
6044 fn modality_counts_sum_to_total(
6045 symbolic_count in 0usize..5,
6046 iconic_count in 0usize..5,
6047 ) {
6048 let mut doc = GroundedDocument::new("test", "test");
6049
6050 for i in 0..symbolic_count {
6052 let mut signal = Signal::new(
6053 0,
6054 Location::text(i * 10, i * 10 + 5),
6055 "entity",
6056 "Type",
6057 0.9,
6058 );
6059 signal.modality = Modality::Symbolic;
6060 doc.add_signal(signal);
6061 }
6062
6063 for i in 0..iconic_count {
6065 let mut signal = Signal::new(
6066 0,
6067 Location::bbox(i as f32 * 0.1, 0.0, 0.05, 0.05),
6068 "entity",
6069 "Type",
6070 0.9,
6071 );
6072 signal.modality = Modality::Iconic;
6073 doc.add_signal(signal);
6074 }
6075
6076 let stats = doc.stats();
6077 prop_assert_eq!(
6078 stats.symbolic_count + stats.iconic_count + stats.hybrid_count,
6079 stats.signal_count,
6080 "Modality counts should sum to total"
6081 );
6082 }
6083 }
6084
6085 proptest! {
6090 #[test]
6092 fn from_text_always_valid(
6093 text in "[a-zA-Z ]{20,100}",
6094 surface_start in 0usize..15,
6095 surface_len in 1usize..8,
6096 ) {
6097 let text_char_len = text.chars().count();
6098 let surface_end = (surface_start + surface_len).min(text_char_len);
6099 let surface_start = surface_start.min(surface_end.saturating_sub(1));
6100
6101 if surface_start < surface_end && surface_end <= text_char_len {
6102 let surface: String = text.chars()
6103 .skip(surface_start)
6104 .take(surface_end - surface_start)
6105 .collect();
6106
6107 if !surface.is_empty() {
6108 if let Some(signal) = Signal::<Location>::from_text(&text, &surface, "Test", 0.9) {
6110 prop_assert!(
6112 signal.validate_against(&text).is_none(),
6113 "Signal created via from_text must be valid"
6114 );
6115 }
6116 }
6117 }
6118 }
6119
6120 #[test]
6122 fn validated_add_rejects_invalid(
6123 text in "[a-z]{10,50}",
6124 wrong_surface in "[A-Z]{3,10}",
6125 ) {
6126 let mut doc = GroundedDocument::new("test", &text);
6127
6128 let signal = Signal::new(
6130 0,
6131 Location::text(0, wrong_surface.chars().count().min(text.chars().count())),
6132 wrong_surface.clone(),
6133 "Test",
6134 0.9,
6135 );
6136
6137 let expected: String = text.chars().take(wrong_surface.chars().count()).collect();
6140 if expected != wrong_surface {
6141 let result = doc.add_signal_validated(signal);
6142 prop_assert!(result.is_err(), "Should reject signal with mismatched surface");
6143 }
6144 }
6145
6146 #[test]
6148 fn round_trip_signal_from_text(
6149 prefix in "[a-z]{5,20}",
6150 entity in "[A-Z][a-z]{3,10}",
6151 suffix in "[a-z]{5,20}",
6152 ) {
6153 let text = format!("{} {} {}", prefix, entity, suffix);
6154 let mut doc = GroundedDocument::new("test", &text);
6155
6156 let id = doc.add_signal_from_text(&entity, "Entity", 0.9);
6157 prop_assert!(id.is_some(), "Should find entity in text");
6158
6159 let signal = doc.signals().iter().find(|s| s.id == id.unwrap());
6160 prop_assert!(signal.is_some(), "Should retrieve added signal");
6161
6162 let signal = signal.unwrap();
6163 prop_assert_eq!(signal.surface(), entity.as_str(), "Surface should match");
6164
6165 prop_assert!(
6167 doc.is_valid(),
6168 "Document should be valid after from_text add"
6169 );
6170 }
6171
6172 #[test]
6174 fn nth_occurrence_finds_correct(
6175 entity in "[A-Z][a-z]{2,5}",
6176 sep in " [a-z]+ ",
6177 ) {
6178 let text = format!("{}{}{}{}{}", entity, sep, entity, sep, entity);
6180 let mut doc = GroundedDocument::new("test", &text);
6181
6182 for n in 0..3 {
6184 let id = doc.add_signal_from_text_nth(&entity, "Entity", 0.9, n);
6185 prop_assert!(id.is_some(), "Should find occurrence {}", n);
6186 }
6187
6188 let id = doc.add_signal_from_text_nth(&entity, "Entity", 0.9, 3);
6190 prop_assert!(id.is_none(), "Should NOT find 4th occurrence");
6191
6192 prop_assert!(doc.is_valid(), "All signals should be valid");
6194
6195 let offsets: Vec<_> = doc.signals()
6197 .iter()
6198 .filter_map(|s| s.text_offsets())
6199 .collect();
6200 let unique: std::collections::HashSet<_> = offsets.iter().collect();
6201 prop_assert_eq!(offsets.len(), unique.len(), "Each occurrence should have distinct offset");
6202 }
6203 }
6204
#[test]
/// Checks `compute_stats` for a two-mention track with identical surfaces
/// and confidences 0.95 and 0.90.
fn test_track_stats_basic() {
    let text = "John met Mary. He said hello. John left.";
    let text_len = text.chars().count();
    let mut doc = GroundedDocument::new("test", text);

    // Both occurrences of "John": offsets 0..4 and 30..34.
    let first_mention = Signal::new(0, Location::text(0, 4), "John", "Person", 0.95);
    let first_id = doc.add_signal(first_mention);
    let second_mention = Signal::new(0, Location::text(30, 34), "John", "Person", 0.90);
    let second_id = doc.add_signal(second_mention);

    // Register an empty track first, then attach the mentions through the document.
    let track_id = doc.add_track(Track::new(0, "John".to_string()));
    doc.add_signal_to_track(first_id, track_id, 0);
    doc.add_signal_to_track(second_id, track_id, 1);

    let stats = doc
        .get_track(track_id)
        .unwrap()
        .compute_stats(&doc, text_len);

    assert_eq!(stats.chain_length, 2, "Two mentions");
    assert_eq!(stats.variation_count, 1, "One unique surface form");
    assert!(stats.spread > 0, "Spread should be positive");
    assert!(stats.relative_spread > 0.0 && stats.relative_spread < 1.0);
    assert!((stats.min_confidence - 0.90).abs() < 0.01);
    assert!((stats.max_confidence - 0.95).abs() < 0.01);
    assert!((stats.mean_confidence - 0.925).abs() < 0.01);
}
6242
#[test]
/// Checks `compute_stats` for a track with a single mention: zero spread,
/// equal first/last positions, and equal min/max confidence.
fn test_track_stats_singleton() {
    let text = "Paris is beautiful.";
    let text_len = text.chars().count();
    let mut doc = GroundedDocument::new("test", text);

    let only_mention = Signal::new(0, Location::text(0, 5), "Paris", "Location", 0.88);
    let mention_id = doc.add_signal(only_mention);

    let track_id = doc.add_track(Track::new(0, "Paris".to_string()));
    doc.add_signal_to_track(mention_id, track_id, 0);

    let stats = doc
        .get_track(track_id)
        .unwrap()
        .compute_stats(&doc, text_len);

    assert_eq!(stats.chain_length, 1);
    assert_eq!(stats.spread, 0, "Singleton has zero spread");
    assert_eq!(stats.first_position, stats.last_position);
    assert!((stats.min_confidence - stats.max_confidence).abs() < 0.001);
}
6267}