Skip to main content

http_request_derive_client_reqwest/
reqwest_client.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: MIT OR Apache-2.0
4
5use std::time::SystemTime;
6
7use bytes::Bytes;
8use http_request_derive::HttpRequest;
9use http_request_derive_client::Client;
10use http_request_derive_logging::HttpLogger;
11use snafu::ResultExt as _;
12use url::Url;
13
14use crate::{
15    ReqwestClientError,
16    reqwest_client_error::{
17        BuildHttpResponseBodySnafu, ConvertToHttpRequestSnafu, ConvertToReqwestRequestSnafu,
18        ReadResponseSnafu, RequestExecutionSnafu, RetrieveResponseBodySnafu,
19    },
20};
21
22/// A client for executing requests as defined by [`http_request_derive::HttpRequest`] implementations.
23#[derive(Debug, Clone)]
24pub struct ReqwestClient {
25    client: reqwest::Client,
26    base_url: Url,
27    logger: Option<HttpLogger>,
28}
29
30impl ReqwestClient {
31    /// Create a new [`ReqwestClient`] from a base [`Url`].
32    pub fn new(base_url: Url) -> Self {
33        Self {
34            client: reqwest::Client::new(),
35            base_url,
36            logger: None,
37        }
38    }
39
40    /// Get the logger for dumping information about the HTTP communication if it is set.
41    pub fn logger(&self) -> Option<HttpLogger> {
42        self.logger.clone()
43    }
44
45    /// Set a logger for dumping information about the HTTP communication.
46    pub fn set_logger(&mut self, logger: HttpLogger) {
47        self.logger = Some(logger);
48    }
49
50    /// Return the client with a new logger for dumping the HTTP communication.
51    pub fn with_logger(mut self, logger: HttpLogger) -> Self {
52        self.logger = Some(logger);
53        self
54    }
55
56    /// Returns the base URL which is used for subsequent requests.
57    pub fn base_url(&self) -> &Url {
58        &self.base_url
59    }
60
61    /// Sets the base URL to the given URL.
62    pub fn set_base_url(&mut self, base_url: Url) {
63        self.base_url = base_url
64    }
65
66    /// Return the client with a new base URL.
67    pub fn with_base_url(mut self, base_url: Url) -> Self {
68        self.base_url = base_url;
69        self
70    }
71
72    async fn execute_http_request(
73        &self,
74        request: http::Request<Vec<u8>>,
75    ) -> Result<http::Response<Bytes>, ReqwestClientError> {
76        let request = reqwest::Request::try_from(request).context(ConvertToReqwestRequestSnafu)?;
77        let response = self
78            .client
79            .execute(request)
80            .await
81            .context(RequestExecutionSnafu)?;
82        let mut http_response = http::Response::builder().status(response.status());
83        #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
84        {
85            http_response = http_response.version(response.version());
86        }
87        if let Some(headers) = http_response.headers_mut() {
88            *headers = response.headers().clone();
89        }
90        let body = response.bytes().await.context(RetrieveResponseBodySnafu)?;
91        let http_response = http_response
92            .body(body)
93            .context(BuildHttpResponseBodySnafu)?;
94        Ok(http_response)
95    }
96
97    async fn execute_http_request_with_optional_logging(
98        &self,
99        request: http::Request<Vec<u8>>,
100    ) -> Result<http::Response<Bytes>, ReqwestClientError> {
101        if let Some(logger) = self.logger.as_ref() {
102            let start_time = SystemTime::now();
103            let response = self.execute_http_request(request.clone()).await;
104            logger
105                .log_request(start_time, &request, response.as_ref().ok())
106                .await;
107            return response;
108        }
109
110        self.execute_http_request(request).await
111    }
112}
113
114#[async_trait::async_trait(?Send)]
115impl Client for ReqwestClient {
116    type ClientError = ReqwestClientError;
117
118    async fn execute<R: HttpRequest + Send>(
119        &self,
120        request: R,
121    ) -> Result<R::Response, Self::ClientError> {
122        let request = request
123            .to_http_request(&self.base_url)
124            .context(ConvertToHttpRequestSnafu)?;
125        let http_response = self
126            .execute_http_request_with_optional_logging(request)
127            .await?;
128        R::read_response(http_response).context(ReadResponseSnafu)
129    }
130}
131
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use http_request_derive::HttpRequest;
    use http_request_derive_client::Client as _;
    use httptest::{
        Expectation, all_of,
        matchers::{json_decoded, request},
        responders::{json_encoded, status_code},
    };
    use pretty_assertions::{assert_eq, assert_matches};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use url::Url;

    use crate::ReqwestClient;

    /// Payload the mock server answers with in the successful cases.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct ResponseBody {
        name: String,
    }

    /// Parameterless GET request shared by all GET test cases.
    #[derive(HttpRequest)]
    #[http_request(method = "GET", response = ResponseBody, path = "/query/a/response")]
    struct QueryRequest;

    /// Full server path `QueryRequest` should hit once joined with the `/api` base URL.
    const QUERY_PATH: &str = "/api/query/a/response";

    /// The canonical successful response body.
    fn hello() -> ResponseBody {
        ResponseBody {
            name: "hello".to_string(),
        }
    }

    /// Initialize logging (ignoring repeated initialization) and start a mock server.
    fn start_server() -> httptest::Server {
        _ = pretty_env_logger::try_init();
        httptest::Server::run()
    }

    /// Build a client whose base URL is `/api` on the given mock server.
    fn client_for(server: &httptest::Server) -> ReqwestClient {
        let base: Url = server
            .url("/api")
            .to_string()
            .parse()
            .expect("must be a valid url");
        ReqwestClient::new(base)
    }

    #[tokio::test]
    async fn simple_successful_get_request() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", QUERY_PATH))
                .respond_with(json_encoded(json!(hello()))),
        );

        let received = client_for(&server)
            .execute(QueryRequest)
            .await
            .expect("valid response expected");

        assert_eq!(received, hello());
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_invalid_json_body() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", QUERY_PATH))
                .respond_with(http::Response::new(r#"{ "invalid": JSON }"#)),
        );

        let outcome = client_for(&server).execute(QueryRequest).await;

        assert_matches!(
            outcome,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::Json { source: _ }
            })
        );
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_internal_server_error() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", QUERY_PATH))
                .respond_with(status_code(StatusCode::INTERNAL_SERVER_ERROR.as_u16())),
        );

        let outcome = client_for(&server).execute(QueryRequest).await;

        assert_matches!(
            outcome,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::NonSuccessStatus {
                    status: StatusCode::INTERNAL_SERVER_ERROR,
                    body: _
                }
            })
        );
    }

    #[tokio::test]
    async fn simple_successful_post_request() {
        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
        struct RequestBody {
            resource: String,
        }

        #[derive(HttpRequest)]
        #[http_request(method = "POST", response = ResponseBody, path = "/post/a/resource")]
        struct Request {
            #[http_request(body)]
            body: RequestBody,
        }

        let server = start_server();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("POST", "/api/post/a/resource"),
                request::body(json_decoded(|sent: &RequestBody| {
                    sent == &RequestBody {
                        resource: "user".to_string(),
                    }
                })),
            ])
            .respond_with(json_encoded(json!(hello()))),
        );

        let outgoing = Request {
            body: RequestBody {
                resource: "user".to_string(),
            },
        };
        let received = client_for(&server)
            .execute(outgoing)
            .await
            .expect("valid response expected");

        assert_eq!(received, hello());
    }

    #[tokio::test]
    async fn get_base_url() {
        let first = Url::parse("http://localhost:9090/v1/api").expect("must be a valid url");
        let second = Url::parse("http://localhost:9090/v2/api").expect("must be a valid url");

        let mut client = ReqwestClient::new(first.clone());
        assert_eq!(client.base_url(), &first);

        client.set_base_url(second.clone());
        assert_eq!(client.base_url(), &second);
    }
}