1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
/// Boxed, thread-safe error type used as the error side of this module's fallible functions.
type DynError = Box<dyn std::error::Error + Send + Sync>;
6use crate::util::{
7 io::{self, create_dual_stack_listener},
8 tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12 extract::Request,
13 response::{IntoResponse, Response},
14 Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23
/// How long the TLS refresh task sleeps between attempts to reload the certificate (24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
26
27pub struct Server<I: ReqInterceptor = DummyInterceptor> {
28 pub port: u16,
29 pub tls_param: Option<TlsParam>,
30 router: Router,
31 pub interceptor: Option<I>,
32 pub idle_timeout: Duration,
33}
34
/// TLS settings for a [`Server`].
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Whether TLS should actually be enabled.
    pub tls: bool,
    /// Certificate argument passed to `tls_config` (presumably a file path — confirm in `util::tls`).
    pub cert: String,
    /// Private-key argument passed to `tls_config` (presumably a file path — confirm in `util::tls`).
    pub key: String,
}
41
42pub enum InterceptResult<T: IntoResponse> {
43 Return(Response),
44 Drop,
45 Continue(Request<Incoming>),
46 Error(T),
47}
48
49pub trait ReqInterceptor: Send {
50 type Error: IntoResponse + Send + Sync + 'static;
51 fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
52}
53
54#[derive(Clone)]
55pub struct DummyInterceptor;
56
57impl ReqInterceptor for DummyInterceptor {
58 type Error = AppError;
59
60 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
61 InterceptResult::Continue(req)
62 }
63}
64
65pub type DefaultServer = Server<DummyInterceptor>;
66
67pub fn new_server(port: u16, router: Router) -> Server {
68 Server {
69 port,
70 tls_param: None, router,
72 interceptor: None,
73 idle_timeout: Duration::from_secs(120),
74 }
75}
76
77impl<I> Server<I>
78where
79 I: ReqInterceptor + Clone + Send + Sync + 'static,
80{
81 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
82 where
83 R: ReqInterceptor + Clone + Send + Sync + 'static,
84 {
85 Server::<R> {
86 port: self.port,
87 tls_param: self.tls_param,
88 router: self.router,
89 interceptor: Some(interceptor),
90 idle_timeout: self.idle_timeout, }
92 }
93 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
94 self.tls_param = tls_param;
96 self
97 }
98
99 pub fn with_timeout(mut self, timeout: Duration) -> Self {
100 self.idle_timeout = timeout;
101 self
102 }
103
104 pub async fn run(&self) -> Result<(), DynError> {
105 let use_tls = match self.tls_param.clone() {
106 Some(config) => config.tls,
107 None => false,
108 };
109 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
110 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
111 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
112 match use_tls {
113 #[allow(clippy::expect_used)]
114 true => {
115 serve_tls(
116 &self.router,
117 server,
118 graceful,
119 self.port,
120 self.tls_param.as_ref().expect("should be some"),
121 self.interceptor.clone(),
122 self.idle_timeout,
123 )
124 .await?
125 }
126 false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
127 }
128 Ok(())
129 }
130}
131
132async fn handle<I>(
133 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
134 interceptor: Option<I>,
135) -> std::result::Result<Response, std::io::Error>
136where
137 I: ReqInterceptor + Clone + Send + Sync + 'static,
138{
139 if let Some(interceptor) = interceptor {
140 match interceptor.intercept(request, client_socket_addr).await {
141 InterceptResult::Return(res) => Ok(res),
142 InterceptResult::Drop => Err(std::io::Error::new(std::io::ErrorKind::Other, "Request dropped by interceptor")),
143 InterceptResult::Continue(req) => app
144 .oneshot(req)
145 .await
146 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
147 InterceptResult::Error(err) => {
148 let res = err.into_response();
149 Ok(res)
150 }
151 }
152 } else {
153 app.oneshot(request)
154 .await
155 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
156 }
157}
158
159async fn handle_connection<C, I>(
160 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
161 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
162) where
163 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
164 I: ReqInterceptor + Clone + Send + Sync + 'static,
165{
166 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
167 use hyper::Request;
168 use hyper_util::rt::TokioIo;
169 let stream = TokioIo::new(timeout_io);
170 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
171 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
172 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
174 handle(request, client_socket_addr, app.clone(), interceptor.clone())
175 });
176
177 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
178 let conn = graceful.watch(conn.into_owned());
179
180 tokio::spawn(async move {
181 if let Err(err) = conn.await {
182 info!("connection error: {}", err);
183 }
184 log::debug!("connection dropped: {}", client_socket_addr);
185 });
186}
187
188async fn serve_plantext<I>(
189 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
190 port: u16, interceptor: Option<I>, timeout: Duration,
191) -> Result<(), DynError>
192where
193 I: ReqInterceptor + Clone + Send + Sync + 'static,
194{
195 let listener = create_dual_stack_listener(port).await?;
196 let signal = handle_signal();
197 pin!(signal);
198 loop {
199 tokio::select! {
200 _ = signal.as_mut() => {
201 info!("start graceful shutdown!");
202 drop(listener);
203 break;
204 }
205 conn = listener.accept() => {
206 match conn {
207 Ok((conn, client_socket_addr)) => {
208 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
209 Err(e) => {
210 warn!("accept error:{}", e);
211 }
212 }
213 }
214 }
215 }
216 tokio::select! {
217 _ = graceful.shutdown() => {
218 info!("Gracefully shutdown!");
219 },
220 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
221 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
222 }
223 }
224 Ok(())
225}
226
227async fn serve_tls<I>(
228 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
229 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
230) -> Result<(), DynError>
231where
232 I: ReqInterceptor + Clone + Send + Sync + 'static,
233{
234 let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
235 let tx_clone = tx.clone();
236 let tls_param_clone = tls_param.clone();
237 tokio::spawn(async move {
238 info!("update tls config every {:?}", REFRESH_INTERVAL);
239 loop {
240 time::sleep(REFRESH_INTERVAL).await;
241 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
242 info!("update tls config");
243 if let Err(e) = tx.send(new_acceptor) {
244 warn!("send tls config error:{}", e);
245 }
246 }
247 }
248 });
249 let mut rx = tx_clone.subscribe();
250 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
251 let signal = handle_signal();
252 pin!(signal);
253 loop {
254 tokio::select! {
255 _ = signal.as_mut() => {
256 info!("start graceful shutdown!");
257 drop(acceptor);
258 break;
259 }
260 message = rx.recv() => {
261 #[allow(clippy::expect_used)]
262 let new_config = message.expect("Channel should not be closed");
263 acceptor.replace_config(new_config);
265 info!("replaced tls config");
266 }
267 conn = acceptor.accept() => {
268 match conn {
269 Ok((conn, client_socket_addr)) => {
270 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
271 Err(e) => {
272 warn!("accept error:{}", e);
273 }
274 }
275 }
276 }
277 }
278 tokio::select! {
279 _ = graceful.shutdown() => {
280 info!("Gracefully shutdown!");
281 },
282 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
283 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
284 }
285 }
286 Ok(())
287}
288
289#[cfg(unix)]
290async fn handle_signal() -> Result<(), DynError> {
291 use log::info;
292 use tokio::signal::unix::{signal, SignalKind};
293 let mut terminate_signal = signal(SignalKind::terminate())?;
294 tokio::select! {
295 _ = terminate_signal.recv() => {
296 info!("receive terminate signal, shutdowning");
297 },
298 _ = tokio::signal::ctrl_c() => {
299 info!("ctrl_c => shutdowning");
300 },
301 };
302 Ok(())
303}
304
/// Resolves when the process receives Ctrl-C; always returns `Ok(())`.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    // A failure to listen for Ctrl-C is deliberately ignored here.
    tokio::signal::ctrl_c().await.ok();
    info!("ctrl_c => shutdowning");
    Ok(())
}
311
312#[derive(Debug)]
314pub struct AppError(anyhow::Error);
315
316impl Display for AppError {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 self.0.fmt(f)
319 }
320}
321
322impl IntoResponse for AppError {
324 fn into_response(self) -> Response {
325 let err = self.0;
326 tracing::error!(%err, "error");
329 (StatusCode::INTERNAL_SERVER_ERROR, format!("axum-bootstrap error: {}", &err)).into_response()
330 }
331}
332
333impl<E> From<E> for AppError
336where
337 E: Into<anyhow::Error>,
338{
339 fn from(err: E) -> Self {
340 Self(err.into())
341 }
342}
343
344impl AppError {
345 pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
346 Self(anyhow!(err))
347 }
348}
349
/// Extracts the value from a `Result` whose error type can never be constructed.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` has no values, so the empty match proves the error branch is unreachable.
    result.unwrap_or_else(|never| match never {})
}