deribit_fix/connection/tcp_connection.rs

1//! Connection management for Deribit FIX client
2
3use crate::model::message::FixMessage;
4use crate::model::stream::Stream;
5use crate::{
6    config::DeribitFixConfig,
7    error::{DeribitFixError, Result},
8};
9use std::collections::VecDeque;
10use std::str::FromStr;
11use tokio::{net::TcpStream, time::timeout};
12use tokio_native_tls::TlsConnector;
13use tracing::{debug, error, info, trace};
14
15/// TCP/TLS connection to Deribit FIX server
16pub struct Connection {
17    stream: Stream,
18    config: DeribitFixConfig,
19    buffer: Vec<u8>,
20    message_queue: VecDeque<FixMessage>,
21    connected: bool,
22}
23
24impl Connection {
25    /// Create a new connection to the Deribit FIX server
26    pub async fn new(config: &DeribitFixConfig) -> Result<Self> {
27        let stream = if config.use_ssl {
28            Self::connect_tls(config).await?
29        } else {
30            Self::connect_tcp(config).await?
31        };
32
33        Ok(Self {
34            stream,
35            config: config.clone(),
36            buffer: Vec::with_capacity(8192),
37            message_queue: VecDeque::new(),
38            connected: true,
39        })
40    }
41
42    /// Connect using raw TCP
43    async fn connect_tcp(config: &DeribitFixConfig) -> Result<Stream> {
44        info!("Connecting to {}:{} via TCP", config.host, config.port);
45
46        let addr = format!("{}:{}", config.host, config.port);
47        let stream = timeout(config.connection_timeout, TcpStream::connect(&addr))
48            .await
49            .map_err(|_| DeribitFixError::Timeout(format!("Connection timeout to {addr}")))?
50            .map_err(|e| {
51                DeribitFixError::Connection(format!("Failed to connect to {addr}: {e}"))
52            })?;
53
54        info!("Successfully connected via TCP");
55        Ok(Stream::Tcp(stream))
56    }
57
58    /// Connect using TLS
59    async fn connect_tls(config: &DeribitFixConfig) -> Result<Stream> {
60        info!("Connecting to {}:{} via TLS", config.host, config.port);
61
62        let addr = format!("{}:{}", config.host, config.port);
63        let tcp_stream = timeout(config.connection_timeout, TcpStream::connect(&addr))
64            .await
65            .map_err(|_| DeribitFixError::Timeout(format!("Connection timeout to {addr}")))?
66            .map_err(|e| {
67                DeribitFixError::Connection(format!("Failed to connect to {addr}: {e}"))
68            })?;
69
70        let connector = TlsConnector::from(
71            native_tls::TlsConnector::builder()
72                .build()
73                .map_err(|e| DeribitFixError::Connection(format!("TLS setup failed: {e}")))?,
74        );
75
76        let tls_stream = connector
77            .connect(&config.host, tcp_stream)
78            .await
79            .map_err(|e| DeribitFixError::Connection(format!("TLS handshake failed: {e}")))?;
80
81        info!("Successfully connected via TLS");
82        Ok(Stream::Tls(tls_stream))
83    }
84
85    /// Send a FIX message
86    pub async fn send_message(&mut self, message: &FixMessage) -> Result<()> {
87        if !self.connected {
88            return Err(DeribitFixError::Connection(
89                "Connection is not active".to_string(),
90            ));
91        }
92
93        let message_str = message.to_string();
94        debug!("Sending FIX message: {}", message_str);
95
96        match self.stream.write_all(message_str.as_bytes()).await {
97            Ok(_) => {}
98            Err(e) => {
99                error!("Failed to send message: {}", e);
100                self.connected = false;
101                return Err(DeribitFixError::Io(e));
102            }
103        }
104
105        match self.stream.flush().await {
106            Ok(_) => Ok(()),
107            Err(e) => {
108                error!("Failed to flush stream: {}", e);
109                self.connected = false;
110                Err(DeribitFixError::Io(e))
111            }
112        }
113    }
114
115    /// Receive a FIX message from the server
116    pub async fn receive_message(&mut self) -> Result<Option<FixMessage>> {
117        if !self.connected {
118            return Err(DeribitFixError::Connection(
119                "Not connected to server".to_string(),
120            ));
121        }
122
123        // Check if we have queued messages first
124        if let Some(message) = self.message_queue.pop_front() {
125            return Ok(Some(message));
126        }
127
128        // Try to parse any existing buffered data first
129        self.parse_all_messages_from_buffer()?;
130        if let Some(message) = self.message_queue.pop_front() {
131            return Ok(Some(message));
132        }
133
134        // Read data from the stream with timeout
135        let mut temp_buffer = vec![0u8; 4096];
136
137        // Use a timeout to avoid blocking indefinitely
138        match tokio::time::timeout(
139            std::time::Duration::from_millis(1000), // Increased to 1 second
140            self.stream.read(&mut temp_buffer),
141        )
142        .await
143        {
144            Ok(Ok(0)) => {
145                // Connection closed
146                debug!("Connection closed by server");
147                self.connected = false;
148                Ok(None)
149            }
150            Ok(Ok(n)) => {
151                trace!("Received {} bytes from server", n);
152                trace!("Raw bytes: {:?}", &temp_buffer[..n]);
153                self.buffer.extend_from_slice(&temp_buffer[..n]);
154
155                // Parse all complete messages from buffer and queue them
156                self.parse_all_messages_from_buffer()?;
157
158                // Return the first message from queue
159                Ok(self.message_queue.pop_front())
160            }
161            Ok(Err(e)) => {
162                if e.kind() == std::io::ErrorKind::WouldBlock {
163                    // No data available right now
164                    return Ok(None);
165                }
166                error!("IO error reading from server: {}", e);
167                // Mark connection as inactive on IO errors
168                self.connected = false;
169                Err(DeribitFixError::Io(e))
170            }
171            Err(_) => {
172                // Timeout - no data available
173                Ok(None)
174            }
175        }
176    }
177
178    /// Parse all complete messages from buffer and add to queue
179    fn parse_all_messages_from_buffer(&mut self) -> Result<()> {
180        while let Some(message) = self.try_parse_message()? {
181            self.message_queue.push_back(message);
182        }
183        Ok(())
184    }
185
186    /// Try to parse a complete FIX message from the buffer
187    fn try_parse_message(&mut self) -> Result<Option<FixMessage>> {
188        if !self.buffer.is_empty() {
189            trace!(
190                "Buffer contains {} bytes: {:?}",
191                self.buffer.len(),
192                String::from_utf8_lossy(&self.buffer)
193            );
194        }
195
196        // Look for SOH (Start of Header) character which delimits FIX fields
197        const SOH: u8 = 0x01;
198
199        // Find the beginning of a FIX message (looking for BeginString field)
200        let buffer_str = String::from_utf8_lossy(&self.buffer);
201
202        // Look for the start of a FIX message with BeginString (8=FIX.4.4)
203        if let Some(msg_start) = buffer_str.find("8=FIX.4.4") {
204            // For FIX messages, we need to check the BodyLength (tag 9) to know the complete message size
205            let message_from_start = &buffer_str[msg_start..];
206
207            // Parse the BodyLength to determine the complete message size
208            if let Some(body_length_start) = message_from_start.find("9=")
209                && let Some(body_length_end) =
210                    message_from_start[body_length_start + 2..].find(char::from(SOH))
211            {
212                let body_length_str = &message_from_start
213                    [body_length_start + 2..body_length_start + 2 + body_length_end];
214                if let Ok(body_length) = body_length_str.parse::<usize>() {
215                    // Calculate the total message length:
216                    // "8=FIX.4.4\x01" + body_length + checksum field
217                    let header_length = body_length_start + 2 + body_length_end + 1; // Up to and including SOH after BodyLength
218                    let expected_total_length = msg_start + header_length + body_length;
219
220                    // Check if we have the complete message
221                    if self.buffer.len() >= expected_total_length {
222                        let message_bytes = self
223                            .buffer
224                            .drain(msg_start..expected_total_length)
225                            .collect::<Vec<u8>>();
226                        let message_str = String::from_utf8_lossy(&message_bytes);
227
228                        debug!(
229                            "Received complete FIX message ({} bytes): {}",
230                            message_bytes.len(),
231                            message_str
232                        );
233
234                        // Parse the message
235                        match FixMessage::from_str(&message_str) {
236                            Ok(message) => return Ok(Some(message)),
237                            Err(e) => {
238                                return Err(DeribitFixError::MessageParsing(format!(
239                                    "Failed to parse FIX message: {e}"
240                                )));
241                            }
242                        }
243                    } else {
244                        debug!(
245                            "Incomplete message: have {} bytes, need {}",
246                            self.buffer.len(),
247                            expected_total_length
248                        );
249                        return Ok(None);
250                    }
251                }
252            }
253
254            // Fallback to old checksum-based parsing if BodyLength parsing fails
255            if let Some(checksum_pos) = message_from_start.find("10=") {
256                debug!(
257                    "Found checksum field at position {}",
258                    msg_start + checksum_pos
259                );
260
261                // Find the SOH after the checksum (should be 3 digits + SOH)
262                let checksum_section = &message_from_start[checksum_pos..];
263                if let Some(end_pos) = checksum_section.find(char::from(SOH)) {
264                    // Make sure we have the full 3-digit checksum
265                    if end_pos >= 4 {
266                        // "10=" + 3 digits = 7 chars minimum, but we'll be more lenient
267                        let message_end = msg_start + checksum_pos + end_pos + 1;
268                        let message_bytes = self
269                            .buffer
270                            .drain(msg_start..message_end)
271                            .collect::<Vec<u8>>();
272                        let message_str = String::from_utf8_lossy(&message_bytes);
273
274                        debug!("Received FIX message (fallback): {}", message_str);
275
276                        // Parse the message
277                        match FixMessage::from_str(&message_str) {
278                            Ok(message) => Ok(Some(message)),
279                            Err(e) => Err(DeribitFixError::MessageParsing(format!(
280                                "Failed to parse FIX message: {e}"
281                            ))),
282                        }
283                    } else {
284                        // Incomplete checksum
285                        Ok(None)
286                    }
287                } else {
288                    // Incomplete message - checksum field found but no terminating SOH
289                    Ok(None)
290                }
291            } else {
292                // No complete message yet - found start but no checksum
293                Ok(None)
294            }
295        } else {
296            // No message start found yet - might be just leftover data or waiting for more
297            // Clear any non-message data from the beginning of buffer
298            if !buffer_str.is_empty() && !buffer_str.starts_with("8=FIX") {
299                // Find if there's a message start somewhere in the buffer
300                if let Some(start_pos) = buffer_str.find("8=FIX") {
301                    // Remove garbage data before the message start
302                    debug!(
303                        "Removing {} bytes of garbage data before message start",
304                        start_pos
305                    );
306                    self.buffer.drain(..start_pos);
307                } else {
308                    // No message start found, could be fragment - keep if small, discard if too large
309                    if self.buffer.len() > 1000 {
310                        debug!(
311                            "Clearing large buffer ({} bytes) with no message start",
312                            self.buffer.len()
313                        );
314                        self.buffer.clear();
315                    } else if self.buffer.len() > 10 && !buffer_str.trim().is_empty() {
316                        // Check if this looks like invalid data (not starting with FIX fields)
317                        let trimmed = buffer_str.trim();
318                        // Valid FIX fragments should contain field numbers like "10=", "35=", etc.
319                        // or be very short (under certain threshold)
320                        if !trimmed.contains('=')
321                            || (!trimmed.starts_with(char::is_numeric) && self.buffer.len() > 20)
322                        {
323                            // This looks like invalid data, not a FIX message fragment
324                            return Err(DeribitFixError::MessageParsing(format!(
325                                "Failed to parse invalid message data: {}",
326                                trimmed
327                            )));
328                        }
329                    }
330                    // Keep smaller fragments or those that look like valid FIX field fragments
331                }
332            }
333            Ok(None)
334        }
335    }
336
337    /// Check if the connection is active
338    pub fn is_connected(&self) -> bool {
339        self.connected
340    }
341
342    /// Close the connection
343    pub async fn close(&mut self) -> Result<()> {
344        self.connected = false;
345        self.stream.shutdown().await.map_err(DeribitFixError::Io)?;
346        info!("Connection closed");
347        Ok(())
348    }
349
350    /// Reconnect to the server
351    pub async fn reconnect(&mut self) -> Result<()> {
352        info!("Reconnecting to Deribit FIX server");
353
354        // Close existing connection
355        let _ = self.close().await;
356
357        // Create new connection
358        let stream = if self.config.use_ssl {
359            Self::connect_tls(&self.config).await?
360        } else {
361            Self::connect_tcp(&self.config).await?
362        };
363
364        self.stream = stream;
365        self.buffer.clear();
366        self.message_queue.clear();
367        self.connected = true;
368
369        info!("Successfully reconnected");
370        Ok(())
371    }
372}