1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4#[cfg(feature = "json")]
49mod json;
50#[cfg(not(target_arch = "wasm32"))]
51mod native;
52mod protocol;
53#[cfg(target_arch = "wasm32")]
54mod wasm;
55
56use std::{
57 pin::Pin,
58 task::{ready, Context, Poll},
59};
60
61#[cfg(not(target_arch = "wasm32"))]
62#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
63pub use crate::native::HandshakeError;
64pub use crate::protocol::{CloseCode, Message};
65pub use bytes::Bytes;
66use futures_util::{Sink, SinkExt, Stream, StreamExt};
67use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
68
69#[derive(Debug, thiserror::Error)]
71#[non_exhaustive]
72pub enum Error {
73 #[cfg(not(target_arch = "wasm32"))]
74 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
75 #[error("websocket upgrade failed")]
76 Handshake(#[from] HandshakeError),
77
78 #[error("reqwest error")]
79 Reqwest(#[from] reqwest::Error),
80
81 #[cfg(not(target_arch = "wasm32"))]
82 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
83 #[error("tungstenite error")]
84 Tungstenite(#[from] tungstenite::Error),
85
86 #[cfg(target_arch = "wasm32")]
87 #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
88 #[error("web_sys error")]
89 WebSys(#[from] wasm::WebSysError),
90
91 #[error("serde_json error")]
93 #[cfg(feature = "json")]
94 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
95 Json(#[from] serde_json::Error),
96}
97
98pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
106 builder_http1_only(Client::builder())
107 .build()?
108 .get(url)
109 .upgrade()
110 .send()
111 .await?
112 .into_websocket()
113 .await
114}
115
116#[inline]
117#[cfg(not(target_arch = "wasm32"))]
118fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
119 builder.http1_only()
120}
121
122#[inline]
123#[cfg(target_arch = "wasm32")]
124fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
125 builder
126}
127
128pub trait RequestBuilderExt {
130 fn upgrade(self) -> UpgradedRequestBuilder;
135}
136
137impl RequestBuilderExt for RequestBuilder {
138 fn upgrade(self) -> UpgradedRequestBuilder {
139 UpgradedRequestBuilder::new(self)
140 }
141}
142
143pub struct UpgradedRequestBuilder {
146 inner: RequestBuilder,
147 protocols: Vec<String>,
148 #[cfg(not(target_arch = "wasm32"))]
149 web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
150}
151
152impl UpgradedRequestBuilder {
153 pub(crate) fn new(inner: RequestBuilder) -> Self {
154 Self {
155 inner,
156 protocols: vec![],
157 #[cfg(not(target_arch = "wasm32"))]
158 web_socket_config: None,
159 }
160 }
161
162 pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
164 self.protocols = protocols.into_iter().map(Into::into).collect();
165
166 self
167 }
168
169 #[cfg(not(target_arch = "wasm32"))]
171 pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
172 self.web_socket_config = Some(config);
173 self
174 }
175
176 pub async fn send(self) -> Result<UpgradeResponse, Error> {
178 #[cfg(not(target_arch = "wasm32"))]
179 let inner = native::send_request(self.inner, &self.protocols).await?;
180
181 #[cfg(target_arch = "wasm32")]
182 let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
183
184 Ok(UpgradeResponse {
185 inner,
186 protocols: self.protocols,
187 #[cfg(not(target_arch = "wasm32"))]
188 web_socket_config: self.web_socket_config,
189 })
190 }
191}
192
193pub struct UpgradeResponse {
198 #[cfg(not(target_arch = "wasm32"))]
199 inner: native::WebSocketResponse,
200
201 #[cfg(target_arch = "wasm32")]
202 inner: wasm::WebSysWebSocketStream,
203
204 #[allow(dead_code)]
205 protocols: Vec<String>,
206
207 #[cfg(not(target_arch = "wasm32"))]
208 #[allow(dead_code)]
209 web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
210}
211
212#[cfg(not(target_arch = "wasm32"))]
213impl std::ops::Deref for UpgradeResponse {
214 type Target = reqwest::Response;
215
216 fn deref(&self) -> &Self::Target {
217 &self.inner.response
218 }
219}
220
221impl UpgradeResponse {
222 pub async fn into_websocket(self) -> Result<WebSocket, Error> {
225 #[cfg(not(target_arch = "wasm32"))]
226 let (inner, protocol) = self
227 .inner
228 .into_stream_and_protocol(self.protocols, self.web_socket_config)
229 .await?;
230
231 #[cfg(target_arch = "wasm32")]
232 let (inner, protocol) = {
233 let protocol = self.inner.protocol();
234 (self.inner, Some(protocol))
235 };
236
237 Ok(WebSocket { inner, protocol })
238 }
239
240 #[must_use]
242 #[cfg(not(target_arch = "wasm32"))]
243 pub fn into_inner(self) -> reqwest::Response {
244 self.inner.response
245 }
246}
247
248#[derive(Debug)]
251pub struct WebSocket {
252 #[cfg(not(target_arch = "wasm32"))]
253 inner: native::WebSocketStream,
254
255 #[cfg(target_arch = "wasm32")]
256 inner: wasm::WebSysWebSocketStream,
257
258 protocol: Option<String>,
259}
260
261impl WebSocket {
262 pub fn protocol(&self) -> Option<&str> {
264 self.protocol.as_deref()
265 }
266
267 pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
275 #[cfg(not(target_arch = "wasm32"))]
276 {
277 let mut inner = self.inner;
278 inner
279 .close(Some(tungstenite::protocol::CloseFrame {
280 code: code.into(),
281 reason: reason.unwrap_or_default().into(),
282 }))
283 .await?;
284 }
285
286 #[cfg(target_arch = "wasm32")]
287 self.inner.close(code.into(), reason.unwrap_or_default())?;
288
289 Ok(())
290 }
291}
292
293impl Stream for WebSocket {
294 type Item = Result<Message, Error>;
295
296 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 match ready!(self.inner.poll_next_unpin(cx)) {
298 None => Poll::Ready(None),
299 Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
300 Some(Ok(message)) => {
301 match message.try_into() {
302 Ok(message) => Poll::Ready(Some(Ok(message))),
303
304 #[cfg(target_arch = "wasm32")]
305 Err(e) => match e {},
306
307 #[cfg(not(target_arch = "wasm32"))]
308 Err(e) => {
309 panic!("Received an invalid frame: {e}");
311 }
312 }
313 }
314 }
315 }
316}
317
318impl Sink<Message> for WebSocket {
319 type Error = Error;
320
321 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322 self.inner.poll_ready_unpin(cx).map_err(Into::into)
323 }
324
325 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
326 self.inner.start_send_unpin(item.into()).map_err(Into::into)
327 }
328
329 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330 self.inner.poll_flush_unpin(cx).map_err(Into::into)
331 }
332
333 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334 self.inner.poll_close_unpin(cx).map_err(Into::into)
335 }
336}
337
#[cfg(test)]
pub mod tests {
    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    /// Public echo server reached over an `https` URL.
    const ECHO_HTTPS_URL: &str = "https://echo.websocket.org/";
    /// The same echo server reached over a `wss` URL.
    const ECHO_WSS_URL: &str = "wss://echo.websocket.org/";

