1use std::{convert::Infallible, error::Error, marker::PhantomData, time::Duration};
2
3use reqwest::{
4 header::{HeaderMap, HeaderName, HeaderValue},
5 Method, StatusCode,
6};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9use crate::{
10 signature::{signature, SignatureParams},
11 timestamp::Timestamp,
12 HttpClient, HttpClientError, HttpClientResult,
13};
14
/// Value sent in the `User-Agent` header of every request.
const USER_AGENT: &str = "openapi-sdk";
/// Upper bound for a single request/response exchange (one attempt, not all retries).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How many times a `429 Too Many Requests` answer is retried after the first attempt.
const RETRY_COUNT: usize = 5;
/// Delay slept before the first retry.
const RETRY_INITIAL_DELAY: Duration = Duration::from_millis(100);
/// Multiplier applied to the retry delay after each retry (exponential backoff).
const RETRY_FACTOR: f32 = 2.0;
20
/// Newtype marking a request body or response payload as JSON.
///
/// Wrapping a `Serialize` value in `Json` makes it usable as a request body
/// ([`ToPayload`]); wrapping a `DeserializeOwned` type makes it usable as a response
/// type ([`FromPayload`]). The inner value is public and can be taken out with `.0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Json<T>(pub T);
23
/// A type that can be decoded from the raw bytes of a response payload.
///
/// `RequestBuilder::send` hands the `data` field of the response envelope to
/// [`FromPayload::parse_from_bytes`] to produce its result.
pub trait FromPayload: Sized + Send + Sync + 'static {
    /// Error returned when the bytes cannot be decoded.
    type Err: Error;

    /// Decodes `data` into `Self`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Err` when `data` is not a valid encoding of `Self`.
    fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err>;
}
32
/// A type that can be encoded into the raw bytes of a request body.
///
/// `RequestBuilder::send` calls [`ToPayload::to_bytes`] on the body (if any) before the
/// request is built.
pub trait ToPayload: Sized + Send + Sync + 'static {
    /// Error returned when the value cannot be encoded.
    type Err: Error;

    /// Encodes `self` into bytes.
    ///
    /// # Errors
    ///
    /// Returns `Self::Err` when `self` cannot be encoded.
    fn to_bytes(&self) -> Result<Vec<u8>, Self::Err>;
}
41
42impl<T> FromPayload for Json<T>
43where
44 T: DeserializeOwned + Send + Sync + 'static,
45{
46 type Err = serde_json::Error;
47
48 #[inline]
49 fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err> {
50 Ok(Json(serde_json::from_slice(data)?))
51 }
52}
53
54impl<T> ToPayload for Json<T>
55where
56 T: Serialize + Send + Sync + 'static,
57{
58 type Err = serde_json::Error;
59
60 #[inline]
61 fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
62 serde_json::to_vec(&self.0)
63 }
64}
65
66impl FromPayload for String {
67 type Err = std::string::FromUtf8Error;
68
69 #[inline]
70 fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err> {
71 String::from_utf8(data.to_vec())
72 }
73}
74
75impl ToPayload for String {
76 type Err = std::string::FromUtf8Error;
77
78 #[inline]
79 fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
80 Ok(self.clone().into_bytes())
81 }
82}
83
84impl FromPayload for () {
85 type Err = Infallible;
86
87 #[inline]
88 fn parse_from_bytes(_data: &[u8]) -> Result<Self, Self::Err> {
89 Ok(())
90 }
91}
92
93impl ToPayload for () {
94 type Err = Infallible;
95
96 #[inline]
97 fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
98 Ok(vec![])
99 }
100}
101
102#[derive(Deserialize)]
103struct OpenApiResponse {
104 code: i32,
105 message: String,
106 data: Option<Box<serde_json::value::RawValue>>,
107}
108
109pub struct RequestBuilder<T, Q, R> {
111 client: HttpClient,
112 method: Method,
113 path: String,
114 headers: HeaderMap,
115 body: Option<T>,
116 query_params: Option<Q>,
117 mark_resp: PhantomData<R>,
118}
119
120impl RequestBuilder<(), (), ()> {
121 pub(crate) fn new(client: HttpClient, method: Method, path: impl Into<String>) -> Self {
122 Self {
123 client,
124 method,
125 path: path.into(),
126 headers: Default::default(),
127 body: None,
128 query_params: None,
129 mark_resp: PhantomData,
130 }
131 }
132}
133
134impl<T, Q, R> RequestBuilder<T, Q, R> {
135 #[must_use]
137 pub fn body<T2>(self, body: T2) -> RequestBuilder<T2, Q, R>
138 where
139 T2: ToPayload,
140 {
141 RequestBuilder {
142 client: self.client,
143 method: self.method,
144 path: self.path,
145 headers: self.headers,
146 body: Some(body),
147 query_params: self.query_params,
148 mark_resp: self.mark_resp,
149 }
150 }
151
152 #[must_use]
154 pub fn header<K, V>(mut self, key: K, value: V) -> Self
155 where
156 K: TryInto<HeaderName>,
157 V: TryInto<HeaderValue>,
158 {
159 let key = key.try_into();
160 let value = value.try_into();
161 if let (Ok(key), Ok(value)) = (key, value) {
162 self.headers.append(key, value);
163 }
164 self
165 }
166
167 #[must_use]
169 pub fn query_params<Q2>(self, params: Q2) -> RequestBuilder<T, Q2, R>
170 where
171 Q2: Serialize + Send + Sync,
172 {
173 RequestBuilder {
174 client: self.client,
175 method: self.method,
176 path: self.path,
177 headers: self.headers,
178 body: self.body,
179 query_params: Some(params),
180 mark_resp: self.mark_resp,
181 }
182 }
183
184 #[must_use]
186 pub fn response<R2>(self) -> RequestBuilder<T, Q, R2>
187 where
188 R2: FromPayload,
189 {
190 RequestBuilder {
191 client: self.client,
192 method: self.method,
193 path: self.path,
194 headers: self.headers,
195 body: self.body,
196 query_params: self.query_params,
197 mark_resp: PhantomData,
198 }
199 }
200}
201
202impl<T, Q, R> RequestBuilder<T, Q, R>
203where
204 T: ToPayload,
205 Q: Serialize + Send,
206 R: FromPayload,
207{
208 async fn do_send(&self) -> HttpClientResult<R> {
209 let HttpClient {
210 http_cli,
211 config,
212 default_headers,
213 } = &self.client;
214 let timestamp = self
215 .headers
216 .get("X-Timestamp")
217 .and_then(|value| value.to_str().ok())
218 .and_then(|value| value.parse().ok())
219 .unwrap_or_else(Timestamp::now);
220 let app_key_value =
221 HeaderValue::from_str(&config.app_key).map_err(|_| HttpClientError::InvalidApiKey)?;
222 let access_token_value = HeaderValue::from_str(&config.access_token)
223 .map_err(|_| HttpClientError::InvalidAccessToken)?;
224
225 let mut request_builder = http_cli
226 .request(
227 self.method.clone(),
228 format!("{}{}", config.http_url, self.path),
229 )
230 .headers(default_headers.clone())
231 .headers(self.headers.clone())
232 .header("User-Agent", USER_AGENT)
233 .header("X-Api-Key", app_key_value)
234 .header("Authorization", access_token_value)
235 .header("X-Timestamp", timestamp.to_string())
236 .header("Content-Type", "application/json; charset=utf-8");
237
238 if let Some(body) = &self.body {
240 let body = body
241 .to_bytes()
242 .map_err(|err| HttpClientError::SerializeRequestBody(err.to_string()))?;
243 request_builder = request_builder.body(body);
244 }
245
246 let mut request = request_builder.build().expect("invalid request");
247
248 if let Some(query_params) = &self.query_params {
250 let query_string = crate::qs::to_string(&query_params)?;
251 request.url_mut().set_query(Some(&query_string));
252 }
253
254 let sign = signature(SignatureParams {
256 request: &request,
257 app_key: &config.app_key,
258 access_token: Some(&config.access_token),
259 app_secret: &config.app_secret,
260 timestamp,
261 });
262 request.headers_mut().insert(
263 "X-Api-Signature",
264 HeaderValue::from_maybe_shared(sign).expect("valid signature"),
265 );
266
267 tracing::debug!(method = %request.method(), url = %request.url(), "http request");
268
269 let (status, trace_id, text) = tokio::time::timeout(REQUEST_TIMEOUT, async move {
271 let resp = http_cli.execute(request).await?;
272 let status = resp.status();
273 let trace_id = resp
274 .headers()
275 .get("x-trace-id")
276 .and_then(|value| value.to_str().ok())
277 .unwrap_or_default()
278 .to_string();
279 let text = resp.text().await.map_err(HttpClientError::from)?;
280 Ok::<_, HttpClientError>((status, trace_id, text))
281 })
282 .await
283 .map_err(|_| HttpClientError::RequestTimeout)??;
284
285 tracing::debug!(body = text.as_str(), "http response");
286
287 let resp = match serde_json::from_str::<OpenApiResponse>(&text) {
288 Ok(resp) if resp.code == 0 => resp.data.ok_or(HttpClientError::UnexpectedResponse),
289 Ok(resp) => Err(HttpClientError::OpenApi {
290 code: resp.code,
291 message: resp.message,
292 trace_id,
293 }),
294 Err(err) if status == StatusCode::OK => {
295 Err(HttpClientError::DeserializeResponseBody(err.to_string()))
296 }
297 Err(_) => Err(HttpClientError::BadStatus(status)),
298 }?;
299
300 R::parse_from_bytes(resp.get().as_bytes())
301 .map_err(|err| HttpClientError::DeserializeResponseBody(err.to_string()))
302 }
303
304 #[tracing::instrument(level = "debug", skip(self))]
306 pub async fn send(self) -> HttpClientResult<R> {
307 match self.do_send().await {
308 Ok(resp) => Ok(resp),
309 Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS)) => {
310 let mut retry_delay = RETRY_INITIAL_DELAY;
311
312 for _ in 0..RETRY_COUNT {
313 tokio::time::sleep(retry_delay).await;
314
315 match self.do_send().await {
316 Ok(resp) => return Ok(resp),
317 Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS)) => {
318 retry_delay =
319 Duration::from_secs_f32(retry_delay.as_secs_f32() * RETRY_FACTOR);
320 continue;
321 }
322 Err(err) => return Err(err),
323 }
324 }
325
326 Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS))
327 }
328 Err(err) => Err(err),
329 }
330 }
331}