entrenar/config/validate/
validator.rs1use super::error::ValidationError;
6use crate::config::schema::TrainSpec;
7
8pub fn validate_config(spec: &TrainSpec) -> Result<(), ValidationError> {
15 validate_model_path(spec)?;
16 validate_data_paths(spec)?;
17 validate_batch_size(spec)?;
18 validate_learning_rate(spec)?;
19 validate_optimizer(spec)?;
20 validate_epochs(spec)?;
21 validate_training_params(spec)?;
22 validate_lora(spec)?;
23 validate_quantization(spec)?;
24 validate_merge(spec)?;
25 validate_publish(spec)?;
26 Ok(())
27}
28
29#[cfg(not(test))]
31fn validate_model_path(spec: &TrainSpec) -> Result<(), ValidationError> {
32 if spec.model.is_hf_repo_id() {
34 return Ok(());
35 }
36 if !spec.model.path.exists() {
37 return Err(ValidationError::ModelPathNotFound(spec.model.path.display().to_string()));
38 }
39 Ok(())
40}
41
/// Test-build stand-in for the model path check: always succeeds.
///
/// No filesystem access happens here, so a spec whose model path does not exist
/// still passes this step when compiled with `cfg(test)`.
#[cfg(test)]
fn validate_model_path(_spec: &TrainSpec) -> Result<(), ValidationError> {
    // Intentionally a no-op; the spec is accepted unconditionally.
    Ok(())
}
46
47#[cfg(not(test))]
49fn validate_data_paths(spec: &TrainSpec) -> Result<(), ValidationError> {
50 if !spec.data.train.exists() {
51 return Err(ValidationError::TrainDataNotFound(spec.data.train.display().to_string()));
52 }
53
54 if let Some(val_path) = &spec.data.val {
55 if !val_path.exists() {
56 return Err(ValidationError::ValDataNotFound(val_path.display().to_string()));
57 }
58 }
59 Ok(())
60}
61
/// Test-build stand-in for the data path check: always succeeds.
///
/// No filesystem access happens here, so specs that reference data files which do
/// not exist still pass this step when compiled with `cfg(test)`.
#[cfg(test)]
fn validate_data_paths(_spec: &TrainSpec) -> Result<(), ValidationError> {
    // Intentionally a no-op; the spec is accepted unconditionally.
    Ok(())
}
66
67fn validate_batch_size(spec: &TrainSpec) -> Result<(), ValidationError> {
69 if spec.data.batch_size == 0 {
70 return Err(ValidationError::InvalidBatchSize(spec.data.batch_size));
71 }
72 Ok(())
73}
74
75fn validate_learning_rate(spec: &TrainSpec) -> Result<(), ValidationError> {
77 if spec.optimizer.lr <= 0.0 || spec.optimizer.lr > 1.0 {
78 return Err(ValidationError::InvalidLearningRate(spec.optimizer.lr));
79 }
80 Ok(())
81}
82
83fn validate_optimizer(spec: &TrainSpec) -> Result<(), ValidationError> {
85 const VALID_OPTIMIZERS: [&str; 6] = ["adam", "adamw", "sgd", "rmsprop", "adagrad", "lamb"];
86 if !VALID_OPTIMIZERS.contains(&spec.optimizer.name.as_str()) {
87 return Err(ValidationError::InvalidOptimizer(spec.optimizer.name.clone()));
88 }
89 Ok(())
90}
91
92fn validate_epochs(spec: &TrainSpec) -> Result<(), ValidationError> {
94 if spec.training.epochs == 0 {
95 return Err(ValidationError::InvalidEpochs(spec.training.epochs));
96 }
97 Ok(())
98}
99
100fn validate_training_params(spec: &TrainSpec) -> Result<(), ValidationError> {
102 validate_grad_clip(spec)?;
103 validate_seq_len(spec)?;
104 validate_save_interval(spec)?;
105 validate_lr_scheduler(spec)?;
106 Ok(())
107}
108
109fn validate_grad_clip(spec: &TrainSpec) -> Result<(), ValidationError> {
111 if let Some(grad_clip) = spec.training.grad_clip {
112 if grad_clip <= 0.0 {
113 return Err(ValidationError::InvalidGradClip(grad_clip));
114 }
115 }
116 Ok(())
117}
118
119fn validate_seq_len(spec: &TrainSpec) -> Result<(), ValidationError> {
121 if let Some(seq_len) = spec.data.seq_len {
122 if seq_len == 0 {
123 return Err(ValidationError::InvalidSeqLen(seq_len));
124 }
125 }
126 Ok(())
127}
128
129fn validate_save_interval(spec: &TrainSpec) -> Result<(), ValidationError> {
131 if spec.training.save_interval == 0 {
132 return Err(ValidationError::InvalidSaveInterval(spec.training.save_interval));
133 }
134 Ok(())
135}
136
137fn validate_lr_scheduler(spec: &TrainSpec) -> Result<(), ValidationError> {
139 if let Some(scheduler) = &spec.training.lr_scheduler {
140 const VALID_SCHEDULERS: [&str; 7] =
141 ["cosine", "linear", "constant", "step", "exponential", "one_cycle", "plateau"];
142 if !VALID_SCHEDULERS.contains(&scheduler.as_str()) {
143 return Err(ValidationError::InvalidLRScheduler(scheduler.clone()));
144 }
145 }
146 Ok(())
147}
148
149fn validate_lora(spec: &TrainSpec) -> Result<(), ValidationError> {
151 let Some(lora) = &spec.lora else {
152 return Ok(());
153 };
154
155 validate_lora_rank(lora.rank)?;
156 validate_lora_alpha(lora.alpha)?;
157 validate_lora_dropout(lora.dropout)?;
158 validate_lora_targets(&lora.target_modules)?;
159 Ok(())
160}
161
162fn validate_lora_rank(rank: usize) -> Result<(), ValidationError> {
164 if rank == 0 || rank > 1024 {
165 return Err(ValidationError::InvalidLoRARank(rank));
166 }
167 Ok(())
168}
169
170fn validate_lora_alpha(alpha: f32) -> Result<(), ValidationError> {
172 if alpha <= 0.0 {
173 return Err(ValidationError::InvalidLoRAAlpha(alpha));
174 }
175 Ok(())
176}
177
178fn validate_lora_dropout(dropout: f32) -> Result<(), ValidationError> {
180 if !(0.0..1.0).contains(&dropout) {
181 return Err(ValidationError::InvalidLoRADropout(dropout));
182 }
183 Ok(())
184}
185
186fn validate_lora_targets(targets: &[String]) -> Result<(), ValidationError> {
188 if targets.is_empty() {
189 return Err(ValidationError::EmptyLoRATargets);
190 }
191 Ok(())
192}
193
194fn validate_quantization(spec: &TrainSpec) -> Result<(), ValidationError> {
196 let Some(quant) = &spec.quantize else {
197 return Ok(());
198 };
199
200 if quant.bits != 4 && quant.bits != 8 {
201 return Err(ValidationError::InvalidQuantBits(quant.bits));
202 }
203 Ok(())
204}
205
206fn validate_merge(spec: &TrainSpec) -> Result<(), ValidationError> {
208 let Some(merge) = &spec.merge else {
209 return Ok(());
210 };
211
212 const VALID_METHODS: [&str; 3] = ["ties", "dare", "slerp"];
213 if !VALID_METHODS.contains(&merge.method.as_str()) {
214 return Err(ValidationError::InvalidMergeMethod(merge.method.clone()));
215 }
216 Ok(())
217}
218
219fn validate_publish(spec: &TrainSpec) -> Result<(), ValidationError> {
221 let Some(publish) = &spec.publish else {
222 return Ok(());
223 };
224
225 let parts: Vec<&str> = publish.repo.split('/').collect();
227 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
228 return Err(ValidationError::InvalidPublishRepo(publish.repo.clone()));
229 }
230
231 if publish.format != "safetensors" && publish.format != "gguf" {
233 return Err(ValidationError::InvalidPublishFormat(publish.format.clone()));
234 }
235
236 Ok(())
237}