1use crate::CowStr;
4use crate::stream::StreamError;
5use bytes::Bytes;
6use n0_future::Stream;
7use std::borrow::Borrow;
8use std::fmt::{self, Display};
9use std::future::Future;
10use std::ops::Deref;
11use std::pin::Pin;
12use url::Url;
13
/// Text payload of a WebSocket message, stored in a [`Bytes`] buffer.
///
/// The safe constructors in this file either start from a `str`/`String` or
/// validate the bytes in `TryFrom`, so the buffer is expected to hold valid
/// UTF-8; [`WsText::as_str`] relies on that and does not re-validate.
#[repr(transparent)]
#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct WsText(Bytes);
18
19impl WsText {
20 pub const fn from_static(s: &'static str) -> Self {
22 Self(Bytes::from_static(s.as_bytes()))
23 }
24
25 pub fn as_str(&self) -> &str {
27 unsafe { std::str::from_utf8_unchecked(&self.0) }
28 }
29
30 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
35 Self(bytes)
36 }
37
38 pub fn into_bytes(self) -> Bytes {
40 self.0
41 }
42}
43
44impl Deref for WsText {
45 type Target = str;
46 fn deref(&self) -> &str {
47 self.as_str()
48 }
49}
50
51impl AsRef<str> for WsText {
52 fn as_ref(&self) -> &str {
53 self.as_str()
54 }
55}
56
57impl AsRef<[u8]> for WsText {
58 fn as_ref(&self) -> &[u8] {
59 &self.0
60 }
61}
62
63impl AsRef<Bytes> for WsText {
64 fn as_ref(&self) -> &Bytes {
65 &self.0
66 }
67}
68
69impl Borrow<str> for WsText {
70 fn borrow(&self) -> &str {
71 self.as_str()
72 }
73}
74
75impl Display for WsText {
76 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77 Display::fmt(self.as_str(), f)
78 }
79}
80
81impl From<String> for WsText {
82 fn from(s: String) -> Self {
83 Self(Bytes::from(s))
84 }
85}
86
87impl From<&str> for WsText {
88 fn from(s: &str) -> Self {
89 Self(Bytes::copy_from_slice(s.as_bytes()))
90 }
91}
92
93impl From<&String> for WsText {
94 fn from(s: &String) -> Self {
95 Self::from(s.as_str())
96 }
97}
98
99impl TryFrom<Bytes> for WsText {
100 type Error = std::str::Utf8Error;
101 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
102 std::str::from_utf8(&bytes)?;
103 Ok(Self(bytes))
104 }
105}
106
107impl TryFrom<Vec<u8>> for WsText {
108 type Error = std::str::Utf8Error;
109 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
110 Self::try_from(Bytes::from(vec))
111 }
112}
113
114impl From<WsText> for Bytes {
115 fn from(t: WsText) -> Bytes {
116 t.0
117 }
118}
119
120impl Default for WsText {
121 fn default() -> Self {
122 Self(Bytes::new())
123 }
124}
125
/// Status code carried by a WebSocket close frame.
///
/// The named variants cover the numeric codes listed below; every other value
/// is kept verbatim in [`CloseCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Code `1000`.
    Normal = 1000,
    /// Code `1001`.
    Away = 1001,
    /// Code `1002`.
    Protocol = 1002,
    /// Code `1003`.
    Unsupported = 1003,
    /// Code `1007`.
    Invalid = 1007,
    /// Code `1008`.
    Policy = 1008,
    /// Code `1009`.
    Size = 1009,
    /// Code `1010`.
    Extension = 1010,
    /// Code `1011`.
    Error = 1011,
    /// Code `1015`.
    Tls = 1015,
    /// Any code without a dedicated variant, stored as received.
    Other(u16),
}

/// Maps a raw code to its named variant, falling back to `Other`.
impl From<u16> for CloseCode {
    fn from(code: u16) -> Self {
        match code {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1015 => Self::Tls,
            unnamed => Self::Other(unnamed),
        }
    }
}

