// longbridge_httpcli/request.rs

1use std::{convert::Infallible, error::Error, marker::PhantomData, time::Duration};
2
3use reqwest::{
4    header::{HeaderMap, HeaderName, HeaderValue},
5    Method, StatusCode,
6};
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9use crate::{
10    signature::{signature, SignatureParams},
11    timestamp::Timestamp,
12    HttpClient, HttpClientError, HttpClientResult,
13};
14
/// Value sent in the `User-Agent` header of every request.
const USER_AGENT: &str = "openapi-sdk";
/// Upper bound for one HTTP round trip: executing the request and reading the whole body.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum number of extra attempts after a request fails with
/// `BadStatus(429 Too Many Requests)`.
const RETRY_COUNT: usize = 5;
/// Delay before the first retry.
const RETRY_INITIAL_DELAY: Duration = Duration::from_millis(100);
/// Multiplier applied to the retry delay after every further 429 failure.
const RETRY_FACTOR: f32 = 2.0;
20
/// A JSON payload.
///
/// Wrap a value in `Json` to send it as a JSON request body (`T: Serialize`), or to
/// decode a response's `data` field as JSON (`T: DeserializeOwned`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Json<T>(pub T);
23
/// A type that can be decoded from a raw response payload.
pub trait FromPayload
where
    Self: Sized + Send + Sync + 'static,
{
    /// The error returned when the payload cannot be decoded.
    type Err: Error;

    /// Decode `data` into a value of this type.
    fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err>;
}
32
/// A type that can be encoded into a raw request payload.
pub trait ToPayload
where
    Self: Sized + Send + Sync + 'static,
{
    /// The error returned when the value cannot be encoded.
    type Err: Error;

    /// Encode this value into payload bytes.
    fn to_bytes(&self) -> Result<Vec<u8>, Self::Err>;
}
41
42impl<T> FromPayload for Json<T>
43where
44    T: DeserializeOwned + Send + Sync + 'static,
45{
46    type Err = serde_json::Error;
47
48    #[inline]
49    fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err> {
50        Ok(Json(serde_json::from_slice(data)?))
51    }
52}
53
54impl<T> ToPayload for Json<T>
55where
56    T: Serialize + Send + Sync + 'static,
57{
58    type Err = serde_json::Error;
59
60    #[inline]
61    fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
62        serde_json::to_vec(&self.0)
63    }
64}
65
66impl FromPayload for String {
67    type Err = std::string::FromUtf8Error;
68
69    #[inline]
70    fn parse_from_bytes(data: &[u8]) -> Result<Self, Self::Err> {
71        String::from_utf8(data.to_vec())
72    }
73}
74
75impl ToPayload for String {
76    type Err = std::string::FromUtf8Error;
77
78    #[inline]
79    fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
80        Ok(self.clone().into_bytes())
81    }
82}
83
84impl FromPayload for () {
85    type Err = Infallible;
86
87    #[inline]
88    fn parse_from_bytes(_data: &[u8]) -> Result<Self, Self::Err> {
89        Ok(())
90    }
91}
92
93impl ToPayload for () {
94    type Err = Infallible;
95
96    #[inline]
97    fn to_bytes(&self) -> Result<Vec<u8>, Self::Err> {
98        Ok(vec![])
99    }
100}
101
/// The JSON envelope that `do_send` expects around every response body.
///
/// A `code` of `0` means success and `data` holds the payload; any other `code`
/// is reported as `HttpClientError::OpenApi` together with `message`.
#[derive(Deserialize)]
struct OpenApiResponse {
    /// Status code inside the envelope; `0` means success.
    code: i32,
    /// Message returned alongside `code`; surfaced in `HttpClientError::OpenApi` on failure.
    message: String,
    /// The raw, not-yet-decoded `data` field, later decoded via `FromPayload`.
    /// `None` when the field is missing or `null`.
    data: Option<Box<serde_json::value::RawValue>>,
}
108
/// A request builder
///
/// Type parameters: `T` is the request body type, `Q` the query-parameters type,
/// and `R` the type the response `data` is decoded into by `send`.
pub struct RequestBuilder<T, Q, R> {
    /// Client supplying the HTTP connection, the configuration and default headers.
    client: HttpClient,
    /// HTTP method of the request.
    method: Method,
    /// Path appended to the configured base URL (`config.http_url`).
    path: String,
    /// Extra headers added by the caller through `header`.
    headers: HeaderMap,
    /// Optional body, serialized with `ToPayload::to_bytes` when sending.
    body: Option<T>,
    /// Optional query parameters, serialized into the URL query string when sending.
    query_params: Option<Q>,
    /// Records the response type `R` without storing a value of it.
    mark_resp: PhantomData<R>,
}
119
120impl RequestBuilder<(), (), ()> {
121    pub(crate) fn new(client: HttpClient, method: Method, path: impl Into<String>) -> Self {
122        Self {
123            client,
124            method,
125            path: path.into(),
126            headers: Default::default(),
127            body: None,
128            query_params: None,
129            mark_resp: PhantomData,
130        }
131    }
132}
133
134impl<T, Q, R> RequestBuilder<T, Q, R> {
135    /// Set the request body
136    #[must_use]
137    pub fn body<T2>(self, body: T2) -> RequestBuilder<T2, Q, R>
138    where
139        T2: ToPayload,
140    {
141        RequestBuilder {
142            client: self.client,
143            method: self.method,
144            path: self.path,
145            headers: self.headers,
146            body: Some(body),
147            query_params: self.query_params,
148            mark_resp: self.mark_resp,
149        }
150    }
151
152    /// Set the header
153    #[must_use]
154    pub fn header<K, V>(mut self, key: K, value: V) -> Self
155    where
156        K: TryInto<HeaderName>,
157        V: TryInto<HeaderValue>,
158    {
159        let key = key.try_into();
160        let value = value.try_into();
161        if let (Ok(key), Ok(value)) = (key, value) {
162            self.headers.append(key, value);
163        }
164        self
165    }
166
167    /// Set the query string
168    #[must_use]
169    pub fn query_params<Q2>(self, params: Q2) -> RequestBuilder<T, Q2, R>
170    where
171        Q2: Serialize + Send + Sync,
172    {
173        RequestBuilder {
174            client: self.client,
175            method: self.method,
176            path: self.path,
177            headers: self.headers,
178            body: self.body,
179            query_params: Some(params),
180            mark_resp: self.mark_resp,
181        }
182    }
183
184    /// Set the response body type
185    #[must_use]
186    pub fn response<R2>(self) -> RequestBuilder<T, Q, R2>
187    where
188        R2: FromPayload,
189    {
190        RequestBuilder {
191            client: self.client,
192            method: self.method,
193            path: self.path,
194            headers: self.headers,
195            body: self.body,
196            query_params: self.query_params,
197            mark_resp: PhantomData,
198        }
199    }
200}
201
202impl<T, Q, R> RequestBuilder<T, Q, R>
203where
204    T: ToPayload,
205    Q: Serialize + Send,
206    R: FromPayload,
207{
208    async fn do_send(&self) -> HttpClientResult<R> {
209        let HttpClient {
210            http_cli,
211            config,
212            default_headers,
213        } = &self.client;
214        let timestamp = self
215            .headers
216            .get("X-Timestamp")
217            .and_then(|value| value.to_str().ok())
218            .and_then(|value| value.parse().ok())
219            .unwrap_or_else(Timestamp::now);
220        let app_key_value =
221            HeaderValue::from_str(&config.app_key).map_err(|_| HttpClientError::InvalidApiKey)?;
222        let access_token_value = HeaderValue::from_str(&config.access_token)
223            .map_err(|_| HttpClientError::InvalidAccessToken)?;
224
225        let mut request_builder = http_cli
226            .request(
227                self.method.clone(),
228                format!("{}{}", config.http_url, self.path),
229            )
230            .headers(default_headers.clone())
231            .headers(self.headers.clone())
232            .header("User-Agent", USER_AGENT)
233            .header("X-Api-Key", app_key_value)
234            .header("Authorization", access_token_value)
235            .header("X-Timestamp", timestamp.to_string())
236            .header("Content-Type", "application/json; charset=utf-8");
237
238        // set the request body
239        if let Some(body) = &self.body {
240            let body = body
241                .to_bytes()
242                .map_err(|err| HttpClientError::SerializeRequestBody(err.to_string()))?;
243            request_builder = request_builder.body(body);
244        }
245
246        let mut request = request_builder.build().expect("invalid request");
247
248        // set the query string
249        if let Some(query_params) = &self.query_params {
250            let query_string = crate::qs::to_string(&query_params)?;
251            request.url_mut().set_query(Some(&query_string));
252        }
253
254        // signature the request
255        let sign = signature(SignatureParams {
256            request: &request,
257            app_key: &config.app_key,
258            access_token: Some(&config.access_token),
259            app_secret: &config.app_secret,
260            timestamp,
261        });
262        request.headers_mut().insert(
263            "X-Api-Signature",
264            HeaderValue::from_maybe_shared(sign).expect("valid signature"),
265        );
266
267        tracing::debug!(method = %request.method(), url = %request.url(), "http request");
268
269        // send request
270        let (status, trace_id, text) = tokio::time::timeout(REQUEST_TIMEOUT, async move {
271            let resp = http_cli.execute(request).await?;
272            let status = resp.status();
273            let trace_id = resp
274                .headers()
275                .get("x-trace-id")
276                .and_then(|value| value.to_str().ok())
277                .unwrap_or_default()
278                .to_string();
279            let text = resp.text().await.map_err(HttpClientError::from)?;
280            Ok::<_, HttpClientError>((status, trace_id, text))
281        })
282        .await
283        .map_err(|_| HttpClientError::RequestTimeout)??;
284
285        tracing::debug!(body = text.as_str(), "http response");
286
287        let resp = match serde_json::from_str::<OpenApiResponse>(&text) {
288            Ok(resp) if resp.code == 0 => resp.data.ok_or(HttpClientError::UnexpectedResponse),
289            Ok(resp) => Err(HttpClientError::OpenApi {
290                code: resp.code,
291                message: resp.message,
292                trace_id,
293            }),
294            Err(err) if status == StatusCode::OK => {
295                Err(HttpClientError::DeserializeResponseBody(err.to_string()))
296            }
297            Err(_) => Err(HttpClientError::BadStatus(status)),
298        }?;
299
300        R::parse_from_bytes(resp.get().as_bytes())
301            .map_err(|err| HttpClientError::DeserializeResponseBody(err.to_string()))
302    }
303
304    /// Send request and get the response
305    #[tracing::instrument(level = "debug", skip(self))]
306    pub async fn send(self) -> HttpClientResult<R> {
307        match self.do_send().await {
308            Ok(resp) => Ok(resp),
309            Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS)) => {
310                let mut retry_delay = RETRY_INITIAL_DELAY;
311
312                for _ in 0..RETRY_COUNT {
313                    tokio::time::sleep(retry_delay).await;
314
315                    match self.do_send().await {
316                        Ok(resp) => return Ok(resp),
317                        Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS)) => {
318                            retry_delay =
319                                Duration::from_secs_f32(retry_delay.as_secs_f32() * RETRY_FACTOR);
320                            continue;
321                        }
322                        Err(err) => return Err(err),
323                    }
324                }
325
326                Err(HttpClientError::BadStatus(StatusCode::TOO_MANY_REQUESTS))
327            }
328            Err(err) => Err(err),
329        }
330    }
331}