1pub mod proxy_protocol;
2
3use std::{
4 fmt::Display,
5 io,
6 net::SocketAddr,
7 os::unix::fs::PermissionsExt,
8 path::PathBuf,
9 sync::{
10 Arc,
11 atomic::{AtomicU32, AtomicU64, Ordering},
12 },
13 time::{Duration, Instant},
14};
15
16use anyhow::{Context, anyhow};
17use async_trait::async_trait;
18use axum::{Router, extract::Request};
19use http::Response;
20use hyper::body::Incoming;
21use hyper_util::{
22 rt::{TokioExecutor, TokioIo, TokioTimer},
23 server::conn::auto::Builder,
24};
25use ic_bn_lib_common::{
26 traits::Run,
27 types::{
28 http::{
29 ALPN_ACME, Addr, ConnInfo, Error, ListenerOpts, Metrics, ProxyProtocolMode,
30 ServerOptions, TlsInfo,
31 },
32 tls::TlsOptions,
33 },
34};
35use prometheus::{
36 Registry,
37 core::{AtomicI64, GenericGauge},
38};
39use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
40use rustls::sign::SingleCertAndKey;
41use scopeguard::defer;
42use socket2::{Domain, Socket, Type};
43use tokio::{
44 io::AsyncWriteExt,
45 net::{TcpListener, UnixListener, UnixSocket},
46 pin, select,
47 sync::mpsc::channel,
48 time::{sleep, timeout},
49};
50use tokio_io_timeout::TimeoutStream;
51use tokio_rustls::TlsAcceptor;
52use tokio_util::{sync::CancellationToken, task::TaskTracker};
53use tower_service::Service;
54use tracing::{debug, info, warn};
55use uuid::Uuid;
56
57use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody};
58use crate::tls::{pem_convert_to_rustls, prepare_server_config};
59
/// Length of a non-leap year: 365 days of 86400 seconds each.
///
/// Multiples of this constant are used as a practically unreachable deadline
/// for the connection idle timer whenever that timer must not fire.
const YEAR: Duration = Duration::from_secs(86400 * 365);
61
/// A listening socket that accepts either TCP or Unix-domain connections.
pub enum Listener {
    /// Listener backed by a Tokio TCP listener.
    Tcp(TcpListener),
    /// Listener backed by a Tokio Unix-domain listener.
    Unix(UnixListener),
}
67
68impl Listener {
69 pub fn new(addr: Addr, opts: ListenerOpts) -> Result<Self, Error> {
71 Ok(match addr {
72 Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?),
73 Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?),
74 })
75 }
76
77 async fn accept(&self) -> Result<(Box<dyn AsyncReadWrite>, Addr), io::Error> {
79 Ok(match self {
80 Self::Tcp(v) => {
81 let x = v.accept().await?;
82 (Box::new(x.0), Addr::Tcp(x.1))
83 }
84 Self::Unix(v) => {
85 let x = v.accept().await?;
86 (
87 Box::new(x.0),
88 Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()),
89 )
90 }
91 })
92 }
93
94 pub fn local_addr(&self) -> Option<SocketAddr> {
95 match &self {
96 Self::Tcp(v) => v.local_addr().ok(),
97 Self::Unix(_) => None,
98 }
99 }
100}
101
102impl From<TcpListener> for Listener {
103 fn from(v: TcpListener) -> Self {
105 Self::Tcp(v)
106 }
107}
108
109impl From<UnixListener> for Listener {
110 fn from(v: UnixListener) -> Self {
112 Self::Unix(v)
113 }
114}
115
/// Request lifecycle events sent over the per-connection state channel.
///
/// The connection loop uses them to track the number of in-flight requests
/// and to arm or disarm the idle timer.
#[derive(Clone)]
enum RequestState {
    /// A request has been handed to the service.
    Start,
    /// Signalled through the `NotifyingBody` that wraps the response body.
    End,
}
121
122async fn tls_handshake(
123 rustls_cfg: Arc<rustls::ServerConfig>,
124 stream: impl AsyncReadWrite,
125) -> Result<(impl AsyncReadWrite, TlsInfo), Error> {
126 let tls_acceptor = TlsAcceptor::from(rustls_cfg);
127
128 let start = Instant::now();
130 let stream = tls_acceptor
131 .accept(stream)
132 .await
133 .context("TLS accept failed")?;
134 let duration = start.elapsed();
135
136 let conn = stream.get_ref().1;
137 let mut tls_info = TlsInfo::try_from(conn)?;
138 tls_info.handshake_dur = duration;
139
140 Ok((stream, tls_info))
141}
142
/// State of a single accepted connection while it is being served.
struct Conn {
    /// Address the server listens on (copied from the owning `Server`).
    addr: Addr,
    /// Peer address as reported when the connection was accepted.
    remote_addr: Addr,
    /// Router that every request on this connection is dispatched to.
    router: Router,
    /// Hyper-util connection builder used to serve the stream.
    builder: Builder<TokioExecutor>,
    /// When cancelled, the connection is shut down gracefully.
    token_graceful: CancellationToken,
    /// When cancelled, the connection is dropped immediately.
    /// Also exposed to request handlers through `ConnInfo::close`.
    token_forceful: CancellationToken,
    /// Server options (timeouts, limits, Proxy Protocol mode, ...).
    options: ServerOptions,
    /// Metrics that this connection reports into.
    metrics: Metrics,
    /// Number of requests currently in flight, driven by `RequestState` events.
    requests: AtomicU32,
    /// TLS configuration; when `None` the connection is served in plaintext.
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
155
156impl Display for Conn {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(f, "[{}] <- [{}]", self.addr, self.remote_addr)
159 }
160}
161
162impl Conn {
163 async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
164 let accepted_at = Instant::now();
165
166 debug!("{self}: got a new connection");
167
168 let addr = self.addr.to_string();
170 let labels = &mut [
171 addr.as_str(), self.remote_addr.family(), "no", "no", "no", "no", ];
178
179 let (stream, stats) = AsyncCounter::new(stream);
181
182 let (stream, proxy_hdr): (Box<dyn AsyncReadWrite>, Option<ProxyHeader>) =
184 if self.options.proxy_protocol_mode != ProxyProtocolMode::Off {
185 let (stream, hdr) = ProxyProtocolStream::accept(stream)
186 .await
187 .context("unable to accept Proxy Protocol")?;
188
189 if self.options.proxy_protocol_mode == ProxyProtocolMode::Forced && hdr.is_none() {
190 return Err(Error::NoProxyProtocolDetected);
191 }
192
193 (Box::new(stream), hdr)
194 } else {
195 (Box::new(stream), None)
196 };
197
198 let (local_addr, remote_addr) = proxy_hdr
200 .map(|x| (Addr::Tcp(x.dst), Addr::Tcp(x.src)))
201 .unwrap_or_else(|| (self.addr.clone(), self.remote_addr.clone()));
202
203 let conn_info = Arc::new(ConnInfo {
204 id: Uuid::now_v7(),
205 accepted_at,
206 remote_addr,
207 local_addr,
208 traffic: stats.clone(),
209 req_count: AtomicU64::new(0),
210 close: self.token_forceful.clone(),
211 });
212
213 let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if let Some(rustls_cfg) =
215 &self.rustls_cfg
216 {
217 debug!("{}: performing TLS handshake", self);
218
219 let (mut stream_tls, tls_info) = timeout(
220 self.options.tls_handshake_timeout,
221 tls_handshake(rustls_cfg.clone(), stream),
222 )
223 .await
224 .context("TLS handshake timed out")?
225 .context("TLS handshake failed")?;
226
227 debug!(
228 "{}: handshake finished in {}ms (SNI: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})",
229 self,
230 tls_info.handshake_dur.as_millis(),
231 tls_info.sni,
232 tls_info.protocol,
233 tls_info.cipher,
234 tls_info.alpn,
235 );
236
237 if tls_info
239 .alpn
240 .as_ref()
241 .is_some_and(|x| x.as_bytes() == ALPN_ACME)
242 {
243 debug!("{self}: ACME ALPN - closing connection");
244
245 timeout(Duration::from_secs(5), stream_tls.shutdown())
246 .await
247 .context("socket shutdown timed out")?
248 .context("socket shutdown failed")?;
249
250 return Ok(());
251 }
252
253 (Box::new(stream_tls), Some(Arc::new(tls_info)))
254 } else {
255 (Box::new(stream), None)
256 };
257
258 if let Some(v) = &tls_info {
260 labels[2] = v.protocol.as_str().unwrap();
261 labels[3] = v.cipher.as_str().unwrap();
262
263 self.metrics
264 .conn_tls_handshake_duration
265 .with_label_values(&labels[0..4])
266 .observe(v.handshake_dur.as_secs_f64());
267 }
268
269 self.metrics
270 .conns_open
271 .with_label_values(&labels[0..4])
272 .inc();
273
274 let requests_inflight = self
275 .metrics
276 .requests_inflight
277 .with_label_values(&labels[0..4]);
278
279 let result = self
281 .handle_inner(stream, conn_info.clone(), tls_info, requests_inflight)
282 .await;
283
284 let (sent, rcvd) = (stats.sent(), stats.rcvd());
286 let dur = accepted_at.elapsed().as_secs_f64();
287 let reqs = conn_info.req_count.load(Ordering::SeqCst);
288
289 if self.token_forceful.is_cancelled() {
291 labels[4] = "yes";
292 }
293 if self.token_graceful.is_cancelled() {
295 labels[5] = "yes";
296 }
297
298 self.metrics.conns.with_label_values(labels).inc();
299 self.metrics
300 .conns_open
301 .with_label_values(&labels[0..4])
302 .dec();
303 self.metrics.requests.with_label_values(labels).inc_by(reqs);
304 self.metrics
305 .bytes_rcvd
306 .with_label_values(labels)
307 .inc_by(rcvd);
308 self.metrics
309 .bytes_sent
310 .with_label_values(labels)
311 .inc_by(sent);
312 self.metrics
313 .conn_duration
314 .with_label_values(labels)
315 .observe(dur);
316 self.metrics
317 .requests_per_conn
318 .with_label_values(labels)
319 .observe(reqs as f64);
320
321 debug!(
322 "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
323 self.token_graceful.is_cancelled(),
324 self.token_forceful.is_cancelled(),
325 );
326
327 result
328 }
329
330 async fn handle_inner(
331 &self,
332 stream: Box<dyn AsyncReadWrite>,
333 conn_info: Arc<ConnInfo>,
334 tls_info: Option<Arc<TlsInfo>>,
335 requests_inflight: GenericGauge<AtomicI64>,
336 ) -> Result<(), Error> {
337 let mut idle_timer = Box::pin(sleep(self.options.idle_timeout.unwrap_or(10 * YEAR)));
340
341 let (state_tx, mut state_rx) = channel(65536);
344
345 let mut stream = TimeoutStream::new(stream);
347 stream.set_read_timeout(self.options.read_timeout);
348 stream.set_write_timeout(self.options.write_timeout);
349
350 let stream = TokioIo::new(stream);
352
353 let max_requests_per_conn = self.options.max_requests_per_conn;
355 let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
356 let _ = state_tx.try_send(RequestState::Start);
358
359 request.extensions_mut().insert(conn_info.clone());
361 if let Some(v) = &tls_info {
362 request.extensions_mut().insert(v.clone());
363 }
364
365 let mut router = self.router.clone();
367 let token = self.token_graceful.clone();
368 let conn_info = conn_info.clone();
369 let state_tx = state_tx.clone();
370 let requests_inflight = requests_inflight.clone();
371
372 async move {
374 requests_inflight.inc();
376
377 defer! {
380 requests_inflight.dec();
381 }
382
383 let result = router.call(request).await.map(|x| {
385 let (parts, body) = x.into_parts();
387 let body = NotifyingBody::new(body, state_tx, RequestState::End);
388 Response::from_parts(parts, body)
389 });
390
391 if let Some(v) = max_requests_per_conn {
393 let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
394 if req_count + 1 >= v {
395 token.cancel();
396 }
397 }
398
399 result
400 }
401 });
402
403 let conn = self
405 .builder
406 .serve_connection_with_upgrades(Box::pin(stream), service);
407
408 pin!(conn);
410
411 loop {
412 select! {
413 biased; () = self.token_forceful.cancelled() => {
417 break;
418 }
419
420 () = self.token_graceful.cancelled() => {
422 conn.as_mut().graceful_shutdown();
425
426 let _ = timeout(self.options.grace_period, conn.as_mut()).await;
430 break;
431 },
432
433 Some(v) = state_rx.recv() => {
435 match v {
436 RequestState::Start => {
437 let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
438 debug!("{self}: request started");
439
440 if self.options.idle_timeout.is_some() {
443 debug!("{self}: stopping idle timer (now: {reqs})");
444 idle_timer.as_mut().reset(tokio::time::Instant::now() + 10 * YEAR);
445 }
446 },
447
448 RequestState::End => {
449 let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
450 debug!("{self}: request finished (now: {reqs})");
451
452 if let Some(v) = self.options.idle_timeout && reqs == 0 {
454 debug!("{self}: no outstanding requests, starting timer");
456 idle_timer.as_mut().reset(tokio::time::Instant::now() + v);
457 }
458 }
459 }
460 },
461
462 () = idle_timer.as_mut(), if self.options.idle_timeout.is_some() => {
464 debug!("{self}: Idle timeout triggered, closing");
465
466 conn.as_mut().graceful_shutdown();
468 let _ = timeout(Duration::from_secs(5), conn.as_mut()).await;
470 break;
471 },
472
473 v = conn.as_mut() => {
475 if let Err(e) = v {
476 return Err(anyhow!("unable to serve connection: {e:#}").into());
477 }
478
479 break;
480 },
481 }
482 }
483
484 Ok(())
485 }
486}
487
/// Builder that collects the pieces needed to construct a [`Server`].
pub struct ServerBuilder {
    /// Address to listen on; must be set before `build` is called.
    addr: Option<Addr>,
    /// Router that will handle the requests.
    router: Router,
    /// Registry used to create the default metrics and the TLS configuration
    /// built by `with_rustls_single_cert`.
    registry: Registry,
    /// Explicitly supplied metrics; when `None`, `build` creates them from `registry`.
    metrics: Option<Metrics>,
    /// Server options; defaults to `ServerOptions::default()`.
    options: ServerOptions,
    /// TLS configuration; when `None` the server runs without TLS.
    rustls_cfg: Option<rustls::ServerConfig>,
}
497
498impl ServerBuilder {
499 pub fn new(router: Router) -> Self {
501 Self {
502 addr: None,
503 router,
504 registry: Registry::new(),
505 metrics: None,
506 options: ServerOptions::default(),
507 rustls_cfg: None,
508 }
509 }
510
511 pub fn listen_tcp(mut self, socket: SocketAddr) -> Self {
513 self.addr = Some(Addr::Tcp(socket));
514 self
515 }
516
517 pub fn listen_unix(mut self, path: PathBuf) -> Self {
519 self.addr = Some(Addr::Unix(path));
520 self
521 }
522
523 pub fn with_metrics_registry(mut self, registry: &Registry) -> Self {
525 self.registry = registry.clone();
526 self
527 }
528
529 pub fn with_metrics(mut self, metrics: Metrics) -> Self {
532 self.metrics = Some(metrics);
533 self
534 }
535
536 pub fn with_rustls_config(mut self, rustls_cfg: rustls::ServerConfig) -> Self {
538 self.rustls_cfg = Some(rustls_cfg);
539 self
540 }
541
542 pub const fn with_options(mut self, options: ServerOptions) -> Self {
544 self.options = options;
545 self
546 }
547
548 pub fn with_rustls_single_cert(mut self, cert: PathBuf, key: PathBuf) -> Result<Self, Error> {
551 let cert = std::fs::read(cert).context("unable to read cert")?;
552 let key = std::fs::read(key).context("unable to read key")?;
553 let cert = pem_convert_to_rustls(&key, &cert).context("unable to parse cert+key pair")?;
554 let resolver = SingleCertAndKey::from(cert);
555 let tls_opts = TlsOptions::default();
556 let rustls_cfg = prepare_server_config(tls_opts, Arc::new(resolver), &self.registry);
557
558 self.rustls_cfg = Some(rustls_cfg);
559 Ok(self)
560 }
561
562 pub fn build(self) -> Result<Server, Error> {
564 let Some(addr) = self.addr else {
565 return Err(Error::Generic(anyhow!("Listening address not specified")));
566 };
567
568 let metrics = self.metrics.unwrap_or_else(|| Metrics::new(&self.registry));
569
570 Ok(Server::new(
571 addr,
572 self.router,
573 self.options,
574 metrics,
575 self.rustls_cfg,
576 ))
577 }
578}
579
/// HTTP server that accepts connections and serves them with an Axum router.
pub struct Server {
    /// Address the server listens on.
    addr: Addr,
    /// Router that handles the requests.
    router: Router,
    /// Tracks the spawned connection tasks so that shutdown can wait for them.
    tracker: TaskTracker,
    /// Server options (timeouts, limits, socket settings, ...).
    options: ServerOptions,
    /// Metrics shared by all connections.
    metrics: Metrics,
    /// Pre-configured Hyper-util connection builder, cloned for each connection.
    builder: Builder<TokioExecutor>,
    /// TLS configuration; when `None` connections are served in plaintext.
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
590
591impl Display for Server {
592 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 write!(f, "[{}]", self.addr)
594 }
595}
596
597impl Server {
598 pub fn new(
600 addr: Addr,
601 router: Router,
602 options: ServerOptions,
603 metrics: Metrics,
604 rustls_cfg: Option<rustls::ServerConfig>,
605 ) -> Self {
606 let mut builder = Builder::new(TokioExecutor::new());
609 builder
610 .http1()
611 .timer(TokioTimer::new()) .header_read_timeout(Some(options.http1_header_read_timeout))
613 .keep_alive(true)
614 .http2()
615 .adaptive_window(true)
616 .max_concurrent_streams(Some(options.http2_max_streams))
617 .timer(TokioTimer::new()) .keep_alive_interval(options.http2_keepalive_interval)
619 .keep_alive_timeout(options.http2_keepalive_timeout)
620 .enable_connect_protocol(); Self {
623 addr,
624 router,
625 options,
626 metrics,
627 tracker: TaskTracker::new(),
628 builder,
629 rustls_cfg: rustls_cfg.map(Arc::new),
630 }
631 }
632
633 pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
635 let opts = ListenerOpts {
636 backlog: self.options.backlog,
637 mss: self.options.tcp_mss,
638 keepalive: (&self.options).into(),
639 };
640
641 let listener = Listener::new(self.addr.clone(), opts)?;
642 self.serve_with_listener(listener, token).await
643 }
644
645 fn spawn_connection(
646 &self,
647 stream: Box<dyn AsyncReadWrite>,
648 remote_addr: Addr,
649 token: CancellationToken,
650 ) {
651 let conn = Conn {
655 addr: self.addr.clone(),
656 remote_addr: remote_addr.clone(),
657 router: self.router.clone(),
658 builder: self.builder.clone(),
659 token_graceful: token,
660 token_forceful: CancellationToken::new(),
661 options: self.options,
662 metrics: self.metrics.clone(), requests: AtomicU32::new(0),
664 rustls_cfg: self.rustls_cfg.clone(),
665 };
666
667 self.tracker.spawn(async move {
669 if let Err(e) = conn.handle(stream).await {
670 info!(
671 "[{}] <- [{remote_addr}]: failed to handle connection: {e:#}",
672 conn.addr
673 );
674 }
675
676 debug!("[{}] <- [{remote_addr}]: connection finished", conn.addr);
677 });
678 }
679
680 pub async fn serve_with_listener(
682 &self,
683 listener: Listener,
684 token: CancellationToken,
685 ) -> Result<(), Error> {
686 warn!("{self}: running (TLS: {})", self.rustls_cfg.is_some());
687
688 loop {
689 select! {
690 biased; () = token.cancelled() => {
693 drop(listener);
695
696 warn!("{self}: shutting down, waiting for the active connections to close for {}s", self.options.grace_period.as_secs());
697 self.tracker.close();
698
699 select! {
700 () = sleep(self.options.grace_period + Duration::from_secs(5)) => {
701 warn!("{self}: connections didn't close in time, shutting down anyway");
702 },
703 () = self.tracker.wait() => {},
704 }
705
706 warn!("{self}: shut down");
707
708 if let Addr::Unix(v) = &self.addr {
710 let _ = std::fs::remove_file(v);
711 }
712
713 return Ok(());
714 },
715
716 v = listener.accept() => {
718 let (stream, remote_addr) = match v {
719 Ok(v) => v,
720 Err(e) => {
721 warn!("{self}: unable to accept connection: {e:#}");
722 sleep(Duration::from_millis(10)).await;
725 continue;
726 }
727 };
728
729 self.spawn_connection(stream, remote_addr, token.child_token());
730 }
731 }
732 }
733 }
734}
735
736pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result<TcpListener, Error> {
738 let domain = if addr.is_ipv4() {
739 Domain::IPV4
740 } else {
741 Domain::IPV6
742 };
743
744 let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?;
745 socket
746 .set_tcp_nodelay(true)
747 .context("unable to set TCP_NODELAY")?;
748
749 if let Some(v) = opts.mss {
750 socket.set_tcp_mss(v).context("unable to set TCP MSS")?;
751 }
752
753 socket
754 .set_reuse_address(true)
755 .context("unable to set SO_REUSEADDR")?;
756 socket
757 .set_tcp_keepalive(&opts.keepalive)
758 .context("unable to set keepalive on the socket")?;
759 socket
760 .set_nonblocking(true)
761 .context("unable to set socket into non-blocking mode")?;
762
763 socket.bind(&addr.into()).context("unable to bind socket")?;
764 socket
765 .listen(opts.backlog as i32)
766 .context("unable to listen on the socket")?;
767
768 let listener = TcpListener::from_std(socket.into())
769 .context("unable to convert socket from the standard one")?;
770
771 Ok(listener)
772}
773
774pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result<UnixListener, Error> {
776 let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?;
777
778 if path.exists() {
779 std::fs::remove_file(&path).context("unable to remove UNIX socket")?;
780 }
781
782 socket.bind(&path).context("unable to bind socket")?;
783
784 let socket = socket
785 .listen(opts.backlog)
786 .context("unable to listen socket")?;
787
788 std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666))
789 .context("unable to set permissions on socket")?;
790
791 Ok(socket)
792}
793
794#[async_trait]
795impl Run for Server {
796 async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
797 self.serve(token).await?;
798 Ok(())
799 }
800}
801
#[cfg(test)]
mod test {
    use http::StatusCode;

