1use std::{future::Future, net::SocketAddr, sync::Arc};
2
3use axum::Router;
4
5use crate::{
6 core::{CoreError, Service, ServiceFuture, ShutdownToken},
7 rest::{RestConfig, RestLayerStack},
8};
9
10#[derive(Debug, Clone)]
12pub struct RestServer {
13 config: RestConfig,
14 router: Router,
15}
16
17impl RestServer {
18 pub fn new(config: RestConfig, router: Router) -> Self {
20 Self { config, router }
21 }
22
23 pub fn raw_router(&self) -> Router {
25 self.router.clone()
26 }
27
28 pub fn into_router(self) -> Router {
30 RestLayerStack::new(self.config).layer(self.router)
31 }
32
33 pub async fn serve_with_shutdown<F>(self, addr: SocketAddr, shutdown: F) -> std::io::Result<()>
35 where
36 F: Future<Output = ()> + Send + 'static,
37 {
38 let listener = tokio::net::TcpListener::bind(addr).await?;
39 axum::serve(listener, self.into_router())
40 .with_graceful_shutdown(shutdown)
41 .await
42 }
43}
44
45pub struct RestService {
47 name: String,
48 addr: SocketAddr,
49 server: std::sync::Mutex<Option<RestServer>>,
50}
51
52impl RestService {
53 pub fn new(name: impl Into<String>, addr: SocketAddr, server: RestServer) -> Self {
55 Self {
56 name: name.into(),
57 addr,
58 server: std::sync::Mutex::new(Some(server)),
59 }
60 }
61
62 pub fn addr(&self) -> SocketAddr {
64 self.addr
65 }
66}
67
68impl Service for RestService {
69 fn name(&self) -> &str {
70 &self.name
71 }
72
73 fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
74 Box::pin(async move {
75 let server = self
76 .server
77 .lock()
78 .expect("rest service mutex")
79 .take()
80 .ok_or_else(|| {
81 CoreError::Service(format!("service {} already started", self.name))
82 })?;
83 server
84 .serve_with_shutdown(self.addr, async move {
85 shutdown.cancelled().await;
86 })
87 .await
88 .map_err(|error| {
89 CoreError::Service(format!("REST service {} failed: {error}", self.name))
90 })
91 })
92 }
93}
94
95impl Service for Arc<RestService> {
96 fn name(&self) -> &str {
97 self.as_ref().name()
98 }
99
100 fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
101 self.as_ref().start(shutdown)
102 }
103
104 fn stop(&self) -> ServiceFuture<'_> {
105 self.as_ref().stop()
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::RestServer;
    use crate::rest::{ApiResponse, RestConfig};
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        routing::get,
    };
    use tower::ServiceExt;

    /// Routes registered on the raw router must stay reachable after `into_router` layers them.
    #[tokio::test]
    async fn server_builds_router() {
        let routes = Router::new().route("/ready", get(|| async { ApiResponse::success("ok") }));
        let app = RestServer::new(RestConfig::default(), routes).into_router();

        let request = Request::builder()
            .uri("/ready")
            .body(Body::empty())
            .expect("request");
        let response = app.oneshot(request).await.expect("response");

        assert_eq!(StatusCode::OK, response.status());
    }
}