1pub mod proxy_protocol;
2
3use std::{
4 fmt::Display,
5 net::SocketAddr,
6 path::PathBuf,
7 sync::{
8 Arc,
9 atomic::{AtomicU32, AtomicU64, Ordering},
10 },
11 time::{Duration, Instant},
12};
13
14use anyhow::{Context, anyhow};
15use async_trait::async_trait;
16use axum::{Router, extract::Request};
17use http::Response;
18use hyper::body::Incoming;
19use hyper_util::{
20 rt::{TokioExecutor, TokioIo, TokioTimer},
21 server::conn::auto::Builder,
22};
23use ic_bn_lib_common::{
24 traits::Run,
25 types::{
26 http::{
27 ALPN_ACME, Addr, ConnInfo, Error, ListenerOpts, Metrics, ProxyProtocolMode,
28 ServerOptions, TlsInfo,
29 },
30 tls::TlsOptions,
31 },
32};
33use prometheus::{
34 Registry,
35 core::{AtomicI64, GenericGauge},
36};
37use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
38use rustls::sign::SingleCertAndKey;
39use scopeguard::defer;
40use tokio::{
41 io::AsyncWriteExt,
42 pin, select,
43 sync::mpsc::channel,
44 time::{sleep, timeout},
45};
46use tokio_io_timeout::TimeoutStream;
47use tokio_util::{sync::CancellationToken, task::TaskTracker};
48use tower_service::Service;
49use tracing::{debug, info, warn};
50use uuid::Uuid;
51
52use super::body::NotifyingBody;
53use crate::{
54 network::{AsyncCounter, AsyncReadWrite, listener::Listener, tls_handshake},
55 tls::{pem_convert_to_rustls, prepare_server_config},
56};
57
/// One non-leap year (365 days); used as the unit for "far in the future"
/// timer deadlines.
const YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
59
/// Lifecycle notification for a single request on a connection.
///
/// Values are sent over the per-connection state channel so the serving loop
/// can keep track of how many requests are currently outstanding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequestState {
    /// A request has been received and is about to be dispatched.
    Start,
    /// The request is finished (attached to the response body as its
    /// completion notification).
    End,
}
65
/// State needed to serve a single accepted connection.
struct Conn {
    /// Address the server is listening on; used as the connection's local
    /// address unless a Proxy Protocol header overrides it.
    addr: Addr,
    /// Peer address as reported when the connection was accepted.
    remote_addr: Addr,
    /// Router that every request on this connection is dispatched to.
    router: Router,
    /// Pre-configured hyper-util connection builder used to serve the stream.
    builder: Builder<TokioExecutor>,
    /// Cancelling this token starts a graceful shutdown of the connection.
    token_graceful: CancellationToken,
    /// Cancelling this token stops serving the connection immediately.
    token_forceful: CancellationToken,
    /// Server options (timeouts, limits, Proxy Protocol mode, ...).
    options: ServerOptions,
    /// Metrics updated over the lifetime of the connection.
    metrics: Metrics,
    /// Number of requests currently outstanding on this connection.
    requests: AtomicU32,
    /// TLS configuration; a TLS handshake is performed only when this is set.
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
78
79impl Display for Conn {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "[{}] <- [{}]", self.addr, self.remote_addr)
82 }
83}
84
85impl Conn {
86 async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
87 let accepted_at = Instant::now();
88
89 debug!("{self}: got a new connection");
90
91 let addr = self.addr.to_string();
93 let labels = &mut [
94 addr.as_str(), self.remote_addr.family(), "no", "no", "no", "no", ];
101
102 let (stream, stats) = AsyncCounter::new(stream);
104
105 let (stream, proxy_hdr): (Box<dyn AsyncReadWrite>, Option<ProxyHeader>) =
107 if self.options.proxy_protocol_mode != ProxyProtocolMode::Off {
108 let (stream, hdr) = ProxyProtocolStream::accept(stream)
109 .await
110 .context("unable to accept Proxy Protocol")?;
111
112 if self.options.proxy_protocol_mode == ProxyProtocolMode::Forced && hdr.is_none() {
113 return Err(Error::NoProxyProtocolDetected);
114 }
115
116 (Box::new(stream), hdr)
117 } else {
118 (Box::new(stream), None)
119 };
120
121 let (local_addr, remote_addr) = proxy_hdr
123 .map(|x| (Addr::Tcp(x.dst), Addr::Tcp(x.src)))
124 .unwrap_or_else(|| (self.addr.clone(), self.remote_addr.clone()));
125
126 let conn_info = Arc::new(ConnInfo {
127 id: Uuid::now_v7(),
128 accepted_at,
129 remote_addr,
130 local_addr,
131 traffic: stats.clone(),
132 req_count: AtomicU64::new(0),
133 close: self.token_forceful.clone(),
134 });
135
136 let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if let Some(rustls_cfg) =
138 &self.rustls_cfg
139 {
140 debug!("{}: performing TLS handshake", self);
141
142 let (mut stream_tls, tls_info) = timeout(
143 self.options.tls_handshake_timeout,
144 tls_handshake(rustls_cfg.clone(), stream),
145 )
146 .await
147 .context("TLS handshake timed out")?
148 .context("TLS handshake failed")?;
149
150 debug!(
151 "{}: handshake finished in {}ms (SNI: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})",
152 self,
153 tls_info.handshake_dur.as_millis(),
154 tls_info.sni,
155 tls_info.protocol,
156 tls_info.cipher,
157 tls_info.alpn,
158 );
159
160 if tls_info
162 .alpn
163 .as_ref()
164 .is_some_and(|x| x.as_bytes() == ALPN_ACME)
165 {
166 debug!("{self}: ACME ALPN - closing connection");
167
168 timeout(Duration::from_secs(5), stream_tls.shutdown())
169 .await
170 .context("socket shutdown timed out")?
171 .context("socket shutdown failed")?;
172
173 return Ok(());
174 }
175
176 (Box::new(stream_tls), Some(Arc::new(tls_info)))
177 } else {
178 (Box::new(stream), None)
179 };
180
181 if let Some(v) = &tls_info {
183 labels[2] = v.protocol.as_str().unwrap();
184 labels[3] = v.cipher.as_str().unwrap();
185
186 self.metrics
187 .conn_tls_handshake_duration
188 .with_label_values(&labels[0..4])
189 .observe(v.handshake_dur.as_secs_f64());
190 }
191
192 self.metrics
193 .conns_open
194 .with_label_values(&labels[0..4])
195 .inc();
196
197 let requests_inflight = self
198 .metrics
199 .requests_inflight
200 .with_label_values(&labels[0..4]);
201
202 let result = self
204 .handle_inner(stream, conn_info.clone(), tls_info, requests_inflight)
205 .await;
206
207 let (sent, rcvd) = (stats.sent(), stats.rcvd());
209 let dur = accepted_at.elapsed().as_secs_f64();
210 let reqs = conn_info.req_count.load(Ordering::SeqCst);
211
212 if self.token_forceful.is_cancelled() {
214 labels[4] = "yes";
215 }
216 if self.token_graceful.is_cancelled() {
218 labels[5] = "yes";
219 }
220
221 self.metrics.conns.with_label_values(labels).inc();
222 self.metrics
223 .conns_open
224 .with_label_values(&labels[0..4])
225 .dec();
226 self.metrics.requests.with_label_values(labels).inc_by(reqs);
227 self.metrics
228 .bytes_rcvd
229 .with_label_values(labels)
230 .inc_by(rcvd);
231 self.metrics
232 .bytes_sent
233 .with_label_values(labels)
234 .inc_by(sent);
235 self.metrics
236 .conn_duration
237 .with_label_values(labels)
238 .observe(dur);
239 self.metrics
240 .requests_per_conn
241 .with_label_values(labels)
242 .observe(reqs as f64);
243
244 debug!(
245 "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
246 self.token_graceful.is_cancelled(),
247 self.token_forceful.is_cancelled(),
248 );
249
250 result
251 }
252
253 async fn handle_inner(
254 &self,
255 stream: Box<dyn AsyncReadWrite>,
256 conn_info: Arc<ConnInfo>,
257 tls_info: Option<Arc<TlsInfo>>,
258 requests_inflight: GenericGauge<AtomicI64>,
259 ) -> Result<(), Error> {
260 let mut idle_timer = Box::pin(sleep(self.options.idle_timeout.unwrap_or(10 * YEAR)));
263
264 let (state_tx, mut state_rx) = channel(65536);
267
268 let mut stream = TimeoutStream::new(stream);
270 stream.set_read_timeout(self.options.read_timeout);
271 stream.set_write_timeout(self.options.write_timeout);
272
273 let stream = TokioIo::new(stream);
275
276 let max_requests_per_conn = self.options.max_requests_per_conn;
278 let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
279 let _ = state_tx.try_send(RequestState::Start);
281
282 request.extensions_mut().insert(conn_info.clone());
284 if let Some(v) = &tls_info {
285 request.extensions_mut().insert(v.clone());
286 }
287
288 let mut router = self.router.clone();
290 let token = self.token_graceful.clone();
291 let conn_info = conn_info.clone();
292 let state_tx = state_tx.clone();
293 let requests_inflight = requests_inflight.clone();
294
295 async move {
297 requests_inflight.inc();
299
300 defer! {
303 requests_inflight.dec();
304 }
305
306 let result = router.call(request).await.map(|x| {
308 let (parts, body) = x.into_parts();
310 let body = NotifyingBody::new(body, state_tx, RequestState::End);
311 Response::from_parts(parts, body)
312 });
313
314 if let Some(v) = max_requests_per_conn {
316 let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
317 if req_count + 1 >= v {
318 token.cancel();
319 }
320 }
321
322 result
323 }
324 });
325
326 let conn = self
328 .builder
329 .serve_connection_with_upgrades(Box::pin(stream), service);
330
331 pin!(conn);
333
334 loop {
335 select! {
336 biased; () = self.token_forceful.cancelled() => {
340 break;
341 }
342
343 () = self.token_graceful.cancelled() => {
345 conn.as_mut().graceful_shutdown();
348
349 let _ = timeout(self.options.grace_period, conn.as_mut()).await;
353 break;
354 },
355
356 Some(v) = state_rx.recv() => {
358 match v {
359 RequestState::Start => {
360 let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
361 debug!("{self}: request started");
362
363 if self.options.idle_timeout.is_some() {
366 debug!("{self}: stopping idle timer (now: {reqs})");
367 idle_timer.as_mut().reset(tokio::time::Instant::now() + 10 * YEAR);
368 }
369 },
370
371 RequestState::End => {
372 let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
373 debug!("{self}: request finished (now: {reqs})");
374
375 if let Some(v) = self.options.idle_timeout && reqs == 0 {
377 debug!("{self}: no outstanding requests, starting timer");
379 idle_timer.as_mut().reset(tokio::time::Instant::now() + v);
380 }
381 }
382 }
383 },
384
385 () = idle_timer.as_mut(), if self.options.idle_timeout.is_some() => {
387 debug!("{self}: Idle timeout triggered, closing");
388
389 conn.as_mut().graceful_shutdown();
391 let _ = timeout(Duration::from_secs(5), conn.as_mut()).await;
393 break;
394 },
395
396 v = conn.as_mut() => {
398 if let Err(e) = v {
399 return Err(anyhow!("unable to serve connection: {e:#}").into());
400 }
401
402 break;
403 },
404 }
405 }
406
407 Ok(())
408 }
409}
410
/// Builder for [`Server`].
pub struct ServerBuilder {
    /// Address to listen on; must be set before calling `build`.
    addr: Option<Addr>,
    /// Router that will handle the requests.
    router: Router,
    /// Registry used to create metrics when none are supplied explicitly,
    /// and passed to the TLS config helper in `with_rustls_single_cert`.
    registry: Registry,
    /// Explicitly supplied metrics; when `None`, `build` creates them from `registry`.
    metrics: Option<Metrics>,
    /// Server options.
    options: ServerOptions,
    /// Optional TLS configuration.
    rustls_cfg: Option<rustls::ServerConfig>,
}
420
421impl ServerBuilder {
422 pub fn new(router: Router) -> Self {
424 Self {
425 addr: None,
426 router,
427 registry: Registry::new(),
428 metrics: None,
429 options: ServerOptions::default(),
430 rustls_cfg: None,
431 }
432 }
433
434 pub fn listen_tcp(mut self, socket: SocketAddr) -> Self {
436 self.addr = Some(Addr::Tcp(socket));
437 self
438 }
439
440 pub fn listen_unix(mut self, path: PathBuf) -> Self {
442 self.addr = Some(Addr::Unix(path));
443 self
444 }
445
446 pub fn with_metrics_registry(mut self, registry: &Registry) -> Self {
448 self.registry = registry.clone();
449 self
450 }
451
452 pub fn with_metrics(mut self, metrics: Metrics) -> Self {
455 self.metrics = Some(metrics);
456 self
457 }
458
459 pub fn with_rustls_config(mut self, rustls_cfg: rustls::ServerConfig) -> Self {
461 self.rustls_cfg = Some(rustls_cfg);
462 self
463 }
464
465 pub const fn with_options(mut self, options: ServerOptions) -> Self {
467 self.options = options;
468 self
469 }
470
471 pub fn with_rustls_single_cert(mut self, cert: PathBuf, key: PathBuf) -> Result<Self, Error> {
474 let cert = std::fs::read(cert).context("unable to read cert")?;
475 let key = std::fs::read(key).context("unable to read key")?;
476 let cert = pem_convert_to_rustls(&key, &cert).context("unable to parse cert+key pair")?;
477 let resolver = SingleCertAndKey::from(cert);
478 let tls_opts = TlsOptions::default();
479 let rustls_cfg = prepare_server_config(tls_opts, Arc::new(resolver), &self.registry);
480
481 self.rustls_cfg = Some(rustls_cfg);
482 Ok(self)
483 }
484
485 pub fn build(self) -> Result<Server, Error> {
487 let Some(addr) = self.addr else {
488 return Err(Error::Generic(anyhow!("Listening address not specified")));
489 };
490
491 let metrics = self.metrics.unwrap_or_else(|| Metrics::new(&self.registry));
492
493 Ok(Server::new(
494 addr,
495 self.router,
496 self.options,
497 metrics,
498 self.rustls_cfg,
499 ))
500 }
501}
502
/// HTTP server that accepts connections and serves them with an Axum router.
pub struct Server {
    /// Address the server listens on.
    addr: Addr,
    /// Router cloned into every connection.
    router: Router,
    /// Tracks the spawned connection tasks so shutdown can wait for them.
    tracker: TaskTracker,
    /// Server options.
    options: ServerOptions,
    /// Metrics shared with every connection.
    metrics: Metrics,
    /// Pre-configured hyper-util connection builder, cloned into every connection.
    builder: Builder<TokioExecutor>,
    /// TLS configuration; connections are served over TLS when set.
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
513
514impl Display for Server {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 write!(f, "[{}]", self.addr)
517 }
518}
519
520impl Server {
521 pub fn new(
523 addr: Addr,
524 router: Router,
525 options: ServerOptions,
526 metrics: Metrics,
527 rustls_cfg: Option<rustls::ServerConfig>,
528 ) -> Self {
529 let mut builder = Builder::new(TokioExecutor::new());
532 builder
533 .http1()
534 .timer(TokioTimer::new()) .header_read_timeout(Some(options.http1_header_read_timeout))
536 .keep_alive(true)
537 .http2()
538 .adaptive_window(true)
539 .max_concurrent_streams(Some(options.http2_max_streams))
540 .timer(TokioTimer::new()) .keep_alive_interval(options.http2_keepalive_interval)
542 .keep_alive_timeout(options.http2_keepalive_timeout)
543 .enable_connect_protocol(); Self {
546 addr,
547 router,
548 options,
549 metrics,
550 tracker: TaskTracker::new(),
551 builder,
552 rustls_cfg: rustls_cfg.map(Arc::new),
553 }
554 }
555
556 pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
558 let opts = ListenerOpts {
559 backlog: self.options.backlog,
560 mss: self.options.tcp_mss,
561 keepalive: (&self.options).into(),
562 };
563
564 let listener =
565 Listener::new(self.addr.clone(), opts).context("unable to create listener")?;
566 self.serve_with_listener(listener, token).await
567 }
568
569 fn spawn_connection(
570 &self,
571 stream: Box<dyn AsyncReadWrite>,
572 remote_addr: Addr,
573 token: CancellationToken,
574 ) {
575 let conn = Conn {
579 addr: self.addr.clone(),
580 remote_addr: remote_addr.clone(),
581 router: self.router.clone(),
582 builder: self.builder.clone(),
583 token_graceful: token,
584 token_forceful: CancellationToken::new(),
585 options: self.options,
586 metrics: self.metrics.clone(), requests: AtomicU32::new(0),
588 rustls_cfg: self.rustls_cfg.clone(),
589 };
590
591 self.tracker.spawn(async move {
593 if let Err(e) = conn.handle(stream).await {
594 info!(
595 "[{}] <- [{remote_addr}]: failed to handle connection: {e:#}",
596 conn.addr
597 );
598 }
599
600 debug!("[{}] <- [{remote_addr}]: connection finished", conn.addr);
601 });
602 }
603
604 pub async fn serve_with_listener(
606 &self,
607 listener: Listener,
608 token: CancellationToken,
609 ) -> Result<(), Error> {
610 warn!("{self}: running (TLS: {})", self.rustls_cfg.is_some());
611
612 loop {
613 select! {
614 biased; () = token.cancelled() => {
617 drop(listener);
619
620 warn!("{self}: shutting down, waiting for the active connections to close for {}s", self.options.grace_period.as_secs());
621 self.tracker.close();
622
623 select! {
624 () = sleep(self.options.grace_period + Duration::from_secs(5)) => {
625 warn!("{self}: connections didn't close in time, shutting down anyway");
626 },
627 () = self.tracker.wait() => {},
628 }
629
630 warn!("{self}: shut down");
631
632 if let Addr::Unix(v) = &self.addr {
634 let _ = std::fs::remove_file(v);
635 }
636
637 return Ok(());
638 },
639
640 v = listener.accept() => {
642 let (stream, remote_addr) = match v {
643 Ok(v) => v,
644 Err(e) => {
645 warn!("{self}: unable to accept connection: {e:#}");
646 sleep(Duration::from_millis(10)).await;
649 continue;
650 }
651 };
652
653 self.spawn_connection(stream, remote_addr, token.child_token());
654 }
655 }
656 }
657 }
658}
659
660#[async_trait]
661impl Run for Server {
662 async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
663 self.serve(token).await?;
664 Ok(())
665 }
666}
667
#[cfg(test)]
mod test {
    use http::StatusCode;

