1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4#[cfg(feature = "json")]
49mod json;
50#[cfg(not(target_arch = "wasm32"))]
51mod native;
52mod protocol;
53#[cfg(target_arch = "wasm32")]
54mod wasm;
55
56use std::{
57 pin::Pin,
58 task::{ready, Context, Poll},
59};
60
61#[cfg(not(target_arch = "wasm32"))]
62#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
63pub use crate::native::HandshakeError;
64pub use crate::protocol::{CloseCode, Message};
65use futures_util::{Sink, SinkExt, Stream, StreamExt};
66use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
67
/// Errors that can occur while establishing or using a [`WebSocket`].
///
/// Every variant wraps the underlying error via `#[from]`, so `?` converts
/// the source error automatically and [`std::error::Error::source`] exposes it.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The server's response to the upgrade request was rejected during the
    /// WebSocket handshake (native targets only).
    #[cfg(not(target_arch = "wasm32"))]
    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
    #[error("websocket upgrade failed")]
    Handshake(#[from] HandshakeError),

    /// An error reported by `reqwest`, e.g. while building the client or
    /// sending the upgrade request.
    #[error("reqwest error")]
    Reqwest(#[from] reqwest::Error),

    /// An error from the `tungstenite` protocol implementation (native only).
    #[cfg(not(target_arch = "wasm32"))]
    #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
    #[error("tungstenite error")]
    Tungstenite(#[from] tungstenite::Error),

    /// An error from the browser's WebSocket API via `web_sys` (wasm only).
    #[cfg(target_arch = "wasm32")]
    #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
    #[error("web_sys error")]
    WebSys(#[from] wasm::WebSysError),

    /// An error from `serde_json` (only with the `json` feature enabled).
    #[error("serde_json error")]
    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    Json(#[from] serde_json::Error),
}
96
97pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
105 builder_http1_only(Client::builder())
106 .build()?
107 .get(url)
108 .upgrade()
109 .send()
110 .await?
111 .into_websocket()
112 .await
113}
114
115#[inline]
116#[cfg(not(target_arch = "wasm32"))]
117fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
118 builder.http1_only()
119}
120
121#[inline]
122#[cfg(target_arch = "wasm32")]
123fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
124 builder
125}
126
127pub trait RequestBuilderExt {
129 fn upgrade(self) -> UpgradedRequestBuilder;
134}
135
136impl RequestBuilderExt for RequestBuilder {
137 fn upgrade(self) -> UpgradedRequestBuilder {
138 UpgradedRequestBuilder::new(self)
139 }
140}
141
142pub struct UpgradedRequestBuilder {
145 inner: RequestBuilder,
146 protocols: Vec<String>,
147 #[cfg(not(target_arch = "wasm32"))]
148 web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
149}
150
151impl UpgradedRequestBuilder {
152 pub(crate) fn new(inner: RequestBuilder) -> Self {
153 Self {
154 inner,
155 protocols: vec![],
156 #[cfg(not(target_arch = "wasm32"))]
157 web_socket_config: None,
158 }
159 }
160
161 pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
163 self.protocols = protocols.into_iter().map(Into::into).collect();
164
165 self
166 }
167
168 #[cfg(not(target_arch = "wasm32"))]
170 pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
171 self.web_socket_config = Some(config);
172 self
173 }
174
175 pub async fn send(self) -> Result<UpgradeResponse, Error> {
177 #[cfg(not(target_arch = "wasm32"))]
178 let inner = native::send_request(self.inner, &self.protocols).await?;
179
180 #[cfg(target_arch = "wasm32")]
181 let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
182
183 Ok(UpgradeResponse {
184 inner,
185 protocols: self.protocols,
186 #[cfg(not(target_arch = "wasm32"))]
187 web_socket_config: self.web_socket_config,
188 })
189 }
190}
191
/// The server's response to a WebSocket upgrade request.
///
/// Call [`UpgradeResponse::into_websocket`] to obtain the [`WebSocket`]. On
/// native targets the HTTP response can also be inspected (it derefs to
/// [`reqwest::Response`]).
pub struct UpgradeResponse {
    // Native: the HTTP response awaiting handshake completion.
    #[cfg(not(target_arch = "wasm32"))]
    inner: native::WebSocketResponse,

    // Wasm: the already-opened browser WebSocket stream.
    #[cfg(target_arch = "wasm32")]
    inner: wasm::WebSysWebSocketStream,

    // Subprotocols that were requested. Only read on native (by
    // `into_websocket`), hence the allow for wasm builds.
    #[allow(dead_code)]
    protocols: Vec<String>,

