/// Default `t_max` placed in the cosine-annealing scheduler section of generated manifests.
const DEFAULT_COSINE_ANNEALING_T_MAX: usize = 10_000;
7
8use super::manifest::{
9 AlertConfig, CallbackConfig, CallbackType, ChartConfig, CheckpointConfig, DataConfig,
10 DataLoader, DataSplit, EarlyStoppingConfig, GradientConfig, LoraConfig, MetricsOutputConfig,
11 MixedPrecisionConfig, ModelConfig, ModelOutputConfig, MonitoringConfig, OptimizerConfig,
12 OutputConfig, QuantizeConfig, RegistryConfig, ReportConfig, SchedulerConfig,
13 SystemMonitorConfig, TerminalMonitor, TrackingConfig, TrainingConfig, TrainingManifest,
14 WarmupConfig,
15};
16
/// Selects which preset manifest the generator produces.
///
/// The presets build on each other: `Lora` extends `Minimal`, `Qlora` extends `Lora`,
/// and `Full` extends `Qlora`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Template {
    /// Baseline manifest: data/model (when supplied), optimizer, scheduler, training,
    /// terminal monitoring and output sections; no LoRA section.
    Minimal,
    /// `Minimal` plus a LoRA adapter section and fine-tuning oriented training defaults.
    Lora,
    /// `Lora` plus 4-bit base-model quantization settings and mixed precision.
    Qlora,
    /// `Qlora` plus quantization, extended monitoring, callbacks and extended output sections.
    Full,
}
29
30pub fn generate_manifest(
32 template: Template,
33 name: &str,
34 model: Option<&str>,
35 data: Option<&str>,
36) -> TrainingManifest {
37 generate_manifest_with_hints(template, name, model, data, None, None)
38}
39
40pub fn generate_manifest_with_hints(
42 template: Template,
43 name: &str,
44 model: Option<&str>,
45 data: Option<&str>,
46 lora_rank: Option<u32>,
47 learning_rate: Option<f64>,
48) -> TrainingManifest {
49 let mut manifest = match template {
50 Template::Minimal => generate_minimal(name, model, data),
51 Template::Lora => generate_lora(name, model, data),
52 Template::Qlora => generate_qlora(name, model, data),
53 Template::Full => generate_full(name, model, data),
54 };
55
56 if let Some(rank) = lora_rank {
58 if let Some(ref mut lora) = manifest.lora {
59 lora.rank = rank.min(1024) as usize;
60 lora.alpha = f64::from(rank * 2);
62 }
63 }
64 if let Some(lr) = learning_rate {
65 if let Some(ref mut optim) = manifest.optimizer {
66 optim.lr = lr;
67 }
68 }
69
70 manifest
71}
72
73pub fn generate_yaml(
75 template: Template,
76 name: &str,
77 model: Option<&str>,
78 data: Option<&str>,
79 lora_rank: Option<u32>,
80 learning_rate: Option<f64>,
81) -> String {
82 let manifest =
83 generate_manifest_with_hints(template, name, model, data, lora_rank, learning_rate);
84 serde_yaml::to_string(&manifest).unwrap_or_else(|_err| "# Error generating YAML".to_string())
85}
86
87fn generate_minimal(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
88 TrainingManifest {
89 entrenar: "1.0".to_string(),
90 name: name.to_string(),
91 version: "1.0.0".to_string(),
92 description: Some("Training experiment".to_string()),
93 seed: Some(42),
94 data: data.map(default_data_config),
95 model: model.map(default_model_config),
96 optimizer: Some(default_optimizer_config()),
97 scheduler: Some(default_scheduler_config()),
98 training: Some(default_training_config()),
99 lora: None,
100 quantize: None,
101 monitoring: Some(default_monitoring_config()),
102 callbacks: None,
103 output: Some(default_output_config()),
104 publish: None,
105 citl: None,
107 rag: None,
108 graph: None,
109 distillation: None,
110 inspect: None,
111 privacy: None,
112 audit: None,
113 session: None,
114 stress: None,
115 benchmark: None,
116 debug: None,
117 signing: None,
118 verification: None,
119 lockfile: None,
120 strict: None,
121 strict_validation: None,
122 require_peer_review: None,
123 }
124}
125
126fn default_data_config(source: &str) -> DataConfig {
127 DataConfig {
128 source: Some(source.to_string()),
129 format: None,
130 split: Some(DataSplit {
131 train: 0.8,
132 val: Some(0.1),
133 test: Some(0.1),
134 stratify: None,
135 seed: Some(42),
136 }),
137 train: None,
138 val: None,
139 test: None,
140 preprocessing: None,
141 augmentation: None,
142 loader: Some(DataLoader {
143 batch_size: 32,
144 shuffle: true,
145 num_workers: Some(4),
146 pin_memory: Some(true),
147 drop_last: Some(false),
148 prefetch_factor: None,
149 }),
150 tokenizer: None,
151 seq_len: None,
152 input_column: None,
153 output_column: None,
154 max_length: None,
155 }
156}
157
158fn default_model_config(source: &str) -> ModelConfig {
159 ModelConfig {
160 source: source.to_string(),
161 format: None,
162 architecture: None,
163 freeze: None,
164 device: Some("auto".to_string()),
165 dtype: Some("float32".to_string()),
166 }
167}
168
169fn default_optimizer_config() -> OptimizerConfig {
170 OptimizerConfig {
171 name: "adamw".to_string(),
172 lr: 0.001,
173 weight_decay: Some(0.01),
174 betas: Some(vec![0.9, 0.999]),
175 eps: Some(1e-8),
176 amsgrad: None,
177 momentum: None,
178 nesterov: None,
179 dampening: None,
180 alpha: None,
181 centered: None,
182 param_groups: None,
183 }
184}
185
186fn default_scheduler_config() -> SchedulerConfig {
187 SchedulerConfig {
188 name: "cosine_annealing".to_string(),
189 warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: Some(1e-7) }),
190 t_max: Some(DEFAULT_COSINE_ANNEALING_T_MAX),
191 eta_min: Some(1e-6),
192 step_size: None,
193 gamma: None,
194 mode: None,
195 factor: None,
196 patience: None,
197 threshold: None,
198 max_lr: None,
199 pct_start: None,
200 anneal_strategy: None,
201 div_factor: None,
202 final_div_factor: None,
203 }
204}
205
206fn default_training_config() -> TrainingConfig {
207 TrainingConfig {
208 epochs: Some(10),
209 max_steps: None,
210 duration: None,
211 gradient: Some(GradientConfig {
212 accumulation_steps: Some(1),
213 clip_norm: Some(1.0),
214 clip_value: None,
215 }),
216 mixed_precision: None,
217 distributed: None,
218 checkpoint: Some(CheckpointConfig {
219 save_every: Some(1000),
220 keep_last: Some(3),
221 save_best: Some(true),
222 metric: Some("val_loss".to_string()),
223 mode: Some("min".to_string()),
224 }),
225 early_stopping: Some(EarlyStoppingConfig {
226 enabled: true,
227 metric: Some("val_loss".to_string()),
228 patience: Some(5),
229 min_delta: Some(0.001),
230 mode: Some("min".to_string()),
231 }),
232 validation: None,
233 deterministic: None,
234 benchmark: None,
235 curriculum: None,
236 }
237}
238
239fn default_monitoring_config() -> MonitoringConfig {
240 MonitoringConfig {
241 terminal: Some(TerminalMonitor {
242 enabled: true,
243 refresh_rate: Some(100),
244 metrics: Some(vec!["loss".to_string(), "accuracy".to_string()]),
245 charts: None,
246 }),
247 tracking: None,
248 system: None,
249 alerts: None,
250 drift_detection: None,
251 }
252}
253
254fn default_output_config() -> OutputConfig {
255 OutputConfig {
256 dir: "./output/{{ name }}/{{ timestamp }}".to_string(),
257 model: Some(ModelOutputConfig {
258 format: Some("safetensors".to_string()),
259 save_optimizer: Some(true),
260 save_scheduler: Some(true),
261 }),
262 metrics: None,
263 report: Some(ReportConfig {
264 enabled: true,
265 format: Some("markdown".to_string()),
266 include_plots: Some(true),
267 }),
268 registry: None,
269 }
270}
271
272fn generate_lora(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
273 let mut manifest = generate_minimal(name, model, data);
274
275 manifest.lora = Some(LoraConfig {
277 enabled: true,
278 rank: 16,
279 alpha: 32.0,
280 dropout: Some(0.05),
281 target_modules: vec![
282 "q_proj".to_string(),
283 "k_proj".to_string(),
284 "v_proj".to_string(),
285 "o_proj".to_string(),
286 ],
287 target_modules_pattern: None,
288 bias: Some("none".to_string()),
289 init_weights: Some("gaussian".to_string()),
290 quantize_base: None,
291 quantize_bits: None,
292 double_quantize: None,
293 quant_type: None,
294 });
295
296 if let Some(ref mut training) = manifest.training {
298 training.epochs = Some(3);
299 if let Some(ref mut grad) = training.gradient {
300 grad.accumulation_steps = Some(4);
301 }
302 }
303
304 if let Some(ref mut optim) = manifest.optimizer {
306 optim.lr = 0.0002;
307 }
308
309 if let Some(ref mut model_config) = manifest.model {
311 model_config.dtype = Some("float16".to_string());
312 }
313
314 manifest
315}
316
317fn generate_qlora(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
318 let mut manifest = generate_lora(name, model, data);
319
320 if let Some(ref mut lora) = manifest.lora {
322 lora.quantize_base = Some(true);
323 lora.quantize_bits = Some(4);
324 lora.double_quantize = Some(true);
325 lora.quant_type = Some("nf4".to_string());
326 }
327
328 if let Some(ref mut training) = manifest.training {
330 training.mixed_precision = Some(MixedPrecisionConfig {
331 enabled: true,
332 dtype: Some("bfloat16".to_string()),
333 loss_scale: Some("dynamic".to_string()),
334 });
335 if let Some(ref mut grad) = training.gradient {
337 grad.accumulation_steps = Some(16);
338 }
339 }
340
341 manifest
342}
343
344fn full_quantize_config() -> QuantizeConfig {
346 QuantizeConfig {
347 enabled: false,
348 bits: 8,
349 scheme: Some("symmetric".to_string()),
350 granularity: Some("per_channel".to_string()),
351 group_size: Some(128),
352 qat: None,
353 calibration: None,
354 exclude: Some(vec!["lm_head".to_string(), "embed_tokens".to_string()]),
355 }
356}
357
358fn full_monitoring_config(name: &str) -> MonitoringConfig {
360 MonitoringConfig {
361 terminal: Some(TerminalMonitor {
362 enabled: true,
363 refresh_rate: Some(100),
364 metrics: Some(vec![
365 "loss".to_string(),
366 "accuracy".to_string(),
367 "learning_rate".to_string(),
368 "throughput".to_string(),
369 ]),
370 charts: Some(vec![
371 ChartConfig {
372 chart_type: "sparkline".to_string(),
373 metric: Some("loss".to_string()),
374 window: Some(100),
375 show_eta: None,
376 },
377 ChartConfig {
378 chart_type: "progress".to_string(),
379 metric: None,
380 window: None,
381 show_eta: Some(true),
382 },
383 ]),
384 }),
385 tracking: Some(TrackingConfig {
386 enabled: true,
387 backend: Some("trueno-db".to_string()),
388 project: Some(name.to_string()),
389 experiment: Some("{{ name }}-{{ timestamp }}".to_string()),
390 tags: None,
391 }),
392 system: Some(SystemMonitorConfig {
393 enabled: true,
394 interval: Some(1000),
395 metrics: Some(vec![
396 "cpu_percent".to_string(),
397 "memory_mb".to_string(),
398 "gpu_utilization".to_string(),
399 "gpu_memory_mb".to_string(),
400 ]),
401 }),
402 alerts: Some(vec![
403 AlertConfig {
404 condition: "loss > 10".to_string(),
405 action: "warn".to_string(),
406 message: "Loss explosion detected".to_string(),
407 },
408 AlertConfig {
409 condition: "gpu_memory > 0.95".to_string(),
410 action: "halt".to_string(),
411 message: "GPU OOM imminent".to_string(),
412 },
413 ]),
414 drift_detection: None,
415 }
416}
417
418fn full_callbacks_config() -> Vec<CallbackConfig> {
420 vec![
421 CallbackConfig {
422 callback_type: CallbackType::Checkpoint,
423 trigger: "epoch_end".to_string(),
424 interval: None,
425 config: None,
426 script: None,
427 },
428 CallbackConfig {
429 callback_type: CallbackType::LrMonitor,
430 trigger: "step".to_string(),
431 interval: None,
432 config: None,
433 script: None,
434 },
435 CallbackConfig {
436 callback_type: CallbackType::GradientMonitor,
437 trigger: "step".to_string(),
438 interval: Some(100),
439 config: None,
440 script: None,
441 },
442 ]
443}
444
445fn full_output_config() -> OutputConfig {
447 OutputConfig {
448 dir: "./experiments/{{ name }}/{{ timestamp }}".to_string(),
449 model: Some(ModelOutputConfig {
450 format: Some("safetensors".to_string()),
451 save_optimizer: Some(true),
452 save_scheduler: Some(true),
453 }),
454 metrics: Some(MetricsOutputConfig {
455 format: Some("parquet".to_string()),
456 include: Some(vec![
457 "train_loss".to_string(),
458 "val_loss".to_string(),
459 "accuracy".to_string(),
460 "learning_rate".to_string(),
461 ]),
462 }),
463 report: Some(ReportConfig {
464 enabled: true,
465 format: Some("markdown".to_string()),
466 include_plots: Some(true),
467 }),
468 registry: Some(RegistryConfig {
469 enabled: false,
470 target: Some("pacha://models/{{ name }}:{{ version }}".to_string()),
471 include_config: Some(true),
472 include_metrics: Some(true),
473 }),
474 }
475}
476
477fn generate_full(name: &str, model: Option<&str>, data: Option<&str>) -> TrainingManifest {
478 let mut manifest = generate_qlora(name, model, data);
479
480 manifest.quantize = Some(full_quantize_config());
481 manifest.monitoring = Some(full_monitoring_config(name));
482 manifest.callbacks = Some(full_callbacks_config());
483 manifest.output = Some(full_output_config());
484
485 manifest
486}
487
#[cfg(test)]
mod tests {
    use super::super::manifest::PublishConfig;
    use super::super::validation::validate_manifest;
    use super::*;

