ant_libp2p_websocket_websys/
lib.rs

1// Copyright (C) 2023 Vince Vasta
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19// SOFTWARE.
20
21//! Libp2p websocket transports built on [web-sys](https://rustwasm.github.io/wasm-bindgen/web-sys/index.html).
22
23#![allow(unexpected_cfgs)]
24
25mod web_context;
26
27use ant_libp2p_core as libp2p_core;
28
29use std::{
30    cmp::min,
31    pin::Pin,
32    rc::Rc,
33    sync::{
34        atomic::{AtomicBool, Ordering},
35        Mutex,
36    },
37    task::{Context, Poll},
38};
39
40use bytes::BytesMut;
41use futures::{future::Ready, io, prelude::*, task::AtomicWaker};
42use js_sys::Array;
43use libp2p_core::{
44    multiaddr::{Multiaddr, Protocol},
45    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
46};
47use send_wrapper::SendWrapper;
48use wasm_bindgen::prelude::*;
49use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
50
51use crate::web_context::WebContext;
52
/// A Websocket transport that can be used in a wasm environment.
///
/// The transport is dial-only: [`libp2p_core::Transport::listen_on`] rejects every address.
/// The single private field prevents construction by struct literal outside this crate; obtain
/// an instance through [`Default`].
///
/// ## Example
///
/// To create an authenticated transport instance with Noise protocol and Yamux:
///
/// ```
/// # use libp2p_core::{upgrade::Version, Transport};
/// # use libp2p_identity::Keypair;
/// # use libp2p_yamux as yamux;
/// # use libp2p_noise as noise;
/// let local_key = Keypair::generate_ed25519();
/// let transport = libp2p_websocket_websys::Transport::default()
///     .upgrade(Version::V1)
///     .authenticate(noise::Config::new(&local_key).unwrap())
///     .multiplex(yamux::Config::default())
///     .boxed();
/// ```
#[derive(Debug, Default)]
pub struct Transport {
    _private: (),
}
75
/// Upper bound (1 MiB) on how many bytes we buffer in either direction before throttling the
/// user (writes) or flagging the connection as errored (reads). The value is arbitrary.
const MAX_BUFFER: usize = 1 << 20;
78
79impl libp2p_core::Transport for Transport {
80    type Output = Connection;
81    type Error = Error;
82    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
83    type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
84
85    fn listen_on(
86        &mut self,
87        _: ListenerId,
88        addr: Multiaddr,
89    ) -> Result<(), TransportError<Self::Error>> {
90        Err(TransportError::MultiaddrNotSupported(addr))
91    }
92
93    fn remove_listener(&mut self, _id: ListenerId) -> bool {
94        false
95    }
96
97    fn dial(
98        &mut self,
99        addr: Multiaddr,
100        dial_opts: DialOpts,
101    ) -> Result<Self::Dial, TransportError<Self::Error>> {
102        if dial_opts.role.is_listener() {
103            return Err(TransportError::MultiaddrNotSupported(addr));
104        }
105
106        let url =
107            extract_websocket_url(&addr).ok_or(TransportError::MultiaddrNotSupported(addr))?;
108
109        Ok(async move {
110            let socket = match WebSocket::new(&url) {
111                Ok(ws) => ws,
112                Err(_) => return Err(Error::invalid_websocket_url(&url)),
113            };
114
115            Ok(Connection::new(socket))
116        }
117        .boxed())
118    }
119
120    fn poll(
121        self: Pin<&mut Self>,
122        _cx: &mut Context<'_>,
123    ) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
124        Poll::Pending
125    }
126}
127
128// Try to convert Multiaddr to a Websocket url.
129fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
130    let mut protocols = addr.iter();
131    let host_port = match (protocols.next(), protocols.next()) {
132        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
133            format!("{ip}:{port}")
134        }
135        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
136            format!("[{ip}]:{port}")
137        }
138        (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
139        | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
140        | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
141            format!("{}:{}", &h, port)
142        }
143        _ => return None,
144    };
145
146    let (scheme, wspath) = match (protocols.next(), protocols.next()) {
147        (Some(Protocol::Tls), Some(Protocol::Ws(path))) => ("wss", path.into_owned()),
148        (Some(Protocol::Ws(path)), _) => ("ws", path.into_owned()),
149        (Some(Protocol::Wss(path)), _) => ("wss", path.into_owned()),
150        _ => return None,
151    };
152
153    Some(format!("{scheme}://{host_port}{wspath}"))
154}
155
156#[derive(thiserror::Error, Debug)]
157#[error("{msg}")]
158pub struct Error {
159    msg: String,
160}
161
162impl Error {
163    fn invalid_websocket_url(url: &str) -> Self {
164        Self {
165            msg: format!("Invalid websocket url: {url}"),
166        }
167    }
168}
169
/// A Websocket connection created by the [`Transport`].
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] on top of a browser `WebSocket`. Dropping the
/// connection detaches its event handlers, closes the socket if it is still connecting or open,
/// and cancels the interval timer used for write back-pressure.
pub struct Connection {
    /// All socket state. It holds `Rc`s and JS handles, so it is wrapped in a [`SendWrapper`];
    /// presumably this is what lets `Connection` satisfy the `Send` bound on the dial future —
    /// NOTE(review): confirm against the `send_wrapper` documentation.
    inner: SendWrapper<Inner>,
}
174
/// State backing a [`Connection`]: the browser socket, the buffer/flag/wakers shared with its
/// event handlers, and the handlers themselves.
struct Inner {
    /// The underlying browser WebSocket.
    socket: WebSocket,

