// anno_core/core/dataset.rs
1//! Dataset specifications and metadata for NER evaluation.
2//!
3//! ## Why This Module Exists
4//!
5//! NER evaluation requires knowing what datasets exist, what entity types they
6//! annotate, what format they're in, and what license governs their use. Without
7//! a structured catalog:
8//!
9//! - Users reinvent the wheel finding/downloading the same datasets
10//! - Entity type mappings differ across implementations (PER vs PERSON vs person)
11//! - License compliance becomes guesswork
12//! - Comparing results across papers requires manual dataset lookup
13//!
14//! This module provides a trait-based abstraction for dataset metadata, plus a
15//! runtime registry for discovering and filtering datasets by task, language,
16//! domain, or license.
17//!
18//! ## Architecture
19//!
20//! ```text
21//! ┌─────────────────────────────────────────────────────────────┐
22//! │                      DatasetSpec trait                      │
23//! │  name(), id(), task(), languages(), entity_types(), ...     │
24//! └─────────────────────────┬───────────────────────────────────┘
25//!                           │ implemented by
26//!           ┌───────────────┴───────────────┐
27//!           │                               │
28//!   ┌───────▼───────┐             ┌─────────▼─────────┐
29//!   │ CustomDataset │             │   Built-in IDs    │
30//!   │ (runtime)     │             │ (compile-time)    │
31//!   └───────────────┘             └───────────────────┘
32//! ```
33//!
34//! ## Example
35//!
36//! ```rust,ignore
37//! use anno::eval::{DatasetId, load_dataset};
38//!
39//! let conll = load_dataset(DatasetId::CoNLL2003, "test")?;
40//! println!("Entity types: {:?}", DatasetId::CoNLL2003.entity_types());
41//! ```
42//!
43//! Built-in datasets include CoNLL2003, OntoNotes5, WikiANN (176 languages),
44//! BC5CDR (biomedical), AnnoCTR (cybersecurity), WNUT17 (social media), and more.
45//!
46//! # Custom Dataset
47//!
48//! ```rust
49//! use anno_core::core::dataset::{DatasetSpec, DatasetStats, Domain, License, ParserHint, Task};
50//!
51//! /// Example NER dataset spec (replace with your own dataset details).
52//! struct ExampleNER;
53//!
54//! impl DatasetSpec for ExampleNER {
55//!     fn name(&self) -> &str { "Example NER" }
56//!     fn id(&self) -> &str { "example_ner_v1" }
57//!     fn task(&self) -> Task { Task::NER }
58//!     fn languages(&self) -> &[&str] { &["en"] }
59//!     fn entity_types(&self) -> &[&str] { &["PER", "ORG", "LOC"] }
60//!     fn parser_hint(&self) -> ParserHint { ParserHint::CoNLL }
61//!     fn license(&self) -> License { License::CCBY }
62//!     fn domain(&self) -> Domain { Domain::News }
63//!     fn stats(&self) -> DatasetStats {
64//!         DatasetStats { doc_count: Some(1_000), ..Default::default() }
65//!     }
66//! }
67//! ```
68//!
69//! # Parser Hints
70//!
71//! Datasets come in various formats. [`ParserHint`] guides the loader:
72//!
73//! | Format | Description | Example Datasets |
74//! |--------|-------------|------------------|
75//! | `CoNLL` | Tab-separated BIO tags | CoNLL2003, WNUT17 |
76//! | `JSON` | Structured JSON objects | LitBank, MultiNERD |
77//! | `JSONL` | JSON Lines (one per doc) | WikiANN, Universal NER |
78//! | `CoNLLU` | Universal Dependencies format | UD treebanks |
79//! | `BRAT` | Standoff annotation format | custom annotations |
80//!
81//! # Licensing
82//!
83//! The [`License`] enum tracks data usage rights:
84//!
85//! - **Research**: Academic use only (e.g., LDC corpora)
86//! - **CC-BY-4.0**: Attribution required, commercial OK
87//! - **Apache-2.0**: Permissive, patent grant
88//! - **Proprietary**: Internal/commercial datasets
89//!
90//! # Domain Coverage
91//!
92//! ```rust
93//! use anno_core::core::dataset::Domain;
94//!
95//! // Built-in domains
96//! let domains = [
97//!     Domain::News,           // CoNLL, OntoNotes
98//!     Domain::Biomedical,     // BC5CDR, NCBI-Disease
99//!     Domain::SocialMedia,    // WNUT17, Twitter
100//!     Domain::Scientific,     // SciERC, WIESP
101//!     Domain::Legal,          // E-NER SEC
102//!     Domain::Cybersecurity,  // AnnoCTR
103//!     Domain::Music,          // Distant Listening Corpus
104//!     Domain::Literary,       // LitBank, etc.
105//!     Domain::Historical,     // HIPE-2022, medieval corpora
106//! ];
107//! ```
108//!
109//! # Classical & Historical Languages
110//!
111//! For historical-language datasets, prefer ISO 639-3 codes and explicit provenance fields.
112//! Keep examples here generic; put dataset-specific details (URLs, stats, citations) in the
113//! dataset registry where they can be reviewed for correctness.
114
115use serde::{Deserialize, Serialize};
116use std::fmt;
117use std::hash::Hash;
118
119// ============================================================================
120// Task Enumeration
121// ============================================================================
122
/// The primary NLP task a dataset is designed for.
///
/// A dataset may support multiple tasks (e.g., NER + Entity Linking),
/// but has one primary task that determines its structure.
///
/// Marked `#[non_exhaustive]`: downstream crates must keep a wildcard arm when
/// matching so new variants can be added without a breaking change. See
/// [`Task::ALL`] for iterating all variants, [`Task::code`] for short codes,
/// and the `FromStr` impl for alias-aware parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Task {
    /// Named Entity Recognition (sequence labeling)
    NER,
    /// Intra-document coreference resolution
    IntraDocCoref,
    /// Inter-document (cross-document) coreference resolution
    InterDocCoref,
    /// Named Entity Disambiguation / Entity Linking to KB
    NED,
    /// Relation Extraction between entities
    RelationExtraction,
    /// Event extraction and argument role labeling
    EventExtraction,
    /// Discontinuous/nested NER (e.g., CADEC, ShARe)
    DiscontinuousNER,
    /// Visual document NER (forms, receipts, etc.)
    VisualNER,
    /// Temporal NER (diachronic entities, time expressions)
    TemporalNER,
    /// Sentiment/opinion target extraction
    AspectExtraction,
    /// Slot filling for dialogue systems
    SlotFilling,
    /// Part-of-speech tagging (often bundled with NER)
    POS,
    /// Dependency parsing
    DependencyParsing,
}
157
158impl Task {
159    /// Returns true if this task produces entity spans.
160    #[must_use]
161    pub const fn produces_entities(&self) -> bool {
162        matches!(
163            self,
164            Self::NER
165                | Self::DiscontinuousNER
166                | Self::VisualNER
167                | Self::TemporalNER
168                | Self::AspectExtraction
169                | Self::SlotFilling
170        )
171    }
172
173    /// Returns true if this task involves coreference chains.
174    #[must_use]
175    pub const fn involves_coreference(&self) -> bool {
176        matches!(self, Self::IntraDocCoref | Self::InterDocCoref)
177    }
178
179    /// Returns true if this task links to external knowledge bases.
180    #[must_use]
181    pub const fn involves_kb_linking(&self) -> bool {
182        matches!(self, Self::NED)
183    }
184
185    /// Returns true if this task extracts relations between entities.
186    #[must_use]
187    pub const fn involves_relations(&self) -> bool {
188        matches!(self, Self::RelationExtraction | Self::EventExtraction)
189    }
190}
191
192impl fmt::Display for Task {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        match self {
195            Self::NER => write!(f, "NER"),
196            Self::IntraDocCoref => write!(f, "Intra-Doc Coreference"),
197            Self::InterDocCoref => write!(f, "Inter-Doc Coreference"),
198            Self::NED => write!(f, "Named Entity Disambiguation"),
199            Self::RelationExtraction => write!(f, "Relation Extraction"),
200            Self::EventExtraction => write!(f, "Event Extraction"),
201            Self::DiscontinuousNER => write!(f, "Discontinuous NER"),
202            Self::VisualNER => write!(f, "Visual NER"),
203            Self::TemporalNER => write!(f, "Temporal NER"),
204            Self::AspectExtraction => write!(f, "Aspect Extraction"),
205            Self::SlotFilling => write!(f, "Slot Filling"),
206            Self::POS => write!(f, "POS Tagging"),
207            Self::DependencyParsing => write!(f, "Dependency Parsing"),
208        }
209    }
210}
211
212impl std::str::FromStr for Task {
213    type Err = String;
214
215    /// Parse task from string (case-insensitive, supports common aliases).
216    ///
217    /// # Examples
218    ///
219    /// ```rust
220    /// use anno_core::core::dataset::Task;
221    ///
222    /// assert_eq!("ner".parse::<Task>().unwrap(), Task::NER);
223    /// assert_eq!("coref".parse::<Task>().unwrap(), Task::IntraDocCoref);
224    /// assert_eq!("entity_linking".parse::<Task>().unwrap(), Task::NED);
225    /// assert_eq!("RE".parse::<Task>().unwrap(), Task::RelationExtraction);
226    /// ```
227    fn from_str(s: &str) -> Result<Self, Self::Err> {
228        match s.to_lowercase().as_str() {
229            "ner" | "named_entity_recognition" | "sequence_labeling" => Ok(Self::NER),
230            "coref" | "coreference" | "intra_doc_coref" | "intradoccoref" => {
231                Ok(Self::IntraDocCoref)
232            }
233            "cdcr" | "inter_doc_coref" | "interdoccoref" | "cross_doc_coref" => {
234                Ok(Self::InterDocCoref)
235            }
236            "ned" | "el" | "entity_linking" | "disambiguation" => Ok(Self::NED),
237            "re" | "relation_extraction" | "relations" => Ok(Self::RelationExtraction),
238            "event" | "event_extraction" | "events" => Ok(Self::EventExtraction),
239            "discontinuous" | "discontinuous_ner" | "nested" | "nested_ner" => {
240                Ok(Self::DiscontinuousNER)
241            }
242            "visual" | "visual_ner" | "document_ner" => Ok(Self::VisualNER),
243            "temporal" | "temporal_ner" | "timex" => Ok(Self::TemporalNER),
244            "aspect" | "aspect_extraction" | "absa" => Ok(Self::AspectExtraction),
245            "slot" | "slot_filling" | "intent" => Ok(Self::SlotFilling),
246            "pos" | "pos_tagging" | "part_of_speech" => Ok(Self::POS),
247            "dep" | "dependency" | "dependency_parsing" => Ok(Self::DependencyParsing),
248            _ => Err(format!(
249                "Unknown task: '{}'. Valid: ner, coref, ned, re, event, ...",
250                s
251            )),
252        }
253    }
254}
255
impl Task {
    /// All task variants for iteration.
    ///
    /// NOTE: `Task` is `#[non_exhaustive]`; this list must be kept in sync
    /// manually whenever a variant is added to the enum.
    pub const ALL: &'static [Task] = &[
        Task::NER,
        Task::IntraDocCoref,
        Task::InterDocCoref,
        Task::NED,
        Task::RelationExtraction,
        Task::EventExtraction,
        Task::DiscontinuousNER,
        Task::VisualNER,
        Task::TemporalNER,
        Task::AspectExtraction,
        Task::SlotFilling,
        Task::POS,
        Task::DependencyParsing,
    ];

    /// Short code for this task (lowercase, no spaces).
    ///
    /// Each code is also accepted as an alias by the `FromStr` impl, so
    /// `t.code().parse::<Task>()` round-trips.
    #[must_use]
    pub const fn code(&self) -> &'static str {
        match self {
            Self::NER => "ner",
            Self::IntraDocCoref => "coref",
            Self::InterDocCoref => "cdcr",
            Self::NED => "el",
            Self::RelationExtraction => "re",
            Self::EventExtraction => "event",
            Self::DiscontinuousNER => "discontinuous",
            Self::VisualNER => "visual",
            Self::TemporalNER => "temporal",
            Self::AspectExtraction => "aspect",
            Self::SlotFilling => "slot",
            Self::POS => "pos",
            Self::DependencyParsing => "dep",
        }
    }
}
294
295// ============================================================================
296// Parser Hints
297// ============================================================================
298
/// Hint for how to parse this dataset's format.
///
/// Used by the loader to select the appropriate parser without
/// requiring format auto-detection.
///
/// Defaults to [`ParserHint::CoNLL`]. `#[non_exhaustive]`: downstream code
/// must match with a wildcard arm so new formats can be added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ParserHint {
    /// CoNLL-style column format (BIO/IOB2 tags)
    #[default]
    CoNLL,
    /// CoNLL-U format (Universal Dependencies)
    CoNLLU,
    /// JSON with tokens and labels arrays
    JSON,
    /// JSON Lines (one JSON object per line)
    JSONL,
    /// HuggingFace datasets API format
    HuggingFaceAPI,
    /// BRAT standoff annotation format
    BRAT,
    /// XML-based format (TEI, etc.)
    XML,
    /// ACE/ERE XML format
    ACE,
    /// OntoNotes-style format
    OntoNotes,
    /// Custom format requiring manual parsing
    Custom,
}
328
329impl ParserHint {
330    /// File extensions typically associated with this format.
331    #[must_use]
332    pub const fn typical_extensions(&self) -> &'static [&'static str] {
333        match self {
334            Self::CoNLL => &["conll", "txt", "bio"],
335            Self::CoNLLU => &["conllu"],
336            Self::JSON => &["json"],
337            Self::JSONL => &["jsonl", "ndjson"],
338            Self::HuggingFaceAPI => &["json"],
339            Self::BRAT => &["ann", "txt"],
340            Self::XML | Self::ACE => &["xml", "sgml"],
341            Self::OntoNotes => &["onf", "name"],
342            Self::Custom => &[],
343        }
344    }
345}
346
347// ============================================================================
348// License Information
349// ============================================================================
350
/// License type for dataset usage.
///
/// Important for determining redistribution rights and
/// commercial usage restrictions.
///
/// Defaults to [`License::Unknown`]. `Other(String)` carries a free-form
/// license name for anything not enumerated here. `#[non_exhaustive]`:
/// downstream matches need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum License {
    /// Creative Commons Attribution
    CCBY,
    /// Creative Commons Attribution-ShareAlike
    CCBYSA,
    /// Creative Commons Attribution-NonCommercial
    CCBYNC,
    /// Creative Commons Attribution-NonCommercial-ShareAlike
    CCBYNCSA,
    /// Creative Commons Zero (public domain)
    CC0,
    /// MIT License
    MIT,
    /// Apache 2.0 License
    Apache2,
    /// GNU General Public License
    GPL,
    /// Linguistic Data Consortium (requires membership)
    LDC,
    /// Research-only license
    ResearchOnly,
    /// Proprietary (restricted use; typically not redistributable)
    Proprietary,
    /// Unknown license
    #[default]
    Unknown,
    /// Other license with description
    Other(String),
}
386
387impl License {
388    /// Returns true if commercial use is allowed.
389    #[must_use]
390    pub fn allows_commercial(&self) -> bool {
391        matches!(
392            self,
393            Self::CCBY | Self::CCBYSA | Self::CC0 | Self::MIT | Self::Apache2
394        )
395    }
396
397    /// Returns true if the dataset can be freely redistributed.
398    #[must_use]
399    pub fn allows_redistribution(&self) -> bool {
400        !matches!(self, Self::LDC | Self::Proprietary | Self::ResearchOnly)
401    }
402}
403
404impl fmt::Display for License {
405    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406        match self {
407            Self::CCBY => write!(f, "CC BY 4.0"),
408            Self::CCBYSA => write!(f, "CC BY-SA 4.0"),
409            Self::CCBYNC => write!(f, "CC BY-NC 4.0"),
410            Self::CCBYNCSA => write!(f, "CC BY-NC-SA 4.0"),
411            Self::CC0 => write!(f, "CC0 (Public Domain)"),
412            Self::MIT => write!(f, "MIT"),
413            Self::Apache2 => write!(f, "Apache 2.0"),
414            Self::GPL => write!(f, "GPL"),
415            Self::LDC => write!(f, "LDC"),
416            Self::ResearchOnly => write!(f, "Research Only"),
417            Self::Proprietary => write!(f, "Proprietary"),
418            Self::Unknown => write!(f, "Unknown"),
419            Self::Other(s) => write!(f, "{s}"),
420        }
421    }
422}
423
424// ============================================================================
425// Domain Information
426// ============================================================================
427
/// Domain/genre of the dataset's source text.
///
/// Useful for domain adaptation and transfer learning decisions.
///
/// Defaults to [`Domain::Mixed`]. `Other(String)` carries a free-form domain
/// name. `#[non_exhaustive]`: downstream matches need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum Domain {
    /// News articles and journalism
    News,
    /// Biomedical and clinical text
    Biomedical,
    /// Scientific papers and abstracts
    Scientific,
    /// Legal documents and contracts
    Legal,
    /// Financial reports and news
    Financial,
    /// Social media (Twitter, Reddit, etc.)
    SocialMedia,
    /// Wikipedia and encyclopedic text
    Wikipedia,
    /// Literary fiction and novels
    Literary,
    /// Historical documents
    Historical,
    /// Conversational/dialogue text
    Dialogue,
    /// Technical documentation
    Technical,
    /// Web text (general)
    Web,
    /// Cybersecurity reports and threat intelligence
    Cybersecurity,
    /// Music-related text (lyrics, reviews, metadata)
    Music,
    /// Multiple domains
    #[default]
    Mixed,
    /// Other specific domain
    Other(String),
}
468
469impl fmt::Display for Domain {
470    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471        match self {
472            Self::News => write!(f, "News"),
473            Self::Biomedical => write!(f, "Biomedical"),
474            Self::Scientific => write!(f, "Scientific"),
475            Self::Legal => write!(f, "Legal"),
476            Self::Financial => write!(f, "Financial"),
477            Self::SocialMedia => write!(f, "Social Media"),
478            Self::Wikipedia => write!(f, "Wikipedia"),
479            Self::Literary => write!(f, "Literary"),
480            Self::Historical => write!(f, "Historical"),
481            Self::Dialogue => write!(f, "Dialogue"),
482            Self::Technical => write!(f, "Technical"),
483            Self::Web => write!(f, "Web"),
484            Self::Cybersecurity => write!(f, "Cybersecurity"),
485            Self::Music => write!(f, "Music"),
486            Self::Mixed => write!(f, "Mixed"),
487            Self::Other(s) => write!(f, "{s}"),
488        }
489    }
490}
491
492// ============================================================================
493// Temporal Coverage
494// ============================================================================
495
/// Temporal coverage of the dataset.
///
/// Important for understanding potential temporal bias and
/// for diachronic entity tracking.
///
/// `Default` yields unknown bounds (`None`) and both flags `false`.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct TemporalCoverage {
    /// Earliest document date (if known)
    // NOTE(review): presumably an inclusive calendar year (negative = BCE?) — confirm.
    pub start_year: Option<i32>,
    /// Latest document date (if known)
    pub end_year: Option<i32>,
    /// Whether the dataset includes explicit temporal annotations
    pub has_temporal_annotations: bool,
    /// Whether entities have validity periods (diachronic)
    pub has_diachronic_entities: bool,
}
511
512// ============================================================================
513// Dataset Statistics
514// ============================================================================
515
/// Statistics about a dataset's size and composition.
///
/// Every field is optional because published datasets frequently report only
/// a subset of these counts; `Default` yields all-`None` (unknown).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct DatasetStats {
    /// Number of documents/examples
    pub doc_count: Option<usize>,
    /// Number of entity mentions
    pub mention_count: Option<usize>,
    /// Number of unique entities (after coreference)
    pub entity_count: Option<usize>,
    /// Number of tokens
    pub token_count: Option<usize>,
    /// Train/dev/test split sizes
    pub split_sizes: Option<SplitSizes>,
}
530
/// Train/dev/test split sizes.
///
/// Deliberately has no `Default`: unknown splits are modeled as
/// `DatasetStats::split_sizes == None` rather than all-zero sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SplitSizes {
    /// Number of examples in training split
    pub train: usize,
    /// Number of examples in development/validation split
    pub dev: usize,
    /// Number of examples in test split
    pub test: usize,
}
541
542// ============================================================================
543// DatasetSpec Trait
544// ============================================================================
545
/// Specification for a dataset that can be loaded and evaluated.
///
/// This trait is the foundation for both built-in datasets (via the
/// `DatasetId` enum) and custom user-defined datasets.
///
/// # Implementing Custom Datasets
///
/// ```rust,ignore
/// use anno_core::core::dataset::*;
///
/// struct MyDataset {
///     path: PathBuf,
/// }
///
/// impl DatasetSpec for MyDataset {
///     fn name(&self) -> &str { "My Custom Dataset" }
///     fn id(&self) -> &str { "my_custom_v1" }
///     fn task(&self) -> Task { Task::NER }
///     fn languages(&self) -> &[&str] { &["en"] }
///     fn entity_types(&self) -> &[&str] { &["PER", "ORG", "LOC"] }
///     fn parser_hint(&self) -> ParserHint { ParserHint::CoNLL }
///     fn license(&self) -> License { License::Proprietary }
///
///     // Override to provide actual data path
///     fn local_path(&self) -> Option<&std::path::Path> {
///         Some(&self.path)
///     }
/// }
/// ```
pub trait DatasetSpec: Send + Sync {
    // ========================================================================
    // Required Methods
    // ========================================================================

