1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod error;
4pub mod init_log;
5pub mod util;
/// Boxed, thread-safe (`Send + Sync`) error used for hyper connection errors and
/// signal-listener setup failures.
type DynError = Box<dyn std::error::Error + Send + Sync>;
7use crate::util::{
8 io::{self, create_dual_stack_listener},
9 tls::{tls_config, TlsAcceptor},
10};
11
12use axum::{
13 extract::Request,
14 response::{IntoResponse, Response},
15 Router,
16};
17
18use hyper::body::Incoming;
19use hyper_util::rt::TokioExecutor;
20use log::{info, warn};
21use tokio::{sync::broadcast, time};
22use tokio_rustls::rustls::ServerConfig;
23use tower::{Service, ServiceExt};
24use util::format::SocketAddrFormat;
25
/// How often `serve_tls` rebuilds the TLS config from `TlsParam::key` / `TlsParam::cert` (24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
/// Upper bound on how long shutdown waits for in-flight connections before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
28
29pub struct Server<I: ReqInterceptor = DummyInterceptor> {
30 pub port: u16,
31 pub tls_param: Option<TlsParam>,
32 router: Router,
33 pub interceptor: Option<I>,
34 pub idle_timeout: Duration,
35 shutdown_rx: broadcast::Receiver<()>,
36}
37
/// TLS settings for a [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Enables TLS; when `false` the server runs plaintext even though
    /// `cert`/`key` are populated.
    pub tls: bool,
    /// Certificate handed to `tls_config` — presumably a PEM file path (TODO confirm).
    pub cert: String,
    /// Private key handed to `tls_config` — presumably a PEM file path (TODO confirm).
    pub key: String,
}
44
45pub enum InterceptResult<T: IntoResponse> {
46 Return(Response),
47 Drop,
48 Continue(Request<Incoming>),
49 Error(T),
50}
51
/// Hook invoked for every request, with the client's address, before the router
/// sees it (when one is installed on the [`Server`]).
pub trait ReqInterceptor: Send {
    /// Error type the interceptor may return; it is rendered as an HTTP response.
    type Error: IntoResponse + Send + Sync + 'static;
    /// Inspects `req` coming from `ip` and decides how it is handled
    /// (see [`InterceptResult`]). The returned future must be `Send`.
    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
}
56
/// Interceptor that lets every request through unchanged; the default `I` of [`Server`].
#[derive(Clone)]
pub struct DummyInterceptor;
59
60impl ReqInterceptor for DummyInterceptor {
61 type Error = error::AppError;
62
63 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
64 InterceptResult::Continue(req)
65 }
66}
67
/// A [`Server`] using the pass-through [`DummyInterceptor`]; this is the type
/// [`new_server`] returns.
pub type DefaultServer = Server<DummyInterceptor>;
69
70pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
71 let server = Server {
72 port,
73 tls_param: None, router,
75 interceptor: None,
76 idle_timeout: Duration::from_secs(120),
77 shutdown_rx,
78 };
79 server
80}
81
82impl<I> Server<I>
83where
84 I: ReqInterceptor + Clone + Send + Sync + 'static,
85{
86 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
87 where
88 R: ReqInterceptor + Clone + Send + Sync + 'static,
89 {
90 Server::<R> {
91 port: self.port,
92 tls_param: self.tls_param,
93 router: self.router,
94 interceptor: Some(interceptor),
95 idle_timeout: self.idle_timeout, shutdown_rx: self.shutdown_rx,
97 }
98 }
99 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
100 self.tls_param = tls_param;
102 self
103 }
104
105 pub fn with_timeout(mut self, timeout: Duration) -> Self {
106 self.idle_timeout = timeout;
107 self
108 }
109
110 pub async fn run(mut self) -> Result<(), std::io::Error> {
111 let use_tls = match self.tls_param.clone() {
112 Some(config) => config.tls,
113 None => false,
114 };
115 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
116 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
117 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
118 match use_tls {
119 #[allow(clippy::expect_used)]
120 true => {
121 serve_tls(
122 &self.router,
123 server,
124 graceful,
125 self.port,
126 self.tls_param.as_ref().expect("should be some"),
127 self.interceptor.clone(),
128 self.idle_timeout,
129 &mut self.shutdown_rx,
130 )
131 .await?
132 }
133 false => {
134 serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
135 }
136 }
137 Ok(())
138 }
139}
140
141async fn handle<I>(
142 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
143 interceptor: Option<I>,
144) -> std::result::Result<Response, std::io::Error>
145where
146 I: ReqInterceptor + Clone + Send + Sync + 'static,
147{
148 if let Some(interceptor) = interceptor {
149 match interceptor.intercept(request, client_socket_addr).await {
150 InterceptResult::Return(res) => Ok(res),
151 InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
152 InterceptResult::Continue(req) => app
153 .oneshot(req)
154 .await
155 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
156 InterceptResult::Error(err) => {
157 let res = err.into_response();
158 Ok(res)
159 }
160 }
161 } else {
162 app.oneshot(request)
163 .await
164 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
165 }
166}
167
168async fn handle_connection<C, I>(
169 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
170 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
171) where
172 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
173 I: ReqInterceptor + Clone + Send + Sync + 'static,
174{
175 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
176 use hyper::Request;
177 use hyper_util::rt::TokioIo;
178 let stream = TokioIo::new(timeout_io);
179 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
180 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
181 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
183 handle(request, client_socket_addr, app.clone(), interceptor.clone())
184 });
185
186 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
187 let conn = graceful.watch(conn.into_owned());
188
189 tokio::spawn(async move {
190 if let Err(err) = conn.await {
191 handle_hyper_error(client_socket_addr, err);
192 }
193 log::debug!("connection dropped: {client_socket_addr}");
194 });
195}
196
197fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
198 use std::error::Error;
199 match http_err.downcast_ref::<hyper::Error>() {
200 Some(hyper_err) => {
201 let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
202 let source = hyper_err.source().unwrap_or(hyper_err);
203 log::log!(
204 level,
205 "[hyper {}]: {:?} from {}",
206 if hyper_err.is_user() { "user" } else { "system" },
207 source,
208 SocketAddrFormat(&client_socket_addr)
209 );
210 }
211 None => match http_err.downcast_ref::<std::io::Error>() {
212 Some(io_err) => {
213 warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
214 }
215 None => {
216 warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
217 }
218 },
219 }
220}
221
222async fn serve_plantext<I>(
223 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
224 port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
225) -> Result<(), std::io::Error>
226where
227 I: ReqInterceptor + Clone + Send + Sync + 'static,
228{
229 let listener = create_dual_stack_listener(port).await?;
230 loop {
231 tokio::select! {
232 _ = shutdown_rx.recv() => {
233 info!("start graceful shutdown!");
234 drop(listener);
235 break;
236 }
237 conn = listener.accept() => {
238 match conn {
239 Ok((conn, client_socket_addr)) => {
240 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
241 Err(e) => {
242 warn!("accept error:{e}");
243 }
244 }
245 }
246 }
247 }
248 tokio::select! {
249 _ = graceful.shutdown() => {
250 info!("Gracefully shutdown!");
251 },
252 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
253 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
254 }
255 }
256 Ok(())
257}
258
259#[allow(clippy::too_many_arguments)]
260async fn serve_tls<I>(
261 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
262 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
263) -> Result<(), std::io::Error>
264where
265 I: ReqInterceptor + Clone + Send + Sync + 'static,
266{
267 let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
268 let tx_clone = tx.clone();
269 let tls_param_clone = tls_param.clone();
270 tokio::spawn(async move {
271 info!("update tls config every {REFRESH_INTERVAL:?}");
272 loop {
273 time::sleep(REFRESH_INTERVAL).await;
274 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
275 info!("update tls config");
276 if let Err(e) = tx.send(new_acceptor) {
277 warn!("send tls config error:{e}");
278 }
279 }
280 }
281 });
282 let mut rx = tx_clone.subscribe();
283 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
284 loop {
285 tokio::select! {
286 _ = shutdown_rx.recv() => {
287 info!("start graceful shutdown!");
288 drop(acceptor);
289 break;
290 }
291 message = rx.recv() => {
292 #[allow(clippy::expect_used)]
293 let new_config = message.expect("Channel should not be closed");
294 acceptor.replace_config(new_config);
296 info!("replaced tls config");
297 }
298 conn = acceptor.accept() => {
299 match conn {
300 Ok((conn, client_socket_addr)) => {
301 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
302 Err(e) => {
303 warn!("accept error:{e}");
304 }
305 }
306 }
307 }
308 }
309 tokio::select! {
310 _ = graceful.shutdown() => {
311 info!("Gracefully shutdown!");
312 },
313 _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
314 info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
315 }
316 }
317 Ok(())
318}
319
320#[cfg(unix)]
321pub async fn wait_signal() -> Result<(), DynError> {
322 use log::info;
323 use tokio::signal::unix::{signal, SignalKind};
324 let mut terminate_signal = signal(SignalKind::terminate())?;
325 tokio::select! {
326 _ = terminate_signal.recv() => {
327 info!("receive terminate signal");
328 },
329 _ = tokio::signal::ctrl_c() => {
330 info!("receive ctrl_c signal");
331 },
332 };
333 Ok(())
334}
335
/// Waits for Ctrl-C, logs it, and returns.
///
/// Any error from listening for Ctrl-C is discarded; the function still logs and
/// returns `Ok(())` in that case.
#[cfg(windows)]
pub async fn wait_signal() -> Result<(), DynError> {
    let _ = tokio::signal::ctrl_c().await;
    info!("receive ctrl_c signal");
    Ok(())
}
342
/// Extracts the value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error closure can never run; the empty
/// `match` proves that to the compiler without introducing a panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}