alpaca_fix/
session.rs

1//! FIX session management.
2
3use crate::codec::{FixEncoder, FixMessage, tags};
4use crate::config::FixConfig;
5use crate::error::{FixError, Result};
6use crate::messages::MsgType;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
/// FIX session state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Disconnected.
    Disconnected,
    /// Connecting.
    Connecting,
    /// Logging on.
    LoggingOn,
    /// Active session.
    Active,
    /// Logging out.
    LoggingOut,
}

impl SessionState {
    /// Return the variant name as a static string.
    ///
    /// This is exactly the text produced by the `Display` impl, exposed so
    /// callers (logging, metrics labels) can use it without allocating.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Disconnected => "Disconnected",
            Self::Connecting => "Connecting",
            Self::LoggingOn => "LoggingOn",
            Self::Active => "Active",
            Self::LoggingOut => "LoggingOut",
        }
    }
}

impl std::fmt::Display for SessionState {
    /// Write the variant name.
    ///
    /// Uses `Formatter::pad` rather than `write!` so width, fill and
    /// alignment flags (e.g. `{:>12}` in aligned log output) are honored;
    /// with no flags the output is just the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
36
/// Sequence number manager.
///
/// Tracks the next outgoing `MsgSeqNum` and the next expected incoming one.
/// Both live in `AtomicU64`s (accessed with `SeqCst`), so every operation
/// works through a shared `&self` reference. A fresh instance starts both
/// counters at 1.
#[derive(Debug)]
pub struct SequenceNumbers {
    /// Outgoing sequence number (the value the next sent message will carry).
    outgoing: AtomicU64,
    /// Expected incoming sequence number.
    incoming: AtomicU64,
}

impl SequenceNumbers {
    /// Value both counters hold after construction or `reset`.
    const FIRST: u64 = 1;

    /// Build a manager with both counters at 1.
    #[must_use]
    pub fn new() -> Self {
        Self {
            outgoing: AtomicU64::new(Self::FIRST),
            incoming: AtomicU64::new(Self::FIRST),
        }
    }

    /// Hand out the current outgoing number and advance it by one.
    ///
    /// The read and the increment happen in one atomic `fetch_add`, so two
    /// callers never receive the same value.
    pub fn next_outgoing(&self) -> u64 {
        self.outgoing.fetch_add(1, Ordering::SeqCst)
    }

    /// Peek at the outgoing number without consuming it.
    #[must_use]
    pub fn current_outgoing(&self) -> u64 {
        self.outgoing.load(Ordering::SeqCst)
    }

    /// Peek at the incoming number the peer is expected to send next.
    #[must_use]
    pub fn expected_incoming(&self) -> u64 {
        self.incoming.load(Ordering::SeqCst)
    }

    /// Advance the expected incoming number by one.
    pub fn increment_incoming(&self) {
        let _previous = self.incoming.fetch_add(1, Ordering::SeqCst);
    }

    /// Overwrite the expected incoming number with `seq`.
    pub fn set_incoming(&self, seq: u64) {
        self.incoming.store(seq, Ordering::SeqCst);
    }

