fire_http/ws/
mod.rs

1#[doc(hidden)]
2pub mod util;
3
4use crate::extractor::Extractor;
5
6use std::convert::Infallible;
7use std::fmt;
8use std::str::Utf8Error;
9
10use hyper_util::rt::TokioIo;
11use tracing::warn;
12
13#[doc(hidden)]
14pub use hyper::upgrade;
15
16use tokio_tungstenite::tungstenite;
17use tokio_tungstenite::WebSocketStream;
18use tungstenite::protocol::Role;
19
// re-exports
21use tungstenite::protocol::Message as ProtMessage;
22pub use tungstenite::{
23	error::Error, protocol::frame::coding::CloseCode, protocol::CloseFrame,
24};
25
26use futures_util::sink::SinkExt;
27use futures_util::stream::StreamExt;
28
/// Implemented by values that can tell whether they represent an error
/// that should be logged.
///
/// Every implementor must also be [`fmt::Debug`] so the value can be
/// printed when it is logged.
// NOTE(review): presumably implemented for the return types of websocket
// handlers; confirm against the call sites.
pub trait LogWebSocketReturn: fmt::Debug {
	/// Returns `true` if this value should be reported as an error.
	fn should_log_error(&self) -> bool;
}
32
33impl<T, E> LogWebSocketReturn for Result<T, E>
34where
35	T: fmt::Debug,
36	E: fmt::Debug,
37{
38	fn should_log_error(&self) -> bool {
39		self.is_err()
40	}
41}
42
43impl LogWebSocketReturn for () {
44	fn should_log_error(&self) -> bool {
45		false
46	}
47}
48
// Unwraps an `Option` inside a function that returns `Result<Option<_>, _>`:
// evaluates to the inner value on `Some`, otherwise returns `Ok(None)` from
// the enclosing function.
#[cfg(feature = "json")]
macro_rules! try2 {
	($opt:expr) => {
		match $opt {
			None => return Ok(None),
			Some(inner) => inner,
		}
	};
}
58
/// A websocket payload: either text or raw bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Message {
	/// A text payload.
	Text(String),
	/// A binary payload.
	Binary(Vec<u8>),
}

impl Message {
	/// Consumes the message and returns its payload as bytes.
	///
	/// A text payload is returned as its UTF-8 encoded bytes.
	pub fn into_data(self) -> Vec<u8> {
		match self {
			Self::Binary(bytes) => bytes,
			Self::Text(text) => text.into_bytes(),
		}
	}

