use {crate::{error::{JetError,
Result},
middleware::{RequestIdLayer,
smart_trace_layer}},
axum::Router,
std::net::SocketAddr,
tokio::net::TcpListener,
tower_http::{compression::CompressionLayer,
cors::{Any,
CorsLayer}},
tracing::info};
#[cfg(feature = "metrics")]
use {crate::{metrics::{MetricsGuard,
metrics_router},
middleware::HttpMetricsLayer},
std::sync::Arc};
/// Builder-style wrapper around an axum [`Router`].
///
/// Routes and explicitly added layers are applied to `router` as soon as the
/// corresponding builder method is called. The optional middleware recorded in
/// the remaining fields is only attached when the server is started through
/// `serve` or `serve_with_shutdown`.
pub struct JetServer {
/// Router holding every route, merged/nested router and explicit layer added so far.
router: Router,
/// CORS layer to attach at serve time; `None` means no CORS layer is added.
cors_layer: Option<CorsLayer>,
/// When `true`, a `CompressionLayer` is attached at serve time.
compression_enabled: bool,
/// When `true`, the layer returned by `smart_trace_layer()` is attached at serve time.
tracing_enabled: bool,
/// When `true`, a `RequestIdLayer` is attached at serve time.
request_id_enabled: bool,
}
impl JetServer {
pub fn new() -> Self {
Self {
router: Router::new(),
cors_layer: None,
compression_enabled: false,
tracing_enabled: false,
request_id_enabled: false,
}
}
pub fn route(mut self, path: &str, method_router: axum::routing::MethodRouter) -> Self {
self.router = self.router.route(path, method_router);
self
}
pub fn merge(mut self, router: Router) -> Self {
self.router = self.router.merge(router);
self
}
pub fn nest(mut self, path: &str, router: Router) -> Self {
self.router = self.router.nest(path, router);
self
}
pub fn with_cors(mut self) -> Self {
self.cors_layer = Some(CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any));
self
}
pub fn with_cors_config(mut self, cors: CorsLayer) -> Self {
self.cors_layer = Some(cors);
self
}
pub fn with_compression(mut self) -> Self {
self.compression_enabled = true;
self
}
pub fn with_request_id(mut self) -> Self {
self.request_id_enabled = true;
self
}
pub fn with_tracing(mut self) -> Self {
self.tracing_enabled = true;
self
}
#[cfg(feature = "metrics")]
pub fn with_metrics(self, guard: Arc<MetricsGuard>, path: &str) -> Self {
self.merge(metrics_router(guard, path)).layer(HttpMetricsLayer::new())
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<axum::routing::Route> + Clone + Send + 'static,
L::Service: tower::Service<axum::extract::Request, Response = axum::response::Response, Error = std::convert::Infallible>
+ Clone
+ Send
+ 'static,
<L::Service as tower::Service<axum::extract::Request>>::Future: Send + 'static, {
self.router = self.router.layer(layer);
self
}
pub fn into_router(self) -> Router {
self.router
}
pub async fn serve(self, addr: &str) -> Result<()> {
let mut router = self.router;
if let Some(cors) = self.cors_layer {
router = router.layer(cors);
}
if self.compression_enabled {
router = router.layer(CompressionLayer::new());
}
if self.request_id_enabled {
router = router.layer(RequestIdLayer::new());
}
if self.tracing_enabled {
router = router.layer(smart_trace_layer());
}
let addr: SocketAddr = addr
.parse()
.map_err(|e| JetError::ServerBind(format!("Invalid address: {}", e)))?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| JetError::ServerBind(format!("Failed to bind: {}", e)))?;
info!("AT-Jet server listening on {}", addr);
axum::serve(listener, router)
.await
.map_err(|e| JetError::ServerBind(format!("Server error: {}", e)))?;
Ok(())
}
pub async fn serve_with_shutdown(self, addr: &str) -> Result<()> {
let mut router = self.router;
if let Some(cors) = self.cors_layer {
router = router.layer(cors);
}
if self.compression_enabled {
router = router.layer(CompressionLayer::new());
}
if self.request_id_enabled {
router = router.layer(RequestIdLayer::new());
}
if self.tracing_enabled {
router = router.layer(smart_trace_layer());
}
let addr: SocketAddr = addr
.parse()
.map_err(|e| JetError::ServerBind(format!("Invalid address: {}", e)))?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| JetError::ServerBind(format!("Failed to bind: {}", e)))?;
info!("AT-Jet server listening on {}", addr);
let http_handle = tokio::spawn(async move { axum::serve(listener, router).await });
tokio::select! {
_ = tokio::signal::ctrl_c() => {
info!("Received shutdown signal");
}
result = http_handle => {
match result {
| Ok(Ok(())) => info!("HTTP server exited"),
| Ok(Err(e)) => tracing::error!(error = %e, "HTTP server failed"),
| Err(e) => tracing::error!(error = %e, "HTTP server task panicked"),
}
}
}
Ok(())
}
#[deprecated(since = "0.8.0", note = "Use StartupBanner from at_jet::startup instead")]
pub fn print_banner(service_name: &str, version: &str, extras: &[(&str, &str)]) {
let mut banner = crate::startup::StartupBanner::new(service_name, version);
for (key, value) in extras {
banner = banner.kv(key, value);
}
banner.print();
}
}
impl Default for JetServer {
fn default() -> Self {
Self::new()
}
}