1use crate::CowStr;
4use crate::deps::fluent_uri::Uri;
5use crate::stream::StreamError;
6use alloc::boxed::Box;
7use alloc::string::String;
8use alloc::string::ToString;
9use alloc::vec::Vec;
10use bytes::Bytes;
11use core::borrow::Borrow;
12use core::fmt::{self, Display};
13use core::future::Future;
14use core::ops::Deref;
15use core::pin::Pin;
16use n0_future::Stream;
17
18#[repr(transparent)]
20#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
21pub struct WsText(Bytes);
22
23impl WsText {
24 pub const fn from_static(s: &'static str) -> Self {
26 Self(Bytes::from_static(s.as_bytes()))
27 }
28
29 pub fn as_str(&self) -> &str {
31 unsafe { core::str::from_utf8_unchecked(&self.0) }
32 }
33
34 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
39 Self(bytes)
40 }
41
42 pub fn into_bytes(self) -> Bytes {
44 self.0
45 }
46}
47
48impl Deref for WsText {
49 type Target = str;
50 fn deref(&self) -> &str {
51 self.as_str()
52 }
53}
54
55impl AsRef<str> for WsText {
56 fn as_ref(&self) -> &str {
57 self.as_str()
58 }
59}
60
61impl AsRef<[u8]> for WsText {
62 fn as_ref(&self) -> &[u8] {
63 &self.0
64 }
65}
66
67impl AsRef<Bytes> for WsText {
68 fn as_ref(&self) -> &Bytes {
69 &self.0
70 }
71}
72
73impl Borrow<str> for WsText {
74 fn borrow(&self) -> &str {
75 self.as_str()
76 }
77}
78
79impl Display for WsText {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 Display::fmt(self.as_str(), f)
82 }
83}
84
85impl From<String> for WsText {
86 fn from(s: String) -> Self {
87 Self(Bytes::from(s))
88 }
89}
90
91impl From<&str> for WsText {
92 fn from(s: &str) -> Self {
93 Self(Bytes::copy_from_slice(s.as_bytes()))
94 }
95}
96
97impl From<&String> for WsText {
98 fn from(s: &String) -> Self {
99 Self::from(s.as_str())
100 }
101}
102
103impl TryFrom<Bytes> for WsText {
104 type Error = core::str::Utf8Error;
105 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
106 core::str::from_utf8(&bytes)?;
107 Ok(Self(bytes))
108 }
109}
110
111impl TryFrom<Vec<u8>> for WsText {
112 type Error = core::str::Utf8Error;
113 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
114 Self::try_from(Bytes::from(vec))
115 }
116}
117
118impl From<WsText> for Bytes {
119 fn from(t: WsText) -> Bytes {
120 t.0
121 }
122}
123
124impl Default for WsText {
125 fn default() -> Self {
126 Self(Bytes::new())
127 }
128}
129
/// Status code carried in a WebSocket close frame (meanings per RFC 6455, section 7.4).
///
/// Named variants cover the codes that `From<u16>` recognises; every other value is kept
/// verbatim in [`CloseCode::Other`]. `#[repr(u16)]` only fixes the discriminant type. The
/// `Other` variant still carries its own `u16` payload, so a `CloseCode` is not a plain `u16`.
///
/// Note: the derived `PartialEq`/`Hash` treat e.g. `Other(1000)` and `Normal` as different
/// values even though both convert to `1000`. `CloseCode::from(u16)` never produces `Other`
/// for a named code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000: normal closure.
    Normal = 1000,
    /// 1001: endpoint is going away.
    Away = 1001,
    /// 1002: protocol error.
    Protocol = 1002,
    /// 1003: received a type of data it cannot accept.
    Unsupported = 1003,
    /// 1007: payload inconsistent with the message type (e.g. non-UTF-8 text).
    Invalid = 1007,
    /// 1008: policy violation.
    Policy = 1008,
    /// 1009: message too big to process.
    Size = 1009,
    /// 1010: client expected the server to negotiate an extension.
    Extension = 1010,
    /// 1011: unexpected condition on the server.
    Error = 1011,
    /// 1015: TLS handshake failure.
    Tls = 1015,
    /// Any code without a named variant, stored as-is.
    Other(u16),
}
157
158impl From<u16> for CloseCode {
159 fn from(code: u16) -> Self {
160 match code {
161 1000 => CloseCode::Normal,
162 1001 => CloseCode::Away,
163 1002 => CloseCode::Protocol,
164 1003 => CloseCode::Unsupported,
165 1007 => CloseCode::Invalid,
166 1008 => CloseCode::Policy,
167 1009 => CloseCode::Size,
168 1010 => CloseCode::Extension,
169 1011 => CloseCode::Error,
170 1015 => CloseCode::Tls,
171 other => CloseCode::Other(other),
172 }
173 }
174}
175
176impl From<CloseCode> for u16 {
177 fn from(code: CloseCode) -> u16 {
178 match code {
179 CloseCode::Normal => 1000,
180 CloseCode::Away => 1001,
181 CloseCode::Protocol => 1002,
182 CloseCode::Unsupported => 1003,
183 CloseCode::Invalid => 1007,
184 CloseCode::Policy => 1008,
185 CloseCode::Size => 1009,
186 CloseCode::Extension => 1010,
187 CloseCode::Error => 1011,
188 CloseCode::Tls => 1015,
189 CloseCode::Other(code) => code,
190 }
191 }
192}
193
/// Payload of a WebSocket close message: a status code plus a textual reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseFrame<'a> {
    /// Close status code.
    pub code: CloseCode,
    /// Reason text accompanying the status code.
    pub reason: CowStr<'a>,
}
202
203impl<'a> CloseFrame<'a> {
204 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
206 Self {
207 code,
208 reason: reason.into(),
209 }
210 }
211}
212
/// A WebSocket message: text, binary data, or a close notification.
///
/// There are no ping/pong variants; only these three kinds are represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// UTF-8 text message.
    Text(WsText),
    /// Binary message.
    Binary(Bytes),
    /// Close message, optionally carrying a status code and reason.
    Close(Option<CloseFrame<'static>>),
}
223
224impl WsMessage {
225 pub fn is_text(&self) -> bool {
227 matches!(self, WsMessage::Text(_))
228 }
229
230 pub fn is_binary(&self) -> bool {
232 matches!(self, WsMessage::Binary(_))
233 }
234
235 pub fn is_close(&self) -> bool {
237 matches!(self, WsMessage::Close(_))
238 }
239
240 pub fn as_text(&self) -> Option<&str> {
242 match self {
243 WsMessage::Text(t) => Some(t.as_str()),
244 _ => None,
245 }
246 }
247
248 pub fn as_bytes(&self) -> Option<&[u8]> {
250 match self {
251 WsMessage::Text(t) => Some(t.as_ref()),
252 WsMessage::Binary(b) => Some(b),
253 WsMessage::Close(_) => None,
254 }
255 }
256}
257
258impl From<WsText> for WsMessage {
259 fn from(text: WsText) -> Self {
260 WsMessage::Text(text)
261 }
262}
263
264impl From<String> for WsMessage {
265 fn from(s: String) -> Self {
266 WsMessage::Text(WsText::from(s))
267 }
268}
269
270impl From<&str> for WsMessage {
271 fn from(s: &str) -> Self {
272 WsMessage::Text(WsText::from(s))
273 }
274}
275
276impl From<Bytes> for WsMessage {
277 fn from(bytes: Bytes) -> Self {
278 WsMessage::Binary(bytes)
279 }
280}
281
282impl From<Vec<u8>> for WsMessage {
283 fn from(vec: Vec<u8>) -> Self {
284 WsMessage::Binary(Bytes::from(vec))
285 }
286}
287
/// Type-erased, boxed and pinned stream of `Result<WsMessage, StreamError>` items.
///
/// On non-`wasm32` targets the wrapped stream must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);

