cfix/
fixapi.rs

1use std::{
2    collections::VecDeque,
3    sync::{
4        atomic::{AtomicBool, AtomicU32, Ordering},
5        Arc,
6    },
7    time::Duration,
8};
9
10use async_std::{
11    channel::{bounded, Receiver},
12    io::{BufWriter, WriteExt},
13    net::TcpStream,
14    stream,
15    stream::StreamExt,
16    sync::RwLock,
17    task,
18};
19
20use crate::types::{Config, Error, Field, InternalMDResult, SubID, DELIMITER};
21use crate::{
22    messages::{HeartbeatReq, LogonReq, LogoutReq, RequestMessage, ResponseMessage},
23    types::ConnectionHandler,
24};
25use crate::{
26    socket::Socket,
27    types::{MarketCallback, TradeCallback},
28};
29
30pub struct FixApi {
31    config: Config,
32    stream: Option<Arc<TcpStream>>,
33    seq: Arc<AtomicU32>,
34    sub_id: SubID,
35
36    is_connected: Arc<AtomicBool>,
37
38    res_receiver: Option<Receiver<ResponseMessage>>,
39    // pub container: Arc<RwLock<HashMap<String, Vec<ResponseMessage>>>>,
40
41    // ReqMessage Container
42    message_buffer: Arc<RwLock<VecDeque<(u32, String)>>>,
43
44    //callback
45    connection_handler: Option<Arc<dyn ConnectionHandler + Send + Sync>>,
46    market_callback: Option<MarketCallback>,
47    trade_callback: Option<TradeCallback>,
48}
49
50impl FixApi {
51    pub fn new(
52        sub_id: SubID,
53        host: String,
54        login: String,
55        password: String,
56        sender_comp_id: String,
57        heartbeat_interval: Option<u32>,
58    ) -> Self {
59        Self {
60            config: Config::new(
61                host,
62                login,
63                password,
64                sender_comp_id,
65                heartbeat_interval.unwrap_or(30),
66            ),
67            stream: None,
68            res_receiver: None,
69            is_connected: Arc::new(AtomicBool::new(false)),
70            seq: Arc::new(AtomicU32::new(1)),
71            // container: Arc::new(RwLock::new(HashMap::new())),
72            sub_id,
73
74            message_buffer: Arc::new(RwLock::new(VecDeque::new())),
75            connection_handler: None,
76            market_callback: None,
77            trade_callback: None,
78        }
79    }
80
81    pub fn register_market_callback<F>(&mut self, callback: F)
82    where
83        F: Fn(InternalMDResult) -> () + Send + Sync + 'static,
84    {
85        self.market_callback = Some(Arc::new(move |mdresult: InternalMDResult| -> () {
86            callback(mdresult)
87        }));
88    }
89
90    pub fn register_trade_callback<F>(&mut self, callback: F)
91    where
92        F: Fn(ResponseMessage) -> () + Send + Sync + 'static,
93    {
94        self.trade_callback = Some(Arc::new(move |res: ResponseMessage| -> () {
95            callback(res)
96        }));
97    }
98
99    pub fn register_connection_handler_arc<T: ConnectionHandler + Send + Sync + 'static>(
100        &mut self,
101        handler: Arc<T>,
102    ) {
103        self.connection_handler = Some(handler);
104    }
105
106    pub fn register_connection_handler<T: ConnectionHandler + Send + Sync + 'static>(
107        &mut self,
108        handler: T,
109    ) {
110        self.connection_handler = Some(Arc::new(handler));
111    }
112
113    pub async fn disconnect(&mut self) -> Result<(), Error> {
114        if let Some(stream) = self.stream.clone() {
115            stream.shutdown(std::net::Shutdown::Both)?;
116        }
117        self.stream = None;
118        self.res_receiver = None;
119        self.is_connected.store(false, Ordering::Relaxed);
120        self.message_buffer.write().await.clear();
121        Ok(())
122    }
123
124    pub async fn connect(&mut self) -> Result<(), Error> {
125        self.message_buffer.write().await.clear();
126        let (sender, receiver) = bounded(1);
127        let mut socket = Socket::connect(
128            self.config.host.as_str(),
129            if self.sub_id == SubID::QUOTE {
130                5201
131            } else {
132                5202
133            },
134            sender,
135        )
136        .await?;
137        self.is_connected.store(true, Ordering::Relaxed);
138        log::debug!("stream connected");
139
140        // notify connection
141        if let Some(handler) = self.connection_handler.clone() {
142            task::spawn(async move {
143                handler.on_connect().await;
144            });
145        }
146
147        self.res_receiver = Some(receiver);
148        self.stream = Some(socket.stream.clone());
149
150        let is_connected = self.is_connected.clone();
151
152        let handler = self.connection_handler.clone();
153        let _ = task::spawn(async move {
154            socket.recv_loop(is_connected, handler).await.ok();
155        });
156
157        Ok(())
158    }
159
160    pub async fn send_message<R: RequestMessage>(&self, req: R) -> Result<(), Error> {
161        let no_seq = self.seq.fetch_add(1, Ordering::Relaxed);
162        let req = req.build(self.sub_id, no_seq, DELIMITER, &self.config);
163        if let Some(stream) = self.stream.clone() {
164            // FIXME
165            self.message_buffer
166                .write()
167                .await
168                .push_back((no_seq, req.clone()));
169            self.message_buffer
170                .write()
171                .await
172                .push_back((no_seq, req.clone()));
173            if self.message_buffer.read().await.len() > 10 {
174                self.message_buffer.write().await.pop_front();
175            }
176
177            log::debug!("Send request : {}", req);
178            let mut writer = BufWriter::new(stream.as_ref());
179            writer.write_all(req.as_bytes()).await?;
180            writer.flush().await?;
181        }
182        Ok(())
183    }
184
185    pub fn is_connected(&self) -> bool {
186        self.is_connected.load(Ordering::Relaxed)
187    }
188
189    pub async fn logon(&self, heartbeat: bool) -> Result<(), Error> {
190        // res3et the seq
191        self.seq.store(1, Ordering::Relaxed);
192        self.send_message(LogonReq::new(Some(true))).await?;
193
194        // wait to receive the response
195        if let Some(recv) = &self.res_receiver {
196            while let Ok(response) = recv.recv().await {
197                // logon response
198                let msg_type = response.get_message_type();
199                match msg_type {
200                    "A" => {
201                        //
202                        if let Some(handler) = self.connection_handler.clone() {
203                            task::spawn(async move {
204                                handler.on_logon().await;
205                            });
206                        }
207
208                        let stream = self.stream.clone().unwrap();
209                        let stream_clone = self.stream.clone().unwrap();
210                        let sub_id = self.sub_id;
211                        let config = self.config.clone();
212                        let seq = self.seq.clone();
213                        let msg_buffer = self.message_buffer.clone();
214                        let is_connected = self.is_connected.clone();
215                        let handler = self.connection_handler.clone();
216
217                        let send_request = move |req: Box<dyn RequestMessage>| {
218                            let stream = stream.clone();
219                            let sub_id = sub_id;
220                            let config = config.clone();
221                            let seq = seq.clone();
222                            let msg_buffer = msg_buffer.clone();
223                            let is_connected = is_connected.clone();
224                            let handler = handler.clone();
225                            async move {
226                                let msg_type = req.get_message_type();
227                                let no_seq = seq.fetch_add(1, Ordering::Relaxed);
228                                let req = req.build(sub_id, no_seq, DELIMITER, &config);
229                                let handler = handler.clone();
230
231                                // FIXME later
232                                msg_buffer.write().await.push_back((no_seq, req.clone()));
233                                if msg_buffer.read().await.len() > 10 {
234                                    msg_buffer.write().await.pop_front();
235                                }
236
237                                let mut writer = BufWriter::new(stream.as_ref());
238                                log::debug!(
239                                    "[Session:MsgType({msg_type})] Sending request: {}",
240                                    req
241                                );
242                                let _ = writer.write_all(req.as_bytes()).await;
243
244                                match writer.flush().await {
245                                    Ok(_) => {}
246                                    Err(err) => {
247                                        log::error!("Failed to send the request - {:?}", err);
248                                        is_connected.store(false, Ordering::Relaxed);
249                                        if let Err(err) = stream.shutdown(std::net::Shutdown::Both)
250                                        {
251                                            log::error!(
252                                                "Failed to shutdown the stream - {:?}",
253                                                err
254                                            );
255                                        }
256                                        if let Some(handler) = handler {
257                                            task::spawn(async move {
258                                                handler.on_disconnect().await;
259                                            });
260                                        }
261                                    }
262                                }
263                            }
264                        };
265                        let send_request_clone = send_request.clone();
266
267                        if heartbeat {
268                            let hb_interval = self.config.heart_beat as u64;
269
270                            let is_connected = self.is_connected.clone();
271
272                            //send heartbeat per hb_interval
273                            task::spawn(async move {
274                                let mut heartbeat_stream =
275                                    stream::interval(Duration::from_secs(hb_interval));
276
277                                while let Some(_) = heartbeat_stream.next().await {
278                                    if !is_connected.load(Ordering::Relaxed) {
279                                        break;
280                                    }
281                                    let req = HeartbeatReq::new(None);
282                                    send_request(Box::new(req)).await;
283                                }
284                            });
285                        }
286
287                        //
288                        // handle the responses
289
290                        let recv = self.res_receiver.clone().unwrap();
291                        let market_callback = self.market_callback.clone();
292                        let trade_callback = self.trade_callback.clone();
293
294                        let is_connected = self.is_connected.clone();
295                        // let seq = self.seq.clone();
296                        let msg_buffer = self.message_buffer.clone();
297                        task::spawn(async move {
298                            while let Ok(res) = recv.recv().await {
299                                if !is_connected.load(Ordering::Relaxed) {
300                                    break;
301                                }
302
303                                let msg_type = res.get_message_type();
304
305                                // notify? or send? via channel?
306                                match msg_type {
307                                    "0" => {
308                                        log::debug!(
309                                            "[Session:MsyType({msg_type})] Received Heartbeat"
310                                        );
311                                    }
312                                    "2" => {
313                                        log::debug!(
314                                            "[Session:MsyType({msg_type})] Received ResendRequest"
315                                        );
316                                        let begin = res
317                                            .get_field_value(Field::BeginSeqNo)
318                                            .map(|v| v.parse::<u32>().unwrap_or(0))
319                                            .unwrap();
320
321                                        let end = res
322                                            .get_field_value(Field::EndSeqNo)
323                                            .map(|v| v.parse::<u32>().unwrap_or(0))
324                                            .unwrap();
325
326                                        {
327                                            for msg in msg_buffer
328                                                .read()
329                                                .await
330                                                .iter()
331                                                .filter(|(no, _)| {
332                                                    if end == 0 {
333                                                        *no >= begin
334                                                    } else {
335                                                        *no >= begin && *no <= end
336                                                    }
337                                                })
338                                                .map(|(_, msg)| msg.clone())
339                                            {
340                                                let mut writer =
341                                                    BufWriter::new(stream_clone.as_ref());
342                                                log::debug!(
343                                                 "[Session:MsgType({msg_type})] Send ResendRequest: {}",
344                                                msg
345                                                );
346                                                let _ = writer.write_all(msg.as_bytes()).await;
347                                                break;
348                                            }
349                                        }
350                                    }
351                                    "5" => {
352                                        log::debug!(
353                                            "[Session:MsyType({msg_type})] Received Logged out"
354                                        );
355                                        // 5 : logout
356                                        //disconnect
357                                        stream_clone.shutdown(std::net::Shutdown::Both).ok();
358                                    }
359                                    "1" => {
360                                        log::debug!(
361                                            "[Session:MsyType({msg_type})] Received TestRequest"
362                                        );
363                                        // send back with test request id
364                                        if let Some(test_req_id) =
365                                            res.get_field_value(Field::TestReqID)
366                                        {
367                                            send_request_clone(Box::new(HeartbeatReq::new(Some(
368                                                test_req_id,
369                                            ))))
370                                            .await;
371                                            log::debug!("Sent the heartbeat from test_req_id");
372                                        }
373                                    }
374                                    "W" | "X" | "Y" => {
375                                        // For market data
376                                        let symbol_id = res
377                                            .get_field_value(Field::Symbol)
378                                            .unwrap_or("0".into())
379                                            .parse::<u32>()
380                                            .unwrap();
381                                        // notify to callback
382                                        if let Some(market_callback) = market_callback.clone() {
383                                            let mdresult = if msg_type == "Y" {
384                                                let md_req_id = res
385                                                    .get_field_value(Field::MDReqID)
386                                                    .map(|v| v.clone())
387                                                    .unwrap_or("".into());
388                                                let err_msg = res
389                                                    .get_field_value(Field::Text)
390                                                    .map(|v| v.clone())
391                                                    .unwrap_or("".into());
392                                                InternalMDResult::MDReject {
393                                                    symbol_id,
394                                                    md_req_id,
395                                                    err_msg,
396                                                }
397                                            } else {
398                                                let data = res.get_repeating_groups(
399                                                    Field::NoMDEntries,
400                                                    if msg_type == "W" {
401                                                        Field::MDEntryType
402                                                    } else {
403                                                        Field::MDUpdateAction
404                                                    },
405                                                    None,
406                                                );
407                                                InternalMDResult::MD {
408                                                    msg_type: msg_type.chars().next().unwrap(),
409                                                    symbol_id,
410                                                    data,
411                                                }
412                                            };
413
414                                            market_callback(mdresult);
415                                        }
416                                    }
417                                    _ => {
418                                        log::debug!("{}", res.get_message());
419                                        if let Some(trade_callback) = trade_callback.clone() {
420                                            trade_callback(res);
421                                        }
422                                    }
423                                }
424                            }
425                        });
426
427                        break;
428                    }
429                    "5" => {
430                        return Err(Error::LoggedOut);
431                    }
432                    _ => {}
433                }
434            }
435        }
436        Ok(())
437    }
438
439    pub async fn logout(&self) -> Result<(), Error> {
440        self.send_message(LogoutReq::default()).await?;
441        Ok(())
442    }
443}