stripe/client/base/
tokio.rs

1use std::future::{self, Future};
2use std::pin::Pin;
3
4use http_types::{Request, StatusCode};
5use hyper::http;
6use hyper::{client::HttpConnector, Body};
7use serde::de::DeserializeOwned;
8use tokio::time::sleep;
9
10use crate::client::request_strategy::{Outcome, RequestStrategy};
11use crate::error::{ErrorResponse, StripeError};
12
/// TLS connector backed by `hyper-rustls`, configured with the native root store.
#[cfg(feature = "hyper-rustls-native")]
mod connector {
    use hyper::client::HttpConnector;
    pub use hyper_rustls::HttpsConnector;
    use hyper_rustls::HttpsConnectorBuilder;

    /// Builds the connector used by the hyper client.
    ///
    /// The builder is configured with native roots, accepts both `https` and
    /// `http` URLs, and enables HTTP/1 as well as HTTP/2.
    pub fn create() -> HttpsConnector<HttpConnector> {
        let with_roots = HttpsConnectorBuilder::new().with_native_roots();
        let with_schemes = with_roots.https_or_http();
        with_schemes.enable_http1().enable_http2().build()
    }
}
28
/// TLS connector backed by `hyper-rustls`, configured with the webpki root store.
#[cfg(feature = "hyper-rustls-webpki")]
mod connector {
    use hyper::client::HttpConnector;
    pub use hyper_rustls::HttpsConnector;
    use hyper_rustls::HttpsConnectorBuilder;

    /// Builds the connector used by the hyper client.
    ///
    /// The builder is configured with webpki roots, accepts both `https` and
    /// `http` URLs, and enables HTTP/1 as well as HTTP/2.
    pub fn create() -> HttpsConnector<HttpConnector> {
        let with_roots = HttpsConnectorBuilder::new().with_webpki_roots();
        let with_schemes = with_roots.https_or_http();
        with_schemes.enable_http1().enable_http2().build()
    }
}
44
/// TLS connector backed by `hyper-tls`.
#[cfg(feature = "hyper-tls")]
mod connector {
    use hyper::client::HttpConnector;
    pub use hyper_tls::HttpsConnector;