/// Type-erased, boxed and pinned stream of `Result<WsMessage, StreamError>` items.
///
/// `wasm32` variant: the wrapped stream has no `Send` bound.
#[cfg(target_arch = "wasm32")]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
295
296impl WsStream {
297 #[cfg(not(target_arch = "wasm32"))]
299 pub fn new<S>(stream: S) -> Self
300 where
301 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
302 {
303 Self(Box::pin(stream))
304 }
305
306 #[cfg(target_arch = "wasm32")]
308 pub fn new<S>(stream: S) -> Self
309 where
310 S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
311 {
312 Self(Box::pin(stream))
313 }
314
315 #[cfg(not(target_arch = "wasm32"))]
317 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
318 self.0
319 }
320
321 #[cfg(target_arch = "wasm32")]
323 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
324 self.0
325 }
326
327 pub fn tee(self) -> (WsStream, WsStream) {
334 use futures::channel::mpsc;
335 use n0_future::StreamExt as _;
336
337 let (tx1, rx1) = mpsc::unbounded();
338 let (tx2, rx2) = mpsc::unbounded();
339
340 n0_future::task::spawn(async move {
341 let mut stream = self.0;
342 while let Some(result) = stream.next().await {
343 match result {
344 Ok(msg) => {
345 let msg2 = msg.clone();
347
348 let send1 = tx1.unbounded_send(Ok(msg));
350 let send2 = tx2.unbounded_send(Ok(msg2));
351
352 if send1.is_err() && send2.is_err() {
354 break;
355 }
356 }
357 Err(_e) => {
358 break;
361 }
362 }
363 }
364 });
365
366 (WsStream::new(rx1), WsStream::new(rx2))
367 }
368}
369
370impl fmt::Debug for WsStream {
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 f.debug_struct("WsStream").finish_non_exhaustive()
373 }
374}
375
/// Type-erased, boxed and pinned sink accepting [`WsMessage`]s, failing with [`StreamError`].
///
/// On non-`wasm32` targets the wrapped sink must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);

