stripe/hyper/
client.rs

1use std::fmt::Write as _;
2
3use http_body_util::{BodyExt, Full};
4use hyper::body::Bytes;
5use hyper::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
6use hyper::http::request::Builder;
7use hyper::http::{HeaderName, HeaderValue};
8use hyper::{Request, StatusCode};
9use hyper_util::client::legacy::Client as HyperClient;
10use hyper_util::rt::TokioExecutor;
11use miniserde::json::from_str;
12use stripe_client_core::{CustomizedStripeRequest, RequestBuilder, StripeMethod};
13use stripe_client_core::{Outcome, RequestStrategy};
14use stripe_shared::AccountId;
15
16use crate::StripeError;
17use crate::hyper::client_builder::{ClientBuilder, ClientConfig};
18
19/// A client for making Stripe API requests.
20#[derive(Clone, Debug)]
21pub struct Client {
22    client: HyperClient<crate::hyper::connector::Connector, Full<Bytes>>,
23    config: ClientConfig,
24}
25
26impl Client {
27    pub(crate) fn from_config(config: ClientConfig) -> Self {
28        Self {
29            client: HyperClient::builder(TokioExecutor::new())
30                .pool_max_idle_per_host(0)
31                .build(crate::hyper::connector::create()),
32            config,
33        }
34    }
35
36    /// Construct a `client` with the given secret key and a default configuration.
37    ///
38    /// # Panics
39    /// This method panics if secret key is not usable as a header value.
40    pub fn new(secret_key: impl Into<String>) -> Self {
41        ClientBuilder::new(secret_key).build().expect("invalid secret provided")
42    }
43
44    fn get_account_id_header(
45        &self,
46        account_id_override: Option<AccountId>,
47    ) -> Result<Option<HeaderValue>, StripeError> {
48        if let Some(overridden) = account_id_override {
49            return Ok(Some(HeaderValue::from_str(overridden.as_str()).map_err(|_| {
50                StripeError::ConfigError("invalid account id set in customizations".into())
51            })?));
52        }
53        Ok(self.config.account_id.clone())
54    }
55
56    fn construct_request(
57        &self,
58        req: RequestBuilder,
59        account_id: Option<AccountId>,
60    ) -> Result<(Builder, Option<Bytes>), StripeError> {
61        let mut uri = format!("{}v1{}", self.config.api_base, req.path);
62        if let Some(query) = req.query {
63            let _ = write!(uri, "?{query}");
64        }
65
66        let mut builder = Request::builder()
67            .method(conv_stripe_method(req.method))
68            .uri(uri)
69            .header(AUTHORIZATION, self.config.secret.clone())
70            .header(USER_AGENT, self.config.user_agent.clone())
71            .header(HeaderName::from_static("stripe-version"), self.config.stripe_version.clone());
72
73        if let Some(client_id) = &self.config.client_id {
74            builder = builder.header(HeaderName::from_static("client-id"), client_id.clone());
75        }
76        if let Some(account_id) = self.get_account_id_header(account_id)? {
77            builder = builder.header(HeaderName::from_static("stripe-account"), account_id);
78        }
79
80        let body = if let Some(body) = req.body {
81            builder = builder.header(
82                CONTENT_TYPE,
83                HeaderValue::from_static("application/x-www-form-urlencoded"),
84            );
85            Some(Bytes::from(body))
86        } else {
87            None
88        };
89        Ok((builder, body))
90    }
91
92    async fn send_inner(
93        &self,
94        body: Option<Bytes>,
95        mut req_builder: Builder,
96        strategy: RequestStrategy,
97    ) -> Result<Bytes, StripeError> {
98        let mut tries = 0;
99        let mut last_status: Option<StatusCode> = None;
100        let mut last_retry_header: Option<bool> = None;
101        let mut last_error = StripeError::ClientError("invalid strategy".into());
102
103        if let Some(key) = strategy.get_key() {
104            const HEADER_NAME: HeaderName = HeaderName::from_static("idempotency-key");
105            req_builder = req_builder.header(HEADER_NAME, key.as_str());
106        }
107
108        let req = req_builder.body(Full::new(body.unwrap_or_default()))?;
109
110        loop {
111            return match strategy.test(last_status.map(|s| s.as_u16()), last_retry_header, tries) {
112                Outcome::Stop => Err(last_error),
113                Outcome::Continue(duration) => {
114                    if let Some(duration) = duration {
115                        tokio::time::sleep(duration).await;
116                    }
117
118                    let response = match self.client.request(req.clone()).await {
119                        Ok(resp) => resp,
120                        Err(err) => {
121                            last_error = StripeError::from(err);
122                            tries += 1;
123                            continue;
124                        }
125                    };
126                    let status = response.status();
127                    let retry = response
128                        .headers()
129                        .get("Stripe-Should-Retry")
130                        .and_then(|s| s.to_str().ok())
131                        .and_then(|s| s.parse().ok());
132
133                    let bytes = response.into_body().collect().await?.to_bytes();
134                    if !status.is_success() {
135                        tries += 1;
136
137                        let str = std::str::from_utf8(bytes.as_ref()).map_err(|_| {
138                            StripeError::JSONDeserialize("Response was not valid UTF-8".into())
139                        })?;
140                        last_error = from_str(str)
141                            .map(|e: stripe_shared::Error| {
142                                StripeError::Stripe(e.error, status.as_u16())
143                            })
144                            .unwrap_or_else(|_| {
145                                StripeError::JSONDeserialize(
146                                    "error deserializing Stripe error".into(),
147                                )
148                            });
149                        last_status = Some(status);
150                        last_retry_header = retry;
151                        continue;
152                    }
153                    Ok(bytes)
154                }
155            };
156        }
157    }
158}
159
160fn conv_stripe_method(method: StripeMethod) -> hyper::Method {
161    match method {
162        StripeMethod::Get => hyper::Method::GET,
163        StripeMethod::Post => hyper::Method::POST,
164        StripeMethod::Delete => hyper::Method::DELETE,
165    }
166}
167
168impl stripe_client_core::StripeClient for Client {
169    type Err = StripeError;
170
171    async fn execute(&self, req_full: CustomizedStripeRequest) -> Result<Bytes, Self::Err> {
172        let (req, config) = req_full.into_pieces();
173        let (builder, body) = self.construct_request(req, config.account_id)?;
174
175        let request_strategy =
176            config.request_strategy.unwrap_or_else(|| self.config.request_strategy.clone());
177        self.send_inner(body, builder, request_strategy).await
178    }
179}