// iroh_net/relay/server/http_server.rs

1use std::{collections::HashMap, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
2
3use anyhow::{bail, ensure, Context as _, Result};
4use bytes::Bytes;
5use derive_more::Debug;
6use futures_lite::FutureExt;
7use http::{header::CONNECTION, response::Builder as ResponseBuilder};
8use hyper::{
9    body::Incoming,
10    header::{HeaderValue, UPGRADE},
11    service::Service,
12    upgrade::Upgraded,
13    HeaderMap, Method, Request, Response, StatusCode,
14};
15use tokio::net::{TcpListener, TcpStream};
16use tokio_rustls_acme::AcmeAcceptor;
17use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
18use tracing::{debug, debug_span, error, info, info_span, warn, Instrument};
19use tungstenite::handshake::derive_accept_key;
20
21use crate::relay::{
22    http::{Protocol, LEGACY_RELAY_PATH, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION},
23    server::{
24        actor::{ClientConnHandler, ServerActorTask},
25        streams::MaybeTlsStream,
26    },
27};
28
29type BytesBody = http_body_util::Full<hyper::body::Bytes>;
30type HyperError = Box<dyn std::error::Error + Send + Sync>;
31type HyperResult<T> = std::result::Result<T, HyperError>;
32type HyperHandler = Box<
33    dyn Fn(Request<Incoming>, ResponseBuilder) -> HyperResult<Response<BytesBody>>
34        + Send
35        + Sync
36        + 'static,
37>;
38
39/// Creates a new [`BytesBody`] with no content.
40fn body_empty() -> BytesBody {
41    http_body_util::Full::new(hyper::body::Bytes::new())
42}
43
44/// Creates a new [`BytesBody`] with given content.
45fn body_full(content: impl Into<hyper::body::Bytes>) -> BytesBody {
46    http_body_util::Full::new(content.into())
47}
48
49fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes)> {
50    match upgraded.downcast::<hyper_util::rt::TokioIo<MaybeTlsStream>>() {
51        Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)),
52        Err(_) => {
53            bail!("could not downcast the upgraded connection to MaybeTlsStream")
54        }
55    }
56}
57
58/// The server HTTP handler to do HTTP upgrades.
59async fn relay_connection_handler(
60    protocol: Protocol,
61    conn_handler: &ClientConnHandler,
62    upgraded: Upgraded,
63) -> Result<()> {
64    debug!(?protocol, "relay_connection upgraded");
65    let (io, read_buf) = downcast_upgrade(upgraded)?;
66    ensure!(
67        read_buf.is_empty(),
68        "can not deal with buffered data yet: {:?}",
69        read_buf
70    );
71
72    conn_handler.accept(protocol, io).await
73}
74
75/// The Relay HTTP server.
76///
77/// A running HTTP server serving the relay endpoint and optionally a number of additional
78/// HTTP services added with [`ServerBuilder::request_handler`].  If configured using
79/// [`ServerBuilder::tls_config`] the server will handle TLS as well.
80///
81/// Created using [`ServerBuilder::spawn`].
82#[derive(Debug)]
83pub struct Server {
84    addr: SocketAddr,
85    http_server_task: AbortOnDropHandle<()>,
86    cancel_server_loop: CancellationToken,
87}
88
89impl Server {
90    /// Returns a handle for this server.
91    ///
92    /// The server runs in the background as several async tasks.  This allows controlling
93    /// the server, in particular it allows gracefully shutting down the server.
94    pub fn handle(&self) -> ServerHandle {
95        ServerHandle {
96            cancel_token: self.cancel_server_loop.clone(),
97        }
98    }
99
100    /// Closes the underlying relay server and the HTTP(S) server tasks.
101    pub fn shutdown(&self) {
102        self.cancel_server_loop.cancel();
103    }
104
105    /// Returns the [`AbortOnDropHandle`] for the supervisor task managing the server.
106    ///
107    /// This is the root of all the tasks for the server.  Aborting it will abort all the
108    /// other tasks for the server.  Awaiting it will complete when all the server tasks are
109    /// completed.
110    pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
111        &mut self.http_server_task
112    }
113
114    /// Returns the local address of this server.
115    pub fn addr(&self) -> SocketAddr {
116        self.addr
117    }
118}
119
120/// A handle for the [`Server`].
121///
122/// This does not allow access to the task but can communicate with it.
123#[derive(Debug, Clone)]
124pub struct ServerHandle {
125    cancel_token: CancellationToken,
126}
127
128impl ServerHandle {
129    /// Gracefully shut down the server.
130    pub fn shutdown(&self) {
131        self.cancel_token.cancel()
132    }
133}
134
135/// Configuration to use for the TLS connection
136#[derive(Debug, Clone)]
137pub struct TlsConfig {
138    /// The server config
139    pub config: Arc<rustls::ServerConfig>,
140    /// The kind
141    pub acceptor: TlsAcceptor,
142}
143
144/// Builder for the Relay HTTP Server.
145///
146/// Defaults to handling relay requests on the "/relay" (and "/derp" for backwards compatibility) endpoint.
147/// Other HTTP endpoints can be added using [`ServerBuilder::request_handler`].
148#[derive(derive_more::Debug)]
149pub struct ServerBuilder {
150    /// The ip + port combination for this server.
151    addr: SocketAddr,
152    /// Optional tls configuration/TlsAcceptor combination.
153    ///
154    /// When `None`, the server will serve HTTP, otherwise it will serve HTTPS.
155    tls_config: Option<TlsConfig>,
156    /// A map of request handlers to routes.
157    ///
158    /// Used when certain routes in your server should be made available at the same port as
159    /// the relay server, and so must be handled along side requests to the relay endpoint.
160    handlers: Handlers,
161    /// Headers to use for HTTP responses.
162    headers: HeaderMap,
163    /// 404 not found response.
164    ///
165    /// When `None`, a default is provided.
166    #[debug("{}", not_found_fn.as_ref().map_or("None", |_| "Some(Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>)"))]
167    not_found_fn: Option<HyperHandler>,
168}
169
170impl ServerBuilder {
171    /// Creates a new [ServerBuilder].
172    pub fn new(addr: SocketAddr) -> Self {
173        Self {
174            addr,
175            tls_config: None,
176            handlers: Default::default(),
177            headers: HeaderMap::new(),
178            not_found_fn: None,
179        }
180    }
181
182    /// Serves all requests content using TLS.
183    pub fn tls_config(mut self, config: Option<TlsConfig>) -> Self {
184        self.tls_config = config;
185        self
186    }
187
188    /// Adds a custom handler for a specific Method & URI.
189    pub fn request_handler(
190        mut self,
191        method: Method,
192        uri_path: &'static str,
193        handler: HyperHandler,
194    ) -> Self {
195        self.handlers.insert((method, uri_path), handler);
196        self
197    }
198
199    /// Sets a custom "404" handler.
200    #[allow(unused)]
201    pub fn not_found_handler(mut self, handler: HyperHandler) -> Self {
202        self.not_found_fn = Some(handler);
203        self
204    }
205
206    /// Adds HTTP headers to responses.
207    pub fn headers(mut self, headers: HeaderMap) -> Self {
208        for (k, v) in headers.iter() {
209            self.headers.insert(k.clone(), v.clone());
210        }
211        self
212    }
213
214    /// Builds and spawns an HTTP(S) Relay Server.
215    pub async fn spawn(self) -> Result<Server> {
216        let relay_server = ServerActorTask::new();
217        let relay_handler = relay_server.client_conn_handler(self.headers.clone());
218
219        let h = self.headers.clone();
220        let not_found_fn = match self.not_found_fn {
221            Some(f) => f,
222            None => Box::new(move |_req: Request<Incoming>, mut res: ResponseBuilder| {
223                for (k, v) in h.iter() {
224                    res = res.header(k.clone(), v.clone());
225                }
226                let body = body_full("Not Found");
227                let r = res.status(StatusCode::NOT_FOUND).body(body)?;
228                HyperResult::Ok(r)
229            }),
230        };
231
232        let service = RelayService::new(self.handlers, relay_handler, not_found_fn, self.headers);
233
234        let server_state = ServerState {
235            addr: self.addr,
236            tls_config: self.tls_config,
237            server: relay_server,
238            service,
239        };
240
241        // Spawns some server tasks, we only wait till all tasks are started.
242        server_state.serve().await
243    }
244}
245
246#[derive(Debug)]
247struct ServerState {
248    addr: SocketAddr,
249    tls_config: Option<TlsConfig>,
250    server: ServerActorTask,
251    service: RelayService,
252}
253
254impl ServerState {
255    // Binds a TCP listener on `addr` and handles content using HTTPS.
256    // Returns the local [`SocketAddr`] on which the server is listening.
257    async fn serve(self) -> Result<Server> {
258        let ServerState {
259            addr,
260            tls_config,
261            server,
262            service,
263        } = self;
264        let listener = TcpListener::bind(&addr)
265            .await
266            .with_context(|| format!("failed to bind server socket to {addr}"))?;
267        // we will use this cancel token to stop the infinite loop in the `listener.accept() task`
268        let cancel_server_loop = CancellationToken::new();
269        let addr = listener.local_addr()?;
270        let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS");
271        info!("[{http_str}] relay: serving on {addr}");
272        let cancel = cancel_server_loop.clone();
273        let task = tokio::task::spawn(async move {
274            // create a join set to track all our connection tasks
275            let mut set = tokio::task::JoinSet::new();
276            loop {
277                tokio::select! {
278                    biased;
279                    _ = cancel.cancelled() => {
280                        break;
281                    }
282                    res = listener.accept() => match res {
283                        Ok((stream, peer_addr)) => {
284                            debug!("[{http_str}] relay: Connection opened from {peer_addr}");
285                            let tls_config = tls_config.clone();
286                            let service = service.clone();
287                            // spawn a task to handle the connection
288                            set.spawn(async move {
289                                if let Err(error) = service
290                                    .handle_connection(stream, tls_config)
291                                    .await
292                                {
293                                    match error.downcast_ref::<std::io::Error>() {
294                                        Some(io_error) if io_error.kind() == std::io::ErrorKind::UnexpectedEof => {
295                                            debug!(reason=?error, "[{http_str}] relay: peer disconnected");
296                                        },
297                                        _ => {
298                                            error!(?error, "[{http_str}] relay: failed to handle connection");
299                                        }
300                                    }
301                                }
302                            }.instrument(info_span!("conn", peer = %peer_addr)));
303                        }
304                        Err(err) => {
305                            error!("[{http_str}] relay: failed to accept connection: {err}");
306                        }
307                    }
308                }
309            }
310            // TODO: if the task this is running in is aborted this server is not shut
311            // down.
312            server.close().await;
313            set.shutdown().await;
314            debug!("[{http_str}] relay: server has been shutdown.");
315        }.instrument(info_span!("relay-http-serve")));
316
317        Ok(Server {
318            addr,
319            http_server_task: AbortOnDropHandle::new(task),
320            cancel_server_loop,
321        })
322    }
323}
324
325impl Service<Request<Incoming>> for ClientConnHandler {
326    type Response = Response<BytesBody>;
327    type Error = hyper::Error;
328    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
329
330    fn call(&self, mut req: Request<Incoming>) -> Self::Future {
331        // TODO: soooo much cloning. See if there is an alternative
332        let closure_conn_handler = self.clone();
333        let mut builder = Response::builder();
334        for (key, value) in self.default_headers.iter() {
335            builder = builder.header(key, value);
336        }
337
338        async move {
339            {
340                // Send a 400 to any request that doesn't have an `Upgrade` header.
341                let Some(protocol) = req.headers().get(UPGRADE).and_then(Protocol::parse_header)
342                else {
343                    return Ok(builder
344                        .status(StatusCode::BAD_REQUEST)
345                        .body(body_empty())
346                        .expect("valid body"));
347                };
348
349                let websocket_headers = if protocol == Protocol::Websocket {
350                    let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else {
351                        warn!("missing header Sec-WebSocket-Key for websocket relay protocol");
352                        return Ok(builder
353                            .status(StatusCode::BAD_REQUEST)
354                            .body(body_empty())
355                            .expect("valid body"));
356                    };
357
358                    let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else {
359                        warn!("missing header Sec-WebSocket-Version for websocket relay protocol");
360                        return Ok(builder
361                            .status(StatusCode::BAD_REQUEST)
362                            .body(body_empty())
363                            .expect("valid body"));
364                    };
365
366                    if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() {
367                        warn!("invalid header Sec-WebSocket-Version: {:?}", version);
368                        return Ok(builder
369                            .status(StatusCode::BAD_REQUEST)
370                            // It's convention to send back the version(s) we *do* support
371                            .header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION)
372                            .body(body_empty())
373                            .expect("valid body"));
374                    }
375
376                    Some((key, version))
377                } else {
378                    None
379                };
380
381                debug!("upgrading protocol: {:?}", protocol);
382
383                // Setup a future that will eventually receive the upgraded
384                // connection and talk a new protocol, and spawn the future
385                // into the runtime.
386                //
387                // Note: This can't possibly be fulfilled until the 101 response
388                // is returned below, so it's better to spawn this future instead
389                // waiting for it to complete to then return a response.
390                tokio::task::spawn(
391                    async move {
392                        match hyper::upgrade::on(&mut req).await {
393                            Ok(upgraded) => {
394                                if let Err(e) = relay_connection_handler(
395                                    protocol,
396                                    &closure_conn_handler,
397                                    upgraded,
398                                )
399                                .await
400                                {
401                                    warn!(
402                                        "upgrade to \"{}\": io error: {:?}",
403                                        e,
404                                        protocol.upgrade_header()
405                                    );
406                                } else {
407                                    debug!("upgrade to \"{}\" success", protocol.upgrade_header());
408                                };
409                            }
410                            Err(e) => warn!("upgrade error: {:?}", e),
411                        }
412                    }
413                    .instrument(debug_span!("handler")),
414                );
415
416                // Now return a 101 Response saying we agree to the upgrade to the
417                // HTTP_UPGRADE_PROTOCOL
418                builder = builder
419                    .status(StatusCode::SWITCHING_PROTOCOLS)
420                    .header(UPGRADE, HeaderValue::from_static(protocol.upgrade_header()));
421
422                if let Some((key, _version)) = websocket_headers {
423                    Ok(builder
424                        .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
425                        .header(CONNECTION, "upgrade")
426                        .body(body_full("switching to websocket protocol"))
427                        .expect("valid body"))
428                } else {
429                    Ok(builder.body(body_empty()).expect("valid body"))
430                }
431            }
432        }
433        .boxed()
434    }
435}
436
437impl Service<Request<Incoming>> for RelayService {
438    type Response = Response<BytesBody>;
439    type Error = HyperError;
440    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
441
442    fn call(&self, req: Request<Incoming>) -> Self::Future {
443        // if the request hits the relay endpoint
444        // or /derp for backwards compat
445        if matches!(
446            (req.method(), req.uri().path()),
447            (&hyper::Method::GET, LEGACY_RELAY_PATH | RELAY_PATH)
448        ) {
449            let h = self.0.relay_handler.clone();
450            // otherwise handle the relay connection as normal
451            return Box::pin(async move { h.call(req).await.map_err(Into::into) });
452        }
453
454        // check all other possible endpoints
455        let uri = req.uri().clone();
456        if let Some(res) = self.0.handlers.get(&(req.method().clone(), uri.path())) {
457            let f = res(req, self.0.default_response());
458            return Box::pin(async move { f });
459        }
460        // otherwise return 404
461        let res = (self.0.not_found_fn)(req, self.0.default_response());
462        Box::pin(async move { res })
463    }
464}
465
466/// The hyper Service that servers the actual relay endpoints
467#[derive(Clone, Debug)]
468struct RelayService(Arc<Inner>);
469
470#[derive(derive_more::Debug)]
471struct Inner {
472    pub relay_handler: ClientConnHandler,
473    #[debug("Box<Fn(ResponseBuilder) -> Result<Response<BytesBody>> + Send + Sync + 'static>")]
474    pub not_found_fn: HyperHandler,
475    pub handlers: Handlers,
476    pub headers: HeaderMap,
477}
478
479impl Inner {
480    fn default_response(&self) -> ResponseBuilder {
481        let mut response = Response::builder();
482        for (key, value) in self.headers.iter() {
483            response = response.header(key.clone(), value.clone());
484        }
485        response
486    }
487}
488
489/// TLS Certificate Authority acceptor.
490#[derive(Clone, derive_more::Debug)]
491pub enum TlsAcceptor {
492    /// Uses Let's Encrypt as the Certificate Authority. This is used in production.
493    LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor),
494    /// Manually added tls acceptor. Generally used for tests or for when we've passed in
495    /// a certificate via a file.
496    Manual(#[debug("tokio_rustls::TlsAcceptor")] tokio_rustls::TlsAcceptor),
497}
498
499impl RelayService {
500    fn new(
501        handlers: Handlers,
502        relay_handler: ClientConnHandler,
503        not_found_fn: HyperHandler,
504        headers: HeaderMap,
505    ) -> Self {
506        Self(Arc::new(Inner {
507            relay_handler,
508            handlers,
509            not_found_fn,
510            headers,
511        }))
512    }
513
514    /// Handle the incoming connection.
515    ///
516    /// If a `tls_config` is given, will serve the connection using HTTPS.
517    async fn handle_connection(
518        self,
519        stream: TcpStream,
520        tls_config: Option<TlsConfig>,
521    ) -> Result<()> {
522        match tls_config {
523            Some(tls_config) => self.tls_serve_connection(stream, tls_config).await,
524            None => {
525                debug!("HTTP: serve connection");
526                self.serve_connection(MaybeTlsStream::Plain(stream)).await
527            }
528        }
529    }
530
531    /// Serve the tls connection
532    async fn tls_serve_connection(self, stream: TcpStream, tls_config: TlsConfig) -> Result<()> {
533        let TlsConfig { acceptor, config } = tls_config;
534        match acceptor {
535            TlsAcceptor::LetsEncrypt(a) => match a.accept(stream).await? {
536                None => {
537                    info!("TLS[acme]: received TLS-ALPN-01 validation request");
538                }
539                Some(start_handshake) => {
540                    debug!("TLS[acme]: start handshake");
541                    let tls_stream = start_handshake
542                        .into_stream(config)
543                        .await
544                        .context("TLS[acme] handshake")?;
545                    self.serve_connection(MaybeTlsStream::Tls(tls_stream))
546                        .await
547                        .context("TLS[acme] serve connection")?;
548                }
549            },
550            TlsAcceptor::Manual(a) => {
551                debug!("TLS[manual]: accept");
552                let tls_stream = a.accept(stream).await.context("TLS[manual] accept")?;
553                self.serve_connection(MaybeTlsStream::Tls(tls_stream))
554                    .await
555                    .context("TLS[manual] serve connection")?;
556            }
557        }
558        Ok(())
559    }
560
561    /// Wrapper for the actual http connection (with upgrades)
562    async fn serve_connection<I>(self, io: I) -> Result<()>
563    where
564        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
565    {
566        hyper::server::conn::http1::Builder::new()
567            .serve_connection(hyper_util::rt::TokioIo::new(io), self)
568            .with_upgrades()
569            .await?;
570        Ok(())
571    }
572}
573
574#[derive(Default)]
575struct Handlers(HashMap<(Method, &'static str), HyperHandler>);
576
577impl std::fmt::Debug for Handlers {
578    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579        let s = self.0.keys().fold(String::new(), |curr, next| {
580            let (method, uri) = next;
581            format!("{curr}\n({method},{uri}): Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>")
582        });
583        write!(f, "HashMap<{s}>")
584    }
585}
586
587impl std::ops::Deref for Handlers {
588    type Target = HashMap<(Method, &'static str), HyperHandler>;
589
590    fn deref(&self) -> &Self::Target {
591        &self.0
592    }
593}
594
595impl std::ops::DerefMut for Handlers {
596    fn deref_mut(&mut self) -> &mut Self::Target {
597        &mut self.0
598    }
599}
600
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use bytes::Bytes;
    use reqwest::Url;
    use tokio::{sync::mpsc, task::JoinHandle};
    use tracing::{info, info_span, Instrument};

    use super::*;
    use crate::{
        key::{PublicKey, SecretKey},
        relay::client::{conn::ReceivedMessage, Client, ClientBuilder},
    };

    /// Builds a [`TlsConfig`] with a freshly generated self-signed certificate for
    /// `localhost`, accepted through a [`TlsAcceptor::Manual`] acceptor.
    pub(crate) fn make_tls_config() -> TlsConfig {
        let subject_alt_names = vec!["localhost".to_string()];

        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let rustls_certificate =
            rustls::pki_types::CertificateDer::from(cert.serialize_der().unwrap());
        let rustls_key =
            rustls::pki_types::PrivatePkcs8KeyDer::from(cert.get_key_pair().serialize_der());
        let rustls_key = rustls::pki_types::PrivateKeyDer::from(rustls_key);
        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .expect("protocols supported by ring")
        .with_no_client_auth()
        .with_single_cert(vec![rustls_certificate], rustls_key)
        .expect("cert is right");

        let config = Arc::new(config);
        let acceptor = tokio_rustls::TlsAcceptor::from(config.clone());

        TlsConfig {
            config,
            acceptor: TlsAcceptor::Manual(acceptor),
        }
    }

    /// Starts a plain HTTP relay server and checks that two clients can ping it and
    /// exchange messages through it in both directions.
    #[tokio::test]
    async fn test_http_clients_and_server() -> Result<()> {
        let _guard = iroh_test::logging::setup();

        let a_key = SecretKey::generate();
        let b_key = SecretKey::generate();

        // start server
        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .spawn()
            .await?;

        let addr = server.addr();

        // get dial info
        let port = addr.port();
        let addr = {
            if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
                ipv4_addr
            } else {
                anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}");
            }
        };
        info!("addr: {addr}:{port}");
        let relay_addr: Url = format!("http://{addr}:{port}").parse().unwrap();

        // create clients
        let (a_key, mut a_recv, client_a_task, client_a) = {
            let span = info_span!("client-a");
            let _guard = span.enter();
            create_test_client(a_key, relay_addr.clone())
        };
        info!("created client {a_key:?}");
        let (b_key, mut b_recv, client_b_task, client_b) = {
            let span = info_span!("client-b");
            let _guard = span.enter();
            create_test_client(b_key, relay_addr)
        };
        info!("created client {b_key:?}");

        info!("ping a");
        client_a.ping().await?;

        info!("ping b");
        client_b.ping().await?;

        info!("sending message from a to b");
        let msg = Bytes::from_static(b"hi there, client b!");
        client_a.send(b_key, msg.clone()).await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) = b_recv.recv().await.expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);

        info!("sending message from b to a");
        let msg = Bytes::from_static(b"right back at ya, client b!");
        client_b.send(a_key, msg.clone()).await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) = a_recv.recv().await.expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);

        client_a.close().await?;
        client_a_task.abort();
        client_b.close().await?;
        client_b_task.abort();
        server.shutdown();
        // Wait until the server tasks have actually finished, as the HTTPS test does.
        server.task_handle().await?;

        Ok(())
    }

    /// Builds a relay client for `server_url` and spawns a task forwarding every received
    /// packet as `(source, data)` into the returned channel.
    ///
    /// Returns the client's public key, the receiving end of that channel, the reader task
    /// handle and the client itself.
    fn create_test_client(
        key: SecretKey,
        server_url: Url,
    ) -> (
        PublicKey,
        mpsc::Receiver<(PublicKey, Bytes)>,
        JoinHandle<()>,
        Client,
    ) {
        let client = ClientBuilder::new(server_url).insecure_skip_cert_verify(true);
        let dns_resolver = crate::dns::default_resolver();
        let (client, mut client_reader) = client.build(key.clone(), dns_resolver.clone());
        let public_key = key.public();
        let (received_msg_s, received_msg_r) = tokio::sync::mpsc::channel(10);
        let client_reader_task = tokio::spawn(
            async move {
                loop {
                    info!("waiting for message on {:?}", key.public());
                    match client_reader.recv().await {
                        None => {
                            info!("client received nothing");
                            return;
                        }
                        Some(Err(e)) => {
                            info!("client {:?} `recv` error {e}", key.public());
                            return;
                        }
                        Some(Ok(msg)) => {
                            info!("got message on {:?}: {msg:?}", key.public());
                            if let ReceivedMessage::ReceivedPacket { source, data } = msg {
                                received_msg_s
                                    .send((source, data))
                                    .await
                                    .unwrap_or_else(|err| {
                                        panic!(
                                            "client {:?}, error sending message over channel: {:?}",
                                            key.public(),
                                            err
                                        )
                                    });
                            }
                        }
                    }
                }
            }
            .instrument(info_span!("test-client-reader")),
        );
        (public_key, received_msg_r, client_reader_task, client)
    }

    /// Same as `test_http_clients_and_server`, but with the server serving HTTPS using a
    /// self-signed certificate.
    #[tokio::test]
    async fn test_https_clients_and_server() -> Result<()> {
        // Same logging setup as the HTTP test.
        let _guard = iroh_test::logging::setup();

        let a_key = SecretKey::generate();
        let b_key = SecretKey::generate();

        // create tls_config
        let tls_config = make_tls_config();

        // start server
        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .tls_config(Some(tls_config))
            .spawn()
            .await?;

        let addr = server.addr();

        // get dial info
        let port = addr.port();
        let addr = {
            if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
                ipv4_addr
            } else {
                anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}");
            }
        };
        info!("Relay listening on: {addr}:{port}");

        let url: Url = format!("https://localhost:{port}").parse().unwrap();

        // create clients
        let (a_key, mut a_recv, client_a_task, client_a) = create_test_client(a_key, url.clone());
        info!("created client {a_key:?}");
        let (b_key, mut b_recv, client_b_task, client_b) = create_test_client(b_key, url);
        info!("created client {b_key:?}");

        client_a.ping().await?;
        client_b.ping().await?;

        info!("sending message from a to b");
        let msg = Bytes::from_static(b"hi there, client b!");
        client_a.send(b_key, msg.clone()).await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) = b_recv.recv().await.expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);

        info!("sending message from b to a");
        let msg = Bytes::from_static(b"right back at ya, client b!");
        client_b.send(a_key, msg.clone()).await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) = a_recv.recv().await.expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);

        server.shutdown();
        server.task_handle().await?;
        client_a.close().await?;
        client_a_task.abort();
        client_b.close().await?;
        client_b_task.abort();
        Ok(())
    }
}