Skip to main content

entrenar/yaml_mode/manifest/
data.rs

1//! Data Configuration
2//!
3//! Contains all data-related configuration types for training manifests.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use super::shorthand::deserialize_human_usize_opt;
9
10/// Dataset configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DataConfig {
13    /// Data source URI (pacha://, hf://, s3://, or local path)
14    #[serde(default, skip_serializing_if = "Option::is_none")]
15    pub source: Option<String>,
16
17    /// Explicit format (auto-detected if omitted)
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub format: Option<String>,
20
21    /// Data split configuration
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub split: Option<DataSplit>,
24
25    /// Explicit training data path
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub train: Option<String>,
28
29    /// Explicit validation data path
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub val: Option<String>,
32
33    /// Explicit test data path
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub test: Option<String>,
36
37    /// Preprocessing pipeline
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub preprocessing: Option<Vec<PreprocessingStep>>,
40
41    /// Data augmentation pipeline
42    #[serde(default, skip_serializing_if = "Option::is_none")]
43    pub augmentation: Option<Vec<HashMap<String, serde_json::Value>>>,
44
45    /// DataLoader settings
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub loader: Option<DataLoader>,
48
49    // === LLM data fields (mirrors TrainSpec DataConfig) ===
50    /// Path to tokenizer.json (for transformer/LLM training)
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    pub tokenizer: Option<String>,
53
54    /// Sequence length (for transformers). Accepts shorthand: `"2K"` = 2048.
55    #[serde(
56        default,
57        skip_serializing_if = "Option::is_none",
58        deserialize_with = "deserialize_human_usize_opt"
59    )]
60    pub seq_len: Option<usize>,
61
62    /// Input text column name (for transformer mode)
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub input_column: Option<String>,
65
66    /// Output/target text column name (for transformer mode)
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub output_column: Option<String>,
69
70    /// Maximum tokenization length. Accepts shorthand: `"512"`, `"1K"`.
71    #[serde(
72        default,
73        skip_serializing_if = "Option::is_none",
74        deserialize_with = "deserialize_human_usize_opt"
75    )]
76    pub max_length: Option<usize>,
77}
78
79/// Data split ratios
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct DataSplit {
82    /// Training set ratio (0.0-1.0)
83    pub train: f64,
84
85    /// Validation set ratio (optional)
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub val: Option<f64>,
88
89    /// Test set ratio (optional)
90    #[serde(default, skip_serializing_if = "Option::is_none")]
91    pub test: Option<f64>,
92
93    /// Column name for stratified sampling
94    #[serde(default, skip_serializing_if = "Option::is_none")]
95    pub stratify: Option<String>,
96
97    /// Split seed (inherits global if omitted)
98    #[serde(default, skip_serializing_if = "Option::is_none")]
99    pub seed: Option<u64>,
100}
101
102/// Preprocessing step (normalize, encode, drop, fillna, tokenize)
103#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(untagged)]
105pub enum PreprocessingStep {
106    /// Normalization step
107    Normalize { normalize: NormalizeConfig },
108    /// Encoding step
109    Encode { encode: EncodeConfig },
110    /// Drop columns step
111    Drop { drop: DropConfig },
112    /// Fill NA step
113    FillNa { fillna: FillNaConfig },
114    /// Tokenization step
115    Tokenize { tokenize: TokenizeConfig },
116}
117
118/// Normalization configuration
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NormalizeConfig {
121    pub columns: Vec<String>,
122    pub method: String,
123}
124
125/// Encoding configuration
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct EncodeConfig {
128    pub columns: Vec<String>,
129    pub method: String,
130}
131
132/// Drop columns configuration
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct DropConfig {
135    pub columns: Vec<String>,
136}
137
138/// Fill NA configuration
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct FillNaConfig {
141    pub strategy: String,
142    #[serde(default, skip_serializing_if = "Option::is_none")]
143    pub value: Option<serde_json::Value>,
144}
145
146/// Tokenization configuration
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TokenizeConfig {
149    pub tokenizer: String,
150    #[serde(default, skip_serializing_if = "Option::is_none")]
151    pub max_length: Option<usize>,
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub padding: Option<String>,
154    #[serde(default, skip_serializing_if = "Option::is_none")]
155    pub truncation: Option<bool>,
156}
157
158/// DataLoader settings
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct DataLoader {
161    /// Batch size
162    pub batch_size: usize,
163
164    /// Shuffle data each epoch
165    pub shuffle: bool,
166
167    /// Number of worker processes
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub num_workers: Option<usize>,
170
171    /// Pin memory for GPU transfer
172    #[serde(default, skip_serializing_if = "Option::is_none")]
173    pub pin_memory: Option<bool>,
174
175    /// Drop incomplete last batch
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub drop_last: Option<bool>,
178
179    /// Prefetch factor
180    #[serde(default, skip_serializing_if = "Option::is_none")]
181    pub prefetch_factor: Option<usize>,
182}