use crate::eval::GoldEntity;
use anno::EntityType;
use serde::{Deserialize, Serialize};
/// Subject-matter domain an [`AnnotatedExample`] belongs to.
///
/// `#[non_exhaustive]`: crates outside this one must include a wildcard arm
/// when matching, so new domains can be added without a breaking change.
/// Variant order is significant for index-based serde formats; append new
/// variants rather than reordering.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
pub enum Domain {
    /// Default domain for examples built without explicit metadata.
    #[default]
    News,
    SocialMedia,
    Biomedical,
    Financial,
    Legal,
    Scientific,
    Conversational,
    Technical,
    Historical,
    Sports,
    Entertainment,
    Politics,
    Ecommerce,
    Academic,
    Email,
    Weather,
    Travel,
    Food,
    RealEstate,
    Cybersecurity,
    Multilingual,
}
impl Domain {
    /// Returns every [`Domain`] variant, in declaration order.
    ///
    /// This list is maintained by hand; keep it in sync with the enum.
    pub fn all() -> &'static [Domain] {
        static EVERY_DOMAIN: [Domain; 21] = [
            Domain::News,
            Domain::SocialMedia,
            Domain::Biomedical,
            Domain::Financial,
            Domain::Legal,
            Domain::Scientific,
            Domain::Conversational,
            Domain::Technical,
            Domain::Historical,
            Domain::Sports,
            Domain::Entertainment,
            Domain::Politics,
            Domain::Ecommerce,
            Domain::Academic,
            Domain::Email,
            Domain::Weather,
            Domain::Travel,
            Domain::Food,
            Domain::RealEstate,
            Domain::Cybersecurity,
            Domain::Multilingual,
        ];
        &EVERY_DOMAIN
    }

    /// Human-readable label for this domain, e.g. `"Social Media"` or `"E-commerce"`.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::News => "News",
            Self::SocialMedia => "Social Media",
            Self::Biomedical => "Biomedical",
            Self::Financial => "Financial",
            Self::Legal => "Legal",
            Self::Scientific => "Scientific",
            Self::Conversational => "Conversational",
            Self::Technical => "Technical",
            Self::Historical => "Historical",
            Self::Sports => "Sports",
            Self::Entertainment => "Entertainment",
            Self::Politics => "Politics",
            Self::Ecommerce => "E-commerce",
            Self::Academic => "Academic",
            Self::Email => "Email",
            Self::Weather => "Weather",
            Self::Travel => "Travel",
            Self::Food => "Food",
            Self::RealEstate => "Real Estate",
            Self::Cybersecurity => "Cybersecurity",
            Self::Multilingual => "Multilingual",
        }
    }
}
/// How hard an [`AnnotatedExample`] is expected to be for a model.
///
/// Variants are declared from easiest to hardest; `Easy` is the default.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum Difficulty {
    #[default]
    Easy,
    Medium,
    Hard,
    Adversarial,
}
impl Difficulty {
    /// Every difficulty level, from easiest to hardest.
    pub fn all() -> &'static [Difficulty] {
        const LEVELS: &[Difficulty] = &[
            Difficulty::Easy,
            Difficulty::Medium,
            Difficulty::Hard,
            Difficulty::Adversarial,
        ];
        LEVELS
    }

    /// `true` for the two hardest levels, `Hard` and `Adversarial`.
    pub fn is_challenging(&self) -> bool {
        match self {
            Self::Easy | Self::Medium => false,
            Self::Hard | Self::Adversarial => true,
        }
    }

    /// Numeric rank of the level: `Easy` = 0 up to `Adversarial` = 3.
    pub fn score(&self) -> u8 {
        match *self {
            Self::Easy => 0,
            Self::Medium => 1,
            Self::Hard => 2,
            Self::Adversarial => 3,
        }
    }
}
/// A piece of text paired with its gold entity annotations and metadata.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AnnotatedExample {
    /// The input text the entities are annotated against.
    pub text: String,
    /// Gold-standard entity annotations over `text`.
    pub entities: Vec<GoldEntity>,
    /// Subject-matter domain of the example.
    pub domain: Domain,
    /// Expected difficulty of the example.
    pub difficulty: Difficulty,
}
impl AnnotatedExample {
    /// Creates an example from pre-built gold entities, using the default
    /// [`Domain`] and [`Difficulty`].
    pub fn new(text: impl Into<String>, entities: Vec<GoldEntity>) -> Self {
        Self::with_metadata(text, entities, Domain::default(), Difficulty::default())
    }

    /// Shorthand for [`AnnotatedExample::from_tuples`].
    ///
    /// # Panics
    ///
    /// Panics if any entity string does not occur in `text`.
    #[must_use]
    pub fn simple(text: impl Into<String>, entities: Vec<(&str, &str)>) -> Self {
        Self::from_tuples(text, entities)
    }