    /// Message used by every `expect` in this module.
    const SHOULD_SUCCEED: &str = "operation should succeed";

    /// The minimal template carries identity fields and a model section but no LoRA section.
    #[test]
    fn test_generate_minimal() {
        let generated = generate_manifest(
            Template::Minimal,
            "test-exp",
            Some("model.safetensors"),
            Some("./data"),
        );

        assert_eq!(generated.entrenar, "1.0");
        assert_eq!(generated.name, "test-exp");
        assert!(generated.lora.is_none());
        assert!(generated.model.is_some());
    }

    /// The LoRA template enables a rank-16 adapter without base quantization.
    #[test]
    fn test_generate_lora() {
        let generated =
            generate_manifest(Template::Lora, "lora-exp", Some("hf://llama"), Some("hf://data"));
        assert!(generated.lora.is_some());

        let adapter = generated.lora.expect(SHOULD_SUCCEED);
        assert!(adapter.enabled);
        assert_eq!(adapter.rank, 16);
        assert!(adapter.quantize_base.is_none());
    }

    /// The QLoRA template quantizes the base model to 4 bits and enables mixed precision.
    #[test]
    fn test_generate_qlora() {
        let generated = generate_manifest(Template::Qlora, "qlora-exp", None, None);
        assert!(generated.lora.is_some());

        let adapter = generated.lora.expect(SHOULD_SUCCEED);
        assert!(adapter.quantize_base.expect(SHOULD_SUCCEED));
        assert_eq!(adapter.quantize_bits, Some(4));

        let training = generated.training.as_ref().expect(SHOULD_SUCCEED);
        assert!(training.mixed_precision.is_some());
    }

