1#![allow(unreachable_code)]
2#![allow(unused_imports)]
3
4use crate::{ClientRequest, Encoding, FromResponse, IntoRequest, JsonEncoding, ServerFnError};
25use axum::{
26 extract::{FromRequest, Request},
27 http::StatusCode,
28};
29use axum_core::response::{IntoResponse, Response};
30use bytes::Bytes;
31use dioxus_core::{use_hook, CapturedError, Result};
32use dioxus_fullstack_core::{HttpError, RequestError};
33use dioxus_hooks::{use_resource, Resource, UseWaker};
34use dioxus_hooks::{use_signal, use_waker};
35use dioxus_signals::{ReadSignal, ReadableExt, ReadableOptionExt, Signal, WritableExt};
36use futures::{
37 stream::{SplitSink, SplitStream},
38 Sink, SinkExt, Stream, StreamExt, TryFutureExt,
39};
40use serde::{de::DeserializeOwned, Serialize};
41use std::{
42 marker::PhantomData,
43 pin::Pin,
44 prelude::rust_2024::Future,
45 rc::Rc,
46 task::{ready, Context, Poll},
47};
48
49#[cfg(feature = "web")]
50use {
51 futures_util::lock::Mutex,
52 gloo_net::websocket::{futures::WebSocket as WsWebsocket, Message as WsMessage},
53};
54
55pub fn use_websocket<
67 In: 'static,
68 Out: 'static,
69 E: Into<CapturedError> + 'static,
70 F: Future<Output = Result<Websocket<In, Out, Enc>, E>> + 'static,
71 Enc: Encoding,
72>(
73 mut connect_to_websocket: impl FnMut() -> F + 'static,
74) -> UseWebsocket<In, Out, Enc> {
75 let mut waker = use_waker();
76 let mut status = use_signal(|| WebsocketState::Connecting);
77 let status_read = use_hook(|| ReadSignal::new(status));
78
79 let connection = use_resource(move || {
80 let connection = connect_to_websocket().map_err(|e| e.into());
81 async move {
82 let connection = connection.await;
83
84 match connection.as_ref() {
86 Ok(_) => status.set(WebsocketState::Open),
87 Err(_) => status.set(WebsocketState::FailedToConnect),
88 }
89
90 waker.wake(());
92
93 connection.map(Rc::new)
96 }
97 });
98
99 UseWebsocket {
100 connection,
101 waker,
102 status,
103 status_read,
104 }
105}
106
107pub struct UseWebsocket<In, Out, Enc = JsonEncoding>
114where
115 In: 'static,
116 Out: 'static,
117 Enc: 'static,
118{
119 #[allow(clippy::type_complexity)]
120 connection: Resource<Result<Rc<Websocket<In, Out, Enc>>, CapturedError>>,
121 waker: UseWaker<()>,
122 status: Signal<WebsocketState>,
123 status_read: ReadSignal<WebsocketState>,
124}
125
126impl<In, Out, E> UseWebsocket<In, Out, E> {
127 pub async fn connect(&self) -> WebsocketState {
130 while !self.connection.finished() {
132 _ = self.waker.wait().await;
133 }
134
135 self.status.cloned()
136 }
137
138 pub fn connecting(&self) -> bool {
142 matches!(self.status.cloned(), WebsocketState::Connecting)
143 }
144
145 pub fn is_err(&self) -> bool {
147 matches!(self.status.cloned(), WebsocketState::FailedToConnect)
148 }
149
150 pub fn is_closed(&self) -> bool {
152 matches!(
153 self.status.cloned(),
154 WebsocketState::Closed | WebsocketState::FailedToConnect
155 )
156 }
157
158 pub fn status(&self) -> ReadSignal<WebsocketState> {
160 self.status_read
161 }
162
163 pub async fn send_raw(&self, msg: Message) -> Result<(), WebsocketError> {
167 self.connect().await;
168 self.get_connection()?.send_raw(msg).await
169 }
170
171 pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
175 self.connect().await;
176 let ws = self.get_connection()?;
177
178 let recv_fut = ws.recv_raw();
182 let waker_fut = self.waker.wait();
183 futures::pin_mut!(recv_fut, waker_fut);
184
185 match futures::future::select(recv_fut, waker_fut).await {
186 futures::future::Either::Left((recv_result, _)) => {
187 if let Err(WebsocketError::ConnectionClosed { .. }) = recv_result.as_ref() {
188 self.received_shutdown();
189 }
190 recv_result
191 }
192 futures::future::Either::Right(_) => Err(WebsocketError::ConnectionClosed {
193 code: CloseCode::Away,
194 description: "Connection replaced by a new one".to_string(),
195 }),
196 }
197 }
198
199 pub async fn send(&self, msg: In) -> Result<(), WebsocketError>
200 where
201 In: Serialize,
202 E: Encoding,
203 {
204 self.send_raw(Message::Binary(
205 E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?,
206 ))
207 .await
208 }
209
210 pub async fn recv(&mut self) -> Result<Out, WebsocketError>
218 where
219 Out: DeserializeOwned,
220 E: Encoding,
221 {
222 self.connect().await;
223 let ws = self.get_connection()?;
224
225 let recv_fut = ws.recv();
226 let waker_fut = self.waker.wait();
227 futures::pin_mut!(recv_fut, waker_fut);
228
229 match futures::future::select(recv_fut, waker_fut).await {
230 futures::future::Either::Left((recv_result, _)) => {
231 if let Err(WebsocketError::ConnectionClosed { .. }) = recv_result.as_ref() {
232 self.received_shutdown();
233 }
234 recv_result
235 }
236 futures::future::Either::Right(_) => Err(WebsocketError::ConnectionClosed {
237 code: CloseCode::Away,
238 description: "Connection replaced by a new one".to_string(),
239 }),
240 }
241 }
242
243 pub fn set<Err: Into<CapturedError>>(&mut self, socket: Result<Websocket<In, Out, E>, Err>) {
248 match socket {
249 Ok(_) => self.status.set(WebsocketState::Open),
250 Err(_) => self.status.set(WebsocketState::FailedToConnect),
251 }
252
253 self.connection
254 .set(Some(socket.map(Rc::new).map_err(|e| e.into())));
255 self.waker.wake(());
256 }
257
258 fn received_shutdown(&self) {
260 let mut _self = *self;
261 _self.status.set(WebsocketState::Closed);
262 _self.waker.wake(());
263 }
264
265 #[allow(clippy::result_large_err)]
269 fn get_connection(&self) -> Result<Rc<Websocket<In, Out, E>>, WebsocketError> {
270 self.connection.with_peek(|opt| {
271 opt.as_ref()
272 .ok_or_else(WebsocketError::closed_away)?
273 .as_ref()
274 .map(Rc::clone)
275 .map_err(|_| WebsocketError::AlreadyClosed)
276 })
277 }
278}
279
280impl<In, Out, E> Copy for UseWebsocket<In, Out, E> {}
281impl<In, Out, E> Clone for UseWebsocket<In, Out, E> {
282 fn clone(&self) -> Self {
283 *self
284 }
285}
286
/// Lifecycle of a websocket managed by [`use_websocket`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum WebsocketState {
    /// A connection attempt is in flight. This is the initial state.
    Connecting,

    /// The connection attempt (or a socket installed via `set`) succeeded.
    Open,

    /// The socket is shutting down. `UseWebsocket` does not currently assign this state itself.
    Closing,

    /// A receive observed that the connection closed.
    Closed,

    /// The connection attempt failed.
    FailedToConnect,
}
304
/// A typed websocket connection.
///
/// On the client it wraps either the split native (tungstenite) socket or the browser
/// (gloo) socket. On the server, a value built by `WebSocketOptions::on_upgrade` only
/// carries the HTTP upgrade `response`. `In` is the type passed to `send`, `Out` the
/// type produced by `recv`, and `E` the encoding used for both.
pub struct Websocket<In = String, Out = String, E = JsonEncoding> {
    /// Subprotocol negotiated during the handshake, if any.
    protocol: Option<String>,

