Skip to main content

entrenar/config/cli/
init.rs

1//! Init command types
2
3use clap::Parser;
4use std::path::PathBuf;
5
6/// Arguments for the init command
7#[derive(Parser, Debug, Clone, PartialEq)]
8pub struct InitArgs {
9    /// Template to use for initialization
10    #[arg(short, long, default_value = "minimal")]
11    pub template: InitTemplate,
12
13    /// Output path (stdout if not specified)
14    #[arg(short, long)]
15    pub output: Option<PathBuf>,
16
17    /// Experiment name
18    #[arg(long, default_value = "my-experiment")]
19    pub name: String,
20
21    /// Model source path or URI
22    #[arg(long)]
23    pub model: Option<String>,
24
25    /// Base model HF repo ID (e.g., Qwen/Qwen2.5-Coder-0.5B)
26    #[arg(long)]
27    pub base: Option<String>,
28
29    /// Training method (overrides --template)
30    #[arg(long)]
31    pub method: Option<TrainingMethod>,
32
33    /// Data source path or URI
34    #[arg(long)]
35    pub data: Option<String>,
36}
37
/// Training method selected with the `--method` flag.
///
/// Parsing (via [`std::str::FromStr`]) is case-insensitive, and formatting
/// (via [`std::fmt::Display`]) always yields the canonical lowercase name, so
/// `method.to_string().parse::<TrainingMethod>()` round-trips.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainingMethod {
    /// Full fine-tuning
    Full,
    /// LoRA fine-tuning
    Lora,
    /// QLoRA (quantized LoRA) fine-tuning
    Qlora,
}

impl std::str::FromStr for TrainingMethod {
    type Err = String;

    /// Parses `full`, `lora`, or `qlora`, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the valid
    /// methods when `s` matches none of them.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "full" => Ok(TrainingMethod::Full),
            "lora" => Ok(TrainingMethod::Lora),
            "qlora" => Ok(TrainingMethod::Qlora),
            _ => Err(format!("Unknown method: {s}. Valid methods: full, lora, qlora")),
        }
    }
}

impl std::fmt::Display for TrainingMethod {
    /// Writes the canonical lowercase method name.
    ///
    /// `Formatter::pad` is used instead of `write!` so width, fill,
    /// alignment and precision flags (e.g. `{:<8}`) are honored exactly as
    /// they are for a plain `&str`; `write!` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TrainingMethod::Full => "full",
            TrainingMethod::Lora => "lora",
            TrainingMethod::Qlora => "qlora",
        };
        f.pad(name)
    }
}
71
/// Init template selected with the `--template` flag.
///
/// Parsing (via [`std::str::FromStr`]) is case-insensitive and also accepts
/// the aliases `min` (for [`InitTemplate::Minimal`]) and `complete` (for
/// [`InitTemplate::Full`]). Formatting (via [`std::fmt::Display`]) always
/// yields the canonical lowercase name. [`InitTemplate::Minimal`] is the
/// `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum InitTemplate {
    /// Minimal manifest with required fields only
    #[default]
    Minimal,
    /// LoRA fine-tuning template
    Lora,
    /// QLoRA fine-tuning template
    Qlora,
    /// Full template with all sections
    Full,
}

impl std::str::FromStr for InitTemplate {
    type Err = String;

    /// Parses a template name, ignoring case.
    ///
    /// Accepted spellings: `minimal`/`min`, `lora`, `qlora`, `full`/`complete`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the canonical
    /// template names when `s` matches none of the accepted spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "minimal" | "min" => Ok(InitTemplate::Minimal),
            "lora" => Ok(InitTemplate::Lora),
            "qlora" => Ok(InitTemplate::Qlora),
            "full" | "complete" => Ok(InitTemplate::Full),
            _ => Err(format!("Unknown template: {s}. Valid templates: minimal, lora, qlora, full")),
        }
    }
}

impl std::fmt::Display for InitTemplate {
    /// Writes the canonical lowercase template name (never an alias).
    ///
    /// `Formatter::pad` is used instead of `write!` so width, fill,
    /// alignment and precision flags (e.g. `{:<10}`) are honored exactly as
    /// they are for a plain `&str`; `write!` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            InitTemplate::Minimal => "minimal",
            InitTemplate::Lora => "lora",
            InitTemplate::Qlora => "qlora",
            InitTemplate::Full => "full",
        };
        f.pad(name)
    }
}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as `T`, failing the test with a shared message when
    /// the parser rejects it.
    fn must_parse<T>(input: &str) -> T
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        input.parse::<T>().expect("parsing should succeed")
    }

    #[test]
    fn test_init_template_from_str() {
        // Canonical names plus the accepted aliases ("min", "complete").
        let accepted = [
            ("minimal", InitTemplate::Minimal),
            ("min", InitTemplate::Minimal),
            ("lora", InitTemplate::Lora),
            ("qlora", InitTemplate::Qlora),
            ("full", InitTemplate::Full),
            ("complete", InitTemplate::Full),
        ];
        for (input, expected) in accepted {
            assert_eq!(must_parse::<InitTemplate>(input), expected);
        }

        assert!("invalid".parse::<InitTemplate>().is_err());
    }

    #[test]
    fn test_init_template_display() {
        let rendered = [
            (InitTemplate::Minimal, "minimal"),
            (InitTemplate::Lora, "lora"),
            (InitTemplate::Qlora, "qlora"),
            (InitTemplate::Full, "full"),
        ];
        for (template, text) in rendered {
            assert_eq!(template.to_string(), text);
        }
    }

    #[test]
    fn test_init_template_default() {
        let fallback: InitTemplate = Default::default();
        assert_eq!(fallback, InitTemplate::Minimal);
    }

    #[test]
    fn test_training_method_from_str() {
        // Matching is case-insensitive, hence the upper-case "LORA" row.
        let accepted = [
            ("full", TrainingMethod::Full),
            ("lora", TrainingMethod::Lora),
            ("qlora", TrainingMethod::Qlora),
            ("LORA", TrainingMethod::Lora),
        ];
        for (input, expected) in accepted {
            assert_eq!(must_parse::<TrainingMethod>(input), expected);
        }

        assert!("invalid".parse::<TrainingMethod>().is_err());
    }

    #[test]
    fn test_training_method_display() {
        let rendered = [
            (TrainingMethod::Full, "full"),
            (TrainingMethod::Lora, "lora"),
            (TrainingMethod::Qlora, "qlora"),
        ];
        for (method, text) in rendered {
            assert_eq!(method.to_string(), text);
        }
    }
}