1#![allow(unreachable_code)]
2#![allow(unused_imports)]
3
4use crate::{ClientRequest, Encoding, FromResponse, IntoRequest, JsonEncoding, ServerFnError};
25use axum::{
26 extract::{FromRequest, Request},
27 http::StatusCode,
28};
29use axum_core::response::{IntoResponse, Response};
30use bytes::Bytes;
31use dioxus_core::{use_hook, CapturedError, Result};
32use dioxus_fullstack_core::{HttpError, RequestError};
33use dioxus_hooks::{use_resource, Resource, UseWaker};
34use dioxus_hooks::{use_signal, use_waker};
35use dioxus_signals::{ReadSignal, ReadableExt, ReadableOptionExt, Signal, WritableExt};
36use futures::StreamExt;
37use futures::{
38 stream::{SplitSink, SplitStream},
39 SinkExt, TryFutureExt,
40};
41use serde::{de::DeserializeOwned, Serialize};
42use std::{marker::PhantomData, prelude::rust_2024::Future};
43
44#[cfg(feature = "web")]
45use {
46 futures_util::lock::Mutex,
47 gloo_net::websocket::{futures::WebSocket as WsWebsocket, Message as WsMessage},
48};
49
50pub fn use_websocket<
62 In: 'static,
63 Out: 'static,
64 E: Into<CapturedError> + 'static,
65 F: Future<Output = Result<Websocket<In, Out, Enc>, E>> + 'static,
66 Enc: Encoding,
67>(
68 mut connect_to_websocket: impl FnMut() -> F + 'static,
69) -> UseWebsocket<In, Out, Enc> {
70 let mut waker = use_waker();
71 let mut status = use_signal(|| WebsocketState::Connecting);
72 let status_read = use_hook(|| ReadSignal::new(status));
73
74 let connection = use_resource(move || {
75 let connection = connect_to_websocket().map_err(|e| e.into());
76 async move {
77 let connection = connection.await;
78
79 match connection.as_ref() {
81 Ok(_) => status.set(WebsocketState::Open),
82 Err(_) => status.set(WebsocketState::FailedToConnect),
83 }
84
85 waker.wake(());
87
88 connection
89 }
90 });
91
92 UseWebsocket {
93 connection,
94 waker,
95 status,
96 status_read,
97 }
98}
99
/// Handle returned by [`use_websocket`].
///
/// Bundles the connection resource with the signals describing its state.
/// `In` is the type accepted by `send`, `Out` the type produced by `recv`, and `Enc`
/// the encoding used to serialize both (defaults to `JsonEncoding`).
pub struct UseWebsocket<In, Out, Enc = JsonEncoding>
where
    In: 'static,
    Out: 'static,
    Enc: 'static,
{
    /// Result of the connection attempt; unset until the connect future resolves.
    connection: Resource<Result<Websocket<In, Out, Enc>, CapturedError>>,
    /// Fired whenever the connection state changes so that `connect()` can resume.
    waker: UseWaker<()>,
    /// Writable connection state.
    status: Signal<WebsocketState>,
    /// Read-only view of `status`, handed out by `status()`.
    status_read: ReadSignal<WebsocketState>,
}
117
impl<In, Out, E> UseWebsocket<In, Out, E> {
    /// Waits until the connection resource has finished, then returns the current state.
    ///
    /// Loops on the waker while `connection.finished()` is false, so it returns
    /// immediately once the connection attempt has already completed.
    pub async fn connect(&self) -> WebsocketState {
        while !self.connection.finished() {
            _ = self.waker.wait().await;
        }

        self.status.cloned()
    }

    /// Returns `true` while the state is [`WebsocketState::Connecting`].
    pub fn connecting(&self) -> bool {
        matches!(self.status.cloned(), WebsocketState::Connecting)
    }

    /// Returns `true` if the connection attempt failed.
    pub fn is_err(&self) -> bool {
        matches!(self.status.cloned(), WebsocketState::FailedToConnect)
    }

    /// Returns `true` if the socket is closed or never managed to connect.
    pub fn is_closed(&self) -> bool {
        matches!(
            self.status.cloned(),
            WebsocketState::Closed | WebsocketState::FailedToConnect
        )
    }

    /// Returns a read-only signal tracking the connection state.
    pub fn status(&self) -> ReadSignal<WebsocketState> {
        self.status_read
    }

    /// Sends a raw [`Message`] after waiting for the connection attempt to finish.
    ///
    /// # Errors
    ///
    /// * the error from [`WebsocketError::closed_away`] if the resource holds no value,
    /// * [`WebsocketError::AlreadyClosed`] if the connection attempt failed,
    /// * whatever [`Websocket::send_raw`] returns otherwise.
    pub async fn send_raw(&self, msg: Message) -> Result<(), WebsocketError> {
        self.connect().await;

        self.connection
            .as_ref()
            .as_deref()
            .ok_or_else(WebsocketError::closed_away)?
            .as_ref()
            .map_err(|_| WebsocketError::AlreadyClosed)?
            .send_raw(msg)
            .await
    }

    /// Receives the next raw [`Message`] after waiting for the connection attempt to finish.
    ///
    /// If the underlying socket reports [`WebsocketError::ConnectionClosed`], the state
    /// is switched to [`WebsocketState::Closed`] before the error is returned.
    ///
    /// # Errors
    ///
    /// Same connection errors as [`Self::send_raw`], plus whatever
    /// [`Websocket::recv_raw`] returns.
    pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
        self.connect().await;

        let result = self
            .connection
            .as_ref()
            .as_deref()
            .ok_or_else(WebsocketError::closed_away)?
            .as_ref()
            .map_err(|_| WebsocketError::AlreadyClosed)?
            .recv_raw()
            .await;

        // Record a peer-initiated shutdown so `is_closed()` reflects it.
        if let Err(WebsocketError::ConnectionClosed { .. }) = result.as_ref() {
            self.received_shutdown();
        }

