1use super::entity::{
90 DiscontinuousSpan, Entity, EntityType, HierarchicalConfidence, Provenance, Span,
91};
92use serde::{Deserialize, Serialize};
93use std::collections::HashMap;
94
/// Kind of representation a signal or location is grounded in.
///
/// `Symbolic` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum Modality {
    /// Reports geometric features but not linguistic ones.
    Iconic,
    /// Reports linguistic features but not geometric ones (the default).
    #[default]
    Symbolic,
    /// Reports both linguistic and geometric features.
    Hybrid,
}
133
134impl Modality {
135 #[must_use]
137 pub const fn supports_linguistic_features(&self) -> bool {
138 matches!(self, Self::Symbolic | Self::Hybrid)
139 }
140
141 #[must_use]
143 pub const fn supports_geometric_features(&self) -> bool {
144 matches!(self, Self::Iconic | Self::Hybrid)
145 }
146}
147
/// Where a signal is anchored in its source material.
///
/// Only some variants carry text offsets (see `Location::text_offsets`) and
/// only `Text`, `BoundingBox` and `TextWithBbox` convert to a `Span`
/// (see `Location::to_span`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Location {
    /// Contiguous text range `[start, end)`.
    ///
    /// `Signal::validate_against` interprets these as character (not byte)
    /// offsets.
    Text {
        /// Inclusive start offset.
        start: usize,
        /// Exclusive end offset.
        end: usize,
    },
    /// Axis-aligned rectangle covering `x..x + width` by `y..y + height`.
    BoundingBox {
        /// Horizontal origin of the box.
        x: f32,
        /// Vertical origin of the box.
        y: f32,
        /// Horizontal extent.
        width: f32,
        /// Vertical extent.
        height: f32,
        /// Optional page; boxes whose `page` values differ never overlap.
        page: Option<u32>,
    },
    /// Time interval, optionally tied to a frame.
    Temporal {
        /// Interval start (presumably seconds, per the field name).
        start_sec: f64,
        /// Interval end (presumably seconds, per the field name).
        end_sec: f64,
        /// Optional frame number.
        frame: Option<u64>,
    },
    /// Three-dimensional box.
    Cuboid {
        /// Centre point.
        center: [f32; 3],
        /// Extent along each axis.
        dimensions: [f32; 3],
        /// Orientation; presumably a quaternion — TODO confirm component order.
        rotation: [f32; 4],
    },
    /// Range on a named sequence.
    Genomic {
        /// Name of the contig the range lies on.
        contig: String,
        /// Range start.
        start: u64,
        /// Range end.
        end: u64,
        /// Optional strand marker (presumably `'+'` / `'-'` — TODO confirm).
        strand: Option<char>,
    },
    /// Several disjoint text segments, each a `(start, end)` pair.
    Discontinuous {
        /// The individual segments.
        segments: Vec<(usize, usize)>,
    },
    /// Text range paired with a nested location describing its visual extent.
    TextWithBbox {
        /// Inclusive start offset of the text range.
        start: usize,
        /// Exclusive end offset of the text range.
        end: usize,
        /// Nested location; expected to be a `BoundingBox`.
        bbox: Box<Location>,
    },
}
248
249impl Location {
250 #[must_use]
252 pub const fn text(start: usize, end: usize) -> Self {
253 Self::Text { start, end }
254 }
255
256 #[must_use]
258 pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
259 Self::BoundingBox {
260 x,
261 y,
262 width,
263 height,
264 page: None,
265 }
266 }
267
268 #[must_use]
270 pub const fn modality(&self) -> Modality {
271 match self {
272 Self::Text { .. } | Self::Genomic { .. } | Self::Discontinuous { .. } => {
273 Modality::Symbolic
274 }
275 Self::BoundingBox { .. } | Self::Cuboid { .. } => Modality::Iconic,
276 Self::Temporal { .. } => Modality::Iconic, Self::TextWithBbox { .. } => Modality::Hybrid,
278 }
279 }
280
281 #[must_use]
283 pub fn text_offsets(&self) -> Option<(usize, usize)> {
284 match self {
285 Self::Text { start, end } => Some((*start, *end)),
286 Self::TextWithBbox { start, end, .. } => Some((*start, *end)),
287 Self::Discontinuous { segments } => {
288 let start = segments.iter().map(|(s, _)| *s).min()?;
289 let end = segments.iter().map(|(_, e)| *e).max()?;
290 Some((start, end))
291 }
292 _ => None,
293 }
294 }
295
296 #[must_use]
298 pub fn overlaps(&self, other: &Self) -> bool {
299 match (self, other) {
300 (Self::Text { start: s1, end: e1 }, Self::Text { start: s2, end: e2 }) => {
301 s1 < e2 && s2 < e1
302 }
303 (
304 Self::BoundingBox {
305 x: x1,
306 y: y1,
307 width: w1,
308 height: h1,
309 page: p1,
310 },
311 Self::BoundingBox {
312 x: x2,
313 y: y2,
314 width: w2,
315 height: h2,
316 page: p2,
317 },
318 ) => {
319 if p1 != p2 {
321 return false;
322 }
323 x1 < &(x2 + w2) && &(x1 + w1) > x2 && y1 < &(y2 + h2) && &(y1 + h1) > y2
325 }
326 _ => false, }
328 }
329
330 #[must_use]
334 pub fn iou(&self, other: &Self) -> Option<f64> {
335 match (self, other) {
336 (Self::Text { start: s1, end: e1 }, Self::Text { start: s2, end: e2 }) => {
337 let intersection_start = (*s1).max(*s2);
338 let intersection_end = (*e1).min(*e2);
339 if intersection_start >= intersection_end {
340 return Some(0.0);
341 }
342 let intersection = (intersection_end - intersection_start) as f64;
343 let union = ((*e1).max(*e2) - (*s1).min(*s2)) as f64;
344 if union == 0.0 {
345 Some(0.0)
346 } else {
347 Some(intersection / union)
348 }
349 }
350 (
351 Self::BoundingBox {
352 x: x1,
353 y: y1,
354 width: w1,
355 height: h1,
356 page: p1,
357 },
358 Self::BoundingBox {
359 x: x2,
360 y: y2,
361 width: w2,
362 height: h2,
363 page: p2,
364 },
365 ) => {
366 if p1 != p2 {
367 return Some(0.0);
368 }
369 let x_overlap = (x1 + w1).min(x2 + w2) - x1.max(*x2);
370 let y_overlap = (y1 + h1).min(y2 + h2) - y1.max(*y2);
371 if x_overlap <= 0.0 || y_overlap <= 0.0 {
372 return Some(0.0);
373 }
374 let intersection = (x_overlap * y_overlap) as f64;
375 let area1 = (*w1 * *h1) as f64;
376 let area2 = (*w2 * *h2) as f64;
377 let union = area1 + area2 - intersection;
378 if union == 0.0 {
379 Some(0.0)
380 } else {
381 Some(intersection / union)
382 }
383 }
384 _ => None,
385 }
386 }
387}
388
389impl Default for Location {
390 fn default() -> Self {
391 Self::Text { start: 0, end: 0 }
392 }
393}
394
395impl From<&Span> for Location {
396 fn from(span: &Span) -> Self {
397 match span {
398 Span::Text { start, end } => Self::Text {
399 start: *start,
400 end: *end,
401 },
402 Span::BoundingBox {
403 x,
404 y,
405 width,
406 height,
407 page,
408 } => Self::BoundingBox {
409 x: *x,
410 y: *y,
411 width: *width,
412 height: *height,
413 page: *page,
414 },
415 Span::Hybrid { start, end, bbox } => Self::TextWithBbox {
416 start: *start,
417 end: *end,
418 bbox: Box::new(Location::from(bbox.as_ref())),
419 },
420 }
421 }
422}
423
424impl From<Span> for Location {
425 fn from(span: Span) -> Self {
426 Self::from(&span)
427 }
428}
429
430impl Location {
439 #[must_use]
444 pub fn to_span(&self) -> Option<Span> {
445 match self {
446 Self::Text { start, end } => Some(Span::Text {
447 start: *start,
448 end: *end,
449 }),
450 Self::BoundingBox {
451 x,
452 y,
453 width,
454 height,
455 page,
456 } => Some(Span::BoundingBox {
457 x: *x,
458 y: *y,
459 width: *width,
460 height: *height,
461 page: *page,
462 }),
463 Self::TextWithBbox { start, end, bbox } => {
464 let inner_span = bbox.to_span()?;
465 Some(Span::Hybrid {
466 start: *start,
467 end: *end,
468 bbox: Box::new(inner_span),
469 })
470 }
471 Self::Temporal { .. }
473 | Self::Cuboid { .. }
474 | Self::Genomic { .. }
475 | Self::Discontinuous { .. } => None,
476 }
477 }
478}
479
480pub use super::types::SignalId;
486
/// A single detected mention, generic over how it is located (`L`, which
/// defaults to `Location`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Signal<L = Location> {
    /// Identifier of this signal; `GroundedDocument::add_signal` overwrites it.
    pub id: SignalId,
    /// Where the signal occurs.
    pub location: L,
    /// Surface text of the mention.
    pub surface: String,
    /// Type label attached to the mention.
    pub label: super::types::TypeLabel,
    /// Confidence score; `Signal::new` clamps it to `[0.0, 1.0]`.
    pub confidence: f32,
    /// Optional hierarchical confidence breakdown.
    pub hierarchical: Option<HierarchicalConfidence>,
    /// Optional provenance record.
    pub provenance: Option<Provenance>,
    /// Modality of the signal; `Signal::new` uses `Modality::default()`.
    pub modality: Modality,
    /// Optional normalized form of the surface text.
    pub normalized: Option<String>,
    /// Whether the mention is negated; `false` unless set.
    pub negated: bool,
    /// Optional quantifier attached to the mention.
    pub quantifier: Option<Quantifier>,
}
543
/// Quantifier that can be attached to a signal.
///
/// NOTE(review): the variant `None` is distinct from `Option::None`; a signal
/// stores `Option<Quantifier>`, so "no quantifier recorded" and
/// `Some(Quantifier::None)` are different states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quantifier {
    /// Universal quantification (presumably "all"/"every" — confirm upstream).
    Universal,
    /// Existential quantification (presumably "some"/"a" — confirm upstream).
    Existential,
    /// Explicit "none" quantifier (presumably "no"/"none" — confirm upstream).
    None,
    /// Definite reference (presumably "the" — confirm upstream).
    Definite,
    /// Bare form (presumably no determiner — confirm upstream).
    Bare,
}
561
562impl<L> Signal<L> {
563 #[must_use]
573 pub fn new(
574 id: impl Into<SignalId>,
575 location: L,
576 surface: impl Into<String>,
577 label: impl Into<super::types::TypeLabel>,
578 confidence: f32,
579 ) -> Self {
580 Self {
581 id: id.into(),
582 location,
583 surface: surface.into(),
584 label: label.into(),
585 confidence: confidence.clamp(0.0, 1.0),
586 hierarchical: None,
587 provenance: None,
588 modality: Modality::default(),
589 normalized: None,
590 negated: false,
591 quantifier: None,
592 }
593 }
594
595 #[must_use]
597 pub fn label(&self) -> &str {
598 self.label.as_str()
599 }
600
601 #[must_use]
603 pub fn type_label(&self) -> super::types::TypeLabel {
604 self.label.clone()
605 }
606
607 #[must_use]
609 pub fn surface(&self) -> &str {
610 &self.surface
611 }
612
613 #[must_use]
615 pub fn is_confident(&self, threshold: f32) -> bool {
616 self.confidence >= threshold
617 }
618
619 #[must_use]
621 pub fn with_modality(mut self, modality: Modality) -> Self {
622 self.modality = modality;
623 self
624 }
625
626 #[must_use]
628 pub fn negated(mut self) -> Self {
629 self.negated = true;
630 self
631 }
632
633 #[must_use]
635 pub fn with_quantifier(mut self, q: Quantifier) -> Self {
636 self.quantifier = Some(q);
637 self
638 }
639
640 #[must_use]
642 pub fn with_provenance(mut self, p: Provenance) -> Self {
643 self.provenance = Some(p);
644 self
645 }
646}
647
648impl Signal<Location> {
649 #[must_use]
651 pub fn text_offsets(&self) -> Option<(usize, usize)> {
652 self.location.text_offsets()
653 }
654
655 #[must_use]
672 pub fn validate_against(&self, source_text: &str) -> Option<SignalValidationError> {
673 let (start, end) = self.location.text_offsets()?;
674
675 let char_count = source_text.chars().count();
676
677 if end > char_count {
679 return Some(SignalValidationError::OutOfBounds {
680 signal_id: self.id,
681 end,
682 text_len: char_count,
683 });
684 }
685
686 if start >= end {
687 return Some(SignalValidationError::InvalidSpan {
688 signal_id: self.id,
689 start,
690 end,
691 });
692 }
693
694 let actual: String = source_text.chars().skip(start).take(end - start).collect();
696
697 if actual != self.surface {
698 return Some(SignalValidationError::TextMismatch {
699 signal_id: self.id,
700 expected: self.surface.clone(),
701 actual,
702 start,
703 end,
704 });
705 }
706
707 None
708 }
709
710 #[must_use]
712 pub fn is_valid(&self, source_text: &str) -> bool {
713 self.validate_against(source_text).is_none()
714 }
715
716 #[must_use]
731 pub fn from_text(
732 source: &str,
733 surface: &str,
734 label: impl Into<super::types::TypeLabel>,
735 confidence: f32,
736 ) -> Option<Self> {
737 Self::from_text_nth(source, surface, label, confidence, 0)
738 }
739
740 #[must_use]
742 pub fn from_text_nth(
743 source: &str,
744 surface: &str,
745 label: impl Into<super::types::TypeLabel>,
746 confidence: f32,
747 occurrence: usize,
748 ) -> Option<Self> {
749 for (count, (byte_idx, _)) in source.match_indices(surface).enumerate() {
751 if count == occurrence {
752 let start = source[..byte_idx].chars().count();
754 let end = start + surface.chars().count();
755
756 return Some(Self::new(
757 SignalId::ZERO,
758 Location::text(start, end),
759 surface,
760 label,
761 confidence,
762 ));
763 }
764 }
765
766 None
767 }
768}
769
/// Problems reported by `Signal::validate_against`.
#[derive(Debug, Clone, PartialEq)]
pub enum SignalValidationError {
    /// The signal's end offset lies beyond the end of the text.
    OutOfBounds {
        /// Signal that failed validation.
        signal_id: SignalId,
        /// Offending end offset.
        end: usize,
        /// Length of the text the signal was checked against.
        text_len: usize,
    },
    /// The signal's span is empty or inverted (`start >= end`).
    InvalidSpan {
        /// Signal that failed validation.
        signal_id: SignalId,
        /// Start offset of the span.
        start: usize,
        /// End offset of the span.
        end: usize,
    },
    /// The text at the signal's span differs from its recorded surface.
    TextMismatch {
        /// Signal that failed validation.
        signal_id: SignalId,
        /// Surface text stored on the signal.
        expected: String,
        /// Text actually found at `[start, end)`.
        actual: String,
        /// Start offset of the span.
        start: usize,
        /// End offset of the span.
        end: usize,
    },
}
805
806impl std::fmt::Display for SignalValidationError {
807 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
808 match self {
809 Self::OutOfBounds {
810 signal_id,
811 end,
812 text_len,
813 } => {
814 write!(
815 f,
816 "S{}: end offset {} exceeds text length {}",
817 signal_id, end, text_len
818 )
819 }
820 Self::InvalidSpan {
821 signal_id,
822 start,
823 end,
824 } => {
825 write!(f, "S{}: invalid span [{}, {})", signal_id, start, end)
826 }
827 Self::TextMismatch {
828 signal_id,
829 expected,
830 actual,
831 start,
832 end,
833 } => {
834 write!(
835 f,
836 "S{}: text mismatch at [{}, {}): expected '{}', found '{}'",
837 signal_id, start, end, expected, actual
838 )
839 }
840 }
841 }
842}
843
// Marker impl: relies on the `Debug` and `Display` impls; no `source()` override.
impl std::error::Error for SignalValidationError {}
845
846pub use super::types::TrackId;
852
/// Reference from a track to one of its member signals.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SignalRef {
    /// Identifier of the referenced signal.
    pub signal_id: SignalId,
    /// Caller-supplied position of the signal within the track.
    pub position: u32,
}
861
/// Reference to a track that also names the document it belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TrackRef {
    /// Identifier of the owning document.
    pub doc_id: String,
    /// Identifier of the track within that document.
    pub track_id: TrackId,
}
874
/// A group of signals within one document, optionally linked to an identity.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Track {
    /// Identifier of the track; `GroundedDocument::add_track` overwrites it.
    pub id: TrackId,
    /// Member signals, in insertion order.
    pub signals: Vec<SignalRef>,
    /// Optional type label for the track.
    pub entity_type: Option<super::types::TypeLabel>,
    /// Representative surface form for the track.
    pub canonical_surface: String,
    /// Identity this track is linked to, if any.
    pub identity_id: Option<IdentityId>,
    /// Confidence in the grouping; `Track::new` starts it at `1.0`.
    pub cluster_confidence: f32,
    /// Optional embedding vector for the track.
    pub embedding: Option<Vec<f32>>,
}
914
915impl Track {
916 #[must_use]
918 pub fn new(id: impl Into<TrackId>, canonical_surface: impl Into<String>) -> Self {
919 Self {
920 id: id.into(),
921 signals: Vec::new(),
922 entity_type: None,
923 canonical_surface: canonical_surface.into(),
924 identity_id: None,
925 cluster_confidence: 1.0,
926 embedding: None,
927 }
928 }
929
930 pub fn add_signal(&mut self, signal_id: impl Into<SignalId>, position: u32) {
932 let signal_id = signal_id.into();
933 self.signals.push(SignalRef {
934 signal_id,
935 position,
936 });
937 }
938
939 #[must_use]
941 pub fn len(&self) -> usize {
942 self.signals.len()
943 }
944
945 #[must_use]
947 pub fn is_empty(&self) -> bool {
948 self.signals.is_empty()
949 }
950
951 #[must_use]
953 pub fn is_singleton(&self) -> bool {
954 self.signals.len() == 1
955 }
956
957 #[must_use]
959 pub const fn id(&self) -> TrackId {
960 self.id
961 }
962
963 #[must_use]
965 pub fn signals(&self) -> &[SignalRef] {
966 &self.signals
967 }
968
969 #[must_use]
971 pub fn canonical_surface(&self) -> &str {
972 &self.canonical_surface
973 }
974
975 #[must_use]
977 pub const fn identity_id(&self) -> Option<IdentityId> {
978 self.identity_id
979 }
980
981 #[must_use]
983 pub const fn cluster_confidence(&self) -> f32 {
984 self.cluster_confidence
985 }
986
987 pub fn set_cluster_confidence(&mut self, confidence: f32) {
989 self.cluster_confidence = confidence.clamp(0.0, 1.0);
990 }
991
992 pub fn set_identity_id(&mut self, identity_id: IdentityId) {
994 self.identity_id = Some(identity_id);
995 }
996
997 pub fn clear_identity_id(&mut self) {
999 self.identity_id = None;
1000 }
1001
1002 #[must_use]
1004 pub fn with_identity(mut self, identity_id: IdentityId) -> Self {
1005 self.identity_id = Some(identity_id);
1006 self
1007 }
1008
1009 #[must_use]
1013 pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
1014 let s = entity_type.into();
1015 self.entity_type = Some(super::types::TypeLabel::from(s.as_str()));
1016 self
1017 }
1018
1019 #[must_use]
1035 pub fn with_type_label(mut self, label: super::types::TypeLabel) -> Self {
1036 self.entity_type = Some(label);
1037 self
1038 }
1039
1040 #[must_use]
1045 pub fn type_label(&self) -> Option<super::types::TypeLabel> {
1046 self.entity_type.clone()
1047 }
1048
1049 #[must_use]
1051 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
1052 self.embedding = Some(embedding);
1053 self
1054 }
1055
1056 pub fn compute_spread(&self, doc: &GroundedDocument) -> Option<usize> {
1060 if self.signals.is_empty() {
1061 return Some(0);
1062 }
1063
1064 let positions: Vec<usize> = self
1065 .signals
1066 .iter()
1067 .filter_map(|sr| {
1068 doc.signals
1069 .iter()
1070 .find(|s| s.id == sr.signal_id)
1071 .and_then(|s| s.location.text_offsets())
1072 .map(|(start, _)| start)
1073 })
1074 .collect();
1075
1076 if positions.is_empty() {
1077 return None;
1078 }
1079
1080 let min_pos = *positions.iter().min().expect("positions non-empty");
1081 let max_pos = *positions.iter().max().expect("positions non-empty");
1082 Some(max_pos.saturating_sub(min_pos))
1083 }
1084
1085 pub fn collect_variations(&self, doc: &GroundedDocument) -> Vec<String> {
1089 let mut variations: std::collections::HashSet<String> = std::collections::HashSet::new();
1090
1091 for sr in &self.signals {
1092 if let Some(signal) = doc.signals.iter().find(|s| s.id == sr.signal_id) {
1093 variations.insert(signal.surface.clone());
1094 }
1095 }
1096
1097 variations.into_iter().collect()
1098 }
1099
1100 pub fn confidence_stats(&self, doc: &GroundedDocument) -> Option<(f32, f32, f32)> {
1104 let confidences: Vec<f32> = self
1105 .signals
1106 .iter()
1107 .filter_map(|sr| {
1108 doc.signals
1109 .iter()
1110 .find(|s| s.id == sr.signal_id)
1111 .map(|s| s.confidence)
1112 })
1113 .collect();
1114
1115 if confidences.is_empty() {
1116 return None;
1117 }
1118
1119 let min = confidences.iter().cloned().fold(f32::INFINITY, f32::min);
1120 let max = confidences
1121 .iter()
1122 .cloned()
1123 .fold(f32::NEG_INFINITY, f32::max);
1124 let mean = confidences.iter().sum::<f32>() / confidences.len() as f32;
1125
1126 Some((min, max, mean))
1127 }
1128
1129 pub fn compute_stats(&self, doc: &GroundedDocument, text_len: usize) -> TrackStats {
1133 let chain_length = self.signals.len();
1134 let spread = self.compute_spread(doc).unwrap_or(0);
1135 let variations = self.collect_variations(doc);
1136 let (min_conf, max_conf, mean_conf) = self.confidence_stats(doc).unwrap_or((0.0, 0.0, 0.0));
1137
1138 let positions: Vec<usize> = self
1140 .signals
1141 .iter()
1142 .filter_map(|sr| {
1143 doc.signals
1144 .iter()
1145 .find(|s| s.id == sr.signal_id)
1146 .and_then(|s| s.location.text_offsets())
1147 .map(|(start, _)| start)
1148 })
1149 .collect();
1150
1151 let first_position = positions.iter().min().copied().unwrap_or(0);
1152 let last_position = positions.iter().max().copied().unwrap_or(0);
1153 let relative_spread = if text_len > 0 {
1154 spread as f64 / text_len as f64
1155 } else {
1156 0.0
1157 };
1158
1159 TrackStats {
1160 chain_length,
1161 variation_count: variations.len(),
1162 variations,
1163 spread,
1164 relative_spread,
1165 first_position,
1166 last_position,
1167 min_confidence: min_conf,
1168 max_confidence: max_conf,
1169 mean_confidence: mean_conf,
1170 has_embedding: self.embedding.is_some(),
1171 }
1172 }
1173}
1174
/// Summary statistics for a track, as produced by `Track::compute_stats`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrackStats {
    /// Number of signal references in the track.
    pub chain_length: usize,
    /// Number of distinct surface forms (`variations.len()`).
    pub variation_count: usize,
    /// Distinct surface forms of the track's signals.
    pub variations: Vec<String>,
    /// Difference between the largest and smallest start offsets.
    pub spread: usize,
    /// `spread` divided by the text length supplied to `compute_stats`
    /// (`0.0` when that length is zero).
    pub relative_spread: f64,
    /// Smallest start offset among the track's text signals (0 if none).
    pub first_position: usize,
    /// Largest start offset among the track's text signals (0 if none).
    pub last_position: usize,
    /// Lowest signal confidence (0.0 if no signal resolved).
    pub min_confidence: f32,
    /// Highest signal confidence (0.0 if no signal resolved).
    pub max_confidence: f32,
    /// Mean signal confidence (0.0 if no signal resolved).
    pub mean_confidence: f32,
    /// Whether the track carries an embedding.
    pub has_embedding: bool,
}
1201
1202pub use super::types::IdentityId;
1208
/// Records how an identity was established.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IdentitySource {
    /// Derived from a set of tracks (per the name, cross-document coreference).
    CrossDocCoref {
        /// Tracks that contributed to the identity.
        track_refs: Vec<TrackRef>,
    },
    /// Taken from a knowledge-base entry.
    KnowledgeBase {
        /// Name of the knowledge base.
        kb_name: String,
        /// Identifier of the entry within that knowledge base.
        kb_id: String,
    },
    /// Backed by both contributing tracks and a knowledge-base entry.
    Hybrid {
        /// Tracks that contributed to the identity.
        track_refs: Vec<TrackRef>,
        /// Name of the knowledge base.
        kb_name: String,
        /// Identifier of the entry within that knowledge base.
        kb_id: String,
    },
}
1240
/// An entity identity that tracks can be linked to.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Identity {
    /// Identifier of the identity; `GroundedDocument::add_identity` overwrites it.
    pub id: IdentityId,
    /// Canonical name.
    pub canonical_name: String,
    /// Optional type label.
    pub entity_type: Option<super::types::TypeLabel>,
    /// Optional knowledge-base entry identifier.
    pub kb_id: Option<String>,
    /// Optional knowledge-base name.
    pub kb_name: Option<String>,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Optional embedding vector.
    pub embedding: Option<Vec<f32>>,
    /// Alternative names, in insertion order.
    pub aliases: Vec<String>,
    /// Confidence in the identity; constructors start it at `1.0`.
    pub confidence: f32,
    /// How the identity was established. Omitted from serialized output when
    /// `None`, and defaults to `None` when absent on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<IdentitySource>,
}
1289
1290impl Identity {
1291 #[must_use]
1293 pub fn new(id: impl Into<IdentityId>, canonical_name: impl Into<String>) -> Self {
1294 Self {
1295 id: id.into(),
1296 canonical_name: canonical_name.into(),
1297 entity_type: None,
1298 kb_id: None,
1299 kb_name: None,
1300 description: None,
1301 embedding: None,
1302 aliases: Vec::new(),
1303 confidence: 1.0,
1304 source: None,
1305 }
1306 }
1307
1308 #[must_use]
1310 pub fn from_kb(
1311 id: impl Into<IdentityId>,
1312 canonical_name: impl Into<String>,
1313 kb_name: impl Into<String>,
1314 kb_id: impl Into<String>,
1315 ) -> Self {
1316 let kb_name_str = kb_name.into();
1317 let kb_id_str = kb_id.into();
1318 Self {
1319 id: id.into(),
1320 canonical_name: canonical_name.into(),
1321 entity_type: None,
1322 kb_id: Some(kb_id_str.clone()),
1323 kb_name: Some(kb_name_str.clone()),
1324 description: None,
1325 embedding: None,
1326 aliases: Vec::new(),
1327 confidence: 1.0,
1328 source: Some(IdentitySource::KnowledgeBase {
1329 kb_name: kb_name_str,
1330 kb_id: kb_id_str,
1331 }),
1332 }
1333 }
1334
1335 pub fn add_alias(&mut self, alias: impl Into<String>) {
1337 self.aliases.push(alias.into());
1338 }
1339
1340 #[must_use]
1342 pub const fn id(&self) -> IdentityId {
1343 self.id
1344 }
1345
1346 #[must_use]
1348 pub fn canonical_name(&self) -> &str {
1349 &self.canonical_name
1350 }
1351
1352 #[must_use]
1354 pub fn kb_id(&self) -> Option<&str> {
1355 self.kb_id.as_deref()
1356 }
1357
1358 #[must_use]
1360 pub fn kb_name(&self) -> Option<&str> {
1361 self.kb_name.as_deref()
1362 }
1363
1364 #[must_use]
1366 pub fn aliases(&self) -> &[String] {
1367 &self.aliases
1368 }
1369
1370 #[must_use]
1372 pub const fn confidence(&self) -> f32 {
1373 self.confidence
1374 }
1375
1376 pub fn set_confidence(&mut self, confidence: f32) {
1378 self.confidence = confidence.clamp(0.0, 1.0);
1379 }
1380
1381 #[must_use]
1383 pub fn source(&self) -> Option<&IdentitySource> {
1384 self.source.as_ref()
1385 }
1386
1387 #[must_use]
1389 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
1390 self.embedding = Some(embedding);
1391 self
1392 }
1393
1394 #[must_use]
1398 pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
1399 let s = entity_type.into();
1400 self.entity_type = Some(super::types::TypeLabel::from(s.as_str()));
1401 self
1402 }
1403
1404 #[must_use]
1409 pub fn with_type_label(mut self, label: super::types::TypeLabel) -> Self {
1410 self.entity_type = Some(label);
1411 self
1412 }
1413
1414 #[must_use]
1419 pub fn type_label(&self) -> Option<super::types::TypeLabel> {
1420 self.entity_type.clone()
1421 }
1422
1423 #[must_use]
1425 pub fn with_description(mut self, description: impl Into<String>) -> Self {
1426 self.description = Some(description.into());
1427 self
1428 }
1429
1430 }
1432
/// A document together with its signals, tracks and identities.
///
/// The private fields are lookup indexes and id counters maintained by the
/// `add_*` and `link_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroundedDocument {
    /// Document identifier.
    pub id: String,
    /// Full document text.
    pub text: String,
    /// Signals in insertion order.
    pub signals: Vec<Signal<Location>>,
    /// Tracks keyed by id.
    pub tracks: HashMap<TrackId, Track>,
    /// Identities keyed by id.
    pub identities: HashMap<IdentityId, Identity>,
    /// Index from signal to the track it was last added to.
    signal_to_track: HashMap<SignalId, TrackId>,
    /// Index from track to the identity it was linked to.
    track_to_identity: HashMap<TrackId, IdentityId>,
    /// Id that the next added signal will receive.
    next_signal_id: SignalId,
    /// Id that the next added track will receive.
    next_track_id: TrackId,
    /// Id that the next added identity will receive.
    next_identity_id: IdentityId,
}
1518
1519impl GroundedDocument {
1520 #[must_use]
1522 pub fn new(id: impl Into<String>, text: impl Into<String>) -> Self {
1523 Self {
1524 id: id.into(),
1525 text: text.into(),
1526 signals: Vec::new(),
1527 tracks: HashMap::new(),
1528 identities: HashMap::new(),
1529 signal_to_track: HashMap::new(),
1530 track_to_identity: HashMap::new(),
1531 next_signal_id: SignalId::ZERO,
1532 next_track_id: TrackId::ZERO,
1533 next_identity_id: IdentityId::ZERO,
1534 }
1535 }
1536
1537 pub fn add_signal(&mut self, mut signal: Signal<Location>) -> SignalId {
1543 let id = self.next_signal_id;
1544 signal.id = id;
1545 self.signals.push(signal);
1546 self.next_signal_id += 1;
1547 id
1548 }
1549
1550 #[must_use]
1552 pub fn get_signal(&self, id: impl Into<SignalId>) -> Option<&Signal<Location>> {
1553 let id = id.into();
1554 self.signals.iter().find(|s| s.id == id)
1555 }
1556
1557 pub fn signals(&self) -> &[Signal<Location>] {
1559 &self.signals
1560 }
1561
1562 pub fn add_track(&mut self, mut track: Track) -> TrackId {
1568 let id = self.next_track_id;
1569 track.id = id;
1570
1571 for signal_ref in &track.signals {
1573 self.signal_to_track.insert(signal_ref.signal_id, id);
1574 }
1575
1576 self.tracks.insert(id, track);
1577 self.next_track_id += 1;
1578 id
1579 }
1580
1581 #[must_use]
1583 pub fn get_track(&self, id: impl Into<TrackId>) -> Option<&Track> {
1584 self.tracks.get(&id.into())
1585 }
1586
1587 #[must_use]
1589 pub fn get_track_mut(&mut self, id: impl Into<TrackId>) -> Option<&mut Track> {
1590 self.tracks.get_mut(&id.into())
1591 }
1592
1593 pub fn add_signal_to_track(
1598 &mut self,
1599 signal_id: impl Into<SignalId>,
1600 track_id: impl Into<TrackId>,
1601 position: u32,
1602 ) -> bool {
1603 let signal_id = signal_id.into();
1604 let track_id = track_id.into();
1605 if let Some(track) = self.tracks.get_mut(&track_id) {
1606 track.add_signal(signal_id, position);
1607 self.signal_to_track.insert(signal_id, track_id);
1608 true
1609 } else {
1610 false
1611 }
1612 }
1613
1614 #[must_use]
1616 pub fn track_for_signal(&self, signal_id: SignalId) -> Option<&Track> {
1617 let track_id = self.signal_to_track.get(&signal_id)?;
1618 self.tracks.get(track_id)
1619 }
1620
1621 pub fn tracks(&self) -> impl Iterator<Item = &Track> {
1623 self.tracks.values()
1624 }
1625
1626 pub fn add_identity(&mut self, mut identity: Identity) -> IdentityId {
1632 let id = self.next_identity_id;
1633 identity.id = id;
1634 self.identities.insert(id, identity);
1635 self.next_identity_id += 1;
1636 id
1637 }
1638
1639 pub fn link_track_to_identity(
1641 &mut self,
1642 track_id: impl Into<TrackId>,
1643 identity_id: impl Into<IdentityId>,
1644 ) {
1645 let track_id = track_id.into();
1646 let identity_id = identity_id.into();
1647 if let Some(track) = self.tracks.get_mut(&track_id) {
1648 track.identity_id = Some(identity_id);
1649 self.track_to_identity.insert(track_id, identity_id);
1650 }
1651 }
1652
1653 #[must_use]
1655 pub fn get_identity(&self, id: IdentityId) -> Option<&Identity> {
1656 self.identities.get(&id)
1657 }
1658
1659 #[must_use]
1661 pub fn identity_for_track(&self, track_id: TrackId) -> Option<&Identity> {
1662 let identity_id = self.track_to_identity.get(&track_id)?;
1663 self.identities.get(identity_id)
1664 }
1665
1666 #[must_use]
1668 pub fn identity_for_signal(&self, signal_id: SignalId) -> Option<&Identity> {
1669 let track_id = self.signal_to_track.get(&signal_id)?;
1670 self.identity_for_track(*track_id)
1671 }
1672
1673 pub fn identities(&self) -> impl Iterator<Item = &Identity> {
1675 self.identities.values()
1676 }
1677
1678 #[must_use]
1683 pub fn track_ref(&self, track_id: TrackId) -> Option<TrackRef> {
1684 if self.tracks.contains_key(&track_id) {
1686 Some(TrackRef {
1687 doc_id: self.id.clone(),
1688 track_id,
1689 })
1690 } else {
1691 None
1692 }
1693 }
1694
1695 #[must_use]
1701 pub fn to_entities(&self) -> Vec<Entity> {
1702 self.signals
1703 .iter()
1704 .map(|signal| {
1705 let (start, end) = signal.location.text_offsets().unwrap_or((0, 0));
1706 let track = self.track_for_signal(signal.id);
1707 let identity = track.and_then(|t| self.identity_for_track(t.id));
1708
1709 Entity {
1710 text: signal.surface.clone(),
1711 entity_type: EntityType::from_label(signal.label.as_str()),
1712 start,
1713 end,
1714 confidence: signal.confidence as f64,
1715 normalized: signal.normalized.clone(),
1716 provenance: signal.provenance.clone(),
1717 kb_id: identity.and_then(|i| i.kb_id.clone()),
1718 canonical_id: track.map(|t| super::types::CanonicalId::new(t.id.get())),
1719 hierarchical_confidence: signal.hierarchical,
1720 visual_span: match &signal.location {
1721 Location::BoundingBox {
1722 x,
1723 y,
1724 width,
1725 height,
1726 page,
1727 } => Some(Span::BoundingBox {
1728 x: *x,
1729 y: *y,
1730 width: *width,
1731 height: *height,
1732 page: *page,
1733 }),
1734 Location::TextWithBbox { bbox, .. } => {
1735 if let Location::BoundingBox {
1736 x,
1737 y,
1738 width,
1739 height,
1740 page,
1741 } = bbox.as_ref()
1742 {
1743 Some(Span::BoundingBox {
1744 x: *x,
1745 y: *y,
1746 width: *width,
1747 height: *height,
1748 page: *page,
1749 })
1750 } else {
1751 None
1752 }
1753 }
1754 _ => None,
1755 },
1756 discontinuous_span: match &signal.location {
1757 Location::Discontinuous { segments } => Some(DiscontinuousSpan::new(
1758 segments.iter().map(|(s, e)| (*s)..(*e)).collect(),
1759 )),
1760 _ => None,
1761 },
1762 valid_from: None,
1763 valid_until: None,
1764 viewport: None,
1765 }
1766 })
1767 .collect()
1768 }
1769
1770 #[must_use]
1772 pub fn from_entities(
1773 id: impl Into<String>,
1774 text: impl Into<String>,
1775 entities: &[Entity],
1776 ) -> Self {
1777 let mut doc = Self::new(id, text);
1778
1779 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1785 enum TrackKey {
1786 Canonical(super::types::CanonicalId),
1787 Singleton(usize),
1788 }
1789
1790 let mut tracks_map: HashMap<TrackKey, Vec<SignalId>> = HashMap::new();
1791 let mut signal_to_entity_idx: HashMap<SignalId, usize> = HashMap::new();
1792
1793 for (idx, entity) in entities.iter().enumerate() {
1794 let location = if let Some(disc) = &entity.discontinuous_span {
1795 Location::Discontinuous {
1796 segments: disc.segments().iter().map(|r| (r.start, r.end)).collect(),
1797 }
1798 } else if let Some(visual) = &entity.visual_span {
1799 Location::from(visual)
1800 } else {
1801 Location::text(entity.start, entity.end)
1802 };
1803
1804 let mut signal = Signal::new(
1805 SignalId::new(idx as u64),
1806 location,
1807 &entity.text,
1808 entity.entity_type.as_label(),
1809 entity.confidence as f32,
1810 );
1811 signal.normalized = entity.normalized.clone();
1812 signal.provenance = entity.provenance.clone();
1813 signal.hierarchical = entity.hierarchical_confidence;
1814
1815 let signal_id = doc.add_signal(signal);
1816 signal_to_entity_idx.insert(signal_id, idx);
1817
1818 let key = match entity.canonical_id {
1819 Some(cid) => TrackKey::Canonical(cid),
1820 None => TrackKey::Singleton(idx),
1821 };
1822 tracks_map.entry(key).or_default().push(signal_id);
1823 }
1824
1825 for (_key, signal_ids) in tracks_map {
1827 if let Some(first_signal) = signal_ids.first().and_then(|id| doc.get_signal(*id)) {
1828 let mut track = Track::new(doc.next_track_id, &first_signal.surface);
1829 track.entity_type =
1830 Some(super::types::TypeLabel::from(first_signal.label.as_str()));
1831
1832 for (pos, &signal_id) in signal_ids.iter().enumerate() {
1833 track.add_signal(signal_id, pos as u32);
1834 }
1835
1836 let kb_id = signal_ids.iter().find_map(|sid| {
1839 let ent_idx = signal_to_entity_idx.get(sid).copied()?;
1840 entities.get(ent_idx)?.kb_id.clone()
1841 });
1842 if let Some(kb_id) = kb_id {
1843 let identity = Identity::from_kb(
1844 doc.next_identity_id,
1845 &track.canonical_surface,
1846 "unknown",
1847 kb_id,
1848 );
1849 let identity_id = doc.add_identity(identity);
1850 track = track.with_identity(identity_id);
1851 }
1852
1853 doc.add_track(track);
1854 }
1855 }
1856
1857 doc
1858 }
1859
1860 #[must_use]
1862 pub fn signals_with_label(&self, label: &str) -> Vec<&Signal<Location>> {
1863 let want = super::types::TypeLabel::from(label);
1864 self.signals.iter().filter(|s| s.label == want).collect()
1865 }
1866
1867 #[must_use]
1869 pub fn confident_signals(&self, threshold: f32) -> Vec<&Signal<Location>> {
1870 self.signals
1871 .iter()
1872 .filter(|s| s.confidence >= threshold)
1873 .collect()
1874 }
1875
1876 pub fn linked_tracks(&self) -> impl Iterator<Item = &Track> {
1878 self.tracks.values().filter(|t| t.identity_id.is_some())
1879 }
1880
1881 pub fn unlinked_tracks(&self) -> impl Iterator<Item = &Track> {
1883 self.tracks.values().filter(|t| t.identity_id.is_none())
1884 }
1885
1886 #[must_use]
1888 pub fn untracked_signal_count(&self) -> usize {
1889 self.signals
1890 .iter()
1891 .filter(|s| !self.signal_to_track.contains_key(&s.id))
1892 .count()
1893 }
1894
1895 #[must_use]
1897 pub fn untracked_signals(&self) -> Vec<&Signal<Location>> {
1898 self.signals
1899 .iter()
1900 .filter(|s| !self.signal_to_track.contains_key(&s.id))
1901 .collect()
1902 }
1903
1904 #[must_use]
1910 pub fn signals_by_modality(&self, modality: Modality) -> Vec<&Signal<Location>> {
1911 self.signals
1912 .iter()
1913 .filter(|s| s.modality == modality)
1914 .collect()
1915 }
1916
1917 #[must_use]
1919 pub fn text_signals(&self) -> Vec<&Signal<Location>> {
1920 self.signals_by_modality(Modality::Symbolic)
1921 }
1922
1923 #[must_use]
1925 pub fn visual_signals(&self) -> Vec<&Signal<Location>> {
1926 self.signals_by_modality(Modality::Iconic)
1927 }
1928
1929 #[must_use]
1931 pub fn overlapping_signals(&self, location: &Location) -> Vec<&Signal<Location>> {
1932 self.signals
1933 .iter()
1934 .filter(|s| s.location.overlaps(location))
1935 .collect()
1936 }
1937
1938 #[must_use]
1940 pub fn signals_in_range(&self, start: usize, end: usize) -> Vec<&Signal<Location>> {
1941 self.signals
1942 .iter()
1943 .filter(|s| {
1944 if let Some((s_start, s_end)) = s.location.text_offsets() {
1945 s_start >= start && s_end <= end
1946 } else {
1947 false
1948 }
1949 })
1950 .collect()
1951 }
1952
1953 #[must_use]
1955 pub fn negated_signals(&self) -> Vec<&Signal<Location>> {
1956 self.signals.iter().filter(|s| s.negated).collect()
1957 }
1958
1959 #[must_use]
1961 pub fn quantified_signals(&self, quantifier: Quantifier) -> Vec<&Signal<Location>> {
1962 self.signals
1963 .iter()
1964 .filter(|s| s.quantifier == Some(quantifier))
1965 .collect()
1966 }
1967
1968 #[must_use]
1990 pub fn validate(&self) -> Vec<SignalValidationError> {
1991 self.signals
1992 .iter()
1993 .filter_map(|s| s.validate_against(&self.text))
1994 .collect()
1995 }
1996
1997 #[must_use]
2021 pub fn validate_invariants(&self) -> Vec<String> {
2022 let mut errors = Vec::new();
2023
2024 let mut seen_ids = std::collections::HashSet::new();
2026 for signal in &self.signals {
2027 if !seen_ids.insert(signal.id) {
2028 errors.push(format!("Duplicate signal ID: {}", signal.id));
2029 }
2030 }
2031
2032 let signal_ids: std::collections::HashSet<_> = self.signals.iter().map(|s| s.id).collect();
2034
2035 for (track_id, track) in &self.tracks {
2037 for signal_ref in &track.signals {
2038 if !signal_ids.contains(&signal_ref.signal_id) {
2039 errors.push(format!(
2040 "Track {} references non-existent signal {}",
2041 track_id, signal_ref.signal_id
2042 ));
2043 }
2044 }
2045 }
2046
2047 for (signal_id, track_id) in &self.signal_to_track {
2049 if let Some(track) = self.tracks.get(track_id) {
2051 if !track.signals.iter().any(|r| r.signal_id == *signal_id) {
2053 errors.push(format!(
2054 "signal_to_track[{}] = {} but track doesn't contain signal",
2055 signal_id, track_id
2056 ));
2057 }
2058 } else {
2059 errors.push(format!(
2060 "signal_to_track[{}] = {} but track doesn't exist",
2061 signal_id, track_id
2062 ));
2063 }
2064 }
2065
2066 for (track_id, identity_id) in &self.track_to_identity {
2068 if let Some(track) = self.tracks.get(track_id) {
2070 if track.identity_id != Some(*identity_id) {
2071 errors.push(format!(
2072 "track_to_identity[{}] = {} but track.identity_id = {:?}",
2073 track_id, identity_id, track.identity_id
2074 ));
2075 }
2076 } else {
2077 errors.push(format!(
2078 "track_to_identity[{}] = {} but track doesn't exist",
2079 track_id, identity_id
2080 ));
2081 }
2082
2083 if !self.identities.contains_key(identity_id) {
2085 errors.push(format!(
2086 "track_to_identity[{}] = {} but identity doesn't exist",
2087 track_id, identity_id
2088 ));
2089 }
2090 }
2091
2092 for (track_id, track) in &self.tracks {
2094 if let Some(identity_id) = track.identity_id {
2095 if !self.identities.contains_key(&identity_id) {
2096 errors.push(format!(
2097 "Track {} references non-existent identity {}",
2098 track_id, identity_id
2099 ));
2100 }
2101 }
2102 }
2103
2104 errors
2105 }
2106
2107 #[must_use]
2109 pub fn invariants_hold(&self) -> bool {
2110 self.validate_invariants().is_empty()
2111 }
2112
2113 #[must_use]
2115 pub fn is_valid(&self) -> bool {
2116 self.signals.iter().all(|s| s.is_valid(&self.text))
2117 }
2118
2119 pub fn add_signal_validated(
2123 &mut self,
2124 signal: Signal<Location>,
2125 ) -> Result<SignalId, SignalValidationError> {
2126 if let Some(err) = signal.validate_against(&self.text) {
2127 return Err(err);
2128 }
2129 Ok(self.add_signal(signal))
2130 }
2131
2132 pub fn add_signal_from_text(
2146 &mut self,
2147 surface: &str,
2148 label: impl Into<super::types::TypeLabel>,
2149 confidence: f32,
2150 ) -> Option<SignalId> {
2151 let signal = Signal::from_text(&self.text, surface, label, confidence)?;
2152 Some(self.add_signal(signal))
2153 }
2154
2155 pub fn add_signal_from_text_nth(
2157 &mut self,
2158 surface: &str,
2159 label: impl Into<super::types::TypeLabel>,
2160 confidence: f32,
2161 occurrence: usize,
2162 ) -> Option<SignalId> {
2163 let signal = Signal::from_text_nth(&self.text, surface, label, confidence, occurrence)?;
2164 Some(self.add_signal(signal))
2165 }
2166
2167 #[must_use]
2173 pub fn stats(&self) -> DocumentStats {
2174 let signal_count = self.signals.len();
2175 let track_count = self.tracks.len();
2176 let identity_count = self.identities.len();
2177
2178 let linked_track_count = self
2179 .tracks
2180 .values()
2181 .filter(|t| t.identity_id.is_some())
2182 .count();
2183 let untracked_count = self.untracked_signal_count();
2184
2185 let avg_track_size = if track_count > 0 {
2186 self.tracks.values().map(|t| t.len()).sum::<usize>() as f32 / track_count as f32
2187 } else {
2188 0.0
2189 };
2190
2191 let singleton_count = self.tracks.values().filter(|t| t.is_singleton()).count();
2192
2193 let avg_confidence = if signal_count > 0 {
2194 self.signals.iter().map(|s| s.confidence).sum::<f32>() / signal_count as f32
2195 } else {
2196 0.0
2197 };
2198
2199 let negated_count = self.signals.iter().filter(|s| s.negated).count();
2200
2201 let symbolic_count = self
2203 .signals
2204 .iter()
2205 .filter(|s| s.modality == Modality::Symbolic)
2206 .count();
2207 let iconic_count = self
2208 .signals
2209 .iter()
2210 .filter(|s| s.modality == Modality::Iconic)
2211 .count();
2212 let hybrid_count = self
2213 .signals
2214 .iter()
2215 .filter(|s| s.modality == Modality::Hybrid)
2216 .count();
2217
2218 DocumentStats {
2219 signal_count,
2220 track_count,
2221 identity_count,
2222 linked_track_count,
2223 untracked_count,
2224 avg_track_size,
2225 singleton_count,
2226 avg_confidence,
2227 negated_count,
2228 symbolic_count,
2229 iconic_count,
2230 hybrid_count,
2231 }
2232 }
2233
2234 pub fn add_signals(
2242 &mut self,
2243 signals: impl IntoIterator<Item = Signal<Location>>,
2244 ) -> Vec<SignalId> {
2245 signals.into_iter().map(|s| self.add_signal(s)).collect()
2246 }
2247
2248 pub fn create_track_from_signals(
2252 &mut self,
2253 canonical: impl Into<String>,
2254 signal_ids: &[SignalId],
2255 ) -> Option<TrackId> {
2256 if signal_ids.is_empty() {
2257 return None;
2258 }
2259
2260 let mut track = Track::new(TrackId::ZERO, canonical);
2261 for (pos, &id) in signal_ids.iter().enumerate() {
2262 track.add_signal(id, pos as u32);
2263 }
2264 Some(self.add_track(track))
2265 }
2266
2267 pub fn merge_tracks(&mut self, track_ids: &[TrackId]) -> Option<TrackId> {
2272 if track_ids.is_empty() {
2273 return None;
2274 }
2275
2276 let mut all_signals: Vec<SignalRef> = Vec::new();
2278 let mut canonical = String::new();
2279 let mut entity_type = None;
2280
2281 for &track_id in track_ids {
2282 if let Some(track) = self.tracks.get(&track_id) {
2283 if canonical.is_empty() {
2284 canonical = track.canonical_surface.clone();
2285 entity_type = track.entity_type.clone();
2286 }
2287 all_signals.extend(track.signals.iter().cloned());
2288 }
2289 }
2290
2291 if all_signals.is_empty() {
2292 return None;
2293 }
2294
2295 all_signals.sort_by_key(|s| s.position);
2297
2298 for &track_id in track_ids {
2300 self.tracks.remove(&track_id);
2301 }
2302
2303 let mut new_track = Track::new(TrackId::ZERO, canonical);
2305 new_track.entity_type = entity_type;
2306 for (pos, signal_ref) in all_signals.iter().enumerate() {
2307 new_track.add_signal(signal_ref.signal_id, pos as u32);
2308 }
2309
2310 Some(self.add_track(new_track))
2311 }
2312
2313 #[must_use]
2315 pub fn find_overlapping_signal_pairs(&self) -> Vec<(SignalId, SignalId)> {
2316 let mut pairs = Vec::new();
2317 let signals: Vec<_> = self.signals.iter().collect();
2318
2319 for i in 0..signals.len() {
2320 for j in (i + 1)..signals.len() {
2321 if signals[i].location.overlaps(&signals[j].location) {
2322 pairs.push((signals[i].id, signals[j].id));
2323 }
2324 }
2325 }
2326
2327 pairs
2328 }
2329}
2330
/// Summary statistics of a grounded document.
#[derive(Debug, Clone, Copy, Default)]
pub struct DocumentStats {
    /// Number of signals.
    pub signal_count: usize,
    /// Number of tracks.
    pub track_count: usize,
    /// Number of identities.
    pub identity_count: usize,
    /// Number of tracks that carry an identity link.
    pub linked_track_count: usize,
    /// Number of signals that belong to no track.
    pub untracked_count: usize,
    /// Mean track size (`0.0` when there are no tracks).
    pub avg_track_size: f32,
    /// Number of singleton tracks.
    pub singleton_count: usize,
    /// Mean signal confidence (`0.0` when there are no signals).
    pub avg_confidence: f32,
    /// Number of negated signals.
    pub negated_count: usize,
    /// Number of signals with symbolic modality.
    pub symbolic_count: usize,
    /// Number of signals with iconic modality.
    pub iconic_count: usize,
    /// Number of signals with hybrid modality.
    pub hybrid_count: usize,
}

