// entrenar/yaml_mode/manifest/data.rs
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::shorthand::deserialize_human_usize_opt;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DataConfig {
13 #[serde(default, skip_serializing_if = "Option::is_none")]
15 pub source: Option<String>,
16
17 #[serde(default, skip_serializing_if = "Option::is_none")]
19 pub format: Option<String>,
20
21 #[serde(default, skip_serializing_if = "Option::is_none")]
23 pub split: Option<DataSplit>,
24
25 #[serde(default, skip_serializing_if = "Option::is_none")]
27 pub train: Option<String>,
28
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub val: Option<String>,
32
33 #[serde(default, skip_serializing_if = "Option::is_none")]
35 pub test: Option<String>,
36
37 #[serde(default, skip_serializing_if = "Option::is_none")]
39 pub preprocessing: Option<Vec<PreprocessingStep>>,
40
41 #[serde(default, skip_serializing_if = "Option::is_none")]
43 pub augmentation: Option<Vec<HashMap<String, serde_json::Value>>>,
44
45 #[serde(default, skip_serializing_if = "Option::is_none")]
47 pub loader: Option<DataLoader>,
48
49 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub tokenizer: Option<String>,
53
54 #[serde(
56 default,
57 skip_serializing_if = "Option::is_none",
58 deserialize_with = "deserialize_human_usize_opt"
59 )]
60 pub seq_len: Option<usize>,
61
62 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub input_column: Option<String>,
65
66 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub output_column: Option<String>,
69
70 #[serde(
72 default,
73 skip_serializing_if = "Option::is_none",
74 deserialize_with = "deserialize_human_usize_opt"
75 )]
76 pub max_length: Option<usize>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct DataSplit {
82 pub train: f64,
84
85 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub val: Option<f64>,
88
89 #[serde(default, skip_serializing_if = "Option::is_none")]
91 pub test: Option<f64>,
92
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub stratify: Option<String>,
96
97 #[serde(default, skip_serializing_if = "Option::is_none")]
99 pub seed: Option<u64>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(untagged)]
105pub enum PreprocessingStep {
106 Normalize { normalize: NormalizeConfig },
108 Encode { encode: EncodeConfig },
110 Drop { drop: DropConfig },
112 FillNa { fillna: FillNaConfig },
114 Tokenize { tokenize: TokenizeConfig },
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NormalizeConfig {
121 pub columns: Vec<String>,
122 pub method: String,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct EncodeConfig {
128 pub columns: Vec<String>,
129 pub method: String,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct DropConfig {
135 pub columns: Vec<String>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct FillNaConfig {
141 pub strategy: String,
142 #[serde(default, skip_serializing_if = "Option::is_none")]
143 pub value: Option<serde_json::Value>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TokenizeConfig {
149 pub tokenizer: String,
150 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub max_length: Option<usize>,
152 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub padding: Option<String>,
154 #[serde(default, skip_serializing_if = "Option::is_none")]
155 pub truncation: Option<bool>,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct DataLoader {
161 pub batch_size: usize,
163
164 pub shuffle: bool,
166
167 #[serde(default, skip_serializing_if = "Option::is_none")]
169 pub num_workers: Option<usize>,
170
171 #[serde(default, skip_serializing_if = "Option::is_none")]
173 pub pin_memory: Option<bool>,
174
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub drop_last: Option<bool>,
178
179 #[serde(default, skip_serializing_if = "Option::is_none")]
181 pub prefetch_factor: Option<usize>,
182}