stripe/client/base/
tokio.rs1use std::future::{self, Future};
2use std::pin::Pin;
3
4use http_types::{Request, StatusCode};
5use hyper::http;
6use hyper::{client::HttpConnector, Body};
7use serde::de::DeserializeOwned;
8use tokio::time::sleep;
9
10use crate::client::request_strategy::{Outcome, RequestStrategy};
11use crate::error::{ErrorResponse, StripeError};
12
#[cfg(feature = "hyper-rustls-native")]
mod connector {
    //! HTTPS connector built on rustls, trusting the platform's native root certificates.

    use hyper::client::connect::dns::GaiResolver;
    use hyper::client::HttpConnector;
    use hyper_rustls::HttpsConnectorBuilder;

    pub use hyper_rustls::HttpsConnector;

    /// Builds the connector used by `TokioClient`.
    ///
    /// The connector accepts both `https://` and plain `http://` URLs and may negotiate
    /// HTTP/1 or HTTP/2.
    pub fn create() -> HttpsConnector<HttpConnector<GaiResolver>> {
        let rooted = HttpsConnectorBuilder::new().with_native_roots();
        let schemes = rooted.https_or_http();
        schemes.enable_http1().enable_http2().build()
    }
}
28
#[cfg(feature = "hyper-rustls-webpki")]
mod connector {
    //! HTTPS connector built on rustls, trusting the bundled `webpki-roots` certificates.

    use hyper::client::connect::dns::GaiResolver;
    use hyper::client::HttpConnector;
    use hyper_rustls::HttpsConnectorBuilder;

    pub use hyper_rustls::HttpsConnector;

    /// Builds the connector used by `TokioClient`.
    ///
    /// The connector accepts both `https://` and plain `http://` URLs and may negotiate
    /// HTTP/1 or HTTP/2.
    pub fn create() -> HttpsConnector<HttpConnector<GaiResolver>> {
        let rooted = HttpsConnectorBuilder::new().with_webpki_roots();
        let schemes = rooted.https_or_http();
        schemes.enable_http1().enable_http2().build()
    }
}
44
#[cfg(feature = "hyper-tls")]
mod connector {
    //! HTTPS connector built on `hyper-tls` (native-tls).

    use hyper::client::connect::dns::GaiResolver;
    use hyper::client::HttpConnector;

    pub use hyper_tls::HttpsConnector;

    /// Builds the connector used by `TokioClient`, with hyper-tls' default settings.
    pub fn create() -> HttpsConnector<HttpConnector<GaiResolver>> {
        HttpsConnector::<HttpConnector<GaiResolver>>::new()
    }
}
54
55#[cfg(all(feature = "hyper-tls", feature = "hyper-rustls"))]
56compile_error!("You must enable only one TLS implementation");
57
58type HttpClient = hyper::Client<connector::HttpsConnector<HttpConnector>, Body>;
59
60pub type Response<T> = Pin<Box<dyn Future<Output = Result<T, StripeError>> + Send>>;
61
62#[allow(dead_code)]
63#[inline(always)]
64pub(crate) fn ok<T: Send + 'static>(ok: T) -> Response<T> {
65 Box::pin(future::ready(Ok(ok)))
66}
67
68#[allow(dead_code)]
69#[inline(always)]
70pub(crate) fn err<T: Send + 'static>(err: StripeError) -> Response<T> {
71 Box::pin(future::ready(Err(err)))
72}
73
74#[derive(Clone)]
75pub struct TokioClient {
76 client: HttpClient,
77}
78
79impl Default for TokioClient {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl TokioClient {
86 pub fn new() -> Self {
87 Self {
88 client: hyper::Client::builder().pool_max_idle_per_host(0).build(connector::create()),
89 }
90 }
91
92 pub fn execute<T: DeserializeOwned + Send + 'static>(
93 &self,
94 request: Request,
95 strategy: &RequestStrategy,
96 ) -> Response<T> {
97 let client = self.client.clone();
100 let strategy = strategy.clone();
101
102 Box::pin(async move {
103 let bytes = send_inner(&client, request, &strategy).await?;
104 let json_deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
105 serde_path_to_error::deserialize(json_deserializer).map_err(StripeError::from)
106 })
107 }
108}
109
110async fn send_inner(
111 client: &HttpClient,
112 mut request: Request,
113 strategy: &RequestStrategy,
114) -> Result<hyper::body::Bytes, StripeError> {
115 let mut tries = 0;
116 let mut last_status: Option<StatusCode> = None;
117 let mut last_retry_header: Option<bool> = None;
118
119 let mut last_error = StripeError::ClientError("Invalid strategy".to_string());
121
122 if let Some(key) = strategy.get_key() {
123 request.insert_header("Idempotency-Key", key);
124 }
125
126 let body = request.body_bytes().await?;
127
128 loop {
129 return match strategy.test(last_status, last_retry_header, tries) {
130 Outcome::Stop => Err(last_error),
131 Outcome::Continue(duration) => {
132 if let Some(duration) = duration {
133 sleep(duration).await;
134 }
135
136 let mut request = request.clone();
140 request.set_body(body.clone());
141
142 let response = match client.request(convert_request(request).await).await {
143 Ok(response) => response,
144 Err(err) => {
145 last_error = StripeError::from(err);
146 tries += 1;
147 continue;
148 }
149 };
150
151 let status = response.status();
152 let retry = response
153 .headers()
154 .get("Stripe-Should-Retry")
155 .and_then(|s| s.to_str().ok())
156 .and_then(|s| s.parse().ok());
157
158 let bytes = hyper::body::to_bytes(response.into_body()).await?;
159
160 if !status.is_success() {
161 tries += 1;
162 let json_deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
163 last_error = serde_path_to_error::deserialize(json_deserializer)
164 .map(|mut e: ErrorResponse| {
165 e.error.http_status = status.into();
166 StripeError::from(e.error)
167 })
168 .unwrap_or_else(StripeError::from);
169 last_status = Some(
170 StatusCode::try_from(u16::from(status))
173 .unwrap_or(StatusCode::InternalServerError),
174 );
175 last_retry_header = retry;
176 continue;
177 }
178
179 Ok(bytes)
180 }
181 };
182 }
183}
184
185async fn convert_request(mut request: http_types::Request) -> http::Request<hyper::Body> {
190 let body = request.body_bytes().await.expect("We know the data is a valid bytes object.");
191 let request: http::Request<_> = request.into();
192 http::Request::from_parts(request.into_parts().0, hyper::Body::from(body))
193}
194
#[cfg(test)]
mod tests {
    use http_types::{Method, Request, Url};
    use httpmock::prelude::*;
    use hyper::{body::to_bytes, Body, Request as HyperRequest};