    // Tungstenite configuration forwarded to `into_websocket`.
    // NOTE(review): this field only exists on native, where `into_websocket`
    // reads it, so this allow looks redundant — confirm before removing.
    #[cfg(not(target_arch = "wasm32"))]
    #[allow(dead_code)]
    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
}
210
211#[cfg(not(target_arch = "wasm32"))]
212impl std::ops::Deref for UpgradeResponse {
213 type Target = reqwest::Response;
214
215 fn deref(&self) -> &Self::Target {
216 &self.inner.response
217 }
218}
219
220impl UpgradeResponse {
221 pub async fn into_websocket(self) -> Result<WebSocket, Error> {
224 #[cfg(not(target_arch = "wasm32"))]
225 let (inner, protocol) = self
226 .inner
227 .into_stream_and_protocol(self.protocols, self.web_socket_config)
228 .await?;
229
230 #[cfg(target_arch = "wasm32")]
231 let (inner, protocol) = {
232 let protocol = self.inner.protocol();
233 (self.inner, Some(protocol))
234 };
235
236 Ok(WebSocket { inner, protocol })
237 }
238
239 #[must_use]
241 #[cfg(not(target_arch = "wasm32"))]
242 pub fn into_inner(self) -> reqwest::Response {
243 self.inner.response
244 }
245}
246
/// An open WebSocket connection.
///
/// Incoming messages are read through its [`Stream`] implementation (items are
/// `Result<Message, Error>`); outgoing messages are written through its
/// [`Sink`]`<Message>` implementation.
#[derive(Debug)]
pub struct WebSocket {
    // Native transport stream; `close` sends tungstenite close frames on it.
    #[cfg(not(target_arch = "wasm32"))]
    inner: native::WebSocketStream,

    // Browser-backed stream on wasm.
    #[cfg(target_arch = "wasm32")]
    inner: wasm::WebSysWebSocketStream,

    // Subprotocol negotiated during the handshake; exposed via `protocol()`.
    protocol: Option<String>,
}
259
260impl WebSocket {
261 pub fn protocol(&self) -> Option<&str> {
263 self.protocol.as_deref()
264 }
265
266 pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
274 #[cfg(not(target_arch = "wasm32"))]
275 {
276 let mut inner = self.inner;
277 inner
278 .close(Some(tungstenite::protocol::CloseFrame {
279 code: code.into(),
280 reason: reason.unwrap_or_default().into(),
281 }))
282 .await?;
283 }
284
285 #[cfg(target_arch = "wasm32")]
286 self.inner.close(code.into(), reason.unwrap_or_default())?;
287
288 Ok(())
289 }
290}
291
292impl Stream for WebSocket {
293 type Item = Result<Message, Error>;
294
295 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296 match ready!(self.inner.poll_next_unpin(cx)) {
297 None => Poll::Ready(None),
298 Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
299 Some(Ok(message)) => {
300 match message.try_into() {
301 Ok(message) => Poll::Ready(Some(Ok(message))),
302
303 #[cfg(target_arch = "wasm32")]
304 Err(e) => match e {},
305
306 #[cfg(not(target_arch = "wasm32"))]
307 Err(e) => {
308 panic!("Received an invalid frame: {e}");
310 }
311 }
312 }
313 }
314 }
315}
316
317impl Sink<Message> for WebSocket {
318 type Error = Error;
319
320 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321 self.inner.poll_ready_unpin(cx).map_err(Into::into)
322 }
323
324 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
325 self.inner.start_send_unpin(item.into()).map_err(Into::into)
326 }
327
328 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
329 self.inner.poll_flush_unpin(cx).map_err(Into::into)
330 }
331
332 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
333 self.inner.poll_close_unpin(cx).map_err(Into::into)
334 }
335}
336
#[cfg(test)]
mod tests {
    // Most tests below talk to the public echo server at echo.websocket.org
    // and therefore need network access.

    use futures_util::{SinkExt, TryStreamExt};
    use reqwest::Client;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

    /// Sends a text message and waits until the same text comes back.
    ///
    /// Panics if the stream errors or ends before the echo arrives.
    async fn test_websocket(mut websocket: WebSocket) {
        let text = "Hello, World!";
        websocket
            .send(Message::Text(text.to_owned()))
            .await
            .unwrap();

        while let Some(message) = websocket.try_next().await.unwrap() {
            // Skip anything that is not the echoed text.
            if matches!(&message, Message::Text(s) if s == text) {
                return;
            }
        }

        panic!("didn't receive text back");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_request_builder() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_shorthand() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_with_ws_scheme() {
        let websocket = websocket("wss://echo.websocket.org/").await.unwrap();

        test_websocket(websocket).await;
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_close() {
        let websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .close(CloseCode::Normal, Some("test"))
            .await
            .expect("close returned an error");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_send_close_frame() {
        let mut websocket = websocket("https://echo.websocket.org/").await.unwrap();
        websocket
            .send(Message::Close {
                code: CloseCode::Normal,
                reason: "Can you please reply with a close frame?".into(),
            })
            .await
            .unwrap();

        let mut close_received = false;
        while let Some(message) = websocket.try_next().await.unwrap() {
            if let Message::Close { code, .. } = message {
                assert_eq!(code, CloseCode::Normal);
                close_received = true;
            }
        }

        assert!(close_received, "No close frame was received");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    #[ignore = "https://echo.websocket.org/ ignores subprotocols"]
    async fn test_with_subprotocol() {
        let websocket = Client::default()
            .get("https://echo.websocket.org/")
            .upgrade()
            .protocols(["chat"])
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        assert_eq!(websocket.protocol(), Some("chat"));

        test_websocket(websocket).await;
    }

    #[test]
    fn close_code_from_u16() {
        let byte = 1008u16;
        assert_eq!(CloseCode::from(byte), CloseCode::Policy);
    }

    #[test]
    fn close_code_into_u16() {
        let text = CloseCode::Away;
        let byte: u16 = text.into();
        assert_eq!(byte, 1001u16);
        assert_eq!(u16::from(text), 1001u16);
    }
}