// entrenar/config/cli/init.rs
use std::path::PathBuf;

use clap::Parser;
6#[derive(Parser, Debug, Clone, PartialEq)]
8pub struct InitArgs {
9 #[arg(short, long, default_value = "minimal")]
11 pub template: InitTemplate,
12
13 #[arg(short, long)]
15 pub output: Option<PathBuf>,
16
17 #[arg(long, default_value = "my-experiment")]
19 pub name: String,
20
21 #[arg(long)]
23 pub model: Option<String>,
24
25 #[arg(long)]
27 pub base: Option<String>,
28
29 #[arg(long)]
31 pub method: Option<TrainingMethod>,
32
33 #[arg(long)]
35 pub data: Option<String>,
36}
37
/// Training method selected with `--method`.
///
/// Parsed from `full`, `lora`, or `qlora`, ignoring ASCII case. [`Display`](std::fmt::Display)
/// renders the lowercase canonical name, so formatting and parsing round-trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainingMethod {
    /// Canonical name `full`.
    Full,
    /// Canonical name `lora`.
    Lora,
    /// Canonical name `qlora`.
    Qlora,
}

impl TrainingMethod {
    /// Every variant, in the order they are listed in the parse error message.
    const ALL: [Self; 3] = [Self::Full, Self::Lora, Self::Qlora];

    /// Returns the canonical lowercase name, as accepted by
    /// [`FromStr`](std::str::FromStr) and produced by [`Display`](std::fmt::Display).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Full => "full",
            Self::Lora => "lora",
            Self::Qlora => "qlora",
        }
    }
}

impl std::str::FromStr for TrainingMethod {
    type Err = String;

    /// Parses a method name, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the valid methods.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // All canonical names are ASCII, so an ASCII case-insensitive comparison
        // avoids allocating a lowercased copy of the input.
        Self::ALL
            .iter()
            .copied()
            .find(|method| method.as_str().eq_ignore_ascii_case(s))
            .ok_or_else(|| format!("Unknown method: {s}. Valid methods: full, lora, qlora"))
    }
}

impl std::fmt::Display for TrainingMethod {
    /// Writes the canonical name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width, alignment and precision flags.
        f.pad(self.as_str())
    }
}

/// Configuration template selected with `--template`.
///
/// Parsed ignoring ASCII case; `min` is accepted as an alias for `minimal` and
/// `complete` as an alias for `full`. [`Display`](std::fmt::Display) always renders
/// the canonical name, so formatting and parsing round-trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum InitTemplate {
    /// Canonical name `minimal` (alias `min`); this is the `Default`.
    #[default]
    Minimal,
    /// Canonical name `lora`.
    Lora,
    /// Canonical name `qlora`.
    Qlora,
    /// Canonical name `full` (alias `complete`).
    Full,
}

impl InitTemplate {
    /// Every accepted spelling (compared ignoring ASCII case) and the template it selects.
    const NAMES: [(&'static str, Self); 6] = [
        ("minimal", Self::Minimal),
        ("min", Self::Minimal),
        ("lora", Self::Lora),
        ("qlora", Self::Qlora),
        ("full", Self::Full),
        ("complete", Self::Full),
    ];

    /// Returns the canonical lowercase name, as produced by [`Display`](std::fmt::Display)
    /// and accepted by [`FromStr`](std::str::FromStr).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Minimal => "minimal",
            Self::Lora => "lora",
            Self::Qlora => "qlora",
            Self::Full => "full",
        }
    }
}

impl std::str::FromStr for InitTemplate {
    type Err = String;

    /// Parses a template name or alias, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the valid templates.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // All accepted spellings are ASCII, so an ASCII case-insensitive comparison
        // avoids allocating a lowercased copy of the input.
        Self::NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, template)| template)
            .ok_or_else(|| {
                format!("Unknown template: {s}. Valid templates: minimal, lora, qlora, full")
            })
    }
}

impl std::fmt::Display for InitTemplate {
    /// Writes the canonical name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width, alignment and precision flags.
        f.pad(self.as_str())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const ALL_TEMPLATES: [InitTemplate; 4] = [
        InitTemplate::Minimal,
        InitTemplate::Lora,
        InitTemplate::Qlora,
        InitTemplate::Full,
    ];

    const ALL_METHODS: [TrainingMethod; 3] =
        [TrainingMethod::Full, TrainingMethod::Lora, TrainingMethod::Qlora];