    /// Human-readable name of the dataset.
    fn name(&self) -> &str;

    /// Unique identifier string (snake_case, no spaces).
    fn id(&self) -> &str;

    /// Primary task this dataset is designed for.
    fn task(&self) -> Task;

    /// ISO 639-1 language codes (e.g., "en", "zh", "de").
    ///
    /// Use `["multilingual"]` for datasets covering many languages.
    ///
    /// Implementations that store owned `String`s (e.g. `CustomDataset`)
    /// cannot return `&[&str]` borrowed from themselves; they return an empty
    /// slice here and override `languages_vec()` instead. Generic code should
    /// therefore prefer `languages_vec()` / `supports_language()`.
    fn languages(&self) -> &[&str];

    /// Entity types annotated in this dataset.
    ///
    /// For NER: `["PER", "LOC", "ORG", "MISC"]`
    /// For biomedical: `["GENE", "DISEASE", "DRUG", "SPECIES"]`
    ///
    /// Same caveat as `languages()`: owned-data implementations return an
    /// empty slice here and override `entity_types_vec()`.
    fn entity_types(&self) -> &[&str];

    /// Parser format hint for loading.
    fn parser_hint(&self) -> ParserHint;

    /// License governing dataset usage.
    fn license(&self) -> License;

    // ========================================================================
    // Optional Methods with Defaults
    // ========================================================================