    use super::*;

    /// Starts a server with an empty router on an ephemeral port and checks
    /// that a plain HTTP request is answered with 404.
    #[tokio::test]
    async fn test_server() {
        let opts = ServerOptions::default();
        let listener = listen_tcp(
            "127.0.0.1:0".parse().unwrap(),
            ListenerOpts {
                backlog: 128,
                mss: None,
                keepalive: (&opts).into(),
            },
        )
        .unwrap();

        // Port 0 was requested, so find out which port was actually assigned
        let addr = listener.local_addr().unwrap();

        let server = Server::new(
            Addr::Tcp(addr),
            Router::new(),
            opts,
            Metrics::new(&Registry::new()),
            None,
        );

        tokio::spawn(async move {
            server
                .serve_with_listener(listener.into(), CancellationToken::new())
                .await
                .unwrap();
        });

        // Retry a few times to give the server task time to start
        let mut status = None;
        for _ in 0..10 {
            match reqwest::get(format!("http://{addr}")).await {
                Ok(response) => {
                    status = Some(response.status());
                    break;
                }
                Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }

        // Fail when no attempt got a response, instead of passing without
        // having checked anything.
        assert_eq!(
            status,
            Some(StatusCode::NOT_FOUND),
            "server did not respond with 404 to any of the attempts"
        );
    }
}