use crate::error::{ConfigError, HalldyllError, Result};
use std::collections::HashSet;
use tracing::debug;
use super::spec::{DeployConfig, PodConfig, StateBackend, VolumeConfig};
#[derive(Debug, Default)]
pub struct ConfigValidator {
known_gpu_types: HashSet<String>,
}
/// GPU type names recognized out of the box by [`ConfigValidator::new`].
///
/// When a pod's primary or fallback GPU type is not in the validator's known
/// set, validation records a warning (not an error). Additional names can be
/// registered at runtime with [`ConfigValidator::add_gpu_type`].
// NOTE(review): these strings look like cloud-provider GPU identifiers whose
// exact spelling must match the provider's catalogue — confirm against it.
const KNOWN_GPU_TYPES: &[&str] = &[
"NVIDIA A40",
"NVIDIA A100 80GB PCIe",
"NVIDIA A100-SXM4-80GB",
"NVIDIA GeForce RTX 3070",
"NVIDIA GeForce RTX 3080",
"NVIDIA GeForce RTX 3080 Ti",
"NVIDIA GeForce RTX 3090",
"NVIDIA GeForce RTX 3090 Ti",
"NVIDIA GeForce RTX 4070 Ti",
"NVIDIA GeForce RTX 4080",
"NVIDIA GeForce RTX 4090",
"NVIDIA H100 80GB HBM3",
"NVIDIA H100 PCIe",
"NVIDIA L4",
"NVIDIA L40",
"NVIDIA L40S",
"NVIDIA RTX 4000 Ada Generation",
"NVIDIA RTX 5000 Ada Generation",
"NVIDIA RTX 6000 Ada Generation",
"NVIDIA RTX A4000",
"NVIDIA RTX A4500",
"NVIDIA RTX A5000",
"NVIDIA RTX A6000",
];
/// Outcome of a validation pass: hard errors plus advisory warnings.
///
/// The configuration is considered valid exactly when `errors` is empty;
/// warnings are informational only.
#[derive(Default, Debug)]
pub struct ValidationResult {
    /// Problems that make the configuration invalid.
    pub errors: Vec<ValidationError>,
    /// Human-readable advisories (prefixed with the field path) that do not fail validation.
    pub warnings: Vec<String>,
}
/// A single validation failure tied to the configuration field that caused it.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can copy errors out of a
/// [`ValidationResult`] and compare them (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Path of the offending field, e.g. `pods[0].gpu.count`.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}
impl ConfigValidator {
#[must_use]
pub fn new() -> Self {
Self {
known_gpu_types: KNOWN_GPU_TYPES.iter().map(|s| (*s).to_string()).collect(),
}
}
pub fn add_gpu_type(&mut self, gpu_type: impl Into<String>) {
self.known_gpu_types.insert(gpu_type.into());
}
pub fn validate(&self, config: &DeployConfig) -> Result<ValidationResult> {
let mut result = ValidationResult::default();
Self::validate_project(&config.project, &mut result);
Self::validate_state(&config.state, &mut result);
self.validate_pods(&config.pods, &mut result);
Self::validate_guardrails(config, &mut result);
if result.errors.is_empty() {
debug!("Configuration validation passed");
Ok(result)
} else {
let first_error = &result.errors[0];
Err(HalldyllError::Config(ConfigError::ValidationError {
message: first_error.message.clone(),
field: Some(first_error.field.clone()),
}))
}
}
fn validate_project(
project: &super::spec::ProjectConfig,
result: &mut ValidationResult,
) {
if project.name.is_empty() {
result.errors.push(ValidationError {
field: String::from("project.name"),
message: String::from("Project name cannot be empty"),
});
} else if !is_valid_name(&project.name) {
result.errors.push(ValidationError {
field: String::from("project.name"),
message: format!(
"Project name '{}' is invalid. Must be lowercase alphanumeric with hyphens.",
project.name
),
});
}
if project.environment.is_empty() {
result.errors.push(ValidationError {
field: String::from("project.environment"),
message: String::from("Environment cannot be empty"),
});
}
}
fn validate_state(state: &super::spec::StateConfig, result: &mut ValidationResult) {
match state.backend {
StateBackend::S3 => {
if state.bucket.is_none() || state.bucket.as_ref().is_some_and(String::is_empty) {
result.errors.push(ValidationError {
field: String::from("state.bucket"),
message: String::from("S3 bucket name is required when using S3 backend"),
});
}
}
StateBackend::Local => {
}
}
}
fn validate_pods(&self, pods: &[PodConfig], result: &mut ValidationResult) {
if pods.is_empty() {
result.warnings.push(String::from("No pods defined in configuration"));
return;
}
let mut seen_names = HashSet::new();
let mut all_ports: HashSet<u16> = HashSet::new();
for (i, pod) in pods.iter().enumerate() {
let prefix = format!("pods[{i}]");
if seen_names.contains(&pod.name) {
result.errors.push(ValidationError {
field: format!("{prefix}.name"),
message: format!("Duplicate pod name: {}", pod.name),
});
} else {
seen_names.insert(&pod.name);
}
if !is_valid_name(&pod.name) {
result.errors.push(ValidationError {
field: format!("{prefix}.name"),
message: format!(
"Pod name '{}' is invalid. Must be lowercase alphanumeric with hyphens.",
pod.name
),
});
}
self.validate_gpu(&pod.gpu, &prefix, result);
Self::validate_ports(&pod.ports, &prefix, &mut all_ports, result);
Self::validate_volumes(&pod.volumes, &prefix, result);
Self::validate_runtime(&pod.runtime, &prefix, result);
Self::validate_models(&pod.models, &prefix, result);
}
}
fn validate_gpu(
&self,
gpu: &super::spec::GpuConfig,
prefix: &str,
result: &mut ValidationResult,
) {
if gpu.count == 0 {
result.errors.push(ValidationError {
field: format!("{prefix}.gpu.count"),
message: String::from("GPU count must be at least 1"),
});
}
if gpu.count > 8 {
result.warnings.push(format!(
"{prefix}.gpu.count: Requesting {count} GPUs is unusual",
count = gpu.count
));
}
if !self.known_gpu_types.contains(&gpu.gpu_type) {
result.warnings.push(format!(
"{prefix}.gpu.type: Unknown GPU type '{}'. This may fail if not available.",
gpu.gpu_type
));
}
for (i, fallback) in gpu.fallback.iter().enumerate() {
if !self.known_gpu_types.contains(fallback) {
result.warnings.push(format!(
"{prefix}.gpu.fallback[{i}]: Unknown fallback GPU type '{fallback}'",
));
}
}
}
fn validate_ports(
ports: &[super::spec::PortConfig],
prefix: &str,
all_ports: &mut HashSet<u16>,
result: &mut ValidationResult,
) {
let mut pod_ports = HashSet::new();
for (i, port) in ports.iter().enumerate() {
if pod_ports.contains(&port.port) {
result.errors.push(ValidationError {
field: format!("{prefix}.ports[{i}]"),
message: format!("Duplicate port {} in pod", port.port),
});
} else {
pod_ports.insert(port.port);
}
if port.port < 1024 && port.port != 22 && port.port != 80 && port.port != 443 {
result.warnings.push(format!(
"{prefix}.ports[{i}]: Port {} is in the reserved range (<1024)",
port.port
));
}
}
all_ports.extend(pod_ports);
}
fn validate_volumes(
volumes: &[VolumeConfig],
prefix: &str,
result: &mut ValidationResult,
) {
let mut seen_names = HashSet::new();
let mut seen_mounts = HashSet::new();
for (i, volume) in volumes.iter().enumerate() {
if seen_names.contains(&volume.name) {
result.errors.push(ValidationError {
field: format!("{prefix}.volumes[{i}].name"),
message: format!("Duplicate volume name: {}", volume.name),
});
} else {
seen_names.insert(&volume.name);
}
if seen_mounts.contains(&volume.mount) {
result.errors.push(ValidationError {
field: format!("{prefix}.volumes[{i}].mount"),
message: format!("Duplicate mount path: {}", volume.mount),
});
} else {
seen_mounts.insert(&volume.mount);
}
if !volume.mount.starts_with('/') {
result.errors.push(ValidationError {
field: format!("{prefix}.volumes[{i}].mount"),
message: format!("Mount path must be absolute: {}", volume.mount),
});
}
}
}
fn validate_runtime(
runtime: &super::spec::RuntimeConfig,
prefix: &str,
result: &mut ValidationResult,
) {
if runtime.image.is_empty() {
result.errors.push(ValidationError {
field: format!("{prefix}.runtime.image"),
message: String::from("Container image cannot be empty"),
});
}
if runtime.image.ends_with(":latest") {
result.warnings.push(format!(
"{prefix}.runtime.image: Using ':latest' tag is not recommended for production"
));
}
}
fn validate_models(
models: &[super::spec::ModelConfig],
prefix: &str,
result: &mut ValidationResult,
) {
let mut seen_ids = HashSet::new();
for (i, model) in models.iter().enumerate() {
if seen_ids.contains(&model.id) {
result.errors.push(ValidationError {
field: format!("{prefix}.models[{i}].id"),
message: format!("Duplicate model ID: {}", model.id),
});
} else {
seen_ids.insert(&model.id);
}
if model.provider == super::spec::ModelProvider::Huggingface
&& model.repo.is_none()
{
result.errors.push(ValidationError {
field: format!("{prefix}.models[{i}].repo"),
message: format!(
"Model '{}' uses huggingface provider but no repo specified",
model.id
),
});
}
if model.provider == super::spec::ModelProvider::Bundle
&& model.components.as_ref().is_none_or(Vec::is_empty)
{
result.errors.push(ValidationError {
field: format!("{prefix}.models[{i}].components"),
message: format!(
"Model '{}' uses bundle provider but no components specified",
model.id
),
});
}
}
}
fn validate_guardrails(config: &DeployConfig, result: &mut ValidationResult) {
if let Some(guardrails) = &config.guardrails {
if let Some(cost) = guardrails.max_hourly_cost
&& cost <= 0.0 {
result.errors.push(ValidationError {
field: String::from("guardrails.max_hourly_cost"),
message: String::from("Maximum hourly cost must be positive"),
});
}
if let Some(max_gpus) = guardrails.max_gpus {
let total_gpus = config.total_gpus();
if total_gpus > max_gpus {
result.errors.push(ValidationError {
field: String::from("guardrails.max_gpus"),
message: format!(
"Configuration requires {total_gpus} GPUs but max_gpus is {max_gpus}"
),
});
}
}
if let Some(ttl) = guardrails.ttl_hours
&& ttl == 0 {
result.errors.push(ValidationError {
field: String::from("guardrails.ttl_hours"),
message: String::from("TTL must be at least 1 hour"),
});
}
}
}
}
/// Reports whether `name` is an acceptable project/pod identifier.
///
/// A valid name:
/// - is non-empty and starts with an ASCII lowercase letter;
/// - contains only ASCII lowercase letters, ASCII digits, and `-`;
/// - does not end with `-`;
/// - never contains two consecutive hyphens (`--`).
///
/// No length limit is enforced.
fn is_valid_name(name: &str) -> bool {
    let Some(leading) = name.chars().next() else {
        return false;
    };
    let allowed = |c: char| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-';

    leading.is_ascii_lowercase()
        && name.chars().all(allowed)
        && !name.ends_with('-')
        && !name.contains("--")
}
impl ValidationResult {
    /// Number of hard errors recorded.
    #[must_use]
    pub const fn error_count(&self) -> usize {
        self.errors.len()
    }

