Skip to main content

entrenar/yaml_mode/
bridge.rs

//! Bridge converter: TrainingManifest → TrainSpec
//!
//! Maps the declarative YAML Mode manifest format to the legacy TrainSpec
//! so that `entrenar train manifest.yaml` works with the existing pipeline.
5
6use crate::config::{
7    DataConfig as SpecDataConfig, LoRASpec, ModelMode, ModelRef, OptimSpec, QuantSpec, TrainSpec,
8    TrainingMode, TrainingParams,
9};
10use crate::yaml_mode::TrainingManifest;
11use std::collections::HashMap;
12use std::path::PathBuf;
13use thiserror::Error;
14
15/// Result of converting a TrainingManifest to TrainSpec
16#[derive(Debug)]
17pub struct BridgeResult {
18    /// The converted TrainSpec
19    pub spec: TrainSpec,
20    /// Warnings about unsupported/ignored manifest fields
21    pub warnings: Vec<String>,
22}
23
24/// Errors that can occur during manifest-to-spec conversion
25#[derive(Debug, Error)]
26pub enum BridgeError {
27    #[error("Missing required field: {0}")]
28    MissingRequired(String),
29
30    #[error("Invalid value for {field}: {reason}")]
31    InvalidValue { field: String, reason: String },
32}
33
34/// Convert a TrainingManifest into a TrainSpec
35///
36/// Maps declarative manifest fields to the legacy spec format.
37/// Returns warnings for manifest features not supported by TrainSpec.
38pub fn manifest_to_spec(manifest: &TrainingManifest) -> Result<BridgeResult, BridgeError> {
39    let mut warnings = Vec::new();
40
41    let model = convert_model(manifest, &mut warnings)?;
42    let data = convert_data(manifest, &mut warnings)?;
43    let optimizer = convert_optimizer(manifest)?;
44    let training = convert_training(manifest, model.mode, &mut warnings);
45    let lora = convert_lora(manifest, &mut warnings);
46    let quantize = convert_quantize(manifest);
47
48    // Merge config is not in manifest format — skip
49    if manifest.monitoring.is_some() {
50        warnings.push("monitoring config is not supported in legacy TrainSpec".into());
51    }
52    if manifest.callbacks.is_some() {
53        warnings.push("callbacks config is not supported in legacy TrainSpec".into());
54    }
55    if manifest.distillation.is_some() {
56        warnings.push("distillation config is not supported in legacy TrainSpec".into());
57    }
58
59    let spec =
60        TrainSpec { model, data, optimizer, training, lora, quantize, merge: None, publish: None };
61
62    Ok(BridgeResult { spec, warnings })
63}
64
65/// Convert model config from manifest to spec
66fn convert_model(
67    manifest: &TrainingManifest,
68    warnings: &mut Vec<String>,
69) -> Result<ModelRef, BridgeError> {
70    use crate::config::ArchitectureOverrides;
71
72    let model_cfg =
73        manifest.model.as_ref().ok_or_else(|| BridgeError::MissingRequired("model".into()))?;
74
75    // Determine model mode from architecture
76    let mode = if let Some(ref arch) = model_cfg.architecture {
77        if arch.arch_type == "transformer" {
78            ModelMode::Transformer
79        } else {
80            ModelMode::Tabular
81        }
82    } else {
83        ModelMode::Tabular
84    };
85
86    // Use LoRA target_modules as model layers
87    let layers = manifest.lora.as_ref().map(|l| l.target_modules.clone()).unwrap_or_default();
88
89    if model_cfg.freeze.is_some() {
90        warnings.push("model.freeze is not supported in legacy TrainSpec".into());
91    }
92    if model_cfg.device.is_some() {
93        warnings.push("model.device is not supported in legacy TrainSpec".into());
94    }
95
96    // Convert architecture params to overrides
97    let architecture = model_cfg
98        .architecture
99        .as_ref()
100        .map(|arch| ArchitectureOverrides {
101            hidden_size: arch.hidden_size,
102            num_hidden_layers: arch.num_layers,
103            num_attention_heads: arch.num_heads,
104            num_kv_heads: arch.num_kv_heads,
105            intermediate_size: arch.intermediate_size,
106            vocab_size: arch.vocab_size,
107            max_position_embeddings: arch.max_seq_length,
108            rms_norm_eps: arch.rms_norm_eps,
109            rope_theta: arch.rope_theta,
110            use_bias: arch.use_bias,
111            head_dim: arch.head_dim,
112        })
113        .filter(|o| !o.is_empty());
114
115    Ok(ModelRef {
116        path: PathBuf::from(&model_cfg.source),
117        layers,
118        mode,
119        config: None,
120        architecture,
121    })
122}
123
124/// Resolve the training data path from manifest data config.
125///
126/// Prefers explicit `train` field; falls back to `source`.
127fn resolve_train_path(
128    data_cfg: &crate::yaml_mode::manifest::data::DataConfig,
129) -> Result<PathBuf, BridgeError> {
130    data_cfg
131        .train
132        .as_deref()
133        .or(data_cfg.source.as_deref())
134        .map(PathBuf::from)
135        .ok_or_else(|| BridgeError::MissingRequired("data.source or data.train".into()))
136}
137
138/// Extract tokenizer path and max_length from the first Tokenize preprocessing step.
139///
140/// Returns `(Option<PathBuf>, Option<usize>)` for tokenizer and max_length respectively.
141fn extract_preprocessing_tokenizer(
142    steps: &[crate::yaml_mode::manifest::data::PreprocessingStep],
143) -> (Option<PathBuf>, Option<usize>) {
144    for step in steps {
145        if let crate::yaml_mode::manifest::data::PreprocessingStep::Tokenize { tokenize } = step {
146            return (Some(PathBuf::from(&tokenize.tokenizer)), tokenize.max_length);
147        }
148    }
149    (None, None)
150}
151
152/// Convert data config from manifest to spec
153fn convert_data(
154    manifest: &TrainingManifest,
155    warnings: &mut Vec<String>,
156) -> Result<SpecDataConfig, BridgeError> {
157    let data_cfg =
158        manifest.data.as_ref().ok_or_else(|| BridgeError::MissingRequired("data".into()))?;
159
160    let train = resolve_train_path(data_cfg)?;
161    let val = data_cfg.val.as_ref().map(PathBuf::from);
162    let batch_size = data_cfg.loader.as_ref().map_or(8, |l| l.batch_size);
163
164    // Bridge LLM data fields directly from manifest
165    let mut tokenizer = data_cfg.tokenizer.as_ref().map(PathBuf::from);
166    let seq_len = data_cfg.seq_len;
167    let input_column = data_cfg.input_column.clone();
168    let output_column = data_cfg.output_column.clone();
169    let mut max_length = data_cfg.max_length;
170
171    // Fallback: extract tokenizer/max_length from preprocessing Tokenize step
172    if let Some(ref steps) = data_cfg.preprocessing {
173        let (fallback_tok, fallback_len) = extract_preprocessing_tokenizer(steps);
174        if tokenizer.is_none() {
175            tokenizer = fallback_tok;
176        }
177        if max_length.is_none() {
178            max_length = fallback_len;
179        }
180    }
181
182    if data_cfg.augmentation.is_some() {
183        warnings.push("data.augmentation is not supported in legacy TrainSpec".into());
184    }
185
186    Ok(SpecDataConfig {
187        train,
188        val,
189        batch_size,
190        auto_infer_types: true,
191        seq_len,
192        tokenizer,
193        input_column,
194        output_column,
195        max_length,
196    })
197}
198
199/// Convert optimizer config from manifest to spec
200fn convert_optimizer(manifest: &TrainingManifest) -> Result<OptimSpec, BridgeError> {
201    let optim_cfg = manifest
202        .optimizer
203        .as_ref()
204        .ok_or_else(|| BridgeError::MissingRequired("optimizer".into()))?;
205
206    let name = optim_cfg.name.to_lowercase();
207
208    // f64 -> f32 conversion for learning rate
209    let lr = optim_cfg.lr as f32;
210
211    // Pack optional parameters into HashMap
212    let mut params: HashMap<String, serde_json::Value> = HashMap::new();
213
214    if let Some(ref betas) = optim_cfg.betas {
215        if betas.len() >= 2 {
216            params.insert("beta1".into(), serde_json::json!(betas[0]));
217            params.insert("beta2".into(), serde_json::json!(betas[1]));
218        }
219    }
220    if let Some(eps) = optim_cfg.eps {
221        params.insert("eps".into(), serde_json::json!(eps));
222    }
223    if let Some(wd) = optim_cfg.weight_decay {
224        params.insert("weight_decay".into(), serde_json::json!(wd));
225    }
226    if let Some(momentum) = optim_cfg.momentum {
227        params.insert("momentum".into(), serde_json::json!(momentum));
228    }
229
230    Ok(OptimSpec { name, lr, params })
231}
232
233/// Collect scheduler-specific parameters into a HashMap.
234///
235/// Inserts each optional scheduler field that is `Some` into the map.
236/// Returns `None` if no parameters were set.
237fn collect_scheduler_params(
238    s: &crate::yaml_mode::manifest::scheduler::SchedulerConfig,
239) -> Option<HashMap<String, serde_json::Value>> {
240    /// Helper: insert a JSON-serializable value into the map if `Some`.
241    fn insert_opt<V: serde::Serialize>(
242        params: &mut HashMap<String, serde_json::Value>,
243        key: &str,
244        value: Option<&V>,
245    ) {
246        if let Some(v) = value {
247            params.insert(key.into(), serde_json::json!(v));
248        }
249    }
250
251    let mut params = HashMap::new();
252    insert_opt(&mut params, "t_max", s.t_max.as_ref());
253    insert_opt(&mut params, "eta_min", s.eta_min.as_ref());
254    insert_opt(&mut params, "step_size", s.step_size.as_ref());
255    insert_opt(&mut params, "gamma", s.gamma.as_ref());
256    insert_opt(&mut params, "mode", s.mode.as_ref());
257    insert_opt(&mut params, "factor", s.factor.as_ref());
258    insert_opt(&mut params, "patience", s.patience.as_ref());
259    insert_opt(&mut params, "threshold", s.threshold.as_ref());
260    insert_opt(&mut params, "max_lr", s.max_lr.as_ref());
261    insert_opt(&mut params, "pct_start", s.pct_start.as_ref());
262    insert_opt(&mut params, "anneal_strategy", s.anneal_strategy.as_ref());
263    insert_opt(&mut params, "div_factor", s.div_factor.as_ref());
264    insert_opt(&mut params, "final_div_factor", s.final_div_factor.as_ref());
265
266    if params.is_empty() {
267        None
268    } else {
269        Some(params)
270    }
271}
272
273/// Emit warnings for unsupported training sub-fields.
274fn warn_unsupported_training_fields(
275    training_cfg: Option<&crate::yaml_mode::manifest::training::TrainingConfig>,
276    warnings: &mut Vec<String>,
277) {
278    if training_cfg.and_then(|t| t.early_stopping.as_ref()).is_some() {
279        warnings.push("training.early_stopping is not supported in legacy TrainSpec".into());
280    }
281    if training_cfg.and_then(|t| t.distributed.as_ref()).is_some() {
282        warnings.push("training.distributed is not supported in legacy TrainSpec".into());
283    }
284}
285
286/// Convert training config from manifest to spec
287fn convert_training(
288    manifest: &TrainingManifest,
289    model_mode: ModelMode,
290    warnings: &mut Vec<String>,
291) -> TrainingParams {
292    let training_cfg = manifest.training.as_ref();
293    let scheduler_cfg = manifest.scheduler.as_ref();
294    let output_cfg = manifest.output.as_ref();
295
296    let epochs = training_cfg.and_then(|t| t.epochs).unwrap_or(10);
297
298    let grad_clip =
299        training_cfg.and_then(|t| t.gradient.as_ref()).and_then(|g| g.clip_norm).map(|v| v as f32);
300
301    let gradient_accumulation =
302        training_cfg.and_then(|t| t.gradient.as_ref()).and_then(|g| g.accumulation_steps);
303
304    let mixed_precision = training_cfg.and_then(|t| t.mixed_precision.as_ref()).and_then(|mp| {
305        if mp.enabled {
306            mp.dtype.clone()
307        } else {
308            None
309        }
310    });
311
312    let save_interval =
313        training_cfg.and_then(|t| t.checkpoint.as_ref()).and_then(|c| c.save_every).unwrap_or(1);
314
315    let lr_scheduler = scheduler_cfg.map(|s| s.name.to_lowercase());
316
317    let warmup_steps =
318        scheduler_cfg.and_then(|s| s.warmup.as_ref()).and_then(|w| w.steps).unwrap_or(0);
319
320    let scheduler_params = scheduler_cfg.and_then(collect_scheduler_params);
321
322    let output_dir =
323        output_cfg.map_or_else(|| PathBuf::from("./checkpoints"), |o| PathBuf::from(&o.dir));
324
325    let mode = if model_mode == ModelMode::Transformer {
326        TrainingMode::CausalLm
327    } else {
328        TrainingMode::default()
329    };
330
331    let seed = manifest.seed;
332
333    warn_unsupported_training_fields(training_cfg, warnings);
334
335    let deterministic = training_cfg.and_then(|t| t.deterministic).unwrap_or(false);
336
337    TrainingParams {
338        epochs,
339        grad_clip,
340        lr_scheduler,
341        warmup_steps,
342        save_interval,
343        output_dir,
344        mode,
345        gradient_accumulation,
346        checkpoints: None,
347        mixed_precision,
348        max_steps: training_cfg.and_then(|t| t.max_steps),
349        scheduler_params,
350        seed,
351        max_checkpoints: 5,
352        shuffle: true,
353        curriculum: training_cfg.and_then(|t| t.curriculum.clone()),
354        profile_interval: 0,
355        deterministic,
356        eval_interval: 0,
357        patience: 0,
358        distributed: None,
359    }
360}
361
362/// Convert LoRA config from manifest to spec
363fn convert_lora(manifest: &TrainingManifest, _warnings: &mut Vec<String>) -> Option<LoRASpec> {
364    let lora_cfg = manifest.lora.as_ref()?;
365
366    if !lora_cfg.enabled {
367        return None;
368    }
369
370    Some(LoRASpec {
371        rank: lora_cfg.rank,
372        alpha: lora_cfg.alpha as f32,
373        target_modules: lora_cfg.target_modules.clone(),
374        dropout: lora_cfg.dropout.map_or(0.0, |d| d as f32),
375        lora_plus_ratio: 1.0,
376        double_quantize: false,
377        quantize_base: lora_cfg.quantize_base.unwrap_or(false),
378    })
379}
380
381/// Convert quantization config from manifest to spec
382fn convert_quantize(manifest: &TrainingManifest) -> Option<QuantSpec> {
383    let quant_cfg = manifest.quantize.as_ref()?;
384
385    if !quant_cfg.enabled {
386        return None;
387    }
388
389    let symmetric = quant_cfg.scheme.as_deref().is_none_or(|s| s == "symmetric");
390
391    let per_channel = quant_cfg.granularity.as_deref().is_none_or(|g| g == "per_channel");
392
393    Some(QuantSpec { bits: quant_cfg.bits, symmetric, per_channel })
394}
395
396#[cfg(test)]
397mod tests {
398    use super::*;
399    use crate::yaml_mode::manifest::core::TrainingManifest;
400    use crate::yaml_mode::manifest::data::{DataConfig, DataLoader};
401    use crate::yaml_mode::manifest::lora::LoraConfig;
402    use crate::yaml_mode::manifest::model::{ArchitectureConfig, ModelConfig};
403    use crate::yaml_mode::manifest::optimizer::OptimizerConfig;
404    use crate::yaml_mode::manifest::output::OutputConfig;
405    use crate::yaml_mode::manifest::quantize::QuantizeConfig;
406    use crate::yaml_mode::manifest::scheduler::{SchedulerConfig, WarmupConfig};
407    use crate::yaml_mode::manifest::training::{
408        CheckpointConfig, GradientConfig, MixedPrecisionConfig, TrainingConfig,
409    };
410
411    /// Create a minimal valid manifest for testing
412    fn minimal_manifest() -> TrainingManifest {
413        TrainingManifest {
414            entrenar: "1.0".into(),
415            name: "test-experiment".into(),
416            version: "1.0.0".into(),
417            description: None,
418            seed: None,
419            data: Some(DataConfig {
420                source: Some("./data/train.parquet".into()),
421                format: None,
422                split: None,
423                train: None,
424                val: None,
425                test: None,
426                preprocessing: None,
427                augmentation: None,
428                loader: None,
429                tokenizer: None,
430                seq_len: None,
431                input_column: None,
432                output_column: None,
433                max_length: None,
434            }),
435            model: Some(ModelConfig {
436                source: "./models/base.safetensors".into(),
437                format: None,
438                architecture: None,
439                freeze: None,
440                device: None,
441                dtype: None,
442            }),
443            optimizer: Some(OptimizerConfig {
444                name: "adam".into(),
445                lr: 1e-4,
446                weight_decay: None,
447                betas: None,
448                eps: None,
449                amsgrad: None,
450                momentum: None,
451                nesterov: None,
452                dampening: None,
453                alpha: None,
454                centered: None,
455                param_groups: None,
456            }),
457            scheduler: None,
458            training: None,
459            lora: None,
460            quantize: None,
461            monitoring: None,
462            callbacks: None,
463            output: None,
464            publish: None,
465            citl: None,
466            rag: None,
467            graph: None,
468            distillation: None,
469            inspect: None,
470            privacy: None,
471            audit: None,
472            session: None,
473            stress: None,
474            benchmark: None,
475            debug: None,
476            signing: None,
477            verification: None,
478            lockfile: None,
479            strict: None,
480            strict_validation: None,
481            require_peer_review: None,
482        }
483    }
484
485    #[test]
486    fn test_minimal_manifest_converts() {
487        let manifest = minimal_manifest();
488        let result = manifest_to_spec(&manifest).expect("operation should succeed");
489        assert_eq!(result.spec.model.path, PathBuf::from("./models/base.safetensors"));
490        assert_eq!(result.spec.model.mode, ModelMode::Tabular);
491        assert_eq!(result.spec.data.train, PathBuf::from("./data/train.parquet"));
492        assert_eq!(result.spec.optimizer.name, "adam");
493        assert!((result.spec.optimizer.lr - 1e-4).abs() < 1e-6);
494        assert_eq!(result.spec.training.epochs, 10);
495        assert!(result.warnings.is_empty());
496    }
497
498    #[test]
499    fn test_missing_model_errors() {
500        let mut manifest = minimal_manifest();
501        manifest.model = None;
502        let err = manifest_to_spec(&manifest).unwrap_err();
503        assert!(matches!(err, BridgeError::MissingRequired(_)));
504    }
505
506    #[test]
507    fn test_missing_data_errors() {
508        let mut manifest = minimal_manifest();
509        manifest.data = None;
510        let err = manifest_to_spec(&manifest).unwrap_err();
511        assert!(matches!(err, BridgeError::MissingRequired(_)));
512    }
513
514    #[test]
515    fn test_missing_optimizer_errors() {
516        let mut manifest = minimal_manifest();
517        manifest.optimizer = None;
518        let err = manifest_to_spec(&manifest).unwrap_err();
519        assert!(matches!(err, BridgeError::MissingRequired(_)));
520    }
521
522    #[test]
523    fn test_missing_data_source_errors() {
524        let mut manifest = minimal_manifest();
525        manifest.data = Some(DataConfig {
526            source: None,
527            format: None,
528            split: None,
529            train: None,
530            val: None,
531            test: None,
532            preprocessing: None,
533            augmentation: None,
534            loader: None,
535            tokenizer: None,
536            seq_len: None,
537            input_column: None,
538            output_column: None,
539            max_length: None,
540        });
541        let err = manifest_to_spec(&manifest).unwrap_err();
542        assert!(matches!(err, BridgeError::MissingRequired(_)));
543    }
544
545    #[test]
546    fn test_explicit_train_path_preferred_over_source() {
547        let mut manifest = minimal_manifest();
548        manifest.data.as_mut().expect("operation should succeed").source =
549            Some("./source.parquet".into());
550        manifest.data.as_mut().expect("operation should succeed").train =
551            Some("./explicit_train.parquet".into());
552        let result = manifest_to_spec(&manifest).expect("operation should succeed");
553        assert_eq!(result.spec.data.train, PathBuf::from("./explicit_train.parquet"));
554    }
555
556    #[test]
557    fn test_val_path_converted() {
558        let mut manifest = minimal_manifest();
559        manifest.data.as_mut().expect("operation should succeed").val =
560            Some("./val.parquet".into());
561        let result = manifest_to_spec(&manifest).expect("operation should succeed");
562        assert_eq!(result.spec.data.val, Some(PathBuf::from("./val.parquet")));
563    }
564
565    #[test]
566    fn test_batch_size_from_loader() {
567        let mut manifest = minimal_manifest();
568        manifest.data.as_mut().expect("load should succeed").loader = Some(DataLoader {
569            batch_size: 32,
570            shuffle: true,
571            num_workers: None,
572            pin_memory: None,
573            drop_last: None,
574            prefetch_factor: None,
575        });
576        let result = manifest_to_spec(&manifest).expect("operation should succeed");
577        assert_eq!(result.spec.data.batch_size, 32);
578    }
579
580    #[test]
581    fn test_batch_size_default_without_loader() {
582        let manifest = minimal_manifest();
583        let result = manifest_to_spec(&manifest).expect("operation should succeed");
584        assert_eq!(result.spec.data.batch_size, 8);
585    }
586
587    #[test]
588    fn test_transformer_mode_from_architecture() {
589        let mut manifest = minimal_manifest();
590        manifest.model.as_mut().expect("config should be valid").architecture =
591            Some(ArchitectureConfig {
592                arch_type: "transformer".into(),
593                hidden_size: None,
594                num_layers: None,
595                num_heads: None,
596                num_kv_heads: None,
597                intermediate_size: None,
598                vocab_size: None,
599                max_seq_length: None,
600                rms_norm_eps: None,
601                rope_theta: None,
602                use_bias: None,
603                head_dim: None,
604                layers: None,
605            });
606        let result = manifest_to_spec(&manifest).expect("operation should succeed");
607        assert_eq!(result.spec.model.mode, ModelMode::Transformer);
608    }
609
610    #[test]
611    fn test_optimizer_params_converted() {
612        let mut manifest = minimal_manifest();
613        let opt = manifest.optimizer.as_mut().expect("operation should succeed");
614        opt.name = "adamw".into();
615        opt.lr = 3e-4;
616        opt.betas = Some(vec![0.9, 0.999]);
617        opt.eps = Some(1e-8);
618        opt.weight_decay = Some(0.01);
619
620        let result = manifest_to_spec(&manifest).expect("operation should succeed");
621        assert_eq!(result.spec.optimizer.name, "adamw");
622        assert!((result.spec.optimizer.lr - 3e-4).abs() < 1e-6);
623        assert_eq!(result.spec.optimizer.params["beta1"], serde_json::json!(0.9));
624        assert_eq!(result.spec.optimizer.params["beta2"], serde_json::json!(0.999));
625        assert!(result.spec.optimizer.params.contains_key("eps"));
626        assert!(result.spec.optimizer.params.contains_key("weight_decay"));
627    }
628
629    #[test]
630    fn test_training_config_converted() {
631        let mut manifest = minimal_manifest();
632        manifest.training = Some(TrainingConfig {
633            epochs: Some(5),
634            max_steps: None,
635            duration: None,
636            gradient: Some(GradientConfig {
637                accumulation_steps: Some(4),
638                clip_norm: Some(1.0),
639                clip_value: None,
640            }),
641            mixed_precision: Some(MixedPrecisionConfig {
642                enabled: true,
643                dtype: Some("bf16".into()),
644                loss_scale: None,
645            }),
646            distributed: None,
647            checkpoint: Some(CheckpointConfig {
648                save_every: Some(2),
649                keep_last: None,
650                save_best: None,
651                metric: None,
652                mode: None,
653            }),
654            early_stopping: None,
655            validation: None,
656            deterministic: None,
657            benchmark: None,
658            curriculum: None,
659        });
660
661        let result = manifest_to_spec(&manifest).expect("operation should succeed");
662        assert_eq!(result.spec.training.epochs, 5);
663        assert_eq!(result.spec.training.grad_clip, Some(1.0));
664        assert_eq!(result.spec.training.gradient_accumulation, Some(4));
665        assert_eq!(result.spec.training.mixed_precision, Some("bf16".into()));
666        assert_eq!(result.spec.training.save_interval, 2);
667    }
668
669    #[test]
670    fn test_scheduler_converted() {
671        let mut manifest = minimal_manifest();
672        manifest.scheduler = Some(SchedulerConfig {
673            name: "cosine".into(),
674            warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: None }),
675            t_max: None,
676            eta_min: None,
677            step_size: None,
678            gamma: None,
679            mode: None,
680            factor: None,
681            patience: None,
682            threshold: None,
683            max_lr: None,
684            pct_start: None,
685            anneal_strategy: None,
686            div_factor: None,
687            final_div_factor: None,
688        });
689
690        let result = manifest_to_spec(&manifest).expect("operation should succeed");
691        assert_eq!(result.spec.training.lr_scheduler, Some("cosine".into()));
692        assert_eq!(result.spec.training.warmup_steps, 100);
693    }
694
695    #[test]
696    fn test_output_dir_converted() {
697        let mut manifest = minimal_manifest();
698        manifest.output = Some(OutputConfig {
699            dir: "./outputs/my-model".into(),
700            model: None,
701            metrics: None,
702            report: None,
703            registry: None,
704        });
705
706        let result = manifest_to_spec(&manifest).expect("operation should succeed");
707        assert_eq!(result.spec.training.output_dir, PathBuf::from("./outputs/my-model"));
708    }
709
710    #[test]
711    fn test_lora_converted() {
712        let mut manifest = minimal_manifest();
713        manifest.lora = Some(LoraConfig {
714            enabled: true,
715            rank: 64,
716            alpha: 16.0,
717            dropout: Some(0.1),
718            target_modules: vec!["q_proj".into(), "v_proj".into()],
719            target_modules_pattern: None,
720            bias: None,
721            init_weights: None,
722            quantize_base: None,
723            quantize_bits: None,
724            double_quantize: None,
725            quant_type: None,
726        });
727
728        let result = manifest_to_spec(&manifest).expect("operation should succeed");
729        let lora = result.spec.lora.expect("operation should succeed");
730        assert_eq!(lora.rank, 64);
731        assert!((lora.alpha - 16.0).abs() < 1e-6);
732        assert!((lora.dropout - 0.1).abs() < 1e-6);
733        assert_eq!(lora.target_modules, vec!["q_proj", "v_proj"]);
734        // Also check that target_modules get mapped to model layers
735        assert_eq!(result.spec.model.layers, vec!["q_proj", "v_proj"]);
736    }
737
738    #[test]
739    fn test_lora_disabled_not_converted() {
740        let mut manifest = minimal_manifest();
741        manifest.lora = Some(LoraConfig {
742            enabled: false,
743            rank: 64,
744            alpha: 16.0,
745            dropout: None,
746            target_modules: vec!["q_proj".into()],
747            target_modules_pattern: None,
748            bias: None,
749            init_weights: None,
750            quantize_base: None,
751            quantize_bits: None,
752            double_quantize: None,
753            quant_type: None,
754        });
755
756        let result = manifest_to_spec(&manifest).expect("operation should succeed");
757        assert!(result.spec.lora.is_none());
758    }
759
760    #[test]
761    fn test_quantize_converted() {
762        let mut manifest = minimal_manifest();
763        manifest.quantize = Some(QuantizeConfig {
764            enabled: true,
765            bits: 4,
766            scheme: Some("symmetric".into()),
767            granularity: Some("per_channel".into()),
768            group_size: None,
769            qat: None,
770            calibration: None,
771            exclude: None,
772        });
773
774        let result = manifest_to_spec(&manifest).expect("operation should succeed");
775        let quant = result.spec.quantize.expect("operation should succeed");
776        assert_eq!(quant.bits, 4);
777        assert!(quant.symmetric);
778        assert!(quant.per_channel);
779    }
780
781    #[test]
782    fn test_quantize_disabled_not_converted() {
783        let mut manifest = minimal_manifest();
784        manifest.quantize = Some(QuantizeConfig {
785            enabled: false,
786            bits: 4,
787            scheme: None,
788            granularity: None,
789            group_size: None,
790            qat: None,
791            calibration: None,
792            exclude: None,
793        });
794
795        let result = manifest_to_spec(&manifest).expect("operation should succeed");
796        assert!(result.spec.quantize.is_none());
797    }
798
799    #[test]
800    fn test_quantize_asymmetric() {
801        let mut manifest = minimal_manifest();
802        manifest.quantize = Some(QuantizeConfig {
803            enabled: true,
804            bits: 8,
805            scheme: Some("asymmetric".into()),
806            granularity: Some("per_tensor".into()),
807            group_size: None,
808            qat: None,
809            calibration: None,
810            exclude: None,
811        });
812
813        let result = manifest_to_spec(&manifest).expect("operation should succeed");
814        let quant = result.spec.quantize.expect("operation should succeed");
815        assert!(!quant.symmetric);
816        assert!(!quant.per_channel);
817    }
818
819    #[test]
820    fn test_unsupported_fields_produce_warnings() {
821        let mut manifest = minimal_manifest();
822        manifest.monitoring = Some(crate::yaml_mode::MonitoringConfig {
823            terminal: None,
824            tracking: None,
825            system: None,
826            alerts: None,
827            drift_detection: None,
828        });
829
830        let result = manifest_to_spec(&manifest).expect("operation should succeed");
831        assert!(!result.warnings.is_empty());
832        assert!(result.warnings.iter().any(|w| w.contains("monitoring")));
833    }
834
835    #[test]
836    fn test_training_defaults_without_config() {
837        let manifest = minimal_manifest();
838        let result = manifest_to_spec(&manifest).expect("operation should succeed");
839        assert_eq!(result.spec.training.epochs, 10);
840        assert!(result.spec.training.grad_clip.is_none());
841        assert!(result.spec.training.lr_scheduler.is_none());
842        assert_eq!(result.spec.training.warmup_steps, 0);
843        assert_eq!(result.spec.training.save_interval, 1);
844        assert_eq!(result.spec.training.output_dir, PathBuf::from("./checkpoints"));
845    }
846
847    #[test]
848    fn test_bridge_error_display() {
849        let e = BridgeError::MissingRequired("model".into());
850        assert!(e.to_string().contains("model"));
851
852        let e = BridgeError::InvalidValue { field: "lr".into(), reason: "must be positive".into() };
853        assert!(e.to_string().contains("lr"));
854        assert!(e.to_string().contains("must be positive"));
855    }
856
857    #[test]
858    fn test_full_manifest_roundtrip_from_yaml() {
859        let yaml = r#"
860entrenar: "1.0"
861name: "full-test"
862version: "1.0.0"
863seed: 42
864
865model:
866  source: "./models/llama.safetensors"
867  architecture:
868    type: transformer
869    hidden_size: 768
870
871data:
872  source: "./data/train.jsonl"
873  val: "./data/val.jsonl"
874  tokenizer: "./tokenizer.json"
875  seq_len: 2048
876  input_column: text
877  output_column: target
878  max_length: 512
879  loader:
880    batch_size: 16
881    shuffle: true
882
883optimizer:
884  name: adamw
885  lr: 0.0003
886  betas: [0.9, 0.95]
887  weight_decay: 0.1
888
889scheduler:
890  name: cosine
891  T_max: 1000
892  eta_min: 0.000001
893  warmup:
894    steps: 200
895
896training:
897  epochs: 3
898  gradient:
899    clip_norm: 1.0
900    accumulation_steps: 8
901  mixed_precision:
902    enabled: true
903    dtype: bf16
904  checkpoint:
905    save_every: 1
906
907lora:
908  enabled: true
909  rank: 32
910  alpha: 64.0
911  dropout: 0.05
912  target_modules: [q_proj, k_proj, v_proj, o_proj]
913
914quantize:
915  enabled: true
916  bits: 4
917  scheme: symmetric
918  granularity: per_channel
919
920output:
921  dir: "./outputs/full-test"
922"#;
923
924        let manifest: TrainingManifest =
925            serde_yaml::from_str(yaml).expect("operation should succeed");
926        let result = manifest_to_spec(&manifest).expect("operation should succeed");
927
928        assert_eq!(result.spec.model.mode, ModelMode::Transformer);
929        assert_eq!(result.spec.model.path, PathBuf::from("./models/llama.safetensors"));
930        assert_eq!(result.spec.model.layers, vec!["q_proj", "k_proj", "v_proj", "o_proj"]);
931        assert_eq!(result.spec.data.train, PathBuf::from("./data/train.jsonl"));
932        assert_eq!(result.spec.data.val, Some(PathBuf::from("./data/val.jsonl")));
933        assert_eq!(result.spec.data.batch_size, 16);
934        // LLM data fields
935        assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./tokenizer.json")));
936        assert_eq!(result.spec.data.seq_len, Some(2048));
937        assert_eq!(result.spec.data.input_column, Some("text".into()));
938        assert_eq!(result.spec.data.output_column, Some("target".into()));
939        assert_eq!(result.spec.data.max_length, Some(512));
940        // Optimizer
941        assert_eq!(result.spec.optimizer.name, "adamw");
942        assert!((result.spec.optimizer.lr - 0.0003).abs() < 1e-6);
943        // Training
944        assert_eq!(result.spec.training.epochs, 3);
945        assert_eq!(result.spec.training.grad_clip, Some(1.0));
946        assert_eq!(result.spec.training.gradient_accumulation, Some(8));
947        assert_eq!(result.spec.training.mixed_precision, Some("bf16".into()));
948        assert_eq!(result.spec.training.lr_scheduler, Some("cosine".into()));
949        assert_eq!(result.spec.training.warmup_steps, 200);
950        assert_eq!(result.spec.training.save_interval, 1);
951        assert_eq!(result.spec.training.output_dir, PathBuf::from("./outputs/full-test"));
952        // Training mode: transformer → CausalLm
953        assert_eq!(result.spec.training.mode, TrainingMode::CausalLm);
954        // Seed
955        assert_eq!(result.spec.training.seed, Some(42));
956        // Scheduler params
957        let sched_params =
958            result.spec.training.scheduler_params.as_ref().expect("operation should succeed");
959        assert_eq!(sched_params["t_max"], serde_json::json!(1000));
960        assert_eq!(sched_params["eta_min"], serde_json::json!(0.000001));
961
962        let lora = result.spec.lora.expect("operation should succeed");
963        assert_eq!(lora.rank, 32);
964
965        let quant = result.spec.quantize.expect("operation should succeed");
966        assert_eq!(quant.bits, 4);
967        assert!(quant.symmetric);
968        assert!(quant.per_channel);
969    }
970
971    // === Phase 2 Tests ===
972
973    #[test]
974    fn test_transformer_gets_causal_lm_mode() {
975        let mut manifest = minimal_manifest();
976        manifest.model.as_mut().expect("config should be valid").architecture =
977            Some(ArchitectureConfig {
978                arch_type: "transformer".into(),
979                hidden_size: None,
980                num_layers: None,
981                num_heads: None,
982                num_kv_heads: None,
983                intermediate_size: None,
984                vocab_size: None,
985                max_seq_length: None,
986                rms_norm_eps: None,
987                rope_theta: None,
988                use_bias: None,
989                head_dim: None,
990                layers: None,
991            });
992        let result = manifest_to_spec(&manifest).expect("operation should succeed");
993        assert_eq!(result.spec.training.mode, TrainingMode::CausalLm);
994    }
995
996    #[test]
997    fn test_tabular_gets_regression_mode() {
998        let manifest = minimal_manifest();
999        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1000        assert_eq!(result.spec.model.mode, ModelMode::Tabular);
1001        assert_eq!(result.spec.training.mode, TrainingMode::Regression);
1002    }
1003
1004    #[test]
1005    fn test_data_llm_fields_converted() {
1006        let mut manifest = minimal_manifest();
1007        let data = manifest.data.as_mut().expect("operation should succeed");
1008        data.tokenizer = Some("./tokenizer.json".into());
1009        data.seq_len = Some(2048);
1010        data.input_column = Some("text".into());
1011        data.output_column = Some("label".into());
1012        data.max_length = Some(512);
1013
1014        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1015        assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./tokenizer.json")));
1016        assert_eq!(result.spec.data.seq_len, Some(2048));
1017        assert_eq!(result.spec.data.input_column, Some("text".into()));
1018        assert_eq!(result.spec.data.output_column, Some("label".into()));
1019        assert_eq!(result.spec.data.max_length, Some(512));
1020    }
1021
1022    #[test]
1023    fn test_data_tokenizer_from_preprocessing_fallback() {
1024        use crate::yaml_mode::manifest::data::{PreprocessingStep, TokenizeConfig};
1025
1026        let mut manifest = minimal_manifest();
1027        let data = manifest.data.as_mut().expect("operation should succeed");
1028        // No top-level tokenizer, but set preprocessing
1029        data.preprocessing = Some(vec![PreprocessingStep::Tokenize {
1030            tokenize: TokenizeConfig {
1031                tokenizer: "./fallback-tokenizer.json".into(),
1032                max_length: Some(256),
1033                padding: None,
1034                truncation: None,
1035            },
1036        }]);
1037
1038        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1039        assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./fallback-tokenizer.json")));
1040        assert_eq!(result.spec.data.max_length, Some(256));
1041    }
1042
1043    #[test]
1044    fn test_data_toplevel_tokenizer_takes_precedence() {
1045        use crate::yaml_mode::manifest::data::{PreprocessingStep, TokenizeConfig};
1046
1047        let mut manifest = minimal_manifest();
1048        let data = manifest.data.as_mut().expect("operation should succeed");
1049        data.tokenizer = Some("./primary.json".into());
1050        data.max_length = Some(1024);
1051        data.preprocessing = Some(vec![PreprocessingStep::Tokenize {
1052            tokenize: TokenizeConfig {
1053                tokenizer: "./fallback.json".into(),
1054                max_length: Some(256),
1055                padding: None,
1056                truncation: None,
1057            },
1058        }]);
1059
1060        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1061        // Top-level takes precedence over preprocessing fallback
1062        assert_eq!(result.spec.data.tokenizer, Some(PathBuf::from("./primary.json")));
1063        assert_eq!(result.spec.data.max_length, Some(1024));
1064    }
1065
1066    #[test]
1067    fn test_scheduler_params_cosine() {
1068        let mut manifest = minimal_manifest();
1069        manifest.scheduler = Some(SchedulerConfig {
1070            name: "cosine".into(),
1071            warmup: None,
1072            t_max: Some(500),
1073            eta_min: Some(1e-6),
1074            step_size: None,
1075            gamma: None,
1076            mode: None,
1077            factor: None,
1078            patience: None,
1079            threshold: None,
1080            max_lr: None,
1081            pct_start: None,
1082            anneal_strategy: None,
1083            div_factor: None,
1084            final_div_factor: None,
1085        });
1086
1087        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1088        let params = result.spec.training.scheduler_params.expect("operation should succeed");
1089        assert_eq!(params["t_max"], serde_json::json!(500));
1090        assert_eq!(params["eta_min"], serde_json::json!(1e-6));
1091        assert_eq!(params.len(), 2);
1092    }
1093
1094    #[test]
1095    fn test_scheduler_params_step() {
1096        let mut manifest = minimal_manifest();
1097        manifest.scheduler = Some(SchedulerConfig {
1098            name: "step".into(),
1099            warmup: None,
1100            t_max: None,
1101            eta_min: None,
1102            step_size: Some(30),
1103            gamma: Some(0.1),
1104            mode: None,
1105            factor: None,
1106            patience: None,
1107            threshold: None,
1108            max_lr: None,
1109            pct_start: None,
1110            anneal_strategy: None,
1111            div_factor: None,
1112            final_div_factor: None,
1113        });
1114
1115        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1116        let params = result.spec.training.scheduler_params.expect("operation should succeed");
1117        assert_eq!(params["step_size"], serde_json::json!(30));
1118        assert_eq!(params["gamma"], serde_json::json!(0.1));
1119        assert_eq!(params.len(), 2);
1120    }
1121
1122    #[test]
1123    fn test_scheduler_params_plateau() {
1124        let mut manifest = minimal_manifest();
1125        manifest.scheduler = Some(SchedulerConfig {
1126            name: "plateau".into(),
1127            warmup: None,
1128            t_max: None,
1129            eta_min: None,
1130            step_size: None,
1131            gamma: None,
1132            mode: Some("min".into()),
1133            factor: Some(0.1),
1134            patience: Some(10),
1135            threshold: Some(1e-4),
1136            max_lr: None,
1137            pct_start: None,
1138            anneal_strategy: None,
1139            div_factor: None,
1140            final_div_factor: None,
1141        });
1142
1143        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1144        let params = result.spec.training.scheduler_params.expect("operation should succeed");
1145        assert_eq!(params["mode"], serde_json::json!("min"));
1146        assert_eq!(params["factor"], serde_json::json!(0.1));
1147        assert_eq!(params["patience"], serde_json::json!(10));
1148        assert_eq!(params["threshold"], serde_json::json!(1e-4));
1149        assert_eq!(params.len(), 4);
1150    }
1151
1152    #[test]
1153    fn test_scheduler_params_one_cycle() {
1154        let mut manifest = minimal_manifest();
1155        manifest.scheduler = Some(SchedulerConfig {
1156            name: "one_cycle".into(),
1157            warmup: None,
1158            t_max: None,
1159            eta_min: None,
1160            step_size: None,
1161            gamma: None,
1162            mode: None,
1163            factor: None,
1164            patience: None,
1165            threshold: None,
1166            max_lr: Some(0.01),
1167            pct_start: Some(0.3),
1168            anneal_strategy: Some("cos".into()),
1169            div_factor: Some(25.0),
1170            final_div_factor: Some(1e4),
1171        });
1172
1173        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1174        let params = result.spec.training.scheduler_params.expect("operation should succeed");
1175        assert_eq!(params["max_lr"], serde_json::json!(0.01));
1176        assert_eq!(params["pct_start"], serde_json::json!(0.3));
1177        assert_eq!(params["anneal_strategy"], serde_json::json!("cos"));
1178        assert_eq!(params["div_factor"], serde_json::json!(25.0));
1179        assert_eq!(params["final_div_factor"], serde_json::json!(1e4));
1180        assert_eq!(params.len(), 5);
1181    }
1182
1183    #[test]
1184    fn test_scheduler_no_params_yields_none() {
1185        let mut manifest = minimal_manifest();
1186        manifest.scheduler = Some(SchedulerConfig {
1187            name: "cosine".into(),
1188            warmup: Some(WarmupConfig { steps: Some(100), ratio: None, start_lr: None }),
1189            t_max: None,
1190            eta_min: None,
1191            step_size: None,
1192            gamma: None,
1193            mode: None,
1194            factor: None,
1195            patience: None,
1196            threshold: None,
1197            max_lr: None,
1198            pct_start: None,
1199            anneal_strategy: None,
1200            div_factor: None,
1201            final_div_factor: None,
1202        });
1203
1204        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1205        // Only warmup set, no scheduler-specific params → None
1206        assert!(result.spec.training.scheduler_params.is_none());
1207    }
1208
1209    #[test]
1210    fn test_seed_passed_through() {
1211        let mut manifest = minimal_manifest();
1212        manifest.seed = Some(12345);
1213        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1214        assert_eq!(result.spec.training.seed, Some(12345));
1215    }
1216
1217    #[test]
1218    fn test_seed_none_when_not_set() {
1219        let manifest = minimal_manifest();
1220        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1221        assert!(result.spec.training.seed.is_none());
1222    }
1223
1224    #[test]
1225    fn test_architecture_overrides_converted() {
1226        let mut manifest = minimal_manifest();
1227        manifest.model.as_mut().expect("config should be valid").architecture =
1228            Some(ArchitectureConfig {
1229                arch_type: "transformer".into(),
1230                hidden_size: Some(1024),
1231                num_layers: Some(16),
1232                num_heads: Some(16),
1233                num_kv_heads: Some(4),
1234                intermediate_size: Some(4096),
1235                vocab_size: Some(50000),
1236                max_seq_length: Some(2048),
1237                rms_norm_eps: Some(1e-5),
1238                rope_theta: Some(500_000.0),
1239                use_bias: Some(true),
1240                head_dim: None,
1241                layers: None,
1242            });
1243        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1244        let arch = result.spec.model.architecture.expect("architecture overrides should be set");
1245        assert_eq!(arch.hidden_size, Some(1024));
1246        assert_eq!(arch.num_hidden_layers, Some(16));
1247        assert_eq!(arch.num_attention_heads, Some(16));
1248        assert_eq!(arch.num_kv_heads, Some(4));
1249        assert_eq!(arch.intermediate_size, Some(4096));
1250        assert_eq!(arch.vocab_size, Some(50000));
1251        assert_eq!(arch.max_position_embeddings, Some(2048));
1252        assert_eq!(arch.rms_norm_eps, Some(1e-5));
1253        assert_eq!(arch.rope_theta, Some(500_000.0));
1254        assert_eq!(arch.use_bias, Some(true));
1255    }
1256
1257    #[test]
1258    fn test_architecture_overrides_none_when_no_params() {
1259        let mut manifest = minimal_manifest();
1260        manifest.model.as_mut().expect("config should be valid").architecture =
1261            Some(ArchitectureConfig {
1262                arch_type: "transformer".into(),
1263                hidden_size: None,
1264                num_layers: None,
1265                num_heads: None,
1266                num_kv_heads: None,
1267                intermediate_size: None,
1268                vocab_size: None,
1269                max_seq_length: None,
1270                rms_norm_eps: None,
1271                rope_theta: None,
1272                use_bias: None,
1273                head_dim: None,
1274                layers: None,
1275            });
1276        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1277        // All None → overrides filtered out
1278        assert!(result.spec.model.architecture.is_none());
1279    }
1280
1281    #[test]
1282    fn test_architecture_overrides_partial() {
1283        let mut manifest = minimal_manifest();
1284        manifest.model.as_mut().expect("config should be valid").architecture =
1285            Some(ArchitectureConfig {
1286                arch_type: "transformer".into(),
1287                hidden_size: Some(768),
1288                num_layers: None,
1289                num_heads: None,
1290                num_kv_heads: None,
1291                intermediate_size: None,
1292                vocab_size: None,
1293                max_seq_length: None,
1294                rms_norm_eps: None,
1295                rope_theta: None,
1296                use_bias: None,
1297                head_dim: None,
1298                layers: None,
1299            });
1300        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1301        let arch = result.spec.model.architecture.expect("architecture overrides should be set");
1302        assert_eq!(arch.hidden_size, Some(768));
1303        assert!(arch.num_attention_heads.is_none());
1304    }
1305
1306    #[test]
1307    fn test_architecture_overrides_from_yaml() {
1308        let yaml = r#"
1309entrenar: "1.0"
1310name: "arch-test"
1311version: "1.0.0"
1312
1313model:
1314  source: "./models/custom.safetensors"
1315  architecture:
1316    type: transformer
1317    hidden_size: 1024
1318    num_layers: 16
1319    num_heads: 16
1320    num_kv_heads: 4
1321    intermediate_size: 4096
1322    vocab_size: 50000
1323    max_seq_length: 2048
1324    rms_norm_eps: 0.00001
1325    rope_theta: 500000.0
1326    use_bias: true
1327
1328data:
1329  source: "./data/train.parquet"
1330
1331optimizer:
1332  name: adamw
1333  lr: 0.0003
1334"#;
1335        let manifest: TrainingManifest = serde_yaml::from_str(yaml).expect("YAML should parse");
1336        let result = manifest_to_spec(&manifest).expect("bridge should succeed");
1337        assert_eq!(result.spec.model.mode, ModelMode::Transformer);
1338        let arch = result.spec.model.architecture.expect("overrides should be set");
1339        assert_eq!(arch.hidden_size, Some(1024));
1340        assert_eq!(arch.num_hidden_layers, Some(16));
1341        assert_eq!(arch.num_attention_heads, Some(16));
1342        assert_eq!(arch.num_kv_heads, Some(4));
1343        assert_eq!(arch.intermediate_size, Some(4096));
1344        assert_eq!(arch.vocab_size, Some(50000));
1345        assert_eq!(arch.max_position_embeddings, Some(2048));
1346    }
1347
1348    #[test]
1349    fn test_deterministic_passed_through() {
1350        let mut manifest = minimal_manifest();
1351        manifest.training = Some(TrainingConfig {
1352            epochs: Some(5),
1353            max_steps: None,
1354            duration: None,
1355            gradient: None,
1356            mixed_precision: None,
1357            distributed: None,
1358            checkpoint: None,
1359            early_stopping: None,
1360            validation: None,
1361            deterministic: Some(true),
1362            benchmark: None,
1363            curriculum: None,
1364        });
1365        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1366        assert!(result.spec.training.deterministic, "deterministic should be true");
1367    }
1368
1369    #[test]
1370    fn test_deterministic_defaults_false_when_not_set() {
1371        let manifest = minimal_manifest();
1372        let result = manifest_to_spec(&manifest).expect("operation should succeed");
1373        assert!(!result.spec.training.deterministic, "deterministic should default to false");
1374    }
1375}