    /// Detailed description of the dataset.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Domain/genre of source text.
    fn domain(&self) -> Domain {
        Domain::Mixed
    }

    /// URL for downloading the dataset.
    fn download_url(&self) -> Option<&str> {
        None
    }

    /// Citation information (BibTeX or plain text).
    fn citation(&self) -> Option<&str> {
        None
    }

    /// DOI or other persistent identifier.
    fn doi(&self) -> Option<&str> {
        None
    }

    /// Local path if already downloaded.
    fn local_path(&self) -> Option<&std::path::Path> {
        None
    }

    /// Dataset statistics (counts, splits).
    fn stats(&self) -> DatasetStats {
        DatasetStats::default()
    }

    /// Temporal coverage information.
    fn temporal_coverage(&self) -> TemporalCoverage {
        TemporalCoverage::default()
    }

    /// Additional tasks supported beyond the primary task.
    fn secondary_tasks(&self) -> &[Task] {
        &[]
    }

    /// Whether this is a constructed/artificial language dataset.
    fn is_constructed_language(&self) -> bool {
        false
    }

    /// Whether this is a historical/ancient language dataset.
    fn is_historical(&self) -> bool {
        false
    }

    /// Whether this dataset requires special access (gated, auth, etc.).
    fn requires_auth(&self) -> bool {
        false
    }

