Skip to main content

entrenar/yaml_mode/
validation.rs

1//! Manifest Validation (Poka-yoke)
2//!
3//! Schema validation catches errors at parse time, not runtime.
4//! Implements the Toyota Way's poka-yoke principle of defect prevention at source.
5
6use super::manifest::TrainingManifest;
7use thiserror::Error;
8
/// Validation result type
///
/// Shorthand for a `Result` whose error side is always [`ManifestError`];
/// every validator in this module returns it.
pub type ValidationResult<T> = Result<T, ManifestError>;
11
/// Manifest validation errors
///
/// The `#[error(...)]` attribute on each variant defines the message rendered
/// for it.
#[derive(Debug, Error)]
pub enum ManifestError {
    /// The manifest's `entrenar` version string is not in `SUPPORTED_VERSIONS`.
    /// The payload is the rejected version string.
    // NOTE(review): the message hard-codes "1.0" instead of reading
    // `SUPPORTED_VERSIONS`; keep the two in sync when a version is added.
    #[error("Unsupported entrenar version: {0}. Supported versions: 1.0")]
    UnsupportedVersion(String),

    /// A required field is empty. The payload is the field's dotted path
    /// (for example `name` or `lora.target_modules`).
    #[error("Empty required field: {0}")]
    EmptyRequiredField(String),

    /// A numeric field is outside its allowed range: `field` is the dotted
    /// path, `value` the offending value rendered as text, and `constraint` a
    /// human-readable description of the accepted range.
    #[error("Invalid range for {field}: {value} (expected {constraint})")]
    InvalidRange { field: String, value: String, constraint: String },

    /// Two fields that may not be combined were both specified.
    #[error("Mutually exclusive fields specified: {field1} and {field2}")]
    MutuallyExclusive { field1: String, field2: String },

    /// The data split ratios do not add up to 1.0; `sum` is the actual total.
    #[error("Invalid split ratios: sum is {sum} (expected 1.0)")]
    InvalidSplitRatios { sum: f64 },

    /// A quantization bit width is not in `VALID_QUANT_BITS`.
    #[error("Invalid quantization bits: {bits}. Valid values: 2, 4, 8")]
    InvalidQuantBits { bits: u8 },

    /// A cross-field dependency is not satisfied.
    // NOTE(review): not constructed anywhere in the visible part of this file;
    // presumably raised by callers elsewhere — confirm before removing.
    #[error("Dependency error: {0}")]
    DependencyError(String),

    /// The optimizer name is not recognised. The payload is a complete
    /// message that already lists the valid options.
    #[error("Invalid optimizer: {0}")]
    InvalidOptimizer(String),

    /// The scheduler name is not recognised. The payload is a complete
    /// message that already lists the valid options.
    #[error("Invalid scheduler: {0}")]
    InvalidScheduler(String),
}
42
/// Supported entrenar specification versions
///
/// Matched exactly (no trimming, no case folding) against the manifest's
/// `entrenar` field.
const SUPPORTED_VERSIONS: &[&str] = &["1.0"];

/// Valid optimizer names
///
/// All entries are lowercase; the name from the manifest is lowercased before
/// it is looked up here.
const VALID_OPTIMIZERS: &[&str] = &["sgd", "adam", "adamw", "rmsprop", "adagrad", "lamb"];

/// Valid scheduler names
///
/// All entries are lowercase; the name from the manifest is lowercased before
/// it is looked up here.
const VALID_SCHEDULERS: &[&str] =
    &["step", "cosine", "cosine_annealing", "linear", "exponential", "plateau", "one_cycle"];

