use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::shorthand::deserialize_human_usize_opt;
/// Dataset section of the configuration.
///
/// Every field is optional. A key absent from the input deserializes to `None`
/// (`default`), and `None` fields are omitted when the config is serialized
/// again (`skip_serializing_if`). Fields are serialized in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataConfig {
    /// Data source (presumably a path or URI — TODO confirm with the loader).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub source: Option<String>,

    /// Free-form format name; not validated by this type.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub format: Option<String>,

    /// Split settings, see [`DataSplit`].
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub split: Option<DataSplit>,

    /// Explicit training set (presumably a path — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub train: Option<String>,

    /// Explicit validation set (presumably a path — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub val: Option<String>,

    /// Explicit test set (presumably a path — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub test: Option<String>,

    /// Ordered list of typed preprocessing steps, see [`PreprocessingStep`].
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub preprocessing: Option<Vec<PreprocessingStep>>,

    /// Augmentation steps, each an arbitrary string-keyed map of JSON values;
    /// no schema is enforced here.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub augmentation: Option<Vec<HashMap<String, serde_json::Value>>>,

    /// Loader settings, see [`DataLoader`].
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub loader: Option<DataLoader>,

    /// Tokenizer identifier; not validated by this type.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub tokenizer: Option<String>,

    /// Sequence length, parsed by `deserialize_human_usize_opt` so it accepts
    /// whatever shorthand that helper supports.
    #[serde(
        deserialize_with = "deserialize_human_usize_opt",
        skip_serializing_if = "Option::is_none",
        default
    )]
    pub seq_len: Option<usize>,

    /// Name of the input column (presumably in a tabular source — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub input_column: Option<String>,

    /// Name of the output column (presumably in a tabular source — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub output_column: Option<String>,

    /// Maximum length, parsed by `deserialize_human_usize_opt` like `seq_len`.
    #[serde(
        deserialize_with = "deserialize_human_usize_opt",
        skip_serializing_if = "Option::is_none",
        default
    )]
    pub max_length: Option<usize>,
}
/// How a single data source is partitioned.
///
/// Only `train` is required; the remaining keys default to `None` when absent
/// and are omitted on serialization when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataSplit {
    /// Training share (presumably a fraction in `0.0..=1.0` — not validated here).
    pub train: f64,

    /// Validation share, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub val: Option<f64>,

    /// Test share, if any.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub test: Option<f64>,

    /// Column to stratify on (presumably a column name — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub stratify: Option<String>,

    /// Seed for the split, if fixed.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub seed: Option<u64>,
}
/// One preprocessing step.
///
/// Represented untagged. Each variant is a map with a single key (`normalize`,
/// `encode`, `drop`, `fillna` or `tokenize`) whose value is that step's config.
/// On deserialization serde tries the variants in the order declared below and
/// keeps the first that succeeds, so the declaration order is significant.
///
/// NOTE(review): no `deny_unknown_fields` is set, so a map carrying several of
/// these keys likely resolves to the earliest-declared matching variant and
/// silently ignores the rest — confirm this is acceptable.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum PreprocessingStep {
    /// `{ normalize: { columns, method } }`
    Normalize { normalize: NormalizeConfig },
    /// `{ encode: { columns, method } }`
    Encode { encode: EncodeConfig },
    /// `{ drop: { columns } }`
    Drop { drop: DropConfig },
    /// `{ fillna: { strategy, value? } }`
    FillNa { fillna: FillNaConfig },
    /// `{ tokenize: { tokenizer, max_length?, padding?, truncation? } }`
    Tokenize { tokenize: TokenizeConfig },
}
/// Settings for a `normalize` preprocessing step; both keys are required.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NormalizeConfig {
    /// Columns the step applies to.
    pub columns: Vec<String>,
    /// Normalization method name; the accepted values are not checked here.
    pub method: String,
}
/// Settings for an `encode` preprocessing step; both keys are required.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EncodeConfig {
    /// Columns the step applies to.
    pub columns: Vec<String>,
    /// Encoding method name; the accepted values are not checked here.
    pub method: String,
}
/// Settings for a `drop` preprocessing step.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DropConfig {
    /// Columns to drop (required key).
    pub columns: Vec<String>,
}
/// Settings for a `fillna` preprocessing step.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FillNaConfig {
    /// Fill strategy name (required); the accepted values are not checked here.
    pub strategy: String,

    /// Fill value as arbitrary JSON, if the strategy needs one. Defaults to
    /// `None` when absent; an explicit `null` presumably also reads as `None`
    /// (serde `Option` handling — confirm for the config format in use).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub value: Option<serde_json::Value>,
}
/// Settings for a `tokenize` preprocessing step.
///
/// Only `tokenizer` is required; the other keys default to `None` when absent
/// and are omitted on serialization when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizeConfig {
    /// Tokenizer identifier; not validated by this type.
    pub tokenizer: String,

    /// Maximum token length.
    ///
    /// Parsed with `deserialize_human_usize_opt`, the same helper used by
    /// `DataConfig::max_length` and `DataConfig::seq_len`. A length written
    /// the same way is therefore accepted here too, instead of failing the
    /// whole untagged `PreprocessingStep` with an opaque "no variant matched"
    /// error. `default` is kept so a missing key still yields `None`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_human_usize_opt"
    )]
    pub max_length: Option<usize>,

    /// Padding mode name; the accepted values are not checked here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub padding: Option<String>,

    /// Whether to truncate; presumably paired with `max_length` — TODO confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub truncation: Option<bool>,
}
/// Batch-loading settings.
///
/// `batch_size` and `shuffle` are required keys. The rest default to `None`
/// when absent and are omitted on serialization when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DataLoader {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Whether to shuffle samples.
    pub shuffle: bool,

    /// Worker count, if overridden.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub num_workers: Option<usize>,

    /// Pin-memory flag, if overridden.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub pin_memory: Option<bool>,

    /// Whether to drop the final short batch, if overridden.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub drop_last: Option<bool>,

    /// Prefetch factor, if overridden.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prefetch_factor: Option<usize>,
}