use crate::{CorefChain, Entity, Model, Relation, Result, StackedNER};
/// A piece of text together with the annotation layers produced for it.
///
/// Instances are built with [`AnnotatedDoc::new`] or returned by [`annotate`].
#[derive(Debug, Clone)]
pub struct AnnotatedDoc {
/// The source text that was annotated, stored as an owned copy.
pub text: String,
/// Entities found in `text`, in the order they were supplied.
pub entities: Vec<Entity>,
/// Relations between entities. [`annotate`] currently leaves this empty.
pub relations: Vec<Relation>,
/// Coreference chains over the text. [`annotate`] currently leaves this empty.
pub coref_chains: Vec<CorefChain>,
}
impl AnnotatedDoc {
    /// Assembles a document from its source text and its annotation layers.
    ///
    /// `text` may be anything convertible into an owned `String`. The three
    /// vectors are moved into the document unchanged.
    #[must_use]
    pub fn new(
        text: impl Into<String>,
        entities: Vec<Entity>,
        relations: Vec<Relation>,
        coref_chains: Vec<CorefChain>,
    ) -> Self {
        let owned_text: String = text.into();
        Self {
            coref_chains,
            relations,
            entities,
            text: owned_text,
        }
    }

    /// Returns the surface text of every entity, in the same order as `self.entities`.
    ///
    /// The returned slices borrow from `self`, so no strings are copied.
    #[must_use]
    pub fn entity_texts(&self) -> Vec<&str> {
        let mut texts = Vec::with_capacity(self.entities.len());
        for entity in &self.entities {
            texts.push(entity.text.as_str());
        }
        texts
    }
}
/// Runs entity extraction over `text` with a default [`StackedNER`] model.
///
/// Only entities are produced. The returned document's `relations` and
/// `coref_chains` are always empty.
///
/// A new `StackedNER` is constructed on every call.
/// NOTE(review): if construction is expensive, callers annotating many
/// documents may want to hold a model themselves — confirm cost before caching.
///
/// # Errors
///
/// Propagates any error returned by the model's `extract_entities`.
pub fn annotate(text: &str) -> Result<AnnotatedDoc> {
    let ner = StackedNER::default();
    let found = ner.extract_entities(text, None)?;
    let doc = AnnotatedDoc::new(text, found, vec![], vec![]);
    Ok(doc)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `annotate` should keep the input text and fill only the entity layer.
    #[test]
    fn annotate_returns_entities() {
        let input = "Marie Curie won the Nobel Prize.";
        let doc = annotate(input).expect("annotating a plain sentence should succeed");
        assert_eq!(doc.text, input);
        // Assumes the default model recognises at least one entity in this sentence.
        assert!(!doc.entities.is_empty(), "should find at least one entity");
        assert!(doc.relations.is_empty());
        assert!(doc.coref_chains.is_empty());
    }

    /// Empty input must not error and must yield an entirely empty document.
    #[test]
    fn annotate_empty_text() {
        let doc = annotate("").expect("annotating empty text should succeed");
        assert!(doc.text.is_empty());
        assert!(doc.entities.is_empty());
        assert!(doc.relations.is_empty());
        assert!(doc.coref_chains.is_empty());
    }

    /// `entity_texts` must mirror `entities` one-to-one and in order.
    #[test]
    fn entity_texts_matches_entities() {
        let doc = annotate("Lynn Conway worked at IBM.").expect("annotation should succeed");
        let texts = doc.entity_texts();
        assert_eq!(texts.len(), doc.entities.len());
        for (text, entity) in texts.iter().zip(&doc.entities) {
            assert_eq!(*text, entity.text.as_str());
        }
    }

    /// `new` must store every argument unchanged.
    #[test]
    fn new_preserves_all_fields() {
        let entities = vec![Entity::new("Alice", crate::EntityType::Person, 0, 5, 0.9)];
        // `entities` is not used after this call, so it is moved rather than cloned.
        let doc = AnnotatedDoc::new("Alice went home.", entities, Vec::new(), Vec::new());
        assert_eq!(doc.text, "Alice went home.");
        assert_eq!(doc.entities.len(), 1);
        assert_eq!(doc.entities[0].text, "Alice");
        assert!(doc.relations.is_empty());
        assert!(doc.coref_chains.is_empty());
    }

    /// A document with no entities yields no entity texts.
    #[test]
    fn entity_texts_empty_doc() {
        let doc = AnnotatedDoc::new("nothing here", vec![], vec![], vec![]);
        assert!(doc.entity_texts().is_empty());
    }
}