/// Valid quantization bit widths
///
/// Shared by the LoRA `quantize_bits` check and the top-level quantization check.
const VALID_QUANT_BITS: &[u8] = &[2, 4, 8];
55
56/// Validate a training manifest
57///
58/// Performs comprehensive validation including:
59/// 1. Version compatibility
60/// 2. Required fields presence
61/// 3. Type constraints
62/// 4. Range constraints
63/// 5. Mutual exclusivity
64/// 6. Cross-field dependencies
65pub fn validate_manifest(manifest: &TrainingManifest) -> ValidationResult<()> {
66    // 1. Version validation
67    validate_version(&manifest.entrenar)?;
68
69    // 2. Required field validation
70    validate_required_fields(manifest)?;
71
72    // 3. Optimizer validation
73    if let Some(ref optim) = manifest.optimizer {
74        validate_optimizer(optim)?;
75    }
76
77    // 4. Scheduler validation
78    if let Some(ref sched) = manifest.scheduler {
79        validate_scheduler(sched)?;
80    }
81
82    // 5. Training config validation
83    if let Some(ref training) = manifest.training {
84        validate_training(training)?;
85    }
86
87    // 6. Data config validation
88    if let Some(ref data) = manifest.data {
89        validate_data(data)?;
90    }
91
92    // 7. LoRA validation
93    if let Some(ref lora) = manifest.lora {
94        validate_lora(lora)?;
95    }
96
97    // 8. Quantization validation
98    if let Some(ref quant) = manifest.quantize {
99        validate_quantize(quant)?;
100    }
101
102    Ok(())
103}
104
105/// Validate specification version
106fn validate_version(version: &str) -> ValidationResult<()> {
107    if !SUPPORTED_VERSIONS.contains(&version) {
108        return Err(ManifestError::UnsupportedVersion(version.to_string()));
109    }
110    Ok(())
111}
112
113/// Validate required fields
114fn validate_required_fields(manifest: &TrainingManifest) -> ValidationResult<()> {
115    if manifest.name.is_empty() {
116        return Err(ManifestError::EmptyRequiredField("name".to_string()));
117    }
118
119    if manifest.version.is_empty() {
120        return Err(ManifestError::EmptyRequiredField("version".to_string()));
121    }
122
123    Ok(())
124}
125
126// ---------------------------------------------------------------------------
127// Shared range-check helpers (reduce nesting in callers)
128// ---------------------------------------------------------------------------
129
130/// Validate that a required f64 is strictly positive
131fn validate_positive_f64(value: f64, field: &str, constraint: &str) -> ValidationResult<()> {
132    if value <= 0.0 {
133        return Err(ManifestError::InvalidRange {
134            field: field.to_string(),
135            value: value.to_string(),
136            constraint: constraint.to_string(),
137        });
138    }
139    Ok(())
140}
141
142/// Validate that an optional usize, if present, is non-zero (>= 1)
143fn validate_nonzero_usize(value: Option<usize>, field: &str) -> ValidationResult<()> {
144    if let Some(v) = value {
145        if v == 0 {
146            return Err(ManifestError::InvalidRange {
147                field: field.to_string(),
148                value: v.to_string(),
149                constraint: ">= 1".to_string(),
150            });
151        }
152    }
153    Ok(())
154}
155
156/// Validate that an optional f64, if present, is non-negative (>= 0)
157fn validate_nonneg_f64(value: Option<f64>, field: &str) -> ValidationResult<()> {
158    if let Some(v) = value {
159        if v < 0.0 {
160            return Err(ManifestError::InvalidRange {
161                field: field.to_string(),
162                value: v.to_string(),
163                constraint: ">= 0".to_string(),
164            });
165        }
166    }
167    Ok(())
168}
169
170/// Validate that an optional f64, if present, lies within the half-open range [0, 1)
171fn validate_dropout_range(value: Option<f64>, field: &str) -> ValidationResult<()> {
172    if let Some(v) = value {
173        if !(0.0..1.0).contains(&v) {
174            return Err(ManifestError::InvalidRange {
175                field: field.to_string(),
176                value: v.to_string(),
177                constraint: "in [0, 1)".to_string(),
178            });
179        }
180    }
181    Ok(())
182}
183
184/// Validate that an optional u8, if present, is a valid quantization bit width
185fn validate_quant_bits(bits: Option<u8>) -> ValidationResult<()> {
186    if let Some(b) = bits {
187        if !VALID_QUANT_BITS.contains(&b) {
188            return Err(ManifestError::InvalidQuantBits { bits: b });
189        }
190    }
191    Ok(())
192}
193
194// ---------------------------------------------------------------------------
195// Optimizer validation
196// ---------------------------------------------------------------------------
197
198/// Validate optimizer configuration
199fn validate_optimizer(optim: &super::manifest::OptimizerConfig) -> ValidationResult<()> {
200    validate_optimizer_name(&optim.name)?;
201    validate_positive_f64(optim.lr, "optimizer.lr", "> 0")?;
202    validate_nonneg_f64(optim.weight_decay, "optimizer.weight_decay")?;
203    validate_optimizer_betas(optim.betas.as_deref())?;
204    Ok(())
205}
206
207/// Validate optimizer name against the allow-list
208fn validate_optimizer_name(name: &str) -> ValidationResult<()> {
209    let name_lower = name.to_lowercase();
210    if !VALID_OPTIMIZERS.contains(&name_lower.as_str()) {
211        return Err(ManifestError::InvalidOptimizer(format!(
212            "Unknown optimizer '{name}'. Valid options: {VALID_OPTIMIZERS:?}",
213        )));
214    }
215    Ok(())
216}
217
218/// Validate that each beta value is in the open interval (0, 1)
219fn validate_optimizer_betas(betas: Option<&[f64]>) -> ValidationResult<()> {
220    let Some(betas) = betas else {
221        return Ok(());
222    };
223    for (i, beta) in betas.iter().enumerate() {
224        if *beta <= 0.0 || *beta >= 1.0 {
225            return Err(ManifestError::InvalidRange {
226                field: format!("optimizer.betas[{i}]"),
227                value: beta.to_string(),
228                constraint: "in (0, 1)".to_string(),
229            });
230        }
231    }
232    Ok(())
233}
234
235// ---------------------------------------------------------------------------
236// Scheduler validation
237// ---------------------------------------------------------------------------
238
239/// Validate scheduler configuration
240fn validate_scheduler(sched: &super::manifest::SchedulerConfig) -> ValidationResult<()> {
241    let name_lower = sched.name.to_lowercase();
242    if !VALID_SCHEDULERS.contains(&name_lower.as_str()) {
243        return Err(ManifestError::InvalidScheduler(format!(
244            "Unknown scheduler '{}'. Valid options: {:?}",
245            sched.name, VALID_SCHEDULERS
246        )));
247    }
248
249    Ok(())
250}
251
252// ---------------------------------------------------------------------------
253// Training validation
254// ---------------------------------------------------------------------------
255
256/// Validate training configuration
257fn validate_training(training: &super::manifest::TrainingConfig) -> ValidationResult<()> {
258    validate_duration_exclusivity(training)?;
259    validate_nonzero_usize(training.epochs, "training.epochs")?;
260    validate_gradient_config(training.gradient.as_ref())?;
261    Ok(())
262}
263
264/// Ensure at most one of epochs / max_steps / duration is specified
265fn validate_duration_exclusivity(
266    training: &super::manifest::TrainingConfig,
267) -> ValidationResult<()> {
268    let has_epochs = training.epochs.is_some();
269    let has_max_steps = training.max_steps.is_some();
270    let has_duration = training.duration.is_some();
271
272    if let Some((f1, f2)) = first_duration_conflict(has_epochs, has_max_steps, has_duration) {
273        return Err(ManifestError::MutuallyExclusive {
274            field1: f1.to_string(),
275            field2: f2.to_string(),
276        });
277    }
278    Ok(())
279}
280
/// Return the first pair of conflicting duration fields, if any.
///
/// Pairs are considered in a fixed priority order: epochs/max_steps, then
/// epochs/duration, then max_steps/duration. When all three flags are set,
/// only the highest-priority pair is reported.
fn first_duration_conflict(
    has_epochs: bool,
    has_max_steps: bool,
    has_duration: bool,
) -> Option<(&'static str, &'static str)> {
    match (has_epochs, has_max_steps, has_duration) {
        (true, true, _) => Some(("training.epochs", "training.max_steps")),
        (true, false, true) => Some(("training.epochs", "training.duration")),
        (false, true, true) => Some(("training.max_steps", "training.duration")),
        // Zero or one field present: nothing can conflict.
        _ => None,
    }
}
298
299/// Validate gradient accumulation steps if present
300fn validate_gradient_config(
301    gradient: Option<&super::manifest::GradientConfig>,
302) -> ValidationResult<()> {
303    let Some(grad) = gradient else {
304        return Ok(());
305    };
306    validate_nonzero_usize(grad.accumulation_steps, "training.gradient.accumulation_steps")
307}
308
309// ---------------------------------------------------------------------------
310// Data validation
311// ---------------------------------------------------------------------------
312
313/// Validate data configuration
314fn validate_data(data: &super::manifest::DataConfig) -> ValidationResult<()> {
315    validate_loader_batch_size(data.loader.as_ref())?;
316    validate_split_ratios(data.split.as_ref())
317}
318
319/// Validate that loader batch_size > 0
320fn validate_loader_batch_size(
321    loader: Option<&super::manifest::DataLoader>,
322) -> ValidationResult<()> {
323    let Some(loader) = loader else {
324        return Ok(());
325    };
326    if loader.batch_size == 0 {
327        return Err(ManifestError::InvalidRange {
328            field: "data.loader.batch_size".to_string(),
329            value: "0".to_string(),
330            constraint: ">= 1".to_string(),
331        });
332    }
333    Ok(())
334}
335
336/// Validate data split ratios sum to 1.0 and train ratio is in [0, 1]
337fn validate_split_ratios(split: Option<&super::manifest::DataSplit>) -> ValidationResult<()> {
338    let Some(split) = split else {
339        return Ok(());
340    };
341
342    let sum = split.train + split.val.unwrap_or(0.0) + split.test.unwrap_or(0.0);
343
344    // Allow small tolerance for floating point
345    if (sum - 1.0).abs() > 0.001 {
346        return Err(ManifestError::InvalidSplitRatios { sum });
347    }
348
349    // Validate individual ratios in [0, 1]
350    if split.train < 0.0 || split.train > 1.0 {
351        return Err(ManifestError::InvalidRange {
352            field: "data.split.train".to_string(),
353            value: split.train.to_string(),
354            constraint: "in [0, 1]".to_string(),
355        });
356    }
357    Ok(())
358}
359
360// ---------------------------------------------------------------------------
361// LoRA validation
362// ---------------------------------------------------------------------------
363
364/// Validate LoRA configuration
365fn validate_lora(lora: &super::manifest::LoraConfig) -> ValidationResult<()> {
366    // Only validate if enabled
367    if !lora.enabled {
368        return Ok(());
369    }
370
371    validate_lora_target_modules(lora)?;
372    validate_lora_rank(lora.rank)?;
373    validate_positive_f64(lora.alpha, "lora.alpha", "> 0")?;
374    validate_dropout_range(lora.dropout, "lora.dropout")?;
375    validate_quant_bits(lora.quantize_bits)
376}
377
378/// Validate that at least one of target_modules or target_modules_pattern is provided
379fn validate_lora_target_modules(lora: &super::manifest::LoraConfig) -> ValidationResult<()> {
380    if lora.target_modules.is_empty() && lora.target_modules_pattern.is_none() {
381        return Err(ManifestError::EmptyRequiredField("lora.target_modules".to_string()));
382    }
383    Ok(())
384}
385
386/// Validate LoRA rank is at least 1
387fn validate_lora_rank(rank: usize) -> ValidationResult<()> {
388    if rank == 0 {
389        return Err(ManifestError::InvalidRange {
390            field: "lora.rank".to_string(),
391            value: "0".to_string(),
392            constraint: ">= 1".to_string(),
393        });
394    }
395    Ok(())
396}
397
398// ---------------------------------------------------------------------------
399// Quantization validation
400// ---------------------------------------------------------------------------
401
402/// Validate quantization configuration
403fn validate_quantize(quant: &super::manifest::QuantizeConfig) -> ValidationResult<()> {
404    // Only validate if enabled
405    if !quant.enabled {
406        return Ok(());
407    }
408
409    // Validate bits
410    if !VALID_QUANT_BITS.contains(&quant.bits) {
411        return Err(ManifestError::InvalidQuantBits { bits: quant.bits });
412    }
413
414    Ok(())
415}
416
#[cfg(test)]
mod tests {
    use super::super::manifest::{OptimizerConfig, QuantizeConfig};
    use super::*;

