1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod error;
4pub mod init_log;
5pub mod util;
/// Boxed, thread-safe error type used by this module's fallible helpers
/// (the error passed to `handle_hyper_error` and returned by `wait_signal`).
type DynError = Box<dyn std::error::Error + Send + Sync>;
7use crate::util::{
8 io::{self, create_dual_stack_listener},
9 tls::{tls_config, TlsAcceptor},
10};
11
12use axum::{
13 extract::Request,
14 response::{IntoResponse, Response},
15 Router,
16};
17
18use hyper::body::Incoming;
19use hyper_util::rt::TokioExecutor;
20use log::{info, warn};
21use tokio::{
22 sync::broadcast::{self, error::RecvError},
23 time,
24};
25use tokio_rustls::rustls::ServerConfig;
26use tower::{Service, ServiceExt};
27use util::format::SocketAddrFormat;
28
/// How often `serve_tls` rebuilds the TLS configuration from its `TlsParam` (once every 24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long a server waits for in-flight connections once graceful shutdown starts.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
31
32pub struct Server<I: ReqInterceptor = DummyInterceptor> {
33 pub port: u16,
34 pub tls_param: Option<TlsParam>,
35 router: Router,
36 pub interceptor: Option<I>,
37 pub idle_timeout: Duration,
38 shutdown_rx: broadcast::Receiver<()>,
39}
40
/// TLS settings for a [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Enables TLS; when `false` the server serves plaintext even though these params are set.
    pub tls: bool,
    /// Certificate, passed as the second argument to `tls_config`.
    /// NOTE(review): presumably a path to a PEM file — confirm against `util::tls::tls_config`.
    pub cert: String,
    /// Private key, passed as the first argument to `tls_config`.
    /// NOTE(review): presumably a path to a PEM file — confirm against `util::tls::tls_config`.
    pub key: String,
}
47
48pub enum InterceptResult<T: IntoResponse> {
49 Return(Response),
50 Drop,
51 Continue(Request<Incoming>),
52 Error(T),
53}
54
/// Hook run on every incoming request before it reaches the router.
///
/// `intercept` receives the request and the client's socket address and returns an
/// [`InterceptResult`]. `Continue` forwards a request to the router, `Return` and `Error`
/// answer directly, and `Drop` makes the request's service call fail with an I/O error.
pub trait ReqInterceptor: Send {
    /// Error type an interceptor may return; it is converted into the HTTP response.
    type Error: IntoResponse + Send + Sync + 'static;
    // The future must be `Send` because requests are served on a multi-threaded executor.
    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
}
59
60#[derive(Clone)]
61pub struct DummyInterceptor;
62
63impl ReqInterceptor for DummyInterceptor {
64 type Error = error::AppError;
65
66 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
67 InterceptResult::Continue(req)
68 }
69}
70
71pub type DefaultServer = Server<DummyInterceptor>;
72
73pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
74 Server {
75 port,
76 tls_param: None, router,
78 interceptor: None,
79 idle_timeout: Duration::from_secs(120),
80 shutdown_rx,
81 }
82}
83
84impl<I> Server<I>
85where
86 I: ReqInterceptor + Clone + Send + Sync + 'static,
87{
88 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
89 where
90 R: ReqInterceptor + Clone + Send + Sync + 'static,
91 {
92 Server::<R> {
93 port: self.port,
94 tls_param: self.tls_param,
95 router: self.router,
96 interceptor: Some(interceptor),
97 idle_timeout: self.idle_timeout, shutdown_rx: self.shutdown_rx,
99 }
100 }
101 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
102 self.tls_param = tls_param;
104 self
105 }
106
107 pub fn with_timeout(mut self, timeout: Duration) -> Self {
108 self.idle_timeout = timeout;
109 self
110 }
111
112 pub async fn run(mut self) -> Result<(), std::io::Error> {
113 let use_tls = match self.tls_param.clone() {
114 Some(config) => config.tls,
115 None => false,
116 };
117 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
118 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
119 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
120 match use_tls {
121 #[allow(clippy::expect_used)]
122 true => {
123 serve_tls(
124 &self.router,
125 server,
126 graceful,
127 self.port,
128 self.tls_param.as_ref().expect("should be some"),
129 self.interceptor.clone(),
130 self.idle_timeout,
131 &mut self.shutdown_rx,
132 )
133 .await?
134 }
135 false => {
136 serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
137 }
138 }
139 Ok(())
140 }
141}
142
143async fn handle<I>(
144 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
145 interceptor: Option<I>,
146) -> std::result::Result<Response, std::io::Error>
147where
148 I: ReqInterceptor + Clone + Send + Sync + 'static,
149{
150 if let Some(interceptor) = interceptor {
151 match interceptor.intercept(request, client_socket_addr).await {
152 InterceptResult::Return(res) => Ok(res),
153 InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
154 InterceptResult::Continue(req) => app
155 .oneshot(req)
156 .await
157 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
158 InterceptResult::Error(err) => {
159 let res = err.into_response();
160 Ok(res)
161 }
162 }
163 } else {
164 app.oneshot(request)
165 .await
166 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
167 }
168}
169
170async fn handle_connection<C, I>(
171 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
172 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
173) where
174 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
175 I: ReqInterceptor + Clone + Send + Sync + 'static,
176{
177 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
178 use hyper::Request;
179 use hyper_util::rt::TokioIo;
180 let stream = TokioIo::new(timeout_io);
181 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
182 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
183 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
185 handle(request, client_socket_addr, app.clone(), interceptor.clone())
186 });
187
188 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
189 let conn = graceful.watch(conn.into_owned());
190
191 tokio::spawn(async move {
192 if let Err(err) = conn.await {
193 handle_hyper_error(client_socket_addr, err);
194 }
195 log::debug!("connection dropped: {client_socket_addr}");
196 });
197}
198
199fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
200 use std::error::Error;
201 match http_err.downcast_ref::<hyper::Error>() {
202 Some(hyper_err) => {
203 let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
204 let source = hyper_err.source().unwrap_or(hyper_err);
205 log::log!(
206 level,
207 "[hyper {}]: {:?} from {}",
208 if hyper_err.is_user() { "user" } else { "system" },
209 source,
210 SocketAddrFormat(&client_socket_addr)
211 );
212 }
213 None => match http_err.downcast_ref::<std::io::Error>() {
214 Some(io_err) => {
215 warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
216 }
217 None => {
218 warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
219 }
220 },
221 }
222}
223
224async fn serve_plantext<I>(
225 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
226 port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
227) -> Result<(), std::io::Error>
228where
229 I: ReqInterceptor + Clone + Send + Sync + 'static,
230{
231 let listener = create_dual_stack_listener(port).await?;
232 loop {
233 tokio::select! {
234 _ = shutdown_rx.recv() => {
235 info!("start graceful shutdown!");
236 drop(listener);
237 break;
238 }
239 conn = listener.accept() => {
240 match conn {
241 Ok((conn, client_socket_addr)) => {
242 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
243 Err(e) => {
244 warn!("accept error:{e}");
245 }
246 }
247 }
248 }
249 }
250 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
251 Ok(_) => info!("Gracefully shutdown!"),
252 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
253 }
254 Ok(())
255}
256
257#[allow(clippy::too_many_arguments)]
258async fn serve_tls<I>(
259 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
260 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
261) -> Result<(), std::io::Error>
262where
263 I: ReqInterceptor + Clone + Send + Sync + 'static,
264{
265 let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
266 let tls_param_clone = tls_param.clone();
267 tokio::spawn(async move {
268 info!("update tls config every {REFRESH_INTERVAL:?}");
269 loop {
270 time::sleep(REFRESH_INTERVAL).await;
271 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
272 info!("update tls config");
273 if let Err(e) = tx.send(new_acceptor) {
274 warn!("send tls config error:{e}");
275 }
276 }
277 }
278 });
279 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
280 loop {
281 tokio::select! {
282 _ = shutdown_rx.recv() => {
283 info!("start graceful shutdown!");
284 drop(acceptor);
285 break;
286 }
287 message = rx.recv() => {
288 match message {
289 Ok(new_config) => {
290 acceptor.replace_config(new_config);
291 info!("replaced tls config");
292 },
293 Err(e) => {
294 match e {
295 RecvError::Closed => {
296 warn!("this channel should not be closed!");
297 break;
298 },
299 RecvError::Lagged(n) => {
300 warn!("lagged {n} messages, this may cause tls config not updated in time");
301 }
302 }
303 }
304 }
305 }
306 conn = acceptor.accept() => {
307 match conn {
308 Ok((conn, client_socket_addr)) => {
309 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
310 Err(e) => {
311 warn!("accept error:{e}");
312 }
313 }
314 }
315 }
316 }
317 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
318 Ok(_) => info!("Gracefully shutdown!"),
319 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
320 }
321 Ok(())
322}
323
324#[cfg(unix)]
325pub async fn wait_signal() -> Result<(), DynError> {
326 use log::info;
327 use tokio::signal::unix::{signal, SignalKind};
328 let mut terminate_signal = signal(SignalKind::terminate())?;
329 tokio::select! {
330 _ = terminate_signal.recv() => {
331 info!("receive terminate signal");
332 },
333 _ = tokio::signal::ctrl_c() => {
334 info!("receive ctrl_c signal");
335 },
336 };
337 Ok(())
338}
339
/// Waits until the process receives Ctrl-C.
///
/// # Errors
/// Returns an error if listening for Ctrl-C fails; the failure is not treated as a received signal.
#[cfg(windows)]
pub async fn wait_signal() -> Result<(), DynError> {
    // Propagate registration failures instead of discarding them and logging a bogus signal.
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
346
/// Extracts the value from a `Result` whose error type can never be constructed.
///
/// `Infallible` has no values, so the empty `match` in the error closure shows the compiler
/// that branch is unreachable, without introducing a panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}