        result
    }

    /// Serializes `msg` with the encoding `E` and sends it as a binary message.
    ///
    /// # Errors
    ///
    /// Returns a serialization error if `E::to_bytes` yields `None`, otherwise the
    /// same errors as [`Self::send_raw`].
    pub async fn send(&self, msg: In) -> Result<(), WebsocketError>
    where
        In: Serialize,
        E: Encoding,
    {
        self.send_raw(Message::Binary(
            E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?,
        ))
        .await
    }

    /// Receives and decodes the next message using the encoding `E`.
    ///
    /// If the underlying socket reports [`WebsocketError::ConnectionClosed`], the state
    /// is switched to [`WebsocketState::Closed`] before the error is returned.
    ///
    /// # Errors
    ///
    /// Same connection errors as [`Self::send_raw`], plus whatever
    /// [`Websocket::recv`] returns (including decode failures).
    pub async fn recv(&mut self) -> Result<Out, WebsocketError>
    where
        Out: DeserializeOwned,
        E: Encoding,
    {
        self.connect().await;

        let result = self
            .connection
            .as_ref()
            .as_deref()
            .ok_or_else(WebsocketError::closed_away)?
            .as_ref()
            .map_err(|_| WebsocketError::AlreadyClosed)?
            .recv()
            .await;

        // Record a peer-initiated shutdown so `is_closed()` reflects it.
        if let Err(WebsocketError::ConnectionClosed { .. }) = result.as_ref() {
            self.received_shutdown();
        }

        result
    }

    /// Replaces the stored connection with `socket` and updates the state to match.
    ///
    /// The state becomes [`WebsocketState::Open`] for `Ok` and
    /// [`WebsocketState::FailedToConnect`] for `Err`; waiters are woken afterwards.
    pub fn set<Err: Into<CapturedError>>(&mut self, socket: Result<Websocket<In, Out, E>, Err>) {
        // The wildcard patterns do not move out of `socket`, so it can be stored below.
        match socket {
            Ok(_) => self.status.set(WebsocketState::Open),
            Err(_) => self.status.set(WebsocketState::FailedToConnect),
        }

        self.connection.set(Some(socket.map_err(|e| e.into())));
        self.waker.wake(());
    }

    /// Marks the socket as closed and wakes any waiters.
    fn received_shutdown(&self) {
        // `UseWebsocket` is `Copy`; a local copy gives the mutable binding the setters need.
        let mut _self = *self;
        _self.status.set(WebsocketState::Closed);
        _self.waker.wake(());
    }
}
257
// Implemented by hand rather than derived: these impls place no `Copy`/`Clone`
// bounds on `In`, `Out` or `E`.
impl<In, Out, E> Copy for UseWebsocket<In, Out, E> {}
impl<In, Out, E> Clone for UseWebsocket<In, Out, E> {
    fn clone(&self) -> Self {
        *self
    }
}
264
/// Lifecycle state of a websocket managed by [`use_websocket`].
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum WebsocketState {
    /// The connection attempt has not finished yet (initial state).
    Connecting,

    /// The connection attempt succeeded.
    Open,

    /// NOTE(review): nothing in this file sets this state; presumably reserved for an
    /// in-progress close handshake — confirm.
    Closing,

    /// A receive reported that the connection was closed.
    Closed,

    /// The connection attempt resolved with an error.
    FailedToConnect,
}
282
/// A typed websocket.
///
/// `In` is the type accepted by `send`, `Out` the type produced by `recv`, and `E` the
/// encoding used for both. The same type is used on the client (where it wraps a
/// platform socket) and as a server-function return value (where it only carries the
/// HTTP upgrade response).
pub struct Websocket<In = String, Out = String, E = JsonEncoding> {
    /// Subprotocol recorded for the connection, if any.
    protocol: Option<String>,

    /// Ties the type parameters to the struct without storing values of those types.
    #[allow(clippy::type_complexity)]
    _in: std::marker::PhantomData<fn() -> (In, Out, E)>,

    /// Client socket on non-wasm targets; `None` on values built by `on_upgrade`.
    #[cfg(not(target_arch = "wasm32"))]
    native: Option<native::SplitSocket>,

    /// Browser socket when the `web` feature is enabled.
    #[cfg(feature = "web")]
    web: Option<WebsysSocket>,

