use serde::{Deserialize, Serialize};
use std::fmt;
use std::hash::Hash;
/// The annotation task a dataset provides gold labels for.
///
/// `#[non_exhaustive]`: downstream crates must keep a wildcard arm when
/// matching, so new tasks can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Task {
    /// Named entity recognition over contiguous spans.
    NER,
    /// Coreference resolution within a single document.
    IntraDocCoref,
    /// Cross-document coreference resolution (CDCR).
    InterDocCoref,
    /// Named entity disambiguation / entity linking (see `FromStr` aliases
    /// "el", "entity_linking").
    NED,
    /// Extraction of relations between entities.
    RelationExtraction,
    /// Extraction of events.
    EventExtraction,
    /// NER with discontinuous mention spans (also parsed from the
    /// "nested"/"nested_ner" aliases).
    DiscontinuousNER,
    /// NER on visual/document inputs (alias "document_ner").
    VisualNER,
    /// NER of temporal expressions (alias "timex").
    TemporalNER,
    /// Aspect extraction (alias "absa").
    AspectExtraction,
    /// Slot filling (alias "intent").
    SlotFilling,
    /// Part-of-speech tagging.
    POS,
    /// Dependency parsing.
    DependencyParsing,
}
impl Task {
    /// True when the task's primary annotations are entity mentions
    /// (flat, discontinuous, visual, temporal, aspect, or slot spans).
    #[must_use]
    pub const fn produces_entities(&self) -> bool {
        match self {
            Self::NER
            | Self::DiscontinuousNER
            | Self::VisualNER
            | Self::TemporalNER
            | Self::AspectExtraction
            | Self::SlotFilling => true,
            _ => false,
        }
    }

    /// True for the two coreference tasks (within- and cross-document).
    #[must_use]
    pub const fn involves_coreference(&self) -> bool {
        match self {
            Self::IntraDocCoref | Self::InterDocCoref => true,
            _ => false,
        }
    }

    /// True when the task links mentions against a knowledge base.
    #[must_use]
    pub const fn involves_kb_linking(&self) -> bool {
        match self {
            Self::NED => true,
            _ => false,
        }
    }

    /// True when the task extracts relations or events.
    #[must_use]
    pub const fn involves_relations(&self) -> bool {
        match self {
            Self::RelationExtraction | Self::EventExtraction => true,
            _ => false,
        }
    }
}
impl fmt::Display for Task {
    /// Writes the human-readable task name (see `Task::code` for the
    /// short machine-readable form).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::NER => "NER",
            Self::IntraDocCoref => "Intra-Doc Coreference",
            Self::InterDocCoref => "Inter-Doc Coreference",
            Self::NED => "Named Entity Disambiguation",
            Self::RelationExtraction => "Relation Extraction",
            Self::EventExtraction => "Event Extraction",
            Self::DiscontinuousNER => "Discontinuous NER",
            Self::VisualNER => "Visual NER",
            Self::TemporalNER => "Temporal NER",
            Self::AspectExtraction => "Aspect Extraction",
            Self::SlotFilling => "Slot Filling",
            Self::POS => "POS Tagging",
            Self::DependencyParsing => "Dependency Parsing",
        };
        f.write_str(label)
    }
}
impl std::str::FromStr for Task {
    type Err = String;

