// Crate-private registry module; the tests in this file import
// `SemanticRegistry` and `SemanticRegistryBuilder` from it.
pub(crate) mod registry;
// Public encoder module, with two of its items re-exported at this level.
pub mod encoder;
pub use encoder::{EncoderOutput, TextEncoder};
// Public traits module, with the listed items re-exported at this level.
pub mod traits;
pub use traits::{
DiscontinuousEntity, DiscontinuousNER, ExtractionWithRelations, RelationExtractor,
RelationTriple, ZeroShotNER,
};
// Crate-private span module; the handshaking types are re-exported for use
// inside the crate only.
pub(crate) mod span;
pub(crate) use span::{HandshakingCell, HandshakingMatrix};
// Public coreference module, with its entry point and config types re-exported.
pub mod coref;
pub use coref::{resolve_coreferences, CoreferenceCluster, CoreferenceConfig};
// Public relation-extraction module, with its functions and config re-exported.
pub mod relation_extraction;
pub use relation_extraction::{
extract_relation_triples, extract_relation_triples_simple, extract_relations,
RelationExtractionConfig,
};
#[cfg(test)]
mod tests {
use super::coref::{resolve_coreferences, CoreferenceConfig};
use super::registry::{SemanticRegistry, SemanticRegistryBuilder};
use super::*;
use crate::{Confidence, Entity, EntityType};
/// A registry built from two entity labels and one relation label reports
/// three labels in total and keeps the two kinds separately countable.
#[test]
fn test_semantic_registry_builder() {
    let built = SemanticRegistry::builder()
        .add_entity("person", "A human being")
        .add_entity("organization", "A company or group")
        .add_relation("WORKS_FOR", "Employment relationship")
        .build_zero(768);

    // Total label count first, then the per-kind breakdown.
    assert_eq!(built.len(), 3);

    let entity_label_count = built.entity_labels().count();
    assert_eq!(entity_label_count, 2);

    let relation_label_count = built.relation_labels().count();
    assert_eq!(relation_label_count, 1);
}
/// The `standard_ner` registry holds at least five labels and indexes both
/// "person" and "organization".
#[test]
fn test_standard_ner_registry() {
    let standard = SemanticRegistry::standard_ner(768);

    assert!(standard.len() >= 5);

    // Look the slugs up directly in the registry's label index.
    let index = &standard.label_index;
    assert!(index.contains_key("person"));
    assert!(index.contains_key("organization"));
}
/// "Marie Curie" and "Curie" (both persons) resolve into a single cluster whose
/// canonical name is the longer mention.
#[test]
fn test_coreference_string_match() {
    let mentions = vec![
        Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95),
        Entity::new("Curie", EntityType::Person, 50, 55, 0.90),
    ];
    // All-zero buffer sized for two mentions at 768 values each.
    let zero_embeddings = vec![0.0f32; 2 * 768];
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&mentions, &zero_embeddings, 768, &config);

    assert_eq!(clusters.len(), 1);
    let cluster = &clusters[0];
    assert_eq!(cluster.members.len(), 2);
    assert_eq!(cluster.canonical_name, "Marie Curie");
}
/// Building a matrix from 18 dense scores keeps at least four cells.
#[test]
fn test_handshaking_matrix() {
    // 18 values = 3 * 3 * 2, matching the `3, 2` arguments below.
    // Four of them (0.9, 0.8, 0.7, 0.6) are above the 0.5 passed as the last
    // argument — presumably the score threshold; confirm against `from_dense`.
    let dense_scores = vec![
        0.9, 0.1, 0.2, 0.8, 0.1, 0.1, //
        0.0, 0.0, 0.7, 0.2, 0.3, 0.6, //
        0.0, 0.0, 0.0, 0.0, 0.1, 0.1, //
    ];

    let matrix = HandshakingMatrix::from_dense(&dense_scores, 3, 2, 0.5);

    let kept_cells = matrix.cells.len();
    assert!(kept_cells >= 4);
}
/// A person and an organization separated by the word "founded" should yield
/// a `FOUNDED` relation as the first extracted relation.
#[test]
fn test_relation_extraction() {
    let text = "Steve Jobs founded Apple Inc in 1976";
    // Offsets into `text`: "Steve Jobs" occupies 0..10 and "Apple" 19..24.
    let entities = vec![
        Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.95),
        Entity::new("Apple", EntityType::Organization, 19, 24, 0.90),
    ];
    // Only one relation label is registered, so it is the only possible match.
    let registry = SemanticRegistry::builder()
        .add_relation("FOUNDED", "Founded an organization")
        .build_zero(768);
    let config = RelationExtractionConfig::default();

    let relations = extract_relations(&entities, text, &registry, &config);

    assert!(!relations.is_empty());
    assert_eq!(relations[0].relation_type, "FOUNDED");
}
/// Relation extraction must work in character offsets: with a multi-byte
/// emoji in front of the sentence, the reported trigger span still selects
/// the word "founded" when applied to `text.chars()`.
#[test]
fn test_relation_extraction_uses_character_offsets_with_unicode_prefix() {
    // The leading emoji makes byte offsets and character offsets diverge.
    let text = "👋 Steve Jobs founded Apple Inc.";
    let conv = crate::offset::SpanConverter::new(text);

    // `str::find` returns byte offsets; convert them to character offsets
    // before building the entities.
    let steve_start = text.find("Steve Jobs").expect("substring present");
    let steve_start_char = conv.byte_to_char(steve_start);
    let steve_end_char = steve_start_char + "Steve Jobs".chars().count();

    let apple_start = text.find("Apple").expect("substring present");
    let apple_start_char = conv.byte_to_char(apple_start);
    let apple_end_char = apple_start_char + "Apple".chars().count();

    let entities = vec![
        Entity::new(
            "Steve Jobs",
            EntityType::Person,
            steve_start_char,
            steve_end_char,
            0.95,
        ),
        Entity::new(
            "Apple",
            EntityType::Organization,
            apple_start_char,
            apple_end_char,
            0.90,
        ),
    ];
    let registry = SemanticRegistry::builder()
        .add_relation("FOUNDED", "Founded an organization")
        .build_zero(768);
    let config = RelationExtractionConfig::default();

    let relations = extract_relations(&entities, text, &registry, &config);

    assert!(
        !relations.is_empty(),
        "Expected FOUNDED relation to be detected"
    );
    assert_eq!(relations[0].relation_type, "FOUNDED");

    // Slice by characters (not bytes) using the reported trigger span.
    let trigger = relations[0]
        .trigger_span
        .expect("expected trigger_span to be present");
    let trigger_text: String = text
        .chars()
        .skip(trigger.0)
        .take(trigger.1.saturating_sub(trigger.0))
        .collect();
    assert_eq!(trigger_text.to_ascii_lowercase(), "founded");
}
/// With no entities and no embeddings, coreference resolution returns no clusters.
#[test]
fn test_coreference_empty_input() {
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&[], &[], 768, &config);

    assert!(clusters.is_empty());
}
/// A lone mention has nothing to link to, so no cluster is produced.
#[test]
fn test_coreference_single_entity_no_cluster() {
    let lone_mention = vec![Entity::new("Alice", EntityType::Person, 0, 5, 0.9)];
    // All-zero buffer sized for one mention at 768 values.
    let zero_embeddings = vec![0.0f32; 768];
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&lone_mention, &zero_embeddings, 768, &config);

    assert!(clusters.is_empty());
}
/// Two mentions with identical text but different entity types must not be
/// placed in the same cluster.
#[test]
fn test_coreference_type_mismatch_prevents_linking() {
    // Same surface form, but one is an organization and the other a location.
    let mentions = vec![
        Entity::new("Apple", EntityType::Organization, 0, 5, 0.9),
        Entity::new("Apple", EntityType::Location, 20, 25, 0.9),
    ];
    let zero_embeddings = vec![0.0f32; 2 * 768];
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&mentions, &zero_embeddings, 768, &config);

    assert!(
        clusters.is_empty(),
        "Different entity types should not cluster even with same text"
    );
}
/// Mentions whose offsets are further apart than `max_distance` must not be
/// clustered.
#[test]
fn test_coreference_distance_filtering() {
    // The two "Bob" mentions start 1000 offsets apart; the limit below is 10.
    let mentions = vec![
        Entity::new("Bob", EntityType::Person, 0, 3, 0.9),
        Entity::new("Bob", EntityType::Person, 1000, 1003, 0.9),
    ];
    let zero_embeddings = vec![0.0f32; 2 * 768];
    // String matching is switched off for this case.
    let config = CoreferenceConfig {
        similarity_threshold: Confidence::new(0.85),
        use_string_match: false,
        max_distance: Some(10),
    };

    let clusters = resolve_coreferences(&mentions, &zero_embeddings, 768, &config);

    assert!(
        clusters.is_empty(),
        "Entities beyond max_distance should not cluster"
    );
}
/// "Smith" is contained in "Dr. Smith"; the two are clustered together and the
/// longer mention becomes the canonical name.
#[test]
fn test_coreference_string_match_substring() {
    let mentions = vec![
        Entity::new("Dr. Smith", EntityType::Person, 0, 9, 0.9),
        Entity::new("Smith", EntityType::Person, 30, 35, 0.9),
    ];
    let zero_embeddings = vec![0.0f32; 2 * 768];
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&mentions, &zero_embeddings, 768, &config);

    assert_eq!(clusters.len(), 1);
    let cluster = &clusters[0];
    assert_eq!(cluster.canonical_name, "Dr. Smith");
}
/// Three mentions that all contain "Johnson" end up in one cluster of three,
/// named after "Robert Johnson".
#[test]
fn test_coreference_transitive_closure() {
    let mentions = vec![
        Entity::new("Robert Johnson", EntityType::Person, 0, 14, 0.9),
        Entity::new("Johnson", EntityType::Person, 30, 37, 0.9),
        Entity::new("Mr. Johnson", EntityType::Person, 60, 71, 0.9),
    ];
    // All-zero buffer sized for three mentions at 768 values each.
    let zero_embeddings = vec![0.0f32; 3 * 768];
    let config = CoreferenceConfig::default();

    let clusters = resolve_coreferences(&mentions, &zero_embeddings, 768, &config);

    assert_eq!(clusters.len(), 1);
    let cluster = &clusters[0];
    assert_eq!(cluster.members.len(), 3);
    assert_eq!(cluster.canonical_name, "Robert Johnson");
}
/// An empty score slice with zero dimensions yields a matrix with no cells.
#[test]
fn test_handshaking_matrix_empty_scores() {
    let no_scores: &[f32] = &[];

    let matrix = HandshakingMatrix::from_dense(no_scores, 0, 0, 0.5);

    assert!(matrix.cells.is_empty());
}
/// When every score is below the 0.5 threshold, no cell is retained.
#[test]
fn test_handshaking_matrix_all_below_threshold() {
    // Four scores (2 * 2 * 1), the largest of which is 0.4.
    let low_scores = vec![0.1, 0.2, 0.3, 0.4];

    let matrix = HandshakingMatrix::from_dense(&low_scores, 2, 1, 0.5);

    assert!(
        matrix.cells.is_empty(),
        "All scores below threshold should yield no cells"
    );
}
/// A matrix with exactly one cell above threshold decodes to exactly one
/// entity carrying the registry's only label and the cell's score.
#[test]
fn test_handshaking_matrix_decode_entities() {
    let registry = SemanticRegistry::builder()
        .add_entity("person", "A human being")
        .build_zero(768);

    // 3 * 3 scores for a single label; only flat index 1 is set above 0.5.
    let mut scores = vec![0.0f32; 3 * 3];
    scores[1] = 0.9;
    let matrix = HandshakingMatrix::from_dense(&scores, 3, 1, 0.5);
    assert_eq!(matrix.cells.len(), 1);

    let entities = matrix.decode_entities(&registry);
    assert_eq!(entities.len(), 1);

    let (span, label, score) = &entities[0];
    assert_eq!(label.slug, "person");
    // Compare with a tolerance rather than exact float equality.
    assert!((score - 0.9).abs() < 0.001);
    // NOTE(review): how flat index 1 maps to this start/end pair depends on
    // `from_dense` / `decode_entities`, which are not visible here.
    assert_eq!(span.start, 1);
    assert_eq!(span.end, 1);
}
/// Two above-threshold cells that do not overlap must both be returned by
/// `decode_entities`.
#[test]
fn test_handshaking_matrix_non_maximum_suppression() {
    let registry = SemanticRegistry::builder()
        .add_entity("person", "A human being")
        .build_zero(768);

    // 4 * 4 scores for a single label; flat indices 0 and 5 are above 0.5.
    let mut scores = vec![0.0f32; 4 * 4];
    scores[0] = 0.9;
    scores[5] = 0.7;
    let matrix = HandshakingMatrix::from_dense(&scores, 4, 1, 0.5);
    assert_eq!(matrix.cells.len(), 2);

    let entities = matrix.decode_entities(&registry);
    assert_eq!(
        entities.len(),
        2,
        "Non-overlapping adjacent spans should both survive NMS"
    );
}
/// `get_embedding` returns a vector of the configured length for a known
/// label and `None` for an unknown one.
#[test]
fn test_registry_get_embedding() {
    let registry = SemanticRegistry::builder()
        .add_entity("person", "A human being")
        .add_entity("org", "An organization")
        .build_zero(4);

    // Known label: present, and as long as the dimension given to `build_zero`.
    let person_embedding = registry.get_embedding("person");
    assert!(person_embedding.is_some());
    assert_eq!(person_embedding.unwrap().len(), 4);

    // Unknown label: absent.
    let unknown_embedding = registry.get_embedding("nonexistent");
    assert!(unknown_embedding.is_none());
}
/// A registry built without any labels is empty by every measure.
#[test]
fn test_registry_empty() {
    let empty_registry = SemanticRegistryBuilder::new().build_zero(768);

    assert!(empty_registry.is_empty());
    assert_eq!(empty_registry.len(), 0);

    let entity_label_count = empty_registry.entity_labels().count();
    assert_eq!(entity_label_count, 0);

    let relation_label_count = empty_registry.relation_labels().count();
    assert_eq!(relation_label_count, 0);
}
/// A `DiscontinuousEntity` with a single span is contiguous and converts to a
/// regular entity with the same text and bounds.
#[test]
fn test_discontinuous_entity_contiguous() {
    let single_span = DiscontinuousEntity {
        spans: vec![(0, 5)],
        text: String::from("hello"),
        entity_type: String::from("person"),
        confidence: Confidence::new(0.9),
    };

    assert!(single_span.is_contiguous());

    let as_entity = single_span
        .to_entity()
        .expect("should convert single-span");
    assert_eq!(as_entity.text, "hello");
    assert_eq!(as_entity.start(), 0);
    assert_eq!(as_entity.end(), 5);
}
/// A `DiscontinuousEntity` with two separate spans is not contiguous and has
/// no single-entity conversion.
#[test]
fn test_discontinuous_entity_non_contiguous() {
    let two_spans = DiscontinuousEntity {
        spans: vec![(0, 3), (10, 15)],
        text: String::from("New airports"),
        entity_type: String::from("location"),
        confidence: Confidence::new(0.8),
    };

    assert!(!two_spans.is_contiguous());

    let as_entity = two_spans.to_entity();
    assert!(as_entity.is_none());
}
/// `into_anno_relations` resolves index-based triples into relations whose
/// head and tail carry the referenced entities' text.
#[test]
fn test_extraction_with_relations_into_anno_relations() {
    let alice = Entity::new("Alice", EntityType::Person, 0, 5, 0.9);
    let acme = Entity::new("Acme", EntityType::Organization, 20, 24, 0.8);
    // Head index 0 -> Alice, tail index 1 -> Acme.
    let works_for = RelationTriple {
        head_idx: 0,
        tail_idx: 1,
        relation_type: String::from("WORKS_FOR"),
        confidence: Confidence::new(0.85),
    };
    let extraction = ExtractionWithRelations {
        entities: vec![alice, acme],
        relations: vec![works_for],
    };

    let (entities, relations) = extraction.into_anno_relations();

    assert_eq!(entities.len(), 2);
    assert_eq!(relations.len(), 1);
    let relation = &relations[0];
    assert_eq!(relation.relation_type, "WORKS_FOR");
    assert_eq!(relation.head.text, "Alice");
    assert_eq!(relation.tail.text, "Acme");
}
/// A triple whose tail index points past the entity list is dropped rather
/// than converted.
#[test]
fn test_extraction_with_relations_out_of_bounds_dropped() {
    // Only one entity exists, so tail index 99 cannot be resolved.
    let dangling = RelationTriple {
        head_idx: 0,
        tail_idx: 99,
        relation_type: String::from("WORKS_FOR"),
        confidence: Confidence::new(0.85),
    };
    let extraction = ExtractionWithRelations {
        entities: vec![Entity::new("Alice", EntityType::Person, 0, 5, 0.9)],
        relations: vec![dangling],
    };

    let (_, relations) = extraction.into_anno_relations();

    assert!(
        relations.is_empty(),
        "Out-of-bounds relation should be silently dropped"
    );
}
/// With no entities there is nothing to relate, so the result is empty even
/// though a relation label is registered.
#[test]
fn test_relation_extraction_empty_entities() {
    let registry = SemanticRegistry::builder()
        .add_relation("FOUNDED", "Founded an organization")
        .build_zero(768);
    let config = RelationExtractionConfig::default();

    let relations = extract_relations(&[], "some text", &registry, &config);

    assert!(relations.is_empty());
}
/// A registry that contains only entity labels gives the extractor no
/// relation types to assign, so no relations are produced.
#[test]
fn test_relation_extraction_no_relation_labels() {
    let text = "Alice works at Acme Corp";
    // Offsets into `text`: "Alice" occupies 0..5 and "Acme" 15..19.
    let entities = vec![
        Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
        Entity::new("Acme", EntityType::Organization, 15, 19, 0.8),
    ];
    // Entity label only; deliberately no `add_relation` call.
    let registry = SemanticRegistry::builder()
        .add_entity("person", "A human being")
        .build_zero(768);
    let config = RelationExtractionConfig::default();

    let relations = extract_relations(&entities, text, &registry, &config);

    assert!(
        relations.is_empty(),
        "No relation labels in registry -> no relations extracted"
    );
}
/// Two entities separated only by the trigger word should produce a relation
/// whose confidence is above 0.5.
///
/// NOTE(review): despite the name, only the close-together case is exercised;
/// there is no far-apart case to compare against.
#[test]
fn test_relation_extraction_distance_penalty() {
    let registry = SemanticRegistry::builder()
        .add_relation("FOUNDED", "Founded an organization")
        .build_zero(768);

    let text_close = "Jobs founded Apple in 1976";
    // Offsets into `text_close`: "Jobs" occupies 0..4 and "Apple" 13..18.
    let entities_close = vec![
        Entity::new("Jobs", EntityType::Person, 0, 4, 0.9),
        Entity::new("Apple", EntityType::Organization, 13, 18, 0.9),
    ];
    let config = RelationExtractionConfig::default();

    let rels_close = extract_relations(&entities_close, text_close, &registry, &config);

    assert!(!rels_close.is_empty());
    assert!(rels_close[0].confidence > 0.5);
}
/// When one entity's span lies inside another's, `extract_relation_triples`
/// must not emit a triple for that pair.
#[test]
fn test_extract_relation_triples_overlapping_spans_skipped() {
    let registry = SemanticRegistry::builder()
        .add_relation("PART_OF", "Part of")
        .build_zero(768);

    let text = "New York City is a great city";
    // "York" (4..8) is nested inside "New York City" (0..13).
    let entities = vec![
        Entity::new("New York City", EntityType::Location, 0, 13, 0.9),
        Entity::new("York", EntityType::Location, 4, 8, 0.8),
    ];
    let config = RelationExtractionConfig::default();

    let triples = extract_relation_triples(&entities, text, &registry, &config);

    assert!(
        triples.is_empty(),
        "Overlapping spans should be skipped in extract_relation_triples"
    );
}
}