1use crate::config::{
7 DataConfig as SpecDataConfig, LoRASpec, ModelMode, ModelRef, OptimSpec, QuantSpec, TrainSpec,
8 TrainingMode, TrainingParams,
9};
10use crate::yaml_mode::TrainingManifest;
11use std::collections::HashMap;
12use std::path::PathBuf;
13use thiserror::Error;
14
/// Outcome of converting a YAML [`TrainingManifest`] into a legacy [`TrainSpec`].
#[derive(Debug)]
pub struct BridgeResult {
    /// The converted legacy training specification.
    pub spec: TrainSpec,
    /// Human-readable notes about manifest settings that were present but have no
    /// counterpart in the legacy `TrainSpec`, and were therefore not carried over.
    pub warnings: Vec<String>,
}
23
/// Errors raised when a manifest cannot be converted into a legacy [`TrainSpec`].
#[derive(Debug, Error)]
pub enum BridgeError {
    /// A section or field the legacy spec requires is absent from the manifest.
    /// The payload names the missing field (e.g. `"model"` or `"data.source or data.train"`).
    #[error("Missing required field: {0}")]
    MissingRequired(String),

    /// A field is present, but its value cannot be accepted.
    #[error("Invalid value for {field}: {reason}")]
    InvalidValue { field: String, reason: String },
}
33
34pub fn manifest_to_spec(manifest: &TrainingManifest) -> Result<BridgeResult, BridgeError> {
39 let mut warnings = Vec::new();
40
41 let model = convert_model(manifest, &mut warnings)?;
42 let data = convert_data(manifest, &mut warnings)?;
43 let optimizer = convert_optimizer(manifest)?;
44 let training = convert_training(manifest, model.mode, &mut warnings);
45 let lora = convert_lora(manifest, &mut warnings);
46 let quantize = convert_quantize(manifest);
47
48 if manifest.monitoring.is_some() {
50 warnings.push("monitoring config is not supported in legacy TrainSpec".into());
51 }
52 if manifest.callbacks.is_some() {
53 warnings.push("callbacks config is not supported in legacy TrainSpec".into());
54 }
55 if manifest.distillation.is_some() {
56 warnings.push("distillation config is not supported in legacy TrainSpec".into());
57 }
58
59 let spec =
60 TrainSpec { model, data, optimizer, training, lora, quantize, merge: None, publish: None };
61
62 Ok(BridgeResult { spec, warnings })
63}
64
65fn convert_model(
67 manifest: &TrainingManifest,
68 warnings: &mut Vec<String>,
69) -> Result<ModelRef, BridgeError> {
70 use crate::config::ArchitectureOverrides;
71
72 let model_cfg =
73 manifest.model.as_ref().ok_or_else(|| BridgeError::MissingRequired("model".into()))?;
74
75 let mode = if let Some(ref arch) = model_cfg.architecture {
77 if arch.arch_type == "transformer" {
78 ModelMode::Transformer
79 } else {
80 ModelMode::Tabular
81 }
82 } else {
83 ModelMode::Tabular
84 };
85
86 let layers = manifest.lora.as_ref().map(|l| l.target_modules.clone()).unwrap_or_default();
88
89 if model_cfg.freeze.is_some() {
90 warnings.push("model.freeze is not supported in legacy TrainSpec".into());
91 }
92 if model_cfg.device.is_some() {
93 warnings.push("model.device is not supported in legacy TrainSpec".into());
94 }
95
96 let architecture = model_cfg
98 .architecture
99 .as_ref()
100 .map(|arch| ArchitectureOverrides {
101 hidden_size: arch.hidden_size,
102 num_hidden_layers: arch.num_layers,
103 num_attention_heads: arch.num_heads,
104 num_kv_heads: arch.num_kv_heads,
105 intermediate_size: arch.intermediate_size,
106 vocab_size: arch.vocab_size,
107 max_position_embeddings: arch.max_seq_length,
108 rms_norm_eps: arch.rms_norm_eps,
109 rope_theta: arch.rope_theta,
110 use_bias: arch.use_bias,
111 head_dim: arch.head_dim,
112 })
113 .filter(|o| !o.is_empty());
114
115 Ok(ModelRef {
116 path: PathBuf::from(&model_cfg.source),
117 layers,
118 mode,
119 config: None,
120 architecture,
121 })
122}
123
124fn resolve_train_path(
128 data_cfg: &crate::yaml_mode::manifest::data::DataConfig,
129) -> Result<PathBuf, BridgeError> {
130 data_cfg
131 .train
132 .as_deref()
133 .or(data_cfg.source.as_deref())
134 .map(PathBuf::from)
135 .ok_or_else(|| BridgeError::MissingRequired("data.source or data.train".into()))
136}
137
138fn extract_preprocessing_tokenizer(
142 steps: &[crate::yaml_mode::manifest::data::PreprocessingStep],
143) -> (Option<PathBuf>, Option<usize>) {
144 for step in steps {
145 if let crate::yaml_mode::manifest::data::PreprocessingStep::Tokenize { tokenize } = step {
146 return (Some(PathBuf::from(&tokenize.tokenizer)), tokenize.max_length);
147 }
148 }
149 (None, None)
150}
151
152fn convert_data(
154 manifest: &TrainingManifest,
155 warnings: &mut Vec<String>,
156) -> Result<SpecDataConfig, BridgeError> {
157 let data_cfg =
158 manifest.data.as_ref().ok_or_else(|| BridgeError::MissingRequired("data".into()))?;
159
160 let train = resolve_train_path(data_cfg)?;
161 let val = data_cfg.val.as_ref().map(PathBuf::from);
162 let batch_size = data_cfg.loader.as_ref().map_or(8, |l| l.batch_size);
163
164 let mut tokenizer = data_cfg.tokenizer.as_ref().map(PathBuf::from);
166 let seq_len = data_cfg.seq_len;
167 let input_column = data_cfg.input_column.clone();
168 let output_column = data_cfg.output_column.clone();
169 let mut max_length = data_cfg.max_length;
170
171 if let Some(ref steps) = data_cfg.preprocessing {
173 let (fallback_tok, fallback_len) = extract_preprocessing_tokenizer(steps);
174 if tokenizer.is_none() {
175 tokenizer = fallback_tok;
176 }
177 if max_length.is_none() {
178 max_length = fallback_len;
179 }
180 }
181
182 if data_cfg.augmentation.is_some() {
183 warnings.push("data.augmentation is not supported in legacy TrainSpec".into());
184 }
185
186 Ok(SpecDataConfig {
187 train,
188 val,
189 batch_size,
190 auto_infer_types: true,
191 seq_len,
192 tokenizer,
193 input_column,
194 output_column,
195 max_length,
196 })
197}
198
199fn convert_optimizer(manifest: &TrainingManifest) -> Result<OptimSpec, BridgeError> {
201 let optim_cfg = manifest
202 .optimizer
203 .as_ref()
204 .ok_or_else(|| BridgeError::MissingRequired("optimizer".into()))?;
205
206 let name = optim_cfg.name.to_lowercase();
207
208 let lr = optim_cfg.lr as f32;
210
211 let mut params: HashMap<String, serde_json::Value> = HashMap::new();
213
214 if let Some(ref betas) = optim_cfg.betas {
215 if betas.len() >= 2 {
216 params.insert("beta1".into(), serde_json::json!(betas[0]));
217 params.insert("beta2".into(), serde_json::json!(betas[1]));
218 }
219 }
220 if let Some(eps) = optim_cfg.eps {
221 params.insert("eps".into(), serde_json::json!(eps));
222 }
223 if let Some(wd) = optim_cfg.weight_decay {
224 params.insert("weight_decay".into(), serde_json::json!(wd));
225 }
226 if let Some(momentum) = optim_cfg.momentum {
227 params.insert("momentum".into(), serde_json::json!(momentum));
228 }
229
230 Ok(OptimSpec { name, lr, params })
231}
232
233fn collect_scheduler_params(
238 s: &crate::yaml_mode::manifest::scheduler::SchedulerConfig,
239) -> Option<HashMap<String, serde_json::Value>> {
240 fn insert_opt<V: serde::Serialize>(
242 params: &mut HashMap<String, serde_json::Value>,
243 key: &str,
244 value: Option<&V>,
245 ) {
246 if let Some(v) = value {
247 params.insert(key.into(), serde_json::json!(v));
248 }
249 }
250
251 let mut params = HashMap::new();
252 insert_opt(&mut params, "t_max", s.t_max.as_ref());
253 insert_opt(&mut params, "eta_min", s.eta_min.as_ref());
254 insert_opt(&mut params, "step_size", s.step_size.as_ref());
255 insert_opt(&mut params, "gamma", s.gamma.as_ref());
256 insert_opt(&mut params, "mode", s.mode.as_ref());
257 insert_opt(&mut params, "factor", s.factor.as_ref());
258 insert_opt(&mut params, "patience", s.patience.as_ref());
259 insert_opt(&mut params, "threshold", s.threshold.as_ref());
260 insert_opt(&mut params, "max_lr", s.max_lr.as_ref());
261 insert_opt(&mut params, "pct_start", s.pct_start.as_ref());
262 insert_opt(&mut params, "anneal_strategy", s.anneal_strategy.as_ref());
263 insert_opt(&mut params, "div_factor", s.div_factor.as_ref());
264 insert_opt(&mut params, "final_div_factor", s.final_div_factor.as_ref());
265
266 if params.is_empty() {
267 None
268 } else {
269 Some(params)
270 }
271}
272
273fn warn_unsupported_training_fields(
275 training_cfg: Option<&crate::yaml_mode::manifest::training::TrainingConfig>,
276 warnings: &mut Vec<String>,
277) {
278 if training_cfg.and_then(|t| t.early_stopping.as_ref()).is_some() {
279 warnings.push("training.early_stopping is not supported in legacy TrainSpec".into());
280 }
281 if training_cfg.and_then(|t| t.distributed.as_ref()).is_some() {
282 warnings.push("training.distributed is not supported in legacy TrainSpec".into());
283 }
284}
285
286fn convert_training(
288 manifest: &TrainingManifest,
289 model_mode: ModelMode,
290 warnings: &mut Vec<String>,
291) -> TrainingParams {
292 let training_cfg = manifest.training.as_ref();
293 let scheduler_cfg = manifest.scheduler.as_ref();
294 let output_cfg = manifest.output.as_ref();
295
296 let epochs = training_cfg.and_then(|t| t.epochs).unwrap_or(10);
297
298 let grad_clip =
299 training_cfg.and_then(|t| t.gradient.as_ref()).and_then(|g| g.clip_norm).map(|v| v as f32);
300
301 let gradient_accumulation =
302 training_cfg.and_then(|t| t.gradient.as_ref()).and_then(|g| g.accumulation_steps);
303
304 let mixed_precision = training_cfg.and_then(|t| t.mixed_precision.as_ref()).and_then(|mp| {
305 if mp.enabled {
306 mp.dtype.clone()
307 } else {
308 None
309 }
310 });
311
312 let save_interval =
313 training_cfg.and_then(|t| t.checkpoint.as_ref()).and_then(|c| c.save_every).unwrap_or(1);
314
315 let lr_scheduler = scheduler_cfg.map(|s| s.name.to_lowercase());
316
317 let warmup_steps =
318 scheduler_cfg.and_then(|s| s.warmup.as_ref()).and_then(|w| w.steps).unwrap_or(0);
319
320 let scheduler_params = scheduler_cfg.and_then(collect_scheduler_params);
321
322 let output_dir =
323 output_cfg.map_or_else(|| PathBuf::from("./checkpoints"), |o| PathBuf::from(&o.dir));
324
325 let mode = if model_mode == ModelMode::Transformer {
326 TrainingMode::CausalLm
327 } else {
328 TrainingMode::default()
329 };
330
331 let seed = manifest.seed;
332
333 warn_unsupported_training_fields(training_cfg, warnings);
334
335 let deterministic = training_cfg.and_then(|t| t.deterministic).unwrap_or(false);
336
337 TrainingParams {
338 epochs,
339 grad_clip,
340 lr_scheduler,
341 warmup_steps,
342 save_interval,
343 output_dir,
344 mode,
345 gradient_accumulation,
346 checkpoints: None,
347 mixed_precision,
348 max_steps: training_cfg.and_then(|t| t.max_steps),
349 scheduler_params,
350 seed,
351 max_checkpoints: 5,
352 shuffle: true,
353 curriculum: training_cfg.and_then(|t| t.curriculum.clone()),
354 profile_interval: 0,
355 deterministic,
356 eval_interval: 0,
357 patience: 0,
358 distributed: None,
359 }
360}
361
362fn convert_lora(manifest: &TrainingManifest, _warnings: &mut Vec<String>) -> Option<LoRASpec> {
364 let lora_cfg = manifest.lora.as_ref()?;
365
366 if !lora_cfg.enabled {
367 return None;
368 }
369
370 Some(LoRASpec {
371 rank: lora_cfg.rank,
372 alpha: lora_cfg.alpha as f32,
373 target_modules: lora_cfg.target_modules.clone(),
374 dropout: lora_cfg.dropout.map_or(0.0, |d| d as f32),
375 lora_plus_ratio: 1.0,
376 double_quantize: false,
377 quantize_base: lora_cfg.quantize_base.unwrap_or(false),
378 })
379}
380
381fn convert_quantize(manifest: &TrainingManifest) -> Option<QuantSpec> {
383 let quant_cfg = manifest.quantize.as_ref()?;
384
385 if !quant_cfg.enabled {
386 return None;
387 }
388
389 let symmetric = quant_cfg.scheme.as_deref().is_none_or(|s| s == "symmetric");
390
391 let per_channel = quant_cfg.granularity.as_deref().is_none_or(|g| g == "per_channel");
392
393 Some(QuantSpec { bits: quant_cfg.bits, symmetric, per_channel })
394}
395
396#[cfg(test)]
397mod tests {
398 use super::*;
399 use crate::yaml_mode::manifest::core::TrainingManifest;
400 use crate::yaml_mode::manifest::data::{DataConfig, DataLoader};
401 use crate::yaml_mode::manifest::lora::LoraConfig;
402 use crate::yaml_mode::manifest::model::{ArchitectureConfig, ModelConfig};
403 use crate::yaml_mode::manifest::optimizer::OptimizerConfig;
404 use crate::yaml_mode::manifest::output::OutputConfig;
405 use crate::yaml_mode::manifest::quantize::QuantizeConfig;
406 use crate::yaml_mode::manifest::scheduler::{SchedulerConfig, WarmupConfig};
407 use crate::yaml_mode::manifest::training::{
408 CheckpointConfig, GradientConfig, MixedPrecisionConfig, TrainingConfig,
409 };
410
411 fn minimal_manifest() -> TrainingManifest {
413 TrainingManifest {
414 entrenar: "1.0".into(),
415 name: "test-experiment".into(),
416 version: "1.0.0".into(),
417 description: None,
418 seed: None,
419 data: Some(DataConfig {
420 source: Some("./data/train.parquet".into()),
421 format: None,
422 split: None,
423 train: None,
424 val: None,
425 test: None,
426 preprocessing: None,
427 augmentation: None,
428 loader: None,
429 tokenizer: None,
430 seq_len: None,
431 input_column: None,
432 output_column: None,
433 max_length: None,
434 }),
435 model: Some(ModelConfig {
436 source: "./models/base.safetensors".into(),
437 format: None,
438 architecture: None,
439 freeze: None,
440 device: None,
441 dtype: None,
442 }),
443 optimizer: Some(OptimizerConfig {
444 name: "adam".into(),
445 lr: 1e-4,
446 weight_decay: None,
447 betas: None,
448 eps: None,
449 amsgrad: None,
450 momentum: None,
451 nesterov: None,
452 dampening: None,
453 alpha: None,
454 centered: None,
455 param_groups: None,
456 }),
457 scheduler: None,
458 training: None,
459 lora: None,
460 quantize: None,
461 monitoring: None,
462 callbacks: None,
463 output: None,
464 publish: None,
465 citl: None,
466 rag: None,
467 graph: None,
468 distillation: None,
469 inspect: None,
470 privacy: None,
471 audit: None,
472 session: None,
473 stress: None,
474 benchmark: None,
475 debug: None,
476 signing: None,
477 verification: None,
478 lockfile: None,
479 strict: None,
480 strict_validation: None,
481 require_peer_review: None,
482 }
483 }
484
485 #[test]
486 fn test_minimal_manifest_converts() {
487 let manifest = minimal_manifest();
488 let result = manifest_to_spec(&manifest).expect("operation should succeed");
489 assert_eq!(result.spec.model.path, PathBuf::from("./models/base.safetensors"));
490 assert_eq!(result.spec.model.mode, ModelMode::Tabular);
491 assert_eq!(result.spec.data.train, PathBuf::from("./data/train.parquet"));
492 assert_eq!(result.spec.optimizer.name, "adam");
493 assert!((result.spec.optimizer.lr - 1e-4).abs() < 1e-6);
494 assert_eq!(result.spec.training.epochs, 10);
495 assert!(result.warnings.is_empty());
496 }
497
498 #[test]
499 fn test_missing_model_errors() {
500 let mut manifest = minimal_manifest();
501 manifest.model = None;
502 let err = manifest_to_spec(&manifest).unwrap_err();
503 assert!(matches!(err, BridgeError::MissingRequired(_)));
504 }
505
506 #[test]
507 fn test_missing_data_errors() {
508 let mut manifest = minimal_manifest();
509 manifest.data = None;
510 let err = manifest_to_spec(&manifest).unwrap_err();
511 assert!(matches!(err, BridgeError::MissingRequired(_)));
512 }
513
514 #[test]
515 fn test_missing_optimizer_errors() {
516 let mut manifest = minimal_manifest();
517 manifest.optimizer = None;
518 let err = manifest_to_spec(&manifest).unwrap_err();
519 assert!(matches!(err, BridgeError::MissingRequired(_)));
520 }
521
522 #[test]
523 fn test_missing_data_source_errors() {
524 let mut manifest = minimal_manifest();
525 manifest.data = Some(DataConfig {
526 source: None,
527 format: None,
528 split: None,
529 train: None,
530 val: None,
531 test: None,
532 preprocessing: None,
533 augmentation: None,
534 loader: None,
535 tokenizer: None,
536 seq_len: None,
537 input_column: None,
538 output_column: None,
539 max_length: None,
540 });
541 let err = manifest_to_spec(&manifest).unwrap_err();
542 assert!(matches!(err, BridgeError::MissingRequired(_)));
543 }
544
545 #[test]
546 fn test_explicit_train_path_preferred_over_source() {
547 let mut manifest = minimal_manifest();
548 manifest.data.as_mut().expect("operation should succeed").source =
549 Some("./source.parquet".into());
550 manifest.data.as_mut().expect("operation should succeed").train =
551 Some("./explicit_train.parquet".into());
552 let result = manifest_to_spec(&manifest).expect("operation should succeed");
553 assert_eq!(result.spec.data.train, PathBuf::from("./explicit_train.parquet"));
554 }
555
556 #[test]
557 fn test_val_path_converted() {
558 let mut manifest = minimal_manifest();
559 manifest.data.as_mut().expect("operation should succeed").val =
560 Some("./val.parquet".into());
561 let result = manifest_to_spec(&manifest).expect("operation should succeed");
562 assert_eq!(result.spec.data.val, Some(PathBuf::from("./val.parquet")));
563 }
564
565 #[test]
566 fn test_batch_size_from_loader() {
567 let mut manifest = minimal_manifest();
568 manifest.data.as_mut().expect("load should succeed").loader = Some(DataLoader {
569 batch_size: 32,
570 shuffle: true,
571 num_workers: None,
572 pin_memory: None,
573 drop_last: None,
574 prefetch_factor: None,
575 });
576 let result = manifest_to_spec(&manifest).expect("operation should succeed");
577 assert_eq!(result.spec.data.batch_size, 32);
578 }
579
580 #[test]
581 fn test_batch_size_default_without_loader() {
582 let manifest = minimal_manifest();
583 let result = manifest_to_spec(&manifest).expect("operation should succeed");
584 assert_eq!(result.spec.data.batch_size, 8);
585 }
586
587 #[test]
588 fn test_transformer_mode_from_architecture() {
589 let mut manifest = minimal_manifest();
590 manifest.model.as_mut().expect("config should be valid").architecture =
591 Some(ArchitectureConfig {
592 arch_type: "transformer".into(),
593 hidden_size: None,
594 num_layers: None,
595 num_heads: None,
596 num_kv_heads: None,
597 intermediate_size: None,
598 vocab_size: None,
599 max_seq_length: None,
600 rms_norm_eps: None,
601 rope_theta: None,
602 use_bias: None,
603 head_dim: None,
604 layers: None,
605 });
606 let result = manifest_to_spec(&manifest).expect("operation should succeed");
607 assert_eq!(result.spec.model.mode, ModelMode::Transformer);
608 }
609
610 #[test]
611 fn test_optimizer_params_converted() {
612 let mut manifest = minimal_manifest();
613 let opt = manifest.optimizer.as_mut().expect("operation should succeed");
614 opt.name = "adamw".into();
615 opt.lr = 3e-4;
616 opt.betas = Some(vec![0.9, 0.999]);
617 opt.eps = Some(1e-8);
618 opt.weight_decay = Some(0.01);
619
620 let result = manifest_to_spec(&manifest).expect("operation should succeed");
621 assert_eq!(result.spec.optimizer.name, "adamw");
622 assert!((result.spec.optimizer.lr - 3e-4).abs() < 1e-6);
623 assert_eq!(result.spec.optimizer.params["beta1"], serde_json::json!(0.9));
624 assert_eq!(result.spec.optimizer.params["beta2"], serde_json::json!(0.999));
625 assert!(result.spec.optimizer.params.contains_key("eps"));
626 assert!(result.spec.optimizer.params.contains_key("weight_decay"));
627 }
628
629 #[test]
630 fn test_training_config_converted() {
631 let mut manifest = minimal_manifest();
632 manifest.training = Some(TrainingConfig {
633 epochs: Some(5),
634 max_steps: None,
635 duration: None,
636 gradient: Some(GradientConfig {
637 accumulation_steps: Some(4),
638 clip_norm: Some(1.0),
639 clip_value: None,
640 }),
641 mixed_precision: Some(MixedPrecisionConfig {
642 enabled: true,
643 dtype: Some("bf16".into()),
644 loss_scale: None,
645 }),
646 distributed: None,
647 checkpoint: Some(CheckpointConfig {
648 save_every: Some(2),
649 keep_last: None,
650 save_best: None,
651 metric: None,
652 mode: None,
653 }),
654 early_stopping: None,
655 validation: None,
656 deterministic: None,
657 benchmark: None,
658 curriculum: None,
659 });
660
661 let result = manifest_to_spec(&manifest).expect("operation should succeed");
662 assert_eq!(result.spec.training.epochs, 5);
663 assert_eq!(result.spec.training.grad_clip, Some(1.0));
664 assert_eq!(result.spec.training.gradient_accumulation, Some(4));
665 assert_eq!(result.spec.training.mixed_precision, Some("bf16".into()));
666 assert_eq!(result.spec.training.save_interval, 2);
667 }
668
669 #[test]
670 fn test_scheduler_converted() {
671 let mut manifest = minimal_manifest();
672 manifest.scheduler = Some(SchedulerConfig {
673 name: "cosine".into(),
674 warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: None }),
675 t_max: None,
676 eta_min: None,
677 step_size: None,
678 gamma: None,
679 mode: None,
680 factor: None,
681 patience: None,
682 threshold: None,
683 max_lr: None,
684 pct_start: None,
685 anneal_strategy: None,
686 div_factor: None,
687 final_div_factor: None,
688 });
689
690 let result = manifest_to_spec(&manifest).expect("operation should succeed");
691 assert_eq!(result.spec.training.lr_scheduler, Some("cosine".into()));
692 assert_eq!(result.spec.training.warmup_steps, 100);
693 }
694
695 #[test]
696 fn test_output_dir_converted() {
697 let mut manifest = minimal_manifest();
698 manifest.output = Some(OutputConfig {
699 dir: "./outputs/my-model".into(),
700 model: None,
701 metrics: None,
702 report: None,
703 registry: None,
704 });
705
706 let result = manifest_to_spec(&manifest).expect("operation should succeed");
707 assert_eq!(result.spec.training.output_dir, PathBuf::from("./outputs/my-model"));
708 }
709
710 #[test]
711 fn test_lora_converted() {
712 let mut manifest = minimal_manifest();
713 manifest.lora = Some(LoraConfig {
714 enabled: true,
715 rank: 64,
716 alpha: 16.0,
717 dropout: Some(0.1),
718 target_modules: vec!["q_proj".into(), "v_proj".into()],
719 target_modules_pattern: None,
720 bias: None,
721 init_weights: None,
722 quantize_base: None,
723 quantize_bits: None,
724 double_quantize: None,
725 quant_type: None,
726 });
727
728 let result = manifest_to_spec(&manifest).expect("operation should succeed");
729 let lora = result.spec.lora.expect("operation should succeed");
730 assert_eq!(lora.rank, 64);
731 assert!((lora.alpha - 16.0).abs() < 1e-6);
732 assert!((lora.dropout - 0.1).abs() < 1e-6);
733 assert_eq!(lora.target_modules, vec!["q_proj", "v_proj"]);
734 assert_eq!(result.spec.model.layers, vec!["q_proj", "v_proj"]);
736 }
737
738 #[test]
739 fn test_lora_disabled_not_converted() {
740 let mut manifest = minimal_manifest();
741 manifest.lora = Some(LoraConfig {
742 enabled: false,
743 rank: 64,
744 alpha: 16.0,
745 dropout: None,
746 target_modules: vec!["q_proj".into()],
747 target_modules_pattern: None,
748 bias: None,
749 init_weights: None,
750 quantize_base: None,
751 quantize_bits: None,
752 double_quantize: None,
753 quant_type: None,
754 });
755
756 let result = manifest_to_spec(&manifest).expect("operation should succeed");
757 assert!(result.spec.lora.is_none());
758 }
759
760 #[test]
761 fn test_quantize_converted() {
762 let mut manifest = minimal_manifest();
763 manifest.quantize = Some(QuantizeConfig {
764 enabled: true,
765 bits: 4,
766 scheme: Some("symmetric".into()),
767 granularity: Some("per_channel".into()),
768 group_size: None,
769 qat: None,
770 calibration: None,
771 exclude: None,
772 });
773
774 let result = manifest_to_spec(&manifest).expect("operation should succeed");
775 let quant = result.spec.quantize.expect("operation should succeed");
776 assert_eq!(quant.bits, 4);
777 assert!(quant.symmetric);
778 assert!(quant.per_channel);
779 }
780
781 #[test]
782 fn test_quantize_disabled_not_converted() {
783 let mut manifest = minimal_manifest();
784 manifest.quantize = Some(QuantizeConfig {
785 enabled: false,
786 bits: 4,
787 scheme: None,
788 granularity: None,
789 group_size: None,
790 qat: None,
791 calibration: None,
792 exclude: None,
793 });
794
795 let result = manifest_to_spec(&manifest).expect("operation should succeed");
796 assert!(result.spec.quantize.is_none());
797 }
798
799 #[test]
800 fn test_quantize_asymmetric() {
801 let mut manifest = minimal_manifest();
802 manifest.quantize = Some(QuantizeConfig {
803 enabled: true,
804 bits: 8,
805 scheme: Some("asymmetric".into()),
806 granularity: Some("per_tensor".into()),
807 group_size: None,
808 qat: None,
809 calibration: None,
810 exclude: None,
811 });
812
813 let result = manifest_to_spec(&manifest).expect("operation should succeed");
814 let quant = result.spec.quantize.expect("operation should succeed");
815 assert!(!quant.symmetric);
816 assert!(!quant.per_channel);
817 }
818
819 #[test]
820 fn test_unsupported_fields_produce_warnings() {
821 let mut manifest = minimal_manifest();
822 manifest.monitoring = Some(crate::yaml_mode::MonitoringConfig {
823 terminal: None,
824 tracking: None,
825 system: None,
826 alerts: None,
827 drift_detection: None,
828 });
829
830 let result = manifest_to_spec(&manifest).expect("operation should succeed");
831 assert!(!result.warnings.is_empty());
832 assert!(result.warnings.iter().any(|w| w.contains("monitoring")));
833 }
834
835 #[test]
836 fn test_training_defaults_without_config() {
837 let manifest = minimal_manifest();
838 let result = manifest_to_spec(&manifest).expect("operation should succeed");
839 assert_eq!(result.spec.training.epochs, 10);
840 assert!(result.spec.training.grad_clip.is_none());
841 assert!(result.spec.training.lr_scheduler.is_none());
842 assert_eq!(result.spec.training.warmup_steps, 0);
843 assert_eq!(result.spec.training.save_interval, 1);
844 assert_eq!(result.spec.training.output_dir, PathBuf::from("./checkpoints"));
845 }
846
847 #[test]
848 fn test_bridge_error_display() {
849 let e = BridgeError::MissingRequired("model".into());
850 assert!(e.to_string().contains("model"));
851
852 let e = BridgeError::InvalidValue { field: "lr".into(), reason: "must be positive".into() };
853 assert!(e.to_string().contains("lr"));
854 assert!(e.to_string().contains("must be positive"));
855 }
856
857 #[test]
858 fn test_full_manifest_roundtrip_from_yaml() {
859 let yaml = r#"
860entrenar: "1.0"
861name: "full-test"
862version: "1.0.0"
863seed: 42
864
865model:
866 source: "./models/llama.safetensors"
867 architecture:
868 type: transformer
869 hidden_size: 768
870
871data:
872 source: "./data/train.jsonl"
873 val: "./data/val.jsonl"
874 tokenizer: "./tokenizer.json"
875 seq_len: 2048
876 input_column: text
877 output_column: target
878 max_length: 512
879 loader:
880 batch_size: 16
881 shuffle: true
882
883optimizer:
884 name: adamw
885 lr: 0.0003
886 betas: [0.9, 0.95]
887 weight_decay: 0.1
888
889scheduler:
890 name: cosine
891 T_max: 1000
892 eta_min: 0.000001
893 warmup:
894 steps: 200
895
896training:
897 epochs: 3
898 gradient:
899 clip_norm: 1.0
900 accumulation_steps: 8
901 mixed_precision:
902 enabled: true
903 dtype: bf16
904 checkpoint:
905 save_every: 1
906
907lora:
908 enabled: true
909 rank: 32
910 alpha: 64.0
911 dropout: 0.05
912 target_modules: [q_proj, k_proj, v_proj, o_proj]
913
914quantize:
915 enabled: true
916 bits: 4
917 scheme: symmetric
918 granularity: per_channel
919
920output:
921 dir: "./outputs/full-test"
922"#;
923
924 let manifest: TrainingManifest =
925 serde_yaml::from_str(yaml).expect("operation should succeed");
926 let result = manifest_to_spec(&manifest).expect("operation should succeed");
927
928 assert_eq!(result.spec.model.mode, ModelMode::Transformer);
929 assert_eq!(result.spec.model.path, PathBuf::from("./models/llama.safetensors"));
930 assert_eq!(result.spec.model.layers, vec!["q_proj", "k_proj", "v_proj", "o_proj"]);
931 assert_eq!(result.spec.data.train, PathBuf::from("./data/train.jsonl"));
932 assert_eq!(result.spec.data.val, Some(PathBuf::from("./data/val.jsonl")));
933 assert_eq!(result.spec.data.batch_size, 16);
934 assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./tokenizer.json")));
936 assert_eq!(result.spec.data.seq_len, Some(2048));
937 assert_eq!(result.spec.data.input_column, Some("text".into()));
938 assert_eq!(result.spec.data.output_column, Some("target".into()));
939 assert_eq!(result.spec.data.max_length, Some(512));
940 assert_eq!(result.spec.optimizer.name, "adamw");
942 assert!((result.spec.optimizer.lr - 0.0003).abs() < 1e-6);
943 assert_eq!(result.spec.training.epochs, 3);
945 assert_eq!(result.spec.training.grad_clip, Some(1.0));
946 assert_eq!(result.spec.training.gradient_accumulation, Some(8));
947 assert_eq!(result.spec.training.mixed_precision, Some("bf16".into()));
948 assert_eq!(result.spec.training.lr_scheduler, Some("cosine".into()));
949 assert_eq!(result.spec.training.warmup_steps, 200);
950 assert_eq!(result.spec.training.save_interval, 1);
951 assert_eq!(result.spec.training.output_dir, PathBuf::from("./outputs/full-test"));
952 assert_eq!(result.spec.training.mode, TrainingMode::CausalLm);
954 assert_eq!(result.spec.training.seed, Some(42));
956 let sched_params =
958 result.spec.training.scheduler_params.as_ref().expect("operation should succeed");
959 assert_eq!(sched_params["t_max"], serde_json::json!(1000));
960 assert_eq!(sched_params["eta_min"], serde_json::json!(0.000001));
961
962 let lora = result.spec.lora.expect("operation should succeed");
963 assert_eq!(lora.rank, 32);
964
965 let quant = result.spec.quantize.expect("operation should succeed");
966 assert_eq!(quant.bits, 4);
967 assert!(quant.symmetric);
968 assert!(quant.per_channel);
969 }
970
971 #[test]
974 fn test_transformer_gets_causal_lm_mode() {
975 let mut manifest = minimal_manifest();
976 manifest.model.as_mut().expect("config should be valid").architecture =
977 Some(ArchitectureConfig {
978 arch_type: "transformer".into(),
979 hidden_size: None,
980 num_layers: None,
981 num_heads: None,
982 num_kv_heads: None,
983 intermediate_size: None,
984 vocab_size: None,
985 max_seq_length: None,
986 rms_norm_eps: None,
987 rope_theta: None,
988 use_bias: None,
989 head_dim: None,
990 layers: None,
991 });
992 let result = manifest_to_spec(&manifest).expect("operation should succeed");
993 assert_eq!(result.spec.training.mode, TrainingMode::CausalLm);
994 }
995
996 #[test]
997 fn test_tabular_gets_regression_mode() {
998 let manifest = minimal_manifest();
999 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1000 assert_eq!(result.spec.model.mode, ModelMode::Tabular);
1001 assert_eq!(result.spec.training.mode, TrainingMode::Regression);
1002 }
1003
1004 #[test]
1005 fn test_data_llm_fields_converted() {
1006 let mut manifest = minimal_manifest();
1007 let data = manifest.data.as_mut().expect("operation should succeed");
1008 data.tokenizer = Some("./tokenizer.json".into());
1009 data.seq_len = Some(2048);
1010 data.input_column = Some("text".into());
1011 data.output_column = Some("label".into());
1012 data.max_length = Some(512);
1013
1014 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1015 assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./tokenizer.json")));
1016 assert_eq!(result.spec.data.seq_len, Some(2048));
1017 assert_eq!(result.spec.data.input_column, Some("text".into()));
1018 assert_eq!(result.spec.data.output_column, Some("label".into()));
1019 assert_eq!(result.spec.data.max_length, Some(512));
1020 }
1021
1022 #[test]
1023 fn test_data_tokenizer_from_preprocessing_fallback() {
1024 use crate::yaml_mode::manifest::data::{PreprocessingStep, TokenizeConfig};
1025
1026 let mut manifest = minimal_manifest();
1027 let data = manifest.data.as_mut().expect("operation should succeed");
1028 data.preprocessing = Some(vec![PreprocessingStep::Tokenize {
1030 tokenize: TokenizeConfig {
1031 tokenizer: "./fallback-tokenizer.json".into(),
1032 max_length: Some(256),
1033 padding: None,
1034 truncation: None,
1035 },
1036 }]);
1037
1038 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1039 assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./fallback-tokenizer.json")));
1040 assert_eq!(result.spec.data.max_length, Some(256));
1041 }
1042
1043 #[test]
1044 fn test_data_toplevel_tokenizer_takes_precedence() {
1045 use crate::yaml_mode::manifest::data::{PreprocessingStep, TokenizeConfig};
1046
1047 let mut manifest = minimal_manifest();
1048 let data = manifest.data.as_mut().expect("operation should succeed");
1049 data.tokenizer = Some("./primary.json".into());
1050 data.max_length = Some(1024);
1051 data.preprocessing = Some(vec![PreprocessingStep::Tokenize {
1052 tokenize: TokenizeConfig {
1053 tokenizer: "./fallback.json".into(),
1054 max_length: Some(256),
1055 padding: None,
1056 truncation: None,
1057 },
1058 }]);
1059
1060 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1061 assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./primary.json")));
1063 assert_eq!(result.spec.data.max_length, Some(1024));
1064 }
1065
1066 #[test]
1067 fn test_scheduler_params_cosine() {
1068 let mut manifest = minimal_manifest();
1069 manifest.scheduler = Some(SchedulerConfig {
1070 name: "cosine".into(),
1071 warmup: None,
1072 t_max: Some(500),
1073 eta_min: Some(1e-6),
1074 step_size: None,
1075 gamma: None,
1076 mode: None,
1077 factor: None,
1078 patience: None,
1079 threshold: None,
1080 max_lr: None,
1081 pct_start: None,
1082 anneal_strategy: None,
1083 div_factor: None,
1084 final_div_factor: None,
1085 });
1086
1087 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1088 let params = result.spec.training.scheduler_params.expect("operation should succeed");
1089 assert_eq!(params["t_max"], serde_json::json!(500));
1090 assert_eq!(params["eta_min"], serde_json::json!(1e-6));
1091 assert_eq!(params.len(), 2);
1092 }
1093
1094 #[test]
1095 fn test_scheduler_params_step() {
1096 let mut manifest = minimal_manifest();
1097 manifest.scheduler = Some(SchedulerConfig {
1098 name: "step".into(),
1099 warmup: None,
1100 t_max: None,
1101 eta_min: None,
1102 step_size: Some(30),
1103 gamma: Some(0.1),
1104 mode: None,
1105 factor: None,
1106 patience: None,
1107 threshold: None,
1108 max_lr: None,
1109 pct_start: None,
1110 anneal_strategy: None,
1111 div_factor: None,
1112 final_div_factor: None,
1113 });
1114
1115 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1116 let params = result.spec.training.scheduler_params.expect("operation should succeed");
1117 assert_eq!(params["step_size"], serde_json::json!(30));
1118 assert_eq!(params["gamma"], serde_json::json!(0.1));
1119 assert_eq!(params.len(), 2);
1120 }
1121
1122 #[test]
1123 fn test_scheduler_params_plateau() {
1124 let mut manifest = minimal_manifest();
1125 manifest.scheduler = Some(SchedulerConfig {
1126 name: "plateau".into(),
1127 warmup: None,
1128 t_max: None,
1129 eta_min: None,
1130 step_size: None,
1131 gamma: None,
1132 mode: Some("min".into()),
1133 factor: Some(0.1),
1134 patience: Some(10),
1135 threshold: Some(1e-4),
1136 max_lr: None,
1137 pct_start: None,
1138 anneal_strategy: None,
1139 div_factor: None,
1140 final_div_factor: None,
1141 });
1142
1143 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1144 let params = result.spec.training.scheduler_params.expect("operation should succeed");
1145 assert_eq!(params["mode"], serde_json::json!("min"));
1146 assert_eq!(params["factor"], serde_json::json!(0.1));
1147 assert_eq!(params["patience"], serde_json::json!(10));
1148 assert_eq!(params["threshold"], serde_json::json!(1e-4));
1149 assert_eq!(params.len(), 4);
1150 }
1151
1152 #[test]
1153 fn test_scheduler_params_one_cycle() {
1154 let mut manifest = minimal_manifest();
1155 manifest.scheduler = Some(SchedulerConfig {
1156 name: "one_cycle".into(),
1157 warmup: None,
1158 t_max: None,
1159 eta_min: None,
1160 step_size: None,
1161 gamma: None,
1162 mode: None,
1163 factor: None,
1164 patience: None,
1165 threshold: None,
1166 max_lr: Some(0.01),
1167 pct_start: Some(0.3),
1168 anneal_strategy: Some("cos".into()),
1169 div_factor: Some(25.0),
1170 final_div_factor: Some(1e4),
1171 });
1172
1173 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1174 let params = result.spec.training.scheduler_params.expect("operation should succeed");
1175 assert_eq!(params["max_lr"], serde_json::json!(0.01));
1176 assert_eq!(params["pct_start"], serde_json::json!(0.3));
1177 assert_eq!(params["anneal_strategy"], serde_json::json!("cos"));
1178 assert_eq!(params["div_factor"], serde_json::json!(25.0));
1179 assert_eq!(params["final_div_factor"], serde_json::json!(1e4));
1180 assert_eq!(params.len(), 5);
1181 }
1182
1183 #[test]
1184 fn test_scheduler_no_params_yields_none() {
1185 let mut manifest = minimal_manifest();
1186 manifest.scheduler = Some(SchedulerConfig {
1187 name: "cosine".into(),
1188 warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: None }),
1189 t_max: None,
1190 eta_min: None,
1191 step_size: None,
1192 gamma: None,
1193 mode: None,
1194 factor: None,
1195 patience: None,
1196 threshold: None,
1197 max_lr: None,
1198 pct_start: None,
1199 anneal_strategy: None,
1200 div_factor: None,
1201 final_div_factor: None,
1202 });
1203
1204 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1205 assert!(result.spec.training.scheduler_params.is_none());
1207 }
1208
1209 #[test]
1210 fn test_seed_passed_through() {
1211 let mut manifest = minimal_manifest();
1212 manifest.seed = Some(12345);
1213 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1214 assert_eq!(result.spec.training.seed, Some(12345));
1215 }
1216
1217 #[test]
1218 fn test_seed_none_when_not_set() {
1219 let manifest = minimal_manifest();
1220 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1221 assert!(result.spec.training.seed.is_none());
1222 }
1223
1224 #[test]
1225 fn test_architecture_overrides_converted() {
1226 let mut manifest = minimal_manifest();
1227 manifest.model.as_mut().expect("config should be valid").architecture =
1228 Some(ArchitectureConfig {
1229 arch_type: "transformer".into(),
1230 hidden_size: Some(1024),
1231 num_layers: Some(16),
1232 num_heads: Some(16),
1233 num_kv_heads: Some(4),
1234 intermediate_size: Some(4096),
1235 vocab_size: Some(50000),
1236 max_seq_length: Some(2048),
1237 rms_norm_eps: Some(1e-5),
1238 rope_theta: Some(500_000.0),
1239 use_bias: Some(true),
1240 head_dim: None,
1241 layers: None,
1242 });
1243 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1244 let arch = result.spec.model.architecture.expect("architecture overrides should be set");
1245 assert_eq!(arch.hidden_size, Some(1024));
1246 assert_eq!(arch.num_hidden_layers, Some(16));
1247 assert_eq!(arch.num_attention_heads, Some(16));
1248 assert_eq!(arch.num_kv_heads, Some(4));
1249 assert_eq!(arch.intermediate_size, Some(4096));
1250 assert_eq!(arch.vocab_size, Some(50000));
1251 assert_eq!(arch.max_position_embeddings, Some(2048));
1252 assert_eq!(arch.rms_norm_eps, Some(1e-5));
1253 assert_eq!(arch.rope_theta, Some(500_000.0));
1254 assert_eq!(arch.use_bias, Some(true));
1255 }
1256
1257 #[test]
1258 fn test_architecture_overrides_none_when_no_params() {
1259 let mut manifest = minimal_manifest();
1260 manifest.model.as_mut().expect("config should be valid").architecture =
1261 Some(ArchitectureConfig {
1262 arch_type: "transformer".into(),
1263 hidden_size: None,
1264 num_layers: None,
1265 num_heads: None,
1266 num_kv_heads: None,
1267 intermediate_size: None,
1268 vocab_size: None,
1269 max_seq_length: None,
1270 rms_norm_eps: None,
1271 rope_theta: None,
1272 use_bias: None,
1273 head_dim: None,
1274 layers: None,
1275 });
1276 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1277 assert!(result.spec.model.architecture.is_none());
1279 }
1280
1281 #[test]
1282 fn test_architecture_overrides_partial() {
1283 let mut manifest = minimal_manifest();
1284 manifest.model.as_mut().expect("config should be valid").architecture =
1285 Some(ArchitectureConfig {
1286 arch_type: "transformer".into(),
1287 hidden_size: Some(768),
1288 num_layers: None,
1289 num_heads: None,
1290 num_kv_heads: None,
1291 intermediate_size: None,
1292 vocab_size: None,
1293 max_seq_length: None,
1294 rms_norm_eps: None,
1295 rope_theta: None,
1296 use_bias: None,
1297 head_dim: None,
1298 layers: None,
1299 });
1300 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1301 let arch = result.spec.model.architecture.expect("architecture overrides should be set");
1302 assert_eq!(arch.hidden_size, Some(768));
1303 assert!(arch.num_attention_heads.is_none());
1304 }
1305
1306 #[test]
1307 fn test_architecture_overrides_from_yaml() {
1308 let yaml = r#"
1309entrenar: "1.0"
1310name: "arch-test"
1311version: "1.0.0"
1312
1313model:
1314 source: "./models/custom.safetensors"
1315 architecture:
1316 type: transformer
1317 hidden_size: 1024
1318 num_layers: 16
1319 num_heads: 16
1320 num_kv_heads: 4
1321 intermediate_size: 4096
1322 vocab_size: 50000
1323 max_seq_length: 2048
1324 rms_norm_eps: 0.00001
1325 rope_theta: 500000.0
1326 use_bias: true
1327
1328data:
1329 source: "./data/train.parquet"
1330
1331optimizer:
1332 name: adamw
1333 lr: 0.0003
1334"#;
1335 let manifest: TrainingManifest = serde_yaml::from_str(yaml).expect("YAML should parse");
1336 let result = manifest_to_spec(&manifest).expect("bridge should succeed");
1337 assert_eq!(result.spec.model.mode, ModelMode::Transformer);
1338 let arch = result.spec.model.architecture.expect("overrides should be set");
1339 assert_eq!(arch.hidden_size, Some(1024));
1340 assert_eq!(arch.num_hidden_layers, Some(16));
1341 assert_eq!(arch.num_attention_heads, Some(16));
1342 assert_eq!(arch.num_kv_heads, Some(4));
1343 assert_eq!(arch.intermediate_size, Some(4096));
1344 assert_eq!(arch.vocab_size, Some(50000));
1345 assert_eq!(arch.max_position_embeddings, Some(2048));
1346 }
1347
1348 #[test]
1349 fn test_deterministic_passed_through() {
1350 let mut manifest = minimal_manifest();
1351 manifest.training = Some(TrainingConfig {
1352 epochs: Some(5),
1353 max_steps: None,
1354 duration: None,
1355 gradient: None,
1356 mixed_precision: None,
1357 distributed: None,
1358 checkpoint: None,
1359 early_stopping: None,
1360 validation: None,
1361 deterministic: Some(true),
1362 benchmark: None,
1363 curriculum: None,
1364 });
1365 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1366 assert!(result.spec.training.deterministic, "deterministic should be true");
1367 }
1368
1369 #[test]
1370 fn test_deterministic_defaults_false_when_not_set() {
1371 let manifest = minimal_manifest();
1372 let result = manifest_to_spec(&manifest).expect("operation should succeed");
1373 assert!(!result.spec.training.deterministic, "deterministic should default to false");
1374 }
1375}