use std::{future::Future, net::SocketAddr, sync::Arc};
use axum::Router;
use crate::{
core::{CoreError, Service, ServiceFuture, ShutdownToken},
rest::{RestConfig, RestLayerStack},
};
/// An axum [`Router`] paired with the [`RestConfig`] used to build its middleware.
///
/// The routes are stored without any layers applied; `into_router` wraps them in a
/// [`RestLayerStack`] built from `config` when the server is turned into a servable router.
#[derive(Clone, Debug)]
pub struct RestServer {
    config: RestConfig,
    router: Router,
}
impl RestServer {
pub fn new(config: RestConfig, router: Router) -> Self {
Self { config, router }
}
pub fn raw_router(&self) -> Router {
self.router.clone()
}
pub fn into_router(self) -> Router {
RestLayerStack::new(self.config).layer(self.router)
}
pub async fn serve_with_shutdown<F>(self, addr: SocketAddr, shutdown: F) -> std::io::Result<()>
where
F: Future<Output = ()> + Send + 'static,
{
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, self.into_router())
.with_graceful_shutdown(shutdown)
.await
}
}
/// A named service that serves a [`RestServer`] on a fixed socket address.
///
/// The server is kept in a mutex-guarded `Option` so that `start`, which only has
/// `&self`, can move it out; once taken, the slot stays `None`.
#[derive(Debug)]
pub struct RestService {
    name: String,
    addr: SocketAddr,
    // `Some` until `start` takes the server out; `None` afterwards.
    server: std::sync::Mutex<Option<RestServer>>,
}
impl RestService {
pub fn new(name: impl Into<String>, addr: SocketAddr, server: RestServer) -> Self {
Self {
name: name.into(),
addr,
server: std::sync::Mutex::new(Some(server)),
}
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
}
impl Service for RestService {
    fn name(&self) -> &str {
        &self.name
    }

    /// Moves the server out of its slot and serves it on the configured address
    /// until `shutdown` is cancelled.
    ///
    /// A `RestService` can be started at most once. The server is taken before the
    /// listener is bound, so even a failed start (e.g. the address is in use) leaves
    /// the slot empty, and later calls report "already started".
    ///
    /// # Errors
    ///
    /// Returns `CoreError::Service` if the service was already started, or if binding
    /// or serving fails.
    fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
        Box::pin(async move {
            // Recover from poisoning instead of panicking. The guarded value is a plain
            // `Option` and `take` cannot leave it half-updated. The guard is a temporary
            // of this statement, so the std mutex is released before any `.await` below.
            let taken = self
                .server
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            let server = taken.ok_or_else(|| {
                CoreError::Service(format!("service {} already started", self.name))
            })?;
            server
                .serve_with_shutdown(self.addr, async move {
                    shutdown.cancelled().await;
                })
                .await
                .map_err(|error| {
                    CoreError::Service(format!("REST service {} failed: {error}", self.name))
                })
        })
    }
}
/// Lets a shared `Arc<RestService>` be used wherever a `Service` is expected.
/// Each method forwards to the `RestService` behind the `Arc`.
impl Service for Arc<RestService> {
    fn name(&self) -> &str {
        <RestService as Service>::name(self)
    }

    fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
        <RestService as Service>::start(self, shutdown)
    }

    fn stop(&self) -> ServiceFuture<'_> {
        <RestService as Service>::stop(self)
    }
}
#[cfg(test)]
mod tests {
    use super::RestServer;
    use crate::rest::{ApiResponse, RestConfig};
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        routing::get,
    };
    use tower::ServiceExt;

    /// The layered router returned by `into_router` still dispatches to the inner routes.
    #[tokio::test]
    async fn server_builds_router() {
        let routes = Router::new().route("/ready", get(|| async { ApiResponse::success("ok") }));
        let app = RestServer::new(RestConfig::default(), routes).into_router();

        let request = Request::builder()
            .uri("/ready")
            .body(Body::empty())
            .expect("request");
        let response = app.oneshot(request).await.expect("response");

        assert_eq!(StatusCode::OK, response.status());
    }
}