    /// Put both counters back to 1 (outgoing first, then incoming).
    pub fn reset(&self) {
        self.outgoing.store(Self::FIRST, Ordering::SeqCst);
        self.incoming.store(Self::FIRST, Ordering::SeqCst);
    }
}

impl Default for SequenceNumbers {
    /// Same as [`SequenceNumbers::new`]: both counters at 1.
    fn default() -> Self {
        SequenceNumbers::new()
    }
}
95
96/// FIX session.
97#[derive(Debug)]
98pub struct FixSession {
99    /// Session configuration.
100    config: FixConfig,
101    /// Session state.
102    state: SessionState,
103    /// Sequence numbers.
104    seq_nums: Arc<SequenceNumbers>,
105    /// Message encoder.
106    encoder: FixEncoder,
107}
108
109impl FixSession {
110    /// Create a new session.
111    #[must_use]
112    pub fn new(config: FixConfig) -> Self {
113        let encoder = FixEncoder::new(
114            config.version,
115            &config.sender_comp_id,
116            &config.target_comp_id,
117        );
118        Self {
119            config,
120            state: SessionState::Disconnected,
121            seq_nums: Arc::new(SequenceNumbers::new()),
122            encoder,
123        }
124    }
125
126    /// Get session state.
127    #[must_use]
128    pub fn state(&self) -> SessionState {
129        self.state
130    }
131
132    /// Set session state.
133    pub fn set_state(&mut self, state: SessionState) {
134        tracing::info!("Session state: {} -> {}", self.state, state);
135        self.state = state;
136    }
137
138    /// Get sequence numbers.
139    #[must_use]
140    pub fn seq_nums(&self) -> Arc<SequenceNumbers> {
141        Arc::clone(&self.seq_nums)
142    }
143
144    /// Create logon message.
145    #[must_use]
146    pub fn create_logon(&self) -> String {
147        let mut fields = vec![
148            (tags::ENCRYPT_METHOD, "0".to_string()),
149            (
150                tags::HEART_BT_INT,
151                self.config.heartbeat_interval_secs.to_string(),
152            ),
153        ];
154
155        if self.config.reset_on_logon {
156            fields.push((tags::RESET_SEQ_NUM_FLAG, "Y".to_string()));
157        }
158
159        self.encoder.encode(
160            MsgType::Logon.as_str(),
161            self.seq_nums.next_outgoing(),
162            &fields,
163        )
164    }
165
166    /// Create logout message.
167    #[must_use]
168    pub fn create_logout(&self, text: Option<&str>) -> String {
169        let fields = if let Some(t) = text {
170            vec![(tags::TEXT, t.to_string())]
171        } else {
172            vec![]
173        };
174
175        self.encoder.encode(
176            MsgType::Logout.as_str(),
177            self.seq_nums.next_outgoing(),
178            &fields,
179        )
180    }
181
182    /// Create heartbeat message.
183    #[must_use]
184    pub fn create_heartbeat(&self, test_req_id: Option<&str>) -> String {
185        let fields = if let Some(id) = test_req_id {
186            vec![(tags::TEST_REQ_ID, id.to_string())]
187        } else {
188            vec![]
189        };
190
191        self.encoder.encode(
192            MsgType::Heartbeat.as_str(),
193            self.seq_nums.next_outgoing(),
194            &fields,
195        )
196    }
197
198    /// Create test request message.
199    #[must_use]
200    pub fn create_test_request(&self, test_req_id: &str) -> String {
201        let fields = vec![(tags::TEST_REQ_ID, test_req_id.to_string())];
202
203        self.encoder.encode(
204            MsgType::TestRequest.as_str(),
205            self.seq_nums.next_outgoing(),
206            &fields,
207        )
208    }
209
210    /// Create resend request message.
211    #[must_use]
212    pub fn create_resend_request(&self, begin_seq: u64, end_seq: u64) -> String {
213        let fields = vec![
214            (tags::BEGIN_SEQ_NO, begin_seq.to_string()),
215            (tags::END_SEQ_NO, end_seq.to_string()),
216        ];
217
218        self.encoder.encode(
219            MsgType::ResendRequest.as_str(),
220            self.seq_nums.next_outgoing(),
221            &fields,
222        )
223    }
224
225    /// Validate incoming message sequence number.
226    pub fn validate_sequence(&self, msg: &FixMessage) -> Result<()> {
227        if let Some(seq_str) = msg.get(tags::MSG_SEQ_NUM) {
228            let seq: u64 = seq_str
229                .parse()
230                .map_err(|_| FixError::Decoding("invalid sequence number".to_string()))?;
231
232            let expected = self.seq_nums.expected_incoming();
233
234            if seq < expected {
235                // Duplicate or old message
236                tracing::warn!("Received old message: seq={}, expected={}", seq, expected);
237            } else if seq > expected {
238                // Gap detected
239                return Err(FixError::SequenceError {
240                    expected,
241                    actual: seq,
242                });
243            }
244
245            self.seq_nums.increment_incoming();
246        }
247
248        Ok(())
249    }
250
251    /// Encode a message with session headers.
252    pub fn encode_message(&self, msg_type: &str, fields: &[(u32, String)]) -> String {
253        self.encoder
254            .encode(msg_type, self.seq_nums.next_outgoing(), fields)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Session configured with only the sender/target comp IDs.
    fn sender_target_session() -> FixSession {
        FixSession::new(
            FixConfig::builder()
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .build(),
        )
    }

    #[test]
    fn test_sequence_numbers() {
        let seq = SequenceNumbers::new();

        // Each call hands out the current outgoing value, then advances it.
        for want in 1..=2 {
            assert_eq!(seq.next_outgoing(), want);
        }
        assert_eq!(seq.current_outgoing(), 3);

        // Incoming starts at 1 and moves only when told to.
        assert_eq!(seq.expected_incoming(), 1);
        seq.increment_incoming();
        assert_eq!(seq.expected_incoming(), 2);

        // Reset puts both counters back to 1.
        seq.reset();
        assert_eq!(
            (seq.current_outgoing(), seq.expected_incoming()),
            (1, 1)
        );
    }

    #[test]
    fn test_session_state() {
        let mut session = sender_target_session();

        assert_eq!(session.state(), SessionState::Disconnected);
        session.set_state(SessionState::Active);
        assert_eq!(session.state(), SessionState::Active);
    }

    #[test]
    fn test_create_logon() {
        let logon = FixSession::new(
            FixConfig::builder()
                .sender_comp_id("SENDER")
                .target_comp_id("TARGET")
                .heartbeat_interval_secs(30)
                .build(),
        )
        .create_logon();

        assert!(logon.contains("35=A"), "missing MsgType=A in {:?}", logon);
        assert!(logon.contains("108=30"), "missing HeartBtInt=30 in {:?}", logon);
    }

    #[test]
    fn test_create_heartbeat() {
        let heartbeat = sender_target_session().create_heartbeat(Some("TEST123"));

        assert!(heartbeat.contains("35=0"), "missing MsgType=0 in {:?}", heartbeat);
        assert!(
            heartbeat.contains("112=TEST123"),
            "missing TestReqID in {:?}",
            heartbeat
        );
    }
}