Skip to main content

iroh_relay/server/
http_server.rs

1//! Low-level HTTP server components for embedding the relay service.
2//!
3//! This module provides [`RelayService`] which can be used to embed relay functionality
4//! into an existing HTTP server. It handles individual connections and provides
5//! the core relay protocol implementation.
6//!
7//! For a complete relay server implementation, see the parent [`server`](super) module.
8
9use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
10
11use bytes::Bytes;
12use derive_more::Debug;
13use http::{
14    header::{CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION},
15    response::Builder as ResponseBuilder,
16};
17use hyper::{
18    HeaderMap, Method, Request, Response, StatusCode,
19    body::Incoming,
20    header::{HeaderValue, SEC_WEBSOCKET_ACCEPT, UPGRADE},
21    service::Service,
22    upgrade::Upgraded,
23};
24use n0_error::{e, ensure, stack_error};
25use n0_future::MaybeFuture;
26use tokio::{
27    net::{TcpListener, TcpStream},
28    sync::Notify,
29};
30use tokio_rustls_acme::AcmeAcceptor;
31use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
32use tracing::{Instrument, debug, error, info, info_span, trace, warn, warn_span};
33
34use super::{AccessConfig, SpawnError, clients::Clients, streams::InvalidBucketConfig};
35use crate::{
36    KeyCache,
37    defaults::{DEFAULT_KEY_CACHE_CAPACITY, timeouts::SERVER_WRITE_TIMEOUT},
38    http::{
39        CLIENT_AUTH_HEADER, ProtocolVersion, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION,
40        WEBSOCKET_UPGRADE_PROTOCOL,
41    },
42    protos::{
43        handshake,
44        relay::{MAX_FRAME_SIZE, PER_CLIENT_SEND_QUEUE_DEPTH},
45        streams::WsBytesFramed,
46    },
47    server::{
48        ClientRateLimit,
49        client::Config,
50        metrics::Metrics,
51        streams::{MaybeTlsStream, RateLimited, RelayedStream},
52    },
53};
54
55type BytesBody = http_body_util::Full<hyper::body::Bytes>;
56type HyperError = Box<dyn std::error::Error + Send + Sync>;
57type HyperResult<T> = std::result::Result<T, HyperError>;
58type HyperHandler = Box<
59    dyn Fn(Request<Incoming>, ResponseBuilder) -> HyperResult<Response<BytesBody>>
60        + Send
61        + Sync
62        + 'static,
63>;
64
/// GUID that is appended to the client's `Sec-WebSocket-Key` when computing the
/// `Sec-WebSocket-Accept` value, see RFC 6455
/// (<https://www.rfc-editor.org/rfc/rfc6455>) section 1.3.
const SEC_WEBSOCKET_ACCEPT_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
67
/// Default time a connection is given to finish the TLS and WebSocket upgrade handshakes.
///
/// The connection is aborted if it does not complete the TLS handshake and establish
/// the relay protocol WebSocket stream within this timeout.
const ESTABLISH_TIMEOUT: Duration = Duration::from_secs(30);
73
74/// Derives the accept key for WebSocket handshake according to RFC 6455.
75/// Takes the client's Sec-WebSocket-Key value and returns the calculated accept key.
76fn derive_accept_key(client_key: &HeaderValue) -> String {
77    use sha1::Digest;
78
79    let mut sha1 = sha1::Sha1::new();
80    sha1.update(client_key.as_bytes());
81    sha1.update(SEC_WEBSOCKET_ACCEPT_GUID);
82    data_encoding::BASE64.encode(&sha1.finalize())
83}
84
85/// Creates a new [`BytesBody`] with given content.
86fn body_full(content: impl Into<hyper::body::Bytes>) -> BytesBody {
87    http_body_util::Full::new(content.into())
88}
89
90#[allow(clippy::result_large_err)]
91fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes), ConnectionHandlerError> {
92    match upgraded.downcast::<hyper_util::rt::TokioIo<MaybeTlsStream>>() {
93        Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)),
94        Err(_) => Err(e!(ConnectionHandlerError::DowncastUpgrade)),
95    }
96}
97
98/// The Relay HTTP server.
99///
100/// A running HTTP server serving the relay endpoint and optionally a number of additional
101/// HTTP services added with [`ServerBuilder::request_handler`].  If configured using
102/// [`ServerBuilder::tls_config`] the server will handle TLS as well.
103///
104/// Created using [`ServerBuilder::spawn`].
105#[derive(Debug)]
106pub(super) struct Server {
107    addr: SocketAddr,
108    http_server_task: AbortOnDropHandle<()>,
109    cancel_server_loop: CancellationToken,
110}
111
112impl Server {
113    /// Returns a handle for this server.
114    ///
115    /// The server runs in the background as several async tasks.  This allows controlling
116    /// the server, in particular it allows gracefully shutting down the server.
117    pub(super) fn handle(&self) -> ServerHandle {
118        ServerHandle {
119            cancel_token: self.cancel_server_loop.clone(),
120        }
121    }
122
123    /// Closes the underlying relay server and the HTTP(S) server tasks.
124    pub(super) fn shutdown(&self) {
125        self.cancel_server_loop.cancel();
126    }
127
128    /// Returns the [`AbortOnDropHandle`] for the supervisor task managing the server.
129    ///
130    /// This is the root of all the tasks for the server.  Aborting it will abort all the
131    /// other tasks for the server.  Awaiting it will complete when all the server tasks are
132    /// completed.
133    pub(super) fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
134        &mut self.http_server_task
135    }
136
137    /// Returns the local address of this server.
138    pub(super) fn addr(&self) -> SocketAddr {
139        self.addr
140    }
141}
142
143/// A handle for the [`Server`].
144///
145/// This does not allow access to the task but can communicate with it.
146#[derive(Debug, Clone)]
147pub(super) struct ServerHandle {
148    cancel_token: CancellationToken,
149}
150
151impl ServerHandle {
152    /// Gracefully shut down the server.
153    pub(super) fn shutdown(&self) {
154        self.cancel_token.cancel()
155    }
156}
157
158/// Configuration to use for the TLS connection
159///
160/// This struct wraps a rustls server configuration and TLS acceptor for use with
161/// [`RelayService::handle_connection`].
162///
163/// # Example
164///
165/// ```
166/// use std::sync::Arc;
167///
168/// use iroh_relay::server::http_server::TlsConfig;
169/// use rustls::ServerConfig;
170/// use webpki_types::{CertificateDer, PrivateKeyDer};
171///
172/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
173/// // Set ring as the process-level default crypto provider
174/// rustls::crypto::ring::default_provider()
175///     .install_default()
176///     .ok();
177/// // Generate a self-signed certificate for testing
178/// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
179/// let cert_der = cert.cert.der().to_vec();
180/// let private_key_der = cert.signing_key.serialize_der();
181///
182/// // Create rustls types
183/// let cert_chain = vec![CertificateDer::from(cert_der)];
184/// let private_key = PrivateKeyDer::try_from(private_key_der)?;
185///
186/// // Create a rustls ServerConfig
187/// let server_config = Arc::new(
188///     ServerConfig::builder()
189///         .with_no_client_auth()
190///         .with_single_cert(cert_chain, private_key)?,
191/// );
192///
193/// // Create TlsConfig for use with RelayService
194/// let tls_config = TlsConfig::new(server_config);
195/// # Ok(())
196/// # }
197/// ```
198#[derive(Debug, Clone)]
199pub struct TlsConfig {
200    /// The server config
201    pub(super) config: Arc<rustls::ServerConfig>,
202    /// The kind
203    pub(super) acceptor: TlsAcceptor,
204}
205
206impl TlsConfig {
207    /// Creates a new `TlsConfig` from a rustls `ServerConfig`.
208    ///
209    /// This creates a manual TLS acceptor using the provided server configuration.
210    /// The acceptor will handle TLS handshakes for incoming connections.
211    ///
212    /// # Example
213    ///
214    /// ```
215    /// use std::sync::Arc;
216    ///
217    /// use iroh_relay::server::http_server::TlsConfig;
218    /// use rustls::ServerConfig;
219    /// use webpki_types::{CertificateDer, PrivateKeyDer};
220    ///
221    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
222    /// // Set ring as the process-level default crypto provider
223    /// rustls::crypto::ring::default_provider()
224    ///     .install_default()
225    ///     .ok();
226    /// // Generate a self-signed certificate for testing
227    /// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
228    /// let cert_der = cert.cert.der().to_vec();
229    /// let private_key_der = cert.signing_key.serialize_der();
230    ///
231    /// // Create rustls types
232    /// let cert_chain = vec![CertificateDer::from(cert_der)];
233    /// let private_key = PrivateKeyDer::try_from(private_key_der)?;
234    ///
235    /// let server_config = Arc::new(
236    ///     ServerConfig::builder()
237    ///         .with_no_client_auth()
238    ///         .with_single_cert(cert_chain, private_key)?,
239    /// );
240    ///
241    /// let tls_config = TlsConfig::new(server_config);
242    /// # Ok(())
243    /// # }
244    /// ```
245    pub fn new(config: Arc<rustls::ServerConfig>) -> Self {
246        let acceptor = tokio_rustls::TlsAcceptor::from(config.clone());
247        Self {
248            config,
249            acceptor: TlsAcceptor::Manual(acceptor),
250        }
251    }
252}
253
254/// Errors when attempting to upgrade and
255#[allow(missing_docs)]
256#[stack_error(derive, add_meta)]
257#[non_exhaustive]
258pub enum ServeConnectionError {
259    #[error("TLS[acme] handshake")]
260    TlsHandshake {
261        #[error(std_err)]
262        source: std::io::Error,
263    },
264    #[error("TLS[acme] serve connection")]
265    ServeConnection {
266        #[error(std_err)]
267        source: hyper::Error,
268    },
269    #[error("TLS[manual] accept")]
270    ManualAccept {
271        #[error(std_err)]
272        source: std::io::Error,
273    },
274    #[error("TLS[acme] accept")]
275    LetsEncryptAccept {
276        #[error(std_err)]
277        source: std::io::Error,
278    },
279    #[error("HTTPS connection")]
280    Https {
281        #[error(std_err)]
282        source: hyper::Error,
283    },
284    #[error("HTTP connection")]
285    Http {
286        #[error(std_err)]
287        source: hyper::Error,
288    },
289    #[error("Connection did not reach established state within timeout")]
290    EstablishTimeout,
291}
292
293/// Server accept errors.
294#[allow(missing_docs)]
295#[stack_error(derive, add_meta, from_sources)]
296#[non_exhaustive]
297pub enum AcceptError {
298    #[error(transparent)]
299    Handshake { source: handshake::Error },
300    #[error("rate limiting misconfigured")]
301    RateLimitingMisconfigured { source: InvalidBucketConfig },
302}
303
304/// Server connection errors, includes errors that can happen on `accept`.
305#[allow(missing_docs)]
306#[stack_error(derive, add_meta, from_sources)]
307#[non_exhaustive]
308pub enum ConnectionHandlerError {
309    #[error(transparent)]
310    Accept { source: AcceptError },
311    #[error("Could not downcast the upgraded connection to MaybeTlsStream")]
312    DowncastUpgrade {},
313    #[error("Cannot deal with buffered data yet: {buf:?}")]
314    BufferNotEmpty { buf: Bytes },
315}
316
317/// Builder for the Relay HTTP Server.
318///
319/// Defaults to handling relay requests on the "/relay" (and "/derp" for backwards compatibility) endpoint.
320/// Other HTTP endpoints can be added using [`ServerBuilder::request_handler`].
321#[derive(derive_more::Debug)]
322pub(super) struct ServerBuilder {
323    /// The ip + port combination for this server.
324    addr: SocketAddr,
325    /// Optional tls configuration/TlsAcceptor combination.
326    ///
327    /// When `None`, the server will serve HTTP, otherwise it will serve HTTPS.
328    tls_config: Option<TlsConfig>,
329    /// A map of request handlers to routes.
330    ///
331    /// Used when certain routes in your server should be made available at the same port as
332    /// the relay server, and so must be handled along side requests to the relay endpoint.
333    handlers: Handlers,
334    /// Headers to use for HTTP responses.
335    headers: HeaderMap,
336    /// Rate-limiting configuration for an individual client connection.
337    ///
338    /// Rate-limiting is enforced on received traffic from individual clients.  This
339    /// configuration applies to a single client connection.
340    client_rx_ratelimit: Option<ClientRateLimit>,
341    /// The capacity of the key cache.
342    key_cache_capacity: usize,
343    /// Access config for endpoints.
344    access: AccessConfig,
345    metrics: Option<Arc<Metrics>>,
346    establish_timeout: Duration,
347}
348
349impl ServerBuilder {
350    /// Creates a new [ServerBuilder].
351    pub(super) fn new(addr: SocketAddr) -> Self {
352        Self {
353            addr,
354            tls_config: None,
355            handlers: Default::default(),
356            headers: HeaderMap::new(),
357            client_rx_ratelimit: None,
358            key_cache_capacity: DEFAULT_KEY_CACHE_CAPACITY,
359            access: AccessConfig::Everyone,
360            metrics: None,
361            establish_timeout: ESTABLISH_TIMEOUT,
362        }
363    }
364
365    /// Sets the metrics collector.
366    pub(super) fn metrics(mut self, metrics: Arc<Metrics>) -> Self {
367        self.metrics = Some(metrics);
368        self
369    }
370
371    /// Set the access configuration.
372    pub(super) fn access(mut self, access: AccessConfig) -> Self {
373        self.access = access;
374        self
375    }
376
377    /// Serves all requests content using TLS.
378    pub(super) fn tls_config(mut self, config: Option<TlsConfig>) -> Self {
379        self.tls_config = config;
380        self
381    }
382
383    /// Sets the timeout after which connections are aborted if they don't become fully established.
384    ///
385    /// The timeout is started immediately after a TCP connection comes in, and cleared once
386    /// the connection has finished the TLS handshake and fully processed the WebSocket request
387    /// to initiate the relay protocol. If the timeout expires before being cleared, the
388    /// connection is aborted.
389    ///
390    /// Defaults to 30s.
391    #[cfg(test)]
392    pub(super) fn establish_timeout(mut self, timeout: Duration) -> Self {
393        self.establish_timeout = timeout;
394        self
395    }
396
397    /// Sets the per-client rate-limit configuration for incoming data.
398    ///
399    /// On each client connection the incoming data is rate-limited.  By default
400    /// no rate limit is enforced.
401    pub(super) fn client_rx_ratelimit(mut self, config: ClientRateLimit) -> Self {
402        self.client_rx_ratelimit = Some(config);
403        self
404    }
405
406    /// Adds a custom handler for a specific Method & URI.
407    pub(super) fn request_handler(
408        mut self,
409        method: Method,
410        uri_path: &'static str,
411        handler: HyperHandler,
412    ) -> Self {
413        self.handlers.insert((method, uri_path), handler);
414        self
415    }
416
417    /// Adds HTTP headers to responses.
418    pub(super) fn headers(mut self, headers: HeaderMap) -> Self {
419        for (k, v) in headers.iter() {
420            self.headers.insert(k.clone(), v.clone());
421        }
422        self
423    }
424
425    /// Set the capacity of the cache for public keys.
426    pub fn key_cache_capacity(mut self, capacity: usize) -> Self {
427        self.key_cache_capacity = capacity;
428        self
429    }
430
431    /// Builds and spawns an HTTP(S) Relay Server.
432    pub(super) async fn spawn(self) -> Result<Server, SpawnError> {
433        let cancel_token = CancellationToken::new();
434
435        let service = RelayService::new(
436            self.handlers,
437            self.headers,
438            self.client_rx_ratelimit,
439            KeyCache::new(self.key_cache_capacity),
440            self.access,
441            self.metrics.unwrap_or_default(),
442        );
443
444        let addr = self.addr;
445        let tls_config = self.tls_config;
446
447        // Bind a TCP listener on `addr` and handles content using HTTPS.
448
449        let listener = TcpListener::bind(&addr)
450            .await
451            .map_err(|err| e!(super::SpawnError::BindTcpListener { addr }, err))?;
452
453        let addr = listener
454            .local_addr()
455            .map_err(|err| e!(super::SpawnError::NoLocalAddr, err))?;
456        let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS");
457        info!("[{http_str}] relay: serving on {addr}");
458
459        let cancel = cancel_token.clone();
460        let task = tokio::task::spawn(
461            async move {
462                // create a join set to track all our connection tasks
463                let mut set = tokio::task::JoinSet::new();
464                loop {
465                    tokio::select! {
466                        biased;
467                        _ = cancel.cancelled() => {
468                            break;
469                        }
470                        Some(res) = set.join_next() => {
471                            if let Err(err) = res
472                                && err.is_panic()
473                            {
474                                panic!("task panicked: {err:#?}");
475                            }
476                        }
477                        res = listener.accept() => match res {
478                            Ok((stream, peer_addr)) => {
479                                debug!("connection opened from {peer_addr}");
480                                let tls_config = tls_config.clone();
481                                let service = service.clone();
482                                // spawn a task to handle the connection
483                                set.spawn(async move {
484                                    service
485                                        .handle_connection(stream, tls_config, self.establish_timeout)
486                                        .await
487                                }.instrument(info_span!("conn", peer = %peer_addr)));
488                            }
489                            Err(err) => {
490                                error!("failed to accept connection: {err}");
491                            }
492                        }
493                    }
494                }
495                service.shutdown().await;
496                set.shutdown().await;
497                debug!("server has been shutdown.");
498            }
499            .instrument(info_span!("relay-http-serve")),
500        );
501
502        Ok(Server {
503            addr,
504            http_server_task: AbortOnDropHandle::new(task),
505            cancel_server_loop: cancel_token,
506        })
507    }
508}
509
510/// The hyper Service that serves the actual relay endpoints.
511///
512/// This service can be used standalone or embedded into an existing HTTP server.
513#[derive(Clone, Debug)]
514pub struct RelayService(Arc<Inner>);
515
516#[derive(Debug)]
517struct Inner {
518    handlers: Handlers,
519    headers: HeaderMap,
520    clients: Clients,
521    write_timeout: Duration,
522    rate_limit: Option<ClientRateLimit>,
523    key_cache: KeyCache,
524    access: AccessConfig,
525    metrics: Arc<Metrics>,
526}
527
528#[stack_error(derive, add_meta)]
529enum RelayUpgradeReqError {
530    #[error("missing header: {header}")]
531    MissingHeader { header: http::HeaderName },
532    #[error("invalid header value for {header}: {details}")]
533    InvalidHeader {
534        header: http::HeaderName,
535        details: String,
536    },
537    #[error(
538        "invalid header value for {SEC_WEBSOCKET_VERSION}: unsupported websocket version, only supporting {SUPPORTED_WEBSOCKET_VERSION}"
539    )]
540    UnsupportedWebsocketVersion,
541    #[error(
542        "invalid header value for {SEC_WEBSOCKET_PROTOCOL}: unsupported relay version: we support {we_support} but you only provide {you_support}"
543    )]
544    UnsupportedRelayVersion {
545        we_support: String,
546        you_support: String,
547    },
548}
549
550impl RelayServiceWithNotify {
551    fn build_response(&self) -> http::response::Builder {
552        let mut res = Response::builder();
553        for (key, value) in self.service.0.headers.iter() {
554            res = res.header(key, value);
555        }
556        res
557    }
558
559    /// Upgrades the HTTP connection to the relay protocol, runs relay client.
560    fn handle_relay_ws_upgrade(
561        &self,
562        mut req: Request<Incoming>,
563    ) -> Result<Response<BytesBody>, RelayUpgradeReqError> {
564        fn expect_header(
565            req: &Request<Incoming>,
566            header: http::HeaderName,
567        ) -> Result<&HeaderValue, RelayUpgradeReqError> {
568            req.headers()
569                .get(&header)
570                .ok_or_else(|| e!(RelayUpgradeReqError::MissingHeader { header }))
571        }
572
573        let upgrade_header = expect_header(&req, UPGRADE)?;
574        ensure!(
575            upgrade_header == HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
576            RelayUpgradeReqError::InvalidHeader {
577                header: UPGRADE,
578                details: format!("value must be {WEBSOCKET_UPGRADE_PROTOCOL}")
579            }
580        );
581
582        let key = expect_header(&req, SEC_WEBSOCKET_KEY)?.clone();
583        let version = expect_header(&req, SEC_WEBSOCKET_VERSION)?.clone();
584
585        ensure!(
586            version.as_bytes() == SUPPORTED_WEBSOCKET_VERSION.as_bytes(),
587            RelayUpgradeReqError::UnsupportedWebsocketVersion
588        );
589
590        let subprotocols = expect_header(&req, SEC_WEBSOCKET_PROTOCOL)?
591            .to_str()
592            .ok()
593            .ok_or_else(|| {
594                e!(RelayUpgradeReqError::InvalidHeader {
595                    header: SEC_WEBSOCKET_PROTOCOL,
596                    details: "header value is not ascii".to_string()
597                })
598            })?;
599        let protocol_version = subprotocols
600            .split(",")
601            .map(|s| s.trim())
602            .filter_map(ProtocolVersion::match_from_str)
603            .max()
604            .ok_or_else(|| {
605                e!(RelayUpgradeReqError::UnsupportedRelayVersion {
606                    we_support: ProtocolVersion::all_joined(),
607                    you_support: subprotocols.to_string()
608                })
609            })?;
610
611        let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned();
612
613        // Setup a future that will eventually receive the upgraded
614        // connection and talk a new protocol, and spawn the future
615        // into the runtime.
616        //
617        // Note: This can't possibly be fulfilled until the 101 response
618        // is returned below, so it's better to spawn this future instead
619        // waiting for it to complete to then return a response.
620        tokio::task::spawn({
621            let this = self.clone();
622            async move {
623                match hyper::upgrade::on(&mut req).await {
624                    Ok(upgraded) => {
625                        if let Err(err) = this
626                            .service
627                            .0
628                            .relay_connection_handler(
629                                upgraded,
630                                client_auth_header,
631                                protocol_version,
632                            )
633                            .await
634                        {
635                            warn!("error accepting upgraded connection: {err:#}",);
636                        } else {
637                            // We have passed the connection to the relay protocol handler,
638                            // thus we trigger the on_establish notification so that timeouts
639                            // on the upper layer will be cleared.
640                            this.on_establish.notify_waiters();
641                            debug!("upgraded connection completed");
642                        };
643                    }
644                    Err(err) => warn!("upgrade error: {err:#}"),
645                }
646            }
647            .instrument(warn_span!("handler"))
648        });
649
650        // Now return a 101 Response saying we agree to the upgrade to the
651        // websocket upgrade protocol
652        Ok(self
653            .build_response()
654            .status(StatusCode::SWITCHING_PROTOCOLS)
655            .header(
656                UPGRADE,
657                HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
658            )
659            .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key))
660            .header(SEC_WEBSOCKET_PROTOCOL, protocol_version.to_header_value())
661            .header(CONNECTION, "upgrade")
662            .body(body_full("switching to websocket protocol"))
663            .expect("valid body"))
664    }
665}
666
667/// Combines [`RelayService`] with a notification token.
668///
669/// This struct implements [`Service`].
670///
671/// The notification token is triggered once the relay connection is fully established.
672#[derive(Debug, Clone)]
673pub struct RelayServiceWithNotify {
674    service: RelayService,
675    on_establish: Arc<Notify>,
676}
677
678impl RelayServiceWithNotify {
679    /// Creates a new service wrapper for a connection.
680    ///
681    /// The `on_establish` notification is triggered once the connection is passed to the
682    /// relay protocol, i.e. after a WebSocket request on /relay is received and established.
683    pub fn new(service: RelayService, on_establish: Arc<Notify>) -> Self {
684        Self {
685            service,
686            on_establish,
687        }
688    }
689}
690
691impl Service<Request<Incoming>> for RelayServiceWithNotify {
692    type Response = Response<BytesBody>;
693    type Error = HyperError;
694    type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
695
696    fn call(&self, req: Request<Incoming>) -> Self::Future {
697        // Create a client if the request hits the relay endpoint.
698        if matches!(
699            (req.method(), req.uri().path()),
700            (&hyper::Method::GET, RELAY_PATH)
701        ) {
702            let response = match self.handle_relay_ws_upgrade(req) {
703                Ok(response) => Ok(response),
704                // It's convention to send back the version(s) we *do* support
705                Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion { .. }) => self
706                    .build_response()
707                    .status(StatusCode::BAD_REQUEST)
708                    .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION)
709                    .body(body_full(e.to_string())),
710                Err(e) => self
711                    .build_response()
712                    .status(StatusCode::BAD_REQUEST)
713                    .body(body_full(e.to_string())),
714            }
715            .map_err(Into::into);
716            return std::future::ready(response);
717        }
718        // Otherwise handle the relay connection as normal.
719
720        // Check all other possible endpoints.
721        let uri = req.uri().clone();
722        if let Some(handler) = self
723            .service
724            .0
725            .handlers
726            .get(&(req.method().clone(), uri.path()))
727        {
728            let response = handler(req, self.service.0.default_response());
729            return std::future::ready(response);
730        }
731
732        // Otherwise return 404
733        let response = self
734            .service
735            .0
736            .not_found_fn(req, self.service.0.default_response());
737        std::future::ready(response)
738    }
739}
740
741impl Inner {
742    fn default_response(&self) -> ResponseBuilder {
743        let mut response = Response::builder();
744        for (key, value) in self.headers.iter() {
745            response = response.header(key.clone(), value.clone());
746        }
747        response
748    }
749
750    fn not_found_fn(
751        &self,
752        _req: Request<Incoming>,
753        mut res: ResponseBuilder,
754    ) -> HyperResult<Response<BytesBody>> {
755        for (k, v) in self.headers.iter() {
756            res = res.header(k.clone(), v.clone());
757        }
758        let body = body_full("Not Found");
759        let r = res.status(StatusCode::NOT_FOUND).body(body)?;
760        HyperResult::Ok(r)
761    }
762
763    /// The server HTTP handler to do HTTP upgrades.
764    ///
765    /// This handler runs while doing the connection upgrade handshake.  Once the connection
766    /// is upgraded it sends the stream to the relay server which takes it over.  After
767    /// having sent off the connection this handler returns.
768    async fn relay_connection_handler(
769        &self,
770        upgraded: Upgraded,
771        client_auth_header: Option<HeaderValue>,
772        protocol_version: ProtocolVersion,
773    ) -> Result<(), ConnectionHandlerError> {
774        debug!("relay_connection upgraded");
775        let (io, read_buf) = downcast_upgrade(upgraded)?;
776        if !read_buf.is_empty() {
777            return Err(e!(ConnectionHandlerError::BufferNotEmpty { buf: read_buf }));
778        }
779
780        self.accept(io, client_auth_header, protocol_version)
781            .await?;
782        Ok(())
783    }
784
785    /// Adds a new connection to the server and serves it.
786    ///
787    /// Will error if it takes too long (10 sec) to write or read to the connection, if there is
788    /// some read or write error to the connection,  if the server is meant to verify clients,
789    /// and is unable to verify this one, or if there is some issue communicating with the server.
790    ///
791    /// The provided [`AsyncRead`] and [`AsyncWrite`] must be already connected to the connection.
792    ///
793    /// [`AsyncRead`]: tokio::io::AsyncRead
794    /// [`AsyncWrite`]: tokio::io::AsyncWrite
795    async fn accept(
796        &self,
797        io: MaybeTlsStream,
798        client_auth_header: Option<HeaderValue>,
799        protocol_version: ProtocolVersion,
800    ) -> Result<(), AcceptError> {
801        trace!("accept: start");
802
803        // Set the socket to NO_DELAY.
804        io.disable_nagle();
805
806        let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone())
807            .map_err(|err| e!(AcceptError::RateLimitingMisconfigured, err))?;
808
809        // Create a server builder with default config
810        let websocket = tokio_websockets::ServerBuilder::new()
811            .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)))
812            // Serve will create a WebSocketStream on an already upgraded connection
813            .serve(io);
814
815        let mut io = WsBytesFramed { io: websocket };
816
817        let authentication = handshake::serverside(&mut io, client_auth_header).await?;
818
819        trace!(?authentication.mechanism, "accept: verified authentication");
820
821        let is_authorized = self.access.is_allowed(authentication.client_key).await;
822        let client_key = authentication.authorize_if(is_authorized, &mut io).await?;
823
824        trace!("accept: verified authorization");
825
826        let io = RelayedStream {
827            inner: io,
828            key_cache: self.key_cache.clone(),
829        };
830
831        trace!("accept: build client conn");
832        let client_conn_builder = Config {
833            endpoint_id: client_key,
834            stream: io,
835            write_timeout: self.write_timeout,
836            channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH,
837            protocol_version,
838        };
839        trace!("accept: create client");
840        let endpoint_id = client_conn_builder.endpoint_id;
841        trace!(endpoint_id = %endpoint_id.fmt_short(), "create client");
842
843        // build and register client, starting up read & write loops for the client
844        // connection
845        self.clients
846            .register(client_conn_builder, self.metrics.clone());
847        Ok(())
848    }
849}
850
/// TLS Certificate Authority acceptor.
///
/// Selects how an incoming TCP stream is upgraded to TLS. Both wrapped acceptor types
/// are rendered as a fixed type-name string in `Debug` output instead of their contents.
#[derive(Clone, derive_more::Debug)]
pub(super) enum TlsAcceptor {
    /// Uses Let's Encrypt as the Certificate Authority. This is used in production.
    ///
    /// Accepting through this variant may yield no stream at all when the incoming
    /// connection was a TLS-ALPN-01 validation request.
    LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor),
    /// Manually added tls acceptor. Generally used for tests or for when we've passed in
    /// a certificate via a file.
    Manual(#[debug("tokio_rustls::TlsAcceptor")] tokio_rustls::TlsAcceptor),
}
860
861impl RelayService {
862    /// Creates a new RelayService.
863    ///
864    /// This allows embedding the relay service into an existing HTTP server.
865    pub fn new(
866        handlers: Handlers,
867        headers: HeaderMap,
868        rate_limit: Option<ClientRateLimit>,
869        key_cache: KeyCache,
870        access: AccessConfig,
871        metrics: Arc<Metrics>,
872    ) -> Self {
873        Self(Arc::new(Inner {
874            handlers,
875            headers,
876            clients: Clients::default(),
877            write_timeout: SERVER_WRITE_TIMEOUT,
878            rate_limit,
879            key_cache,
880            access,
881            metrics,
882        }))
883    }
884
885    /// Shuts down the relay service, disconnecting all clients.
886    pub async fn shutdown(&self) {
887        self.0.clients.shutdown().await;
888    }
889
890    /// Handle the incoming connection.
891    ///
892    /// If a `tls_config` is given, will serve the connection using HTTPS, otherwise HTTP.
893    ///
894    /// If the connection did not fully upgrade to a relay WebSocket connection after
895    /// `establish_timeout`, the connection is aborted.
896    ///
897    /// # Example
898    ///
899    /// ```no_run
900    /// # use std::{sync::Arc, time::Duration};
901    /// # use tokio::net::TcpStream;
902    /// # use http::HeaderMap;
903    /// # use iroh_relay::server::http_server::{Handlers, RelayService, TlsConfig};
904    /// # use iroh_relay::{KeyCache, server::{AccessConfig, Metrics}};
905    /// # use webpki_types::{CertificateDer, PrivateKeyDer};
906    /// # async fn example(stream: TcpStream) -> Result<(), Box<dyn std::error::Error>> {
907    /// // Create a relay service
908    /// let handlers = Handlers::default();
909    /// let headers = HeaderMap::new();
910    /// let key_cache = KeyCache::new(1024);
911    /// let metrics = Arc::new(Metrics::default());
912    /// let relay_service = RelayService::new(
913    ///     handlers,
914    ///     headers,
915    ///     None, // No rate limiting
916    ///     key_cache,
917    ///     AccessConfig::Everyone,
918    ///     metrics,
919    /// );
920    ///
921    /// // Generate a self-signed certificate for HTTPS
922    /// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
923    /// let cert_der = cert.cert.der().to_vec();
924    /// let private_key_der = cert.signing_key.serialize_der();
925    /// let cert_chain = vec![CertificateDer::from(cert_der)];
926    /// let private_key = PrivateKeyDer::try_from(private_key_der)?;
927    ///
928    /// // Serve with HTTPS
929    /// let server_config = Arc::new(
930    ///     rustls::ServerConfig::builder()
931    ///         .with_no_client_auth()
932    ///         .with_single_cert(cert_chain, private_key)?,
933    /// );
934    /// let tls_config = TlsConfig::new(server_config);
935    /// relay_service
936    ///     .clone()
937    ///     .handle_connection(stream, Some(tls_config), Duration::from_secs(30))
938    ///     .await;
939    ///
940    /// // Or serve with plain HTTP
941    /// # let stream = TcpStream::connect("127.0.0.1:0").await?;
942    /// relay_service
943    ///     .handle_connection(stream, None, Duration::from_secs(30))
944    ///     .await;
945    /// # Ok(())
946    /// # }
947    /// ```
948    pub async fn handle_connection(
949        self,
950        stream: TcpStream,
951        tls_config: Option<TlsConfig>,
952        establish_timeout: Duration,
953    ) {
954        let metrics = self.0.metrics.clone();
955        metrics.http_connections.inc();
956        // We create a notification token to be triggered once the connection is fully established
957        // and passed to the relay server.
958        let on_establish = Arc::new(Notify::new());
959        let service = RelayServiceWithNotify::new(self, on_establish.clone());
960
961        // This is the main connection future, driving the connection to completion.
962        let serve_fut = async move {
963            match tls_config {
964                Some(tls_config) => {
965                    debug!("HTTPS: serve connection");
966                    service.tls_serve_connection(stream, tls_config).await
967                }
968                None => {
969                    debug!("HTTP: serve connection");
970                    let stream = MaybeTlsStream::Plain(stream);
971                    service.serve_connection(stream).await
972                }
973            }
974        };
975
976        // We set a timeout for the connection to limit lingering connections during establishment.
977        // The timeout is cleared once the connection has completed the TLS and WebSocket
978        // handshakes and has been passed over to the relay protocol handler.
979        // If the timeout expires before that happens, the connection is aborted.
980        let res = clearable_timeout(establish_timeout, on_establish, serve_fut)
981            .await
982            .map_err(|_elapsed| e!(ServeConnectionError::EstablishTimeout))
983            .flatten();
984
985        metrics.http_connections_closed.inc();
986
987        if let Err(error) = res {
988            match error {
989                ServeConnectionError::ManualAccept { source, .. }
990                | ServeConnectionError::LetsEncryptAccept { source, .. }
991                    if source.kind() == std::io::ErrorKind::UnexpectedEof =>
992                {
993                    debug!(reason=?source, "peer disconnected");
994                }
995                // From hyper: <https://github.com/hyperium/hyper/commit/271bba16672ff54a44e043c5cc1ae6b9345bb172>
996                // `hyper::Error::IncompleteMessage` is hyper's equivalent of UnexpectedEof
997                ServeConnectionError::Https { source, .. }
998                | ServeConnectionError::Http { source, .. }
999                    if source.is_incomplete_message() =>
1000                {
1001                    debug!(reason=?source, "peer disconnected");
1002                }
1003                _ => {
1004                    metrics.http_connections_errored.inc();
1005                    error!(?error, "failed to handle connection");
1006                }
1007            }
1008        }
1009    }
1010}
1011
1012impl RelayServiceWithNotify {
1013    /// Serves a TLS connection.
1014    async fn tls_serve_connection(
1015        self,
1016        stream: TcpStream,
1017        tls_config: TlsConfig,
1018    ) -> Result<(), ServeConnectionError> {
1019        let TlsConfig { acceptor, config } = tls_config;
1020        let stream = match acceptor {
1021            TlsAcceptor::LetsEncrypt(a) => {
1022                match a
1023                    .accept(stream)
1024                    .await
1025                    .map_err(|err| e!(ServeConnectionError::LetsEncryptAccept, err))?
1026                {
1027                    None => {
1028                        info!("TLS[acme]: received TLS-ALPN-01 validation request");
1029                        return Ok(());
1030                    }
1031                    Some(start_handshake) => {
1032                        debug!("TLS[acme]: start handshake");
1033                        let tls_stream = start_handshake
1034                            .into_stream(config)
1035                            .await
1036                            .map_err(|err| e!(ServeConnectionError::TlsHandshake, err))?;
1037                        MaybeTlsStream::Tls(tls_stream)
1038                    }
1039                }
1040            }
1041            TlsAcceptor::Manual(a) => {
1042                debug!("TLS[manual]: accept");
1043                let tls_stream = a
1044                    .accept(stream)
1045                    .await
1046                    .map_err(|err| e!(ServeConnectionError::ManualAccept, err))?;
1047                MaybeTlsStream::Tls(tls_stream)
1048            }
1049        };
1050        self.serve_connection(stream).await
1051    }
1052
1053    /// Wrapper for the actual http connection (with upgrades)
1054    async fn serve_connection<I>(self, io: I) -> Result<(), ServeConnectionError>
1055    where
1056        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
1057    {
1058        hyper::server::conn::http1::Builder::new()
1059            .serve_connection(hyper_util::rt::TokioIo::new(io), self)
1060            .with_upgrades()
1061            .await
1062            .map_err(|err| e!(ServeConnectionError::ServeConnection, err))
1063    }
1064}
1065
/// A collection of HTTP request handlers for custom endpoints.
///
/// Handlers are stored in a map keyed by `(Method, &'static str)`. The `Deref` and
/// `DerefMut` impls expose that map directly, so entries are looked up and registered
/// through the regular [`HashMap`] API.
#[derive(Default)]
pub struct Handlers(HashMap<(Method, &'static str), HyperHandler>);
1069
1070impl std::fmt::Debug for Handlers {
1071    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1072        let s = self.0.keys().fold(String::new(), |curr, next| {
1073            let (method, uri) = next;
1074            format!("{curr}\n({method},{uri}): Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>")
1075        });
1076        write!(f, "HashMap<{s}>")
1077    }
1078}
1079
1080impl std::ops::Deref for Handlers {
1081    type Target = HashMap<(Method, &'static str), HyperHandler>;
1082
1083    fn deref(&self) -> &Self::Target {
1084        &self.0
1085    }
1086}
1087
1088impl std::ops::DerefMut for Handlers {
1089    fn deref_mut(&mut self) -> &mut Self::Target {
1090        &mut self.0
1091    }
1092}
1093
1094/// Requires a future to complete before the specified duration elapses, unless the timeout is cleared.
1095///
1096/// If the future completes before the duration has elapsed, then the completed value is returned.
1097/// Otherwise, an error is returned and the future is canceled.
1098///
1099/// If `clear_timeout` is triggered, the timeout is cleared and the future is always run to completion.
1100async fn clearable_timeout<F: Future>(
1101    timeout: Duration,
1102    clear_timeout: Arc<Notify>,
1103    fut: F,
1104) -> Result<F::Output, Elapsed> {
1105    tokio::pin!(fut);
1106    let timeout = MaybeFuture::Some(tokio::time::sleep(timeout));
1107    tokio::pin!(timeout);
1108    loop {
1109        tokio::select! {
1110            biased;
1111            res = &mut fut => {
1112                return Ok(res);
1113            }
1114            _ = clear_timeout.notified() => {
1115                timeout.as_mut().set_none();
1116            },
1117            _ = &mut timeout => {
1118                return Err(Elapsed);
1119            }
1120        }
1121    }
1122}
1123
/// Error returned by `clearable_timeout` when the timeout fires before the future completes.
#[stack_error(derive)]
#[error("Timeout elapsed")]
struct Elapsed;
1127
1128#[cfg(test)]
1129mod tests {
1130    use std::sync::Arc;
1131
1132    use iroh_base::{PublicKey, SecretKey};
1133    use n0_error::{Result, StdResultExt, bail_any};
1134    use n0_future::{SinkExt, StreamExt};
1135    use n0_tracing_test::traced_test;
1136    use rand::{RngExt, SeedableRng};
1137    use reqwest::Url;
1138    use tokio::io::{AsyncReadExt, AsyncWriteExt};
1139    use tracing::info;
1140
1141    use super::*;
1142    use crate::{
1143        client::{Client, ClientBuilder, ConnectError, conn::Conn},
1144        dns::DnsResolver,
1145        protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg},
1146        tls::{CaRootsConfig, default_provider},
1147    };
1148
1149    pub(crate) fn make_tls_config() -> TlsConfig {
1150        let subject_alt_names = vec!["localhost".to_string()];
1151
1152        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
1153        let rustls_certificate = cert.cert.der().clone();
1154        let rustls_key =
1155            rustls::pki_types::PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
1156        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
1157            rustls::crypto::ring::default_provider(),
1158        ))
1159        .with_safe_default_protocol_versions()
1160        .expect("protocols supported by ring")
1161        .with_no_client_auth()
1162        .with_single_cert(vec![(rustls_certificate)], rustls_key.into())
1163        .expect("cert is right");
1164
1165        TlsConfig::new(Arc::new(config))
1166    }
1167
1168    #[tokio::test]
1169    #[traced_test]
1170    async fn test_http_clients_and_server() -> Result {
1171        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
1172
1173        let a_key = SecretKey::from_bytes(&rng.random());
1174        let b_key = SecretKey::from_bytes(&rng.random());
1175
1176        // start server
1177        let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
1178            .spawn()
1179            .await?;
1180
1181        let addr = server.addr();
1182
1183        // get dial info
1184        let port = addr.port();
1185        let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
1186            ipv4_addr
1187        } else {
1188            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
1189        };
1190
1191        info!("addr: {addr}:{port}");
1192        let relay_addr: Url = format!("http://{addr}:{port}").parse().unwrap();
1193
1194        // create clients
1195        let (a_key, mut client_a) = create_test_client(a_key, relay_addr.clone()).await?;
1196        info!("created client {a_key:?}");
1197        let (b_key, mut client_b) = create_test_client(b_key, relay_addr).await?;
1198        info!("created client {b_key:?}");
1199
1200        info!("ping a");
1201        client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?;
1202        let pong = client_a.next().await.expect("eos")?;
1203        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));
1204
1205        info!("ping b");
1206        client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?;
1207        let pong = client_b.next().await.expect("eos")?;
1208        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));
1209
1210        info!("sending message from a to b");
1211        let msg = Datagrams::from(b"hi there, client b!");
1212        client_a
1213            .send(ClientToRelayMsg::Datagrams {
1214                dst_endpoint_id: b_key,
1215                datagrams: msg.clone(),
1216            })
1217            .await?;
1218        info!("waiting for message from a on b");
1219        let (got_key, got_msg) =
1220            process_msg(client_b.next().await).expect("expected message from client_a");
1221        assert_eq!(a_key, got_key);
1222        assert_eq!(msg, got_msg);
1223
1224        info!("sending message from b to a");
1225        let msg = Datagrams::from(b"right back at ya, client b!");
1226        client_b
1227            .send(ClientToRelayMsg::Datagrams {
1228                dst_endpoint_id: a_key,
1229                datagrams: msg.clone(),
1230            })
1231            .await?;
1232        info!("waiting for message b on a");
1233        let (got_key, got_msg) =
1234            process_msg(client_a.next().await).expect("expected message from client_b");
1235        assert_eq!(b_key, got_key);
1236        assert_eq!(msg, got_msg);
1237
1238        // Close before shutting down, otherwise we'll try to send close frames on broken pipes
1239        client_a.close().await?;
1240        client_b.close().await?;
1241        server.shutdown();
1242
1243        Ok(())
1244    }
1245
1246    async fn create_test_client(
1247        key: SecretKey,
1248        server_url: Url,
1249    ) -> Result<(PublicKey, Client), ConnectError> {
1250        let public_key = key.public();
1251        let client = ClientBuilder::new(server_url, key, DnsResolver::new()).tls_client_config(
1252            CaRootsConfig::insecure_skip_verify()
1253                .client_config(default_provider())
1254                .expect("infallible"),
1255        );
1256        let client = client.connect().await?;
1257
1258        Ok((public_key, client))
1259    }
1260
1261    fn process_msg(
1262        msg: Option<Result<RelayToClientMsg, crate::client::RecvError>>,
1263    ) -> Option<(PublicKey, Datagrams)> {
1264        match msg {
1265            Some(Err(e)) => {
1266                info!("client `recv` error {e}");
1267                None
1268            }
1269            Some(Ok(msg)) => {
1270                info!("got message on: {msg:?}");
1271                if let RelayToClientMsg::Datagrams {
1272                    remote_endpoint_id: source,
1273                    datagrams,
1274                } = msg
1275                {
1276                    Some((source, datagrams))
1277                } else {
1278                    None
1279                }
1280            }
1281            None => {
1282                info!("client end of stream");
1283                None
1284            }
1285        }
1286    }
1287
1288    #[tokio::test]
1289    #[traced_test]
1290    async fn test_subprotocol_negotiation_picks_latest() -> Result {
1291        let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
1292            .spawn()
1293            .await?;
1294        let addr = server.addr();
1295
1296        for offered in [
1297            "iroh-relay-v2,iroh-relay-v1",
1298            "iroh-relay-v1,iroh-relay-v2",
1299            "baz, iroh-relay-v1, iroh-relay-v2, boo",
1300            "foo, iroh-relay-v2, bar",
1301        ] {
1302            let ws_uri = format!("ws://{addr}{RELAY_PATH}");
1303            let (_stream, response) = tokio_websockets::ClientBuilder::new()
1304                .uri(&ws_uri)
1305                .expect("valid websocket URI")
1306                .add_header(
1307                    SEC_WEBSOCKET_PROTOCOL,
1308                    HeaderValue::from_str(offered).expect("valid subprotocol header value"),
1309                )
1310                .expect("header accepted by websocket client")
1311                .connect()
1312                .await
1313                .expect("websocket upgrade succeeds");
1314            let negotiated = response
1315                .headers()
1316                .get(SEC_WEBSOCKET_PROTOCOL)
1317                .expect("Sec-WebSocket-Protocol response header is present");
1318            assert_eq!(negotiated, "iroh-relay-v2", "offered={offered}");
1319        }
1320
1321        server.shutdown();
1322        Ok(())
1323    }
1324
1325    #[tokio::test]
1326    #[traced_test]
1327    async fn test_https_clients_and_server() -> Result {
1328        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
1329
1330        let a_key = SecretKey::from_bytes(&rng.random());
1331        let b_key = SecretKey::from_bytes(&rng.random());
1332
1333        // create tls_config
1334        let tls_config = make_tls_config();
1335
1336        // start server
1337        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
1338            .tls_config(Some(tls_config))
1339            .spawn()
1340            .await?;
1341
1342        let addr = server.addr();
1343
1344        // get dial info
1345        let port = addr.port();
1346        let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
1347            ipv4_addr
1348        } else {
1349            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
1350        };
1351
1352        info!("Relay listening on: {addr}:{port}");
1353
1354        let url: Url = format!("https://localhost:{port}").parse().unwrap();
1355
1356        // create clients
1357        let (a_key, mut client_a) = create_test_client(a_key, url.clone()).await?;
1358        info!("created client {a_key:?}");
1359        let (b_key, mut client_b) = create_test_client(b_key, url).await?;
1360        info!("created client {b_key:?}");
1361
1362        info!("ping a");
1363        client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?;
1364        let pong = client_a.next().await.expect("eos")?;
1365        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));
1366
1367        info!("ping b");
1368        client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?;
1369        let pong = client_b.next().await.expect("eos")?;
1370        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));
1371
1372        info!("sending message from a to b");
1373        let msg = Datagrams::from(b"hi there, client b!");
1374        client_a
1375            .send(ClientToRelayMsg::Datagrams {
1376                dst_endpoint_id: b_key,
1377                datagrams: msg.clone(),
1378            })
1379            .await?;
1380        info!("waiting for message from a on b");
1381        let (got_key, got_msg) =
1382            process_msg(client_b.next().await).expect("expected message from client_a");
1383        assert_eq!(a_key, got_key);
1384        assert_eq!(msg, got_msg);
1385
1386        info!("sending message from b to a");
1387        let msg = Datagrams::from(b"right back at ya, client b!");
1388        client_b
1389            .send(ClientToRelayMsg::Datagrams {
1390                dst_endpoint_id: a_key,
1391                datagrams: msg.clone(),
1392            })
1393            .await?;
1394        info!("waiting for message b on a");
1395        let (got_key, got_msg) =
1396            process_msg(client_a.next().await).expect("expected message from client_b");
1397        assert_eq!(b_key, got_key);
1398        assert_eq!(msg, got_msg);
1399
1400        // Close before shutting down, otherwise we'll try to send close frames on broken pipes
1401        client_a.close().await?;
1402        client_b.close().await?;
1403        server.shutdown();
1404        server.task_handle().await.std_context("join")?;
1405
1406        Ok(())
1407    }
1408
1409    async fn make_test_client(client: tokio::io::DuplexStream, key: &SecretKey) -> Result<Conn> {
1410        let client = crate::client::streams::MaybeTlsStream::Test(client);
1411        let client = tokio_websockets::ClientBuilder::new().take_over(client);
1412        let client = Conn::new(client, KeyCache::test(), key, Default::default()).await?;
1413        Ok(client)
1414    }
1415
1416    #[tokio::test]
1417    #[traced_test]
1418    async fn test_server_basic() -> Result {
1419        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
1420
1421        info!("Create the server.");
1422        let metrics = Arc::new(Metrics::default());
1423        let service = RelayService::new(
1424            Default::default(),
1425            Default::default(),
1426            None,
1427            KeyCache::test(),
1428            AccessConfig::Everyone,
1429            metrics.clone(),
1430        );
1431
1432        info!("Create client A and connect it to the server.");
1433        let key_a = SecretKey::from_bytes(&rng.random());
1434        let public_key_a = key_a.public();
1435        let (client_a, rw_a) = tokio::io::duplex(10);
1436        let s = service.clone();
1437        let handler_task = tokio::spawn(async move {
1438            s.0.accept(MaybeTlsStream::Test(rw_a), None, Default::default())
1439                .await
1440        });
1441        let mut client_a = make_test_client(client_a, &key_a).await?;
1442        handler_task.await.std_context("join")??;
1443
1444        info!("Create client B and connect it to the server.");
1445        let key_b = SecretKey::from_bytes(&rng.random());
1446        let public_key_b = key_b.public();
1447        let (client_b, rw_b) = tokio::io::duplex(10);
1448        let s = service.clone();
1449        let handler_task = tokio::spawn(async move {
1450            s.0.accept(MaybeTlsStream::Test(rw_b), None, Default::default())
1451                .await
1452        });
1453        let mut client_b = make_test_client(client_b, &key_b).await?;
1454        handler_task.await.std_context("join")??;
1455
1456        info!("Send message from A to B.");
1457        let msg = Datagrams::from(b"hello client b!!");
1458        client_a
1459            .send(ClientToRelayMsg::Datagrams {
1460                dst_endpoint_id: public_key_b,
1461                datagrams: msg.clone(),
1462            })
1463            .await?;
1464        match client_b.next().await.unwrap()? {
1465            RelayToClientMsg::Datagrams {
1466                remote_endpoint_id,
1467                datagrams,
1468            } => {
1469                assert_eq!(public_key_a, remote_endpoint_id);
1470                assert_eq!(msg, datagrams);
1471            }
1472            msg => {
1473                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1474            }
1475        }
1476
1477        info!("Send message from B to A.");
1478        let msg = Datagrams::from(b"nice to meet you client a!!");
1479        client_b
1480            .send(ClientToRelayMsg::Datagrams {
1481                dst_endpoint_id: public_key_a,
1482                datagrams: msg.clone(),
1483            })
1484            .await?;
1485        match client_a.next().await.unwrap()? {
1486            RelayToClientMsg::Datagrams {
1487                remote_endpoint_id,
1488                datagrams,
1489            } => {
1490                assert_eq!(public_key_b, remote_endpoint_id);
1491                assert_eq!(msg, datagrams);
1492            }
1493            msg => {
1494                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1495            }
1496        }
1497
1498        info!("Close the server and clients");
1499        service.shutdown().await;
1500        tokio::time::sleep(Duration::from_secs(1)).await;
1501
1502        info!("Fail to send message from A to B.");
1503        let res = client_a
1504            .send(ClientToRelayMsg::Datagrams {
1505                dst_endpoint_id: public_key_b,
1506                datagrams: Datagrams::from(b"try to send"),
1507            })
1508            .await;
1509        assert!(res.is_err());
1510        assert!(client_b.next().await.is_none());
1511
1512        drop(client_a);
1513        drop(client_b);
1514
1515        service.shutdown().await;
1516
1517        assert_eq!(metrics.accepts.get(), metrics.disconnects.get());
1518
1519        Ok(())
1520    }
1521
1522    #[tokio::test]
1523    async fn test_server_replace_client() -> Result {
1524        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
1525
1526        info!("Create the server.");
1527        let service = RelayService::new(
1528            Default::default(),
1529            Default::default(),
1530            None,
1531            KeyCache::test(),
1532            AccessConfig::Everyone,
1533            Default::default(),
1534        );
1535
1536        info!("Create client A and connect it to the server.");
1537        let key_a = SecretKey::from_bytes(&rng.random());
1538        let public_key_a = key_a.public();
1539        let (client_a, rw_a) = tokio::io::duplex(10);
1540        let s = service.clone();
1541        let handler_task = tokio::spawn(async move {
1542            s.0.accept(MaybeTlsStream::Test(rw_a), None, Default::default())
1543                .await
1544        });
1545        let mut client_a = make_test_client(client_a, &key_a).await?;
1546        handler_task.await.std_context("join")??;
1547
1548        info!("Create client B and connect it to the server.");
1549        let key_b = SecretKey::from_bytes(&rng.random());
1550        let public_key_b = key_b.public();
1551        let (client_b, rw_b) = tokio::io::duplex(10);
1552        let s = service.clone();
1553        let handler_task = tokio::spawn(async move {
1554            s.0.accept(MaybeTlsStream::Test(rw_b), None, Default::default())
1555                .await
1556        });
1557        let mut client_b = make_test_client(client_b, &key_b).await?;
1558        handler_task.await.std_context("join")??;
1559
1560        info!("Send message from A to B.");
1561        let msg = Datagrams::from(b"hello client b!!");
1562        client_a
1563            .send(ClientToRelayMsg::Datagrams {
1564                dst_endpoint_id: public_key_b,
1565                datagrams: msg.clone(),
1566            })
1567            .await?;
1568        match client_b.next().await.expect("eos")? {
1569            RelayToClientMsg::Datagrams {
1570                remote_endpoint_id,
1571                datagrams,
1572            } => {
1573                assert_eq!(public_key_a, remote_endpoint_id);
1574                assert_eq!(msg, datagrams);
1575            }
1576            msg => {
1577                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1578            }
1579        }
1580
1581        info!("Send message from B to A.");
1582        let msg = Datagrams::from(b"nice to meet you client a!!");
1583        client_b
1584            .send(ClientToRelayMsg::Datagrams {
1585                dst_endpoint_id: public_key_a,
1586                datagrams: msg.clone(),
1587            })
1588            .await?;
1589        match client_a.next().await.expect("eos")? {
1590            RelayToClientMsg::Datagrams {
1591                remote_endpoint_id,
1592                datagrams,
1593            } => {
1594                assert_eq!(public_key_b, remote_endpoint_id);
1595                assert_eq!(msg, datagrams);
1596            }
1597            msg => {
1598                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1599            }
1600        }
1601
1602        info!("Create client B and connect it to the server");
1603        let (new_client_b, new_rw_b) = tokio::io::duplex(10);
1604        let s = service.clone();
1605        let handler_task = tokio::spawn(async move {
1606            s.0.accept(MaybeTlsStream::Test(new_rw_b), None, Default::default())
1607                .await
1608        });
1609        let mut new_client_b = make_test_client(new_client_b, &key_b).await?;
1610        handler_task.await.std_context("join")??;
1611
1612        // assert!(client_b.recv().await.is_err());
1613
1614        info!("Send message from A to B.");
1615        let msg = Datagrams::from(b"are you still there, b?!");
1616        client_a
1617            .send(ClientToRelayMsg::Datagrams {
1618                dst_endpoint_id: public_key_b,
1619                datagrams: msg.clone(),
1620            })
1621            .await?;
1622        match new_client_b.next().await.expect("eos")? {
1623            RelayToClientMsg::Datagrams {
1624                remote_endpoint_id,
1625                datagrams,
1626            } => {
1627                assert_eq!(public_key_a, remote_endpoint_id);
1628                assert_eq!(msg, datagrams);
1629            }
1630            msg => {
1631                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1632            }
1633        }
1634
1635        info!("Send message from B to A.");
1636        let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!");
1637        new_client_b
1638            .send(ClientToRelayMsg::Datagrams {
1639                dst_endpoint_id: public_key_a,
1640                datagrams: msg.clone(),
1641            })
1642            .await?;
1643        match client_a.next().await.expect("eos")? {
1644            RelayToClientMsg::Datagrams {
1645                remote_endpoint_id,
1646                datagrams,
1647            } => {
1648                assert_eq!(public_key_b, remote_endpoint_id);
1649                assert_eq!(msg, datagrams);
1650            }
1651            msg => {
1652                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
1653            }
1654        }
1655
1656        info!("Close the server and clients");
1657        service.shutdown().await;
1658
1659        info!("Sending message from A to B fails");
1660        let res = client_a
1661            .send(ClientToRelayMsg::Datagrams {
1662                dst_endpoint_id: public_key_b,
1663                datagrams: Datagrams::from(b"try to send"),
1664            })
1665            .await;
1666        assert!(res.is_err());
1667        assert!(new_client_b.next().await.is_none());
1668        Ok(())
1669    }
1670
1671    #[tokio::test]
1672    #[traced_test]
1673    async fn test_establish_timeout() -> Result {
1674        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42u64);
1675
1676        // Start server with a very short establish timeout.
1677        let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
1678            .establish_timeout(Duration::from_millis(500))
1679            .spawn()
1680            .await?;
1681
1682        let addr = server.addr();
1683        let port = addr.port();
1684        let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
1685            ipv4_addr
1686        } else {
1687            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
1688        };
1689        let relay_url: Url = format!("http://{addr}:{port}").parse().unwrap();
1690
1691        // 1. A lingering connection that never upgrades should be aborted by the timeout.
1692        info!("opening lingering TCP connection (no upgrade)");
1693        let mut lingering = TcpStream::connect(format!("{addr}:{port}")).await?;
1694        // Write a partial HTTP request but never complete the upgrade.
1695        lingering
1696            .write_all(b"GET / HTTP/1.1\r\nHost: localhost\r\n")
1697            .await?;
1698        // Wait for the server to abort this connection.
1699        let mut buf = [0u8; 1];
1700        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
1701        let read = tokio::time::timeout_at(deadline, lingering.read(&mut buf)).await;
1702        // The server should close the connection, resulting in a read of 0 bytes or an error.
1703        match read {
1704            Ok(Ok(0)) => info!("lingering connection closed by server (EOF)"),
1705            Ok(Err(e)) => info!("lingering connection closed by server (error: {e})"),
1706            other => bail_any!("expected lingering connection to be closed, got {other:?}"),
1707        }
1708
1709        // 2. A properly established client should NOT be aborted by the timeout.
1710        info!("connecting a proper relay client");
1711        let key = SecretKey::from_bytes(&rng.random());
1712        let (_, mut client) = create_test_client(key, relay_url).await?;
1713
1714        // Wait longer than the establish timeout to prove the connection survives.
1715        tokio::time::sleep(Duration::from_millis(1000)).await;
1716
1717        // Ping should still work.
1718        client.send(ClientToRelayMsg::Ping([7u8; 8])).await?;
1719        let pong = client.next().await.expect("expected pong")?;
1720        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));
1721        info!("established connection survived past the timeout");
1722
1723        client.close().await?;
1724        server.shutdown();
1725        Ok(())
1726    }
1727}