http_request_derive_client_reqwest/
reqwest_client.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: MIT OR Apache-2.0
4
5use http_request_derive::HttpRequest;
6use http_request_derive_client::Client;
7use snafu::ResultExt as _;
8use url::Url;
9
10use crate::{
11    ReqwestClientError,
12    reqwest_client_error::{
13        BuildHttpResponseBodySnafu, ConvertToHttpRequestSnafu, ConvertToReqwestRequestSnafu,
14        ReadResponseSnafu, RequestExecutionSnafu, RetrieveResponseBodySnafu,
15    },
16};
17
18/// A client for executing requests as defined by [`http_request_derive::HttpRequest`] implementations.
19#[derive(Debug, Clone)]
20pub struct ReqwestClient {
21    client: reqwest::Client,
22    base_url: Url,
23}
24
25impl ReqwestClient {
26    /// Create a new [`ReqwestClient`] from a base [`Url`].
27    pub fn new(base_url: Url) -> Self {
28        Self {
29            client: reqwest::Client::new(),
30            base_url,
31        }
32    }
33
34    /// Returns the base URL which is used for subsequent requests.
35    pub fn base_url(&self) -> &Url {
36        &self.base_url
37    }
38
39    /// Sets the base URL to the given URL.
40    pub fn set_base_url(&mut self, base_url: Url) {
41        self.base_url = base_url
42    }
43}
44
45#[async_trait::async_trait]
46impl Client for ReqwestClient {
47    type ClientError = ReqwestClientError;
48
49    async fn execute<R: HttpRequest + Send>(
50        &self,
51        request: R,
52    ) -> Result<R::Response, Self::ClientError> {
53        let request = request
54            .to_http_request(&self.base_url)
55            .context(ConvertToHttpRequestSnafu)?;
56        let request = reqwest::Request::try_from(request).context(ConvertToReqwestRequestSnafu)?;
57        let response = self
58            .client
59            .execute(request)
60            .await
61            .context(RequestExecutionSnafu)?;
62        let mut http_response = http::Response::builder()
63            .status(response.status())
64            .version(response.version());
65        if let Some(headers) = http_response.headers_mut() {
66            *headers = response.headers().clone();
67        }
68        let body = response.bytes().await.context(RetrieveResponseBodySnafu)?;
69        let http_response = http_response
70            .body(body)
71            .context(BuildHttpResponseBodySnafu)?;
72        R::read_response(http_response).context(ReadResponseSnafu)
73    }
74}
75
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use http_request_derive::HttpRequest;
    use http_request_derive_client::Client as _;
    use httptest::{
        Expectation, all_of,
        matchers::{json_decoded, request},
        responders::{json_encoded, status_code},
    };
    use pretty_assertions::{assert_eq, assert_matches};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use url::Url;

    use crate::ReqwestClient;

    /// Response payload shared by the request tests below.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct ResponseBody {
        name: String,
    }

    /// The canonical successful response used by the tests.
    fn hello() -> ResponseBody {
        ResponseBody {
            name: "hello".to_string(),
        }
    }

    /// Prepares a mock server and a matching client.
    ///
    /// - Installs the test logger, ignoring the error if one is already installed.
    /// - Starts a mock server.
    /// - Creates a [`ReqwestClient`] whose base URL is the server's `/api` prefix.
    ///
    /// The server is returned too, so the caller keeps it alive for the whole test.
    fn setup() -> (httptest::Server, ReqwestClient) {
        _ = pretty_env_logger::try_init();
        let server = httptest::Server::run();
        let url: Url = server
            .url("/api")
            .to_string()
            .parse()
            .expect("must be a valid url");
        let client = ReqwestClient::new(url);
        (server, client)
    }

    #[tokio::test]
    async fn simple_successful_get_request() {
        #[derive(HttpRequest)]
        #[http_request(method = "GET", response = ResponseBody, path = "/query/a/response")]
        struct Request;

        let (server, client) = setup();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(json_encoded(json!(hello()))),
        );

        let response = client
            .execute(Request)
            .await
            .expect("valid response expected");

        assert_eq!(response, hello());
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_invalid_json_body() {
        #[derive(HttpRequest)]
        #[http_request(method = "GET", response = ResponseBody, path = "/query/a/response")]
        struct Request;

        let (server, client) = setup();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(http::Response::new(r#"{ "invalid": JSON }"#)),
        );

        let response = client.execute(Request).await;
        assert_matches!(
            response,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::Json { source: _ }
            })
        );
    }

    #[tokio::test]
    async fn simple_get_request_failing_with_internal_server_error() {
        #[derive(HttpRequest)]
        #[http_request(method = "GET", response = ResponseBody, path = "/query/a/response")]
        struct Request;

        let (server, client) = setup();
        server.expect(
            Expectation::matching(request::method_path("GET", "/api/query/a/response"))
                .respond_with(status_code(StatusCode::INTERNAL_SERVER_ERROR.as_u16())),
        );

        let response = client.execute(Request).await;
        assert_matches!(
            response,
            Err(crate::ReqwestClientError::ReadResponse {
                source: http_request_derive::Error::NonSuccessStatus {
                    status: StatusCode::INTERNAL_SERVER_ERROR,
                    body: _
                }
            })
        );
    }

    #[tokio::test]
    async fn simple_successful_post_request() {
        #[derive(HttpRequest)]
        #[http_request(method = "POST", response = ResponseBody, path = "/post/a/resource")]
        struct Request {
            #[http_request(body)]
            body: RequestBody,
        }

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
        struct RequestBody {
            resource: String,
        }

        let (server, client) = setup();
        server.expect(
            Expectation::matching(all_of![
                request::method_path("POST", "/api/post/a/resource"),
                request::body(json_decoded(|b: &RequestBody| {
                    b == &RequestBody {
                        resource: "user".to_string(),
                    }
                })),
            ])
            .respond_with(json_encoded(json!(hello()))),
        );

        let response = client
            .execute(Request {
                body: RequestBody {
                    resource: "user".to_string(),
                },
            })
            .await
            .expect("valid response expected");

        assert_eq!(response, hello());
    }

    // Purely synchronous: no request is sent, so no async runtime is needed.
    #[test]
    fn get_base_url() {
        let url = Url::parse("http://localhost:9090/v1/api").expect("must be a valid url");
        let mut client = ReqwestClient::new(url.clone());

        assert_eq!(client.base_url(), &url);

        let new_url = Url::parse("http://localhost:9090/v2/api").expect("must be a valid url");
        client.set_base_url(new_url.clone());

        assert_eq!(client.base_url(), &new_url);
    }
}