    /// Parses a task from one of its case-insensitive aliases
    /// (e.g. "ner", "coref", "cdcr", "el", "re", "event", "pos", "dep").
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let task = match normalized.as_str() {
            "ner" | "named_entity_recognition" | "sequence_labeling" => Self::NER,
            "coref" | "coreference" | "intra_doc_coref" | "intradoccoref" => Self::IntraDocCoref,
            "cdcr" | "inter_doc_coref" | "interdoccoref" | "cross_doc_coref" => {
                Self::InterDocCoref
            }
            "ned" | "el" | "entity_linking" | "disambiguation" => Self::NED,
            "re" | "relation_extraction" | "relations" => Self::RelationExtraction,
            "event" | "event_extraction" | "events" => Self::EventExtraction,
            // NOTE(review): "nested"/"nested_ner" are routed to
            // DiscontinuousNER — confirm nested and discontinuous NER are
            // meant to share a variant.
            "discontinuous" | "discontinuous_ner" | "nested" | "nested_ner" => {
                Self::DiscontinuousNER
            }
            "visual" | "visual_ner" | "document_ner" => Self::VisualNER,
            "temporal" | "temporal_ner" | "timex" => Self::TemporalNER,
            "aspect" | "aspect_extraction" | "absa" => Self::AspectExtraction,
            "slot" | "slot_filling" | "intent" => Self::SlotFilling,
            "pos" | "pos_tagging" | "part_of_speech" => Self::POS,
            "dep" | "dependency" | "dependency_parsing" => Self::DependencyParsing,
            _ => {
                // Report the caller's original (un-lowercased) spelling.
                return Err(format!(
                    "Unknown task: '{s}'. Valid: ner, coref, ned, re, event, ..."
                ));
            }
        };
        Ok(task)
    }
}
impl Task {
    /// Every known task, in declaration order.
    ///
    /// Must be kept in sync by hand when a variant is added — the enum is
    /// `#[non_exhaustive]`, so nothing enforces this at compile time.
    pub const ALL: &'static [Task] = &[
        Task::NER,
        Task::IntraDocCoref,
        Task::InterDocCoref,
        Task::NED,
        Task::RelationExtraction,
        Task::EventExtraction,
        Task::DiscontinuousNER,
        Task::VisualNER,
        Task::TemporalNER,
        Task::AspectExtraction,
        Task::SlotFilling,
        Task::POS,
        Task::DependencyParsing,
    ];

    /// Short, stable machine-readable code for this task.
    ///
    /// Each code is accepted by the `FromStr` impl, so
    /// `t.code().parse::<Task>()` round-trips back to `t`.
    #[must_use]
    pub const fn code(&self) -> &'static str {
        match self {
            Self::NER => "ner",
            Self::IntraDocCoref => "coref",
            Self::InterDocCoref => "cdcr",
            // Note: "el" (entity linking), not "ned"; both parse back to NED.
            Self::NED => "el",
            Self::RelationExtraction => "re",
            Self::EventExtraction => "event",
            Self::DiscontinuousNER => "discontinuous",
            Self::VisualNER => "visual",
            Self::TemporalNER => "temporal",
            Self::AspectExtraction => "aspect",
            Self::SlotFilling => "slot",
            Self::POS => "pos",
            Self::DependencyParsing => "dep",
        }
    }
}
/// Hint for which on-disk format / parser a dataset uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ParserHint {
    /// Column-based CoNLL-style files (the default).
    #[default]
    CoNLL,
    /// CoNLL-U files.
    CoNLLU,
    /// A JSON document.
    JSON,
    /// Newline-delimited JSON records.
    JSONL,
    /// Hugging Face API payloads (JSON on disk).
    HuggingFaceAPI,
    /// BRAT standoff annotation (`.ann` alongside `.txt`).
    BRAT,
    /// Generic XML.
    XML,
    /// ACE-style XML/SGML corpora.
    ACE,
    /// OntoNotes release files (`.onf`, `.name`).
    OntoNotes,
    /// Caller-supplied parser; no typical file extensions.
    Custom,
}
impl ParserHint {
    /// File extensions (lowercase, without the leading dot) typically
    /// associated with this format. `Custom` has none.
    #[must_use]
    pub const fn typical_extensions(&self) -> &'static [&'static str] {
        match self {
            Self::CoNLL => &["conll", "txt", "bio"],
            Self::CoNLLU => &["conllu"],
            Self::JSON => &["json"],
            Self::JSONL => &["jsonl", "ndjson"],
            Self::HuggingFaceAPI => &["json"],
            Self::BRAT => &["ann", "txt"],
            Self::XML | Self::ACE => &["xml", "sgml"],
            Self::OntoNotes => &["onf", "name"],
            Self::Custom => &[],
        }
    }
}
/// License under which a dataset is distributed.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum License {
    /// Creative Commons Attribution 4.0.
    CCBY,
    /// CC Attribution-ShareAlike 4.0.
    CCBYSA,
    /// CC Attribution-NonCommercial 4.0.
    CCBYNC,
    /// CC Attribution-NonCommercial-ShareAlike 4.0.
    CCBYNCSA,
    /// CC0 public-domain dedication.
    CC0,
    /// MIT license.
    MIT,
    /// Apache License 2.0.
    Apache2,
    /// GNU GPL (version not tracked here).
    GPL,
    /// LDC (Linguistic Data Consortium) distribution terms.
    LDC,
    /// Usable for research purposes only.
    ResearchOnly,
    /// Proprietary terms.
    Proprietary,
    /// License not known (the default).
    #[default]
    Unknown,
    /// Any other license, identified by a free-form string.
    Other(String),
}
impl License {
    /// Whether the license is treated as commercial-use-friendly.
    #[must_use]
    pub fn allows_commercial(&self) -> bool {
        match self {
            Self::CCBY | Self::CCBYSA | Self::CC0 | Self::MIT | Self::Apache2 => true,
            // NOTE(review): GPL does permit commercial use (under copyleft
            // terms) but is classed `false` here — confirm this is the
            // intended, conservative reading.
            _ => false,
        }
    }

