Skip to main content

entrenar/config/
schema.rs

1//! YAML schema definitions for declarative training configuration
2//!
3//! ENT-114: Added `ModelMode` and `TrainingMode` for LLM training support.
4
5use serde::{Deserialize, Deserializer, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
9/// Deserialize a bool from either a YAML boolean (`true`) or a quoted string (`"true"`).
10/// This supports CB-950 compliance where all truthy values must be quoted in YAML.
11fn deserialize_bool_lenient<'de, D>(deserializer: D) -> Result<bool, D::Error>
12where
13    D: Deserializer<'de>,
14{
15    #[derive(Deserialize)]
16    #[serde(untagged)]
17    enum BoolOrString {
18        Bool(bool),
19        Str(String),
20    }
21
22    match BoolOrString::deserialize(deserializer)? {
23        BoolOrString::Bool(b) => Ok(b),
24        BoolOrString::Str(s) => match s.to_lowercase().as_str() {
25            "true" => Ok(true),
26            "false" => Ok(false),
27            other => {
28                Err(serde::de::Error::custom(format!("expected 'true' or 'false', got '{other}'")))
29            }
30        },
31    }
32}
33
34/// Model execution mode
35///
36/// Determines whether to use generic tabular ML training or transformer-based LLM training.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum ModelMode {
40    /// Generic tabular ML (uses Trainer + MSELoss)
41    #[default]
42    Tabular,
43    /// Transformer-based LLM (uses TransformerTrainer + CausalLMLoss)
44    Transformer,
45}
46
47/// Training loss mode
48///
49/// Determines which loss function to use during training.
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum TrainingMode {
53    /// Mean squared error loss (regression)
54    #[default]
55    Regression,
56    /// Cross-entropy loss for next-token prediction (language modeling)
57    CausalLm,
58}
59
60/// Complete training specification
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct TrainSpec {
63    /// Model configuration
64    pub model: ModelRef,
65
66    /// Data configuration
67    pub data: DataConfig,
68
69    /// Optimizer configuration
70    pub optimizer: OptimSpec,
71
72    /// Optional LoRA configuration
73    #[serde(default, skip_serializing_if = "Option::is_none")]
74    pub lora: Option<LoRASpec>,
75
76    /// Optional quantization configuration
77    #[serde(default, skip_serializing_if = "Option::is_none")]
78    pub quantize: Option<QuantSpec>,
79
80    /// Optional model merging configuration
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub merge: Option<MergeSpec>,
83
84    /// Training hyperparameters
85    #[serde(default)]
86    pub training: TrainingParams,
87
88    /// Optional auto-publish after training completes
89    #[serde(default, skip_serializing_if = "Option::is_none")]
90    pub publish: Option<PublishSpec>,
91}
92
93/// Auto-publish configuration for uploading to HuggingFace Hub after training.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct PublishSpec {
96    /// HuggingFace repo ID (e.g., "myuser/my-model")
97    pub repo: String,
98
99    /// Make the repository private
100    #[serde(default)]
101    pub private: bool,
102
103    /// Generate and upload a model card
104    #[serde(default = "default_true")]
105    pub model_card: bool,
106
107    /// Merge LoRA adapters before publishing
108    #[serde(default)]
109    pub merge_adapters: bool,
110
111    /// Export format (safetensors or gguf)
112    #[serde(default = "default_safetensors")]
113    pub format: String,
114}
115
/// Serde default for `PublishSpec::format`.
fn default_safetensors() -> String {
    String::from("safetensors")
}
119
120/// Architecture override parameters from YAML manifest.
121///
122/// These override individual fields of the `TransformerConfig` resolved from
123/// `config.json` or preset defaults. Only `Some` fields apply; `None` fields
124/// are left as-is from the base config.
125#[derive(Debug, Clone, Default, Serialize, Deserialize)]
126pub struct ArchitectureOverrides {
127    /// Hidden size (embedding dimension)
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub hidden_size: Option<usize>,
130    /// Number of transformer layers
131    #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_layers")]
132    pub num_hidden_layers: Option<usize>,
133    /// Number of attention heads
134    #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_heads")]
135    pub num_attention_heads: Option<usize>,
136    /// Number of key-value heads (for grouped-query attention)
137    #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_key_value_heads")]
138    pub num_kv_heads: Option<usize>,
139    /// FFN intermediate dimension
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub intermediate_size: Option<usize>,
142    /// Vocabulary size
143    #[serde(default, skip_serializing_if = "Option::is_none")]
144    pub vocab_size: Option<usize>,
145    /// Maximum sequence/position length
146    #[serde(default, skip_serializing_if = "Option::is_none", alias = "max_seq_length")]
147    pub max_position_embeddings: Option<usize>,
148    /// RMS normalization epsilon
149    #[serde(default, skip_serializing_if = "Option::is_none")]
150    pub rms_norm_eps: Option<f32>,
151    /// RoPE theta (rotary positional encoding base)
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub rope_theta: Option<f32>,
154    /// Whether to use bias in linear layers
155    #[serde(default, skip_serializing_if = "Option::is_none")]
156    pub use_bias: Option<bool>,
157    /// Per-head dimension override (for models where head_dim != hidden_size / num_heads)
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub head_dim: Option<usize>,
160}
161
162impl ArchitectureOverrides {
163    /// Returns true if no overrides are set.
164    pub fn is_empty(&self) -> bool {
165        self.hidden_size.is_none()
166            && self.num_hidden_layers.is_none()
167            && self.num_attention_heads.is_none()
168            && self.num_kv_heads.is_none()
169            && self.intermediate_size.is_none()
170            && self.vocab_size.is_none()
171            && self.max_position_embeddings.is_none()
172            && self.rms_norm_eps.is_none()
173            && self.rope_theta.is_none()
174            && self.use_bias.is_none()
175            && self.head_dim.is_none()
176    }
177}
178
179/// Model reference and target layers
180#[derive(Debug, Clone, Serialize, Deserialize, Default)]
181pub struct ModelRef {
182    /// Path to base model — local path (GGUF, safetensors, etc.) or HuggingFace repo ID
183    /// (e.g., "Qwen/Qwen2.5-Coder-0.5B"). HF repo IDs are auto-detected and downloaded.
184    #[serde(default)]
185    pub path: PathBuf,
186
187    /// Target layers for LoRA (if applicable)
188    #[serde(default)]
189    pub layers: Vec<String>,
190
191    /// Model execution mode (tabular or transformer)
192    /// ENT-114: Routes to TransformerTrainer when mode=transformer
193    #[serde(default)]
194    pub mode: ModelMode,
195
196    /// Transformer architecture config preset (e.g., "qwen2_1_5b", "llama2_7b")
197    /// Only used when mode=transformer
198    #[serde(default, skip_serializing_if = "Option::is_none")]
199    pub config: Option<String>,
200
201    /// Architecture parameter overrides from YAML manifest.
202    /// Applied on top of config.json / preset values.
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub architecture: Option<ArchitectureOverrides>,
205}
206
207impl ModelRef {
208    /// Check if the model path looks like a HuggingFace repo ID (e.g., "org/model-name").
209    ///
210    /// Detection: contains exactly one `/`, no file extension, and both parts are non-empty.
211    pub fn is_hf_repo_id(&self) -> bool {
212        let s = self.path.to_string_lossy();
213        is_hf_repo_id(&s)
214    }
215}
216
/// Check if a string looks like a HuggingFace repo ID.
///
/// Returns true if the string has the format "org/name" where:
/// - There is exactly one `/`
/// - Both parts are non-empty
/// - Both parts contain only alphanumerics, `-`, `_`, or `.` (the HuggingFace
///   repo-ID character set). This rejects filesystem-looking inputs such as
///   `~/model` or `C:/models`, which would otherwise be mistaken for repo IDs
///   and trigger a download.
/// - The name doesn't end with a known model file extension
/// - The string doesn't start with `.` or `/` (not a filesystem path)
///
/// This is a purely syntactic check. It touches neither the filesystem nor the
/// network, so a two-part relative directory such as `checkpoints/run1` still
/// matches.
pub fn is_hf_repo_id(s: &str) -> bool {
    // Extensions that mark the string as a local model/config file.
    const FILE_EXTENSIONS: [&str; 11] = [
        ".safetensors",
        ".gguf",
        ".bin",
        ".pt",
        ".pth",
        ".onnx",
        ".json",
        ".yaml",
        ".yml",
        ".toml",
        ".txt",
    ];

    // Must not start with `.` or `/` (those are filesystem paths)
    if s.starts_with('.') || s.starts_with('/') {
        return false;
    }

    // Exactly one `/`: split at the first and require no further separators.
    let (org, name) = match s.split_once('/') {
        Some((org, name)) if !name.contains('/') => (org, name),
        _ => return false,
    };

    // Both parts must be non-empty and use only repo-ID characters; this is what
    // rejects `~`, `:` (Windows drive letters), whitespace, backslashes, etc.
    let is_valid_part = |part: &str| {
        !part.is_empty()
            && part.chars().all(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | '.'))
    };
    if !is_valid_part(org) || !is_valid_part(name) {
        return false;
    }

    // Reject if name ends with a known model file extension
    let name_lower = name.to_lowercase();
    !FILE_EXTENSIONS.iter().any(|ext| name_lower.ends_with(ext))
}
259
260/// Data configuration
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct DataConfig {
263    /// Training data path
264    #[serde(default)]
265    pub train: PathBuf,
266
267    /// Optional validation data path
268    #[serde(default, skip_serializing_if = "Option::is_none")]
269    pub val: Option<PathBuf>,
270
271    /// Batch size
272    #[serde(default = "default_batch_size")]
273    pub batch_size: usize,
274
275    /// Auto-infer feature types from data
276    #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
277    pub auto_infer_types: bool,
278
279    /// Sequence length (for transformers)
280    #[serde(default, skip_serializing_if = "Option::is_none")]
281    pub seq_len: Option<usize>,
282
283    // === ENT-114: LLM training fields ===
284    /// Path to HuggingFace tokenizer.json (for transformer mode)
285    #[serde(default, skip_serializing_if = "Option::is_none")]
286    pub tokenizer: Option<PathBuf>,
287
288    /// Input text column name (for transformer mode)
289    #[serde(default, skip_serializing_if = "Option::is_none")]
290    pub input_column: Option<String>,
291
292    /// Output/target text column name (for transformer mode)
293    #[serde(default, skip_serializing_if = "Option::is_none")]
294    pub output_column: Option<String>,
295
296    /// Maximum sequence length for tokenization
297    #[serde(default, skip_serializing_if = "Option::is_none")]
298    pub max_length: Option<usize>,
299}
300
301impl Default for DataConfig {
302    fn default() -> Self {
303        Self {
304            train: PathBuf::new(),
305            val: None,
306            batch_size: 8,
307            auto_infer_types: true,
308            seq_len: None,
309            tokenizer: None,
310            input_column: None,
311            output_column: None,
312            max_length: None,
313        }
314    }
315}
316
/// Serde default for `DataConfig::batch_size`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 8;
    DEFAULT_BATCH_SIZE
}
320
321/// Optimizer specification
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct OptimSpec {
324    /// Optimizer name: "adam" | "adamw" | "sgd"
325    pub name: String,
326
327    /// Learning rate
328    pub lr: f32,
329
330    /// Optimizer-specific parameters (beta1, beta2, momentum, etc.)
331    #[serde(flatten)]
332    pub params: HashMap<String, serde_json::Value>,
333}
334
335/// LoRA configuration
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct LoRASpec {
338    /// Rank of low-rank decomposition
339    pub rank: usize,
340
341    /// Scaling factor (alpha)
342    pub alpha: f32,
343
344    /// Target modules (e.g., [q_proj, v_proj])
345    pub target_modules: Vec<String>,
346
347    /// Dropout probability
348    #[serde(default)]
349    pub dropout: f32,
350
351    /// LoRA+ ratio: LR multiplier for B matrices (ENT-LoRA-006)
352    /// Default 1.0 = standard LoRA. 16.0 = LoRA+ (Hayou et al. ICML 2024)
353    #[serde(default = "default_lora_plus_ratio")]
354    pub lora_plus_ratio: f32,
355
356    /// Double quantization for QLoRA (ENT-LoRA-008)
357    /// Quantizes FP32 absmax constants to 8-bit, saving ~0.37 bits/param
358    /// Default true when quantize_base is true
359    #[serde(default)]
360    pub double_quantize: bool,
361
362    /// Quantize frozen base weights to NF4 (4-bit) for QLoRA (ENT-263)
363    ///
364    /// When true, base model weights are quantized to NF4 and frozen.
365    /// Only LoRA adapters + norm weights are trainable in fp32.
366    /// Achieves ~8x VRAM compression on base weights.
367    #[serde(default)]
368    pub quantize_base: bool,
369}
370
/// Serde default for `LoRASpec::lora_plus_ratio`: a ratio of 1.0 means standard LoRA.
fn default_lora_plus_ratio() -> f32 {
    const STANDARD_LORA_RATIO: f32 = 1.0;
    STANDARD_LORA_RATIO
}
374
375/// Quantization configuration
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct QuantSpec {
378    /// Quantization bits (4 or 8)
379    pub bits: u8,
380
381    /// Symmetric quantization
382    #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
383    pub symmetric: bool,
384
385    /// Per-channel quantization
386    #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
387    pub per_channel: bool,
388}
389
390/// Model merging configuration
391#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct MergeSpec {
393    /// Merge method: "ties" | "dare" | "slerp"
394    pub method: String,
395
396    /// Method-specific parameters
397    #[serde(flatten)]
398    pub params: HashMap<String, serde_json::Value>,
399}
400
401/// Training hyperparameters
402#[derive(Debug, Clone, Serialize, Deserialize)]
403#[serde(default)]
404pub struct TrainingParams {
405    /// Number of epochs
406    pub epochs: usize,
407
408    /// Gradient clipping threshold
409    #[serde(skip_serializing_if = "Option::is_none")]
410    pub grad_clip: Option<f32>,
411
412    /// Learning rate scheduler
413    #[serde(skip_serializing_if = "Option::is_none")]
414    pub lr_scheduler: Option<String>,
415
416    /// Warmup steps
417    pub warmup_steps: usize,
418
419    /// Save checkpoint every N epochs
420    pub save_interval: usize,
421
422    /// Output directory for checkpoints
423    pub output_dir: PathBuf,
424
425    // === ENT-114: LLM training fields ===
426    /// Training mode (regression or causal_lm)
427    /// ENT-114: Uses CausalLMLoss when mode=causal_lm
428    pub mode: TrainingMode,
429
430    /// Gradient accumulation steps (for large batch simulation)
431    #[serde(skip_serializing_if = "Option::is_none")]
432    pub gradient_accumulation: Option<usize>,
433
434    /// Number of gradient checkpoints (for memory optimization)
435    #[serde(skip_serializing_if = "Option::is_none")]
436    pub checkpoints: Option<usize>,
437
438    /// Use mixed precision training (bf16 or fp16)
439    #[serde(skip_serializing_if = "Option::is_none")]
440    pub mixed_precision: Option<String>,
441
442    /// Scheduler-specific parameters (t_max, gamma, step_size, etc.)
443    #[serde(skip_serializing_if = "Option::is_none")]
444    pub scheduler_params: Option<HashMap<String, serde_json::Value>>,
445
446    /// Maximum training steps (overrides epochs if set)
447    #[serde(skip_serializing_if = "Option::is_none")]
448    pub max_steps: Option<usize>,
449
450    /// Global random seed for reproducibility
451    #[serde(skip_serializing_if = "Option::is_none")]
452    pub seed: Option<u64>,
453
454    // === R-009: Multiple checkpoint retention ===
455    /// Maximum number of checkpoints to keep (default 5, 0 = unlimited)
456    #[serde(default = "default_max_checkpoints")]
457    pub max_checkpoints: usize,
458
459    // === R-015: Data shuffling ===
460    /// Shuffle training batches per epoch (default true)
461    #[serde(default = "default_true")]
462    pub shuffle: bool,
463
464    // === R-023: Curriculum learning ===
465    /// Multi-stage data mixing: switch data sources at step boundaries.
466    /// Each stage specifies a data path and the step at which to transition.
467    #[serde(default, skip_serializing_if = "Option::is_none")]
468    pub curriculum: Option<Vec<CurriculumStage>>,
469
470    // === KAIZEN-047: Step profiler ===
471    /// Print per-phase timing breakdown every N steps (0 = disabled)
472    #[serde(default)]
473    pub profile_interval: usize,
474
475    // === R-084: Bitwise deterministic training ===
476    /// Enable bitwise deterministic training mode (C-DETERM-001).
477    /// Sets CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic mode, disables
478    /// cuDNN benchmark. May reduce throughput but guarantees reproducibility.
479    #[serde(default)]
480    pub deterministic: bool,
481
482    // === ALB-087: Auto eval scheduling + best-model tracking ===
483    /// Steps between validation evaluations (defaults to save_interval if 0).
484    /// Decouples eval frequency from checkpoint save frequency.
485    #[serde(default)]
486    pub eval_interval: usize,
487
488    /// Number of eval intervals without improvement before early stop (0 = disabled).
489    #[serde(default)]
490    pub patience: usize,
491
492    // === Distributed training (tickets #131-#140) ===
493    /// Distributed data-parallel training configuration.
494    /// When present, enables multi-GPU / multi-node DDP pretraining.
495    #[serde(default, skip_serializing_if = "Option::is_none")]
496    pub distributed: Option<DistributedSpec>,
497}
498
499/// A curriculum learning stage: a data source active until a given step.
500#[derive(Debug, Clone, Serialize, Deserialize)]
501pub struct CurriculumStage {
502    /// Path to training data for this stage
503    pub data: PathBuf,
504    /// Step number at which to transition to next stage (None = until end)
505    #[serde(default, skip_serializing_if = "Option::is_none")]
506    pub until_step: Option<usize>,
507}
508
509/// Distributed training configuration from YAML.
510///
511/// Specifies how to coordinate multiple workers for data-parallel pretraining.
512/// When present in the `training` section, enables distributed mode.
513///
514/// # Example YAML
515///
516/// ```yaml
517/// training:
518///   distributed:
519///     world_size: 2
520///     backend: "auto"
521///     role: "coordinator"
522///     coordinator_addr: "0.0.0.0:9000"
523/// ```
524#[derive(Debug, Clone, Serialize, Deserialize)]
525pub struct DistributedSpec {
526    /// Total number of workers
527    pub world_size: usize,
528
529    /// Compute backend: "cuda", "wgpu", or "auto" (default: "auto")
530    #[serde(default = "default_backend")]
531    pub backend: String,
532
533    /// Node role: "coordinator" or "worker" (default: "coordinator")
534    #[serde(default = "default_role")]
535    pub role: String,
536
537    /// Coordinator address (bind for coordinator, connect for worker)
538    #[serde(default = "default_coordinator_addr")]
539    pub coordinator_addr: String,
540
541    /// This worker's global rank (default: 0)
542    #[serde(default)]
543    pub rank: usize,
544
545    /// This worker's local rank on its machine (default: 0)
546    #[serde(default)]
547    pub local_rank: usize,
548}
549
/// Serde default for `DistributedSpec::backend`.
fn default_backend() -> String {
    String::from("auto")
}

