1use std::{
2 collections::VecDeque,
3 sync::{
4 atomic::{AtomicBool, AtomicU32, Ordering},
5 Arc,
6 },
7 time::Duration,
8};
9
10use async_std::{
11 channel::{bounded, Receiver},
12 io::{BufWriter, WriteExt},
13 net::TcpStream,
14 stream,
15 stream::StreamExt,
16 sync::RwLock,
17 task,
18};
19
20use crate::types::{Config, Error, Field, InternalMDResult, SubID, DELIMITER};
21use crate::{
22 messages::{HeartbeatReq, LogonReq, LogoutReq, RequestMessage, ResponseMessage},
23 types::ConnectionHandler,
24};
25use crate::{
26 socket::Socket,
27 types::{MarketCallback, TradeCallback},
28};
29
30pub struct FixApi {
31 config: Config,
32 stream: Option<Arc<TcpStream>>,
33 seq: Arc<AtomicU32>,
34 sub_id: SubID,
35
36 is_connected: Arc<AtomicBool>,
37
38 res_receiver: Option<Receiver<ResponseMessage>>,
39 message_buffer: Arc<RwLock<VecDeque<(u32, String)>>>,
43
44 connection_handler: Option<Arc<dyn ConnectionHandler + Send + Sync>>,
46 market_callback: Option<MarketCallback>,
47 trade_callback: Option<TradeCallback>,
48}
49
50impl FixApi {
51 pub fn new(
52 sub_id: SubID,
53 host: String,
54 login: String,
55 password: String,
56 sender_comp_id: String,
57 heartbeat_interval: Option<u32>,
58 ) -> Self {
59 Self {
60 config: Config::new(
61 host,
62 login,
63 password,
64 sender_comp_id,
65 heartbeat_interval.unwrap_or(30),
66 ),
67 stream: None,
68 res_receiver: None,
69 is_connected: Arc::new(AtomicBool::new(false)),
70 seq: Arc::new(AtomicU32::new(1)),
71 sub_id,
73
74 message_buffer: Arc::new(RwLock::new(VecDeque::new())),
75 connection_handler: None,
76 market_callback: None,
77 trade_callback: None,
78 }
79 }
80
81 pub fn register_market_callback<F>(&mut self, callback: F)
82 where
83 F: Fn(InternalMDResult) -> () + Send + Sync + 'static,
84 {
85 self.market_callback = Some(Arc::new(move |mdresult: InternalMDResult| -> () {
86 callback(mdresult)
87 }));
88 }
89
90 pub fn register_trade_callback<F>(&mut self, callback: F)
91 where
92 F: Fn(ResponseMessage) -> () + Send + Sync + 'static,
93 {
94 self.trade_callback = Some(Arc::new(move |res: ResponseMessage| -> () {
95 callback(res)
96 }));
97 }
98
99 pub fn register_connection_handler_arc<T: ConnectionHandler + Send + Sync + 'static>(
100 &mut self,
101 handler: Arc<T>,
102 ) {
103 self.connection_handler = Some(handler);
104 }
105
106 pub fn register_connection_handler<T: ConnectionHandler + Send + Sync + 'static>(
107 &mut self,
108 handler: T,
109 ) {
110 self.connection_handler = Some(Arc::new(handler));
111 }
112
113 pub async fn disconnect(&mut self) -> Result<(), Error> {
114 if let Some(stream) = self.stream.clone() {
115 stream.shutdown(std::net::Shutdown::Both)?;
116 }
117 self.stream = None;
118 self.res_receiver = None;
119 self.is_connected.store(false, Ordering::Relaxed);
120 self.message_buffer.write().await.clear();
121 Ok(())
122 }
123
124 pub async fn connect(&mut self) -> Result<(), Error> {
125 self.message_buffer.write().await.clear();
126 let (sender, receiver) = bounded(1);
127 let mut socket = Socket::connect(
128 self.config.host.as_str(),
129 if self.sub_id == SubID::QUOTE {
130 5201
131 } else {
132 5202
133 },
134 sender,
135 )
136 .await?;
137 self.is_connected.store(true, Ordering::Relaxed);
138 log::debug!("stream connected");
139
140 if let Some(handler) = self.connection_handler.clone() {
142 task::spawn(async move {
143 handler.on_connect().await;
144 });
145 }
146
147 self.res_receiver = Some(receiver);
148 self.stream = Some(socket.stream.clone());
149
150 let is_connected = self.is_connected.clone();
151
152 let handler = self.connection_handler.clone();
153 let _ = task::spawn(async move {
154 socket.recv_loop(is_connected, handler).await.ok();
155 });
156
157 Ok(())
158 }
159
160 pub async fn send_message<R: RequestMessage>(&self, req: R) -> Result<(), Error> {
161 let no_seq = self.seq.fetch_add(1, Ordering::Relaxed);
162 let req = req.build(self.sub_id, no_seq, DELIMITER, &self.config);
163 if let Some(stream) = self.stream.clone() {
164 self.message_buffer
166 .write()
167 .await
168 .push_back((no_seq, req.clone()));
169 self.message_buffer
170 .write()
171 .await
172 .push_back((no_seq, req.clone()));
173 if self.message_buffer.read().await.len() > 10 {
174 self.message_buffer.write().await.pop_front();
175 }
176
177 log::debug!("Send request : {}", req);
178 let mut writer = BufWriter::new(stream.as_ref());
179 writer.write_all(req.as_bytes()).await?;
180 writer.flush().await?;
181 }
182 Ok(())
183 }
184
185 pub fn is_connected(&self) -> bool {
186 self.is_connected.load(Ordering::Relaxed)
187 }
188
189 pub async fn logon(&self, heartbeat: bool) -> Result<(), Error> {
190 self.seq.store(1, Ordering::Relaxed);
192 self.send_message(LogonReq::new(Some(true))).await?;
193
194 if let Some(recv) = &self.res_receiver {
196 while let Ok(response) = recv.recv().await {
197 let msg_type = response.get_message_type();
199 match msg_type {
200 "A" => {
201 if let Some(handler) = self.connection_handler.clone() {
203 task::spawn(async move {
204 handler.on_logon().await;
205 });
206 }
207
208 let stream = self.stream.clone().unwrap();
209 let stream_clone = self.stream.clone().unwrap();
210 let sub_id = self.sub_id;
211 let config = self.config.clone();
212 let seq = self.seq.clone();
213 let msg_buffer = self.message_buffer.clone();
214 let is_connected = self.is_connected.clone();
215 let handler = self.connection_handler.clone();
216
217 let send_request = move |req: Box<dyn RequestMessage>| {
218 let stream = stream.clone();
219 let sub_id = sub_id;
220 let config = config.clone();
221 let seq = seq.clone();
222 let msg_buffer = msg_buffer.clone();
223 let is_connected = is_connected.clone();
224 let handler = handler.clone();
225 async move {
226 let msg_type = req.get_message_type();
227 let no_seq = seq.fetch_add(1, Ordering::Relaxed);
228 let req = req.build(sub_id, no_seq, DELIMITER, &config);
229 let handler = handler.clone();
230
231 msg_buffer.write().await.push_back((no_seq, req.clone()));
233 if msg_buffer.read().await.len() > 10 {
234 msg_buffer.write().await.pop_front();
235 }
236
237 let mut writer = BufWriter::new(stream.as_ref());
238 log::debug!(
239 "[Session:MsgType({msg_type})] Sending request: {}",
240 req
241 );
242 let _ = writer.write_all(req.as_bytes()).await;
243
244 match writer.flush().await {
245 Ok(_) => {}
246 Err(err) => {
247 log::error!("Failed to send the request - {:?}", err);
248 is_connected.store(false, Ordering::Relaxed);
249 if let Err(err) = stream.shutdown(std::net::Shutdown::Both)
250 {
251 log::error!(
252 "Failed to shutdown the stream - {:?}",
253 err
254 );
255 }
256 if let Some(handler) = handler {
257 task::spawn(async move {
258 handler.on_disconnect().await;
259 });
260 }
261 }
262 }
263 }
264 };
265 let send_request_clone = send_request.clone();
266
267 if heartbeat {
268 let hb_interval = self.config.heart_beat as u64;
269
270 let is_connected = self.is_connected.clone();
271
272 task::spawn(async move {
274 let mut heartbeat_stream =
275 stream::interval(Duration::from_secs(hb_interval));
276
277 while let Some(_) = heartbeat_stream.next().await {
278 if !is_connected.load(Ordering::Relaxed) {
279 break;
280 }
281 let req = HeartbeatReq::new(None);
282 send_request(Box::new(req)).await;
283 }
284 });
285 }
286
287 let recv = self.res_receiver.clone().unwrap();
291 let market_callback = self.market_callback.clone();
292 let trade_callback = self.trade_callback.clone();
293
294 let is_connected = self.is_connected.clone();
295 let msg_buffer = self.message_buffer.clone();
297 task::spawn(async move {
298 while let Ok(res) = recv.recv().await {
299 if !is_connected.load(Ordering::Relaxed) {
300 break;
301 }
302
303 let msg_type = res.get_message_type();
304
305 match msg_type {
307 "0" => {
308 log::debug!(
309 "[Session:MsyType({msg_type})] Received Heartbeat"
310 );
311 }
312 "2" => {
313 log::debug!(
314 "[Session:MsyType({msg_type})] Received ResendRequest"
315 );
316 let begin = res
317 .get_field_value(Field::BeginSeqNo)
318 .map(|v| v.parse::<u32>().unwrap_or(0))
319 .unwrap();
320
321 let end = res
322 .get_field_value(Field::EndSeqNo)
323 .map(|v| v.parse::<u32>().unwrap_or(0))
324 .unwrap();
325
326 {
327 for msg in msg_buffer
328 .read()
329 .await
330 .iter()
331 .filter(|(no, _)| {
332 if end == 0 {
333 *no >= begin
334 } else {
335 *no >= begin && *no <= end
336 }
337 })
338 .map(|(_, msg)| msg.clone())
339 {
340 let mut writer =
341 BufWriter::new(stream_clone.as_ref());
342 log::debug!(
343 "[Session:MsgType({msg_type})] Send ResendRequest: {}",
344 msg
345 );
346 let _ = writer.write_all(msg.as_bytes()).await;
347 break;
348 }
349 }
350 }
351 "5" => {
352 log::debug!(
353 "[Session:MsyType({msg_type})] Received Logged out"
354 );
355 stream_clone.shutdown(std::net::Shutdown::Both).ok();
358 }
359 "1" => {
360 log::debug!(
361 "[Session:MsyType({msg_type})] Received TestRequest"
362 );
363 if let Some(test_req_id) =
365 res.get_field_value(Field::TestReqID)
366 {
367 send_request_clone(Box::new(HeartbeatReq::new(Some(
368 test_req_id,
369 ))))
370 .await;
371 log::debug!("Sent the heartbeat from test_req_id");
372 }
373 }
374 "W" | "X" | "Y" => {
375 let symbol_id = res
377 .get_field_value(Field::Symbol)
378 .unwrap_or("0".into())
379 .parse::<u32>()
380 .unwrap();
381 if let Some(market_callback) = market_callback.clone() {
383 let mdresult = if msg_type == "Y" {
384 let md_req_id = res
385 .get_field_value(Field::MDReqID)
386 .map(|v| v.clone())
387 .unwrap_or("".into());
388 let err_msg = res
389 .get_field_value(Field::Text)
390 .map(|v| v.clone())
391 .unwrap_or("".into());
392 InternalMDResult::MDReject {
393 symbol_id,
394 md_req_id,
395 err_msg,
396 }
397 } else {
398 let data = res.get_repeating_groups(
399 Field::NoMDEntries,
400 if msg_type == "W" {
401 Field::MDEntryType
402 } else {
403 Field::MDUpdateAction
404 },
405 None,
406 );
407 InternalMDResult::MD {
408 msg_type: msg_type.chars().next().unwrap(),
409 symbol_id,
410 data,
411 }
412 };
413
414 market_callback(mdresult);
415 }
416 }
417 _ => {
418 log::debug!("{}", res.get_message());
419 if let Some(trade_callback) = trade_callback.clone() {
420 trade_callback(res);
421 }
422 }
423 }
424 }
425 });
426
427 break;
428 }
429 "5" => {
430 return Err(Error::LoggedOut);
431 }
432 _ => {}
433 }
434 }
435 }
436 Ok(())
437 }
438
439 pub async fn logout(&self) -> Result<(), Error> {
440 self.send_message(LogoutReq::default()).await?;
441 Ok(())
442 }
443}