1use serde::{Deserialize, Deserializer, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
9fn deserialize_bool_lenient<'de, D>(deserializer: D) -> Result<bool, D::Error>
12where
13 D: Deserializer<'de>,
14{
15 #[derive(Deserialize)]
16 #[serde(untagged)]
17 enum BoolOrString {
18 Bool(bool),
19 Str(String),
20 }
21
22 match BoolOrString::deserialize(deserializer)? {
23 BoolOrString::Bool(b) => Ok(b),
24 BoolOrString::Str(s) => match s.to_lowercase().as_str() {
25 "true" => Ok(true),
26 "false" => Ok(false),
27 other => {
28 Err(serde::de::Error::custom(format!("expected 'true' or 'false', got '{other}'")))
29 }
30 },
31 }
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum ModelMode {
40 #[default]
42 Tabular,
43 Transformer,
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum TrainingMode {
53 #[default]
55 Regression,
56 CausalLm,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct TrainSpec {
63 pub model: ModelRef,
65
66 pub data: DataConfig,
68
69 pub optimizer: OptimSpec,
71
72 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub lora: Option<LoRASpec>,
75
76 #[serde(default, skip_serializing_if = "Option::is_none")]
78 pub quantize: Option<QuantSpec>,
79
80 #[serde(default, skip_serializing_if = "Option::is_none")]
82 pub merge: Option<MergeSpec>,
83
84 #[serde(default)]
86 pub training: TrainingParams,
87
88 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub publish: Option<PublishSpec>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct PublishSpec {
96 pub repo: String,
98
99 #[serde(default)]
101 pub private: bool,
102
103 #[serde(default = "default_true")]
105 pub model_card: bool,
106
107 #[serde(default)]
109 pub merge_adapters: bool,
110
111 #[serde(default = "default_safetensors")]
113 pub format: String,
114}
115
/// Serde default for `PublishSpec::format`: the `safetensors` format name.
fn default_safetensors() -> String {
    String::from("safetensors")
}
119
120#[derive(Debug, Clone, Default, Serialize, Deserialize)]
126pub struct ArchitectureOverrides {
127 #[serde(default, skip_serializing_if = "Option::is_none")]
129 pub hidden_size: Option<usize>,
130 #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_layers")]
132 pub num_hidden_layers: Option<usize>,
133 #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_heads")]
135 pub num_attention_heads: Option<usize>,
136 #[serde(default, skip_serializing_if = "Option::is_none", alias = "num_key_value_heads")]
138 pub num_kv_heads: Option<usize>,
139 #[serde(default, skip_serializing_if = "Option::is_none")]
141 pub intermediate_size: Option<usize>,
142 #[serde(default, skip_serializing_if = "Option::is_none")]
144 pub vocab_size: Option<usize>,
145 #[serde(default, skip_serializing_if = "Option::is_none", alias = "max_seq_length")]
147 pub max_position_embeddings: Option<usize>,
148 #[serde(default, skip_serializing_if = "Option::is_none")]
150 pub rms_norm_eps: Option<f32>,
151 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub rope_theta: Option<f32>,
154 #[serde(default, skip_serializing_if = "Option::is_none")]
156 pub use_bias: Option<bool>,
157 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub head_dim: Option<usize>,
160}
161
162impl ArchitectureOverrides {
163 pub fn is_empty(&self) -> bool {
165 self.hidden_size.is_none()
166 && self.num_hidden_layers.is_none()
167 && self.num_attention_heads.is_none()
168 && self.num_kv_heads.is_none()
169 && self.intermediate_size.is_none()
170 && self.vocab_size.is_none()
171 && self.max_position_embeddings.is_none()
172 && self.rms_norm_eps.is_none()
173 && self.rope_theta.is_none()
174 && self.use_bias.is_none()
175 && self.head_dim.is_none()
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize, Default)]
181pub struct ModelRef {
182 #[serde(default)]
185 pub path: PathBuf,
186
187 #[serde(default)]
189 pub layers: Vec<String>,
190
191 #[serde(default)]
194 pub mode: ModelMode,
195
196 #[serde(default, skip_serializing_if = "Option::is_none")]
199 pub config: Option<String>,
200
201 #[serde(default, skip_serializing_if = "Option::is_none")]
204 pub architecture: Option<ArchitectureOverrides>,
205}
206
207impl ModelRef {
208 pub fn is_hf_repo_id(&self) -> bool {
212 let s = self.path.to_string_lossy();
213 is_hf_repo_id(&s)
214 }
215}
216
/// Returns `true` if `s` looks like a Hugging Face Hub repository ID
/// (`org/name`, e.g. `Qwen/Qwen2.5-Coder-0.5B`) rather than a local path.
///
/// `s` is treated as a repo ID only when all of these hold:
/// - it has exactly two `/`-separated segments;
/// - each segment is non-empty, uses only ASCII alphanumerics, `-`, `_` and `.`,
///   does not start or end with `-` or `.`, and contains neither `--` nor `..`
///   (the Hub's repo-ID rules). This rejects absolute paths (empty `org`),
///   `./`, `../` and hidden-directory paths, and strings such as `~/model`,
///   `C:/models` or names containing spaces;
/// - the name does not end in a common model/config file extension
///   (case-insensitive), so `user/model.gguf` stays a local path.
///
/// A two-segment relative directory such as `models/my-model` is still
/// syntactically indistinguishable from a repo ID.
pub fn is_hf_repo_id(s: &str) -> bool {
    // Extensions that mark the `name` segment as a local file, not a repo.
    const FILE_EXTENSIONS: [&str; 11] = [
        ".safetensors",
        ".gguf",
        ".bin",
        ".pt",
        ".pth",
        ".onnx",
        ".json",
        ".yaml",
        ".yml",
        ".toml",
        ".txt",
    ];

    // Checks one `org` / `name` segment against the Hub's repo-ID character rules.
    fn is_valid_segment(segment: &str) -> bool {
        let edge_ok = |c: char| c != '-' && c != '.';
        match (segment.chars().next(), segment.chars().next_back()) {
            (Some(first), Some(last)) => {
                edge_ok(first)
                    && edge_ok(last)
                    && segment
                        .chars()
                        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
                    && !segment.contains("--")
                    && !segment.contains("..")
            }
            // Empty segment, e.g. "org/" or "/name".
            _ => false,
        }
    }

    // Exactly one '/' separating `org` from `name`.
    let (org, name) = match s.split_once('/') {
        Some((org, name)) if !name.contains('/') => (org, name),
        _ => return false,
    };

    if !is_valid_segment(org) || !is_valid_segment(name) {
        return false;
    }

    // `name` is ASCII-only at this point, so ASCII lowercasing is sufficient.
    let name_lower = name.to_ascii_lowercase();
    !FILE_EXTENSIONS.iter().any(|ext| name_lower.ends_with(ext))
}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct DataConfig {
263 #[serde(default)]
265 pub train: PathBuf,
266
267 #[serde(default, skip_serializing_if = "Option::is_none")]
269 pub val: Option<PathBuf>,
270
271 #[serde(default = "default_batch_size")]
273 pub batch_size: usize,
274
275 #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
277 pub auto_infer_types: bool,
278
279 #[serde(default, skip_serializing_if = "Option::is_none")]
281 pub seq_len: Option<usize>,
282
283 #[serde(default, skip_serializing_if = "Option::is_none")]
286 pub tokenizer: Option<PathBuf>,
287
288 #[serde(default, skip_serializing_if = "Option::is_none")]
290 pub input_column: Option<String>,
291
292 #[serde(default, skip_serializing_if = "Option::is_none")]
294 pub output_column: Option<String>,
295
296 #[serde(default, skip_serializing_if = "Option::is_none")]
298 pub max_length: Option<usize>,
299}
300
301impl Default for DataConfig {
302 fn default() -> Self {
303 Self {
304 train: PathBuf::new(),
305 val: None,
306 batch_size: 8,
307 auto_infer_types: true,
308 seq_len: None,
309 tokenizer: None,
310 input_column: None,
311 output_column: None,
312 max_length: None,
313 }
314 }
315}
316
/// Serde default for `DataConfig::batch_size`.
const fn default_batch_size() -> usize {
    8
}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct OptimSpec {
324 pub name: String,
326
327 pub lr: f32,
329
330 #[serde(flatten)]
332 pub params: HashMap<String, serde_json::Value>,
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct LoRASpec {
338 pub rank: usize,
340
341 pub alpha: f32,
343
344 pub target_modules: Vec<String>,
346
347 #[serde(default)]
349 pub dropout: f32,
350
351 #[serde(default = "default_lora_plus_ratio")]
354 pub lora_plus_ratio: f32,
355
356 #[serde(default)]
360 pub double_quantize: bool,
361
362 #[serde(default)]
368 pub quantize_base: bool,
369}
370
/// Serde default for `LoRASpec::lora_plus_ratio`.
const fn default_lora_plus_ratio() -> f32 {
    1.0
}
374
375#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct QuantSpec {
378 pub bits: u8,
380
381 #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
383 pub symmetric: bool,
384
385 #[serde(default = "default_true", deserialize_with = "deserialize_bool_lenient")]
387 pub per_channel: bool,
388}
389
390#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct MergeSpec {
393 pub method: String,
395
396 #[serde(flatten)]
398 pub params: HashMap<String, serde_json::Value>,
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
403#[serde(default)]
404pub struct TrainingParams {
405 pub epochs: usize,
407
408 #[serde(skip_serializing_if = "Option::is_none")]
410 pub grad_clip: Option<f32>,
411
412 #[serde(skip_serializing_if = "Option::is_none")]
414 pub lr_scheduler: Option<String>,
415
416 pub warmup_steps: usize,
418
419 pub save_interval: usize,
421
422 pub output_dir: PathBuf,
424
425 pub mode: TrainingMode,
429
430 #[serde(skip_serializing_if = "Option::is_none")]
432 pub gradient_accumulation: Option<usize>,
433
434 #[serde(skip_serializing_if = "Option::is_none")]
436 pub checkpoints: Option<usize>,
437
438 #[serde(skip_serializing_if = "Option::is_none")]
440 pub mixed_precision: Option<String>,
441
442 #[serde(skip_serializing_if = "Option::is_none")]
444 pub scheduler_params: Option<HashMap<String, serde_json::Value>>,
445
446 #[serde(skip_serializing_if = "Option::is_none")]
448 pub max_steps: Option<usize>,
449
450 #[serde(skip_serializing_if = "Option::is_none")]
452 pub seed: Option<u64>,
453
454 #[serde(default = "default_max_checkpoints")]
457 pub max_checkpoints: usize,
458
459 #[serde(default = "default_true")]
462 pub shuffle: bool,
463
464 #[serde(default, skip_serializing_if = "Option::is_none")]
468 pub curriculum: Option<Vec<CurriculumStage>>,
469
470 #[serde(default)]
473 pub profile_interval: usize,
474
475 #[serde(default)]
480 pub deterministic: bool,
481
482 #[serde(default)]
486 pub eval_interval: usize,
487
488 #[serde(default)]
490 pub patience: usize,
491
492 #[serde(default, skip_serializing_if = "Option::is_none")]
496 pub distributed: Option<DistributedSpec>,
497}
498
499#[derive(Debug, Clone, Serialize, Deserialize)]
501pub struct CurriculumStage {
502 pub data: PathBuf,
504 #[serde(default, skip_serializing_if = "Option::is_none")]
506 pub until_step: Option<usize>,
507}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
525pub struct DistributedSpec {
526 pub world_size: usize,
528
529 #[serde(default = "default_backend")]
531 pub backend: String,
532
533 #[serde(default = "default_role")]
535 pub role: String,
536
537 #[serde(default = "default_coordinator_addr")]
539 pub coordinator_addr: String,
540
541 #[serde(default)]
543 pub rank: usize,
544
545 #[serde(default)]
547 pub local_rank: usize,
548}
549
/// Serde default for `DistributedSpec::backend`.
fn default_backend() -> String {
    String::from("auto")
}
553
/// Serde default for `DistributedSpec::role`.
fn default_role() -> String {
    String::from("coordinator")
}
557
/// Serde default for `DistributedSpec::coordinator_addr`.
fn default_coordinator_addr() -> String {
    "0.0.0.0:9000".to_owned()
}
561
562impl Default for TrainingParams {
563 fn default() -> Self {
564 Self {
565 epochs: 10,
566 grad_clip: None,
567 lr_scheduler: None,
568 warmup_steps: 0,
569 save_interval: 1,
570 output_dir: PathBuf::from("./checkpoints"),
571 mode: TrainingMode::default(),
572 gradient_accumulation: None,
573 checkpoints: None,
574 mixed_precision: None,
575 scheduler_params: None,
576 max_steps: None,
577 seed: None,
578 max_checkpoints: 5,
579 shuffle: true,
580 curriculum: None,
581 profile_interval: 0,
582 deterministic: false,
583 eval_interval: 0,
584 patience: 0,
585 distributed: None,
586 }
587 }
588}
589
/// Serde default for `TrainingParams::max_checkpoints`.
const fn default_max_checkpoints() -> usize {
    5
}
593
/// Serde default helper for boolean fields that should be on when omitted.
const fn default_true() -> bool {
    true
}
597
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Basic parsing ---------------------------------------------------

    #[test]
    fn test_deserialize_minimal_config() {
        let yaml = r"
model:
  path: model.gguf
  layers: []

data:
  train: train.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001
";

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(spec.model.path, PathBuf::from("model.gguf"));
        assert_eq!(spec.data.batch_size, 8);
        assert_eq!(spec.optimizer.name, "adam");
        assert_eq!(spec.optimizer.lr, 0.001);
    }

    #[test]
    fn test_deserialize_full_config() {
        let yaml = r"
model:
  path: llama-7b.gguf
  layers: [q_proj, k_proj, v_proj, o_proj]

data:
  train: train.parquet
  val: val.parquet
  batch_size: 32
  auto_infer_types: true
  seq_len: 2048

optimizer:
  name: adamw
  lr: 0.0001
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.01

lora:
  rank: 64
  alpha: 16
  target_modules: [q_proj, v_proj]
  dropout: 0.1

quantize:
  bits: 4
  symmetric: true
  per_channel: true

training:
  epochs: 3
  grad_clip: 1.0
  lr_scheduler: cosine
  warmup_steps: 100
  save_interval: 1
  output_dir: ./outputs
";

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(spec.model.layers.len(), 4);
        assert!(spec.lora.is_some());
        assert_eq!(spec.lora.as_ref().expect("operation should succeed").rank, 64);
        assert!(spec.quantize.is_some());
        assert_eq!(spec.quantize.as_ref().expect("operation should succeed").bits, 4);
        assert_eq!(spec.training.epochs, 3);
    }

    #[test]
    fn test_default_training_params() {
        let params = TrainingParams::default();
        assert_eq!(params.epochs, 10);
        assert_eq!(params.save_interval, 1);
        assert!(params.grad_clip.is_none());
    }

    // ---- Mode enums ------------------------------------------------------

    #[test]
    fn test_model_mode_default_is_tabular() {
        let mode = ModelMode::default();
        assert_eq!(mode, ModelMode::Tabular);
    }

    #[test]
    fn test_training_mode_default_is_regression() {
        let mode = TrainingMode::default();
        assert_eq!(mode, TrainingMode::Regression);
    }

    #[test]
    fn test_model_mode_serde_roundtrip() {
        let yaml = "tabular";
        let mode: ModelMode = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(mode, ModelMode::Tabular);

        let yaml = "transformer";
        let mode: ModelMode = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(mode, ModelMode::Transformer);
    }

    #[test]
    fn test_training_mode_serde_roundtrip() {
        let yaml = "regression";
        let mode: TrainingMode = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(mode, TrainingMode::Regression);

        let yaml = "causal_lm";
        let mode: TrainingMode = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(mode, TrainingMode::CausalLm);
    }

    // ---- Transformer / backward compatibility ----------------------------

    #[test]
    fn test_deserialize_transformer_config() {
        let yaml = r"
model:
  path: qwen2.5-coder-1.5b.safetensors
  mode: transformer
  config: qwen2_1_5b
  layers: [q_proj, v_proj]

data:
  train: corpus/train.parquet
  batch_size: 4
  tokenizer: tokenizer.json
  input_column: input
  output_column: output
  max_length: 512

optimizer:
  name: adamw
  lr: 0.0001

training:
  epochs: 3
  mode: causal_lm
  gradient_accumulation: 4
  checkpoints: 6
  mixed_precision: bf16
";

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");

        assert_eq!(spec.model.mode, ModelMode::Transformer);
        assert_eq!(spec.model.config, Some("qwen2_1_5b".to_string()));

        assert_eq!(spec.data.tokenizer, Some(PathBuf::from("tokenizer.json")));
        assert_eq!(spec.data.input_column, Some("input".to_string()));
        assert_eq!(spec.data.output_column, Some("output".to_string()));
        assert_eq!(spec.data.max_length, Some(512));

        assert_eq!(spec.training.mode, TrainingMode::CausalLm);
        assert_eq!(spec.training.gradient_accumulation, Some(4));
        assert_eq!(spec.training.checkpoints, Some(6));
        assert_eq!(spec.training.mixed_precision, Some("bf16".to_string()));
    }

    #[test]
    fn test_backward_compatible_minimal_config() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001
";

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(spec.model.mode, ModelMode::Tabular);
        assert_eq!(spec.training.mode, TrainingMode::Regression);
        assert!(spec.data.tokenizer.is_none());
    }

    #[test]
    fn test_training_params_new_fields_default() {
        let params = TrainingParams::default();
        assert_eq!(params.mode, TrainingMode::Regression);
        assert!(params.gradient_accumulation.is_none());
        assert!(params.checkpoints.is_none());
        assert!(params.mixed_precision.is_none());
        assert!(params.scheduler_params.is_none());
        assert!(params.seed.is_none());
        assert!(params.distributed.is_none());
    }

    // ---- Distributed -----------------------------------------------------

    #[test]
    fn test_deserialize_distributed_config() {
        let yaml = r"
model:
  path: model.safetensors
  mode: transformer

data:
  train: data/train/
  batch_size: 4

optimizer:
  name: adamw
  lr: 0.0003

training:
  epochs: 5
  mode: causal_lm
  distributed:
    world_size: 2
    backend: cuda
    role: coordinator
    coordinator_addr: '0.0.0.0:9000'
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        let dist = spec.training.distributed.expect("distributed config present");
        assert_eq!(dist.world_size, 2);
        assert_eq!(dist.backend, "cuda");
        assert_eq!(dist.role, "coordinator");
        assert_eq!(dist.coordinator_addr, "0.0.0.0:9000");
        assert_eq!(dist.rank, 0);
        assert_eq!(dist.local_rank, 0);
    }

    #[test]
    fn test_distributed_config_defaults() {
        let yaml = r"
model:
  path: model.safetensors

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adamw
  lr: 0.001

training:
  distributed:
    world_size: 4
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        let dist = spec.training.distributed.expect("distributed config present");
        assert_eq!(dist.world_size, 4);
        assert_eq!(dist.backend, "auto");
        assert_eq!(dist.role, "coordinator");
        assert_eq!(dist.coordinator_addr, "0.0.0.0:9000");
    }

    #[test]
    fn test_backward_compatible_no_distributed() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert!(spec.training.distributed.is_none());
    }

    // ---- Scheduler / seed ------------------------------------------------

    #[test]
    fn test_deserialize_scheduler_params_and_seed() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001

training:
  epochs: 5
  seed: 42
  lr_scheduler: cosine
  scheduler_params:
    t_max: 1000
    eta_min: 0.000001
";

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert_eq!(spec.training.seed, Some(42));
        let params = spec.training.scheduler_params.expect("operation should succeed");
        assert_eq!(params["t_max"], serde_json::json!(1000));
        assert_eq!(params["eta_min"], serde_json::json!(0.000001));
    }

    // ---- Lenient booleans ------------------------------------------------

    #[test]
    fn test_cb950_quoted_booleans_deserialize() {
        let yaml = r#"
model:
  path: model.gguf
  layers: []

data:
  train: train.parquet
  batch_size: 8
  auto_infer_types: "true"

optimizer:
  name: adam
  lr: 0.001

quantize:
  bits: 4
  symmetric: "true"
  per_channel: "false"
"#;

        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert!(spec.data.auto_infer_types);
        let quant = spec.quantize.expect("operation should succeed");
        assert!(quant.symmetric);
        assert!(!quant.per_channel);
    }

    // ---- Hugging Face repo IDs -------------------------------------------

    #[test]
    fn test_is_hf_repo_id_valid() {
        assert!(is_hf_repo_id("Qwen/Qwen2.5-Coder-0.5B"));
        assert!(is_hf_repo_id("meta-llama/Llama-2-7b"));
        assert!(is_hf_repo_id("google/gemma-2b"));
        assert!(is_hf_repo_id("myuser/my-model"));
    }

    #[test]
    fn test_is_hf_repo_id_local_paths() {
        assert!(!is_hf_repo_id("model.gguf"));
        assert!(!is_hf_repo_id("./models/model.safetensors"));
        assert!(!is_hf_repo_id("/absolute/path/model.bin"));
        assert!(!is_hf_repo_id("relative/path/model.gguf"));
    }

    #[test]
    fn test_is_hf_repo_id_edge_cases() {
        assert!(!is_hf_repo_id(""));
        assert!(!is_hf_repo_id("/"));
        assert!(!is_hf_repo_id("single-part"));
        assert!(!is_hf_repo_id("too/many/parts"));
        assert!(!is_hf_repo_id(".hidden/path"));
        assert!(!is_hf_repo_id("/org/name"));
        assert!(!is_hf_repo_id("org/"));
        assert!(!is_hf_repo_id("/name"));
    }

    #[test]
    fn test_is_hf_repo_id_with_extension_rejected() {
        assert!(!is_hf_repo_id("org/model.safetensors"));
        assert!(!is_hf_repo_id("user/model.gguf"));
    }

    #[test]
    fn test_model_ref_is_hf_repo_id() {
        let model =
            ModelRef { path: PathBuf::from("Qwen/Qwen2.5-Coder-0.5B"), ..Default::default() };
        assert!(model.is_hf_repo_id());

        let model = ModelRef { path: PathBuf::from("model.gguf"), ..Default::default() };
        assert!(!model.is_hf_repo_id());
    }

    #[test]
    fn test_deserialize_hf_repo_id_as_model_path() {
        let yaml = r"
model:
  path: Qwen/Qwen2.5-Coder-0.5B
  mode: transformer

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adamw
  lr: 0.0001
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert!(spec.model.is_hf_repo_id());
        assert_eq!(spec.model.path, PathBuf::from("Qwen/Qwen2.5-Coder-0.5B"));
    }

    // ---- Publish section -------------------------------------------------

    #[test]
    fn test_deserialize_with_publish_section() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adamw
  lr: 0.0001

publish:
  repo: myuser/my-model
  private: false
  model_card: true
  merge_adapters: true
  format: safetensors
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        let publish = spec.publish.expect("operation should succeed");
        assert_eq!(publish.repo, "myuser/my-model");
        assert!(!publish.private);
        assert!(publish.model_card);
        assert!(publish.merge_adapters);
        assert_eq!(publish.format, "safetensors");
    }

    #[test]
    fn test_deserialize_without_publish_section() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        assert!(spec.publish.is_none());
    }

    #[test]
    fn test_publish_spec_defaults() {
        let yaml = r"
model:
  path: model.gguf

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adam
  lr: 0.001

publish:
  repo: org/model
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("operation should succeed");
        let publish = spec.publish.expect("operation should succeed");
        assert_eq!(publish.repo, "org/model");
        assert!(!publish.private);
        assert!(publish.model_card);
        assert!(!publish.merge_adapters);
        assert_eq!(publish.format, "safetensors");
    }

    // ---- Determinism -----------------------------------------------------

    #[test]
    fn test_deterministic_config_yaml() {
        // `mode` is the ModelRef key for the model family. A `type` key would
        // be silently ignored because unknown fields are not rejected.
        let yaml = r"
model:
  path: test-model
  mode: transformer

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adamw
  lr: 0.001

training:
  epochs: 10
  deterministic: true
  seed: 12345
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
        assert_eq!(spec.model.mode, ModelMode::Transformer);
        assert!(spec.training.deterministic, "deterministic should be true from YAML");
        assert_eq!(spec.training.seed, Some(12345));
    }

    #[test]
    fn test_deterministic_defaults_to_false() {
        let yaml = r"
model:
  path: test-model

data:
  train: data.parquet
  batch_size: 8

optimizer:
  name: adamw
  lr: 0.001
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
        assert!(!spec.training.deterministic, "deterministic should default to false");
    }

    // ---- LoRA quantize_base ----------------------------------------------

    #[test]
    fn test_ent_263_lora_quantize_base_yaml() {
        let yaml = r"
model:
  path: model.safetensors

data:
  train: data.parquet
  batch_size: 4

optimizer:
  name: adamw
  lr: 0.0001

lora:
  rank: 16
  alpha: 32.0
  target_modules: [q_proj, v_proj]
  quantize_base: true
  double_quantize: true

training:
  epochs: 1
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
        let lora = spec.lora.expect("lora should be present");
        assert!(lora.quantize_base, "quantize_base should be true");
        assert!(lora.double_quantize, "double_quantize should be true");
        assert_eq!(lora.rank, 16);
    }

    #[test]
    fn test_ent_263_lora_quantize_base_default_false() {
        let yaml = r"
model:
  path: model.safetensors

data:
  train: data.parquet
  batch_size: 4

optimizer:
  name: adamw
  lr: 0.0001

lora:
  rank: 8
  alpha: 16.0
  target_modules: [q_proj]
";
        let spec: TrainSpec = serde_yaml::from_str(yaml).expect("parse YAML");
        let lora = spec.lora.expect("lora should be present");
        assert!(!lora.quantize_base, "quantize_base should default to false");
    }
}