ant_libp2p_websocket_websys/
lib.rs1#![allow(unexpected_cfgs)]
24
25mod web_context;
26
27use ant_libp2p_core as libp2p_core;
28
29use std::{
30 cmp::min,
31 pin::Pin,
32 rc::Rc,
33 sync::{
34 atomic::{AtomicBool, Ordering},
35 Mutex,
36 },
37 task::{Context, Poll},
38};
39
40use bytes::BytesMut;
41use futures::{future::Ready, io, prelude::*, task::AtomicWaker};
42use js_sys::Array;
43use libp2p_core::{
44 multiaddr::{Multiaddr, Protocol},
45 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
46};
47use send_wrapper::SendWrapper;
48use wasm_bindgen::prelude::*;
49use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
50
51use crate::web_context::WebContext;
52
/// A libp2p transport that dials websockets through the JavaScript `WebSocket` API
/// (via `web_sys`).
///
/// It only dials: `listen_on` rejects every address and `poll` never yields an event.
#[derive(Default)]
pub struct Transport {
    // Private so the type can only be built through `Default`.
    _private: (),
}

/// Per-connection byte budget in each direction: the receive buffer may not grow beyond it,
/// and `poll_write` never lets the socket's `bufferedAmount` exceed it.
const MAX_BUFFER: usize = 1024 * 1024;
78
79impl libp2p_core::Transport for Transport {
80 type Output = Connection;
81 type Error = Error;
82 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
83 type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
84
85 fn listen_on(
86 &mut self,
87 _: ListenerId,
88 addr: Multiaddr,
89 ) -> Result<(), TransportError<Self::Error>> {
90 Err(TransportError::MultiaddrNotSupported(addr))
91 }
92
93 fn remove_listener(&mut self, _id: ListenerId) -> bool {
94 false
95 }
96
97 fn dial(
98 &mut self,
99 addr: Multiaddr,
100 dial_opts: DialOpts,
101 ) -> Result<Self::Dial, TransportError<Self::Error>> {
102 if dial_opts.role.is_listener() {
103 return Err(TransportError::MultiaddrNotSupported(addr));
104 }
105
106 let url =
107 extract_websocket_url(&addr).ok_or(TransportError::MultiaddrNotSupported(addr))?;
108
109 Ok(async move {
110 let socket = match WebSocket::new(&url) {
111 Ok(ws) => ws,
112 Err(_) => return Err(Error::invalid_websocket_url(&url)),
113 };
114
115 Ok(Connection::new(socket))
116 }
117 .boxed())
118 }
119
120 fn poll(
121 self: Pin<&mut Self>,
122 _cx: &mut Context<'_>,
123 ) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
124 Poll::Pending
125 }
126}
127
128fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
130 let mut protocols = addr.iter();
131 let host_port = match (protocols.next(), protocols.next()) {
132 (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
133 format!("{ip}:{port}")
134 }
135 (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
136 format!("[{ip}]:{port}")
137 }
138 (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
139 | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
140 | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
141 format!("{}:{}", &h, port)
142 }
143 _ => return None,
144 };
145
146 let (scheme, wspath) = match (protocols.next(), protocols.next()) {
147 (Some(Protocol::Tls), Some(Protocol::Ws(path))) => ("wss", path.into_owned()),
148 (Some(Protocol::Ws(path)), _) => ("ws", path.into_owned()),
149 (Some(Protocol::Wss(path)), _) => ("wss", path.into_owned()),
150 _ => return None,
151 };
152
153 Some(format!("{scheme}://{host_port}{wspath}"))
154}
155
156#[derive(thiserror::Error, Debug)]
157#[error("{msg}")]
158pub struct Error {
159 msg: String,
160}
161
162impl Error {
163 fn invalid_websocket_url(url: &str) -> Self {
164 Self {
165 msg: format!("Invalid websocket url: {url}"),
166 }
167 }
168}
169
/// A single websocket connection, readable and writable through `AsyncRead`/`AsyncWrite`.
///
/// The JS-bound state is held in a [`SendWrapper`] so the connection can be returned from the
/// `Send` dial future. Per the `send_wrapper` crate's contract, accessing it from a thread other
/// than the creating one panics.
pub struct Connection {
    inner: SendWrapper<Inner>,
}
174
/// JS-facing state of a [`Connection`].
///
/// The `Rc`-shared fields are also captured by the closures installed in `Connection::new`;
/// that sharing is how JS events reach the Rust side.
struct Inner {
    socket: WebSocket,

    // Registered by `poll_read` while no data is buffered; woken by the `onmessage` handler.
    new_data_waker: Rc<AtomicWaker>,
    // Bytes received but not yet read; `onmessage` refuses to grow it past `MAX_BUFFER`.
    read_buffer: Rc<Mutex<BytesMut>>,

    // Registered by `poll_open` while connecting; woken by the `onopen` handler.
    open_waker: Rc<AtomicWaker>,