    /// Waker for a reader waiting for data; woken by the `message` handler when new data
    /// arrives.
    new_data_waker: Rc<AtomicWaker>,
    /// Bytes received from the remote that have not yet been handed out by `poll_read`.
    /// The `message` handler refuses to grow it beyond [`MAX_BUFFER`].
    read_buffer: Rc<Mutex<BytesMut>>,

    /// Waker for when we are waiting for the WebSocket to be opened.
    open_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting to write (again) to the WebSocket because we previously
    /// exceeded the [`MAX_BUFFER`] threshold.
    write_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting for the WebSocket to be closed.
    close_waker: Rc<AtomicWaker>,

    /// Whether the connection errored.
    ///
    /// Set by the `error` handler and by the `message` handler when the remote would overflow
    /// `read_buffer`; checked by `error_barrier`.
    errored: Rc<AtomicBool>,

    // Store the closures for proper garbage collection.
    // These are wrapped in an [`Rc`] so we can implement [`Clone`].
    // NOTE(review): no `Clone` impl is visible in this file — confirm whether the `Rc` wrappers
    // are still needed.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    /// Handle of the interval timer that periodically runs `_on_buffered_amount_low_closure`;
    /// cleared again when the [`Connection`] is dropped.
    buffered_amount_low_interval: i32,
}
203
204impl Inner {
205    fn ready_state(&self) -> ReadyState {
206        match self.socket.ready_state() {
207            0 => ReadyState::Connecting,
208            1 => ReadyState::Open,
209            2 => ReadyState::Closing,
210            3 => ReadyState::Closed,
211            unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
212        }
213    }
214
215    fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
216        match self.ready_state() {
217            ReadyState::Connecting => {
218                self.open_waker.register(cx.waker());
219                Poll::Pending
220            }
221            ReadyState::Open => Poll::Ready(Ok(())),
222            ReadyState::Closed | ReadyState::Closing => {
223                Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
224            }
225        }
226    }
227
228    fn error_barrier(&self) -> io::Result<()> {
229        if self.errored.load(Ordering::SeqCst) {
230            return Err(io::ErrorKind::BrokenPipe.into());
231        }
232
233        Ok(())
234    }
235}
236
/// The state of the WebSocket.
///
/// Variants are listed in the order of the numeric `readyState` values `0..=3`.
///
/// See <https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/readyState>.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadyState {
    /// The socket has been created but the connection is not yet open.
    Connecting,
    /// The connection is open and ready to communicate.
    Open,
    /// The connection is in the process of closing.
    Closing,
    /// The connection is closed or could not be opened.
    Closed,
}
247
248impl Connection {
249    fn new(socket: WebSocket) -> Self {
250        socket.set_binary_type(web_sys::BinaryType::Arraybuffer);
251
252        let open_waker = Rc::new(AtomicWaker::new());
253        let onopen_closure = Closure::<dyn FnMut(_)>::new({
254            let open_waker = open_waker.clone();
255            move |_| {
256                open_waker.wake();
257            }
258        });
259        socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
260
261        let close_waker = Rc::new(AtomicWaker::new());
262        let onclose_closure = Closure::<dyn FnMut(_)>::new({
263            let close_waker = close_waker.clone();
264            move |_| {
265                close_waker.wake();
266            }
267        });
268        socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
269
270        let errored = Rc::new(AtomicBool::new(false));
271        let onerror_closure = Closure::<dyn FnMut(_)>::new({
272            let errored = errored.clone();
273            move |_| {
274                errored.store(true, Ordering::SeqCst);
275            }
276        });
277        socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
278
279        let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
280        let new_data_waker = Rc::new(AtomicWaker::new());
281        let onmessage_closure = Closure::<dyn FnMut(_)>::new({
282            let read_buffer = read_buffer.clone();
283            let new_data_waker = new_data_waker.clone();
284            let errored = errored.clone();
285            move |e: MessageEvent| {
286                let data = js_sys::Uint8Array::new(&e.data());
287
288                let mut read_buffer = read_buffer.lock().unwrap();
289
290                if read_buffer.len() + data.length() as usize > MAX_BUFFER {
291                    tracing::warn!("Remote is overloading us with messages, closing connection");
292                    errored.store(true, Ordering::SeqCst);
293
294                    return;
295                }
296
297                read_buffer.extend_from_slice(&data.to_vec());
298                new_data_waker.wake();
299            }
300        });
301        socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
302
303        let write_waker = Rc::new(AtomicWaker::new());
304        let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
305            let write_waker = write_waker.clone();
306            let socket = socket.clone();
307            move |_| {
308                if socket.buffered_amount() == 0 {
309                    write_waker.wake();
310                }
311            }
312        });
313        let buffered_amount_low_interval = WebContext::new()
314            .expect("to have a window or worker context")
315            .set_interval_with_callback_and_timeout_and_arguments(
316                on_buffered_amount_low_closure.as_ref().unchecked_ref(),
317                // Chosen arbitrarily and likely worth tuning. Due to low impact of the /ws
318                // transport, no further effort was invested at the time.
319                100,
320                &Array::new(),
321            )
322            .expect("to be able to set an interval");
323
324        Self {
325            inner: SendWrapper::new(Inner {
326                socket,
327                new_data_waker,
328                read_buffer,
329                open_waker,
330                write_waker,
331                close_waker,
332                errored,
333                _on_open_closure: Rc::new(onopen_closure),
334                _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
335                _on_close_closure: Rc::new(onclose_closure),
336                _on_error_closure: Rc::new(onerror_closure),
337                _on_message_closure: Rc::new(onmessage_closure),
338                buffered_amount_low_interval,
339            }),
340        }
341    }
342
343    fn buffered_amount(&self) -> usize {
344        self.inner.socket.buffered_amount() as usize
345    }
346}
347
348impl AsyncRead for Connection {
349    fn poll_read(
350        self: Pin<&mut Self>,
351        cx: &mut Context<'_>,
352        buf: &mut [u8],
353    ) -> Poll<Result<usize, io::Error>> {
354        let this = self.get_mut();
355        this.inner.error_barrier()?;
356        futures::ready!(this.inner.poll_open(cx))?;
357
358        let mut read_buffer = this.inner.read_buffer.lock().unwrap();
359
360        if read_buffer.is_empty() {
361            this.inner.new_data_waker.register(cx.waker());
362            return Poll::Pending;
363        }
364
365        // Ensure that we:
366        // - at most return what the caller can read (`buf.len()`)
367        // - at most what we have (`read_buffer.len()`)
368        let split_index = min(buf.len(), read_buffer.len());
369
370        let bytes_to_return = read_buffer.split_to(split_index);
371        let len = bytes_to_return.len();
372        buf[..len].copy_from_slice(&bytes_to_return);
373
374        Poll::Ready(Ok(len))
375    }
376}
377
378impl AsyncWrite for Connection {
379    fn poll_write(
380        self: Pin<&mut Self>,
381        cx: &mut Context<'_>,
382        buf: &[u8],
383    ) -> Poll<io::Result<usize>> {
384        let this = self.get_mut();
385
386        this.inner.error_barrier()?;
387        futures::ready!(this.inner.poll_open(cx))?;
388
389        debug_assert!(this.buffered_amount() <= MAX_BUFFER);
390        let remaining_space = MAX_BUFFER - this.buffered_amount();
391
392        if remaining_space == 0 {
393            this.inner.write_waker.register(cx.waker());
394            return Poll::Pending;
395        }
396
397        let bytes_to_send = min(buf.len(), remaining_space);
398
399        if this
400            .inner
401            .socket
402            .send_with_u8_array(&buf[..bytes_to_send])
403            .is_err()
404        {
405            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
406        }
407
408        Poll::Ready(Ok(bytes_to_send))
409    }
410
411    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
412        if self.buffered_amount() == 0 {
413            return Poll::Ready(Ok(()));
414        }
415
416        self.inner.error_barrier()?;
417
418        self.inner.write_waker.register(cx.waker());
419        Poll::Pending
420    }
421
422    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
423        const REGULAR_CLOSE: u16 = 1000; // See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
424
425        if self.inner.ready_state() == ReadyState::Closed {
426            return Poll::Ready(Ok(()));
427        }
428
429        self.inner.error_barrier()?;
430
431        if self.inner.ready_state() != ReadyState::Closing {
432            let _ = self
433                .inner
434                .socket
435                .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
436        }
437
438        self.inner.close_waker.register(cx.waker());
439        Poll::Pending
440    }
441}
442
443impl Drop for Connection {
444    fn drop(&mut self) {
445        // Unset event listeners, as otherwise they will be called by JS after the handlers have
446        // already been dropped.
447        self.inner.socket.set_onclose(None);
448        self.inner.socket.set_onerror(None);
449        self.inner.socket.set_onopen(None);
450        self.inner.socket.set_onmessage(None);
451
452        // In browsers, userland code is not allowed to use any other status code than 1000: https://websockets.spec.whatwg.org/#dom-websocket-close
453        const REGULAR_CLOSE: u16 = 1000; // See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
454
455        if let ReadyState::Connecting | ReadyState::Open = self.inner.ready_state() {
456            let _ = self
457                .inner
458                .socket
459                .close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
460        }
461
462        WebContext::new()
463            .expect("to have a window or worker context")
464            .clear_interval_with_handle(self.inner.buffered_amount_low_interval);
465    }
466}
467
#[cfg(test)]
mod tests {
    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a [`Multiaddr`] and runs it through [`extract_websocket_url`].
    fn url_for(addr: &str) -> Option<String> {
        let addr = addr
            .parse::<Multiaddr>()
            .expect("test address to be a valid multiaddr");

        extract_websocket_url(&addr)
    }

