holochain_websocket/
lib.rs

1#![deny(missing_docs)]
//! Holochain websocket support library.
//! This is currently a thin wrapper around tokio-tungstenite that
//! provides RPC-style requests and responses, correlated by u64 message ids.
5
6use holochain_serialized_bytes::prelude::*;
7use holochain_types::websocket::AllowedOrigins;
8use std::io::ErrorKind;
9pub use std::io::{Error, Result};
10use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
11use std::sync::Arc;
12use tokio::net::ToSocketAddrs;
13use tokio::select;
14use tokio_tungstenite::tungstenite::handshake::client::Request;
15use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Response};
16use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue, StatusCode};
17use tokio_tungstenite::tungstenite::protocol::Message;
18
19#[derive(Debug, serde::Serialize, serde::Deserialize, SerializedBytes)]
20#[serde(rename_all = "snake_case", tag = "type")]
21/// The messages actually sent over the wire by this library.
22/// If you want to implement your own server or client you
23/// will need this type or be able to serialize / deserialize it.
24pub enum WireMessage {
25    /// A message without a response.
26    Signal {
27        #[serde(with = "serde_bytes")]
28        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
29        data: Vec<u8>,
30    },
31
32    /// An authentication message, sent by the client if the server requires it.
33    Authenticate {
34        #[serde(with = "serde_bytes")]
35        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
36        data: Vec<u8>,
37    },
38
39    /// A request that requires a response.
40    Request {
41        /// The id of this request.
42        id: u64,
43        #[serde(with = "serde_bytes")]
44        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
45        data: Vec<u8>,
46    },
47
48    /// The response to a request.
49    Response {
50        /// The id of the request that this response is for.
51        id: u64,
52        #[serde(with = "serde_bytes")]
53        /// Actual bytes of the message serialized as [message pack](https://msgpack.org/).
54        data: Option<Vec<u8>>,
55    },
56}
57
58impl WireMessage {
59    /// Deserialize a WireMessage.
60    fn try_from_bytes(b: Vec<u8>) -> WebsocketResult<Self> {
61        let b = UnsafeBytes::from(b);
62        let b = SerializedBytes::from(b);
63        let b: WireMessage = b.try_into()?;
64        Ok(b)
65    }
66
67    /// Create a new authenticate message.
68    fn authenticate<S>(s: S) -> WebsocketResult<Message>
69    where
70        S: std::fmt::Debug,
71        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
72    {
73        let s1 = SerializedBytes::try_from(s)?;
74        let s2 = Self::Authenticate {
75            data: UnsafeBytes::from(s1).into(),
76        };
77        let s3: SerializedBytes = s2.try_into()?;
78        Ok(Message::Binary(UnsafeBytes::from(s3).into()))
79    }
80
81    /// Create a new request message (with new unique msg id).
82    fn request<S>(s: S) -> WebsocketResult<(Message, u64)>
83    where
84        S: std::fmt::Debug,
85        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
86    {
87        static ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
88        let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
89        tracing::trace!(?s, %id, "OutRequest");
90        let s1 = SerializedBytes::try_from(s)?;
91        let s2 = Self::Request {
92            id,
93            data: UnsafeBytes::from(s1).into(),
94        };
95        let s3: SerializedBytes = s2.try_into()?;
96        Ok((Message::Binary(UnsafeBytes::from(s3).into()), id))
97    }
98
99    /// Create a new response message.
100    fn response<S>(id: u64, s: S) -> WebsocketResult<Message>
101    where
102        S: std::fmt::Debug,
103        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
104    {
105        let s1 = SerializedBytes::try_from(s)?;
106        let s2 = Self::Response {
107            id,
108            data: Some(UnsafeBytes::from(s1).into()),
109        };
110        let s3: SerializedBytes = s2.try_into()?;
111        Ok(Message::Binary(UnsafeBytes::from(s3).into()))
112    }
113
114    /// Create a new signal message.
115    fn signal<S>(s: S) -> WebsocketResult<Message>
116    where
117        S: std::fmt::Debug,
118        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
119    {
120        tracing::trace!(?s, "SendSignal");
121        let s1 = SerializedBytes::try_from(s)?;
122        let s2 = Self::Signal {
123            data: UnsafeBytes::from(s1).into(),
124        };
125        let s3: SerializedBytes = s2.try_into()?;
126        Ok(Message::Binary(UnsafeBytes::from(s3).into()))
127    }
128}
129
130/// Websocket configuration struct.
131#[derive(Clone, Debug)]
132pub struct WebsocketConfig {
133    /// Seconds after which the lib will stop tracking individual request ids.
134    /// [default = 60 seconds]
135    pub default_request_timeout: std::time::Duration,
136
137    /// Maximum total message size of a websocket message. [default = 64M]
138    pub max_message_size: usize,
139
140    /// Maximum websocket frame size. [default = 16M]
141    pub max_frame_size: usize,
142
143    /// Allowed origins access control for a [WebsocketListener].
144    /// Not used by the [WebsocketSender].
145    pub allowed_origins: Option<AllowedOrigins>,
146}
147
148impl WebsocketConfig {
149    /// The default client WebsocketConfig.
150    pub const CLIENT_DEFAULT: WebsocketConfig = WebsocketConfig {
151        default_request_timeout: std::time::Duration::from_secs(60),
152        max_message_size: 64 << 20,
153        max_frame_size: 16 << 20,
154        allowed_origins: None,
155    };
156
157    /// The default listener WebsocketConfig.
158    pub const LISTENER_DEFAULT: WebsocketConfig = WebsocketConfig {
159        default_request_timeout: std::time::Duration::from_secs(60),
160        max_message_size: 64 << 20,
161        max_frame_size: 16 << 20,
162        allowed_origins: Some(AllowedOrigins::Any),
163    };
164
165    /// Internal convert to tungstenite config.
166    pub(crate) fn as_tungstenite(
167        &self,
168    ) -> tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
169        tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
170            max_message_size: Some(self.max_message_size),
171            max_frame_size: Some(self.max_frame_size),
172            ..Default::default()
173        }
174    }
175}
176
177struct RMapInner(
178    pub  std::collections::HashMap<
179        u64,
180        tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
181    >,
182);
183
184impl Drop for RMapInner {
185    fn drop(&mut self) {
186        self.close();
187    }
188}
189
190impl RMapInner {
191    fn close(&mut self) {
192        for (_, s) in self.0.drain() {
193            let _ = s.send(Err(WebsocketError::Close("ConnectionClosed".to_string())));
194        }
195    }
196}
197
198#[derive(Clone)]
199struct RMap(Arc<std::sync::Mutex<RMapInner>>);
200
201impl Default for RMap {
202    fn default() -> Self {
203        Self(Arc::new(std::sync::Mutex::new(RMapInner(
204            std::collections::HashMap::default(),
205        ))))
206    }
207}
208
209impl RMap {
210    pub fn close(&self) {
211        if let Ok(mut lock) = self.0.lock() {
212            lock.close();
213        }
214    }
215
216    pub fn insert(
217        &self,
218        id: u64,
219        sender: tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
220    ) {
221        self.0.lock().unwrap().0.insert(id, sender);
222    }
223
224    pub fn remove(
225        &self,
226        id: u64,
227    ) -> Option<tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>> {
228        self.0.lock().unwrap().0.remove(&id)
229    }
230}
231
232/// An error produced when working with websockets.
233///
234/// It is intended to capture all the errors that a caller might want to handle. Other errors that
235/// are unlikely to be recoverable are mapped to [WebsocketError::Other].
236#[derive(thiserror::Error, Debug)]
237pub enum WebsocketError {
238    /// The websocket has been closed by the other side.
239    #[error("Websocket closed: {0}")]
240    Close(String),
241    /// A received messaged did not deserialize to the expected type.
242    #[error("Received a message that did not deserialize: {0}")]
243    Deserialize(#[from] SerializedBytesError),
244    /// A websocket error from the underlying tungstenite library.
245    #[error("Websocket error: {0}")]
246    Websocket(#[from] tokio_tungstenite::tungstenite::Error),
247    /// A timeout occurred.
248    #[error("Timeout")]
249    Timeout(#[from] tokio::time::error::Elapsed),
250    /// An IO error occurred.
251    #[error("IO error: {0}")]
252    Io(#[from] Error),
253    /// Some other error occurred.
254    #[error("Other error: {0}")]
255    Other(String),
256}
257
258/// A result type, with the error type [WebsocketError].
259pub type WebsocketResult<T> = std::result::Result<T, WebsocketError>;
260
261type WsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
262type WsSend =
263    futures::stream::SplitSink<WsStream, tokio_tungstenite::tungstenite::protocol::Message>;
264type WsSendSync = Arc<tokio::sync::Mutex<WsSend>>;
265type WsRecv = futures::stream::SplitStream<WsStream>;
266type WsRecvSync = Arc<tokio::sync::Mutex<WsRecv>>;
267
268#[derive(Clone)]
269struct WsCore {
270    pub send: WsSendSync,
271    pub recv: WsRecvSync,
272    pub rmap: RMap,
273    pub timeout: std::time::Duration,
274}
275
276#[derive(Clone)]
277struct WsCoreSync(Arc<std::sync::Mutex<Option<WsCore>>>);
278
279impl PartialEq for WsCoreSync {
280    fn eq(&self, other: &Self) -> bool {
281        Arc::ptr_eq(&self.0, &other.0)
282    }
283}
284
285impl WsCoreSync {
286    fn close(&self) {
287        if let Some(core) = self.0.lock().unwrap().take() {
288            core.rmap.close();
289            tokio::task::spawn(async move {
290                use futures::sink::SinkExt;
291                let _ = core.send.lock().await.close().await;
292            });
293        }
294    }
295
296    fn close_if_err<R>(&self, r: WebsocketResult<R>) -> WebsocketResult<R> {
297        match r {
298            Err(e @ WebsocketError::Deserialize { .. }) => {
299                // Don't close the connection on a deserialization error.
300                // That's a client issue and not a connection issue.
301                Err(e)
302            }
303            Err(err) => {
304                self.close();
305                Err(err)
306            }
307            Ok(res) => Ok(res),
308        }
309    }
310
311    pub async fn exec<F, C, R>(&self, c: C) -> WebsocketResult<R>
312    where
313        F: std::future::Future<Output = WebsocketResult<R>>,
314        C: FnOnce(WsCoreSync, WsCore) -> F,
315    {
316        let core = match self.0.lock().unwrap().as_ref() {
317            Some(core) => core.clone(),
318            None => return Err(WebsocketError::Close("No connection".to_string())),
319        };
320        self.close_if_err(c(self.clone(), core).await)
321    }
322}
323
324/// Respond to an incoming request.
325#[derive(PartialEq)]
326pub struct WebsocketRespond {
327    id: u64,
328    core: WsCoreSync,
329}
330
331impl std::fmt::Debug for WebsocketRespond {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        f.debug_struct("WebsocketRespond")
334            .field("id", &self.id)
335            .finish()
336    }
337}
338
339impl WebsocketRespond {
340    /// Respond to an incoming request.
341    pub async fn respond<S>(self, s: S) -> WebsocketResult<()>
342    where
343        S: std::fmt::Debug,
344        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
345    {
346        tracing::trace!(?s, %self.id, "OutResponse");
347        use futures::sink::SinkExt;
348        self.core
349            .exec(move |_, core| async move {
350                tokio::time::timeout(core.timeout, async {
351                    let s = WireMessage::response(self.id, s)?;
352                    core.send.lock().await.send(s).await?;
353                    Ok(())
354                })
355                .await?
356            })
357            .await
358    }
359}
360
361/// Types of messages that can be received by a WebsocketReceiver.
362#[derive(Debug, PartialEq)]
363pub enum ReceiveMessage<D>
364where
365    D: std::fmt::Debug,
366    SerializedBytes: TryInto<D, Error = SerializedBytesError>,
367{
368    /// Received a request to authenticate from the client.
369    Authenticate(Vec<u8>),
370
371    /// Received a signal from the remote.
372    Signal(Vec<u8>),
373
374    /// Received a request from the remote.
375    Request(D, WebsocketRespond),
376}
377
378/// Receive signals and requests from a websocket connection.
379/// Note, This receiver must be polled (recv()) for responses to requests
380/// made on the Sender side to be received.
381/// If this receiver is dropped, the sender side will also be closed.
382pub struct WebsocketReceiver(
383    WsCoreSync,
384    std::net::SocketAddr,
385    tokio::task::JoinHandle<()>,
386);
387
388impl Drop for WebsocketReceiver {
389    fn drop(&mut self) {
390        self.0.close();
391        self.2.abort();
392    }
393}
394
395impl WebsocketReceiver {
396    fn new(core: WsCoreSync, addr: std::net::SocketAddr) -> Self {
397        let core2 = core.clone();
398        let ping_task = tokio::task::spawn(async move {
399            loop {
400                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
401                let core = core2.0.lock().unwrap().as_ref().cloned();
402                if let Some(core) = core {
403                    use futures::sink::SinkExt;
404                    if core
405                        .send
406                        .lock()
407                        .await
408                        .send(Message::Ping(Vec::new()))
409                        .await
410                        .is_err()
411                    {
412                        core2.close();
413                    }
414                } else {
415                    break;
416                }
417            }
418        });
419        Self(core, addr, ping_task)
420    }
421
422    /// Peer address.
423    pub fn peer_addr(&self) -> std::net::SocketAddr {
424        self.1
425    }
426
427    /// Receive the next message.
428    pub async fn recv<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
429    where
430        D: std::fmt::Debug,
431        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
432    {
433        match self.recv_inner().await {
434            Err(err) => {
435                tracing::warn!(?err, "WebsocketReceiver Error");
436                Err(err)
437            }
438            Ok(msg) => Ok(msg),
439        }
440    }
441
442    async fn recv_inner<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
443    where
444        D: std::fmt::Debug,
445        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
446    {
447        use futures::sink::SinkExt;
448        use futures::stream::StreamExt;
449        loop {
450            if let Some(result) = self
451                .0
452                .exec(move |core_sync, core| async move {
453                    let msg = core
454                        .recv
455                        .lock()
456                        .await
457                        .next()
458                        .await
459                        .ok_or::<WebsocketError>(WebsocketError::Other(
460                            "ReceiverClosed".to_string(),
461                        ))??;
462                    let msg = match msg {
463                        Message::Text(s) => s.into_bytes(),
464                        Message::Binary(b) => b,
465                        Message::Ping(b) => {
466                            core.send.lock().await.send(Message::Pong(b)).await?;
467                            return Ok(None);
468                        }
469                        Message::Pong(_) => return Ok(None),
470                        Message::Close(frame) => {
471                            return Err(WebsocketError::Close(format!("{frame:?}")));
472                        }
473                        Message::Frame(_) => {
474                            return Err(WebsocketError::Other("UnexpectedRawFrame".to_string()))
475                        }
476                    };
477                    match WireMessage::try_from_bytes(msg)? {
478                        WireMessage::Authenticate { data } => {
479                            Ok(Some(ReceiveMessage::Authenticate(data)))
480                        }
481                        WireMessage::Request { id, data } => {
482                            let resp = WebsocketRespond {
483                                id,
484                                core: core_sync,
485                            };
486                            let data: D =
487                                SerializedBytes::from(UnsafeBytes::from(data)).try_into()?;
488                            tracing::trace!(?data, %id, "InRequest");
489                            Ok(Some(ReceiveMessage::Request(data, resp)))
490                        }
491                        WireMessage::Response { id, data } => {
492                            if let Some(sender) = core.rmap.remove(id) {
493                                if let Some(data) = data {
494                                    let data = SerializedBytes::from(UnsafeBytes::from(data));
495                                    tracing::trace!(%id, ?data, "InResponse");
496                                    let _ = sender.send(Ok(data));
497                                }
498                            }
499                            Ok(None)
500                        }
501                        WireMessage::Signal { data } => Ok(Some(ReceiveMessage::Signal(data))),
502                    }
503                })
504                .await?
505            {
506                return Ok(result);
507            }
508        }
509    }
510}
511
512/// Send requests and signals to the remote end of this websocket connection.
513/// Note, this receiver side must be polled (recv()) for responses to requests
514/// made on this sender to be received.
515#[derive(Clone)]
516pub struct WebsocketSender(WsCoreSync, std::time::Duration);
517
518impl WebsocketSender {
519    /// Authenticate with the remote using the default configured timeout.
520    pub async fn authenticate<S>(&self, s: S) -> WebsocketResult<()>
521    where
522        S: std::fmt::Debug,
523        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
524    {
525        self.authenticate_timeout(s, self.1).await
526    }
527
528    /// Authenticate with the remote.
529    pub async fn authenticate_timeout<S>(
530        &self,
531        s: S,
532        timeout: std::time::Duration,
533    ) -> WebsocketResult<()>
534    where
535        S: std::fmt::Debug,
536        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
537    {
538        use futures::sink::SinkExt;
539        self.0
540            .exec(move |_, core| async move {
541                tokio::time::timeout(timeout, async {
542                    let s = WireMessage::authenticate(s)?;
543                    core.send.lock().await.send(s).await?;
544                    Ok(())
545                })
546                .await?
547            })
548            .await
549    }
550
551    /// Make a request of the remote using the default configured timeout.
552    /// Note, this receiver side must be polled (recv()) for responses to
553    /// requests made on this sender to be received.
554    pub async fn request<S, R>(&self, s: S) -> WebsocketResult<R>
555    where
556        S: std::fmt::Debug,
557        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
558        R: serde::de::DeserializeOwned + std::fmt::Debug,
559    {
560        self.request_timeout(s, self.1).await
561    }
562
563    /// Make a request of the remote.
564    pub async fn request_timeout<S, R>(
565        &self,
566        s: S,
567        timeout: std::time::Duration,
568    ) -> WebsocketResult<R>
569    where
570        S: std::fmt::Debug,
571        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
572        R: serde::de::DeserializeOwned + std::fmt::Debug,
573    {
574        let timeout_at = tokio::time::Instant::now() + timeout;
575
576        use futures::sink::SinkExt;
577
578        let (s, id) = WireMessage::request(s)?;
579
580        /// Drop helper to remove our response callback if we timeout.
581        struct D(RMap, u64);
582
583        impl Drop for D {
584            fn drop(&mut self) {
585                self.0.remove(self.1);
586            }
587        }
588
589        let (resp_s, resp_r) = tokio::sync::oneshot::channel();
590
591        let _drop = self
592            .0
593            .exec(move |_, core| async move {
594                // create the drop helper
595                let drop = D(core.rmap.clone(), id);
596
597                // register the response callback
598                core.rmap.insert(id, resp_s);
599
600                tokio::time::timeout_at(timeout_at, async move {
601                    // send the actual message
602                    core.send.lock().await.send(s).await?;
603
604                    Ok(drop)
605                })
606                .await?
607            })
608            .await?;
609
610        // do the remainder outside the 'exec' because we don't actually
611        // want to close the connection down if an individual response is
612        // not returned... that is separate from the connection no longer
613        // being viable. (but we still want it to timeout at the same point)
614        tokio::time::timeout_at(timeout_at, async {
615            // await the response
616            let resp = resp_r
617                .await
618                .map_err(|_| WebsocketError::Other("ResponderDropped".to_string()))??;
619
620            // decode the response
621            let res = decode(&Vec::from(UnsafeBytes::from(resp)))?;
622            tracing::trace!(?res, %id, "OutRequestResponse");
623            Ok(res)
624        })
625        .await?
626    }
627
628    /// Send a signal to the remote using the default configured timeout.
629    pub async fn signal<S>(&self, s: S) -> WebsocketResult<()>
630    where
631        S: std::fmt::Debug,
632        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
633    {
634        self.signal_timeout(s, self.1).await
635    }
636
637    /// Send a signal to the remote.
638    pub async fn signal_timeout<S>(&self, s: S, timeout: std::time::Duration) -> WebsocketResult<()>
639    where
640        S: std::fmt::Debug,
641        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
642    {
643        use futures::sink::SinkExt;
644        self.0
645            .exec(move |_, core| async move {
646                tokio::time::timeout(timeout, async {
647                    let s = WireMessage::signal(s)?;
648                    core.send.lock().await.send(s).await?;
649                    Ok(())
650                })
651                .await?
652            })
653            .await
654    }
655}
656
657fn split(
658    stream: WsStream,
659    timeout: std::time::Duration,
660    peer_addr: std::net::SocketAddr,
661) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
662    let (sink, stream) = futures::stream::StreamExt::split(stream);
663
664    // Q: Why do we split the parts only to seemingly put them back together?
665    // A: They are in separate tokio mutexes, so we can still receive
666    //    and send at the same time in separate tasks, but being in the same
667    //    WsCore(Sync) lets us close them both at the same time if either
668    //    one errors.
669    let core = WsCore {
670        send: Arc::new(tokio::sync::Mutex::new(sink)),
671        recv: Arc::new(tokio::sync::Mutex::new(stream)),
672        rmap: RMap::default(),
673        timeout,
674    };
675
676    let core_send = WsCoreSync(Arc::new(std::sync::Mutex::new(Some(core))));
677    let core_recv = core_send.clone();
678
679    Ok((
680        WebsocketSender(core_send, timeout),
681        WebsocketReceiver::new(core_recv, peer_addr),
682    ))
683}
684
685/// Establish a new outgoing websocket connection to remote.
686pub async fn connect(
687    config: Arc<WebsocketConfig>,
688    request: impl Into<ConnectRequest>,
689) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
690    let request = request.into();
691    let stream = tokio::net::TcpStream::connect(request.addr).await?;
692    let peer_addr = stream.peer_addr()?;
693    let (stream, _addr) = tokio_tungstenite::client_async_with_config(
694        request.into_client_request()?,
695        stream,
696        Some(config.as_tungstenite()),
697    )
698    .await?;
699    split(stream, config.default_request_timeout, peer_addr)
700}
701
702/// A request to connect to a websocket server.
703pub struct ConnectRequest {
704    addr: std::net::SocketAddr,
705    headers: HeaderMap<HeaderValue>,
706}
707
708impl From<std::net::SocketAddr> for ConnectRequest {
709    fn from(addr: std::net::SocketAddr) -> Self {
710        Self::new(addr)
711    }
712}
713
714impl ConnectRequest {
715    /// Create a new [ConnectRequest].
716    pub fn new(addr: std::net::SocketAddr) -> Self {
717        let mut cr = ConnectRequest {
718            addr,
719            headers: HeaderMap::new(),
720        };
721
722        // Set a default Origin so that the connection request will be allowed by default when the listener is
723        // using `Any` as the allowed origin.
724        cr.headers.insert(
725            "Origin",
726            HeaderValue::from_str("holochain_websocket").expect("Invalid Origin value"),
727        );
728
729        cr
730    }
731
732    /// Try to set a header on this request.
733    ///
734    /// Errors if the value is invalid. See [HeaderValue::from_str].
735    pub fn try_set_header(mut self, name: &'static str, value: &str) -> Result<Self> {
736        self.headers
737            .insert(name, HeaderValue::from_str(value).map_err(Error::other)?);
738        Ok(self)
739    }
740
741    fn into_client_request(
742        self,
743    ) -> Result<impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin> {
744        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
745        let mut req =
746            String::into_client_request(format!("ws://{}", self.addr)).map_err(Error::other)?;
747        for (name, value) in self.headers {
748            if let Some(name) = name {
749                req.headers_mut().insert(name, value);
750            } else {
751                tracing::warn!("Dropping invalid header");
752            }
753        }
754        Ok(req)
755    }
756
757    #[cfg(test)]
758    pub(crate) fn clear_headers(mut self) -> Self {
759        self.headers.clear();
760
761        self
762    }
763}
764
765// TODO async_trait still needed for dynamic dispatch https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#dynamic-dispatch
766#[async_trait::async_trait]
767trait TcpListener: Send + Sync {
768    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)>;
769
770    fn local_addrs(&self) -> Result<Vec<SocketAddr>>;
771}
772
773#[async_trait::async_trait]
774impl TcpListener for tokio::net::TcpListener {
775    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
776        self.accept().await
777    }
778
779    fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
780        Ok(vec![self.local_addr()?])
781    }
782}
783
784struct DualStackListener {
785    v4: tokio::net::TcpListener,
786    v6: tokio::net::TcpListener,
787}
788
789#[async_trait::async_trait]
790impl TcpListener for DualStackListener {
791    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
792        let (stream, addr) = select! {
793            res = self.v4.accept() => res?,
794            res = self.v6.accept() => res?,
795        };
796        Ok((stream, addr))
797    }
798
799    fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
800        Ok(vec![self.v4.local_addr()?, self.v6.local_addr()?])
801    }
802}
803
804/// A Holochain websocket listener.
805pub struct WebsocketListener {
806    config: Arc<WebsocketConfig>,
807    access_control: Arc<AllowedOrigins>,
808    listener: Box<dyn TcpListener>,
809}
810
811impl Drop for WebsocketListener {
812    fn drop(&mut self) {
813        tracing::info!("WebsocketListenerDrop");
814    }
815}
816
817impl WebsocketListener {
818    /// Bind a new websocket listener.
819    pub async fn bind(config: Arc<WebsocketConfig>, addr: impl ToSocketAddrs) -> Result<Self> {
820        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
821            Error::other("WebsocketListener requires allowed_origins to be set in the config")
822        })?);
823
824        let listener = tokio::net::TcpListener::bind(addr).await?;
825
826        let addr = listener.local_addr()?;
827        tracing::info!(?addr, "WebsocketListener Listening");
828
829        Ok(Self {
830            config,
831            access_control,
832            listener: Box::new(listener),
833        })
834    }
835
836    /// Bind a new websocket listener on the same port using a v4 and a v6 socket.
837    ///
838    /// If the port is 0, then the OS will be allowed to pick a port for IPv6. This function will
839    /// then try to bind to the same port for IPv4. If the OS picks a port that is not available for
840    /// IPv4, then the function will retry binding IPv6 to get a new port and see if that is
841    /// available for IPv4. If this fails after 5 retries, then an error will be returned.
842    ///
843    /// If either IPv4 or IPv6 is disabled, then the function will fall back to binding to the
844    /// available stack. An info message will be logged to let the user know that one interface was
845    /// unavailable, but this is likely intentional or expected in the user's environment, so it will
846    /// not be treated as an error that should prevent the listener from starting.
847    ///
848    /// Note: The interface fallback behaviour can be tested manually on Linux by running:
849    /// `echo 1 | sudo tee /proc/sys/net/ipv6/conf/lo/disable_ipv6`
850    /// and then trying to start Holochain with info logging enabled. You can undo the change with:
851    /// `echo 0 | sudo tee /proc/sys/net/ipv6/conf/lo/disable_ipv6`.
852    pub async fn dual_bind(
853        config: Arc<WebsocketConfig>,
854        addr_v4: SocketAddrV4,
855        addr_v6: SocketAddrV6,
856    ) -> Result<Self> {
857        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
858            Error::other("WebsocketListener requires allowed_origins to be set in the config")
859        })?);
860
861        let addr_v6: SocketAddr = addr_v6.into();
862        let mut addr_v4: SocketAddr = addr_v4.into();
863
864        // The point of dual_bind is to bind to the same port on both v4 and v6
865        if addr_v6.port() != 0 && addr_v6.port() != addr_v4.port() {
866            return Err(Error::other(
867                "dual_bind requires the same port for IPv4 and IPv6",
868            ));
869        }
870
871        // Note that tokio binds to the stack matching the address type, so we can re-use the port
872        // without needing to create the sockets ourselves to configure this.
873
874        let mut listener: Option<DualStackListener> = None;
875        for _ in 0..5 {
876            let v6_listener = match tokio::net::TcpListener::bind(addr_v6).await {
877                Ok(l) => l,
878                // This is the error code that *should* be returned if IPv6 is disabled
879                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
880                    tracing::info!(?e, "Failed to bind IPv6 listener because IPv6 appears to be disabled, falling back to IPv4 only");
881                    return Self::bind(config, addr_v4).await;
882                }
883                Err(e) => {
884                    return Err(e);
885                }
886            };
887
888            addr_v4.set_port(v6_listener.local_addr()?.port());
889
890            let v4_listener = match tokio::net::TcpListener::bind(addr_v4).await {
891                Ok(l) => l,
892                // This is the error code that *should* be returned if IPv4 is disabled
893                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
894                    tracing::info!(?e, "Failed to bind IPv4 listener because IPv4 appears to be disabled, falling back to IPv6 only");
895                    // No need to re-bind the v6 listener, it's already bound. Just create a new Self
896                    // from the v6 listener and return it.
897                    return Ok(Self {
898                        config,
899                        access_control,
900                        listener: Box::new(v6_listener),
901                    });
902                }
903                // If the port for IPv6 was selected by the OS but it isn't available for IPv4, retry and let the OS pick a new port for IPv6
904                // and hopefully it will be available for IPv4.
905                Err(e) if addr_v6.port() == 0 && e.kind() == ErrorKind::AddrInUse => {
906                    tracing::warn!(?e, "Failed to bind the same port for IPv4 that was selected for IPv6, retrying with a new port");
907                    continue;
908                }
909                Err(e) => {
910                    return Err(e);
911                }
912            };
913
914            listener = Some(DualStackListener {
915                v4: v4_listener,
916                v6: v6_listener,
917            });
918            break;
919        }
920
921        // Gave up after a few retries, there's no point in continuing forever because there might be
922        // something wrong that the logic above isn't accounting for.
923        let listener = listener.ok_or_else(|| {
924            Error::other("Failed to bind listener to IPv4 and IPv6 interfaces after 5 retries")
925        })?;
926
927        let addr = listener.v4.local_addr()?;
928        tracing::info!(?addr, "WebsocketListener listening");
929
930        let addr = listener.v6.local_addr()?;
931        tracing::info!(?addr, "WebsocketListener listening");
932
933        Ok(Self {
934            config,
935            access_control,
936            listener: Box::new(listener),
937        })
938    }
939
940    /// Get the bound local address of this listener.
941    pub fn local_addrs(&self) -> Result<Vec<std::net::SocketAddr>> {
942        self.listener.local_addrs()
943    }
944
945    /// Accept an incoming connection.
946    pub async fn accept(&self) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
947        let (stream, addr) = self.listener.accept().await?;
948        tracing::debug!(?addr, "Accept Incoming Websocket Connection");
949        let stream = tokio_tungstenite::accept_hdr_async_with_config(
950            stream,
951            ConnectCallback {
952                allowed_origin: self.access_control.clone(),
953            },
954            Some(self.config.as_tungstenite()),
955        )
956        .await
957        .map_err(Error::other)?;
958        split(stream, self.config.default_request_timeout, addr)
959    }
960}
961
962struct ConnectCallback {
963    allowed_origin: Arc<AllowedOrigins>,
964}
965
966impl Callback for ConnectCallback {
967    fn on_request(
968        self,
969        request: &Request,
970        response: Response,
971    ) -> std::result::Result<Response, ErrorResponse> {
972        tracing::trace!(
973            "Checking incoming websocket connection request with allowed origin {:?}: {:?}",
974            self.allowed_origin,
975            request.headers()
976        );
977        match request
978            .headers()
979            .get("Origin")
980            .and_then(|v| v.to_str().ok())
981        {
982            Some(origin) => {
983                if self.allowed_origin.is_allowed(origin) {
984                    Ok(response)
985                } else {
986                    tracing::warn!("Rejecting websocket connection request with disallowed `Origin` header: {:?}", request);
987                    let allowed_origin: String = self.allowed_origin.as_ref().clone().into();
988                    match HeaderValue::from_str(&allowed_origin) {
989                        Ok(allowed_origin) => {
990                            let mut err_response = ErrorResponse::new(None);
991                            *err_response.status_mut() = StatusCode::BAD_REQUEST;
992                            err_response
993                                .headers_mut()
994                                .insert("Access-Control-Allow-Origin", allowed_origin);
995                            Err(err_response)
996                        }
997                        Err(_) => {
998                            // Shouldn't be possible to get here, the listener should be configured to require an origin
999                            let mut err_response = ErrorResponse::new(Some(
1000                                "Invalid listener configuration for `Origin`".to_string(),
1001                            ));
1002                            *err_response.status_mut() = StatusCode::BAD_REQUEST;
1003                            Err(err_response)
1004                        }
1005                    }
1006                }
1007            }
1008            None => {
1009                tracing::warn!(
1010                    "Rejecting websocket connection request with missing `Origin` header: {:?}",
1011                    request
1012                );
1013                let mut err_response =
1014                    ErrorResponse::new(Some("Missing `Origin` header".to_string()));
1015                *err_response.status_mut() = StatusCode::BAD_REQUEST;
1016                Err(err_response)
1017            }
1018        }
1019    }
1020}
1021
1022#[cfg(test)]
1023mod test;