holochain_websocket/
lib.rs

1#![deny(missing_docs)]
2//! Holochain websocket support library.
3//! This is currently a thin wrapper around tokio-tungstenite that
4//! provides rpc-style request/responses via u64 message ids.
5
6use bytes::Bytes;
7use holochain_serialized_bytes::prelude::*;
8use holochain_types::websocket::AllowedOrigins;
9use std::io::ErrorKind;
10pub use std::io::{Error, Result};
11use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
12use std::sync::Arc;
13use tokio::net::ToSocketAddrs;
14use tokio::select;
15use tokio_tungstenite::tungstenite::handshake::client::Request;
16use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Response};
17use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue, StatusCode};
18use tokio_tungstenite::tungstenite::protocol::Message;
19
20/// The messages actually sent over the wire by this library.
21/// If you want to implement your own server or client you
22/// will need this type or be able to serialize / deserialize it.
23#[derive(Debug, serde::Serialize, serde::Deserialize, SerializedBytes)]
24#[serde(rename_all = "snake_case", tag = "type")]
25pub enum WireMessage {
26    /// A message without a response.
27    Signal {
28        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
29        #[serde(with = "serde_bytes")]
30        data: Vec<u8>,
31    },
32
33    /// An authentication message, sent by the client if the server requires it.
34    Authenticate {
35        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
36        #[serde(with = "serde_bytes")]
37        data: Vec<u8>,
38    },
39
40    /// A request that requires a response.
41    Request {
42        /// The id of this request.
43        id: u64,
44        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
45        #[serde(with = "serde_bytes")]
46        data: Vec<u8>,
47    },
48
49    /// The response to a request.
50    Response {
51        /// The id of the request that this response is for.
52        id: u64,
53        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
54        #[serde(with = "serde_bytes")]
55        data: Option<Vec<u8>>,
56    },
57}
58
59impl WireMessage {
60    /// Deserialize a WireMessage.
61    fn try_from_bytes(b: Vec<u8>) -> WebsocketResult<Self> {
62        let b = UnsafeBytes::from(b);
63        let b = SerializedBytes::from(b);
64        let b: WireMessage = b.try_into()?;
65        Ok(b)
66    }
67
68    /// Create a new authenticate message.
69    fn authenticate<S>(s: S) -> WebsocketResult<Message>
70    where
71        S: std::fmt::Debug,
72        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
73    {
74        let s1 = SerializedBytes::try_from(s)?;
75        let s2 = Self::Authenticate {
76            data: UnsafeBytes::from(s1).into(),
77        };
78        let s3: SerializedBytes = s2.try_into()?;
79        Ok(Message::Binary(Bytes::copy_from_slice(
80            s3.bytes().as_slice(),
81        )))
82    }
83
84    /// Create a new request message (with new unique msg id).
85    fn request<S>(s: S) -> WebsocketResult<(Message, u64)>
86    where
87        S: std::fmt::Debug,
88        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
89    {
90        static ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
91        let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
92        tracing::trace!(?s, %id, "OutRequest");
93        let s1 = SerializedBytes::try_from(s)?;
94        let s2 = Self::Request {
95            id,
96            data: UnsafeBytes::from(s1).into(),
97        };
98        let s3: SerializedBytes = s2.try_into()?;
99        Ok((
100            Message::Binary(Bytes::copy_from_slice(s3.bytes().as_slice())),
101            id,
102        ))
103    }
104
105    /// Create a new response message.
106    fn response<S>(id: u64, s: S) -> WebsocketResult<Message>
107    where
108        S: std::fmt::Debug,
109        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
110    {
111        let s1 = SerializedBytes::try_from(s)?;
112        let s2 = Self::Response {
113            id,
114            data: Some(UnsafeBytes::from(s1).into()),
115        };
116        let s3: SerializedBytes = s2.try_into()?;
117        Ok(Message::Binary(Bytes::copy_from_slice(
118            s3.bytes().as_slice(),
119        )))
120    }
121
122    /// Create a new signal message.
123    fn signal<S>(s: S) -> WebsocketResult<Message>
124    where
125        S: std::fmt::Debug,
126        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
127    {
128        tracing::trace!(?s, "SendSignal");
129        let s1 = SerializedBytes::try_from(s)?;
130        let s2 = Self::Signal {
131            data: UnsafeBytes::from(s1).into(),
132        };
133        let s3: SerializedBytes = s2.try_into()?;
134        Ok(Message::Binary(Bytes::copy_from_slice(
135            s3.bytes().as_slice(),
136        )))
137    }
138}
139
140/// Websocket configuration struct.
141#[derive(Clone, Debug)]
142pub struct WebsocketConfig {
143    /// Seconds after which the lib will stop tracking individual request ids.
144    /// [default = 60 seconds]
145    pub default_request_timeout: std::time::Duration,
146
147    /// Maximum total message size of a websocket message. [default = 64M]
148    pub max_message_size: usize,
149
150    /// Maximum websocket frame size. [default = 16M]
151    pub max_frame_size: usize,
152
153    /// Allowed origins access control for a [WebsocketListener].
154    /// Not used by the [WebsocketSender].
155    pub allowed_origins: Option<AllowedOrigins>,
156}
157
158impl WebsocketConfig {
159    /// The default client WebsocketConfig.
160    pub const CLIENT_DEFAULT: WebsocketConfig = WebsocketConfig {
161        default_request_timeout: std::time::Duration::from_secs(60),
162        max_message_size: 64 << 20,
163        max_frame_size: 16 << 20,
164        allowed_origins: None,
165    };
166
167    /// The default listener WebsocketConfig.
168    pub const LISTENER_DEFAULT: WebsocketConfig = WebsocketConfig {
169        default_request_timeout: std::time::Duration::from_secs(60),
170        max_message_size: 64 << 20,
171        max_frame_size: 16 << 20,
172        allowed_origins: Some(AllowedOrigins::Any),
173    };
174
175    /// Internal convert to tungstenite config.
176    pub(crate) fn as_tungstenite(
177        &self,
178    ) -> tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
179        tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
180            .max_message_size(Some(self.max_message_size))
181            .max_frame_size(Some(self.max_frame_size))
182    }
183}
184
185struct RMapInner(
186    pub  std::collections::HashMap<
187        u64,
188        tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
189    >,
190);
191
192impl Drop for RMapInner {
193    fn drop(&mut self) {
194        self.close();
195    }
196}
197
198impl RMapInner {
199    fn close(&mut self) {
200        for (_, s) in self.0.drain() {
201            let _ = s.send(Err(WebsocketError::Close("ConnectionClosed".to_string())));
202        }
203    }
204}
205
206#[derive(Clone)]
207struct RMap(Arc<std::sync::Mutex<RMapInner>>);
208
209impl Default for RMap {
210    fn default() -> Self {
211        Self(Arc::new(std::sync::Mutex::new(RMapInner(
212            std::collections::HashMap::default(),
213        ))))
214    }
215}
216
217impl RMap {
218    pub fn close(&self) {
219        if let Ok(mut lock) = self.0.lock() {
220            lock.close();
221        }
222    }
223
224    pub fn insert(
225        &self,
226        id: u64,
227        sender: tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
228    ) {
229        self.0.lock().unwrap().0.insert(id, sender);
230    }
231
232    pub fn remove(
233        &self,
234        id: u64,
235    ) -> Option<tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>> {
236        self.0.lock().unwrap().0.remove(&id)
237    }
238}
239
240/// An error produced when working with websockets.
241///
242/// It is intended to capture all the errors that a caller might want to handle. Other errors that
243/// are unlikely to be recoverable are mapped to [WebsocketError::Other].
244#[derive(thiserror::Error, Debug)]
245pub enum WebsocketError {
246    /// The websocket has been closed by the other side.
247    #[error("Websocket closed: {0}")]
248    Close(String),
249    /// A received messaged did not deserialize to the expected type.
250    #[error("Received a message that did not deserialize: {0}")]
251    Deserialize(#[from] SerializedBytesError),
252    /// A websocket error from the underlying tungstenite library.
253    #[error("Websocket error: {0}")]
254    Websocket(#[from] Box<tokio_tungstenite::tungstenite::Error>),
255    /// A timeout occurred.
256    #[error("Timeout")]
257    Timeout(#[from] tokio::time::error::Elapsed),
258    /// An IO error occurred.
259    #[error("IO error: {0}")]
260    Io(#[from] Error),
261    /// Some other error occurred.
262    #[error("Other error: {0}")]
263    Other(String),
264}
265
266/// A result type, with the error type [WebsocketError].
267pub type WebsocketResult<T> = std::result::Result<T, WebsocketError>;
268
269type WsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
270type WsSend =
271    futures::stream::SplitSink<WsStream, tokio_tungstenite::tungstenite::protocol::Message>;
272type WsSendSync = Arc<tokio::sync::Mutex<WsSend>>;
273type WsRecv = futures::stream::SplitStream<WsStream>;
274type WsRecvSync = Arc<tokio::sync::Mutex<WsRecv>>;
275
276#[derive(Clone)]
277struct WsCore {
278    pub send: WsSendSync,
279    pub recv: WsRecvSync,
280    pub rmap: RMap,
281    pub timeout: std::time::Duration,
282}
283
284#[derive(Clone)]
285struct WsCoreSync(Arc<std::sync::Mutex<Option<WsCore>>>);
286
287impl PartialEq for WsCoreSync {
288    fn eq(&self, other: &Self) -> bool {
289        Arc::ptr_eq(&self.0, &other.0)
290    }
291}
292
293impl WsCoreSync {
294    fn close(&self) {
295        if let Some(core) = self.0.lock().unwrap().take() {
296            core.rmap.close();
297            tokio::task::spawn(async move {
298                use futures::sink::SinkExt;
299                let _ = core.send.lock().await.close().await;
300            });
301        }
302    }
303
304    fn close_if_err<R>(&self, r: WebsocketResult<R>) -> WebsocketResult<R> {
305        match r {
306            Err(e @ WebsocketError::Deserialize { .. }) => {
307                // Don't close the connection on a deserialization error.
308                // That's a client issue and not a connection issue.
309                Err(e)
310            }
311            Err(err) => {
312                self.close();
313                Err(err)
314            }
315            Ok(res) => Ok(res),
316        }
317    }
318
319    pub async fn exec<F, C, R>(&self, c: C) -> WebsocketResult<R>
320    where
321        F: std::future::Future<Output = WebsocketResult<R>>,
322        C: FnOnce(WsCoreSync, WsCore) -> F,
323    {
324        let core = match self.0.lock().unwrap().as_ref() {
325            Some(core) => core.clone(),
326            None => return Err(WebsocketError::Close("No connection".to_string())),
327        };
328        self.close_if_err(c(self.clone(), core).await)
329    }
330}
331
332/// Respond to an incoming request.
333#[derive(PartialEq)]
334pub struct WebsocketRespond {
335    id: u64,
336    core: WsCoreSync,
337}
338
339impl std::fmt::Debug for WebsocketRespond {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.debug_struct("WebsocketRespond")
342            .field("id", &self.id)
343            .finish()
344    }
345}
346
347impl WebsocketRespond {
348    /// Respond to an incoming request.
349    pub async fn respond<S>(self, s: S) -> WebsocketResult<()>
350    where
351        S: std::fmt::Debug,
352        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
353    {
354        tracing::trace!(?s, %self.id, "OutResponse");
355        use futures::sink::SinkExt;
356        self.core
357            .exec(move |_, core| async move {
358                tokio::time::timeout(core.timeout, async {
359                    let s = WireMessage::response(self.id, s)?;
360                    core.send.lock().await.send(s).await.map_err(Box::new)?;
361                    Ok(())
362                })
363                .await?
364            })
365            .await
366    }
367}
368
369/// Types of messages that can be received by a WebsocketReceiver.
370#[derive(Debug, PartialEq)]
371pub enum ReceiveMessage<D>
372where
373    D: std::fmt::Debug,
374    SerializedBytes: TryInto<D, Error = SerializedBytesError>,
375{
376    /// Received a request to authenticate from the client.
377    Authenticate(Vec<u8>),
378
379    /// Received a signal from the remote.
380    Signal(Vec<u8>),
381
382    /// Received a request from the remote.
383    Request(D, WebsocketRespond),
384    /// Received a request that is malformed.
385    BadRequest(WebsocketRespond),
386}
387
388/// Receive signals and requests from a websocket connection.
389/// Note, This receiver must be polled (recv()) for responses to requests
390/// made on the Sender side to be received.
391/// If this receiver is dropped, the sender side will also be closed.
392pub struct WebsocketReceiver(
393    WsCoreSync,
394    std::net::SocketAddr,
395    tokio::task::JoinHandle<()>,
396);
397
398impl Drop for WebsocketReceiver {
399    fn drop(&mut self) {
400        self.0.close();
401        self.2.abort();
402    }
403}
404
impl WebsocketReceiver {
    /// Wrap a connection core and spawn its keep-alive task.
    ///
    /// The spawned task wakes every 5 seconds and sends a `Ping` frame. If the
    /// ping cannot be sent, the whole connection is closed. The task exits once
    /// the core has been taken out of the shared slot (i.e. the connection is
    /// closed). Its handle is stored so `Drop` can abort it.
    ///
    /// Calls `tokio::task::spawn`, so it must run inside a tokio runtime.
    fn new(core: WsCoreSync, addr: std::net::SocketAddr) -> Self {
        let core2 = core.clone();
        let ping_task = tokio::task::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                // Clone the core out; the std mutex guard is a temporary dropped
                // at the end of this statement, so it is not held across awaits.
                let core = core2.0.lock().unwrap().as_ref().cloned();
                if let Some(core) = core {
                    use futures::sink::SinkExt;
                    if core
                        .send
                        .lock()
                        .await
                        .send(Message::Ping(Bytes::new()))
                        .await
                        .is_err()
                    {
                        // A failed ping means the transport is unusable.
                        core2.close();
                    }
                } else {
                    // Connection already closed: nothing left to keep alive.
                    break;
                }
            }
        });
        Self(core, addr, ping_task)
    }

    /// Address of the remote peer of this connection.
    pub fn peer_addr(&self) -> std::net::SocketAddr {
        self.1
    }

    /// Receive the next message from the remote.
    ///
    /// Only authentication messages, signals and requests (well-formed or not)
    /// are returned. Other frames are handled internally without returning:
    /// - pings are answered with a pong;
    /// - pongs are ignored;
    /// - responses to requests made on the paired [WebsocketSender] are routed
    ///   to their waiting callers.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the stream ends;
    /// - a close frame or raw frame arrives;
    /// - the transport fails;
    /// - an incoming frame is not a valid [WireMessage].
    ///
    /// Every error is logged at `warn` level before being returned. Any error
    /// other than [WebsocketError::Deserialize] also closes the connection.
    pub async fn recv<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        match self.recv_inner().await {
            Err(err) => {
                tracing::warn!(?err, "WebsocketReceiver Error");
                Err(err)
            }
            Ok(msg) => Ok(msg),
        }
    }

    /// Read frames until one should be surfaced to the caller.
    ///
    /// Each iteration reads one frame inside `exec`. Frames handled internally
    /// make the closure return `Ok(None)`, which continues the loop.
    async fn recv_inner<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        use futures::sink::SinkExt;
        use futures::stream::StreamExt;
        loop {
            if let Some(result) = self
                .0
                .exec(move |core_sync, core| async move {
                    // `None` from the stream means it has ended.
                    let msg = core
                        .recv
                        .lock()
                        .await
                        .next()
                        .await
                        .ok_or::<WebsocketError>(WebsocketError::Other(
                            "ReceiverClosed".to_string(),
                        ))?
                        .map_err(Box::new)?;
                    let msg = match msg {
                        // Text frames are decoded exactly like binary ones.
                        Message::Text(s) => s.as_bytes().to_vec(),
                        Message::Binary(b) => b.to_vec(),
                        Message::Ping(b) => {
                            core.send
                                .lock()
                                .await
                                .send(Message::Pong(b))
                                .await
                                .map_err(Box::new)?;
                            return Ok(None);
                        }
                        Message::Pong(_) => return Ok(None),
                        Message::Close(frame) => {
                            return Err(WebsocketError::Close(format!("{frame:?}")));
                        }
                        Message::Frame(_) => {
                            return Err(WebsocketError::Other("UnexpectedRawFrame".to_string()))
                        }
                    };
                    match WireMessage::try_from_bytes(msg)? {
                        WireMessage::Authenticate { data } => {
                            Ok(Some(ReceiveMessage::Authenticate(data)))
                        }
                        WireMessage::Request { id, data } => {
                            let resp = WebsocketRespond {
                                id,
                                core: core_sync,
                            };
                            // An undecodable payload is surfaced as BadRequest,
                            // so the caller can still respond, instead of
                            // erroring (and closing the connection).
                            let data: D =
                                match SerializedBytes::from(UnsafeBytes::from(data)).try_into() {
                                    Ok(value) => value,
                                    Err(_) => {
                                        return Ok(Some(ReceiveMessage::BadRequest(resp)));
                                    }
                                };
                            tracing::trace!(?data, %id, "InRequest");
                            Ok(Some(ReceiveMessage::Request(data, resp)))
                        }
                        WireMessage::Response { id, data } => {
                            // Route to the requester registered under `id`, if any.
                            // An unknown id is ignored. If `data` is `None`, the
                            // removed sender is dropped unsent and the requester's
                            // receiver sees the channel closed.
                            if let Some(sender) = core.rmap.remove(id) {
                                if let Some(data) = data {
                                    let data = SerializedBytes::from(UnsafeBytes::from(data));
                                    tracing::trace!(%id, ?data, "InResponse");
                                    let _ = sender.send(Ok(data));
                                }
                            }
                            Ok(None)
                        }
                        WireMessage::Signal { data } => Ok(Some(ReceiveMessage::Signal(data))),
                    }
                })
                .await?
            {
                return Ok(result);
            }
        }
    }
}
532
533/// Send requests and signals to the remote end of this websocket connection.
534/// Note, this receiver side must be polled (recv()) for responses to requests
535/// made on this sender to be received.
536#[derive(Clone)]
537pub struct WebsocketSender(WsCoreSync, std::time::Duration);
538
539impl WebsocketSender {
540    /// Authenticate with the remote using the default configured timeout.
541    pub async fn authenticate<S>(&self, s: S) -> WebsocketResult<()>
542    where
543        S: std::fmt::Debug,
544        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
545    {
546        self.authenticate_timeout(s, self.1).await
547    }
548
549    /// Authenticate with the remote.
550    pub async fn authenticate_timeout<S>(
551        &self,
552        s: S,
553        timeout: std::time::Duration,
554    ) -> WebsocketResult<()>
555    where
556        S: std::fmt::Debug,
557        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
558    {
559        use futures::sink::SinkExt;
560        self.0
561            .exec(move |_, core| async move {
562                tokio::time::timeout(timeout, async {
563                    let s = WireMessage::authenticate(s)?;
564                    core.send.lock().await.send(s).await.map_err(Box::new)?;
565                    Ok(())
566                })
567                .await?
568            })
569            .await
570    }
571
572    /// Make a request of the remote using the default configured timeout.
573    /// Note, this receiver side must be polled (recv()) for responses to
574    /// requests made on this sender to be received.
575    pub async fn request<S, R>(&self, s: S) -> WebsocketResult<R>
576    where
577        S: std::fmt::Debug,
578        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
579        R: serde::de::DeserializeOwned + std::fmt::Debug,
580    {
581        self.request_timeout(s, self.1).await
582    }
583
584    /// Make a request of the remote.
585    pub async fn request_timeout<S, R>(
586        &self,
587        s: S,
588        timeout: std::time::Duration,
589    ) -> WebsocketResult<R>
590    where
591        S: std::fmt::Debug,
592        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
593        R: serde::de::DeserializeOwned + std::fmt::Debug,
594    {
595        let timeout_at = tokio::time::Instant::now() + timeout;
596
597        use futures::sink::SinkExt;
598
599        let (s, id) = WireMessage::request(s)?;
600
601        /// Drop helper to remove our response callback if we timeout.
602        struct D(RMap, u64);
603
604        impl Drop for D {
605            fn drop(&mut self) {
606                self.0.remove(self.1);
607            }
608        }
609
610        let (resp_s, resp_r) = tokio::sync::oneshot::channel();
611
612        let _drop = self
613            .0
614            .exec(move |_, core| async move {
615                // create the drop helper
616                let drop = D(core.rmap.clone(), id);
617
618                // register the response callback
619                core.rmap.insert(id, resp_s);
620
621                tokio::time::timeout_at(timeout_at, async move {
622                    // send the actual message
623                    core.send.lock().await.send(s).await.map_err(Box::new)?;
624
625                    Ok(drop)
626                })
627                .await?
628            })
629            .await?;
630
631        // do the remainder outside the 'exec' because we don't actually
632        // want to close the connection down if an individual response is
633        // not returned... that is separate from the connection no longer
634        // being viable. (but we still want it to timeout at the same point)
635        tokio::time::timeout_at(timeout_at, async {
636            // await the response
637            let resp = resp_r
638                .await
639                .map_err(|_| WebsocketError::Other("ResponderDropped".to_string()))??;
640
641            // decode the response
642            let res = decode(&Vec::from(UnsafeBytes::from(resp)))?;
643            tracing::trace!(?res, %id, "OutRequestResponse");
644            Ok(res)
645        })
646        .await?
647    }
648
649    /// Send a signal to the remote using the default configured timeout.
650    pub async fn signal<S>(&self, s: S) -> WebsocketResult<()>
651    where
652        S: std::fmt::Debug,
653        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
654    {
655        self.signal_timeout(s, self.1).await
656    }
657
658    /// Send a signal to the remote.
659    pub async fn signal_timeout<S>(&self, s: S, timeout: std::time::Duration) -> WebsocketResult<()>
660    where
661        S: std::fmt::Debug,
662        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
663    {
664        use futures::sink::SinkExt;
665        self.0
666            .exec(move |_, core| async move {
667                tokio::time::timeout(timeout, async {
668                    let s = WireMessage::signal(s)?;
669                    core.send.lock().await.send(s).await.map_err(Box::new)?;
670                    Ok(())
671                })
672                .await?
673            })
674            .await
675    }
676}
677
678fn split(
679    stream: WsStream,
680    timeout: std::time::Duration,
681    peer_addr: std::net::SocketAddr,
682) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
683    let (sink, stream) = futures::stream::StreamExt::split(stream);
684
685    // Q: Why do we split the parts only to seemingly put them back together?
686    // A: They are in separate tokio mutexes, so we can still receive
687    //    and send at the same time in separate tasks, but being in the same
688    //    WsCore(Sync) lets us close them both at the same time if either
689    //    one errors.
690    let core = WsCore {
691        send: Arc::new(tokio::sync::Mutex::new(sink)),
692        recv: Arc::new(tokio::sync::Mutex::new(stream)),
693        rmap: RMap::default(),
694        timeout,
695    };
696
697    let core_send = WsCoreSync(Arc::new(std::sync::Mutex::new(Some(core))));
698    let core_recv = core_send.clone();
699
700    Ok((
701        WebsocketSender(core_send, timeout),
702        WebsocketReceiver::new(core_recv, peer_addr),
703    ))
704}
705
706/// Establish a new outgoing websocket connection to remote.
707pub async fn connect(
708    config: Arc<WebsocketConfig>,
709    request: impl Into<ConnectRequest>,
710) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
711    let request = request.into();
712    let stream = tokio::net::TcpStream::connect(request.addr).await?;
713    let peer_addr = stream.peer_addr()?;
714    let (stream, _addr) = tokio_tungstenite::client_async_with_config(
715        request.into_client_request()?,
716        stream,
717        Some(config.as_tungstenite()),
718    )
719    .await
720    .map_err(Box::new)?;
721    split(stream, config.default_request_timeout, peer_addr)
722}
723
724/// A request to connect to a websocket server.
725pub struct ConnectRequest {
726    addr: std::net::SocketAddr,
727    headers: HeaderMap<HeaderValue>,
728}
729
730impl From<std::net::SocketAddr> for ConnectRequest {
731    fn from(addr: std::net::SocketAddr) -> Self {
732        Self::new(addr)
733    }
734}
735
736impl ConnectRequest {
737    /// Create a new [ConnectRequest].
738    pub fn new(addr: std::net::SocketAddr) -> Self {
739        let mut cr = ConnectRequest {
740            addr,
741            headers: HeaderMap::new(),
742        };
743
744        // Set a default Origin so that the connection request will be allowed by default when the listener is
745        // using `Any` as the allowed origin.
746        cr.headers.insert(
747            "Origin",
748            HeaderValue::from_str("holochain_websocket").expect("Invalid Origin value"),
749        );
750
751        cr
752    }
753
754    /// Try to set a header on this request.
755    ///
756    /// Errors if the value is invalid. See [HeaderValue::from_str].
757    pub fn try_set_header(mut self, name: &'static str, value: &str) -> Result<Self> {
758        self.headers
759            .insert(name, HeaderValue::from_str(value).map_err(Error::other)?);
760        Ok(self)
761    }
762
763    fn into_client_request(
764        self,
765    ) -> Result<impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin> {
766        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
767        let mut req =
768            String::into_client_request(format!("ws://{}", self.addr)).map_err(Error::other)?;
769        for (name, value) in self.headers {
770            if let Some(name) = name {
771                req.headers_mut().insert(name, value);
772            } else {
773                tracing::warn!("Dropping invalid header");
774            }
775        }
776        Ok(req)
777    }
778
779    #[cfg(test)]
780    pub(crate) fn clear_headers(mut self) -> Self {
781        self.headers.clear();
782
783        self
784    }
785}
786
787// TODO async_trait still needed for dynamic dispatch https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#dynamic-dispatch
788#[async_trait::async_trait]
789trait TcpListener: Send + Sync {
790    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)>;
791
792    fn local_addrs(&self) -> Result<Vec<SocketAddr>>;
793}
794
795#[async_trait::async_trait]
796impl TcpListener for tokio::net::TcpListener {
797    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
798        self.accept().await
799    }
800
801    fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
802        Ok(vec![self.local_addr()?])
803    }
804}
805
806struct DualStackListener {
807    v4: tokio::net::TcpListener,
808    v6: tokio::net::TcpListener,
809}
810
811#[async_trait::async_trait]
812impl TcpListener for DualStackListener {
813    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
814        let (stream, addr) = select! {
815            res = self.v4.accept() => res?,
816            res = self.v6.accept() => res?,
817        };
818        Ok((stream, addr))
819    }
820
821    fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
822        Ok(vec![self.v4.local_addr()?, self.v6.local_addr()?])
823    }
824}
825
826/// A Holochain websocket listener.
827pub struct WebsocketListener {
828    config: Arc<WebsocketConfig>,
829    access_control: Arc<AllowedOrigins>,
830    listener: Box<dyn TcpListener>,
831}
832
833impl Drop for WebsocketListener {
834    fn drop(&mut self) {
835        tracing::info!("WebsocketListenerDrop");
836    }
837}
838
839impl WebsocketListener {
840    /// Bind a new websocket listener.
841    pub async fn bind(config: Arc<WebsocketConfig>, addr: impl ToSocketAddrs) -> Result<Self> {
842        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
843            Error::other("WebsocketListener requires allowed_origins to be set in the config")
844        })?);
845
846        let listener = tokio::net::TcpListener::bind(addr).await?;
847
848        let addr = listener.local_addr()?;
849        tracing::info!(?addr, "WebsocketListener Listening");
850
851        Ok(Self {
852            config,
853            access_control,
854            listener: Box::new(listener),
855        })
856    }
857
858    /// Bind a new websocket listener on the same port using a v4 and a v6 socket.
859    ///
860    /// If the port is 0, then the OS will be allowed to pick a port for IPv6. This function will
861    /// then try to bind to the same port for IPv4. If the OS picks a port that is not available for
862    /// IPv4, then the function will retry binding IPv6 to get a new port and see if that is
863    /// available for IPv4. If this fails after 5 retries, then an error will be returned.
864    ///
865    /// If either IPv4 or IPv6 is disabled, then the function will fall back to binding to the
866    /// available stack. An info message will be logged to let the user know that one interface was
867    /// unavailable, but this is likely intentional or expected in the user's environment, so it will
868    /// not be treated as an error that should prevent the listener from starting.
869    ///
870    /// Note: The interface fallback behaviour can be tested manually on Linux by running:
871    /// `echo 1 | sudo tee /proc/sys/net/ipv6/conf/lo/disable_ipv6`
872    /// and then trying to start Holochain with info logging enabled. You can undo the change with:
873    /// `echo 0 | sudo tee /proc/sys/net/ipv6/conf/lo/disable_ipv6`.
874    pub async fn dual_bind(
875        config: Arc<WebsocketConfig>,
876        addr_v4: SocketAddrV4,
877        addr_v6: SocketAddrV6,
878    ) -> Result<Self> {
879        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
880            Error::other("WebsocketListener requires allowed_origins to be set in the config")
881        })?);
882
883        let addr_v6: SocketAddr = addr_v6.into();
884        let mut addr_v4: SocketAddr = addr_v4.into();
885
886        // The point of dual_bind is to bind to the same port on both v4 and v6
887        if addr_v6.port() != 0 && addr_v6.port() != addr_v4.port() {
888            return Err(Error::other(
889                "dual_bind requires the same port for IPv4 and IPv6",
890            ));
891        }
892
893        // Note that tokio binds to the stack matching the address type, so we can re-use the port
894        // without needing to create the sockets ourselves to configure this.
895
896        let mut listener: Option<DualStackListener> = None;
897        for _ in 0..5 {
898            let v6_listener = match tokio::net::TcpListener::bind(addr_v6).await {
899                Ok(l) => l,
900                // This is the error code that *should* be returned if IPv6 is disabled
901                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
902                    tracing::info!(?e, "Failed to bind IPv6 listener because IPv6 appears to be disabled, falling back to IPv4 only");
903                    return Self::bind(config, addr_v4).await;
904                }
905                Err(e) => {
906                    tracing::error!("Failed to bind IPv6 listener: {:?}", e);
907                    return Err(e);
908                }
909            };
910
911            addr_v4.set_port(v6_listener.local_addr()?.port());
912
913            let v4_listener = match tokio::net::TcpListener::bind(addr_v4).await {
914                Ok(l) => l,
915                // This is the error code that *should* be returned if IPv4 is disabled
916                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
917                    tracing::info!(?e, "Failed to bind IPv4 listener because IPv4 appears to be disabled, falling back to IPv6 only");
918                    // No need to re-bind the v6 listener, it's already bound. Just create a new Self
919                    // from the v6 listener and return it.
920                    return Ok(Self {
921                        config,
922                        access_control,
923                        listener: Box::new(v6_listener),
924                    });
925                }
926                // This is expected if `[::]` is bound for IPv6 and this OS automatically handles receiving IPv4
927                // connections on the IPv6 socket. In this case we just use the IPv6 socket and ignore IPv4.
928                Err(e) if addr_v6.ip().is_unspecified() && e.kind() == ErrorKind::AddrInUse => {
929                    tracing::info!(?e, "Failed to bind IPv4 listener because the address is already in use, falling back to IPv6 only");
930                    // No need to re-bind the v6 listener, it's already bound. Just create a new Self
931                    // from the v6 listener and return it.
932                    return Ok(Self {
933                        config,
934                        access_control,
935                        listener: Box::new(v6_listener),
936                    });
937                }
938                // If the port for IPv6 was selected by the OS but it isn't available for IPv4, retry and let the OS pick a new port for IPv6
939                // and hopefully it will be available for IPv4.
940                Err(e) if addr_v6.port() == 0 && e.kind() == ErrorKind::AddrInUse => {
941                    tracing::warn!(?e, "Failed to bind the same port for IPv4 that was selected for IPv6, retrying with a new port");
942                    continue;
943                }
944                Err(e) => {
945                    tracing::error!("Failed to bind IPv4 listener: {:?}", e);
946                    return Err(e);
947                }
948            };
949
950            listener = Some(DualStackListener {
951                v4: v4_listener,
952                v6: v6_listener,
953            });
954            break;
955        }
956
957        // Gave up after a few retries, there's no point in continuing forever because there might be
958        // something wrong that the logic above isn't accounting for.
959        let listener = listener.ok_or_else(|| {
960            Error::other("Failed to bind listener to IPv4 and IPv6 interfaces after 5 retries")
961        })?;
962
963        let addr = listener.v4.local_addr()?;
964        tracing::info!(?addr, "WebsocketListener listening");
965
966        let addr = listener.v6.local_addr()?;
967        tracing::info!(?addr, "WebsocketListener listening");
968
969        Ok(Self {
970            config,
971            access_control,
972            listener: Box::new(listener),
973        })
974    }
975
976    /// Get the bound local address of this listener.
977    pub fn local_addrs(&self) -> Result<Vec<std::net::SocketAddr>> {
978        self.listener.local_addrs()
979    }
980
981    /// Accept an incoming connection.
982    pub async fn accept(&self) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
983        let (stream, addr) = self.listener.accept().await?;
984        tracing::debug!(?addr, "Accept Incoming Websocket Connection");
985        let stream = tokio_tungstenite::accept_hdr_async_with_config(
986            stream,
987            ConnectCallback {
988                allowed_origin: self.access_control.clone(),
989            },
990            Some(self.config.as_tungstenite()),
991        )
992        .await
993        .map_err(Error::other)?;
994        split(stream, self.config.default_request_timeout, addr)
995    }
996}
997
/// Handshake callback that checks the `Origin` header of each incoming websocket upgrade
/// request against the listener's allowed origins.
struct ConnectCallback {
    /// Origins this listener accepts connections from (shared with the listener).
    allowed_origin: Arc<AllowedOrigins>,
}
1001
1002impl Callback for ConnectCallback {
1003    fn on_request(
1004        self,
1005        request: &Request,
1006        response: Response,
1007    ) -> std::result::Result<Response, ErrorResponse> {
1008        tracing::trace!(
1009            "Checking incoming websocket connection request with allowed origin {:?}: {:?}",
1010            self.allowed_origin,
1011            request.headers()
1012        );
1013        match request
1014            .headers()
1015            .get("Origin")
1016            .and_then(|v| v.to_str().ok())
1017        {
1018            Some(origin) => {
1019                if self.allowed_origin.is_allowed(origin) {
1020                    Ok(response)
1021                } else {
1022                    tracing::warn!("Rejecting websocket connection request with disallowed `Origin` header: {:?}", request);
1023                    let allowed_origin: String = self.allowed_origin.as_ref().clone().into();
1024                    match HeaderValue::from_str(&allowed_origin) {
1025                        Ok(allowed_origin) => {
1026                            let mut err_response = ErrorResponse::new(None);
1027                            *err_response.status_mut() = StatusCode::BAD_REQUEST;
1028                            err_response
1029                                .headers_mut()
1030                                .insert("Access-Control-Allow-Origin", allowed_origin);
1031                            Err(err_response)
1032                        }
1033                        Err(_) => {
1034                            // Shouldn't be possible to get here, the listener should be configured to require an origin
1035                            let mut err_response = ErrorResponse::new(Some(
1036                                "Invalid listener configuration for `Origin`".to_string(),
1037                            ));
1038                            *err_response.status_mut() = StatusCode::BAD_REQUEST;
1039                            Err(err_response)
1040                        }
1041                    }
1042                }
1043            }
1044            None => {
1045                tracing::warn!(
1046                    "Rejecting websocket connection request with missing `Origin` header: {:?}",
1047                    request
1048                );
1049                let mut err_response =
1050                    ErrorResponse::new(Some("Missing `Origin` header".to_string()));
1051                *err_response.status_mut() = StatusCode::BAD_REQUEST;
1052                Err(err_response)
1053            }
1054        }
1055    }
1056}
1057
1058#[cfg(test)]
1059mod test;