1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![allow(clippy::result_large_err)]
5
6#[cfg(feature = "json")]
51mod json;
52#[cfg(feature = "middleware")]
53mod middleware;
54#[cfg(not(target_arch = "wasm32"))]
55mod native;
56mod protocol;
57#[cfg(target_arch = "wasm32")]
58mod wasm;
59
60use std::{
61 future::Future,
62 pin::Pin,
63 task::{ready, Context, Poll},
64};
65
66#[cfg(not(target_arch = "wasm32"))]
67#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
68pub use crate::native::HandshakeError;
69pub use crate::protocol::{CloseCode, Message};
70pub use bytes::Bytes;
71use futures_util::{Sink, SinkExt, Stream, StreamExt};
72use reqwest::IntoUrl;
73
74#[derive(Debug, thiserror::Error)]
76#[non_exhaustive]
77pub enum Error {
78 #[cfg(not(target_arch = "wasm32"))]
79 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
80 #[error("websocket upgrade failed")]
81 Handshake(#[from] HandshakeError),
82
83 #[error("reqwest error")]
84 Reqwest(#[from] reqwest::Error),
85
86 #[cfg(not(target_arch = "wasm32"))]
87 #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
88 #[error("tungstenite error")]
89 Tungstenite(#[from] tungstenite::Error),
90
91 #[cfg(target_arch = "wasm32")]
92 #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
93 #[error("web_sys error")]
94 WebSys(#[from] wasm::WebSysError),
95
96 #[cfg(feature = "json")]
98 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
99 #[error("serde_json error")]
100 Json(#[from] serde_json::Error),
101
102 #[cfg(feature = "middleware")]
103 #[error("reqwest_middleware error")]
104 ReqwestMiddleware(#[from] reqwest_middleware::Error),
105}
106
107pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
115 builder_http1_only(reqwest::Client::builder())
116 .build()?
117 .get(url)
118 .upgrade()
119 .send()
120 .await?
121 .into_websocket()
122 .await
123}
124
/// Non-wasm variant: restricts the client builder to HTTP/1.
// NOTE(review): presumably needed because the websocket upgrade is performed
// over HTTP/1.1 — confirm against the native handshake implementation.
#[inline]
#[cfg(not(target_arch = "wasm32"))]
fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
    builder.http1_only()
}
130
/// wasm32 variant: returns the client builder unchanged.
#[inline]
#[cfg(target_arch = "wasm32")]
fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
    builder
}
136
/// An HTTP client that can execute a [`reqwest::Request`].
///
/// This is the client half produced by [`RequestBuilder::build_split`].
pub trait Client {
    /// Executes `request`, resolving to the response or to an [`Error`].
    fn execute(
        &self,
        request: reqwest::Request,
    ) -> impl Future<Output = Result<reqwest::Response, Error>> + '_;
}
150
151impl Client for reqwest::Client {
152 async fn execute(&self, request: reqwest::Request) -> Result<reqwest::Response, Error> {
153 self.execute(request).await.map_err(Into::into)
154 }
155}
156
/// A request builder that can be split into a client and the built request.
///
/// Any type implementing this trait gets [`Upgrade`] through the blanket
/// implementation in this file.
pub trait RequestBuilder {
    /// The client type that is returned alongside the request.
    type Client: Client;

    /// Consumes the builder, returning the client and the result of building
    /// the request.
    fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>);
}
166
167impl RequestBuilder for reqwest::RequestBuilder {
168 type Client = reqwest::Client;
169
170 fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>) {
171 let (client, request) = reqwest::RequestBuilder::build_split(self);
172 (client, request.map_err(Into::into))
173 }
174}
175
/// Extension trait that turns a request builder into an [`Upgraded`] request.
pub trait Upgrade: Sized {
    /// Wraps `self` in an [`Upgraded`] so it can be sent as a websocket
    /// upgrade request.
    fn upgrade(self) -> Upgraded<Self>;
}
186
/// Blanket implementation: every [`RequestBuilder`] can be upgraded.
impl<R> Upgrade for R
where
    R: RequestBuilder,
{
    fn upgrade(self) -> Upgraded<Self> {
        Upgraded::new(self)
    }
}
195
/// A request builder wrapped for a websocket upgrade.
///
/// Created through [`Upgrade::upgrade`].
pub struct Upgraded<R> {
    /// The wrapped request builder.
    inner: R,
    /// Subprotocols passed along when the request is sent.
    protocols: Vec<String>,
    /// Optional tungstenite configuration (non-wasm targets only).
    #[cfg(not(target_arch = "wasm32"))]
    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
}
204
205impl<R> Upgraded<R>
206where
207 R: RequestBuilder,
208{
209 pub(crate) fn new(inner: R) -> Self {
210 Self {
211 inner,
212 protocols: vec![],
213 #[cfg(not(target_arch = "wasm32"))]
214 web_socket_config: None,
215 }
216 }
217
218 pub fn protocols<S: Into<String>>(mut self, protocols: impl IntoIterator<Item = S>) -> Self {
220 self.protocols = protocols.into_iter().map(Into::into).collect();
221
222 self
223 }
224
225 #[cfg(not(target_arch = "wasm32"))]
227 pub fn web_socket_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
228 self.web_socket_config = Some(config);
229 self
230 }
231
232 pub async fn send(self) -> Result<UpgradeResponse, Error> {
234 #[cfg(not(target_arch = "wasm32"))]
235 let inner = native::send_request(self.inner, &self.protocols).await?;
236
237 #[cfg(target_arch = "wasm32")]
238 let inner = {
239 let request = self.inner.build_split().1?;
240 wasm::WebSysWebSocketStream::new(request, &self.protocols).await?
241 };
242
243 Ok(UpgradeResponse {
244 inner,
245 protocols: self.protocols,
246 #[cfg(not(target_arch = "wasm32"))]
247 web_socket_config: self.web_socket_config,
248 })
249 }
250}
251
/// The result of sending an [`Upgraded`] request.
///
/// Convert it into a [`WebSocket`] with [`UpgradeResponse::into_websocket`].
pub struct UpgradeResponse {
    /// Native response wrapper (non-wasm targets).
    #[cfg(not(target_arch = "wasm32"))]
    inner: native::WebSocketResponse,