    /// Build an optimizer config with the given name, a valid learning rate,
    /// and every optional field unset.
    fn optimizer_named(name: &str) -> OptimizerConfig {
        OptimizerConfig {
            name: name.to_string(),
            lr: 0.001,
            weight_decay: None,
            betas: None,
            eps: None,
            amsgrad: None,
            momentum: None,
            nesterov: None,
            dampening: None,
            alpha: None,
            centered: None,
            param_groups: None,
        }
    }

    /// Build an enabled quantization config with the given bit width and
    /// every optional field unset.
    fn quantize_with_bits(bits: u8) -> QuantizeConfig {
        QuantizeConfig {
            enabled: true,
            bits,
            scheme: None,
            granularity: None,
            group_size: None,
            qat: None,
            calibration: None,
            exclude: None,
        }
    }

    /// The only supported version passes; any other is rejected.
    #[test]
    fn test_validate_version() {
        assert!(validate_version("1.0").is_ok());
        assert!(validate_version("2.0").is_err());
    }

    /// Every name in the allow-list validates with otherwise default settings.
    #[test]
    fn test_valid_optimizers() {
        for opt in VALID_OPTIMIZERS {
            let optim = optimizer_named(opt);
            assert!(validate_optimizer(&optim).is_ok(), "Optimizer {opt} should be valid");
        }
    }

    /// Every width in the allow-list validates for an enabled config.
    #[test]
    fn test_valid_quant_bits() {
        for bits in VALID_QUANT_BITS {
            let quant = quantize_with_bits(*bits);
            assert!(validate_quantize(&quant).is_ok(), "Quant bits {bits} should be valid");
        }
    }
}