    /// Whether the data may be redistributed.
    #[must_use]
    pub fn allows_redistribution(&self) -> bool {
        match self {
            Self::LDC | Self::Proprietary | Self::ResearchOnly => false,
            // NOTE(review): `Unknown` and `Other(_)` default to
            // redistributable — verify this permissive default is intended.
            _ => true,
        }
    }
}
impl fmt::Display for License {
    /// Writes the conventional human-readable license name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::CCBY => "CC BY 4.0",
            Self::CCBYSA => "CC BY-SA 4.0",
            Self::CCBYNC => "CC BY-NC 4.0",
            Self::CCBYNCSA => "CC BY-NC-SA 4.0",
            Self::CC0 => "CC0 (Public Domain)",
            Self::MIT => "MIT",
            Self::Apache2 => "Apache 2.0",
            Self::GPL => "GPL",
            Self::LDC => "LDC",
            Self::ResearchOnly => "Research Only",
            Self::Proprietary => "Proprietary",
            Self::Unknown => "Unknown",
            Self::Other(s) => s.as_str(),
        };
        f.write_str(label)
    }
}
/// Text domain (genre) a dataset's documents come from.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum Domain {
    News,
    Biomedical,
    Scientific,
    Legal,
    Financial,
    SocialMedia,
    Wikipedia,
    Literary,
    Historical,
    Dialogue,
    Technical,
    Web,
    Cybersecurity,
    Music,
    /// Several domains, or no single dominant one (the default).
    #[default]
    Mixed,
    /// Any other domain, identified by a free-form string.
    Other(String),
}
impl fmt::Display for Domain {
    /// Writes the human-readable domain name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::News => "News",
            Self::Biomedical => "Biomedical",
            Self::Scientific => "Scientific",
            Self::Legal => "Legal",
            Self::Financial => "Financial",
            Self::SocialMedia => "Social Media",
            Self::Wikipedia => "Wikipedia",
            Self::Literary => "Literary",
            Self::Historical => "Historical",
            Self::Dialogue => "Dialogue",
            Self::Technical => "Technical",
            Self::Web => "Web",
            Self::Cybersecurity => "Cybersecurity",
            Self::Music => "Music",
            Self::Mixed => "Mixed",
            Self::Other(s) => s.as_str(),
        };
        f.write_str(label)
    }
}
/// Time period a dataset's documents cover, plus temporal-annotation flags.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct TemporalCoverage {
    /// Earliest year covered, if known.
    pub start_year: Option<i32>,
    /// Latest year covered, if known.
    pub end_year: Option<i32>,
    /// Whether the corpus carries explicit temporal annotations.
    pub has_temporal_annotations: bool,
    /// Whether entities are annotated across time (diachronic).
    pub has_diachronic_entities: bool,
}
/// Corpus-size statistics; every field is optional because published
/// numbers are often incomplete.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct DatasetStats {
    /// Number of documents.
    pub doc_count: Option<usize>,
    /// Number of annotated mentions.
    pub mention_count: Option<usize>,
    /// Number of distinct entities.
    pub entity_count: Option<usize>,
    /// Number of tokens.
    pub token_count: Option<usize>,
    /// Per-split sizes, when the dataset ships with official splits.
    pub split_sizes: Option<SplitSizes>,
}
/// Sizes of the official train/dev/test splits.
/// (Unit — documents vs. sentences — is dataset-dependent; not enforced here.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitSizes {
    /// Training split size.
    pub train: usize,
    /// Development/validation split size.
    pub dev: usize,
    /// Test split size.
    pub test: usize,
}
/// Describes a dataset: what task it annotates, in which languages, how it
/// is stored, and under what terms it may be used.
///
/// Only the first seven methods are required; the rest have conservative
/// defaults. `Send + Sync` is required so specs can be boxed and stored in
/// a registry shared across threads.
pub trait DatasetSpec: Send + Sync {
    /// Human-readable display name.
    fn name(&self) -> &str;
    /// Stable identifier (used as the registry key).
    fn id(&self) -> &str;
    /// Primary annotation task.
    fn task(&self) -> Task;
    /// Language codes covered. Implementations that store owned `String`s
    /// (e.g. `CustomDataset`) may return an empty slice here and override
    /// `languages_vec` instead — prefer `languages_vec` in generic code.
    fn languages(&self) -> &[&str];
    /// Entity-type labels used by the dataset (same caveat as `languages`).
    fn entity_types(&self) -> &[&str];
    /// Which on-disk format / parser the raw files use.
    fn parser_hint(&self) -> ParserHint;
    /// Distribution license.
    fn license(&self) -> License;
    /// Optional free-text description.
    fn description(&self) -> Option<&str> {
        None
    }
    /// Text domain; defaults to `Mixed`.
    fn domain(&self) -> Domain {
        Domain::Mixed
    }
    /// Where the data can be downloaded, if anywhere.
    fn download_url(&self) -> Option<&str> {
        None
    }
    /// Citation text for the dataset's paper, if any.
    fn citation(&self) -> Option<&str> {
        None
    }
    /// DOI of the dataset or paper, if any.
    fn doi(&self) -> Option<&str> {
        None
    }
    /// Local filesystem location, if the data is already on disk.
    fn local_path(&self) -> Option<&std::path::Path> {
        None
    }
    /// Size statistics; defaults to all-unknown.
    fn stats(&self) -> DatasetStats {
        DatasetStats::default()
    }
    /// Temporal coverage; defaults to all-unknown.
    fn temporal_coverage(&self) -> TemporalCoverage {
        TemporalCoverage::default()
    }
    /// Additional tasks the dataset is annotated for beyond `task()`.
    fn secondary_tasks(&self) -> &[Task] {
        &[]
    }
    /// Whether the language is constructed (e.g. not a natural language).
    fn is_constructed_language(&self) -> bool {
        false
    }
    /// Whether the corpus is historical material.
    fn is_historical(&self) -> bool {
        false
    }
    /// Whether downloading requires authentication.
    fn requires_auth(&self) -> bool {
        false
    }
    /// Dataset version string, if versioned.
    fn version(&self) -> Option<&str> {
        None
    }
    /// Free-form maintainer notes.
    fn notes(&self) -> Option<&str> {
        None
    }
    /// Owned copy of `languages()`; override when the backing storage is
    /// owned `String`s (see `languages` above).
    fn languages_vec(&self) -> Vec<String> {
        self.languages().iter().map(|s| (*s).to_string()).collect()
    }
    /// Owned copy of `entity_types()`; same override convention.
    fn entity_types_vec(&self) -> Vec<String> {
        self.entity_types()
            .iter()
            .map(|s| (*s).to_string())
            .collect()
    }
    /// Publicly obtainable: redistributable license and no auth required.
    fn is_public(&self) -> bool {
        self.license().allows_redistribution() && !self.requires_auth()
    }
    /// True if `task` is the primary task or among the secondary tasks.
    fn supports_task(&self, task: Task) -> bool {
        self.task() == task || self.secondary_tasks().contains(&task)
    }
    /// True if `lang` is listed, or the dataset declares the special
    /// "multilingual" wildcard language.
    fn supports_language(&self, lang: &str) -> bool {
        let langs = self.languages_vec();
        langs.iter().any(|l| l == "multilingual" || l == lang)
    }
    /// Case-insensitive (ASCII) membership test over the entity types.
    fn has_entity_type(&self, entity_type: &str) -> bool {
        self.entity_types_vec()
            .iter()
            .any(|t| t.eq_ignore_ascii_case(entity_type))
    }
}
/// A runtime-constructed dataset description, built with the builder-style
/// `with_*` methods and usable anywhere a `DatasetSpec` is expected.
///
/// Languages and entity types are stored as owned `String`s, which is why
/// the `DatasetSpec::languages`/`entity_types` impls return empty slices
/// and the `*_vec` methods are overridden instead.
#[derive(Debug, Clone)]
pub struct CustomDataset {
    id: String,
    name: String,
    task: Task,
    languages: Vec<String>,
    entity_types: Vec<String>,
    parser_hint: ParserHint,
    license: License,
    description: Option<String>,
    domain: Domain,
    download_url: Option<String>,
    local_path: Option<std::path::PathBuf>,
    stats: DatasetStats,
    temporal_coverage: TemporalCoverage,
    secondary_tasks: Vec<Task>,
    is_constructed: bool,
    is_historical: bool,
    requires_auth: bool,
    version: Option<String>,
    notes: Option<String>,
    citation: Option<String>,
}
impl CustomDataset {
    /// Creates a dataset with the given id and primary task.
    ///
    /// Defaults: display name = the id, languages = `["en"]`, parser =
    /// `CoNLL`, license = `Unknown`, domain = `Mixed`, everything else
    /// empty/false. Fill in the rest with the `with_*` builders.
    #[must_use]
    pub fn new(id: impl Into<String>, task: Task) -> Self {
        let id = id.into();
        Self {
            // The name defaults to the id until `with_name` overrides it.
            name: id.clone(),
            id,
            task,
            languages: vec!["en".to_string()],
            entity_types: vec![],
            parser_hint: ParserHint::CoNLL,
            license: License::Unknown,
            description: None,
            domain: Domain::Mixed,
            download_url: None,
            local_path: None,
            stats: DatasetStats::default(),
            temporal_coverage: TemporalCoverage::default(),
            secondary_tasks: vec![],
            is_constructed: false,
            is_historical: false,
            requires_auth: false,
            version: None,
            notes: None,
            citation: None,
        }
    }
    /// Sets the human-readable display name.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }
    /// Replaces the language list (overwrites the `["en"]` default).
    #[must_use]
    pub fn with_languages(mut self, langs: &[&str]) -> Self {
        self.languages = langs.iter().map(|s| (*s).to_string()).collect();
        self
    }
    /// Replaces the entity-type label list.
    #[must_use]
    pub fn with_entity_types(mut self, types: &[&str]) -> Self {
        self.entity_types = types.iter().map(|s| (*s).to_string()).collect();
        self
    }
    /// Sets the on-disk format / parser hint.
    #[must_use]
    pub fn with_parser(mut self, parser: ParserHint) -> Self {
        self.parser_hint = parser;
        self
    }
    /// Sets the distribution license.
    #[must_use]
    pub fn with_license(mut self, license: License) -> Self {
        self.license = license;
        self
    }
    /// Sets a free-text description.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
    /// Sets the text domain.
    #[must_use]
    pub fn with_domain(mut self, domain: Domain) -> Self {
        self.domain = domain;
        self
    }
    /// Sets the download URL.
    #[must_use]
    pub fn with_url(mut self, url: impl Into<String>) -> Self {
        self.download_url = Some(url.into());
        self
    }
    /// Sets the local filesystem path.
    ///
    /// Generalized from a hard-coded `PathBuf` parameter: now accepts
    /// anything convertible into one (`&str`, `String`, `&Path`, `PathBuf`),
    /// matching the `impl Into<String>` style of the other builders.
    /// Existing callers that pass a `PathBuf` are unaffected.
    #[must_use]
    pub fn with_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.local_path = Some(path.into());
        self
    }
    /// Sets corpus-size statistics.
    #[must_use]
    pub fn with_stats(mut self, stats: DatasetStats) -> Self {
        self.stats = stats;
        self
    }
    /// Sets temporal coverage information.
    #[must_use]
    pub fn with_temporal_coverage(mut self, coverage: TemporalCoverage) -> Self {
        self.temporal_coverage = coverage;
        self
    }
    /// Replaces the list of secondary tasks.
    #[must_use]
    pub fn with_secondary_tasks(mut self, tasks: Vec<Task>) -> Self {
        self.secondary_tasks = tasks;
        self
    }
    /// Marks the dataset's language as constructed.
    #[must_use]
    pub fn constructed(mut self) -> Self {
        self.is_constructed = true;
        self
    }
    /// Marks the dataset as historical material.
    #[must_use]
    pub fn historical(mut self) -> Self {
        self.is_historical = true;
        self
    }
    /// Marks the dataset as requiring authentication to download.
    #[must_use]
    pub fn requires_authentication(mut self) -> Self {
        self.requires_auth = true;
        self
    }
    /// Sets the dataset version string.
    #[must_use]
    pub fn with_version(mut self, version: impl Into<String>) -> Self {
        self.version = Some(version.into());
        self
    }
    /// Borrows the owned language list (no allocation, unlike
    /// `DatasetSpec::languages_vec`).
    #[must_use]
    pub fn languages_owned(&self) -> &[String] {
        &self.languages
    }
    /// Borrows the owned entity-type list.
    #[must_use]
    pub fn entity_types_owned(&self) -> &[String] {
        &self.entity_types
    }
    /// Sets free-form maintainer notes.
    #[must_use]
    pub fn with_notes(mut self, notes: impl Into<String>) -> Self {
        self.notes = Some(notes.into());
        self
    }
    /// Sets the citation text.
    #[must_use]
    pub fn with_citation(mut self, citation: impl Into<String>) -> Self {
        self.citation = Some(citation.into());
        self
    }
}
impl DatasetSpec for CustomDataset {
    fn name(&self) -> &str {
        &self.name
    }
    fn id(&self) -> &str {
        &self.id
    }
    fn task(&self) -> Task {
        self.task
    }
    /// Always empty: the trait asks for `&[&str]`, but this type stores
    /// owned `String`s and cannot hand out a `&[&str]` view of them.
    /// `languages_vec` is overridden below to return the real data, and
    /// the trait's default `supports_language` goes through it.
    fn languages(&self) -> &[&str] {
        static EMPTY: &[&str] = &[];
        EMPTY
    }
    /// Always empty for the same reason as `languages`; use
    /// `entity_types_vec` (overridden below) instead.
    fn entity_types(&self) -> &[&str] {
        static EMPTY: &[&str] = &[];
        EMPTY
    }
    fn parser_hint(&self) -> ParserHint {
        self.parser_hint
    }
    // License and Domain carry a String in their Other variants, so these
    // accessors clone rather than copy.
    fn license(&self) -> License {
        self.license.clone()
    }
    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
    fn domain(&self) -> Domain {
        self.domain.clone()
    }
    fn download_url(&self) -> Option<&str> {
        self.download_url.as_deref()
    }
    fn local_path(&self) -> Option<&std::path::Path> {
        self.local_path.as_deref()
    }
    fn stats(&self) -> DatasetStats {
        self.stats.clone()
    }
    fn temporal_coverage(&self) -> TemporalCoverage {
        self.temporal_coverage.clone()
    }
    fn secondary_tasks(&self) -> &[Task] {
        &self.secondary_tasks
    }
    fn is_constructed_language(&self) -> bool {
        self.is_constructed
    }
    fn is_historical(&self) -> bool {
        self.is_historical
    }
    fn requires_auth(&self) -> bool {
        self.requires_auth
    }
    fn version(&self) -> Option<&str> {
        self.version.as_deref()
    }
    fn notes(&self) -> Option<&str> {
        self.notes.as_deref()
    }
    fn citation(&self) -> Option<&str> {
        self.citation.as_deref()
    }
    /// Overridden to expose the owned language data (see `languages`).
    fn languages_vec(&self) -> Vec<String> {
        self.languages.clone()
    }
    /// Overridden to expose the owned entity-type data (see `entity_types`).
    fn entity_types_vec(&self) -> Vec<String> {
        self.entity_types.clone()
    }
}
/// An id-keyed collection of boxed `DatasetSpec` trait objects.
///
/// `Debug` is implemented by hand (below) because `dyn DatasetSpec` does
/// not require `Debug`.
#[derive(Default)]
pub struct DatasetRegistry {
    datasets: std::collections::HashMap<String, Box<dyn DatasetSpec>>,
}
impl DatasetRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `dataset` keyed by its own `id()`, returning the dataset
    /// previously registered under that id, if any.
    pub fn register(
        &mut self,
        dataset: impl DatasetSpec + 'static,
    ) -> Option<Box<dyn DatasetSpec>> {
        let key = dataset.id().to_string();
        self.datasets.insert(key, Box::new(dataset))
    }

    /// Looks up a dataset by id.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<&dyn DatasetSpec> {
        self.datasets.get(id).map(|boxed| &**boxed)
    }

    /// Removes and returns the dataset registered under `id`, if any.
    pub fn unregister(&mut self, id: &str) -> Option<Box<dyn DatasetSpec>> {
        self.datasets.remove(id)
    }

    /// All registered ids, in arbitrary (hash-map) order.
    #[must_use]
    pub fn list_ids(&self) -> Vec<&str> {
        self.datasets.keys().map(String::as_str).collect()
    }

    /// Number of registered datasets.
    #[must_use]
    pub fn len(&self) -> usize {
        self.datasets.len()
    }

    /// True when nothing is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.datasets.is_empty()
    }

    /// Iterates over `(id, spec)` pairs in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn DatasetSpec)> {
        self.datasets
            .iter()
            .map(|(id, boxed)| (id.as_str(), &**boxed))
    }

    /// Datasets whose primary or secondary tasks include `task`.
    pub fn by_task(&self, task: Task) -> impl Iterator<Item = &dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(move |spec| spec.supports_task(task))
            .map(|boxed| &**boxed)
    }

    /// Datasets covering `lang` (or declaring the "multilingual" wildcard).
    pub fn by_language<'a>(&'a self, lang: &'a str) -> impl Iterator<Item = &'a dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(move |spec| spec.supports_language(lang))
            .map(|boxed| &**boxed)
    }

    /// Datasets whose domain equals `domain` exactly.
    pub fn by_domain(&self, domain: Domain) -> impl Iterator<Item = &dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(move |spec| spec.domain() == domain)
            .map(|boxed| &**boxed)
    }

    /// Datasets that are redistributable and need no authentication.
    pub fn public_only(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(|spec| spec.is_public())
            .map(|boxed| &**boxed)
    }

    /// Datasets flagged as historical material.
    pub fn historical(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(|spec| spec.is_historical())
            .map(|boxed| &**boxed)
    }

    /// Datasets that annotate `entity_type` (ASCII case-insensitive).
    pub fn with_entity_type<'a>(
        &'a self,
        entity_type: &'a str,
    ) -> impl Iterator<Item = &'a dyn DatasetSpec> {
        self.datasets
            .values()
            .filter(move |spec| spec.has_entity_type(entity_type))
            .map(|boxed| &**boxed)
    }

    /// Aggregates per-task and per-domain counts plus the union of all
    /// declared languages across the registry.
    #[must_use]
    pub fn summary(&self) -> RegistrySummary {
        use std::collections::{HashMap, HashSet};
        let mut by_task: HashMap<Task, usize> = HashMap::new();
        let mut by_domain: HashMap<Domain, usize> = HashMap::new();
        let mut languages: HashSet<String> = HashSet::new();
        for spec in self.datasets.values() {
            *by_task.entry(spec.task()).or_default() += 1;
            *by_domain.entry(spec.domain()).or_default() += 1;
            languages.extend(spec.languages_vec());
        }
        RegistrySummary {
            total: self.datasets.len(),
            by_task,
            by_domain,
            languages: languages.into_iter().collect(),
        }
    }
}
/// Aggregate view of a `DatasetRegistry`, produced by
/// `DatasetRegistry::summary`.
#[derive(Debug, Clone)]
pub struct RegistrySummary {
    /// Total number of registered datasets.
    pub total: usize,
    /// Dataset count per primary task.
    pub by_task: std::collections::HashMap<Task, usize>,
    /// Dataset count per domain.
    pub by_domain: std::collections::HashMap<Domain, usize>,
    /// Union of all declared language codes, in arbitrary order.
    pub languages: Vec<String>,
}
impl fmt::Debug for DatasetRegistry {
    /// Hand-written because the boxed `dyn DatasetSpec` values are not
    /// `Debug`; shows only the entry count and the registered ids.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("DatasetRegistry");
        dbg.field("count", &self.datasets.len());
        dbg.field("ids", &self.list_ids());
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_custom_dataset_creation() {
        let ds = CustomDataset::new("test_ner", Task::NER)
            .with_name("Test NER Dataset")
            .with_languages(&["en", "de"])
            .with_entity_types(&["PER", "LOC", "ORG"])
            .with_license(License::MIT)
            .with_domain(Domain::News);
        assert_eq!(ds.id(), "test_ner");
        assert_eq!(ds.name(), "Test NER Dataset");
        assert_eq!(ds.task(), Task::NER);
        for lang in ["en", "de"] {
            assert!(ds.languages_owned().contains(&lang.to_string()));
        }
        assert!(!ds.languages_owned().contains(&"fr".to_string()));
        // Entity-type lookup is ASCII case-insensitive.
        let has_type = |name: &str| {
            ds.entity_types_owned()
                .iter()
                .any(|t| t.eq_ignore_ascii_case(name))
        };
        assert!(has_type("PER"));
        assert!(has_type("per"));
        assert!(ds.is_public());
    }

