1use async_std::{
2 channel::{bounded, Receiver, Sender},
3 sync::RwLock,
4 task,
5};
6use chrono::NaiveDateTime;
7use uuid::Uuid;
8
9use crate::{
10 fixapi::FixApi,
11 messages::{
12 NewOrderSingleReq, OrderCancelReplaceReq, OrderCancelReq, OrderMassStatusReq, PositionsReq,
13 ResponseMessage, SecurityListReq,
14 },
15 parse_func::{self, parse_execution_report},
16 types::{
17 ConnectionHandler, Error, ExecutionReport, Field, OrderType, PositionReport, Side,
18 SymbolInformation, TradeDataHandler,
19 },
20};
21
22use std::{
23 collections::VecDeque,
24 sync::{
25 atomic::{AtomicBool, Ordering},
26 Arc,
27 },
28 time::{Duration, Instant},
29};
30
/// A queued value tagged with the instant after which it may be evicted and a
/// flag recording whether a consumer has already taken it.
#[derive(Debug)]
struct TimeoutItem<T> {
    /// The wrapped value.
    item: T,
    /// Deadline computed at construction time (`now + lifetime`).
    expiry: Instant,
    /// `false` until some consumer claims this entry.
    consumed: AtomicBool,
}

impl<T> TimeoutItem<T> {
    /// Wraps `item`, stamping it to expire `lifetime` from now, initially unconsumed.
    fn new(item: T, lifetime: Duration) -> Self {
        let expiry = Instant::now() + lifetime;
        let consumed = AtomicBool::new(false);
        Self {
            item,
            expiry,
            consumed,
        }
    }
}
47
/// Client for the FIX trade session (`SubID::TRADE`).
///
/// Incoming trade-session messages are appended to `queue` by a callback that
/// `connect` installs; request methods then wait on the notification channel
/// and scan the queue for the message that matches their request id.
pub struct TradeClient {
    /// Underlying FIX session.
    internal: FixApi,

    /// Optional user callback for execution reports. It is cloned into the
    /// message callback when `connect` runs, so register it before connecting.
    trade_data_handler: Option<Arc<dyn TradeDataHandler + Send + Sync>>,

    /// Recently received messages, oldest at the front; each carries an
    /// expiry instant and a "consumed" flag.
    queue: Arc<RwLock<VecDeque<TimeoutItem<ResponseMessage>>>>,

    /// Sending half of the "a message arrived" notification channel
    /// (created with capacity 1 in `new`).
    signal: Sender<()>,
    /// Receiving half of the notification channel; each waiter uses a clone.
    receiver: Receiver<()>,

