Skip to main content

reqwest_websocket/
middleware.rs

1use crate::{Client, Error, RequestBuilder};
2
3impl Client for reqwest_middleware::ClientWithMiddleware {
4    async fn execute(&self, request: reqwest::Request) -> Result<reqwest::Response, Error> {
5        self.execute(request).await.map_err(Into::into)
6    }
7}
8
9impl RequestBuilder for reqwest_middleware::RequestBuilder {
10    type Client = reqwest_middleware::ClientWithMiddleware;
11
12    fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>) {
13        let (client, request) = reqwest_middleware::RequestBuilder::build_split(self);
14        (client, request.map_err(Into::into))
15    }
16}
17
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use crate::{
        tests::{test_websocket, TestServer},
        Upgrade,
    };
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    /// Middleware that records that it was invoked, then forwards the request unchanged.
    #[derive(Debug)]
    struct TestMiddleware {
        /// Set to `true` by `handle`; shared with the test body so it can be checked afterwards.
        did_run: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl reqwest_middleware::Middleware for TestMiddleware {
        async fn handle(
            &self,
            req: reqwest::Request,
            extensions: &mut http::Extensions,
            next: reqwest_middleware::Next<'_>,
        ) -> Result<reqwest::Response, reqwest_middleware::Error> {
            self.did_run.store(true, Ordering::SeqCst);
            next.run(req, extensions).await
        }
    }

    /// Upgrading a request built from a `ClientWithMiddleware` must run the registered
    /// middleware and still yield a working WebSocket.
    // Native-only: the whole module is gated on `not(target_arch = "wasm32")`, so a plain
    // `#[tokio::test]` suffices (no `wasm_bindgen_test` variant).
    #[tokio::test]
    async fn websocket_with_middleware() {
        let echo = TestServer::new().await;

        let did_run = Arc::new(AtomicBool::new(false));
        let middleware = TestMiddleware {
            did_run: Arc::clone(&did_run),
        };

        // NOTE(review): presumably HTTP/1.1-only because the upgrade handshake relies on the
        // HTTP/1.1 `Upgrade` mechanism — confirm.
        let client = reqwest::Client::builder().http1_only().build().unwrap();
        let client = reqwest_middleware::ClientBuilder::new(client)
            .with(middleware)
            .build();

        let websocket = client
            .get(echo.http_url())
            .upgrade()
            .send()
            .await
            .unwrap()
            .into_websocket()
            .await
            .unwrap();

        test_websocket(websocket).await;

        assert!(
            did_run.load(Ordering::SeqCst),
            "middleware was not invoked for the upgrade request"
        );
    }
}