    /// Creates an example with an explicit domain and difficulty.
    pub fn with_metadata(
        text: impl Into<String>,
        entities: Vec<GoldEntity>,
        domain: Domain,
        difficulty: Difficulty,
    ) -> Self {
        Self {
            text: text.into(),
            entities,
            domain,
            difficulty,
        }
    }

    /// Builds an example from `(surface_text, label)` pairs.
    ///
    /// Each surface string is located in `text`, and each label is parsed
    /// with [`EntityType::from_label`]. Domain and difficulty are left at
    /// their defaults.
    ///
    /// If the same surface string is listed more than once, the entries are
    /// matched to its occurrences from left to right. Repeated mentions
    /// therefore get distinct offsets instead of all collapsing onto the
    /// first one. Entries beyond the last occurrence fall back to the first
    /// occurrence.
    ///
    /// NOTE(review): offsets are byte offsets, as returned by [`str::find`].
    /// Confirm that `GoldEntity::new` expects byte offsets rather than
    /// character offsets; the two differ for non-ASCII text.
    ///
    /// # Panics
    ///
    /// Panics if an entity string does not occur anywhere in `text`.
    pub fn from_tuples(text: impl Into<String>, entities: Vec<(&str, &str)>) -> Self {
        let text = text.into();
        // Byte offset at which to resume searching for each surface string.
        // Offsets always land on a char boundary because each one is the
        // end of a previous match.
        let mut resume_at: std::collections::HashMap<&str, usize> =
            std::collections::HashMap::new();
        let gold_entities = entities
            .into_iter()
            .map(|(entity_text, entity_type_str)| {
                let from = resume_at.get(entity_text).copied().unwrap_or(0);
                let start = match text.get(from..).and_then(|rest| rest.find(entity_text)) {
                    Some(offset) => {
                        let start = from + offset;
                        resume_at.insert(entity_text, start + entity_text.len());
                        start
                    }
                    // No later occurrence: fall back to the first one.
                    None => text.find(entity_text).unwrap_or_else(|| {
                        panic!("Entity '{}' not found in text '{}'", entity_text, text)
                    }),
                };
                let entity_type = EntityType::from_label(entity_type_str);
                GoldEntity::new(entity_text, entity_type, start)
            })
            .collect();
        Self::new(text, gold_entities)
    }

    /// Returns `self` with its domain replaced by `domain`.
    #[must_use]
    pub fn with_domain(mut self, domain: Domain) -> Self {
        self.domain = domain;
        self
    }

    /// Returns `self` with its difficulty replaced by `difficulty`.
    #[must_use]
    pub fn with_difficulty(mut self, difficulty: Difficulty) -> Self {
        self.difficulty = difficulty;
        self
    }

    /// `true` when the example has no gold entities.
    pub fn is_negative(&self) -> bool {
        self.entities.is_empty()
    }

    /// Number of gold entities, counting repeats.
    pub fn entity_count(&self) -> usize {
        self.entities.len()
    }

    /// Distinct entity types present in the example.
    ///
    /// The result is ordered by each type's `Debug` rendering.
    pub fn entity_types(&self) -> Vec<&EntityType> {
        let mut types: Vec<&EntityType> = self.entities.iter().map(|e| &e.entity_type).collect();
        // Sorting by the Debug string makes equal types adjacent so `dedup`
        // can remove them. The cached variant formats each key once instead
        // of on every comparison, and it is stable like `sort_by_key`.
        types.sort_by_cached_key(|t| format!("{:?}", t));
        types.dedup();
        types
    }

    /// Borrows the example as a `(text, gold_entities)` pair.
    pub fn as_test_case(&self) -> (&str, &[GoldEntity]) {
        (&self.text, &self.entities)
    }

    /// Consumes the example, returning its `(text, gold_entities)` pair.
    pub fn into_test_case(self) -> (String, Vec<GoldEntity>) {
        (self.text, self.entities)
    }
}
/// Shorthand constructors for gold entities used when writing example data.
pub(crate) mod helpers {
    use super::*;

    /// Builds a gold entity of type `entity_type` whose text begins at `start`.
    pub fn entity(text: &str, entity_type: EntityType, start: usize) -> GoldEntity {
        GoldEntity::new(text, entity_type, start)
    }

    /// Builds a gold entity typed `EntityType::Custom { name, category }`.
    ///
    /// Shared by the label-specific constructors in this module.
    fn custom_entity(
        text: &str,
        name: &str,
        category: anno::EntityCategory,
        start: usize,
    ) -> GoldEntity {
        entity(
            text,
            EntityType::Custom {
                name: name.to_string(),
                category,
            },
            start,
        )
    }