    use super::convert_request;
    use super::TokioClient;
    use crate::client::request_strategy::RequestStrategy;
    use crate::StripeError;

    const TEST_URL: &str = "https://api.stripe.com/v1/";

    #[tokio::test]
    async fn basic_conversion() {
        req_equal(
            convert_request(Request::new(Method::Get, TEST_URL)).await,
            HyperRequest::builder().method("GET").uri(TEST_URL).body(Body::empty()).unwrap(),
        )
        .await;
    }

    #[tokio::test]
    async fn bytes_body_conversion() {
        let body = "test".as_bytes();

        let mut req = Request::new(Method::Post, TEST_URL);
        req.set_body(body);

        req_equal(
            convert_request(req).await,
            HyperRequest::builder().method("POST").uri(TEST_URL).body(Body::from(body)).unwrap(),
        )
        .await;
    }

    #[tokio::test]
    async fn string_body_conversion() {
        let body = "test";

        let mut req = Request::new(Method::Post, TEST_URL);
        req.set_body(body);

        req_equal(
            convert_request(req).await,
            HyperRequest::builder().method("POST").uri(TEST_URL).body(Body::from(body)).unwrap(),
        )
        .await;
    }

    /// Asserts that two hyper requests have the same method, URI and body bytes.
    ///
    /// Headers are not compared. The expected requests are built without any, while the
    /// converted request may carry headers added by `http_types`.
    async fn req_equal(a: HyperRequest<Body>, b: HyperRequest<Body>) {
        let (a_parts, a_body) = a.into_parts();
        let (b_parts, b_body) = b.into_parts();

        assert_eq!(a_parts.method, b_parts.method);
        assert_eq!(a_parts.uri, b_parts.uri);
        assert_eq!(to_bytes(a_body).await.unwrap(), to_bytes(b_body).await.unwrap());
    }

    #[tokio::test]
    async fn retry() {
        let client = TokioClient::new();
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/server-errors");
            then.status(500);
        });

        let req = Request::get(Url::parse(&server.url("/server-errors")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        // A persistently failing endpoint is hit once per allowed attempt.
        mock.assert_hits_async(5).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn user_error() {
        let client = TokioClient::new();
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/v1/missing");
            then.status(404).body(
                r#"{
    "error": {
        "message": "Unrecognized request URL (GET: /v1/missing). Please see https://stripe.com/docs or we can help at https://support.stripe.com/.",
        "type": "invalid_request_error"
    }
}
"#,
            );
        });

        let req = Request::get(Url::parse(&server.url("/v1/missing")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(3)).await;

        // A 404 must not be retried, even though the strategy would allow it.
        mock.assert_hits_async(1).await;

        match res {
            Err(StripeError::Stripe(x)) => println!("{:?}", x),
            _ => panic!("Expected stripe error {:?}", res),
        }
    }

    #[tokio::test]
    async fn nice_serde_error() {
        use serde::Deserialize;

        // The fields are never read; they only describe the JSON shape being decoded.
        #[derive(Debug, Deserialize)]
        struct DataType {
            #[allow(dead_code)]
            id: String,
            #[allow(dead_code)]
            name: String,
        }

        let client = TokioClient::new();
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/v1/odd_data");
            then.status(200).body(
                r#"{
    "id": "test",
    "name": 10
}
"#,
            );
        });

        let req = Request::get(Url::parse(&server.url("/v1/odd_data")).unwrap());
        let res = client.execute::<DataType>(req, &RequestStrategy::Retry(3)).await;

        mock.assert_hits_async(1).await;

        match res {
            // `name` holds a number instead of a string, so the error must point at it.
            Err(StripeError::JSONSerialize(err)) => assert_eq!(err.path().to_string(), "name"),
            _ => panic!("Expected JSON deserialization error {:?}", res),
        }
    }

    #[tokio::test]
    async fn retry_header() {
        let client = TokioClient::new();
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.method(GET).path("/server-errors");
            then.status(500).header("Stripe-Should-Retry", "false");
        });

        let req = Request::get(Url::parse(&server.url("/server-errors")).unwrap());
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        // `Stripe-Should-Retry: false` must stop retries after the first attempt.
        mock.assert_hits_async(1).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn retry_body() {
        let client = TokioClient::new();
        let server = MockServer::start_async().await;

        // The mock only matches when the body is present, so every retry must resend it.
        let mock = server.mock(|when, then| {
            when.method(POST).path("/server-errors").body("body");
            then.status(500);
        });

        let mut req = Request::post(Url::parse(&server.url("/server-errors")).unwrap());
        req.set_body("body");
        let res = client.execute::<()>(req, &RequestStrategy::Retry(5)).await;

        mock.assert_hits_async(5).await;
        assert!(res.is_err());
    }
}