    // `fn() -> ...` keeps the marker `Send`/`Sync` and covariant regardless of the
    // type parameters, since no values of them are stored.
    #[allow(clippy::type_complexity)]
    _in: std::marker::PhantomData<fn() -> (In, Out, E)>,

    /// Native client socket; `None` for values created on the server.
    #[cfg(not(target_arch = "wasm32"))]
    native: Option<native::SplitSocket>,

    /// Browser client socket, split into separately locked halves.
    #[cfg(feature = "web")]
    web: Option<WebsysSocket>,

    /// Upgrade response produced on the server by `WebSocketOptions::on_upgrade`.
    response: Option<axum::response::Response>,
}
320
321impl<I, O, E> Websocket<I, O, E> {
322 pub async fn recv(&self) -> Result<O, WebsocketError>
323 where
324 O: DeserializeOwned,
325 E: Encoding,
326 {
327 loop {
328 let msg = self.recv_raw().await?;
329 match msg {
330 Message::Text(text) => {
331 let e: O =
332 E::decode(text.into()).ok_or_else(WebsocketError::deserialization)?;
333 return Ok(e);
334 }
335 Message::Binary(bytes) => {
336 let e: O = E::decode(bytes).ok_or_else(WebsocketError::deserialization)?;
337 return Ok(e);
338 }
339 Message::Close { code, reason } => {
340 return Err(WebsocketError::ConnectionClosed {
341 code,
342 description: reason,
343 });
344 }
345
346 Message::Ping(_bytes) => continue,
348 Message::Pong(_bytes) => continue,
349 }
350 }
351 }
352
353 pub async fn send(&self, msg: I) -> Result<(), WebsocketError>
359 where
360 I: Serialize,
361 E: Encoding,
362 {
363 let bytes = E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?;
364 self.send_raw(Message::Binary(bytes)).await
365 }
366
367 pub async fn send_raw(&self, message: Message) -> Result<(), WebsocketError> {
371 #[cfg(feature = "web")]
372 if cfg!(target_arch = "wasm32") {
373 let mut sender = self
374 .web
375 .as_ref()
376 .ok_or_else(|| WebsocketError::Uninitialized)?
377 .sender
378 .lock()
379 .await;
380
381 match message {
382 Message::Text(s) => {
383 sender.send(gloo_net::websocket::Message::Text(s)).await?;
384 }
385 Message::Binary(bytes) => {
386 sender
387 .send(gloo_net::websocket::Message::Bytes(bytes.into()))
388 .await?;
389 }
390 Message::Close { .. } => {
391 sender.close().await?;
392 }
393 Message::Ping(_bytes) => return Ok(()),
394 Message::Pong(_bytes) => return Ok(()),
395 }
396
397 return Ok(());
398 }
399
400 #[cfg(not(target_arch = "wasm32"))]
401 {
402 let mut sender = self
403 .native
404 .as_ref()
405 .ok_or_else(|| WebsocketError::Uninitialized)?
406 .sender
407 .lock()
408 .await;
409
410 sender
411 .send(message.into())
412 .await
413 .map_err(WebsocketError::from)?;
414 }
415
416 Ok(())
417 }
418
419 pub async fn recv_raw(&self) -> Result<Message, WebsocketError> {
421 #[cfg(feature = "web")]
422 if cfg!(target_arch = "wasm32") {
423 let mut conn = self.web.as_ref().unwrap().receiver.lock().await;
424 return match conn.next().await {
425 Some(Ok(WsMessage::Text(text))) => Ok(Message::Text(text)),
426 Some(Ok(WsMessage::Bytes(items))) => Ok(Message::Binary(items.into())),
427 Some(Err(e)) => Err(WebsocketError::from(e)),
428 None => Err(WebsocketError::closed_away()),
429 };
430 }
431
432 #[cfg(not(target_arch = "wasm32"))]
433 {
434 use tungstenite::Message as TMessage;
435 let mut conn = self.native.as_ref().unwrap().receiver.lock().await;
436 return match conn.next().await {
437 Some(Ok(res)) => match res {
438 TMessage::Text(utf8_bytes) => Ok(Message::Text(utf8_bytes.to_string())),
439 TMessage::Binary(bytes) => Ok(Message::Binary(bytes)),
440 TMessage::Close(Some(cf)) => Ok(Message::Close {
441 code: cf.code.into(),
442 reason: cf.reason.to_string(),
443 }),
444 TMessage::Close(None) => Ok(Message::Close {
445 code: CloseCode::Away,
446 reason: "Away".to_string(),
447 }),
448 TMessage::Ping(bytes) => Ok(Message::Ping(bytes)),
449 TMessage::Pong(bytes) => Ok(Message::Pong(bytes)),
450 TMessage::Frame(_frame) => Err(WebsocketError::Unexpected),
451 },
452 Some(Err(e)) => Err(WebsocketError::from(e)),
453 None => Err(WebsocketError::closed_away()),
454 };
455 }
456
457 unimplemented!("Non web wasm32 clients are not supported yet")
458 }
459
460 pub fn protocol(&self) -> Option<&str> {
461 self.protocol.as_deref()
462 }
463}
464
465impl<I, O, E> PartialEq for Websocket<I, O, E> {
467 fn eq(&self, _other: &Self) -> bool {
468 false
469 }
470}
471
472impl<In, Out, E> IntoResponse for Websocket<In, Out, E> {
474 fn into_response(self) -> Response {
475 let Some(response) = self.response else {
476 return HttpError::new(
477 StatusCode::INTERNAL_SERVER_ERROR,
478 "WebSocket response not initialized",
479 )
480 .into_response();
481 };
482
483 response.into_response()
484 }
485}
486
487impl<I, O, E> FromResponse<UpgradingWebsocket> for Websocket<I, O, E> {
488 fn from_response(res: UpgradingWebsocket) -> impl Future<Output = Result<Self, ServerFnError>> {
489 async move {
490 #[cfg(not(target_arch = "wasm32"))]
491 let native = res.native;
492
493 #[cfg(feature = "web")]
494 let web = res.web.map(|f| {
495 let (sender, receiver) = f.split();
496 WebsysSocket {
497 sender: Mutex::new(sender),
498 receiver: Mutex::new(receiver),
499 }
500 });
501
502 Ok(Websocket {
503 protocol: res.protocol,
504 #[cfg(not(target_arch = "wasm32"))]
505 native,
506 #[cfg(feature = "web")]
507 web,
508 response: None,
509 _in: PhantomData,
510 })
511 }
512 }
513}
514
/// Options for a websocket server function, used on both ends of the connection.
///
/// On the client, they configure the handshake (the requested subprotocols). On the
/// server, they are extracted from the incoming request and hold the pending upgrade.
pub struct WebSocketOptions {
    /// Subprotocols offered by the client during the handshake.
    protocols: Vec<String>,
    /// Set by `with_automatic_reconnect`.
    // NOTE(review): no code visible in this file reads this flag — confirm where it is used.
    automatic_reconnect: bool,
    /// Upgrade handle extracted from the incoming request (server only).
    #[cfg(feature = "server")]
    upgrade: Option<axum::extract::ws::WebSocketUpgrade>,
    /// Callback invoked with the error if the upgrade fails (server only).
    #[cfg(feature = "server")]
    on_failed_upgrade: Option<Box<dyn FnOnce(axum::Error) + Send + 'static>>,
}
523
524impl WebSocketOptions {
525 pub fn new() -> Self {
526 Self {
527 protocols: Vec::new(),
528 automatic_reconnect: false,
529
530 #[cfg(feature = "server")]
531 upgrade: None,
532
533 #[cfg(feature = "server")]
534 on_failed_upgrade: None,
535 }
536 }
537
538 pub fn with_automatic_reconnect(mut self) -> Self {
540 self.automatic_reconnect = true;
541 self
542 }
543
544 #[cfg(feature = "server")]
545 pub fn on_failed_upgrade(
546 mut self,
547 callback: impl FnOnce(axum::Error) + Send + 'static,
548 ) -> Self {
549 self.on_failed_upgrade = Some(Box::new(callback));
550
551 self
552 }
553
554 #[cfg(feature = "server")]
555 pub fn on_upgrade<F, Fut, In, Out, Enc>(mut self, callback: F) -> Websocket<In, Out, Enc>
556 where
557 F: FnOnce(TypedWebsocket<In, Out, Enc>) -> Fut + Send + 'static,
558 Fut: Future<Output = ()> + 'static,
559 {
560 let on_failed_upgrade = self.on_failed_upgrade.take();
561 let response = self
562 .upgrade
563 .unwrap()
564 .on_failed_upgrade(|e| {
565 if let Some(callback) = on_failed_upgrade {
566 callback(e);
567 }
568 })
569 .on_upgrade(|socket| {
570 let res = crate::spawn_platform(move || {
571 callback(TypedWebsocket {
572 _in: PhantomData,
573 _out: PhantomData,
574 _enc: PhantomData,
575 inner: socket,
576 })
577 });
578 async move {
579 let _ = res.await;
580 }
581 });
582
583 Websocket {
584 protocol: None,
586 response: Some(response),
587 _in: PhantomData,
588
589 #[cfg(not(target_arch = "wasm32"))]
590 native: None,
591
592 #[cfg(feature = "web")]
593 web: None,
594 }
595 }
596}
597
598impl Default for WebSocketOptions {
599 fn default() -> Self {
600 Self::new()
601 }
602}
603
604impl IntoRequest<UpgradingWebsocket> for WebSocketOptions {
605 fn into_request(
606 self,
607 request: ClientRequest,
608 ) -> impl Future<Output = std::result::Result<UpgradingWebsocket, RequestError>> + 'static {
609 async move {
610 #[cfg(feature = "web")]
611 if cfg!(target_arch = "wasm32") {
612 let url_path = request.url().path();
613 let url_query = request.url().query();
614 let url_fragment = request.url().fragment();
615 let path_and_query = format!(
616 "{}{}{}",
617 url_path,
618 url_query.map_or("".to_string(), |q| format!("?{q}")),
619 url_fragment.map_or("".to_string(), |f| format!("#{f}"))
620 );
621
622 let socket = gloo_net::websocket::futures::WebSocket::open_with_protocols(
623 &path_and_query,
626 &self
627 .protocols
628 .iter()
629 .map(String::as_str)
630 .collect::<Vec<_>>(),
631 )
632 .map_err(|error| RequestError::Connect(error.to_string()))?;
633
634 return Ok(UpgradingWebsocket {
635 protocol: Some(socket.protocol()),
636 web: Some(socket),
637 #[cfg(not(target_arch = "wasm32"))]
638 native: None,
639 });
640 }
641
642 #[cfg(not(target_arch = "wasm32"))]
643 {
644 let response = native::send_request(request, &self.protocols).await?;
645
646 let (inner, protocol) = response
647 .into_stream_and_protocol(self.protocols, None)
648 .await?;
649
650 return Ok(UpgradingWebsocket {
651 protocol,
652 native: Some(inner),
653 #[cfg(feature = "web")]
654 web: None,
655 });
656 }
657
658 unimplemented!("Non web wasm32 clients are not supported yet")
659 }
660 }
661}
662
663impl<S: Send> FromRequest<S> for WebSocketOptions {
664 type Rejection = axum::response::Response;
665
666 fn from_request(
667 _req: Request,
668 _: &S,
669 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
670 #[cfg(not(feature = "server"))]
671 return async move { Err(StatusCode::NOT_IMPLEMENTED.into_response()) };
672
673 #[cfg(feature = "server")]
674 async move {
675 let ws = match axum::extract::ws::WebSocketUpgrade::from_request(_req, &()).await {
676 Ok(ws) => ws,
677 Err(rejection) => return Err(rejection.into_response()),
678 };
679
680 Ok(WebSocketOptions {
681 protocols: vec![],
682 automatic_reconnect: false,
683 upgrade: Some(ws),
684 on_failed_upgrade: None,
685 })
686 }
687 }
688}
689
/// Intermediate value produced by the client-side request: the socket after the
/// handshake, before it is wrapped into a typed `Websocket`.
#[doc(hidden)]
pub struct UpgradingWebsocket {
    /// Subprotocol reported by the handshake, if any.
    protocol: Option<String>,

