http_request_derive_client_reqwest/
reqwest_client.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: MIT OR Apache-2.0
4
5use http_request_derive::HttpRequest;
6use http_request_derive_client::Client;
7use snafu::ResultExt as _;
8use url::Url;
9
10use crate::{
11    reqwest_client_error::{
12        BuildHttpResponseBodySnafu, ConvertToHttpRequestSnafu, ConvertToReqwestRequestSnafu,
13        ReadResponseSnafu, RequestExecutionSnafu, RetrieveResponseBodySnafu,
14    },
15    ReqwestClientError,
16};
17
18/// A client for executing requests as defined by [`http_request_derive::HttpRequest`] implementations.
19#[derive(Debug, Clone)]
20pub struct ReqwestClient {
21    client: reqwest::Client,
22    base_url: Url,
23}
24
25impl ReqwestClient {
26    /// Create a new [`ReqwestClient`] from a base [`Url`].
27    pub fn new(base_url: Url) -> Self {
28        Self {
29            client: reqwest::Client::new(),
30            base_url,
31        }
32    }
33
34    /// Returns the base URL which is used for subsequent requests.
35    pub fn base_url(&self) -> &Url {
36        &self.base_url
37    }
38
39    /// Sets the base URL to the given URL.
40    pub fn set_base_url(&mut self, base_url: Url) {
41        self.base_url = base_url
42    }
43}
44
45#[async_trait::async_trait]
46impl Client for ReqwestClient {
47    type ClientError = ReqwestClientError;
48
49    async fn execute<R: HttpRequest + Send>(
50        &self,
51        request: R,
52    ) -> Result<R::Response, Self::ClientError> {
53        let request = request
54            .to_http_request(&self.base_url)
55            .context(ConvertToHttpRequestSnafu)?;
56        let request = reqwest::Request::try_from(request).context(ConvertToReqwestRequestSnafu)?;
57        let response = self
58            .client
59            .execute(request)
60            .await
61            .context(RequestExecutionSnafu)?;
62        let mut http_response = http::Response::builder()
63            .status(response.status())
64            .version(response.version());
65        if let Some(headers) = http_response.headers_mut() {
66            *headers = response.headers().clone();
67        }
68        let body = response.bytes().await.context(RetrieveResponseBodySnafu)?;
69        let http_response = http_response
70            .body(body)
71            .context(BuildHttpResponseBodySnafu)?;
72        R::read_response(http_response).context(ReadResponseSnafu)
73    }
74}
75
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use http_request_derive::HttpRequest;
    use http_request_derive_client::Client as _;
    use httptest::{
        all_of,
        matchers::{json_decoded, request},
        responders::{json_encoded, status_code},
        Expectation,
    };
    use pretty_assertions::{assert_eq, assert_matches};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use url::Url;

    use crate::ReqwestClient;

    /// `GET` request shared by the read-only test cases.
    #[derive(HttpRequest)]
    #[http_request(method = "GET", response = ResponseBody, path = "/query/a/response")]
    struct GetRequest;

    /// JSON body returned by the mock server in the success cases.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct ResponseBody {
        name: String,
    }

    /// Initializes logging (ignoring repeated initialization) and starts a mock server.
    fn start_server() -> httptest::Server {
        _ = pretty_env_logger::try_init();
        httptest::Server::run()
    }

    /// Creates a [`ReqwestClient`] whose base URL is `/api` on the given mock server.
    fn client_for(server: &httptest::Server) -> ReqwestClient {
        let url = server
            .url("/api")
            .to_string()
            .parse()
            .expect("must be a valid url");
        ReqwestClient::new(url)
    }

    /// The response body used by the success cases.
    fn hello_body() -> ResponseBody {
        ResponseBody {
            name: "hello".to_string(),
        }
    }

    #[tokio::test]
    async fn simple_successful_get_request() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(json_encoded(json!(hello_body()))),
        );
        let client = client_for(&server);

        let response = client
            .execute(GetRequest)
            .await
            .expect("valid response expected");

        assert_eq!(response, hello_body());
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_invalid_json_body() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(http::Response::new(r#"{ "invalid": JSON }"#)),
        );
        let client = client_for(&server);

        let response = client.execute(GetRequest).await;

        assert_matches!(
            response,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::Json { source: _ }
            })
        );
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_internal_server_error() {
        let server = start_server();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(status_code(StatusCode::INTERNAL_SERVER_ERROR.as_u16())),
        );
        let client = client_for(&server);

        let response = client.execute(GetRequest).await;

        assert_matches!(
            response,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::NonSuccessStatus {
                    status: StatusCode::INTERNAL_SERVER_ERROR,
                    data: _
                }
            })
        );
    }

    #[tokio::test]
    async fn simple_successful_post_request() {
        #[derive(HttpRequest)]
        #[http_request(method = "POST", response = ResponseBody, path = "/post/a/resource")]
        struct PostRequest {
            #[http_request(body)]
            body: RequestBody,
        }

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
        struct RequestBody {
            resource: String,
        }

        let server = start_server();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("POST", "/api/post/a/resource"),
                request::body(json_decoded(|b: &RequestBody| {
                    b == &RequestBody {
                        resource: "user".to_string(),
                    }
                })),
            ])
            .respond_with(json_encoded(json!(hello_body()))),
        );
        let client = client_for(&server);

        let response = client
            .execute(PostRequest {
                body: RequestBody {
                    resource: "user".to_string(),
                },
            })
            .await
            .expect("valid response expected");

        assert_eq!(response, hello_body());
    }

    #[tokio::test]
    async fn get_base_url() {
        let url = Url::parse("http://localhost:9090/v1/api").expect("must be a valid url");
        let mut client = ReqwestClient::new(url.clone());

        assert_eq!(client.base_url(), &url);

        let new_url = Url::parse("http://localhost:9090/v2/api").expect("must be a valid url");
        client.set_base_url(new_url.clone());

        assert_eq!(client.base_url(), &new_url);
    }
}