use clap::{Parser, ValueEnum};
/// Hyper-parameters and checkpoint options handed to the training code.
///
/// Carries the same `learning_rate`, `load`, `save` and `epochs` names as the CLI [`Args`],
/// except that `learning_rate` is a concrete `f64` here rather than an `Option<f64>`.
/// NOTE(review): presumably built from [`Args`] after a fallback learning rate has been
/// chosen — confirm against the caller.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingArgs {
    /// Learning rate; always present here, unlike the optional CLI value.
    pub learning_rate: f64,
    /// Optional load location; `None` when nothing should be loaded.
    /// NOTE(review): presumably a path to previously saved weights — confirm.
    pub load: Option<String>,
    /// Optional save location; `None` when nothing should be saved.
    /// NOTE(review): presumably a path for the trained weights — confirm.
    pub save: Option<String>,
    /// Number of epochs to run.
    pub epochs: usize,
}
/// Model architecture selected by the positional `model` argument of `Args`.
///
/// Parsed by clap's `ValueEnum` derive, which accepts the kebab-case variant names
/// (`linear`, `mlp`) on the command line.
// Variant notes use `//` on purpose: `ValueEnum` turns `///` on variants into CLI help text.
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum)]
pub enum WhichModel {
    // CLI value `linear`; the default for `Args::model`.
    Linear,
    // CLI value `mlp`.
    Mlp,
}
/// Optimiser selected with the `--optim` flag of `Args`.
///
/// clap's `ValueEnum` derive converts variant names to kebab-case, so the accepted CLI
/// spellings are `adadelta`, `adagrad`, `adamax`, `sgd`, `n-adam`, `r-adam`, `rms` and
/// `adam` — note the hyphens in `n-adam` and `r-adam`.
// Variant notes use `//` on purpose: `ValueEnum` turns `///` on variants into CLI help text.
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum)]
pub enum WhichOptim {
    // Declaration order is the order clap lists the possible values in; keep it stable.
    // CLI value `adadelta`; the default for `Args::optim`.
    Adadelta,
    Adagrad,
    Adamax,
    Sgd,
    // CLI value `n-adam` (kebab-case split of `NAdam`).
    NAdam,
    // CLI value `r-adam` (kebab-case split of `RAdam`).
    RAdam,
    Rms,
    Adam,
}
// Command-line interface, parsed via clap's `Parser` derive.
//
// Plain `//` comments are used deliberately: on a `Parser` struct, `///` doc comments become
// the `--help` text, so these maintainer notes are kept out of the user-facing help.
// Long flag names are the kebab-case field names (e.g. `learning_rate` -> `--learning-rate`).
#[derive(Parser)]
pub struct Args {
    // Positional argument (no `long`): a `WhichModel` value, defaulting to `linear`.
    // Spelled `#[arg(...)]` like every other field (clap 4's preferred field attribute).
    #[arg(value_enum, default_value_t = WhichModel::Linear)]
    pub model: WhichModel,

    // `--optim <OPTIM>`: a `WhichOptim` value, defaulting to `adadelta`.
    #[arg(long, value_enum, default_value_t = WhichOptim::Adadelta)]
    pub optim: WhichOptim,

    // `--learning-rate <F64>`: optional; `None` when the flag is omitted, so the caller
    // has to supply its own fallback.
    #[arg(long)]
    pub learning_rate: Option<f64>,

    // `--epochs <N>`: defaults to 200.
    #[arg(long, default_value_t = 200)]
    pub epochs: usize,

    // `--save <SAVE>`: optional. NOTE(review): presumably where trained weights are written.
    #[arg(long)]
    pub save: Option<String>,

    // `--load <LOAD>`: optional. NOTE(review): presumably weights to start from.
    #[arg(long)]
    pub load: Option<String>,

    // `--local-mnist <LOCAL_MNIST>`: optional. NOTE(review): presumably a local dataset
    // location used instead of fetching it — confirm with the data-loading code.
    #[arg(long)]
    pub local_mnist: Option<String>,
}