1use super::lsh::{LSHConfig, MinHashLSH};
103use crate::core as anno_core;
104use std::collections::HashMap;
105
/// Tuning parameters for [`StreamingResolver`].
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Minimum similarity (inclusive) between a mention and its best candidate
    /// cluster for the mention to join that cluster; otherwise a new cluster
    /// is created.
    pub add_threshold: f32,
    /// Minimum similarity (inclusive) between two clusters for them to be
    /// merged during a merge pass.
    pub merge_threshold: f32,
    /// Soft limit on the number of clusters: when an insertion leaves more
    /// clusters than this, a merge pass is triggered. The count can stay above
    /// the limit if no pair of clusters is similar enough to merge.
    pub max_clusters: usize,
    /// Whether the resolver builds a `MinHashLSH` index for candidate lookup.
    /// When `false`, every existing cluster is scored for each new mention.
    pub use_lsh: bool,
    /// Configuration handed to `MinHashLSH::new` when `use_lsh` is set.
    pub lsh_config: LSHConfig,
    /// When `true`, a typed mention is only matched against clusters with the
    /// same entity type, and clusters with different types are never merged.
    pub require_type_match: bool,
}
122
123impl Default for StreamingConfig {
124 fn default() -> Self {
125 Self {
126 add_threshold: 0.6,
127 merge_threshold: 0.7,
128 max_clusters: 10_000,
129 use_lsh: true,
130 lsh_config: LSHConfig::default(),
131 require_type_match: true,
132 }
133 }
134}
135
136impl StreamingConfig {
137 pub fn high_recall() -> Self {
139 Self {
140 add_threshold: 0.4,
141 merge_threshold: 0.5,
142 use_lsh: true,
143 lsh_config: LSHConfig::high_recall(),
144 ..Default::default()
145 }
146 }
147
148 pub fn high_precision() -> Self {
150 Self {
151 add_threshold: 0.7,
152 merge_threshold: 0.8,
153 use_lsh: true,
154 lsh_config: LSHConfig::high_precision(),
155 ..Default::default()
156 }
157 }
158
159 }
162
/// A single occurrence of an entity in one document, as fed to the resolver.
#[derive(Debug, Clone)]
pub struct EntityMention {
    /// Identifier of the document the mention came from.
    pub doc_id: String,
    /// Surface form of the mention; used for LSH lookup and trigram
    /// similarity when no embeddings are available.
    pub canonical_surface: String,
    /// Entity type label, if known.
    pub entity_type: Option<anno_core::TypeLabel>,
    /// Embedding vector, if available; compared by cosine similarity against
    /// cluster centroids.
    pub embedding: Option<Vec<f32>>,
    /// Track this mention was derived from, if any.
    pub track_id: Option<anno_core::TrackId>,
    /// When the mention was observed, if known.
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
    /// Start of the period the mention is valid for, if known.
    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
    /// End of the period the mention is valid for, if known.
    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
}
183
184impl EntityMention {
185 pub fn new(doc_id: impl Into<String>, surface: impl Into<String>) -> Self {
187 Self {
188 doc_id: doc_id.into(),
189 canonical_surface: surface.into(),
190 entity_type: None,
191 embedding: None,
192 track_id: None,
193 timestamp: None,
194 valid_from: None,
195 valid_until: None,
196 }
197 }
198
199 pub fn with_type(mut self, entity_type: impl Into<String>) -> Self {
201 let s = entity_type.into();
202 self.entity_type = Some(anno_core::TypeLabel::from(s.as_str()));
203 self
204 }
205
206 pub fn with_timestamp(mut self, ts: chrono::DateTime<chrono::Utc>) -> Self {
210 self.timestamp = Some(ts);
211 self
212 }
213
214 pub fn with_temporal_bounds(
219 mut self,
220 from: Option<chrono::DateTime<chrono::Utc>>,
221 until: Option<chrono::DateTime<chrono::Utc>>,
222 ) -> Self {
223 self.valid_from = from;
224 self.valid_until = until;
225 self
226 }
227
228 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
230 self.embedding = Some(embedding);
231 self
232 }
233
234 pub fn with_track_id(mut self, track_id: anno_core::TrackId) -> Self {
236 self.track_id = Some(track_id);
237 self
238 }
239}
240
/// A group of mentions that the resolver considers to be the same entity.
#[derive(Debug, Clone)]
pub struct EntityCluster {
    /// Identifier assigned by the resolver when the cluster was created.
    pub id: anno_core::IdentityId,
    /// Surface form of the mention that founded the cluster.
    pub canonical_name: String,
    /// Entity type of the founding mention, if it had one.
    pub entity_type: Option<anno_core::TypeLabel>,
    /// Every mention assigned to this cluster, in insertion order.
    pub mentions: Vec<EntityMention>,
    /// Running average of the member embeddings, when any are available.
    pub centroid: Option<Vec<f32>>,
    /// Cluster confidence: `1.0` for a new cluster, averaged with the other
    /// cluster's value whenever two clusters are merged.
    pub confidence: f32,
}
257
258impl EntityCluster {
259 fn from_mention(id: anno_core::IdentityId, mention: EntityMention) -> Self {
261 let canonical_name = mention.canonical_surface.clone();
262 let entity_type = mention.entity_type.clone();
263 let centroid = mention.embedding.clone();
264
265 Self {
266 id,
267 canonical_name,
268 entity_type,
269 mentions: vec![mention],
270 centroid,
271 confidence: 1.0,
272 }
273 }
274
275 fn add_mention(&mut self, mention: EntityMention) {
277 if let (Some(existing), Some(new)) = (&mut self.centroid, &mention.embedding) {
279 let n = self.mentions.len() as f32;
280 for (i, v) in new.iter().enumerate() {
281 if i < existing.len() {
282 existing[i] = (existing[i] * n + v) / (n + 1.0);
284 }
285 }
286 } else if self.centroid.is_none() && mention.embedding.is_some() {
287 self.centroid = mention.embedding.clone();
288 }
289
290 self.mentions.push(mention);
291 }
292
293 fn merge(&mut self, other: EntityCluster) {
295 if let (Some(c1), Some(c2)) = (&mut self.centroid, &other.centroid) {
297 let n1 = self.mentions.len() as f32;
298 let n2 = other.mentions.len() as f32;
299 for (i, v2) in c2.iter().enumerate() {
300 if i < c1.len() {
301 c1[i] = (c1[i] * n1 + v2 * n2) / (n1 + n2);
302 }
303 }
304 }
305
306 self.mentions.extend(other.mentions);
308
309 self.confidence = (self.confidence + other.confidence) / 2.0;
311 }
312
313 pub fn document_ids(&self) -> Vec<&str> {
315 let mut doc_ids: Vec<&str> = self.mentions.iter().map(|m| m.doc_id.as_str()).collect();
316 doc_ids.sort();
317 doc_ids.dedup();
318 doc_ids
319 }
320
321 pub fn has_temporal_bounds(&self) -> bool {
323 self.mentions
324 .iter()
325 .any(|m| m.valid_from.is_some() || m.valid_until.is_some())
326 }
327
328 pub fn temporal_bounds(
336 &self,
337 ) -> (
338 Option<chrono::DateTime<chrono::Utc>>,
339 Option<chrono::DateTime<chrono::Utc>>,
340 ) {
341 let valid_from = self.mentions.iter().filter_map(|m| m.valid_from).min();
342
343 let valid_until = self.mentions.iter().filter_map(|m| m.valid_until).max();
344
345 (valid_from, valid_until)
346 }
347
348 pub fn observation_times(&self) -> Vec<chrono::DateTime<chrono::Utc>> {
352 let mut times: Vec<_> = self.mentions.iter().filter_map(|m| m.timestamp).collect();
353 times.sort();
354 times.dedup();
355 times
356 }
357
358 pub fn observation_span(
362 &self,
363 ) -> Option<(chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> {
364 let times = self.observation_times();
365 if times.is_empty() {
366 None
367 } else {
368 Some((
369 times[0],
370 *times
371 .last()
372 .expect("times should not be empty after empty check"),
373 ))
374 }
375 }
376}
377
/// Incrementally groups incoming mentions into entity clusters.
#[derive(Debug)]
pub struct StreamingResolver {
    /// Thresholds and feature switches.
    config: StreamingConfig,
    /// Live clusters keyed by their id.
    clusters: HashMap<anno_core::IdentityId, EntityCluster>,
    /// Candidate index over cluster surface forms; present only when
    /// `config.use_lsh` was set at construction time.
    lsh: Option<MinHashLSH>,
    /// Maps an LSH index (the value of `lsh.len()` just before insertion) to
    /// the cluster that was inserted at that position.
    lsh_to_cluster: HashMap<usize, anno_core::IdentityId>,
    /// Id that will be given to the next newly created cluster.
    next_id: anno_core::IdentityId,
    /// Total number of mentions ever added.
    mention_count: usize,
}
393
394impl StreamingResolver {
395 pub fn new(config: StreamingConfig) -> Self {
397 let lsh = if config.use_lsh {
398 Some(MinHashLSH::new(config.lsh_config.clone()))
399 } else {
400 None
401 };
402
403 Self {
404 config,
405 clusters: HashMap::new(),
406 lsh,
407 lsh_to_cluster: HashMap::new(),
408 next_id: anno_core::IdentityId::ZERO,
409 mention_count: 0,
410 }
411 }
412
413 pub fn add_mention(&mut self, mention: EntityMention) -> anno_core::IdentityId {
418 self.mention_count += 1;
419
420 let best_cluster = self.find_best_cluster(&mention);
422
423 let cluster_id = if let Some((cluster_id, similarity)) = best_cluster {
424 if similarity >= self.config.add_threshold {
425 if let Some(cluster) = self.clusters.get_mut(&cluster_id) {
427 cluster.add_mention(mention);
428 }
429 cluster_id
430 } else {
431 self.create_cluster(mention)
433 }
434 } else {
435 self.create_cluster(mention)
437 };
438
439 if self.clusters.len() > self.config.max_clusters {
441 self.merge_clusters();
442 }
443
444 cluster_id
445 }
446
447 pub fn add_entity(
449 &mut self,
450 doc_id: impl Into<String>,
451 surface: impl Into<String>,
452 entity_type: Option<String>,
453 ) -> anno_core::IdentityId {
454 let mut mention = EntityMention::new(doc_id, surface);
455 if let Some(et) = entity_type {
456 mention = mention.with_type(et);
457 }
458 self.add_mention(mention)
459 }
460
461 pub fn clusters(&self) -> Vec<&EntityCluster> {
463 self.clusters.values().collect()
464 }
465
466 pub fn get_cluster(&self, id: anno_core::IdentityId) -> Option<&EntityCluster> {
468 self.clusters.get(&id)
469 }
470
471 pub fn num_clusters(&self) -> usize {
473 self.clusters.len()
474 }
475
476 pub fn num_mentions(&self) -> usize {
478 self.mention_count
479 }
480
481 pub fn merge_clusters(&mut self) {
483 use anno_core::IdentityId;
484
485 let cluster_ids: Vec<IdentityId> = self.clusters.keys().copied().collect();
487 let mut to_merge: Vec<(IdentityId, IdentityId)> = Vec::new();
488
489 for i in 0..cluster_ids.len() {
490 for j in (i + 1)..cluster_ids.len() {
491 let id_a = cluster_ids[i];
492 let id_b = cluster_ids[j];
493
494 if let (Some(cluster_a), Some(cluster_b)) =
495 (self.clusters.get(&id_a), self.clusters.get(&id_b))
496 {
497 if self.config.require_type_match
499 && cluster_a.entity_type != cluster_b.entity_type
500 {
501 continue;
502 }
503
504 let similarity = self.cluster_similarity(cluster_a, cluster_b);
505 if similarity >= self.config.merge_threshold {
506 to_merge.push((id_a, id_b));
507 }
508 }
509 }
510 }
511
512 let mut merged_into: HashMap<IdentityId, IdentityId> = HashMap::new();
514
515 fn find_root(
516 merged_into: &mut HashMap<IdentityId, IdentityId>,
517 id: IdentityId,
518 ) -> IdentityId {
519 if let Some(&parent) = merged_into.get(&id) {
520 if parent != id {
521 let root = find_root(merged_into, parent);
522 merged_into.insert(id, root);
523 return root;
524 }
525 }
526 id
527 }
528
529 for (a, b) in to_merge {
530 let root_a = find_root(&mut merged_into, a);
531 let root_b = find_root(&mut merged_into, b);
532 if root_a != root_b {
533 merged_into.insert(root_b, root_a);
534 }
535 }
536
537 let to_remove: Vec<IdentityId> = merged_into
539 .iter()
540 .filter(|(k, v)| *k != *v)
541 .map(|(k, _)| *k)
542 .collect();
543
544 for id in to_remove {
545 if let Some(cluster) = self.clusters.remove(&id) {
546 let root = find_root(&mut merged_into, id);
547 if let Some(target) = self.clusters.get_mut(&root) {
548 target.merge(cluster);
549 }
550 }
551 }
552 }
553
554 fn find_best_cluster(&self, mention: &EntityMention) -> Option<(anno_core::IdentityId, f32)> {
560 if let Some(lsh) = &self.lsh {
561 let candidates = lsh.query(&mention.canonical_surface);
563
564 let mut best: Option<(anno_core::IdentityId, f32)> = None;
565 for idx in candidates {
566 if let Some(&cluster_id) = self.lsh_to_cluster.get(&idx) {
567 if let Some(cluster) = self.clusters.get(&cluster_id) {
568 if self.config.require_type_match
570 && mention.entity_type.is_some()
571 && cluster.entity_type != mention.entity_type
572 {
573 continue;
574 }
575
576 let sim = self.mention_cluster_similarity(mention, cluster);
577 let should_update = match best {
578 None => true,
579 Some((_, s)) => sim > s,
580 };
581 if should_update {
582 best = Some((cluster_id, sim));
583 }
584 }
585 }
586 }
587 best
588 } else {
589 let mut best: Option<(anno_core::IdentityId, f32)> = None;
591
592 for (&cluster_id, cluster) in &self.clusters {
593 if self.config.require_type_match
595 && mention.entity_type.is_some()
596 && cluster.entity_type != mention.entity_type
597 {
598 continue;
599 }
600
601 let sim = self.mention_cluster_similarity(mention, cluster);
602 let should_update = match best {
603 None => true,
604 Some((_, s)) => sim > s,
605 };
606 if should_update {
607 best = Some((cluster_id, sim));
608 }
609 }
610 best
611 }
612 }
613
614 fn create_cluster(&mut self, mention: EntityMention) -> anno_core::IdentityId {
616 let id = self.next_id;
617 self.next_id += 1;
618
619 if let Some(lsh) = &mut self.lsh {
621 let lsh_idx = lsh.len();
622 lsh.insert_text(id.get().to_string(), &mention.canonical_surface);
623 self.lsh_to_cluster.insert(lsh_idx, id);
624 }
625
626 let cluster = EntityCluster::from_mention(id, mention);
627 self.clusters.insert(id, cluster);
628 id
629 }
630
631 fn mention_cluster_similarity(&self, mention: &EntityMention, cluster: &EntityCluster) -> f32 {
633 if let (Some(emb), Some(centroid)) = (&mention.embedding, &cluster.centroid) {
634 return cosine_similarity(emb, centroid);
635 }
636 trigram_similarity(&mention.canonical_surface, &cluster.canonical_name)
637 }
638
639 fn cluster_similarity(&self, cluster_a: &EntityCluster, cluster_b: &EntityCluster) -> f32 {
641 if let (Some(c1), Some(c2)) = (&cluster_a.centroid, &cluster_b.centroid) {
642 return cosine_similarity(c1, c2);
643 }
644 trigram_similarity(&cluster_a.canonical_name, &cluster_b.canonical_name)
645 }
646}
647
648impl Default for StreamingResolver {
649 fn default() -> Self {
650 Self::new(StreamingConfig::default())
651 }
652}
653
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or either has
/// zero magnitude; otherwise the dot product divided by the product of the
/// Euclidean norms.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
674
675pub fn trigram_similarity(a: &str, b: &str) -> f32 {
698 textprep::similarity::trigram_jaccard(a, b) as f32
700}
701
702#[doc(hidden)]
704#[deprecated(since = "0.3.0", note = "Use trigram_similarity instead")]
705pub fn string_similarity(a: &str, b: &str) -> f32 {
706 trigram_similarity(a, b)
707}
708
709impl EntityMention {
717 #[must_use]
723 pub fn from_track(doc_id: impl Into<String>, track: &crate::Track) -> Self {
724 Self {
725 doc_id: doc_id.into(),
726 canonical_surface: track.canonical_surface.clone(),
727 entity_type: track.entity_type.clone(),
728 embedding: track.embedding.clone(),
729 track_id: Some(track.id),
730 timestamp: None,
734 valid_from: None,
735 valid_until: None,
736 }
737 }
738}
739
740impl EntityCluster {
741 #[must_use]
747 pub fn to_identity(&self) -> anno_core::Identity {
748 let track_refs: Vec<anno_core::TrackRef> = self
750 .mentions
751 .iter()
752 .filter_map(|m| {
753 m.track_id.map(|tid| anno_core::TrackRef {
754 doc_id: m.doc_id.clone(),
755 track_id: tid,
756 })
757 })
758 .collect();
759
760 let source = if track_refs.is_empty() {
761 None
762 } else {
763 Some(anno_core::IdentitySource::CrossDocCoref { track_refs })
764 };
765
766 let valid_from = self.mentions.iter().filter_map(|m| m.valid_from).min();
768 let valid_until = self.mentions.iter().filter_map(|m| m.valid_until).max();
769
770 let _ = (valid_from, valid_until);
773
774 anno_core::Identity {
775 id: self.id,
776 canonical_name: self.canonical_name.clone(),
777 entity_type: self.entity_type.clone(),
778 kb_id: None,
779 kb_name: None,
780 description: None,
781 embedding: self.centroid.clone(),
782 aliases: self
783 .mentions
784 .iter()
785 .map(|m| m.canonical_surface.clone())
786 .filter(|s| s != &self.canonical_name)
787 .collect::<std::collections::HashSet<_>>()
788 .into_iter()
789 .collect(),
790 confidence: self.confidence,
791 source,
792 }
793 }
794}
795
796impl StreamingResolver {
797 #[must_use]
802 pub fn to_identities(&self) -> Vec<anno_core::Identity> {
803 self.clusters()
804 .into_iter()
805 .map(|c| c.to_identity())
806 .collect()
807 }
808
809 pub fn add_track(
814 &mut self,
815 doc_id: impl Into<String>,
816 track: &anno_core::Track,
817 ) -> anno_core::IdentityId {
818 let mention = EntityMention::from_track(doc_id, track);
819 self.add_mention(mention)
820 }
821}
822
#[cfg(test)]
mod tests {
    use super::*;

    // Three person mentions never yield more than three clusters, and every
    // mention is counted.
    #[test]
    fn test_basic_streaming() {
        let mut resolver = StreamingResolver::new(StreamingConfig::default());

        resolver.add_entity("doc1", "Barack Obama", Some("Person".to_string()));
        resolver.add_entity("doc2", "obama", Some("Person".to_string()));
        resolver.add_entity("doc3", "Donald Trump", Some("Person".to_string()));

        assert!(resolver.num_clusters() <= 3);
        assert_eq!(resolver.num_mentions(), 3);
    }

    // Identical surfaces with different types must stay in separate clusters.
    #[test]
    fn test_type_filtering() {
        let config = StreamingConfig {
            require_type_match: true,
            ..Default::default()
        };
        let mut resolver = StreamingResolver::new(config);

        resolver.add_entity("doc1", "Apple", Some("Organization".to_string()));
        resolver.add_entity("doc2", "Apple", Some("Food".to_string()));

        assert_eq!(resolver.num_clusters(), 2);
    }

    // A tiny `max_clusters` forces merge passes. Whatever gets merged, no
    // mention may be lost or duplicated along the way.
    #[test]
    fn test_cluster_merging() {
        let config = StreamingConfig {
            max_clusters: 2,
            merge_threshold: 0.3,
            ..Default::default()
        };
        let mut resolver = StreamingResolver::new(config);

        resolver.add_entity("doc1", "New York City", None);
        resolver.add_entity("doc2", "NYC", None);
        resolver.add_entity("doc3", "New York", None);
        resolver.add_entity("doc4", "Los Angeles", None);
        resolver.add_entity("doc5", "LA", None);

        assert!(resolver.num_clusters() >= 1);
        assert!(resolver.num_clusters() <= 5);
        assert_eq!(resolver.num_mentions(), 5);

        let clustered: usize = resolver.clusters().iter().map(|c| c.mentions.len()).sum();
        assert_eq!(clustered, 5, "merging must conserve mentions");
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_trigram_similarity() {
        assert!(trigram_similarity("Barack Obama", "barack obama") > 0.9);
        assert!(trigram_similarity("Obama", "Trump") < 0.3);
    }

    // Any multi-mention cluster must report at least one document id.
    #[test]
    fn test_document_ids() {
        let mut resolver = StreamingResolver::new(StreamingConfig::default());

        resolver.add_entity("doc1", "Barack Obama", None);
        resolver.add_entity("doc2", "obama", None);

        let clusters = resolver.clusters();
        for cluster in clusters {
            if cluster.mentions.len() > 1 {
                let doc_ids = cluster.document_ids();
                assert!(!doc_ids.is_empty());
            }
        }
    }

    #[test]
    fn test_entity_mention_from_track() {
        let track = anno_core::Track::new(42, "Barack Obama").with_type("Person".to_string());

        let mention = EntityMention::from_track("doc1", &track);

        assert_eq!(mention.doc_id, "doc1");
        assert_eq!(mention.canonical_surface, "Barack Obama");
        assert_eq!(
            mention.entity_type,
            Some(anno_core::TypeLabel::from("Person"))
        );
        assert_eq!(mention.track_id, Some(anno_core::TrackId::new(42)));
    }

    #[test]
    fn test_cluster_to_identity() {
        let mut resolver = StreamingResolver::new(StreamingConfig::default());

        resolver.add_entity("doc1", "Barack Obama", Some("Person".to_string()));
        resolver.add_entity("doc2", "obama", Some("Person".to_string()));

        let identities = resolver.to_identities();

        assert!(!identities.is_empty());

        for identity in &identities {
            assert!(!identity.canonical_name.is_empty());
            assert!((0.0..=1.0).contains(&identity.confidence));
        }
    }

    #[test]
    fn test_add_track() {
        let mut resolver = StreamingResolver::new(StreamingConfig::default());

        let track1 = anno_core::Track::new(1, "Jensen Huang").with_type("Person".to_string());
        let track2 = anno_core::Track::new(2, "Nvidia").with_type("Organization".to_string());

        resolver.add_track("doc1", &track1);
        resolver.add_track("doc1", &track2);

        assert_eq!(resolver.num_mentions(), 2);
        assert!(resolver.num_clusters() >= 1);
    }

    #[test]
    fn test_streaming_basic_similarity_smoke() {
        let mut resolver = StreamingResolver::new(StreamingConfig::default());

        resolver.add_entity("doc1", "Barack Obama", Some("Person".to_string()));
        resolver.add_entity("doc2", "obama", Some("Person".to_string()));
        resolver.add_entity("doc3", "Donald Trump", Some("Person".to_string()));

        assert!(resolver.num_clusters() <= 3);
    }
}
974
// Property-based tests for the streaming resolver and the similarity helpers.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // 50 generated cases per property.
        #![proptest_config(ProptestConfig::with_cases(50))]

        // Every added mention ends up in exactly one cluster: the sum of the
        // cluster sizes equals the number of mentions added.
        #[test]
        fn streaming_mention_conservation(
            entities in proptest::collection::vec("[A-Za-z ]{3,20}", 1..20)
        ) {
            let mut resolver = StreamingResolver::new(StreamingConfig::default());

            for (i, entity) in entities.iter().enumerate() {
                resolver.add_entity(format!("doc{}", i), entity, None);
            }

            let cluster_mentions: usize = resolver.clusters()
                .iter()
                .map(|c| c.mentions.len())
                .sum();

            prop_assert_eq!(resolver.num_mentions(), cluster_mentions,
                "Mention count mismatch: {} != {}",
                resolver.num_mentions(), cluster_mentions);
        }

        // There can never be more clusters than mentions.
        #[test]
        fn streaming_cluster_bounded(
            entities in proptest::collection::vec("[A-Za-z]{3,15}", 1..30)
        ) {
            let mut resolver = StreamingResolver::new(StreamingConfig::default());

            for (i, entity) in entities.iter().enumerate() {
                resolver.add_entity(format!("doc{}", i), entity, None);
            }

            prop_assert!(resolver.num_clusters() <= resolver.num_mentions(),
                "More clusters ({}) than mentions ({})",
                resolver.num_clusters(), resolver.num_mentions());
        }

        // Repeating the same untyped surface form yields a single cluster that
        // holds every repetition.
        #[test]
        fn streaming_identical_cluster(name in "[A-Za-z]{5,15}", count in 2usize..10) {
            let mut resolver = StreamingResolver::new(StreamingConfig::default());

            for i in 0..count {
                resolver.add_entity(format!("doc{}", i), &name, None);
            }

            prop_assert_eq!(resolver.num_clusters(), 1,
                "Identical entities should form one cluster, got {}",
                resolver.num_clusters());

            let cluster = resolver.clusters().into_iter().next().expect("should have at least one cluster");
            prop_assert_eq!(cluster.mentions.len(), count,
                "Cluster should have {} mentions, got {}",
                count, cluster.mentions.len());
        }

        // With type matching required, the same surface form under two
        // different types stays in two clusters.
        #[test]
        fn streaming_type_separation(name in "[A-Za-z]{5,15}") {
            let config = StreamingConfig {
                require_type_match: true,
                ..Default::default()
            };
            let mut resolver = StreamingResolver::new(config);

            resolver.add_entity("doc1", &name, Some("Person".to_string()));
            resolver.add_entity("doc2", &name, Some("Organization".to_string()));

            prop_assert_eq!(resolver.num_clusters(), 2,
                "Different types should not cluster");
        }

        // Cluster confidence always stays within [0, 1].
        #[test]
        fn streaming_confidence_bounded(
            entities in proptest::collection::vec("[A-Za-z ]{3,20}", 1..15)
        ) {
            let mut resolver = StreamingResolver::new(StreamingConfig::default());

            for (i, entity) in entities.iter().enumerate() {
                resolver.add_entity(format!("doc{}", i), entity, None);
            }

            for cluster in resolver.clusters() {
                prop_assert!((0.0..=1.0).contains(&cluster.confidence),
                    "Confidence {} out of bounds", cluster.confidence);
            }
        }

        // Trigram similarity does not depend on argument order.
        #[test]
        fn trigram_sim_symmetric(a in "[A-Za-z ]{3,20}", b in "[A-Za-z ]{3,20}") {
            let sim_ab = trigram_similarity(&a, &b);
            let sim_ba = trigram_similarity(&b, &a);
            prop_assert!((sim_ab - sim_ba).abs() < 0.001,
                "Trigram similarity not symmetric: {} vs {}", sim_ab, sim_ba);
        }

        // For vectors with non-negative components, cosine similarity stays
        // within [0, 1] up to a small floating-point tolerance.
        #[test]
        fn cosine_sim_bounded(
            dim in 10usize..100,
            seed in any::<u64>()
        ) {
            // Small linear congruential generator so each case is reproducible
            // from `seed`; components fall in [0, 1).
            let mut rng = seed;
            let a: Vec<f32> = (0..dim).map(|_| {
                rng = rng.wrapping_mul(1103515245).wrapping_add(12345);
                (rng % 1000) as f32 / 1000.0
            }).collect();
            let b: Vec<f32> = (0..dim).map(|_| {
                rng = rng.wrapping_mul(1103515245).wrapping_add(12345);
                (rng % 1000) as f32 / 1000.0
            }).collect();

            let sim = cosine_similarity(&a, &b);
            prop_assert!((-0.001..=1.001).contains(&sim),
                "Cosine similarity {} out of bounds", sim);
        }
    }
}