1#![warn(missing_docs)]
59#[cfg(not(target_family = "wasm"))]
60compile_error!("websocket-web requires a WebAssembly target");
61
62mod closed;
63mod standard;
64mod stream;
65mod util;
66
67use futures_core::Stream;
68use futures_sink::Sink;
69use futures_util::{SinkExt, StreamExt};
70use js_sys::{Reflect, Uint8Array};
71use std::{
72 fmt, io,
73 io::ErrorKind,
74 mem,
75 pin::Pin,
76 rc::Rc,
77 task::{ready, Context, Poll},
78};
79use tokio::io::{AsyncRead, AsyncWrite};
80use wasm_bindgen::prelude::*;
81
82pub use closed::{CloseCode, Closed, ClosedReason};
83
/// Selects which browser WebSocket API backs a connection.
///
/// When no interface is chosen explicitly, [`WebSocketBuilder::connect`] uses
/// [`Interface::Stream`] if it is supported and falls back to [`Interface::Standard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interface {
    /// Uses the JavaScript `WebSocketStream` API.
    Stream,
    /// Uses the standard JavaScript `WebSocket` API.
    Standard,
}
96
97impl Interface {
98 pub fn is_supported(&self) -> bool {
100 let global = js_sys::global();
101 match self {
102 Self::Stream => Reflect::has(&global, &JsValue::from_str("WebSocketStream")).unwrap_or_default(),
103 Self::Standard => Reflect::has(&global, &JsValue::from_str("WebSocket")).unwrap_or_default(),
104 }
105 }
106}
107
/// A WebSocket message: either UTF-8 text or raw binary data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Msg {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
}

impl Msg {
    /// Returns `true` if this is a [`Msg::Text`] message.
    pub const fn is_text(&self) -> bool {
        matches!(self, Self::Text(_))
    }

    /// Returns `true` if this is a [`Msg::Binary`] message.
    pub const fn is_binary(&self) -> bool {
        matches!(self, Self::Binary(_))
    }

    /// Converts the message into its raw bytes.
    ///
    /// Text messages yield their UTF-8 encoding. Both variants hand over their
    /// existing buffer, so no bytes are copied.
    pub fn to_vec(self) -> Vec<u8> {
        match self {
            // `into_bytes` moves the string's buffer instead of copying it.
            Self::Text(text) => text.into_bytes(),
            Self::Binary(vec) => vec,
        }
    }

    /// Length of the payload in bytes (for text this is the UTF-8 byte length,
    /// not the number of characters).
    pub fn len(&self) -> usize {
        match self {
            Self::Text(text) => text.len(),
            Self::Binary(vec) => vec.len(),
        }
    }

    /// Returns `true` if the payload contains no bytes.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.is_empty(),
            Self::Binary(vec) => vec.is_empty(),
        }
    }
}

impl fmt::Display for Msg {
    // Text is written as-is. Binary data is written as lossy UTF-8, so invalid
    // byte sequences appear as U+FFFD.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Text(text) => write!(f, "{text}"),
            Self::Binary(binary) => write!(f, "{}", String::from_utf8_lossy(binary)),
        }
    }
}

impl From<Msg> for Vec<u8> {
    /// Equivalent to [`Msg::to_vec`].
    fn from(msg: Msg) -> Self {
        msg.to_vec()
    }
}