    #[test]
    fn test_registry() {
        let mut reg = DatasetRegistry::new();
        reg.register(
            CustomDataset::new("ds1", Task::NER)
                .with_name("Dataset 1")
                .with_languages(&["en"]),
        );
        reg.register(
            CustomDataset::new("ds2", Task::IntraDocCoref)
                .with_name("Dataset 2")
                .with_languages(&["de"]),
        );
        assert_eq!(reg.len(), 2);
        assert!(reg.get("ds1").is_some());
        assert!(reg.get("ds2").is_some());
        assert!(reg.get("ds3").is_none());
        let ner: Vec<_> = reg.by_task(Task::NER).collect();
        assert_eq!(ner.len(), 1);
        assert_eq!(ner[0].id(), "ds1");
    }

    #[test]
    fn test_task_properties() {
        assert!(Task::NER.produces_entities());
        assert!(!Task::IntraDocCoref.produces_entities());
        assert!(Task::IntraDocCoref.involves_coreference());
        assert!(Task::InterDocCoref.involves_coreference());
        assert!(!Task::NER.involves_coreference());
        assert!(Task::NED.involves_kb_linking());
        assert!(Task::RelationExtraction.involves_relations());
    }

    #[test]
    fn test_license_properties() {
        assert!(License::MIT.allows_commercial());
        assert!(License::MIT.allows_redistribution());
        assert!(!License::LDC.allows_redistribution());
        assert!(!License::ResearchOnly.allows_commercial());
    }

