Skip to main content

entrenar/yaml_mode/
templates.rs

1//! Template Generation for YAML Mode Training
2//!
3//! Generates starter manifests for common training scenarios.
4
/// Cosine-annealing period (`T_max`) used by the default scheduler, i.e. the total number
/// of steps before the schedule restarts.
const DEFAULT_COSINE_ANNEALING_T_MAX: usize = 10_000;
7
8use super::manifest::{
9    AlertConfig, CallbackConfig, CallbackType, ChartConfig, CheckpointConfig, DataConfig,
10    DataLoader, DataSplit, EarlyStoppingConfig, GradientConfig, LoraConfig, MetricsOutputConfig,
11    MixedPrecisionConfig, ModelConfig, ModelOutputConfig, MonitoringConfig, OptimizerConfig,
12    OutputConfig, QuantizeConfig, RegistryConfig, ReportConfig, SchedulerConfig,
13    SystemMonitorConfig, TerminalMonitor, TrackingConfig, TrainingConfig, TrainingManifest,
14    WarmupConfig,
15};
16
/// Starter-manifest template selected when initializing a YAML training config.
///
/// The templates are layered: `Lora` builds on `Minimal`, `Qlora` builds on `Lora`,
/// and `Full` builds on `Qlora`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so templates can be used as keys
/// in hash maps/sets (e.g. when registering or deduplicating template choices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Template {
    /// Minimal manifest with required fields only
    Minimal,
    /// LoRA fine-tuning template
    Lora,
    /// QLoRA fine-tuning template
    Qlora,
    /// Full template with all sections
    Full,
}
29
30/// Generate a training manifest from a template
31pub fn generate_manifest(
32    template: Template,
33    name: &str,
34    model: Option<&str>,
35    data: Option<&str>,
36) -> TrainingManifest {
37    generate_manifest_with_hints(template, name, model, data, None, None)
38}
39
40/// Generate a training manifest with optional smart defaults
41pub fn generate_manifest_with_hints(
42    template: Template,
43    name: &str,
44    model: Option<&str>,
45    data: Option<&str>,
46    lora_rank: Option<u32>,
47    learning_rate: Option<f64>,
48) -> TrainingManifest {
49    let mut manifest = match template {
50        Template::Minimal => generate_minimal(name, model, data),
51        Template::Lora => generate_lora(name, model, data),
52        Template::Qlora => generate_qlora(name, model, data),
53        Template::Full => generate_full(name, model, data),
54    };
55
56    // Apply smart defaults if provided
57    if let Some(rank) = lora_rank {
58        if let Some(ref mut lora) = manifest.lora {
59            lora.rank = rank.min(1024) as usize;
60            // Alpha is typically 2x rank
61            lora.alpha = f64::from(rank * 2);
62        }
63    }
64    if let Some(lr) = learning_rate {
65        if let Some(ref mut optim) = manifest.optimizer {
66            optim.lr = lr;
67        }
68    }
69
70    manifest
71}
72
73/// Generate YAML string from a template
74pub fn generate_yaml(
75    template: Template,
76    name: &str,
77    model: Option<&str>,
78    data: Option<&str>,
79    lora_rank: Option<u32>,
80    learning_rate: Option<f64>,
81) -> String {
82    let manifest =
83        generate_manifest_with_hints(template, name, model, data, lora_rank, learning_rate);
84    serde_yaml::to_string(&manifest).unwrap_or_else(|_err| "# Error generating YAML".to_string())
85}
86
87fn generate_minimal(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
88    TrainingManifest {
89        entrenar: "1.0".to_string(),
90        name: name.to_string(),
91        version: "1.0.0".to_string(),
92        description: Some("Training experiment".to_string()),
93        seed: Some(42),
94        data: data.map(default_data_config),
95        model: model.map(default_model_config),
96        optimizer: Some(default_optimizer_config()),
97        scheduler: Some(default_scheduler_config()),
98        training: Some(default_training_config()),
99        lora: None,
100        quantize: None,
101        monitoring: Some(default_monitoring_config()),
102        callbacks: None,
103        output: Some(default_output_config()),
104        publish: None,
105        // Extended configurations (YAML Mode QA Epic)
106        citl: None,
107        rag: None,
108        graph: None,
109        distillation: None,
110        inspect: None,
111        privacy: None,
112        audit: None,
113        session: None,
114        stress: None,
115        benchmark: None,
116        debug: None,
117        signing: None,
118        verification: None,
119        lockfile: None,
120        strict: None,
121        strict_validation: None,
122        require_peer_review: None,
123    }
124}
125
126fn default_data_config(source: &str) -> DataConfig {
127    DataConfig {
128        source: Some(source.to_string()),
129        format: None,
130        split: Some(DataSplit {
131            train: 0.8,
132            val: Some(0.1),
133            test: Some(0.1),
134            stratify: None,
135            seed: Some(42),
136        }),
137        train: None,
138        val: None,
139        test: None,
140        preprocessing: None,
141        augmentation: None,
142        loader: Some(DataLoader {
143            batch_size: 32,
144            shuffle: true,
145            num_workers: Some(4),
146            pin_memory: Some(true),
147            drop_last: Some(false),
148            prefetch_factor: None,
149        }),
150        tokenizer: None,
151        seq_len: None,
152        input_column: None,
153        output_column: None,
154        max_length: None,
155    }
156}
157
158fn default_model_config(source: &str) -> ModelConfig {
159    ModelConfig {
160        source: source.to_string(),
161        format: None,
162        architecture: None,
163        freeze: None,
164        device: Some("auto".to_string()),
165        dtype: Some("float32".to_string()),
166    }
167}
168
169fn default_optimizer_config() -> OptimizerConfig {
170    OptimizerConfig {
171        name: "adamw".to_string(),
172        lr: 0.001,
173        weight_decay: Some(0.01),
174        betas: Some(vec![0.9, 0.999]),
175        eps: Some(1e-8),
176        amsgrad: None,
177        momentum: None,
178        nesterov: None,
179        dampening: None,
180        alpha: None,
181        centered: None,
182        param_groups: None,
183    }
184}
185
186fn default_scheduler_config() -> SchedulerConfig {
187    SchedulerConfig {
188        name: "cosine_annealing".to_string(),
189        warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: Some(1e-7) }),
190        t_max: Some(DEFAULT_COSINE_ANNEALING_T_MAX),
191        eta_min: Some(1e-6),
192        step_size: None,
193        gamma: None,
194        mode: None,
195        factor: None,
196        patience: None,
197        threshold: None,
198        max_lr: None,
199        pct_start: None,
200        anneal_strategy: None,
201        div_factor: None,
202        final_div_factor: None,
203    }
204}
205
206fn default_training_config() -> TrainingConfig {
207    TrainingConfig {
208        epochs: Some(10),
209        max_steps: None,
210        duration: None,
211        gradient: Some(GradientConfig {
212            accumulation_steps: Some(1),
213            clip_norm: Some(1.0),
214            clip_value: None,
215        }),
216        mixed_precision: None,
217        distributed: None,
218        checkpoint: Some(CheckpointConfig {
219            save_every: Some(1000),
220            keep_last: Some(3),
221            save_best: Some(true),
222            metric: Some("val_loss".to_string()),
223            mode: Some("min".to_string()),
224        }),
225        early_stopping: Some(EarlyStoppingConfig {
226            enabled: true,
227            metric: Some("val_loss".to_string()),
228            patience: Some(5),
229            min_delta: Some(0.001),
230            mode: Some("min".to_string()),
231        }),
232        validation: None,
233        deterministic: None,
234        benchmark: None,
235        curriculum: None,
236    }
237}
238
239fn default_monitoring_config() -> MonitoringConfig {
240    MonitoringConfig {
241        terminal: Some(TerminalMonitor {
242            enabled: true,
243            refresh_rate: Some(100),
244            metrics: Some(vec!["loss".to_string(), "accuracy".to_string()]),
245            charts: None,
246        }),
247        tracking: None,
248        system: None,
249        alerts: None,
250        drift_detection: None,
251    }
252}
253
254fn default_output_config() -> OutputConfig {
255    OutputConfig {
256        dir: "./output/{{ name }}/{{ timestamp }}".to_string(),
257        model: Some(ModelOutputConfig {
258            format: Some("safetensors".to_string()),
259            save_optimizer: Some(true),
260            save_scheduler: Some(true),
261        }),
262        metrics: None,
263        report: Some(ReportConfig {
264            enabled: true,
265            format: Some("markdown".to_string()),
266            include_plots: Some(true),
267        }),
268        registry: None,
269    }
270}
271
272fn generate_lora(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
273    let mut manifest = generate_minimal(name, model, data);
274
275    // Add LoRA configuration
276    manifest.lora = Some(LoraConfig {
277        enabled: true,
278        rank: 16,
279        alpha: 32.0,
280        dropout: Some(0.05),
281        target_modules: vec![
282            "q_proj".to_string(),
283            "k_proj".to_string(),
284            "v_proj".to_string(),
285            "o_proj".to_string(),
286        ],
287        target_modules_pattern: None,
288        bias: Some("none".to_string()),
289        init_weights: Some("gaussian".to_string()),
290        quantize_base: None,
291        quantize_bits: None,
292        double_quantize: None,
293        quant_type: None,
294    });
295
296    // Adjust training params for LoRA
297    if let Some(ref mut training) = manifest.training {
298        training.epochs = Some(3);
299        if let Some(ref mut grad) = training.gradient {
300            grad.accumulation_steps = Some(4);
301        }
302    }
303
304    // Lower learning rate for fine-tuning
305    if let Some(ref mut optim) = manifest.optimizer {
306        optim.lr = 0.0002;
307    }
308
309    // Use float16 for model
310    if let Some(ref mut model_config) = manifest.model {
311        model_config.dtype = Some("float16".to_string());
312    }
313
314    manifest
315}
316
317fn generate_qlora(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
318    let mut manifest = generate_lora(name, model, data);
319
320    // Enable QLoRA (quantized LoRA)
321    if let Some(ref mut lora) = manifest.lora {
322        lora.quantize_base = Some(true);
323        lora.quantize_bits = Some(4);
324        lora.double_quantize = Some(true);
325        lora.quant_type = Some("nf4".to_string());
326    }
327
328    // Enable mixed precision
329    if let Some(ref mut training) = manifest.training {
330        training.mixed_precision = Some(MixedPrecisionConfig {
331            enabled: true,
332            dtype: Some("bfloat16".to_string()),
333            loss_scale: Some("dynamic".to_string()),
334        });
335        // Increase gradient accumulation for memory efficiency
336        if let Some(ref mut grad) = training.gradient {
337            grad.accumulation_steps = Some(16);
338        }
339    }
340
341    manifest
342}
343
344/// Build the full quantization config section.
345fn full_quantize_config() -> QuantizeConfig {
346    QuantizeConfig {
347        enabled: false,
348        bits: 8,
349        scheme: Some("symmetric".to_string()),
350        granularity: Some("per_channel".to_string()),
351        group_size: Some(128),
352        qat: None,
353        calibration: None,
354        exclude: Some(vec!["lm_head".to_string(), "embed_tokens".to_string()]),
355    }
356}
357
358/// Build the full monitoring config with terminal, tracking, system, and alerts.
359fn full_monitoring_config(name: &str) -> MonitoringConfig {
360    MonitoringConfig {
361        terminal: Some(TerminalMonitor {
362            enabled: true,
363            refresh_rate: Some(100),
364            metrics: Some(vec![
365                "loss".to_string(),
366                "accuracy".to_string(),
367                "learning_rate".to_string(),
368                "throughput".to_string(),
369            ]),
370            charts: Some(vec![
371                ChartConfig {
372                    chart_type: "sparkline".to_string(),
373                    metric: Some("loss".to_string()),
374                    window: Some(100),
375                    show_eta: None,
376                },
377                ChartConfig {
378                    chart_type: "progress".to_string(),
379                    metric: None,
380                    window: None,
381                    show_eta: Some(true),
382                },
383            ]),
384        }),
385        tracking: Some(TrackingConfig {
386            enabled: true,
387            backend: Some("trueno-db".to_string()),
388            project: Some(name.to_string()),
389            experiment: Some("{{ name }}-{{ timestamp }}".to_string()),
390            tags: None,
391        }),
392        system: Some(SystemMonitorConfig {
393            enabled: true,
394            interval: Some(1000),
395            metrics: Some(vec![
396                "cpu_percent".to_string(),
397                "memory_mb".to_string(),
398                "gpu_utilization".to_string(),
399                "gpu_memory_mb".to_string(),
400            ]),
401        }),
402        alerts: Some(vec![
403            AlertConfig {
404                condition: "loss > 10".to_string(),
405                action: "warn".to_string(),
406                message: "Loss explosion detected".to_string(),
407            },
408            AlertConfig {
409                condition: "gpu_memory > 0.95".to_string(),
410                action: "halt".to_string(),
411                message: "GPU OOM imminent".to_string(),
412            },
413        ]),
414        drift_detection: None,
415    }
416}
417
418/// Build the full callbacks list (checkpoint, LR monitor, gradient monitor).
419fn full_callbacks_config() -> Vec<CallbackConfig> {
420    vec![
421        CallbackConfig {
422            callback_type: CallbackType::Checkpoint,
423            trigger: "epoch_end".to_string(),
424            interval: None,
425            config: None,
426            script: None,
427        },
428        CallbackConfig {
429            callback_type: CallbackType::LrMonitor,
430            trigger: "step".to_string(),
431            interval: None,
432            config: None,
433            script: None,
434        },
435        CallbackConfig {
436            callback_type: CallbackType::GradientMonitor,
437            trigger: "step".to_string(),
438            interval: Some(100),
439            config: None,
440            script: None,
441        },
442    ]
443}
444
445/// Build the full output config with model, metrics, report, and registry.
446fn full_output_config() -> OutputConfig {
447    OutputConfig {
448        dir: "./experiments/{{ name }}/{{ timestamp }}".to_string(),
449        model: Some(ModelOutputConfig {
450            format: Some("safetensors".to_string()),
451            save_optimizer: Some(true),
452            save_scheduler: Some(true),
453        }),
454        metrics: Some(MetricsOutputConfig {
455            format: Some("parquet".to_string()),
456            include: Some(vec![
457                "train_loss".to_string(),
458                "val_loss".to_string(),
459                "accuracy".to_string(),
460                "learning_rate".to_string(),
461            ]),
462        }),
463        report: Some(ReportConfig {
464            enabled: true,
465            format: Some("markdown".to_string()),
466            include_plots: Some(true),
467        }),
468        registry: Some(RegistryConfig {
469            enabled: false,
470            target: Some("pacha://models/{{ name }}:{{ version }}".to_string()),
471            include_config: Some(true),
472            include_metrics: Some(true),
473        }),
474    }
475}
476
477fn generate_full(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
478    let mut manifest = generate_qlora(name, model, data);
479
480    manifest.quantize = Some(full_quantize_config());
481    manifest.monitoring = Some(full_monitoring_config(name));
482    manifest.callbacks = Some(full_callbacks_config());
483    manifest.output = Some(full_output_config());
484
485    manifest
486}
487
#[cfg(test)]
mod tests {
    //! Unit tests for template generation, smart-default hints and `publish` parsing.