    /// Upgrade response returned by `into_response`; `None` on client-side values.
    response: Option<axum::response::Response>,
}
298
299impl<I, O, E> Websocket<I, O, E> {
300 pub async fn recv(&self) -> Result<O, WebsocketError>
301 where
302 O: DeserializeOwned,
303 E: Encoding,
304 {
305 loop {
306 let msg = self.recv_raw().await?;
307 match msg {
308 Message::Text(text) => {
309 let e: O =
310 E::decode(text.into()).ok_or_else(WebsocketError::deserialization)?;
311 return Ok(e);
312 }
313 Message::Binary(bytes) => {
314 let e: O = E::decode(bytes).ok_or_else(WebsocketError::deserialization)?;
315 return Ok(e);
316 }
317 Message::Close { code, reason } => {
318 return Err(WebsocketError::ConnectionClosed {
319 code,
320 description: reason,
321 });
322 }
323
324 Message::Ping(_bytes) => continue,
326 Message::Pong(_bytes) => continue,
327 }
328 }
329 }
330
331 pub async fn send(&self, msg: I) -> Result<(), WebsocketError>
337 where
338 I: Serialize,
339 E: Encoding,
340 {
341 let bytes = E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?;
342 self.send_raw(Message::Binary(bytes)).await
343 }
344
345 pub async fn send_raw(&self, message: Message) -> Result<(), WebsocketError> {
349 #[cfg(feature = "web")]
350 if cfg!(target_arch = "wasm32") {
351 let mut sender = self
352 .web
353 .as_ref()
354 .ok_or_else(|| WebsocketError::Uninitialized)?
355 .sender
356 .lock()
357 .await;
358
359 match message {
360 Message::Text(s) => {
361 sender.send(gloo_net::websocket::Message::Text(s)).await?;
362 }
363 Message::Binary(bytes) => {
364 sender
365 .send(gloo_net::websocket::Message::Bytes(bytes.into()))
366 .await?;
367 }
368 Message::Close { .. } => {
369 sender.close().await?;
370 }
371 Message::Ping(_bytes) => return Ok(()),
372 Message::Pong(_bytes) => return Ok(()),
373 }
374
375 return Ok(());
376 }
377
378 #[cfg(not(target_arch = "wasm32"))]
379 {
380 let mut sender = self
381 .native
382 .as_ref()
383 .ok_or_else(|| WebsocketError::Uninitialized)?
384 .sender
385 .lock()
386 .await;
387
388 sender
389 .send(message.into())
390 .await
391 .map_err(WebsocketError::from)?;
392 }
393
394 Ok(())
395 }
396
397 pub async fn recv_raw(&self) -> Result<Message, WebsocketError> {
399 #[cfg(feature = "web")]
400 if cfg!(target_arch = "wasm32") {
401 let mut conn = self.web.as_ref().unwrap().receiver.lock().await;
402 return match conn.next().await {
403 Some(Ok(WsMessage::Text(text))) => Ok(Message::Text(text)),
404 Some(Ok(WsMessage::Bytes(items))) => Ok(Message::Binary(items.into())),
405 Some(Err(e)) => Err(WebsocketError::from(e)),
406 None => Err(WebsocketError::closed_away()),
407 };
408 }
409
410 #[cfg(not(target_arch = "wasm32"))]
411 {
412 use tungstenite::Message as TMessage;
413 let mut conn = self.native.as_ref().unwrap().receiver.lock().await;
414 return match conn.next().await {
415 Some(Ok(res)) => match res {
416 TMessage::Text(utf8_bytes) => Ok(Message::Text(utf8_bytes.to_string())),
417 TMessage::Binary(bytes) => Ok(Message::Binary(bytes)),
418 TMessage::Close(Some(cf)) => Ok(Message::Close {
419 code: cf.code.into(),
420 reason: cf.reason.to_string(),
421 }),
422 TMessage::Close(None) => Ok(Message::Close {
423 code: CloseCode::Away,
424 reason: "Away".to_string(),
425 }),
426 TMessage::Ping(bytes) => Ok(Message::Ping(bytes)),
427 TMessage::Pong(bytes) => Ok(Message::Pong(bytes)),
428 TMessage::Frame(_frame) => Err(WebsocketError::Unexpected),
429 },
430 Some(Err(e)) => Err(WebsocketError::from(e)),
431 None => Err(WebsocketError::closed_away()),
432 };
433 }
434
435 unimplemented!("Non web wasm32 clients are not supported yet")
436 }
437
438 pub fn protocol(&self) -> Option<&str> {
439 self.protocol.as_deref()
440 }
441}
442
/// Two websockets never compare equal, not even a value with itself.
// NOTE(review): presumably this impl exists only to satisfy a `PartialEq` bound
// elsewhere in the framework — confirm.
impl<I, O, E> PartialEq for Websocket<I, O, E> {
    fn eq(&self, _other: &Self) -> bool {
        false
    }
}
449
450impl<In, Out, E> IntoResponse for Websocket<In, Out, E> {
452 fn into_response(self) -> Response {
453 let Some(response) = self.response else {
454 return HttpError::new(
455 StatusCode::INTERNAL_SERVER_ERROR,
456 "WebSocket response not initialized",
457 )
458 .into_response();
459 };
460
461 response.into_response()
462 }
463}
464
465impl<I, O, E> FromResponse<UpgradingWebsocket> for Websocket<I, O, E> {
466 fn from_response(res: UpgradingWebsocket) -> impl Future<Output = Result<Self, ServerFnError>> {
467 async move {
468 #[cfg(not(target_arch = "wasm32"))]
469 let native = res.native;
470
471 #[cfg(feature = "web")]
472 let web = res.web.map(|f| {
473 let (sender, receiver) = f.split();
474 WebsysSocket {
475 sender: Mutex::new(sender),
476 receiver: Mutex::new(receiver),
477 }
478 });
479
480 Ok(Websocket {
481 protocol: res.protocol,
482 #[cfg(not(target_arch = "wasm32"))]
483 native,
484 #[cfg(feature = "web")]
485 web,
486 response: None,
487 _in: PhantomData,
488 })
489 }
490 }
491}
492
/// Options for establishing a websocket.
///
/// On the client it is converted into the upgrade request; on the server it is
/// extracted from the incoming request and carries the pending axum upgrade.
pub struct WebSocketOptions {
    /// Subprotocols offered to the server during the handshake.
    protocols: Vec<String>,
    /// Set by `with_automatic_reconnect`.
    // NOTE(review): nothing in this file reads this flag — confirm where it is used.
    automatic_reconnect: bool,
    /// Pending axum upgrade; populated by the `FromRequest` extractor.
    #[cfg(feature = "server")]
    upgrade: Option<axum::extract::ws::WebSocketUpgrade>,
    /// Callback invoked by axum if the upgrade fails.
    #[cfg(feature = "server")]
    on_failed_upgrade: Option<Box<dyn FnOnce(axum::Error) + Send + 'static>>,
}
501
502impl WebSocketOptions {
503 pub fn new() -> Self {
504 Self {
505 protocols: Vec::new(),
506 automatic_reconnect: false,
507
508 #[cfg(feature = "server")]
509 upgrade: None,
510
511 #[cfg(feature = "server")]
512 on_failed_upgrade: None,
513 }
514 }
515
516 pub fn with_automatic_reconnect(mut self) -> Self {
518 self.automatic_reconnect = true;
519 self
520 }
521
522 #[cfg(feature = "server")]
523 pub fn on_failed_upgrade(
524 mut self,
525 callback: impl FnOnce(axum::Error) + Send + 'static,
526 ) -> Self {
527 self.on_failed_upgrade = Some(Box::new(callback));
528
529 self
530 }
531
532 #[cfg(feature = "server")]
533 pub fn on_upgrade<F, Fut, In, Out, Enc>(mut self, callback: F) -> Websocket<In, Out, Enc>
534 where
535 F: FnOnce(TypedWebsocket<In, Out, Enc>) -> Fut + Send + 'static,
536 Fut: Future<Output = ()> + 'static,
537 {
538 let on_failed_upgrade = self.on_failed_upgrade.take();
539 let response = self
540 .upgrade
541 .unwrap()
542 .on_failed_upgrade(|e| {
543 if let Some(callback) = on_failed_upgrade {
544 callback(e);
545 }
546 })
547 .on_upgrade(|socket| {
548 let res = crate::spawn_platform(move || {
549 callback(TypedWebsocket {
550 _in: PhantomData,
551 _out: PhantomData,
552 _enc: PhantomData,
553 inner: socket,
554 })
555 });
556 async move {
557 let _ = res.await;
558 }
559 });
560
561 Websocket {
562 protocol: None,
564 response: Some(response),
565 _in: PhantomData,
566
567 #[cfg(not(target_arch = "wasm32"))]
568 native: None,
569
570 #[cfg(feature = "web")]
571 web: None,
572 }
573 }
574}
575
/// The default options are the same as [`WebSocketOptions::new`].
impl Default for WebSocketOptions {
    fn default() -> Self {
        Self::new()
    }
}
581
/// Client side: opens the websocket described by these options for `request`.
///
/// With the `web` feature on wasm32 the socket is opened through `gloo_net`; on
/// non-wasm targets the handshake is performed through the `native` module.
///
/// # Panics
///
/// Connection failures are currently not reported through the returned `Result`:
/// the `unwrap()` calls below panic if the socket cannot be opened, the upgrade
/// request fails, or the handshake response is rejected. On wasm32 without the `web`
/// feature the trailing `unimplemented!` is hit.
// NOTE(review): these `unwrap()`s should presumably be mapped to `RequestError`;
// its constructors are not visible from this file.
impl IntoRequest<UpgradingWebsocket> for WebSocketOptions {
    fn into_request(
        self,
        request: ClientRequest,
    ) -> impl Future<Output = std::result::Result<UpgradingWebsocket, RequestError>> + 'static {
        async move {
            #[cfg(feature = "web")]
            if cfg!(target_arch = "wasm32") {
                // Only path, query and fragment of the request URL are passed on.
                let url_path = request.url().path();
                let url_query = request.url().query();
                let url_fragment = request.url().fragment();
                let path_and_query = format!(
                    "{}{}{}",
                    url_path,
                    url_query.map_or("".to_string(), |q| format!("?{q}")),
                    url_fragment.map_or("".to_string(), |f| format!("#{f}"))
                );

                let socket = gloo_net::websocket::futures::WebSocket::open_with_protocols(
                    &path_and_query,
                    &self
                        .protocols
                        .iter()
                        .map(String::as_str)
                        .collect::<Vec<_>>(),
                )
                .unwrap();

                return Ok(UpgradingWebsocket {
                    protocol: Some(socket.protocol()),
                    web: Some(socket),
                    #[cfg(not(target_arch = "wasm32"))]
                    native: None,
                });
            }

            #[cfg(not(target_arch = "wasm32"))]
            {
                // Send the HTTP upgrade request with the handshake headers attached.
                let response = native::send_request(request, &self.protocols)
                    .await
                    .unwrap();

                // Validate the handshake response and take over the upgraded stream.
                let (inner, protocol) = response
                    .into_stream_and_protocol(self.protocols, None)
                    .await
                    .unwrap();

                return Ok(UpgradingWebsocket {
                    protocol,
                    native: Some(inner),
                    #[cfg(feature = "web")]
                    web: None,
                });
            }

            unimplemented!("Non web wasm32 clients are not supported yet")
        }
    }
}
643
644impl<S: Send> FromRequest<S> for WebSocketOptions {
645 type Rejection = axum::response::Response;
646
647 fn from_request(
648 _req: Request,
649 _: &S,
650 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
651 #[cfg(not(feature = "server"))]
652 return async move { Err(StatusCode::NOT_IMPLEMENTED.into_response()) };
653
654 #[cfg(feature = "server")]
655 async move {
656 let ws = match axum::extract::ws::WebSocketUpgrade::from_request(_req, &()).await {
657 Ok(ws) => ws,
658 Err(rejection) => return Err(rejection.into_response()),
659 };
660
661 Ok(WebSocketOptions {
662 protocols: vec![],
663 automatic_reconnect: false,
664 upgrade: Some(ws),
665 on_failed_upgrade: None,
666 })
667 }
668 }
669}
670
/// Intermediate value produced by `WebSocketOptions::into_request` and consumed by
/// `Websocket::from_response`: the opened platform socket plus the selected protocol.
#[doc(hidden)]
pub struct UpgradingWebsocket {
    /// Subprotocol reported for the connection, if any.
    protocol: Option<String>,