    /// Builds the connector used by the hyper client via `HttpsConnector::new()`.
    pub fn create() -> HttpsConnector<HttpConnector> {
        let connector = HttpsConnector::new();
        connector
    }
}
54
55#[cfg(all(feature = "hyper-tls", feature = "hyper-rustls"))]
56compile_error!("You must enable only one TLS implementation");
57
58type HttpClient = hyper::Client<connector::HttpsConnector<HttpConnector>, Body>;
59
60pub type Response<T> = Pin<Box<dyn Future<Output = Result<T, StripeError>> + Send>>;
61
62#[allow(dead_code)]
63#[inline(always)]
64pub(crate) fn ok<T: Send + 'static>(ok: T) -> Response<T> {
65    Box::pin(future::ready(Ok(ok)))
66}
67
68#[allow(dead_code)]
69#[inline(always)]
70pub(crate) fn err<T: Send + 'static>(err: StripeError) -> Response<T> {
71    Box::pin(future::ready(Err(err)))
72}
73
74#[derive(Clone)]
75pub struct TokioClient {
76    client: HttpClient,
77}
78
79impl Default for TokioClient {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl TokioClient {
86    pub fn new() -> Self {
87        Self {
88            client: hyper::Client::builder().pool_max_idle_per_host(0).build(connector::create()),
89        }
90    }
91
92    pub fn execute<T: DeserializeOwned + Send + 'static>(
93        &self,
94        request: Request,
95        strategy: &RequestStrategy,
96    ) -> Response<T> {
97        // need to clone here since client could be used across threads.
98        // N.B. Client is send sync; cloned clients share the same pool.
99        let client = self.client.clone();
100        let strategy = strategy.clone();
101
102        Box::pin(async move {
103            let bytes = send_inner(&client, request, &strategy).await?;
104            let json_deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
105            serde_path_to_error::deserialize(json_deserializer).map_err(StripeError::from)
106        })
107    }
108}
109
110async fn send_inner(
111    client: &HttpClient,
112    mut request: Request,
113    strategy: &RequestStrategy,
114) -> Result<hyper::body::Bytes, StripeError> {
115    let mut tries = 0;
116    let mut last_status: Option<StatusCode> = None;
117    let mut last_retry_header: Option<bool> = None;
118
119    // if we have no last error, then the strategy is invalid
120    let mut last_error = StripeError::ClientError("Invalid strategy".to_string());
121
122    if let Some(key) = strategy.get_key() {
123        request.insert_header("Idempotency-Key", key);
124    }
125
126    let body = request.body_bytes().await?;
127
128    loop {
129        return match strategy.test(last_status, last_retry_header, tries) {
130            Outcome::Stop => Err(last_error),
131            Outcome::Continue(duration) => {
132                if let Some(duration) = duration {
133                    sleep(duration).await;
134                }
135
136                // note: http::Request provides no easy way to clone, so we perform
137                //       the conversion from the clonable http_types::Request each time
138                //       obviously cloning before the first request is not ideal
139                let mut request = request.clone();
140                request.set_body(body.clone());
141
142                let response = match client.request(convert_request(request).await).await {
143                    Ok(response) => response,
144                    Err(err) => {
145                        last_error = StripeError::from(err);
146                        tries += 1;
147                        continue;
148                    }
149                };
150
151                let status = response.status();
152                let retry = response
153                    .headers()
154                    .get("Stripe-Should-Retry")
155                    .and_then(|s| s.to_str().ok())
156                    .and_then(|s| s.parse().ok());
157
158                let bytes = hyper::body::to_bytes(response.into_body()).await?;
159
160                if !status.is_success() {
161                    tries += 1;
162                    let json_deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
163                    last_error = serde_path_to_error::deserialize(json_deserializer)
164                        .map(|mut e: ErrorResponse| {
165                            e.error.http_status = status.into();
166                            StripeError::from(e.error)
167                        })
168                        .unwrap_or_else(StripeError::from);
169                    last_status = Some(
170                        // NOTE: StatusCode::from can panic here, so fall back to InternalServerError
171                        //       see https://github.com/http-rs/http-types/blob/ac5d645ce5294554b86ebd49233d3ec01665d1d7/src/hyperium_http.rs#L20-L24
172                        StatusCode::try_from(u16::from(status))
173                            .unwrap_or(StatusCode::InternalServerError),
174                    );
175                    last_retry_header = retry;
176                    continue;
177                }
178
179                Ok(bytes)
180            }
181        };
182    }
183}
184
185/// convert an http_types::Request with a http_types::Body into a http::Request<hyper::Body>
186///
187/// note: this is necesarry because `http` deliberately does not support a `Body` type
188///       so hyper has a `Body` for which http_types cannot provide automatic conversion.
189async fn convert_request(mut request: http_types::Request) -> http::Request<hyper::Body> {
190    let body = request.body_bytes().await.expect("We know the data is a valid bytes object.");
191    let request: http::Request<_> = request.into();
192    http::Request::from_parts(request.into_parts().0, hyper::Body::from(body))
193}
194
#[cfg(test)]
mod tests {
    use http_types::{Method, Request, Url};
    use httpmock::prelude::*;
    use hyper::{body::to_bytes, Body, Request as HyperRequest};

    use super::convert_request;
    use super::TokioClient;
    use crate::client::request_strategy::RequestStrategy;
    use crate::StripeError;

    const TEST_URL: &str = "https://api.stripe.com/v1/";

    /// A body-less GET converts to a hyper request with the same method and an empty body.
    #[tokio::test]
    async fn basic_conversion() {
        req_equal(
            convert_request(Request::new(Method::Get, TEST_URL)).await,
            HyperRequest::builder().method("GET").uri(TEST_URL).body(Body::empty()).unwrap(),
        )
        .await;
    }

    /// A byte-slice body survives the conversion.
    #[tokio::test]
    async fn bytes_body_conversion() {
        let body = "test".as_bytes();

        let mut req = Request::new(Method::Post, TEST_URL);
        req.set_body(body);

        req_equal(
            convert_request(req).await,
            HyperRequest::builder().method("POST").uri(TEST_URL).body(Body::from(body)).unwrap(),
        )
        .await;
    }

    /// A string body survives the conversion.
    #[tokio::test]
    async fn string_body_conversion() {
        let body = "test";

        let mut req = Request::new(Method::Post, TEST_URL);
        req.set_body(body);

        req_equal(
            convert_request(req).await,
            HyperRequest::builder().method("POST").uri(TEST_URL).body(Body::from(body)).unwrap(),
        )
        .await;
    }