/// Serde default for `DistributedSpec::role`.
fn default_role() -> String {
    String::from("coordinator")
}

/// Serde default for `DistributedSpec::coordinator_addr`.
fn default_coordinator_addr() -> String {
    String::from("0.0.0.0:9000")
}
561
562impl Default for TrainingParams {
563    fn default() -> Self {
564        Self {
565            epochs: 10,
566            grad_clip: None,
567            lr_scheduler: None,
568            warmup_steps: 0,
569            save_interval: 1,
570            output_dir: PathBuf::from("./checkpoints"),
571            mode: TrainingMode::default(),
572            gradient_accumulation: None,
573            checkpoints: None,
574            mixed_precision: None,
575            scheduler_params: None,
576            max_steps: None,
577            seed: None,
578            max_checkpoints: 5,
579            shuffle: true,
580            curriculum: None,
581            profile_interval: 0,
582            deterministic: false,
583            eval_interval: 0,
584            patience: 0,
585            distributed: None,
586        }
587    }
588}
589
/// Serde default for `TrainingParams::max_checkpoints` (0 would mean unlimited).
fn default_max_checkpoints() -> usize {
    const KEEP_LAST: usize = 5;
    KEEP_LAST
}

/// Shared serde default for boolean fields that should start enabled.
fn default_true() -> bool {
    !false
}
597
598#[cfg(test)]
599mod tests {
600    use super::*;
601
602    #[test]
603    fn test_deserialize_minimal_config() {
604        let yaml = r"
605model:
606  path: model.gguf
607  layers: []
608
609data:
610  train: train.parquet
611  batch_size: 8
612
613optimizer:
614  name: adam
615  lr: 0.001
616";
617
618        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
619        assert_eq!(spec.model.path, PathBuf::from("model.gguf"));
620        assert_eq!(spec.data.batch_size, 8);
621        assert_eq!(spec.optimizer.name, "adam");
622        assert_eq!(spec.optimizer.lr, 0.001);
623    }
624
625    #[test]
626    fn test_deserialize_full_config() {
627        let yaml = r"
628model:
629  path: llama-7b.gguf
630  layers: [q_proj, k_proj, v_proj, o_proj]
631
632data:
633  train: train.parquet
634  val: val.parquet
635  batch_size: 32
636  auto_infer_types: true
637  seq_len: 2048
638
639optimizer:
640  name: adamw
641  lr: 0.0001
642  beta1: 0.9
643  beta2: 0.999
644  weight_decay: 0.01
645
646lora:
647  rank: 64
648  alpha: 16
649  target_modules: [q_proj, v_proj]
650  dropout: 0.1
651
652quantize:
653  bits: 4
654  symmetric: true
655  per_channel: true
656
657training:
658  epochs: 3
659  grad_clip: 1.0
660  lr_scheduler: cosine
661  warmup_steps: 100
662  save_interval: 1
663  output_dir: ./outputs
664";
665
666        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
667        assert_eq!(spec.model.layers.len(), 4);
668        assert!(spec.lora.is_some());
669        assert_eq!(spec.lora.as_ref().expect("operation should succeed").rank, 64);
670        assert!(spec.quantize.is_some());
671        assert_eq!(spec.quantize.as_ref().expect("operation should succeed").bits, 4);
672        assert_eq!(spec.training.epochs, 3);
673    }
674
675    #[test]
676    fn test_default_training_params() {
677        let params = TrainingParams::default();
678        assert_eq!(params.epochs, 10);
679        assert_eq!(params.save_interval, 1);
680        assert!(params.grad_clip.is_none());
681    }
682
683    // === ENT-114 Tests: LLM Training Schema ===
684
685    #[test]
686    fn test_model_mode_default_is_tabular() {
687        let mode = ModelMode::default();
688        assert_eq!(mode, ModelMode::Tabular);
689    }
690
691    #[test]
692    fn test_training_mode_default_is_regression() {
693        let mode = TrainingMode::default();
694        assert_eq!(mode, TrainingMode::Regression);
695    }
696
697    #[test]
698    fn test_model_mode_serde_roundtrip() {
699        // Tabular mode
700        let yaml = "tabular";
701        let mode: ModelMode = serde_yaml::from_str(yaml).expect("operation should succeed");
702        assert_eq!(mode, ModelMode::Tabular);
703
704        // Transformer mode
705        let yaml = "transformer";
706        let mode: ModelMode = serde_yaml::from_str(yaml).expect("operation should succeed");
707        assert_eq!(mode, ModelMode::Transformer);
708    }
709
710    #[test]
711    fn test_training_mode_serde_roundtrip() {
712        // Regression mode
713        let yaml = "regression";
714        let mode: TrainingMode = serde_yaml::from_str(yaml).expect("operation should succeed");
715        assert_eq!(mode, TrainingMode::Regression);
716
717        // CausalLM mode
718        let yaml = "causal_lm";
719        let mode: TrainingMode = serde_yaml::from_str(yaml).expect("operation should succeed");
720        assert_eq!(mode, TrainingMode::CausalLm);
721    }
722
723    #[test]
724    fn test_deserialize_transformer_config() {
725        let yaml = r"
726model:
727  path: qwen2.5-coder-1.5b.safetensors
728  mode: transformer
729  config: qwen2_1_5b
730  layers: [q_proj, v_proj]
731
732data:
733  train: corpus/train.parquet
734  batch_size: 4
735  tokenizer: tokenizer.json
736  input_column: input
737  output_column: output
738  max_length: 512
739
740optimizer:
741  name: adamw
742  lr: 0.0001
743
744training:
745  epochs: 3
746  mode: causal_lm
747  gradient_accumulation: 4
748  checkpoints: 6
749  mixed_precision: bf16
750";
751
752        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
753
754        // Model assertions
755        assert_eq!(spec.model.mode, ModelMode::Transformer);
756        assert_eq!(spec.model.config, Some("qwen2_1_5b".to_string()));
757
758        // Data assertions
759        assert_eq!(spec.data.tokenizer, Some(PathBuf::from("tokenizer.json")));
760        assert_eq!(spec.data.input_column, Some("input".to_string()));
761        assert_eq!(spec.data.output_column, Some("output".to_string()));
762        assert_eq!(spec.data.max_length, Some(512));
763
764        // Training assertions
765        assert_eq!(spec.training.mode, TrainingMode::CausalLm);
766        assert_eq!(spec.training.gradient_accumulation, Some(4));
767        assert_eq!(spec.training.checkpoints, Some(6));
768        assert_eq!(spec.training.mixed_precision, Some("bf16".to_string()));
769    }
770
771    #[test]
772    fn test_backward_compatible_minimal_config() {
773        // Ensure old configs still work (defaults to tabular/regression)
774        let yaml = r"
775model:
776  path: model.gguf
777
778data:
779  train: data.parquet
780  batch_size: 8
781
782optimizer:
783  name: adam
784  lr: 0.001
785";
786
787        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
788        assert_eq!(spec.model.mode, ModelMode::Tabular);
789        assert_eq!(spec.training.mode, TrainingMode::Regression);
790        assert!(spec.data.tokenizer.is_none());
791    }
792
793    #[test]
794    fn test_training_params_new_fields_default() {
795        let params = TrainingParams::default();
796        assert_eq!(params.mode, TrainingMode::Regression);
797        assert!(params.gradient_accumulation.is_none());
798        assert!(params.checkpoints.is_none());
799        assert!(params.mixed_precision.is_none());
800        assert!(params.scheduler_params.is_none());
801        assert!(params.seed.is_none());
802        assert!(params.distributed.is_none());
803    }
804
805    #[test]
806    fn test_deserialize_distributed_config() {
807        let yaml = r"
808model:
809  path: model.safetensors
810  mode: transformer
811
812data:
813  train: data/train/
814  batch_size: 4
815
816optimizer:
817  name: adamw
818  lr: 0.0003
819
820training:
821  epochs: 5
822  mode: causal_lm
823  distributed:
824    world_size: 2
825    backend: cuda
826    role: coordinator
827    coordinator_addr: '0.0.0.0:9000'
828";
829        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
830        let dist = spec.training.distributed.expect("distributed config present");
831        assert_eq!(dist.world_size, 2);
832        assert_eq!(dist.backend, "cuda");
833        assert_eq!(dist.role, "coordinator");
834        assert_eq!(dist.coordinator_addr, "0.0.0.0:9000");
835        assert_eq!(dist.rank, 0);
836        assert_eq!(dist.local_rank, 0);
837    }
838
839    #[test]
840    fn test_distributed_config_defaults() {
841        let yaml = r"
842model:
843  path: model.safetensors
844
845data:
846  train: data.parquet
847  batch_size: 8
848
849optimizer:
850  name: adamw
851  lr: 0.001
852
853training:
854  distributed:
855    world_size: 4
856";
857        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
858        let dist = spec.training.distributed.expect("distributed config present");
859        assert_eq!(dist.world_size, 4);
860        assert_eq!(dist.backend, "auto");
861        assert_eq!(dist.role, "coordinator");
862        assert_eq!(dist.coordinator_addr, "0.0.0.0:9000");
863    }
864
865    #[test]
866    fn test_backward_compatible_no_distributed() {
867        let yaml = r"
868model:
869  path: model.gguf
870
871data:
872  train: data.parquet
873  batch_size: 8
874
875optimizer:
876  name: adam
877  lr: 0.001
878";
879        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
880        assert!(spec.training.distributed.is_none());
881    }
882
883    #[test]
884    fn test_deserialize_scheduler_params_and_seed() {
885        let yaml = r"
886model:
887  path: model.gguf
888
889data:
890  train: data.parquet
891  batch_size: 8
892
893optimizer:
894  name: adam
895  lr: 0.001
896
897training:
898  epochs: 5
899  seed: 42
900  lr_scheduler: cosine
901  scheduler_params:
902    t_max: 1000
903    eta_min: 0.000001
904";
905
906        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
907        assert_eq!(spec.training.seed, Some(42));
908        let params = spec.training.scheduler_params.expect("operation should succeed");
909        assert_eq!(params["t_max"], serde_json::json!(1000));
910        assert_eq!(params["eta_min"], serde_json::json!(0.000001));
911    }
912
913    /// CB-950: Verify that quoted boolean strings ("true"/"false") deserialize correctly.
914    /// PMAT compliance requires all YAML truthy values to be quoted.
915    #[test]
916    fn test_cb950_quoted_booleans_deserialize() {
917        let yaml = r#"
918model:
919  path: model.gguf
920  layers: []
921
922data:
923  train: train.parquet
924  batch_size: 8
925  auto_infer_types: "true"
926
927optimizer:
928  name: adam
929  lr: 0.001
930
931quantize:
932  bits: 4
933  symmetric: "true"
934  per_channel: "false"
935"#;
936
937        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
938        assert!(spec.data.auto_infer_types);
939        let quant = spec.quantize.expect("operation should succeed");
940        assert!(quant.symmetric);
941        assert!(!quant.per_channel);
942    }
943
944    // === HF Repo ID Detection Tests ===
945
946    #[test]
947    fn test_is_hf_repo_id_valid() {
948        assert!(is_hf_repo_id("Qwen/Qwen2.5-Coder-0.5B"));
949        assert!(is_hf_repo_id("meta-llama/Llama-2-7b"));
950        assert!(is_hf_repo_id("google/gemma-2b"));
951        assert!(is_hf_repo_id("myuser/my-model"));
952    }
953
954    #[test]
955    fn test_is_hf_repo_id_local_paths() {
956        assert!(!is_hf_repo_id("model.gguf"));
957        assert!(!is_hf_repo_id("./models/model.safetensors"));
958        assert!(!is_hf_repo_id("/absolute/path/model.bin"));
959        assert!(!is_hf_repo_id("relative/path/model.gguf"));
960    }
961
962    #[test]
963    fn test_is_hf_repo_id_edge_cases() {
964        assert!(!is_hf_repo_id(""));
965        assert!(!is_hf_repo_id("/"));
966        assert!(!is_hf_repo_id("single-part"));
967        assert!(!is_hf_repo_id("too/many/parts"));
968        assert!(!is_hf_repo_id(".hidden/path"));
969        assert!(!is_hf_repo_id("/org/name"));
970        assert!(!is_hf_repo_id("org/"));
971        assert!(!is_hf_repo_id("/name"));
972    }
973
974    #[test]
975    fn test_is_hf_repo_id_with_extension_rejected() {
976        // Files with extensions are local paths, not HF IDs
977        assert!(!is_hf_repo_id("org/model.safetensors"));
978        assert!(!is_hf_repo_id("user/model.gguf"));
979    }
980
981    #[test]
982    fn test_model_ref_is_hf_repo_id() {
983        let model =
984            ModelRef { path: PathBuf::from("Qwen/Qwen2.5-Coder-0.5B"), ..Default::default() };
985        assert!(model.is_hf_repo_id());
986
987        let model = ModelRef { path: PathBuf::from("model.gguf"), ..Default::default() };
988        assert!(!model.is_hf_repo_id());
989    }
990
991    #[test]
992    fn test_deserialize_hf_repo_id_as_model_path() {
993        let yaml = r"
994model:
995  path: Qwen/Qwen2.5-Coder-0.5B
996  mode: transformer
997
998data:
999  train: data.parquet
1000  batch_size: 8
1001
1002optimizer:
1003  name: adamw
1004  lr: 0.0001
1005";
1006        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
1007        assert!(spec.model.is_hf_repo_id());
1008        assert_eq!(spec.model.path, PathBuf::from("Qwen/Qwen2.5-Coder-0.5B"));
1009    }
1010
1011    // === Publish Section Tests ===
1012
1013    #[test]
1014    fn test_deserialize_with_publish_section() {
1015        let yaml = r"
1016model:
1017  path: model.gguf
1018
1019data:
1020  train: data.parquet
1021  batch_size: 8
1022
1023optimizer:
1024  name: adamw
1025  lr: 0.0001
1026
1027publish:
1028  repo: myuser/my-model
1029  private: false
1030  model_card: true
1031  merge_adapters: true
1032  format: safetensors
1033";
1034        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
1035        let publish = spec.publish.expect("operation should succeed");
1036        assert_eq!(publish.repo, "myuser/my-model");
1037        assert!(!publish.private);
1038        assert!(publish.model_card);
1039        assert!(publish.merge_adapters);
1040        assert_eq!(publish.format, "safetensors");
1041    }
1042
1043    #[test]
1044    fn test_deserialize_without_publish_section() {
1045        let yaml = r"
1046model:
1047  path: model.gguf
1048
1049data:
1050  train: data.parquet
1051  batch_size: 8
1052
1053optimizer:
1054  name: adam
1055  lr: 0.001
1056";
1057        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
1058        assert!(spec.publish.is_none());
1059    }
1060
1061    #[test]
1062    fn test_publish_spec_defaults() {
1063        let yaml = r"
1064model:
1065  path: model.gguf
1066
1067data:
1068  train: data.parquet
1069  batch_size: 8
1070
1071optimizer:
1072  name: adam
1073  lr: 0.001
1074
1075publish:
1076  repo: org/model
1077";
1078        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
1079        let publish = spec.publish.expect("operation should succeed");
1080        assert_eq!(publish.repo, "org/model");
1081        assert!(!publish.private);
1082        assert!(publish.model_card);
1083        assert!(!publish.merge_adapters);
1084        assert_eq!(publish.format, "safetensors");
1085    }
1086
1087    #[test]
1088    fn test_deterministic_config_yaml() {
1089        let yaml = r"
1090model:
1091  path: test-model
1092  type: transformer
1093
1094data:
1095  train: data.parquet
1096  batch_size: 8
1097
1098optimizer:
1099  name: adamw
1100  lr: 0.001
1101
1102training:
1103  epochs: 10
1104  deterministic: true
1105  seed: 12345
1106";
1107        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
1108        assert!(spec.training.deterministic, "deterministic should be true from YAML");
1109        assert_eq!(spec.training.seed, Some(12345));
1110    }
1111
1112    #[test]
1113    fn test_deterministic_defaults_to_false() {
1114        let yaml = r"
1115model:
1116  path: test-model
1117
1118data:
1119  train: data.parquet
1120  batch_size: 8
1121
1122optimizer:
1123  name: adamw
1124  lr: 0.001
1125";
1126        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
1127        assert!(!spec.training.deterministic, "deterministic should default to false");
1128    }
1129
1130    #[test]
1131    fn test_ent_263_lora_quantize_base_yaml() {
1132        let yaml = r"
1133model:
1134  path: model.safetensors
1135
1136data:
1137  train: data.parquet
1138  batch_size: 4
1139
1140optimizer:
1141  name: adamw
1142  lr: 0.0001
1143
1144lora:
1145  rank: 16
1146  alpha: 32.0
1147  target_modules: [q_proj, v_proj]
1148  quantize_base: true
1149  double_quantize: true
1150
1151training:
1152  epochs: 1
1153";
1154        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
1155        let lora = spec.lora.expect("lora should be present");
1156        assert!(lora.quantize_base, "quantize_base should be true");
1157        assert!(lora.double_quantize, "double_quantize should be true");
1158        assert_eq!(lora.rank, 16);
1159    }
1160
1161    #[test]
1162    fn test_ent_263_lora_quantize_base_default_false() {
1163        let yaml = r"
1164model:
1165  path: model.safetensors
1166
1167data:
1168  train: data.parquet
1169  batch_size: 4
1170
1171optimizer:
1172  name: adamw
1173  lr: 0.0001
1174
1175lora:
1176  rank: 8
1177  alpha: 16.0
1178  target_modules: [q_proj]
1179";
1180        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
1181        let lora = spec.lora.expect("lora should be present");
1182        assert!(!lora.quantize_base, "quantize_base should default to false");
1183    }
1184}