    /// Unsplit browser socket.
    #[cfg(feature = "web")]
    web: Option<gloo_net::websocket::futures::WebSocket>,

    /// Split native socket after a validated HTTP upgrade.
    #[cfg(not(target_arch = "wasm32"))]
    native: Option<native::SplitSocket>,
}
700
// SAFETY: NOTE(review): these impls assert thread-safety for every field, including the
// gloo browser socket, which is assumed here to hold JavaScript handles that are not
// thread-safe. They are presumably sound only because browser wasm32 targets run
// single-threaded, and exist so the value satisfies `Send` bounds in the request
// pipeline. Confirm this holds for every target/feature combination (including
// threaded wasm) before relying on it.
unsafe impl Send for UpgradingWebsocket {}
unsafe impl Sync for UpgradingWebsocket {}
703
/// Server-side handle to an upgraded websocket, passed to the `on_upgrade` callback.
///
/// `In` is decoded from incoming frames and `Out` is encoded for outgoing frames,
/// both with the encoding `E`.
#[cfg(feature = "server")]
pub struct TypedWebsocket<In, Out, E = JsonEncoding> {
    _in: std::marker::PhantomData<fn() -> In>,
    _out: std::marker::PhantomData<fn() -> Out>,
    _enc: std::marker::PhantomData<fn() -> E>,

