1use super::manifest::TrainingManifest;
7use thiserror::Error;
8
9pub type ValidationResult<T> = Result<T, ManifestError>;
11
12#[derive(Debug, Error)]
14pub enum ManifestError {
15 #[error("Unsupported entrenar version: {0}. Supported versions: 1.0")]
16 UnsupportedVersion(String),
17
18 #[error("Empty required field: {0}")]
19 EmptyRequiredField(String),
20
21 #[error("Invalid range for {field}: {value} (expected {constraint})")]
22 InvalidRange { field: String, value: String, constraint: String },
23
24 #[error("Mutually exclusive fields specified: {field1} and {field2}")]
25 MutuallyExclusive { field1: String, field2: String },
26
27 #[error("Invalid split ratios: sum is {sum} (expected 1.0)")]
28 InvalidSplitRatios { sum: f64 },
29
30 #[error("Invalid quantization bits: {bits}. Valid values: 2, 4, 8")]
31 InvalidQuantBits { bits: u8 },
32
33 #[error("Dependency error: {0}")]
34 DependencyError(String),
35
36 #[error("Invalid optimizer: {0}")]
37 InvalidOptimizer(String),
38
39 #[error("Invalid scheduler: {0}")]
40 InvalidScheduler(String),
41}
42
/// Values of the manifest's `entrenar` field accepted by `validate_version`.
const SUPPORTED_VERSIONS: &[&str] = &["1.0"];

/// Optimizer names accepted by `validate_optimizer_name`, which lowercases its
/// input before looking it up here.
const VALID_OPTIMIZERS: &[&str] = &["sgd", "adam", "adamw", "rmsprop", "adagrad", "lamb"];

/// Scheduler names accepted by `validate_scheduler`, which lowercases the
/// configured name before looking it up here.
const VALID_SCHEDULERS: &[&str] =
    &["step", "cosine", "cosine_annealing", "linear", "exponential", "plateau", "one_cycle"];

