ic_bn_lib/http/server/
mod.rs

1pub mod proxy_protocol;
2
3use std::{
4    fmt::Display,
5    io,
6    net::SocketAddr,
7    os::unix::fs::PermissionsExt,
8    path::PathBuf,
9    sync::{
10        Arc,
11        atomic::{AtomicU32, AtomicU64, Ordering},
12    },
13    time::{Duration, Instant},
14};
15
16use anyhow::{Context, anyhow};
17use async_trait::async_trait;
18use axum::{Router, extract::Request};
19use http::Response;
20use hyper::body::Incoming;
21use hyper_util::{
22    rt::{TokioExecutor, TokioIo, TokioTimer},
23    server::conn::auto::Builder,
24};
25use ic_bn_lib_common::{
26    traits::Run,
27    types::{
28        http::{
29            ALPN_ACME, Addr, ConnInfo, Error, ListenerOpts, Metrics, ProxyProtocolMode,
30            ServerOptions, TlsInfo,
31        },
32        tls::TlsOptions,
33    },
34};
35use prometheus::{
36    Registry,
37    core::{AtomicI64, GenericGauge},
38};
39use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
40use rustls::sign::SingleCertAndKey;
41use scopeguard::defer;
42use socket2::{Domain, Socket, Type};
43use tokio::{
44    io::AsyncWriteExt,
45    net::{TcpListener, UnixListener, UnixSocket},
46    pin, select,
47    sync::mpsc::channel,
48    time::{sleep, timeout},
49};
50use tokio_io_timeout::TimeoutStream;
51use tokio_rustls::TlsAcceptor;
52use tokio_util::{sync::CancellationToken, task::TaskTracker};
53use tower_service::Service;
54use tracing::{debug, info, warn};
55use uuid::Uuid;
56
57use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody};
58use crate::tls::{pem_convert_to_rustls, prepare_server_config};
59
/// One (non-leap) year: 365 days of 24 hours of 60 minutes of 60 seconds.
const YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
61
62/// Connection listener
63pub enum Listener {
64    Tcp(TcpListener),
65    Unix(UnixListener),
66}
67
68impl Listener {
69    /// Create a new Listener
70    pub fn new(addr: Addr, opts: ListenerOpts) -> Result<Self, Error> {
71        Ok(match addr {
72            Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?),
73            Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?),
74        })
75    }
76
77    /// Accept the connection
78    async fn accept(&self) -> Result<(Box<dyn AsyncReadWrite>, Addr), io::Error> {
79        Ok(match self {
80            Self::Tcp(v) => {
81                let x = v.accept().await?;
82                (Box::new(x.0), Addr::Tcp(x.1))
83            }
84            Self::Unix(v) => {
85                let x = v.accept().await?;
86                (
87                    Box::new(x.0),
88                    Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()),
89                )
90            }
91        })
92    }
93
94    pub fn local_addr(&self) -> Option<SocketAddr> {
95        match &self {
96            Self::Tcp(v) => v.local_addr().ok(),
97            Self::Unix(_) => None,
98        }
99    }
100}
101
102impl From<TcpListener> for Listener {
103    /// Creates a Listener from TcpListener
104    fn from(v: TcpListener) -> Self {
105        Self::Tcp(v)
106    }
107}
108
109impl From<UnixListener> for Listener {
110    /// Creates a Listener from UnixListener
111    fn from(v: UnixListener) -> Self {
112        Self::Unix(v)
113    }
114}
115
/// Request lifecycle notification sent from the HTTP service to the connection loop.
#[derive(Clone)]
enum RequestState {
    /// Sent when the service starts handling a request
    Start,
    /// Emitted through the `NotifyingBody` wrapping the response body
    /// (presumably once the body is done — see `NotifyingBody`)
    End,
}
121
122async fn tls_handshake(
123    rustls_cfg: Arc<rustls::ServerConfig>,
124    stream: impl AsyncReadWrite,
125) -> Result<(impl AsyncReadWrite, TlsInfo), Error> {
126    let tls_acceptor = TlsAcceptor::from(rustls_cfg);
127
128    // Perform the TLS handshake
129    let start = Instant::now();
130    let stream = tls_acceptor
131        .accept(stream)
132        .await
133        .context("TLS accept failed")?;
134    let duration = start.elapsed();
135
136    let conn = stream.get_ref().1;
137    let mut tls_info = TlsInfo::try_from(conn)?;
138    tls_info.handshake_dur = duration;
139
140    Ok((stream, tls_info))
141}
142
/// State needed to serve a single accepted connection.
struct Conn {
    /// Address the server listens on (also the fallback local address if no Proxy Protocol)
    addr: Addr,
    /// Peer address as returned by `accept()`
    remote_addr: Addr,
    /// Router that handles the requests
    router: Router,
    /// Hyper connection builder (auto-detects HTTP/1 vs HTTP/2)
    builder: Builder<TokioExecutor>,
    /// When cancelled, the connection is shut down gracefully. It is also cancelled by the
    /// connection itself once `max_requests_per_conn` is reached.
    token_graceful: CancellationToken,
    /// When cancelled, the connection is closed immediately; exposed to handlers as `ConnInfo::close`
    token_forceful: CancellationToken,
    /// Server options
    options: ServerOptions,
    /// Metrics to record into
    metrics: Metrics,
    /// Number of outstanding requests on this connection; drives the idle timer
    requests: AtomicU32,
    /// TLS configuration; the TLS handshake is skipped when `None`
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
155
156impl Display for Conn {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        write!(f, "[{}] <- [{}]", self.addr, self.remote_addr)
159    }
160}
161
162impl Conn {
163    async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
164        let accepted_at = Instant::now();
165
166        debug!("{self}: got a new connection");
167
168        // Prepare metric labels
169        let addr = self.addr.to_string();
170        let labels = &mut [
171            addr.as_str(),             // Listening addr
172            self.remote_addr.family(), // Remote client address family
173            "no",                      // TLS version
174            "no",                      // TLS ciphersuite
175            "no",                      // Force-closed
176            "no",                      // Recycled
177        ];
178
179        // Wrap with traffic counter
180        let (stream, stats) = AsyncCounter::new(stream);
181
182        // Read & parse Proxy Protocol v2 header if configured
183        let (stream, proxy_hdr): (Box<dyn AsyncReadWrite>, Option<ProxyHeader>) =
184            if self.options.proxy_protocol_mode != ProxyProtocolMode::Off {
185                let (stream, hdr) = ProxyProtocolStream::accept(stream)
186                    .await
187                    .context("unable to accept Proxy Protocol")?;
188
189                if self.options.proxy_protocol_mode == ProxyProtocolMode::Forced && hdr.is_none() {
190                    return Err(Error::NoProxyProtocolDetected);
191                }
192
193                (Box::new(stream), hdr)
194            } else {
195                (Box::new(stream), None)
196            };
197
198        // Use IPs from Proxy Protocol if available
199        let (local_addr, remote_addr) = proxy_hdr
200            .map(|x| (Addr::Tcp(x.dst), Addr::Tcp(x.src)))
201            .unwrap_or_else(|| (self.addr.clone(), self.remote_addr.clone()));
202
203        let conn_info = Arc::new(ConnInfo {
204            id: Uuid::now_v7(),
205            accepted_at,
206            remote_addr,
207            local_addr,
208            traffic: stats.clone(),
209            req_count: AtomicU64::new(0),
210            close: self.token_forceful.clone(),
211        });
212
213        // Perform TLS handshake if we're in TLS mode
214        let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if let Some(rustls_cfg) =
215            &self.rustls_cfg
216        {
217            debug!("{}: performing TLS handshake", self);
218
219            let (mut stream_tls, tls_info) = timeout(
220                self.options.tls_handshake_timeout,
221                tls_handshake(rustls_cfg.clone(), stream),
222            )
223            .await
224            .context("TLS handshake timed out")?
225            .context("TLS handshake failed")?;
226
227            debug!(
228                "{}: handshake finished in {}ms (SNI: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})",
229                self,
230                tls_info.handshake_dur.as_millis(),
231                tls_info.sni,
232                tls_info.protocol,
233                tls_info.cipher,
234                tls_info.alpn,
235            );
236
237            // Close the connection if agreed ALPN is ACME - the handshake is enough for the challenge
238            if tls_info
239                .alpn
240                .as_ref()
241                .is_some_and(|x| x.as_bytes() == ALPN_ACME)
242            {
243                debug!("{self}: ACME ALPN - closing connection");
244
245                timeout(Duration::from_secs(5), stream_tls.shutdown())
246                    .await
247                    .context("socket shutdown timed out")?
248                    .context("socket shutdown failed")?;
249
250                return Ok(());
251            }
252
253            (Box::new(stream_tls), Some(Arc::new(tls_info)))
254        } else {
255            (Box::new(stream), None)
256        };
257
258        // Record TLS metrics
259        if let Some(v) = &tls_info {
260            labels[2] = v.protocol.as_str().unwrap();
261            labels[3] = v.cipher.as_str().unwrap();
262
263            self.metrics
264                .conn_tls_handshake_duration
265                .with_label_values(&labels[0..4])
266                .observe(v.handshake_dur.as_secs_f64());
267        }
268
269        self.metrics
270            .conns_open
271            .with_label_values(&labels[0..4])
272            .inc();
273
274        let requests_inflight = self
275            .metrics
276            .requests_inflight
277            .with_label_values(&labels[0..4]);
278
279        // Handle the connection
280        let result = self
281            .handle_inner(stream, conn_info.clone(), tls_info, requests_inflight)
282            .await;
283
284        // Record connection metrics
285        let (sent, rcvd) = (stats.sent(), stats.rcvd());
286        let dur = accepted_at.elapsed().as_secs_f64();
287        let reqs = conn_info.req_count.load(Ordering::SeqCst);
288
289        // force-closed
290        if self.token_forceful.is_cancelled() {
291            labels[4] = "yes";
292        }
293        // recycled
294        if self.token_graceful.is_cancelled() {
295            labels[5] = "yes";
296        }
297
298        self.metrics.conns.with_label_values(labels).inc();
299        self.metrics
300            .conns_open
301            .with_label_values(&labels[0..4])
302            .dec();
303        self.metrics.requests.with_label_values(labels).inc_by(reqs);
304        self.metrics
305            .bytes_rcvd
306            .with_label_values(labels)
307            .inc_by(rcvd);
308        self.metrics
309            .bytes_sent
310            .with_label_values(labels)
311            .inc_by(sent);
312        self.metrics
313            .conn_duration
314            .with_label_values(labels)
315            .observe(dur);
316        self.metrics
317            .requests_per_conn
318            .with_label_values(labels)
319            .observe(reqs as f64);
320
321        debug!(
322            "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
323            self.token_graceful.is_cancelled(),
324            self.token_forceful.is_cancelled(),
325        );
326
327        result
328    }
329
330    async fn handle_inner(
331        &self,
332        stream: Box<dyn AsyncReadWrite>,
333        conn_info: Arc<ConnInfo>,
334        tls_info: Option<Arc<TlsInfo>>,
335        requests_inflight: GenericGauge<AtomicI64>,
336    ) -> Result<(), Error> {
337        // Create a timer for idle connection tracking.
338        // Falls back to 10 years if idle timer is not set (for simplicity)
339        let mut idle_timer = Box::pin(sleep(self.options.idle_timeout.unwrap_or(10 * YEAR)));
340
341        // Create channel to notify about request start/stop.
342        // Use bounded but big enough so that it's larger than our concurrency.
343        let (state_tx, mut state_rx) = channel(65536);
344
345        // Apply timeouts on read/write calls
346        let mut stream = TimeoutStream::new(stream);
347        stream.set_read_timeout(self.options.read_timeout);
348        stream.set_write_timeout(self.options.write_timeout);
349
350        // Convert stream from Tokio to Hyper
351        let stream = TokioIo::new(stream);
352
353        // Convert router to Hyper service
354        let max_requests_per_conn = self.options.max_requests_per_conn;
355        let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
356            // Notify that we have started processing the request
357            let _ = state_tx.try_send(RequestState::Start);
358
359            // Inject connection information
360            request.extensions_mut().insert(conn_info.clone());
361            if let Some(v) = &tls_info {
362                request.extensions_mut().insert(v.clone());
363            }
364
365            // Clone the stuff needed in the async block below
366            let mut router = self.router.clone();
367            let token = self.token_graceful.clone();
368            let conn_info = conn_info.clone();
369            let state_tx = state_tx.clone();
370            let requests_inflight = requests_inflight.clone();
371
372            // Return the future
373            async move {
374                // Increase the global inflight requests counter
375                requests_inflight.inc();
376
377                // Since the future can be cancelled we need defer to decrease the counter in any case
378                // to avoid leaking the inflight requests
379                defer! {
380                    requests_inflight.dec();
381                }
382
383                // Execute the request
384                let result = router.call(request).await.map(|x| {
385                    // Wrap the response body into a notifying one
386                    let (parts, body) = x.into_parts();
387                    let body = NotifyingBody::new(body, state_tx, RequestState::End);
388                    Response::from_parts(parts, body)
389                });
390
391                // Check if we need to gracefully shutdown this connection
392                if let Some(v) = max_requests_per_conn {
393                    let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
394                    if req_count + 1 >= v {
395                        token.cancel();
396                    }
397                }
398
399                result
400            }
401        });
402
403        // Serve the connection
404        let conn = self
405            .builder
406            .serve_connection_with_upgrades(Box::pin(stream), service);
407
408        // Using mutable future reference requires pinning
409        pin!(conn);
410
411        loop {
412            select! {
413                biased; // Poll top-down
414
415                // Immediately close the connection if was requested
416                () = self.token_forceful.cancelled() => {
417                    break;
418                }
419
420                // Start graceful shutdown of the connection
421                () = self.token_graceful.cancelled() => {
422                    // For H2: sends GOAWAY frames to the client
423                    // For H1: disables keepalives
424                    conn.as_mut().graceful_shutdown();
425
426                    // Wait for the grace period to finish or connection to complete.
427                    // Connection must still be polled for the shutdown to proceed.
428                    // We don't really care for the result.
429                    let _ = timeout(self.options.grace_period, conn.as_mut()).await;
430                    break;
431                },
432
433                // Get request state change notifications
434                Some(v) = state_rx.recv() => {
435                    match v {
436                        RequestState::Start => {
437                            let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
438                            debug!("{self}: request started");
439
440                            // Effectively disable the timer by setting it to 10 years into the future.
441                            // TODO improve?
442                            if self.options.idle_timeout.is_some() {
443                                debug!("{self}: stopping idle timer (now: {reqs})");
444                                idle_timer.as_mut().reset(tokio::time::Instant::now() + 10 * YEAR);
445                            }
446                        },
447
448                        RequestState::End => {
449                            let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
450                            debug!("{self}: request finished (now: {reqs})");
451
452                            // Check if the number of outstanding requests is now zero
453                            if let Some(v) = self.options.idle_timeout && reqs == 0 {
454                                // Enable the idle timer
455                                debug!("{self}: no outstanding requests, starting timer");
456                                idle_timer.as_mut().reset(tokio::time::Instant::now() + v);
457                            }
458                        }
459                    }
460                },
461
462                // See if the idle timeout has kicked in
463                () = idle_timer.as_mut(), if self.options.idle_timeout.is_some() => {
464                    debug!("{self}: Idle timeout triggered, closing");
465
466                    // Signal that we're closing
467                    conn.as_mut().graceful_shutdown();
468                    // Give the client some time to shut down
469                    let _ = timeout(Duration::from_secs(5), conn.as_mut()).await;
470                    break;
471                },
472
473                // Drive the connection by polling it
474                v = conn.as_mut() => {
475                    if let Err(e) = v {
476                        return Err(anyhow!("unable to serve connection: {e:#}").into());
477                    }
478
479                    break;
480                },
481            }
482        }
483
484        Ok(())
485    }
486}
487
/// Builder for a `Server`
pub struct ServerBuilder {
    /// Address to listen on; `build()` fails if it was not set via `listen_tcp`/`listen_unix`
    addr: Option<Addr>,
    /// Router that handles the requests
    router: Router,
    /// Registry used to create `Metrics` (unless `metrics` is set).
    /// Also passed to `prepare_server_config` by `with_rustls_single_cert`.
    registry: Registry,
    /// Pre-built metrics; take precedence over `registry` when set
    metrics: Option<Metrics>,
    /// Server options
    options: ServerOptions,
    /// TLS configuration; connections are served without TLS when `None`
    rustls_cfg: Option<rustls::ServerConfig>,
}
497
498impl ServerBuilder {
499    /// Creates a builder with a given router & defaults
500    pub fn new(router: Router) -> Self {
501        Self {
502            addr: None,
503            router,
504            registry: Registry::new(),
505            metrics: None,
506            options: ServerOptions::default(),
507            rustls_cfg: None,
508        }
509    }
510
511    /// Listens on the given TCP socket
512    pub fn listen_tcp(mut self, socket: SocketAddr) -> Self {
513        self.addr = Some(Addr::Tcp(socket));
514        self
515    }
516
517    /// Listens on the given Unix socket
518    pub fn listen_unix(mut self, path: PathBuf) -> Self {
519        self.addr = Some(Addr::Unix(path));
520        self
521    }
522
523    /// Sets up metrics with provided Registry
524    pub fn with_metrics_registry(mut self, registry: &Registry) -> Self {
525        self.registry = registry.clone();
526        self
527    }
528
529    /// Sets up metrics with provided Metrics.
530    /// Overrides `with_metrics_registry()`.
531    pub fn with_metrics(mut self, metrics: Metrics) -> Self {
532        self.metrics = Some(metrics);
533        self
534    }
535
536    /// Sets up TLS with provided ServerConfig
537    pub fn with_rustls_config(mut self, rustls_cfg: rustls::ServerConfig) -> Self {
538        self.rustls_cfg = Some(rustls_cfg);
539        self
540    }
541
542    /// Sets up with provided Options
543    pub const fn with_options(mut self, options: ServerOptions) -> Self {
544        self.options = options;
545        self
546    }
547
548    /// Sets up TLS with a single certificate.
549    /// If metrics are needed - provide registry using `with_metrics_registry` before calling this method.
550    pub fn with_rustls_single_cert(mut self, cert: PathBuf, key: PathBuf) -> Result<Self, Error> {
551        let cert = std::fs::read(cert).context("unable to read cert")?;
552        let key = std::fs::read(key).context("unable to read key")?;
553        let cert = pem_convert_to_rustls(&key, &cert).context("unable to parse cert+key pair")?;
554        let resolver = SingleCertAndKey::from(cert);
555        let tls_opts = TlsOptions::default();
556        let rustls_cfg = prepare_server_config(tls_opts, Arc::new(resolver), &self.registry);
557
558        self.rustls_cfg = Some(rustls_cfg);
559        Ok(self)
560    }
561
562    /// Build the Server
563    pub fn build(self) -> Result<Server, Error> {
564        let Some(addr) = self.addr else {
565            return Err(Error::Generic(anyhow!("Listening address not specified")));
566        };
567
568        let metrics = self.metrics.unwrap_or_else(|| Metrics::new(&self.registry));
569
570        Ok(Server::new(
571            addr,
572            self.router,
573            self.options,
574            metrics,
575            self.rustls_cfg,
576        ))
577    }
578}
579
/// Listens for new connections (optionally over TLS) and serves them with the provided `Router`
pub struct Server {
    /// Address to listen on; a Unix socket file is removed on shutdown
    addr: Addr,
    /// Router that handles the requests
    router: Router,
    /// Tracks spawned connection tasks so that shutdown can wait for them
    tracker: TaskTracker,
    /// Server options, copied into every connection
    options: ServerOptions,
    /// Metrics shared by all connections
    metrics: Metrics,
    /// Hyper connection builder, cloned for every connection
    builder: Builder<TokioExecutor>,
    /// TLS configuration; connections are served without TLS when `None`
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
590
591impl Display for Server {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        write!(f, "[{}]", self.addr)
594    }
595}
596
597impl Server {
598    /// Create a new `Server`
599    pub fn new(
600        addr: Addr,
601        router: Router,
602        options: ServerOptions,
603        metrics: Metrics,
604        rustls_cfg: Option<rustls::ServerConfig>,
605    ) -> Self {
606        // Prepare Hyper connection builder
607        // It automatically figures out whether to do HTTP1 or HTTP2
608        let mut builder = Builder::new(TokioExecutor::new());
609        builder
610            .http1()
611            .timer(TokioTimer::new()) // Needed for the keepalives below
612            .header_read_timeout(Some(options.http1_header_read_timeout))
613            .keep_alive(true)
614            .http2()
615            .adaptive_window(true)
616            .max_concurrent_streams(Some(options.http2_max_streams))
617            .timer(TokioTimer::new()) // Needed for the keepalives below
618            .keep_alive_interval(options.http2_keepalive_interval)
619            .keep_alive_timeout(options.http2_keepalive_timeout)
620            .enable_connect_protocol(); // Needed for Websockets
621
622        Self {
623            addr,
624            router,
625            options,
626            metrics,
627            tracker: TaskTracker::new(),
628            builder,
629            rustls_cfg: rustls_cfg.map(Arc::new),
630        }
631    }
632
633    /// Start serving with given cancellation token
634    pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
635        let opts = ListenerOpts {
636            backlog: self.options.backlog,
637            mss: self.options.tcp_mss,
638            keepalive: (&self.options).into(),
639        };
640
641        let listener = Listener::new(self.addr.clone(), opts)?;
642        self.serve_with_listener(listener, token).await
643    }
644
645    fn spawn_connection(
646        &self,
647        stream: Box<dyn AsyncReadWrite>,
648        remote_addr: Addr,
649        token: CancellationToken,
650    ) {
651        // Create a new connection
652        // Router & TlsAcceptor are both Arc<> inside so it's cheap to clone
653        // Builder is a bit more complex, but cloning is better than to create it again
654        let conn = Conn {
655            addr: self.addr.clone(),
656            remote_addr: remote_addr.clone(),
657            router: self.router.clone(),
658            builder: self.builder.clone(),
659            token_graceful: token,
660            token_forceful: CancellationToken::new(),
661            options: self.options,
662            metrics: self.metrics.clone(), // All metrics have Arc inside
663            requests: AtomicU32::new(0),
664            rustls_cfg: self.rustls_cfg.clone(),
665        };
666
667        // Spawn a task to handle connection & track it
668        self.tracker.spawn(async move {
669            if let Err(e) = conn.handle(stream).await {
670                info!(
671                    "[{}] <- [{remote_addr}]: failed to handle connection: {e:#}",
672                    conn.addr
673                );
674            }
675
676            debug!("[{}] <- [{remote_addr}]: connection finished", conn.addr);
677        });
678    }
679
680    /// Start serving with a given listener & cancellation token
681    pub async fn serve_with_listener(
682        &self,
683        listener: Listener,
684        token: CancellationToken,
685    ) -> Result<(), Error> {
686        warn!("{self}: running (TLS: {})", self.rustls_cfg.is_some());
687
688        loop {
689            select! {
690                biased; // Poll top-down
691
692                () = token.cancelled() => {
693                    // Stop accepting new connections
694                    drop(listener);
695
696                    warn!("{self}: shutting down, waiting for the active connections to close for {}s", self.options.grace_period.as_secs());
697                    self.tracker.close();
698
699                    select! {
700                        () = sleep(self.options.grace_period + Duration::from_secs(5)) => {
701                            warn!("{self}: connections didn't close in time, shutting down anyway");
702                        },
703                        () = self.tracker.wait() => {},
704                    }
705
706                    warn!("{self}: shut down");
707
708                    // Remove the socket
709                    if let Addr::Unix(v) = &self.addr {
710                        let _ = std::fs::remove_file(v);
711                    }
712
713                    return Ok(());
714                },
715
716                // Try to accept the connection
717                v = listener.accept() => {
718                    let (stream, remote_addr) = match v {
719                        Ok(v) => v,
720                        Err(e) => {
721                            warn!("{self}: unable to accept connection: {e:#}");
722                            // Wait few ms just in case that there's an overflown backlog
723                            // so that we don't run into infinite error loop
724                            sleep(Duration::from_millis(10)).await;
725                            continue;
726                        }
727                    };
728
729                    self.spawn_connection(stream, remote_addr, token.child_token());
730                }
731            }
732        }
733    }
734}
735
736/// Creates a TCP listener with given opts
737pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result<TcpListener, Error> {
738    let domain = if addr.is_ipv4() {
739        Domain::IPV4
740    } else {
741        Domain::IPV6
742    };
743
744    let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?;
745    socket
746        .set_tcp_nodelay(true)
747        .context("unable to set TCP_NODELAY")?;
748
749    if let Some(v) = opts.mss {
750        socket.set_tcp_mss(v).context("unable to set TCP MSS")?;
751    }
752
753    socket
754        .set_reuse_address(true)
755        .context("unable to set SO_REUSEADDR")?;
756    socket
757        .set_tcp_keepalive(&opts.keepalive)
758        .context("unable to set keepalive on the socket")?;
759    socket
760        .set_nonblocking(true)
761        .context("unable to set socket into non-blocking mode")?;
762
763    socket.bind(&addr.into()).context("unable to bind socket")?;
764    socket
765        .listen(opts.backlog as i32)
766        .context("unable to listen on the socket")?;
767
768    let listener = TcpListener::from_std(socket.into())
769        .context("unable to convert socket from the standard one")?;
770
771    Ok(listener)
772}
773
774/// Creates a Unix Socket listener with given opts
775pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result<UnixListener, Error> {
776    let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?;
777
778    if path.exists() {
779        std::fs::remove_file(&path).context("unable to remove UNIX socket")?;
780    }
781
782    socket.bind(&path).context("unable to bind socket")?;
783
784    let socket = socket
785        .listen(opts.backlog)
786        .context("unable to listen socket")?;
787
788    std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666))
789        .context("unable to set permissions on socket")?;
790
791    Ok(socket)
792}
793
794#[async_trait]
795impl Run for Server {
796    async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
797        self.serve(token).await?;
798        Ok(())
799    }
800}
801
#[cfg(test)]
mod test {
    use http::StatusCode;

    use super::*;

    /// Starts a server without TLS and an empty router, and checks that it answers with 404.
    #[tokio::test]
    async fn test_server() {
        let opts = ServerOptions::default();
        let listener = listen_tcp(
            "127.0.0.1:0".parse().unwrap(),
            ListenerOpts {
                backlog: 128,
                mss: None,
                keepalive: (&opts).into(),
            },
        )
        .unwrap();

        let addr = listener.local_addr().unwrap();

        let server = Server::new(
            Addr::Tcp(addr),
            Router::new(),
            opts,
            Metrics::new(&Registry::new()),
            None,
        );

        tokio::spawn(async move {
            server
                .serve_with_listener(listener.into(), CancellationToken::new())
                .await
                .unwrap();
        });

        // Retry a few times in case of transient connection errors
        let mut responded = false;
        for _ in 0..10 {
            let Ok(result) = reqwest::get(format!("http://{addr}")).await else {
                tokio::time::sleep(Duration::from_millis(10)).await;
                continue;
            };

            // An empty router answers every request with 404
            assert_eq!(result.status(), StatusCode::NOT_FOUND);
            responded = true;
            break;
        }

        // Without this the test would pass vacuously if every attempt failed
        assert!(responded, "server didn't respond");
    }
}