    #[test]
    fn test_parser_extensions() {
        assert!(ParserHint::CoNLL.typical_extensions().contains(&"conll"));
        assert!(ParserHint::JSONL.typical_extensions().contains(&"jsonl"));
    }

    #[test]
    fn test_task_from_str() {
        let parse = |s: &str| s.parse::<Task>().expect("task parse");
        assert_eq!(parse("ner"), Task::NER);
        // Parsing is case-insensitive.
        assert_eq!(parse("NER"), Task::NER);
        assert_eq!(parse("coref"), Task::IntraDocCoref);
        assert_eq!(parse("cdcr"), Task::InterDocCoref);
        assert_eq!(parse("el"), Task::NED);
        assert_eq!(parse("entity_linking"), Task::NED);
        assert_eq!(parse("re"), Task::RelationExtraction);
        assert!("invalid_task".parse::<Task>().is_err());
    }

    #[test]
    fn test_task_code() {
        assert_eq!(Task::NER.code(), "ner");
        assert_eq!(Task::IntraDocCoref.code(), "coref");
        assert_eq!(Task::NED.code(), "el");
        assert_eq!(Task::RelationExtraction.code(), "re");
    }

    #[test]
    fn test_task_all_variants() {
        for task in [Task::NER, Task::IntraDocCoref, Task::NED] {
            assert!(Task::ALL.contains(&task));
        }
        assert_eq!(Task::ALL.len(), 13);
    }