impl AsRef<[u8]> for Msg {
    /// Borrows the payload bytes (UTF-8 bytes for text messages).
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Text(text) => text.as_bytes(),
            Self::Binary(vec) => vec,
        }
    }
}
176
/// Builder for configuring and opening a [`WebSocket`] connection.
#[derive(Debug, Clone)]
pub struct WebSocketBuilder {
    /// URL of the server to connect to.
    url: String,
    /// Subprotocols to request, set via `set_protocols`.
    protocols: Vec<String>,
    /// Explicit interface choice; `None` lets `connect` pick one automatically.
    interface: Option<Interface>,
    /// Optional send buffer size. It is interpreted by the interface
    /// implementations (`stream`/`standard`), which are not visible here.
    send_buffer_size: Option<usize>,
    /// Optional receive buffer size. It is interpreted by the interface
    /// implementations (`stream`/`standard`), which are not visible here.
    receive_buffer_size: Option<usize>,
}
186
187impl WebSocketBuilder {
188 pub fn new(url: impl AsRef<str>) -> Self {
190 Self {
191 url: url.as_ref().to_string(),
192 protocols: Vec::new(),
193 interface: None,
194 send_buffer_size: None,
195 receive_buffer_size: None,
196 }
197 }
198
199 pub fn set_interface(&mut self, interface: Interface) {
203 self.interface = Some(interface);
204 }
205
206 pub fn set_protocols<P>(&mut self, protocols: impl IntoIterator<Item = P>)
218 where
219 P: AsRef<str>,
220 {
221 self.protocols = protocols.into_iter().map(|s| s.as_ref().to_string()).collect();
222 }
223
224 pub fn set_send_buffer_size(&mut self, send_buffer_size: usize) {
236 self.send_buffer_size = Some(send_buffer_size);
237 }
238
239 pub fn set_receive_buffer_size(&mut self, receive_buffer_size: usize) {
249 self.receive_buffer_size = Some(receive_buffer_size);
250 }
251
252 pub async fn connect(self) -> io::Result<WebSocket> {
254 let interface = match self.interface {
255 Some(interface) => interface,
256 None if Interface::Stream.is_supported() => Interface::Stream,
257 None => Interface::Standard,
258 };
259
260 if !interface.is_supported() {
261 match interface {
262 Interface::Stream => {
263 return Err(io::Error::new(ErrorKind::Unsupported, "WebSocketStream not supported"))
264 }
265 Interface::Standard => {
266 return Err(io::Error::new(ErrorKind::Unsupported, "WebSocket not supported"))
267 }
268 }
269 }
270
271 match interface {
272 Interface::Stream => {
273 let (stream, info) = stream::Inner::new(self).await?;
274 Ok(WebSocket { inner: Inner::Stream(stream), info: Rc::new(info), read_buf: Vec::new() })
275 }
276 Interface::Standard => {
277 let (standard, info) = standard::Inner::new(self).await?;
278 Ok(WebSocket { inner: Inner::Standard(standard), info: Rc::new(info), read_buf: Vec::new() })
279 }
280 }
281 }
282}
283
/// Connection metadata shared (via `Rc`) between a [`WebSocket`] and the halves
/// produced by [`WebSocket::into_split`].
struct Info {
    /// URL of the server. It is filled in by the interface implementation at
    /// connect time.
    url: String,
    /// Subprotocol recorded for the connection.
    protocol: String,
    /// Interface backing the connection.
    interface: Interface,
}
289
/// A WebSocket connection backed by either the `WebSocketStream` API or the
/// standard `WebSocket` browser API.
///
/// The connection offers two ways to exchange data:
/// - Messages are sent through the [`Sink`] implementations and received through [`Stream`].
/// - The [`AsyncRead`]/[`AsyncWrite`] implementations expose the connection as a byte stream.
///   Each write is sent as one binary message.
///
/// Use [`into_split`](Self::into_split) to obtain independent sender and receiver halves.
pub struct WebSocket {
    /// Interface-specific connection state.
    inner: Inner,
    /// Metadata shared with the halves produced by `into_split`.
    info: Rc<Info>,
    /// Bytes of the current message not yet returned by `AsyncRead::poll_read`.
    read_buf: Vec<u8>,
}
298
/// Interface-specific state of a [`WebSocket`].
enum Inner {
    /// Connection using the `WebSocketStream` API.
    Stream(stream::Inner),
    /// Connection using the standard `WebSocket` API.
    Standard(standard::Inner),
}
303
304impl fmt::Debug for WebSocket {
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 f.debug_struct("WebSocket")
307 .field("url", &self.info.url)
308 .field("protocol", &self.protocol())
309 .field("interface", &self.interface())
310 .finish()
311 }
312}
313
314impl WebSocket {
315 pub async fn connect(url: impl AsRef<str>) -> io::Result<Self> {
317 WebSocketBuilder::new(url).connect().await
318 }
319
320 pub fn url(&self) -> &str {
322 &self.info.url
323 }
324
325 pub fn protocol(&self) -> &str {
331 &self.info.protocol
332 }
333
334 pub fn interface(&self) -> Interface {
336 self.info.interface
337 }
338
339 pub fn into_split(self) -> (WebSocketSender, WebSocketReceiver) {
341 let Self { inner, info, read_buf } = self;
342 match inner {
343 Inner::Stream(inner) => {
344 let (sender, receiver) = inner.into_split();
345 let sender = WebSocketSender { inner: SenderInner::Stream(sender), info: info.clone() };
346 let receiver = WebSocketReceiver { inner: ReceiverInner::Stream(receiver), info, read_buf };
347 (sender, receiver)
348 }
349 Inner::Standard(inner) => {
350 let (sender, receiver) = inner.into_split();
351 let sender = WebSocketSender { inner: SenderInner::Standard(sender), info: info.clone() };
352 let receiver =
353 WebSocketReceiver { inner: ReceiverInner::Standard(receiver), info, read_buf: Vec::new() };
354 (sender, receiver)
355 }
356 }
357 }
358
359 pub fn close(self) {
361 self.into_split().0.close();
362 }
363
364 #[track_caller]
370 pub fn close_with_reason(self, code: CloseCode, reason: &str) {
371 self.into_split().0.close_with_reason(code, reason);
372 }
373
374 pub fn closed(&self) -> Closed {
376 match &self.inner {
377 Inner::Stream(inner) => inner.closed(),
378 Inner::Standard(inner) => inner.closed(),
379 }
380 }
381
382 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
383 match &mut self.inner {
384 Inner::Stream(inner) => inner.sender.poll_ready_unpin(cx),
385 Inner::Standard(inner) => inner.sender.poll_ready_unpin(cx),
386 }
387 }
388
389 fn start_send(mut self: Pin<&mut Self>, item: &JsValue, len: usize) -> Result<(), io::Error> {
390 match &mut self.inner {
391 Inner::Stream(inner) => inner.sender.start_send_unpin((item, len)),
392 Inner::Standard(inner) => inner.sender.start_send_unpin(item),
393 }
394 }
395
396 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
397 match &mut self.inner {
398 Inner::Stream(inner) => inner.sender.poll_flush_unpin(cx),
399 Inner::Standard(inner) => inner.sender.poll_flush_unpin(cx),
400 }
401 }
402
403 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
404 match &mut self.inner {
405 Inner::Stream(inner) => inner.sender.poll_close_unpin(cx),
406 Inner::Standard(inner) => inner.sender.poll_close_unpin(cx),
407 }
408 }
409}
410
411impl Sink<&str> for WebSocket {
412 type Error = io::Error;
413
414 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
415 self.poll_ready(cx)
416 }
417
418 fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
419 self.start_send(&JsValue::from_str(item), item.len())
420 }
421
422 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
423 self.poll_flush(cx)
424 }
425
426 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
427 self.poll_close(cx)
428 }
429}
430
431impl Sink<String> for WebSocket {
432 type Error = io::Error;
433
434 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
435 self.poll_ready(cx)
436 }
437
438 fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
439 self.start_send(&JsValue::from_str(&item), item.len())
440 }
441
442 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
443 self.poll_flush(cx)
444 }
445
446 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
447 self.poll_close(cx)
448 }
449}
450
451impl Sink<&[u8]> for WebSocket {
452 type Error = io::Error;
453
454 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
455 self.poll_ready(cx)
456 }
457
458 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
459 self.start_send(&Uint8Array::from(item), item.len())
460 }
461
462 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
463 self.poll_flush(cx)
464 }
465
466 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
467 self.poll_close(cx)
468 }
469}
470
471impl Sink<Vec<u8>> for WebSocket {
472 type Error = io::Error;
473
474 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
475 self.poll_ready(cx)
476 }
477
478 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
479 self.start_send(&Uint8Array::from(&item[..]), item.len())
480 }
481
482 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
483 self.poll_flush(cx)
484 }
485
486 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
487 self.poll_close(cx)
488 }
489}
490
491impl Sink<Msg> for WebSocket {
492 type Error = io::Error;
493
494 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
495 self.poll_ready(cx)
496 }
497
498 fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
499 match item {
500 Msg::Text(text) => self.start_send(&JsValue::from_str(&text), text.len()),
501 Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..]), vec.len()),
502 }
503 }
504
505 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
506 self.poll_flush(cx)
507 }
508
509 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
510 self.poll_close(cx)
511 }
512}
513
514impl AsyncWrite for WebSocket {
515 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
516 ready!(self.as_mut().poll_ready(cx))?;
517 self.start_send(&Uint8Array::from(buf), buf.len())?;
518 Poll::Ready(Ok(buf.len()))
519 }
520
521 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
522 self.poll_flush(cx)
523 }
524
525 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
526 self.poll_close(cx)
527 }
528}
529
530impl Stream for WebSocket {
531 type Item = io::Result<Msg>;
532
533 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
534 match &mut self.inner {
535 Inner::Stream(inner) => inner.receiver.poll_next_unpin(cx),
536 Inner::Standard(inner) => inner.receiver.poll_next_unpin(cx),
537 }
538 }
539}
540
541impl AsyncRead for WebSocket {
542 fn poll_read(
543 mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
544 ) -> Poll<io::Result<()>> {
545 while self.read_buf.is_empty() {
546 let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
547 self.read_buf = msg.to_vec();
548 }
549
550 let part = if buf.remaining() < self.read_buf.len() {
551 let rem = self.read_buf.split_off(buf.remaining());
552 mem::replace(&mut self.read_buf, rem)
553 } else {
554 mem::take(&mut self.read_buf)
555 };
556
557 buf.put_slice(&part);
558 Poll::Ready(Ok(()))
559 }
560}
561
/// Sending half of a [`WebSocket`], obtained from [`WebSocket::into_split`].
///
/// The sending half implements [`Sink`] for text and binary payloads and [`AsyncWrite`].
/// With [`AsyncWrite`], each write is sent as one binary message.
pub struct WebSocketSender {
    /// Interface-specific sender state.
    inner: SenderInner,
    /// Metadata shared with the corresponding receiver.
    info: Rc<Info>,
}
570
/// Interface-specific state of a [`WebSocketSender`].
enum SenderInner {
    /// Sender using the `WebSocketStream` API.
    Stream(stream::Sender),
    /// Sender using the standard `WebSocket` API.
    Standard(standard::Sender),
}
575
576impl fmt::Debug for WebSocketSender {
577 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
578 f.debug_struct("WebSocketSender")
579 .field("url", &self.info.url)
580 .field("protocol", &self.protocol())
581 .field("interface", &self.interface())
582 .finish()
583 }
584}
585
586impl WebSocketSender {
587 pub fn url(&self) -> &str {
589 &self.info.url
590 }
591
592 pub fn protocol(&self) -> &str {
594 &self.info.protocol
595 }
596
597 pub fn interface(&self) -> Interface {
599 self.info.interface
600 }
601
602 pub fn close(self) {
606 self.close_with_reason(CloseCode::NormalClosure, "");
607 }
608
609 #[track_caller]
617 pub fn close_with_reason(self, code: CloseCode, reason: &str) {
618 if !code.is_valid() {
619 panic!("WebSocket close code {code} is invalid");
620 }
621
622 match self.inner {
623 SenderInner::Stream(sender) => sender.close(code.into(), reason),
624 SenderInner::Standard(sender) => sender.close(code.into(), reason),
625 }
626 }
627
628 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
629 match &mut self.inner {
630 SenderInner::Stream(inner) => inner.poll_ready_unpin(cx),
631 SenderInner::Standard(inner) => inner.poll_ready_unpin(cx),
632 }
633 }
634
635 fn start_send(mut self: Pin<&mut Self>, item: &JsValue, len: usize) -> Result<(), io::Error> {
636 match &mut self.inner {
637 SenderInner::Stream(inner) => inner.start_send_unpin((item, len)),
638 SenderInner::Standard(inner) => inner.start_send_unpin(item),
639 }
640 }
641
642 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
643 match &mut self.inner {
644 SenderInner::Stream(inner) => inner.poll_flush_unpin(cx),
645 SenderInner::Standard(inner) => inner.poll_flush_unpin(cx),
646 }
647 }
648
649 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
650 match &mut self.inner {
651 SenderInner::Stream(inner) => inner.poll_close_unpin(cx),
652 SenderInner::Standard(inner) => inner.poll_close_unpin(cx),
653 }
654 }
655}
656
657impl Sink<&str> for WebSocketSender {
658 type Error = io::Error;
659
660 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
661 self.poll_ready(cx)
662 }
663
664 fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
665 self.start_send(&JsValue::from_str(item), item.len())
666 }
667
668 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
669 self.poll_flush(cx)
670 }
671
672 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
673 self.poll_close(cx)
674 }
675}
676
677impl Sink<String> for WebSocketSender {
678 type Error = io::Error;
679
680 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
681 self.poll_ready(cx)
682 }
683
684 fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
685 self.start_send(&JsValue::from_str(&item), item.len())
686 }
687
688 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
689 self.poll_flush(cx)
690 }
691
692 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
693 self.poll_close(cx)
694 }
695}
696
697impl Sink<&[u8]> for WebSocketSender {
698 type Error = io::Error;
699
700 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
701 self.poll_ready(cx)
702 }
703
704 fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
705 self.start_send(&Uint8Array::from(item), item.len())
706 }
707
708 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
709 self.poll_flush(cx)
710 }
711
712 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
713 self.poll_close(cx)
714 }
715}
716
717impl Sink<Vec<u8>> for WebSocketSender {
718 type Error = io::Error;
719
720 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
721 self.poll_ready(cx)
722 }
723
724 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
725 self.start_send(&Uint8Array::from(&item[..]), item.len())
726 }
727
728 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
729 self.poll_flush(cx)
730 }
731
732 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
733 self.poll_close(cx)
734 }
735}
736
737impl Sink<Msg> for WebSocketSender {
738 type Error = io::Error;
739
740 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
741 self.poll_ready(cx)
742 }
743
744 fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
745 match item {
746 Msg::Text(text) => self.start_send(&JsValue::from_str(&text), text.len()),
747 Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..]), vec.len()),
748 }
749 }
750
751 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
752 self.poll_flush(cx)
753 }
754
755 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
756 self.poll_close(cx)
757 }
758}
759
760impl AsyncWrite for WebSocketSender {
761 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
762 ready!(self.as_mut().poll_ready(cx))?;
763 self.start_send(&Uint8Array::from(buf), buf.len())?;
764 Poll::Ready(Ok(buf.len()))
765 }
766
767 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
768 self.poll_flush(cx)
769 }
770
771 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
772 self.poll_close(cx)
773 }
774}
775
/// Receiving half of a [`WebSocket`], obtained from [`WebSocket::into_split`].
///
/// The receiving half yields messages through [`Stream`], or message payloads as
/// bytes through [`AsyncRead`].
pub struct WebSocketReceiver {
    /// Interface-specific receiver state.
    inner: ReceiverInner,
    /// Metadata shared with the corresponding sender.
    info: Rc<Info>,
    /// Bytes of the current message not yet returned by `AsyncRead::poll_read`.
    read_buf: Vec<u8>,
}
785
/// Interface-specific state of a [`WebSocketReceiver`].
enum ReceiverInner {
    /// Receiver using the `WebSocketStream` API.
    Stream(stream::Receiver),
    /// Receiver using the standard `WebSocket` API.
    Standard(standard::Receiver),
}
790
791impl fmt::Debug for WebSocketReceiver {
792 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
793 f.debug_struct("WebSocketReceiver")
794 .field("url", &self.info.url)
795 .field("protocol", &self.protocol())
796 .field("interface", &self.interface())
797 .finish()
798 }
799}
800
801impl WebSocketReceiver {
802 pub fn url(&self) -> &str {
804 &self.info.url
805 }
806
807 pub fn protocol(&self) -> &str {
809 &self.info.protocol
810 }
811
812 pub fn interface(&self) -> Interface {
814 self.info.interface
815 }
816}
817
818impl Stream for WebSocketReceiver {
819 type Item = io::Result<Msg>;
820
821 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
822 match &mut self.inner {
823 ReceiverInner::Stream(inner) => inner.poll_next_unpin(cx),
824 ReceiverInner::Standard(inner) => inner.poll_next_unpin(cx),
825 }
826 }
827}
828
829impl AsyncRead for WebSocketReceiver {
830 fn poll_read(
831 mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
832 ) -> Poll<io::Result<()>> {
833 while self.read_buf.is_empty() {
834 let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
835 self.read_buf = msg.to_vec();
836 }
837
838 let part = if buf.remaining() < self.read_buf.len() {
839 let rem = self.read_buf.split_off(buf.remaining());
840 mem::replace(&mut self.read_buf, rem)
841 } else {
842 mem::take(&mut self.read_buf)
843 };
844
845 buf.put_slice(&part);
846 Poll::Ready(Ok(()))
847 }
848}