    // Registered when the send budget is exhausted or a flush is pending; woken by the
    // interval callback once `bufferedAmount` reaches zero.
    write_waker: Rc<AtomicWaker>,

    // Registered by `poll_close`; woken by the `onclose` handler.
    close_waker: Rc<AtomicWaker>,

    // Set by `onerror` or when the remote overflows `read_buffer`; checked by
    // `error_barrier` so reads, writes, flushes and closes fail with `BrokenPipe`.
    errored: Rc<AtomicBool>,

    // Keep the JS callbacks alive for the connection's lifetime. `Drop for Connection`
    // detaches the socket handlers and clears the interval before these are freed.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    // Driven by `set_interval`, not by a socket event.
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    // Handle returned by `set_interval...`; cleared in `Drop for Connection`.
    buffered_amount_low_interval: i32,
}
203
204impl Inner {
205 fn ready_state(&self) -> ReadyState {
206 match self.socket.ready_state() {
207 0 => ReadyState::Connecting,
208 1 => ReadyState::Open,
209 2 => ReadyState::Closing,
210 3 => ReadyState::Closed,
211 unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
212 }
213 }
214
215 fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
216 match self.ready_state() {
217 ReadyState::Connecting => {
218 self.open_waker.register(cx.waker());
219 Poll::Pending
220 }
221 ReadyState::Open => Poll::Ready(Ok(())),
222 ReadyState::Closed | ReadyState::Closing => {
223 Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
224 }
225 }
226 }
227
228 fn error_barrier(&self) -> io::Result<()> {
229 if self.errored.load(Ordering::SeqCst) {
230 return Err(io::ErrorKind::BrokenPipe.into());
231 }
232
233 Ok(())
234 }
235}
236
/// Typed form of the numeric `WebSocket.readyState`, as produced by `Inner::ready_state`.
#[derive(PartialEq)]
enum ReadyState {
    /// Raw value 0.
    Connecting,
    /// Raw value 1.
    Open,
    /// Raw value 2.
    Closing,
    /// Raw value 3.
    Closed,
}
247
248impl Connection {
249 fn new(socket: WebSocket) -> Self {
250 socket.set_binary_type(web_sys::BinaryType::Arraybuffer);
251
252 let open_waker = Rc::new(AtomicWaker::new());
253 let onopen_closure = Closure::<dyn FnMut(_)>::new({
254 let open_waker = open_waker.clone();
255 move |_| {
256 open_waker.wake();
257 }
258 });
259 socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
260
261 let close_waker = Rc::new(AtomicWaker::new());
262 let onclose_closure = Closure::<dyn FnMut(_)>::new({
263 let close_waker = close_waker.clone();
264 move |_| {
265 close_waker.wake();
266 }
267 });
268 socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
269
270 let errored = Rc::new(AtomicBool::new(false));
271 let onerror_closure = Closure::<dyn FnMut(_)>::new({
272 let errored = errored.clone();
273 move |_| {
274 errored.store(true, Ordering::SeqCst);
275 }
276 });
277 socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
278
279 let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
280 let new_data_waker = Rc::new(AtomicWaker::new());
281 let onmessage_closure = Closure::<dyn FnMut(_)>::new({
282 let read_buffer = read_buffer.clone();
283 let new_data_waker = new_data_waker.clone();
284 let errored = errored.clone();
285 move |e: MessageEvent| {
286 let data = js_sys::Uint8Array::new(&e.data());
287
288 let mut read_buffer = read_buffer.lock().unwrap();
289
290 if read_buffer.len() + data.length() as usize > MAX_BUFFER {
291 tracing::warn!("Remote is overloading us with messages, closing connection");
292 errored.store(true, Ordering::SeqCst);
293
294 return;
295 }
296
297 read_buffer.extend_from_slice(&data.to_vec());
298 new_data_waker.wake();
299 }
300 });
301 socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
302
303 let write_waker = Rc::new(AtomicWaker::new());
304 let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
305 let write_waker = write_waker.clone();
306 let socket = socket.clone();
307 move |_| {
308 if socket.buffered_amount() == 0 {
309 write_waker.wake();
310 }
311 }
312 });
313 let buffered_amount_low_interval = WebContext::new()
314 .expect("to have a window or worker context")
315 .set_interval_with_callback_and_timeout_and_arguments(
316 on_buffered_amount_low_closure.as_ref().unchecked_ref(),
317 100,
320 &Array::new(),
321 )
322 .expect("to be able to set an interval");
323
324 Self {
325 inner: SendWrapper::new(Inner {
326 socket,
327 new_data_waker,
328 read_buffer,
329 open_waker,
330 write_waker,
331 close_waker,
332 errored,
333 _on_open_closure: Rc::new(onopen_closure),
334 _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
335 _on_close_closure: Rc::new(onclose_closure),
336 _on_error_closure: Rc::new(onerror_closure),
337 _on_message_closure: Rc::new(onmessage_closure),
338 buffered_amount_low_interval,
339 }),
340 }
341 }
342
343 fn buffered_amount(&self) -> usize {
344 self.inner.socket.buffered_amount() as usize
345 }
346}
347
348impl AsyncRead for Connection {
349 fn poll_read(
350 self: Pin<&mut Self>,
351 cx: &mut Context<'_>,
352 buf: &mut [u8],
353 ) -> Poll<Result<usize, io::Error>> {
354 let this = self.get_mut();
355 this.inner.error_barrier()?;
356 futures::ready!(this.inner.poll_open(cx))?;
357
358 let mut read_buffer = this.inner.read_buffer.lock().unwrap();
359
360 if read_buffer.is_empty() {
361 this.inner.new_data_waker.register(cx.waker());
362 return Poll::Pending;
363 }
364
365 let split_index = min(buf.len(), read_buffer.len());
369
370 let bytes_to_return = read_buffer.split_to(split_index);
371 let len = bytes_to_return.len();
372 buf[..len].copy_from_slice(&bytes_to_return);
373
374 Poll::Ready(Ok(len))
375 }
376}
377
378impl AsyncWrite for Connection {
379 fn poll_write(
380 self: Pin<&mut Self>,
381 cx: &mut Context<'_>,
382 buf: &[u8],
383 ) -> Poll<io::Result<usize>> {
384 let this = self.get_mut();
385
386 this.inner.error_barrier()?;
387 futures::ready!(this.inner.poll_open(cx))?;
388
389 debug_assert!(this.buffered_amount() <= MAX_BUFFER);
390 let remaining_space = MAX_BUFFER - this.buffered_amount();
391
392 if remaining_space == 0 {
393 this.inner.write_waker.register(cx.waker());
394 return Poll::Pending;
395 }
396
397 let bytes_to_send = min(buf.len(), remaining_space);
398
399 if this
400 .inner
401 .socket
402 .send_with_u8_array(&buf[..bytes_to_send])
403 .is_err()
404 {
405 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
406 }
407
408 Poll::Ready(Ok(bytes_to_send))
409 }
410
411 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
412 if self.buffered_amount() == 0 {
413 return Poll::Ready(Ok(()));
414 }
415
416 self.inner.error_barrier()?;
417
418 self.inner.write_waker.register(cx.waker());
419 Poll::Pending
420 }
421
422 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
423 const REGULAR_CLOSE: u16 = 1000; if self.inner.ready_state() == ReadyState::Closed {
426 return Poll::Ready(Ok(()));
427 }
428
429 self.inner.error_barrier()?;
430
431 if self.inner.ready_state() != ReadyState::Closing {
432 let _ = self
433 .inner
434 .socket
435 .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
436 }
437
438 self.inner.close_waker.register(cx.waker());
439 Poll::Pending
440 }
441}
442
443impl Drop for Connection {
444 fn drop(&mut self) {
445 self.inner.socket.set_onclose(None);
448 self.inner.socket.set_onerror(None);
449 self.inner.socket.set_onopen(None);
450 self.inner.socket.set_onmessage(None);
451
452 const REGULAR_CLOSE: u16 = 1000; if let ReadyState::Connecting | ReadyState::Open = self.inner.ready_state() {
456 let _ = self
457 .inner
458 .socket
459 .close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
460 }
461
462 WebContext::new()
463 .expect("to have a window or worker context")
464 .clear_interval_with_handle(self.inner.buffered_amount_low_interval);
465 }
466}
467
#[cfg(test)]
mod tests {
    use libp2p_identity::PeerId;

    use super::*;

    /// Table-driven check of `extract_websocket_url`: every accepted multiaddr maps to the
    /// given URL, and every rejected one yields `None`.
    #[test]
    fn extract_url() {
        let peer_id = PeerId::random();

        let accepted = [
            ("/dns4/example.com/tcp/2222/tls/ws".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/tls/ws".to_owned(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/wss".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/wss".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/wss".to_owned(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/ws".to_owned(), "ws://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}"),
                "ws://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), "ws://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/ws".to_owned(), "ws://[::1]:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), "ws://127.0.0.1:2222/"),
        ];
        for (addr, expected) in &accepted {
            let addr = addr.parse::<Multiaddr>().unwrap();
            let url = extract_websocket_url(&addr).unwrap();
            assert_eq!(url, *expected);
        }

        let rejected = [
            "/ip4/127.0.0.1/tcp/2222/tls/wss",
            "/dnsaddr/example.com/tcp/2222/ws",
            "/ip4/127.0.0.1/tcp/2222",
        ];
        for addr in &rejected {
            let addr = addr.parse::<Multiaddr>().unwrap();
            assert!(extract_websocket_url(&addr).is_none());
        }
    }
}