    use super::*;

    #[test]
    fn test_generate_minimal() {
        let manifest = generate_manifest(
            Template::Minimal,
            "test-exp",
            Some("model.safetensors"),
            Some("./data"),
        );
        assert_eq!(manifest.entrenar, "1.0");
        assert_eq!(manifest.name, "test-exp");
        assert!(manifest.lora.is_none());
        assert!(manifest.data.is_some());
        // The supplied model path must become the model source.
        assert_eq!(
            manifest.model.as_ref().map(|m| m.source.as_str()),
            Some("model.safetensors")
        );
    }

    #[test]
    fn test_generate_lora() {
        let manifest =
            generate_manifest(Template::Lora, "lora-exp", Some("hf://llama"), Some("hf://data"));
        let lora = manifest.lora.expect("LoRA template must include a lora section");
        assert!(lora.enabled);
        assert_eq!(lora.rank, 16);
        assert!(lora.quantize_base.is_none());
    }

    #[test]
    fn test_generate_qlora() {
        let manifest = generate_manifest(Template::Qlora, "qlora-exp", None, None);
        let lora = manifest.lora.as_ref().expect("QLoRA template must include a lora section");
        assert_eq!(lora.quantize_base, Some(true));
        assert_eq!(lora.quantize_bits, Some(4));
        let training =
            manifest.training.as_ref().expect("QLoRA template must include a training section");
        assert!(training.mixed_precision.is_some());
    }

