use std::net::SocketAddr;
use axum::Router as AxumRouter;
use tokio::net::TcpListener;
use crate::error::{ServiceError, ServiceResult};
use crate::router::ServiceRouter;
#[cfg(feature = "cache")]
use crate::middleware::cache::{CacheConfig, CacheLayer};
use super::ServerConfig;
pub struct AxumServer {
app: AxumRouter,
config: ServerConfig,
}
impl AxumServer {
pub fn new(router: ServiceRouter) -> Self {
Self::with_config(router, ServerConfig::default())
}
pub fn with_config(router: ServiceRouter, config: ServerConfig) -> Self {
let connect_router = router.into_inner();
let app = connect_router.into_axum_router();
Self { app, config }
}
pub fn with_app(mut self, app: AxumRouter) -> Self {
self.app = app;
self
}
pub fn merge(mut self, other: AxumRouter) -> Self {
self.app = self.app.merge(other);
self
}
#[cfg(feature = "cache")]
pub fn with_cache(mut self, config: CacheConfig) -> Self {
self.app = self.app.layer(CacheLayer::new(config));
self
}
pub fn config(&self) -> &ServerConfig {
&self.config
}
pub fn app(&self) -> AxumRouter {
self.app.clone()
}
pub async fn serve(self) -> ServiceResult<()> {
let listener = TcpListener::bind(self.config.addr)
.await
.map_err(|e| ServiceError::Internal(format!("bind {}: {}", self.config.addr, e)))?;
self.serve_with_listener(listener).await
}
pub async fn serve_with_listener(self, listener: TcpListener) -> ServiceResult<()> {
axum::serve(listener, self.app)
.await
.map_err(|e| ServiceError::Internal(format!("axum::serve: {}", e)))
}
pub async fn serve_with_shutdown<F>(self, shutdown: F) -> ServiceResult<()>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let listener = TcpListener::bind(self.config.addr)
.await
.map_err(|e| ServiceError::Internal(format!("bind {}: {}", self.config.addr, e)))?;
axum::serve(listener, self.app)
.with_graceful_shutdown(shutdown)
.await
.map_err(|e| ServiceError::Internal(format!("axum::serve: {}", e)))
}
}
/// Binds a TCP listener on `host` with an OS-assigned port (port `0`).
///
/// `host` may be:
/// - an IPv4 literal;
/// - an IPv6 literal, with or without surrounding brackets (`"::1"` or `"[::1]"`);
/// - a hostname to resolve.
///
/// Returns the listener together with the address it actually bound, so the
/// caller can learn the chosen port.
///
/// # Errors
///
/// Returns `ServiceError::Internal` if binding fails or the bound local
/// address cannot be read back.
pub async fn bind_random_port(host: &str) -> ServiceResult<(TcpListener, SocketAddr)> {
    // Pass a `(host, port)` pair instead of formatting `"{host}:0"`. For an
    // unbracketed IPv6 literal the string form becomes `"::1:0"`, which does
    // not parse as a socket address. Brackets are stripped so callers that
    // already passed `"[::1]"` keep working.
    let bare_host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    let listener = TcpListener::bind((bare_host, 0))
        .await
        .map_err(|e| ServiceError::Internal(format!("bind {host}:0: {}", e)))?;
    let addr = listener
        .local_addr()
        .map_err(|e| ServiceError::Internal(format!("local_addr: {e}")))?;
    Ok((listener, addr))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A server built with `new` uses the default configuration name.
    #[test]
    fn test_axum_server_new() {
        let server = AxumServer::new(ServiceRouter::new());
        let config = server.config();
        assert_eq!(config.name, "sunbeam-server");
    }

    /// Binding to port 0 must yield a concrete, OS-assigned port.
    #[tokio::test]
    async fn test_bind_random_port_assigns_nonzero() {
        let bound = bind_random_port("127.0.0.1").await.unwrap();
        let port = bound.1.port();
        assert_ne!(port, 0);
    }

    /// Serves an empty router, checks that an unknown path yields a 4xx/5xx
    /// status, then triggers graceful shutdown and waits for the task to finish.
    #[tokio::test]
    async fn test_serve_and_shutdown() {
        let (listener, addr) = bind_random_port("127.0.0.1").await.unwrap();
        let server = AxumServer::new(ServiceRouter::new());

        let (stop_tx, stop_rx) = tokio::sync::oneshot::channel::<()>();
        let shutdown_signal = async move {
            let _ = stop_rx.await;
        };
        let server_task = tokio::spawn(async move {
            axum::serve(listener, server.app)
                .with_graceful_shutdown(shutdown_signal)
                .await
                .unwrap();
        });

        let target = format!("http://{addr}/__doesntexist");
        let response = reqwest::get(&target).await.unwrap();
        let status = response.status();
        // 400..=599 covers both the client-error and server-error ranges.
        assert!(matches!(status.as_u16(), 400..=599));

        let _ = stop_tx.send(());
        server_task.await.unwrap();
    }
}