    /// Browser socket, set when the connection was opened through `gloo_net`.
    #[cfg(feature = "web")]
    web: Option<gloo_net::websocket::futures::WebSocket>,

    /// Native socket halves, set when the connection was opened through `native`.
    #[cfg(not(target_arch = "wasm32"))]
    native: Option<native::SplitSocket>,
}
681
// SAFETY: NOTE(review): no justification is visible in this file. Presumably these
// impls exist because the gloo `WebSocket` is not `Send`/`Sync` and wasm32 is
// single-threaded — confirm that the value is never shared across threads when the
// `web` feature is enabled on a native target.
unsafe impl Send for UpgradingWebsocket {}
unsafe impl Sync for UpgradingWebsocket {}
684
/// Server-side websocket handed to the `on_upgrade` callback.
///
/// Wraps axum's `WebSocket`; `In` is the type received from the client, `Out` the type
/// sent to it, and `E` the encoding used for both.
#[cfg(feature = "server")]
pub struct TypedWebsocket<In, Out, E = JsonEncoding> {
    // Markers tying the type parameters to the struct without storing values.
    _in: std::marker::PhantomData<fn() -> In>,
    _out: std::marker::PhantomData<fn() -> Out>,
    _enc: std::marker::PhantomData<fn() -> E>,