/// Bit widths accepted by `validate_quant_bits` and `validate_quantize`.
const VALID_QUANT_BITS: &[u8] = &[2, 4, 8];
55
56pub fn validate_manifest(manifest: &TrainingManifest) -> ValidationResult<()> {
66 validate_version(&manifest.entrenar)?;
68
69 validate_required_fields(manifest)?;
71
72 if let Some(ref optim) = manifest.optimizer {
74 validate_optimizer(optim)?;
75 }
76
77 if let Some(ref sched) = manifest.scheduler {
79 validate_scheduler(sched)?;
80 }
81
82 if let Some(ref training) = manifest.training {
84 validate_training(training)?;
85 }
86
87 if let Some(ref data) = manifest.data {
89 validate_data(data)?;
90 }
91
92 if let Some(ref lora) = manifest.lora {
94 validate_lora(lora)?;
95 }
96
97 if let Some(ref quant) = manifest.quantize {
99 validate_quantize(quant)?;
100 }
101
102 Ok(())
103}
104
105fn validate_version(version: &str) -> ValidationResult<()> {
107 if !SUPPORTED_VERSIONS.contains(&version) {
108 return Err(ManifestError::UnsupportedVersion(version.to_string()));
109 }
110 Ok(())
111}
112
113fn validate_required_fields(manifest: &TrainingManifest) -> ValidationResult<()> {
115 if manifest.name.is_empty() {
116 return Err(ManifestError::EmptyRequiredField("name".to_string()));
117 }
118
119 if manifest.version.is_empty() {
120 return Err(ManifestError::EmptyRequiredField("version".to_string()));
121 }
122
123 Ok(())
124}
125
126fn validate_positive_f64(value: f64, field: &str, constraint: &str) -> ValidationResult<()> {
132 if value <= 0.0 {
133 return Err(ManifestError::InvalidRange {
134 field: field.to_string(),
135 value: value.to_string(),
136 constraint: constraint.to_string(),
137 });
138 }
139 Ok(())
140}
141
142fn validate_nonzero_usize(value: Option<usize>, field: &str) -> ValidationResult<()> {
144 if let Some(v) = value {
145 if v == 0 {
146 return Err(ManifestError::InvalidRange {
147 field: field.to_string(),
148 value: v.to_string(),
149 constraint: ">= 1".to_string(),
150 });
151 }
152 }
153 Ok(())
154}
155
156fn validate_nonneg_f64(value: Option<f64>, field: &str) -> ValidationResult<()> {
158 if let Some(v) = value {
159 if v < 0.0 {
160 return Err(ManifestError::InvalidRange {
161 field: field.to_string(),
162 value: v.to_string(),
163 constraint: ">= 0".to_string(),
164 });
165 }
166 }
167 Ok(())
168}
169
170fn validate_dropout_range(value: Option<f64>, field: &str) -> ValidationResult<()> {
172 if let Some(v) = value {
173 if !(0.0..1.0).contains(&v) {
174 return Err(ManifestError::InvalidRange {
175 field: field.to_string(),
176 value: v.to_string(),
177 constraint: "in [0, 1)".to_string(),
178 });
179 }
180 }
181 Ok(())
182}
183
184fn validate_quant_bits(bits: Option<u8>) -> ValidationResult<()> {
186 if let Some(b) = bits {
187 if !VALID_QUANT_BITS.contains(&b) {
188 return Err(ManifestError::InvalidQuantBits { bits: b });
189 }
190 }
191 Ok(())
192}
193
194fn validate_optimizer(optim: &super::manifest::OptimizerConfig) -> ValidationResult<()> {
200 validate_optimizer_name(&optim.name)?;
201 validate_positive_f64(optim.lr, "optimizer.lr", "> 0")?;
202 validate_nonneg_f64(optim.weight_decay, "optimizer.weight_decay")?;
203 validate_optimizer_betas(optim.betas.as_deref())?;
204 Ok(())
205}
206
207fn validate_optimizer_name(name: &str) -> ValidationResult<()> {
209 let name_lower = name.to_lowercase();
210 if !VALID_OPTIMIZERS.contains(&name_lower.as_str()) {
211 return Err(ManifestError::InvalidOptimizer(format!(
212 "Unknown optimizer '{name}'. Valid options: {VALID_OPTIMIZERS:?}",
213 )));
214 }
215 Ok(())
216}
217
218fn validate_optimizer_betas(betas: Option<&[f64]>) -> ValidationResult<()> {
220 let Some(betas) = betas else {
221 return Ok(());
222 };
223 for (i, beta) in betas.iter().enumerate() {
224 if *beta <= 0.0 || *beta >= 1.0 {
225 return Err(ManifestError::InvalidRange {
226 field: format!("optimizer.betas[{i}]"),
227 value: beta.to_string(),
228 constraint: "in (0, 1)".to_string(),
229 });
230 }
231 }
232 Ok(())
233}
234
235fn validate_scheduler(sched: &super::manifest::SchedulerConfig) -> ValidationResult<()> {
241 let name_lower = sched.name.to_lowercase();
242 if !VALID_SCHEDULERS.contains(&name_lower.as_str()) {
243 return Err(ManifestError::InvalidScheduler(format!(
244 "Unknown scheduler '{}'. Valid options: {:?}",
245 sched.name, VALID_SCHEDULERS
246 )));
247 }
248
249 Ok(())
250}
251
252fn validate_training(training: &super::manifest::TrainingConfig) -> ValidationResult<()> {
258 validate_duration_exclusivity(training)?;
259 validate_nonzero_usize(training.epochs, "training.epochs")?;
260 validate_gradient_config(training.gradient.as_ref())?;
261 Ok(())
262}
263
264fn validate_duration_exclusivity(
266 training: &super::manifest::TrainingConfig,
267) -> ValidationResult<()> {
268 let has_epochs = training.epochs.is_some();
269 let has_max_steps = training.max_steps.is_some();
270 let has_duration = training.duration.is_some();
271
272 if let Some((f1, f2)) = first_duration_conflict(has_epochs, has_max_steps, has_duration) {
273 return Err(ManifestError::MutuallyExclusive {
274 field1: f1.to_string(),
275 field2: f2.to_string(),
276 });
277 }
278 Ok(())
279}
280
/// Picks the first pair of duration fields that are both set.
///
/// Pairs are considered in a fixed priority order: epochs/max_steps, then
/// epochs/duration, then max_steps/duration. Returns `None` when at most one
/// of the three flags is `true`.
fn first_duration_conflict(
    has_epochs: bool,
    has_max_steps: bool,
    has_duration: bool,
) -> Option<(&'static str, &'static str)> {
    match (has_epochs, has_max_steps, has_duration) {
        (true, true, _) => Some(("training.epochs", "training.max_steps")),
        (true, false, true) => Some(("training.epochs", "training.duration")),
        (false, true, true) => Some(("training.max_steps", "training.duration")),
        _ => None,
    }
}
298
299fn validate_gradient_config(
301 gradient: Option<&super::manifest::GradientConfig>,
302) -> ValidationResult<()> {
303 let Some(grad) = gradient else {
304 return Ok(());
305 };
306 validate_nonzero_usize(grad.accumulation_steps, "training.gradient.accumulation_steps")
307}
308
309fn validate_data(data: &super::manifest::DataConfig) -> ValidationResult<()> {
315 validate_loader_batch_size(data.loader.as_ref())?;
316 validate_split_ratios(data.split.as_ref())
317}
318
319fn validate_loader_batch_size(
321 loader: Option<&super::manifest::DataLoader>,
322) -> ValidationResult<()> {
323 let Some(loader) = loader else {
324 return Ok(());
325 };
326 if loader.batch_size == 0 {
327 return Err(ManifestError::InvalidRange {
328 field: "data.loader.batch_size".to_string(),
329 value: "0".to_string(),
330 constraint: ">= 1".to_string(),
331 });
332 }
333 Ok(())
334}
335
336fn validate_split_ratios(split: Option<&super::manifest::DataSplit>) -> ValidationResult<()> {
338 let Some(split) = split else {
339 return Ok(());
340 };
341
342 let sum = split.train + split.val.unwrap_or(0.0) + split.test.unwrap_or(0.0);
343
344 if (sum - 1.0).abs() > 0.001 {
346 return Err(ManifestError::InvalidSplitRatios { sum });
347 }
348
349 if split.train < 0.0 || split.train > 1.0 {
351 return Err(ManifestError::InvalidRange {
352 field: "data.split.train".to_string(),
353 value: split.train.to_string(),
354 constraint: "in [0, 1]".to_string(),
355 });
356 }
357 Ok(())
358}
359
360fn validate_lora(lora: &super::manifest::LoraConfig) -> ValidationResult<()> {
366 if !lora.enabled {
368 return Ok(());
369 }
370
371 validate_lora_target_modules(lora)?;
372 validate_lora_rank(lora.rank)?;
373 validate_positive_f64(lora.alpha, "lora.alpha", "> 0")?;
374 validate_dropout_range(lora.dropout, "lora.dropout")?;
375 validate_quant_bits(lora.quantize_bits)
376}
377
378fn validate_lora_target_modules(lora: &super::manifest::LoraConfig) -> ValidationResult<()> {
380 if lora.target_modules.is_empty() && lora.target_modules_pattern.is_none() {
381 return Err(ManifestError::EmptyRequiredField("lora.target_modules".to_string()));
382 }
383 Ok(())
384}
385
386fn validate_lora_rank(rank: usize) -> ValidationResult<()> {
388 if rank == 0 {
389 return Err(ManifestError::InvalidRange {
390 field: "lora.rank".to_string(),
391 value: "0".to_string(),
392 constraint: ">= 1".to_string(),
393 });
394 }
395 Ok(())
396}
397
398fn validate_quantize(quant: &super::manifest::QuantizeConfig) -> ValidationResult<()> {
404 if !quant.enabled {
406 return Ok(());
407 }
408
409 if !VALID_QUANT_BITS.contains(&quant.bits) {
411 return Err(ManifestError::InvalidQuantBits { bits: quant.bits });
412 }
413
414 Ok(())
415}
416
#[cfg(test)]
mod tests {
    use super::super::manifest::{OptimizerConfig, QuantizeConfig};
    use super::*;