    /// The underlying axum socket.
    inner: axum::extract::ws::WebSocket,
}
712
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> TypedWebsocket<In, Out, E> {
    /// Receive and decode the next data message.
    ///
    /// The end of the stream is reported as [`WebsocketError::closed_away`].
    pub async fn recv(&mut self) -> Result<In, WebsocketError> {
        match self.next().await {
            Some(item) => item,
            None => Err(WebsocketError::closed_away()),
        }
    }

    /// Encode `msg` with `E` and send it as a binary frame.
    pub async fn send(&mut self, msg: Out) -> Result<(), WebsocketError> {
        SinkExt::send(self, msg).await
    }

    /// Receive the next raw frame.
    ///
    /// A close without a frame is reported as `CloseCode::Away` with reason `"Away"`.
    /// Transport errors map to [`WebsocketError::AlreadyClosed`].
    pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
        use axum::extract::ws::Message as AxumMessage;

        let Some(next) = self.inner.next().await else {
            return Err(WebsocketError::closed_away());
        };
        let frame = next.map_err(|_| WebsocketError::AlreadyClosed)?;

        let message = match frame {
            AxumMessage::Text(text) => Message::Text(text.to_string()),
            AxumMessage::Binary(payload) => Message::Binary(payload),
            AxumMessage::Ping(payload) => Message::Ping(payload),
            AxumMessage::Pong(payload) => Message::Pong(payload),
            AxumMessage::Close(Some(close_frame)) => Message::Close {
                code: close_frame.code.into(),
                reason: close_frame.reason.to_string(),
            },
            AxumMessage::Close(None) => Message::Close {
                code: CloseCode::Away,
                reason: "Away".to_string(),
            },
        };

        Ok(message)
    }

    /// Send a raw frame. Transport errors map to [`WebsocketError::AlreadyClosed`].
    pub async fn send_raw(&mut self, msg: Message) -> Result<(), WebsocketError> {
        use axum::extract::ws::{CloseFrame, Message as AxumMessage};

        let frame = match msg {
            Message::Text(text) => AxumMessage::Text(text.into()),
            Message::Binary(payload) => AxumMessage::Binary(payload),
            Message::Ping(payload) => AxumMessage::Ping(payload),
            Message::Pong(payload) => AxumMessage::Pong(payload),
            Message::Close { code, reason } => AxumMessage::Close(Some(CloseFrame {
                code: code.into(),
                reason: reason.into(),
            })),
        };

        self.inner
            .send(frame)
            .await
            .map_err(|_err| WebsocketError::AlreadyClosed)
    }

    /// Subprotocol selected for this connection, if any.
    pub fn protocol(&self) -> Option<&http::HeaderValue> {
        self.inner.protocol()
    }

    /// Direct access to the underlying axum socket.
    pub fn socket(&mut self) -> &mut axum::extract::ws::WebSocket {
        &mut self.inner
    }
}
783
/// Yields decoded `In` values. Ping and pong frames are skipped; close frames and
/// transport errors are yielded as errors.
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> Stream for TypedWebsocket<In, Out, E> {
    type Item = Result<In, WebsocketError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use axum::extract::ws::Message as AxumMessage;

        loop {
            let item: Result<In, WebsocketError> = match ready!(self.inner.poll_next_unpin(cx)) {
                None => return Poll::Ready(None),
                Some(Err(_)) => Err(WebsocketError::closed_away()),
                // Control frames carry nothing for the caller; poll again.
                Some(Ok(AxumMessage::Ping(_) | AxumMessage::Pong(_))) => continue,
                Some(Ok(AxumMessage::Text(text))) => {
                    E::decode(text.into()).ok_or_else(WebsocketError::deserialization)
                }
                Some(Ok(AxumMessage::Binary(payload))) => {
                    E::decode(payload).ok_or_else(WebsocketError::deserialization)
                }
                Some(Ok(AxumMessage::Close(Some(frame)))) => {
                    Err(WebsocketError::ConnectionClosed {
                        code: frame.code.into(),
                        description: frame.reason.to_string(),
                    })
                }
                Some(Ok(AxumMessage::Close(None))) => Err(WebsocketError::AlreadyClosed),
            };

            return Poll::Ready(Some(item));
        }
    }
}
825
/// Sending `Out` values encodes them with `E` and writes them as binary frames. Any
/// transport error is reported as [`WebsocketError::AlreadyClosed`].
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> Sink<Out> for TypedWebsocket<In, Out, E> {
    type Error = WebsocketError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = Pin::new(&mut self.inner);
        socket
            .poll_ready(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
        let payload = E::to_bytes(&item).ok_or_else(WebsocketError::serialization)?;
        let frame = axum::extract::ws::Message::Binary(payload);

        let socket = Pin::new(&mut self.inner);
        socket
            .start_send(frame)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = Pin::new(&mut self.inner);
        socket
            .poll_flush(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = Pin::new(&mut self.inner);
        socket
            .poll_close(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }
}
860
861#[derive(thiserror::Error, Debug)]
862pub enum WebsocketError {
863 #[error("Connection closed")]
864 ConnectionClosed {
865 code: CloseCode,
866 description: String,
867 },
868
869 #[error("WebSocket already closed")]
870 AlreadyClosed,
871
872 #[error("WebSocket capacity reached")]
873 Capacity,
874
875 #[error("An unexpected internal error occurred")]
876 Unexpected,
877
878 #[error("WebSocket is not initialized on this platform")]
879 Uninitialized,
880
881 #[cfg(not(target_arch = "wasm32"))]
882 #[error("websocket upgrade failed")]
883 Handshake(#[from] native::HandshakeError),
884
885 #[error("reqwest error")]
886 Reqwest(#[from] reqwest::Error),
887
888 #[cfg(not(target_arch = "wasm32"))]
889 #[error("tungstenite error")]
890 Tungstenite(#[from] tungstenite::Error),
891
892 #[error("error during serialization/deserialization")]
894 Deserialization(Box<dyn std::error::Error + Send + Sync>),
895
896 #[error("error during serialization/deserialization")]
898 Serialization(Box<dyn std::error::Error + Send + Sync>),
899
900 #[error("serde_json error")]
902 Json(#[from] serde_json::Error),
903
904 #[error("ciborium error")]
906 Cbor(#[from] ciborium::de::Error<std::io::Error>),
907}
908
/// Maps gloo's browser websocket errors onto [`WebsocketError`].
#[cfg(feature = "web")]
impl From<gloo_net::websocket::WebSocketError> for WebsocketError {
    fn from(value: gloo_net::websocket::WebSocketError) -> Self {
        use gloo_net::websocket::WebSocketError as GlooError;

        match value {
            GlooError::ConnectionClose(event) => WebsocketError::ConnectionClosed {
                code: event.code.into(),
                description: event.reason,
            },
            GlooError::ConnectionError => WebsocketError::AlreadyClosed,
            // Send failures and any other gloo variants carry nothing we can map.
            _ => WebsocketError::Unexpected,
        }
    }
}
924
925impl WebsocketError {
926 pub fn closed_away() -> Self {
927 Self::ConnectionClosed {
928 code: CloseCode::Normal,
929 description: "Connection closed normally".into(),
930 }
931 }
932
933 pub fn deserialization() -> Self {
934 Self::Deserialization(anyhow::anyhow!("Failed to deserialize message").into())
935 }
936
937 pub fn serialization() -> Self {
938 Self::Serialization(anyhow::anyhow!("Failed to serialize message").into())
939 }
940}
941
/// Browser socket split into halves, each behind its own async mutex, so a send and a
/// receive can be in progress at the same time.
#[cfg(feature = "web")]
struct WebsysSocket {
    /// Outgoing half.
    sender: Mutex<SplitSink<WsWebsocket, WsMessage>>,
    /// Incoming half.
    receiver: Mutex<SplitStream<WsWebsocket>>,
}
947
/// A raw websocket frame, independent of the native, browser, or server backend.
///
/// The browser backend silently drops outgoing `Ping`/`Pong` frames.
#[derive(Clone, Debug)]
pub enum Message {
    /// A UTF-8 text frame.
    Text(String),

    /// A binary frame.
    Binary(Bytes),

    /// A ping control frame and its payload.
    Ping(Bytes),

    /// A pong control frame and its payload.
    Pong(Bytes),

    /// A close frame with its status code and reason.
    Close { code: CloseCode, reason: String },
}
981
982impl From<String> for Message {
983 #[inline]
984 fn from(value: String) -> Self {
985 Self::Text(value)
986 }
987}
988
989impl From<&str> for Message {
990 #[inline]
991 fn from(value: &str) -> Self {
992 Self::from(value.to_owned())
993 }
994}
995
996impl From<Bytes> for Message {
997 #[inline]
998 fn from(value: Bytes) -> Self {
999 Self::Binary(value)
1000 }
1001}
1002
1003impl From<Vec<u8>> for Message {
1004 #[inline]
1005 fn from(value: Vec<u8>) -> Self {
1006 Self::from(Bytes::from(value))
1007 }
1008}
1009
1010impl From<&[u8]> for Message {
1011 #[inline]
1012 fn from(value: &[u8]) -> Self {
1013 Self::from(Bytes::copy_from_slice(value))
1014 }
1015}
1016
/// Status code carried by a websocket close frame.
///
/// Named variants correspond to single fixed codes (see the `u16` conversions below).
/// The tuple variants keep the raw number for the remaining ranges.
#[derive(Debug, Default, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum CloseCode {
    /// 1000: normal closure. This is the default.
    #[default]
    Normal,

    /// 1001: endpoint going away.
    Away,

    /// 1002: protocol error.
    Protocol,

    /// 1003: unsupported data.
    Unsupported,

    /// 1005: no status code was present.
    Status,

    /// 1006: abnormal closure.
    Abnormal,

    /// 1007: invalid frame payload data.
    Invalid,

    /// 1008: policy violation.
    Policy,

    /// 1009: message too big.
    Size,

    /// 1010: mandatory extension missing.
    Extension,

    /// 1011: internal error.
    Error,

    /// 1012: service restart.
    Restart,

    /// 1013: try again later.
    Again,

    /// 1015: TLS handshake failure.
    Tls,

    /// Codes in 1016..=2999.
    Reserved(u16),

    /// Codes in 3000..=3999.
    Iana(u16),

    /// Codes in 4000..=4999.
    Library(u16),

    /// Any other code, including 1004, 1014 and everything below 1000 or above 4999.
    Bad(u16),
}

impl CloseCode {
    /// `false` for codes that should not be sent in a close frame: `Status` (1005),
    /// `Abnormal` (1006), `Tls` (1015), and the `Reserved` and `Bad` ranges.
    #[must_use]
    pub const fn is_allowed(self) -> bool {
        match self {
            Self::Bad(_) | Self::Reserved(_) | Self::Status | Self::Abnormal | Self::Tls => false,
            _ => true,
        }
    }
}

/// Formats the numeric code.
impl std::fmt::Display for CloseCode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", u16::from(*self))
    }
}

impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> Self {
        match code {
            // Range variants already carry their number.
            CloseCode::Reserved(raw)
            | CloseCode::Iana(raw)
            | CloseCode::Library(raw)
            | CloseCode::Bad(raw) => raw,
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
        }
    }
}

impl From<u16> for CloseCode {
    fn from(raw: u16) -> Self {
        match raw {
            1000 => CloseCode::Normal,
            1001 => CloseCode::Away,
            1002 => CloseCode::Protocol,
            1003 => CloseCode::Unsupported,
            1005 => CloseCode::Status,
            1006 => CloseCode::Abnormal,
            1007 => CloseCode::Invalid,
            1008 => CloseCode::Policy,
            1009 => CloseCode::Size,
            1010 => CloseCode::Extension,
            1011 => CloseCode::Error,
            1012 => CloseCode::Restart,
            1013 => CloseCode::Again,
            1015 => CloseCode::Tls,
            1016..=2999 => CloseCode::Reserved(raw),
            3000..=3999 => CloseCode::Iana(raw),
            4000..=4999 => CloseCode::Library(raw),
            _ => CloseCode::Bad(raw),
        }
    }
}
1187
1188#[cfg(not(target_arch = "wasm32"))]
1189mod native {
1190 use std::borrow::Cow;
1191
1192 use crate::ClientRequest;
1193
1194 use super::{CloseCode, Message, WebsocketError};
1195 use dioxus_fullstack_core::RequestError;
1196 use reqwest::{
1197 header::{HeaderName, HeaderValue},
1198 Response, StatusCode, Version,
1199 };
1200 use tungstenite::protocol::WebSocketConfig;
1201
    /// Native websocket split into halves, each behind its own async mutex, so a send
    /// and a receive can be in progress at the same time.
    pub(crate) struct SplitSocket {
        /// Outgoing half.
        pub sender: futures_util::lock::Mutex<
            async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
        >,

        /// Incoming half.
        pub receiver: futures_util::lock::Mutex<
            async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
        >,
    }
1211
1212 pub async fn send_request(
1213 request: ClientRequest,
1214 protocols: &[String],
1215 ) -> Result<WebSocketResponse, WebsocketError> {
1216 let request_builder = request.new_reqwest_request();
1217 let (client, request_result) = request_builder.build_split();
1218 let mut request = request_result?;
1219
1220 let url = request.url_mut();
1222 match url.scheme() {
1223 "ws" => {
1224 url.set_scheme("http")
1225 .expect("url should accept http scheme");
1226 }
1227 "wss" => {
1228 url.set_scheme("https")
1229 .expect("url should accept https scheme");
1230 }
1231 _ => {}
1232 }
1233
1234 let version = request.version();
1236 let nonce = match version {
1237 Version::HTTP_10 | Version::HTTP_11 => {
1238 let nonce_value = tungstenite::handshake::client::generate_key();
1240 let headers = request.headers_mut();
1241 headers.insert(
1242 reqwest::header::CONNECTION,
1243 HeaderValue::from_static("upgrade"),
1244 );
1245 headers.insert(
1246 reqwest::header::UPGRADE,
1247 HeaderValue::from_static("websocket"),
1248 );
1249 headers.insert(
1250 reqwest::header::SEC_WEBSOCKET_KEY,
1251 HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
1252 );
1253 headers.insert(
1254 reqwest::header::SEC_WEBSOCKET_VERSION,
1255 HeaderValue::from_static("13"),
1256 );
1257 if !protocols.is_empty() {
1258 headers.insert(
1259 reqwest::header::SEC_WEBSOCKET_PROTOCOL,
1260 HeaderValue::from_str(&protocols.join(", "))
1261 .expect("protocols is an invalid header value"),
1262 );
1263 }
1264
1265 Some(nonce_value)
1266 }
1267 Version::HTTP_2 => {
1268 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1270 }
1271 _ => {
1272 return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1273 }
1274 };
1275
1276 let response = client.execute(request).await?;
1278
1279 Ok(WebSocketResponse {
1280 response,
1281 version,
1282 nonce,
1283 })
1284 }
1285
    /// Unsplit async-tungstenite stream running over reqwest's upgraded connection.
    pub type WebSocketStream =
        async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
1288
    /// Reasons the native client-side websocket handshake can fail.
    #[derive(Debug, thiserror::Error, Clone)]
    pub enum HandshakeError {
        /// The request used an HTTP version other than 1.0 or 1.1.
        #[error("unsupported http version: {0:?}")]
        UnsupportedHttpVersion(Version),

        /// The response's HTTP version differs from the request's.
        #[error(
            "the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2"
        )]
        ServerRespondedWithDifferentVersion,

        /// A required handshake header was absent from the response.
        #[error("missing header {header}")]
        MissingHeader { header: HeaderName },

        /// A handshake header had a value other than the one required.
        #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
        UnexpectedHeaderValue {
            header: HeaderName,
            got: HeaderValue,
            expected: Cow<'static, str>,
        },

        /// Subprotocols were requested but the server did not select one.
        #[error("expected the server to select a protocol.")]
        ExpectedAProtocol,

        /// The server selected a protocol that was not requested.
        #[error("unexpected protocol: {got}")]
        UnexpectedProtocol { got: String },

        /// The server did not answer with `101 Switching Protocols`.
        #[error("unexpected status code: {0}")]
        UnexpectedStatusCode(StatusCode),
    }
1319
    /// Response to a websocket upgrade request, plus the data needed to validate it.
    pub struct WebSocketResponse {
        /// The raw HTTP response.
        pub response: Response,
        /// HTTP version used for the request; the response must use the same version.
        pub version: Version,
        /// `Sec-WebSocket-Key` sent with the request, used to check `Sec-WebSocket-Accept`.
        pub nonce: Option<String>,
    }
1325
1326 impl WebSocketResponse {
1327 pub async fn into_stream_and_protocol(
1328 self,
1329 protocols: Vec<String>,
1330 web_socket_config: Option<WebSocketConfig>,
1331 ) -> Result<(SplitSocket, Option<String>), WebsocketError> {
1332 let headers = self.response.headers();
1333
1334 if self.response.version() != self.version {
1335 return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
1336 }
1337
1338 if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
1339 tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
1340 return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
1341 }
1342
1343 if let Some(header) = headers.get(reqwest::header::CONNECTION) {
1344 if !header
1345 .to_str()
1346 .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
1347 {
1348 tracing::debug!("server responded with invalid Connection header: {header:?}");
1349 return Err(HandshakeError::UnexpectedHeaderValue {
1350 header: reqwest::header::CONNECTION,
1351 got: header.clone(),
1352 expected: "upgrade".into(),
1353 }
1354 .into());
1355 }
1356 } else {
1357 tracing::debug!("missing Connection header");
1358 return Err(HandshakeError::MissingHeader {
1359 header: reqwest::header::CONNECTION,
1360 }
1361 .into());
1362 }
1363
1364 if let Some(header) = headers.get(reqwest::header::UPGRADE) {
1365 if !header
1366 .to_str()
1367 .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
1368 {
1369 tracing::debug!("server responded with invalid Upgrade header: {header:?}");
1370 return Err(HandshakeError::UnexpectedHeaderValue {
1371 header: reqwest::header::UPGRADE,
1372 got: header.clone(),
1373 expected: "websocket".into(),
1374 }
1375 .into());
1376 }
1377 } else {
1378 tracing::debug!("missing Upgrade header");
1379 return Err(HandshakeError::MissingHeader {
1380 header: reqwest::header::UPGRADE,
1381 }
1382 .into());
1383 }
1384
1385 if let Some(nonce) = &self.nonce {
1386 let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
1387
1388 if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
1389 if !header.to_str().is_ok_and(|s| s == expected_nonce) {
1390 tracing::debug!(
1391 "server responded with invalid Sec-Websocket-Accept header: {header:?}"
1392 );
1393 return Err(HandshakeError::UnexpectedHeaderValue {
1394 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1395 got: header.clone(),
1396 expected: expected_nonce.into(),
1397 }
1398 .into());
1399 }
1400 } else {
1401 tracing::debug!("missing Sec-Websocket-Accept header");
1402 return Err(HandshakeError::MissingHeader {
1403 header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1404 }
1405 .into());
1406 }
1407 }
1408
1409 let protocol = headers
1410 .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
1411 .and_then(|v| v.to_str().ok())
1412 .map(ToOwned::to_owned);
1413
1414 match (protocols.is_empty(), &protocol) {
1415 (true, None) => {
1416 }
1419 (false, None) => {
1420 return Err(HandshakeError::ExpectedAProtocol.into());
1422 }
1423 (false, Some(protocol)) => {
1424 if !protocols.contains(protocol) {
1425 return Err(HandshakeError::UnexpectedProtocol {
1427 got: protocol.clone(),
1428 }
1429 .into());
1430 }
1431 }
1432 (true, Some(protocol)) => {
1433 return Err(HandshakeError::UnexpectedProtocol {
1435 got: protocol.clone(),
1436 }
1437 .into());
1438 }
1439 }
1440
1441 use tokio_util::compat::TokioAsyncReadCompatExt;
1442
1443 let inner = WebSocketStream::from_raw_socket(
1444 self.response.upgrade().await?.compat(),
1445 tungstenite::protocol::Role::Client,
1446 web_socket_config,
1447 )
1448 .await;
1449
1450 let split: (
1451 async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
1452 async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
1453 ) = inner.split();
1454
1455 let split_socket = SplitSocket {
1456 sender: futures_util::lock::Mutex::new(split.0),
1457 receiver: futures_util::lock::Mutex::new(split.1),
1458 };
1459
1460 Ok((split_socket, protocol))
1461 }
1462 }
1463
1464 #[derive(Debug, thiserror::Error)]
1465 #[error("could not convert message")]
1466 pub struct FromTungsteniteMessageError {
1467 pub original: tungstenite::Message,
1468 }
1469
1470 impl TryFrom<tungstenite::Message> for Message {
1471 type Error = FromTungsteniteMessageError;
1472
1473 fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
1474 match value {
1475 tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
1476 tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
1477 tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
1478 tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
1479 tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
1480 code,
1481 reason,
1482 })) => Ok(Self::Close {
1483 code: code.into(),
1484 reason: reason.as_str().to_owned(),
1485 }),
1486 tungstenite::Message::Close(None) => Ok(Self::Close {
1487 code: CloseCode::default(),
1488 reason: "".to_owned(),
1489 }),
1490 tungstenite::Message::Frame(_) => {
1491 Err(FromTungsteniteMessageError { original: value })
1492 }
1493 }
1494 }
1495 }
1496
1497 impl From<Message> for tungstenite::Message {
1498 fn from(value: Message) -> Self {
1499 match value {
1500 Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
1501 Message::Binary(data) => Self::Binary(data),
1502 Message::Ping(data) => Self::Ping(data),
1503 Message::Pong(data) => Self::Pong(data),
1504 Message::Close { code, reason } => {
1505 Self::Close(Some(tungstenite::protocol::CloseFrame {
1506 code: code.into(),
1507 reason: reason.into(),
1508 }))
1509 }
1510 }
1511 }
1512 }
1513
1514 impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
1515 fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
1516 u16::from(value).into()
1517 }
1518 }
1519
1520 impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
1521 fn from(value: CloseCode) -> Self {
1522 u16::from(value).into()
1523 }
1524 }
1525
1526 impl From<HandshakeError> for RequestError {
1527 fn from(value: HandshakeError) -> Self {
1528 let string = value.to_string();
1529 match value {
1530 HandshakeError::UnexpectedStatusCode(status) => {
1531 Self::Status(string, status.as_u16())
1532 }
1533 HandshakeError::UnsupportedHttpVersion(_)
1534 | HandshakeError::MissingHeader { .. }
1535 | HandshakeError::UnexpectedHeaderValue { .. }
1536 | HandshakeError::ExpectedAProtocol
1537 | HandshakeError::UnexpectedProtocol { .. }
1538 | HandshakeError::ServerRespondedWithDifferentVersion => Self::Connect(string),
1539 }
1540 }
1541 }
1542
/// Converts a transport-level error into the shared [`RequestError`] type.
///
/// NOTE(review): this is presumably a local trait rather than `From` impls because both the
/// source error types (reqwest, tungstenite) and `RequestError` are defined outside this crate,
/// and orphan rules forbid implementing `From` between them here. Confirm.
trait IntoRequestError {
    /// Consumes the error and returns the closest matching [`RequestError`].
    fn into_request_error(self) -> RequestError;
}
1546
1547 impl IntoRequestError for reqwest::Error {
1548 fn into_request_error(self) -> RequestError {
1549 const DEFAULT_STATUS_CODE: u16 = 0;
1550 let string = self.to_string();
1551 if self.is_builder() {
1552 RequestError::Builder(string)
1553 } else if self.is_redirect() {
1554 RequestError::Redirect(string)
1555 } else if self.is_status() {
1556 RequestError::Status(
1557 string,
1558 self.status()
1559 .as_ref()
1560 .map(StatusCode::as_u16)
1561 .unwrap_or(DEFAULT_STATUS_CODE),
1562 )
1563 } else if self.is_body() {
1564 RequestError::Body(string)
1565 } else if self.is_decode() {
1566 RequestError::Decode(string)
1567 } else if self.is_upgrade() {
1568 RequestError::Connect(string)
1569 } else {
1570 RequestError::Request(string)
1571 }
1572 }
1573 }
1574
1575 impl IntoRequestError for tungstenite::Error {
1576 fn into_request_error(self) -> RequestError {
1577 match self {
1578 tungstenite::Error::ConnectionClosed => {
1579 RequestError::Connect("websocket connection closed".to_owned())
1580 }
1581 tungstenite::Error::AlreadyClosed => {
1582 RequestError::Connect("websocket already closed".to_owned())
1583 }
1584 tungstenite::Error::Io(error) => RequestError::Connect(error.to_string()),
1585 tungstenite::Error::Tls(error) => RequestError::Connect(error.to_string()),
1586 tungstenite::Error::Capacity(error) => RequestError::Body(error.to_string()),
1587 tungstenite::Error::Protocol(error) => RequestError::Request(error.to_string()),
1588 tungstenite::Error::WriteBufferFull(message) => {
1589 RequestError::Body(message.to_string())
1590 }
1591 tungstenite::Error::Utf8(error) => RequestError::Decode(error),
1592 tungstenite::Error::AttackAttempt => {
1593 RequestError::Request("Tungstenite attack attempt detected".to_owned())
1594 }
1595 tungstenite::Error::Url(error) => RequestError::Builder(error.to_string()),
1596 tungstenite::Error::Http(response) => {
1597 let status_code = response.status();
1598 RequestError::Status(format!("HTTP error: {status_code}"), status_code.as_u16())
1599 }
1600 tungstenite::Error::HttpFormat(error) => RequestError::Builder(error.to_string()),
1601 }
1602 }
1603 }
1604
1605 impl From<WebsocketError> for RequestError {
1606 fn from(value: WebsocketError) -> Self {
1607 match value {
1608 WebsocketError::ConnectionClosed { code, description } => {
1609 Self::Connect(format!("connection closed ({code}): {description}"))
1610 }
1611 WebsocketError::AlreadyClosed => Self::Connect(value.to_string()),
1612 WebsocketError::Capacity => Self::Body(value.to_string()),
1613 WebsocketError::Unexpected => Self::Request(value.to_string()),
1614 WebsocketError::Uninitialized => Self::Builder(value.to_string()),
1615 WebsocketError::Handshake(error) => error.into(),
1616 WebsocketError::Reqwest(error) => error.into_request_error(),
1617 WebsocketError::Tungstenite(error) => error.into_request_error(),
1618 WebsocketError::Serialization(error) => Self::Serialization(error.to_string()),
1619 WebsocketError::Deserialization(error) => Self::Decode(error.to_string()),
1620 WebsocketError::Json(error) => Self::Decode(error.to_string()),
1621 WebsocketError::Cbor(error) => Self::Decode(error.to_string()),
1622 }
1623 }
1624 }
1625}