    /// The full template populates every extended section, including monitoring sub-sections.
    #[test]
    fn test_generate_full() {
        let generated = generate_manifest(Template::Full, "full-exp", None, None);
        assert!(generated.lora.is_some());
        assert!(generated.quantize.is_some());
        assert!(generated.monitoring.is_some());
        assert!(generated.callbacks.is_some());
        assert!(generated.output.is_some());

        let monitoring = generated.monitoring.expect(SHOULD_SUCCEED);
        assert!(monitoring.tracking.is_some());
        assert!(monitoring.system.is_some());
        assert!(monitoring.alerts.is_some());
    }

    /// YAML output contains the schema version (either quoting style) and the experiment name.
    #[test]
    fn test_generate_yaml_output() {
        let rendered = generate_yaml(Template::Minimal, "yaml-test", None, None, None, None);

        let single_quoted = rendered.contains("entrenar: '1.0'");
        let double_quoted = rendered.contains("entrenar: \"1.0\"");
        assert!(single_quoted || double_quoted);
        assert!(rendered.contains("yaml-test"));
    }

    /// Every template must pass manifest validation.
    #[test]
    fn test_manifest_validates() {
        let templates = [Template::Minimal, Template::Lora, Template::Qlora, Template::Full];

        for template in templates {
            let generated = generate_manifest(template, "test", None, None);
            let result = validate_manifest(&generated);
            assert!(result.is_ok(), "Template {template:?} produced invalid manifest: {result:?}");
        }
    }

