use crate::protocol::{self, order_gateway::*};
use crate::types::PlaceOrder;
use crate::OrderId;
use anyhow::{anyhow, bail, Result};
use futures::{SinkExt, StreamExt};
use log::{debug, error, info, trace, warn};
use std::collections::HashMap;
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async,
tungstenite::{handshake::client::generate_key, http::Request, Message},
MaybeTlsStream, WebSocketStream,
};
use url::Url;
pub type SendCallback = Box<dyn Fn(&str) + Send + Sync>;
pub type ReceiveCallback = Box<dyn Fn(&str) + Send + Sync>;
pub struct OrderGatewayWsClient {
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
next_request_id: i32,
in_flight_requests: HashMap<i32, OrderGatewayRequestType>,
on_send: Option<SendCallback>,
on_receive: Option<ReceiveCallback>,
}
impl OrderGatewayWsClient {
pub async fn connect(base_url: Url, token: impl AsRef<str>) -> Result<Self> {
Self::connect_inner(base_url, token, None).await
}
pub async fn connect_with_cancel_on_disconnect(
base_url: Url,
token: impl AsRef<str>,
) -> Result<Self> {
Self::connect_inner(base_url, token, Some("cancel_on_disconnect=true")).await
}
async fn connect_inner(
base_url: Url,
token: impl AsRef<str>,
query: Option<&str>,
) -> Result<Self> {
let mut ws_base_url = base_url.clone();
let res = match base_url.scheme() {
"http" => ws_base_url.set_scheme("ws"),
"https" => ws_base_url.set_scheme("wss"),
_ => bail!("invalid url scheme"),
};
res.map_err(|_| anyhow!("invalid url scheme"))?;
let mut order_gateway_url = ws_base_url.join("orders/ws")?;
if let Some(q) = query {
order_gateway_url.set_query(Some(q));
}
info!("connecting to {order_gateway_url}");
let authority = order_gateway_url.authority();
let host = authority
.find('@')
.map(|idx| authority.split_at(idx + 1).1)
.unwrap_or_else(|| authority);
let request = Request::builder()
.method("GET")
.uri(order_gateway_url.as_str())
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key())
.header("Authorization", token.as_ref())
.body(())?;
let (ws, _) = connect_async(request).await?;
Ok(Self {
ws,
next_request_id: 1,
in_flight_requests: HashMap::new(),
on_send: None,
on_receive: None,
})
}
pub fn on_send<F>(&mut self, callback: F)
where
F: Fn(&str) + Send + Sync + 'static,
{
self.on_send = Some(Box::new(callback));
}
pub fn on_receive<F>(&mut self, callback: F)
where
F: Fn(&str) + Send + Sync + 'static,
{
self.on_receive = Some(Box::new(callback));
}
pub async fn next(&mut self) -> Result<OrderGatewayMessage> {
loop {
let msg = self
.ws
.next()
.await
.ok_or_else(|| anyhow!("ws stream ended"))??;
match msg {
Message::Text(text) => {
if let Some(ref callback) = self.on_receive {
callback(&text);
}
trace!("decoding order gateway message: {text}");
match serde_json::from_str::<OrderGatewayEvent>(&text) {
Ok(e) => {
self.handle_event(&e);
return Ok(OrderGatewayMessage::Event(e));
}
Err(e_as_event) => {
match serde_json::from_str::<
protocol::ws::Response<Box<serde_json::value::RawValue>>,
>(&text)
{
Ok(r) => match self.handle_response(r) {
Ok(Some(res)) => return Ok(OrderGatewayMessage::Response(res)),
Ok(None) => continue,
Err(e_res) => {
error!("handling response: {e_res:?}");
}
},
Err(e_as_response) => {
error!(
"decoding order gateway message as event: {e_as_event:?}"
);
error!("decoding order gateway message as response: {e_as_response:?}");
continue;
}
}
}
}
}
Message::Ping(..) => {
trace!("ws ping received");
}
Message::Binary(..)
| Message::Frame(..)
| Message::Pong(..)
| Message::Close(..) => {}
}
}
}
fn handle_response(
&mut self,
res: protocol::ws::Response<Box<serde_json::value::RawValue>>,
) -> Result<Option<protocol::ws::Response<OrderGatewayResponse>>> {
macro_rules! try_parse {
($res:expr, $type:ty, $v:path) => {
$res.response
.map(|r| serde_json::from_str::<$type>(r.get()))
.transpose()?
.map(|r| $v(r))
};
}
let Some(request_id) = res.request_id else {
if let Some(err) = res.error {
warn!("received error with unknown request_id: {}", err);
}
return Ok(None);
};
let parsed = if let Some(req_type) = self.in_flight_requests.remove(&request_id) {
match req_type {
OrderGatewayRequestType::PlaceOrder => {
try_parse!(
res,
PlaceOrderResponse,
OrderGatewayResponse::PlaceOrderResponse
)
}
OrderGatewayRequestType::CancelOrder => {
try_parse!(
res,
CancelOrderResponse,
OrderGatewayResponse::CancelOrderResponse
)
}
OrderGatewayRequestType::ReplaceOrder => {
try_parse!(
res,
ReplaceOrderResponse,
OrderGatewayResponse::ReplaceOrderResponse
)
}
OrderGatewayRequestType::CancelAllOrders => {
try_parse!(
res,
CancelAllOrdersResponse,
OrderGatewayResponse::CancelAllOrdersResponse
)
}
OrderGatewayRequestType::GetOpenOrders => {
try_parse!(
res,
GetOpenOrdersResponse,
OrderGatewayResponse::GetOpenOrdersResponse
)
}
}
} else {
warn!("response to unknown request: {}", request_id);
return Ok(None);
};
Ok(Some(protocol::ws::Response {
request_id: Some(request_id),
response: parsed,
error: res.error,
data: None,
}))
}
fn handle_event(&mut self, e: &protocol::order_gateway::OrderGatewayEvent) {
trace!("order gateway event: {e:?}");
if let OrderGatewayEvent::Heartbeat(t) = e {
debug!("heartbeat: {:?}", t.as_datetime());
}
}
pub async fn get_open_orders(&mut self) -> Result<()> {
let request_id = self.next_request_id;
self.next_request_id += 1;
let req = protocol::order_gateway::OrderGatewayRequest::GetOpenOrders(
protocol::order_gateway::GetOpenOrdersRequest {},
);
let wrapped_req = protocol::ws::Request {
request_id,
request: req,
};
let payload = serde_json::to_string(&wrapped_req)?;
if let Some(ref callback) = self.on_send {
callback(&payload);
}
trace!("sending get open orders request: {payload}");
self.ws.send(Message::Text(payload.into())).await?;
self.in_flight_requests
.insert(request_id, OrderGatewayRequestType::GetOpenOrders);
Ok(())
}
pub async fn place_order(&mut self, place_order: PlaceOrder) -> Result<i32> {
let request_id = self.next_request_id;
self.next_request_id += 1;
let req = protocol::order_gateway::OrderGatewayRequest::PlaceOrder(place_order.into());
let wrapped_req = protocol::ws::Request {
request_id,
request: req,
};
let payload = serde_json::to_string(&wrapped_req)?;
if let Some(ref callback) = self.on_send {
callback(&payload);
}
trace!("sending place order request: {payload}");
self.ws.send(Message::Text(payload.into())).await?;
self.in_flight_requests
.insert(request_id, OrderGatewayRequestType::PlaceOrder);
Ok(request_id)
}
pub async fn cancel_all_orders(&mut self, symbol: Option<&str>) -> Result<i32> {
let request_id = self.next_request_id;
self.next_request_id += 1;
let req = protocol::order_gateway::OrderGatewayRequest::CancelAllOrders(
protocol::order_gateway::CancelAllOrdersRequest {
symbol: symbol.map(|s| s.to_string()),
},
);
let wrapped_req = protocol::ws::Request {
request_id,
request: req,
};
let payload = serde_json::to_string(&wrapped_req)?;
if let Some(ref callback) = self.on_send {
callback(&payload);
}
trace!("sending cancel all orders request: {payload}");
self.ws.send(Message::Text(payload.into())).await?;
self.in_flight_requests
.insert(request_id, OrderGatewayRequestType::CancelAllOrders);
Ok(request_id)
}
pub async fn cancel_order(&mut self, order_id: &OrderId) -> Result<i32> {
let request_id = self.next_request_id;
self.next_request_id += 1;
let req = protocol::order_gateway::OrderGatewayRequest::CancelOrder(
protocol::order_gateway::CancelOrderRequest {
order_id: order_id.clone(),
},
);
let wrapped_req = protocol::ws::Request {
request_id,
request: req,
};
let payload = serde_json::to_string(&wrapped_req)?;
if let Some(ref callback) = self.on_send {
callback(&payload);
}
trace!("sending cancel order request: {payload}");
self.ws.send(Message::Text(payload.into())).await?;
self.in_flight_requests
.insert(request_id, OrderGatewayRequestType::CancelOrder);
Ok(request_id)
}
pub async fn replace_order(
&mut self,
req: protocol::order_gateway::ReplaceOrderRequest,
) -> Result<i32> {
let request_id = self.next_request_id;
self.next_request_id += 1;
let req = protocol::order_gateway::OrderGatewayRequest::ReplaceOrder(req);
let wrapped_req = protocol::ws::Request {
request_id,
request: req,
};
let payload = serde_json::to_string(&wrapped_req)?;
if let Some(ref callback) = self.on_send {
callback(&payload);
}
trace!("sending replace order request: {payload}");
self.ws.send(Message::Text(payload.into())).await?;
self.in_flight_requests
.insert(request_id, OrderGatewayRequestType::ReplaceOrder);
Ok(request_id)
}
}