    /// Number of advisory warnings recorded.
    #[must_use]
    pub const fn warning_count(&self) -> usize {
        self.warnings.len()
    }

    /// `true` when no errors were recorded; warnings do not affect validity.
    #[must_use]
    pub const fn is_valid(&self) -> bool {
        self.error_count() == 0
    }
}
impl std::fmt::Display for ValidationError {
    /// Renders the error as `<field>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { field, message } = self;
        write!(f, "{field}: {message}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_name() {
        assert!(is_valid_name("pod-text"));
        assert!(is_valid_name("my-pod-123"));
        assert!(is_valid_name("a"));
        assert!(is_valid_name("test"));
        assert!(is_valid_name("a-1"));
    }

    #[test]
    fn test_invalid_name() {
        assert!(!is_valid_name(""));
        assert!(!is_valid_name("Pod-Text")); // uppercase
        assert!(!is_valid_name("123-pod")); // must start with a letter
        assert!(!is_valid_name("pod_text")); // underscore not allowed
        assert!(!is_valid_name("pod-")); // trailing hyphen
        assert!(!is_valid_name("pod--text")); // consecutive hyphens
        assert!(!is_valid_name("-pod")); // leading hyphen
        assert!(!is_valid_name("pöd")); // non-ASCII letter
    }

    #[test]
    fn test_validation_error_display() {
        let error = ValidationError {
            field: String::from("pods[0].name"),
            message: String::from("Duplicate pod name: web"),
        };
        assert_eq!(error.to_string(), "pods[0].name: Duplicate pod name: web");
    }

    #[test]
    fn test_default_result_is_valid() {
        let result = ValidationResult::default();
        assert!(result.is_valid());
        assert_eq!(result.error_count(), 0);
        assert_eq!(result.warning_count(), 0);
    }

    #[test]
    fn test_add_gpu_type() {
        let mut validator = ConfigValidator::new();
        assert!(validator.known_gpu_types.contains("NVIDIA L4"));
        assert!(!validator.known_gpu_types.contains("Custom GPU"));
        validator.add_gpu_type("Custom GPU");
        assert!(validator.known_gpu_types.contains("Custom GPU"));
    }
}