// cfix/trade_client.rs

1use async_std::{
2    channel::{bounded, Receiver, Sender},
3    sync::RwLock,
4    task,
5};
6use chrono::NaiveDateTime;
7use uuid::Uuid;
8
9use crate::{
10    fixapi::FixApi,
11    messages::{
12        NewOrderSingleReq, OrderCancelReplaceReq, OrderCancelReq, OrderMassStatusReq, PositionsReq,
13        ResponseMessage, SecurityListReq,
14    },
15    parse_func::{self, parse_execution_report},
16    types::{
17        ConnectionHandler, Error, ExecutionReport, Field, OrderType, PositionReport, Side,
18        SymbolInformation, TradeDataHandler,
19    },
20};
21
22use std::{
23    collections::VecDeque,
24    sync::{
25        atomic::{AtomicBool, Ordering},
26        Arc,
27    },
28    time::{Duration, Instant},
29};
30
/// A queued value tagged with an expiry instant and a "consumed" flag.
#[derive(Debug)]
struct TimeoutItem<T> {
    /// The wrapped value.
    item: T,
    /// Instant after which the entry counts as expired.
    expiry: Instant,
    /// Set once a reader has claimed this entry.
    consumed: AtomicBool,
}

impl<T> TimeoutItem<T> {
    /// Wraps `item` so that it expires `lifetime` from now; the entry starts unconsumed.
    fn new(item: T, lifetime: Duration) -> Self {
        let expiry = Instant::now() + lifetime;
        let consumed = AtomicBool::new(false);
        Self {
            item,
            expiry,
            consumed,
        }
    }
}
47
48pub struct TradeClient {
49    internal: FixApi,
50
51    trade_data_handler: Option<Arc<dyn TradeDataHandler + Send + Sync>>,
52
53    queue: Arc<RwLock<VecDeque<TimeoutItem<ResponseMessage>>>>,
54
55    signal: Sender<()>,
56    receiver: Receiver<()>,
57
58    // for waiting response in fetch methods.
59    timeout: u64,
60}
61
62impl TradeClient {
63    pub fn new(
64        host: String,
65        login: String,
66        password: String,
67        sender_comp_id: String,
68        heartbeat_interval: Option<u32>,
69    ) -> Self {
70        let (tx, rx) = bounded(1);
71        Self {
72            internal: FixApi::new(
73                crate::types::SubID::TRADE,
74                host,
75                login,
76                password,
77                sender_comp_id,
78                heartbeat_interval,
79            ),
80            trade_data_handler: None,
81            queue: Arc::new(RwLock::new(VecDeque::new())),
82
83            signal: tx,
84            receiver: rx,
85
86            timeout: 5000, //
87        }
88    }
89
90    pub fn get_timeout(&self) -> u64 {
91        self.timeout
92    }
93
94    pub fn set_timeout(&mut self, timeout: u64) {
95        self.timeout = timeout;
96    }
97
98    pub fn register_trade_handler_arc<T: TradeDataHandler + Send + Sync + 'static>(
99        &mut self,
100        handler: Arc<T>,
101    ) {
102        self.trade_data_handler = Some(handler);
103    }
104
105    pub fn register_trade_handler<T: TradeDataHandler + Send + Sync + 'static>(
106        &mut self,
107        handler: T,
108    ) {
109        self.trade_data_handler = Some(Arc::new(handler));
110    }
111
112    pub fn register_connection_handler<T: ConnectionHandler + Send + Sync + 'static>(
113        &mut self,
114        handler: T,
115    ) {
116        self.internal.register_connection_handler(handler);
117    }
118
119    pub fn register_connection_handler_arc<T: ConnectionHandler + Send + Sync + 'static>(
120        &mut self,
121        handler: Arc<T>,
122    ) {
123        self.internal.register_connection_handler_arc(handler);
124    }
125
126    pub async fn connect(&mut self) -> Result<(), Error> {
127        self.register_internal_handler();
128        self.internal.connect().await?;
129        self.internal.logon(false).await
130    }
131
132    pub async fn disconnect(&mut self) -> Result<(), Error> {
133        self.internal.disconnect().await
134    }
135
136    pub fn is_connected(&self) -> bool {
137        self.internal.is_connected()
138    }
139
140    fn register_internal_handler(&mut self) {
141        let queue = self.queue.clone();
142        let handler = self.trade_data_handler.clone();
143        let signal = self.signal.clone();
144        let trade_callback = move |res: ResponseMessage| {
145            let signal = signal.clone();
146            let handler = handler.clone();
147            let queue = queue.clone();
148            let lifetime = Duration::from_millis(5000);
149            task::spawn(async move {
150                match res.get_message_type() {
151                    "8" => {
152                        if res
153                            .get_field_value(Field::ExecType)
154                            .map(|v| v.as_str() != "I")
155                            .unwrap_or(true)
156                        {
157                            match parse_execution_report(res.clone()) {
158                                Ok(report) => {
159                                    if let Some(handler) = handler {
160                                        handler.on_execution_report(report).await;
161                                    }
162                                }
163                                Err(_err) => {
164                                    // IGNORE
165                                }
166                            }
167                        }
168                    }
169                    _ => {}
170                }
171
172                queue
173                    .write()
174                    .await
175                    .push_back(TimeoutItem::new(res, lifetime));
176
177                // check timeout
178                let now = Instant::now();
179                loop {
180                    let expiry = queue.read().await.front().map(|v| v.expiry).unwrap_or(now);
181                    if expiry < now {
182                        // pop old item
183                        queue.write().await.pop_front();
184                    } else {
185                        break;
186                    }
187                }
188
189                signal.try_send(()).ok();
190                // signal.send(()).await.ok();
191            });
192        };
193
194        self.internal.register_trade_callback(trade_callback);
195    }
196
197    fn create_unique_id(&self) -> String {
198        Uuid::new_v4().to_string()
199    }
200
201    async fn wait_notifier(&self, receiver: Receiver<()>, dur: u64) -> Result<(), Error> {
202        if !self.is_connected() {
203            return Err(Error::NotConnected);
204        }
205        async_std::future::timeout(Duration::from_millis(dur), receiver.recv())
206            .await
207            .map_err(|_| Error::TimeoutError)?
208            .map_err(|e| e.into())
209    }
210
211    async fn fetch_response(
212        &self,
213        arg: Vec<(&str, Field, String)>,
214    ) -> Result<ResponseMessage, Error> {
215        // setup for timeout
216        let now = Instant::now();
217        let mut remain = self.timeout;
218
219        loop {
220            let _ = self.wait_notifier(self.receiver.clone(), remain).await?;
221            // match self.wait_notifier(receiver, remain).await {
222            let mut res = None;
223            let q = self.queue.read().await;
224            for v in q.iter().rev() {
225                let mut b = false;
226                let consumed = v.consumed.load(Ordering::Relaxed);
227                if consumed {
228                    continue;
229                }
230
231                for (msg_type, field, value) in arg.iter() {
232                    if v.item.matching_field_value(msg_type, *field, value) {
233                        b = true;
234                        res = Some(v.item.clone());
235                        v.consumed.store(true, Ordering::Relaxed);
236                        break;
237                    }
238                }
239                if b {
240                    break;
241                }
242            }
243
244            match res {
245                Some(res) => {
246                    return Ok(res);
247                }
248                None => {
249                    // check remaining time.
250                    let past = (Instant::now() - now).as_millis() as u64;
251                    if past < self.timeout {
252                        // continue.
253                        remain = self.timeout - past;
254
255                        // check if there is more waiting receiver.
256                        if self.receiver.receiver_count() > 1 {
257                            self.signal.try_send(()).ok();
258                        }
259                        continue;
260                    } else {
261                        return Err(Error::TimeoutError);
262                    }
263                }
264            }
265        }
266    }
267
268    fn check_connection(&self) -> Result<(), Error> {
269        if self.is_connected() {
270            Ok(())
271        } else {
272            Err(Error::NotConnected)
273        }
274    }
275
276    /// Fetch the security list from the server.
277    ///
278    ///
279    /// This is asn asynchronous method that sends a request to the server and waits for the
280    /// response. It returns a result containing the data if the request succesful, or an error if
281    /// it fails.
282    pub async fn fetch_security_list(&self) -> Result<Vec<SymbolInformation>, Error> {
283        self.check_connection()?;
284        let security_req_id = self.create_unique_id();
285        let req = SecurityListReq::new(security_req_id.clone(), 0, None);
286        self.internal.send_message(req).await?;
287        match self
288            .fetch_response(vec![("y", Field::SecurityReqID, security_req_id)])
289            .await
290        {
291            Ok(res) => parse_func::parse_security_list(&res),
292            Err(err) => Err(err),
293        }
294    }
295
296    pub async fn fetch_positions(&self) -> Result<Vec<PositionReport>, Error> {
297        self.check_connection()?;
298        let pos_req_id = self.create_unique_id();
299        let req = PositionsReq::new(pos_req_id.clone(), None);
300        self.internal.send_message(req).await?;
301
302        let mut result = Vec::new();
303
304        loop {
305            match self
306                .fetch_response(vec![("AP", Field::PosReqID, pos_req_id.clone())])
307                .await
308            {
309                Ok(res) => {
310                    if res.get_message_type() == "AP"
311                        && res
312                            .get_field_value(Field::PosReqResult)
313                            .map_or(false, |v| v.as_str() == "0")
314                    {
315                        let no_pos = res
316                            .get_field_value(Field::TotalNumPosReports)
317                            .unwrap_or("0".into())
318                            .parse::<usize>()
319                            .unwrap();
320                        result.push(res);
321                        if no_pos <= result.len() {
322                            return parse_func::parse_positions(result);
323                        } else {
324                            continue;
325                        }
326                    } else {
327                        return parse_func::parse_positions(vec![res]);
328                    }
329                }
330                Err(err) => {
331                    return Err(err);
332                }
333            }
334        }
335    }
336
337    pub async fn fetch_all_order_status(
338        &self,
339        issue_data: Option<NaiveDateTime>,
340    ) -> Result<Vec<ExecutionReport>, Error> {
341        self.check_connection()?;
342        let mass_status_req_id = self.create_unique_id();
343        // FIXME if mass_status_req_id is not 7, then return 'j' but response does not include the mass_status_req_id
344        let req = OrderMassStatusReq::new(mass_status_req_id.clone(), 7, issue_data);
345        self.internal.send_message(req).await?;
346
347        let mut result = Vec::new();
348
349        loop {
350            match self
351                .fetch_response(vec![
352                    ("8", Field::MassStatusReqID, mass_status_req_id.clone()),
353                    ("j", Field::BusinessRejectRefID, mass_status_req_id.clone()),
354                ])
355                .await
356            {
357                Ok(res) => {
358                    return match res.get_message_type() {
359                        "j" => Ok(Vec::new()),
360                        "8" => {
361                            let no_report = res
362                                .get_field_value(Field::TotNumReports)
363                                .unwrap_or("0".into())
364                                .parse::<usize>()
365                                .unwrap();
366
367                            result.push(res);
368
369                            if no_report <= result.len() {
370                                parse_func::parse_order_mass_status(result)
371                            } else {
372                                continue;
373                            }
374                        }
375                        _ => Err(Error::UnknownError),
376                    };
377                }
378                Err(err) => {
379                    return Err(err);
380                }
381            }
382        }
383    }
384
385    async fn new_order(&self, req: NewOrderSingleReq) -> Result<ExecutionReport, Error> {
386        self.check_connection()?;
387        let cl_ord_id = req.cl_ord_id.clone();
388
389        self.internal.send_message(req).await?;
390        match self
391            .fetch_response(vec![
392                ("8", Field::ClOrdId, cl_ord_id.clone()),
393                ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
394            ])
395            .await
396        {
397            Ok(res) => match res.get_message_type() {
398                "j" => Err(Error::OrderFailed(
399                    res.get_field_value(Field::Text).unwrap_or("Unknown".into()),
400                )),
401                "8" => parse_func::parse_execution_report(res),
402                _ => Err(Error::UnknownError),
403            },
404            Err(err) => Err(err),
405        }
406    }
407
408    pub async fn new_market_order(
409        &self,
410        symbol: u32,
411        side: Side,
412        order_qty: f64,
413        cl_ord_id: Option<String>,
414        custom_ord_label: Option<String>,
415    ) -> Result<ExecutionReport, Error> {
416        let req = NewOrderSingleReq::new(
417            cl_ord_id.unwrap_or(self.create_unique_id()),
418            symbol,
419            side,
420            None,
421            order_qty,
422            OrderType::Market,
423            None,
424            None,
425            None,
426            None,
427            custom_ord_label,
428        );
429        self.new_order(req).await
430    }
431
432    pub async fn new_limit_order(
433        &self,
434        symbol: u32,
435        side: Side,
436        price: f64,
437        order_qty: f64,
438        cl_ord_id: Option<String>,
439        expire_time: Option<NaiveDateTime>,
440        custom_ord_label: Option<String>,
441    ) -> Result<ExecutionReport, Error> {
442        let req = NewOrderSingleReq::new(
443            cl_ord_id.unwrap_or(self.create_unique_id()),
444            symbol,
445            side,
446            None,
447            order_qty,
448            OrderType::Limit,
449            Some(price),
450            None,
451            expire_time,
452            None,
453            custom_ord_label,
454        );
455
456        self.new_order(req).await
457    }
458
459    pub async fn new_stop_order(
460        &self,
461        symbol: u32,
462        side: Side,
463        stop_px: f64,
464        order_qty: f64,
465        cl_ord_id: Option<String>,
466        expire_time: Option<NaiveDateTime>,
467        custom_ord_label: Option<String>,
468    ) -> Result<ExecutionReport, Error> {
469        let req = NewOrderSingleReq::new(
470            cl_ord_id.unwrap_or(self.create_unique_id()),
471            symbol,
472            side,
473            None,
474            order_qty,
475            OrderType::Stop,
476            None,
477            Some(stop_px),
478            expire_time,
479            None,
480            custom_ord_label,
481        );
482
483        self.new_order(req).await
484    }
485    pub async fn close_position(
486        &self,
487        pos_report: PositionReport,
488        custom_ord_label: Option<String>,
489    ) -> Result<ExecutionReport, Error> {
490        self.adjust_position_size(
491            pos_report.position_id,
492            pos_report.symbol_id,
493            if pos_report.long_qty == 0.0 {
494                pos_report.short_qty
495            } else {
496                pos_report.long_qty
497            },
498            if pos_report.long_qty == 0.0 {
499                Side::BUY
500            } else {
501                Side::SELL
502            },
503            custom_ord_label,
504        )
505        .await
506    }
507
508    /// Adjusts the size of a position.
509    ///
510    /// This method takes a position id, symbol_id, a side (buy or sell), and a lot size.
511    /// If the position exists, it adjusts the size of the position by adding or subtracting the given lot size.
512    /// If the side is 'buy', the lot size is added to the position.
513    /// If the side is 'sell', the lot size is subtracted from the position.
514    pub async fn adjust_position_size(
515        &self,
516        pos_id: String,
517        symbol_id: u32,
518        lot: f64,
519        side: Side,
520        custom_ord_label: Option<String>,
521    ) -> Result<ExecutionReport, Error> {
522        let req = NewOrderSingleReq::new(
523            self.create_unique_id(),
524            symbol_id,
525            side,
526            None,
527            lot,
528            OrderType::Market,
529            None,
530            None,
531            None,
532            Some(pos_id),
533            custom_ord_label,
534        );
535
536        self.new_order(req).await
537    }
538
539    /// Replace order request
540    ///
541    /// # Arguments
542    ///
543    /// * `orig_cl_ord_id` - A unique identifier for the order, which is going to be canceled, allocated by the client.
544    /// * `order_id` - Unique ID of an order, returned by the server.
545    /// ...
546    ///
547    ///  Either `orig_cl_ord_id` or `order_id` must be passed to this function. If both are `None`, the function will return an error.
548    pub async fn replace_order(
549        &self,
550        org_cl_ord_id: Option<String>,
551        order_id: Option<String>,
552        order_qty: f64,
553        price: Option<f64>,
554        stop_px: Option<f64>,
555        expire_time: Option<NaiveDateTime>,
556    ) -> Result<ExecutionReport, Error> {
557        if org_cl_ord_id.is_none() && order_id.is_none() {
558            return Err(Error::MissingArgumentError);
559        }
560        self.check_connection()?;
561        let orgid = match org_cl_ord_id.clone() {
562            Some(v) => v,
563            None => order_id.clone().unwrap(),
564        };
565        let oid = match order_id.clone() {
566            Some(v) => v,
567            None => org_cl_ord_id.clone().unwrap(),
568        };
569        let cl_ord_id = self.create_unique_id();
570        let req = OrderCancelReplaceReq::new(
571            orgid,
572            Some(oid),
573            cl_ord_id.clone(),
574            order_qty,
575            price,
576            stop_px,
577            expire_time,
578        );
579        self.internal.send_message(req).await?;
580        match self
581            .fetch_response(vec![
582                if org_cl_ord_id.is_some() {
583                    ("8", Field::ClOrdId, org_cl_ord_id.unwrap())
584                } else {
585                    ("8", Field::OrderID, order_id.unwrap())
586                },
587                ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
588            ])
589            .await
590        {
591            Ok(res) => {
592                match res.get_message_type() {
593                    "j" => {
594                        // failed
595                        Err(Error::OrderFailed(
596                            res.get_field_value(Field::Text)
597                                .unwrap_or("Unknown error".into()),
598                        )
599                        .into())
600                    }
601                    _ => {
602                        // "8" Success
603                        parse_func::parse_execution_report(res)
604                    }
605                }
606            }
607            Err(err) => Err(err),
608        }
609    }
610
611    /// Order cancel reqeuest
612    ///
613    /// # Arguments
614    ///
615    /// * `orig_cl_ord_id` - A unique identifier for the order, which is going to be canceled, allocated by the client.
616    /// * `order_id` - Unique ID of an order, returned by the server.
617    ///
618    ///  Either `orig_cl_ord_id` or `order_id` must be passed to this function. If both are `None`, the function will return an error.
619    pub async fn cancel_order(
620        &self,
621        org_cl_ord_id: Option<String>,
622        order_id: Option<String>,
623    ) -> Result<ExecutionReport, Error> {
624        if org_cl_ord_id.is_none() && order_id.is_none() {
625            return Err(Error::MissingArgumentError);
626        }
627        self.check_connection()?;
628
629        let orgid = match org_cl_ord_id.clone() {
630            Some(v) => v,
631            None => order_id.clone().unwrap(),
632        };
633        let oid = match order_id {
634            Some(v) => v,
635            None => org_cl_ord_id.unwrap(),
636        };
637
638        let cl_ord_id = self.create_unique_id();
639        let req = OrderCancelReq::new(orgid, Some(oid), cl_ord_id.clone());
640        self.internal.send_message(req).await?;
641        match self
642            .fetch_response(vec![
643                ("8", Field::ClOrdId, cl_ord_id.clone()),
644                ("j", Field::BusinessRejectRefID, cl_ord_id.clone()),
645                ("9", Field::ClOrdId, cl_ord_id.clone()),
646            ])
647            .await
648        {
649            Ok(res) => {
650                match res.get_message_type() {
651                    "j" => {
652                        // failed
653                        Err(Error::OrderFailed(
654                            res.get_field_value(Field::Text)
655                                .unwrap_or("Unknown error".into()),
656                        )
657                        .into())
658                    }
659                    "9" => {
660                        // cancel rejected
661                        Err(Error::OrderCancelRejected(
662                            res.get_field_value(Field::Text)
663                                .unwrap_or("Unknown error".into()),
664                        )
665                        .into())
666                    }
667                    _ => {
668                        // "8" Success
669                        parse_func::parse_execution_report(res)
670                    }
671                }
672            }
673            Err(err) => Err(err),
674        }
675    }
676}