/// Returns the numeric code of a variant; `Other` yields its stored value.
impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Other(raw) => raw,
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Tls => 1015,
        }
    }
}
189
/// Payload of a WebSocket close message: a status code plus a textual reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseFrame<'a> {
    /// Status code of the close frame.
    pub code: CloseCode,
    /// Reason text accompanying the code; may borrow for `'a`.
    pub reason: CowStr<'a>,
}
198
199impl<'a> CloseFrame<'a> {
200 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
202 Self {
203 code,
204 reason: reason.into(),
205 }
206 }
207}
208
/// A WebSocket message as exposed by this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// Text message.
    Text(WsText),
    /// Binary message.
    Binary(Bytes),
    /// Close message, optionally carrying a code and reason.
    Close(Option<CloseFrame<'static>>),
}
219
220impl WsMessage {
221 pub fn is_text(&self) -> bool {
223 matches!(self, WsMessage::Text(_))
224 }
225
226 pub fn is_binary(&self) -> bool {
228 matches!(self, WsMessage::Binary(_))
229 }
230
231 pub fn is_close(&self) -> bool {
233 matches!(self, WsMessage::Close(_))
234 }
235
236 pub fn as_text(&self) -> Option<&str> {
238 match self {
239 WsMessage::Text(t) => Some(t.as_str()),
240 _ => None,
241 }
242 }
243
244 pub fn as_bytes(&self) -> Option<&[u8]> {
246 match self {
247 WsMessage::Text(t) => Some(t.as_ref()),
248 WsMessage::Binary(b) => Some(b),
249 WsMessage::Close(_) => None,
250 }
251 }
252}
253
254impl From<WsText> for WsMessage {
255 fn from(text: WsText) -> Self {
256 WsMessage::Text(text)
257 }
258}
259
260impl From<String> for WsMessage {
261 fn from(s: String) -> Self {
262 WsMessage::Text(WsText::from(s))
263 }
264}
265
266impl From<&str> for WsMessage {
267 fn from(s: &str) -> Self {
268 WsMessage::Text(WsText::from(s))
269 }
270}
271
272impl From<Bytes> for WsMessage {
273 fn from(bytes: Bytes) -> Self {
274 WsMessage::Binary(bytes)
275 }
276}
277
278impl From<Vec<u8>> for WsMessage {
279 fn from(vec: Vec<u8>) -> Self {
280 WsMessage::Binary(Bytes::from(vec))
281 }
282}
283
/// Boxed, pinned stream of incoming [`WsMessage`]s (or [`StreamError`]s).
///
/// On non-wasm targets the boxed stream is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);

/// Boxed, pinned stream of incoming [`WsMessage`]s (or [`StreamError`]s).
///
/// The wasm32 variant does not require the boxed stream to be `Send`.
#[cfg(target_arch = "wasm32")]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
291
impl WsStream {
    /// Boxes and pins `stream` (non-wasm: the stream must be `Send`).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
    {
        Self(Box::pin(stream))
    }

    /// Boxes and pins `stream` (wasm32: no `Send` bound).
    #[cfg(target_arch = "wasm32")]
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
    {
        Self(Box::pin(stream))
    }

    /// Returns the boxed stream held by this wrapper.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
        self.0
    }

    /// Returns the boxed stream held by this wrapper.
    #[cfg(target_arch = "wasm32")]
    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
        self.0
    }

    /// Splits this stream into two streams that each receive a clone of every
    /// successfully received message.
    ///
    /// A task is spawned with `n0_future::task::spawn`; it reads the original
    /// stream and forwards each `Ok` message into two unbounded channels whose
    /// receivers back the returned streams. Forwarding stops when the source
    /// ends, when the source yields an `Err`, or when sending to *both*
    /// channels fails.
    ///
    /// NOTE(review): an upstream `Err` is discarded rather than forwarded, so
    /// both halves simply end and consumers cannot tell an error from a clean
    /// end of stream — confirm this is intended.
    ///
    /// NOTE(review): the channels are unbounded, so messages presumably
    /// accumulate for a half that is kept alive but not polled.
    pub fn tee(self) -> (WsStream, WsStream) {
        use futures::channel::mpsc;
        use n0_future::StreamExt as _;

        let (tx1, rx1) = mpsc::unbounded();
        let (tx2, rx2) = mpsc::unbounded();

        n0_future::task::spawn(async move {
            let mut stream = self.0;
            while let Some(result) = stream.next().await {
                match result {
                    Ok(msg) => {
                        let msg2 = msg.clone();

                        // Send to both halves before inspecting the results so
                        // that one failed send does not starve the other half.
                        let send1 = tx1.unbounded_send(Ok(msg));
                        let send2 = tx2.unbounded_send(Ok(msg2));

                        // Keep forwarding while at least one send succeeds.
                        if send1.is_err() && send2.is_err() {
                            break;
                        }
                    }
                    Err(_e) => {
                        // The error is dropped; leaving the loop drops both
                        // senders, which ends both returned streams.
                        break;
                    }
                }
            }
        });

        (WsStream::new(rx1), WsStream::new(rx2))
    }
}
365
366impl fmt::Debug for WsStream {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 f.debug_struct("WsStream").finish_non_exhaustive()
369 }
370}
371
/// Boxed, pinned sink accepting outgoing [`WsMessage`]s and failing with
/// [`StreamError`].
///
/// On non-wasm targets the boxed sink is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);

