Skip to main content

entrenar/config/validate/
validator.rs

1//! Configuration validation logic
2//!
3//! Validates training specifications for correctness before execution.
4
5use super::error::ValidationError;
6use crate::config::schema::TrainSpec;
7
8/// Validate a training specification
9///
10/// Checks:
11/// - File paths exist
12/// - Numeric values are in valid ranges
13/// - Enums match allowed values
14pub fn validate_config(spec: &TrainSpec) -> Result<(), ValidationError> {
15    validate_model_path(spec)?;
16    validate_data_paths(spec)?;
17    validate_batch_size(spec)?;
18    validate_learning_rate(spec)?;
19    validate_optimizer(spec)?;
20    validate_epochs(spec)?;
21    validate_training_params(spec)?;
22    validate_lora(spec)?;
23    validate_quantization(spec)?;
24    validate_merge(spec)?;
25    validate_publish(spec)?;
26    Ok(())
27}
28
29/// Validate model path exists (or is a valid HuggingFace repo ID)
30#[cfg(not(test))]
31fn validate_model_path(spec: &TrainSpec) -> Result<(), ValidationError> {
32    // Accept HF repo IDs — they'll be resolved at training time
33    if spec.model.is_hf_repo_id() {
34        return Ok(());
35    }
36    if !spec.model.path.exists() {
37        return Err(ValidationError::ModelPathNotFound(spec.model.path.display().to_string()));
38    }
39    Ok(())
40}
41
/// Test-build stand-in for the model path check.
///
/// Always succeeds, so `cfg(test)` builds never consult the filesystem for the model.
#[cfg(test)]
fn validate_model_path(_: &TrainSpec) -> Result<(), ValidationError> {
    Ok(())
}
46
47/// Validate data paths exist
48#[cfg(not(test))]
49fn validate_data_paths(spec: &TrainSpec) -> Result<(), ValidationError> {
50    if !spec.data.train.exists() {
51        return Err(ValidationError::TrainDataNotFound(spec.data.train.display().to_string()));
52    }
53
54    if let Some(val_path) = &spec.data.val {
55        if !val_path.exists() {
56            return Err(ValidationError::ValDataNotFound(val_path.display().to_string()));
57        }
58    }
59    Ok(())
60}
61
/// Test-build stand-in for the dataset path check.
///
/// Always succeeds, so `cfg(test)` builds never consult the filesystem for data files.
#[cfg(test)]
fn validate_data_paths(_: &TrainSpec) -> Result<(), ValidationError> {
    Ok(())
}
66
67/// Validate batch size is non-zero
68fn validate_batch_size(spec: &TrainSpec) -> Result<(), ValidationError> {
69    if spec.data.batch_size == 0 {
70        return Err(ValidationError::InvalidBatchSize(spec.data.batch_size));
71    }
72    Ok(())
73}
74
75/// Validate learning rate is positive and reasonable
76fn validate_learning_rate(spec: &TrainSpec) -> Result<(), ValidationError> {
77    if spec.optimizer.lr <= 0.0 || spec.optimizer.lr > 1.0 {
78        return Err(ValidationError::InvalidLearningRate(spec.optimizer.lr));
79    }
80    Ok(())
81}
82
83/// Validate optimizer name is supported
84fn validate_optimizer(spec: &TrainSpec) -> Result<(), ValidationError> {
85    const VALID_OPTIMIZERS: [&str; 6] = ["adam", "adamw", "sgd", "rmsprop", "adagrad", "lamb"];
86    if !VALID_OPTIMIZERS.contains(&spec.optimizer.name.as_str()) {
87        return Err(ValidationError::InvalidOptimizer(spec.optimizer.name.clone()));
88    }
89    Ok(())
90}
91
92/// Validate epochs is non-zero
93fn validate_epochs(spec: &TrainSpec) -> Result<(), ValidationError> {
94    if spec.training.epochs == 0 {
95        return Err(ValidationError::InvalidEpochs(spec.training.epochs));
96    }
97    Ok(())
98}
99
100/// Validate training parameters (grad_clip, seq_len, save_interval, lr_scheduler)
101fn validate_training_params(spec: &TrainSpec) -> Result<(), ValidationError> {
102    validate_grad_clip(spec)?;
103    validate_seq_len(spec)?;
104    validate_save_interval(spec)?;
105    validate_lr_scheduler(spec)?;
106    Ok(())
107}
108
109/// Validate gradient clipping value
110fn validate_grad_clip(spec: &TrainSpec) -> Result<(), ValidationError> {
111    if let Some(grad_clip) = spec.training.grad_clip {
112        if grad_clip <= 0.0 {
113            return Err(ValidationError::InvalidGradClip(grad_clip));
114        }
115    }
116    Ok(())
117}
118
119/// Validate sequence length if specified
120fn validate_seq_len(spec: &TrainSpec) -> Result<(), ValidationError> {
121    if let Some(seq_len) = spec.data.seq_len {
122        if seq_len == 0 {
123            return Err(ValidationError::InvalidSeqLen(seq_len));
124        }
125    }
126    Ok(())
127}
128
129/// Validate save interval
130fn validate_save_interval(spec: &TrainSpec) -> Result<(), ValidationError> {
131    if spec.training.save_interval == 0 {
132        return Err(ValidationError::InvalidSaveInterval(spec.training.save_interval));
133    }
134    Ok(())
135}
136
137/// Validate LR scheduler if specified
138fn validate_lr_scheduler(spec: &TrainSpec) -> Result<(), ValidationError> {
139    if let Some(scheduler) = &spec.training.lr_scheduler {
140        const VALID_SCHEDULERS: [&str; 7] =
141            ["cosine", "linear", "constant", "step", "exponential", "one_cycle", "plateau"];
142        if !VALID_SCHEDULERS.contains(&scheduler.as_str()) {
143            return Err(ValidationError::InvalidLRScheduler(scheduler.clone()));
144        }
145    }
146    Ok(())
147}
148
149/// Validate LoRA configuration if present
150fn validate_lora(spec: &TrainSpec) -> Result<(), ValidationError> {
151    let Some(lora) = &spec.lora else {
152        return Ok(());
153    };
154
155    validate_lora_rank(lora.rank)?;
156    validate_lora_alpha(lora.alpha)?;
157    validate_lora_dropout(lora.dropout)?;
158    validate_lora_targets(&lora.target_modules)?;
159    Ok(())
160}
161
162/// Validate LoRA rank (1-1024)
163fn validate_lora_rank(rank: usize) -> Result<(), ValidationError> {
164    if rank == 0 || rank > 1024 {
165        return Err(ValidationError::InvalidLoRARank(rank));
166    }
167    Ok(())
168}
169
170/// Validate LoRA alpha (must be positive)
171fn validate_lora_alpha(alpha: f32) -> Result<(), ValidationError> {
172    if alpha <= 0.0 {
173        return Err(ValidationError::InvalidLoRAAlpha(alpha));
174    }
175    Ok(())
176}
177
178/// Validate LoRA dropout (0.0 to <1.0)
179fn validate_lora_dropout(dropout: f32) -> Result<(), ValidationError> {
180    if !(0.0..1.0).contains(&dropout) {
181        return Err(ValidationError::InvalidLoRADropout(dropout));
182    }
183    Ok(())
184}
185
186/// Validate LoRA target modules are not empty
187fn validate_lora_targets(targets: &[String]) -> Result<(), ValidationError> {
188    if targets.is_empty() {
189        return Err(ValidationError::EmptyLoRATargets);
190    }
191    Ok(())
192}
193
194/// Validate quantization configuration if present
195fn validate_quantization(spec: &TrainSpec) -> Result<(), ValidationError> {
196    let Some(quant) = &spec.quantize else {
197        return Ok(());
198    };
199
200    if quant.bits != 4 && quant.bits != 8 {
201        return Err(ValidationError::InvalidQuantBits(quant.bits));
202    }
203    Ok(())
204}
205
206/// Validate merge configuration if present
207fn validate_merge(spec: &TrainSpec) -> Result<(), ValidationError> {
208    let Some(merge) = &spec.merge else {
209        return Ok(());
210    };
211
212    const VALID_METHODS: [&str; 3] = ["ties", "dare", "slerp"];
213    if !VALID_METHODS.contains(&merge.method.as_str()) {
214        return Err(ValidationError::InvalidMergeMethod(merge.method.clone()));
215    }
216    Ok(())
217}
218
219/// Validate publish configuration if present
220fn validate_publish(spec: &TrainSpec) -> Result<(), ValidationError> {
221    let Some(publish) = &spec.publish else {
222        return Ok(());
223    };
224
225    // Repo ID must contain exactly one `/`
226    let parts: Vec<&str> = publish.repo.split('/').collect();
227    if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
228        return Err(ValidationError::InvalidPublishRepo(publish.repo.clone()));
229    }
230
231    // Format must be safetensors or gguf
232    if publish.format != "safetensors" && publish.format != "gguf" {
233        return Err(ValidationError::InvalidPublishFormat(publish.format.clone()));
234    }
235
236    Ok(())
237}