    #[test]
    fn test_init_template_from_str() {
        let cases = [
            ("minimal", InitTemplate::Minimal),
            ("min", InitTemplate::Minimal),
            ("lora", InitTemplate::Lora),
            ("qlora", InitTemplate::Qlora),
            ("full", InitTemplate::Full),
            ("complete", InitTemplate::Full),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<InitTemplate>(), Ok(expected), "input: {input:?}");
        }
        assert!("invalid".parse::<InitTemplate>().is_err());
    }

    #[test]
    fn test_init_template_from_str_ignores_case() {
        let cases = [
            ("MINIMAL", InitTemplate::Minimal),
            ("Min", InitTemplate::Minimal),
            ("LoRA", InitTemplate::Lora),
            ("QLORA", InitTemplate::Qlora),
            ("Full", InitTemplate::Full),
            ("COMPLETE", InitTemplate::Full),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<InitTemplate>(), Ok(expected), "input: {input:?}");
        }
    }

    #[test]
    fn test_init_template_from_str_rejects_unknown() {
        assert_eq!(
            "invalid".parse::<InitTemplate>(),
            Err("Unknown template: invalid. Valid templates: minimal, lora, qlora, full"
                .to_string())
        );
        // Empty input and partial names are not accepted as aliases.
        for input in ["", "mini", "lo"] {
            assert!(input.parse::<InitTemplate>().is_err(), "input: {input:?}");
        }
    }

    #[test]
    fn test_init_template_display() {
        let cases = [
            (InitTemplate::Minimal, "minimal"),
            (InitTemplate::Lora, "lora"),
            (InitTemplate::Qlora, "qlora"),
            (InitTemplate::Full, "full"),
        ];
        for (template, expected) in cases {
            assert_eq!(format!("{template}"), expected);
        }
    }

    #[test]
    fn test_init_template_display_round_trips() {
        for template in ALL_TEMPLATES {
            assert_eq!(template.to_string().parse::<InitTemplate>(), Ok(template));
        }
    }

    #[test]
    fn test_init_template_default() {
        assert_eq!(InitTemplate::default(), InitTemplate::Minimal);
    }

    #[test]
    fn test_training_method_from_str() {
        let cases = [
            ("full", TrainingMethod::Full),
            ("lora", TrainingMethod::Lora),
            ("qlora", TrainingMethod::Qlora),
            ("LORA", TrainingMethod::Lora),
            ("QLoRA", TrainingMethod::Qlora),
            ("Full", TrainingMethod::Full),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<TrainingMethod>(), Ok(expected), "input: {input:?}");
        }
        assert!("invalid".parse::<TrainingMethod>().is_err());
    }

    #[test]
    fn test_training_method_from_str_rejects_unknown() {
        assert_eq!(
            "invalid".parse::<TrainingMethod>(),
            Err("Unknown method: invalid. Valid methods: full, lora, qlora".to_string())
        );
        assert!("".parse::<TrainingMethod>().is_err());
    }

    #[test]
    fn test_training_method_display() {
        let cases = [
            (TrainingMethod::Full, "full"),
            (TrainingMethod::Lora, "lora"),
            (TrainingMethod::Qlora, "qlora"),
        ];
        for (method, expected) in cases {
            assert_eq!(format!("{method}"), expected);
        }
    }

    #[test]
    fn test_training_method_display_round_trips() {
        for method in ALL_METHODS {
            assert_eq!(method.to_string().parse::<TrainingMethod>(), Ok(method));
        }
    }

    #[test]
    fn test_init_args_defaults() {
        let args = InitArgs::try_parse_from(["init"]).expect("no arguments should parse");
        assert_eq!(
            args,
            InitArgs {
                template: InitTemplate::Minimal,
                output: None,
                name: "my-experiment".to_string(),
                model: None,
                base: None,
                method: None,
                data: None,
            }
        );
    }

    #[test]
    fn test_init_args_parses_flags() {
        let args = InitArgs::try_parse_from([
            "init", "-t", "QLoRA", "-o", "out.yaml", "--name", "exp", "--method", "Lora",
        ])
        .expect("valid flags should parse");
        assert_eq!(args.template, InitTemplate::Qlora);
        assert_eq!(args.output, Some(PathBuf::from("out.yaml")));
        assert_eq!(args.name, "exp");
        assert_eq!(args.method, Some(TrainingMethod::Lora));
    }

    #[test]
    fn test_init_args_rejects_unknown_values() {
        assert!(InitArgs::try_parse_from(["init", "--method", "dora"]).is_err());
        assert!(InitArgs::try_parse_from(["init", "--template", "huge"]).is_err());
    }
}