    /// Custom `"DISEASE"` entity (category `Misc`).
    pub fn disease(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "DISEASE", anno::EntityCategory::Misc, start)
    }

    /// Custom `"DRUG"` entity (category `Misc`).
    pub fn drug(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "DRUG", anno::EntityCategory::Misc, start)
    }

    /// Custom `"GENE"` entity (category `Misc`).
    pub fn gene(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "GENE", anno::EntityCategory::Misc, start)
    }

    /// Custom `"CHEMICAL"` entity (category `Misc`).
    pub fn chemical(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "CHEMICAL", anno::EntityCategory::Misc, start)
    }

    /// Custom `"EMAIL"` entity (category `Contact`).
    pub fn entity_email(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "EMAIL", anno::EntityCategory::Contact, start)
    }

    /// Custom `"URL"` entity (category `Misc`).
    pub fn entity_url(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "URL", anno::EntityCategory::Misc, start)
    }

    /// Custom `"PHONE"` entity (category `Contact`).
    pub fn entity_phone(text: &str, start: usize) -> GoldEntity {
        custom_entity(text, "PHONE", anno::EntityCategory::Contact, start)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_domain_all() {
        let domains = Domain::all();
        assert!(domains.len() >= 20);
        assert!(domains.contains(&Domain::News));
        assert!(domains.contains(&Domain::Biomedical));
        // The list is hand-maintained; catch copy-paste duplicates.
        let unique: HashSet<Domain> = domains.iter().copied().collect();
        assert_eq!(unique.len(), domains.len(), "Domain::all() has duplicate entries");
        // Display names should be distinct and non-empty as well.
        let names: HashSet<&str> = domains.iter().map(Domain::name).collect();
        assert_eq!(names.len(), domains.len(), "Domain::name() values collide");
        assert!(names.iter().all(|n| !n.is_empty()));
    }

    #[test]
    fn test_difficulty_ordering() {
        assert!(Difficulty::Easy.score() < Difficulty::Medium.score());
        assert!(Difficulty::Medium.score() < Difficulty::Hard.score());
        assert!(Difficulty::Hard.score() < Difficulty::Adversarial.score());
    }

    #[test]
    fn test_difficulty_all_and_challenging() {
        let levels = Difficulty::all();
        assert_eq!(levels.len(), 4);
        assert!(levels.windows(2).all(|w| w[0].score() < w[1].score()));
        assert_eq!(Difficulty::default(), Difficulty::Easy);
        assert!(!Difficulty::Easy.is_challenging());
        assert!(!Difficulty::Medium.is_challenging());
        assert!(Difficulty::Hard.is_challenging());
        assert!(Difficulty::Adversarial.is_challenging());
    }

    #[test]
    fn test_annotated_example_from_tuples() {
        let example = AnnotatedExample::from_tuples(
            "John works at Google in NYC.",
            vec![("John", "PER"), ("Google", "ORG"), ("NYC", "LOC")],
        );
        assert_eq!(example.entities.len(), 3);
        assert_eq!(example.entities[0].text, "John");
        assert_eq!(example.entities[0].start, 0);
        assert_eq!(example.entities[1].text, "Google");
        assert_eq!(example.entities[1].start, 14);
    }

    #[test]
    #[should_panic(expected = "not found")]
    fn test_from_tuples_missing_entity_panics() {
        let _ = AnnotatedExample::from_tuples("John is here", vec![("Mary", "PER")]);
    }

    #[test]
    fn test_annotated_example_builder() {
        let example = AnnotatedExample::new("Test text", vec![])
            .with_domain(Domain::Biomedical)
            .with_difficulty(Difficulty::Hard);
        assert_eq!(example.domain, Domain::Biomedical);
        assert_eq!(example.difficulty, Difficulty::Hard);
    }

    #[test]
    fn test_is_negative() {
        let positive = AnnotatedExample::from_tuples("John is here", vec![("John", "PER")]);
        let negative = AnnotatedExample::new("No entities here", vec![]);
        assert!(!positive.is_negative());
        assert!(negative.is_negative());
    }

    #[test]
    fn test_entity_types_dedup() {
        let example = AnnotatedExample::from_tuples(
            "John met Mary at Google.",
            vec![("John", "PER"), ("Mary", "PER"), ("Google", "ORG")],
        );
        assert_eq!(example.entity_count(), 3);
        assert_eq!(example.entity_types().len(), 2);
    }

    #[test]
    fn test_test_case_views() {
        let example = AnnotatedExample::from_tuples("John is here", vec![("John", "PER")]);
        let (text, entities) = example.as_test_case();
        assert_eq!(text, "John is here");
        assert_eq!(entities.len(), 1);
        let (owned_text, owned_entities) = example.into_test_case();
        assert_eq!(owned_text, "John is here");
        assert_eq!(owned_entities.len(), 1);
    }
}