1use crate::codec::{FixEncoder, FixMessage, tags};
4use crate::config::FixConfig;
5use crate::error::{FixError, Result};
6use crate::messages::MsgType;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
/// Lifecycle state of a FIX session.
///
/// A session starts out `Disconnected`; all other transitions are made by the
/// owner of the session via `FixSession::set_state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// No connection; the initial state of a newly created session.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Logon is in progress.
    LoggingOn,
    /// Logged on and running.
    Active,
    /// Logout is in progress.
    LoggingOut,
}

impl std::fmt::Display for SessionState {
    /// Writes the bare variant name (e.g. `Active`); width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            SessionState::Disconnected => "Disconnected",
            SessionState::Connecting => "Connecting",
            SessionState::LoggingOn => "LoggingOn",
            SessionState::Active => "Active",
            SessionState::LoggingOut => "LoggingOut",
        };
        f.write_str(label)
    }
}
36
/// Outgoing and incoming FIX sequence counters for one session.
///
/// Each counter is an independent [`AtomicU64`] accessed with `SeqCst`
/// ordering, so a shared `&SequenceNumbers` (e.g. behind an `Arc`) can be
/// updated from several threads. Operations that touch both counters, such as
/// [`reset`](Self::reset), update them one after the other rather than as a
/// single atomic step.
#[derive(Debug)]
pub struct SequenceNumbers {
    /// Sequence number the next outgoing message will be stamped with.
    outgoing: AtomicU64,
    /// Sequence number expected on the next incoming message.
    incoming: AtomicU64,
}

impl SequenceNumbers {
    /// Value both counters start from and return to on [`reset`](Self::reset).
    const FIRST: u64 = 1;
    /// Memory ordering used for every counter access.
    const ORDER: Ordering = Ordering::SeqCst;

    /// Creates counters with both directions starting at 1.
    #[must_use]
    pub fn new() -> Self {
        Self {
            outgoing: AtomicU64::new(Self::FIRST),
            incoming: AtomicU64::new(Self::FIRST),
        }
    }

    /// Claims an outgoing sequence number: returns the current value and
    /// advances the counter by one.
    pub fn next_outgoing(&self) -> u64 {
        self.outgoing.fetch_add(1, Self::ORDER)
    }

    /// Returns the number the next call to [`next_outgoing`](Self::next_outgoing)
    /// would hand out, without consuming it.
    #[must_use]
    pub fn current_outgoing(&self) -> u64 {
        self.outgoing.load(Self::ORDER)
    }

    /// Returns the sequence number expected on the next incoming message.
    #[must_use]
    pub fn expected_incoming(&self) -> u64 {
        self.incoming.load(Self::ORDER)
    }

    /// Advances the expected incoming sequence number by one.
    pub fn increment_incoming(&self) {
        self.incoming.fetch_add(1, Self::ORDER);
    }

    /// Overwrites the expected incoming sequence number with `seq`.
    pub fn set_incoming(&self, seq: u64) {
        self.incoming.store(seq, Self::ORDER);
    }

    /// Returns both counters to 1 (outgoing first, then incoming).
    pub fn reset(&self) {
        for counter in [&self.outgoing, &self.incoming] {
            counter.store(Self::FIRST, Self::ORDER);
        }
    }
}

