1use serde::{Deserialize, Serialize};
116use std::fmt;
117use std::hash::Hash;
118
/// NLP task a dataset is annotated for.
///
/// Parsed from short codes or aliases via `FromStr`; `Task::code` returns the canonical
/// short code and `Display` a human-readable name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Task {
    /// Named entity recognition (aliases: `ner`, `named_entity_recognition`, `sequence_labeling`).
    NER,
    /// Coreference resolution within a single document (code `coref`).
    IntraDocCoref,
    /// Cross-document coreference resolution (code `cdcr`, alias `cross_doc_coref`).
    InterDocCoref,
    /// Named entity disambiguation / entity linking (code `el`, alias `ned`).
    NED,
    /// Relation extraction (code `re`).
    RelationExtraction,
    /// Event extraction (code `event`).
    EventExtraction,
    /// NER with discontinuous spans (code `discontinuous`).
    ///
    /// NOTE(review): `FromStr` also maps `nested`/`nested_ner` here.
    DiscontinuousNER,
    /// NER on visual documents (code `visual`, alias `document_ner`).
    VisualNER,
    /// Temporal expression recognition (code `temporal`, alias `timex`).
    TemporalNER,
    /// Aspect extraction (code `aspect`, alias `absa`).
    AspectExtraction,
    /// Slot filling (code `slot`; `FromStr` also accepts `intent`).
    SlotFilling,
    /// Part-of-speech tagging (code `pos`).
    POS,
    /// Dependency parsing (code `dep`).
    DependencyParsing,
}
157
158impl Task {
159 #[must_use]
161 pub const fn produces_entities(&self) -> bool {
162 matches!(
163 self,
164 Self::NER
165 | Self::DiscontinuousNER
166 | Self::VisualNER
167 | Self::TemporalNER
168 | Self::AspectExtraction
169 | Self::SlotFilling
170 )
171 }
172
173 #[must_use]
175 pub const fn involves_coreference(&self) -> bool {
176 matches!(self, Self::IntraDocCoref | Self::InterDocCoref)
177 }
178
179 #[must_use]
181 pub const fn involves_kb_linking(&self) -> bool {
182 matches!(self, Self::NED)
183 }
184
185 #[must_use]
187 pub const fn involves_relations(&self) -> bool {
188 matches!(self, Self::RelationExtraction | Self::EventExtraction)
189 }
190}
191
192impl fmt::Display for Task {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 match self {
195 Self::NER => write!(f, "NER"),
196 Self::IntraDocCoref => write!(f, "Intra-Doc Coreference"),
197 Self::InterDocCoref => write!(f, "Inter-Doc Coreference"),
198 Self::NED => write!(f, "Named Entity Disambiguation"),
199 Self::RelationExtraction => write!(f, "Relation Extraction"),
200 Self::EventExtraction => write!(f, "Event Extraction"),
201 Self::DiscontinuousNER => write!(f, "Discontinuous NER"),
202 Self::VisualNER => write!(f, "Visual NER"),
203 Self::TemporalNER => write!(f, "Temporal NER"),
204 Self::AspectExtraction => write!(f, "Aspect Extraction"),
205 Self::SlotFilling => write!(f, "Slot Filling"),
206 Self::POS => write!(f, "POS Tagging"),
207 Self::DependencyParsing => write!(f, "Dependency Parsing"),
208 }
209 }
210}
211
212impl std::str::FromStr for Task {
213 type Err = String;
214
215 fn from_str(s: &str) -> Result<Self, Self::Err> {
228 match s.to_lowercase().as_str() {
229 "ner" | "named_entity_recognition" | "sequence_labeling" => Ok(Self::NER),
230 "coref" | "coreference" | "intra_doc_coref" | "intradoccoref" => {
231 Ok(Self::IntraDocCoref)
232 }
233 "cdcr" | "inter_doc_coref" | "interdoccoref" | "cross_doc_coref" => {
234 Ok(Self::InterDocCoref)
235 }
236 "ned" | "el" | "entity_linking" | "disambiguation" => Ok(Self::NED),
237 "re" | "relation_extraction" | "relations" => Ok(Self::RelationExtraction),
238 "event" | "event_extraction" | "events" => Ok(Self::EventExtraction),
239 "discontinuous" | "discontinuous_ner" | "nested" | "nested_ner" => {
240 Ok(Self::DiscontinuousNER)
241 }
242 "visual" | "visual_ner" | "document_ner" => Ok(Self::VisualNER),
243 "temporal" | "temporal_ner" | "timex" => Ok(Self::TemporalNER),
244 "aspect" | "aspect_extraction" | "absa" => Ok(Self::AspectExtraction),
245 "slot" | "slot_filling" | "intent" => Ok(Self::SlotFilling),
246 "pos" | "pos_tagging" | "part_of_speech" => Ok(Self::POS),
247 "dep" | "dependency" | "dependency_parsing" => Ok(Self::DependencyParsing),
248 _ => Err(format!(
249 "Unknown task: '{}'. Valid: ner, coref, ned, re, event, ...",
250 s
251 )),
252 }
253 }
254}
255
256impl Task {
257 pub const ALL: &'static [Task] = &[
259 Task::NER,
260 Task::IntraDocCoref,
261 Task::InterDocCoref,
262 Task::NED,
263 Task::RelationExtraction,
264 Task::EventExtraction,
265 Task::DiscontinuousNER,
266 Task::VisualNER,
267 Task::TemporalNER,
268 Task::AspectExtraction,
269 Task::SlotFilling,
270 Task::POS,
271 Task::DependencyParsing,
272 ];
273
274 #[must_use]
276 pub const fn code(&self) -> &'static str {
277 match self {
278 Self::NER => "ner",
279 Self::IntraDocCoref => "coref",
280 Self::InterDocCoref => "cdcr",
281 Self::NED => "el",
282 Self::RelationExtraction => "re",
283 Self::EventExtraction => "event",
284 Self::DiscontinuousNER => "discontinuous",
285 Self::VisualNER => "visual",
286 Self::TemporalNER => "temporal",
287 Self::AspectExtraction => "aspect",
288 Self::SlotFilling => "slot",
289 Self::POS => "pos",
290 Self::DependencyParsing => "dep",
291 }
292 }
293}
294
/// On-disk annotation format a dataset is distributed in, used to pick a parser.
///
/// Defaults to [`ParserHint::CoNLL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ParserHint {
    /// CoNLL-style column format (typical extensions `conll`, `txt`, `bio`); the default.
    #[default]
    CoNLL,
    /// CoNLL-U format (`conllu`).
    CoNLLU,
    /// A single JSON document (`json`).
    JSON,
    /// JSON Lines / newline-delimited JSON (`jsonl`, `ndjson`).
    JSONL,
    /// JSON fetched from the Hugging Face API (NOTE(review): exact schema not visible here).
    HuggingFaceAPI,
    /// brat standoff annotations (`ann` alongside `txt`).
    BRAT,
    /// Generic XML/SGML.
    XML,
    /// ACE corpus XML/SGML.
    ACE,
    /// OntoNotes files (`onf`, `name`).
    OntoNotes,
    /// Dataset-specific format with no typical extension.
    Custom,
}
328
329impl ParserHint {
330 #[must_use]
332 pub const fn typical_extensions(&self) -> &'static [&'static str] {
333 match self {
334 Self::CoNLL => &["conll", "txt", "bio"],
335 Self::CoNLLU => &["conllu"],
336 Self::JSON => &["json"],
337 Self::JSONL => &["jsonl", "ndjson"],
338 Self::HuggingFaceAPI => &["json"],
339 Self::BRAT => &["ann", "txt"],
340 Self::XML | Self::ACE => &["xml", "sgml"],
341 Self::OntoNotes => &["onf", "name"],
342 Self::Custom => &[],
343 }
344 }
345}
346
/// License under which a dataset is distributed.
///
/// Defaults to [`License::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum License {
    /// Creative Commons Attribution (displayed as "CC BY 4.0").
    CCBY,
    /// Creative Commons Attribution-ShareAlike ("CC BY-SA 4.0").
    CCBYSA,
    /// Creative Commons Attribution-NonCommercial ("CC BY-NC 4.0").
    CCBYNC,
    /// Creative Commons Attribution-NonCommercial-ShareAlike ("CC BY-NC-SA 4.0").
    CCBYNCSA,
    /// CC0 public-domain dedication.
    CC0,
    /// MIT license.
    MIT,
    /// Apache License 2.0.
    Apache2,
    /// GNU General Public License (version not specified).
    GPL,
    /// Linguistic Data Consortium license.
    LDC,
    /// Research-only terms.
    ResearchOnly,
    /// Proprietary terms.
    Proprietary,
    /// License not known; the default.
    #[default]
    Unknown,
    /// Any other license, identified by free-form text (displayed verbatim).
    Other(String),
}
386
387impl License {
388 #[must_use]
390 pub fn allows_commercial(&self) -> bool {
391 matches!(
392 self,
393 Self::CCBY | Self::CCBYSA | Self::CC0 | Self::MIT | Self::Apache2
394 )
395 }
396
397 #[must_use]
399 pub fn allows_redistribution(&self) -> bool {
400 !matches!(self, Self::LDC | Self::Proprietary | Self::ResearchOnly)
401 }
402}
403
404impl fmt::Display for License {
405 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406 match self {
407 Self::CCBY => write!(f, "CC BY 4.0"),
408 Self::CCBYSA => write!(f, "CC BY-SA 4.0"),
409 Self::CCBYNC => write!(f, "CC BY-NC 4.0"),
410 Self::CCBYNCSA => write!(f, "CC BY-NC-SA 4.0"),
411 Self::CC0 => write!(f, "CC0 (Public Domain)"),
412 Self::MIT => write!(f, "MIT"),
413 Self::Apache2 => write!(f, "Apache 2.0"),
414 Self::GPL => write!(f, "GPL"),
415 Self::LDC => write!(f, "LDC"),
416 Self::ResearchOnly => write!(f, "Research Only"),
417 Self::Proprietary => write!(f, "Proprietary"),
418 Self::Unknown => write!(f, "Unknown"),
419 Self::Other(s) => write!(f, "{s}"),
420 }
421 }
422}
423
/// Text domain / genre a dataset is drawn from.
///
/// Defaults to [`Domain::Mixed`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum Domain {
    /// News text.
    News,
    /// Biomedical text.
    Biomedical,
    /// Scientific text.
    Scientific,
    /// Legal text.
    Legal,
    /// Financial text.
    Financial,
    /// Social-media text (displayed as "Social Media").
    SocialMedia,
    /// Wikipedia text.
    Wikipedia,
    /// Literary text.
    Literary,
    /// Historical text.
    Historical,
    /// Dialogue / conversational text.
    Dialogue,
    /// Technical text.
    Technical,
    /// General web text.
    Web,
    /// Cybersecurity text.
    Cybersecurity,
    /// Music-related text.
    Music,
    /// Mixed or unspecified domain; the default.
    #[default]
    Mixed,
    /// Any other domain, identified by free-form text (displayed verbatim).
    Other(String),
}
468
469impl fmt::Display for Domain {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 match self {
472 Self::News => write!(f, "News"),
473 Self::Biomedical => write!(f, "Biomedical"),
474 Self::Scientific => write!(f, "Scientific"),
475 Self::Legal => write!(f, "Legal"),
476 Self::Financial => write!(f, "Financial"),
477 Self::SocialMedia => write!(f, "Social Media"),
478 Self::Wikipedia => write!(f, "Wikipedia"),
479 Self::Literary => write!(f, "Literary"),
480 Self::Historical => write!(f, "Historical"),
481 Self::Dialogue => write!(f, "Dialogue"),
482 Self::Technical => write!(f, "Technical"),
483 Self::Web => write!(f, "Web"),
484 Self::Cybersecurity => write!(f, "Cybersecurity"),
485 Self::Music => write!(f, "Music"),
486 Self::Mixed => write!(f, "Mixed"),
487 Self::Other(s) => write!(f, "{s}"),
488 }
489 }
490}
491
/// Time span and temporal-annotation flags declared for a dataset.
///
/// The `Default` leaves both years `None` and both flags `false`.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct TemporalCoverage {
    /// Earliest year covered, if declared.
    pub start_year: Option<i32>,
    /// Latest year covered, if declared.
    pub end_year: Option<i32>,
    /// Whether the dataset is declared to carry temporal annotations.
    pub has_temporal_annotations: bool,
    /// Whether the dataset is declared to contain diachronic entities.
    ///
    /// NOTE(review): precise meaning is not established in this file. Presumably it
    /// refers to entities whose attributes change over time; confirm.
    pub has_diachronic_entities: bool,
}
511
/// Optional size statistics for a dataset.
///
/// Every field is `None` in the `Default`, i.e. when not provided.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct DatasetStats {
    /// Number of documents.
    pub doc_count: Option<usize>,
    /// Number of annotated mentions.
    pub mention_count: Option<usize>,
    /// Number of distinct entities.
    ///
    /// NOTE(review): presumably distinct from `mention_count` (e.g. after coreference or
    /// linking); confirm.
    pub entity_count: Option<usize>,
    /// Number of tokens.
    pub token_count: Option<usize>,
    /// Sizes of the train/dev/test splits, if known.
    pub split_sizes: Option<SplitSizes>,
}
530
/// Sizes of the standard train/dev/test splits.
///
/// NOTE(review): the unit (documents, sentences, ...) is not specified here. Confirm
/// against the data source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitSizes {
    /// Size of the training split.
    pub train: usize,
    /// Size of the development/validation split.
    pub dev: usize,
    /// Size of the test split.
    pub test: usize,
}
541
542pub trait DatasetSpec: Send + Sync {
577 fn name(&self) -> &str;
583
584 fn id(&self) -> &str;
586
587 fn task(&self) -> Task;
589
590 fn languages(&self) -> &[&str];
594
595 fn entity_types(&self) -> &[&str];
600
601 fn parser_hint(&self) -> ParserHint;
603
604 fn license(&self) -> License;
606
607 fn description(&self) -> Option<&str> {
613 None
614 }
615
616 fn domain(&self) -> Domain {
618 Domain::Mixed
619 }
620
621 fn download_url(&self) -> Option<&str> {
623 None
624 }
625
626 fn citation(&self) -> Option<&str> {
628 None
629 }
630
631 fn doi(&self) -> Option<&str> {
633 None
634 }
635
636 fn local_path(&self) -> Option<&std::path::Path> {
638 None
639 }
640
641 fn stats(&self) -> DatasetStats {
643 DatasetStats::default()
644 }
645
646 fn temporal_coverage(&self) -> TemporalCoverage {
648 TemporalCoverage::default()
649 }
650
651 fn secondary_tasks(&self) -> &[Task] {
653 &[]
654 }
655
656 fn is_constructed_language(&self) -> bool {
658 false
659 }
660
661 fn is_historical(&self) -> bool {
663 false
664 }
665
666 fn requires_auth(&self) -> bool {
668 false
669 }
670
671 fn version(&self) -> Option<&str> {
673 None
674 }
675
676 fn notes(&self) -> Option<&str> {
678 None
679 }
680
681 fn languages_vec(&self) -> Vec<String> {
689 self.languages().iter().map(|s| (*s).to_string()).collect()
690 }
691
692 fn entity_types_vec(&self) -> Vec<String> {
696 self.entity_types()
697 .iter()
698 .map(|s| (*s).to_string())
699 .collect()
700 }
701
702 fn is_public(&self) -> bool {
708 self.license().allows_redistribution() && !self.requires_auth()
709 }
710
711 fn supports_task(&self, task: Task) -> bool {
713 self.task() == task || self.secondary_tasks().contains(&task)
714 }
715
716 fn supports_language(&self, lang: &str) -> bool {
718 let langs = self.languages_vec();
719 langs.iter().any(|l| l == "multilingual" || l == lang)
720 }
721
722 fn has_entity_type(&self, entity_type: &str) -> bool {
724 self.entity_types_vec()
725 .iter()
726 .any(|t| t.eq_ignore_ascii_case(entity_type))
727 }
728}
729
/// A [`DatasetSpec`] whose metadata is supplied at runtime through builder methods.
///
/// Created with `CustomDataset::new(id, task)` and refined with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct CustomDataset {
    // Identifier, used as the registry key.
    id: String,
    // Human-readable name; `new` initializes it to `id`.
    name: String,
    // Primary task.
    task: Task,
    // Language codes; `new` initializes this to `["en"]`.
    languages: Vec<String>,
    // Entity type labels.
    entity_types: Vec<String>,
    // Suggested input format.
    parser_hint: ParserHint,
    // Distribution license.
    license: License,
    // Optional free-text description.
    description: Option<String>,
    // Text domain.
    domain: Domain,
    // Optional download URL.
    download_url: Option<String>,
    // Optional local path to the data.
    local_path: Option<std::path::PathBuf>,
    // Size statistics.
    stats: DatasetStats,
    // Temporal coverage metadata.
    temporal_coverage: TemporalCoverage,
    // Additional supported tasks.
    secondary_tasks: Vec<Task>,
    // Set by `constructed()`; reported via `is_constructed_language`.
    is_constructed: bool,
    // Set by `historical()`.
    is_historical: bool,
    // Set by `requires_authentication()`.
    requires_auth: bool,
    // Optional version string.
    version: Option<String>,
    // Optional free-text notes.
    notes: Option<String>,
    // Optional citation text.
    citation: Option<String>,
}
777
778impl CustomDataset {
779 #[must_use]
781 pub fn new(id: impl Into<String>, task: Task) -> Self {
782 let id = id.into();
783 Self {
784 name: id.clone(),
785 id,
786 task,
787 languages: vec!["en".to_string()],
788 entity_types: vec![],
789 parser_hint: ParserHint::CoNLL,
790 license: License::Unknown,
791 description: None,
792 domain: Domain::Mixed,
793 download_url: None,
794 local_path: None,
795 stats: DatasetStats::default(),
796 temporal_coverage: TemporalCoverage::default(),
797 secondary_tasks: vec![],
798 is_constructed: false,
799 is_historical: false,
800 requires_auth: false,
801 version: None,
802 notes: None,
803 citation: None,
804 }
805 }
806
807 #[must_use]
809 pub fn with_name(mut self, name: impl Into<String>) -> Self {
810 self.name = name.into();
811 self
812 }
813
814 #[must_use]
816 pub fn with_languages(mut self, langs: &[&str]) -> Self {
817 self.languages = langs.iter().map(|s| (*s).to_string()).collect();
818 self
819 }
820
821 #[must_use]
823 pub fn with_entity_types(mut self, types: &[&str]) -> Self {
824 self.entity_types = types.iter().map(|s| (*s).to_string()).collect();
825 self
826 }
827
828 #[must_use]
830 pub fn with_parser(mut self, parser: ParserHint) -> Self {
831 self.parser_hint = parser;
832 self
833 }
834
835 #[must_use]
837 pub fn with_license(mut self, license: License) -> Self {
838 self.license = license;
839 self
840 }
841
842 #[must_use]
844 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
845 self.description = Some(desc.into());
846 self
847 }
848
849 #[must_use]
851 pub fn with_domain(mut self, domain: Domain) -> Self {
852 self.domain = domain;
853 self
854 }
855
856 #[must_use]
858 pub fn with_url(mut self, url: impl Into<String>) -> Self {
859 self.download_url = Some(url.into());
860 self
861 }
862
863 #[must_use]
865 pub fn with_path(mut self, path: std::path::PathBuf) -> Self {
866 self.local_path = Some(path);
867 self
868 }
869
870 #[must_use]
872 pub fn with_stats(mut self, stats: DatasetStats) -> Self {
873 self.stats = stats;
874 self
875 }
876
877 #[must_use]
879 pub fn with_temporal_coverage(mut self, coverage: TemporalCoverage) -> Self {
880 self.temporal_coverage = coverage;
881 self
882 }
883
884 #[must_use]
886 pub fn with_secondary_tasks(mut self, tasks: Vec<Task>) -> Self {
887 self.secondary_tasks = tasks;
888 self
889 }
890
891 #[must_use]
893 pub fn constructed(mut self) -> Self {
894 self.is_constructed = true;
895 self
896 }
897
898 #[must_use]
900 pub fn historical(mut self) -> Self {
901 self.is_historical = true;
902 self
903 }
904
905 #[must_use]
907 pub fn requires_authentication(mut self) -> Self {
908 self.requires_auth = true;
909 self
910 }
911
912 #[must_use]
914 pub fn with_version(mut self, version: impl Into<String>) -> Self {
915 self.version = Some(version.into());
916 self
917 }
918
919 #[must_use]
921 pub fn languages_owned(&self) -> &[String] {
922 &self.languages
923 }
924
925 #[must_use]
927 pub fn entity_types_owned(&self) -> &[String] {
928 &self.entity_types
929 }
930
931 #[must_use]
933 pub fn with_notes(mut self, notes: impl Into<String>) -> Self {
934 self.notes = Some(notes.into());
935 self
936 }
937
938 #[must_use]
940 pub fn with_citation(mut self, citation: impl Into<String>) -> Self {
941 self.citation = Some(citation.into());
942 self
943 }
944}
945
946impl DatasetSpec for CustomDataset {
947 fn name(&self) -> &str {
948 &self.name
949 }
950
951 fn id(&self) -> &str {
952 &self.id
953 }
954
955 fn task(&self) -> Task {
956 self.task
957 }
958
959 fn languages(&self) -> &[&str] {
960 static EMPTY: &[&str] = &[];
963 EMPTY
964 }
965
966 fn entity_types(&self) -> &[&str] {
967 static EMPTY: &[&str] = &[];
970 EMPTY
971 }
972
973 fn parser_hint(&self) -> ParserHint {
974 self.parser_hint
975 }
976
977 fn license(&self) -> License {
978 self.license.clone()
979 }
980
981 fn description(&self) -> Option<&str> {
982 self.description.as_deref()
983 }
984
985 fn domain(&self) -> Domain {
986 self.domain.clone()
987 }
988
989 fn download_url(&self) -> Option<&str> {
990 self.download_url.as_deref()
991 }
992
993 fn local_path(&self) -> Option<&std::path::Path> {
994 self.local_path.as_deref()
995 }
996
997 fn stats(&self) -> DatasetStats {
998 self.stats.clone()
999 }
1000
1001 fn temporal_coverage(&self) -> TemporalCoverage {
1002 self.temporal_coverage.clone()
1003 }
1004
1005 fn secondary_tasks(&self) -> &[Task] {
1006 &self.secondary_tasks
1007 }
1008
1009 fn is_constructed_language(&self) -> bool {
1010 self.is_constructed
1011 }
1012
1013 fn is_historical(&self) -> bool {
1014 self.is_historical
1015 }
1016
1017 fn requires_auth(&self) -> bool {
1018 self.requires_auth
1019 }
1020
1021 fn version(&self) -> Option<&str> {
1022 self.version.as_deref()
1023 }
1024
1025 fn notes(&self) -> Option<&str> {
1026 self.notes.as_deref()
1027 }
1028
1029 fn citation(&self) -> Option<&str> {
1030 self.citation.as_deref()
1031 }
1032
1033 fn languages_vec(&self) -> Vec<String> {
1035 self.languages.clone()
1036 }
1037
1038 fn entity_types_vec(&self) -> Vec<String> {
1039 self.entity_types.clone()
1040 }
1041}
1042
/// Collection of heterogeneous [`DatasetSpec`] implementations keyed by their
/// [`DatasetSpec::id`].
#[derive(Default)]
pub struct DatasetRegistry {
    // Boxed trait objects so different `DatasetSpec` types can share one map.
    // HashMap: iteration order over datasets is unspecified.
    datasets: std::collections::HashMap<String, Box<dyn DatasetSpec>>,
}
1055
1056impl DatasetRegistry {
1057 #[must_use]
1059 pub fn new() -> Self {
1060 Self::default()
1061 }
1062
1063 pub fn register(
1067 &mut self,
1068 dataset: impl DatasetSpec + 'static,
1069 ) -> Option<Box<dyn DatasetSpec>> {
1070 let id = dataset.id().to_string();
1071 self.datasets.insert(id, Box::new(dataset))
1072 }
1073
1074 #[must_use]
1076 pub fn get(&self, id: &str) -> Option<&dyn DatasetSpec> {
1077 self.datasets.get(id).map(|b| &**b)
1078 }
1079
1080 pub fn unregister(&mut self, id: &str) -> Option<Box<dyn DatasetSpec>> {
1082 self.datasets.remove(id)
1083 }
1084
1085 #[must_use]
1087 pub fn list_ids(&self) -> Vec<&str> {
1088 self.datasets.keys().map(|s| s.as_str()).collect()
1089 }
1090
1091 #[must_use]
1093 pub fn len(&self) -> usize {
1094 self.datasets.len()
1095 }
1096
1097 #[must_use]
1099 pub fn is_empty(&self) -> bool {
1100 self.datasets.is_empty()
1101 }
1102
1103 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn DatasetSpec)> {
1105 self.datasets.iter().map(|(k, v)| (k.as_str(), &**v))
1106 }
1107
1108 pub fn by_task(&self, task: Task) -> impl Iterator<Item = &dyn DatasetSpec> {
1110 self.datasets
1111 .values()
1112 .filter(move |d| d.supports_task(task))
1113 .map(|b| &**b)
1114 }
1115
1116 pub fn by_language<'a>(&'a self, lang: &'a str) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1118 self.datasets
1119 .values()
1120 .filter(move |d| d.supports_language(lang))
1121 .map(|b| &**b)
1122 }
1123
1124 pub fn by_domain(&self, domain: Domain) -> impl Iterator<Item = &dyn DatasetSpec> {
1126 self.datasets
1127 .values()
1128 .filter(move |d| d.domain() == domain)
1129 .map(|b| &**b)
1130 }
1131
1132 pub fn public_only(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1134 self.datasets
1135 .values()
1136 .filter(|d| d.is_public())
1137 .map(|b| &**b)
1138 }
1139
1140 pub fn historical(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1142 self.datasets
1143 .values()
1144 .filter(|d| d.is_historical())
1145 .map(|b| &**b)
1146 }
1147
1148 pub fn with_entity_type<'a>(
1150 &'a self,
1151 entity_type: &'a str,
1152 ) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1153 self.datasets
1154 .values()
1155 .filter(move |d| d.has_entity_type(entity_type))
1156 .map(|b| &**b)
1157 }
1158
1159 #[must_use]
1161 pub fn summary(&self) -> RegistrySummary {
1162 let mut tasks = std::collections::HashMap::new();
1163 let mut domains = std::collections::HashMap::new();
1164 let mut languages = std::collections::HashSet::new();
1165
1166 for ds in self.datasets.values() {
1167 *tasks.entry(ds.task()).or_insert(0) += 1;
1168 *domains.entry(ds.domain()).or_insert(0) += 1;
1169 for lang in ds.languages_vec() {
1170 languages.insert(lang);
1171 }
1172 }
1173
1174 RegistrySummary {
1175 total: self.datasets.len(),
1176 by_task: tasks,
1177 by_domain: domains,
1178 languages: languages.into_iter().collect(),
1179 }
1180 }
1181}
1182
/// Aggregate statistics produced by `DatasetRegistry::summary`.
#[derive(Debug, Clone)]
pub struct RegistrySummary {
    /// Total number of registered datasets.
    pub total: usize,
    /// Number of datasets per primary task (`DatasetSpec::task`); secondary tasks are
    /// not counted.
    pub by_task: std::collections::HashMap<Task, usize>,
    /// Number of datasets per domain.
    pub by_domain: std::collections::HashMap<Domain, usize>,
    /// Distinct language codes across all registered datasets.
    pub languages: Vec<String>,
}
1195
1196impl fmt::Debug for DatasetRegistry {
1197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1198 f.debug_struct("DatasetRegistry")
1199 .field("count", &self.datasets.len())
1200 .field("ids", &self.list_ids())
1201 .finish()
1202 }
1203}
1204
// Unit tests for the dataset metadata types, `CustomDataset`, and `DatasetRegistry`.
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters are reflected through the accessors; labels match case-insensitively.
    #[test]
    fn test_custom_dataset_creation() {
        let dataset = CustomDataset::new("test_ner", Task::NER)
            .with_name("Test NER Dataset")
            .with_languages(&["en", "de"])
            .with_entity_types(&["PER", "LOC", "ORG"])
            .with_license(License::MIT)
            .with_domain(Domain::News);

        assert_eq!(dataset.id(), "test_ner");
        assert_eq!(dataset.name(), "Test NER Dataset");
        assert_eq!(dataset.task(), Task::NER);
        // `languages()` is empty for CustomDataset, so inspect the owned storage.
        assert!(dataset.languages_owned().contains(&"en".to_string()));
        assert!(dataset.languages_owned().contains(&"de".to_string()));
        assert!(!dataset.languages_owned().contains(&"fr".to_string()));
        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("PER")));
        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("per")));
        // MIT allows redistribution and no auth is required, so the dataset is public.
        assert!(dataset.is_public());
    }

    // Registration, lookup by id, and filtering by primary task.
    #[test]
    fn test_registry() {
        let mut registry = DatasetRegistry::new();

        let dataset1 = CustomDataset::new("ds1", Task::NER)
            .with_name("Dataset 1")
            .with_languages(&["en"]);

        let dataset2 = CustomDataset::new("ds2", Task::IntraDocCoref)
            .with_name("Dataset 2")
            .with_languages(&["de"]);

        registry.register(dataset1);
        registry.register(dataset2);

        assert_eq!(registry.len(), 2);
        assert!(registry.get("ds1").is_some());
        assert!(registry.get("ds2").is_some());
        assert!(registry.get("ds3").is_none());

        let ner_datasets: Vec<_> = registry.by_task(Task::NER).collect();
        assert_eq!(ner_datasets.len(), 1);
        assert_eq!(ner_datasets[0].id(), "ds1");
    }

    // Task capability predicates.
    #[test]
    fn test_task_properties() {
        assert!(Task::NER.produces_entities());
        assert!(!Task::IntraDocCoref.produces_entities());
        assert!(Task::IntraDocCoref.involves_coreference());
        assert!(Task::InterDocCoref.involves_coreference());
        assert!(!Task::NER.involves_coreference());
        assert!(Task::NED.involves_kb_linking());
        assert!(Task::RelationExtraction.involves_relations());
    }

    // License commercial-use and redistribution flags.
    #[test]
    fn test_license_properties() {
        assert!(License::MIT.allows_commercial());
        assert!(License::MIT.allows_redistribution());
        assert!(!License::LDC.allows_redistribution());
        assert!(!License::ResearchOnly.allows_commercial());
    }

    // Typical file extensions per parser hint.
    #[test]
    fn test_parser_extensions() {
        assert!(ParserHint::CoNLL.typical_extensions().contains(&"conll"));
        assert!(ParserHint::JSONL.typical_extensions().contains(&"jsonl"));
    }

    // Parsing tasks from codes/aliases is case-insensitive; unknown names are rejected.
    #[test]
    fn test_task_from_str() {
        assert_eq!("ner".parse::<Task>().expect("task parse"), Task::NER);
        assert_eq!("NER".parse::<Task>().expect("task parse"), Task::NER);
        assert_eq!(
            "coref".parse::<Task>().expect("task parse"),
            Task::IntraDocCoref
        );
        assert_eq!(
            "cdcr".parse::<Task>().expect("task parse"),
            Task::InterDocCoref
        );
        assert_eq!("el".parse::<Task>().expect("task parse"), Task::NED);
        assert_eq!(
            "entity_linking".parse::<Task>().expect("task parse"),
            Task::NED
        );
        assert_eq!(
            "re".parse::<Task>().expect("task parse"),
            Task::RelationExtraction
        );

        assert!("invalid_task".parse::<Task>().is_err());
    }

    // Canonical short codes.
    #[test]
    fn test_task_code() {
        assert_eq!(Task::NER.code(), "ner");
        assert_eq!(Task::IntraDocCoref.code(), "coref");
        assert_eq!(Task::NED.code(), "el");
        assert_eq!(Task::RelationExtraction.code(), "re");
    }

    // `Task::ALL` lists every variant.
    #[test]
    fn test_task_all_variants() {
        assert!(Task::ALL.contains(&Task::NER));
        assert!(Task::ALL.contains(&Task::IntraDocCoref));
        assert!(Task::ALL.contains(&Task::NED));
        assert_eq!(Task::ALL.len(), 13);
    }

    // Registry filters by domain, language, historical flag, and entity type.
    #[test]
    fn test_registry_filtering() {
        let mut registry = DatasetRegistry::new();

        registry.register(
            CustomDataset::new("biomedical_ner", Task::NER)
                .with_languages(&["en"])
                .with_domain(Domain::Biomedical)
                .with_entity_types(&["DISEASE", "DRUG"]),
        );
        registry.register(
            CustomDataset::new("news_coref", Task::IntraDocCoref)
                .with_languages(&["en", "de"])
                .with_domain(Domain::News),
        );
        registry.register(
            CustomDataset::new("sanskrit_edl", Task::NED)
                .with_languages(&["sa"])
                .with_domain(Domain::Literary)
                .historical(),
        );

        let bio: Vec<_> = registry.by_domain(Domain::Biomedical).collect();
        assert_eq!(bio.len(), 1);
        assert_eq!(bio[0].id(), "biomedical_ner");

        let german: Vec<_> = registry.by_language("de").collect();
        assert_eq!(german.len(), 1);
        assert_eq!(german[0].id(), "news_coref");

        let historical: Vec<_> = registry.historical().collect();
        assert_eq!(historical.len(), 1);
        assert_eq!(historical[0].id(), "sanskrit_edl");

        let disease: Vec<_> = registry.with_entity_type("DISEASE").collect();
        assert_eq!(disease.len(), 1);
    }

    // Summary counts per primary task and collects distinct languages.
    #[test]
    fn test_registry_summary() {
        let mut registry = DatasetRegistry::new();
        registry.register(CustomDataset::new("a", Task::NER).with_languages(&["en"]));
        registry.register(CustomDataset::new("b", Task::NER).with_languages(&["de"]));
        registry.register(CustomDataset::new("c", Task::IntraDocCoref).with_languages(&["en"]));

        let summary = registry.summary();
        assert_eq!(summary.total, 3);
        assert_eq!(summary.by_task.get(&Task::NER), Some(&2));
        assert_eq!(summary.by_task.get(&Task::IntraDocCoref), Some(&1));
        assert!(summary.languages.contains(&"en".to_string()));
        assert!(summary.languages.contains(&"de".to_string()));
    }

    // End-to-end builder usage with most optional metadata populated.
    #[test]
    fn test_historical_custom_dataset_smoke() {
        let ds = CustomDataset::new("historical_edl", Task::NED)
            .with_name("Historical EDL (example)")
            .with_languages(&["sa"])
            .with_entity_types(&["Person", "Location"])
            .with_parser(ParserHint::CoNLLU)
            .with_license(License::CCBY)
            .with_domain(Domain::Literary)
            .with_secondary_tasks(vec![Task::IntraDocCoref, Task::NER])
            .with_stats(DatasetStats {
                doc_count: Some(10),
                mention_count: Some(100),
                ..Default::default()
            })
            .with_citation("Example citation")
            .historical();

        assert_eq!(ds.task(), Task::NED);
        assert!(ds.supports_language("sa"));
        assert!(ds.is_historical());
        assert!(ds.is_public());
    }

    // Domain display names, including the free-form `Other` variant.
    #[test]
    fn test_domain_display() {
        assert_eq!(format!("{}", Domain::Biomedical), "Biomedical");
        assert_eq!(format!("{}", Domain::Literary), "Literary");
        assert_eq!(format!("{}", Domain::Other("custom".into())), "custom");
    }

    // License display names.
    #[test]
    fn test_license_display() {
        assert_eq!(format!("{}", License::CCBY), "CC BY 4.0");
        assert_eq!(format!("{}", License::MIT), "MIT");
        assert_eq!(format!("{}", License::LDC), "LDC");
    }

    // TemporalCoverage is a plain data struct.
    #[test]
    fn test_temporal_coverage() {
        let cov = TemporalCoverage {
            start_year: Some(2010),
            end_year: Some(2020),
            has_temporal_annotations: true,
            has_diachronic_entities: false,
        };

        assert_eq!(cov.start_year, Some(2010));
        assert!(cov.has_temporal_annotations);
    }

    // SplitSizes is a plain data struct.
    #[test]
    fn test_split_sizes() {
        let splits = SplitSizes {
            train: 1000,
            dev: 100,
            test: 200,
        };

        assert_eq!(splits.train + splits.dev + splits.test, 1300);
    }
}