    /// The supported version passes and an unknown one is rejected.
    #[test]
    fn test_validate_version() {
        let accepted = validate_version("1.0");
        let rejected = validate_version("2.0");

        assert!(accepted.is_ok());
        assert!(rejected.is_err());
    }

    /// Every name in `VALID_OPTIMIZERS` validates with an otherwise minimal
    /// configuration.
    #[test]
    fn test_valid_optimizers() {
        for &opt in VALID_OPTIMIZERS {
            let optim = OptimizerConfig {
                name: opt.to_owned(),
                lr: 0.001,
                weight_decay: None,
                betas: None,
                eps: None,
                amsgrad: None,
                momentum: None,
                nesterov: None,
                dampening: None,
                alpha: None,
                centered: None,
                param_groups: None,
            };

            let outcome = validate_optimizer(&optim);
            assert!(outcome.is_ok(), "Optimizer {opt} should be valid");
        }
    }

    /// Every width in `VALID_QUANT_BITS` validates for an enabled section.
    #[test]
    fn test_valid_quant_bits() {
        for &bits in VALID_QUANT_BITS {
            let quant = QuantizeConfig {
                enabled: true,
                bits,
                scheme: None,
                granularity: None,
                group_size: None,
                qat: None,
                calibration: None,
                exclude: None,
            };

            let outcome = validate_quantize(&quant);
            assert!(outcome.is_ok(), "Quant bits {bits} should be valid");
        }
    }
}