    /// Version string (e.g., "1.0", "2024-01").
    fn version(&self) -> Option<&str> {
        None
    }

    /// Notes or caveats about the dataset.
    fn notes(&self) -> Option<&str> {
        None
    }

    // ========================================================================
    // Owned Variants (for runtime/custom datasets)
    // ========================================================================

    /// Get languages as owned Vec (for custom datasets that don't have static data).
    ///
    /// Default implementation converts from `languages()`.
    /// This is the authoritative accessor for generic code: implementations
    /// with owned data override this and leave `languages()` empty.
    fn languages_vec(&self) -> Vec<String> {
        self.languages().iter().map(|s| (*s).to_string()).collect()
    }

    /// Get entity types as owned Vec (for custom datasets that don't have static data).
    ///
    /// Default implementation converts from `entity_types()`.
    fn entity_types_vec(&self) -> Vec<String> {
        self.entity_types()
            .iter()
            .map(|s| (*s).to_string())
            .collect()
    }

    // ========================================================================
    // Computed Properties
    // ========================================================================

    /// Check if this dataset is publicly available.
    ///
    /// Derived entirely from `license()` and `requires_auth()`.
    fn is_public(&self) -> bool {
        self.license().allows_redistribution() && !self.requires_auth()
    }

    /// Check if this dataset supports a specific task.
    ///
    /// True for the primary task or any entry in `secondary_tasks()`.
    fn supports_task(&self, task: Task) -> bool {
        self.task() == task || self.secondary_tasks().contains(&task)
    }