impl std::fmt::Display for DocumentStats {
    /// Writes a multi-line, human-readable report, one newline-terminated line
    /// per section. The "Negated" line is emitted only when at least one signal
    /// is negated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring keeps the format strings short and makes the compiler
        // flag any field that is added later but not considered here.
        let Self {
            signal_count,
            track_count,
            identity_count,
            linked_track_count,
            untracked_count,
            avg_track_size,
            singleton_count,
            avg_confidence,
            negated_count,
            symbolic_count,
            iconic_count,
            hybrid_count,
        } = *self;

        writeln!(f, "Document Statistics:")?;
        writeln!(
            f,
            " Signals: {signal_count} (avg confidence: {avg_confidence:.2})"
        )?;
        writeln!(
            f,
            " Tracks: {track_count} (avg size: {avg_track_size:.1}, singletons: {singleton_count})"
        )?;
        writeln!(
            f,
            " Identities: {identity_count} ({linked_track_count} tracks linked)"
        )?;
        writeln!(f, " Untracked signals: {untracked_count}")?;
        writeln!(
            f,
            " Modalities: {symbolic_count} symbolic, {iconic_count} iconic, {hybrid_count} hybrid"
        )?;
        if negated_count > 0 {
            writeln!(f, " Negated: {negated_count}")?;
        }
        Ok(())
    }
}
2390
2391#[derive(Debug, Clone)]
2401struct IntervalNode {
2402 signal_id: SignalId,
2404 start: usize,
2406 end: usize,
2408 max_end: usize,
2410 left: Option<Box<IntervalNode>>,
2412 right: Option<Box<IntervalNode>>,
2414}
2415
2416impl IntervalNode {
2417 fn new(signal_id: SignalId, start: usize, end: usize) -> Self {
2418 Self {
2419 signal_id,
2420 start,
2421 end,
2422 max_end: end,
2423 left: None,
2424 right: None,
2425 }
2426 }
2427
2428 fn insert(&mut self, signal_id: SignalId, start: usize, end: usize) {
2429 self.max_end = self.max_end.max(end);
2430
2431 if start < self.start {
2432 if let Some(ref mut left) = self.left {
2433 left.insert(signal_id, start, end);
2434 } else {
2435 self.left = Some(Box::new(IntervalNode::new(signal_id, start, end)));
2436 }
2437 } else if let Some(ref mut right) = self.right {
2438 right.insert(signal_id, start, end);
2439 } else {
2440 self.right = Some(Box::new(IntervalNode::new(signal_id, start, end)));
2441 }
2442 }
2443
2444 fn query_overlap(&self, query_start: usize, query_end: usize, results: &mut Vec<SignalId>) {
2445 if self.start < query_end && query_start < self.end {
2447 results.push(self.signal_id);
2448 }
2449
2450 if let Some(ref left) = self.left {
2452 if left.max_end > query_start {
2453 left.query_overlap(query_start, query_end, results);
2454 }
2455 }
2456
2457 if let Some(ref right) = self.right {
2459 if self.start < query_end {
2460 right.query_overlap(query_start, query_end, results);
2461 }
2462 }
2463 }
2464
2465 fn query_containing(&self, query_start: usize, query_end: usize, results: &mut Vec<SignalId>) {
2466 if self.start <= query_start && self.end >= query_end {
2468 results.push(self.signal_id);
2469 }
2470
2471 if let Some(ref left) = self.left {
2473 if left.max_end >= query_end {
2474 left.query_containing(query_start, query_end, results);
2475 }
2476 }
2477
2478 if let Some(ref right) = self.right {
2480 if self.start <= query_start {
2481 right.query_containing(query_start, query_end, results);
2482 }
2483 }
2484 }
2485
2486 fn query_contained_in(
2487 &self,
2488 range_start: usize,
2489 range_end: usize,
2490 results: &mut Vec<SignalId>,
2491 ) {
2492 if self.start >= range_start && self.end <= range_end {
2494 results.push(self.signal_id);
2495 }
2496
2497 if let Some(ref left) = self.left {
2499 left.query_contained_in(range_start, range_end, results);
2500 }
2501
2502 if let Some(ref right) = self.right {
2504 if self.start < range_end {
2505 right.query_contained_in(range_start, range_end, results);
2506 }
2507 }
2508 }
2509}
2510
2511#[derive(Debug, Clone, Default)]
2527pub struct TextSpatialIndex {
2528 root: Option<IntervalNode>,
2529 size: usize,
2530}
2531
2532impl TextSpatialIndex {
2533 #[must_use]
2535 pub fn new() -> Self {
2536 Self::default()
2537 }
2538
2539 #[must_use]
2541 pub fn from_signals(signals: &[Signal<Location>]) -> Self {
2542 let mut index = Self::new();
2543 for signal in signals {
2544 if let Some((start, end)) = signal.location.text_offsets() {
2545 index.insert(signal.id, start, end);
2546 }
2547 }
2548 index
2549 }
2550
2551 pub fn insert(&mut self, signal_id: SignalId, start: usize, end: usize) {
2553 if let Some(ref mut root) = self.root {
2554 root.insert(signal_id, start, end);
2555 } else {
2556 self.root = Some(IntervalNode::new(signal_id, start, end));
2557 }
2558 self.size += 1;
2559 }
2560
2561 #[must_use]
2563 pub fn query_overlap(&self, start: usize, end: usize) -> Vec<SignalId> {
2564 let mut results = Vec::new();
2565 if let Some(ref root) = self.root {
2566 root.query_overlap(start, end, &mut results);
2567 }
2568 results
2569 }
2570
2571 #[must_use]
2573 pub fn query_containing(&self, start: usize, end: usize) -> Vec<SignalId> {
2574 let mut results = Vec::new();
2575 if let Some(ref root) = self.root {
2576 root.query_containing(start, end, &mut results);
2577 }
2578 results
2579 }
2580
2581 #[must_use]
2583 pub fn query_contained_in(&self, start: usize, end: usize) -> Vec<SignalId> {
2584 let mut results = Vec::new();
2585 if let Some(ref root) = self.root {
2586 root.query_contained_in(start, end, &mut results);
2587 }
2588 results
2589 }
2590
2591 #[must_use]
2593 pub fn len(&self) -> usize {
2594 self.size
2595 }
2596
2597 #[must_use]
2599 pub fn is_empty(&self) -> bool {
2600 self.size == 0
2601 }
2602}
2603
2604impl GroundedDocument {
2605 #[must_use]
2624 pub fn build_text_index(&self) -> TextSpatialIndex {
2625 TextSpatialIndex::from_signals(&self.signals)
2626 }
2627
2628 #[must_use]
2633 pub fn query_signals_in_range_indexed(
2634 &self,
2635 start: usize,
2636 end: usize,
2637 ) -> Vec<&Signal<Location>> {
2638 let index = self.build_text_index();
2639 let ids = index.query_contained_in(start, end);
2640 ids.iter().filter_map(|&id| self.get_signal(id)).collect()
2641 }
2642
2643 #[must_use]
2645 pub fn query_overlapping_signals_indexed(
2646 &self,
2647 start: usize,
2648 end: usize,
2649 ) -> Vec<&Signal<Location>> {
2650 let index = self.build_text_index();
2651 let ids = index.query_overlap(start, end);
2652 ids.iter().filter_map(|&id| self.get_signal(id)).collect()
2653 }
2654
2655 #[must_use]
2668 pub fn to_coref_document(&self) -> super::coref::CorefDocument {
2669 use super::coref::{CorefChain, CorefDocument, Mention};
2670 use std::collections::HashMap;
2671
2672 let signal_by_id: HashMap<SignalId, &Signal<Location>> =
2674 self.signals.iter().map(|s| (s.id, s)).collect();
2675
2676 let mut chains: Vec<CorefChain> = Vec::new();
2677
2678 for track in self.tracks.values() {
2679 let mut mentions: Vec<Mention> = Vec::new();
2680
2681 for sref in &track.signals {
2682 let Some(signal) = signal_by_id.get(&sref.signal_id) else {
2683 continue;
2684 };
2685
2686 let Some((start, end)) = signal.location.text_offsets() else {
2687 continue;
2688 };
2689
2690 let mut m = Mention::new(signal.surface.clone(), start, end);
2691 m.entity_type = Some(signal.label.to_string());
2692 mentions.push(m);
2693 }
2694
2695 if mentions.is_empty() {
2696 continue;
2697 }
2698
2699 let mut chain = CorefChain::new(mentions);
2700 chain.entity_type = track.entity_type.as_ref().map(|t| t.to_string());
2701 chains.push(chain);
2702 }
2703
2704 chains.sort_by_key(|c| c.mentions.first().map(|m| m.start).unwrap_or(usize::MAX));
2706
2707 CorefDocument::with_id(&self.text, &self.id, chains)
2708 }
2709}
2710
2711pub fn render_document_html(doc: &GroundedDocument) -> String {
2719 let mut html = String::new();
2720 let stats = doc.stats();
2721
2722 html.push_str(r#"<!DOCTYPE html>
2723<html>
2724<head>
2725<meta charset="UTF-8">
2726<meta name="color-scheme" content="dark light">
2727<title>grounded::GroundedDocument</title>
2728<style>
2729:root{
2730 /* Allow UA widgets (inputs/scrollbars) to match the theme */
2731 color-scheme: light dark;
2732 /* Dark (default) */
2733 --bg:#0a0a0a;
2734 --panel-bg:#0d0d0d;
2735 --text:#b0b0b0;
2736 --text-strong:#fff;
2737 --muted:#666;
2738 --border:#222;
2739 --border-strong:#333;
2740 --hover:#111;
2741 --input-bg:#080808;
2742 --active:#fff;
2743 --track-strong:rgba(255,255,255,0.35);
2744 --track-soft:rgba(255,255,255,0.18);
2745 /* Entity colors (dark) */
2746 --per-bg:#1a1a2e; --per-br:#4a4a8a; --per-tx:#8888cc;
2747 --org-bg:#1a2e1a; --org-br:#4a8a4a; --org-tx:#88cc88;
2748 --loc-bg:#2e2e1a; --loc-br:#8a8a4a; --loc-tx:#cccc88;
2749 --mis-bg:#1a1a1a; --mis-br:#4a4a4a; --mis-tx:#999;
2750 --dat-bg:#2e1a1a; --dat-br:#8a4a4a; --dat-tx:#cc8888;
2751 --badge-y-bg:#1a2e1a; --badge-y-tx:#4a8a4a; --badge-y-br:#2a4a2a;
2752 --badge-n-bg:#2e2e1a; --badge-n-tx:#8a8a4a; --badge-n-br:#4a4a2a;
2753}
2754@media (prefers-color-scheme: light){
2755 :root{
2756 --bg:#ffffff;
2757 --panel-bg:#f7f7f7;
2758 --text:#222;
2759 --text-strong:#000;
2760 --muted:#555;
2761 --border:#d6d6d6;
2762 --border-strong:#c6c6c6;
2763 --hover:#f0f0f0;
2764 --input-bg:#ffffff;
2765 --active:#000;
2766 --track-strong:rgba(0,0,0,0.25);
2767 --track-soft:rgba(0,0,0,0.12);
2768 /* Entity colors (light) */
2769 --per-bg:#e9e9ff; --per-br:#6c6cff; --per-tx:#2b2b7a;
2770 --org-bg:#e9f7e9; --org-br:#2f8a2f; --org-tx:#1f5a1f;
2771 --loc-bg:#fff7db; --loc-br:#8a7a2f; --loc-tx:#5a4d12;
2772 --mis-bg:#f2f2f2; --mis-br:#8a8a8a; --mis-tx:#333;
2773 --dat-bg:#ffe9e9; --dat-br:#8a2f2f; --dat-tx:#5a1f1f;
2774 --badge-y-bg:#e9f7e9; --badge-y-tx:#1f5a1f; --badge-y-br:#9ad19a;
2775 --badge-n-bg:#fff7db; --badge-n-tx:#5a4d12; --badge-n-br:#e2d39a;
2776 }
2777}
2778html[data-theme='dark']{
2779 --bg:#0a0a0a; --panel-bg:#0d0d0d; --text:#b0b0b0; --text-strong:#fff;
2780 --muted:#666; --border:#222; --border-strong:#333; --hover:#111;
2781 --input-bg:#080808; --active:#fff;
2782 --track-strong:rgba(255,255,255,0.35); --track-soft:rgba(255,255,255,0.18);
2783 --per-bg:#1a1a2e; --per-br:#4a4a8a; --per-tx:#8888cc;
2784 --org-bg:#1a2e1a; --org-br:#4a8a4a; --org-tx:#88cc88;
2785 --loc-bg:#2e2e1a; --loc-br:#8a8a4a; --loc-tx:#cccc88;
2786 --mis-bg:#1a1a1a; --mis-br:#4a4a4a; --mis-tx:#999;
2787 --dat-bg:#2e1a1a; --dat-br:#8a4a4a; --dat-tx:#cc8888;
2788 --badge-y-bg:#1a2e1a; --badge-y-tx:#4a8a4a; --badge-y-br:#2a4a2a;
2789 --badge-n-bg:#2e2e1a; --badge-n-tx:#8a8a4a; --badge-n-br:#4a4a2a;
2790}
2791html[data-theme='light']{
2792 --bg:#ffffff; --panel-bg:#f7f7f7; --text:#222; --text-strong:#000;
2793 --muted:#555; --border:#d6d6d6; --border-strong:#c6c6c6; --hover:#f0f0f0;
2794 --input-bg:#ffffff; --active:#000;
2795 --track-strong:rgba(0,0,0,0.25); --track-soft:rgba(0,0,0,0.12);
2796 --per-bg:#e9e9ff; --per-br:#6c6cff; --per-tx:#2b2b7a;
2797 --org-bg:#e9f7e9; --org-br:#2f8a2f; --org-tx:#1f5a1f;
2798 --loc-bg:#fff7db; --loc-br:#8a7a2f; --loc-tx:#5a4d12;
2799 --mis-bg:#f2f2f2; --mis-br:#8a8a8a; --mis-tx:#333;
2800 --dat-bg:#ffe9e9; --dat-br:#8a2f2f; --dat-tx:#5a1f1f;
2801 --badge-y-bg:#e9f7e9; --badge-y-tx:#1f5a1f; --badge-y-br:#9ad19a;
2802 --badge-n-bg:#fff7db; --badge-n-tx:#5a4d12; --badge-n-br:#e2d39a;
2803}
2804
2805*{box-sizing:border-box;margin:0;padding:0}
2806body{font:12px/1.4 monospace;background:var(--bg);color:var(--text);padding:8px}
2807h1,h2,h3{color:var(--text-strong);font-weight:normal;border-bottom:1px solid var(--border-strong);padding:4px 0;margin:16px 0 8px}
2808h1{font-size:14px}h2{font-size:12px}h3{font-size:11px;color:var(--muted)}
2809 a{color:inherit}
2810 a:hover{text-decoration:underline}
2811table{width:100%;border-collapse:collapse;font-size:11px;margin:4px 0}
2812th,td{padding:4px 8px;text-align:left;border:1px solid var(--border)}
2813th{background:var(--hover);color:var(--muted);font-weight:normal;text-transform:uppercase;font-size:10px}
2814tr:hover{background:var(--hover)}
2815.grid{display:grid;grid-template-columns:repeat(auto-fit,minmax(300px,1fr));gap:8px}
2816.panel{border:1px solid var(--border);background:var(--panel-bg);padding:8px}
2817.panel-h{display:flex;align-items:center;gap:8px}
2818.toggle{cursor:pointer;user-select:none;color:var(--muted);border:1px solid var(--border);background:var(--bg);padding:2px 6px;font-size:10px}
2819.panel-collapsed table,.panel-collapsed .panel-body{display:none}
2820.toolbar{display:flex;gap:8px;align-items:center;margin:8px 0 0}
2821.toolbar input{width:100%;max-width:520px;background:var(--input-bg);border:1px solid var(--border);color:var(--text);padding:6px 8px;font:12px monospace}
2822.muted{color:var(--muted)}
2823.panel-body{white-space:pre-wrap;word-break:break-word}
2824.text-box{background:var(--input-bg);border:1px solid var(--border);padding:8px;white-space:pre-wrap;word-break:break-word;line-height:1.6}
2825.e{padding:1px 2px;border-bottom:1px solid}
2826.seg{cursor:pointer}
2827.e-per{background:var(--per-bg);border-color:var(--per-br);color:var(--per-tx)}
2828.e-org{background:var(--org-bg);border-color:var(--org-br);color:var(--org-tx)}
2829.e-loc{background:var(--loc-bg);border-color:var(--loc-br);color:var(--loc-tx)}
2830.e-misc{background:var(--mis-bg);border-color:var(--mis-br);color:var(--mis-tx)}
2831.e-date{background:var(--dat-bg);border-color:var(--dat-br);color:var(--dat-tx)}
2832.e-track{box-shadow:inset 0 0 0 1px var(--track-strong)}
2833.e-track-hover{box-shadow:inset 0 0 0 1px var(--track-soft)}
2834.e-active{outline:2px solid var(--active);outline-offset:1px}
2835.conf{color:var(--muted);font-size:10px}
2836.badge{display:inline-block;padding:1px 4px;font-size:9px;text-transform:uppercase}
2837.badge-y{background:var(--badge-y-bg);color:var(--badge-y-tx);border:1px solid var(--badge-y-br)}
2838.badge-n{background:var(--badge-n-bg);color:var(--badge-n-tx);border:1px solid var(--badge-n-br)}
2839.stats{display:flex;gap:16px;padding:8px 0;border-bottom:1px solid var(--border);margin-bottom:8px}
2840.stat{text-align:center}.stat-v{font-size:18px;color:var(--text-strong)}.stat-l{font-size:9px;color:var(--muted);text-transform:uppercase}
2841.id{color:var(--muted);font-size:9px}
2842.kb{color:var(--muted)}
2843.arrow{color:var(--muted)}
2844</style>
2845</head>
2846<body>
2847"#);
2848
2849 html.push_str(&format!(
2851 r#"<div class="panel-h" style="justify-content:space-between"><h1>doc_id="{}" len={}</h1><span class="toggle" id="theme-toggle" title="toggle theme (auto → dark → light)">theme: auto</span></div>"#,
2852 html_escape(&doc.id),
2853 doc.text.len()
2854 ));
2855
2856 html.push_str(r#"<div class="stats">"#);
2857 html.push_str(&format!(
2858 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">signals</div></div>"#,
2859 stats.signal_count
2860 ));
2861 html.push_str(&format!(
2862 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">tracks</div></div>"#,
2863 stats.track_count
2864 ));
2865 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">identities</div></div>"#, stats.identity_count));
2866 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{:.2}</div><div class="stat-l">avg_conf</div></div>"#, stats.avg_confidence));
2867 html.push_str(&format!(
2868 r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">linked</div></div>"#,
2869 stats.linked_track_count
2870 ));
2871 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}</div><div class="stat-l">untracked</div></div>"#, stats.untracked_count));
2872 if stats.iconic_count > 0 || stats.hybrid_count > 0 {
2873 html.push_str(&format!(r#"<div class="stat"><div class="stat-v">{}/{}/{}</div><div class="stat-l">sym/ico/hyb</div></div>"#,
2874 stats.symbolic_count, stats.iconic_count, stats.hybrid_count));
2875 }
2876 html.push_str(r#"</div>"#);
2877
2878 html.push_str(r#"<h2>text</h2>"#);
2880 html.push_str(r#"<div class="text-box">"#);
2881 html.push_str(&annotate_text_html(
2882 &doc.text,
2883 doc.signals(),
2884 &doc.signal_to_track,
2885 ));
2886 html.push_str(r#"</div>"#);
2887
2888 html.push_str(
2890 r#"<h2>selection</h2><div class="panel" id="selection-panel" role="region" aria-label="selection"><div class="panel-h"><h3>selection</h3><span class="muted" id="selection-hint" role="status" aria-live="polite">click a mention / row to see coref track details</span></div><pre class="panel-body" id="selection-body" role="textbox" aria-readonly="true" aria-label="selection details">—</pre></div>"#,
2891 );
2892
2893 html.push_str(r#"<div class="grid">"#);
2895
2896 html.push_str(r#"<div class="panel" id="panel-signals"><div class="panel-h"><h3>signals (level 1)</h3><span class="toggle" data-toggle="panel-signals">toggle</span></div><div class="toolbar"><input id="signal-filter" type="text" placeholder="filter signals: id / label / surface (e.g. 'PER', 'S12', 'Paris')" /><span class="muted" id="signal-filter-count"></span></div><table id="signals-table">"#);
2898 html.push_str(r#"<tr><th>id</th><th>span</th><th>surface</th><th>label</th><th>conf</th><th>track</th></tr>"#);
2899 for signal in doc.signals() {
2900 let (span, start_opt, end_opt) = if let Some((s, e)) = signal.location.text_offsets() {
2901 (format!("[{},{})", s, e), Some(s), Some(e))
2902 } else {
2903 ("bbox".to_string(), None, None)
2904 };
2905 let track_id_num = doc.signal_to_track.get(&signal.id).copied();
2906 let track_id = track_id_num
2907 .map(|t| format!("T{}", t))
2908 .unwrap_or_else(|| "-".to_string());
2909 let track_attr = track_id_num
2910 .map(|t| format!(r#" data-track="{}""#, t))
2911 .unwrap_or_default();
2912 let offs_attr = match (start_opt, end_opt) {
2913 (Some(s), Some(e)) => format!(r#" data-start="{}" data-end="{}""#, s, e),
2914 _ => String::new(),
2915 };
2916 let neg = if signal.negated { " NEG" } else { "" };
2917 html.push_str(&format!(
2918 r#"<tr data-sid="S{sid}" data-label="{label}" data-surface="{surface}"{track_attr}{offs_attr} data-conf="{conf:.2}"><td class="id"><a href='#S{sid}'>S{sid}</a></td><td>{span}</td><td>{surface}</td><td>{label}{neg}</td><td class="conf">{conf:.2}</td><td class="id">{track}</td></tr>"#,
2919 sid = signal.id,
2920 span = span,
2921 surface = html_escape(&signal.surface),
2922 label = html_escape(signal.label.as_str()),
2923 neg = neg,
2924 conf = signal.confidence,
2925 track = track_id,
2926 track_attr = track_attr,
2927 offs_attr = offs_attr
2928 ));
2929 }
2930 html.push_str(r#"</table></div>"#);
2931
2932 html.push_str(r#"<div class="panel" id="panel-tracks"><div class="panel-h"><h3>tracks (level 2)</h3><span class="toggle" data-toggle="panel-tracks">toggle</span></div><table id="tracks-table">"#);
2934 html.push_str(r#"<tr><th>id</th><th>canonical</th><th>type</th><th>|S|</th><th>signals</th><th>identity</th></tr>"#);
2935 for track in doc.tracks() {
2936 let entity_type = track
2937 .entity_type
2938 .as_ref()
2939 .map(|t| t.as_str())
2940 .unwrap_or("-");
2941 let signals: Vec<String> = track
2942 .signals
2943 .iter()
2944 .map(|s| format!("S{}", s.signal_id))
2945 .collect();
2946 let identity = doc
2947 .identity_for_track(track.id)
2948 .map(|i| format!("I{}", i.id))
2949 .unwrap_or_else(|| "-".to_string());
2950 let linked_badge = if track.identity_id.is_some() {
2951 r#"<span class="badge badge-y">y</span>"#
2952 } else {
2953 r#"<span class="badge badge-n">n</span>"#
2954 };
2955 html.push_str(&format!(
2956 r#"<tr data-tid="{tid}"><td class="id">T{tid}</td><td>{canonical_surface}</td><td>{etype}</td><td>{n}</td><td class="id">{sigs}</td><td class="id">{ident} {badge}</td></tr>"#,
2957 tid = track.id,
2958 canonical_surface = html_escape(&track.canonical_surface),
2959 etype = html_escape(entity_type),
2960 n = track.len(),
2961 sigs = html_escape(&signals.join(" ")),
2962 ident = identity,
2963 badge = linked_badge
2964 ));
2965 }
2966 html.push_str(r#"</table></div>"#);
2967
2968 html.push_str(r#"<div class="panel" id="panel-identities"><div class="panel-h"><h3>identities (level 3)</h3><span class="toggle" data-toggle="panel-identities">toggle</span></div><table>"#);
2970 html.push_str(r#"<tr><th>id</th><th>name</th><th>type</th><th>kb</th><th>kb_id</th><th>aliases</th></tr>"#);
2971 for identity in doc.identities() {
2972 let kb = identity.kb_name.as_deref().unwrap_or("-");
2973 let kb_id = identity.kb_id.as_deref().unwrap_or("-");
2974 let entity_type = identity
2975 .entity_type
2976 .as_ref()
2977 .map(|t| t.as_str())
2978 .unwrap_or("-");
2979 let aliases = if identity.aliases.is_empty() {
2980 "-".to_string()
2981 } else {
2982 identity.aliases.join(", ")
2983 };
2984 html.push_str(&format!(
2985 r#"<tr><td class="id">I{}</td><td>{}</td><td>{}</td><td class="kb">{}</td><td class="kb">{}</td><td>{}</td></tr>"#,
2986 identity.id, html_escape(&identity.canonical_name), entity_type, kb, kb_id, html_escape(&aliases)
2987 ));
2988 }
2989 html.push_str(r#"</table></div>"#);
2990
2991 html.push_str(r#"</div>"#); html.push_str(r#"<h2>hierarchy trace</h2><div class="panel"><table>"#);
2995 html.push_str(r#"<tr><th>signal</th><th></th><th>track</th><th></th><th>identity</th><th>kb_id</th></tr>"#);
2996 for signal in doc.signals() {
2997 let track = doc.track_for_signal(signal.id);
2998 let identity = doc.identity_for_signal(signal.id);
2999
3000 let track_str = track
3001 .map(|t| format!("T{} \"{}\"", t.id, html_escape(&t.canonical_surface)))
3002 .unwrap_or_else(|| "-".to_string());
3003 let identity_str = identity
3004 .map(|i| format!("I{} \"{}\"", i.id, html_escape(&i.canonical_name)))
3005 .unwrap_or_else(|| "-".to_string());
3006 let kb_str = identity
3007 .and_then(|i| i.kb_id.as_ref())
3008 .map(|s| s.as_str())
3009 .unwrap_or("-");
3010
3011 html.push_str(&format!(
3012 r#"<tr><td>S{} "{}"</td><td class="arrow">→</td><td>{}</td><td class="arrow">→</td><td>{}</td><td class="kb">{}</td></tr>"#,
3013 signal.id, html_escape(&signal.surface), track_str, identity_str, kb_str
3014 ));
3015 }
3016 html.push_str(r#"</table></div>"#);
3017
3018 html.push_str(r#"<script>
3021(() => {
3022 // Index signal metadata from the signals table, and map signal/track → text elements.
3023 const signalMeta = new Map();
3024 document.querySelectorAll('#signals-table tr[data-sid]').forEach((row) => {
3025 const sid = row.getAttribute('data-sid');
3026 if (!sid) return;
3027 signalMeta.set(sid, {
3028 sid,
3029 label: row.getAttribute('data-label') || '',
3030 surface: row.getAttribute('data-surface') || '',
3031 conf: row.getAttribute('data-conf') || '',
3032 start: row.getAttribute('data-start'),
3033 end: row.getAttribute('data-end'),
3034 track: row.getAttribute('data-track'),
3035 });
3036 });
3037
3038 const signalEls = new Map();
3039 const addSignalEl = (sid, el) => {
3040 if (!sid || !el) return;
3041 const arr = signalEls.get(sid) || [];
3042 arr.push(el);
3043 signalEls.set(sid, arr);
3044 };
3045 // Old-style inline spans (non-overlapping renderer).
3046 document.querySelectorAll('span.e[data-sid]').forEach((el) => {
3047 addSignalEl(el.getAttribute('data-sid'), el);
3048 });
3049 // Segmented spans (overlap/discontinuous-safe renderer).
3050 document.querySelectorAll('span.seg[data-sids]').forEach((el) => {
3051 const raw = (el.getAttribute('data-sids') || '').trim();
3052 if (!raw) return;
3053 raw.split(/\s+/).filter(Boolean).forEach((sid) => addSignalEl(sid, el));
3054 });
3055
3056 const trackEls = new Map();
3057 for (const [sid, els] of signalEls.entries()) {
3058 const meta = signalMeta.get(sid);
3059 const tid = meta ? meta.track : null;
3060 if (!tid) continue;
3061 const arr = trackEls.get(tid) || [];
3062 els.forEach((el) => arr.push(el));
3063 trackEls.set(tid, arr);
3064 }
3065
3066 const selectionBody = document.getElementById('selection-body');
3067 const selectionHint = document.getElementById('selection-hint');
3068 const defaultHint = selectionHint ? (selectionHint.textContent || '') : '';
3069 const setSelection = (text) => {
3070 if (!selectionBody) return;
3071 selectionBody.textContent = text;
3072 };
3073 const setHint = (text) => {
3074 if (!selectionHint) return;
3075 selectionHint.textContent = text || defaultHint;
3076 };
3077
3078 // Theme toggle: auto (prefers-color-scheme) → dark → light.
3079 const themeBtn = document.getElementById('theme-toggle');
3080 const themeKey = 'anno-theme';
3081 const applyTheme = (theme) => {
3082 const t = theme || 'auto';
3083 if (t === 'auto') {
3084 delete document.documentElement.dataset.theme;
3085 } else {
3086 document.documentElement.dataset.theme = t;
3087 }
3088 if (themeBtn) themeBtn.textContent = `theme: ${t}`;
3089 };
3090 const readTheme = () => {
3091 try { return localStorage.getItem(themeKey) || 'auto'; } catch (_) { return 'auto'; }
3092 };
3093 const writeTheme = (t) => {
3094 try { localStorage.setItem(themeKey, t); } catch (_) { /* ignore */ }
3095 };
3096 applyTheme(readTheme());
3097 if (themeBtn) {
3098 themeBtn.addEventListener('click', () => {
3099 const cur = readTheme();
3100 const next = cur === 'auto' ? 'dark' : (cur === 'dark' ? 'light' : 'auto');
3101 writeTheme(next);
3102 applyTheme(next);
3103 });
3104 }
3105
3106 let activeSignalEls = [];
3107 let activeSignalRow = null;
3108 const clearActive = () => {
3109 if (activeSignalEls && activeSignalEls.length) {
3110 activeSignalEls.forEach((el) => el.classList.remove('e-active'));
3111 }
3112 if (activeSignalRow) activeSignalRow.classList.remove('e-active');
3113 activeSignalEls = [];
3114 activeSignalRow = null;
3115 };
3116
3117 let activeTrack = null;
3118 let hoverTrack = null;
3119
3120 const removeTrackClass = (tid, cls) => {
3121 if (!tid) return;
3122 const els = trackEls.get(tid);
3123 if (!els) return;
3124 els.forEach((el) => el.classList.remove(cls));
3125 };
3126
3127 const addTrackClass = (tid, cls) => {
3128 if (!tid) return;
3129 const els = trackEls.get(tid);
3130 if (!els) return;
3131 els.forEach((el) => el.classList.add(cls));
3132 };
3133
3134 const trackSize = (tid) => {
3135 const els = tid ? trackEls.get(tid) : null;
3136 return els ? els.length : 0;
3137 };
3138
3139 const getTrackSelectionText = (tid) => {
3140 if (!tid) return 'track: - (untracked)';
3141 const row = document.querySelector(`#tracks-table tr[data-tid='${tid}']`);
3142 if (!row) return `track T${tid}`;
3143 const cells = row.querySelectorAll('td');
3144 const canonical = (cells[1]?.textContent || '').trim();
3145 const etype = (cells[2]?.textContent || '').trim();
3146 const count = (cells[3]?.textContent || '').trim();
3147 const sigs = (cells[4]?.textContent || '').trim();
3148 const lines = [];
3149 lines.push(`track T${tid} canonical="${canonical}" type="${etype}" mentions=${count}`);
3150 if (sigs) lines.push(`track signals: ${sigs}`);
3151 return lines.join('\n');
3152 };
3153
3154 const renderTrackSelection = (tid) => setSelection(getTrackSelectionText(tid));
3155
3156 const renderSignalSelectionBySid = (sid) => {
3157 const meta = signalMeta.get(sid);
3158 const label = meta ? (meta.label || '') : '';
3159 const conf = meta ? (meta.conf || '') : '';
3160 const start = meta ? meta.start : null;
3161 const end = meta ? meta.end : null;
3162 const tid = meta ? meta.track : null;
3163 const lines = [];
3164 if (start !== null && end !== null) {
3165 lines.push(`signal ${sid} label=${label} conf=${conf} span=[${start},${end})`);
3166 } else {
3167 lines.push(`signal ${sid} label=${label} conf=${conf}`);
3168 }
3169 if (meta && meta.surface) lines.push(`surface: ${meta.surface}`);
3170 lines.push('');
3171 lines.push(getTrackSelectionText(tid));
3172 setSelection(lines.join('\n'));
3173 };
3174
3175 const setActiveTrack = (tid) => {
3176 const next = tid || null;
3177 if (activeTrack === next) return;
3178 removeTrackClass(activeTrack, 'e-track');
3179 activeTrack = next;
3180 if (activeTrack) addTrackClass(activeTrack, 'e-track');
3181 if (hoverTrack && activeTrack && hoverTrack === activeTrack) {
3182 removeTrackClass(hoverTrack, 'e-track-hover');
3183 }
3184 };
3185
3186 const setHoverTrack = (tid) => {
3187 const next = tid || null;
3188 if (hoverTrack === next) return;
3189 removeTrackClass(hoverTrack, 'e-track-hover');
3190 hoverTrack = next;
3191 if (!hoverTrack) {
3192 setHint('');
3193 return;
3194 }
3195 if (activeTrack && hoverTrack === activeTrack) {
3196 setHint(`selected track T${hoverTrack} (${trackSize(hoverTrack)} mentions)`);
3197 return;
3198 }
3199 addTrackClass(hoverTrack, 'e-track-hover');
3200 setHint(`hover track T${hoverTrack} (${trackSize(hoverTrack)} mentions)`);
3201 };
3202
3203 const emitToParentSpan = (start, end) => {
3204 try {
3205 if (!window.parent || window.parent === window) return;
3206 if (start === null || end === null) return;
3207 window.parent.postMessage({ type: 'anno:activate-span', start: Number(start), end: Number(end) }, '*');
3208 } catch (_) {
3209 // ignore: best-effort bridge for iframe containers
3210 }
3211 };
3212
3213 const activateBySpan = (start, end, emit) => {
3214 if (start === null || end === null || start === undefined || end === undefined) return;
3215 // Prefer an exact signal span if present; otherwise fall back to the table row metadata.
3216 const el = document.querySelector(`span.e[data-sid][data-start='${start}'][data-end='${end}']`);
3217 if (el) {
3218 const sid = el.getAttribute('data-sid');
3219 if (sid) activateSignal(sid, emit);
3220 return;
3221 }
3222 const row = document.querySelector(`#signals-table tr[data-start='${start}'][data-end='${end}']`);
3223 if (!row) return;
3224 const sid = row.getAttribute('data-sid');
3225 if (!sid) return;
3226 activateSignal(sid, emit);
3227 };
3228
3229 const activateSignal = (sid, emit) => {
3230 clearActive();
3231 const els = signalEls.get(sid) || [];
3232 if (!els.length) return;
3233 els.forEach((el) => el.classList.add('e-active'));
3234 activeSignalEls = els;
3235 const row = document.querySelector(`#signals-table tr[data-sid='${sid}']`);
3236 if (row) {
3237 row.classList.add('e-active');
3238 activeSignalRow = row;
3239 }
3240 const primaryEl = els[0];
3241 primaryEl.scrollIntoView({ block: 'center', behavior: 'smooth' });
3242 const meta = signalMeta.get(sid);
3243 const tid = meta ? meta.track : primaryEl.getAttribute('data-track');
3244 setActiveTrack(tid);
3245 renderSignalSelectionBySid(sid);
3246 if (emit && meta && meta.start !== null && meta.end !== null) {
3247 emitToParentSpan(meta.start, meta.end);
3248 }
3249 };
3250
3251 // Table click
3252 const signalsTable = document.getElementById('signals-table');
3253 if (signalsTable) {
3254 signalsTable.addEventListener('click', (ev) => {
3255 const a = ev.target && ev.target.closest ? ev.target.closest("a[href^='#S']") : null;
3256 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-sid]') : null;
3257 const sid = (a && a.getAttribute('href') ? a.getAttribute('href').slice(1) : null) || (row ? row.getAttribute('data-sid') : null);
3258 if (!sid) return;
3259 ev.preventDefault();
3260 activateSignal(sid, true);
3261 history.replaceState(null, '', '#' + sid);
3262 });
3263
3264 // Hover a signals row → preview track highlight
3265 signalsTable.addEventListener('mouseover', (ev) => {
3266 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-sid]') : null;
3267 if (!row) return;
3268 const tid = row.getAttribute('data-track');
3269 setHoverTrack(tid);
3270 });
3271 signalsTable.addEventListener('mouseout', (ev) => {
3272 const to = ev.relatedTarget;
3273 if (to && signalsTable.contains(to)) return;
3274 setHoverTrack(null);
3275 });
3276 }
3277
3278 // Clicking an inline entity should also toggle active highlight.
3279 const pickPrimarySid = (el) => {
3280 if (!el) return null;
3281 const p = el.getAttribute('data-primary');
3282 if (p) return p;
3283 const raw = (el.getAttribute('data-sids') || '').trim();
3284 if (!raw) return null;
3285 const sids = raw.split(/\s+/).filter(Boolean);
3286 if (!sids.length) return null;
3287 // Prefer the shortest mention span from metadata.
3288 let best = sids[0];
3289 let bestLen = null;
3290 for (const sid of sids) {
3291 const meta = signalMeta.get(sid);
3292 const s = meta && meta.start !== null ? Number(meta.start) : null;
3293 const e = meta && meta.end !== null ? Number(meta.end) : null;
3294 const len = (s !== null && e !== null) ? (e - s) : null;
3295 if (len === null) continue;
3296 if (bestLen === null || len < bestLen) {
3297 best = sid;
3298 bestLen = len;
3299 }
3300 }
3301 return best;
3302 };
3303
3304 document.addEventListener('click', (ev) => {
3305 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3306 if (span) {
3307 activateSignal(span.getAttribute('data-sid'), true);
3308 return;
3309 }
3310 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3311 if (!seg) return;
3312 activateSignal(pickPrimarySid(seg), true);
3313 });
3314
3315 // Hover an inline entity → preview highlight its track
3316 document.addEventListener('mouseover', (ev) => {
3317 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3318 if (span) {
3319 setHoverTrack(span.getAttribute('data-track'));
3320 return;
3321 }
3322 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3323 if (!seg) return;
3324 const sid = pickPrimarySid(seg);
3325 const meta = sid ? signalMeta.get(sid) : null;
3326 setHoverTrack(meta ? meta.track : null);
3327 });
3328 document.addEventListener('mouseout', (ev) => {
3329 const span = ev.target && ev.target.closest ? ev.target.closest('span.e[data-sid]') : null;
3330 const seg = ev.target && ev.target.closest ? ev.target.closest('span.seg[data-sids]') : null;
3331 if (!span && !seg) return;
3332 const to = ev.relatedTarget;
3333 if (to && to.closest && (to.closest('span.e[data-sid]') || to.closest('span.seg[data-sids]'))) return;
3334 setHoverTrack(null);
3335 });
3336
3337 // Clicking a track row → select track (highlight + details)
3338 const tracksTable = document.getElementById('tracks-table');
3339 if (tracksTable) {
3340 tracksTable.addEventListener('click', (ev) => {
3341 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-tid]') : null;
3342 if (!row) return;
3343 const tid = row.getAttribute('data-tid');
3344 setActiveTrack(tid);
3345 renderTrackSelection(tid);
3346 });
3347 tracksTable.addEventListener('mouseover', (ev) => {
3348 const row = ev.target && ev.target.closest ? ev.target.closest('tr[data-tid]') : null;
3349 if (!row) return;
3350 setHoverTrack(row.getAttribute('data-tid'));
3351 });
3352 tracksTable.addEventListener('mouseout', (ev) => {
3353 const to = ev.relatedTarget;
3354 if (to && tracksTable.contains(to)) return;
3355 setHoverTrack(null);
3356 });
3357 }
3358
3359 // Filter
3360 const input = document.getElementById('signal-filter');
3361 const countEl = document.getElementById('signal-filter-count');
3362 if (input && signalsTable) {
3363 const update = () => {
3364 const q = (input.value || '').trim().toLowerCase();
3365 let shown = 0;
3366 const rows = signalsTable.querySelectorAll('tr[data-sid]');
3367 rows.forEach(row => {
3368 const sid = (row.getAttribute('data-sid') || '').toLowerCase();
3369 const label = (row.getAttribute('data-label') || '').toLowerCase();
3370 const surface = (row.getAttribute('data-surface') || '').toLowerCase();
3371 const ok = !q || sid.includes(q) || label.includes(q) || surface.includes(q);
3372 row.style.display = ok ? '' : 'none';
3373 if (ok) shown += 1;
3374 });
3375 if (countEl) countEl.textContent = shown + ' shown';
3376 };
3377 input.addEventListener('input', update);
3378 update();
3379 }
3380
3381 // Panel toggles
3382 document.querySelectorAll('[data-toggle]').forEach(btn => {
3383 btn.addEventListener('click', () => {
3384 const id = btn.getAttribute('data-toggle');
3385 const panel = id ? document.getElementById(id) : null;
3386 if (!panel) return;
3387 panel.classList.toggle('panel-collapsed');
3388 });
3389 });
3390
3391 // If URL hash is #S123, focus it.
3392 const hash = (location.hash || '').slice(1);
3393 if (hash && hash.startsWith('S')) activateSignal(hash, false);
3394
3395 // Optional: allow parent pages (e.g., dataset explorers) to sync selection across iframes.
3396 window.addEventListener('message', (ev) => {
3397 const data = ev && ev.data ? ev.data : null;
3398 if (!data || data.type !== 'anno:activate-span') return;
3399 if (typeof data.start !== 'number' || typeof data.end !== 'number') return;
3400 activateBySpan(data.start, data.end, false);
3401 });
3402})();
3403</script>"#);
3404
3405 html.push_str(r#"</body></html>"#);
3406 html
3407}
3408
/// Escapes the characters that are significant in HTML text and in
/// double-quoted attribute values.
///
/// `&`, `<`, `>` and `"` are replaced by `&amp;`, `&lt;`, `&gt;` and `&quot;`.
/// Every other character is copied through unchanged. Single quotes are not
/// escaped, so the result must only be placed in text nodes or inside
/// double-quoted attributes.
fn html_escape(s: &str) -> String {
    // One pass over the input; escaping only ever grows the string, so the
    // input length is a useful lower bound for the capacity.
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
3415
3416fn annotate_text_html(
3417 text: &str,
3418 signals: &[Signal<Location>],
3419 signal_to_track: &std::collections::HashMap<SignalId, TrackId>,
3420) -> String {
3421 let char_count = text.chars().count();
3422 if char_count == 0 {
3423 return String::new();
3424 }
3425
3426 #[derive(Debug, Clone)]
3427 struct SigMeta {
3428 sid: String,
3429 label: String,
3430 conf: f32,
3431 track_id: Option<TrackId>,
3432 covered_len: usize,
3433 }
3434
3435 #[derive(Debug, Clone)]
3436 struct Event {
3437 pos: usize,
3438 meta_idx: usize,
3439 delta: i32, }
3441
3442 let mut metas: Vec<SigMeta> = Vec::new();
3444 let mut events: Vec<Event> = Vec::new();
3445 let mut boundaries: Vec<usize> = vec![0, char_count];
3446
3447 for s in signals {
3448 let raw_segments: Vec<(usize, usize)> = match &s.location {
3449 Location::Text { start, end } => vec![(*start, *end)],
3450 Location::TextWithBbox { start, end, .. } => vec![(*start, *end)],
3451 Location::Discontinuous { segments } => segments.clone(),
3452 _ => Vec::new(),
3453 };
3454 if raw_segments.is_empty() {
3455 continue;
3456 }
3457
3458 let mut cleaned: Vec<(usize, usize)> = Vec::new();
3459 let mut covered_len = 0usize;
3460 for (start, end) in raw_segments {
3461 let start = start.min(char_count);
3462 let end = end.min(char_count);
3463 if start >= end {
3464 continue;
3465 }
3466 covered_len = covered_len.saturating_add(end - start);
3467 cleaned.push((start, end));
3468 }
3469 if cleaned.is_empty() {
3470 continue;
3471 }
3472
3473 let meta_idx = metas.len();
3474 let track_id = signal_to_track.get(&s.id).copied();
3475 metas.push(SigMeta {
3476 sid: format!("S{}", s.id),
3477 label: s.label.to_string(),
3478 conf: s.confidence,
3479 track_id,
3480 covered_len,
3481 });
3482
3483 for (start, end) in cleaned {
3484 boundaries.push(start);
3485 boundaries.push(end);
3486 events.push(Event {
3487 pos: start,
3488 meta_idx,
3489 delta: 1,
3490 });
3491 events.push(Event {
3492 pos: end,
3493 meta_idx,
3494 delta: -1,
3495 });
3496 }
3497 }
3498
3499 if metas.is_empty() {
3500 return html_escape(text);
3501 }
3502
3503 boundaries.sort_unstable();
3504 boundaries.dedup();
3505 events.sort_by(|a, b| a.pos.cmp(&b.pos).then_with(|| a.delta.cmp(&b.delta)));
3506
3507 let mut active_counts: Vec<u32> = vec![0; metas.len()];
3508 let mut active: Vec<usize> = Vec::new();
3509 let mut ev_idx = 0usize;
3510
3511 let mut result = String::new();
3512
3513 for bi in 0..boundaries.len().saturating_sub(1) {
3514 let pos = boundaries[bi];
3515 while ev_idx < events.len() && events[ev_idx].pos == pos {
3517 let e = &events[ev_idx];
3518 let idx = e.meta_idx;
3519 if e.delta < 0 {
3520 if active_counts[idx] > 0 {
3521 active_counts[idx] -= 1;
3522 if active_counts[idx] == 0 {
3523 active.retain(|&x| x != idx);
3524 }
3525 }
3526 } else {
3527 active_counts[idx] += 1;
3528 if active_counts[idx] == 1 {
3529 active.push(idx);
3530 }
3531 }
3532 ev_idx += 1;
3533 }
3534
3535 let next = boundaries[bi + 1];
3536 if next <= pos {
3537 continue;
3538 }
3539
3540 let seg_text: String = text.chars().skip(pos).take(next - pos).collect();
3541 if active.is_empty() {
3542 result.push_str(&html_escape(&seg_text));
3543 continue;
3544 }
3545
3546 let primary_idx = active
3548 .iter()
3549 .copied()
3550 .min_by(|a, b| {
3551 metas[*a]
3552 .covered_len
3553 .cmp(&metas[*b].covered_len)
3554 .then_with(|| {
3555 metas[*b]
3556 .conf
3557 .partial_cmp(&metas[*a].conf)
3558 .unwrap_or(std::cmp::Ordering::Equal)
3559 })
3560 })
3561 .unwrap_or(active[0]);
3562 let primary = &metas[primary_idx];
3563
3564 let class = match primary.label.to_uppercase().as_str() {
3565 "PER" | "PERSON" => "e-per",
3566 "ORG" | "ORGANIZATION" | "COMPANY" => "e-org",
3567 "LOC" | "LOCATION" | "GPE" => "e-loc",
3568 "DATE" | "TIME" => "e-date",
3569 _ => "e-misc",
3570 };
3571
3572 let mut sids: Vec<&str> = active.iter().map(|i| metas[*i].sid.as_str()).collect();
3573 sids.sort_unstable();
3574 let data_sids = sids.join(" ");
3575
3576 let mut title = format!(
3577 "sids=[{}] primary={} [{}..{})",
3578 data_sids, primary.sid, pos, next
3579 );
3580 if let Some(t) = primary.track_id {
3581 title.push_str(&format!(" track=T{}", t));
3582 }
3583
3584 result.push_str(&format!(
3585 r#"<span class="e seg {class}" data-sids="{sids}" data-start="{start}" data-end="{end}" data-primary="{primary}" title="{title}">{text}</span>"#,
3586 class = class,
3587 sids = html_escape(&data_sids),
3588 start = pos,
3589 end = next,
3590 primary = html_escape(&primary.sid),
3591 title = html_escape(&title),
3592 text = html_escape(&seg_text),
3593 ));
3594 }
3595
3596 result
3597}
3598
/// Result of aligning predicted signals against gold signals over one text.
///
/// Built by [`EvalComparison::compare`], which stores its inputs alongside the
/// list of match outcomes it derives from them.
#[derive(Debug, Clone)]
pub struct EvalComparison {
    /// The text that was passed to `compare` (stored as an owned copy).
    pub text: String,
    /// Gold signals, as passed to `compare`.
    pub gold: Vec<Signal<Location>>,
    /// Predicted signals, as passed to `compare`.
    pub predicted: Vec<Signal<Location>>,
    /// Outcomes produced by `compare`: paired gold/predicted signals first,
    /// then unmatched predictions, then unmatched gold signals.
    pub matches: Vec<EvalMatch>,
}
3615
/// One outcome of comparing predicted signals against gold signals.
///
/// The descriptions below reflect how `EvalComparison::compare` produces each
/// variant.
#[derive(Debug, Clone)]
pub enum EvalMatch {
    /// Gold and predicted signal have equal text offsets and equal labels.
    Correct {
        /// Id of the matched gold signal.
        gold_id: SignalId,
        /// Id of the matched predicted signal.
        pred_id: SignalId,
    },
    /// Gold and predicted signal have equal text offsets but different labels.
    TypeMismatch {
        /// Id of the matched gold signal.
        gold_id: SignalId,
        /// Id of the matched predicted signal.
        pred_id: SignalId,
        /// Label of the gold signal, rendered as a string.
        gold_label: String,
        /// Label of the predicted signal, rendered as a string.
        pred_label: String,
    },
    /// Gold and predicted signal overlap in the text without having equal
    /// offsets. Labels are not compared for this variant.
    BoundaryError {
        /// Id of the overlapping gold signal.
        gold_id: SignalId,
        /// Id of the overlapping predicted signal.
        pred_id: SignalId,
        /// Value returned by `Location::iou` for the two locations, or `0.0`
        /// when it returns `None`.
        iou: f64,
    },
    /// A predicted signal that was not paired with any gold signal.
    Spurious {
        /// Id of the unmatched predicted signal.
        pred_id: SignalId,
    },
    /// A gold signal that was not paired with any predicted signal.
    Missed {
        /// Id of the unmatched gold signal.
        gold_id: SignalId,
    },
}
3657
3658impl EvalComparison {
3659 #[must_use]
3678 pub fn compare(
3679 text: &str,
3680 gold: Vec<Signal<Location>>,
3681 predicted: Vec<Signal<Location>>,
3682 ) -> Self {
3683 let mut matches = Vec::new();
3684 let mut gold_matched = vec![false; gold.len()];
3685 let mut pred_matched = vec![false; predicted.len()];
3686
3687 for (pi, pred) in predicted.iter().enumerate() {
3689 let pred_offsets = match pred.location.text_offsets() {
3690 Some(o) => o,
3691 None => continue,
3692 };
3693
3694 for (gi, g) in gold.iter().enumerate() {
3695 if gold_matched[gi] {
3696 continue;
3697 }
3698 let gold_offsets = match g.location.text_offsets() {
3699 Some(o) => o,
3700 None => continue,
3701 };
3702
3703 if pred_offsets == gold_offsets {
3705 if pred.label == g.label {
3706 matches.push(EvalMatch::Correct {
3707 gold_id: g.id,
3708 pred_id: pred.id,
3709 });
3710 } else {
3711 matches.push(EvalMatch::TypeMismatch {
3712 gold_id: g.id,
3713 pred_id: pred.id,
3714 gold_label: g.label.to_string(),
3715 pred_label: pred.label.to_string(),
3716 });
3717 }
3718 gold_matched[gi] = true;
3719 pred_matched[pi] = true;
3720 break;
3721 }
3722 }
3723 }
3724
3725 for (pi, pred) in predicted.iter().enumerate() {
3727 if pred_matched[pi] {
3728 continue;
3729 }
3730 let pred_offsets = match pred.location.text_offsets() {
3731 Some(o) => o,
3732 None => continue,
3733 };
3734
3735 for (gi, g) in gold.iter().enumerate() {
3736 if gold_matched[gi] {
3737 continue;
3738 }
3739 let gold_offsets = match g.location.text_offsets() {
3740 Some(o) => o,
3741 None => continue,
3742 };
3743
3744 if pred_offsets.0 < gold_offsets.1 && pred_offsets.1 > gold_offsets.0 {
3746 let iou = pred.location.iou(&g.location).unwrap_or(0.0);
3747 matches.push(EvalMatch::BoundaryError {
3748 gold_id: g.id,
3749 pred_id: pred.id,
3750 iou,
3751 });
3752 gold_matched[gi] = true;
3753 pred_matched[pi] = true;
3754 break;
3755 }
3756 }
3757 }
3758
3759 for (pi, pred) in predicted.iter().enumerate() {
3761 if !pred_matched[pi] {
3762 matches.push(EvalMatch::Spurious { pred_id: pred.id });
3763 }
3764 }
3765
3766 for (gi, g) in gold.iter().enumerate() {
3768 if !gold_matched[gi] {
3769 matches.push(EvalMatch::Missed { gold_id: g.id });
3770 }
3771 }
3772
3773 Self {
3774 text: text.to_string(),
3775 gold,
3776 predicted,
3777 matches,
3778 }
3779 }
3780
3781 #[must_use]
3783 pub fn correct_count(&self) -> usize {
3784 self.matches
3785 .iter()
3786 .filter(|m| matches!(m, EvalMatch::Correct { .. }))
3787 .count()
3788 }
3789
3790 #[must_use]
3792 pub fn error_count(&self) -> usize {
3793 self.matches.len() - self.correct_count()
3794 }
3795
3796 #[must_use]
3798 pub fn precision(&self) -> f64 {
3799 if self.predicted.is_empty() {
3800 0.0
3801 } else {
3802 self.correct_count() as f64 / self.predicted.len() as f64
3803 }
3804 }
3805
3806 #[must_use]
3808 pub fn recall(&self) -> f64 {
3809 if self.gold.is_empty() {
3810 0.0
3811 } else {
3812 self.correct_count() as f64 / self.gold.len() as f64
3813 }
3814 }
3815
3816 #[must_use]
3818 pub fn f1(&self) -> f64 {
3819 let p = self.precision();
3820 let r = self.recall();
3821 if p + r > 0.0 {
3822 2.0 * p * r / (p + r)
3823 } else {
3824 0.0
3825 }
3826 }
3827}
3828
3829pub fn render_eval_html(cmp: &EvalComparison) -> String {
3833 render_eval_html_with_title(cmp, "eval comparison")
3834}
3835
3836#[must_use]
3840pub fn render_eval_html_with_title(cmp: &EvalComparison, title: &str) -> String {
3841 let mut html = String::new();
3842 let title = html_escape(title);
3843
3844 html.push_str(
3845 r#"<!DOCTYPE html>
3846<html>
3847<head>
3848<meta charset="UTF-8">
3849<meta name="color-scheme" content="dark light">
3850"#,
3851 );
3852 html.push_str(&format!("<title>{}</title>", title));
3853 html.push_str(r#"
3854:root{
3855 color-scheme: light dark;
3856 --bg:#0a0a0a;
3857 --panel-bg:#0d0d0d;
3858 --text:#b0b0b0;
3859 --text-strong:#fff;
3860 --muted:#666;
3861 --border:#222;
3862 --border-strong:#333;
3863 --hover:#111;
3864 --input-bg:#080808;
3865 --active:#ddd;
3866 /* Eval entity colors (dark) */
3867 --gold-bg:#1a2e1a; --gold-br:#4a8a4a; --gold-tx:#88cc88;
3868 --pred-bg:#1a1a2e; --pred-br:#4a4a8a; --pred-tx:#8888cc;
3869 /* Match row borders */
3870 --m-ok:#4a8a4a;
3871 --m-type:#8a8a4a;
3872 --m-bound:#4a8a8a;
3873 --m-fp:#8a4a4a;
3874 --m-fn:#8a4a8a;
3875}
3876@media (prefers-color-scheme: light){
3877 :root{
3878 --bg:#ffffff;
3879 --panel-bg:#f7f7f7;
3880 --text:#222;
3881 --text-strong:#000;
3882 --muted:#555;
3883 --border:#d6d6d6;
3884 --border-strong:#c6c6c6;
3885 --hover:#f0f0f0;
3886 --input-bg:#ffffff;
3887 --active:#000;
3888 --gold-bg:#e9f7e9; --gold-br:#2f8a2f; --gold-tx:#1f5a1f;
3889 --pred-bg:#e9e9ff; --pred-br:#6c6cff; --pred-tx:#2b2b7a;
3890 --m-ok:#2f8a2f;
3891 --m-type:#8a7a2f;
3892 --m-bound:#2f7a8a;
3893 --m-fp:#8a2f2f;
3894 --m-fn:#6a2f8a;
3895 }
3896}
3897html[data-theme='dark']{
3898 --bg:#0a0a0a; --panel-bg:#0d0d0d; --text:#b0b0b0; --text-strong:#fff;
3899 --muted:#666; --border:#222; --border-strong:#333; --hover:#111; --input-bg:#080808; --active:#ddd;
3900 --gold-bg:#1a2e1a; --gold-br:#4a8a4a; --gold-tx:#88cc88;
3901 --pred-bg:#1a1a2e; --pred-br:#4a4a8a; --pred-tx:#8888cc;
3902 --m-ok:#4a8a4a; --m-type:#8a8a4a; --m-bound:#4a8a8a; --m-fp:#8a4a4a; --m-fn:#8a4a8a;
3903}
3904html[data-theme='light']{
3905 --bg:#ffffff; --panel-bg:#f7f7f7; --text:#222; --text-strong:#000;
3906 --muted:#555; --border:#d6d6d6; --border-strong:#c6c6c6; --hover:#f0f0f0; --input-bg:#ffffff; --active:#000;
3907 --gold-bg:#e9f7e9; --gold-br:#2f8a2f; --gold-tx:#1f5a1f;
3908 --pred-bg:#e9e9ff; --pred-br:#6c6cff; --pred-tx:#2b2b7a;
3909 --m-ok:#2f8a2f; --m-type:#8a7a2f; --m-bound:#2f7a8a; --m-fp:#8a2f2f; --m-fn:#6a2f8a;
3910}
3911
3912<style>
3913*{box-sizing:border-box;margin:0;padding:0}
3914body{font:12px/1.4 monospace;background:var(--bg);color:var(--text);padding:8px}
3915h1,h2{color:var(--text-strong);font-weight:normal;border-bottom:1px solid var(--border-strong);padding:4px 0;margin:16px 0 8px}
3916h1{font-size:14px}h2{font-size:12px}
3917table{width:100%;border-collapse:collapse;font-size:11px;margin:4px 0}
3918th,td{padding:4px 8px;text-align:left;border:1px solid var(--border)}
3919th{background:var(--hover);color:var(--muted);font-weight:normal;text-transform:uppercase;font-size:10px}
3920tr:hover{background:var(--hover)}
3921.grid{display:grid;grid-template-columns:1fr 1fr;gap:8px}
3922.panel{border:1px solid var(--border);background:var(--panel-bg);padding:8px}
3923.text-box{background:var(--input-bg);border:1px solid var(--border);padding:8px;white-space:pre-wrap;word-break:break-word;line-height:1.6}
3924.stats{display:flex;gap:24px;padding:8px 0;border-bottom:1px solid var(--border);margin-bottom:8px}
3925.stat{text-align:center}.stat-v{font-size:18px;color:var(--text-strong)}.stat-l{font-size:9px;color:var(--muted);text-transform:uppercase}
3926/* Entities */
3927.e{padding:1px 2px;border-bottom:2px solid}
3928.seg{cursor:pointer}
3929.e-gold{background:var(--gold-bg);border-color:var(--gold-br);color:var(--gold-tx)}
3930.e-pred{background:var(--pred-bg);border-color:var(--pred-br);color:var(--pred-tx)}
3931.e-active{outline:1px solid var(--active);outline-offset:1px}
3932/* Match types */
3933.correct{background:#1a2e1a;border-color:#4a8a4a}
3934.type-err{background:#2e2e1a;border-color:#8a8a4a}
3935.boundary{background:#1a2e2e;border-color:#4a8a8a}
3936.spurious{background:#2e1a1a;border-color:#8a4a4a}
3937.missed{background:#2e1a2e;border-color:#8a4a8a}
3938.match-row.correct{border-left:3px solid var(--m-ok)}
3939.match-row.type-err{border-left:3px solid var(--m-type)}
3940.match-row.boundary{border-left:3px solid var(--m-bound)}
3941.match-row.spurious{border-left:3px solid var(--m-fp)}
3942.match-row.missed{border-left:3px solid var(--m-fn)}
3943.match-row.active{outline:1px solid var(--muted)}
3944.sel{color:var(--muted);margin:6px 0 12px}
3945.metric{font-size:14px;color:var(--muted)}.metric b{color:var(--text-strong)}
3946</style>
3947</head>
3948<body>
3949"#);
3950
3951 html.push_str(&format!(
3953 "<div class=\"panel-h\" style=\"justify-content:space-between\"><h1>{}</h1><span class=\"toggle\" id=\"theme-toggle\" title=\"toggle theme (auto → dark → light)\">theme: auto</span></div>",
3954 title
3955 ));
3956
3957 html.push_str("<div class=\"stats\">");
3959 html.push_str(&format!(
3960 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">gold</div></div>",
3961 cmp.gold.len()
3962 ));
3963 html.push_str(&format!(
3964 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">predicted</div></div>",
3965 cmp.predicted.len()
3966 ));
3967 html.push_str(&format!(
3968 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">correct</div></div>",
3969 cmp.correct_count()
3970 ));
3971 html.push_str(&format!(
3972 "<div class=\"stat\"><div class=\"stat-v\">{}</div><div class=\"stat-l\">errors</div></div>",
3973 cmp.error_count()
3974 ));
3975 html.push_str(&format!(
3976 "<div class=\"metric\">P=<b>{:.1}%</b> R=<b>{:.1}%</b> F1=<b>{:.1}%</b></div>",
3977 cmp.precision() * 100.0,
3978 cmp.recall() * 100.0,
3979 cmp.f1() * 100.0
3980 ));
3981 html.push_str("</div>");
3982
3983 html.push_str("<div id=\"selection\" class=\"sel\">click a match row to select spans</div>");
3985
3986 html.push_str("<div class=\"grid\">");
3988
3989 html.push_str("<div class=\"panel\"><h2>gold (ground truth)</h2><div class=\"text-box\">");
3991 let gold_spans: Vec<EvalHtmlSpan> = cmp
3992 .gold
3993 .iter()
3994 .map(|s| {
3995 let (start, end) = s.location.text_offsets().unwrap_or((0, 0));
3996 EvalHtmlSpan {
3997 start,
3998 end,
3999 label: s.label.to_string(),
4000 class: "e-gold",
4001 id: format!("G{}", s.id),
4002 }
4003 })
4004 .collect();
4005 html.push_str(&annotate_text_spans(&cmp.text, &gold_spans));
4006 html.push_str("</div></div>");
4007
4008 html.push_str("<div class=\"panel\"><h2>predicted</h2><div class=\"text-box\">");
4010 let pred_spans: Vec<EvalHtmlSpan> = cmp
4011 .predicted
4012 .iter()
4013 .map(|s| {
4014 let (start, end) = s.location.text_offsets().unwrap_or((0, 0));
4015 EvalHtmlSpan {
4016 start,
4017 end,
4018 label: s.label.to_string(),
4019 class: "e-pred",
4020 id: format!("P{}", s.id),
4021 }
4022 })
4023 .collect();
4024 html.push_str(&annotate_text_spans(&cmp.text, &pred_spans));
4025 html.push_str("</div></div>");
4026
4027 html.push_str("</div>");
4028
4029 html.push_str("<h2>matches</h2><table>");
4031 html.push_str("<tr><th>type</th><th>gold</th><th>predicted</th><th>notes</th></tr>");
4032
4033 for (mi, m) in cmp.matches.iter().enumerate() {
4034 let (class, mtype, gold_text, pred_text, notes, gid, pid) = match m {
4035 EvalMatch::Correct { gold_id, pred_id } => {
4036 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4037 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4038 (
4039 "correct",
4040 "✓",
4041 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4042 .unwrap_or_default(),
4043 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4044 .unwrap_or_default(),
4045 String::new(),
4046 Some(format!("G{}", gold_id)),
4047 Some(format!("P{}", pred_id)),
4048 )
4049 }
4050 EvalMatch::TypeMismatch {
4051 gold_id,
4052 pred_id,
4053 gold_label,
4054 pred_label,
4055 } => {
4056 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4057 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4058 (
4059 "type-err",
4060 "type",
4061 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4062 .unwrap_or_default(),
4063 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4064 .unwrap_or_default(),
4065 format!("{} → {}", gold_label, pred_label),
4066 Some(format!("G{}", gold_id)),
4067 Some(format!("P{}", pred_id)),
4068 )
4069 }
4070 EvalMatch::BoundaryError {
4071 gold_id,
4072 pred_id,
4073 iou,
4074 } => {
4075 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4076 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4077 (
4078 "boundary",
4079 "bound",
4080 g.map(|s| format!("[{}] \"{}\"", s.label, s.surface()))
4081 .unwrap_or_default(),
4082 p.map(|s| format!("[{}] \"{}\"", s.label, s.surface()))
4083 .unwrap_or_default(),
4084 format!("IoU={:.2}", iou),
4085 Some(format!("G{}", gold_id)),
4086 Some(format!("P{}", pred_id)),
4087 )
4088 }
4089 EvalMatch::Spurious { pred_id } => {
4090 let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
4091 (
4092 "spurious",
4093 "FP",
4094 String::new(),
4095 p.map(|s| format!("[{}] {}", s.label, s.surface()))
4096 .unwrap_or_default(),
4097 "false positive".to_string(),
4098 None,
4099 Some(format!("P{}", pred_id)),
4100 )
4101 }
4102 EvalMatch::Missed { gold_id } => {
4103 let g = cmp.gold.iter().find(|s| s.id == *gold_id);
4104 (
4105 "missed",
4106 "FN",
4107 g.map(|s| format!("[{}] {}", s.label, s.surface()))
4108 .unwrap_or_default(),
4109 String::new(),
4110 "false negative".to_string(),
4111 Some(format!("G{}", gold_id)),
4112 None,
4113 )
4114 }
4115 };
4116
4117 let mut data_attrs = String::new();
4118 if let Some(gid) = gid.as_deref() {
4119 data_attrs.push_str(&format!(" data-gid=\"{}\"", html_escape(gid)));
4120 }
4121 if let Some(pid) = pid.as_deref() {
4122 data_attrs.push_str(&format!(" data-pid=\"{}\"", html_escape(pid)));
4123 }
4124
4125 html.push_str(&format!(
4126 "<tr id=\"M{mid}\" class=\"match-row {class}\"{attrs}><td><a class=\"match-link\" href=\"#M{mid}\">{mtype}</a></td><td>{gold}</td><td>{pred}</td><td>{notes}</td></tr>",
4127 mid = mi,
4128 class = class,
4129 attrs = data_attrs,
4130 mtype = html_escape(mtype),
4131 gold = html_escape(&gold_text),
4132 pred = html_escape(&pred_text),
4133 notes = html_escape(¬es)
4134 ));
4135 }
4136 html.push_str("</table>");
4137
4138 html.push_str(
4139 r#"<script>
4140(() => {
4141 // Theme toggle: auto (prefers-color-scheme) → dark → light.
4142 const themeBtn = document.getElementById('theme-toggle');
4143 const themeKey = 'anno-theme';
4144 const applyTheme = (theme) => {
4145 const t = theme || 'auto';
4146 if (t === 'auto') {
4147 delete document.documentElement.dataset.theme;
4148 } else {
4149 document.documentElement.dataset.theme = t;
4150 }
4151 if (themeBtn) themeBtn.textContent = `theme: ${t}`;
4152 };
4153 const readTheme = () => {
4154 try { return localStorage.getItem(themeKey) || 'auto'; } catch (_) { return 'auto'; }
4155 };
4156 const writeTheme = (t) => {
4157 try { localStorage.setItem(themeKey, t); } catch (_) { /* ignore */ }
4158 };
4159 applyTheme(readTheme());
4160 if (themeBtn) {
4161 themeBtn.addEventListener('click', () => {
4162 const cur = readTheme();
4163 const next = cur === 'auto' ? 'dark' : (cur === 'dark' ? 'light' : 'auto');
4164 writeTheme(next);
4165 applyTheme(next);
4166 });
4167 }
4168
4169 function clearActive() {
4170 document.querySelectorAll(".e-active").forEach((el) => el.classList.remove("e-active"));
4171 document.querySelectorAll("tr.match-row.active").forEach((el) => el.classList.remove("active"));
4172 }
4173
4174 function findSpanEls(eid) {
4175 if (!eid) return [];
4176 // New segmented renderer: one span can be split across multiple elements.
4177 const els = Array.from(document.querySelectorAll(`span.e[data-eids~='${eid}']`));
4178 if (els.length) return els;
4179 // Back-compat: older HTML used a single element id.
4180 const single = document.getElementById(eid);
4181 return single ? [single] : [];
4182 }
4183
4184 function activate(gid, pid, row) {
4185 clearActive();
4186 const gEls = findSpanEls(gid);
4187 const pEls = findSpanEls(pid);
4188 const sel = document.getElementById("selection");
4189 gEls.forEach((el) => el.classList.add("e-active"));
4190 pEls.forEach((el) => el.classList.add("e-active"));
4191 if (row) row.classList.add("active");
4192 if (sel) {
4193 const parts = [];
4194 if (gEls.length) {
4195 const lbl = gEls[0].dataset && gEls[0].dataset.label ? ` [${gEls[0].dataset.label}]` : "";
4196 parts.push(`gold ${gid}${lbl}`);
4197 }
4198 if (pEls.length) {
4199 const lbl = pEls[0].dataset && pEls[0].dataset.label ? ` [${pEls[0].dataset.label}]` : "";
4200 parts.push(`pred ${pid}${lbl}`);
4201 }
4202 sel.textContent = parts.length ? parts.join(" | ") : "no selection";
4203 }
4204 if (row && row.id) {
4205 // Keep deep links stable without triggering navigation jump.
4206 // NOTE: single quotes avoid the Rust raw-string delimiter issue with quote+hash.
4207 history.replaceState(null, "", '#' + row.id);
4208 }
4209 const target = gEls[0] || pEls[0];
4210 if (target) target.scrollIntoView({ behavior: "smooth", block: "center" });
4211 }
4212
4213 document.querySelectorAll("tr.match-row[data-gid], tr.match-row[data-pid]").forEach((tr) => {
4214 tr.addEventListener("click", () => activate(tr.dataset.gid, tr.dataset.pid, tr));
4215 });
4216
4217 document.querySelectorAll("a.match-link").forEach((a) => {
4218 a.addEventListener("click", (ev) => {
4219 ev.preventDefault();
4220 const tr = a.closest("tr.match-row");
4221 if (!tr) return;
4222 activate(tr.dataset.gid, tr.dataset.pid, tr);
4223 });
4224 });
4225
4226 // Auto-select a match row if the URL has a deep link (e.g. #M12).
4227 const hash = (location.hash || "").slice(1);
4228 if (hash && hash.startsWith("M")) {
4229 const tr = document.getElementById(hash);
4230 if (tr && tr.classList && tr.classList.contains("match-row")) {
4231 activate(tr.dataset.gid, tr.dataset.pid, tr);
4232 }
4233 }
4234})();
4235</script>"#,
4236 );
4237
4238 html.push_str("</body></html>");
4239 html
4240}
4241
/// One labelled span to highlight when rendering text as HTML via `annotate_text_spans`.
#[derive(Debug, Clone)]
struct EvalHtmlSpan {
    /// Start of the span as a character (not byte) offset. `annotate_text_spans` clamps it
    /// to the number of characters in the text.
    start: usize,
    /// Exclusive end of the span as a character offset, clamped the same way as `start`.
    /// A span whose clamped `start >= end` is ignored by the renderer.
    end: usize,
    /// Label written to the `data-label` attribute when this span is the primary span of a
    /// rendered segment.
    label: String,
    /// CSS class appended after `e seg` on the emitted `<span>` when this span is primary.
    /// It is written into the HTML without escaping.
    class: &'static str,
    /// Identifier listed in the `data-eids` attribute of every segment this span covers.
    id: String,
}
4251
4252fn annotate_text_spans(text: &str, spans: &[EvalHtmlSpan]) -> String {
4253 let char_count = text.chars().count();
4254 if char_count == 0 || spans.is_empty() {
4255 return html_escape(text);
4256 }
4257
4258 #[derive(Debug, Clone)]
4259 struct Meta {
4260 id: String,
4261 label: String,
4262 class: &'static str,
4263 len: usize,
4264 }
4265 #[derive(Debug, Clone)]
4266 struct Event {
4267 pos: usize,
4268 meta_idx: usize,
4269 delta: i32,
4270 }
4271
4272 let mut metas: Vec<Meta> = Vec::with_capacity(spans.len());
4273 let mut events: Vec<Event> = Vec::new();
4274 let mut boundaries: Vec<usize> = vec![0, char_count];
4275
4276 for s in spans {
4277 let start = s.start.min(char_count);
4278 let end = s.end.min(char_count);
4279 if start >= end {
4280 continue;
4281 }
4282 let meta_idx = metas.len();
4283 metas.push(Meta {
4284 id: s.id.clone(),
4285 label: s.label.to_string(),
4286 class: s.class,
4287 len: end - start,
4288 });
4289 boundaries.push(start);
4290 boundaries.push(end);
4291 events.push(Event {
4292 pos: start,
4293 meta_idx,
4294 delta: 1,
4295 });
4296 events.push(Event {
4297 pos: end,
4298 meta_idx,
4299 delta: -1,
4300 });
4301 }
4302
4303 if metas.is_empty() {
4304 return html_escape(text);
4305 }
4306
4307 boundaries.sort_unstable();
4308 boundaries.dedup();
4309 events.sort_by(|a, b| a.pos.cmp(&b.pos).then_with(|| a.delta.cmp(&b.delta)));
4310
4311 let mut active_counts: Vec<u32> = vec![0; metas.len()];
4312 let mut active: Vec<usize> = Vec::new();
4313 let mut ev_idx = 0usize;
4314 let mut result = String::new();
4315
4316 for bi in 0..boundaries.len().saturating_sub(1) {
4317 let pos = boundaries[bi];
4318 while ev_idx < events.len() && events[ev_idx].pos == pos {
4319 let e = &events[ev_idx];
4320 let idx = e.meta_idx;
4321 if e.delta < 0 {
4322 if active_counts[idx] > 0 {
4323 active_counts[idx] -= 1;
4324 if active_counts[idx] == 0 {
4325 active.retain(|&x| x != idx);
4326 }
4327 }
4328 } else {
4329 active_counts[idx] += 1;
4330 if active_counts[idx] == 1 {
4331 active.push(idx);
4332 }
4333 }
4334 ev_idx += 1;
4335 }
4336
4337 let next = boundaries[bi + 1];
4338 if next <= pos {
4339 continue;
4340 }
4341
4342 let seg_text: String = text.chars().skip(pos).take(next - pos).collect();
4343 if active.is_empty() {
4344 result.push_str(&html_escape(&seg_text));
4345 continue;
4346 }
4347
4348 let primary_idx = active
4349 .iter()
4350 .copied()
4351 .min_by_key(|i| metas[*i].len)
4352 .unwrap_or(active[0]);
4353 let primary = &metas[primary_idx];
4354 let mut eids: Vec<&str> = active.iter().map(|i| metas[*i].id.as_str()).collect();
4355 eids.sort_unstable();
4356 let data_eids = eids.join(" ");
4357
4358 let title = format!(
4359 "eids=[{}] primary={} [{}..{})",
4360 data_eids, primary.id, pos, next
4361 );
4362 result.push_str(&format!(
4363 "<span class=\"e seg {class}\" data-eids=\"{eids}\" data-label=\"{label}\" data-start=\"{start}\" data-end=\"{end}\" title=\"{title}\">{text}</span>",
4364 class = primary.class,
4365 eids = html_escape(&data_eids),
4366 label = html_escape(&primary.label),
4367 start = pos,
4368 end = next,
4369 title = html_escape(&title),
4370 text = html_escape(&seg_text)
4371 ));
4372 }
4373
4374 result
4375}
4376
/// Options for a text-processing run.
///
/// `Default` yields an empty `labels` list and a `threshold` of `0.0`.
#[derive(Debug, Clone, Default)]
pub struct ProcessOptions {
    /// Labels of interest.
    /// NOTE(review): presumably restricts which entity labels are produced, and an empty
    /// list may mean "all" — the consumer is not visible here, confirm before relying on it.
    pub labels: Vec<String>,
    /// Threshold value.
    /// NOTE(review): presumably a minimum confidence in `[0, 1]` — confirm against the consumer.
    pub threshold: f32,
}
4389
/// Outcome of processing a text: the grounded document plus its validation status.
#[derive(Debug)]
pub struct ProcessResult {
    /// The document produced by processing; rendered by [`ProcessResult::to_html`].
    pub document: GroundedDocument,
    /// Validation flag.
    /// NOTE(review): presumably `true` exactly when `errors` is empty — the producer is not
    /// visible here, confirm.
    pub valid: bool,
    /// Signal validation errors collected for `document`.
    pub errors: Vec<SignalValidationError>,
}
4400
impl ProcessResult {
    /// Renders the contained document as HTML by delegating to `render_document_html`.
    ///
    /// Only `document` is rendered; `valid` and `errors` are not consulted here.
    #[must_use]
    pub fn to_html(&self) -> String {
        render_document_html(&self.document)
    }
}
4408
/// Documentation-only stub for the text-processing entry point.
///
/// The body is never meant to run: per its own panic message, callers should use
/// `anno::process_text` instead. Both arguments are ignored.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
// `dead_code` is allowed because nothing is expected to call this stub.
#[allow(dead_code)]
#[doc(hidden)]
pub fn process_text(
    _text: &str,
    _model: Option<&dyn std::any::Any>,
) -> super::Result<ProcessResult> {
    unimplemented!("Use anno::process_text instead - this stub documents the API only")
}
4436
/// A collection of grounded documents together with the identities shared across them.
#[derive(Debug, Clone)]
pub struct Corpus {
    /// Documents keyed by their own `id` (see `Corpus::add_document`).
    documents: std::collections::HashMap<String, GroundedDocument>,
    /// Corpus-level identities keyed by their id.
    identities: std::collections::HashMap<IdentityId, Identity>,
    /// Id that will be assigned to the next identity created by this corpus; advanced by
    /// `Corpus::add_identity` and `Corpus::link_track_to_kb`.
    next_identity_id: IdentityId,
}
4460
4461impl Corpus {
4462 #[must_use]
4464 pub fn new() -> Self {
4465 Self {
4466 documents: std::collections::HashMap::new(),
4467 identities: std::collections::HashMap::new(),
4468 next_identity_id: IdentityId::ZERO,
4469 }
4470 }
4471
4472 #[must_use]
4474 pub fn identities(&self) -> &std::collections::HashMap<IdentityId, Identity> {
4475 &self.identities
4476 }
4477
4478 #[must_use]
4480 pub fn get_identity(&self, id: IdentityId) -> Option<&Identity> {
4481 self.identities.get(&id)
4482 }
4483
4484 pub fn add_identity(&mut self, mut identity: Identity) -> IdentityId {
4489 let id = self.next_identity_id;
4490 identity.id = id;
4491 self.identities.insert(id, identity);
4492 self.next_identity_id += 1;
4493 id
4494 }
4495
4496 #[must_use]
4500 pub fn next_identity_id(&self) -> IdentityId {
4501 self.next_identity_id
4502 }
4503
4504 pub fn documents(&self) -> impl Iterator<Item = &GroundedDocument> {
4508 self.documents.values()
4509 }
4510
4511 #[must_use]
4515 pub fn get_document(&self, doc_id: &str) -> Option<&GroundedDocument> {
4516 self.documents.get(doc_id)
4517 }
4518
4519 pub fn get_document_mut(&mut self, doc_id: &str) -> Option<&mut GroundedDocument> {
4523 self.documents.get_mut(doc_id)
4524 }
4525
4526 pub fn add_document(&mut self, document: GroundedDocument) -> String {
4531 let doc_id = document.id.clone();
4532 self.documents.insert(doc_id.clone(), document);
4533 doc_id
4534 }
4535
4536 pub fn link_track_to_kb(
4558 &mut self,
4559 track_ref: &TrackRef,
4560 kb_name: impl Into<String>,
4561 kb_id: impl Into<String>,
4562 canonical_name: impl Into<String>,
4563 ) -> super::Result<IdentityId> {
4564 use super::error::Error;
4565
4566 let doc = self.documents.get_mut(&track_ref.doc_id).ok_or_else(|| {
4567 Error::track_ref(format!(
4568 "Document '{}' not found in corpus",
4569 track_ref.doc_id
4570 ))
4571 })?;
4572 let track = doc.get_track(track_ref.track_id).ok_or_else(|| {
4573 Error::track_ref(format!(
4574 "Track {} not found in document '{}'",
4575 track_ref.track_id, track_ref.doc_id
4576 ))
4577 })?;
4578
4579 let kb_name_str = kb_name.into();
4580 let kb_id_str = kb_id.into();
4581 let canonical_name_str = canonical_name.into();
4582
4583 let identity_id = if let Some(existing_id) = track.identity_id {
4585 if let Some(identity) = self.identities.get_mut(&existing_id) {
4587 identity.kb_id = Some(kb_id_str.clone());
4588 identity.kb_name = Some(kb_name_str.clone());
4589 identity.canonical_name = canonical_name_str.clone();
4590
4591 identity.source = Some(match identity.source.take() {
4593 Some(IdentitySource::CrossDocCoref { track_refs }) => IdentitySource::Hybrid {
4594 track_refs,
4595 kb_name: kb_name_str.clone(),
4596 kb_id: kb_id_str.clone(),
4597 },
4598 _ => IdentitySource::KnowledgeBase {
4599 kb_name: kb_name_str.clone(),
4600 kb_id: kb_id_str.clone(),
4601 },
4602 });
4603
4604 existing_id
4605 } else {
4606 let new_id = self.next_identity_id;
4614 self.next_identity_id += 1;
4615
4616 let identity = Identity {
4617 id: new_id,
4618 canonical_name: canonical_name_str,
4619 entity_type: track.entity_type.clone(),
4620 kb_id: Some(kb_id_str.clone()),
4621 kb_name: Some(kb_name_str.clone()),
4622 description: None,
4623 embedding: track.embedding.clone(),
4624 aliases: Vec::new(),
4625 confidence: track.cluster_confidence,
4626 source: Some(IdentitySource::KnowledgeBase {
4627 kb_name: kb_name_str,
4628 kb_id: kb_id_str,
4629 }),
4630 };
4631
4632 self.identities.insert(new_id, identity);
4633 doc.link_track_to_identity(track_ref.track_id, new_id);
4636 new_id
4637 }
4638 } else {
4639 let new_id = self.next_identity_id;
4641 self.next_identity_id += 1;
4642
4643 let identity = Identity {
4644 id: new_id,
4645 canonical_name: canonical_name_str,
4646 entity_type: track.entity_type.clone(),
4647 kb_id: Some(kb_id_str.clone()),
4648 kb_name: Some(kb_name_str.clone()),
4649 description: None,
4650 embedding: track.embedding.clone(),
4651 aliases: Vec::new(),
4652 confidence: track.cluster_confidence,
4653 source: Some(IdentitySource::KnowledgeBase {
4654 kb_name: kb_name_str,
4655 kb_id: kb_id_str,
4656 }),
4657 };
4658
4659 self.identities.insert(new_id, identity);
4660 doc.link_track_to_identity(track_ref.track_id, new_id);
4661 new_id
4662 };
4663
4664 Ok(identity_id)
4665 }
4666}
4667
impl Default for Corpus {
    /// Equivalent to [`Corpus::new`]: an empty corpus.
    fn default() -> Self {
        Self::new()
    }
}
4673
#[cfg(test)]
mod tests {
    // Tests use `unwrap()` freely: a panic is the desired failure mode here.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// The eval HTML report must contain the hooks the embedded script looks up, and must
    /// cope with multi-byte text addressed by character offsets.
    #[test]
    fn test_render_eval_html_has_interactive_hooks_and_is_unicode_safe() {
        // Offsets below are character offsets: 習近平 = 0..3, 北京 = 4..6.
        let text = "習近平在北京會見了普京。";

        let gold: Vec<Signal<Location>> = vec![
            Signal::new(SignalId::new(0), Location::text(0, 3), "習近平", "PER", 1.0),
            Signal::new(SignalId::new(1), Location::text(4, 6), "北京", "LOC", 1.0),
        ];

        // Same spans as gold; the second prediction carries a different label (PER vs LOC).
        let predicted: Vec<Signal<Location>> = vec![
            Signal::new(SignalId::new(0), Location::text(0, 3), "習近平", "PER", 0.9),
            Signal::new(SignalId::new(1), Location::text(4, 6), "北京", "PER", 0.7),
        ];

        let cmp = EvalComparison::compare(text, gold, predicted);
        let html = render_eval_html_with_title(&cmp, "test");

        // Selection panel targeted by the script.
        assert!(html.contains("id=\"selection\""));

        // Annotated spans carry their gold / predicted ids.
        assert!(html.contains("data-eids=\"G0\""));
        assert!(html.contains("data-eids=\"P0\""));

        // Match rows are deep-linkable and reference their gold / predicted ids.
        assert!(html.contains("class=\"match-link\""));
        assert!(html.contains("href=\"#M0\""));
        assert!(html.contains("data-gid=\"G0\""));
        assert!(html.contains("data-pid=\"P0\""));

        // Multi-byte text survives rendering intact.
        assert!(html.contains("北京"));
    }

    /// Returns the `(start, end)` character offsets of the first occurrence of `needle` in
    /// `text`, or `None` when `needle` is empty or does not occur.
    fn find_char_span(text: &str, needle: &str) -> Option<(usize, usize)> {
        let hay: Vec<char> = text.chars().collect();
        let pat: Vec<char> = needle.chars().collect();
        if pat.is_empty() || hay.len() < pat.len() {
            return None;
        }
        for i in 0..=(hay.len() - pat.len()) {
            if hay[i..(i + pat.len())] == pat[..] {
                return Some((i, i + pat.len()));
            }
        }
        None
    }

    /// Nested spans, a discontinuous span and several scripts in one text must all render.
    #[test]
    fn test_annotate_text_html_supports_overlaps_discontinuous_and_unicode() {
        let text = "Marie Curie met Cher in Paris. 習近平在北京會見了普京。 \
التقى محمد بن سلمان في الرياض. Путин встретился с Си Цзиньпином в Москве. \
प्रधान मंत्री शर्मा दिल्ली में मिले। severe pain ... in abdomen.";

        // Nested pair: "Curie" lies inside "Marie Curie".
        let (m0s, m0e) = find_char_span(text, "Marie Curie").unwrap();
        let (m1s, m1e) = find_char_span(text, "Curie").unwrap();

        // Two separate segments forming one discontinuous span.
        let pain = find_char_span(text, "pain").unwrap();
        let abdomen = find_char_span(text, "abdomen").unwrap();

        let signals: Vec<Signal<Location>> = vec![
            Signal::new(
                SignalId::new(0),
                Location::text(m0s, m0e),
                "Marie Curie",
                "PER",
                0.9,
            ),
            Signal::new(
                SignalId::new(1),
                Location::text(m1s, m1e),
                "Curie",
                "PER",
                0.8,
            ),
            Signal::new(
                SignalId::new(2),
                Location::Discontinuous {
                    segments: vec![pain, abdomen],
                },
                "pain … abdomen",
                "SYMPTOM",
                0.7,
            ),
        ];

        let html = annotate_text_html(text, &signals, &std::collections::HashMap::new());

        // The overlapping region lists both signal ids (either order is accepted).
        assert!(html.contains("data-sids=\"S0 S1\"") || html.contains("data-sids=\"S1 S0\""));

        // The discontinuous signal is rendered.
        assert!(html.contains("data-sids=\"S2\""));

        // Every script survives rendering intact.
        assert!(html.contains("北京"));
        assert!(html.contains("Москве"));
        assert!(html.contains("शर्मा"));
        assert!(html.contains("محمد"));
    }

    #[test]
    fn test_location_text_iou() {
        let l1 = Location::text(0, 10);
        let l2 = Location::text(5, 15);
        let iou = l1.iou(&l2).unwrap();
        // Intersection 5 (5..10), union 15 (0..15) => 1/3.
        assert!((iou - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_location_bbox_iou() {
        let b1 = Location::bbox(0.0, 0.0, 0.5, 0.5);
        let b2 = Location::bbox(0.25, 0.25, 0.5, 0.5);
        let iou = b1.iou(&b2).unwrap();
        // Intersection 0.25 * 0.25 = 0.0625, union 0.25 + 0.25 - 0.0625 = 0.4375 => ~0.143.
        assert!((iou - 0.143).abs() < 0.01);
    }

    /// IoU between different location kinds is undefined.
    #[test]
    fn test_location_different_types_no_iou() {
        let text = Location::text(0, 10);
        let bbox = Location::bbox(0.0, 0.0, 0.5, 0.5);
        assert!(text.iou(&bbox).is_none());
    }

    #[test]
    fn test_signal_creation() {
        let signal: Signal<Location> =
            Signal::new(0, Location::text(0, 11), "Marie Curie", "Person", 0.95);
        assert_eq!(signal.surface, "Marie Curie");
        assert_eq!(signal.label, "Person".into());
        assert!((signal.confidence - 0.95).abs() < 0.001);
        // A freshly created signal is not negated.
        assert!(!signal.negated);
    }

    #[test]
    fn test_signal_with_linguistic_features() {
        let signal: Signal<Location> =
            Signal::new(0, Location::text(0, 10), "not a doctor", "Occupation", 0.8)
                .negated()
                .with_quantifier(Quantifier::Existential)
                .with_modality(Modality::Symbolic);

        assert!(signal.negated);
        assert_eq!(signal.quantifier, Some(Quantifier::Existential));
        assert!(signal.modality.supports_linguistic_features());
    }

    #[test]
    fn test_track_formation() {
        let mut track = Track::new(0, "Marie Curie");
        track.add_signal(0, 0);
        track.add_signal(1, 1);
        track.add_signal(2, 2);

        assert_eq!(track.len(), 3);
        assert!(!track.is_singleton());
        assert!(!track.is_empty());
    }

    #[test]
    fn test_identity_creation() {
        let identity = Identity::from_kb(0, "Marie Curie", "wikidata", "Q7186")
            .with_type("Person")
            .with_embedding(vec![0.1, 0.2, 0.3]);

        assert_eq!(identity.canonical_name, "Marie Curie");
        assert_eq!(identity.kb_id, Some("Q7186".to_string()));
        assert_eq!(identity.kb_name, Some("wikidata".to_string()));
        assert!(identity.embedding.is_some());
    }

    /// Exercises the full signal -> track -> identity chain within one document.
    #[test]
    fn test_grounded_document_hierarchy() {
        let mut doc = GroundedDocument::new(
            "doc1",
            "Marie Curie won the Nobel Prize. She was a physicist.",
        );

        // Level 1: signals.
        // NOTE(review): the offsets below do not line up exactly with the surfaces in the
        // text above; nothing in this test depends on them.
        let s1 = doc.add_signal(Signal::new(
            0,
            Location::text(0, 12),
            "Marie Curie",
            "Person",
            0.95,
        ));
        let s2 = doc.add_signal(Signal::new(
            1,
            Location::text(38, 41),
            "She",
            "Person",
            0.88,
        ));
        let s3 = doc.add_signal(Signal::new(
            2,
            Location::text(17, 29),
            "Nobel Prize",
            "Award",
            0.92,
        ));

        // Level 2: tracks. "Marie Curie" and "She" share a track.
        let mut track1 = Track::new(0, "Marie Curie");
        track1.add_signal(s1, 0);
        track1.add_signal(s2, 1);
        let track1_id = doc.add_track(track1);

        let mut track2 = Track::new(1, "Nobel Prize");
        track2.add_signal(s3, 0);
        doc.add_track(track2);

        // Level 3: identity, linked to the first track only.
        let identity = Identity::from_kb(0, "Marie Curie", "wikidata", "Q7186");
        let identity_id = doc.add_identity(identity);
        doc.link_track_to_identity(track1_id, identity_id);

        // Counts at each level.
        assert_eq!(doc.signals().len(), 3);
        assert_eq!(doc.tracks().count(), 2);
        assert_eq!(doc.identities().count(), 1);

        // Signal -> track.
        let track = doc.track_for_signal(s1).unwrap();
        assert_eq!(track.canonical_surface, "Marie Curie");
        assert_eq!(track.len(), 2);

        // Track -> identity.
        let identity = doc.identity_for_track(track1_id).unwrap();
        assert_eq!(identity.kb_id, Some("Q7186".to_string()));

        // Signal -> identity (through the track).
        let identity = doc.identity_for_signal(s1).unwrap();
        assert_eq!(identity.canonical_name, "Marie Curie");
    }

    #[test]
    fn test_modality_features() {
        assert!(Modality::Symbolic.supports_linguistic_features());
        assert!(!Modality::Symbolic.supports_geometric_features());

        assert!(!Modality::Iconic.supports_linguistic_features());
        assert!(Modality::Iconic.supports_geometric_features());

        // Hybrid supports both feature families.
        assert!(Modality::Hybrid.supports_linguistic_features());
        assert!(Modality::Hybrid.supports_geometric_features());
    }

    #[test]
    fn test_location_from_span() {
        let span = Span::Text { start: 0, end: 10 };
        let location = Location::from(&span);
        assert_eq!(location.text_offsets(), Some((0, 10)));

        let span = Span::BoundingBox {
            x: 0.1,
            y: 0.2,
            width: 0.3,
            height: 0.4,
            page: Some(1),
        };
        let location = Location::from(&span);
        assert!(matches!(location, Location::BoundingBox { .. }));
    }

    /// Entities converted into a document and back keep their text.
    #[test]
    fn test_entity_roundtrip() {
        use super::EntityType;

        let entities = vec![
            Entity::new("Marie Curie", EntityType::Person, 0, 12, 0.95),
            Entity::new(
                "Nobel Prize",
                EntityType::Other("Award".to_string()),
                17,
                29,
                0.92,
            ),
        ];

        let doc =
            GroundedDocument::from_entities("doc1", "Marie Curie won the Nobel Prize.", &entities);
        let converted = doc.to_entities();

        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0].text, "Marie Curie");
        assert_eq!(converted[1].text, "Nobel Prize");
    }

    #[test]
    fn test_signal_confidence_threshold() {
        let signal: Signal<Location> = Signal::new(0, Location::text(0, 10), "test", "Type", 0.75);
        assert!(signal.is_confident(0.5));
        // The threshold is inclusive: confidence equal to it counts as confident.
        assert!(signal.is_confident(0.75));
        assert!(!signal.is_confident(0.8));
    }

    #[test]
    fn test_document_filtering() {
        let mut doc = GroundedDocument::new("doc1", "Test text");

        // Two Person signals (one high, one low confidence) and one Organization.
        doc.add_signal(Signal::new(0, Location::text(0, 4), "high", "Person", 0.95));
        doc.add_signal(Signal::new(1, Location::text(5, 8), "low", "Person", 0.3));
        doc.add_signal(Signal::new(
            2,
            Location::text(9, 12),
            "org",
            "Organization",
            0.8,
        ));

        // Confidence filter: 0.95 and 0.8 pass, 0.3 does not.
        let confident = doc.confident_signals(0.5);
        assert_eq!(confident.len(), 2);

        // Label filter.
        let persons = doc.signals_with_label("Person");
        assert_eq!(persons.len(), 2);

        let orgs = doc.signals_with_label("Organization");
        assert_eq!(orgs.len(), 1);
    }

    #[test]
    fn test_untracked_signals() {
        let mut doc = GroundedDocument::new("doc1", "Test");

        let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "a", "T", 0.9));
        let s2 = doc.add_signal(Signal::new(1, Location::text(5, 8), "b", "T", 0.9));
        let _s3 = doc.add_signal(Signal::new(2, Location::text(9, 12), "c", "T", 0.9));

        // Only the first two signals join a track; "c" stays untracked.
        let mut track = Track::new(0, "a");
        track.add_signal(s1, 0);
        track.add_signal(s2, 1);
        doc.add_track(track);

        assert_eq!(doc.untracked_signal_count(), 1);
        let untracked = doc.untracked_signals();
        assert_eq!(untracked.len(), 1);
        assert_eq!(untracked[0].surface, "c");
    }

    #[test]
    fn test_linked_unlinked_tracks() {
        let mut doc = GroundedDocument::new("doc1", "Test");

        let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "a", "T", 0.9));
        let s2 = doc.add_signal(Signal::new(1, Location::text(5, 8), "b", "T", 0.9));

        let mut track1 = Track::new(0, "a");
        track1.add_signal(s1, 0);
        let track1_id = doc.add_track(track1);

        let mut track2 = Track::new(1, "b");
        track2.add_signal(s2, 0);
        doc.add_track(track2);

        // Only the first track gets an identity.
        let identity = Identity::new(0, "Entity A");
        let identity_id = doc.add_identity(identity);
        doc.link_track_to_identity(track1_id, identity_id);

        assert_eq!(doc.linked_tracks().count(), 1);
        assert_eq!(doc.unlinked_tracks().count(), 1);
    }

    #[test]
    fn test_location_overlaps() {
        let l1 = Location::text(0, 10);
        let l2 = Location::text(5, 15);
        let l3 = Location::text(15, 20);

        assert!(l1.overlaps(&l2));
        assert!(!l1.overlaps(&l3));
        // Adjacent ranges (one ends at 15, the other starts at 15) do not overlap.
        assert!(!l2.overlaps(&l3));

        let b1 = Location::bbox(0.0, 0.0, 0.5, 0.5);
        let b2 = Location::bbox(0.4, 0.4, 0.5, 0.5);
        let b3 = Location::bbox(0.6, 0.6, 0.2, 0.2);

        assert!(b1.overlaps(&b2));
        assert!(!b1.overlaps(&b3));
    }

    #[test]
    fn test_iou_edge_cases() {
        // Disjoint ranges.
        let l1 = Location::text(0, 5);
        let l2 = Location::text(10, 15);
        assert_eq!(l1.iou(&l2), Some(0.0));

        // Identical ranges.
        let l3 = Location::text(0, 10);
        let l4 = Location::text(0, 10);
        assert_eq!(l3.iou(&l4), Some(1.0));

        // One range contains the other: intersection 10, union 20.
        let l5 = Location::text(0, 20);
        let l6 = Location::text(5, 15);
        let iou = l5.iou(&l6).unwrap();
        assert!((iou - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_document_stats() {
        let mut doc = GroundedDocument::new("doc1", "Test document with entities.");

        // NOTE(review): every signal is constructed with id 0; presumably `add_signal`
        // assigns the real id — confirm against its implementation.
        let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9));
        let mut negated = Signal::new(0, Location::text(5, 13), "document", "Type", 0.8);
        negated.negated = true;
        let s2 = doc.add_signal(negated);
        let _s3 = doc.add_signal(Signal::new(
            0,
            Location::text(19, 27),
            "entities",
            "Type",
            0.7,
        ));

        // One track holding the first two signals; the third stays untracked.
        let mut track = Track::new(0, "Test");
        track.add_signal(s1, 0);
        track.add_signal(s2, 1);
        doc.add_track(track);

        // Link that track to an identity.
        let identity = Identity::new(0, "Test Entity");
        let identity_id = doc.add_identity(identity);
        doc.link_track_to_identity(0, identity_id);

        let stats = doc.stats();

        assert_eq!(stats.signal_count, 3);
        assert_eq!(stats.track_count, 1);
        assert_eq!(stats.identity_count, 1);
        assert_eq!(stats.linked_track_count, 1);
        assert_eq!(stats.untracked_count, 1);
        assert_eq!(stats.negated_count, 1);
        // Mean confidence: (0.9 + 0.8 + 0.7) / 3 = 0.8.
        assert!((stats.avg_confidence - 0.8).abs() < 0.01);
        // The single track has two signals.
        assert!((stats.avg_track_size - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_operations() {
        let mut doc = GroundedDocument::new("doc1", "Test document.");

        // Add several signals in one call.
        let signals = vec![
            Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9),
            Signal::new(0, Location::text(5, 13), "document", "Type", 0.8),
        ];
        let ids = doc.add_signals(signals);

        assert_eq!(ids.len(), 2);
        assert_eq!(doc.signals().len(), 2);

        // Build a track directly from the returned signal ids.
        let track_id = doc.create_track_from_signals("Test", &ids);
        assert!(track_id.is_some());

        let track = doc.get_track(track_id.unwrap()).unwrap();
        assert_eq!(track.len(), 2);
        assert_eq!(track.canonical_surface, "Test");
    }

    #[test]
    fn test_merge_tracks() {
        let mut doc = GroundedDocument::new("doc1", "John Smith works at Acme. He is great.");

        let s1 = doc.add_signal(Signal::new(
            0,
            Location::text(0, 10),
            "John Smith",
            "Person",
            0.9,
        ));
        let s2 = doc.add_signal(Signal::new(0, Location::text(26, 28), "He", "Person", 0.8));

        // Start with two singleton tracks.
        let mut track1 = Track::new(0, "John Smith");
        track1.add_signal(s1, 0);
        let track1_id = doc.add_track(track1);

        let mut track2 = Track::new(0, "He");
        track2.add_signal(s2, 0);
        let track2_id = doc.add_track(track2);

        assert_eq!(doc.tracks().count(), 2);

        let merged_id = doc.merge_tracks(&[track1_id, track2_id]);
        assert!(merged_id.is_some());

        // The two tracks collapse into one holding both signals.
        assert_eq!(doc.tracks().count(), 1);
        let merged = doc.get_track(merged_id.unwrap()).unwrap();
        assert_eq!(merged.len(), 2);
        // The merged track is expected to carry the first track's canonical surface.
        assert_eq!(merged.canonical_surface, "John Smith");
    }

    #[test]
    fn test_find_overlapping_pairs() {
        let mut doc = GroundedDocument::new("doc1", "New York City is great.");

        // "New York City" and "New York" overlap; "great" overlaps neither.
        doc.add_signal(Signal::new(
            0,
            Location::text(0, 13),
            "New York City",
            "Location",
            0.9,
        ));
        doc.add_signal(Signal::new(
            0,
            Location::text(0, 8),
            "New York",
            "Location",
            0.85,
        ));
        doc.add_signal(Signal::new(0, Location::text(17, 22), "great", "Adj", 0.7));

        let pairs = doc.find_overlapping_signal_pairs();

        assert_eq!(pairs.len(), 1);
    }

    #[test]
    fn test_signals_in_range() {
        let mut doc = GroundedDocument::new("doc1", "John went to Paris and Berlin last year.");

        doc.add_signal(Signal::new(0, Location::text(0, 4), "John", "Person", 0.9));
        doc.add_signal(Signal::new(
            0,
            Location::text(13, 18),
            "Paris",
            "Location",
            0.9,
        ));
        doc.add_signal(Signal::new(
            0,
            Location::text(23, 29),
            "Berlin",
            "Location",
            0.9,
        ));
        doc.add_signal(Signal::new(
            0,
            Location::text(30, 39),
            "last year",
            "Date",
            0.8,
        ));

        // Only Paris (13..18) and Berlin (23..29) are expected for the range 10..30.
        let in_range = doc.signals_in_range(10, 30);
        assert_eq!(in_range.len(), 2);

        let surfaces: Vec<_> = in_range.iter().map(|s| &s.surface).collect();
        assert!(surfaces.contains(&&"Paris".to_string()));
        assert!(surfaces.contains(&&"Berlin".to_string()));
    }

    #[test]
    fn test_modality_filtering() {
        let mut doc = GroundedDocument::new("doc1", "Test");

        // One symbolic (text) signal.
        let mut text_signal = Signal::new(0, Location::text(0, 4), "Test", "Type", 0.9);
        text_signal.modality = Modality::Symbolic;
        doc.add_signal(text_signal);

        // One iconic (bounding-box) signal.
        let mut visual_signal =
            Signal::new(0, Location::bbox(0.0, 0.0, 0.5, 0.5), "Box", "Type", 0.8);
        visual_signal.modality = Modality::Iconic;
        doc.add_signal(visual_signal);

        assert_eq!(doc.text_signals().len(), 1);
        assert_eq!(doc.visual_signals().len(), 1);
        assert_eq!(doc.signals_by_modality(Modality::Hybrid).len(), 0);
    }

    /// Every quantifier variant round-trips through `with_quantifier`.
    #[test]
    fn test_quantifier_variants() {
        let quantifiers = [
            Quantifier::Universal,
            Quantifier::Existential,
            Quantifier::None,
            Quantifier::Definite,
            Quantifier::Bare,
        ];

        for q in quantifiers {
            let signal: Signal<Location> =
                Signal::new(0, Location::text(0, 5), "test", "Type", 0.9).with_quantifier(q);

            assert_eq!(signal.quantifier, Some(q));
        }
    }

    /// `Location::modality` maps each location kind to its modality.
    #[test]
    fn test_location_modality_derivation() {
        assert_eq!(Location::text(0, 10).modality(), Modality::Symbolic);
        assert_eq!(
            Location::bbox(0.0, 0.0, 0.5, 0.5).modality(),
            Modality::Iconic
        );

        let temporal = Location::Temporal {
            start_sec: 0.0,
            end_sec: 5.0,
            frame: None,
        };
        assert_eq!(temporal.modality(), Modality::Iconic);

        let genomic = Location::Genomic {
            contig: "chr1".into(),
            start: 0,
            end: 1000,
            strand: Some('+'),
        };
        assert_eq!(genomic.modality(), Modality::Symbolic);

        let hybrid = Location::TextWithBbox {
            start: 0,
            end: 10,
            bbox: Box::new(Location::bbox(0.0, 0.0, 0.5, 0.5)),
        };
        assert_eq!(hybrid.modality(), Modality::Hybrid);
    }
}
5335
5336#[cfg(test)]
5344mod proptests {
5345 #![allow(clippy::unwrap_used)] use super::*;
5347 use proptest::prelude::*;
5348
5349 fn confidence_strategy() -> impl Strategy<Value = f32> {
5355 0.0f32..=1.0
5356 }
5357
5358 fn label_strategy() -> impl Strategy<Value = String> {
5360 prop_oneof![
5361 Just("Person".to_string()),
5362 Just("Organization".to_string()),
5363 Just("Location".to_string()),
5364 Just("Date".to_string()),
5365 "[A-Z][a-z]{2,10}".prop_map(|s| s),
5366 ]
5367 }
5368
5369 fn surface_strategy() -> impl Strategy<Value = String> {
5371 "[A-Za-z ]{1,50}".prop_map(|s| s.trim().to_string())
5372 }
5373
5374 proptest! {
5379 #[test]
5381 fn iou_symmetric(
5382 start1 in 0usize..1000,
5383 len1 in 1usize..500,
5384 start2 in 0usize..1000,
5385 len2 in 1usize..500,
5386 ) {
5387 let a = Location::text(start1, start1 + len1);
5388 let b = Location::text(start2, start2 + len2);
5389
5390 let iou_ab = a.iou(&b);
5391 let iou_ba = b.iou(&a);
5392
5393 prop_assert_eq!(iou_ab, iou_ba, "IoU must be symmetric");
5394 }
5395
5396 #[test]
5398 fn iou_bounded(
5399 start1 in 0usize..1000,
5400 len1 in 1usize..500,
5401 start2 in 0usize..1000,
5402 len2 in 1usize..500,
5403 ) {
5404 let a = Location::text(start1, start1 + len1);
5405 let b = Location::text(start2, start2 + len2);
5406
5407 if let Some(iou) = a.iou(&b) {
5408 prop_assert!(iou >= 0.0, "IoU must be non-negative: got {}", iou);
5409 prop_assert!(iou <= 1.0, "IoU must be at most 1: got {}", iou);
5410 }
5411 }
5412
5413 #[test]
5415 fn iou_self_identity(start in 0usize..1000, len in 1usize..500) {
5416 let loc = Location::text(start, start + len);
5417 let iou = loc.iou(&loc).unwrap();
5418 prop_assert!(
5419 (iou - 1.0).abs() < 1e-6,
5420 "Self-IoU must be 1.0, got {}",
5421 iou
5422 );
5423 }
5424
5425 #[test]
5427 fn iou_non_overlapping_zero(
5428 start1 in 0usize..500,
5429 len1 in 1usize..100,
5430 ) {
5431 let end1 = start1 + len1;
5432 let start2 = end1 + 100; let len2 = 50;
5434
5435 let a = Location::text(start1, end1);
5436 let b = Location::text(start2, start2 + len2);
5437
5438 let iou = a.iou(&b).expect("bbox iou should be defined");
5439 prop_assert!(
5440 iou.abs() < 1e-6,
5441 "Non-overlapping IoU must be 0, got {}",
5442 iou
5443 );
5444 }
5445
5446 #[test]
5448 fn bbox_iou_symmetric_bounded(
5449 x1 in 0.0f32..0.8,
5450 y1 in 0.0f32..0.8,
5451 w1 in 0.05f32..0.2,
5452 h1 in 0.05f32..0.2,
5453 x2 in 0.0f32..0.8,
5454 y2 in 0.0f32..0.8,
5455 w2 in 0.05f32..0.2,
5456 h2 in 0.05f32..0.2,
5457 ) {
5458 let a = Location::bbox(x1, y1, w1, h1);
5459 let b = Location::bbox(x2, y2, w2, h2);
5460
5461 let iou_ab = a.iou(&b);
5462 let iou_ba = b.iou(&a);
5463
5464 prop_assert_eq!(iou_ab, iou_ba, "BBox IoU must be symmetric");
5466
5467 if let Some(iou) = iou_ab {
5469 prop_assert!(
5470 (0.0..=1.0).contains(&iou),
5471 "BBox IoU out of bounds: {}",
5472 iou
5473 );
5474 }
5475 }
5476 }
5477
5478 proptest! {
5483 #[test]
5485 fn signal_confidence_clamped(raw_conf in -10.0f32..10.0) {
5486 let signal: Signal<Location> = Signal::new(
5487 0,
5488 Location::text(0, 10),
5489 "test",
5490 "Type",
5491 raw_conf,
5492 );
5493
5494 prop_assert!(signal.confidence >= 0.0, "Confidence below 0: {}", signal.confidence);
5495 prop_assert!(signal.confidence <= 1.0, "Confidence above 1: {}", signal.confidence);
5496 }
5497
5498 #[test]
5500 fn signal_preserves_data(
5501 surface in surface_strategy(),
5502 label in label_strategy(),
5503 conf in confidence_strategy(),
5504 start in 0usize..1000,
5505 len in 1usize..100,
5506 ) {
5507 let signal: Signal<Location> = Signal::new(
5508 0,
5509 Location::text(start, start + len),
5510 &surface,
5511 label.as_str(),
5512 conf,
5513 );
5514
5515 prop_assert_eq!(&signal.surface, &surface);
5516 let want = crate::TypeLabel::from(label.as_str());
5517 prop_assert_eq!(signal.label, want);
5518 }
5519
5520 #[test]
5524 fn signal_negation_stable(conf in confidence_strategy()) {
5525 let signal: Signal<Location> = Signal::new(
5526 0,
5527 Location::text(0, 10),
5528 "test",
5529 "Type",
5530 conf,
5531 )
5532 .negated();
5533
5534 prop_assert!(signal.negated, "Signal should be negated after .negated()");
5535 }
5536
5537 #[test]
5539 fn symbolic_supports_linguistic(
5540 start in 0usize..1000,
5541 len in 1usize..100,
5542 ) {
5543 let loc = Location::text(start, start + len);
5544 prop_assert!(
5545 loc.modality().supports_linguistic_features(),
5546 "Text locations must support linguistic features"
5547 );
5548 }
5549
5550 #[test]
5552 fn iconic_supports_geometric(
5553 x in 0.0f32..0.9,
5554 y in 0.0f32..0.9,
5555 w in 0.01f32..0.5,
5556 h in 0.01f32..0.5,
5557 ) {
5558 let loc = Location::bbox(x, y, w, h);
5559 prop_assert!(
5560 loc.modality().supports_geometric_features(),
5561 "BBox locations must support geometric features"
5562 );
5563 }
5564 }
5565
5566 proptest! {
5571 #[test]
5573 fn track_length_monotonic(signal_count in 1usize..20) {
5574 let mut track = Track::new(0, "test");
5575
5576 for i in 0..signal_count {
5577 track.add_signal(i, i as u32);
5578 prop_assert_eq!(
5579 track.len(),
5580 i + 1,
5581 "Track length should be {} after adding {} signals",
5582 i + 1,
5583 i + 1
5584 );
5585 }
5586 }
5587
5588 #[test]
5590 fn track_not_empty_after_add(canonical in surface_strategy()) {
5591 let mut track = Track::new(0, &canonical);
5592 prop_assert!(track.is_empty(), "New track should be empty");
5593
5594 track.add_signal(0, 0);
5595 prop_assert!(!track.is_empty(), "Track should not be empty after add");
5596 }
5597
5598 #[test]
5600 fn track_positions_stored(signal_count in 1usize..10) {
5601 let mut track = Track::new(0, "test");
5602
5603 for i in 0..signal_count {
5604 track.add_signal(i, i as u32);
5605 }
5606
5607 for (idx, signal_ref) in track.signals.iter().enumerate() {
5608 prop_assert_eq!(
5609 signal_ref.position as usize,
5610 idx,
5611 "Signal position mismatch at index {}",
5612 idx
5613 );
5614 }
5615 }
5616 }
5617
5618 proptest! {
5623 #[test]
5625 fn document_signal_ids_monotonic(signal_count in 1usize..20) {
5626 let mut doc = GroundedDocument::new("test", "test text");
5627
5628 let mut prev_id: Option<SignalId> = None;
5629 for i in 0..signal_count {
5630 let id = doc.add_signal(Signal::new(
5631 999, Location::text(i * 10, i * 10 + 5),
5633 format!("entity_{}", i),
5634 "Type",
5635 0.9,
5636 ));
5637
5638 if let Some(prev) = prev_id {
5639 prop_assert!(id > prev, "Signal IDs should be monotonically increasing");
5640 }
5641 prev_id = Some(id);
5642 }
5643 }
5644
5645 #[test]
5647 fn document_track_membership_consistent(signal_count in 1usize..5) {
5648 let mut doc = GroundedDocument::new("test", "test text");
5649
5650 let mut signal_ids = Vec::new();
5652 for i in 0..signal_count {
5653 let id = doc.add_signal(Signal::new(
5654 0,
5655 Location::text(i * 10, i * 10 + 5),
5656 format!("entity_{}", i),
5657 "Type",
5658 0.9,
5659 ));
5660 signal_ids.push(id);
5661 }
5662
5663 let mut track = Track::new(0, "canonical");
5665 for (pos, &id) in signal_ids.iter().enumerate() {
5666 track.add_signal(id, pos as u32);
5667 }
5668 let track_id = doc.add_track(track);
5669
5670 for &signal_id in &signal_ids {
5672 let found_track = doc.track_for_signal(signal_id);
5673 prop_assert!(found_track.is_some(), "Signal should be in a track");
5674 prop_assert_eq!(
5675 found_track.unwrap().id,
5676 track_id,
5677 "Signal should be in the correct track"
5678 );
5679 }
5680 }
5681
5682 #[test]
5684 fn document_identity_transitivity(signal_count in 1usize..3) {
5685 let mut doc = GroundedDocument::new("test", "test text");
5686
5687 let mut signal_ids = Vec::new();
5689 for i in 0..signal_count {
5690 let id = doc.add_signal(Signal::new(
5691 0,
5692 Location::text(i * 10, i * 10 + 5),
5693 format!("entity_{}", i),
5694 "Type",
5695 0.9,
5696 ));
5697 signal_ids.push(id);
5698 }
5699
5700 let mut track = Track::new(0, "canonical");
5702 for (pos, &id) in signal_ids.iter().enumerate() {
5703 track.add_signal(id, pos as u32);
5704 }
5705 let track_id = doc.add_track(track);
5706
5707 let identity = Identity::from_kb(0, "Entity", "wikidata", "Q123");
5708 let identity_id = doc.add_identity(identity);
5709 doc.link_track_to_identity(track_id, identity_id);
5710
5711 for &signal_id in &signal_ids {
5713 let identity = doc.identity_for_signal(signal_id);
5714 prop_assert!(identity.is_some(), "Should find identity through signal");
5715 prop_assert_eq!(
5716 identity.unwrap().id,
5717 identity_id,
5718 "Should find correct identity"
5719 );
5720 }
5721 }
5722
5723 #[test]
5725 fn document_untracked_signals(total in 2usize..10, tracked in 0usize..10) {
5726 let tracked = tracked.min(total - 1); let mut doc = GroundedDocument::new("test", "test text");
5728
5729 let mut signal_ids = Vec::new();
5731 for i in 0..total {
5732 let id = doc.add_signal(Signal::new(
5733 0,
5734 Location::text(i * 10, i * 10 + 5),
5735 format!("entity_{}", i),
5736 "Type",
5737 0.9,
5738 ));
5739 signal_ids.push(id);
5740 }
5741
5742 let mut track = Track::new(0, "canonical");
5744 for (pos, &id) in signal_ids.iter().take(tracked).enumerate() {
5745 track.add_signal(id, pos as u32);
5746 }
5747 if tracked > 0 {
5748 doc.add_track(track);
5749 }
5750
5751 prop_assert_eq!(
5753 doc.untracked_signal_count(),
5754 total - tracked,
5755 "Wrong untracked count"
5756 );
5757 }
5758 }
5759
5760 proptest! {
5765 #[test]
5767 fn entity_roundtrip_preserves_text(
5768 text in surface_strategy(),
5769 start in 0usize..1000,
5770 len in 1usize..100,
5771 conf in 0.0f64..=1.0,
5772 ) {
5773 use super::EntityType;
5774
5775 let end = start + len;
5776 let entity = super::Entity::new(&text, EntityType::Person, start, end, conf);
5777
5778 let doc = GroundedDocument::from_entities("test", "x".repeat(end + 10), &[entity]);
5779 let converted = doc.to_entities();
5780
5781 prop_assert_eq!(converted.len(), 1, "Should have exactly one entity");
5782 prop_assert_eq!(&converted[0].text, &text, "Text should be preserved");
5783 prop_assert_eq!(converted[0].start, start, "Start should be preserved");
5784 prop_assert_eq!(converted[0].end, end, "End should be preserved");
5785 }
5786
5787 }
5790
5791 proptest! {
5796 #[test]
5798 fn modality_feature_consistency(_dummy in 0..1) {
5799 prop_assert!(Modality::Iconic.supports_geometric_features());
5801 prop_assert!(!Modality::Iconic.supports_linguistic_features());
5802
5803 prop_assert!(Modality::Symbolic.supports_linguistic_features());
5805 prop_assert!(!Modality::Symbolic.supports_geometric_features());
5806
5807 prop_assert!(Modality::Hybrid.supports_linguistic_features());
5809 prop_assert!(Modality::Hybrid.supports_geometric_features());
5810 }
5811 }
5812
5813 proptest! {
5818 #[test]
5820 fn overlap_symmetric(
5821 start1 in 0usize..1000,
5822 len1 in 1usize..100,
5823 start2 in 0usize..1000,
5824 len2 in 1usize..100,
5825 ) {
5826 let a = Location::text(start1, start1 + len1);
5827 let b = Location::text(start2, start2 + len2);
5828
5829 prop_assert_eq!(
5830 a.overlaps(&b),
5831 b.overlaps(&a),
5832 "Overlap must be symmetric"
5833 );
5834 }
5835
5836 #[test]
5838 fn overlap_reflexive(start in 0usize..1000, len in 1usize..100) {
5839 let loc = Location::text(start, start + len);
5840 prop_assert!(loc.overlaps(&loc), "Location must overlap with itself");
5841 }
5842
5843 #[test]
5845 fn iou_implies_overlap(
5846 start1 in 0usize..500,
5847 len1 in 1usize..100,
5848 start2 in 0usize..500,
5849 len2 in 1usize..100,
5850 ) {
5851 let a = Location::text(start1, start1 + len1);
5852 let b = Location::text(start2, start2 + len2);
5853
5854 if let Some(iou) = a.iou(&b) {
5855 if iou > 0.0 {
5856 prop_assert!(
5857 a.overlaps(&b),
5858 "IoU > 0 should imply overlap"
5859 );
5860 }
5861 }
5862 }
5863 }
5864
5865 proptest! {
5870 #[test]
5872 fn stats_signal_count_accurate(signal_count in 0usize..20) {
5873 let mut doc = GroundedDocument::new("test", "test");
5874 for i in 0..signal_count {
5875 doc.add_signal(Signal::new(
5876 0,
5877 Location::text(i * 10, i * 10 + 5),
5878 "entity",
5879 "Type",
5880 0.9,
5881 ));
5882 }
5883
5884 let stats = doc.stats();
5885 prop_assert_eq!(stats.signal_count, signal_count);
5886 }
5887
5888 #[test]
5890 fn stats_track_count_accurate(track_count in 0usize..10) {
5891 let mut doc = GroundedDocument::new("test", "test");
5892 for i in 0..track_count {
5893 let id = doc.add_signal(Signal::new(
5894 0,
5895 Location::text(i * 10, i * 10 + 5),
5896 "entity",
5897 "Type",
5898 0.9,
5899 ));
5900 let mut track = Track::new(0, format!("track_{}", i));
5901 track.add_signal(id, 0);
5902 doc.add_track(track);
5903 }
5904
5905 let stats = doc.stats();
5906 prop_assert_eq!(stats.track_count, track_count);
5907 }
5908
5909 #[test]
5911 fn stats_avg_confidence_bounded(
5912 confidences in proptest::collection::vec(0.0f32..=1.0, 1..10)
5913 ) {
5914 let mut doc = GroundedDocument::new("test", "test");
5915 for (i, conf) in confidences.iter().enumerate() {
5916 doc.add_signal(Signal::new(
5917 0,
5918 Location::text(i * 10, i * 10 + 5),
5919 "entity",
5920 "Type",
5921 *conf,
5922 ));
5923 }
5924
5925 let stats = doc.stats();
5926 prop_assert!(stats.avg_confidence >= 0.0);
5927 prop_assert!(stats.avg_confidence <= 1.0);
5928 }
5929 }
5930
5931 proptest! {
5936 #[test]
5938 fn batch_add_returns_all_ids(count in 1usize..10) {
5939 let mut doc = GroundedDocument::new("test", "test");
5940 let signals: Vec<Signal<Location>> = (0..count)
5941 .map(|i| Signal::new(0, Location::text(i * 10, i * 10 + 5), "e", "T", 0.9))
5942 .collect();
5943
5944 let ids = doc.add_signals(signals);
5945 prop_assert_eq!(ids.len(), count);
5946 prop_assert_eq!(doc.signals().len(), count);
5947 }
5948
5949 #[test]
5951 fn create_track_valid(signal_count in 1usize..5) {
5952 let mut doc = GroundedDocument::new("test", "test");
5953 let mut signal_ids = Vec::new();
5954 for i in 0..signal_count {
5955 let id = doc.add_signal(Signal::new(
5956 0,
5957 Location::text(i * 10, i * 10 + 5),
5958 "entity",
5959 "Type",
5960 0.9,
5961 ));
5962 signal_ids.push(id);
5963 }
5964
5965 let track_id = doc.create_track_from_signals("canonical", &signal_ids);
5966 prop_assert!(track_id.is_some());
5967
5968 let track = doc.get_track(track_id.unwrap());
5969 prop_assert!(track.is_some());
5970 prop_assert_eq!(track.unwrap().len(), signal_count);
5971 }
5972
5973 #[test]
5975 fn create_track_empty_returns_none(_dummy in 0..1) {
5976 let mut doc = GroundedDocument::new("test", "test");
5977 let track_id = doc.create_track_from_signals("canonical", &[]);
5978 prop_assert!(track_id.is_none());
5979 }
5980 }
5981
5982 proptest! {
5987 #[test]
5989 fn signals_in_range_within_bounds(
5990 range_start in 0usize..100,
5991 range_len in 10usize..50,
5992 ) {
5993 let range_end = range_start + range_len;
5994 let mut doc = GroundedDocument::new("test", "x".repeat(200));
5995
5996 doc.add_signal(Signal::new(0, Location::text(range_start + 2, range_start + 5), "inside", "T", 0.9));
5998 doc.add_signal(Signal::new(0, Location::text(0, 5), "before", "T", 0.9));
5999 doc.add_signal(Signal::new(0, Location::text(190, 195), "after", "T", 0.9));
6000
6001 let in_range = doc.signals_in_range(range_start, range_end);
6002
6003 for signal in &in_range {
6004 if let Some((start, end)) = signal.location.text_offsets() {
6005 prop_assert!(start >= range_start, "Signal start {} < range start {}", start, range_start);
6006 prop_assert!(end <= range_end, "Signal end {} > range end {}", end, range_end);
6007 }
6008 }
6009 }
6010
6011 #[test]
6013 fn overlapping_signals_symmetric(
6014 start1 in 10usize..50,
6015 len1 in 5usize..20,
6016 start2 in 10usize..50,
6017 len2 in 5usize..20,
6018 ) {
6019 let mut doc = GroundedDocument::new("test", "x".repeat(100));
6020
6021 let loc1 = Location::text(start1, start1 + len1);
6022 let loc2 = Location::text(start2, start2 + len2);
6023
6024 doc.add_signal(Signal::new(0, loc1.clone(), "A", "T", 0.9));
6025 doc.add_signal(Signal::new(0, loc2.clone(), "B", "T", 0.9));
6026
6027 let overlaps_loc1 = doc.overlapping_signals(&loc1);
6028 let overlaps_loc2 = doc.overlapping_signals(&loc2);
6029
6030 if loc1.overlaps(&loc2) {
6032 prop_assert!(overlaps_loc1.len() >= 2, "Should find both when overlapping");
6033 prop_assert!(overlaps_loc2.len() >= 2, "Should find both when overlapping");
6034 }
6035 }
6036 }
6037
6038 proptest! {
6043 #[test]
6045 fn modality_counts_sum_to_total(
6046 symbolic_count in 0usize..5,
6047 iconic_count in 0usize..5,
6048 ) {
6049 let mut doc = GroundedDocument::new("test", "test");
6050
6051 for i in 0..symbolic_count {
6053 let mut signal = Signal::new(
6054 0,
6055 Location::text(i * 10, i * 10 + 5),
6056 "entity",
6057 "Type",
6058 0.9,
6059 );
6060 signal.modality = Modality::Symbolic;
6061 doc.add_signal(signal);
6062 }
6063
6064 for i in 0..iconic_count {
6066 let mut signal = Signal::new(
6067 0,
6068 Location::bbox(i as f32 * 0.1, 0.0, 0.05, 0.05),
6069 "entity",
6070 "Type",
6071 0.9,
6072 );
6073 signal.modality = Modality::Iconic;
6074 doc.add_signal(signal);
6075 }
6076
6077 let stats = doc.stats();
6078 prop_assert_eq!(
6079 stats.symbolic_count + stats.iconic_count + stats.hybrid_count,
6080 stats.signal_count,
6081 "Modality counts should sum to total"
6082 );
6083 }
6084 }
6085
6086 proptest! {
6091 #[test]
6093 fn from_text_always_valid(
6094 text in "[a-zA-Z ]{20,100}",
6095 surface_start in 0usize..15,
6096 surface_len in 1usize..8,
6097 ) {
6098 let text_char_len = text.chars().count();
6099 let surface_end = (surface_start + surface_len).min(text_char_len);
6100 let surface_start = surface_start.min(surface_end.saturating_sub(1));
6101
6102 if surface_start < surface_end && surface_end <= text_char_len {
6103 let surface: String = text.chars()
6104 .skip(surface_start)
6105 .take(surface_end - surface_start)
6106 .collect();
6107
6108 if !surface.is_empty() {
6109 if let Some(signal) = Signal::<Location>::from_text(&text, &surface, "Test", 0.9) {
6111 prop_assert!(
6113 signal.validate_against(&text).is_none(),
6114 "Signal created via from_text must be valid"
6115 );
6116 }
6117 }
6118 }
6119 }
6120
6121 #[test]
6123 fn validated_add_rejects_invalid(
6124 text in "[a-z]{10,50}",
6125 wrong_surface in "[A-Z]{3,10}",
6126 ) {
6127 let mut doc = GroundedDocument::new("test", &text);
6128
6129 let signal = Signal::new(
6131 0,
6132 Location::text(0, wrong_surface.chars().count().min(text.chars().count())),
6133 wrong_surface.clone(),
6134 "Test",
6135 0.9,
6136 );
6137
6138 let expected: String = text.chars().take(wrong_surface.chars().count()).collect();
6141 if expected != wrong_surface {
6142 let result = doc.add_signal_validated(signal);
6143 prop_assert!(result.is_err(), "Should reject signal with mismatched surface");
6144 }
6145 }
6146
6147 #[test]
6149 fn round_trip_signal_from_text(
6150 prefix in "[a-z]{5,20}",
6151 entity in "[A-Z][a-z]{3,10}",
6152 suffix in "[a-z]{5,20}",
6153 ) {
6154 let text = format!("{} {} {}", prefix, entity, suffix);
6155 let mut doc = GroundedDocument::new("test", &text);
6156
6157 let id = doc.add_signal_from_text(&entity, "Entity", 0.9);
6158 prop_assert!(id.is_some(), "Should find entity in text");
6159
6160 let signal = doc.signals().iter().find(|s| s.id == id.unwrap());
6161 prop_assert!(signal.is_some(), "Should retrieve added signal");
6162
6163 let signal = signal.unwrap();
6164 prop_assert_eq!(signal.surface(), entity.as_str(), "Surface should match");
6165
6166 prop_assert!(
6168 doc.is_valid(),
6169 "Document should be valid after from_text add"
6170 );
6171 }
6172
6173 #[test]
6175 fn nth_occurrence_finds_correct(
6176 entity in "[A-Z][a-z]{2,5}",
6177 sep in " [a-z]+ ",
6178 ) {
6179 let text = format!("{}{}{}{}{}", entity, sep, entity, sep, entity);
6181 let mut doc = GroundedDocument::new("test", &text);
6182
6183 for n in 0..3 {
6185 let id = doc.add_signal_from_text_nth(&entity, "Entity", 0.9, n);
6186 prop_assert!(id.is_some(), "Should find occurrence {}", n);
6187 }
6188
6189 let id = doc.add_signal_from_text_nth(&entity, "Entity", 0.9, 3);
6191 prop_assert!(id.is_none(), "Should NOT find 4th occurrence");
6192
6193 prop_assert!(doc.is_valid(), "All signals should be valid");
6195
6196 let offsets: Vec<_> = doc.signals()
6198 .iter()
6199 .filter_map(|s| s.text_offsets())
6200 .collect();
6201 let unique: std::collections::HashSet<_> = offsets.iter().collect();
6202 prop_assert_eq!(offsets.len(), unique.len(), "Each occurrence should have distinct offset");
6203 }
6204 }
6205
6206 #[test]
6211 fn test_track_stats_basic() {
6212 let text = "John met Mary. He said hello. John left.";
6213 let mut doc = GroundedDocument::new("test", text);
6214 let text_len = text.chars().count();
6215
6216 let s1 = doc.add_signal(Signal::new(0, Location::text(0, 4), "John", "Person", 0.95));
6218 let s2 = doc.add_signal(Signal::new(
6219 0,
6220 Location::text(30, 34),
6221 "John",
6222 "Person",
6223 0.90,
6224 ));
6225
6226 let track_id = doc.add_track(Track::new(0, "John".to_string()));
6228 doc.add_signal_to_track(s1, track_id, 0);
6229 doc.add_signal_to_track(s2, track_id, 1);
6230
6231 let track = doc.get_track(track_id).unwrap();
6233 let stats = track.compute_stats(&doc, text_len);
6234
6235 assert_eq!(stats.chain_length, 2, "Two mentions");
6236 assert_eq!(stats.variation_count, 1, "One unique surface form");
6237 assert!(stats.spread > 0, "Spread should be positive");
6238 assert!(stats.relative_spread > 0.0 && stats.relative_spread < 1.0);
6239 assert!((stats.min_confidence - 0.90).abs() < 0.01);
6240 assert!((stats.max_confidence - 0.95).abs() < 0.01);
6241 assert!((stats.mean_confidence - 0.925).abs() < 0.01);
6242 }
6243
6244 #[test]
6245 fn test_track_stats_singleton() {
6246 let text = "Paris is beautiful.";
6247 let mut doc = GroundedDocument::new("test", text);
6248 let text_len = text.chars().count();
6249
6250 let s1 = doc.add_signal(Signal::new(
6251 0,
6252 Location::text(0, 5),
6253 "Paris",
6254 "Location",
6255 0.88,
6256 ));
6257 let track_id = doc.add_track(Track::new(0, "Paris".to_string()));
6258 doc.add_signal_to_track(s1, track_id, 0);
6259
6260 let track = doc.get_track(track_id).unwrap();
6261 let stats = track.compute_stats(&doc, text_len);
6262
6263 assert_eq!(stats.chain_length, 1);
6264 assert_eq!(stats.spread, 0, "Singleton has zero spread");
6265 assert_eq!(stats.first_position, stats.last_position);
6266 assert!((stats.min_confidence - stats.max_confidence).abs() < 0.001);
6267 }
6268}