    /// The underlying axum socket.
    inner: axum::extract::ws::WebSocket,
}
693
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> TypedWebsocket<In, Out, E> {
    /// Receives the next data message from the client and decodes it with `E`.
    ///
    /// Ping and pong frames are skipped.
    ///
    /// # Errors
    ///
    /// * the error from [`WebsocketError::closed_away`] when the stream ends or yields
    ///   a transport error,
    /// * [`WebsocketError::ConnectionClosed`] for a close frame with a payload,
    /// * [`WebsocketError::AlreadyClosed`] for a close frame without a payload,
    /// * a deserialization error when `E::decode` yields `None`.
    pub async fn recv(&mut self) -> Result<In, WebsocketError> {
        use axum::extract::ws::Message as AxumMessage;

        while let Some(incoming) = self.inner.next().await {
            let frame = match incoming {
                Ok(frame) => frame,
                Err(_transport) => return Err(WebsocketError::closed_away()),
            };

            let payload: Bytes = match frame {
                // Control frames carry nothing for the caller; wait for the next frame.
                AxumMessage::Ping(_) | AxumMessage::Pong(_) => continue,
                AxumMessage::Text(text) => text.into(),
                AxumMessage::Binary(bytes) => bytes,
                AxumMessage::Close(Some(close_frame)) => {
                    return Err(WebsocketError::ConnectionClosed {
                        code: close_frame.code.into(),
                        description: close_frame.reason.to_string(),
                    });
                }
                AxumMessage::Close(None) => return Err(WebsocketError::AlreadyClosed),
            };

            return E::decode(payload).ok_or_else(WebsocketError::deserialization);
        }

        // The stream ended without a close frame.
        Err(WebsocketError::closed_away())
    }

    /// Serializes `msg` with `E` and sends it to the client as a binary message.
    ///
    /// # Errors
    ///
    /// Returns a serialization error when `E::to_bytes` yields `None`, and
    /// [`WebsocketError::AlreadyClosed`] when the socket rejects the message.
    pub async fn send(&mut self, msg: Out) -> Result<(), WebsocketError> {
        let payload = E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?;
        let frame = axum::extract::ws::Message::Binary(payload);

        self.inner
            .send(frame)
            .await
            .map_err(|_err| WebsocketError::AlreadyClosed)
    }

    /// Receives the next raw [`Message`], including control frames.
    ///
    /// A close frame without a payload is reported with [`CloseCode::Away`] and the
    /// reason `"Away"`.
    ///
    /// # Errors
    ///
    /// Returns the error from [`WebsocketError::closed_away`] when the stream has
    /// ended and [`WebsocketError::AlreadyClosed`] on a transport error.
    pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
        use axum::extract::ws::Message as AxumMessage;

        let Some(incoming) = self.inner.next().await else {
            return Err(WebsocketError::closed_away());
        };
        let frame = incoming.map_err(|_| WebsocketError::AlreadyClosed)?;

        let message = match frame {
            AxumMessage::Text(utf8_bytes) => Message::Text(utf8_bytes.to_string()),
            AxumMessage::Binary(bytes) => Message::Binary(bytes),
            AxumMessage::Ping(bytes) => Message::Ping(bytes),
            AxumMessage::Pong(bytes) => Message::Pong(bytes),
            AxumMessage::Close(Some(close_frame)) => Message::Close {
                code: close_frame.code.into(),
                reason: close_frame.reason.to_string(),
            },
            AxumMessage::Close(None) => Message::Close {
                code: CloseCode::Away,
                reason: "Away".to_string(),
            },
        };

        Ok(message)
    }

    /// Sends a raw [`Message`] to the client.
    ///
    /// # Errors
    ///
    /// Returns [`WebsocketError::AlreadyClosed`] when the socket rejects the message.
    pub async fn send_raw(&mut self, msg: Message) -> Result<(), WebsocketError> {
        use axum::extract::ws::{CloseFrame, Message as AxumMessage};

        let outgoing = match msg {
            Message::Text(text) => AxumMessage::Text(text.into()),
            Message::Binary(bytes) => AxumMessage::Binary(bytes),
            Message::Ping(bytes) => AxumMessage::Ping(bytes),
            Message::Pong(bytes) => AxumMessage::Pong(bytes),
            Message::Close { code, reason } => AxumMessage::Close(Some(CloseFrame {
                code: code.into(),
                reason: reason.into(),
            })),
        };

        self.inner
            .send(outgoing)
            .await
            .map_err(|_err| WebsocketError::AlreadyClosed)
    }

    /// Returns the subprotocol reported by the underlying axum socket, if any.
    pub fn protocol(&self) -> Option<&http::HeaderValue> {
        self.inner.protocol()
    }

    /// Gives mutable access to the underlying axum socket.
    pub fn socket(&mut self) -> &mut axum::extract::ws::WebSocket {
        &mut self.inner
    }
}
807
/// Errors produced by the websocket client and server wrappers in this module.
#[derive(thiserror::Error, Debug)]
pub enum WebsocketError {
    /// The peer closed the connection, or the message stream ended.
    #[error("Connection closed")]
    ConnectionClosed {
        /// Close code reported for the shutdown.
        code: CloseCode,
        /// Close reason reported for the shutdown.
        description: String,
    },

    /// The socket is no longer usable (failed connection, rejected send, ...).
    #[error("WebSocket already closed")]
    AlreadyClosed,

