1use crate::codec::{FixDecoder, FixMessage, tags};
4use crate::config::FixConfig;
5use crate::error::{FixError, Result};
6use crate::messages::{
7 ExecType, ExecutionReport, MarketDataRequest, MsgType, NewOrderSingle, OrdStatus,
8 OrderCancelReplaceRequest, OrderCancelRequest, Side,
9};
10use crate::session::{FixSession, SessionState};
11use crate::transport::{self, FixTransport};
12use alpaca_base::Credentials;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::{Mutex, mpsc};
16use tokio::time::{interval, timeout};
17
/// Capacity of the bounded channel that carries messages from the background
/// receiver task to [`FixClient::next_message`].
const MESSAGE_CHANNEL_SIZE: usize = 1_000;

/// Seconds [`FixClient::connect`] waits for the server's reply to the logon message.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
23
24pub struct FixClient {
26 #[allow(dead_code)]
28 credentials: Credentials,
29 session: Arc<Mutex<FixSession>>,
31 transport: Arc<Mutex<Option<FixTransport>>>,
33 #[allow(dead_code)]
35 decoder: FixDecoder,
36 config: FixConfig,
38 message_rx: Arc<Mutex<Option<mpsc::Receiver<FixMessage>>>>,
40 shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
42}
43
44impl std::fmt::Debug for FixClient {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("FixClient")
47 .field("config", &self.config)
48 .finish()
49 }
50}
51
52impl FixClient {
53 #[must_use]
55 pub fn new(credentials: Credentials, config: FixConfig) -> Self {
56 let session = FixSession::new(config.clone());
57 Self {
58 credentials,
59 session: Arc::new(Mutex::new(session)),
60 transport: Arc::new(Mutex::new(None)),
61 decoder: FixDecoder::new(),
62 config,
63 message_rx: Arc::new(Mutex::new(None)),
64 shutdown_tx: Arc::new(Mutex::new(None)),
65 }
66 }
67
68 pub async fn state(&self) -> SessionState {
70 self.session.lock().await.state()
71 }
72
73 pub async fn connect(&self) -> Result<()> {
78 let mut session = self.session.lock().await;
79 session.set_state(SessionState::Connecting);
80
81 tracing::info!(
83 "Connecting to FIX server at {}:{}",
84 self.config.host,
85 self.config.port
86 );
87
88 let tcp_transport = transport::connect(&self.config.host, self.config.port).await?;
89
90 {
92 let mut transport_guard = self.transport.lock().await;
93 *transport_guard = Some(tcp_transport);
94 }
95
96 session.set_state(SessionState::LoggingOn);
97
98 let logon = session.create_logon();
100 self.send_raw(&logon).await?;
101
102 let logon_response = self
104 .receive_with_timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
105 .await?;
106
107 if let Some(msg_type) = logon_response.msg_type() {
109 match MsgType::from_fix_str(msg_type) {
110 Some(MsgType::Logon) => {
111 tracing::info!("Logon successful");
112 session.set_state(SessionState::Active);
113 }
114 Some(MsgType::Logout) => {
115 let text = logon_response.get(tags::TEXT).unwrap_or("unknown reason");
116 session.set_state(SessionState::Disconnected);
117 return Err(FixError::Authentication(format!(
118 "logon rejected: {}",
119 text
120 )));
121 }
122 _ => {
123 return Err(FixError::Session(format!(
124 "unexpected response to logon: {:?}",
125 msg_type
126 )));
127 }
128 }
129 } else {
130 return Err(FixError::InvalidMessage(
131 "missing MsgType in response".to_string(),
132 ));
133 }
134
135 self.start_background_tasks().await;
137
138 tracing::info!("FIX session established");
139 Ok(())
140 }
141
142 pub async fn disconnect(&self) -> Result<()> {
147 if let Some(tx) = self.shutdown_tx.lock().await.take() {
149 let _ = tx.send(()).await;
150 }
151
152 let mut session = self.session.lock().await;
153
154 if session.state() == SessionState::Active {
155 session.set_state(SessionState::LoggingOut);
156
157 let logout = session.create_logout(None);
159 if let Err(e) = self.send_raw(&logout).await {
160 tracing::warn!("Failed to send logout: {}", e);
161 }
162
163 if let Ok(response) = self.receive_with_timeout(Duration::from_secs(5)).await
165 && let Some(msg_type) = response.msg_type()
166 && MsgType::from_fix_str(msg_type) == Some(MsgType::Logout)
167 {
168 tracing::info!("Logout confirmed by server");
169 }
170 }
171
172 if let Some(transport) = self.transport.lock().await.take() {
174 let _ = transport.close().await;
175 }
176
177 session.set_state(SessionState::Disconnected);
178 tracing::info!("FIX session terminated");
179
180 Ok(())
181 }
182
183 pub async fn send_order(&self, order: &NewOrderSingle) -> Result<String> {
191 let session = self.session.lock().await;
192
193 if session.state() != SessionState::Active {
194 return Err(FixError::Session("session not active".to_string()));
195 }
196
197 let fields = self.build_new_order_fields(order);
198 let msg = session.encode_message(MsgType::NewOrderSingle.as_str(), &fields);
199 drop(session);
200
201 self.send_raw(&msg).await?;
202
203 tracing::debug!("Sent new order: cl_ord_id={}", order.cl_ord_id);
204 Ok(order.cl_ord_id.clone())
205 }
206
207 pub async fn cancel_order(&self, cancel: &OrderCancelRequest) -> Result<String> {
215 let session = self.session.lock().await;
216
217 if session.state() != SessionState::Active {
218 return Err(FixError::Session("session not active".to_string()));
219 }
220
221 let fields = vec![
222 (tags::ORIG_CL_ORD_ID, cancel.orig_cl_ord_id.clone()),
223 (tags::CL_ORD_ID, cancel.cl_ord_id.clone()),
224 (tags::SYMBOL, cancel.symbol.clone()),
225 (tags::SIDE, cancel.side.as_char().to_string()),
226 ];
227
228 let msg = session.encode_message(MsgType::OrderCancelRequest.as_str(), &fields);
229 drop(session);
230
231 self.send_raw(&msg).await?;
232
233 tracing::debug!("Sent cancel request: cl_ord_id={}", cancel.cl_ord_id);
234 Ok(cancel.cl_ord_id.clone())
235 }
236
237 pub async fn replace_order(&self, replace: &OrderCancelReplaceRequest) -> Result<String> {
245 let session = self.session.lock().await;
246
247 if session.state() != SessionState::Active {
248 return Err(FixError::Session("session not active".to_string()));
249 }
250
251 let mut fields = vec![
252 (tags::ORIG_CL_ORD_ID, replace.orig_cl_ord_id.clone()),
253 (tags::CL_ORD_ID, replace.cl_ord_id.clone()),
254 (tags::SYMBOL, replace.symbol.clone()),
255 (tags::SIDE, replace.side.as_char().to_string()),
256 (tags::ORD_TYPE, replace.ord_type.as_char().to_string()),
257 (tags::ORDER_QTY, replace.order_qty.to_string()),
258 ];
259
260 if let Some(price) = replace.price {
261 fields.push((tags::PRICE, price.to_string()));
262 }
263
264 let msg = session.encode_message(MsgType::OrderCancelReplaceRequest.as_str(), &fields);
265 drop(session);
266
267 self.send_raw(&msg).await?;
268
269 tracing::debug!("Sent replace request: cl_ord_id={}", replace.cl_ord_id);
270 Ok(replace.cl_ord_id.clone())
271 }
272
273 pub async fn request_market_data(&self, request: &MarketDataRequest) -> Result<String> {
281 let session = self.session.lock().await;
282
283 if session.state() != SessionState::Active {
284 return Err(FixError::Session("session not active".to_string()));
285 }
286
287 let fields = vec![
288 (tags::MD_REQ_ID, request.md_req_id.clone()),
289 (
290 tags::SUBSCRIPTION_REQUEST_TYPE,
291 request.subscription_request_type.to_string(),
292 ),
293 (tags::MARKET_DEPTH, request.market_depth.to_string()),
294 ];
295
296 let msg = session.encode_message(MsgType::MarketDataRequest.as_str(), &fields);
297 drop(session);
298
299 self.send_raw(&msg).await?;
300
301 tracing::debug!("Sent market data request: md_req_id={}", request.md_req_id);
302 Ok(request.md_req_id.clone())
303 }
304
305 pub async fn next_message(&self) -> Result<FixMessage> {
310 let mut rx_guard = self.message_rx.lock().await;
311 if let Some(ref mut rx) = *rx_guard {
312 rx.recv()
313 .await
314 .ok_or_else(|| FixError::Connection("message channel closed".to_string()))
315 } else {
316 Err(FixError::Session("not connected".to_string()))
317 }
318 }
319
320 pub async fn process_message(&self, msg: &FixMessage) -> Result<()> {
328 let mut session = self.session.lock().await;
329 session.validate_sequence(msg)?;
330
331 if let Some(msg_type) = msg.msg_type() {
333 match MsgType::from_fix_str(msg_type) {
334 Some(MsgType::Heartbeat) => {
335 tracing::debug!("Received heartbeat");
336 }
337 Some(MsgType::TestRequest) => {
338 if let Some(test_req_id) = msg.get(tags::TEST_REQ_ID) {
339 let heartbeat = session.create_heartbeat(Some(test_req_id));
340 drop(session);
341 self.send_raw(&heartbeat).await?;
342 tracing::debug!("Sent heartbeat response");
343 }
344 }
345 Some(MsgType::Logout) => {
346 session.set_state(SessionState::Disconnected);
347 tracing::info!("Received logout from server");
348 }
349 Some(MsgType::ResendRequest) => {
350 tracing::warn!("Resend request received - not fully implemented");
351 }
353 Some(MsgType::SequenceReset) => {
354 if let Some(new_seq) = msg.get(tags::MSG_SEQ_NUM)
355 && let Ok(seq) = new_seq.parse::<u64>()
356 {
357 session.seq_nums().set_incoming(seq);
358 tracing::info!("Sequence reset to {}", seq);
359 }
360 }
361 _ => {}
362 }
363 }
364
365 Ok(())
366 }
367
368 pub fn parse_execution_report(&self, msg: &FixMessage) -> Result<ExecutionReport> {
376 let order_id = msg
377 .get(tags::ORDER_ID)
378 .ok_or_else(|| FixError::InvalidMessage("missing OrderID".to_string()))?
379 .to_string();
380
381 let cl_ord_id = msg
382 .get(tags::CL_ORD_ID)
383 .ok_or_else(|| FixError::InvalidMessage("missing ClOrdID".to_string()))?
384 .to_string();
385
386 let exec_id = msg
387 .get(tags::EXEC_ID)
388 .ok_or_else(|| FixError::InvalidMessage("missing ExecID".to_string()))?
389 .to_string();
390
391 let exec_type_char = msg
392 .get(tags::EXEC_TYPE)
393 .and_then(|s| s.chars().next())
394 .ok_or_else(|| FixError::InvalidMessage("missing ExecType".to_string()))?;
395
396 let exec_type = ExecType::from_char(exec_type_char)
397 .ok_or_else(|| FixError::InvalidMessage("invalid ExecType".to_string()))?;
398
399 let ord_status_char = msg
400 .get(tags::ORD_STATUS)
401 .and_then(|s| s.chars().next())
402 .ok_or_else(|| FixError::InvalidMessage("missing OrdStatus".to_string()))?;
403
404 let ord_status = OrdStatus::from_char(ord_status_char)
405 .ok_or_else(|| FixError::InvalidMessage("invalid OrdStatus".to_string()))?;
406
407 let symbol = msg
408 .get(tags::SYMBOL)
409 .ok_or_else(|| FixError::InvalidMessage("missing Symbol".to_string()))?
410 .to_string();
411
412 let side_char = msg
413 .get(tags::SIDE)
414 .and_then(|s| s.chars().next())
415 .ok_or_else(|| FixError::InvalidMessage("missing Side".to_string()))?;
416
417 let side = Side::from_char(side_char)
418 .ok_or_else(|| FixError::InvalidMessage("invalid Side".to_string()))?;
419
420 let order_qty: f64 = msg
421 .get(tags::ORDER_QTY)
422 .ok_or_else(|| FixError::InvalidMessage("missing OrderQty".to_string()))?
423 .parse()
424 .map_err(|_| FixError::Decoding("invalid OrderQty".to_string()))?;
425
426 let cum_qty: f64 = msg.get(tags::CUM_QTY).unwrap_or("0").parse().unwrap_or(0.0);
427
428 let avg_px: f64 = msg.get(tags::AVG_PX).unwrap_or("0").parse().unwrap_or(0.0);
429
430 let leaves_qty: f64 = msg
431 .get(tags::LEAVES_QTY)
432 .unwrap_or("0")
433 .parse()
434 .unwrap_or(0.0);
435
436 let last_qty = msg.get(tags::LAST_QTY).and_then(|s| s.parse().ok());
437 let last_px = msg.get(tags::LAST_PX).and_then(|s| s.parse().ok());
438 let text = msg.get(tags::TEXT).map(String::from);
439
440 Ok(ExecutionReport {
441 order_id,
442 cl_ord_id,
443 exec_id,
444 exec_type,
445 ord_status,
446 symbol,
447 side,
448 order_qty,
449 last_qty,
450 last_px,
451 cum_qty,
452 avg_px,
453 leaves_qty,
454 text,
455 })
456 }
457
458 fn build_new_order_fields(&self, order: &NewOrderSingle) -> Vec<(u32, String)> {
460 let mut fields = vec![
461 (tags::CL_ORD_ID, order.cl_ord_id.clone()),
462 (tags::SYMBOL, order.symbol.clone()),
463 (tags::SIDE, order.side.as_char().to_string()),
464 (tags::ORD_TYPE, order.ord_type.as_char().to_string()),
465 (tags::ORDER_QTY, order.order_qty.to_string()),
466 (
467 tags::TIME_IN_FORCE,
468 order.time_in_force.as_char().to_string(),
469 ),
470 ];
471
472 if let Some(price) = order.price {
473 fields.push((tags::PRICE, price.to_string()));
474 }
475
476 if let Some(stop_px) = order.stop_px {
477 fields.push((tags::STOP_PX, stop_px.to_string()));
478 }
479
480 if let Some(ref account) = order.account {
481 fields.push((tags::ACCOUNT, account.clone()));
482 }
483
484 fields
485 }
486
487 async fn send_raw(&self, message: &str) -> Result<()> {
489 let transport_guard = self.transport.lock().await;
490 if let Some(ref transport) = *transport_guard {
491 transport.send(message).await
492 } else {
493 Err(FixError::Connection("not connected".to_string()))
494 }
495 }
496
497 async fn receive_with_timeout(&self, duration: Duration) -> Result<FixMessage> {
499 let transport_guard = self.transport.lock().await;
500 if transport_guard.is_some() {
501 drop(transport_guard);
502
503 let transport_clone = self.transport.clone();
504 timeout(duration, async move {
505 let guard = transport_clone.lock().await;
506 if let Some(ref t) = *guard {
507 t.receive().await
508 } else {
509 Err(FixError::Connection("not connected".to_string()))
510 }
511 })
512 .await
513 .map_err(|_| FixError::Timeout("receive timeout".to_string()))?
514 } else {
515 Err(FixError::Connection("not connected".to_string()))
516 }
517 }
518
519 async fn start_background_tasks(&self) {
521 let (msg_tx, msg_rx) = mpsc::channel(MESSAGE_CHANNEL_SIZE);
522 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
523
524 *self.message_rx.lock().await = Some(msg_rx);
526 *self.shutdown_tx.lock().await = Some(shutdown_tx);
527
528 let transport = Arc::clone(&self.transport);
530 let session = Arc::clone(&self.session);
531 let heartbeat_interval = self.config.heartbeat_interval_secs;
532
533 let transport_recv = Arc::clone(&transport);
535 let session_recv = Arc::clone(&session);
536 let msg_tx_clone = msg_tx.clone();
537
538 tokio::spawn(async move {
539 loop {
540 tokio::select! {
541 _ = shutdown_rx.recv() => {
542 tracing::debug!("Message receiver shutting down");
543 break;
544 }
545 result = async {
546 let guard = transport_recv.lock().await;
547 if guard.is_some() {
548 drop(guard);
549 let guard2 = transport_recv.lock().await;
550 if let Some(ref t) = *guard2 {
551 t.receive().await
552 } else {
553 Err(FixError::Connection("disconnected".to_string()))
554 }
555 } else {
556 tokio::time::sleep(Duration::from_millis(100)).await;
558 Err(FixError::Connection("not connected".to_string()))
559 }
560 } => {
561 match result {
562 Ok(msg) => {
563 if let Some(msg_type) = msg.msg_type() {
565 match MsgType::from_fix_str(msg_type) {
566 Some(MsgType::TestRequest) => {
567 if let Some(test_req_id) = msg.get(tags::TEST_REQ_ID) {
569 let session_guard = session_recv.lock().await;
570 let heartbeat = session_guard.create_heartbeat(Some(test_req_id));
571 drop(session_guard);
572
573 let transport_guard = transport_recv.lock().await;
574 if let Some(ref t) = *transport_guard {
575 let _ = t.send(&heartbeat).await;
576 }
577 }
578 }
579 Some(MsgType::Logout) => {
580 let mut session_guard = session_recv.lock().await;
581 session_guard.set_state(SessionState::Disconnected);
582 tracing::info!("Server initiated logout");
583 }
584 _ => {}
585 }
586 }
587
588 if msg_tx_clone.send(msg).await.is_err() {
590 tracing::debug!("Message channel closed");
591 break;
592 }
593 }
594 Err(FixError::Connection(_)) => {
595 let mut session_guard = session_recv.lock().await;
597 if session_guard.state() == SessionState::Active {
598 session_guard.set_state(SessionState::Disconnected);
599 tracing::warn!("Connection lost");
600 }
601 break;
602 }
603 Err(e) => {
604 tracing::error!("Error receiving message: {}", e);
605 }
606 }
607 }
608 }
609 }
610 });
611
612 let transport_hb = Arc::clone(&transport);
614 let session_hb = Arc::clone(&session);
615
616 tokio::spawn(async move {
617 let mut heartbeat_timer = interval(Duration::from_secs(heartbeat_interval.into()));
618
619 loop {
620 heartbeat_timer.tick().await;
621
622 let session_guard = session_hb.lock().await;
623 if session_guard.state() != SessionState::Active {
624 break;
625 }
626
627 let heartbeat = session_guard.create_heartbeat(None);
628 drop(session_guard);
629
630 let transport_guard = transport_hb.lock().await;
631 if let Some(ref t) = *transport_guard {
632 if let Err(e) = t.send(&heartbeat).await {
633 tracing::warn!("Failed to send heartbeat: {}", e);
634 break;
635 }
636 tracing::debug!("Sent heartbeat");
637 } else {
638 break;
639 }
640 }
641 });
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::FixVersion;

    /// Dummy credentials; no test here opens a real connection.
    fn test_credentials() -> Credentials {
        Credentials::new(String::from("test_key"), String::from("test_secret"))
    }

    /// A freshly constructed client reports `Disconnected`.
    #[tokio::test]
    async fn test_client_creation() {
        let client = FixClient::new(
            test_credentials(),
            FixConfig::builder()
                .version(FixVersion::Fix44)
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .build(),
        );

        assert_eq!(client.state().await, SessionState::Disconnected);
    }

    /// Sending an order before `connect` is rejected.
    #[tokio::test]
    async fn test_send_order_requires_active_session() {
        let client = FixClient::new(
            test_credentials(),
            FixConfig::builder()
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .build(),
        );
        let order = NewOrderSingle::market("AAPL", Side::Buy, 100.0);

        assert!(client.send_order(&order).await.is_err());
    }
}