use super::{Credentials, Exchange, RequestHeaders};
use crate::data::DataApi;
use crate::trading::TradingApi;
use std::env;
use std::error::Error;
use std::io::Error as IoError;
use chrono::Utc;
use reqwest::{
Client,
header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue},
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::crypto::{encode_base64, encrypt_hmac_sha_256};
use crate::model::{
AssetClass, CanceledOrderResponse, ListOrdersRequest, OptionChainSnapshot, Order,
OrderQueryStatus, OrderRequest, OrderSide, OrderStatus, OrderType, ReplaceOrderRequest,
SortDirection, TimeInForce, TradingAccount,
};
/// Root of every OKX REST endpoint; API paths such as `/api/v5/...` are appended to it.
const OKX_BASE_URL: &str = "https://www.okx.com";

/// API-key triple used to authenticate private OKX REST calls.
///
/// Fields are private so the secrets are only reachable through signing and
/// header construction.
pub struct OKXCredentials {
    /// Value sent in the `ok-access-key` header.
    api_key: String,
    /// Key used when computing the request signature.
    secret_key: String,
    /// Value sent in the `ok-access-passphrase` header.
    passphrase: String,
}
impl OKXCredentials {
pub fn new(api_key: String, secret_key: String, passphrase: String) -> Self {
OKXCredentials {
api_key,
secret_key,
passphrase,
}
}
pub fn env() -> Result<Self, Box<dyn Error>> {
let _ = dotenvy::dotenv();
Ok(OKXCredentials {
api_key: env::var("OKX_API_KEY")?,
secret_key: env::var("OKX_SECRET_KEY")?,
passphrase: env::var("OKX_PASSPHRASE")?,
})
}
}
impl Credentials for OKXCredentials {
    /// Signs `payload` with the secret key.
    ///
    /// The payload is run through `encrypt_hmac_sha_256` keyed by `secret_key`,
    /// and the resulting digest is Base64-encoded with `encode_base64`.
    ///
    /// # Errors
    /// Propagates any failure from either crypto helper.
    fn sign(&self, payload: &str) -> Result<String, Box<dyn Error>> {
        let digest = encrypt_hmac_sha_256(&self.secret_key, payload)?;
        encode_base64(&digest)
    }
}
impl RequestHeaders for OKXCredentials {
    /// Builds the authentication headers for one OKX REST request.
    ///
    /// The signed string is `timestamp + method + path + payload`. `path` is the
    /// request path including any query string, and `payload` is the raw body
    /// (empty for GET). The same timestamp is sent in `ok-access-timestamp`.
    ///
    /// # Errors
    /// Fails if signing fails or if a credential/signature is not a valid HTTP
    /// header value.
    fn headers(
        &self,
        method: &str,
        path: &str,
        payload: &str,
    ) -> Result<HeaderMap, Box<dyn Error>> {
        // UTC ISO-8601 with millisecond precision, e.g. 2020-12-08T09:08:57.715Z.
        let timestamp = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        let signature = self.sign(&format!("{}{}{}{}", timestamp, method, path, payload))?;

        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("ok-access-key"),
            HeaderValue::from_str(&self.api_key)?,
        );
        headers.insert(
            HeaderName::from_static("ok-access-sign"),
            HeaderValue::from_str(&signature)?,
        );
        headers.insert(
            HeaderName::from_static("ok-access-timestamp"),
            HeaderValue::from_str(&timestamp)?,
        );
        headers.insert(
            HeaderName::from_static("ok-access-passphrase"),
            HeaderValue::from_str(&self.passphrase)?,
        );
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        Ok(headers)
    }
}
/// Trade mode sent to OKX as `tdMode` on new orders.
///
/// Each variant serializes to its lowercase OKX string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OKXTradeMode {
    /// `"cash"` — the default used by `OKXConfig::default()`.
    #[serde(rename = "cash")]
    Cash,
    /// `"isolated"`.
    #[serde(rename = "isolated")]
    Isolated,
    /// `"cross"`.
    #[serde(rename = "cross")]
    Cross,
}
/// Runtime options for the OKX client.
#[derive(Debug, Clone)]
pub struct OKXConfig {
    /// Trade mode submitted as `tdMode` on every new order.
    pub trade_mode: OKXTradeMode,
    /// When `true`, every request carries the `x-simulated-trading: 1` header.
    pub demo: bool,
}