    /// Asserts that two hyper requests have the same method and the same body bytes.
    async fn req_equal(a: HyperRequest<Body>, b: HyperRequest<Body>) {
        let (a_parts, a_body) = a.into_parts();
        let (b_parts, b_body) = b.into_parts();

        assert_eq!(a_parts.method, b_parts.method);
        // Compare the content itself; equal lengths alone would not catch a corrupted body.
        assert_eq!(to_bytes(a_body).await.unwrap(), to_bytes(b_body).await.unwrap());
    }

    /// With `Retry(5)`, a server that always answers 500 is hit five times.
    #[tokio::test]
    async fn retry() {
        let client = TokioClient::new();

        // Start a lightweight mock server.
        let server = MockServer::start_async().await;

        // Create a mock on the server.
        let hello_mock = server.mock(|when, then| {
            when.method(GET).path("/server-errors");
            then.status(500);
        });

        let req = Request::get(Url::parse(&server.url("/server-errors")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        hello_mock.assert_hits_async(5).await;
        assert!(res.is_err());
    }

    /// A 404 with a Stripe error body is not retried and surfaces as `StripeError::Stripe`.
    #[tokio::test]
    async fn user_error() {
        let client = TokioClient::new();

        // Start a lightweight mock server.
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/v1/missing");
            then.status(404).body("{
                \"error\": {
                  \"message\": \"Unrecognized request URL (GET: /v1/missing). Please see https://stripe.com/docs or we can help at https://support.stripe.com/.\",
                  \"type\": \"invalid_request_error\"
                }
              }
              ");
        });

        let req = Request::get(Url::parse(&server.url("/v1/missing")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(3)).await;

        mock.assert_hits_async(1).await;

        match res {
            Err(StripeError::Stripe(x)) => println!("{:?}", x),
            _ => panic!("Expected stripe error {:?}", res),
        }
    }

    /// A successful response with a mistyped field surfaces as `StripeError::JSONSerialize`.
    #[tokio::test]
    async fn nice_serde_error() {
        use serde::Deserialize;

        #[derive(Debug, Deserialize)]
        struct DataType {
            // Allowing dead code since used for deserialization
            #[allow(dead_code)]
            id: String,
            #[allow(dead_code)]
            name: String,
        }

        let client = TokioClient::new();

        // Start a lightweight mock server.
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/v1/odd_data");
            then.status(200).body(
                "{
                \"id\": \"test\",
                \"name\": 10
              }
              ",
            );
        });

        let req = Request::get(Url::parse(&server.url("/v1/odd_data")).unwrap());
        let res = client.execute::<DataType>(req, &RequestStrategy::Retry(3)).await;

        mock.assert_hits_async(1).await;

        match res {
            Err(StripeError::JSONSerialize(err)) => {
                println!("Error: {:?} Path: {:?}", err.inner(), err.path().to_string())
            }
            _ => panic!("Expected stripe error {:?}", res),
        }
    }

    /// `Stripe-Should-Retry: false` stops retrying after the first attempt.
    #[tokio::test]
    async fn retry_header() {
        let client = TokioClient::new();

        // Start a lightweight mock server.
        let server = MockServer::start_async().await;

        // Create a mock on the server.
        let hello_mock = server.mock(|when, then| {
            when.method(GET).path("/server-errors");
            then.status(500).header("Stripe-Should-Retry", "false");
        });

        let req = Request::get(Url::parse(&server.url("/server-errors")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        hello_mock.assert_hits_async(1).await;
        assert!(res.is_err());
    }

    /// The request body is re-sent on every retry (the mock only matches when it is present).
    #[tokio::test]
    async fn retry_body() {
        let client = TokioClient::new();

        // Start a lightweight mock server.
        let server = MockServer::start_async().await;

        // Create a mock on the server.
        let hello_mock = server.mock(|when, then| {
            when.method(POST).path("/server-errors").body("body");
            then.status(500);
        });

        let mut req = Request::post(Url::parse(&server.url("/server-errors")).unwrap());
        req.set_body("body");
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        hello_mock.assert_hits_async(5).await;
        assert!(res.is_err());
    }
}
388}