1use std::fmt::Write as _;
2
3use http_body_util::{BodyExt, Full};
4use hyper::body::Bytes;
5use hyper::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
6use hyper::http::request::Builder;
7use hyper::http::{HeaderName, HeaderValue};
8use hyper::{Request, StatusCode};
9use hyper_util::client::legacy::Client as HyperClient;
10use hyper_util::rt::TokioExecutor;
11use miniserde::json::from_str;
12use stripe_client_core::{CustomizedStripeRequest, RequestBuilder, StripeMethod};
13use stripe_client_core::{Outcome, RequestStrategy};
14use stripe_shared::AccountId;
15
16use crate::StripeError;
17use crate::hyper::client_builder::{ClientBuilder, ClientConfig};
18
19#[derive(Clone, Debug)]
21pub struct Client {
22 client: HyperClient<crate::hyper::connector::Connector, Full<Bytes>>,
23 config: ClientConfig,
24}
25
26impl Client {
27 pub(crate) fn from_config(config: ClientConfig) -> Self {
28 Self {
29 client: HyperClient::builder(TokioExecutor::new())
30 .pool_max_idle_per_host(0)
31 .build(crate::hyper::connector::create()),
32 config,
33 }
34 }
35
36 pub fn new(secret_key: impl Into<String>) -> Self {
41 ClientBuilder::new(secret_key).build().expect("invalid secret provided")
42 }
43
44 fn get_account_id_header(
45 &self,
46 account_id_override: Option<AccountId>,
47 ) -> Result<Option<HeaderValue>, StripeError> {
48 if let Some(overridden) = account_id_override {
49 return Ok(Some(HeaderValue::from_str(overridden.as_str()).map_err(|_| {
50 StripeError::ConfigError("invalid account id set in customizations".into())
51 })?));
52 }
53 Ok(self.config.account_id.clone())
54 }
55
56 fn construct_request(
57 &self,
58 req: RequestBuilder,
59 account_id: Option<AccountId>,
60 ) -> Result<(Builder, Option<Bytes>), StripeError> {
61 let mut uri = format!("{}v1{}", self.config.api_base, req.path);
62 if let Some(query) = req.query {
63 let _ = write!(uri, "?{query}");
64 }
65
66 let mut builder = Request::builder()
67 .method(conv_stripe_method(req.method))
68 .uri(uri)
69 .header(AUTHORIZATION, self.config.secret.clone())
70 .header(USER_AGENT, self.config.user_agent.clone())
71 .header(HeaderName::from_static("stripe-version"), self.config.stripe_version.clone());
72
73 if let Some(client_id) = &self.config.client_id {
74 builder = builder.header(HeaderName::from_static("client-id"), client_id.clone());
75 }
76 if let Some(account_id) = self.get_account_id_header(account_id)? {
77 builder = builder.header(HeaderName::from_static("stripe-account"), account_id);
78 }
79
80 let body = if let Some(body) = req.body {
81 builder = builder.header(
82 CONTENT_TYPE,
83 HeaderValue::from_static("application/x-www-form-urlencoded"),
84 );
85 Some(Bytes::from(body))
86 } else {
87 None
88 };
89 Ok((builder, body))
90 }
91
92 async fn send_inner(
93 &self,
94 body: Option<Bytes>,
95 mut req_builder: Builder,
96 strategy: RequestStrategy,
97 ) -> Result<Bytes, StripeError> {
98 let mut tries = 0;
99 let mut last_status: Option<StatusCode> = None;
100 let mut last_retry_header: Option<bool> = None;
101 let mut last_error = StripeError::ClientError("invalid strategy".into());
102
103 if let Some(key) = strategy.get_key() {
104 const HEADER_NAME: HeaderName = HeaderName::from_static("idempotency-key");
105 req_builder = req_builder.header(HEADER_NAME, key.as_str());
106 }
107
108 let req = req_builder.body(Full::new(body.unwrap_or_default()))?;
109
110 loop {
111 return match strategy.test(last_status.map(|s| s.as_u16()), last_retry_header, tries) {
112 Outcome::Stop => Err(last_error),
113 Outcome::Continue(duration) => {
114 if let Some(duration) = duration {
115 tokio::time::sleep(duration).await;
116 }
117
118 let response = match self.client.request(req.clone()).await {
119 Ok(resp) => resp,
120 Err(err) => {
121 last_error = StripeError::from(err);
122 tries += 1;
123 continue;
124 }
125 };
126 let status = response.status();
127 let retry = response
128 .headers()
129 .get("Stripe-Should-Retry")
130 .and_then(|s| s.to_str().ok())
131 .and_then(|s| s.parse().ok());
132
133 let bytes = response.into_body().collect().await?.to_bytes();
134 if !status.is_success() {
135 tries += 1;
136
137 let str = std::str::from_utf8(bytes.as_ref()).map_err(|_| {
138 StripeError::JSONDeserialize("Response was not valid UTF-8".into())
139 })?;
140 last_error = from_str(str)
141 .map(|e: stripe_shared::Error| {
142 StripeError::Stripe(e.error, status.as_u16())
143 })
144 .unwrap_or_else(|_| {
145 StripeError::JSONDeserialize(
146 "error deserializing Stripe error".into(),
147 )
148 });
149 last_status = Some(status);
150 last_retry_header = retry;
151 continue;
152 }
153 Ok(bytes)
154 }
155 };
156 }
157 }
158}
159
160fn conv_stripe_method(method: StripeMethod) -> hyper::Method {
161 match method {
162 StripeMethod::Get => hyper::Method::GET,
163 StripeMethod::Post => hyper::Method::POST,
164 StripeMethod::Delete => hyper::Method::DELETE,
165 }
166}
167
168impl stripe_client_core::StripeClient for Client {
169 type Err = StripeError;
170
171 async fn execute(&self, req_full: CustomizedStripeRequest) -> Result<Bytes, Self::Err> {
172 let (req, config) = req_full.into_pieces();
173 let (builder, body) = self.construct_request(req, config.account_id)?;
174
175 let request_strategy =
176 config.request_strategy.unwrap_or_else(|| self.config.request_strategy.clone());
177 self.send_inner(body, builder, request_strategy).await
178 }
179}