impl Default for OKXConfig {
    /// `Cash` trade mode with demo trading disabled.
    fn default() -> Self {
        Self {
            demo: false,
            trade_mode: OKXTradeMode::Cash,
        }
    }
}
/// OKX exchange client: an HTTP client plus the credentials and options used
/// to sign and shape every request.
pub struct OKX {
    /// Trade mode and demo-trading switch.
    config: OKXConfig,
    /// Keys used to sign private requests.
    credentials: OKXCredentials,
    /// Shared HTTP client reused across calls.
    client: Client,
}
impl OKX {
fn endpoint(path: &str) -> String {
format!("{}{}", OKX_BASE_URL, path)
}
fn boxed_error(message: String) -> Box<dyn Error> {
IoError::other(message).into()
}
fn trade_mode_str(&self) -> &'static str {
match self.config.trade_mode {
OKXTradeMode::Cash => "cash",
OKXTradeMode::Isolated => "isolated",
OKXTradeMode::Cross => "cross",
}
}
fn side_str(side: &OrderSide) -> &'static str {
match side {
OrderSide::Buy => "buy",
OrderSide::Sell => "sell",
}
}
fn order_type_str(order_type: &OrderType) -> Result<&'static str, Box<dyn Error>> {
match order_type {
OrderType::Market => Ok("market"),
OrderType::Limit => Ok("limit"),
_ => Err(Self::boxed_error(format!(
"unsupported order type: {:?}",
order_type,
))),
}
}
fn parse_order_side(side: &str) -> Option<OrderSide> {
match side {
"buy" => Some(OrderSide::Buy),
"sell" => Some(OrderSide::Sell),
_ => None,
}
}
fn parse_order_type(order_type: &str) -> OrderType {
match order_type {
"market" => OrderType::Market,
"limit" => OrderType::Limit,
"stop" => OrderType::Stop,
"post_only" | "fok" | "ioc" | "optimal_limit_ioc" | "mmp" | "mmp_and_post_only" => {
OrderType::Limit
}
_ => OrderType::Market,
}
}
fn parse_order_status(status: &str) -> OrderStatus {
match status {
"live" => OrderStatus::New,
"partially_filled" => OrderStatus::PartiallyFilled,
"filled" => OrderStatus::Filled,
"canceled" => OrderStatus::Canceled,
"mmp_canceled" => OrderStatus::Canceled,
"order_failed" => OrderStatus::Rejected,
"effective" => OrderStatus::Accepted,
_ => OrderStatus::Unknown,
}
}
fn total_available_balance(details: &Option<Vec<OKXBalanceDetail>>) -> Option<String> {
let total = details
.as_ref()?
.iter()
.filter_map(|detail| {
detail
.avail_bal
.as_deref()
.and_then(|balance| balance.parse::<f64>().ok())
.or_else(|| {
let equity = detail.eq.as_deref()?.parse::<f64>().ok()?;
let frozen = detail
.frozen_bal
.as_deref()
.and_then(|balance| balance.parse::<f64>().ok())
.unwrap_or(0.0);
Some(equity - frozen)
})
})
.sum::<f64>();
Some(total.to_string())
}
fn build_list_orders_query(request: &ListOrdersRequest) -> Vec<(&'static str, String)> {
let mut query = Vec::new();
if let Some(limit) = request.limit {
query.push(("limit", limit.to_string()));
}
if let Some(after) = request.after.as_ref() {
query.push(("after", after.clone()));
}
if let Some(until) = request.until.as_ref() {
query.push(("before", until.clone()));
}
query
}
fn filter_orders(mut orders: Vec<Order>, request: &ListOrdersRequest) -> Vec<Order> {
if let Some(side) = request.side.as_ref() {
orders.retain(|order| order.side.as_ref() == Some(side));
}
if matches!(request.direction, Some(SortDirection::Asc)) {
orders.reverse();
}
if let Some(limit) = request.limit {
orders.truncate(limit as usize);
}
orders
}
fn ensure_action_success(response: &OKXOrderActionResponse) -> Result<(), Box<dyn Error>> {
if response.s_code.is_empty() || response.s_code == "0" {
return Ok(());
}
Err(Self::boxed_error(format!(
"okx order error {}: {}",
response.s_code, response.s_msg,
)))
}
async fn parse_response<T: DeserializeOwned>(body: &str) -> Result<Vec<T>, Box<dyn Error>> {
let wrapper: OKXResponseWrapper<serde_json::Value> = serde_json::from_str(body)?;
if wrapper.code != "0" {
let detail = wrapper
.data
.first()
.and_then(|item| item.get("sMsg"))
.and_then(|item| item.as_str())
.map(String::from)
.unwrap_or_default();
let message = if detail.is_empty() {
wrapper.msg
} else {
detail
};
return Err(Self::boxed_error(format!(
"okx api error {}: {}",
wrapper.code, message,
)));
}
Ok(wrapper
.data
.into_iter()
.map(serde_json::from_value)
.collect::<Result<Vec<T>, _>>()?)
}
async fn send_get<T: DeserializeOwned>(
&self,
path: &str,
query: Option<&Vec<(&'static str, String)>>,
) -> Result<Vec<T>, Box<dyn Error>> {
let url = Self::endpoint(path);
let query_string = query
.map(|query| {
query
.iter()
.map(|(key, value)| format!("{}={}", key, value))
.collect::<Vec<_>>()
.join("&")
})
.unwrap_or_default();
let full_path = if query_string.is_empty() {
path.to_string()
} else {
format!("{}?{}", path, query_string)
};
let headers = self.credentials.headers("GET", &full_path, "")?;
let mut request = self.client.get(&url).headers(headers);
if let Some(query) = query {
request = request.query(query);
}
if self.config.demo {
request = request.header("x-simulated-trading", "1");
}
let response = request.send().await?;
let body = response.text().await?;
Self::parse_response(&body).await
}
async fn send_post<B: Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: &B,
) -> Result<Vec<T>, Box<dyn Error>> {
let body_str = serde_json::to_string(body)?;
let headers = self.credentials.headers("POST", path, &body_str)?;
let mut request = self
.client
.post(Self::endpoint(path))
.headers(headers)
.body(body_str);
if self.config.demo {
request = request.header("x-simulated-trading", "1");
}
let response = request.send().await?;
let body = response.text().await?;
Self::parse_response(&body).await
}
async fn get_order_from_endpoints(
&self,
pending_query: &Vec<(&'static str, String)>,
history_query: &Vec<(&'static str, String)>,
missing_message: String,
) -> Result<Order, Box<dyn Error>> {
let pending_results: Vec<OKXOrderDetailResponse> = self
.send_get("/api/v5/trade/orders-pending", Some(pending_query))
.await?;
if let Some(order) = pending_results.into_iter().next() {
return Ok(order.into());
}
let history_results: Vec<OKXOrderDetailResponse> = self
.send_get("/api/v5/trade/orders-history-archive", Some(history_query))
.await?;
history_results
.into_iter()
.next()
.map(Order::from)
.ok_or_else(|| Self::boxed_error(missing_message))
}
pub fn new(credentials: OKXCredentials, config: OKXConfig) -> Self {
OKX {
client: Client::new(),
credentials,
config,
}
}
}
impl Exchange for OKX {
    type Credentials = OKXCredentials;

    /// Trait constructor: uses the default configuration (`Cash` mode, demo off).
    fn new(credentials: OKXCredentials) -> Self {
        let config = OKXConfig::default();
        // Path call resolves to the inherent two-argument constructor.
        OKX::new(credentials, config)
    }
}
impl DataApi for OKX {
    /// Reads `/api/v5/account/balance` and converts its first entry into an account.
    ///
    /// # Errors
    /// `"empty response"` when OKX returns no rows.
    async fn get_account(&self) -> Result<TradingAccount, Box<dyn Error>> {
        let accounts: Vec<OKXAccountResponse> =
            self.send_get("/api/v5/account/balance", None).await?;
        let Some(first) = accounts.into_iter().next() else {
            return Err(Self::boxed_error("empty response".to_string()));
        };
        Ok(first.into())
    }

    /// Finds an order by its exchange-assigned id (open orders, then the archive).
    async fn get_order(&self, order_id: &str) -> Result<Order, Box<dyn Error>> {
        // Both lookups use the same filter.
        let query = vec![("ordId", order_id.to_string())];
        let not_found = format!("order not found: {}", order_id);
        self.get_order_from_endpoints(&query, &query, not_found).await
    }

    /// Finds an order by its client-assigned id (open orders, then the archive).
    async fn get_order_by_client_id(&self, client_order_id: &str) -> Result<Order, Box<dyn Error>> {
        let query = vec![("clOrdId", client_order_id.to_string())];
        let not_found = format!("order not found for client id: {}", client_order_id);
        self.get_order_from_endpoints(&query, &query, not_found).await
    }

    /// Lists orders by status (default: open), then applies client-side filters.
    ///
    /// Open orders come from `orders-pending` and closed ones from
    /// `orders-history-archive`. `All` concatenates open orders followed by
    /// closed ones.
    async fn list_orders(&self, request: &ListOrdersRequest) -> Result<Vec<Order>, Box<dyn Error>> {
        const PENDING_PATH: &str = "/api/v5/trade/orders-pending";
        const HISTORY_PATH: &str = "/api/v5/trade/orders-history-archive";

        let query = Self::build_list_orders_query(request);
        let (want_pending, want_history) =
            match request.status.as_ref().unwrap_or(&OrderQueryStatus::Open) {
                OrderQueryStatus::Open => (true, false),
                OrderQueryStatus::Closed => (false, true),
                OrderQueryStatus::All => (true, true),
            };

        let mut rows: Vec<OKXOrderDetailResponse> = Vec::new();
        if want_pending {
            rows.extend(
                self.send_get::<OKXOrderDetailResponse>(PENDING_PATH, Some(&query))
                    .await?,
            );
        }
        if want_history {
            rows.extend(
                self.send_get::<OKXOrderDetailResponse>(HISTORY_PATH, Some(&query))
                    .await?,
            );
        }

        let orders: Vec<Order> = rows.into_iter().map(Order::from).collect();
        Ok(Self::filter_orders(orders, request))
    }

    /// Option chains are not provided by this adapter; always returns an empty list.
    async fn get_option_chain(
        &self,
        _underlying_symbol: &str,
    ) -> Result<Vec<OptionChainSnapshot>, Box<dyn Error>> {
        Ok(Vec::new())
    }
}
impl TradingApi for OKX {
    /// Places an order via `POST /api/v5/trade/order` and returns it as re-read
    /// from OKX.
    ///
    /// `symbol`, `side` and `qty` are required; only `Market` and `Limit` order
    /// types are accepted.
    ///
    /// NOTE(review): for SPOT market buys OKX may interpret `sz` as a
    /// quote-currency amount unless `tgtCcy` is set — confirm the unit callers
    /// use for `qty`.
    async fn submit(&self, order: &OrderRequest) -> Result<Order, Box<dyn Error>> {
        let symbol = order
            .symbol
            .as_deref()
            .ok_or_else(|| Self::boxed_error("symbol required".to_string()))?;
        let side = order
            .side
            .as_ref()
            .map(Self::side_str)
            .ok_or_else(|| Self::boxed_error("side required".to_string()))?;
        let order_type = Self::order_type_str(&order.order_type)?;
        let qty = order
            .qty
            .as_deref()
            .ok_or_else(|| Self::boxed_error("qty required".to_string()))?;
        let request = OKXOrderRequest {
            inst_id: symbol,
            td_mode: self.trade_mode_str(),
            side,
            ord_type: order_type,
            sz: qty,
            px: order.limit_price.as_deref(),
            cl_ord_id: order.client_order_id.as_deref(),
        };
        let results: Vec<OKXOrderActionResponse> =
            self.send_post("/api/v5/trade/order", &request).await?;
        let result = results
            .into_iter()
            .next()
            .ok_or_else(|| Self::boxed_error("empty response".to_string()))?;
        Self::ensure_action_success(&result)?;
        self.get_order(&result.ord_id).await
    }

    /// Amends quantity and/or limit price of an existing order, then re-reads it.
    ///
    /// The order is fetched first because OKX requires its `instId`.
    async fn replace(
        &self,
        order_id: &str,
        request: &ReplaceOrderRequest,
    ) -> Result<Order, Box<dyn Error>> {
        let current_order = self.get_order(order_id).await?;
        let symbol = current_order
            .symbol
            .as_deref()
            .ok_or_else(|| Self::boxed_error(format!("missing symbol for order: {}", order_id)))?;
        let amend = OKXAmendOrderRequest {
            inst_id: symbol,
            ord_id: Some(order_id),
            new_sz: request.qty.as_deref(),
            new_px: request.limit_price.as_deref(),
        };
        let results: Vec<OKXOrderActionResponse> =
            self.send_post("/api/v5/trade/amend-order", &amend).await?;
        let result = results
            .into_iter()
            .next()
            .ok_or_else(|| Self::boxed_error("empty response".to_string()))?;
        Self::ensure_action_success(&result)?;
        self.get_order(&result.ord_id).await
    }

    /// Cancels a single order; its symbol is looked up first to supply `instId`.
    async fn cancel(&self, order_id: &str) -> Result<(), Box<dyn Error>> {
        let order = self.get_order(order_id).await?;
        let symbol = order
            .symbol
            .as_deref()
            .ok_or_else(|| Self::boxed_error(format!("missing symbol for order: {}", order_id)))?;
        let request = OKXCancelOrderRequest {
            inst_id: symbol,
            ord_id: order_id,
        };
        let results: Vec<OKXOrderActionResponse> = self
            .send_post("/api/v5/trade/cancel-order", &request)
            .await?;
        let result = results
            .into_iter()
            .next()
            .ok_or_else(|| Self::boxed_error("empty response".to_string()))?;
        Self::ensure_action_success(&result)
    }

    /// Cancels every open order that has a symbol, returning one entry per order.
    ///
    /// OKX accepts at most 20 orders per `cancel-batch-orders` request, so the
    /// cancellations are sent in chunks of that size. Nothing is sent when there
    /// is nothing cancellable.
    ///
    /// NOTE(review): an error on a later batch aborts after earlier batches were
    /// already cancelled. Open orders are fetched with a single list call,
    /// presumably only the first page.
    async fn cancel_all(&self) -> Result<Vec<CanceledOrderResponse>, Box<dyn Error>> {
        const MAX_BATCH_CANCEL: usize = 20;

        let pending_orders = self
            .list_orders(&ListOrdersRequest {
                status: Some(OrderQueryStatus::Open),
                ..Default::default()
            })
            .await?;
        // Orders without a symbol cannot be cancelled (instId is required) and are skipped.
        let requests = pending_orders
            .iter()
            .filter_map(|order| {
                order.symbol.as_deref().map(|symbol| OKXCancelOrderRequest {
                    inst_id: symbol,
                    ord_id: order.id.as_str(),
                })
            })
            .collect::<Vec<_>>();

        let mut canceled = Vec::with_capacity(requests.len());
        for batch in requests.chunks(MAX_BATCH_CANCEL) {
            let results: Vec<OKXOrderActionResponse> = self
                .send_post("/api/v5/trade/cancel-batch-orders", &batch)
                .await?;
            canceled.extend(results.into_iter().map(CanceledOrderResponse::from));
        }
        Ok(canceled)
    }
}
#[derive(Debug, Deserialize)]
struct OKXResponseWrapper<T> {
code: String,
msg: String,
data: Vec<T>,
}
/// JSON body for `POST /api/v5/trade/order` (camelCase field names).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXOrderRequest<'a> {
    /// Instrument id, e.g. `BTC-USDT`.
    inst_id: &'a str,
    /// Trade mode (`cash`, `isolated`, `cross`).
    td_mode: &'a str,
    /// `buy` or `sell`.
    side: &'a str,
    /// Order type (`market` or `limit` from this adapter).
    ord_type: &'a str,
    /// Order size.
    sz: &'a str,
    /// Limit price; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    px: Option<&'a str>,
    /// Client order id; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    cl_ord_id: Option<&'a str>,
}

