use crate::config::ServerConfig;
use axum::Router;
use std::net::SocketAddr;
#[cfg(feature = "server")]
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
/// Builder that gathers everything needed to start an HTTP server: the
/// application state, a closure that turns that state into a [`Router`], the
/// bind address, and optional tracing filter directives.
///
/// `S` is the application state type and defaults to `()`.
pub struct ServerBootstrap<S = ()> {
/// Application state handed to the router builder by `serve`; `serve`
/// fails with an error when this is `None`.
state: Option<S>,
/// One-shot closure that receives the state and produces the [`Router`]
/// to serve; `serve` fails with an error when this is `None`.
router_builder: Option<Box<dyn FnOnce(S) -> Router + Send>>,
/// Host part of the bind address. Constructors default it to `"0.0.0.0"`.
host: String,
/// Port part of the bind address. Constructors default it to `8080`.
port: u16,
/// Starts out `false`; set by `serve` once it has installed the tracing
/// subscriber.
tracing_initialized: bool,
/// Filter directives used by `serve` as a fallback when no filter can be
/// built from the environment. `None` means `serve` sets up no tracing.
tracing_filter: Option<String>,
}
impl ServerBootstrap<()> {
    /// Creates a bootstrap for a server that carries no application state.
    ///
    /// Defaults: host `0.0.0.0`, port `8080`, no router builder and no
    /// tracing filter.
    ///
    /// The unit state is stored right away. `serve` refuses to run when the
    /// state is missing and no builder method can supply it afterwards, so
    /// leaving it empty here would make every stateless bootstrap fail with
    /// "State is required".
    pub fn new() -> Self {
        Self {
            state: Some(()),
            router_builder: None,
            host: "0.0.0.0".to_string(),
            port: 8080,
            tracing_initialized: false,
            tracing_filter: None,
        }
    }
}
impl<S> ServerBootstrap<S>
where
    S: Clone + Send + 'static,
{
    /// Creates a bootstrap that owns `state`.
    ///
    /// Defaults: host `0.0.0.0`, port `8080`, no router builder and no
    /// tracing filter.
    pub fn with_state(state: S) -> Self {
        Self {
            state: Some(state),
            router_builder: None,
            host: "0.0.0.0".to_string(),
            port: 8080,
            tracing_initialized: false,
            tracing_filter: None,
        }
    }

    /// Sets the host part of the bind address.
    ///
    /// `serve` accepts an IPv4 literal (`127.0.0.1`) or an IPv6 literal, with
    /// or without brackets (`::1`, `[::1]`).
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the port part of the bind address.
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Copies `host` and `port` from `config`, replacing earlier values.
    ///
    /// Only those two fields are read; the rest of `config` is dropped.
    pub fn with_server_config(mut self, config: ServerConfig) -> Self {
        self.host = config.host;
        self.port = config.port;
        self
    }

    /// Requests tracing with `debug` level for `app_name` and `tower_http`.
    ///
    /// `app_name` is inserted verbatim as the filter target. The resulting
    /// filter is only a fallback: `serve` prefers a filter built from the
    /// environment when one is available.
    pub fn with_tracing(mut self, app_name: &str) -> Self {
        let filter = format!("{}=debug,tower_http=debug", app_name);
        self.tracing_filter = Some(filter);
        self
    }

    /// Requests tracing with custom filter directives.
    ///
    /// As with [`Self::with_tracing`], `serve` uses these directives only
    /// when no filter can be built from the environment.
    pub fn with_tracing_filter(mut self, filter: impl Into<String>) -> Self {
        self.tracing_filter = Some(filter.into());
        self
    }

    /// Registers the closure that turns the application state into the
    /// [`Router`] to serve. A later call replaces an earlier builder.
    pub fn with_router<F>(mut self, builder: F) -> Self
    where
        F: FnOnce(S) -> Router + Send + 'static,
    {
        self.router_builder = Some(Box::new(builder));
        self
    }

    /// Builds the router, binds the listener and serves until the server
    /// stops.
    ///
    /// With the `server` feature enabled this first loads a `.env` file (a
    /// missing file is ignored) and then, if a tracing filter was requested,
    /// installs a global tracing subscriber. A subscriber that is already
    /// installed is left in place.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - no state is present (`"State is required"`),
    /// - no router builder was registered (`"Router builder is required"`),
    /// - host and port do not form a valid socket address,
    /// - binding the listener fails, or
    /// - `axum::serve` returns an error.
    pub async fn serve(mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Load `.env` before building the env filter, so that a `RUST_LOG`
        // defined there is honoured by `try_from_default_env` below.
        #[cfg(feature = "server")]
        dotenvy::dotenv().ok();

        #[cfg(feature = "server")]
        if let Some(filter) = self.tracing_filter.take()
            && !self.tracing_initialized
        {
            let env_filter =
                EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter));
            // `try_init` instead of `init`: the per-instance flag cannot know
            // whether some other code already installed a global subscriber,
            // and `init` would panic in that case.
            let installed = tracing_subscriber::registry()
                .with(env_filter)
                .with(tracing_subscriber::fmt::layer())
                .try_init();
            match installed {
                Ok(()) => self.tracing_initialized = true,
                Err(err) => {
                    tracing::debug!("Tracing subscriber not installed by bootstrap: {}", err)
                }
            }
        }

        let state = self.state.ok_or("State is required")?;
        let router_builder = self
            .router_builder
            .ok_or("Router builder is required")?;
        let app = router_builder(state);

        // A bare IP literal is combined with the port directly; this is what
        // makes unbracketed IPv6 hosts such as `::1` work. Anything else goes
        // through the textual `host:port` form (e.g. a bracketed `[::1]`).
        let addr = match self.host.parse::<std::net::IpAddr>() {
            Ok(ip) => SocketAddr::new(ip, self.port),
            Err(_) => format!("{}:{}", self.host, self.port)
                .parse::<SocketAddr>()
                .map_err(|e| format!("Invalid socket address: {}", e))?,
        };

        tracing::info!("Starting server on {}", addr);
        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, app).await?;
        Ok(())
    }
}
impl Default for ServerBootstrap<()> {
fn default() -> Self {
Self::new()
}
}
/// Serves a ready-made `router` on `port` without any application state.
///
/// Convenience wrapper around [`ServerBootstrap`]: it keeps the default host,
/// requests tracing for `app_name` via `with_tracing`, and runs the server
/// until it stops.
///
/// # Errors
///
/// Returns whatever error `ServerBootstrap::serve` reports.
pub async fn quick_start(
    app_name: &str,
    router: Router,
    port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // The state is the unit type, so the router builder has nothing to read
    // and simply hands back the router it captured.
    let bootstrap = ServerBootstrap::with_state(())
        .with_port(port)
        .with_tracing(app_name)
        .with_router(move |()| router);

    bootstrap.serve().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The host and port setters override the constructor defaults.
    #[test]
    fn test_server_bootstrap_builder() {
        let configured = ServerBootstrap::new()
            .with_port(3000)
            .with_host("127.0.0.1");

        assert_eq!(
            (configured.host.as_str(), configured.port),
            ("127.0.0.1", 3000)
        );
    }

    /// `with_state` stores the value it is given.
    #[test]
    fn test_server_bootstrap_with_state() {
        #[derive(Clone)]
        struct TestState {
            value: i32,
        }

        let configured = ServerBootstrap::with_state(TestState { value: 42 });

        let stored_value = configured.state.map(|s| s.value);
        assert_eq!(stored_value, Some(42));
    }

    /// `with_server_config` takes over host and port from the config.
    #[test]
    fn test_server_config_integration() {
        let configured = ServerBootstrap::new().with_server_config(ServerConfig {
            host: String::from("192.168.1.1"),
            port: 9000,
            cors_origins: vec![],
        });

        assert_eq!(
            (configured.host.as_str(), configured.port),
            ("192.168.1.1", 9000)
        );
    }
}