use super::error::ValidationError;
use crate::config::schema::TrainSpec;
pub fn validate_config(spec: &TrainSpec) -> Result<(), ValidationError> {
validate_model_path(spec)?;
validate_data_paths(spec)?;
validate_batch_size(spec)?;
validate_learning_rate(spec)?;
validate_optimizer(spec)?;
validate_epochs(spec)?;
validate_training_params(spec)?;
validate_lora(spec)?;
validate_quantization(spec)?;
validate_merge(spec)?;
validate_publish(spec)?;
Ok(())
}
/// Checks that the configured model source is resolvable: Hugging Face
/// repo IDs are accepted as-is, local paths must exist on disk.
#[cfg(not(test))]
fn validate_model_path(spec: &TrainSpec) -> Result<(), ValidationError> {
    // A remote repo ID has no local path to verify; it is resolved later.
    if spec.model.is_hf_repo_id() || spec.model.path.exists() {
        Ok(())
    } else {
        Err(ValidationError::ModelPathNotFound(
            spec.model.path.display().to_string(),
        ))
    }
}
// Test-build stand-in: skips the filesystem/repo-ID check so unit tests can
// use fabricated model paths without touching disk.
#[cfg(test)]
fn validate_model_path(_spec: &TrainSpec) -> Result<(), ValidationError> {
Ok(())
}
/// Verifies that the training dataset exists on disk, and — when one is
/// configured — that the validation dataset does too.
#[cfg(not(test))]
fn validate_data_paths(spec: &TrainSpec) -> Result<(), ValidationError> {
    let train = &spec.data.train;
    if !train.exists() {
        return Err(ValidationError::TrainDataNotFound(train.display().to_string()));
    }
    // Validation data is optional; only a configured-but-missing path is an error.
    match spec.data.val.as_ref().filter(|p| !p.exists()) {
        Some(missing) => Err(ValidationError::ValDataNotFound(missing.display().to_string())),
        None => Ok(()),
    }
}
// Test-build stand-in: skips the existence checks so unit tests can use
// fabricated dataset paths without touching disk.
#[cfg(test)]
fn validate_data_paths(_spec: &TrainSpec) -> Result<(), ValidationError> {
Ok(())
}
/// Rejects a batch size of zero, which would yield no batches at all.
fn validate_batch_size(spec: &TrainSpec) -> Result<(), ValidationError> {
    if spec.data.batch_size != 0 {
        Ok(())
    } else {
        Err(ValidationError::InvalidBatchSize(spec.data.batch_size))
    }
}
/// Validates that the learning rate lies in the half-open interval (0.0, 1.0].
///
/// Written as a positive-form check (`lr > 0.0 && lr <= 1.0`) rather than the
/// negated `lr <= 0.0 || lr > 1.0`: NaN fails every ordered comparison, so the
/// negated form silently *accepted* a NaN learning rate. The positive form
/// rejects NaN along with non-positive and >1.0 values.
///
/// # Errors
/// `ValidationError::InvalidLearningRate` for lr <= 0.0, lr > 1.0, or NaN.
fn validate_learning_rate(spec: &TrainSpec) -> Result<(), ValidationError> {
    let lr = spec.optimizer.lr;
    if lr > 0.0 && lr <= 1.0 {
        Ok(())
    } else {
        Err(ValidationError::InvalidLearningRate(lr))
    }
}
/// Checks the optimizer name against the set of supported identifiers.
fn validate_optimizer(spec: &TrainSpec) -> Result<(), ValidationError> {
    // Whitelist of optimizer names the trainer understands.
    const VALID_OPTIMIZERS: [&str; 6] = ["adam", "adamw", "sgd", "rmsprop", "adagrad", "lamb"];
    let name = spec.optimizer.name.as_str();
    if VALID_OPTIMIZERS.iter().any(|&valid| valid == name) {
        Ok(())
    } else {
        Err(ValidationError::InvalidOptimizer(name.to_owned()))
    }
}
/// Rejects an epoch count of zero, which would mean no training runs.
fn validate_epochs(spec: &TrainSpec) -> Result<(), ValidationError> {
    match spec.training.epochs {
        0 => Err(ValidationError::InvalidEpochs(spec.training.epochs)),
        _ => Ok(()),
    }
}
/// Validates the numeric training knobs (grad clip, sequence length,
/// checkpoint interval, LR scheduler name), stopping at the first failure.
fn validate_training_params(spec: &TrainSpec) -> Result<(), ValidationError> {
    validate_grad_clip(spec)
        .and_then(|()| validate_seq_len(spec))
        .and_then(|()| validate_save_interval(spec))
        .and_then(|()| validate_lr_scheduler(spec))
}
/// Validates that gradient clipping, when configured, is strictly positive.
///
/// Uses a positive-form guard (`gc > 0.0`) instead of the original
/// `grad_clip <= 0.0`: NaN fails every ordered comparison, so the old check
/// silently accepted a NaN clip value. The guard form sends NaN (and any
/// non-positive value) to the error arm.
///
/// # Errors
/// `ValidationError::InvalidGradClip` for grad_clip <= 0.0 or NaN.
fn validate_grad_clip(spec: &TrainSpec) -> Result<(), ValidationError> {
    match spec.training.grad_clip {
        // None means clipping is disabled — nothing to validate.
        None => Ok(()),
        Some(gc) if gc > 0.0 => Ok(()),
        Some(gc) => Err(ValidationError::InvalidGradClip(gc)),
    }
}
/// Rejects an explicitly configured sequence length of zero; an absent
/// value is fine.
fn validate_seq_len(spec: &TrainSpec) -> Result<(), ValidationError> {
    match spec.data.seq_len {
        Some(len) if len == 0 => Err(ValidationError::InvalidSeqLen(len)),
        _ => Ok(()),
    }
}
/// Rejects a checkpoint save interval of zero.
fn validate_save_interval(spec: &TrainSpec) -> Result<(), ValidationError> {
    let interval = spec.training.save_interval;
    if interval == 0 {
        Err(ValidationError::InvalidSaveInterval(interval))
    } else {
        Ok(())
    }
}
/// Checks the LR scheduler name, when configured, against the supported set.
fn validate_lr_scheduler(spec: &TrainSpec) -> Result<(), ValidationError> {
    // No scheduler configured is valid — constant LR is the implicit default here.
    let Some(scheduler) = &spec.training.lr_scheduler else {
        return Ok(());
    };
    const VALID_SCHEDULERS: [&str; 7] =
        ["cosine", "linear", "constant", "step", "exponential", "one_cycle", "plateau"];
    if VALID_SCHEDULERS.iter().any(|&s| s == scheduler.as_str()) {
        Ok(())
    } else {
        Err(ValidationError::InvalidLRScheduler(scheduler.clone()))
    }
}
/// Validates the LoRA section when present; an absent section is valid.
fn validate_lora(spec: &TrainSpec) -> Result<(), ValidationError> {
    match &spec.lora {
        None => Ok(()),
        Some(lora) => {
            validate_lora_rank(lora.rank)?;
            validate_lora_alpha(lora.alpha)?;
            validate_lora_dropout(lora.dropout)?;
            validate_lora_targets(&lora.target_modules)
        }
    }
}
/// Validates that the LoRA rank lies in the accepted range 1..=1024.
fn validate_lora_rank(rank: usize) -> Result<(), ValidationError> {
    if (1..=1024).contains(&rank) {
        Ok(())
    } else {
        Err(ValidationError::InvalidLoRARank(rank))
    }
}
/// Validates that the LoRA alpha scaling factor is strictly positive.
///
/// Written as a positive-form check (`alpha > 0.0` → Ok, everything else →
/// Err) rather than the original `alpha <= 0.0` rejection: NaN fails every
/// ordered comparison, so the old form silently accepted a NaN alpha. Here
/// NaN falls through to the error branch.
///
/// # Errors
/// `ValidationError::InvalidLoRAAlpha` for alpha <= 0.0 or NaN.
fn validate_lora_alpha(alpha: f32) -> Result<(), ValidationError> {
    if alpha > 0.0 {
        Ok(())
    } else {
        Err(ValidationError::InvalidLoRAAlpha(alpha))
    }
}
/// Validates that LoRA dropout lies in the half-open interval [0.0, 1.0).
/// (NaN fails both comparisons and is rejected, matching the original
/// `Range::contains` behavior.)
fn validate_lora_dropout(dropout: f32) -> Result<(), ValidationError> {
    if dropout >= 0.0 && dropout < 1.0 {
        Ok(())
    } else {
        Err(ValidationError::InvalidLoRADropout(dropout))
    }
}
/// Rejects an empty LoRA target-module list.
fn validate_lora_targets(targets: &[String]) -> Result<(), ValidationError> {
    match targets {
        [] => Err(ValidationError::EmptyLoRATargets),
        _ => Ok(()),
    }
}
/// Validates the quantization section when present; only 4- and 8-bit
/// widths are accepted. An absent section is valid.
fn validate_quantization(spec: &TrainSpec) -> Result<(), ValidationError> {
    let Some(quant) = &spec.quantize else {
        return Ok(());
    };
    match quant.bits {
        4 | 8 => Ok(()),
        bits => Err(ValidationError::InvalidQuantBits(bits)),
    }
}
/// Validates the merge section when present: the method must be one of the
/// supported merge algorithms. An absent section is valid.
fn validate_merge(spec: &TrainSpec) -> Result<(), ValidationError> {
    const VALID_METHODS: [&str; 3] = ["ties", "dare", "slerp"];
    match &spec.merge {
        None => Ok(()),
        Some(merge) if VALID_METHODS.contains(&merge.method.as_str()) => Ok(()),
        Some(merge) => Err(ValidationError::InvalidMergeMethod(merge.method.clone())),
    }
}
/// Validates the publish section when present: the repo must be exactly
/// `<owner>/<name>` (one slash, both halves non-empty) and the format must
/// be one of the supported serializations. An absent section is valid.
fn validate_publish(spec: &TrainSpec) -> Result<(), ValidationError> {
    let Some(publish) = &spec.publish else {
        return Ok(());
    };
    // split_once yields the text around the FIRST '/'; requiring the tail to
    // contain no further '/' reproduces the original exactly-two-parts check.
    let repo_ok = matches!(
        publish.repo.split_once('/'),
        Some((owner, name)) if !owner.is_empty() && !name.is_empty() && !name.contains('/')
    );
    if !repo_ok {
        return Err(ValidationError::InvalidPublishRepo(publish.repo.clone()));
    }
    match publish.format.as_str() {
        "safetensors" | "gguf" => Ok(()),
        _ => Err(ValidationError::InvalidPublishFormat(publish.format.clone())),
    }
}