    // NOTE(review): not constructed anywhere in this file.
    #[error("WebSocket capacity reached")]
    Capacity,

    /// An error with no more specific mapping, e.g. a raw frame or an unmapped gloo error.
    #[error("An unexpected internal error occurred")]
    Unexpected,

    /// The `Websocket` value carries no socket for the current platform.
    #[error("WebSocket is not initialized on this platform")]
    Uninitialized,

    /// The native handshake was rejected.
    #[cfg(not(target_arch = "wasm32"))]
    #[error("websocket upgrade failed")]
    Handshake(#[from] native::HandshakeError),

    /// Error from the HTTP client while sending or upgrading the request.
    #[error("reqwest error")]
    Reqwest(#[from] reqwest::Error),

    /// Error from the native websocket transport.
    #[cfg(not(target_arch = "wasm32"))]
    #[error("tungstenite error")]
    Tungstenite(#[from] tungstenite::Error),

    /// An incoming message could not be decoded.
    #[error("error during serialization/deserialization")]
    Deserialization(Box<dyn std::error::Error + Send + Sync>),

    /// An outgoing message could not be encoded.
    #[error("error during serialization/deserialization")]
    Serialization(Box<dyn std::error::Error + Send + Sync>),

    /// JSON error converted via `From<serde_json::Error>`.
    #[error("serde_json error")]
    Json(#[from] serde_json::Error),

    /// CBOR decoding error converted via `From`.
    #[error("ciborium error")]
    Cbor(#[from] ciborium::de::Error<std::io::Error>),
}
855
/// Maps gloo websocket errors onto [`WebsocketError`].
///
/// A connection error becomes `AlreadyClosed`, a close event becomes
/// `ConnectionClosed` with the event's code and reason, and everything else
/// (including send errors) becomes `Unexpected`.
#[cfg(feature = "web")]
impl From<gloo_net::websocket::WebSocketError> for WebsocketError {
    fn from(value: gloo_net::websocket::WebSocketError) -> Self {
        use gloo_net::websocket::WebSocketError;

        match value {
            WebSocketError::ConnectionClose(close_event) => Self::ConnectionClosed {
                code: close_event.code.into(),
                description: close_event.reason,
            },
            WebSocketError::ConnectionError => Self::AlreadyClosed,
            // `MessageSendError` and any other variant are all reported the same way.
            _ => Self::Unexpected,
        }
    }
}
871
872impl WebsocketError {
873 pub fn closed_away() -> Self {
874 Self::ConnectionClosed {
875 code: CloseCode::Normal,
876 description: "Connection closed normally".into(),
877 }
878 }
879
880 pub fn deserialization() -> Self {
881 Self::Deserialization(anyhow::anyhow!("Failed to deserialize message").into())
882 }
883
884 pub fn serialization() -> Self {
885 Self::Serialization(anyhow::anyhow!("Failed to serialize message").into())
886 }
887}
888
/// Browser websocket split into its two halves, each behind its own async mutex so
/// that sending and receiving only need `&self`.
#[cfg(feature = "web")]
struct WebsysSocket {
    /// Sending half.
    sender: Mutex<SplitSink<WsWebsocket, WsMessage>>,
    /// Receiving half.
    receiver: Mutex<SplitStream<WsWebsocket>>,
}
894
/// A raw websocket message, independent of the underlying transport.
#[derive(Clone, Debug)]
pub enum Message {
    /// A text message.
    Text(String),

    /// A binary message.
    Binary(Bytes),

    /// A ping control message with its payload.
    Ping(Bytes),

    /// A pong control message with its payload.
    Pong(Bytes),

    /// A close message with the close code and reason.
    Close { code: CloseCode, reason: String },
}
928
929impl From<String> for Message {
930 #[inline]
931 fn from(value: String) -> Self {
932 Self::Text(value)
933 }
934}
935
936impl From<&str> for Message {
937 #[inline]
938 fn from(value: &str) -> Self {
939 Self::from(value.to_owned())
940 }
941}
942
943impl From<Bytes> for Message {
944 #[inline]
945 fn from(value: Bytes) -> Self {
946 Self::Binary(value)
947 }
948}
949
950impl From<Vec<u8>> for Message {
951 #[inline]
952 fn from(value: Vec<u8>) -> Self {
953 Self::from(Bytes::from(value))
954 }
955}
956
957impl From<&[u8]> for Message {
958 #[inline]
959 fn from(value: &[u8]) -> Self {
960 Self::from(Bytes::copy_from_slice(value))
961 }
962}
963
/// Websocket close status code.
///
/// Named variants correspond to the fixed numeric codes listed in the `u16`
/// conversions below; the tuple variants keep the raw value for codes that fall into
/// a range instead.
#[derive(Debug, Default, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum CloseCode {
    /// 1000
    #[default]
    Normal,

    /// 1001
    Away,

    /// 1002
    Protocol,

    /// 1003
    Unsupported,

    /// 1005
    Status,

    /// 1006
    Abnormal,

    /// 1007
    Invalid,

    /// 1008
    Policy,

    /// 1009
    Size,

    /// 1010
    Extension,

    /// 1011
    Error,

    /// 1012
    Restart,

    /// 1013
    Again,

    /// 1015
    Tls,

    /// Codes in 1016..=2999.
    Reserved(u16),

    /// Codes in 3000..=3999.
    Iana(u16),

    /// Codes in 4000..=4999.
    Library(u16),

    /// Any other code.
    Bad(u16),
}

impl CloseCode {
    /// Returns `false` for `Bad`, `Reserved`, `Status`, `Abnormal` and `Tls`, and
    /// `true` for every other code.
    #[must_use]
    pub const fn is_allowed(self) -> bool {
        match self {
            Self::Bad(_) | Self::Reserved(_) | Self::Status | Self::Abnormal | Self::Tls => false,
            _ => true,
        }
    }
}

/// Displays the numeric value of the code.
impl std::fmt::Display for CloseCode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", u16::from(*self))
    }
}

/// Converts a close code into its numeric value.
impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> Self {
        match code {
            // Range-based variants carry their raw value.
            CloseCode::Reserved(raw) => raw,
            CloseCode::Iana(raw) => raw,
            CloseCode::Library(raw) => raw,
            CloseCode::Bad(raw) => raw,

            // Fixed codes.
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
        }
    }
}

/// Classifies a numeric close code.
///
/// Values without a named variant and outside 1016..=4999 (including 1004 and 1014)
/// map to [`CloseCode::Bad`].
impl From<u16> for CloseCode {
    fn from(code: u16) -> Self {
        match code {
            1000 => CloseCode::Normal,
            1001 => CloseCode::Away,
            1002 => CloseCode::Protocol,
            1003 => CloseCode::Unsupported,
            1005 => CloseCode::Status,
            1006 => CloseCode::Abnormal,
            1007 => CloseCode::Invalid,
            1008 => CloseCode::Policy,
            1009 => CloseCode::Size,
            1010 => CloseCode::Extension,
            1011 => CloseCode::Error,
            1012 => CloseCode::Restart,
            1013 => CloseCode::Again,
            1015 => CloseCode::Tls,
            raw @ 1016..=2999 => CloseCode::Reserved(raw),
            raw @ 3000..=3999 => CloseCode::Iana(raw),
            raw @ 4000..=4999 => CloseCode::Library(raw),
            raw => CloseCode::Bad(raw),
        }
    }
}
1134
1135#[cfg(not(target_arch = "wasm32"))]
1136mod native {
1137 use std::borrow::Cow;
1138
1139 use crate::ClientRequest;
1140
1141 use super::{CloseCode, Message, WebsocketError};
1142 use reqwest::{
1143 header::{HeaderName, HeaderValue},
1144 Response, StatusCode, Version,
1145 };
1146 use tungstenite::protocol::WebSocketConfig;
1147
    /// Native websocket split into its two halves, each behind its own async mutex so
    /// that sending and receiving only need a shared reference.
    pub(crate) struct SplitSocket {
        /// Sending half.
        pub sender: futures_util::lock::Mutex<
            async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
        >,

        /// Receiving half.
        pub receiver: futures_util::lock::Mutex<
            async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
        >,
    }
1157
1158 pub async fn send_request(
1159 request: ClientRequest,
1160 protocols: &[String],
1161 ) -> Result<WebSocketResponse, WebsocketError> {
1162 let request_builder = request.new_reqwest_request();
1163 let (client, request_result) = request_builder.build_split();
1164 let mut request = request_result?;
1165
1166 let url = request.url_mut();
1168 match url.scheme() {
1169 "ws" => {
1170 url.set_scheme("http")
1171 .expect("url should accept http scheme");
1172 }
1173 "wss" => {
1174 url.set_scheme("https")
1175 .expect("url should accept https scheme");
1176 }
1177 _ => {}
1178 }
1179
1180 let version = request.version();
1182 let nonce = match version {
1183 Version::HTTP_10 | Version::HTTP_11 => {
1184 let nonce_value = tungstenite::handshake::client::generate_key();
1186 let headers = request.headers_mut();
1187 headers.insert(
1188 reqwest::header::CONNECTION,
1189 HeaderValue::from_static("upgrade"),
1190 );
1191 headers.insert(
1192 reqwest::header::UPGRADE,
1193 HeaderValue::from_static("websocket"),
1194 );
1195 headers.insert(
1196 reqwest::header::SEC_WEBSOCKET_KEY,
1197 HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
1198 );
1199 headers.insert(
1200 reqwest::header::SEC_WEBSOCKET_VERSION,
1201 HeaderValue::from_static("13"),
1202 );
1203 if !protocols.is_empty() {
1204 headers.insert(
1205 reqwest::header::SEC_WEBSOCKET_PROTOCOL,
1206 HeaderValue::from_str(&protocols.join(", "))
1207 .expect("protocols is an invalid header value"),
1208 );
1209 }
1210
1211 Some(nonce_value)
1212 }
1213 Version::HTTP_2 => {
1214 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1216 }
1217 _ => {
1218 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1219 }
1220 };
1221
1222 let response = client.execute(request).await?;
1224
1225 Ok(WebSocketResponse {
1226 response,
1227 version,
1228 nonce,
1229 })
1230 }
1231
    /// The websocket stream type layered over the upgraded `reqwest` connection.
    pub type WebSocketStream =
        async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
1234
    /// Reasons the client-side websocket handshake can be rejected.
    #[derive(Debug, thiserror::Error, Clone)]
    pub enum HandshakeError {
        /// The request uses an HTTP version other than 1.0 or 1.1.
        #[error("unsupported http version: {0:?}")]
        UnsupportedHttpVersion(Version),

        /// The response's HTTP version differs from the request's.
        #[error("the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2")]
        ServerRespondedWithDifferentVersion,

        /// A required handshake header is absent from the response.
        #[error("missing header {header}")]
        MissingHeader { header: HeaderName },

        /// A handshake header is present but does not have the required value.
        #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
        UnexpectedHeaderValue {
            header: HeaderName,
            got: HeaderValue,
            expected: Cow<'static, str>,
        },

        /// Protocols were offered but the server selected none.
        #[error("expected the server to select a protocol.")]
        ExpectedAProtocol,

        /// The server selected a protocol that was not offered.
        #[error("unexpected protocol: {got}")]
        UnexpectedProtocol { got: String },

        /// The response status is not `101 Switching Protocols`.
        #[error("unexpected status code: {0}")]
        UnexpectedStatusCode(StatusCode),
    }
1263
    /// Response to the upgrade request, together with what is needed to validate it.
    pub struct WebSocketResponse {
        /// The raw HTTP response from the server.
        pub response: Response,
        /// HTTP version the request was sent with.
        pub version: Version,
        /// The `Sec-WebSocket-Key` value sent with the request, if one was generated.
        pub nonce: Option<String>,
    }
1269
1270 impl WebSocketResponse {
1271 pub async fn into_stream_and_protocol(
1272 self,
1273 protocols: Vec<String>,
1274 web_socket_config: Option<WebSocketConfig>,
1275 ) -> Result<(SplitSocket, Option<String>), WebsocketError> {
1276 let headers = self.response.headers();
1277
1278 if self.response.version() != self.version {
1279 return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
1280 }
1281
1282 if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
1283 tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
1284 return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
1285 }
1286
1287 if let Some(header) = headers.get(reqwest::header::CONNECTION) {
1288 if !header
1289 .to_str()
1290 .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
1291 {
1292 tracing::debug!("server responded with invalid Connection header: {header:?}");
1293 return Err(HandshakeError::UnexpectedHeaderValue {
1294 header: reqwest::header::CONNECTION,
1295 got: header.clone(),
1296 expected: "upgrade".into(),
1297 }
1298 .into());
1299 }
1300 } else {
1301 tracing::debug!("missing Connection header");
1302 return Err(HandshakeError::MissingHeader {
1303 header: reqwest::header::CONNECTION,
1304 }
1305 .into());
1306 }
1307
1308 if let Some(header) = headers.get(reqwest::header::UPGRADE) {
1309 if !header
1310 .to_str()
1311 .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
1312 {
1313 tracing::debug!("server responded with invalid Upgrade header: {header:?}");
1314 return Err(HandshakeError::UnexpectedHeaderValue {
1315 header: reqwest::header::UPGRADE,
1316 got: header.clone(),
1317 expected: "websocket".into(),
1318 }
1319 .into());
1320 }
1321 } else {
1322 tracing::debug!("missing Upgrade header");
1323 return Err(HandshakeError::MissingHeader {
1324 header: reqwest::header::UPGRADE,
1325 }
1326 .into());
1327 }
1328
1329 if let Some(nonce) = &self.nonce {
1330 let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
1331
1332 if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
1333 if !header.to_str().is_ok_and(|s| s == expected_nonce) {
1334 tracing::debug!(
1335 "server responded with invalid Sec-Websocket-Accept header: {header:?}"
1336 );
1337 return Err(HandshakeError::UnexpectedHeaderValue {
1338 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1339 got: header.clone(),
1340 expected: expected_nonce.into(),
1341 }
1342 .into());
1343 }
1344 } else {
1345 tracing::debug!("missing Sec-Websocket-Accept header");
1346 return Err(HandshakeError::MissingHeader {
1347 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1348 }
1349 .into());
1350 }
1351 }
1352
1353 let protocol = headers
1354 .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
1355 .and_then(|v| v.to_str().ok())
1356 .map(ToOwned::to_owned);
1357
1358 match (protocols.is_empty(), &protocol) {
1359 (true, None) => {
1360 }
1363 (false, None) => {
1364 return Err(HandshakeError::ExpectedAProtocol.into());
1366 }
1367 (false, Some(protocol)) => {
1368 if !protocols.contains(protocol) {
1369 return Err(HandshakeError::UnexpectedProtocol {
1371 got: protocol.clone(),
1372 }
1373 .into());
1374 }
1375 }
1376 (true, Some(protocol)) => {
1377 return Err(HandshakeError::UnexpectedProtocol {
1379 got: protocol.clone(),
1380 }
1381 .into());
1382 }
1383 }
1384
1385 use tokio_util::compat::TokioAsyncReadCompatExt;
1386
1387 let inner = WebSocketStream::from_raw_socket(
1388 self.response.upgrade().await?.compat(),
1389 tungstenite::protocol::Role::Client,
1390 web_socket_config,
1391 )
1392 .await;
1393
1394 let split: (
1395 async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
1396 async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
1397 ) = inner.split();
1398
1399 let split_socket = SplitSocket {
1400 sender: futures_util::lock::Mutex::new(split.0),
1401 receiver: futures_util::lock::Mutex::new(split.1),
1402 };
1403
1404 Ok((split_socket, protocol))
1405 }
1406 }
1407
    /// Returned when a tungstenite message has no [`Message`] equivalent.
    #[derive(Debug, thiserror::Error)]
    #[error("could not convert message")]
    pub struct FromTungsteniteMessageError {
        /// The message that could not be converted.
        pub original: tungstenite::Message,
    }