    /// The websocket stream (wasm32).
    #[cfg(target_arch = "wasm32")]
    inner: wasm::WebSysWebSocketStream,

    /// Subprotocols configured on the request; only the non-wasm branch of
    /// `into_websocket` reads them.
    #[allow(dead_code)]
    protocols: Vec<String>,

    /// Optional tungstenite configuration carried over from the request.
    #[cfg(not(target_arch = "wasm32"))]
    #[allow(dead_code)]
    web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
}
270
271#[cfg(not(target_arch = "wasm32"))]
272impl std::ops::Deref for UpgradeResponse {
273 type Target = reqwest::Response;
274
275 fn deref(&self) -> &Self::Target {
276 &self.inner.response
277 }
278}
279
280impl UpgradeResponse {
281 pub async fn into_websocket(self) -> Result<WebSocket, Error> {
284 #[cfg(not(target_arch = "wasm32"))]
285 let (inner, protocol) = self
286 .inner
287 .into_stream_and_protocol(self.protocols, self.web_socket_config)
288 .await?;
289
290 #[cfg(target_arch = "wasm32")]
291 let (inner, protocol) = {
292 let protocol = self.inner.protocol();
293 (self.inner, Some(protocol))
294 };
295
296 Ok(WebSocket { inner, protocol })
297 }
298
299 #[must_use]
301 #[cfg(not(target_arch = "wasm32"))]
302 pub fn into_inner(self) -> reqwest::Response {
303 self.inner.response
304 }
305}
306
/// A websocket connection.
///
/// Implements [`Stream`] for receiving and [`Sink`] for sending [`Message`]s.
#[derive(Debug)]
pub struct WebSocket {
    /// Native websocket stream (non-wasm targets).
    #[cfg(not(target_arch = "wasm32"))]
    inner: native::WebSocketStream,

    /// `web_sys` based websocket stream (wasm32).
    #[cfg(target_arch = "wasm32")]
    inner: wasm::WebSysWebSocketStream,

