1#[doc(hidden)]
2pub mod util;
3
4use crate::extractor::Extractor;
5
6use std::convert::Infallible;
7use std::fmt;
8use std::str::Utf8Error;
9
10use hyper_util::rt::TokioIo;
11use tracing::warn;
12
13#[doc(hidden)]
14pub use hyper::upgrade;
15
16use tokio_tungstenite::tungstenite;
17use tokio_tungstenite::WebSocketStream;
18use tungstenite::protocol::Role;
19
20use tungstenite::protocol::Message as ProtMessage;
22pub use tungstenite::{
23 error::Error, protocol::frame::coding::CloseCode, protocol::CloseFrame,
24};
25
26use futures_util::sink::SinkExt;
27use futures_util::stream::StreamExt;
28
/// Decides whether a value should be reported (logged) as an error.
///
/// NOTE(review): judging by the name, this is applied to the value a
/// websocket handler returns — confirm against callers.
///
/// The `Debug` supertrait guarantees every implementor can be formatted,
/// presumably so the caller can include the value in its log output.
pub trait LogWebSocketReturn: std::fmt::Debug {
    /// Returns `true` when this value represents a failure worth logging.
    fn should_log_error(&self) -> bool;
}
32
33impl<T, E> LogWebSocketReturn for Result<T, E>
34where
35 T: fmt::Debug,
36 E: fmt::Debug,
37{
38 fn should_log_error(&self) -> bool {
39 self.is_err()
40 }
41}
42
43impl LogWebSocketReturn for () {
44 fn should_log_error(&self) -> bool {
45 false
46 }
47}
48
// Unwraps an `Option`, or makes the *enclosing function* return `Ok(None)`
// when it is `None`.
//
// Only usable inside functions whose return type is `Result<Option<_>, _>`.
#[cfg(feature = "json")]
macro_rules! try2 {
    ($e:expr) => {
        if let Some(v) = $e {
            v
        } else {
            return Ok(None);
        }
    };
}
58
/// A data message exchanged over a websocket.
///
/// Only text and binary payloads are represented; ping, pong and close
/// frames are not part of this type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Message {
    /// A text message.
    Text(String),
    /// A binary message with arbitrary bytes.
    Binary(Vec<u8>),
}
64
65impl Message {
66 pub fn into_data(self) -> Vec<u8> {
67 match self {
68 Self::Text(t) => t.into(),
69 Self::Binary(b) => b,
70 }
71 }
72
73 pub fn to_text(&self) -> Result<&str, Utf8Error> {
74 match self {
75 Self::Text(t) => Ok(&t),
76 Self::Binary(b) => std::str::from_utf8(b),
77 }
78 }
79}
80
81impl From<String> for Message {
82 fn from(s: String) -> Self {
83 Self::Text(s)
84 }
85}
86
87impl From<&str> for Message {
88 fn from(s: &str) -> Self {
89 Self::Text(s.into())
90 }
91}
92
93impl From<Vec<u8>> for Message {
94 fn from(v: Vec<u8>) -> Self {
95 Self::Binary(v)
96 }
97}
98
99impl From<&[u8]> for Message {
100 fn from(v: &[u8]) -> Self {
101 Self::Binary(v.into())
102 }
103}
104
105impl From<Message> for ProtMessage {
106 fn from(m: Message) -> Self {
107 match m {
108 Message::Text(t) => Self::Text(t),
109 Message::Binary(b) => Self::Binary(b),
110 }
111 }
112}
113
114#[derive(Debug)]
115pub struct WebSocket {
116 inner: WebSocketStream<TokioIo<upgrade::Upgraded>>,
117}
118
119impl WebSocket {
120 pub async fn new(upgraded: upgrade::Upgraded) -> Self {
121 Self {
122 inner: WebSocketStream::from_raw_socket(
123 TokioIo::new(upgraded),
124 Role::Server,
125 None,
126 )
127 .await,
128 }
129 }
130
131 #[doc(hidden)]
133 pub fn from_raw(
134 inner: WebSocketStream<TokioIo<upgrade::Upgraded>>,
135 ) -> Self {
136 Self { inner }
137 }
138
139 pub async fn receive(&mut self) -> Result<Option<Message>, Error> {
143 loop {
145 let res = self.inner.next().await.transpose();
146 return match res {
147 Ok(None) => Ok(None),
148 Ok(Some(ProtMessage::Text(t))) => Ok(Some(Message::Text(t))),
149 Ok(Some(ProtMessage::Binary(b))) => {
150 Ok(Some(Message::Binary(b)))
151 }
152 Ok(Some(ProtMessage::Ping(d))) => {
153 self.inner.send(ProtMessage::Pong(d)).await?;
155 continue;
157 }
158 Ok(Some(ProtMessage::Pong(_))) => continue,
159 Ok(Some(ProtMessage::Close(_))) => Ok(None),
160 Ok(Some(ProtMessage::Frame(f))) => {
161 warn!("we received a websocket frame {:?}", f);
162 continue;
164 }
165 Err(Error::ConnectionClosed) | Err(Error::AlreadyClosed) => {
166 Ok(None)
167 }
168 Err(e) => Err(e),
169 };
170 }
171 }
172
173 pub async fn send<M>(&mut self, msg: M) -> Result<(), Error>
174 where
175 M: Into<Message>,
176 {
177 self.inner.send(msg.into().into()).await
178 }
179
180 pub async fn close(&mut self, code: CloseCode, reason: String) {
181 let _ = self
182 .inner
183 .send(ProtMessage::Close(Some(CloseFrame {
184 code,
185 reason: reason.into(),
186 })))
187 .await;
188 let _ = self.inner.close(None).await;
189 }
192
193 pub async fn ping(&mut self) -> Result<(), Error> {
194 self.inner.send(ProtMessage::Ping(vec![])).await
195 }
196
197 #[cfg(feature = "json")]
199 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
200 pub async fn deserialize<D>(&mut self) -> Result<Option<D>, JsonError>
201 where
202 D: serde::de::DeserializeOwned,
203 {
204 let msg = try2!(self.receive().await?).into_data();
205 serde_json::from_slice(&msg)
206 .map(|d| Some(d))
207 .map_err(|e| e.into())
208 }
209
210 #[cfg(feature = "json")]
212 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
213 pub async fn serialize<S: ?Sized>(
214 &mut self,
215 value: &S,
216 ) -> Result<(), JsonError>
217 where
218 S: serde::Serialize,
219 {
220 let v = serde_json::to_string(value)?;
221 self.send(v).await.map_err(|e| e.into())
222 }
223}
224
#[cfg(feature = "json")]
mod json_error {
    use super::Error;
    use std::fmt;

