use serde::Deserialize;
use std::path::PathBuf;
/// Top-level configuration for a fit run, deserialized with serde.
///
/// The `input`, `model` and `output` sections have no serde default and must be
/// present; `forecast` may be omitted entirely, in which case
/// `ForecastConfig::default()` is used.
#[derive(Deserialize, Debug)]
pub struct FitConfig {
    /// The `input` section: the data source.
    pub input: InputConfig,
    /// The `model` section: fitting hyper-parameters.
    pub model: ModelConfig,
    /// The `output` section: artifact location and output options.
    pub output: OutputConfig,
    /// The optional `forecast` section; filled from `Default` when absent.
    #[serde(default)]
    pub forecast: ForecastConfig,
}
/// The `input` section of a [`FitConfig`]-style file.
#[derive(Deserialize, Debug)]
pub struct InputConfig {
    /// Filesystem path of the input data (required).
    pub path: PathBuf,
    /// Free-form data-type tag (required); its interpretation lives elsewhere.
    /// NOTE(review): if the accepted values are a fixed set, an enum would
    /// reject typos at parse time.
    pub data_type: String,
    /// Which data to select; `default_selection` supplies the value when the
    /// key is missing.
    #[serde(default = "default_selection")]
    pub selection: String,
}
/// Serde default for `InputConfig::selection`: the string `"latest"`.
fn default_selection() -> String {
    String::from("latest")
}
/// The `model` section: fitting hyper-parameters.
///
/// Every field is optional; a missing key is filled by the `default_*`
/// function named in its serde attribute.
#[derive(Deserialize, Debug)]
pub struct ModelConfig {
    /// Iteration cap for the fitting loop (see `default_max_iter`).
    #[serde(default = "default_max_iter")]
    pub max_iter: usize,
    /// Stopping tolerance for the fitting loop (see `default_tolerance`).
    #[serde(default = "default_tolerance")]
    pub tolerance: f64,
    /// Step size used by the fitter (see `default_learning_rate`).
    #[serde(default = "default_learning_rate")]
    pub learning_rate: f64,
    /// Training fraction of the data (see `default_train_ratio`); presumably
    /// the remainder is held out — confirm against the fitter.
    #[serde(default = "default_train_ratio")]
    pub train_ratio: f64,
    /// Gap threshold, in seconds per the field name (see
    /// `default_gap_threshold`).
    #[serde(default = "default_gap_threshold")]
    pub gap_threshold_secs: f64,
}
/// Serde default for `ModelConfig::max_iter`.
fn default_max_iter() -> usize {
    const FIFTY_THOUSAND: usize = 50_000;
    FIFTY_THOUSAND
}
/// Serde default for `ModelConfig::tolerance`.
///
/// Deliberately `1e-3`, not `1e3`. A stopping tolerance has to be small
/// relative to the quantity it is compared against. A threshold of 1000 would
/// end the fitting loop almost at once, which would make the 50 000-iteration
/// budget and the `1e-2` learning rate pointless.
/// NOTE(review): assumes `tolerance` is a convergence threshold on the
/// per-iteration change — confirm against the fitter.
fn default_tolerance() -> f64 {
    1e-3
}
/// Serde default for `ModelConfig::learning_rate` (one hundredth).
fn default_learning_rate() -> f64 {
    0.01
}
/// Serde default for `ModelConfig::train_ratio`.
fn default_train_ratio() -> f64 {
    const EIGHTY_PERCENT: f64 = 0.8;
    EIGHTY_PERCENT
}
/// Serde default for `ModelConfig::gap_threshold_secs`.
fn default_gap_threshold() -> f64 {
    5.0_f64
}
/// The `output` section: where artifacts go and what they contain.
#[derive(Deserialize, Debug)]
pub struct OutputConfig {
    /// Directory for written artifacts (required).
    pub artifact_dir: PathBuf,
    /// Read from the `format` key (falls back to `default_format`). The
    /// leading underscore on the Rust name suggests it is not consumed yet.
    #[serde(rename = "format")]
    #[serde(default = "default_format")]
    pub _format: String,
    /// Whether to include diagnostics (falls back to `default_include_diag`).
    #[serde(default = "default_include_diag")]
    pub include_diagnostics: bool,
}
/// Serde default for `OutputConfig::_format` (the `format` key): `"json"`.
fn default_format() -> String {
    String::from("json")
}
/// Serde default for `OutputConfig::include_diagnostics`: diagnostics are on.
fn default_include_diag() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// The optional `forecast` section.
///
/// The container-level `#[serde(default)]` fills any missing key from
/// `ForecastConfig::default()`. The struct's defaults are therefore defined in
/// one place, the `Default` impl, rather than repeated on every field.
#[derive(Deserialize, Debug)]
#[serde(default)]
pub struct ForecastConfig {
    /// Number of Monte Carlo paths to draw.
    pub mc_paths: usize,
    /// Statistic used to summarise the Monte Carlo paths.
    pub mc_statistic: EnsembleStatistic,
}
impl Default for ForecastConfig {
fn default() -> Self {
Self {
mc_paths: default_mc_paths(),
mc_statistic: EnsembleStatistic::default(),
}
}
}
/// Default for `ForecastConfig::mc_paths`: a single path.
fn default_mc_paths() -> usize {
    const SINGLE_PATH: usize = 1;
    SINGLE_PATH
}
/// Summary statistic taken across an ensemble (presumably the `mc_paths`
/// Monte Carlo paths — confirm with the forecaster).
///
/// Deserialized from the lowercase names `"median"`, `"mean"`, `"p25"` and
/// `"p75"`; `Median` is the `Default`.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum EnsembleStatistic {
    /// The median; chosen when nothing is specified.
    #[default]
    Median,
    /// The arithmetic mean.
    Mean,
    /// Presumably the 25th percentile.
    P25,
    /// Presumably the 75th percentile.
    P75,
}