    #[test]
    fn test_generate_full() {
        let manifest = generate_manifest(Template::Full, "full-exp", None, None);
        assert!(manifest.lora.is_some());
        assert!(manifest.quantize.is_some());
        assert!(manifest.callbacks.is_some());
        assert!(manifest.output.is_some());

        let monitoring =
            manifest.monitoring.expect("full template must include a monitoring section");
        assert!(monitoring.tracking.is_some());
        assert!(monitoring.system.is_some());
        assert!(monitoring.alerts.is_some());
    }

    #[test]
    fn test_generate_yaml_output() {
        let yaml = generate_yaml(Template::Minimal, "yaml-test", None, None, None, None);
        // serde_yaml may quote the version string with either quote style.
        assert!(yaml.contains("entrenar: '1.0'") || yaml.contains("entrenar: \"1.0\""));
        assert!(yaml.contains("yaml-test"));
    }

    #[test]
    fn test_manifest_validates() {
        use super::super::validation::validate_manifest;

        // All templates should produce valid manifests
        for template in [Template::Minimal, Template::Lora, Template::Qlora, Template::Full] {
            let manifest = generate_manifest(template, "test", None, None);
            let result = validate_manifest(&manifest);
            assert!(result.is_ok(), "Template {template:?} produced invalid manifest: {result:?}");
        }
    }

