use crate::protocol::{order_gateway::*, ErrorResponse, HealthResponse};
use crate::types::trading::{Order, PlaceOrder};
use anyhow::{anyhow, bail, Result};
use chrono::{DateTime, Utc};
use log::{debug, trace};
use reqwest;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::time::Duration;
use url::Url;
/// REST client for the order gateway HTTP API.
///
/// Holds the HTTP client, the base URL that request paths are joined onto,
/// and an optional auth token (with its expiry) that authenticated requests
/// send in the `Authorization` header.
pub struct OrderGatewayRestClient {
/// HTTP client used to send every request.
client: reqwest::Client,
/// Base URL; each request path is resolved against it with `Url::join`.
base_url: Url,
/// Auth token sent with authenticated requests; `None` until `set_token` is called.
token: Option<String>,
/// Instant after which the token is rejected as expired; when `None`, the
/// expiry check is skipped.
token_expires_at: Option<DateTime<Utc>>,
}
impl OrderGatewayRestClient {
pub fn new(base_url: Url) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
Ok(Self {
client,
base_url,
token: None,
token_expires_at: None,
})
}
pub fn set_token(&mut self, token: String, expires_at: DateTime<Utc>) {
self.token = Some(token);
self.token_expires_at = Some(expires_at);
}
fn token(&self) -> Result<&str> {
if let Some(token) = &self.token {
if self.token_expires_at.is_some_and(|exp| Utc::now() > exp) {
bail!("token expired")
}
return Ok(token);
} else {
bail!("token not available")
}
}
async fn request<T: Serialize, R: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
params: Option<T>,
auth: bool,
) -> Result<R> {
let url = self.base_url.join(path)?;
debug!("=> {} {}", method, url);
let mut req = self
.client
.request(method.clone(), url.clone())
.header("Content-Type", "application/json");
if auth {
let token = self.token()?;
req = req.header("Authorization", format!("{}", token));
}
if let Some(params) = params {
if method == reqwest::Method::POST
|| method == reqwest::Method::PUT
|| method == reqwest::Method::PATCH
{
req = req.json(¶ms);
} else {
req = req.query(¶ms);
}
}
let res = req.send().await?;
let res_status = res.status();
let res_text = res.text().await?;
trace!("<= {method} {url}: {res_status}");
trace!("<= {res_text}");
if res_status.is_success() {
Ok(serde_json::from_str(&res_text)?)
} else {
match serde_json::from_str::<ErrorResponse>(&res_text) {
Ok(error_response) => Err(anyhow!(error_response.error)),
Err(e) => Err(anyhow!("while parsing error response: {e:?}")),
}
}
}
pub async fn health(&self) -> Result<HealthResponse> {
self.request(reqwest::Method::GET, "health", None::<&str>, false)
.await
}
pub async fn open_orders(&self) -> Result<Vec<Order>> {
let payload = GetOpenOrdersRequest {};
let res: GetOpenOrdersResponse = self
.request(reqwest::Method::GET, "open_orders", Some(payload), true)
.await?;
let orders = res
.orders
.into_iter()
.map(|o| o.try_into())
.collect::<Result<Vec<Order>>>()?;
Ok(orders)
}
pub async fn place_order(&self, order: PlaceOrder) -> Result<String> {
let payload: PlaceOrderRequest = order.into();
let res: PlaceOrderResponse = self
.request(reqwest::Method::POST, "place_order", Some(payload), true)
.await?;
Ok(res.order_id)
}
pub async fn cancel_order(&self, order_id: impl AsRef<str>) -> Result<bool> {
let payload = CancelOrderRequest {
order_id: order_id.as_ref().to_string(),
};
let res: CancelOrderResponse = self
.request(reqwest::Method::POST, "cancel_order", Some(payload), true)
.await?;
Ok(res.cancel_request_accepted)
}
}