/// JSON body for `POST /api/v5/trade/amend-order`; absent fields are omitted.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXAmendOrderRequest<'a> {
    /// Instrument id of the order being amended.
    inst_id: &'a str,
    /// Exchange order id.
    #[serde(skip_serializing_if = "Option::is_none")]
    ord_id: Option<&'a str>,
    /// Replacement size.
    #[serde(skip_serializing_if = "Option::is_none")]
    new_sz: Option<&'a str>,
    /// Replacement price.
    #[serde(skip_serializing_if = "Option::is_none")]
    new_px: Option<&'a str>,
}

/// One cancel instruction, used both alone (`cancel-order`) and in arrays
/// (`cancel-batch-orders`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXCancelOrderRequest<'a> {
    /// Instrument id of the order.
    inst_id: &'a str,
    /// Exchange order id.
    ord_id: &'a str,
}
/// Per-order result row returned by place/amend/cancel endpoints.
///
/// `s_code` is empty or `"0"` on success; the other optional fields default to
/// empty strings when absent.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXOrderActionResponse {
    /// Exchange order id (required in the payload).
    ord_id: String,
    /// Client order id echoed back.
    #[serde(default)]
    cl_ord_id: String,
    /// Per-order result code.
    #[serde(default)]
    s_code: String,
    /// Per-order result message.
    #[serde(default)]
    s_msg: String,
}