    #[test]
    fn test_registry_filtering() {
        let mut reg = DatasetRegistry::new();
        reg.register(
            CustomDataset::new("biomedical_ner", Task::NER)
                .with_languages(&["en"])
                .with_domain(Domain::Biomedical)
                .with_entity_types(&["DISEASE", "DRUG"]),
        );
        reg.register(
            CustomDataset::new("news_coref", Task::IntraDocCoref)
                .with_languages(&["en", "de"])
                .with_domain(Domain::News),
        );
        reg.register(
            CustomDataset::new("sanskrit_edl", Task::NED)
                .with_languages(&["sa"])
                .with_domain(Domain::Literary)
                .historical(),
        );
        let bio: Vec<_> = reg.by_domain(Domain::Biomedical).collect();
        assert_eq!(bio.len(), 1);
        assert_eq!(bio[0].id(), "biomedical_ner");
        let german: Vec<_> = reg.by_language("de").collect();
        assert_eq!(german.len(), 1);
        assert_eq!(german[0].id(), "news_coref");
        let hist: Vec<_> = reg.historical().collect();
        assert_eq!(hist.len(), 1);
        assert_eq!(hist[0].id(), "sanskrit_edl");
        assert_eq!(reg.with_entity_type("DISEASE").count(), 1);
    }