    /// Sends a text message and waits until the same text is received back.
    ///
    /// Non-matching messages are skipped; panics if the stream ends first.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket.send(Message::Text(text.into())).await.unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Text(received) = message {
                if received == text {
                    return;
                }
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let upgrade_response = Client::default()
            .get(ECHO_HTTPS_URL)
            .upgrade()
            .send()
            .await
            .unwrap();
        let websocket = upgrade_response.into_websocket().await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        test_websocket(websocket(ECHO_HTTPS_URL).await.unwrap()).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        test_websocket(websocket(ECHO_WSS_URL).await.unwrap()).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let connection = websocket(ECHO_HTTPS_URL).await.unwrap();
        let outcome = connection.close(CloseCode::Normal, Some("test")).await;
        outcome.expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut connection = websocket(ECHO_HTTPS_URL).await.unwrap();
        let close_request = Message::Close {
            code: CloseCode::Normal,
            reason: "Can you please reply with a close frame?".into(),
        };
        connection.send(close_request).await.unwrap();

        let mut close_received = false;
        while let Some(message) = connection.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let upgrade_response = Client::default()
            .get(ECHO_HTTPS_URL)
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap();
        let websocket = upgrade_response.into_websocket().await.unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        assert_eq!(CloseCode::from(1008u16), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let code = CloseCode::Away;
        let raw: u16 = code.into();
        assert_eq!(raw, 1001u16);
        assert_eq!(u16::from(code), 1001u16);
    }
}
468}