    /// The subprotocol reported for this connection, if any.
    protocol: Option<String>,
}
319
320impl WebSocket {
321 pub fn protocol(&self) -> Option<&str> {
323 self.protocol.as_deref()
324 }
325
326 pub async fn close(self, code: CloseCode, reason: Option<&str>) -> Result<(), Error> {
334 #[cfg(not(target_arch = "wasm32"))]
335 {
336 let mut inner = self.inner;
337 inner
338 .close(Some(tungstenite::protocol::CloseFrame {
339 code: code.into(),
340 reason: reason.unwrap_or_default().into(),
341 }))
342 .await?;
343 }
344
345 #[cfg(target_arch = "wasm32")]
346 self.inner.close(code.into(), reason.unwrap_or_default())?;
347
348 Ok(())
349 }
350}
351
352impl Stream for WebSocket {
353 type Item = Result<Message, Error>;
354
355 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
356 match ready!(self.inner.poll_next_unpin(cx)) {
357 None => Poll::Ready(None),
358 Some(Err(error)) => Poll::Ready(Some(Err(error.into()))),
359 Some(Ok(message)) => {
360 match message.try_into() {
361 Ok(message) => Poll::Ready(Some(Ok(message))),
362
363 #[cfg(target_arch = "wasm32")]
364 Err(e) => match e {},
365
366 #[cfg(not(target_arch = "wasm32"))]
367 Err(e) => {
368 panic!("Received an invalid frame: {e}");
370 }
371 }
372 }
373 }
374 }
375}
376
377impl Sink<Message> for WebSocket {
378 type Error = Error;
379
380 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381 self.inner.poll_ready_unpin(cx).map_err(Into::into)
382 }
383
384 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
385 self.inner.start_send_unpin(item.into()).map_err(Into::into)
386 }
387
388 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
389 self.inner.poll_flush_unpin(cx).map_err(Into::into)
390 }
391
392 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
393 self.inner.poll_close_unpin(cx).map_err(Into::into)
394 }
395}
396
397#[cfg(test)]
398mod tests {
399 use futures_util::{SinkExt, TryStreamExt};
400 use reqwest::Client;
401 #[cfg(target_arch = "wasm32")]
402 use wasm_bindgen_test::wasm_bindgen_test;
403
404 use crate::{websocket, CloseCode, Message, Upgrade, WebSocket};
405
406 #[cfg(target_arch = "wasm32")]
407 wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
408
/// Local echo server used by the tests on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub struct TestServer {
    /// Signals the spawned server to shut down; taken on drop.
    shutdown_sender: Option<tokio::sync::oneshot::Sender<()>>,
    /// `http://` URL of the server.
    http_url: String,
    /// `ws://` URL of the server.
    ws_url: String,
}
416
417 #[cfg(not(target_arch = "wasm32"))]
418 impl TestServer {
419 pub async fn new() -> Self {
420 async fn handle_connection(mut socket: axum::extract::ws::WebSocket) {
421 if let Some(protocol) = socket.protocol() {
422 if let Ok(protocol) = protocol.to_str() {
423 println!("server/protocol: {protocol:?}");
424 if let Err(error) = socket
425 .send(axum::extract::ws::Message::Text(
426 format!("protocol: {protocol}").into(),
427 ))
428 .await
429 {
430 eprintln!("server/send: {error}");
431 return;
432 }
433 } else {
434 println!("server/protocol: could not convert to utf-8");
435 }
436 }
437
438 while let Some(message) = socket.recv().await {
439 match message {
440 Ok(message) => match &message {
441 axum::extract::ws::Message::Text(_)
442 | axum::extract::ws::Message::Binary(_) => {
443 if let Err(error) = socket.send(message).await {
444 eprintln!("server/send: {error}");
445 break;
446 }
447 }
448 _ => {}
449 },
450 Err(error) => {
451 eprintln!("server/recv: {error}");
452 break;
453 }
454 }
455 }
456 }
457
458 let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel();
459 let listener = tokio::net::TcpListener::bind(("localhost", 0))
460 .await
461 .unwrap();
462 let port = listener.local_addr().unwrap().port();
463 let app = axum::Router::new().route(
464 "/",
465 axum::routing::any(|ws: axum::extract::ws::WebSocketUpgrade| async move {
466 ws.protocols(["chat"]).on_upgrade(handle_connection)
467 }),
468 );
469
470 let _join_handle = tokio::spawn(async move {
472 axum::serve(listener, app)
473 .with_graceful_shutdown(async move {
474 let _ = shutdown_receiver.await;
475 })
476 .await
477 .unwrap();
478 });
479 Self {
480 shutdown_sender: Some(shutdown_sender),
481 http_url: format!("http://localhost:{port}/"),
482 ws_url: format!("ws://localhost:{port}/"),
483 }
484 }
485
486 pub fn http_url(&self) -> &str {
487 &self.http_url
488 }
489
490 pub fn ws_url(&self) -> &str {
491 &self.ws_url
492 }
493 }
494
495 #[cfg(not(target_arch = "wasm32"))]
496 impl Drop for TestServer {
497 fn drop(&mut self) {
498 if let Some(shutdown_sender) = self.shutdown_sender.take() {
499 println!("Shutting down server");
500 let _ = shutdown_sender.send(());
501 }
502 }
503 }
504
/// wasm32 stand-in for the test server; it holds no state and its accessors
/// return fixed URLs.
#[cfg(target_arch = "wasm32")]
pub struct TestServer;
507
#[cfg(target_arch = "wasm32")]
impl TestServer {
    /// Creates the stand-in; no server is started.
    pub async fn new() -> Self {
        Self
    }