    /// Rank and learning-rate hints override the LoRA template defaults.
    #[test]
    fn test_smart_defaults_lora_rank() {
        let generated = generate_manifest_with_hints(
            Template::Lora,
            "smart-test",
            Some("Qwen/Qwen2.5-Coder-0.5B"),
            None,
            Some(32),
            Some(3e-4),
        );

        let adapter = generated.lora.expect(SHOULD_SUCCEED);
        assert_eq!(adapter.rank, 32);
        assert!((adapter.alpha - 64.0).abs() < 0.01);

        let optimizer = generated.optimizer.expect(SHOULD_SUCCEED);
        assert!((optimizer.lr - 3e-4).abs() < 1e-10);
    }

    /// A larger rank hint is applied to the QLoRA template with alpha at twice the rank.
    #[test]
    fn test_smart_defaults_large_model() {
        let generated = generate_manifest_with_hints(
            Template::Qlora,
            "large-test",
            Some("meta-llama/Llama-3-13B"),
            None,
            Some(128),
            Some(1e-4),
        );

        let adapter = generated.lora.expect(SHOULD_SUCCEED);
        assert_eq!(adapter.rank, 128);
        assert!((adapter.alpha - 256.0).abs() < 0.01);
    }

    /// Without hints the LoRA template keeps its built-in rank and alpha.
    #[test]
    fn test_smart_defaults_no_hints() {
        let generated =
            generate_manifest_with_hints(Template::Lora, "no-hints", None, None, None, None);

        let adapter = generated.lora.expect(SHOULD_SUCCEED);
        assert_eq!(adapter.rank, 16);
        assert!((adapter.alpha - 32.0).abs() < 0.01);
    }

