#[allow(unused_imports)]
use crate::{Confidence, Entity, EntityType, Relation};
/// Entity extraction driven by labels or descriptions supplied at call time.
///
/// Implementors must be `Send + Sync`.
pub trait ZeroShotNER: Send + Sync {
    /// Extracts entities from `text` for the given `entity_types` labels.
    ///
    /// `threshold` is handed to the implementation as a score cutoff. Its exact
    /// meaning is implementation-defined (presumably a minimum confidence —
    /// confirm against implementors).
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>>;

    /// Extracts entities from `text` using free-text `descriptions` rather than
    /// bare labels.
    ///
    /// How a description maps to the type of a returned entity is up to the
    /// implementation; nothing in this trait constrains it.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>>;

    /// Extracts entities given `(label, description)` pairs.
    ///
    /// The provided implementation ignores the description half of every pair
    /// and forwards only the labels, in their original order, to
    /// [`ZeroShotNER::extract_with_types`]. Implementations that can make use
    /// of the descriptions should override this method.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`ZeroShotNER::extract_with_types`].
    fn extract_with_described_types(
        &self,
        text: &str,
        types_with_descriptions: &[(&str, &str)],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>> {
        // Only the first element of each pair is used by this fallback.
        let label_only = types_with_descriptions
            .iter()
            .map(|&(label, _description)| label)
            .collect::<Vec<&str>>();
        self.extract_with_types(text, &label_only, threshold)
    }

    /// Returns the implementation's default label set.
    ///
    /// How (or whether) these labels are applied when a caller supplies none is
    /// not visible from this trait — confirm against implementors.
    fn default_types(&self) -> &[&'static str];
}
pub trait RelationExtractor: Send + Sync {
fn extract_with_relations(
&self,
text: &str,
entity_types: &[&str],
relation_types: &[&str],
threshold: f32,
) -> crate::Result<ExtractionWithRelations>;
fn extract_relations_default(
&self,
text: &str,
) -> crate::Result<(Vec<crate::Entity>, Vec<crate::Relation>)> {
let result =
self.extract_with_relations(text, DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES, 0.5)?;
Ok(result.into_anno_relations())
}
}
/// Entities together with the relations found between them.
///
/// Relations are stored as index-based [`RelationTriple`]s that refer back
/// into `entities`; [`ExtractionWithRelations::into_anno_relations`] resolves
/// them into `crate::Relation` values.
#[derive(Debug, Clone, Default)]
pub struct ExtractionWithRelations {
/// Extracted entities; relation triples index into this vector.
pub entities: Vec<Entity>,
/// Relations between entries of `entities`, expressed as index pairs.
pub relations: Vec<RelationTriple>,
}
/// A single relation expressed as indices into an entity list.
///
/// `head_idx` and `tail_idx` are looked up in
/// `ExtractionWithRelations::entities` when the triple is resolved; a triple
/// whose index is out of range is dropped at that point.
#[derive(Debug, Clone)]
pub struct RelationTriple {
/// Index of the head entity.
pub head_idx: usize,
/// Index of the tail entity.
pub tail_idx: usize,
/// Label of the relation.
pub relation_type: String,
/// Confidence attached to this relation.
pub confidence: Confidence,
}
/// Entity labels passed by [`RelationExtractor::extract_relations_default`].
///
/// The unit tests in this file require the list to be non-empty, all
/// lowercase, and free of duplicates.
pub(crate) const DEFAULT_ENTITY_TYPES: &[&str] = &[
"person",
"organization",
"location",
"date",
"product",
"event",
];
/// Relation labels passed by [`RelationExtractor::extract_relations_default`].
///
/// The unit tests in this file require the list to be non-empty, all
/// lowercase, and free of duplicates.
pub(crate) const DEFAULT_RELATION_TYPES: &[&str] = &[
"founded",
"works_for",
"located_in",
"acquired",
"born_in",
"member_of",
"ceo_of",
"part_of",
"subsidiary_of",
];
impl ExtractionWithRelations {
    /// Resolves the index-based triples into `crate::Relation` values and
    /// returns them alongside the entity list.
    ///
    /// Each relation receives its own clone of the head and tail entity, so the
    /// returned entity vector stays complete. Triples whose `head_idx` or
    /// `tail_idx` is out of range for `entities` are dropped silently; the
    /// remaining relations keep their original order.
    #[must_use]
    pub fn into_anno_relations(self) -> (Vec<Entity>, Vec<crate::Relation>) {
        // `self` is consumed, so take the fields apart and iterate the triples
        // by value: the relation label can then be moved instead of cloned.
        let Self {
            entities,
            relations: triples,
        } = self;

        let relations = triples
            .into_iter()
            .filter_map(|triple| {
                // `get` + `?` skips triples that point outside `entities`.
                let head = entities.get(triple.head_idx)?.clone();
                let tail = entities.get(triple.tail_idx)?.clone();
                Some(crate::Relation::new(
                    head,
                    tail,
                    triple.relation_type,
                    triple.confidence,
                ))
            })
            .collect();

        (entities, relations)
    }
}
/// Extraction of entities that may be made up of several separate spans.
///
/// Implementors must be `Send + Sync`.
pub trait DiscontinuousNER: Send + Sync {
/// Extracts entities of the requested `entity_types` from `text`.
///
/// `threshold` is handed to the implementation as a score cutoff; its exact
/// meaning is not visible here (presumably a minimum confidence — confirm
/// against implementors).
///
/// # Errors
///
/// Returns whatever error the implementation reports.
fn extract_discontinuous(
&self,
text: &str,
entity_types: &[&str],
threshold: f32,
) -> crate::Result<Vec<DiscontinuousEntity>>;
}
/// An entity described by one or more `(start, end)` spans.
#[derive(Debug, Clone)]
pub struct DiscontinuousEntity {
/// `(start, end)` pairs, one per fragment of the entity.
/// NOTE(review): whether these are byte or character offsets, and whether
/// `end` is exclusive, is not visible here — confirm against producers.
pub spans: Vec<(usize, usize)>,
/// Text associated with the entity; copied as-is by `to_entity`.
pub text: String,
/// Type label; converted with `EntityType::from_label` by `to_entity`.
pub entity_type: String,
/// Confidence attached to this entity.
pub confidence: Confidence,
}
impl DiscontinuousEntity {
    /// Returns `true` when the entity consists of exactly one span.
    ///
    /// An entity with no spans at all is not considered contiguous.
    pub fn is_contiguous(&self) -> bool {
        matches!(self.spans.as_slice(), [_])
    }

    /// Converts a single-span entity into a regular [`Entity`].
    ///
    /// Returns `None` when the entity has zero spans or more than one span.
    /// On success the text, the type (via `EntityType::from_label`), the
    /// span's start/end and the confidence are carried over.
    pub fn to_entity(&self) -> Option<Entity> {
        match self.spans.as_slice() {
            // Exactly one span: this is the contiguous case.
            &[(start, end)] => Some(Entity::new(
                self.text.clone(),
                EntityType::from_label(&self.entity_type),
                start,
                end,
                self.confidence,
            )),
            _ => None,
        }
    }
}
// Unit tests for the default label tables, the default trait methods and the
// conversion helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// The default entity table must be non-empty and contain the core labels.
#[test]
fn test_default_entity_types_not_empty() {
assert!(
!DEFAULT_ENTITY_TYPES.is_empty(),
"DEFAULT_ENTITY_TYPES must have at least one entry"
);
assert!(DEFAULT_ENTITY_TYPES.contains(&"person"));
assert!(DEFAULT_ENTITY_TYPES.contains(&"organization"));
assert!(DEFAULT_ENTITY_TYPES.contains(&"location"));
}
// The default relation table must be non-empty and contain the core labels.
#[test]
fn test_default_relation_types_not_empty() {
assert!(
!DEFAULT_RELATION_TYPES.is_empty(),
"DEFAULT_RELATION_TYPES must have at least one entry"
);
assert!(DEFAULT_RELATION_TYPES.contains(&"founded"));
assert!(DEFAULT_RELATION_TYPES.contains(&"works_for"));
}
// Every default label must already be in lowercase form.
#[test]
fn test_default_types_are_lowercase() {
for ty in DEFAULT_ENTITY_TYPES {
assert_eq!(*ty, ty.to_lowercase(), "entity type should be lowercase");
}
for rel in DEFAULT_RELATION_TYPES {
assert_eq!(
*rel,
rel.to_lowercase(),
"relation type should be lowercase"
);
}
}
// Neither default table may list the same label twice.
#[test]
fn test_default_types_no_duplicates() {
let mut seen = std::collections::HashSet::new();
for ty in DEFAULT_ENTITY_TYPES {
assert!(seen.insert(*ty), "duplicate entity type: {}", ty);
}
let mut seen = std::collections::HashSet::new();
for rel in DEFAULT_RELATION_TYPES {
assert!(seen.insert(*rel), "duplicate relation type: {}", rel);
}
}
// Test double that records the labels most recently passed to
// `extract_with_types`. The mutex lets the `&self` method store them.
#[derive(Debug, Default)]
struct CapturingNer {
last_labels: std::sync::Mutex<Vec<String>>,
}
impl ZeroShotNER for CapturingNer {
// Stores the received labels and reports no entities.
fn extract_with_types(
&self,
_text: &str,
entity_types: &[&str],
_threshold: f32,
) -> crate::Result<Vec<Entity>> {
// A poisoned lock is recovered rather than turned into a panic.
let mut last = self.last_labels.lock().unwrap_or_else(|e| e.into_inner());
*last = entity_types.iter().map(|s| (*s).to_string()).collect();
Ok(Vec::new())
}
// Not exercised by these tests; always reports no entities.
fn extract_with_descriptions(
&self,
_text: &str,
_descriptions: &[&str],
_threshold: f32,
) -> crate::Result<Vec<Entity>> {
Ok(Vec::new())
}
fn default_types(&self) -> &'static [&'static str] {
DEFAULT_ENTITY_TYPES
}
}
// The default `extract_with_described_types` forwards the labels only, in
// their original order.
#[test]
fn extract_with_described_types_default_passes_only_labels() {
let ner = CapturingNer::default();
let pairs = &[
("person", "a named human individual"),
("organization", "a company or institution"),
];
let _ = ZeroShotNER::extract_with_described_types(&ner, "any text", pairs, 0.5);
let last = ner.last_labels.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(
*last,
vec!["person".to_string(), "organization".to_string()]
);
}
// A pair with an empty description still contributes its label.
#[test]
fn extract_with_described_types_handles_empty_descriptions() {
let ner = CapturingNer::default();
let pairs = &[("date", ""), ("amount", "money mentioned in the text")];
let _ = ZeroShotNER::extract_with_described_types(&ner, "any text", pairs, 0.5);
let last = ner.last_labels.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(*last, vec!["date".to_string(), "amount".to_string()]);
}
// An empty pair list still succeeds; the stub reports no entities.
#[test]
fn extract_with_described_types_empty_input_returns_no_entities() {
let ner = CapturingNer::default();
let pairs: &[(&str, &str)] = &[];
let result =
ZeroShotNER::extract_with_described_types(&ner, "any text", pairs, 0.5).unwrap();
assert!(result.is_empty());
}
// `Default` yields an extraction with no entities and no relations.
#[test]
fn test_extraction_with_relations_default_is_empty() {
let extraction = ExtractionWithRelations::default();
assert!(extraction.entities.is_empty());
assert!(extraction.relations.is_empty());
}
// Converting an empty extraction yields two empty vectors.
#[test]
fn test_extraction_with_relations_into_anno_empty() {
let extraction = ExtractionWithRelations::default();
let (entities, relations) = extraction.into_anno_relations();
assert!(entities.is_empty());
assert!(relations.is_empty());
}
// Several triples may share an entity; relation order follows triple order.
#[test]
fn test_extraction_with_relations_multiple_relations() {
let extraction = ExtractionWithRelations {
entities: vec![
Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
Entity::new("Bob", EntityType::Person, 10, 13, 0.85),
Entity::new("Acme", EntityType::Organization, 20, 24, 0.8),
],
relations: vec![
RelationTriple {
head_idx: 0,
tail_idx: 2,
relation_type: "WORKS_FOR".to_string(),
confidence: Confidence::new(0.8),
},
RelationTriple {
head_idx: 1,
tail_idx: 2,
relation_type: "WORKS_FOR".to_string(),
confidence: Confidence::new(0.7),
},
],
};
let (entities, relations) = extraction.into_anno_relations();
assert_eq!(entities.len(), 3);
assert_eq!(relations.len(), 2);
assert_eq!(relations[0].head.text, "Alice");
assert_eq!(relations[1].head.text, "Bob");
}
// Triples with an out-of-range head or tail index are dropped; valid ones
// are kept.
#[test]
fn test_extraction_mixed_valid_and_invalid_indices() {
let extraction = ExtractionWithRelations {
entities: vec![
Entity::new("X", EntityType::Person, 0, 1, 0.9),
Entity::new("Y", EntityType::Organization, 5, 6, 0.8),
],
relations: vec![
RelationTriple {
head_idx: 0,
tail_idx: 1,
relation_type: "VALID".to_string(),
confidence: Confidence::new(0.9),
},
RelationTriple {
head_idx: 0,
tail_idx: 100,
relation_type: "INVALID_TAIL".to_string(),
confidence: Confidence::new(0.5),
},
RelationTriple {
head_idx: 50,
tail_idx: 1,
relation_type: "INVALID_HEAD".to_string(),
confidence: Confidence::new(0.5),
},
],
};
let (_, relations) = extraction.into_anno_relations();
assert_eq!(relations.len(), 1, "only the valid relation should survive");
assert_eq!(relations[0].relation_type, "VALID");
}
// Cloning a triple preserves every field.
#[test]
fn test_relation_triple_clone() {
let triple = RelationTriple {
head_idx: 0,
tail_idx: 1,
relation_type: "FOUNDED".to_string(),
confidence: Confidence::new(0.95),
};
let cloned = triple.clone();
assert_eq!(cloned.head_idx, 0);
assert_eq!(cloned.tail_idx, 1);
assert_eq!(cloned.relation_type, "FOUNDED");
assert!((cloned.confidence.value() - 0.95).abs() < f64::EPSILON);
}
// Zero spans: not contiguous, and no conversion to `Entity`.
#[test]
fn test_discontinuous_entity_empty_spans() {
let entity = DiscontinuousEntity {
spans: vec![],
text: String::new(),
entity_type: "misc".to_string(),
confidence: Confidence::new(0.5),
};
assert!(!entity.is_contiguous());
assert!(entity.to_entity().is_none());
}
// More than one span: not contiguous, and no conversion to `Entity`.
#[test]
fn test_discontinuous_entity_three_spans() {
let entity = DiscontinuousEntity {
spans: vec![(0, 3), (10, 15), (20, 25)],
text: "compound entity".to_string(),
entity_type: "location".to_string(),
confidence: Confidence::new(0.7),
};
assert!(!entity.is_contiguous());
assert!(entity.to_entity().is_none());
}
// A single span converts, carrying over text, offsets, type and confidence.
#[test]
fn test_discontinuous_entity_to_entity_preserves_fields() {
let entity = DiscontinuousEntity {
spans: vec![(5, 10)],
text: "Smith".to_string(),
entity_type: "person".to_string(),
confidence: Confidence::new(0.88),
};
let converted = entity.to_entity().expect("single span should convert");
assert_eq!(converted.text, "Smith");
assert_eq!(converted.start(), 5);
assert_eq!(converted.end(), 10);
assert_eq!(converted.entity_type, EntityType::Person);
assert!((converted.confidence - 0.88).abs() < 0.001);
}
}