use crate::config::paths;
use crate::config::{Config, LogLevel, StreamingMode};
use crate::error::{ProxyError, Result};
use std::path::Path;
/// Validates a loaded [`Config`] and reports every problem found in a single pass.
///
/// Findings are split into two tiers:
/// * **errors** make [`ConfigValidator::validate`] return `Err(ProxyError::Config(..))`;
/// * **warnings** are logged through `tracing::warn!` but do not fail validation.
pub struct ConfigValidator<'a> {
    /// Configuration under inspection.
    config: &'a Config,
    /// Messages that cause validation to fail.
    errors: Vec<String>,
    /// Advisory messages that are logged but do not fail validation.
    warnings: Vec<String>,
}

impl<'a> ConfigValidator<'a> {
    /// Creates a validator for `config` with empty error and warning lists.
    pub fn new(config: &'a Config) -> Self {
        Self {
            config,
            errors: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Runs every check (server, auth, streaming, security) and consumes the validator.
    ///
    /// All collected warnings are logged at `warn` level. If any errors were collected,
    /// they are numbered and combined into a single message.
    ///
    /// # Errors
    ///
    /// Returns [`ProxyError::Config`] listing every collected error when at least one
    /// check failed.
    pub fn validate(mut self) -> Result<()> {
        self.validate_server_config();
        self.validate_auth_config();
        self.validate_streaming_config();
        self.validate_security_requirements();

        for warning in &self.warnings {
            tracing::warn!("Configuration warning: {}", warning);
        }

        if !self.errors.is_empty() {
            let numbered = self
                .errors
                .iter()
                .enumerate()
                .map(|(i, e)| format!("{}. {}", i + 1, e))
                .collect::<Vec<_>>()
                .join("\n");
            let error_msg = format!(
                "Configuration validation failed with {} error(s):\n\n{}\n\
                 \n\
                 Please fix these issues and try again.\n\
                 Run 'modelmux config init' for interactive configuration setup.",
                self.errors.len(),
                numbered
            );
            return Err(ProxyError::Config(error_msg));
        }

        tracing::info!("Configuration validation passed");
        if !self.warnings.is_empty() {
            tracing::info!("Configuration has {} warning(s) but is valid", self.warnings.len());
        }
        Ok(())
    }

    /// Checks the listening port and the retry settings.
    fn validate_server_config(&mut self) {
        let server = &self.config.server;

        if server.port == 0 {
            self.add_error(format!(
                "Invalid server port {}: must be between 1 and 65535",
                server.port
            ));
        } else if server.port < 1024 {
            // Port 0 is already an error; only warn about privileges for usable ports.
            self.add_warning(format!(
                "Server port {} requires root/administrator privileges",
                server.port
            ));
        }

        match server.port {
            80 | 443 => self.add_warning(format!(
                "Port {} is commonly used by web servers and may conflict",
                server.port
            )),
            22 => self.add_warning("Port 22 is used by SSH and may conflict"),
            25 | 465 | 587 => self.add_warning(format!(
                "Port {} is used by mail servers and may conflict",
                server.port
            )),
            _ => {}
        }

        // The attempt count only matters when retries are actually enabled.
        if server.enable_retries && server.max_retry_attempts > 10 {
            self.add_warning(format!(
                "High retry count ({}): may cause long delays on failures",
                server.max_retry_attempts
            ));
        }

        tracing::debug!("Server config validation completed");
    }

    /// Requires at least one service-account source and validates each one that is set.
    fn validate_auth_config(&mut self) {
        let auth = &self.config.auth;
        let file = auth.service_account_file.as_deref();
        let json = auth.service_account_json.as_deref();

        if file.is_none() && json.is_none() {
            self.add_error(
                "No service account configuration found. Please set either:\n\
                 - auth.service_account_file = \"/path/to/service-account.json\"\n\
                 - auth.service_account_json = \"{ ... }\" (inline JSON)",
            );
            // Nothing else in this section can be checked without a source.
            return;
        }

        if let Some(file_path) = file {
            self.validate_service_account_file(file_path);
        }
        if let Some(json_str) = json {
            self.validate_service_account_json(json_str);
        }

        if file.is_some() && json.is_some() {
            self.add_warning(
                "Both service_account_file and service_account_json are specified. \
                 service_account_json will take precedence.",
            );
        }

        tracing::debug!("Auth config validation completed");
    }

    /// Expands `file_path`, checks that it names a readable regular file, inspects its
    /// permissions, and validates its contents as service-account JSON.
    fn validate_service_account_file(&mut self, file_path: &str) {
        let expanded_path = match paths::expand_path(file_path) {
            Ok(path) => path,
            Err(e) => {
                self.add_error(format!(
                    "Failed to expand service account file path '{}': {}",
                    file_path, e
                ));
                return;
            }
        };

        if !expanded_path.exists() {
            self.add_error(format!(
                "Service account file not found: '{}'\n\
                 \n\
                 To fix this:\n\
                 1. Download your Google Cloud service account key JSON\n\
                 2. Save it to the specified path\n\
                 3. Ensure the file is readable\n\
                 \n\
                 Example:\n\
                 mkdir -p ~/.config/modelmux\n\
                 cp /path/to/downloaded-key.json ~/.config/modelmux/service-account.json\n\
                 chmod 600 ~/.config/modelmux/service-account.json",
                expanded_path.display()
            ));
            return;
        }

        if !expanded_path.is_file() {
            self.add_error(format!(
                "Service account path exists but is not a regular file: '{}'",
                expanded_path.display()
            ));
            return;
        }

        self.validate_file_permissions(&expanded_path);

        // This read is the single place where an inaccessible file is reported.
        match std::fs::read_to_string(&expanded_path) {
            Ok(contents) => self.validate_service_account_json(&contents),
            Err(e) => self.add_error(format!(
                "Cannot read service account file '{}': {}\n\
                 Please check file permissions.",
                expanded_path.display(),
                e
            )),
        }
    }

    /// Parses `json_str` and checks the fields of a Google service-account key.
    ///
    /// The checks are: required fields are present and non-empty strings, `type` equals
    /// `service_account`, `private_key` starts with a PKCS#8 PEM header, and
    /// `client_email` looks like a Google service-account address (warning only).
    fn validate_service_account_json(&mut self, json_str: &str) {
        let service_account: serde_json::Value = match serde_json::from_str(json_str) {
            Ok(value) => value,
            Err(e) => {
                self.add_error(format!(
                    "Invalid service account JSON: {}\n\
                     Please ensure the JSON is properly formatted.",
                    e
                ));
                return;
            }
        };

        // Valid JSON that is not an object (e.g. `[]` or `"x"`) would otherwise yield one
        // "missing field" error per required field; report the real problem once instead.
        if !service_account.is_object() {
            self.add_error(
                "Service account JSON must be a JSON object containing the service account key fields",
            );
            return;
        }

        const REQUIRED_FIELDS: [&str; 8] = [
            "type",
            "project_id",
            "private_key_id",
            "private_key",
            "client_email",
            "client_id",
            "auth_uri",
            "token_uri",
        ];
        for field in REQUIRED_FIELDS.iter() {
            let present = matches!(
                service_account.get(*field).and_then(|v| v.as_str()),
                Some(s) if !s.is_empty()
            );
            if !present {
                self.add_error(format!(
                    "Service account JSON missing or empty required field: '{}'",
                    field
                ));
            }
        }

        if let Some(account_type) = service_account.get("type").and_then(|v| v.as_str()) {
            if account_type != "service_account" {
                self.add_error(format!(
                    "Invalid service account type: '{}'. Expected 'service_account'",
                    account_type
                ));
            }
        }

        if let Some(client_email) = service_account.get("client_email").and_then(|v| v.as_str()) {
            if !client_email.contains('@') || !client_email.contains("gserviceaccount.com") {
                self.add_warning(format!(
                    "Service account email '{}' doesn't look like a Google service account email",
                    client_email
                ));
            }
        }

        if let Some(private_key) = service_account.get("private_key").and_then(|v| v.as_str()) {
            if !private_key.starts_with("-----BEGIN PRIVATE KEY-----") {
                self.add_error("Private key doesn't appear to be in valid PEM format");
            }
        }
    }

    /// Warns when the key file at `path` is readable or writable by group/others (Unix only).
    fn validate_file_permissions(&mut self, path: &Path) {
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;

            // A metadata failure is not reported here: the read in
            // `validate_service_account_file` surfaces any access problem.
            if let Ok(metadata) = std::fs::metadata(path) {
                let mode = metadata.permissions().mode();
                if mode & 0o044 != 0 {
                    self.add_warning(format!(
                        "Service account file '{}' is readable by group/others (permissions: {:o}). \
                         Consider restricting permissions: chmod 600 '{}'",
                        path.display(),
                        mode & 0o777,
                        path.display()
                    ));
                }
                if mode & 0o022 != 0 {
                    self.add_warning(format!(
                        "Service account file '{}' is writable by group/others (permissions: {:o}). \
                         Consider restricting permissions: chmod 600 '{}'",
                        path.display(),
                        mode & 0o777,
                        path.display()
                    ));
                }
            }
        }

        // There are no POSIX mode bits to inspect off Unix. Accessibility is checked by the
        // caller's read, so opening the file here as well only produced a duplicate error.
        #[cfg(not(unix))]
        let _ = path;
    }

    /// Checks streaming buffer size, chunk timeout, and their fit with the streaming mode.
    fn validate_streaming_config(&mut self) {
        let streaming = &self.config.streaming;

        if streaming.buffer_size == 0 {
            self.add_error("Streaming buffer size cannot be zero");
        } else if streaming.buffer_size < 1024 {
            self.add_warning(format!(
                "Small streaming buffer size ({} bytes) may impact performance",
                streaming.buffer_size
            ));
        } else if streaming.buffer_size > 10 * 1024 * 1024 {
            self.add_warning(format!(
                "Large streaming buffer size ({} bytes) may consume excessive memory",
                streaming.buffer_size
            ));
        }

        if streaming.chunk_timeout_ms == 0 {
            self.add_error("Streaming chunk timeout cannot be zero");
        } else if streaming.chunk_timeout_ms < 100 {
            self.add_warning(format!(
                "Very short chunk timeout ({}ms) may cause premature timeouts",
                streaming.chunk_timeout_ms
            ));
        } else if streaming.chunk_timeout_ms > 60000 {
            self.add_warning(format!(
                "Long chunk timeout ({}ms) may cause poor user experience",
                streaming.chunk_timeout_ms
            ));
        }

        match streaming.mode {
            StreamingMode::Never => {
                if streaming.buffer_size > 1024 * 1024 {
                    self.add_warning("Large buffer size not needed when streaming is disabled");
                }
            }
            StreamingMode::Buffered => {
                // A zero-sized buffer is already reported as an error above.
                if streaming.buffer_size > 0 && streaming.buffer_size < 4096 {
                    self.add_warning(
                        "Small buffer size may reduce effectiveness of buffered streaming",
                    );
                }
            }
            _ => {}
        }

        tracing::debug!("Streaming config validation completed");
    }

    /// Emits warnings for settings that are risky in production (trace logging, no retries).
    fn validate_security_requirements(&mut self) {
        let server = &self.config.server;

        if server.log_level == LogLevel::Trace {
            self.add_warning(
                "Trace log level enabled: may log sensitive information in production",
            );
        }
        if !server.enable_retries {
            self.add_warning("Retries are disabled: may impact reliability in production");
        }

        tracing::debug!("Security validation completed");
    }

    /// Records a validation error, which will make `validate` fail.
    fn add_error(&mut self, error: impl Into<String>) {
        let error = error.into();
        tracing::debug!("Validation error: {}", error);
        self.errors.push(error);
    }

    /// Records a validation warning, which is logged but does not fail `validate`.
    fn add_warning(&mut self, warning: impl Into<String>) {
        let warning = warning.into();
        tracing::debug!("Validation warning: {}", warning);
        self.warnings.push(warning);
    }
}
/// Applies `validator` to `value` and tags any failure with the field it concerns.
///
/// On success this returns `Ok(())`. On failure, the validator's error is rendered with
/// `Display` and re-wrapped as `ProxyError::Config("Invalid <field_name>: <error>")`.
///
/// # Errors
///
/// Returns [`ProxyError::Config`] when `validator` returns an error.
// NOTE(review): `dead_code` is allowed presumably because nothing outside the tests calls
// this helper yet — confirm before removing the attribute.
#[allow(dead_code)]
pub fn validate_field<T, F>(value: &T, field_name: &str, validator: F) -> Result<()>
where
    F: FnOnce(&T) -> Result<()>,
{
    match validator(value) {
        Ok(()) => Ok(()),
        Err(cause) => Err(ProxyError::Config(format!("Invalid {}: {}", field_name, cause))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AuthConfig, Config, ServerConfig, StreamingConfig, default_auth_strategy};
    use std::fs;
    use tempfile::TempDir;

    /// A minimal service-account key that passes every JSON check in the validator.
    /// The inline-JSON and file-based fixtures share it so they cannot drift apart.
    const VALID_SERVICE_ACCOUNT_JSON: &str = r#"{"type":"service_account","project_id":"test","private_key_id":"123","private_key":"-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----","client_email":"test@test.gserviceaccount.com","client_id":"123","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token"}"#;

    /// Builds a configuration that validates with no errors.
    fn create_test_config() -> Config {
        Config {
            server: ServerConfig {
                port: 3000,
                log_level: LogLevel::Info,
                enable_retries: true,
                max_retry_attempts: 3,
            },
            auth: AuthConfig {
                service_account_file: None,
                service_account_json: Some(VALID_SERVICE_ACCOUNT_JSON.to_string()),
                strategy: default_auth_strategy(),
            },
            streaming: StreamingConfig {
                mode: StreamingMode::Auto,
                buffer_size: 65536,
                chunk_timeout_ms: 5000,
            },
            vertex: None,
            llm_provider: None,
        }
    }

    /// Returns the valid fixture JSON with `field` set to `value`, or removed when `value`
    /// is `None`.
    fn service_account_json_with(field: &str, value: Option<serde_json::Value>) -> String {
        let mut json: serde_json::Value =
            serde_json::from_str(VALID_SERVICE_ACCOUNT_JSON).expect("fixture JSON is valid");
        let object = json.as_object_mut().expect("fixture JSON is an object");
        match value {
            Some(v) => {
                object.insert(field.to_string(), v);
            }
            None => {
                object.remove(field);
            }
        }
        json.to_string()
    }

    /// Writes `contents` to a fresh temp file and returns the directory guard and file path.
    fn write_temp_service_account(contents: &str) -> (TempDir, String) {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("service-account.json");
        fs::write(&path, contents).unwrap();
        (temp_dir, path.to_string_lossy().to_string())
    }

    /// Runs validation and returns the rendered error, panicking if validation succeeds.
    fn validation_error(config: &Config) -> String {
        match ConfigValidator::new(config).validate() {
            Ok(()) => panic!("expected configuration validation to fail"),
            Err(e) => e.to_string(),
        }
    }

    #[test]
    fn test_valid_config_passes_validation() {
        let config = create_test_config();
        let result = ConfigValidator::new(&config).validate();
        assert!(result.is_ok(), "Valid config should pass validation");
    }

    #[test]
    fn test_invalid_port_fails_validation() {
        let mut config = create_test_config();
        config.server.port = 0;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("Invalid server port 0"));
    }

    #[test]
    fn test_missing_auth_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_file = None;
        config.auth.service_account_json = None;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("No service account configuration"));
    }

    #[test]
    fn test_invalid_json_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_json = Some("invalid json".to_string());
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("Invalid service account JSON"));
    }

    #[test]
    fn test_wrong_account_type_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_json =
            Some(service_account_json_with("type", Some(serde_json::json!("authorized_user"))));
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("Invalid service account type: 'authorized_user'"));
    }

    #[test]
    fn test_missing_required_field_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_json = Some(service_account_json_with("client_id", None));
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("missing or empty required field: 'client_id'"));
    }

    #[test]
    fn test_empty_required_field_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_json =
            Some(service_account_json_with("project_id", Some(serde_json::json!(""))));
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("missing or empty required field: 'project_id'"));
    }

    #[test]
    fn test_non_pem_private_key_fails_validation() {
        let mut config = create_test_config();
        config.auth.service_account_json =
            Some(service_account_json_with("private_key", Some(serde_json::json!("not-a-key"))));
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("valid PEM format"));
    }

    #[test]
    fn test_service_account_file_validation() {
        let (_temp_dir, path) = write_temp_service_account(VALID_SERVICE_ACCOUNT_JSON);
        let mut config = create_test_config();
        config.auth.service_account_file = Some(path);
        config.auth.service_account_json = None;
        let result = ConfigValidator::new(&config).validate();
        assert!(result.is_ok(), "Valid service account file should pass validation");
    }

    #[test]
    fn test_missing_service_account_file_fails_validation() {
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("does-not-exist.json");
        let mut config = create_test_config();
        config.auth.service_account_file = Some(missing.to_string_lossy().to_string());
        config.auth.service_account_json = None;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("does-not-exist.json"));
    }

    #[test]
    fn test_both_auth_sources_only_warns() {
        let (_temp_dir, path) = write_temp_service_account(VALID_SERVICE_ACCOUNT_JSON);
        let mut config = create_test_config();
        config.auth.service_account_file = Some(path);
        let result = ConfigValidator::new(&config).validate();
        assert!(result.is_ok(), "Two valid auth sources should only produce a warning");
    }

    #[test]
    fn test_zero_buffer_size_fails_validation() {
        let mut config = create_test_config();
        config.streaming.buffer_size = 0;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("buffer size cannot be zero"));
    }

    #[test]
    fn test_zero_chunk_timeout_fails_validation() {
        let mut config = create_test_config();
        config.streaming.chunk_timeout_ms = 0;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("chunk timeout cannot be zero"));
    }

    #[test]
    fn test_multiple_errors_are_numbered() {
        let mut config = create_test_config();
        config.server.port = 0;
        config.streaming.buffer_size = 0;
        let error_msg = validation_error(&config);
        assert!(error_msg.contains("failed with 2 error(s)"));
        assert!(error_msg.contains("1. "));
        assert!(error_msg.contains("2. "));
    }

    #[test]
    fn test_privileged_port_warning() {
        let mut config = create_test_config();
        config.server.port = 80;
        let result = ConfigValidator::new(&config).validate();
        assert!(result.is_ok(), "Config with privileged port should still be valid");
    }

    #[test]
    fn test_validate_field_utility() {
        let port = 8080u16;
        let result = validate_field(&port, "port", |p| {
            if *p == 0 { Err(ProxyError::Config("cannot be zero".to_string())) } else { Ok(()) }
        });
        assert!(result.is_ok());
        let bad_port = 0u16;
        let result = validate_field(&bad_port, "port", |p| {
            if *p == 0 { Err(ProxyError::Config("cannot be zero".to_string())) } else { Ok(()) }
        });
        assert!(result.is_err());
    }
}