/// Order row returned by the pending and archived order list endpoints.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXOrderDetailResponse {
    /// Instrument id.
    inst_id: String,
    /// Exchange order id.
    ord_id: String,
    /// Client order id, if present in the payload.
    cl_ord_id: Option<String>,
    /// `buy` / `sell`.
    side: String,
    /// OKX order type string.
    ord_type: String,
    /// Price, if present in the payload.
    px: Option<String>,
    /// Order size.
    sz: String,
    /// Accumulated filled size, if present.
    acc_fill_sz: Option<String>,
    /// OKX order state string.
    state: String,
    /// Creation time as sent by OKX; empty when the field is absent.
    #[serde(default)]
    c_time: String,
}
impl From<OKXOrderDetailResponse> for Order {
    /// Converts an OKX order row into the generic `Order` model.
    ///
    /// OKX reports unset optional fields as empty strings (e.g. `px` on market
    /// orders, `clOrdId` when none was supplied). Those become `None` rather
    /// than `Some("")`. Every order is tagged as `AssetClass::Crypto`, and time
    /// in force is fixed to `Day`.
    fn from(order: OKXOrderDetailResponse) -> Self {
        /// Drops empty-string values so "not set" is represented as `None`.
        fn non_empty(value: Option<String>) -> Option<String> {
            value.filter(|text| !text.is_empty())
        }

        let side = OKX::parse_order_side(&order.side);
        let order_type = OKX::parse_order_type(&order.ord_type);
        let status = OKX::parse_order_status(&order.state);

        Order {
            id: order.ord_id,
            client_order_id: non_empty(order.cl_ord_id),
            symbol: Some(order.inst_id),
            asset_class: Some(AssetClass::Crypto),
            qty: Some(order.sz),
            notional: None,
            filled_qty: non_empty(order.acc_fill_sz),
            side,
            order_type,
            time_in_force: TimeInForce::Day,
            status,
            order_class: None,
            limit_price: non_empty(order.px),
            stop_price: None,
            trail_price: None,
            trail_percent: None,
            created_at: non_empty(Some(order.c_time)),
            extended_hours: None,
            position_intent: None,
            ratio_qty: None,
            legs: None,
        }
    }
}
/// Account entry from `/api/v5/account/balance`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct OKXAccountResponse {
    /// Total equity string as reported by OKX.
    total_eq: Option<String>,
    /// Per-currency balance rows.
    details: Option<Vec<OKXBalanceDetail>>,
}