    #[test]
    fn test_smart_defaults_lora_rank() {
        let manifest = generate_manifest_with_hints(
            Template::Lora,
            "smart-test",
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            None,
            Some(32),   // small model rank
            Some(3e-4), // small model lr
        );
        let lora = manifest.lora.expect("LoRA template must include a lora section");
        assert_eq!(lora.rank, 32);
        assert!((lora.alpha - 64.0).abs() < 0.01); // alpha = 2 * rank
        let optimizer = manifest.optimizer.expect("LoRA template must include an optimizer");
        assert!((optimizer.lr - 3e-4).abs() < 1e-10);
    }

    #[test]
    fn test_smart_defaults_large_model() {
        let manifest = generate_manifest_with_hints(
            Template::Qlora,
            "large-test",
            Some("meta-llama/Llama-3-13B"),
            None,
            Some(128),
            Some(1e-4),
        );
        let lora = manifest.lora.expect("QLoRA template must include a lora section");
        assert_eq!(lora.rank, 128);
        assert!((lora.alpha - 256.0).abs() < 0.01);
    }

    #[test]
    fn test_smart_defaults_no_hints() {
        // Without hints, defaults should be unchanged
        let manifest =
            generate_manifest_with_hints(Template::Lora, "no-hints", None, None, None, None);
        let lora = manifest.lora.expect("LoRA template must include a lora section");
        assert_eq!(lora.rank, 16); // original default
        assert!((lora.alpha - 32.0).abs() < 0.01);
    }