    #[test]
    fn extract_url() {
        // Addresses that must convert, paired with the expected URL.
        let supported = [
            // `/tls/ws`
            ("/dns4/example.com/tcp/2222/tls/ws", "wss://example.com:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws", "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/tls/ws", "wss://[::1]:2222/"),
            // `/wss`
            ("/dns4/example.com/tcp/2222/wss", "wss://example.com:2222/"),
            ("/dns6/example.com/tcp/2222/wss", "wss://example.com:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/wss", "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/wss", "wss://[::1]:2222/"),
            // `/ws`
            ("/dns/example.com/tcp/2222/ws", "ws://example.com:2222/"),
            ("/dns4/example.com/tcp/2222/ws", "ws://example.com:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/ws", "ws://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/ws", "ws://[::1]:2222/"),
        ];

        for &(addr, expected) in &supported {
            assert_eq!(
                url_for(addr).as_deref(),
                Some(expected),
                "unexpected URL for `{addr}`"
            );
        }

        // A trailing `/p2p/<peer>` component must be ignored.
        let peer_id = PeerId::random();
        let with_peer_suffix = [
            ("/dns4/example.com/tcp/2222/tls/ws", "wss://example.com:2222/"),
            ("/dns4/example.com/tcp/2222/wss", "wss://example.com:2222/"),
            ("/dns4/example.com/tcp/2222/ws", "ws://example.com:2222/"),
        ];

        for &(base, expected) in &with_peer_suffix {
            let addr = format!("{base}/p2p/{peer_id}");
            assert_eq!(
                url_for(&addr).as_deref(),
                Some(expected),
                "unexpected URL for `{addr}`"
            );
        }

        // Addresses that must be rejected.
        let unsupported = [
            // `/tls/wss` is invalid
            "/ip4/127.0.0.1/tcp/2222/tls/wss",
            // `/dnsaddr` is not a supported host component
            "/dnsaddr/example.com/tcp/2222/ws",
            // non-ws address
            "/ip4/127.0.0.1/tcp/2222",
        ];

        for &addr in &unsupported {
            assert!(
                url_for(addr).is_none(),
                "`{addr}` should not convert to a URL"
            );
        }
    }
}