1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![allow(clippy::result_large_err)]
5
6#[cfg(feature = "json")]
51mod json;
52#[cfg(not(target_arch = "wasm32"))]
53mod native;
54mod protocol;
55#[cfg(target_arch = "wasm32")]
56mod wasm;
57
58use std::{
59 pin::Pin,
60 task::{ready, Context, Poll},
61};
62
63#[cfg(not(target_arch = "wasm32"))]
64#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
65pub use crate::native::HandshakeError;
66pub use crate::protocol::{CloseCode, Message};
67pub use bytes::Bytes;
68use futures_util::{Sink, SinkExt, Stream, StreamExt};
69use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
70
71#[derive(Debug, thiserror::Error)]
73#[non_exhaustive]
74pub enum Error {
75 #[cfg(not(target_arch = "wasm32"))]
76 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
77 #[error("websocket upgrade failed")]
78 Handshake(#[from] HandshakeError),
79
80 #[error("reqwest error")]
81 Reqwest(#[from] reqwest::Error),
82
83 #[cfg(not(target_arch = "wasm32"))]
84 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
85 #[error("tungstenite error")]
86 Tungstenite(#[from] tungstenite::Error),
87
88 #[cfg(target_arch = "wasm32")]
89 #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
90 #[error("web_sys error")]
91 WebSys(#[from] wasm::WebSysError),
92
93 #[error("serde_json error")]
95 #[cfg(feature = "json")]
96 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
97 Json(#[from] serde_json::Error),
98}
99
100pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
108 builder_http1_only(Client::builder())
109 .build()?
110 .get(url)
111 .upgrade()
112 .send()
113 .await?
114 .into_websocket()
115 .await
116}
117
118#[inline]
119#[cfg(not(target_arch = "wasm32"))]
120fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
121 builder.http1_only()
122}
123
124#[inline]
125#[cfg(target_arch = "wasm32")]
126fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
127 builder
128}
129
130pub trait RequestBuilderExt {
132 fn upgrade(self) -> UpgradedRequestBuilder;
137}
138
139impl RequestBuilderExt for RequestBuilder {
140 fn upgrade(self) -> UpgradedRequestBuilder {
141 UpgradedRequestBuilder::new(self)
142 }
143}
144
145pub struct UpgradedRequestBuilder {
148 inner: RequestBuilder,
149 protocols: Vec<String>,
150 #[cfg(not(target_arch = "wasm32"))]
151 web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
152}
153
154impl UpgradedRequestBuilder {
155 pub(crate) fn new(inner: RequestBuilder) -> Self {
156 Self {
157 inner,
158 protocols: vec![],
159 #[cfg(not(target_arch = "wasm32"))]
160 web_socket_config: None,
161 }
162 }
163
164 pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
166 self.protocols = protocols.into_iter().map(Into::into).collect();
167
168 self
169 }
170
171 #[cfg(not(target_arch = "wasm32"))]
173 pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
174 self.web_socket_config = Some(config);
175 self
176 }
177
178 pub async fn send(self) -> Result<UpgradeResponse, Error> {
180 #[cfg(not(target_arch = "wasm32"))]
181 let inner = native::send_request(self.inner, &self.protocols).await?;
182
183 #[cfg(target_arch = "wasm32")]
184 let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
185
186 Ok(UpgradeResponse {
187 inner,
188 protocols: self.protocols,
189 #[cfg(not(target_arch = "wasm32"))]
190 web_socket_config: self.web_socket_config,
191 })
192 }
193}
194
195pub struct UpgradeResponse {
200 #[cfg(not(target_arch = "wasm32"))]
201 inner: native::WebSocketResponse,
202
203 #[cfg(target_arch = "wasm32")]
204 inner: wasm::WebSysWebSocketStream,
205
206 #[allow(dead_code)]
207 protocols: Vec<String>,
208
209 #[cfg(not(target_arch = "wasm32"))]
210 #[allow(dead_code)]
211 web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
212}
213
214#[cfg(not(target_arch = "wasm32"))]
215impl std::ops::Deref for UpgradeResponse {
216 type Target = reqwest::Response;
217
218 fn deref(&self) -> &Self::Target {
219 &self.inner.response
220 }
221}
222
223impl UpgradeResponse {
224 pub async fn into_websocket(self) -> Result<WebSocket, Error> {
227 #[cfg(not(target_arch = "wasm32"))]
228 let (inner, protocol) = self
229 .inner
230 .into_stream_and_protocol(self.protocols, self.web_socket_config)
231 .await?;
232
233 #[cfg(target_arch = "wasm32")]
234 let (inner, protocol) = {
235 let protocol = self.inner.protocol();
236 (self.inner, Some(protocol))
237 };
238
239 Ok(WebSocket { inner, protocol })
240 }
241
242 #[must_use]
244 #[cfg(not(target_arch = "wasm32"))]
245 pub fn into_inner(self) -> reqwest::Response {
246 self.inner.response
247 }
248}
249
250#[derive(Debug)]
253pub struct WebSocket {
254 #[cfg(not(target_arch = "wasm32"))]
255 inner: native::WebSocketStream,
256
257 #[cfg(target_arch = "wasm32")]
258 inner: wasm::WebSysWebSocketStream,
259
260 protocol: Option<String>,
261}
262
263impl WebSocket {
264 pub fn protocol(&self) -> Option<&str> {
266 self.protocol.as_deref()
267 }
268
269 pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
277 #[cfg(not(target_arch = "wasm32"))]
278 {
279 let mut inner = self.inner;
280 inner
281 .close(Some(tungstenite::protocol::CloseFrame {
282 code: code.into(),
283 reason: reason.unwrap_or_default().into(),
284 }))
285 .await?;
286 }
287
288 #[cfg(target_arch = "wasm32")]
289 self.inner.close(code.into(), reason.unwrap_or_default())?;
290
291 Ok(())
292 }
293}
294
295impl Stream for WebSocket {
296 type Item = Result<Message, Error>;
297
298 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
299 match ready!(self.inner.poll_next_unpin(cx)) {
300 None => Poll::Ready(None),
301 Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
302 Some(Ok(message)) => {
303 match message.try_into() {
304 Ok(message) => Poll::Ready(Some(Ok(message))),
305
306 #[cfg(target_arch = "wasm32")]
307 Err(e) => match e {},
308
309 #[cfg(not(target_arch = "wasm32"))]
310 Err(e) => {
311 panic!("Received an invalid frame: {e}");
313 }
314 }
315 }
316 }
317 }
318}
319
320impl Sink<Message> for WebSocket {
321 type Error = Error;
322
323 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
324 self.inner.poll_ready_unpin(cx).map_err(Into::into)
325 }
326
327 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
328 self.inner.start_send_unpin(item.into()).map_err(Into::into)
329 }
330
331 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332 self.inner.poll_flush_unpin(cx).map_err(Into::into)
333 }
334
335 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336 self.inner.poll_close_unpin(cx).map_err(Into::into)
337 }
338}
339
#[cfg(test)]
pub mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Endpoint used by the network tests over the `https` scheme.
    const ECHO_URL: &str = "https://echo.websocket.org/";

    /// Sends a text message and waits until the same text is received back.
    ///
    /// Panics if the stream ends before the text was seen.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket.send(Message::Text(text.into())).await.unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            // Anything other than the expected text is skipped.
            if let Message::Text(s) = message {
                if s == text {
                    return;
                }
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let response = Client::default()
            .get(ECHO_URL)
            .upgrade()
            .send()
            .await
            .unwrap();
        let websocket = response.into_websocket().await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let websocket = websocket(ECHO_URL).await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let websocket = websocket("wss://echo.websocket.org/").await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let websocket = websocket(ECHO_URL).await.unwrap();
        let outcome = websocket.close(CloseCode::Normal, Some("test")).await;

        outcome.expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut websocket = websocket(ECHO_URL).await.unwrap();
        let close_request = Message::Close {
            code: CloseCode::Normal,
            reason: "Can you please reply with a close frame?".into(),
        };
        websocket.send(close_request).await.unwrap();

        // Drain the stream and remember whether a close frame showed up.
        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let response = Client::default()
            .get(ECHO_URL)
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap();
        let websocket = response.into_websocket().await.unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        assert_eq!(CloseCode::from(1008u16), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let code = CloseCode::Away;
        let raw: u16 = code.into();

        assert_eq!(raw, 1001u16);
        assert_eq!(u16::from(code), 1001u16);
    }
}