/// Type-erased, boxed and pinned sink accepting [`WsMessage`]s, failing with [`StreamError`].
///
/// `wasm32` variant: the wrapped sink has no `Send` bound.
#[cfg(target_arch = "wasm32")]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
383
384impl WsSink {
385 #[cfg(not(target_arch = "wasm32"))]
387 pub fn new<S>(sink: S) -> Self
388 where
389 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
390 {
391 Self(Box::pin(sink))
392 }
393
394 #[cfg(target_arch = "wasm32")]
396 pub fn new<S>(sink: S) -> Self
397 where
398 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
399 {
400 Self(Box::pin(sink))
401 }
402
403 #[cfg(not(target_arch = "wasm32"))]
405 pub fn into_inner(
406 self,
407 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
408 self.0
409 }
410
411 #[cfg(target_arch = "wasm32")]
413 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
414 self.0
415 }
416
417 #[cfg(not(target_arch = "wasm32"))]
419 pub fn get_mut(
420 &mut self,
421 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
422 use core::borrow::BorrowMut;
423
424 self.0.borrow_mut()
425 }
426
427 #[cfg(target_arch = "wasm32")]
429 pub fn get_mut(
430 &mut self,
431 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
432 use core::borrow::BorrowMut;
433
434 self.0.borrow_mut()
435 }
436}
437
438impl fmt::Debug for WsSink {
439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 f.debug_struct("WsSink").finish_non_exhaustive()
441 }
442}
443
/// A client that can open WebSocket connections.
///
/// On non-`wasm32` targets `trait_variant::make(Send)` is applied to this trait, which is
/// intended to require the returned futures to be `Send` (see the `trait_variant` docs). On
/// `wasm32` no such bound is added. Implementors must be `Sync`.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait WebSocketClient: Sync {
    /// Error returned when a connection cannot be established.
    type Error: core::error::Error + Send + Sync + 'static;

    /// Opens a WebSocket connection to `uri`.
    fn connect(
        &self,
        uri: Uri<&str>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;

    /// Opens a WebSocket connection to `uri` with extra headers.
    ///
    /// Presumably `_headers` are (name, value) pairs for the handshake request; confirm with
    /// implementors. The default implementation ignores `_headers` entirely and just calls
    /// [`connect`](Self::connect). Implementations that can send custom headers should
    /// override it.
    fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
        async move { self.connect(uri).await }
    }
}
468
/// A WebSocket connection represented as a sending half ([`WsSink`]) and a receiving half
/// ([`WsStream`]).
pub struct WebSocketConnection {
    // Sending half; exposed via `sender`, `sender_mut` and `split`.
    tx: WsSink,
    // Receiving half; exposed via `receiver`, `receiver_mut` and `split`.
    rx: WsStream,
}
474
475impl WebSocketConnection {
476 pub fn new(tx: WsSink, rx: WsStream) -> Self {
478 Self { tx, rx }
479 }
480
481 pub fn sender_mut(&mut self) -> &mut WsSink {
483 &mut self.tx
484 }
485
486 pub fn receiver_mut(&mut self) -> &mut WsStream {
488 &mut self.rx
489 }
490
491 pub fn receiver(&self) -> &WsStream {
493 &self.rx
494 }
495
496 pub fn sender(&self) -> &WsSink {
498 &self.tx
499 }
500
501 pub fn split(self) -> (WsSink, WsStream) {
503 (self.tx, self.rx)
504 }
505
506 pub fn is_open(&self) -> bool {
508 true
509 }
510}
511
512impl fmt::Debug for WebSocketConnection {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 f.debug_struct("WebSocketConnection")
515 .finish_non_exhaustive()
516 }
517}
518
519pub mod tungstenite_client {
521 use super::*;
522 use crate::IntoStatic;
523 use futures::{SinkExt, StreamExt};
524
525 #[derive(Debug, Clone, Default)]
527 pub struct TungsteniteClient;
528
529 impl TungsteniteClient {
530 pub fn new() -> Self {
532 Self
533 }
534 }
535
536 impl WebSocketClient for TungsteniteClient {
537 type Error = tokio_tungstenite_wasm::Error;
538
539 async fn connect(&self, uri: Uri<&str>) -> Result<WebSocketConnection, Self::Error> {
540 let ws_stream = tokio_tungstenite_wasm::connect(uri.as_str()).await?;
541
542 let (sink, stream) = ws_stream.split();
543
544 let rx_stream = stream.filter_map(|result| async move {
546 match result {
547 Ok(msg) => match convert_message(msg) {
548 Some(ws_msg) => Some(Ok(ws_msg)),
549 None => None, },
551 Err(e) => Some(Err(StreamError::transport(e))),
552 }
553 });
554
555 let rx = WsStream::new(rx_stream);
556
557 let tx_sink = sink.with(|msg: WsMessage| async move {
559 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
560 });
561
562 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
563 let tx = WsSink::new(tx_sink_mapped);
564
565 Ok(WebSocketConnection::new(tx, rx))
566 }
567 }
568
569 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
572 use tokio_tungstenite_wasm::Message;
573
574 match msg {
575 Message::Text(vec) => {
576 let bytes = Bytes::from(vec);
578 Some(WsMessage::Text(unsafe {
579 WsText::from_bytes_unchecked(bytes)
580 }))
581 }
582 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
583 Message::Close(frame) => {
584 let close_frame = frame.map(|f| {
585 let code = convert_close_code(f.code);
586 CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
587 });
588 Some(WsMessage::Close(close_frame))
589 }
590 }
591 }
592
593 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
595 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
596
597 match code {
598 TungsteniteCode::Normal => CloseCode::Normal,
599 TungsteniteCode::Away => CloseCode::Away,
600 TungsteniteCode::Protocol => CloseCode::Protocol,
601 TungsteniteCode::Unsupported => CloseCode::Unsupported,
602 TungsteniteCode::Invalid => CloseCode::Invalid,
603 TungsteniteCode::Policy => CloseCode::Policy,
604 TungsteniteCode::Size => CloseCode::Size,
605 TungsteniteCode::Extension => CloseCode::Extension,
606 TungsteniteCode::Error => CloseCode::Error,
607 TungsteniteCode::Tls => CloseCode::Tls,
608 other => {
610 let raw: u16 = other.into();
611 CloseCode::from(raw)
612 }
613 }
614 }
615
616 impl From<WsMessage> for tokio_tungstenite_wasm::Message {
617 fn from(msg: WsMessage) -> Self {
618 use tokio_tungstenite_wasm::Message;
619
620 match msg {
621 WsMessage::Text(text) => {
622 let bytes = text.into_bytes();
624 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
626 Message::Text(string)
627 }
628 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
629 WsMessage::Close(frame) => {
630 let close_frame = frame.map(|f| {
631 let code = u16::from(f.code).into();
632 tokio_tungstenite_wasm::CloseFrame {
633 code,
634 reason: f.reason.into_static().to_string().into(),
635 }
636 });
637 Message::Close(close_frame)
638 }
639 }
640 }
641 }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ws_text_from_str() {
        let text = WsText::from("hello");
        assert_eq!(text.as_str(), "hello");
    }

    #[test]
    fn ws_text_from_string() {
        let owned = String::from("hello");
        assert_eq!(WsText::from(&owned).as_str(), "hello");
        assert_eq!(WsText::from(owned).as_str(), "hello");
    }

    #[test]
    fn ws_text_deref() {
        let text = WsText::from(String::from("world"));
        assert_eq!(&*text, "world");
    }

    #[test]
    fn ws_text_default_is_empty() {
        assert!(WsText::default().is_empty());
        assert_eq!(WsText::default(), WsText::from_static(""));
    }

    #[test]
    fn ws_text_try_from_bytes() {
        let bytes = Bytes::from("test");
        let text = WsText::try_from(bytes).unwrap();
        assert_eq!(text.as_str(), "test");
    }

    #[test]
    fn ws_text_invalid_utf8() {
        let bytes = Bytes::from(vec![0xFF, 0xFE]);
        assert!(WsText::try_from(bytes).is_err());
        assert!(WsText::try_from(vec![0xFFu8, 0xFE]).is_err());
    }

    #[test]
    fn ws_text_into_bytes_round_trip() {
        let text = WsText::from("abc");
        assert_eq!(Bytes::from(text.clone()), Bytes::from_static(b"abc"));
        assert_eq!(text.into_bytes(), Bytes::from_static(b"abc"));
    }

    #[test]
    fn ws_message_text() {
        let msg = WsMessage::from("hello");
        assert!(msg.is_text());
        assert!(!msg.is_binary());
        assert_eq!(msg.as_text(), Some("hello"));
        assert_eq!(msg.as_bytes(), Some(&b"hello"[..]));
    }

    #[test]
    fn ws_message_binary() {
        let msg = WsMessage::from(vec![1, 2, 3]);
        assert!(msg.is_binary());
        assert_eq!(msg.as_text(), None);
        assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..]));
    }

    #[test]
    fn ws_message_close() {
        let msg = WsMessage::Close(None);
        assert!(msg.is_close());
        assert_eq!(msg.as_text(), None);
        assert_eq!(msg.as_bytes(), None);
    }

    #[test]
    fn close_code_conversion() {
        assert_eq!(u16::from(CloseCode::Normal), 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    #[test]
    fn close_code_round_trips_through_u16() {
        for raw in [1000u16, 1001, 1002, 1003, 1005, 1007, 1008, 1009, 1010, 1011, 1015, 4000] {
            assert_eq!(u16::from(CloseCode::from(raw)), raw);
        }
    }

    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]);
        let rx = WsStream::new(rx_stream);

        // `core` path: this file only relies on `core`/`alloc` elsewhere.
        let drain_sink = futures::sink::drain()
            .sink_map_err(|_: core::convert::Infallible| StreamError::closed());
        let tx = WsSink::new(drain_sink);

        let conn = WebSocketConnection::new(tx, rx);
        assert!(conn.is_open());
    }
}