1413
1414 impl TryFrom<tungstenite::Message> for Message {
1415 type Error = FromTungsteniteMessageError;
1416
1417 fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
1418 match value {
1419 tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
1420 tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
1421 tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
1422 tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
1423 tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
1424 code,
1425 reason,
1426 })) => Ok(Self::Close {
1427 code: code.into(),
1428 reason: reason.as_str().to_owned(),
1429 }),
1430 tungstenite::Message::Close(None) => Ok(Self::Close {
1431 code: CloseCode::default(),
1432 reason: "".to_owned(),
1433 }),
1434 tungstenite::Message::Frame(_) => {
1435 Err(FromTungsteniteMessageError { original: value })
1436 }
1437 }
1438 }
1439 }
1440
1441 impl From<Message> for tungstenite::Message {
1442 fn from(value: Message) -> Self {
1443 match value {
1444 Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
1445 Message::Binary(data) => Self::Binary(data),
1446 Message::Ping(data) => Self::Ping(data),
1447 Message::Pong(data) => Self::Pong(data),
1448 Message::Close { code, reason } => {
1449 Self::Close(Some(tungstenite::protocol::CloseFrame {
1450 code: code.into(),
1451 reason: reason.into(),
1452 }))
1453 }
1454 }
1455 }
1456 }
1457
1458 impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
1459 fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
1460 u16::from(value).into()
1461 }
1462 }
1463
1464 impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
1465 fn from(value: CloseCode) -> Self {
1466 u16::from(value).into()
1467 }
1468 }
1469}