1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
/// Boxed, thread-safe error type returned by `Server::run`, the serve loops and `handle_signal`.
type DynError = Box<dyn std::error::Error + Send + Sync>;
6use crate::util::{
7 io::{self, create_dual_stack_listener},
8 tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12 extract::Request,
13 response::{IntoResponse, Response},
14 Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23
/// How often the background task in `serve_tls` re-reads the certificate and key (once a day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections to finish.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);
26
/// An HTTP server around an axum [`Router`], optionally using TLS and a request interceptor.
///
/// Create one with `new_server`, adjust it with the `with_*` methods, then call `run`.
pub struct Server<I: ReqInterceptor = DummyInterceptor> {
    /// Port the listener is bound to.
    pub port: u16,
    /// TLS settings; `run` only serves TLS when this is `Some` and its `tls` flag is `true`.
    pub tls_param: Option<TlsParam>,
    /// Routes that receive every request the interceptor lets through.
    router: Router,
    /// Optional hook run before routing; `None` sends every request straight to `router`.
    pub interceptor: Option<I>,
    /// Timeout handed to each connection's `io::TimeoutIO` wrapper (`new_server` uses 120s).
    pub idle_timeout: Duration,
}
34
/// TLS configuration for a [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Whether TLS is enabled; when `false`, `Server::run` serves plain text even if this is set.
    pub tls: bool,
    /// Certificate passed to `tls_config` (initially and on every refresh).
    /// NOTE(review): whether this is a file path or PEM contents is decided in `util::tls` — confirm there.
    pub cert: String,
    /// Private key passed to `tls_config` (initially and on every refresh).
    /// NOTE(review): path vs. PEM contents is decided in `util::tls` — confirm there.
    pub key: String,
}
41
/// Outcome of [`ReqInterceptor::intercept`], deciding what happens to a request.
pub enum InterceptResult {
    /// Answer immediately with this response; the router is not invoked.
    Return(Response),
    /// Forward this (possibly modified) request to the router.
    Continue(Request<Incoming>),
    /// Answer with this error rendered through `IntoResponse` (an HTTP 500).
    Error(AppError),
}
47
/// Hook that sees every request, together with the client's socket address, before routing.
pub trait ReqInterceptor {
    /// Decides whether to forward `req` to the router, answer it directly, or fail it.
    ///
    /// `ip` is the peer address of the connection the request arrived on. The returned
    /// future must be `Send`.
    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult> + Send;
}
51
52#[derive(Clone)]
53pub struct DummyInterceptor;
54
55impl ReqInterceptor for DummyInterceptor {
56 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult {
57 InterceptResult::Continue(req)
58 }
59}
60
61pub type DefaultServer = Server<DummyInterceptor>;
62
63pub fn new_server(port: u16, router: Router) -> Server {
64 Server {
65 port,
66 tls_param: None, router,
68 interceptor: None,
69 idle_timeout: Duration::from_secs(120),
70 }
71}
72
73impl<I> Server<I>
74where
75 I: ReqInterceptor + Clone + Send + Sync + 'static,
76{
77 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
78 where
79 R: ReqInterceptor + Clone + Send + Sync + 'static,
80 {
81 Server::<R> {
82 port: self.port,
83 tls_param: self.tls_param,
84 router: self.router,
85 interceptor: Some(interceptor),
86 idle_timeout: self.idle_timeout, }
88 }
89 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
90 self.tls_param = tls_param;
92 self
93 }
94
95 pub fn with_timeout(mut self, timeout: Duration) -> Self {
96 self.idle_timeout = timeout;
97 self
98 }
99
100 pub async fn run(&self) -> Result<(), DynError> {
101 let use_tls = match self.tls_param.clone() {
102 Some(config) => config.tls,
103 None => false,
104 };
105 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
106 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
107 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
108 match use_tls {
109 #[allow(clippy::expect_used)]
110 true => {
111 serve_tls(
112 &self.router,
113 server,
114 graceful,
115 self.port,
116 self.tls_param.as_ref().expect("should be some"),
117 self.interceptor.clone(),
118 self.idle_timeout,
119 )
120 .await?
121 }
122 false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
123 }
124 Ok(())
125 }
126}
127
128async fn handle<I>(
129 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
130 interceptor: Option<I>,
131) -> std::result::Result<Response, std::convert::Infallible>
132where
133 I: ReqInterceptor + Clone + Send + Sync + 'static,
134{
135 if let Some(interceptor) = interceptor {
136 match interceptor.intercept(request, client_socket_addr).await {
137 InterceptResult::Continue(req) => app.oneshot(req).await,
138 InterceptResult::Return(res) => Ok(res),
139 InterceptResult::Error(err) => {
140 let res = err.into_response();
141 Ok(res)
142 }
143 }
144 } else {
145 app.oneshot(request).await
146 }
147}
148
149async fn handle_connection<C, I>(
150 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
151 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
152) where
153 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
154 I: ReqInterceptor + Clone + Send + Sync + 'static,
155{
156 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
157 use hyper::Request;
158 use hyper_util::rt::TokioIo;
159 let stream = TokioIo::new(timeout_io);
160 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
161 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
162 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
164 handle(request, client_socket_addr, app.clone(), interceptor.clone())
165 });
166
167 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
168 let conn = graceful.watch(conn.into_owned());
169
170 tokio::spawn(async move {
171 if let Err(err) = conn.await {
172 info!("connection error: {}", err);
173 }
174 log::debug!("connection dropped: {}", client_socket_addr);
175 });
176}
177
178async fn serve_plantext<I>(
179 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
180 port: u16, interceptor: Option<I>, timeout: Duration,
181) -> Result<(), DynError>
182where
183 I: ReqInterceptor + Clone + Send + Sync + 'static,
184{
185 let listener = create_dual_stack_listener(port).await?;
186 let signal = handle_signal();
187 pin!(signal);
188 loop {
189 tokio::select! {
190 _ = signal.as_mut() => {
191 info!("start graceful shutdown!");
192 drop(listener);
193 break;
194 }
195 conn = listener.accept() => {
196 match conn {
197 Ok((conn, client_socket_addr)) => {
198 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
199 Err(e) => {
200 warn!("accept error:{}", e);
201 }
202 }
203 }
204 }
205 }
206 tokio::select! {
207 _ = graceful.shutdown() => {
208 info!("Gracefully shutdown!");
209 },
210 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
211 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
212 }
213 }
214 Ok(())
215}
216
217async fn serve_tls<I>(
218 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
219 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
220) -> Result<(), DynError>
221where
222 I: ReqInterceptor + Clone + Send + Sync + 'static,
223{
224 let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
225 let tx_clone = tx.clone();
226 let tls_param_clone = tls_param.clone();
227 tokio::spawn(async move {
228 info!("update tls config every {:?}", REFRESH_INTERVAL);
229 loop {
230 time::sleep(REFRESH_INTERVAL).await;
231 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
232 info!("update tls config");
233 if let Err(e) = tx.send(new_acceptor) {
234 warn!("send tls config error:{}", e);
235 }
236 }
237 }
238 });
239 let mut rx = tx_clone.subscribe();
240 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
241 let signal = handle_signal();
242 pin!(signal);
243 loop {
244 tokio::select! {
245 _ = signal.as_mut() => {
246 info!("start graceful shutdown!");
247 drop(acceptor);
248 break;
249 }
250 message = rx.recv() => {
251 #[allow(clippy::expect_used)]
252 let new_config = message.expect("Channel should not be closed");
253 acceptor.replace_config(new_config);
255 info!("replaced tls config");
256 }
257 conn = acceptor.accept() => {
258 match conn {
259 Ok((conn, client_socket_addr)) => {
260 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
261 Err(e) => {
262 warn!("accept error:{}", e);
263 }
264 }
265 }
266 }
267 }
268 tokio::select! {
269 _ = graceful.shutdown() => {
270 info!("Gracefully shutdown!");
271 },
272 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
273 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
274 }
275 }
276 Ok(())
277}
278
279#[cfg(unix)]
280async fn handle_signal() -> Result<(), DynError> {
281 use log::info;
282 use tokio::signal::unix::{signal, SignalKind};
283 let mut terminate_signal = signal(SignalKind::terminate())?;
284 tokio::select! {
285 _ = terminate_signal.recv() => {
286 info!("receive terminate signal, shutdowning");
287 },
288 _ = tokio::signal::ctrl_c() => {
289 info!("ctrl_c => shutdowning");
290 },
291 };
292 Ok(())
293}
294
/// Resolves once Ctrl+C is received (the only signal watched on Windows).
///
/// Always returns `Ok`; if listening for Ctrl+C fails, this resolves immediately.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await.unwrap_or_default();
    info!("ctrl_c => shutdowning");
    Ok(())
}
301
302#[derive(Debug)]
304pub struct AppError(anyhow::Error);
305
306impl Display for AppError {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 self.0.fmt(f)
309 }
310}
311
312impl IntoResponse for AppError {
314 fn into_response(self) -> Response {
315 let err = self.0;
316 tracing::error!(%err, "error");
319 (StatusCode::INTERNAL_SERVER_ERROR, format!("axum-bootstrap error: {}", &err)).into_response()
320 }
321}
322
323impl<E> From<E> for AppError
326where
327 E: Into<anyhow::Error>,
328{
329 fn from(err: E) -> Self {
330 Self(err.into())
331 }
332}
333
334impl AppError {
335 pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
336 Self(anyhow!(err))
337 }
338}
339
/// Extracts the value from a `Result` whose error type can never be constructed.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` has no values, so the empty match proves this closure never runs.
    result.unwrap_or_else(|never| match never {})
}