    /// The minimal template does not include a publish section.
    #[test]
    fn test_minimal_has_no_publish() {
        let generated = generate_manifest(Template::Minimal, "test", None, None);
        assert!(generated.publish.is_none());
    }

    /// A fully specified publish section deserializes field by field.
    #[test]
    fn test_publish_config_yaml_roundtrip() {
        let yaml = r#"
            repo: "myuser/my-model"
            private: false
            model_card: true
            merge_adapters: true
            format: safetensors
        "#;

        let publish: PublishConfig = serde_yaml::from_str(yaml).expect("config should be valid");
        assert_eq!(publish.repo, "myuser/my-model");
        assert!(!publish.private);
        assert!(publish.model_card);
        assert!(publish.merge_adapters);
        assert_eq!(publish.format, "safetensors");
    }

    /// Omitted publish fields fall back to their deserialization defaults.
    #[test]
    fn test_publish_config_defaults() {
        let yaml = r#"repo: "org/name""#;

        let publish: PublishConfig = serde_yaml::from_str(yaml).expect("config should be valid");
        assert!(!publish.private);
        assert!(publish.model_card);
        assert!(!publish.merge_adapters);
        assert_eq!(publish.format, "safetensors");
    }

    /// A manifest document with a publish section deserializes that section.
    #[test]
    fn test_manifest_with_publish_section() {
        let yaml = r#"
entrenar: "1.0"
name: test
version: "1.0.0"
publish:
  repo: myuser/my-model
  merge_adapters: true
"#;

        let parsed: TrainingManifest = serde_yaml::from_str(yaml).expect(SHOULD_SUCCEED);
        let publish = parsed.publish.expect(SHOULD_SUCCEED);
        assert_eq!(publish.repo, "myuser/my-model");
        assert!(publish.merge_adapters);
        assert!(publish.model_card);
    }
}