    #[test]
    fn test_lora_rank_hint_ignored_without_lora_section() {
        // Minimal has no LoRA section, so the rank hint has nothing to apply to,
        // while the learning-rate hint still reaches the optimizer.
        let manifest = generate_manifest_with_hints(
            Template::Minimal,
            "hint-minimal",
            None,
            None,
            Some(8),
            Some(5e-5),
        );
        assert!(manifest.lora.is_none());
        let optimizer = manifest.optimizer.expect("minimal template must include an optimizer");
        assert!((optimizer.lr - 5e-5).abs() < 1e-12);
    }

    #[test]
    fn test_minimal_has_no_publish() {
        let manifest = generate_manifest(Template::Minimal, "test", None, None);
        assert!(manifest.publish.is_none());
    }

    #[test]
    fn test_publish_config_yaml_roundtrip() {
        use super::super::manifest::PublishConfig;

        let yaml = r#"
            repo: "myuser/my-model"
            private: false
            model_card: true
            merge_adapters: true
            format: safetensors
        "#;
        let config: PublishConfig =
            serde_yaml::from_str(yaml).expect("publish YAML fixture should parse");
        assert_eq!(config.repo, "myuser/my-model");
        assert!(!config.private);
        assert!(config.model_card);
        assert!(config.merge_adapters);
        assert_eq!(config.format, "safetensors");
    }

    #[test]
    fn test_publish_config_defaults() {
        use super::super::manifest::PublishConfig;

        let yaml = r#"repo: "org/name""#;
        let config: PublishConfig =
            serde_yaml::from_str(yaml).expect("repo-only publish YAML should parse");
        assert!(!config.private);
        assert!(config.model_card); // default true
        assert!(!config.merge_adapters);
        assert_eq!(config.format, "safetensors"); // default
    }

    #[test]
    fn test_manifest_with_publish_section() {
        let yaml = r#"
entrenar: "1.0"
name: test
version: "1.0.0"
publish:
  repo: myuser/my-model
  merge_adapters: true
"#;
        let manifest: TrainingManifest =
            serde_yaml::from_str(yaml).expect("manifest YAML fixture should parse");
        let publish = manifest.publish.expect("fixture declares a publish section");
        assert_eq!(publish.repo, "myuser/my-model");
        assert!(publish.merge_adapters);
        assert!(publish.model_card); // default
    }
}