/// Boxed, pinned sink accepting outgoing [`WsMessage`]s and failing with
/// [`StreamError`].
///
/// The wasm32 variant does not require the boxed sink to be `Send`.
#[cfg(target_arch = "wasm32")]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
379
380impl WsSink {
381 #[cfg(not(target_arch = "wasm32"))]
383 pub fn new<S>(sink: S) -> Self
384 where
385 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
386 {
387 Self(Box::pin(sink))
388 }
389
390 #[cfg(target_arch = "wasm32")]
392 pub fn new<S>(sink: S) -> Self
393 where
394 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
395 {
396 Self(Box::pin(sink))
397 }
398
399 #[cfg(not(target_arch = "wasm32"))]
401 pub fn into_inner(
402 self,
403 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
404 self.0
405 }
406
407 #[cfg(target_arch = "wasm32")]
409 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
410 self.0
411 }
412
413 #[cfg(not(target_arch = "wasm32"))]
415 pub fn get_mut(
416 &mut self,
417 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
418 use std::borrow::BorrowMut;
419
420 self.0.borrow_mut()
421 }
422
423 #[cfg(target_arch = "wasm32")]
425 pub fn get_mut(
426 &mut self,
427 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
428 use std::borrow::BorrowMut;
429
430 self.0.borrow_mut()
431 }
432}
433
434impl fmt::Debug for WsSink {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 f.debug_struct("WsSink").finish_non_exhaustive()
437 }
438}
439
/// Abstraction over something that can open WebSocket connections.
///
/// On non-wasm targets the trait is processed by `trait_variant::make(Send)`.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait WebSocketClient: Sync {
    /// Error produced when establishing a connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Opens a WebSocket connection to `url`.
    fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;

    /// Opens a WebSocket connection to `url` with additional headers.
    ///
    /// The default implementation ignores `_headers` and simply calls
    /// [`WebSocketClient::connect`]; implementors that can send custom headers
    /// are expected to override it.
    fn connect_with_headers(
        &self,
        url: Url,
        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
        async move { self.connect(url).await }
    }
}
461
/// An established WebSocket connection: a sending half and a receiving half.
pub struct WebSocketConnection {
    /// Sink for outgoing messages.
    tx: WsSink,
    /// Stream of incoming messages.
    rx: WsStream,
}
467
468impl WebSocketConnection {
469 pub fn new(tx: WsSink, rx: WsStream) -> Self {
471 Self { tx, rx }
472 }
473
474 pub fn sender_mut(&mut self) -> &mut WsSink {
476 &mut self.tx
477 }
478
479 pub fn receiver_mut(&mut self) -> &mut WsStream {
481 &mut self.rx
482 }
483
484 pub fn receiver(&self) -> &WsStream {
486 &self.rx
487 }
488
489 pub fn sender(&self) -> &WsSink {
491 &self.tx
492 }
493
494 pub fn split(self) -> (WsSink, WsStream) {
496 (self.tx, self.rx)
497 }
498
499 pub fn is_open(&self) -> bool {
501 true
502 }
503}
504
505impl fmt::Debug for WebSocketConnection {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 f.debug_struct("WebSocketConnection")
508 .finish_non_exhaustive()
509 }
510}
511
512pub mod tungstenite_client {
514 use super::*;
515 use crate::IntoStatic;
516 use futures::{SinkExt, StreamExt};
517
518 #[derive(Debug, Clone, Default)]
520 pub struct TungsteniteClient;
521
522 impl TungsteniteClient {
523 pub fn new() -> Self {
525 Self
526 }
527 }
528
529 impl WebSocketClient for TungsteniteClient {
530 type Error = tokio_tungstenite_wasm::Error;
531
532 async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
533 let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?;
534
535 let (sink, stream) = ws_stream.split();
536
537 let rx_stream = stream.filter_map(|result| async move {
539 match result {
540 Ok(msg) => match convert_message(msg) {
541 Some(ws_msg) => Some(Ok(ws_msg)),
542 None => None, },
544 Err(e) => Some(Err(StreamError::transport(e))),
545 }
546 });
547
548 let rx = WsStream::new(rx_stream);
549
550 let tx_sink = sink.with(|msg: WsMessage| async move {
552 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
553 });
554
555 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
556 let tx = WsSink::new(tx_sink_mapped);
557
558 Ok(WebSocketConnection::new(tx, rx))
559 }
560 }
561
562 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
565 use tokio_tungstenite_wasm::Message;
566
567 match msg {
568 Message::Text(vec) => {
569 let bytes = Bytes::from(vec);
571 Some(WsMessage::Text(unsafe {
572 WsText::from_bytes_unchecked(bytes)
573 }))
574 }
575 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
576 Message::Close(frame) => {
577 let close_frame = frame.map(|f| {
578 let code = convert_close_code(f.code);
579 CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
580 });
581 Some(WsMessage::Close(close_frame))
582 }
583 }
584 }
585
586 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
588 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
589
590 match code {
591 TungsteniteCode::Normal => CloseCode::Normal,
592 TungsteniteCode::Away => CloseCode::Away,
593 TungsteniteCode::Protocol => CloseCode::Protocol,
594 TungsteniteCode::Unsupported => CloseCode::Unsupported,
595 TungsteniteCode::Invalid => CloseCode::Invalid,
596 TungsteniteCode::Policy => CloseCode::Policy,
597 TungsteniteCode::Size => CloseCode::Size,
598 TungsteniteCode::Extension => CloseCode::Extension,
599 TungsteniteCode::Error => CloseCode::Error,
600 TungsteniteCode::Tls => CloseCode::Tls,
601 other => {
603 let raw: u16 = other.into();
604 CloseCode::from(raw)
605 }
606 }
607 }
608
609 impl From<WsMessage> for tokio_tungstenite_wasm::Message {
610 fn from(msg: WsMessage) -> Self {
611 use tokio_tungstenite_wasm::Message;
612
613 match msg {
614 WsMessage::Text(text) => {
615 let bytes = text.into_bytes();
617 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
619 Message::Text(string)
620 }
621 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
622 WsMessage::Close(frame) => {
623 let close_frame = frame.map(|f| {
624 let code = u16::from(f.code).into();
625 tokio_tungstenite_wasm::CloseFrame {
626 code,
627 reason: f.reason.into_static().to_string().into(),
628 }
629 });
630 Message::Close(close_frame)
631 }
632 }
633 }
634 }
635}
636
/// Unit tests for the message types and the connection wrapper.
#[cfg(test)]
mod tests {
    use super::*;