/// One per-currency balance row inside an account entry.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct OKXBalanceDetail {
    /// Currency code.
    ccy: String,
    /// Available balance.
    avail_bal: Option<String>,
    /// Frozen balance.
    frozen_bal: Option<String>,
    /// Equity in this currency.
    eq: Option<String>,
}
impl From<OKXAccountResponse> for TradingAccount {
    /// Converts an OKX balance entry into the generic account model.
    ///
    /// The currency is taken from the first balance detail, falling back to
    /// `"USD"`. Buying power comes from `OKX::total_available_balance` and
    /// equity from `totalEq`.
    ///
    /// NOTE(review): in a multi-currency account the first detail's currency may
    /// not be the unit of `totalEq` — confirm which currency should be reported.
    fn from(account: OKXAccountResponse) -> Self {
        let currency = match account.details.as_deref() {
            Some([first, ..]) => first.ccy.clone(),
            _ => "USD".to_string(),
        };
        let buying_power = OKX::total_available_balance(&account.details);
        TradingAccount {
            id: "okx".to_string(),
            account_number: None,
            status: "ACTIVE".to_string(),
            currency,
            buying_power,
            equity: account.total_eq,
        }
    }
}
impl From<OKXOrderActionResponse> for CanceledOrderResponse {
fn from(response: OKXOrderActionResponse) -> Self {
let body = Some(serde_json::json!({
"clOrdId": response.cl_ord_id,
"sCode": response.s_code,
"sMsg": response.s_msg,
}));
CanceledOrderResponse {
id: response.ord_id,
status: if response.s_code.is_empty() || response.s_code == "0" {
200
} else {
400
},
body,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds credentials from string slices for tests that never touch the network.
    fn credentials(key: &str, secret: &str, passphrase: &str) -> OKXCredentials {
        OKXCredentials::new(key.to_string(), secret.to_string(), passphrase.to_string())
    }

    #[test]
    fn test_okx_credentials_headers_include_required_keys() {
        let signed = credentials("api-key", "secret-key", "passphrase")
            .headers("GET", "/api/v5/account/balance", "")
            .expect("headers should build for ASCII credentials");
        assert_eq!(
            signed.get("ok-access-key"),
            Some(&HeaderValue::from_static("api-key")),
        );
        assert!(signed.contains_key("ok-access-sign"));
        assert!(signed.contains_key("ok-access-timestamp"));
        assert_eq!(
            signed.get("ok-access-passphrase"),
            Some(&HeaderValue::from_static("passphrase")),
        );
    }

    #[test]
    fn test_okx_order_request_serializes_camel_case() {
        let body = serde_json::to_string(&OKXOrderRequest {
            inst_id: "BTC-USDT",
            td_mode: "cash",
            side: "buy",
            ord_type: "limit",
            sz: "1",
            px: Some("100000"),
            cl_ord_id: Some("client-1"),
        })
        .expect("order request should serialize");
        for fragment in [
            "\"instId\":\"BTC-USDT\"",
            "\"tdMode\":\"cash\"",
            "\"ordType\":\"limit\"",
            "\"clOrdId\":\"client-1\"",
        ] {
            assert!(body.contains(fragment), "missing {fragment} in {body}");
        }
    }

    #[tokio::test]
    async fn test_okx_response_wrapper_parses_success() {
        let body =
            r#"{"code":"0","msg":"","data":[{"ordId":"1","clOrdId":"c1","sCode":"0","sMsg":""}]}"#;
        let parsed = OKX::parse_response::<OKXOrderActionResponse>(body)
            .await
            .expect("success envelope should parse");
        match parsed.as_slice() {
            [only] => assert_eq!(only.ord_id, "1"),
            other => panic!("expected exactly one row, got {}", other.len()),
        }
    }

    #[tokio::test]
    async fn test_okx_response_wrapper_parses_error() {
        let body = r#"{"code":"51000","msg":"Parameter error","data":[{"sCode":"51000","sMsg":"size too small"}]}"#;
        let failure = match OKX::parse_response::<OKXOrderActionResponse>(body).await {
            Ok(rows) => panic!("expected an api error, got {:?}", rows),
            Err(err) => err,
        };
        assert_eq!(failure.to_string(), "okx api error 51000: size too small");
    }

    #[test]
    fn test_okx_order_detail_response_maps_to_order() {
        let mapped = Order::from(OKXOrderDetailResponse {
            inst_id: "BTC-USDT".to_string(),
            ord_id: "order-1".to_string(),
            cl_ord_id: Some("client-1".to_string()),
            side: "buy".to_string(),
            ord_type: "limit".to_string(),
            px: Some("50000".to_string()),
            sz: "2".to_string(),
            acc_fill_sz: Some("1".to_string()),
            state: "partially_filled".to_string(),
            c_time: "1700000000000".to_string(),
        });
        assert_eq!(mapped.id, "order-1");
        assert_eq!(mapped.symbol.as_deref(), Some("BTC-USDT"));
        assert_eq!(mapped.asset_class, Some(AssetClass::Crypto));
        assert_eq!(mapped.side, Some(OrderSide::Buy));
        assert_eq!(mapped.order_type, OrderType::Limit);
        assert_eq!(mapped.status, OrderStatus::PartiallyFilled);
        assert_eq!(mapped.qty.as_deref(), Some("2"));
        assert_eq!(mapped.filled_qty.as_deref(), Some("1"));
        assert_eq!(mapped.limit_price.as_deref(), Some("50000"));
        assert_eq!(mapped.created_at.as_deref(), Some("1700000000000"));
    }

    #[test]
    fn test_okx_default_constructor_uses_cash_trade_mode() {
        let exchange = <OKX as Exchange>::new(credentials("key", "secret", "pass"));
        assert_eq!(exchange.config.trade_mode, OKXTradeMode::Cash);
        assert!(!exchange.config.demo);
    }
}