use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router};
use premortem::prelude::*;
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
/// Settings for the HTTP server.
///
/// Values are assembled in `main` from built-in defaults, an optional
/// `config.toml`, and `SERVER_`-prefixed environment variables, and are
/// checked by this type's `Validate` implementation.
///
/// NOTE(review): in the code visible here only `host`, `port`, `api_prefix`,
/// `max_connections` and `tls_cert` are read after validation; the remaining
/// fields are validated and/or printed but not applied to the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ServerConfig {
/// Host part of the listen address. Must be non-empty; defaults to `127.0.0.1`.
host: String,
/// Port to listen on. Must be non-zero; defaults to `3000`.
port: u16,
/// TLS certificate; must be set together with `tls_key`.
/// Presumably a file path (the tests use `/path/to/cert.pem`) — not opened here.
tls_cert: Option<String>,
/// TLS private key; must be set together with `tls_cert`.
tls_key: Option<String>,
/// Maximum request body size in megabytes. Validated to be in `1..=1024`; defaults to `10`.
max_body_size_mb: u32,
/// Request timeout in seconds. Validated to be in `1..=3600`; defaults to `30`.
request_timeout_secs: u64,
/// Maximum number of connections. Must be at least 1; defaults to `1000`.
max_connections: u32,
/// Idle timeout in seconds; defaults to `60`. Not validated in the visible code.
idle_timeout_secs: u64,
/// Path prefix under which the API routes are nested. Must start with `/`;
/// defaults to `/api/v1`.
api_prefix: String,
/// Allowed CORS origins; defaults to an empty list. Not validated in the visible code.
cors_allowed_origins: Vec<String>,
/// Rate-limit request count. Must be at least 1; defaults to `100`.
rate_limit_requests: u32,
/// Rate-limit window in seconds. Must be at least 1; defaults to `60`.
rate_limit_window_secs: u64,
/// Log level; one of `trace`, `debug`, `info`, `warn`, `error`
/// (compared case-insensitively). Defaults to `info`.
log_level: String,
/// Log format; one of `json`, `pretty`, `compact`
/// (compared case-insensitively). Defaults to `json`.
log_format: String,
}
/// Baseline settings used as the lowest-priority configuration layer.
impl Default for ServerConfig {
    /// Returns a loopback-only, TLS-less configuration that passes validation.
    fn default() -> Self {
        Self {
            // Network binding.
            host: String::from("127.0.0.1"),
            port: 3000,

            // TLS stays off unless both halves are supplied.
            tls_cert: None,
            tls_key: None,

            // Request limits.
            max_body_size_mb: 10,
            request_timeout_secs: 30,

            // Connection limits.
            max_connections: 1000,
            idle_timeout_secs: 60,

            // API surface.
            api_prefix: String::from("/api/v1"),
            cors_allowed_origins: Vec::new(),

            // Rate limiting.
            rate_limit_requests: 100,
            rate_limit_window_secs: 60,

            // Logging.
            log_level: String::from("info"),
            log_format: String::from("json"),
        }
    }
}
impl Validate for ServerConfig {
fn validate(&self) -> ConfigValidation<()> {
use stillwater::Validation;
let mut errors = Vec::new();
if self.host.is_empty() {
errors.push(ConfigError::ValidationError {
path: "host".to_string(),
source_location: None,
value: Some(self.host.clone()),
message: "host cannot be empty".to_string(),
});
}
if self.port == 0 {
errors.push(ConfigError::ValidationError {
path: "port".to_string(),
source_location: None,
value: Some(self.port.to_string()),
message: "port must be between 1 and 65535".to_string(),
});
}
match (&self.tls_cert, &self.tls_key) {
(Some(_), None) => {
errors.push(ConfigError::CrossFieldError {
paths: vec!["tls_cert".to_string(), "tls_key".to_string()],
message: "TLS certificate provided but TLS key is missing".to_string(),
});
}
(None, Some(_)) => {
errors.push(ConfigError::CrossFieldError {
paths: vec!["tls_cert".to_string(), "tls_key".to_string()],
message: "TLS key provided but TLS certificate is missing".to_string(),
});
}
_ => {}
}
if self.max_body_size_mb == 0 {
errors.push(ConfigError::ValidationError {
path: "max_body_size_mb".to_string(),
source_location: None,
value: Some(self.max_body_size_mb.to_string()),
message: "max body size must be at least 1 MB".to_string(),
});
}
if self.max_body_size_mb > 1024 {
errors.push(ConfigError::ValidationError {
path: "max_body_size_mb".to_string(),
source_location: None,
value: Some(self.max_body_size_mb.to_string()),
message: "max body size should not exceed 1024 MB (1 GB)".to_string(),
});
}
if self.request_timeout_secs == 0 {
errors.push(ConfigError::ValidationError {
path: "request_timeout_secs".to_string(),
source_location: None,
value: Some(self.request_timeout_secs.to_string()),
message: "request timeout must be at least 1 second".to_string(),
});
}
if self.request_timeout_secs > 3600 {
errors.push(ConfigError::ValidationError {
path: "request_timeout_secs".to_string(),
source_location: None,
value: Some(self.request_timeout_secs.to_string()),
message: "request timeout should not exceed 1 hour (3600 seconds)".to_string(),
});
}
if self.max_connections == 0 {
errors.push(ConfigError::ValidationError {
path: "max_connections".to_string(),
source_location: None,
value: Some(self.max_connections.to_string()),
message: "max connections must be at least 1".to_string(),
});
}
if !self.api_prefix.starts_with('/') {
errors.push(ConfigError::ValidationError {
path: "api_prefix".to_string(),
source_location: None,
value: Some(self.api_prefix.clone()),
message: "API prefix must start with '/'".to_string(),
});
}
let valid_log_levels = ["trace", "debug", "info", "warn", "error"];
if !valid_log_levels.contains(&self.log_level.to_lowercase().as_str()) {
errors.push(ConfigError::ValidationError {
path: "log_level".to_string(),
source_location: None,
value: Some(self.log_level.clone()),
message: "log level must be one of: trace, debug, info, warn, error".to_string(),
});
}
let valid_log_formats = ["json", "pretty", "compact"];
if !valid_log_formats.contains(&self.log_format.to_lowercase().as_str()) {
errors.push(ConfigError::ValidationError {
path: "log_format".to_string(),
source_location: None,
value: Some(self.log_format.clone()),
message: "log format must be one of: json, pretty, compact".to_string(),
});
}
if self.rate_limit_requests == 0 {
errors.push(ConfigError::ValidationError {
path: "rate_limit_requests".to_string(),
source_location: None,
value: Some(self.rate_limit_requests.to_string()),
message: "rate limit requests must be at least 1".to_string(),
});
}
if self.rate_limit_window_secs == 0 {
errors.push(ConfigError::ValidationError {
path: "rate_limit_window_secs".to_string(),
source_location: None,
value: Some(self.rate_limit_window_secs.to_string()),
message: "rate limit window must be at least 1 second".to_string(),
});
}
match ConfigErrors::from_vec(errors) {
Some(errs) => Validation::Failure(errs),
None => Validation::Success(()),
}
}
}
/// State handed to the axum handlers via `Router::with_state`.
///
/// Cloning is cheap: the only field is an `Arc`, so a clone shares the same
/// `ServerConfig` rather than copying it.
#[derive(Clone)]
struct AppState {
/// The validated server configuration, shared read-only between handlers.
config: Arc<ServerConfig>,
}
/// JSON body returned by the `/health` endpoint.
#[derive(Serialize)]
struct HealthResponse {
/// Health indicator; the handler in this file always sends `"healthy"`.
status: &'static str,
/// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
version: &'static str,
}
/// JSON body returned by the `{api_prefix}/config` endpoint: a subset of the
/// active `ServerConfig`.
#[derive(Serialize)]
struct ConfigInfoResponse {
/// Configured host.
host: String,
/// Configured port.
port: u16,
/// Prefix under which the API routes are nested.
api_prefix: String,
/// Configured connection limit.
max_connections: u32,
/// `true` when a TLS certificate is configured. The key is never exposed.
tls_enabled: bool,
}
/// `GET /health` — liveness probe.
///
/// Always reports `"healthy"` together with the crate version baked in at
/// compile time; it does not inspect any server state.
async fn health_check() -> Json<HealthResponse> {
    let body = HealthResponse {
        status: "healthy",
        version: env!("CARGO_PKG_VERSION"),
    };
    Json(body)
}
async fn config_info(State(state): State<AppState>) -> Json<ConfigInfoResponse> {
Json(ConfigInfoResponse {
host: state.config.host.clone(),
port: state.config.port,
api_prefix: state.config.api_prefix.clone(),
max_connections: state.config.max_connections,
tls_enabled: state.config.tls_cert.is_some(),
})
}
/// `GET {api_prefix}/` — plain-text greeting that echoes the configured
/// host and port with a `200 OK` status.
async fn api_root(State(state): State<AppState>) -> (StatusCode, String) {
    let cfg = &state.config;
    let greeting = format!(
        "Welcome to the API! Server running on {}:{}",
        cfg.host, cfg.port
    );
    (StatusCode::OK, greeting)
}
/// Assembles the application router.
///
/// `/health` is mounted at the top level; the API root and `/config` routes
/// are nested under the configured `api_prefix`. The given state is attached
/// last so that every handler can extract it.
fn build_router(state: AppState) -> Router {
    // Routes that live under the configurable prefix.
    let prefixed = Router::new()
        .route("/", get(api_root))
        .route("/config", get(config_info));

    // The prefix is only borrowed for the duration of `nest`, so `state` can
    // still be moved into the router afterwards.
    let router = Router::new()
        .route("/health", get(health_check))
        .nest(&state.config.api_prefix, prefixed);

    router.with_state(state)
}
/// Program entry point.
///
/// 1. Loads the configuration from defaults, an optional `config.toml`, and
///    `SERVER_`-prefixed environment variables (later sources are added on
///    top of earlier ones).
/// 2. On validation failure prints every error and exits with status 1.
/// 3. Prints a summary, binds the listener, and serves the router.
///
/// Runtime failures (unresolvable host, port already in use, server I/O
/// error) are reported on stderr and end the process with status 1 instead of
/// panicking, mirroring how configuration errors are handled.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    println!("Loading server configuration...");
    println!();

    let config_result = Config::<ServerConfig>::builder()
        .source(Defaults::from(ServerConfig::default()))
        .source(Toml::file("config.toml").optional())
        .source(Env::prefix("SERVER_"))
        .build();

    let config = match config_result {
        Ok(config) => {
            println!("Configuration validated successfully!");
            println!();
            config
        }
        Err(errors) => {
            eprintln!("Configuration validation failed!");
            eprintln!();
            eprintln!("Found {} configuration error(s):", errors.len());
            eprintln!();
            for (i, error) in errors.iter().enumerate() {
                eprintln!("  {}. {}", i + 1, error);
            }
            eprintln!();
            eprintln!("Please fix the configuration errors and try again.");
            std::process::exit(1);
        }
    };

    // Unwrap the validated config once and share it; this avoids cloning the
    // whole configuration just to hand a copy to the router state.
    let config: Arc<ServerConfig> = Arc::new(config.into_inner());

    println!("Server Configuration:");
    println!("  Host: {}", config.host);
    println!("  Port: {}", config.port);
    println!(
        "  TLS: {}",
        if config.tls_cert.is_some() {
            "enabled"
        } else {
            "disabled"
        }
    );
    println!("  Max Body Size: {} MB", config.max_body_size_mb);
    println!("  Request Timeout: {}s", config.request_timeout_secs);
    println!("  Max Connections: {}", config.max_connections);
    println!("  API Prefix: {}", config.api_prefix);
    println!("  Log Level: {}", config.log_level);
    println!();

    let state = AppState {
        config: Arc::clone(&config),
    };
    let app = build_router(state);

    // Let tokio resolve "host:port". Unlike parsing into a `SocketAddr`, this
    // also accepts host names such as `localhost` and unbracketed IPv6
    // literals, while every value that parsed before still works.
    let bind_target = format!("{}:{}", config.host, config.port);
    let listener = match tokio::net::TcpListener::bind(bind_target.as_str()).await {
        Ok(listener) => listener,
        Err(err) => {
            eprintln!("Failed to bind to {}: {}", bind_target, err);
            std::process::exit(1);
        }
    };

    // Report the address that was actually bound (the resolved IP and port).
    let addr: SocketAddr = match listener.local_addr() {
        Ok(addr) => addr,
        Err(err) => {
            eprintln!("Failed to read the bound address: {}", err);
            std::process::exit(1);
        }
    };

    println!("Starting server on http://{}", addr);
    println!();
    println!("Available endpoints:");
    println!("  GET /health          - Health check");
    println!(
        "  GET {api_prefix}/          - API root",
        api_prefix = config.api_prefix
    );
    println!(
        "  GET {api_prefix}/config    - Configuration info",
        api_prefix = config.api_prefix
    );
    println!();

    if let Err(err) = axum::serve(listener, app).await {
        eprintln!("Server error: {}", err);
        std::process::exit(1);
    }
}
/// Unit tests for the `ServerConfig` validation rules.
#[cfg(test)]
mod tests {
    use super::*;
    use stillwater::Validation;