    use crate::network::listener::listen_tcp;

    use super::*;

    /// Starts a server with an empty router on an ephemeral port and checks
    /// that a plain HTTP request is answered with `404 Not Found`.
    #[tokio::test]
    async fn test_server() {
        let opts = ServerOptions::default();
        let listener = listen_tcp(
            "127.0.0.1:0".parse().unwrap(),
            ListenerOpts {
                backlog: 128,
                mss: None,
                keepalive: (&opts).into(),
            },
        )
        .unwrap();

        // The port is picked by the OS, so read the actual address back.
        let addr = listener.local_addr().unwrap();

        let server = Server::new(
            Addr::Tcp(addr),
            Router::new(),
            opts,
            Metrics::new(&Registry::new()),
            None,
        );

        tokio::spawn(async move {
            server
                .serve_with_listener(listener.into(), CancellationToken::new())
                .await
                .unwrap();
        });

        // Retry a few times, remembering the status of the first response.
        let mut status = None;
        for _ in 0..10 {
            match reqwest::get(format!("http://{addr}")).await {
                Ok(result) => {
                    status = Some(result.status());
                    break;
                }
                Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }

        // Fail if no request succeeded at all, rather than passing vacuously.
        assert_eq!(
            status,
            Some(StatusCode::NOT_FOUND),
            "server did not answer any request with 404"
        );
    }
}