	/// Borrows the payload as a string slice.
	///
	/// # Errors
	///
	/// Returns a [`Utf8Error`] if the message is binary and its bytes are
	/// not valid UTF-8.
	pub fn to_text(&self) -> Result<&str, Utf8Error> {
		match self {
			Self::Binary(bytes) => std::str::from_utf8(bytes),
			Self::Text(text) => Ok(text.as_str()),
		}
	}
}
80
81impl From<String> for Message {
82	fn from(s: String) -> Self {
83		Self::Text(s)
84	}
85}
86
87impl From<&str> for Message {
88	fn from(s: &str) -> Self {
89		Self::Text(s.into())
90	}
91}
92
93impl From<Vec<u8>> for Message {
94	fn from(v: Vec<u8>) -> Self {
95		Self::Binary(v)
96	}
97}
98
99impl From<&[u8]> for Message {
100	fn from(v: &[u8]) -> Self {
101		Self::Binary(v.into())
102	}
103}
104
105impl From<Message> for ProtMessage {
106	fn from(m: Message) -> Self {
107		match m {
108			Message::Text(t) => Self::Text(t),
109			Message::Binary(b) => Self::Binary(b),
110		}
111	}
112}
113
114#[derive(Debug)]
115pub struct WebSocket {
116	inner: WebSocketStream<TokioIo<upgrade::Upgraded>>,
117}
118
119impl WebSocket {
120	pub async fn new(upgraded: upgrade::Upgraded) -> Self {
121		Self {
122			inner: WebSocketStream::from_raw_socket(
123				TokioIo::new(upgraded),
124				Role::Server,
125				None,
126			)
127			.await,
128		}
129	}
130
131	// used for tests
132	#[doc(hidden)]
133	pub fn from_raw(
134		inner: WebSocketStream<TokioIo<upgrade::Upgraded>>,
135	) -> Self {
136		Self { inner }
137	}
138
139	/// Handles Ping and Pong messages
140	///
141	/// never returns Error::ConnectionClose | Error::AlreadyClosed
142	pub async fn receive(&mut self) -> Result<Option<Message>, Error> {
143		// loop used to handle Message::Pong | Message::Ping
144		loop {
145			let res = self.inner.next().await.transpose();
146			return match res {
147				Ok(None) => Ok(None),
148				Ok(Some(ProtMessage::Text(t))) => Ok(Some(Message::Text(t))),
149				Ok(Some(ProtMessage::Binary(b))) => {
150					Ok(Some(Message::Binary(b)))
151				}
152				Ok(Some(ProtMessage::Ping(d))) => {
153					// respond with a pong
154					self.inner.send(ProtMessage::Pong(d)).await?;
155					// then listen for a new message
156					continue;
157				}
158				Ok(Some(ProtMessage::Pong(_))) => continue,
159				Ok(Some(ProtMessage::Close(_))) => Ok(None),
160				Ok(Some(ProtMessage::Frame(f))) => {
161					warn!("we received a websocket frame {:?}", f);
162					// Todod should we do something about this frame??
163					continue;
164				}
165				Err(Error::ConnectionClosed) | Err(Error::AlreadyClosed) => {
166					Ok(None)
167				}
168				Err(e) => Err(e),
169			};
170		}
171	}
172
173	pub async fn send<M>(&mut self, msg: M) -> Result<(), Error>
174	where
175		M: Into<Message>,
176	{
177		self.inner.send(msg.into().into()).await
178	}
179
180	pub async fn close(&mut self, code: CloseCode, reason: String) {
181		let _ = self
182			.inner
183			.send(ProtMessage::Close(Some(CloseFrame {
184				code,
185				reason: reason.into(),
186			})))
187			.await;
188		let _ = self.inner.close(None).await;
189		// close is close
190		// don't mind if you could send close or not
191	}
192
193	pub async fn ping(&mut self) -> Result<(), Error> {
194		self.inner.send(ProtMessage::Ping(vec![])).await
195	}
196
197	/// calls receive and then deserialize
198	#[cfg(feature = "json")]
199	#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
200	pub async fn deserialize<D>(&mut self) -> Result<Option<D>, JsonError>
201	where
202		D: serde::de::DeserializeOwned,
203	{
204		let msg = try2!(self.receive().await?).into_data();
205		serde_json::from_slice(&msg)
206			.map(|d| Some(d))
207			.map_err(|e| e.into())
208	}
209
210	/// calls serialize then send
211	#[cfg(feature = "json")]
212	#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
213	pub async fn serialize<S: ?Sized>(
214		&mut self,
215		value: &S,
216	) -> Result<(), JsonError>
217	where
218		S: serde::Serialize,
219	{
220		let v = serde_json::to_string(value)?;
221		self.send(v).await.map_err(|e| e.into())
222	}
223}
224
#[cfg(feature = "json")]
mod json_error {

	use super::Error;
	use std::fmt;

	/// Error returned by the JSON helpers: either the websocket failed or
	/// the JSON (de)serialization did.
	#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
	#[derive(Debug)]
	pub enum JsonError {
		/// The websocket reported an error.
		ConnectionError(Error),
		/// `serde_json` reported an error.
		SerdeError(serde_json::Error),
	}

	impl From<Error> for JsonError {
		fn from(err: Error) -> Self {
			JsonError::ConnectionError(err)
		}
	}

	impl From<serde_json::Error> for JsonError {
		fn from(err: serde_json::Error) -> Self {
			JsonError::SerdeError(err)
		}
	}

	impl fmt::Display for JsonError {
		// Display reuses the derived Debug representation.
		fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
			<Self as fmt::Debug>::fmt(self, f)
		}
	}

	impl std::error::Error for JsonError {
		/// Exposes the wrapped error as the source of this error.
		fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
			let cause: &(dyn std::error::Error + 'static) = match self {
				Self::SerdeError(err) => err,
				Self::ConnectionError(err) => err,
			};

			Some(cause)
		}
	}
}
265
266#[cfg(feature = "json")]
267pub use json_error::JsonError;
268
/// Makes `WebSocket` usable as an extractor that yields the `WebSocket`
/// stored in the request slot.
// NOTE(review): the `extractor_*!` macros are defined elsewhere in the crate;
// confirm there what the argument-less validate/prepare forms expand to.
impl<'a> Extractor<'a, WebSocket> for WebSocket {
	// Extraction is declared as infallible.
	type Error = Infallible;
	// No prepared state is carried (unit type).
	type Prepared = ();

	extractor_validate!();

	extractor_prepare!();

	extractor_extract!(<WebSocket> |extract| {
		// Moves the value out of `extract.request`; the `unwrap` panics if
		// the slot is empty (presumably only if it was already taken;
		// verify against the macro and its callers).
		Ok(extract.request.take().unwrap())
	});
}