use async_std::{
channel::{bounded, Receiver, Sender},
sync::RwLock,
task,
};
use chrono::NaiveDateTime;
use uuid::Uuid;
use crate::{
fixapi::FixApi,
messages::{
NewOrderSingleReq, OrderCancelReplaceReq, OrderCancelReq, OrderMassStatusReq, PositionsReq,
ResponseMessage, SecurityListReq,
},
parse_func::{self, parse_execution_report},
types::{
ConnectionHandler, Error, ExecutionReport, Field, OrderType, PositionReport, Side,
SymbolInformation, TradeDataHandler,
},
};
use std::{
collections::VecDeque,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
/// A queued value tagged with the instant after which it may be evicted, plus a flag
/// recording whether a reader has already taken it.
#[derive(Debug)]
struct TimeoutItem<T> {
    /// The wrapped value.
    item: T,
    /// Moment after which this entry is eligible for eviction from its queue.
    expiry: Instant,
    /// Becomes `true` once a consumer has taken this entry, so it is not handed out twice.
    consumed: AtomicBool,
}

impl<T> TimeoutItem<T> {
    /// Wraps `item`, marking it to expire `lifetime` from now and as not yet consumed.
    fn new(item: T, lifetime: Duration) -> Self {
        let expiry = Instant::now() + lifetime;
        Self {
            consumed: AtomicBool::default(),
            expiry,
            item,
        }
    }
}
/// Trade-session client built on [`FixApi`]: sends order, position and security requests
/// and waits (up to `timeout` milliseconds) for the matching responses.
pub struct TradeClient {
    /// Underlying FIX session.
    internal: FixApi,
    /// Per-request response timeout, in milliseconds.
    timeout: u64,
    /// Optional user callback for execution reports. It is cloned when `connect` registers
    /// the internal trade callback, so it must be set before connecting.
    trade_data_handler: Option<Arc<dyn TradeDataHandler + Send + Sync>>,
    /// Recently received trade messages, waiting to be picked up by a pending request.
    queue: Arc<RwLock<VecDeque<TimeoutItem<ResponseMessage>>>>,
    /// Wake-up sender, pinged after each incoming trade message is queued.
    signal: Sender<()>,
    /// Wake-up receiver that request methods wait on.
    receiver: Receiver<()>,
}
impl TradeClient {
pub fn new(
host: String,
login: String,
password: String,
sender_comp_id: String,
heartbeat_interval: Option<u32>,
) -> Self {
let (tx, rx) = bounded(1);
Self {
internal: FixApi::new(
crate::types::SubID::TRADE,
host,
login,
password,
sender_comp_id,
heartbeat_interval,
),
trade_data_handler: None,
queue: Arc::new(RwLock::new(VecDeque::new())),
signal: tx,
receiver: rx,
timeout: 5000, }
}
pub fn get_timeout(&self) -> u64 {
self.timeout
}
pub fn set_timeout(&mut self, timeout: u64) {
self.timeout = timeout;
}
pub fn register_trade_handler_arc<T: TradeDataHandler + Send + Sync + 'static>(
&mut self,
handler: Arc<T>,
) {
self.trade_data_handler = Some(handler);
}
pub fn register_trade_handler<T: TradeDataHandler + Send + Sync + 'static>(
&mut self,
handler: T,
) {
self.trade_data_handler = Some(Arc::new(handler));
}
pub fn register_connection_handler<T: ConnectionHandler + Send + Sync + 'static>(
&mut self,
handler: T,
) {
self.internal.register_connection_handler(handler);
}
pub fn register_connection_handler_arc<T: ConnectionHandler + Send + Sync + 'static>(
&mut self,
handler: Arc<T>,
) {
self.internal.register_connection_handler_arc(handler);
}
pub async fn connect(&mut self) -> Result<(), Error> {
self.register_internal_handler();
self.internal.connect().await?;
self.internal.logon(false).await
}
pub async fn disconnect(&mut self) -> Result<(), Error> {
self.internal.disconnect().await
}
pub fn is_connected(&self) -> bool {
self.internal.is_connected()
}
fn register_internal_handler(&mut self) {
let queue = self.queue.clone();
let handler = self.trade_data_handler.clone();
let signal = self.signal.clone();
let trade_callback = move |res: ResponseMessage| {
let signal = signal.clone();
let handler = handler.clone();
let queue = queue.clone();
let lifetime = Duration::from_millis(5000);
task::spawn(async move {
match res.get_message_type() {
"8" => {
if res
.get_field_value(Field::ExecType)
.map(|v| v.as_str() != "I")
.unwrap_or(true)
{
match parse_execution_report(res.clone()) {
Ok(report) => {
if let Some(handler) = handler {
handler.on_execution_report(report).await;
}
}
Err(_err) => {
}
}
}
}
_ => {}
}
queue
.write()
.await
.push_back(TimeoutItem::new(res, lifetime));
let now = Instant::now();
loop {
let expiry = queue.read().await.front().map(|v| v.expiry).unwrap_or(now);
if expiry < now {
queue.write().await.pop_front();
} else {
break;
}
}
signal.try_send(()).ok();
});
};
self.internal.register_trade_callback(trade_callback);
}
fn create_unique_id(&self) -> String {
Uuid::new_v4().to_string()
}
async fn wait_notifier(&self, receiver: Receiver<()>, dur: u64) -> Result<(), Error> {
if !self.is_connected() {
return Err(Error::NotConnected);
}
async_std::future::timeout(Duration::from_millis(dur), receiver.recv())
.await
.map_err(|_| Error::TimeoutError)?
.map_err(|e| e.into())
}
async fn fetch_response(
&self,
arg: Vec<(&str, Field, String)>,
) -> Result<ResponseMessage, Error> {
let now = Instant::now();
let mut remain = self.timeout;
loop {
let _ = self.wait_notifier(self.receiver.clone(), remain).await?;
let mut res = None;
let q = self.queue.read().await;
for v in q.iter().rev() {
let mut b = false;
let consumed = v.consumed.load(Ordering::Relaxed);
if consumed {
continue;
}
for (msg_type, field, value) in arg.iter() {
if v.item.matching_field_value(msg_type, *field, value) {
b = true;
res = Some(v.item.clone());
v.consumed.store(true, Ordering::Relaxed);
break;
}
}
if b {
break;
}
}
match res {
Some(res) => {
return Ok(res);
}
None => {
let past = (Instant::now() - now).as_millis() as u64;
if past < self.timeout {
remain = self.timeout - past;
if self.receiver.receiver_count() > 1 {
self.signal.try_send(()).ok();
}
continue;
} else {
return Err(Error::TimeoutError);
}
}
}
}
}
fn check_connection(&self) -> Result<(), Error> {
if self.is_connected() {
Ok(())
} else {
Err(Error::NotConnected)
}
}
pub async fn fetch_security_list(&self) -> Result<Vec<SymbolInformation>, Error> {
self.check_connection()?;
let security_req_id = self.create_unique_id();
let req = SecurityListReq::new(security_req_id.clone(), 0, None);
self.internal.send_message(req).await?;
match self
.fetch_response(vec![("y", Field::SecurityReqID, security_req_id)])
.await
{
Ok(res) => parse_func::parse_security_list(&res),
Err(err) => Err(err),
}
}
pub async fn fetch_positions(&self) -> Result<Vec<PositionReport>, Error> {
self.check_connection()?;
let pos_req_id = self.create_unique_id();
let req = PositionsReq::new(pos_req_id.clone(), None);
self.internal.send_message(req).await?;
let mut result = Vec::new();
loop {
match self
.fetch_response(vec![("AP", Field::PosReqID, pos_req_id.clone())])
.await
{
Ok(res) => {
if res.get_message_type() == "AP"
&& res
.get_field_value(Field::PosReqResult)
.map_or(false, |v| v.as_str() == "0")
{
let no_pos = res
.get_field_value(Field::TotalNumPosReports)
.unwrap_or("0".into())
.parse::<usize>()
.unwrap();
result.push(res);
if no_pos <= result.len() {
return parse_func::parse_positions(result);
} else {
continue;
}
} else {
return parse_func::parse_positions(vec![res]);
}
}
Err(err) => {
return Err(err);
}
}
}
}
pub async fn fetch_all_order_status(
&self,
issue_data: Option<NaiveDateTime>,
) -> Result<Vec<ExecutionReport>, Error> {
self.check_connection()?;
let mass_status_req_id = self.create_unique_id();
let req = OrderMassStatusReq::new(mass_status_req_id.clone(), 7, issue_data);
self.internal.send_message(req).await?;
let mut result = Vec::new();
loop {
match self
.fetch_response(vec![
("8", Field::MassStatusReqID, mass_status_req_id.clone()),
("j", Field::BusinessRejectRefID, mass_status_req_id.clone()),
])
.await
{
Ok(res) => {
return match res.get_message_type() {
"j" => Ok(Vec::new()),
"8" => {
let no_report = res
.get_field_value(Field::TotNumReports)
.unwrap_or("0".into())
.parse::<usize>()
.unwrap();
result.push(res);
if no_report <= result.len() {
parse_func::parse_order_mass_status(result)
} else {
continue;
}
}
_ => Err(Error::UnknownError),
};
}
Err(err) => {
return Err(err);
}
}
}
}
async fn new_order(&self, req: NewOrderSingleReq) -> Result<ExecutionReport, Error> {
self.check_connection()?;
let cl_ord_id = req.cl_ord_id.clone();
self.internal.send_message(req).await?;
match self
.fetch_response(vec![
("8", Field::ClOrdId, cl_ord_id.clone()),
("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
])
.await
{
Ok(res) => match res.get_message_type() {
"j" => Err(Error::OrderFailed(
res.get_field_value(Field::Text).unwrap_or("Unknown".into()),
)),
"8" => parse_func::parse_execution_report(res),
_ => Err(Error::UnknownError),
},
Err(err) => Err(err),
}
}
pub async fn new_market_order(
&self,
symbol: u32,
side: Side,
order_qty: f64,
cl_ord_id: Option<String>,
custom_ord_label: Option<String>,
) -> Result<ExecutionReport, Error> {
let req = NewOrderSingleReq::new(
cl_ord_id.unwrap_or(self.create_unique_id()),
symbol,
side,
None,
order_qty,
OrderType::Market,
None,
None,
None,
None,
custom_ord_label,
);
self.new_order(req).await
}
pub async fn new_limit_order(
&self,
symbol: u32,
side: Side,
price: f64,
order_qty: f64,
cl_ord_id: Option<String>,
expire_time: Option<NaiveDateTime>,
custom_ord_label: Option<String>,
) -> Result<ExecutionReport, Error> {
let req = NewOrderSingleReq::new(
cl_ord_id.unwrap_or(self.create_unique_id()),
symbol,
side,
None,
order_qty,
OrderType::Limit,
Some(price),
None,
expire_time,
None,
custom_ord_label,
);
self.new_order(req).await
}
pub async fn new_stop_order(
&self,
symbol: u32,
side: Side,
stop_px: f64,
order_qty: f64,
cl_ord_id: Option<String>,
expire_time: Option<NaiveDateTime>,
custom_ord_label: Option<String>,
) -> Result<ExecutionReport, Error> {
let req = NewOrderSingleReq::new(
cl_ord_id.unwrap_or(self.create_unique_id()),
symbol,
side,
None,
order_qty,
OrderType::Stop,
None,
Some(stop_px),
expire_time,
None,
custom_ord_label,
);
self.new_order(req).await
}
pub async fn close_position(
&self,
pos_report: PositionReport,
) -> Result<ExecutionReport, Error> {
self.adjust_position_size(
pos_report.position_id,
pos_report.symbol_id,
if pos_report.long_qty == 0.0 {
pos_report.short_qty
} else {
pos_report.long_qty
},
if pos_report.long_qty == 0.0 {
Side::BUY
} else {
Side::SELL
},
)
.await
}
pub async fn adjust_position_size(
&self,
pos_id: String,
symbol_id: u32,
lot: f64,
side: Side,
) -> Result<ExecutionReport, Error> {
let req = NewOrderSingleReq::new(
self.create_unique_id(),
symbol_id,
side,
None,
lot,
OrderType::Market,
None,
None,
None,
Some(pos_id),
None,
);
self.new_order(req).await
}
pub async fn replace_order(
&self,
org_cl_ord_id: Option<String>,
order_id: Option<String>,
order_qty: f64,
price: Option<f64>,
stop_px: Option<f64>,
expire_time: Option<NaiveDateTime>,
) -> Result<ExecutionReport, Error> {
if org_cl_ord_id.is_none() && order_id.is_none() {
return Err(Error::MissingArgumentError);
}
self.check_connection()?;
let orgid = match org_cl_ord_id.clone() {
Some(v) => v,
None => order_id.clone().unwrap(),
};
let oid = match order_id.clone() {
Some(v) => v,
None => org_cl_ord_id.clone().unwrap(),
};
let cl_ord_id = self.create_unique_id();
let req = OrderCancelReplaceReq::new(
orgid,
Some(oid),
cl_ord_id.clone(),
order_qty,
price,
stop_px,
expire_time,
);
self.internal.send_message(req).await?;
match self
.fetch_response(vec![
if org_cl_ord_id.is_some() {
("8", Field::ClOrdId, org_cl_ord_id.unwrap())
} else {
("8", Field::OrderID, order_id.unwrap())
},
("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
])
.await
{
Ok(res) => {
match res.get_message_type() {
"j" => {
Err(Error::OrderFailed(
res.get_field_value(Field::Text)
.unwrap_or("Unknown error".into()),
)
.into())
}
_ => {
parse_func::parse_execution_report(res)
}
}
}
Err(err) => Err(err),
}
}
pub async fn cancel_order(
&self,
org_cl_ord_id: Option<String>,
order_id: Option<String>,
) -> Result<ExecutionReport, Error> {
if org_cl_ord_id.is_none() && order_id.is_none() {
return Err(Error::MissingArgumentError);
}
self.check_connection()?;
let orgid = match org_cl_ord_id.clone() {
Some(v) => v,
None => order_id.clone().unwrap(),
};
let oid = match order_id {
Some(v) => v,
None => org_cl_ord_id.unwrap(),
};
let cl_ord_id = self.create_unique_id();
let req = OrderCancelReq::new(orgid, Some(oid), cl_ord_id.clone());
self.internal.send_message(req).await?;
match self
.fetch_response(vec![
("8", Field::ClOrdId, cl_ord_id.clone()),
("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
("9", Field::ClOrdId, cl_ord_id.clone()),
])
.await
{
Ok(res) => {
match res.get_message_type() {
"j" => {
Err(Error::OrderFailed(
res.get_field_value(Field::Text)
.unwrap_or("Unknown error".into()),
)
.into())
}
"9" => {
Err(Error::OrderCancelRejected(
res.get_field_value(Field::Text)
.unwrap_or("Unknown error".into()),
)
.into())
}
_ => {
parse_func::parse_execution_report(res)
}
}
}
Err(err) => Err(err),
}
}
}