alpaca_fix/
client.rs

1//! FIX protocol client implementation.
2
3use crate::codec::{FixDecoder, FixMessage, tags};
4use crate::config::FixConfig;
5use crate::error::{FixError, Result};
6use crate::messages::{
7    ExecType, ExecutionReport, MarketDataRequest, MsgType, NewOrderSingle, OrdStatus,
8    OrderCancelReplaceRequest, OrderCancelRequest, Side,
9};
10use crate::session::{FixSession, SessionState};
11use crate::transport::{self, FixTransport};
12use alpaca_base::Credentials;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::{Mutex, mpsc};
16use tokio::time::{interval, timeout};
17
/// Channel buffer size for incoming messages.
///
/// Used as the capacity of the bounded `mpsc` channel that the background
/// receiver task fills and `FixClient::next_message` drains.
const MESSAGE_CHANNEL_SIZE: usize = 1000;

/// Default timeout for operations in seconds.
///
/// In this file it bounds the wait for the server's reply to the Logon
/// message in `FixClient::connect`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
23
/// FIX protocol client for Alpaca.
///
/// The session and transport live behind `Arc<Mutex<..>>` so that the
/// background receiver and heartbeat tasks spawned after logon can share
/// them with the caller-facing methods.
pub struct FixClient {
    /// Alpaca credentials.
    // Stored by `new` but not read anywhere in this file, hence the allowance.
    #[allow(dead_code)]
    credentials: Credentials,
    /// FIX session.
    ///
    /// Holds the session state and encodes outgoing messages.
    session: Arc<Mutex<FixSession>>,
    /// TCP transport.
    ///
    /// `None` until `connect` stores a transport; `disconnect` takes it back out.
    transport: Arc<Mutex<Option<FixTransport>>>,
    /// Message decoder.
    // Constructed by `new` but not used by any method in this file.
    #[allow(dead_code)]
    decoder: FixDecoder,
    /// Configuration.
    ///
    /// Supplies host, port and the heartbeat interval.
    config: FixConfig,
    /// Incoming message receiver.
    ///
    /// `None` until the background tasks are started.
    message_rx: Arc<Mutex<Option<mpsc::Receiver<FixMessage>>>>,
    /// Shutdown signal sender.
    ///
    /// `None` until the background tasks are started; consumed by `disconnect`.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
43
44impl std::fmt::Debug for FixClient {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("FixClient")
47            .field("config", &self.config)
48            .finish()
49    }
50}
51
52impl FixClient {
53    /// Create a new FIX client.
54    #[must_use]
55    pub fn new(credentials: Credentials, config: FixConfig) -> Self {
56        let session = FixSession::new(config.clone());
57        Self {
58            credentials,
59            session: Arc::new(Mutex::new(session)),
60            transport: Arc::new(Mutex::new(None)),
61            decoder: FixDecoder::new(),
62            config,
63            message_rx: Arc::new(Mutex::new(None)),
64            shutdown_tx: Arc::new(Mutex::new(None)),
65        }
66    }
67
68    /// Get the current session state.
69    pub async fn state(&self) -> SessionState {
70        self.session.lock().await.state()
71    }
72
73    /// Connect to the FIX server and establish a session.
74    ///
75    /// # Errors
76    /// Returns error if connection or logon fails.
77    pub async fn connect(&self) -> Result<()> {
78        let mut session = self.session.lock().await;
79        session.set_state(SessionState::Connecting);
80
81        // Establish TCP connection
82        tracing::info!(
83            "Connecting to FIX server at {}:{}",
84            self.config.host,
85            self.config.port
86        );
87
88        let tcp_transport = transport::connect(&self.config.host, self.config.port).await?;
89
90        // Store transport
91        {
92            let mut transport_guard = self.transport.lock().await;
93            *transport_guard = Some(tcp_transport);
94        }
95
96        session.set_state(SessionState::LoggingOn);
97
98        // Send logon message
99        let logon = session.create_logon();
100        self.send_raw(&logon).await?;
101
102        // Wait for logon response
103        let logon_response = self
104            .receive_with_timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
105            .await?;
106
107        // Validate logon response
108        if let Some(msg_type) = logon_response.msg_type() {
109            match MsgType::from_fix_str(msg_type) {
110                Some(MsgType::Logon) => {
111                    tracing::info!("Logon successful");
112                    session.set_state(SessionState::Active);
113                }
114                Some(MsgType::Logout) => {
115                    let text = logon_response.get(tags::TEXT).unwrap_or("unknown reason");
116                    session.set_state(SessionState::Disconnected);
117                    return Err(FixError::Authentication(format!(
118                        "logon rejected: {}",
119                        text
120                    )));
121                }
122                _ => {
123                    return Err(FixError::Session(format!(
124                        "unexpected response to logon: {:?}",
125                        msg_type
126                    )));
127                }
128            }
129        } else {
130            return Err(FixError::InvalidMessage(
131                "missing MsgType in response".to_string(),
132            ));
133        }
134
135        // Start background tasks
136        self.start_background_tasks().await;
137
138        tracing::info!("FIX session established");
139        Ok(())
140    }
141
142    /// Disconnect from the FIX server.
143    ///
144    /// # Errors
145    /// Returns error if disconnect fails.
146    pub async fn disconnect(&self) -> Result<()> {
147        // Send shutdown signal to background tasks
148        if let Some(tx) = self.shutdown_tx.lock().await.take() {
149            let _ = tx.send(()).await;
150        }
151
152        let mut session = self.session.lock().await;
153
154        if session.state() == SessionState::Active {
155            session.set_state(SessionState::LoggingOut);
156
157            // Send logout message
158            let logout = session.create_logout(None);
159            if let Err(e) = self.send_raw(&logout).await {
160                tracing::warn!("Failed to send logout: {}", e);
161            }
162
163            // Wait briefly for logout response
164            if let Ok(response) = self.receive_with_timeout(Duration::from_secs(5)).await
165                && let Some(msg_type) = response.msg_type()
166                && MsgType::from_fix_str(msg_type) == Some(MsgType::Logout)
167            {
168                tracing::info!("Logout confirmed by server");
169            }
170        }
171
172        // Close transport
173        if let Some(transport) = self.transport.lock().await.take() {
174            let _ = transport.close().await;
175        }
176
177        session.set_state(SessionState::Disconnected);
178        tracing::info!("FIX session terminated");
179
180        Ok(())
181    }
182
183    /// Send a new order.
184    ///
185    /// # Arguments
186    /// * `order` - New order single message
187    ///
188    /// # Errors
189    /// Returns error if order submission fails.
190    pub async fn send_order(&self, order: &NewOrderSingle) -> Result<String> {
191        let session = self.session.lock().await;
192
193        if session.state() != SessionState::Active {
194            return Err(FixError::Session("session not active".to_string()));
195        }
196
197        let fields = self.build_new_order_fields(order);
198        let msg = session.encode_message(MsgType::NewOrderSingle.as_str(), &fields);
199        drop(session);
200
201        self.send_raw(&msg).await?;
202
203        tracing::debug!("Sent new order: cl_ord_id={}", order.cl_ord_id);
204        Ok(order.cl_ord_id.clone())
205    }
206
207    /// Cancel an order.
208    ///
209    /// # Arguments
210    /// * `cancel` - Order cancel request
211    ///
212    /// # Errors
213    /// Returns error if cancel request fails.
214    pub async fn cancel_order(&self, cancel: &OrderCancelRequest) -> Result<String> {
215        let session = self.session.lock().await;
216
217        if session.state() != SessionState::Active {
218            return Err(FixError::Session("session not active".to_string()));
219        }
220
221        let fields = vec![
222            (tags::ORIG_CL_ORD_ID, cancel.orig_cl_ord_id.clone()),
223            (tags::CL_ORD_ID, cancel.cl_ord_id.clone()),
224            (tags::SYMBOL, cancel.symbol.clone()),
225            (tags::SIDE, cancel.side.as_char().to_string()),
226        ];
227
228        let msg = session.encode_message(MsgType::OrderCancelRequest.as_str(), &fields);
229        drop(session);
230
231        self.send_raw(&msg).await?;
232
233        tracing::debug!("Sent cancel request: cl_ord_id={}", cancel.cl_ord_id);
234        Ok(cancel.cl_ord_id.clone())
235    }
236
237    /// Replace an order.
238    ///
239    /// # Arguments
240    /// * `replace` - Order cancel/replace request
241    ///
242    /// # Errors
243    /// Returns error if replace request fails.
244    pub async fn replace_order(&self, replace: &OrderCancelReplaceRequest) -> Result<String> {
245        let session = self.session.lock().await;
246
247        if session.state() != SessionState::Active {
248            return Err(FixError::Session("session not active".to_string()));
249        }
250
251        let mut fields = vec![
252            (tags::ORIG_CL_ORD_ID, replace.orig_cl_ord_id.clone()),
253            (tags::CL_ORD_ID, replace.cl_ord_id.clone()),
254            (tags::SYMBOL, replace.symbol.clone()),
255            (tags::SIDE, replace.side.as_char().to_string()),
256            (tags::ORD_TYPE, replace.ord_type.as_char().to_string()),
257            (tags::ORDER_QTY, replace.order_qty.to_string()),
258        ];
259
260        if let Some(price) = replace.price {
261            fields.push((tags::PRICE, price.to_string()));
262        }
263
264        let msg = session.encode_message(MsgType::OrderCancelReplaceRequest.as_str(), &fields);
265        drop(session);
266
267        self.send_raw(&msg).await?;
268
269        tracing::debug!("Sent replace request: cl_ord_id={}", replace.cl_ord_id);
270        Ok(replace.cl_ord_id.clone())
271    }
272
273    /// Request market data.
274    ///
275    /// # Arguments
276    /// * `request` - Market data request
277    ///
278    /// # Errors
279    /// Returns error if request fails.
280    pub async fn request_market_data(&self, request: &MarketDataRequest) -> Result<String> {
281        let session = self.session.lock().await;
282
283        if session.state() != SessionState::Active {
284            return Err(FixError::Session("session not active".to_string()));
285        }
286
287        let fields = vec![
288            (tags::MD_REQ_ID, request.md_req_id.clone()),
289            (
290                tags::SUBSCRIPTION_REQUEST_TYPE,
291                request.subscription_request_type.to_string(),
292            ),
293            (tags::MARKET_DEPTH, request.market_depth.to_string()),
294        ];
295
296        let msg = session.encode_message(MsgType::MarketDataRequest.as_str(), &fields);
297        drop(session);
298
299        self.send_raw(&msg).await?;
300
301        tracing::debug!("Sent market data request: md_req_id={}", request.md_req_id);
302        Ok(request.md_req_id.clone())
303    }
304
305    /// Receive the next message from the server.
306    ///
307    /// # Errors
308    /// Returns error if no message is available or channel is closed.
309    pub async fn next_message(&self) -> Result<FixMessage> {
310        let mut rx_guard = self.message_rx.lock().await;
311        if let Some(ref mut rx) = *rx_guard {
312            rx.recv()
313                .await
314                .ok_or_else(|| FixError::Connection("message channel closed".to_string()))
315        } else {
316            Err(FixError::Session("not connected".to_string()))
317        }
318    }
319
320    /// Process an incoming message.
321    ///
322    /// # Arguments
323    /// * `msg` - FIX message
324    ///
325    /// # Errors
326    /// Returns error if message processing fails.
327    pub async fn process_message(&self, msg: &FixMessage) -> Result<()> {
328        let mut session = self.session.lock().await;
329        session.validate_sequence(msg)?;
330
331        // Handle session-level messages
332        if let Some(msg_type) = msg.msg_type() {
333            match MsgType::from_fix_str(msg_type) {
334                Some(MsgType::Heartbeat) => {
335                    tracing::debug!("Received heartbeat");
336                }
337                Some(MsgType::TestRequest) => {
338                    if let Some(test_req_id) = msg.get(tags::TEST_REQ_ID) {
339                        let heartbeat = session.create_heartbeat(Some(test_req_id));
340                        drop(session);
341                        self.send_raw(&heartbeat).await?;
342                        tracing::debug!("Sent heartbeat response");
343                    }
344                }
345                Some(MsgType::Logout) => {
346                    session.set_state(SessionState::Disconnected);
347                    tracing::info!("Received logout from server");
348                }
349                Some(MsgType::ResendRequest) => {
350                    tracing::warn!("Resend request received - not fully implemented");
351                    // TODO: Implement message resend
352                }
353                Some(MsgType::SequenceReset) => {
354                    if let Some(new_seq) = msg.get(tags::MSG_SEQ_NUM)
355                        && let Ok(seq) = new_seq.parse::<u64>()
356                    {
357                        session.seq_nums().set_incoming(seq);
358                        tracing::info!("Sequence reset to {}", seq);
359                    }
360                }
361                _ => {}
362            }
363        }
364
365        Ok(())
366    }
367
368    /// Parse an execution report from a FIX message.
369    ///
370    /// # Arguments
371    /// * `msg` - FIX message
372    ///
373    /// # Errors
374    /// Returns error if parsing fails.
375    pub fn parse_execution_report(&self, msg: &FixMessage) -> Result<ExecutionReport> {
376        let order_id = msg
377            .get(tags::ORDER_ID)
378            .ok_or_else(|| FixError::InvalidMessage("missing OrderID".to_string()))?
379            .to_string();
380
381        let cl_ord_id = msg
382            .get(tags::CL_ORD_ID)
383            .ok_or_else(|| FixError::InvalidMessage("missing ClOrdID".to_string()))?
384            .to_string();
385
386        let exec_id = msg
387            .get(tags::EXEC_ID)
388            .ok_or_else(|| FixError::InvalidMessage("missing ExecID".to_string()))?
389            .to_string();
390
391        let exec_type_char = msg
392            .get(tags::EXEC_TYPE)
393            .and_then(|s| s.chars().next())
394            .ok_or_else(|| FixError::InvalidMessage("missing ExecType".to_string()))?;
395
396        let exec_type = ExecType::from_char(exec_type_char)
397            .ok_or_else(|| FixError::InvalidMessage("invalid ExecType".to_string()))?;
398
399        let ord_status_char = msg
400            .get(tags::ORD_STATUS)
401            .and_then(|s| s.chars().next())
402            .ok_or_else(|| FixError::InvalidMessage("missing OrdStatus".to_string()))?;
403
404        let ord_status = OrdStatus::from_char(ord_status_char)
405            .ok_or_else(|| FixError::InvalidMessage("invalid OrdStatus".to_string()))?;
406
407        let symbol = msg
408            .get(tags::SYMBOL)
409            .ok_or_else(|| FixError::InvalidMessage("missing Symbol".to_string()))?
410            .to_string();
411
412        let side_char = msg
413            .get(tags::SIDE)
414            .and_then(|s| s.chars().next())
415            .ok_or_else(|| FixError::InvalidMessage("missing Side".to_string()))?;
416
417        let side = Side::from_char(side_char)
418            .ok_or_else(|| FixError::InvalidMessage("invalid Side".to_string()))?;
419
420        let order_qty: f64 = msg
421            .get(tags::ORDER_QTY)
422            .ok_or_else(|| FixError::InvalidMessage("missing OrderQty".to_string()))?
423            .parse()
424            .map_err(|_| FixError::Decoding("invalid OrderQty".to_string()))?;
425
426        let cum_qty: f64 = msg.get(tags::CUM_QTY).unwrap_or("0").parse().unwrap_or(0.0);
427
428        let avg_px: f64 = msg.get(tags::AVG_PX).unwrap_or("0").parse().unwrap_or(0.0);
429
430        let leaves_qty: f64 = msg
431            .get(tags::LEAVES_QTY)
432            .unwrap_or("0")
433            .parse()
434            .unwrap_or(0.0);
435
436        let last_qty = msg.get(tags::LAST_QTY).and_then(|s| s.parse().ok());
437        let last_px = msg.get(tags::LAST_PX).and_then(|s| s.parse().ok());
438        let text = msg.get(tags::TEXT).map(String::from);
439
440        Ok(ExecutionReport {
441            order_id,
442            cl_ord_id,
443            exec_id,
444            exec_type,
445            ord_status,
446            symbol,
447            side,
448            order_qty,
449            last_qty,
450            last_px,
451            cum_qty,
452            avg_px,
453            leaves_qty,
454            text,
455        })
456    }
457
458    /// Build FIX fields for a new order.
459    fn build_new_order_fields(&self, order: &NewOrderSingle) -> Vec<(u32, String)> {
460        let mut fields = vec![
461            (tags::CL_ORD_ID, order.cl_ord_id.clone()),
462            (tags::SYMBOL, order.symbol.clone()),
463            (tags::SIDE, order.side.as_char().to_string()),
464            (tags::ORD_TYPE, order.ord_type.as_char().to_string()),
465            (tags::ORDER_QTY, order.order_qty.to_string()),
466            (
467                tags::TIME_IN_FORCE,
468                order.time_in_force.as_char().to_string(),
469            ),
470        ];
471
472        if let Some(price) = order.price {
473            fields.push((tags::PRICE, price.to_string()));
474        }
475
476        if let Some(stop_px) = order.stop_px {
477            fields.push((tags::STOP_PX, stop_px.to_string()));
478        }
479
480        if let Some(ref account) = order.account {
481            fields.push((tags::ACCOUNT, account.clone()));
482        }
483
484        fields
485    }
486
487    /// Send a raw FIX message over the transport.
488    async fn send_raw(&self, message: &str) -> Result<()> {
489        let transport_guard = self.transport.lock().await;
490        if let Some(ref transport) = *transport_guard {
491            transport.send(message).await
492        } else {
493            Err(FixError::Connection("not connected".to_string()))
494        }
495    }
496
497    /// Receive a message with timeout.
498    async fn receive_with_timeout(&self, duration: Duration) -> Result<FixMessage> {
499        let transport_guard = self.transport.lock().await;
500        if transport_guard.is_some() {
501            drop(transport_guard);
502
503            let transport_clone = self.transport.clone();
504            timeout(duration, async move {
505                let guard = transport_clone.lock().await;
506                if let Some(ref t) = *guard {
507                    t.receive().await
508                } else {
509                    Err(FixError::Connection("not connected".to_string()))
510                }
511            })
512            .await
513            .map_err(|_| FixError::Timeout("receive timeout".to_string()))?
514        } else {
515            Err(FixError::Connection("not connected".to_string()))
516        }
517    }
518
519    /// Start background tasks for heartbeat and message receiving.
520    async fn start_background_tasks(&self) {
521        let (msg_tx, msg_rx) = mpsc::channel(MESSAGE_CHANNEL_SIZE);
522        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
523
524        // Store receivers
525        *self.message_rx.lock().await = Some(msg_rx);
526        *self.shutdown_tx.lock().await = Some(shutdown_tx);
527
528        // Clone references for background tasks
529        let transport = Arc::clone(&self.transport);
530        let session = Arc::clone(&self.session);
531        let heartbeat_interval = self.config.heartbeat_interval_secs;
532
533        // Spawn message receiver task
534        let transport_recv = Arc::clone(&transport);
535        let session_recv = Arc::clone(&session);
536        let msg_tx_clone = msg_tx.clone();
537
538        tokio::spawn(async move {
539            loop {
540                tokio::select! {
541                    _ = shutdown_rx.recv() => {
542                        tracing::debug!("Message receiver shutting down");
543                        break;
544                    }
545                    result = async {
546                        let guard = transport_recv.lock().await;
547                        if guard.is_some() {
548                            drop(guard);
549                            let guard2 = transport_recv.lock().await;
550                            if let Some(ref t) = *guard2 {
551                                t.receive().await
552                            } else {
553                                Err(FixError::Connection("disconnected".to_string()))
554                            }
555                        } else {
556                            // Not connected, wait a bit
557                            tokio::time::sleep(Duration::from_millis(100)).await;
558                            Err(FixError::Connection("not connected".to_string()))
559                        }
560                    } => {
561                        match result {
562                            Ok(msg) => {
563                                // Process session-level messages
564                                if let Some(msg_type) = msg.msg_type() {
565                                    match MsgType::from_fix_str(msg_type) {
566                                        Some(MsgType::TestRequest) => {
567                                            // Respond to test request with heartbeat
568                                            if let Some(test_req_id) = msg.get(tags::TEST_REQ_ID) {
569                                                let session_guard = session_recv.lock().await;
570                                                let heartbeat = session_guard.create_heartbeat(Some(test_req_id));
571                                                drop(session_guard);
572
573                                                let transport_guard = transport_recv.lock().await;
574                                                if let Some(ref t) = *transport_guard {
575                                                    let _ = t.send(&heartbeat).await;
576                                                }
577                                            }
578                                        }
579                                        Some(MsgType::Logout) => {
580                                            let mut session_guard = session_recv.lock().await;
581                                            session_guard.set_state(SessionState::Disconnected);
582                                            tracing::info!("Server initiated logout");
583                                        }
584                                        _ => {}
585                                    }
586                                }
587
588                                // Forward message to channel
589                                if msg_tx_clone.send(msg).await.is_err() {
590                                    tracing::debug!("Message channel closed");
591                                    break;
592                                }
593                            }
594                            Err(FixError::Connection(_)) => {
595                                // Connection lost
596                                let mut session_guard = session_recv.lock().await;
597                                if session_guard.state() == SessionState::Active {
598                                    session_guard.set_state(SessionState::Disconnected);
599                                    tracing::warn!("Connection lost");
600                                }
601                                break;
602                            }
603                            Err(e) => {
604                                tracing::error!("Error receiving message: {}", e);
605                            }
606                        }
607                    }
608                }
609            }
610        });
611
612        // Spawn heartbeat task
613        let transport_hb = Arc::clone(&transport);
614        let session_hb = Arc::clone(&session);
615
616        tokio::spawn(async move {
617            let mut heartbeat_timer = interval(Duration::from_secs(heartbeat_interval.into()));
618
619            loop {
620                heartbeat_timer.tick().await;
621
622                let session_guard = session_hb.lock().await;
623                if session_guard.state() != SessionState::Active {
624                    break;
625                }
626
627                let heartbeat = session_guard.create_heartbeat(None);
628                drop(session_guard);
629
630                let transport_guard = transport_hb.lock().await;
631                if let Some(ref t) = *transport_guard {
632                    if let Err(e) = t.send(&heartbeat).await {
633                        tracing::warn!("Failed to send heartbeat: {}", e);
634                        break;
635                    }
636                    tracing::debug!("Sent heartbeat");
637                } else {
638                    break;
639                }
640            }
641        });
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::FixVersion;

    /// Placeholder credentials for tests.
    fn test_credentials() -> Credentials {
        let key = "test_key".to_string();
        let secret = "test_secret".to_string();
        Credentials::new(key, secret)
    }

    /// A freshly constructed client reports a disconnected session.
    #[tokio::test]
    async fn test_client_creation() {
        let client = FixClient::new(
            test_credentials(),
            FixConfig::builder()
                .version(FixVersion::Fix44)
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .build(),
        );

        let observed = client.state().await;
        assert_eq!(observed, SessionState::Disconnected);
    }

    /// Submitting an order before `connect` is rejected.
    #[tokio::test]
    async fn test_send_order_requires_active_session() {
        let client = FixClient::new(
            test_credentials(),
            FixConfig::builder()
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .build(),
        );
        let order = NewOrderSingle::market("AAPL", Side::Buy, 100.0);

        let result = client.send_order(&order).await;
        assert!(result.is_err());
    }
}