    /// A `&str` converts into a `WsText` with the same contents.
    #[test]
    fn ws_text_from_string() {
        let text: WsText = "hello".into();
        assert_eq!(text.as_str(), "hello");
    }

    /// `WsText` dereferences to the `str` it was built from.
    #[test]
    fn ws_text_deref() {
        let owned = String::from("world");
        let text = WsText::from(owned);
        assert_eq!(&*text, "world");
    }

    /// Valid UTF-8 bytes are accepted by `TryFrom<Bytes>`.
    #[test]
    fn ws_text_try_from_bytes() {
        let bytes = Bytes::from("test");
        let converted = WsText::try_from(bytes);
        let text = converted.unwrap();
        assert_eq!(text.as_str(), "test");
    }

    /// Invalid UTF-8 bytes are rejected by `TryFrom<Bytes>`.
    #[test]
    fn ws_text_invalid_utf8() {
        let bytes = Bytes::from(vec![0xFF, 0xFE]);
        let converted = WsText::try_from(bytes);
        assert!(converted.is_err());
    }

    /// A `&str` becomes a text message whose text can be read back.
    #[test]
    fn ws_message_text() {
        let msg: WsMessage = "hello".into();
        assert!(msg.is_text());
        assert_eq!(msg.as_text(), Some("hello"));
    }

    /// A byte vector becomes a binary message exposing the same bytes.
    #[test]
    fn ws_message_binary() {
        let payload = vec![1, 2, 3];
        let msg = WsMessage::from(payload);
        assert!(msg.is_binary());
        assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..]));
    }

    /// Close codes convert to and from their numeric values.
    #[test]
    fn close_code_conversion() {
        let normal: u16 = CloseCode::Normal.into();
        assert_eq!(normal, 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    /// A connection can be assembled from an in-memory stream and sink.
    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let incoming = vec![Ok(WsMessage::from("test"))];
        let rx = WsStream::new(stream::iter(incoming));

        let drain_sink = futures::sink::drain()
            .sink_map_err(|_: std::convert::Infallible| StreamError::closed());
        let tx = WsSink::new(drain_sink);

        let conn = WebSocketConnection::new(tx, rx);
        assert!(conn.is_open());
    }
}