use clap::{Parser, Subcommand};
use std::path::PathBuf;
use super::extended::{
AuditArgs, BenchArgs, CompletionArgs, ExperimentsArgs, FinetuneArgs, InspectArgs, MonitorArgs,
PublishArgs,
};
use super::init::InitArgs;
use super::quant_merge::{MergeArgs, QuantizeArgs};
use super::research::ResearchArgs;
use super::types::OutputFormat;
// Root command-line interface for the `entrenar` binary.
//
// These notes are plain `//` comments on purpose: clap's derive turns `///`
// doc comments into `--help` text, so keeping them as ordinary comments leaves
// the CLI's help output exactly as configured by the attributes below.
//
// `verbose` and `quiet` are marked `global`, so they are accepted both before
// and after the subcommand (e.g. `entrenar train cfg -v`).
#[derive(Parser, Debug, Clone, PartialEq)]
#[command(
    name = "entrenar",
    author = "PAIML",
    version,
    about = "Training & Optimization Library with autograd, LoRA, quantization, and model merging"
)]
pub struct Cli {
    // The subcommand to run. The field is not an `Option`, so clap requires one.
    #[command(subcommand)]
    pub command: Command,
    // `-v` / `--verbose` switch.
    #[arg(long, short, global = true)]
    pub verbose: bool,
    // `-q` / `--quiet` switch. No conflict with `--verbose` is declared, so both
    // may be set at once.
    #[arg(long, short, global = true)]
    pub quiet: bool,
}
// Top-level subcommands of `entrenar`.
//
// clap derives each subcommand name from its variant in kebab-case
// (`Train` -> `entrenar train`, `Experiments` -> `entrenar experiments`) and
// parses that subcommand's flags into the wrapped args struct.
//
// Plain `//` comments are used deliberately: clap would pick up `///` doc
// comments on these variants as user-visible `--help` text.
#[derive(Subcommand, Debug, Clone, PartialEq)]
pub enum Command {
    // Args structs defined in this module.
    Train(TrainArgs),
    Validate(ValidateArgs),
    Info(InfoArgs),
    // Args structs imported from sibling modules
    // (`init`, `quant_merge`, `research`, `extended`).
    Init(InitArgs),
    Quantize(QuantizeArgs),
    Merge(MergeArgs),
    Research(ResearchArgs),
    Completion(CompletionArgs),
    Bench(BenchArgs),
    Inspect(InspectArgs),
    Audit(AuditArgs),
    Monitor(MonitorArgs),
    Publish(PublishArgs),
    Finetune(FinetuneArgs),
    Experiments(ExperimentsArgs),
}
// Arguments for `entrenar train`.
//
// Only the positional `CONFIG` path is required; every other field is an
// optional override or switch. Flag names are spelled out explicitly below and
// match what clap would infer from the field names.
//
// `apply_overrides` copies `output_dir`, `epochs`, `batch_size`, `lr` and
// `save_every` into the loaded spec; `resume`, `dry_run`, `log_every` and
// `seed` are not consumed there.
//
// (Plain `//` comments: `///` on clap-derived items becomes `--help` text.)
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct TrainArgs {
    // Positional path to the training config.
    #[arg(value_name = "CONFIG")]
    pub config: PathBuf,
    // `-o` / `--output-dir <OUTPUT_DIR>`
    #[arg(short = 'o', long = "output-dir")]
    pub output_dir: Option<PathBuf>,
    // `-r` / `--resume <RESUME>`
    #[arg(short = 'r', long = "resume")]
    pub resume: Option<PathBuf>,
    // `-e` / `--epochs <EPOCHS>`
    #[arg(short = 'e', long = "epochs")]
    pub epochs: Option<usize>,
    // `-b` / `--batch-size <BATCH_SIZE>`
    #[arg(short = 'b', long = "batch-size")]
    pub batch_size: Option<usize>,
    // `-l` / `--lr <LR>`
    #[arg(short = 'l', long = "lr")]
    pub lr: Option<f32>,
    // `--dry-run` switch.
    #[arg(long = "dry-run")]
    pub dry_run: bool,
    // `--save-every <SAVE_EVERY>`
    #[arg(long = "save-every")]
    pub save_every: Option<usize>,
    // `--log-every <LOG_EVERY>`
    #[arg(long = "log-every")]
    pub log_every: Option<usize>,
    // `--seed <SEED>`
    #[arg(long = "seed")]
    pub seed: Option<u64>,
}
// Arguments for `entrenar validate`: a positional config path plus an
// optional `-d` / `--detailed` switch.
//
// (Plain `//` comments: `///` on clap-derived items becomes `--help` text.)
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct ValidateArgs {
    // Positional path to the config to check.
    #[arg(value_name = "CONFIG")]
    pub config: PathBuf,
    // `-d` / `--detailed` switch.
    #[arg(short = 'd', long = "detailed")]
    pub detailed: bool,
}
// Arguments for `entrenar info`: a positional config path and the output
// format, selected with `-f` / `--format` and parsed into `OutputFormat`
// (`"text"` when the flag is omitted).
//
// (Plain `//` comments: `///` on clap-derived items becomes `--help` text.)
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct InfoArgs {
    // Positional path to the config to describe.
    #[arg(value_name = "CONFIG")]
    pub config: PathBuf,
    // `-f` / `--format <FORMAT>`, default `text`.
    #[arg(short = 'f', long = "format", default_value = "text")]
    pub format: OutputFormat,
}
pub fn parse_args<I, T>(args: I) -> Result<Cli, clap::Error>
where
I: IntoIterator<Item = T>,
T: Into<std::ffi::OsString> + Clone,
{
Cli::try_parse_from(args)
}
pub fn apply_overrides(spec: &mut crate::config::TrainSpec, args: &TrainArgs) {
if let Some(output_dir) = &args.output_dir {
spec.training.output_dir = output_dir.clone();
}
if let Some(epochs) = args.epochs {
spec.training.epochs = epochs;
}
if let Some(batch_size) = args.batch_size {
spec.data.batch_size = batch_size;
}
if let Some(lr) = args.lr {
spec.optimizer.lr = lr;
}
if let Some(save_every) = args.save_every {
spec.training.save_interval = save_every;
}
}