    // True when validation reports at least one error for `config`.
    fn is_rejected(config: &ServerConfig) -> bool {
        matches!(config.validate(), Validation::Failure(_))
    }

    // True when validation passes for `config`.
    fn is_accepted(config: &ServerConfig) -> bool {
        matches!(config.validate(), Validation::Success(()))
    }

    // The built-in defaults must always be a valid configuration.
    #[test]
    fn test_valid_config() {
        assert!(is_accepted(&ServerConfig::default()));
    }

    // Port 0 is not a usable listen port.
    #[test]
    fn test_invalid_port() {
        let zero_port = ServerConfig {
            port: 0,
            ..Default::default()
        };
        assert!(is_rejected(&zero_port));
    }

    // Certificate and key are only valid as a pair.
    #[test]
    fn test_tls_requires_both_cert_and_key() {
        let cert = || Some("/path/to/cert.pem".to_string());
        let key = || Some("/path/to/key.pem".to_string());

        let cert_only = ServerConfig {
            tls_cert: cert(),
            tls_key: None,
            ..Default::default()
        };
        assert!(is_rejected(&cert_only));

        let key_only = ServerConfig {
            tls_cert: None,
            tls_key: key(),
            ..Default::default()
        };
        assert!(is_rejected(&key_only));

        let both = ServerConfig {
            tls_cert: cert(),
            tls_key: key(),
            ..Default::default()
        };
        assert!(is_accepted(&both));
    }

    // A prefix without the leading slash is rejected.
    #[test]
    fn test_api_prefix_must_start_with_slash() {
        let missing_slash = ServerConfig {
            api_prefix: "api/v1".to_string(),
            ..Default::default()
        };
        assert!(is_rejected(&missing_slash));
    }

    // Only the five known level names are accepted.
    #[test]
    fn test_invalid_log_level() {
        let unknown_level = ServerConfig {
            log_level: "verbose".to_string(),
            ..Default::default()
        };
        assert!(is_rejected(&unknown_level));
    }

    // Validation accumulates errors instead of stopping at the first one.
    #[test]
    fn test_multiple_validation_errors_accumulated() {
        let broken = ServerConfig {
            port: 0,
            host: "".to_string(),
            api_prefix: "no-slash".to_string(),
            log_level: "invalid".to_string(),
            ..Default::default()
        };

        match broken.validate() {
            Validation::Failure(errors) => {
                assert!(
                    errors.len() >= 4,
                    "Expected at least 4 errors, got {}",
                    errors.len()
                );
            }
            _ => panic!("Expected validation failure"),
        }
    }
}