Skip to main content

entrenar/cli/commands/
init.rs

//! Init command implementation
//!
//! Supports smart initialization with:
//! - `--base` for HF model IDs (the model size is estimated from the name and
//!   used to suggest a LoRA rank and learning rate)
//! - `--method` for training method selection (lora, qlora, full)
//! - Data format auto-detection (JSONL, Parquet, text, CSV)
7
8use crate::cli::logging::log;
9use crate::cli::LogLevel;
10use crate::config::{InitArgs, InitTemplate, TrainingMethod};
11use crate::yaml_mode::{generate_yaml, Template};
12
/// Estimated model size category, used to pick a suggested LoRA rank and
/// learning rate.
///
/// Variants are declared from smallest to largest, so the derived ordering
/// satisfies `Small < Medium < Large < XLarge`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ModelSize {
    /// Fewer than 1B parameters
    Small,
    /// 1B up to and including 7B parameters
    Medium,
    /// More than 7B, up to and including 30B parameters
    Large,
    /// More than 30B parameters
    XLarge,
}
25
/// Data format detected for a training data path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFormat {
    /// JSON Lines (e.g. `.jsonl`)
    Jsonl,
    /// Apache Parquet (e.g. `.parquet`)
    Parquet,
    /// Delimited text (e.g. `.csv`, `.tsv`)
    Csv,
    /// Plain text (e.g. `.txt`)
    Text,
    /// No recognized extension was found
    Unknown,
}
35
36/// Estimate model size from HF model name patterns
37///
38/// Parses common naming conventions like:
39/// - "Qwen/Qwen2.5-Coder-0.5B" -> Small
40/// - "meta-llama/Llama-3-7B" -> Medium
41/// - "meta-llama/Llama-3-13B" -> Large
42/// - "meta-llama/Llama-3-70B" -> XLarge
43pub fn estimate_model_size(model_name: &str) -> ModelSize {
44    // Extract the part after the last "/" (the model name itself)
45    let name = model_name.rsplit('/').next().unwrap_or(model_name);
46    let name_upper = name.to_uppercase();
47
48    // Split by '-' and '_' only (not '.') to preserve decimal numbers like "0.5B"
49    for segment in name_upper.split(['-', '_']) {
50        if let Some(stripped) = segment.strip_suffix('B') {
51            if let Ok(size) = stripped.parse::<f64>() {
52                return categorize_param_count(size);
53            }
54        }
55    }
56
57    // Default to medium if we can't determine
58    ModelSize::Medium
59}
60
61fn categorize_param_count(billions: f64) -> ModelSize {
62    if billions < 1.0 {
63        ModelSize::Small
64    } else if billions <= 7.0 {
65        ModelSize::Medium
66    } else if billions <= 30.0 {
67        ModelSize::Large
68    } else {
69        ModelSize::XLarge
70    }
71}
72
73/// Suggest LoRA rank based on model size
74pub fn suggest_lora_rank(size: ModelSize) -> u32 {
75    match size {
76        ModelSize::Small => 32,
77        ModelSize::Medium => 64,
78        ModelSize::Large => 128,
79        ModelSize::XLarge => 256,
80    }
81}
82
83/// Suggest learning rate based on model size
84pub fn suggest_learning_rate(size: ModelSize) -> f64 {
85    match size {
86        ModelSize::Small => 3e-4,
87        ModelSize::Medium => 2e-4,
88        ModelSize::Large => 1e-4,
89        ModelSize::XLarge => 5e-5,
90    }
91}
92
93/// Detect data format from a path (file or directory)
94pub fn detect_data_format(path: &str) -> DataFormat {
95    let path = std::path::Path::new(path);
96
97    // If it's a file, check the extension
98    if path.is_file() {
99        return format_from_extension(path);
100    }
101
102    // If it's a directory, scan for common data files
103    if path.is_dir() {
104        if let Ok(entries) = std::fs::read_dir(path) {
105            for entry in entries.flatten() {
106                let p = entry.path();
107                let fmt = format_from_extension(&p);
108                if fmt != DataFormat::Unknown {
109                    return fmt;
110                }
111            }
112        }
113    }
114
115    // Try extension-based detection for paths that don't exist yet
116    format_from_extension(path)
117}
118
119fn format_from_extension(path: &std::path::Path) -> DataFormat {
120    match path.extension().and_then(|e| e.to_str()) {
121        Some("jsonl" | "jsonlines") => DataFormat::Jsonl,
122        Some("parquet" | "pq") => DataFormat::Parquet,
123        Some("csv" | "tsv") => DataFormat::Csv,
124        Some("txt" | "text") => DataFormat::Text,
125        _ => DataFormat::Unknown,
126    }
127}
128
129impl std::fmt::Display for DataFormat {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            DataFormat::Jsonl => write!(f, "jsonl"),
133            DataFormat::Parquet => write!(f, "parquet"),
134            DataFormat::Csv => write!(f, "csv"),
135            DataFormat::Text => write!(f, "text"),
136            DataFormat::Unknown => write!(f, "unknown"),
137        }
138    }
139}
140
141impl std::fmt::Display for ModelSize {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            ModelSize::Small => write!(f, "small (<1B)"),
145            ModelSize::Medium => write!(f, "medium (1-7B)"),
146            ModelSize::Large => write!(f, "large (7-30B)"),
147            ModelSize::XLarge => write!(f, "xlarge (>30B)"),
148        }
149    }
150}
151
152pub fn run_init(args: InitArgs, level: LogLevel) -> Result<(), String> {
153    // Resolve model source: --base takes precedence over --model
154    let model_source = args.base.as_deref().or(args.model.as_deref());
155
156    // Resolve template: --method overrides --template
157    let template = if let Some(method) = &args.method {
158        match method {
159            TrainingMethod::Full => Template::Full,
160            TrainingMethod::Lora => Template::Lora,
161            TrainingMethod::Qlora => Template::Qlora,
162        }
163    } else {
164        match args.template {
165            InitTemplate::Minimal => Template::Minimal,
166            InitTemplate::Lora => Template::Lora,
167            InitTemplate::Qlora => Template::Qlora,
168            InitTemplate::Full => Template::Full,
169        }
170    };
171
172    // Detect model size if a base model is specified
173    let model_size = model_source.map(estimate_model_size);
174    let lora_rank = model_size.map(suggest_lora_rank);
175    let lr = model_size.map(suggest_learning_rate);
176
177    // Detect data format
178    let data_format = args.data.as_deref().map(detect_data_format);
179
180    // Log detected settings
181    if let Some(source) = model_source {
182        log(level, LogLevel::Normal, &format!("Model: {source}"));
183    }
184    if let Some(size) = model_size {
185        log(
186            level,
187            LogLevel::Normal,
188            &format!("Detected size: {size}, suggested LoRA rank: {}", lora_rank.unwrap_or(64)),
189        );
190    }
191    if let Some(fmt) = data_format {
192        if fmt != DataFormat::Unknown {
193            log(level, LogLevel::Normal, &format!("Data format: {fmt}"));
194        }
195    }
196
197    log(level, LogLevel::Normal, &format!("Generating {template:?} template for: {}", args.name));
198
199    // Generate YAML manifest with smart defaults
200    let yaml =
201        generate_yaml(template, &args.name, model_source, args.data.as_deref(), lora_rank, lr);
202
203    // Output to file or stdout
204    if let Some(output_path) = &args.output {
205        std::fs::write(output_path, &yaml).map_err(|e| format!("Failed to write file: {e}"))?;
206        log(level, LogLevel::Normal, &format!("Manifest saved to: {}", output_path.display()));
207    } else {
208        println!("{yaml}");
209    }
210
211    Ok(())
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every name in `names` is estimated as `expected`.
    fn assert_sizes(expected: ModelSize, names: &[&str]) {
        for name in names {
            assert_eq!(estimate_model_size(name), expected, "model name: {name}");
        }
    }

    /// Assert that every path in `paths` is detected as `expected`.
    fn assert_formats(expected: DataFormat, paths: &[&str]) {
        for path in paths {
            assert_eq!(detect_data_format(path), expected, "path: {path}");
        }
    }

    /// Build `InitArgs` that print the manifest to stdout (no output file).
    fn stdout_args(
        name: &str,
        template: InitTemplate,
        model: Option<&str>,
        base: Option<&str>,
        method: Option<TrainingMethod>,
        data: Option<&str>,
    ) -> InitArgs {
        InitArgs {
            name: name.to_string(),
            output: None,
            template,
            model: model.map(String::from),
            base: base.map(String::from),
            method,
            data: data.map(String::from),
        }
    }

    // === Model Size Detection ===

    #[test]
    fn test_estimate_model_size_small() {
        assert_sizes(ModelSize::Small, &["Qwen/Qwen2.5-Coder-0.5B", "microsoft/phi-2-0.3B"]);
    }

    #[test]
    fn test_estimate_model_size_medium() {
        assert_sizes(
            ModelSize::Medium,
            &["meta-llama/Llama-3-7B", "mistralai/Mistral-1.5B-Instruct"],
        );
    }

    #[test]
    fn test_estimate_model_size_large() {
        assert_sizes(ModelSize::Large, &["meta-llama/Llama-3-13B"]);
    }

    #[test]
    fn test_estimate_model_size_xlarge() {
        assert_sizes(ModelSize::XLarge, &["meta-llama/Llama-3-70B"]);
    }

    #[test]
    fn test_estimate_model_size_unknown_defaults_medium() {
        assert_sizes(ModelSize::Medium, &["some-org/some-model"]);
    }

    #[test]
    fn test_suggest_lora_rank() {
        let cases = [
            (ModelSize::Small, 32),
            (ModelSize::Medium, 64),
            (ModelSize::Large, 128),
            (ModelSize::XLarge, 256),
        ];
        for (size, rank) in cases.iter().copied() {
            assert_eq!(suggest_lora_rank(size), rank);
        }
    }

    #[test]
    fn test_suggest_learning_rate() {
        let cases = [
            (ModelSize::Small, 3e-4),
            (ModelSize::Medium, 2e-4),
            (ModelSize::Large, 1e-4),
            (ModelSize::XLarge, 5e-5),
        ];
        for (size, rate) in cases.iter().copied() {
            assert!((suggest_learning_rate(size) - rate).abs() < 1e-10);
        }
    }

    // === Data Format Detection ===

    #[test]
    fn test_detect_data_format_jsonl() {
        assert_formats(DataFormat::Jsonl, &["train.jsonl", "data/train.jsonlines"]);
    }

    #[test]
    fn test_detect_data_format_parquet() {
        assert_formats(DataFormat::Parquet, &["data.parquet", "data.pq"]);
    }

    #[test]
    fn test_detect_data_format_csv() {
        assert_formats(DataFormat::Csv, &["train.csv", "train.tsv"]);
    }

    #[test]
    fn test_detect_data_format_text() {
        assert_formats(DataFormat::Text, &["corpus.txt"]);
    }

    #[test]
    fn test_detect_data_format_unknown() {
        assert_formats(DataFormat::Unknown, &["data.bin", "./my-data/"]);
    }

    #[test]
    fn test_data_format_display() {
        let cases = [
            (DataFormat::Jsonl, "jsonl"),
            (DataFormat::Parquet, "parquet"),
            (DataFormat::Csv, "csv"),
            (DataFormat::Text, "text"),
            (DataFormat::Unknown, "unknown"),
        ];
        for (fmt, text) in cases.iter().copied() {
            assert_eq!(format!("{}", fmt), text);
        }
    }

    #[test]
    fn test_model_size_display() {
        let cases = [
            (ModelSize::Small, "small (<1B)"),
            (ModelSize::Medium, "medium (1-7B)"),
            (ModelSize::Large, "large (7-30B)"),
            (ModelSize::XLarge, "xlarge (>30B)"),
        ];
        for (size, text) in cases.iter().copied() {
            assert_eq!(format!("{}", size), text);
        }
    }

    // === Integration Tests ===

    #[test]
    fn test_run_init_with_base_flag() {
        let args = stdout_args(
            "test_project",
            InitTemplate::Minimal,
            None,
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            Some(TrainingMethod::Qlora),
            None,
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    #[test]
    fn test_run_init_method_overrides_template() {
        // The Minimal template should be overridden by the Lora method.
        let args = stdout_args(
            "test_project",
            InitTemplate::Minimal,
            None,
            Some("meta-llama/Llama-3-7B"),
            Some(TrainingMethod::Lora),
            Some("train.jsonl"),
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    #[test]
    fn test_run_init_base_overrides_model() {
        // --base should take precedence over --model
        let args = stdout_args(
            "test",
            InitTemplate::Lora,
            Some("local-model.safetensors"),
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            None,
            None,
        );

        assert!(run_init(args, LogLevel::Quiet).is_ok());
    }

    #[test]
    fn test_categorize_param_count() {
        let cases = [
            (0.5, ModelSize::Small),
            (1.0, ModelSize::Medium),
            (7.0, ModelSize::Medium),
            (13.0, ModelSize::Large),
            (30.0, ModelSize::Large),
            (70.0, ModelSize::XLarge),
        ];
        for (billions, expected) in cases.iter().copied() {
            assert_eq!(categorize_param_count(billions), expected);
        }
    }
}