1use serde::{Deserialize, Serialize};
116use std::fmt;
117use std::hash::Hash;
118
/// Annotation task a dataset is intended for.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions outside this crate
/// need a wildcard arm. The per-variant notes below quote the human-readable
/// label produced by `Display` and the short code returned by `Task::code`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Task {
    /// Displayed as "NER"; short code `"ner"`.
    NER,
    /// Displayed as "Intra-Doc Coreference"; short code `"coref"`.
    IntraDocCoref,
    /// Displayed as "Inter-Doc Coreference"; short code `"cdcr"`.
    InterDocCoref,
    /// Displayed as "Named Entity Disambiguation"; short code `"el"`.
    NED,
    /// Displayed as "Relation Extraction"; short code `"re"`.
    RelationExtraction,
    /// Displayed as "Event Extraction"; short code `"event"`.
    EventExtraction,
    /// Displayed as "Discontinuous NER"; short code `"discontinuous"`.
    DiscontinuousNER,
    /// Displayed as "Visual NER"; short code `"visual"`.
    VisualNER,
    /// Displayed as "Temporal NER"; short code `"temporal"`.
    TemporalNER,
    /// Displayed as "Aspect Extraction"; short code `"aspect"`.
    AspectExtraction,
    /// Displayed as "Slot Filling"; short code `"slot"`.
    SlotFilling,
    /// Displayed as "POS Tagging"; short code `"pos"`.
    POS,
    /// Displayed as "Dependency Parsing"; short code `"dep"`.
    DependencyParsing,
}
157
158impl Task {
159 #[must_use]
161 pub const fn produces_entities(&self) -> bool {
162 matches!(
163 self,
164 Self::NER
165 | Self::DiscontinuousNER
166 | Self::VisualNER
167 | Self::TemporalNER
168 | Self::AspectExtraction
169 | Self::SlotFilling
170 )
171 }
172
173 #[must_use]
175 pub const fn involves_coreference(&self) -> bool {
176 matches!(self, Self::IntraDocCoref | Self::InterDocCoref)
177 }
178
179 #[must_use]
181 pub const fn involves_kb_linking(&self) -> bool {
182 matches!(self, Self::NED)
183 }
184
185 #[must_use]
187 pub const fn involves_relations(&self) -> bool {
188 matches!(self, Self::RelationExtraction | Self::EventExtraction)
189 }
190}
191
192impl fmt::Display for Task {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 match self {
195 Self::NER => write!(f, "NER"),
196 Self::IntraDocCoref => write!(f, "Intra-Doc Coreference"),
197 Self::InterDocCoref => write!(f, "Inter-Doc Coreference"),
198 Self::NED => write!(f, "Named Entity Disambiguation"),
199 Self::RelationExtraction => write!(f, "Relation Extraction"),
200 Self::EventExtraction => write!(f, "Event Extraction"),
201 Self::DiscontinuousNER => write!(f, "Discontinuous NER"),
202 Self::VisualNER => write!(f, "Visual NER"),
203 Self::TemporalNER => write!(f, "Temporal NER"),
204 Self::AspectExtraction => write!(f, "Aspect Extraction"),
205 Self::SlotFilling => write!(f, "Slot Filling"),
206 Self::POS => write!(f, "POS Tagging"),
207 Self::DependencyParsing => write!(f, "Dependency Parsing"),
208 }
209 }
210}
211
212impl std::str::FromStr for Task {
213 type Err = String;
214
215 fn from_str(s: &str) -> Result<Self, Self::Err> {
228 match s.to_lowercase().as_str() {
229 "ner" | "named_entity_recognition" | "sequence_labeling" => Ok(Self::NER),
230 "coref" | "coreference" | "intra_doc_coref" | "intradoccoref" => {
231 Ok(Self::IntraDocCoref)
232 }
233 "cdcr" | "inter_doc_coref" | "interdoccoref" | "cross_doc_coref" => {
234 Ok(Self::InterDocCoref)
235 }
236 "ned" | "el" | "entity_linking" | "disambiguation" => Ok(Self::NED),
237 "re" | "relation_extraction" | "relations" => Ok(Self::RelationExtraction),
238 "event" | "event_extraction" | "events" => Ok(Self::EventExtraction),
239 "discontinuous" | "discontinuous_ner" | "nested" | "nested_ner" => {
240 Ok(Self::DiscontinuousNER)
241 }
242 "visual" | "visual_ner" | "document_ner" => Ok(Self::VisualNER),
243 "temporal" | "temporal_ner" | "timex" => Ok(Self::TemporalNER),
244 "aspect" | "aspect_extraction" | "absa" => Ok(Self::AspectExtraction),
245 "slot" | "slot_filling" | "intent" => Ok(Self::SlotFilling),
246 "pos" | "pos_tagging" | "part_of_speech" => Ok(Self::POS),
247 "dep" | "dependency" | "dependency_parsing" => Ok(Self::DependencyParsing),
248 _ => Err(format!(
249 "Unknown task: '{}'. Valid: ner, coref, ned, re, event, ...",
250 s
251 )),
252 }
253 }
254}
255
256impl Task {
257 pub const ALL: &'static [Task] = &[
259 Task::NER,
260 Task::IntraDocCoref,
261 Task::InterDocCoref,
262 Task::NED,
263 Task::RelationExtraction,
264 Task::EventExtraction,
265 Task::DiscontinuousNER,
266 Task::VisualNER,
267 Task::TemporalNER,
268 Task::AspectExtraction,
269 Task::SlotFilling,
270 Task::POS,
271 Task::DependencyParsing,
272 ];
273
274 #[must_use]
276 pub const fn code(&self) -> &'static str {
277 match self {
278 Self::NER => "ner",
279 Self::IntraDocCoref => "coref",
280 Self::InterDocCoref => "cdcr",
281 Self::NED => "el",
282 Self::RelationExtraction => "re",
283 Self::EventExtraction => "event",
284 Self::DiscontinuousNER => "discontinuous",
285 Self::VisualNER => "visual",
286 Self::TemporalNER => "temporal",
287 Self::AspectExtraction => "aspect",
288 Self::SlotFilling => "slot",
289 Self::POS => "pos",
290 Self::DependencyParsing => "dep",
291 }
292 }
293}
294
/// Hint describing which on-disk format a dataset uses.
///
/// `CoNLL` is the `Default`. The per-variant notes list the file extensions
/// reported by `ParserHint::typical_extensions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ParserHint {
    /// Default hint; extensions `conll`, `txt`, `bio`.
    #[default]
    CoNLL,
    /// Extension `conllu`.
    CoNLLU,
    /// Extension `json`.
    JSON,
    /// Extensions `jsonl`, `ndjson`.
    JSONL,
    /// Extension `json`.
    HuggingFaceAPI,
    /// Extensions `ann`, `txt`.
    BRAT,
    /// Extensions `xml`, `sgml`.
    XML,
    /// Extensions `xml`, `sgml` (same list as `XML`).
    ACE,
    /// Extensions `onf`, `name`.
    OntoNotes,
    /// No extensions are associated with this hint.
    Custom,
}
328
329impl ParserHint {
330 #[must_use]
332 pub const fn typical_extensions(&self) -> &'static [&'static str] {
333 match self {
334 Self::CoNLL => &["conll", "txt", "bio"],
335 Self::CoNLLU => &["conllu"],
336 Self::JSON => &["json"],
337 Self::JSONL => &["jsonl", "ndjson"],
338 Self::HuggingFaceAPI => &["json"],
339 Self::BRAT => &["ann", "txt"],
340 Self::XML | Self::ACE => &["xml", "sgml"],
341 Self::OntoNotes => &["onf", "name"],
342 Self::Custom => &[],
343 }
344 }
345}
346
/// License under which a dataset is distributed.
///
/// `Unknown` is the `Default`. The per-variant notes quote the label produced
/// by `Display`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum License {
    /// Displayed as "CC BY 4.0".
    CCBY,
    /// Displayed as "CC BY-SA 4.0".
    CCBYSA,
    /// Displayed as "CC BY-NC 4.0".
    CCBYNC,
    /// Displayed as "CC BY-NC-SA 4.0".
    CCBYNCSA,
    /// Displayed as "CC0 (Public Domain)".
    CC0,
    /// Displayed as "MIT".
    MIT,
    /// Displayed as "Apache 2.0".
    Apache2,
    /// Displayed as "GPL".
    GPL,
    /// Displayed as "LDC".
    LDC,
    /// Displayed as "Research Only".
    ResearchOnly,
    /// Displayed as "Proprietary".
    Proprietary,
    /// Default value; displayed as "Unknown".
    #[default]
    Unknown,
    /// Free-form license; displayed as the contained string verbatim.
    Other(String),
}
386
387impl License {
388 #[must_use]
390 pub fn allows_commercial(&self) -> bool {
391 matches!(
392 self,
393 Self::CCBY | Self::CCBYSA | Self::CC0 | Self::MIT | Self::Apache2
394 )
395 }
396
397 #[must_use]
399 pub fn allows_redistribution(&self) -> bool {
400 !matches!(self, Self::LDC | Self::Proprietary | Self::ResearchOnly)
401 }
402}
403
404impl fmt::Display for License {
405 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406 match self {
407 Self::CCBY => write!(f, "CC BY 4.0"),
408 Self::CCBYSA => write!(f, "CC BY-SA 4.0"),
409 Self::CCBYNC => write!(f, "CC BY-NC 4.0"),
410 Self::CCBYNCSA => write!(f, "CC BY-NC-SA 4.0"),
411 Self::CC0 => write!(f, "CC0 (Public Domain)"),
412 Self::MIT => write!(f, "MIT"),
413 Self::Apache2 => write!(f, "Apache 2.0"),
414 Self::GPL => write!(f, "GPL"),
415 Self::LDC => write!(f, "LDC"),
416 Self::ResearchOnly => write!(f, "Research Only"),
417 Self::Proprietary => write!(f, "Proprietary"),
418 Self::Unknown => write!(f, "Unknown"),
419 Self::Other(s) => write!(f, "{s}"),
420 }
421 }
422}
423
/// Subject-matter domain of a dataset.
///
/// `Mixed` is the `Default`. `Display` prints each variant's name, with
/// `SocialMedia` rendered as "Social Media" and `Other` rendered as its
/// contained string.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum Domain {
    /// Displayed as "News".
    News,
    /// Displayed as "Biomedical".
    Biomedical,
    /// Displayed as "Scientific".
    Scientific,
    /// Displayed as "Legal".
    Legal,
    /// Displayed as "Financial".
    Financial,
    /// Displayed as "Social Media".
    SocialMedia,
    /// Displayed as "Wikipedia".
    Wikipedia,
    /// Displayed as "Literary".
    Literary,
    /// Displayed as "Historical".
    Historical,
    /// Displayed as "Dialogue".
    Dialogue,
    /// Displayed as "Technical".
    Technical,
    /// Displayed as "Web".
    Web,
    /// Displayed as "Cybersecurity".
    Cybersecurity,
    /// Displayed as "Music".
    Music,
    /// Default value; displayed as "Mixed".
    #[default]
    Mixed,
    /// Free-form domain; displayed as the contained string verbatim.
    Other(String),
}
468
469impl fmt::Display for Domain {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 match self {
472 Self::News => write!(f, "News"),
473 Self::Biomedical => write!(f, "Biomedical"),
474 Self::Scientific => write!(f, "Scientific"),
475 Self::Legal => write!(f, "Legal"),
476 Self::Financial => write!(f, "Financial"),
477 Self::SocialMedia => write!(f, "Social Media"),
478 Self::Wikipedia => write!(f, "Wikipedia"),
479 Self::Literary => write!(f, "Literary"),
480 Self::Historical => write!(f, "Historical"),
481 Self::Dialogue => write!(f, "Dialogue"),
482 Self::Technical => write!(f, "Technical"),
483 Self::Web => write!(f, "Web"),
484 Self::Cybersecurity => write!(f, "Cybersecurity"),
485 Self::Music => write!(f, "Music"),
486 Self::Mixed => write!(f, "Mixed"),
487 Self::Other(s) => write!(f, "{s}"),
488 }
489 }
490}
491
/// Time-related metadata for a dataset.
///
/// `Default` yields no year bounds and both flags set to `false`.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct TemporalCoverage {
    /// First year of the covered period, or `None` when not recorded.
    pub start_year: Option<i32>,
    /// Last year of the covered period, or `None` when not recorded.
    pub end_year: Option<i32>,
    /// Flag for the presence of temporal annotations.
    pub has_temporal_annotations: bool,
    /// Flag for diachronic entities.
    // NOTE(review): the exact criterion is not defined in this file — confirm with data owners.
    pub has_diachronic_entities: bool,
}
511
/// Optional size statistics for a dataset.
///
/// Every field is `None` in the `Default` value; a field stays `None` when
/// the corresponding count has not been supplied.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct DatasetStats {
    /// Number of documents, when supplied.
    pub doc_count: Option<usize>,
    /// Number of mentions, when supplied.
    pub mention_count: Option<usize>,
    /// Number of entities, when supplied.
    pub entity_count: Option<usize>,
    /// Number of tokens, when supplied.
    pub token_count: Option<usize>,
    /// Train/dev/test sizes, when supplied.
    pub split_sizes: Option<SplitSizes>,
}
530
/// Sizes of the train, dev and test splits.
// NOTE(review): the counting unit (documents, sentences, ...) is not stated in this file — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitSizes {
    /// Size of the training split.
    pub train: usize,
    /// Size of the development split.
    pub dev: usize,
    /// Size of the test split.
    pub test: usize,
}
541
542pub trait DatasetSpec: Send + Sync {
576 fn name(&self) -> &str;
582
583 fn id(&self) -> &str;
585
586 fn task(&self) -> Task;
588
589 fn languages(&self) -> &[&str];
593
594 fn entity_types(&self) -> &[&str];
599
600 fn parser_hint(&self) -> ParserHint;
602
603 fn license(&self) -> License;
605
606 fn description(&self) -> Option<&str> {
612 None
613 }
614
615 fn domain(&self) -> Domain {
617 Domain::Mixed
618 }
619
620 fn download_url(&self) -> Option<&str> {
622 None
623 }
624
625 fn citation(&self) -> Option<&str> {
627 None
628 }
629
630 fn doi(&self) -> Option<&str> {
632 None
633 }
634
635 fn local_path(&self) -> Option<&std::path::Path> {
637 None
638 }
639
640 fn stats(&self) -> DatasetStats {
642 DatasetStats::default()
643 }
644
645 fn temporal_coverage(&self) -> TemporalCoverage {
647 TemporalCoverage::default()
648 }
649
650 fn secondary_tasks(&self) -> &[Task] {
652 &[]
653 }
654
655 fn is_constructed_language(&self) -> bool {
657 false
658 }
659
660 fn is_historical(&self) -> bool {
662 false
663 }
664
665 fn requires_auth(&self) -> bool {
667 false
668 }
669
670 fn version(&self) -> Option<&str> {
672 None
673 }
674
675 fn notes(&self) -> Option<&str> {
677 None
678 }
679
680 fn languages_vec(&self) -> Vec<String> {
688 self.languages().iter().map(|s| (*s).to_string()).collect()
689 }
690
691 fn entity_types_vec(&self) -> Vec<String> {
695 self.entity_types()
696 .iter()
697 .map(|s| (*s).to_string())
698 .collect()
699 }
700
701 fn is_public(&self) -> bool {
707 self.license().allows_redistribution() && !self.requires_auth()
708 }
709
710 fn supports_task(&self, task: Task) -> bool {
712 self.task() == task || self.secondary_tasks().contains(&task)
713 }
714
715 fn supports_language(&self, lang: &str) -> bool {
717 let langs = self.languages_vec();
718 langs.iter().any(|l| l == "multilingual" || l == lang)
719 }
720
721 fn has_entity_type(&self, entity_type: &str) -> bool {
723 self.entity_types_vec()
724 .iter()
725 .any(|t| t.eq_ignore_ascii_case(entity_type))
726 }
727}
728
/// Dataset description assembled at runtime through builder methods.
///
/// All fields are private; they are populated by `CustomDataset::new` and the
/// `with_*` / flag builders and read back through the `DatasetSpec`
/// implementation.
#[derive(Debug, Clone)]
pub struct CustomDataset {
    // Registry key; also the initial value of `name`.
    id: String,
    // Display name; `new` initialises it to a copy of `id`.
    name: String,
    // Primary task.
    task: Task,
    // Owned language identifiers; `new` starts with ["en"].
    languages: Vec<String>,
    // Owned entity type labels; empty until set.
    entity_types: Vec<String>,
    // Format hint; `new` uses `ParserHint::CoNLL`.
    parser_hint: ParserHint,
    // License; `new` uses `License::Unknown`.
    license: License,
    // Optional free-text description.
    description: Option<String>,
    // Subject domain; `new` uses `Domain::Mixed`.
    domain: Domain,
    // Optional download location.
    download_url: Option<String>,
    // Optional local filesystem path.
    local_path: Option<std::path::PathBuf>,
    // Size statistics; default (all `None`) until set.
    stats: DatasetStats,
    // Temporal metadata; default until set.
    temporal_coverage: TemporalCoverage,
    // Tasks supported in addition to `task`.
    secondary_tasks: Vec<Task>,
    // Set by `constructed()`.
    is_constructed: bool,
    // Set by `historical()`.
    is_historical: bool,
    // Set by `requires_authentication()`.
    requires_auth: bool,
    // Optional version string.
    version: Option<String>,
    // Optional free-text notes.
    notes: Option<String>,
    // Optional citation text.
    citation: Option<String>,
}
776
777impl CustomDataset {
778 #[must_use]
780 pub fn new(id: impl Into<String>, task: Task) -> Self {
781 let id = id.into();
782 Self {
783 name: id.clone(),
784 id,
785 task,
786 languages: vec!["en".to_string()],
787 entity_types: vec![],
788 parser_hint: ParserHint::CoNLL,
789 license: License::Unknown,
790 description: None,
791 domain: Domain::Mixed,
792 download_url: None,
793 local_path: None,
794 stats: DatasetStats::default(),
795 temporal_coverage: TemporalCoverage::default(),
796 secondary_tasks: vec![],
797 is_constructed: false,
798 is_historical: false,
799 requires_auth: false,
800 version: None,
801 notes: None,
802 citation: None,
803 }
804 }
805
806 #[must_use]
808 pub fn with_name(mut self, name: impl Into<String>) -> Self {
809 self.name = name.into();
810 self
811 }
812
813 #[must_use]
815 pub fn with_languages(mut self, langs: &[&str]) -> Self {
816 self.languages = langs.iter().map(|s| (*s).to_string()).collect();
817 self
818 }
819
820 #[must_use]
822 pub fn with_entity_types(mut self, types: &[&str]) -> Self {
823 self.entity_types = types.iter().map(|s| (*s).to_string()).collect();
824 self
825 }
826
827 #[must_use]
829 pub fn with_parser(mut self, parser: ParserHint) -> Self {
830 self.parser_hint = parser;
831 self
832 }
833
834 #[must_use]
836 pub fn with_license(mut self, license: License) -> Self {
837 self.license = license;
838 self
839 }
840
841 #[must_use]
843 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
844 self.description = Some(desc.into());
845 self
846 }
847
848 #[must_use]
850 pub fn with_domain(mut self, domain: Domain) -> Self {
851 self.domain = domain;
852 self
853 }
854
855 #[must_use]
857 pub fn with_url(mut self, url: impl Into<String>) -> Self {
858 self.download_url = Some(url.into());
859 self
860 }
861
862 #[must_use]
864 pub fn with_path(mut self, path: std::path::PathBuf) -> Self {
865 self.local_path = Some(path);
866 self
867 }
868
869 #[must_use]
871 pub fn with_stats(mut self, stats: DatasetStats) -> Self {
872 self.stats = stats;
873 self
874 }
875
876 #[must_use]
878 pub fn with_temporal_coverage(mut self, coverage: TemporalCoverage) -> Self {
879 self.temporal_coverage = coverage;
880 self
881 }
882
883 #[must_use]
885 pub fn with_secondary_tasks(mut self, tasks: Vec<Task>) -> Self {
886 self.secondary_tasks = tasks;
887 self
888 }
889
890 #[must_use]
892 pub fn constructed(mut self) -> Self {
893 self.is_constructed = true;
894 self
895 }
896
897 #[must_use]
899 pub fn historical(mut self) -> Self {
900 self.is_historical = true;
901 self
902 }
903
904 #[must_use]
906 pub fn requires_authentication(mut self) -> Self {
907 self.requires_auth = true;
908 self
909 }
910
911 #[must_use]
913 pub fn with_version(mut self, version: impl Into<String>) -> Self {
914 self.version = Some(version.into());
915 self
916 }
917
918 #[must_use]
920 pub fn languages_owned(&self) -> &[String] {
921 &self.languages
922 }
923
924 #[must_use]
926 pub fn entity_types_owned(&self) -> &[String] {
927 &self.entity_types
928 }
929
930 #[must_use]
932 pub fn with_notes(mut self, notes: impl Into<String>) -> Self {
933 self.notes = Some(notes.into());
934 self
935 }
936
937 #[must_use]
939 pub fn with_citation(mut self, citation: impl Into<String>) -> Self {
940 self.citation = Some(citation.into());
941 self
942 }
943}
944
945impl DatasetSpec for CustomDataset {
946 fn name(&self) -> &str {
947 &self.name
948 }
949
950 fn id(&self) -> &str {
951 &self.id
952 }
953
954 fn task(&self) -> Task {
955 self.task
956 }
957
958 fn languages(&self) -> &[&str] {
959 static EMPTY: &[&str] = &[];
962 EMPTY
963 }
964
965 fn entity_types(&self) -> &[&str] {
966 static EMPTY: &[&str] = &[];
969 EMPTY
970 }
971
972 fn parser_hint(&self) -> ParserHint {
973 self.parser_hint
974 }
975
976 fn license(&self) -> License {
977 self.license.clone()
978 }
979
980 fn description(&self) -> Option<&str> {
981 self.description.as_deref()
982 }
983
984 fn domain(&self) -> Domain {
985 self.domain.clone()
986 }
987
988 fn download_url(&self) -> Option<&str> {
989 self.download_url.as_deref()
990 }
991
992 fn local_path(&self) -> Option<&std::path::Path> {
993 self.local_path.as_deref()
994 }
995
996 fn stats(&self) -> DatasetStats {
997 self.stats.clone()
998 }
999
1000 fn temporal_coverage(&self) -> TemporalCoverage {
1001 self.temporal_coverage.clone()
1002 }
1003
1004 fn secondary_tasks(&self) -> &[Task] {
1005 &self.secondary_tasks
1006 }
1007
1008 fn is_constructed_language(&self) -> bool {
1009 self.is_constructed
1010 }
1011
1012 fn is_historical(&self) -> bool {
1013 self.is_historical
1014 }
1015
1016 fn requires_auth(&self) -> bool {
1017 self.requires_auth
1018 }
1019
1020 fn version(&self) -> Option<&str> {
1021 self.version.as_deref()
1022 }
1023
1024 fn notes(&self) -> Option<&str> {
1025 self.notes.as_deref()
1026 }
1027
1028 fn citation(&self) -> Option<&str> {
1029 self.citation.as_deref()
1030 }
1031
1032 fn languages_vec(&self) -> Vec<String> {
1034 self.languages.clone()
1035 }
1036
1037 fn entity_types_vec(&self) -> Vec<String> {
1038 self.entity_types.clone()
1039 }
1040}
1041
/// Collection of dataset descriptions keyed by their id.
///
/// `Default` produces an empty registry.
#[derive(Default)]
pub struct DatasetRegistry {
    // Maps `DatasetSpec::id()` (copied at registration time) to the boxed spec.
    datasets: std::collections::HashMap<String, Box<dyn DatasetSpec>>,
}
1054
1055impl DatasetRegistry {
1056 #[must_use]
1058 pub fn new() -> Self {
1059 Self::default()
1060 }
1061
1062 pub fn register(
1066 &mut self,
1067 dataset: impl DatasetSpec + 'static,
1068 ) -> Option<Box<dyn DatasetSpec>> {
1069 let id = dataset.id().to_string();
1070 self.datasets.insert(id, Box::new(dataset))
1071 }
1072
1073 #[must_use]
1075 pub fn get(&self, id: &str) -> Option<&dyn DatasetSpec> {
1076 self.datasets.get(id).map(|b| &**b)
1077 }
1078
1079 pub fn unregister(&mut self, id: &str) -> Option<Box<dyn DatasetSpec>> {
1081 self.datasets.remove(id)
1082 }
1083
1084 #[must_use]
1086 pub fn list_ids(&self) -> Vec<&str> {
1087 self.datasets.keys().map(|s| s.as_str()).collect()
1088 }
1089
1090 #[must_use]
1092 pub fn len(&self) -> usize {
1093 self.datasets.len()
1094 }
1095
1096 #[must_use]
1098 pub fn is_empty(&self) -> bool {
1099 self.datasets.is_empty()
1100 }
1101
1102 pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn DatasetSpec)> {
1104 self.datasets.iter().map(|(k, v)| (k.as_str(), &**v))
1105 }
1106
1107 pub fn by_task(&self, task: Task) -> impl Iterator<Item = &dyn DatasetSpec> {
1109 self.datasets
1110 .values()
1111 .filter(move |d| d.supports_task(task))
1112 .map(|b| &**b)
1113 }
1114
1115 pub fn by_language<'a>(&'a self, lang: &'a str) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1117 self.datasets
1118 .values()
1119 .filter(move |d| d.supports_language(lang))
1120 .map(|b| &**b)
1121 }
1122
1123 pub fn by_domain(&self, domain: Domain) -> impl Iterator<Item = &dyn DatasetSpec> {
1125 self.datasets
1126 .values()
1127 .filter(move |d| d.domain() == domain)
1128 .map(|b| &**b)
1129 }
1130
1131 pub fn public_only(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1133 self.datasets
1134 .values()
1135 .filter(|d| d.is_public())
1136 .map(|b| &**b)
1137 }
1138
1139 pub fn historical(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1141 self.datasets
1142 .values()
1143 .filter(|d| d.is_historical())
1144 .map(|b| &**b)
1145 }
1146
1147 pub fn with_entity_type<'a>(
1149 &'a self,
1150 entity_type: &'a str,
1151 ) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1152 self.datasets
1153 .values()
1154 .filter(move |d| d.has_entity_type(entity_type))
1155 .map(|b| &**b)
1156 }
1157
1158 #[must_use]
1160 pub fn summary(&self) -> RegistrySummary {
1161 let mut tasks = std::collections::HashMap::new();
1162 let mut domains = std::collections::HashMap::new();
1163 let mut languages = std::collections::HashSet::new();
1164
1165 for ds in self.datasets.values() {
1166 *tasks.entry(ds.task()).or_insert(0) += 1;
1167 *domains.entry(ds.domain()).or_insert(0) += 1;
1168 for lang in ds.languages_vec() {
1169 languages.insert(lang);
1170 }
1171 }
1172
1173 RegistrySummary {
1174 total: self.datasets.len(),
1175 by_task: tasks,
1176 by_domain: domains,
1177 languages: languages.into_iter().collect(),
1178 }
1179 }
1180}
1181
/// Aggregate counts produced by `DatasetRegistry::summary`.
#[derive(Debug, Clone)]
pub struct RegistrySummary {
    /// Number of registered datasets.
    pub total: usize,
    /// Number of datasets per primary task.
    pub by_task: std::collections::HashMap<Task, usize>,
    /// Number of datasets per domain.
    pub by_domain: std::collections::HashMap<Domain, usize>,
    /// Distinct languages reported by the datasets' `languages_vec()`.
    pub languages: Vec<String>,
}
1194
1195impl fmt::Debug for DatasetRegistry {
1196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1197 f.debug_struct("DatasetRegistry")
1198 .field("count", &self.datasets.len())
1199 .field("ids", &self.list_ids())
1200 .finish()
1201 }
1202}
1203
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder values are readable through the getters and owned accessors.
    #[test]
    fn test_custom_dataset_creation() {
        let dataset = CustomDataset::new("test_ner", Task::NER)
            .with_name("Test NER Dataset")
            .with_languages(&["en", "de"])
            .with_entity_types(&["PER", "LOC", "ORG"])
            .with_license(License::MIT)
            .with_domain(Domain::News);

        assert_eq!(dataset.id(), "test_ner");
        assert_eq!(dataset.name(), "Test NER Dataset");
        assert_eq!(dataset.task(), Task::NER);

        // Languages live in the owned list, not in the `&[&str]` getter.
        assert!(dataset.languages_owned().contains(&"en".to_string()));
        assert!(dataset.languages_owned().contains(&"de".to_string()));
        assert!(!dataset.languages_owned().contains(&"fr".to_string()));

        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("PER")));
        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("per")));

        // MIT is redistributable and no authentication was requested.
        assert!(dataset.is_public());
    }

    /// The trait-level helpers read the owned lists of a `CustomDataset`.
    #[test]
    fn test_custom_dataset_trait_helpers() {
        let dataset = CustomDataset::new("helpers", Task::NER)
            .with_languages(&["en", "de"])
            .with_entity_types(&["PER", "LOC"])
            .with_secondary_tasks(vec![Task::NED]);

        assert_eq!(dataset.languages_vec(), vec!["en", "de"]);
        assert_eq!(dataset.entity_types_vec(), vec!["PER", "LOC"]);
        assert!(dataset.supports_language("de"));
        assert!(!dataset.supports_language("fr"));
        assert!(dataset.has_entity_type("per"));
        assert!(!dataset.has_entity_type("ORG"));
        assert!(dataset.supports_task(Task::NER));
        assert!(dataset.supports_task(Task::NED));
        assert!(!dataset.supports_task(Task::POS));
    }

    /// Authentication or a non-redistributable license makes a dataset non-public.
    #[test]
    fn test_custom_dataset_not_public() {
        let gated = CustomDataset::new("gated", Task::NER)
            .with_license(License::MIT)
            .requires_authentication();
        assert!(gated.requires_auth());
        assert!(!gated.is_public());

        let ldc = CustomDataset::new("ldc", Task::NER).with_license(License::LDC);
        assert!(!ldc.is_public());
    }

    #[test]
    fn test_registry() {
        let mut registry = DatasetRegistry::new();

        let dataset1 = CustomDataset::new("ds1", Task::NER)
            .with_name("Dataset 1")
            .with_languages(&["en"]);

        let dataset2 = CustomDataset::new("ds2", Task::IntraDocCoref)
            .with_name("Dataset 2")
            .with_languages(&["de"]);

        registry.register(dataset1);
        registry.register(dataset2);

        assert_eq!(registry.len(), 2);
        assert!(registry.get("ds1").is_some());
        assert!(registry.get("ds2").is_some());
        assert!(registry.get("ds3").is_none());

        let ner_datasets: Vec<_> = registry.by_task(Task::NER).collect();
        assert_eq!(ner_datasets.len(), 1);
        assert_eq!(ner_datasets[0].id(), "ds1");
    }

    /// Registering an existing id replaces the entry; `unregister` removes it.
    #[test]
    fn test_registry_replace_and_unregister() {
        let mut registry = DatasetRegistry::new();
        assert!(registry.is_empty());

        let first = CustomDataset::new("dup", Task::NER).with_name("first");
        let second = CustomDataset::new("dup", Task::NED).with_name("second");

        assert!(registry.register(first).is_none());
        let replaced = registry.register(second).expect("previous entry is returned");
        assert_eq!(replaced.name(), "first");
        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("dup").map(|d| d.name().to_string()),
            Some("second".to_string())
        );

        let mut ids = registry.list_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec!["dup"]);

        assert!(registry.unregister("dup").is_some());
        assert!(registry.unregister("dup").is_none());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_task_properties() {
        assert!(Task::NER.produces_entities());
        assert!(!Task::IntraDocCoref.produces_entities());
        assert!(Task::IntraDocCoref.involves_coreference());
        assert!(Task::InterDocCoref.involves_coreference());
        assert!(!Task::NER.involves_coreference());
        assert!(Task::NED.involves_kb_linking());
        assert!(Task::RelationExtraction.involves_relations());
    }

    #[test]
    fn test_license_properties() {
        assert!(License::MIT.allows_commercial());
        assert!(License::MIT.allows_redistribution());
        assert!(!License::LDC.allows_redistribution());
        assert!(!License::ResearchOnly.allows_commercial());
    }

    #[test]
    fn test_parser_extensions() {
        assert!(ParserHint::CoNLL.typical_extensions().contains(&"conll"));
        assert!(ParserHint::JSONL.typical_extensions().contains(&"jsonl"));
        assert!(ParserHint::Custom.typical_extensions().is_empty());
    }

    #[test]
    fn test_task_from_str() {
        assert_eq!("ner".parse::<Task>().expect("task parse"), Task::NER);
        // Parsing is case-insensitive.
        assert_eq!("NER".parse::<Task>().expect("task parse"), Task::NER);
        assert_eq!(
            "coref".parse::<Task>().expect("task parse"),
            Task::IntraDocCoref
        );
        assert_eq!(
            "cdcr".parse::<Task>().expect("task parse"),
            Task::InterDocCoref
        );
        assert_eq!("el".parse::<Task>().expect("task parse"), Task::NED);
        assert_eq!(
            "entity_linking".parse::<Task>().expect("task parse"),
            Task::NED
        );
        assert_eq!(
            "re".parse::<Task>().expect("task parse"),
            Task::RelationExtraction
        );

        assert!("invalid_task".parse::<Task>().is_err());
    }

    #[test]
    fn test_task_code() {
        assert_eq!(Task::NER.code(), "ner");
        assert_eq!(Task::IntraDocCoref.code(), "coref");
        assert_eq!(Task::NED.code(), "el");
        assert_eq!(Task::RelationExtraction.code(), "re");
    }

    /// Every short code parses back to the variant that produced it.
    #[test]
    fn test_task_code_round_trips() {
        for task in Task::ALL {
            assert_eq!(task.code().parse::<Task>(), Ok(*task));
        }
    }

    #[test]
    fn test_task_all_variants() {
        assert!(Task::ALL.contains(&Task::NER));
        assert!(Task::ALL.contains(&Task::IntraDocCoref));
        assert!(Task::ALL.contains(&Task::NED));
        assert_eq!(Task::ALL.len(), 13);
    }

    #[test]
    fn test_registry_filtering() {
        let mut registry = DatasetRegistry::new();

        registry.register(
            CustomDataset::new("biomedical_ner", Task::NER)
                .with_languages(&["en"])
                .with_domain(Domain::Biomedical)
                .with_entity_types(&["DISEASE", "DRUG"]),
        );
        registry.register(
            CustomDataset::new("news_coref", Task::IntraDocCoref)
                .with_languages(&["en", "de"])
                .with_domain(Domain::News),
        );
        registry.register(
            CustomDataset::new("sanskrit_edl", Task::NED)
                .with_languages(&["sa"])
                .with_domain(Domain::Literary)
                .historical(),
        );

        let bio: Vec<_> = registry.by_domain(Domain::Biomedical).collect();
        assert_eq!(bio.len(), 1);
        assert_eq!(bio[0].id(), "biomedical_ner");

        let german: Vec<_> = registry.by_language("de").collect();
        assert_eq!(german.len(), 1);
        assert_eq!(german[0].id(), "news_coref");

        let historical: Vec<_> = registry.historical().collect();
        assert_eq!(historical.len(), 1);
        assert_eq!(historical[0].id(), "sanskrit_edl");

        let disease: Vec<_> = registry.with_entity_type("DISEASE").collect();
        assert_eq!(disease.len(), 1);
    }

    #[test]
    fn test_registry_summary() {
        let mut registry = DatasetRegistry::new();
        registry.register(CustomDataset::new("a", Task::NER).with_languages(&["en"]));
        registry.register(CustomDataset::new("b", Task::NER).with_languages(&["de"]));
        registry.register(CustomDataset::new("c", Task::IntraDocCoref).with_languages(&["en"]));

        let summary = registry.summary();
        assert_eq!(summary.total, 3);
        assert_eq!(summary.by_task.get(&Task::NER), Some(&2));
        assert_eq!(summary.by_task.get(&Task::IntraDocCoref), Some(&1));
        assert_eq!(summary.by_domain.get(&Domain::Mixed), Some(&3));
        // "en" appears in two datasets but is reported once.
        assert_eq!(summary.languages.len(), 2);
        assert!(summary.languages.contains(&"en".to_string()));
        assert!(summary.languages.contains(&"de".to_string()));
    }

    #[test]
    fn test_historical_custom_dataset_smoke() {
        let ds = CustomDataset::new("historical_edl", Task::NED)
            .with_name("Historical EDL (example)")
            .with_languages(&["sa"])
            .with_entity_types(&["Person", "Location"])
            .with_parser(ParserHint::CoNLLU)
            .with_license(License::CCBY)
            .with_domain(Domain::Literary)
            .with_secondary_tasks(vec![Task::IntraDocCoref, Task::NER])
            .with_stats(DatasetStats {
                doc_count: Some(10),
                mention_count: Some(100),
                ..Default::default()
            })
            .with_citation("Example citation")
            .historical();

        assert_eq!(ds.task(), Task::NED);
        assert!(ds.supports_language("sa"));
        assert!(ds.is_historical());
        assert!(ds.is_public());
    }

    #[test]
    fn test_domain_display() {
        assert_eq!(Domain::Biomedical.to_string(), "Biomedical");
        assert_eq!(Domain::Literary.to_string(), "Literary");
        assert_eq!(Domain::Other("custom".into()).to_string(), "custom");
    }

    #[test]
    fn test_license_display() {
        assert_eq!(License::CCBY.to_string(), "CC BY 4.0");
        assert_eq!(License::MIT.to_string(), "MIT");
        assert_eq!(License::LDC.to_string(), "LDC");
    }

    #[test]
    fn test_temporal_coverage() {
        let cov = TemporalCoverage {
            start_year: Some(2010),
            end_year: Some(2020),
            has_temporal_annotations: true,
            has_diachronic_entities: false,
        };

        assert_eq!(cov.start_year, Some(2010));
        assert!(cov.has_temporal_annotations);
    }

    #[test]
    fn test_split_sizes() {
        let splits = SplitSizes {
            train: 1000,
            dev: 100,
            test: 200,
        };

        assert_eq!(splits.train + splits.dev + splits.test, 1300);
    }
}