    #[test]
    fn test_registry_summary() {
        let mut reg = DatasetRegistry::new();
        reg.register(CustomDataset::new("a", Task::NER).with_languages(&["en"]));
        reg.register(CustomDataset::new("b", Task::NER).with_languages(&["de"]));
        reg.register(CustomDataset::new("c", Task::IntraDocCoref).with_languages(&["en"]));
        let summary = reg.summary();
        assert_eq!(summary.total, 3);
        assert_eq!(summary.by_task.get(&Task::NER), Some(&2));
        assert_eq!(summary.by_task.get(&Task::IntraDocCoref), Some(&1));
        assert!(summary.languages.contains(&"en".to_string()));
        assert!(summary.languages.contains(&"de".to_string()));
    }

    #[test]
    fn test_historical_custom_dataset_smoke() {
        let ds = CustomDataset::new("historical_edl", Task::NED)
            .with_name("Historical EDL (example)")
            .with_languages(&["sa"])
            .with_entity_types(&["Person", "Location"])
            .with_parser(ParserHint::CoNLLU)
            .with_license(License::CCBY)
            .with_domain(Domain::Literary)
            .with_secondary_tasks(vec![Task::IntraDocCoref, Task::NER])
            .with_stats(DatasetStats {
                doc_count: Some(10),
                mention_count: Some(100),
                ..Default::default()
            })
            .with_citation("Example citation")
            .historical();
        assert_eq!(ds.task(), Task::NED);
        assert!(ds.supports_language("sa"));
        assert!(ds.is_historical());
        assert!(ds.is_public());
    }

    #[test]
    fn test_domain_display() {
        assert_eq!(Domain::Biomedical.to_string(), "Biomedical");
        assert_eq!(Domain::Literary.to_string(), "Literary");
        assert_eq!(Domain::Other("custom".into()).to_string(), "custom");
    }

    #[test]
    fn test_license_display() {
        assert_eq!(License::CCBY.to_string(), "CC BY 4.0");
        assert_eq!(License::MIT.to_string(), "MIT");
        assert_eq!(License::LDC.to_string(), "LDC");
    }

    #[test]
    fn test_temporal_coverage() {
        let cov = TemporalCoverage {
            start_year: Some(2010),
            end_year: Some(2020),
            has_temporal_annotations: true,
            has_diachronic_entities: false,
        };
        assert_eq!(cov.start_year, Some(2010));
        assert!(cov.has_temporal_annotations);
    }

    #[test]
    fn test_split_sizes() {
        let splits = SplitSizes {
            train: 1000,
            dev: 100,
            test: 200,
        };
        assert_eq!(splits.train + splits.dev + splits.test, 1300);
    }
}