    /// Response timeout in milliseconds for the request methods.
    timeout: u64,
}
61
62impl TradeClient {
63 pub fn new(
64 host: String,
65 login: String,
66 password: String,
67 sender_comp_id: String,
68 heartbeat_interval: Option<u32>,
69 ) -> Self {
70 let (tx, rx) = bounded(1);
71 Self {
72 internal: FixApi::new(
73 crate::types::SubID::TRADE,
74 host,
75 login,
76 password,
77 sender_comp_id,
78 heartbeat_interval,
79 ),
80 trade_data_handler: None,
81 queue: Arc::new(RwLock::new(VecDeque::new())),
82
83 signal: tx,
84 receiver: rx,
85
86 timeout: 5000, }
88 }
89
90 pub fn get_timeout(&self) -> u64 {
91 self.timeout
92 }
93
94 pub fn set_timeout(&mut self, timeout: u64) {
95 self.timeout = timeout;
96 }
97
98 pub fn register_trade_handler_arc<T: TradeDataHandler + Send + Sync + 'static>(
99 &mut self,
100 handler: Arc<T>,
101 ) {
102 self.trade_data_handler = Some(handler);
103 }
104
105 pub fn register_trade_handler<T: TradeDataHandler + Send + Sync + 'static>(
106 &mut self,
107 handler: T,
108 ) {
109 self.trade_data_handler = Some(Arc::new(handler));
110 }
111
112 pub fn register_connection_handler<T: ConnectionHandler + Send + Sync + 'static>(
113 &mut self,
114 handler: T,
115 ) {
116 self.internal.register_connection_handler(handler);
117 }
118
119 pub fn register_connection_handler_arc<T: ConnectionHandler + Send + Sync + 'static>(
120 &mut self,
121 handler: Arc<T>,
122 ) {
123 self.internal.register_connection_handler_arc(handler);
124 }
125
126 pub async fn connect(&mut self) -> Result<(), Error> {
127 self.register_internal_handler();
128 self.internal.connect().await?;
129 self.internal.logon(false).await
130 }
131
132 pub async fn disconnect(&mut self) -> Result<(), Error> {
133 self.internal.disconnect().await
134 }
135
136 pub fn is_connected(&self) -> bool {
137 self.internal.is_connected()
138 }
139
140 fn register_internal_handler(&mut self) {
141 let queue = self.queue.clone();
142 let handler = self.trade_data_handler.clone();
143 let signal = self.signal.clone();
144 let trade_callback = move |res: ResponseMessage| {
145 let signal = signal.clone();
146 let handler = handler.clone();
147 let queue = queue.clone();
148 let lifetime = Duration::from_millis(5000);
149 task::spawn(async move {
150 match res.get_message_type() {
151 "8" => {
152 if res
153 .get_field_value(Field::ExecType)
154 .map(|v| v.as_str() != "I")
155 .unwrap_or(true)
156 {
157 match parse_execution_report(res.clone()) {
158 Ok(report) => {
159 if let Some(handler) = handler {
160 handler.on_execution_report(report).await;
161 }
162 }
163 Err(_err) => {
164 }
166 }
167 }
168 }
169 _ => {}
170 }
171
172 queue
173 .write()
174 .await
175 .push_back(TimeoutItem::new(res, lifetime));
176
177 let now = Instant::now();
179 loop {
180 let expiry = queue.read().await.front().map(|v| v.expiry).unwrap_or(now);
181 if expiry < now {
182 queue.write().await.pop_front();
184 } else {
185 break;
186 }
187 }
188
189 signal.try_send(()).ok();
190 });
192 };
193
194 self.internal.register_trade_callback(trade_callback);
195 }
196
197 fn create_unique_id(&self) -> String {
198 Uuid::new_v4().to_string()
199 }
200
201 async fn wait_notifier(&self, receiver: Receiver<()>, dur: u64) -> Result<(), Error> {
202 if !self.is_connected() {
203 return Err(Error::NotConnected);
204 }
205 async_std::future::timeout(Duration::from_millis(dur), receiver.recv())
206 .await
207 .map_err(|_| Error::TimeoutError)?
208 .map_err(|e| e.into())
209 }
210
211 async fn fetch_response(
212 &self,
213 arg: Vec<(&str, Field, String)>,
214 ) -> Result<ResponseMessage, Error> {
215 let now = Instant::now();
217 let mut remain = self.timeout;
218
219 loop {
220 let _ = self.wait_notifier(self.receiver.clone(), remain).await?;
221 let mut res = None;
223 let q = self.queue.read().await;
224 for v in q.iter().rev() {
225 let mut b = false;
226 let consumed = v.consumed.load(Ordering::Relaxed);
227 if consumed {
228 continue;
229 }
230
231 for (msg_type, field, value) in arg.iter() {
232 if v.item.matching_field_value(msg_type, *field, value) {
233 b = true;
234 res = Some(v.item.clone());
235 v.consumed.store(true, Ordering::Relaxed);
236 break;
237 }
238 }
239 if b {
240 break;
241 }
242 }
243
244 match res {
245 Some(res) => {
246 return Ok(res);
247 }
248 None => {
249 let past = (Instant::now() - now).as_millis() as u64;
251 if past < self.timeout {
252 remain = self.timeout - past;
254
255 if self.receiver.receiver_count() > 1 {
257 self.signal.try_send(()).ok();
258 }
259 continue;
260 } else {
261 return Err(Error::TimeoutError);
262 }
263 }
264 }
265 }
266 }
267
268 fn check_connection(&self) -> Result<(), Error> {
269 if self.is_connected() {
270 Ok(())
271 } else {
272 Err(Error::NotConnected)
273 }
274 }
275
276 pub async fn fetch_security_list(&self) -> Result<Vec<SymbolInformation>, Error> {
283 self.check_connection()?;
284 let security_req_id = self.create_unique_id();
285 let req = SecurityListReq::new(security_req_id.clone(), 0, None);
286 self.internal.send_message(req).await?;
287 match self
288 .fetch_response(vec![("y", Field::SecurityReqID, security_req_id)])
289 .await
290 {
291 Ok(res) => parse_func::parse_security_list(&res),
292 Err(err) => Err(err),
293 }
294 }
295
296 pub async fn fetch_positions(&self) -> Result<Vec<PositionReport>, Error> {
297 self.check_connection()?;
298 let pos_req_id = self.create_unique_id();
299 let req = PositionsReq::new(pos_req_id.clone(), None);
300 self.internal.send_message(req).await?;
301
302 let mut result = Vec::new();
303
304 loop {
305 match self
306 .fetch_response(vec![("AP", Field::PosReqID, pos_req_id.clone())])
307 .await
308 {
309 Ok(res) => {
310 if res.get_message_type() == "AP"
311 && res
312 .get_field_value(Field::PosReqResult)
313 .map_or(false, |v| v.as_str() == "0")
314 {
315 let no_pos = res
316 .get_field_value(Field::TotalNumPosReports)
317 .unwrap_or("0".into())
318 .parse::<usize>()
319 .unwrap();
320 result.push(res);
321 if no_pos <= result.len() {
322 return parse_func::parse_positions(result);
323 } else {
324 continue;
325 }
326 } else {
327 return parse_func::parse_positions(vec![res]);
328 }
329 }
330 Err(err) => {
331 return Err(err);
332 }
333 }
334 }
335 }
336
337 pub async fn fetch_all_order_status(
338 &self,
339 issue_data: Option<NaiveDateTime>,
340 ) -> Result<Vec<ExecutionReport>, Error> {
341 self.check_connection()?;
342 let mass_status_req_id = self.create_unique_id();
343 let req = OrderMassStatusReq::new(mass_status_req_id.clone(), 7, issue_data);
345 self.internal.send_message(req).await?;
346
347 let mut result = Vec::new();
348
349 loop {
350 match self
351 .fetch_response(vec![
352 ("8", Field::MassStatusReqID, mass_status_req_id.clone()),
353 ("j", Field::BusinessRejectRefID, mass_status_req_id.clone()),
354 ])
355 .await
356 {
357 Ok(res) => {
358 return match res.get_message_type() {
359 "j" => Ok(Vec::new()),
360 "8" => {
361 let no_report = res
362 .get_field_value(Field::TotNumReports)
363 .unwrap_or("0".into())
364 .parse::<usize>()
365 .unwrap();
366
367 result.push(res);
368
369 if no_report <= result.len() {
370 parse_func::parse_order_mass_status(result)
371 } else {
372 continue;
373 }
374 }
375 _ => Err(Error::UnknownError),
376 };
377 }
378 Err(err) => {
379 return Err(err);
380 }
381 }
382 }
383 }
384
385 async fn new_order(&self, req: NewOrderSingleReq) -> Result<ExecutionReport, Error> {
386 self.check_connection()?;
387 let cl_ord_id = req.cl_ord_id.clone();
388
389 self.internal.send_message(req).await?;
390 match self
391 .fetch_response(vec![
392 ("8", Field::ClOrdId, cl_ord_id.clone()),
393 ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
394 ])
395 .await
396 {
397 Ok(res) => match res.get_message_type() {
398 "j" => Err(Error::OrderFailed(
399 res.get_field_value(Field::Text).unwrap_or("Unknown".into()),
400 )),
401 "8" => parse_func::parse_execution_report(res),
402 _ => Err(Error::UnknownError),
403 },
404 Err(err) => Err(err),
405 }
406 }
407
408 pub async fn new_market_order(
409 &self,
410 symbol: u32,
411 side: Side,
412 order_qty: f64,
413 cl_ord_id: Option<String>,
414 custom_ord_label: Option<String>,
415 ) -> Result<ExecutionReport, Error> {
416 let req = NewOrderSingleReq::new(
417 cl_ord_id.unwrap_or(self.create_unique_id()),
418 symbol,
419 side,
420 None,
421 order_qty,
422 OrderType::Market,
423 None,
424 None,
425 None,
426 None,
427 custom_ord_label,
428 );
429 self.new_order(req).await
430 }
431
432 pub async fn new_limit_order(
433 &self,
434 symbol: u32,
435 side: Side,
436 price: f64,
437 order_qty: f64,
438 cl_ord_id: Option<String>,
439 expire_time: Option<NaiveDateTime>,
440 custom_ord_label: Option<String>,
441 ) -> Result<ExecutionReport, Error> {
442 let req = NewOrderSingleReq::new(
443 cl_ord_id.unwrap_or(self.create_unique_id()),
444 symbol,
445 side,
446 None,
447 order_qty,
448 OrderType::Limit,
449 Some(price),
450 None,
451 expire_time,
452 None,
453 custom_ord_label,
454 );
455
456 self.new_order(req).await
457 }
458
459 pub async fn new_stop_order(
460 &self,
461 symbol: u32,
462 side: Side,
463 stop_px: f64,
464 order_qty: f64,
465 cl_ord_id: Option<String>,
466 expire_time: Option<NaiveDateTime>,
467 custom_ord_label: Option<String>,
468 ) -> Result<ExecutionReport, Error> {
469 let req = NewOrderSingleReq::new(
470 cl_ord_id.unwrap_or(self.create_unique_id()),
471 symbol,
472 side,
473 None,
474 order_qty,
475 OrderType::Stop,
476 None,
477 Some(stop_px),
478 expire_time,
479 None,
480 custom_ord_label,
481 );
482
483 self.new_order(req).await
484 }
485 pub async fn close_position(
486 &self,
487 pos_report: PositionReport,
488 custom_ord_label: Option<String>,
489 ) -> Result<ExecutionReport, Error> {
490 self.adjust_position_size(
491 pos_report.position_id,
492 pos_report.symbol_id,
493 if pos_report.long_qty == 0.0 {
494 pos_report.short_qty
495 } else {
496 pos_report.long_qty
497 },
498 if pos_report.long_qty == 0.0 {
499 Side::BUY
500 } else {
501 Side::SELL
502 },
503 custom_ord_label,
504 )
505 .await
506 }
507
508 pub async fn adjust_position_size(
515 &self,
516 pos_id: String,
517 symbol_id: u32,
518 lot: f64,
519 side: Side,
520 custom_ord_label: Option<String>,
521 ) -> Result<ExecutionReport, Error> {
522 let req = NewOrderSingleReq::new(
523 self.create_unique_id(),
524 symbol_id,
525 side,
526 None,
527 lot,
528 OrderType::Market,
529 None,
530 None,
531 None,
532 Some(pos_id),
533 custom_ord_label,
534 );
535
536 self.new_order(req).await
537 }
538
539 pub async fn replace_order(
549 &self,
550 org_cl_ord_id: Option<String>,
551 order_id: Option<String>,
552 order_qty: f64,
553 price: Option<f64>,
554 stop_px: Option<f64>,
555 expire_time: Option<NaiveDateTime>,
556 ) -> Result<ExecutionReport, Error> {
557 if org_cl_ord_id.is_none() && order_id.is_none() {
558 return Err(Error::MissingArgumentError);
559 }
560 self.check_connection()?;
561 let orgid = match org_cl_ord_id.clone() {
562 Some(v) => v,
563 None => order_id.clone().unwrap(),
564 };
565 let oid = match order_id.clone() {
566 Some(v) => v,
567 None => org_cl_ord_id.clone().unwrap(),
568 };
569 let cl_ord_id = self.create_unique_id();
570 let req = OrderCancelReplaceReq::new(
571 orgid,
572 Some(oid),
573 cl_ord_id.clone(),
574 order_qty,
575 price,
576 stop_px,
577 expire_time,
578 );
579 self.internal.send_message(req).await?;
580 match self
581 .fetch_response(vec![
582 if org_cl_ord_id.is_some() {
583 ("8", Field::ClOrdId, org_cl_ord_id.unwrap())
584 } else {
585 ("8", Field::OrderID, order_id.unwrap())
586 },
587 ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
588 ])
589 .await
590 {
591 Ok(res) => {
592 match res.get_message_type() {
593 "j" => {
594 Err(Error::OrderFailed(
596 res.get_field_value(Field::Text)
597 .unwrap_or("Unknown error".into()),
598 )
599 .into())
600 }
601 _ => {
602 parse_func::parse_execution_report(res)
604 }
605 }
606 }
607 Err(err) => Err(err),
608 }
609 }
610
611 pub async fn cancel_order(
620 &self,
621 org_cl_ord_id: Option<String>,
622 order_id: Option<String>,
623 ) -> Result<ExecutionReport, Error> {
624 if org_cl_ord_id.is_none() && order_id.is_none() {
625 return Err(Error::MissingArgumentError);
626 }
627 self.check_connection()?;
628
629 let orgid = match org_cl_ord_id.clone() {
630 Some(v) => v,
631 None => order_id.clone().unwrap(),
632 };
633 let oid = match order_id {
634 Some(v) => v,
635 None => org_cl_ord_id.unwrap(),
636 };
637
638 let cl_ord_id = self.create_unique_id();
639 let req = OrderCancelReq::new(orgid, Some(oid), cl_ord_id.clone());
640 self.internal.send_message(req).await?;
641 match self
642 .fetch_response(vec![
643 ("8", Field::ClOrdId, cl_ord_id.clone()),
644 ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
645 ("9", Field::ClOrdId, cl_ord_id.clone()),
646 ])
647 .await
648 {
649 Ok(res) => {
650 match res.get_message_type() {
651 "j" => {
652 Err(Error::OrderFailed(
654 res.get_field_value(Field::Text)
655 .unwrap_or("Unknown error".into()),
656 )
657 .into())
658 }
659 "9" => {
660 Err(Error::OrderCancelRejected(
662 res.get_field_value(Field::Text)
663 .unwrap_or("Unknown error".into()),
664 )
665 .into())
666 }
667 _ => {
668 parse_func::parse_execution_report(res)
670 }
671 }
672 }
673 Err(err) => Err(err),
674 }
675 }
676}