    /// Check if this dataset covers a specific language.
    ///
    /// Goes through `languages_vec()` so it works for owned-data
    /// implementations; a `"multilingual"` entry matches any language.
    /// Comparison is exact (case-sensitive) — language codes are expected
    /// to be lowercase.
    fn supports_language(&self, lang: &str) -> bool {
        let langs = self.languages_vec();
        langs.iter().any(|l| l == "multilingual" || l == lang)
    }

    /// Check if this dataset has a specific entity type.
    ///
    /// Comparison is ASCII case-insensitive (so "per" matches "PER").
    fn has_entity_type(&self, entity_type: &str) -> bool {
        self.entity_types_vec()
            .iter()
            .any(|t| t.eq_ignore_ascii_case(entity_type))
    }
}
728
729// ============================================================================
730// Custom Dataset Implementation
731// ============================================================================
732
/// A custom dataset defined at runtime.
///
/// Use this when you need to load a dataset that isn't in the built-in
/// `DatasetId` enum.
///
/// # Example
///
/// ```rust
/// use anno_core::core::dataset::{CustomDataset, Task, ParserHint, License, Domain};
/// use std::path::PathBuf;
///
/// let dataset = CustomDataset::new("my_ner_data", Task::NER)
///     .with_name("My Company NER Dataset")
///     .with_languages(&["en", "de"])
///     .with_entity_types(&["PRODUCT", "TEAM", "PROJECT"])
///     .with_parser(ParserHint::CoNLL)
///     .with_license(License::Proprietary)
///     .with_domain(Domain::Technical)
///     .with_path(PathBuf::from("/data/my_ner.conll"));
/// ```
#[derive(Debug, Clone)]
pub struct CustomDataset {
    /// Unique identifier (snake_case); set in `new`.
    id: String,
    /// Human-readable name; defaults to `id` in `new`.
    name: String,
    /// Primary task; set in `new`.
    task: Task,
    /// Owned language codes; defaults to `["en"]`.
    languages: Vec<String>,
    /// Owned entity type labels; defaults to empty.
    entity_types: Vec<String>,
    /// Format hint for the loader; defaults to CoNLL.
    parser_hint: ParserHint,
    /// Usage license; defaults to `Unknown`.
    license: License,
    /// Optional long-form description.
    description: Option<String>,
    /// Source-text domain; defaults to `Mixed`.
    domain: Domain,
    /// Optional download URL.
    download_url: Option<String>,
    /// Optional path to an already-downloaded copy.
    local_path: Option<std::path::PathBuf>,
    /// Size/composition statistics (all-unknown by default).
    stats: DatasetStats,
    /// Temporal coverage metadata (unknown by default).
    temporal_coverage: TemporalCoverage,
    /// Tasks supported beyond the primary task.
    secondary_tasks: Vec<Task>,
    /// Constructed/artificial-language flag.
    is_constructed: bool,
    /// Historical/ancient-language flag.
    is_historical: bool,
    /// Whether access is gated (auth required).
    requires_auth: bool,
    /// Optional version string.
    version: Option<String>,
    /// Optional free-form notes/caveats.
    notes: Option<String>,
    /// Optional citation (BibTeX or plain text).
    citation: Option<String>,
}
776
777impl CustomDataset {
778    /// Create a new custom dataset with minimal required fields.
779    #[must_use]
780    pub fn new(id: impl Into<String>, task: Task) -> Self {
781        let id = id.into();
782        Self {
783            name: id.clone(),
784            id,
785            task,
786            languages: vec!["en".to_string()],
787            entity_types: vec![],
788            parser_hint: ParserHint::CoNLL,
789            license: License::Unknown,
790            description: None,
791            domain: Domain::Mixed,
792            download_url: None,
793            local_path: None,
794            stats: DatasetStats::default(),
795            temporal_coverage: TemporalCoverage::default(),
796            secondary_tasks: vec![],
797            is_constructed: false,
798            is_historical: false,
799            requires_auth: false,
800            version: None,
801            notes: None,
802            citation: None,
803        }
804    }
805
806    /// Set the human-readable name.
807    #[must_use]
808    pub fn with_name(mut self, name: impl Into<String>) -> Self {
809        self.name = name.into();
810        self
811    }
812
813    /// Set the languages covered.
814    #[must_use]
815    pub fn with_languages(mut self, langs: &[&str]) -> Self {
816        self.languages = langs.iter().map(|s| (*s).to_string()).collect();
817        self
818    }
819
820    /// Set the entity types.
821    #[must_use]
822    pub fn with_entity_types(mut self, types: &[&str]) -> Self {
823        self.entity_types = types.iter().map(|s| (*s).to_string()).collect();
824        self
825    }
826
827    /// Set the parser hint.
828    #[must_use]
829    pub fn with_parser(mut self, parser: ParserHint) -> Self {
830        self.parser_hint = parser;
831        self
832    }
833
834    /// Set the license.
835    #[must_use]
836    pub fn with_license(mut self, license: License) -> Self {
837        self.license = license;
838        self
839    }
840
841    /// Set the description.
842    #[must_use]
843    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
844        self.description = Some(desc.into());
845        self
846    }
847
848    /// Set the domain.
849    #[must_use]
850    pub fn with_domain(mut self, domain: Domain) -> Self {
851        self.domain = domain;
852        self
853    }
854
855    /// Set the download URL.
856    #[must_use]
857    pub fn with_url(mut self, url: impl Into<String>) -> Self {
858        self.download_url = Some(url.into());
859        self
860    }
861
862    /// Set the local file path.
863    #[must_use]
864    pub fn with_path(mut self, path: std::path::PathBuf) -> Self {
865        self.local_path = Some(path);
866        self
867    }
868
869    /// Set dataset statistics.
870    #[must_use]
871    pub fn with_stats(mut self, stats: DatasetStats) -> Self {
872        self.stats = stats;
873        self
874    }
875
876    /// Set temporal coverage.
877    #[must_use]
878    pub fn with_temporal_coverage(mut self, coverage: TemporalCoverage) -> Self {
879        self.temporal_coverage = coverage;
880        self
881    }
882
883    /// Add secondary tasks.
884    #[must_use]
885    pub fn with_secondary_tasks(mut self, tasks: Vec<Task>) -> Self {
886        self.secondary_tasks = tasks;
887        self
888    }
889
890    /// Mark as constructed language.
891    #[must_use]
892    pub fn constructed(mut self) -> Self {
893        self.is_constructed = true;
894        self
895    }
896
897    /// Mark as historical language.
898    #[must_use]
899    pub fn historical(mut self) -> Self {
900        self.is_historical = true;
901        self
902    }
903
904    /// Mark as requiring authentication.
905    #[must_use]
906    pub fn requires_authentication(mut self) -> Self {
907        self.requires_auth = true;
908        self
909    }
910
911    /// Set version string.
912    #[must_use]
913    pub fn with_version(mut self, version: impl Into<String>) -> Self {
914        self.version = Some(version.into());
915        self
916    }
917
918    /// Get languages as owned strings (for custom datasets).
919    #[must_use]
920    pub fn languages_owned(&self) -> &[String] {
921        &self.languages
922    }
923
924    /// Get entity types as owned strings (for custom datasets).
925    #[must_use]
926    pub fn entity_types_owned(&self) -> &[String] {
927        &self.entity_types
928    }
929
930    /// Set notes.
931    #[must_use]
932    pub fn with_notes(mut self, notes: impl Into<String>) -> Self {
933        self.notes = Some(notes.into());
934        self
935    }
936
937    /// Set citation.
938    #[must_use]
939    pub fn with_citation(mut self, citation: impl Into<String>) -> Self {
940        self.citation = Some(citation.into());
941        self
942    }
943}
944
945impl DatasetSpec for CustomDataset {
946    fn name(&self) -> &str {
947        &self.name
948    }
949
950    fn id(&self) -> &str {
951        &self.id
952    }
953
954    fn task(&self) -> Task {
955        self.task
956    }
957
958    fn languages(&self) -> &[&str] {
959        // This is handled via cached_languages field
960        // For CustomDataset, we use a different pattern - see languages_owned()
961        static EMPTY: &[&str] = &[];
962        EMPTY
963    }
964
965    fn entity_types(&self) -> &[&str] {
966        // This is handled via cached_entity_types field
967        // For CustomDataset, we use a different pattern - see entity_types_owned()
968        static EMPTY: &[&str] = &[];
969        EMPTY
970    }
971
972    fn parser_hint(&self) -> ParserHint {
973        self.parser_hint
974    }
975
976    fn license(&self) -> License {
977        self.license.clone()
978    }
979
980    fn description(&self) -> Option<&str> {
981        self.description.as_deref()
982    }
983
984    fn domain(&self) -> Domain {
985        self.domain.clone()
986    }
987
988    fn download_url(&self) -> Option<&str> {
989        self.download_url.as_deref()
990    }
991
992    fn local_path(&self) -> Option<&std::path::Path> {
993        self.local_path.as_deref()
994    }
995
996    fn stats(&self) -> DatasetStats {
997        self.stats.clone()
998    }
999
1000    fn temporal_coverage(&self) -> TemporalCoverage {
1001        self.temporal_coverage.clone()
1002    }
1003
1004    fn secondary_tasks(&self) -> &[Task] {
1005        &self.secondary_tasks
1006    }
1007
1008    fn is_constructed_language(&self) -> bool {
1009        self.is_constructed
1010    }
1011
1012    fn is_historical(&self) -> bool {
1013        self.is_historical
1014    }
1015
1016    fn requires_auth(&self) -> bool {
1017        self.requires_auth
1018    }
1019
1020    fn version(&self) -> Option<&str> {
1021        self.version.as_deref()
1022    }
1023
1024    fn notes(&self) -> Option<&str> {
1025        self.notes.as_deref()
1026    }
1027
1028    fn citation(&self) -> Option<&str> {
1029        self.citation.as_deref()
1030    }
1031
1032    // Override to use owned data directly instead of converting
1033    fn languages_vec(&self) -> Vec<String> {
1034        self.languages.clone()
1035    }
1036
1037    fn entity_types_vec(&self) -> Vec<String> {
1038        self.entity_types.clone()
1039    }
1040}
1041
1042// ============================================================================
1043// Dataset Registry
1044// ============================================================================
1045
/// Registry for dynamically registered custom datasets.
///
/// This allows users to register their own datasets at runtime
/// without modifying the built-in enum.
#[derive(Default)]
pub struct DatasetRegistry {
    // Keyed by `DatasetSpec::id()`. Values are boxed trait objects so
    // heterogeneous spec implementations can share one map; HashMap
    // iteration order is unspecified.
    datasets: std::collections::HashMap<String, Box<dyn DatasetSpec>>,
}
1054
1055impl DatasetRegistry {
1056    /// Create a new empty registry.
1057    #[must_use]
1058    pub fn new() -> Self {
1059        Self::default()
1060    }
1061
1062    /// Register a custom dataset.
1063    ///
1064    /// Returns the previous dataset with this ID if one existed.
1065    pub fn register(
1066        &mut self,
1067        dataset: impl DatasetSpec + 'static,
1068    ) -> Option<Box<dyn DatasetSpec>> {
1069        let id = dataset.id().to_string();
1070        self.datasets.insert(id, Box::new(dataset))
1071    }
1072
1073    /// Get a dataset by ID.
1074    #[must_use]
1075    pub fn get(&self, id: &str) -> Option<&dyn DatasetSpec> {
1076        self.datasets.get(id).map(|b| &**b)
1077    }
1078
1079    /// Remove a dataset by ID.
1080    pub fn unregister(&mut self, id: &str) -> Option<Box<dyn DatasetSpec>> {
1081        self.datasets.remove(id)
1082    }
1083
1084    /// List all registered dataset IDs.
1085    #[must_use]
1086    pub fn list_ids(&self) -> Vec<&str> {
1087        self.datasets.keys().map(|s| s.as_str()).collect()
1088    }
1089
1090    /// Number of registered datasets.
1091    #[must_use]
1092    pub fn len(&self) -> usize {
1093        self.datasets.len()
1094    }
1095
1096    /// Check if registry is empty.
1097    #[must_use]
1098    pub fn is_empty(&self) -> bool {
1099        self.datasets.is_empty()
1100    }
1101
1102    /// Iterate over all registered datasets.
1103    pub fn iter(&self) -> impl Iterator<Item = (&str, &dyn DatasetSpec)> {
1104        self.datasets.iter().map(|(k, v)| (k.as_str(), &**v))
1105    }
1106
1107    /// Filter datasets by task.
1108    pub fn by_task(&self, task: Task) -> impl Iterator<Item = &dyn DatasetSpec> {
1109        self.datasets
1110            .values()
1111            .filter(move |d| d.supports_task(task))
1112            .map(|b| &**b)
1113    }
1114
1115    /// Filter datasets by language.
1116    pub fn by_language<'a>(&'a self, lang: &'a str) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1117        self.datasets
1118            .values()
1119            .filter(move |d| d.supports_language(lang))
1120            .map(|b| &**b)
1121    }
1122
1123    /// Filter datasets by domain.
1124    pub fn by_domain(&self, domain: Domain) -> impl Iterator<Item = &dyn DatasetSpec> {
1125        self.datasets
1126            .values()
1127            .filter(move |d| d.domain() == domain)
1128            .map(|b| &**b)
1129    }
1130
1131    /// Filter datasets that are publicly available (no auth, redistributable license).
1132    pub fn public_only(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1133        self.datasets
1134            .values()
1135            .filter(|d| d.is_public())
1136            .map(|b| &**b)
1137    }
1138
1139    /// Filter historical/ancient language datasets.
1140    pub fn historical(&self) -> impl Iterator<Item = &dyn DatasetSpec> {
1141        self.datasets
1142            .values()
1143            .filter(|d| d.is_historical())
1144            .map(|b| &**b)
1145    }
1146
1147    /// Find datasets supporting a specific entity type.
1148    pub fn with_entity_type<'a>(
1149        &'a self,
1150        entity_type: &'a str,
1151    ) -> impl Iterator<Item = &'a dyn DatasetSpec> {
1152        self.datasets
1153            .values()
1154            .filter(move |d| d.has_entity_type(entity_type))
1155            .map(|b| &**b)
1156    }
1157
1158    /// Get summary statistics about registered datasets.
1159    #[must_use]
1160    pub fn summary(&self) -> RegistrySummary {
1161        let mut tasks = std::collections::HashMap::new();
1162        let mut domains = std::collections::HashMap::new();
1163        let mut languages = std::collections::HashSet::new();
1164
1165        for ds in self.datasets.values() {
1166            *tasks.entry(ds.task()).or_insert(0) += 1;
1167            *domains.entry(ds.domain()).or_insert(0) += 1;
1168            for lang in ds.languages_vec() {
1169                languages.insert(lang);
1170            }
1171        }
1172
1173        RegistrySummary {
1174            total: self.datasets.len(),
1175            by_task: tasks,
1176            by_domain: domains,
1177            languages: languages.into_iter().collect(),
1178        }
1179    }
1180}
1181
/// Summary statistics for a dataset registry.
///
/// Produced by `DatasetRegistry::summary`.
#[derive(Debug, Clone)]
pub struct RegistrySummary {
    /// Total number of datasets.
    pub total: usize,
    /// Count by primary task.
    pub by_task: std::collections::HashMap<Task, usize>,
    /// Count by domain.
    pub by_domain: std::collections::HashMap<Domain, usize>,
    /// All languages covered.
    pub languages: Vec<String>,
}
1194
1195impl fmt::Debug for DatasetRegistry {
1196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1197        f.debug_struct("DatasetRegistry")
1198            .field("count", &self.datasets.len())
1199            .field("ids", &self.list_ids())
1200            .finish()
1201    }
1202}
1203
1204// ============================================================================
1205// Tests
1206// ============================================================================
1207
#[cfg(test)]
mod tests {
    use super::*;

