use std::sync::Arc;
use reqwest::Client as HttpClient;
use serde::{de::DeserializeOwned, Serialize};
use crate::config::Config;
use crate::error::{Error, ErrorCode, Result};
use crate::models::ApiResponse;
use crate::token::TokenManager;
/// Per-request overrides for the headers that [`BaseClient`] attaches to a call.
///
/// Any field left as `None` makes the client fall back to the matching value
/// from its configuration.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequestOptions {
    /// Override for the `tradeType` header.
    pub trade_type: Option<i32>,
    /// Override for the `lang` header.
    pub lang: Option<String>,
}

impl RequestOptions {
    /// Creates options with no overrides set (equivalent to `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the options with the trade type override set to `trade_type`.
    #[must_use]
    pub fn with_trade_type(mut self, trade_type: i32) -> Self {
        self.trade_type = Some(trade_type);
        self
    }

    /// Returns the options with the language override set to `lang`.
    #[must_use]
    pub fn with_lang(mut self, lang: impl Into<String>) -> Self {
        self.lang = Some(lang.into());
        self
    }
}
/// Core API client bundling the configuration, the `reqwest` HTTP client and
/// the token manager needed to issue authenticated requests.
///
/// `config` and `token_manager` are held behind `Arc`, so clones of the client
/// share them rather than copying them.
#[derive(Debug, Clone)]
pub struct BaseClient {
/// Client configuration (base URL, API key, default trade type and language).
config: Arc<Config>,
/// Underlying `reqwest` client used to send the requests.
http_client: HttpClient,
/// Supplies the bearer token for each request and refreshes it on expiry.
token_manager: Arc<TokenManager>,
}
impl BaseClient {
pub fn new(config: Config, http_client: HttpClient, token_manager: Arc<TokenManager>) -> Self {
BaseClient {
config: Arc::new(config),
http_client,
token_manager,
}
}
pub async fn do_request<T, R>(
&self,
method: reqwest::Method,
path: &str,
body: Option<&T>,
opts: Option<RequestOptions>,
) -> Result<R>
where
T: Serialize,
R: DeserializeOwned,
{
let opts = opts.unwrap_or_default();
let result = self.execute_request(&method, path, body, &opts).await;
if let Err(Error::Api { code, .. }) = &result {
if *code == ErrorCode::TokenExpired as i32 {
self.token_manager.refresh().await?;
return self.execute_request(&method, path, body, &opts).await;
}
}
result
}
async fn execute_request<T, R>(
&self,
method: &reqwest::Method,
path: &str,
body: Option<&T>,
opts: &RequestOptions,
) -> Result<R>
where
T: Serialize,
R: DeserializeOwned,
{
let token = self.token_manager.token().await?;
let url = format!("{}{}", self.config.base_url, path);
let mut request = self.http_client.request(method.clone(), &url);
request = request
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.header("apikey", &self.config.api_key)
.header(
"tradeType",
opts.trade_type.unwrap_or(self.config.trade_type).to_string(),
);
if let Some(lang) = opts.lang.as_ref().or(Some(&self.config.lang)) {
request = request.header("lang", lang);
}
if let Some(body) = body {
request = request.json(body);
}
let response = request.send().await?;
let resp_text = response.text().await?;
self.parse_response(&resp_text)
}
fn parse_response<R>(&self, resp_text: &str) -> Result<R>
where
R: DeserializeOwned,
{
let api_resp: ApiResponse = serde_json::from_str(resp_text).map_err(|e| {
Error::parse(resp_text, format!("failed to parse response: {}", e))
})?;
match ErrorCode::from_code(api_resp.code) {
Some(ErrorCode::Success) => {
serde_json::from_value(api_resp.data).map_err(|e| {
Error::parse(
resp_text,
format!("failed to deserialize response data: {}", e),
)
})
}
Some(ErrorCode::ParamError) => {
Err(Error::api(ErrorCode::ParamError as i32, api_resp.msg))
}
Some(ErrorCode::NoPermission) => {
Err(Error::api(ErrorCode::NoPermission as i32, api_resp.msg))
}
Some(ErrorCode::TokenExpired) => {
Err(Error::api(ErrorCode::TokenExpired as i32, api_resp.msg))
}
Some(ErrorCode::ServerError) => {
Err(Error::api(ErrorCode::ServerError as i32, api_resp.msg))
}
Some(ErrorCode::RateLimit) => {
Err(Error::api(ErrorCode::RateLimit as i32, api_resp.msg))
}
None => {
Err(Error::api(api_resp.code, api_resp.msg))
}
}
}
pub async fn do_get<R>(&self, path: &str, opts: Option<RequestOptions>) -> Result<R>
where
R: DeserializeOwned,
{
self.do_request::<(), R>(reqwest::Method::GET, path, None, opts)
.await
}
pub async fn do_post<T, R>(
&self,
path: &str,
body: &T,
opts: Option<RequestOptions>,
) -> Result<R>
where
T: Serialize,
R: DeserializeOwned,
{
self.do_request(reqwest::Method::POST, path, Some(body), opts)
.await
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn token_manager(&self) -> &TokenManager {
&self.token_manager
}
}