1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2pub mod error;
3pub mod init_log;
4#[cfg(feature = "jwt")]
5pub mod jwt;
6pub mod util;
/// Boxed, thread-safe error used where the concrete error type is erased (e.g. hyper connection errors).
type DynError = Box<dyn std::error::Error + Send + Sync + 'static>;
8use crate::util::{
9 io::{self, create_dual_stack_listener},
10 tls::{TlsAcceptor, tls_config},
11};
12
13use axum::{
14 Router,
15 extract::Request,
16 response::{IntoResponse, Response},
17};
18
19use hyper::body::Incoming;
20use hyper_util::rt::TokioExecutor;
21use log::{info, warn};
22use tokio::{
23 sync::broadcast::{self, Receiver, Sender, error::RecvError},
24 time,
25};
26use tokio_rustls::rustls::ServerConfig;
27use tower::{Service, ServiceExt};
28use util::format::SocketAddrFormat;
29
/// Delay between rebuilds of the TLS configuration from `TlsParam::key` / `TlsParam::cert` (one day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
32
33pub struct Server<I: ReqInterceptor = DummyInterceptor> {
34 pub port: u16,
35 pub tls_param: Option<TlsParam>,
36 router: Router,
37 pub interceptor: Option<I>,
38 pub idle_timeout: Duration,
39 shutdown_rx: broadcast::Receiver<()>,
40}
41
/// TLS settings for a [`Server`].
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Whether TLS is enabled; when `false` the server speaks plain HTTP even if `cert`/`key` are set.
    pub tls: bool,
    /// Certificate passed to `tls_config` (presumably a file path — confirm against `util::tls`).
    pub cert: String,
    /// Private key passed to `tls_config` (presumably a file path — confirm against `util::tls`).
    pub key: String,
}
48
49pub enum InterceptResult<T: IntoResponse> {
50 Return(Response),
51 Drop,
52 Continue(Request<Incoming>),
53 Error(T),
54}
55
56pub trait ReqInterceptor: Send {
57 type Error: IntoResponse + Send + Sync + 'static;
58 fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
59}
60
61#[derive(Clone)]
62pub struct DummyInterceptor;
63
64impl ReqInterceptor for DummyInterceptor {
65 type Error = error::AppError;
66
67 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
68 InterceptResult::Continue(req)
69 }
70}
71
72pub type DefaultServer = Server<DummyInterceptor>;
73
74pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
75 Server {
76 port,
77 tls_param: None, router,
79 interceptor: None,
80 idle_timeout: Duration::from_secs(120),
81 shutdown_rx,
82 }
83}
84
85impl<I> Server<I>
86where
87 I: ReqInterceptor + Clone + Send + Sync + 'static,
88{
89 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
90 where
91 R: ReqInterceptor + Clone + Send + Sync + 'static,
92 {
93 Server::<R> {
94 port: self.port,
95 tls_param: self.tls_param,
96 router: self.router,
97 interceptor: Some(interceptor),
98 idle_timeout: self.idle_timeout, shutdown_rx: self.shutdown_rx,
100 }
101 }
102 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
103 self.tls_param = tls_param;
105 self
106 }
107
108 pub fn with_timeout(mut self, timeout: Duration) -> Self {
109 self.idle_timeout = timeout;
110 self
111 }
112
113 pub async fn run(mut self) -> Result<(), std::io::Error> {
114 let use_tls = match self.tls_param.clone() {
115 Some(config) => config.tls,
116 None => false,
117 };
118 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
119 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
120 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
121 match use_tls {
122 #[allow(clippy::expect_used)]
123 true => {
124 serve_tls(
125 &self.router,
126 server,
127 graceful,
128 self.port,
129 self.tls_param.as_ref().expect("should be some"),
130 self.interceptor.clone(),
131 self.idle_timeout,
132 &mut self.shutdown_rx,
133 )
134 .await?
135 }
136 false => {
137 serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
138 }
139 }
140 Ok(())
141 }
142}
143
144async fn handle<I>(
145 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
146 interceptor: Option<I>,
147) -> std::result::Result<Response, std::io::Error>
148where
149 I: ReqInterceptor + Clone + Send + Sync + 'static,
150{
151 if let Some(interceptor) = interceptor {
152 match interceptor.intercept(request, client_socket_addr).await {
153 InterceptResult::Return(res) => Ok(res),
154 InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
155 InterceptResult::Continue(req) => app
156 .oneshot(req)
157 .await
158 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
159 InterceptResult::Error(err) => {
160 let res = err.into_response();
161 Ok(res)
162 }
163 }
164 } else {
165 app.oneshot(request)
166 .await
167 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
168 }
169}
170
171async fn handle_connection<C, I>(
172 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
173 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
174) where
175 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
176 I: ReqInterceptor + Clone + Send + Sync + 'static,
177{
178 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
179 use hyper::Request;
180 use hyper_util::rt::TokioIo;
181 let stream = TokioIo::new(timeout_io);
182 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
183 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
184 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
186 handle(request, client_socket_addr, app.clone(), interceptor.clone())
187 });
188
189 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
190 let conn = graceful.watch(conn.into_owned());
191
192 tokio::spawn(async move {
193 if let Err(err) = conn.await {
194 handle_hyper_error(client_socket_addr, err);
195 }
196 log::debug!("connection dropped: {client_socket_addr}");
197 });
198}
199
200fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
201 use std::error::Error;
202 match http_err.downcast_ref::<hyper::Error>() {
203 Some(hyper_err) => {
204 let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
205 let source = hyper_err.source().unwrap_or(hyper_err);
206 log::log!(
207 level,
208 "[hyper {}]: {:?} from {}",
209 if hyper_err.is_user() { "user" } else { "system" },
210 source,
211 SocketAddrFormat(&client_socket_addr)
212 );
213 }
214 None => match http_err.downcast_ref::<std::io::Error>() {
215 Some(io_err) => {
216 warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
217 }
218 None => {
219 warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
220 }
221 },
222 }
223}
224
225async fn serve_plantext<I>(
226 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
227 port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
228) -> Result<(), std::io::Error>
229where
230 I: ReqInterceptor + Clone + Send + Sync + 'static,
231{
232 let listener = create_dual_stack_listener(port).await?;
233 loop {
234 tokio::select! {
235 _ = shutdown_rx.recv() => {
236 info!("start graceful shutdown!");
237 drop(listener);
238 break;
239 }
240 conn = listener.accept() => {
241 match conn {
242 Ok((conn, client_socket_addr)) => {
243 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
244 Err(e) => {
245 warn!("accept error:{e}");
246 }
247 }
248 }
249 }
250 }
251 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
252 Ok(_) => info!("Gracefully shutdown!"),
253 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
254 }
255 Ok(())
256}
257
258#[allow(clippy::too_many_arguments)]
259async fn serve_tls<I>(
260 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
261 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
262) -> Result<(), std::io::Error>
263where
264 I: ReqInterceptor + Clone + Send + Sync + 'static,
265{
266 let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
267 let tls_param_clone = tls_param.clone();
268 tokio::spawn(async move {
269 info!("update tls config every {REFRESH_INTERVAL:?}");
270 loop {
271 time::sleep(REFRESH_INTERVAL).await;
272 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
273 info!("update tls config");
274 if let Err(e) = tx.send(new_acceptor) {
275 warn!("send tls config error:{e}");
276 }
277 }
278 }
279 });
280 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
281 loop {
282 tokio::select! {
283 _ = shutdown_rx.recv() => {
284 info!("start graceful shutdown!");
285 drop(acceptor);
286 break;
287 }
288 message = rx.recv() => {
289 match message {
290 Ok(new_config) => {
291 acceptor.replace_config(new_config);
292 info!("replaced tls config");
293 },
294 Err(e) => {
295 match e {
296 RecvError::Closed => {
297 warn!("this channel should not be closed!");
298 break;
299 },
300 RecvError::Lagged(n) => {
301 warn!("lagged {n} messages, this may cause tls config not updated in time");
302 }
303 }
304 }
305 }
306 }
307 conn = acceptor.accept() => {
308 match conn {
309 Ok((conn, client_socket_addr)) => {
310 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
311 Err(e) => {
312 warn!("accept error:{e}");
313 }
314 }
315 }
316 }
317 }
318 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
319 Ok(_) => info!("Gracefully shutdown!"),
320 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
321 }
322 Ok(())
323}
324
325pub fn generate_shutdown_receiver() -> Receiver<()> {
326 let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
327 subscribe_shutdown_sender(shutdown_tx);
328 shutdown_rx
329}
330
331pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
332 tokio::spawn(async move {
333 match wait_signal().await {
334 Ok(_) => {
335 let _ = shutdown_tx.send(());
336 }
337 Err(e) => {
338 log::error!("wait_signal error: {}", e);
339 panic!("wait_signal error: {}", e);
340 }
341 }
342 });
343}
344
345#[cfg(unix)]
346pub(crate) async fn wait_signal() -> Result<(), DynError> {
347 use log::info;
348 use tokio::signal::unix::{SignalKind, signal};
349 let mut terminate_signal = signal(SignalKind::terminate())?;
350 tokio::select! {
351 _ = terminate_signal.recv() => {
352 info!("receive terminate signal");
353 },
354 _ = tokio::signal::ctrl_c() => {
355 info!("receive ctrl_c signal");
356 },
357 };
358 Ok(())
359}
360
/// Waits until the process receives Ctrl-C.
///
/// # Errors
/// Returns an error if the Ctrl-C handler cannot be installed, instead of treating that failure
/// as a received signal.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
367
/// Extracts the success value from a `Result` whose error type can never be constructed.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` has no values, so the empty match proves this closure is unreachable.
    result.unwrap_or_else(|never| match never {})
}