    /// Fixed `https` URL of the echo service.
    pub fn http_url(&self) -> &str {
        "https://echo.websocket.org/"
    }

    /// Fixed `wss` URL of the echo service.
    pub fn ws_url(&self) -> &str {
        "wss://echo.websocket.org/"
    }
}
522
523 pub async fn test_websocket(mut websocket: WebSocket) {
524 let text = "Hello, World!";
525 websocket.send(Message::Text(text.into())).await.unwrap();
526
527 while let Some(message) = websocket.try_next().await.unwrap() {
528 match message {
529 Message::Text(s) => {
530 if s == text {
531 return;
532 }
533 }
534 _ => {}
535 }
536 }
537
538 panic!("didn't receive text back");
539 }
540
541 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
542 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
543 async fn test_with_request_builder() {
544 let echo = TestServer::new().await;
545
546 let websocket = Client::default()
547 .get(echo.http_url())
548 .upgrade()
549 .send()
550 .await
551 .unwrap()
552 .into_websocket()
553 .await
554 .unwrap();
555
556 test_websocket(websocket).await;
557 }
558
559 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
560 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
561 async fn test_shorthand() {
562 let echo = TestServer::new().await;
563
564 let websocket = websocket(echo.http_url()).await.unwrap();
565 test_websocket(websocket).await;
566 }
567
568 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
569 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
570 async fn test_with_ws_scheme() {
571 let echo = TestServer::new().await;
572 let websocket = websocket(echo.ws_url()).await.unwrap();
573
574 test_websocket(websocket).await;
575 }
576
577 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
578 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
579 async fn test_close() {
580 let echo = TestServer::new().await;
581
582 let websocket = websocket(echo.http_url()).await.unwrap();
583 websocket
584 .close(CloseCode::Normal, Some("test"))
585 .await
586 .expect("close returned an error");
587 }
588
589 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
590 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
591 async fn test_send_close_frame() {
592 let echo = TestServer::new().await;
593
594 let mut websocket = websocket(echo.http_url()).await.unwrap();
595 websocket
596 .send(Message::Close {
597 code: CloseCode::Normal,
598 reason: "Can you please reply with a close frame?".into(),
599 })
600 .await
601 .unwrap();
602
603 let mut close_received = false;
604 while let Some(message) = websocket.try_next().await.unwrap() {
605 match message {
606 Message::Close { code, .. } => {
607 assert_eq!(code, CloseCode::Normal);
608 close_received = true;
609 }
610 _ => {}
611 }
612 }
613
614 assert!(close_received, "No close frame was received");
615 }
616
617 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
618 #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
619 #[cfg_attr(
620 target_arch = "wasm32",
621 ignore = "echo.websocket.org ignores subprotocols"
622 )]
623 async fn test_with_subprotocol() {
624 let echo = TestServer::new().await;
625
626 let mut websocket = Client::default()
627 .get(echo.http_url())
628 .upgrade()
629 .protocols(["chat"])
630 .send()
631 .await
632 .unwrap()
633 .into_websocket()
634 .await
635 .unwrap();
636
637 assert_eq!(websocket.protocol(), Some("chat"));
638
639 let message = websocket.try_next().await.unwrap().unwrap();
640 match message {
641 Message::Text(s) => {
642 assert_eq!(s, "protocol: chat");
643 }
644 _ => {
645 panic!("Expected text message with selected protocol");
646 }
647 }
648 }
649
/// `1008` converts into the policy close code.
#[test]
fn close_code_from_u16() {
    let converted = CloseCode::from(1008u16);
    assert_eq!(converted, CloseCode::Policy);
}
655
/// The away close code converts to `1001` through both `Into` and `From`.
#[test]
fn close_code_into_u16() {
    let away = CloseCode::Away;

    let via_into: u16 = away.into();
    assert_eq!(via_into, 1001u16);

    let via_from = u16::from(away);
    assert_eq!(via_from, 1001u16);
}
663}