1use crate::cli::logging::log;
9use crate::cli::LogLevel;
10use crate::config::{InitArgs, InitTemplate, TrainingMethod};
11use crate::yaml_mode::{generate_yaml, Template};
12
/// Coarse parameter-count bucket for a model, used to pick training defaults.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the bucket can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelSize {
    /// Fewer than 1B parameters.
    Small,
    /// From 1B up to and including 7B parameters.
    Medium,
    /// Above 7B, up to and including 30B parameters.
    Large,
    /// Above 30B parameters.
    XLarge,
}
25
/// On-disk format of a training data source, inferred from a file extension.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the format can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFormat {
    /// JSON Lines (`.jsonl`, `.jsonlines`).
    Jsonl,
    /// Apache Parquet (`.parquet`, `.pq`).
    Parquet,
    /// Delimited text (`.csv`, `.tsv`).
    Csv,
    /// Plain text (`.txt`, `.text`).
    Text,
    /// No recognised extension.
    Unknown,
}
35
36pub fn estimate_model_size(model_name: &str) -> ModelSize {
44 let name = model_name.rsplit('/').next().unwrap_or(model_name);
46 let name_upper = name.to_uppercase();
47
48 for segment in name_upper.split(['-', '_']) {
50 if let Some(stripped) = segment.strip_suffix('B') {
51 if let Ok(size) = stripped.parse::<f64>() {
52 return categorize_param_count(size);
53 }
54 }
55 }
56
57 ModelSize::Medium
59}
60
61fn categorize_param_count(billions: f64) -> ModelSize {
62 if billions < 1.0 {
63 ModelSize::Small
64 } else if billions <= 7.0 {
65 ModelSize::Medium
66 } else if billions <= 30.0 {
67 ModelSize::Large
68 } else {
69 ModelSize::XLarge
70 }
71}
72
73pub fn suggest_lora_rank(size: ModelSize) -> u32 {
75 match size {
76 ModelSize::Small => 32,
77 ModelSize::Medium => 64,
78 ModelSize::Large => 128,
79 ModelSize::XLarge => 256,
80 }
81}
82
83pub fn suggest_learning_rate(size: ModelSize) -> f64 {
85 match size {
86 ModelSize::Small => 3e-4,
87 ModelSize::Medium => 2e-4,
88 ModelSize::Large => 1e-4,
89 ModelSize::XLarge => 5e-5,
90 }
91}
92
93pub fn detect_data_format(path: &str) -> DataFormat {
95 let path = std::path::Path::new(path);
96
97 if path.is_file() {
99 return format_from_extension(path);
100 }
101
102 if path.is_dir() {
104 if let Ok(entries) = std::fs::read_dir(path) {
105 for entry in entries.flatten() {
106 let p = entry.path();
107 let fmt = format_from_extension(&p);
108 if fmt != DataFormat::Unknown {
109 return fmt;
110 }
111 }
112 }
113 }
114
115 format_from_extension(path)
117}
118
119fn format_from_extension(path: &std::path::Path) -> DataFormat {
120 match path.extension().and_then(|e| e.to_str()) {
121 Some("jsonl" | "jsonlines") => DataFormat::Jsonl,
122 Some("parquet" | "pq") => DataFormat::Parquet,
123 Some("csv" | "tsv") => DataFormat::Csv,
124 Some("txt" | "text") => DataFormat::Text,
125 _ => DataFormat::Unknown,
126 }
127}
128
129impl std::fmt::Display for DataFormat {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 DataFormat::Jsonl => write!(f, "jsonl"),
133 DataFormat::Parquet => write!(f, "parquet"),
134 DataFormat::Csv => write!(f, "csv"),
135 DataFormat::Text => write!(f, "text"),
136 DataFormat::Unknown => write!(f, "unknown"),
137 }
138 }
139}
140
141impl std::fmt::Display for ModelSize {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 ModelSize::Small => write!(f, "small (<1B)"),
145 ModelSize::Medium => write!(f, "medium (1-7B)"),
146 ModelSize::Large => write!(f, "large (7-30B)"),
147 ModelSize::XLarge => write!(f, "xlarge (>30B)"),
148 }
149 }
150}
151
152pub fn run_init(args: InitArgs, level: LogLevel) -> Result<(), String> {
153 let model_source = args.base.as_deref().or(args.model.as_deref());
155
156 let template = if let Some(method) = &args.method {
158 match method {
159 TrainingMethod::Full => Template::Full,
160 TrainingMethod::Lora => Template::Lora,
161 TrainingMethod::Qlora => Template::Qlora,
162 }
163 } else {
164 match args.template {
165 InitTemplate::Minimal => Template::Minimal,
166 InitTemplate::Lora => Template::Lora,
167 InitTemplate::Qlora => Template::Qlora,
168 InitTemplate::Full => Template::Full,
169 }
170 };
171
172 let model_size = model_source.map(estimate_model_size);
174 let lora_rank = model_size.map(suggest_lora_rank);
175 let lr = model_size.map(suggest_learning_rate);
176
177 let data_format = args.data.as_deref().map(detect_data_format);
179
180 if let Some(source) = model_source {
182 log(level, LogLevel::Normal, &format!("Model: {source}"));
183 }
184 if let Some(size) = model_size {
185 log(
186 level,
187 LogLevel::Normal,
188 &format!("Detected size: {size}, suggested LoRA rank: {}", lora_rank.unwrap_or(64)),
189 );
190 }
191 if let Some(fmt) = data_format {
192 if fmt != DataFormat::Unknown {
193 log(level, LogLevel::Normal, &format!("Data format: {fmt}"));
194 }
195 }
196
197 log(level, LogLevel::Normal, &format!("Generating {template:?} template for: {}", args.name));
198
199 let yaml =
201 generate_yaml(template, &args.name, model_source, args.data.as_deref(), lora_rank, lr);
202
203 if let Some(output_path) = &args.output {
205 std::fs::write(output_path, &yaml).map_err(|e| format!("Failed to write file: {e}"))?;
206 log(level, LogLevel::Normal, &format!("Manifest saved to: {}", output_path.display()));
207 } else {
208 println!("{yaml}");
209 }
210
211 Ok(())
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `InitArgs` with no output path, so `run_init` prints to stdout.
    fn init_args(
        name: &str,
        template: InitTemplate,
        model: Option<&str>,
        base: Option<&str>,
        method: Option<TrainingMethod>,
        data: Option<&str>,
    ) -> InitArgs {
        InitArgs {
            name: name.to_string(),
            output: None,
            template,
            model: model.map(String::from),
            base: base.map(String::from),
            method,
            data: data.map(String::from),
        }
    }

    // --- model size estimation ---

    #[test]
    fn test_estimate_model_size_small() {
        for name in ["Qwen/Qwen2.5-Coder-0.5B", "microsoft/phi-2-0.3B"] {
            assert_eq!(estimate_model_size(name), ModelSize::Small, "{name}");
        }
    }

    #[test]
    fn test_estimate_model_size_medium() {
        for name in ["meta-llama/Llama-3-7B", "mistralai/Mistral-1.5B-Instruct"] {
            assert_eq!(estimate_model_size(name), ModelSize::Medium, "{name}");
        }
    }

    #[test]
    fn test_estimate_model_size_large() {
        let size = estimate_model_size("meta-llama/Llama-3-13B");
        assert_eq!(size, ModelSize::Large);
    }

    #[test]
    fn test_estimate_model_size_xlarge() {
        let size = estimate_model_size("meta-llama/Llama-3-70B");
        assert_eq!(size, ModelSize::XLarge);
    }

    #[test]
    fn test_estimate_model_size_unknown_defaults_medium() {
        let size = estimate_model_size("some-org/some-model");
        assert_eq!(size, ModelSize::Medium);
    }

    // --- suggestions ---

    #[test]
    fn test_suggest_lora_rank() {
        let expected = [
            (ModelSize::Small, 32),
            (ModelSize::Medium, 64),
            (ModelSize::Large, 128),
            (ModelSize::XLarge, 256),
        ];
        for (size, rank) in expected {
            assert_eq!(suggest_lora_rank(size), rank, "{size:?}");
        }
    }

    #[test]
    fn test_suggest_learning_rate() {
        let expected = [
            (ModelSize::Small, 3e-4),
            (ModelSize::Medium, 2e-4),
            (ModelSize::Large, 1e-4),
            (ModelSize::XLarge, 5e-5),
        ];
        for (size, rate) in expected {
            assert!((suggest_learning_rate(size) - rate).abs() < 1e-10, "{size:?}");
        }
    }

    // --- data format detection ---

    #[test]
    fn test_detect_data_format_jsonl() {
        for path in ["train.jsonl", "data/train.jsonlines"] {
            assert_eq!(detect_data_format(path), DataFormat::Jsonl, "{path}");
        }
    }

    #[test]
    fn test_detect_data_format_parquet() {
        for path in ["data.parquet", "data.pq"] {
            assert_eq!(detect_data_format(path), DataFormat::Parquet, "{path}");
        }
    }

    #[test]
    fn test_detect_data_format_csv() {
        for path in ["train.csv", "train.tsv"] {
            assert_eq!(detect_data_format(path), DataFormat::Csv, "{path}");
        }
    }

    #[test]
    fn test_detect_data_format_text() {
        let fmt = detect_data_format("corpus.txt");
        assert_eq!(fmt, DataFormat::Text);
    }

    #[test]
    fn test_detect_data_format_unknown() {
        for path in ["data.bin", "./my-data/"] {
            assert_eq!(detect_data_format(path), DataFormat::Unknown, "{path}");
        }
    }

    // --- Display implementations ---

    #[test]
    fn test_data_format_display() {
        let expected = [
            (DataFormat::Jsonl, "jsonl"),
            (DataFormat::Parquet, "parquet"),
            (DataFormat::Csv, "csv"),
            (DataFormat::Text, "text"),
            (DataFormat::Unknown, "unknown"),
        ];
        for (fmt, text) in expected {
            assert_eq!(format!("{}", fmt), text);
        }
    }

    #[test]
    fn test_model_size_display() {
        let expected = [
            (ModelSize::Small, "small (<1B)"),
            (ModelSize::Medium, "medium (1-7B)"),
            (ModelSize::Large, "large (7-30B)"),
            (ModelSize::XLarge, "xlarge (>30B)"),
        ];
        for (size, text) in expected {
            assert_eq!(format!("{}", size), text);
        }
    }

    // --- run_init ---

    #[test]
    fn test_run_init_with_base_flag() {
        let args = init_args(
            "test_project",
            InitTemplate::Minimal,
            None,
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            Some(TrainingMethod::Qlora),
            None,
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    #[test]
    fn test_run_init_method_overrides_template() {
        let args = init_args(
            "test_project",
            InitTemplate::Minimal,
            None,
            Some("meta-llama/Llama-3-7B"),
            Some(TrainingMethod::Lora),
            Some("train.jsonl"),
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    #[test]
    fn test_run_init_base_overrides_model() {
        let args = init_args(
            "test",
            InitTemplate::Lora,
            Some("local-model.safetensors"),
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            None,
            None,
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    // --- bucket thresholds ---

    #[test]
    fn test_categorize_param_count() {
        let expected = [
            (0.5, ModelSize::Small),
            (1.0, ModelSize::Medium),
            (7.0, ModelSize::Medium),
            (13.0, ModelSize::Large),
            (30.0, ModelSize::Large),
            (70.0, ModelSize::XLarge),
        ];
        for (billions, size) in expected {
            assert_eq!(categorize_param_count(billions), size, "{billions}");
        }
    }
}