impl Default for SequenceNumbers {
    /// Equivalent to [`SequenceNumbers::new`].
    fn default() -> Self {
        SequenceNumbers::new()
    }
}
95
96#[derive(Debug)]
98pub struct FixSession {
99 config: FixConfig,
101 state: SessionState,
103 seq_nums: Arc<SequenceNumbers>,
105 encoder: FixEncoder,
107}
108
109impl FixSession {
110 #[must_use]
112 pub fn new(config: FixConfig) -> Self {
113 let encoder = FixEncoder::new(
114 config.version,
115 &config.sender_comp_id,
116 &config.target_comp_id,
117 );
118 Self {
119 config,
120 state: SessionState::Disconnected,
121 seq_nums: Arc::new(SequenceNumbers::new()),
122 encoder,
123 }
124 }
125
126 #[must_use]
128 pub fn state(&self) -> SessionState {
129 self.state
130 }
131
132 pub fn set_state(&mut self, state: SessionState) {
134 tracing::info!("Session state: {} -> {}", self.state, state);
135 self.state = state;
136 }
137
138 #[must_use]
140 pub fn seq_nums(&self) -> Arc<SequenceNumbers> {
141 Arc::clone(&self.seq_nums)
142 }
143
144 #[must_use]
146 pub fn create_logon(&self) -> String {
147 let mut fields = vec![
148 (tags::ENCRYPT_METHOD, "0".to_string()),
149 (
150 tags::HEART_BT_INT,
151 self.config.heartbeat_interval_secs.to_string(),
152 ),
153 ];
154
155 if self.config.reset_on_logon {
156 fields.push((tags::RESET_SEQ_NUM_FLAG, "Y".to_string()));
157 }
158
159 self.encoder.encode(
160 MsgType::Logon.as_str(),
161 self.seq_nums.next_outgoing(),
162 &fields,
163 )
164 }
165
166 #[must_use]
168 pub fn create_logout(&self, text: Option<&str>) -> String {
169 let fields = if let Some(t) = text {
170 vec![(tags::TEXT, t.to_string())]
171 } else {
172 vec![]
173 };
174
175 self.encoder.encode(
176 MsgType::Logout.as_str(),
177 self.seq_nums.next_outgoing(),
178 &fields,
179 )
180 }
181
182 #[must_use]
184 pub fn create_heartbeat(&self, test_req_id: Option<&str>) -> String {
185 let fields = if let Some(id) = test_req_id {
186 vec![(tags::TEST_REQ_ID, id.to_string())]
187 } else {
188 vec![]
189 };
190
191 self.encoder.encode(
192 MsgType::Heartbeat.as_str(),
193 self.seq_nums.next_outgoing(),
194 &fields,
195 )
196 }
197
198 #[must_use]
200 pub fn create_test_request(&self, test_req_id: &str) -> String {
201 let fields = vec![(tags::TEST_REQ_ID, test_req_id.to_string())];
202
203 self.encoder.encode(
204 MsgType::TestRequest.as_str(),
205 self.seq_nums.next_outgoing(),
206 &fields,
207 )
208 }
209
210 #[must_use]
212 pub fn create_resend_request(&self, begin_seq: u64, end_seq: u64) -> String {
213 let fields = vec![
214 (tags::BEGIN_SEQ_NO, begin_seq.to_string()),
215 (tags::END_SEQ_NO, end_seq.to_string()),
216 ];
217
218 self.encoder.encode(
219 MsgType::ResendRequest.as_str(),
220 self.seq_nums.next_outgoing(),
221 &fields,
222 )
223 }
224
225 pub fn validate_sequence(&self, msg: &FixMessage) -> Result<()> {
227 if let Some(seq_str) = msg.get(tags::MSG_SEQ_NUM) {
228 let seq: u64 = seq_str
229 .parse()
230 .map_err(|_| FixError::Decoding("invalid sequence number".to_string()))?;
231
232 let expected = self.seq_nums.expected_incoming();
233
234 if seq < expected {
235 tracing::warn!("Received old message: seq={}, expected={}", seq, expected);
237 } else if seq > expected {
238 return Err(FixError::SequenceError {
240 expected,
241 actual: seq,
242 });
243 }
244
245 self.seq_nums.increment_incoming();
246 }
247
248 Ok(())
249 }
250
251 pub fn encode_message(&self, msg_type: &str, fields: &[(u32, String)]) -> String {
253 self.encoder
254 .encode(msg_type, self.seq_nums.next_outgoing(), fields)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only the sender/target comp IDs set; everything else defaulted.
    fn sender_target_config() -> FixConfig {
        FixConfig::builder()
            .sender_comp_id("SENDER")
            .target_comp_id("TARGET")
            .build()
    }

    #[test]
    fn test_sequence_numbers() {
        let counters = SequenceNumbers::new();

        // Outgoing numbers are handed out starting from 1.
        for want in 1..=2 {
            assert_eq!(counters.next_outgoing(), want);
        }
        assert_eq!(counters.current_outgoing(), 3);

        assert_eq!(counters.expected_incoming(), 1);
        counters.increment_incoming();
        assert_eq!(counters.expected_incoming(), 2);

        // After a reset both directions are back at 1.
        counters.reset();
        assert_eq!(
            (counters.current_outgoing(), counters.expected_incoming()),
            (1, 1)
        );
    }

    #[test]
    fn test_session_state() {
        let mut session = FixSession::new(sender_target_config());

        assert_eq!(session.state(), SessionState::Disconnected);
        session.set_state(SessionState::Active);
        assert_eq!(session.state(), SessionState::Active);
    }

    #[test]
    fn test_create_logon() {
        let config = FixConfig::builder()
            .sender_comp_id("SENDER")
            .target_comp_id("TARGET")
            .heartbeat_interval_secs(30)
            .build();
        let logon = FixSession::new(config).create_logon();

        for needle in ["35=A", "108=30"] {
            assert!(logon.contains(needle));
        }
    }

    #[test]
    fn test_create_heartbeat() {
        let heartbeat = FixSession::new(sender_target_config()).create_heartbeat(Some("TEST123"));

        for needle in ["35=0", "112=TEST123"] {
            assert!(heartbeat.contains(needle));
        }
    }
}