    // Builder round-trip: every `with_*` setter should be observable through
    // the corresponding getter; entity-type lookup is case-insensitive.
    #[test]
    fn test_custom_dataset_creation() {
        let dataset = CustomDataset::new("test_ner", Task::NER)
            .with_name("Test NER Dataset")
            .with_languages(&["en", "de"])
            .with_entity_types(&["PER", "LOC", "ORG"])
            .with_license(License::MIT)
            .with_domain(Domain::News);

        assert_eq!(dataset.id(), "test_ner");
        assert_eq!(dataset.name(), "Test NER Dataset");
        assert_eq!(dataset.task(), Task::NER);
        // Use owned getters for custom datasets
        assert!(dataset.languages_owned().contains(&"en".to_string()));
        assert!(dataset.languages_owned().contains(&"de".to_string()));
        assert!(!dataset.languages_owned().contains(&"fr".to_string()));
        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("PER")));
        assert!(dataset
            .entity_types_owned()
            .iter()
            .any(|t| t.eq_ignore_ascii_case("per"))); // case insensitive
        assert!(dataset.is_public());
    }

    // Basic register/get/len behavior plus a single-task filter.
    #[test]
    fn test_registry() {
        let mut registry = DatasetRegistry::new();

        let dataset1 = CustomDataset::new("ds1", Task::NER)
            .with_name("Dataset 1")
            .with_languages(&["en"]);

        let dataset2 = CustomDataset::new("ds2", Task::IntraDocCoref)
            .with_name("Dataset 2")
            .with_languages(&["de"]);

        registry.register(dataset1);
        registry.register(dataset2);

        assert_eq!(registry.len(), 2);
        assert!(registry.get("ds1").is_some());
        assert!(registry.get("ds2").is_some());
        assert!(registry.get("ds3").is_none());

        let ner_datasets: Vec<_> = registry.by_task(Task::NER).collect();
        assert_eq!(ner_datasets.len(), 1);
        assert_eq!(ner_datasets[0].id(), "ds1");
    }

    // Task predicate helpers: entity production, coreference, KB linking,
    // and relation extraction should each be flagged by the right variants.
    #[test]
    fn test_task_properties() {
        assert!(Task::NER.produces_entities());
        assert!(!Task::IntraDocCoref.produces_entities());
        assert!(Task::IntraDocCoref.involves_coreference());
        assert!(Task::InterDocCoref.involves_coreference());
        assert!(!Task::NER.involves_coreference());
        assert!(Task::NED.involves_kb_linking());
        assert!(Task::RelationExtraction.involves_relations());
    }

    // License predicate helpers: permissive vs. restricted licenses.
    #[test]
    fn test_license_properties() {
        assert!(License::MIT.allows_commercial());
        assert!(License::MIT.allows_redistribution());
        assert!(!License::LDC.allows_redistribution());
        assert!(!License::ResearchOnly.allows_commercial());
    }

    // Each parser hint should report its conventional file extensions.
    #[test]
    fn test_parser_extensions() {
        assert!(ParserHint::CoNLL.typical_extensions().contains(&"conll"));
        assert!(ParserHint::JSONL.typical_extensions().contains(&"jsonl"));
    }

    // FromStr parsing of Task codes, including aliases and case-insensitivity.
    #[test]
    fn test_task_from_str() {
        // Basic parsing
        assert_eq!("ner".parse::<Task>().expect("task parse"), Task::NER);
        assert_eq!("NER".parse::<Task>().expect("task parse"), Task::NER);
        assert_eq!(
            "coref".parse::<Task>().expect("task parse"),
            Task::IntraDocCoref
        );
        assert_eq!(
            "cdcr".parse::<Task>().expect("task parse"),
            Task::InterDocCoref
        );
        assert_eq!("el".parse::<Task>().expect("task parse"), Task::NED);
        assert_eq!(
            "entity_linking".parse::<Task>().expect("task parse"),
            Task::NED
        );
        assert_eq!(
            "re".parse::<Task>().expect("task parse"),
            Task::RelationExtraction
        );

        // Invalid task
        assert!("invalid_task".parse::<Task>().is_err());
    }

    // code() should be the inverse of the canonical parse strings above.
    #[test]
    fn test_task_code() {
        assert_eq!(Task::NER.code(), "ner");
        assert_eq!(Task::IntraDocCoref.code(), "coref");
        assert_eq!(Task::NED.code(), "el");
        assert_eq!(Task::RelationExtraction.code(), "re");
    }

    #[test]
    fn test_task_all_variants() {
        // Ensure ALL contains all variants
        assert!(Task::ALL.contains(&Task::NER));
        assert!(Task::ALL.contains(&Task::IntraDocCoref));
        assert!(Task::ALL.contains(&Task::NED));
        assert_eq!(Task::ALL.len(), 13); // Update if variants change
    }

    // Exercises each registry filter (domain, language, historical,
    // entity type) against a small mixed population of datasets.
    #[test]
    fn test_registry_filtering() {
        let mut registry = DatasetRegistry::new();

        // Add diverse datasets
        registry.register(
            CustomDataset::new("biomedical_ner", Task::NER)
                .with_languages(&["en"])
                .with_domain(Domain::Biomedical)
                .with_entity_types(&["DISEASE", "DRUG"]),
        );
        registry.register(
            CustomDataset::new("news_coref", Task::IntraDocCoref)
                .with_languages(&["en", "de"])
                .with_domain(Domain::News),
        );
        registry.register(
            CustomDataset::new("sanskrit_edl", Task::NED)
                .with_languages(&["sa"])
                .with_domain(Domain::Literary)
                .historical(),
        );

        // Test by_domain
        let bio: Vec<_> = registry.by_domain(Domain::Biomedical).collect();
        assert_eq!(bio.len(), 1);
        assert_eq!(bio[0].id(), "biomedical_ner");

        // Test by_language
        let german: Vec<_> = registry.by_language("de").collect();
        assert_eq!(german.len(), 1);
        assert_eq!(german[0].id(), "news_coref");

        // Test historical
        let historical: Vec<_> = registry.historical().collect();
        assert_eq!(historical.len(), 1);
        assert_eq!(historical[0].id(), "sanskrit_edl");

        // Test with_entity_type
        let disease: Vec<_> = registry.with_entity_type("DISEASE").collect();
        assert_eq!(disease.len(), 1);
    }

    // summary() should count by task and deduplicate languages.
    #[test]
    fn test_registry_summary() {
        let mut registry = DatasetRegistry::new();
        registry.register(CustomDataset::new("a", Task::NER).with_languages(&["en"]));
        registry.register(CustomDataset::new("b", Task::NER).with_languages(&["de"]));
        registry.register(CustomDataset::new("c", Task::IntraDocCoref).with_languages(&["en"]));

        let summary = registry.summary();
        assert_eq!(summary.total, 3);
        assert_eq!(summary.by_task.get(&Task::NER), Some(&2));
        assert_eq!(summary.by_task.get(&Task::IntraDocCoref), Some(&1));
        assert!(summary.languages.contains(&"en".to_string()));
        assert!(summary.languages.contains(&"de".to_string()));
    }

    // End-to-end builder smoke test combining many optional fields.
    #[test]
    fn test_historical_custom_dataset_smoke() {
        // Keep tests generic; dataset-specific examples belong in the dataset registry.
        let ds = CustomDataset::new("historical_edl", Task::NED)
            .with_name("Historical EDL (example)")
            .with_languages(&["sa"])
            .with_entity_types(&["Person", "Location"])
            .with_parser(ParserHint::CoNLLU)
            .with_license(License::CCBY)
            .with_domain(Domain::Literary)
            .with_secondary_tasks(vec![Task::IntraDocCoref, Task::NER])
            .with_stats(DatasetStats {
                doc_count: Some(10),
                mention_count: Some(100),
                ..Default::default()
            })
            .with_citation("Example citation")
            .historical();

        assert_eq!(ds.task(), Task::NED);
        assert!(ds.supports_language("sa"));
        assert!(ds.is_historical());
        assert!(ds.is_public());
    }

    // Display impls: fixed variants use their name, Other passes through.
    #[test]
    fn test_domain_display() {
        assert_eq!(format!("{}", Domain::Biomedical), "Biomedical");
        assert_eq!(format!("{}", Domain::Literary), "Literary");
        assert_eq!(format!("{}", Domain::Other("custom".into())), "custom");
    }

    #[test]
    fn test_license_display() {
        assert_eq!(format!("{}", License::CCBY), "CC BY 4.0");
        assert_eq!(format!("{}", License::MIT), "MIT");
        assert_eq!(format!("{}", License::LDC), "LDC");
    }

    // Plain-struct construction checks for the metadata value types.
    #[test]
    fn test_temporal_coverage() {
        let cov = TemporalCoverage {
            start_year: Some(2010),
            end_year: Some(2020),
            has_temporal_annotations: true,
            has_diachronic_entities: false,
        };

        assert_eq!(cov.start_year, Some(2010));
        assert!(cov.has_temporal_annotations);
    }

    #[test]
    fn test_split_sizes() {
        let splits = SplitSizes {
            train: 1000,
            dev: 100,
            test: 200,
        };

        assert_eq!(splits.train + splits.dev + splits.test, 1300);
    }
}