    /// Error returned by the JSON helpers (`deserialize` / `serialize`) on
    /// `WebSocket`.
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    #[derive(Debug)]
    pub enum JsonError {
        /// Receiving from or sending on the websocket failed.
        ConnectionError(Error),
        /// The payload could not be converted to or from JSON.
        SerdeError(serde_json::Error),
    }

    impl From<Error> for JsonError {
        fn from(e: Error) -> Self {
            Self::ConnectionError(e)
        }
    }

    impl From<serde_json::Error> for JsonError {
        fn from(e: serde_json::Error) -> Self {
            Self::SerdeError(e)
        }
    }

    // `Display` is the user-facing message. It used to forward to `Debug`,
    // which printed variant names instead of a readable description.
    impl fmt::Display for JsonError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                Self::ConnectionError(e) => write!(f, "websocket connection error: {}", e),
                Self::SerdeError(e) => write!(f, "json error: {}", e),
            }
        }
    }

    impl std::error::Error for JsonError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self {
                Self::ConnectionError(e) => Some(e),
                Self::SerdeError(e) => Some(e),
            }
        }
    }
}
265
266#[cfg(feature = "json")]
267pub use json_error::JsonError;
268
// Lets a handler receive a `WebSocket` through the crate's extractor
// machinery.
//
// NOTE(review): the exact contract of `extractor_validate!`,
// `extractor_prepare!` and `extractor_extract!` is defined elsewhere in the
// crate; these comments only describe what is visible here.
impl<'a> Extractor<'a, WebSocket> for WebSocket {
    // This impl never produces an error value (`Infallible`).
    type Error = Infallible;
    // No prepared state is carried between the prepare and extract steps.
    type Prepared = ();

    extractor_validate!();

    extractor_prepare!();

    // Moves the `WebSocket` out of `extract.request`.
    // NOTE(review): `take().unwrap()` panics if the slot is already empty —
    // presumably the framework fills `request` with the upgraded socket
    // exactly once per websocket handler; TODO confirm.
    extractor_extract!(<WebSocket> |extract| {
        Ok(extract.request.take().unwrap())
    });
}