deribit_fix/session/fix_session.rs

1//! FIX session management
2
3use crate::model::message::FixMessage;
4use crate::model::types::MsgType;
5use crate::{
6    config::DeribitFixConfig,
7    connection::Connection,
8    error::{DeribitFixError, Result},
9    message::{MessageBuilder, PositionReport, RequestForPositions},
10};
11use base64::prelude::*;
12use chrono::Utc;
13use deribit_base::prelude::*;
14use rand;
15use sha2::{Digest, Sha256};
16use std::str::FromStr;
17use std::sync::Arc;
18use tokio::sync::Mutex;
19use tracing::{debug, info};
20
/// Lifecycle stage of a FIX session, covering the Logon / Logout handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// No FIX session is currently established.
    Disconnected,
    /// A Logon request has gone out and the counterparty's reply is pending.
    LogonSent,
    /// The Logon handshake completed and the session is active.
    LoggedOn,
    /// A Logout request has gone out and confirmation is pending.
    LogoutSent,
}
33
34/// FIX session manager
35pub struct Session {
36    config: DeribitFixConfig,
37    connection: Option<Arc<Mutex<Connection>>>,
38    state: SessionState,
39    outgoing_seq_num: u32,
40    incoming_seq_num: u32,
41}
42
43impl Session {
44    /// Create a new FIX session
45    pub fn new(config: &DeribitFixConfig, connection: Arc<Mutex<Connection>>) -> Result<Self> {
46        info!("Creating new FIX session");
47        Ok(Self {
48            config: config.clone(),
49            state: SessionState::Disconnected,
50            outgoing_seq_num: 1,
51            incoming_seq_num: 1,
52            connection: Some(connection),
53        })
54    }
55
56    /// Set the connection for this session
57    pub fn set_connection(&mut self, connection: Arc<Mutex<Connection>>) {
58        self.connection = Some(connection);
59    }
60
61    /// Get the current session state
62    pub fn get_state(&self) -> SessionState {
63        self.state
64    }
65
66    /// Send a FIX message through the connection
67    async fn send_message(&mut self, message: FixMessage) -> Result<()> {
68        if let Some(connection) = &self.connection {
69            let mut conn_guard = connection.lock().await;
70            conn_guard.send_message(&message).await?;
71            debug!("Sent FIX message: {}", message.to_string());
72        } else {
73            return Err(DeribitFixError::Connection(
74                "No connection available".to_string(),
75            ));
76        }
77        Ok(())
78    }
79
80    /// Perform FIX logon
81    pub async fn logon(&mut self) -> Result<()> {
82        info!("Performing FIX logon");
83
84        // Generate RawData and password hash according to Deribit FIX spec
85        let (raw_data, password_hash) = self.generate_auth_data(&self.config.password)?;
86
87        let mut message_builder = MessageBuilder::new()
88            .msg_type(MsgType::Logon)
89            .sender_comp_id(self.config.sender_comp_id.clone())
90            .target_comp_id(self.config.target_comp_id.clone())
91            .msg_seq_num(self.outgoing_seq_num)
92            .field(108, self.config.heartbeat_interval.to_string()) // HeartBtInt - Required
93            .field(96, raw_data.clone()) // RawData - Required (timestamp.nonce)
94            .field(553, self.config.username.clone()) // Username - Required
95            .field(554, password_hash); // Password - Required
96
97        // Add RawDataLength if needed (optional but recommended)
98        message_builder = message_builder.field(95, raw_data.len().to_string()); // RawDataLength
99
100        // Add optional Deribit-specific tags based on configuration
101        if let Some(use_wordsafe_tags) = &self.config.use_wordsafe_tags {
102            message_builder =
103                message_builder.field(9002, if *use_wordsafe_tags { "Y" } else { "N" }.to_string()); // UseWordsafeTags
104        }
105
106        // CancelOnDisconnect - always include based on config
107        message_builder = message_builder.field(
108            9001,
109            if self.config.cancel_on_disconnect {
110                "Y"
111            } else {
112                "N"
113            }
114            .to_string(),
115        ); // CancelOnDisconnect
116
117        if let Some(app_id) = &self.config.app_id {
118            message_builder = message_builder.field(9004, app_id.clone()); // DeribitAppId
119        }
120
121        if let Some(app_secret) = &self.config.app_secret
122            && let Some(raw_data_str) = raw_data
123                .split_once('.')
124                .map(|(timestamp, nonce)| format!("{}.{}", timestamp, nonce))
125            && let Ok(app_sig) = self.calculate_app_signature(&raw_data_str, app_secret)
126        {
127            message_builder = message_builder.field(9005, app_sig); // DeribitAppSig
128        }
129
130        if let Some(deribit_sequential) = &self.config.deribit_sequential {
131            message_builder = message_builder.field(
132                9007,
133                if *deribit_sequential { "Y" } else { "N" }.to_string(),
134            ); // DeribitSequential
135        }
136
137        if let Some(unsubscribe_exec_reports) = &self.config.unsubscribe_execution_reports {
138            message_builder = message_builder.field(
139                9009,
140                if *unsubscribe_exec_reports { "Y" } else { "N" }.to_string(),
141            ); // UnsubscribeExecutionReports
142        }
143
144        if let Some(connection_only_exec_reports) = &self.config.connection_only_execution_reports {
145            message_builder = message_builder.field(
146                9010,
147                if *connection_only_exec_reports {
148                    "Y"
149                } else {
150                    "N"
151                }
152                .to_string(),
153            ); // ConnectionOnlyExecutionReports
154        }
155
156        if let Some(report_fills_as_exec_reports) = &self.config.report_fills_as_exec_reports {
157            message_builder = message_builder.field(
158                9015,
159                if *report_fills_as_exec_reports {
160                    "Y"
161                } else {
162                    "N"
163                }
164                .to_string(),
165            ); // ReportFillsAsExecReports
166        }
167
168        if let Some(display_increment_steps) = &self.config.display_increment_steps {
169            message_builder = message_builder.field(
170                9018,
171                if *display_increment_steps { "Y" } else { "N" }.to_string(),
172            ); // DisplayIncrementSteps
173        }
174
175        // Add AppID if available - temporarily disabled for testing
176        // if let Some(app_id) = &self.config.app_id {
177        //     message_builder = message_builder.field(1128, app_id.clone()); // AppID
178        // }
179
180        let logon_message = message_builder.build()?;
181
182        // Send the logon message
183        self.send_message(logon_message).await?;
184        self.state = SessionState::LogonSent;
185        self.outgoing_seq_num += 1;
186
187        info!("Logon message sent");
188        Ok(())
189    }
190
191    /// Perform FIX logout
192    pub async fn logout(&mut self) -> Result<()> {
193        self.logout_with_options(None, None).await
194    }
195
196    /// Perform FIX logout with optional parameters
197    pub async fn logout_with_options(
198        &mut self,
199        text: Option<String>,
200        dont_cancel_on_disconnect: Option<bool>,
201    ) -> Result<()> {
202        info!("Performing FIX logout");
203
204        let mut message_builder = MessageBuilder::new()
205            .msg_type(MsgType::Logout)
206            .sender_comp_id(self.config.sender_comp_id.clone())
207            .target_comp_id(self.config.target_comp_id.clone())
208            .msg_seq_num(self.outgoing_seq_num);
209
210        // Add Text field (tag 58) - optional
211        let logout_text = text.unwrap_or_else(|| "Normal logout".to_string());
212        message_builder = message_builder.field(58, logout_text); // Text
213
214        // Add DontCancelOnDisconnect field (tag 9003) - optional
215        if let Some(dont_cancel) = dont_cancel_on_disconnect {
216            message_builder =
217                message_builder.field(9003, if dont_cancel { "Y" } else { "N" }.to_string()); // DontCancelOnDisconnect
218        }
219
220        let logout_message = message_builder.build()?;
221
222        // Send the logout message
223        self.send_message(logout_message).await?;
224        self.state = SessionState::LogoutSent;
225        self.outgoing_seq_num += 1;
226
227        info!("Logout message sent");
228        Ok(())
229    }
230
231    /// Send a heartbeat message
232    pub async fn send_heartbeat(&mut self, test_req_id: Option<String>) -> Result<()> {
233        debug!("Sending heartbeat message");
234
235        let mut builder = MessageBuilder::new()
236            .msg_type(MsgType::Heartbeat)
237            .sender_comp_id(self.config.sender_comp_id.clone())
238            .target_comp_id(self.config.target_comp_id.clone())
239            .msg_seq_num(self.outgoing_seq_num);
240
241        if let Some(test_req_id) = test_req_id {
242            builder = builder.field(112, test_req_id); // TestReqID
243        }
244
245        let heartbeat_message = builder.build()?;
246
247        // Send the heartbeat message
248        self.send_message(heartbeat_message).await?;
249        self.outgoing_seq_num += 1;
250
251        debug!("Heartbeat message sent");
252        Ok(())
253    }
254
255    /// Send a new order
256    pub fn send_new_order(&mut self, order: NewOrderRequest) -> Result<String> {
257        info!("Sending new order: {:?}", order);
258
259        let order_id = format!("ORDER_{}", chrono::Utc::now().timestamp_millis());
260
261        let _order_message = MessageBuilder::new()
262            .msg_type(MsgType::NewOrderSingle)
263            .sender_comp_id(self.config.sender_comp_id.clone())
264            .target_comp_id(self.config.target_comp_id.clone())
265            .msg_seq_num(self.outgoing_seq_num)
266            .field(11, order_id.clone()) // ClOrdID
267            .field(55, order.instrument_name.clone()) // Symbol
268            .field(
269                54,
270                match order.side {
271                    deribit_base::model::order::OrderSide::Buy => "1".to_string(),
272                    deribit_base::model::order::OrderSide::Sell => "2".to_string(),
273                },
274            ) // Side
275            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()) // TransactTime
276            .field(38, order.amount.to_string()) // OrderQty
277            .field(40, "2".to_string()) // OrdType (2 = Limit)
278            .field(44, order.price.unwrap_or(0.0).to_string()) // Price
279            .build()?;
280
281        // In a real implementation, you would send this message
282        self.outgoing_seq_num += 1;
283
284        info!("New order message prepared with ID: {}", order_id);
285        Ok(order_id)
286    }
287
288    /// Cancel an order
289    pub fn cancel_order(&mut self, order_id: String) -> Result<()> {
290        info!("Cancelling order: {}", order_id);
291
292        let cancel_id = format!("CANCEL_{}", chrono::Utc::now().timestamp_millis());
293
294        let _cancel_message = MessageBuilder::new()
295            .msg_type(MsgType::OrderCancelRequest)
296            .sender_comp_id(self.config.sender_comp_id.clone())
297            .target_comp_id(self.config.target_comp_id.clone())
298            .msg_seq_num(self.outgoing_seq_num)
299            .field(11, cancel_id) // ClOrdID
300            .field(41, order_id) // OrigClOrdID
301            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()) // TransactTime
302            .build()?;
303
304        // In a real implementation, you would send this message
305        self.outgoing_seq_num += 1;
306
307        info!("Order cancel message prepared");
308        Ok(())
309    }
310
311    /// Subscribe to market data
312    pub async fn subscribe_market_data(&mut self, symbol: String) -> Result<()> {
313        info!("Subscribing to market data for: {}", symbol);
314
315        let request_id = format!("MDR_{}", chrono::Utc::now().timestamp_millis());
316
317        let market_data_request = MessageBuilder::new()
318            .msg_type(MsgType::MarketDataRequest)
319            .sender_comp_id(self.config.sender_comp_id.clone())
320            .target_comp_id(self.config.target_comp_id.clone())
321            .msg_seq_num(self.outgoing_seq_num)
322            .field(262, request_id.clone()) // MDReqID
323            .field(263, "1".to_string()) // SubscriptionRequestType (1 = Snapshot + Updates)
324            .field(264, "0".to_string()) // MarketDepth (0 = Full Book)
325            .field(267, "2".to_string()) // NoMDEntryTypes
326            .field(269, "0".to_string()) // MDEntryType (0 = Bid)
327            .field(269, "1".to_string()) // MDEntryType (1 = Offer)
328            .field(146, "1".to_string()) // NoRelatedSym
329            .field(55, symbol.clone()) // Symbol
330            .build()?;
331
332        // Send the market data request
333        self.send_message(market_data_request).await?;
334        self.outgoing_seq_num += 1;
335
336        info!(
337            "Market data subscription request sent for symbol: {} with ID: {}",
338            symbol, request_id
339        );
340        Ok(())
341    }
342
343    /// Request positions asynchronously
344    pub async fn request_positions(&mut self) -> Result<Vec<Position>> {
345        use std::time::{Duration, Instant};
346        use tracing::{debug, info, warn};
347
348        info!("Requesting positions");
349
350        let request_id = format!("POS_{}", chrono::Utc::now().timestamp_millis());
351
352        // Create typed position request
353        let position_request = RequestForPositions::all_positions(request_id.clone())
354            .with_clearing_date(Utc::now().format("%Y%m%d").to_string());
355
356        // Build the FIX message
357        let fix_message = position_request.to_fix_message(
358            self.config.sender_comp_id.clone(),
359            self.config.target_comp_id.clone(),
360            self.outgoing_seq_num,
361        )?;
362
363        // Send the position request
364        self.send_message(fix_message).await?;
365        self.outgoing_seq_num += 1;
366
367        info!(
368            "Position request sent, awaiting responses for request ID: {}",
369            request_id
370        );
371
372        // Collect position reports with correlation by PosReqID
373        let mut positions = Vec::new();
374        let timeout = Duration::from_secs(30); // 30 second timeout
375        let start_time = Instant::now();
376
377        loop {
378            // Check for timeout
379            if start_time.elapsed() > timeout {
380                warn!("Position request timed out after {:?}", timeout);
381                break;
382            }
383
384            // Receive and process messages
385            match self.receive_and_process_message().await {
386                Ok(Some(message)) => {
387                    // Check if this is a PositionReport message
388                    if let Some(msg_type_str) = message.get_field(35)
389                        && msg_type_str == "AP"
390                    {
391                        // PositionReport
392                        // Check if this position report matches our request ID
393                        if let Some(pos_req_id) = message.get_field(710) {
394                            if pos_req_id == &request_id {
395                                debug!("Received PositionReport for request: {}", request_id);
396
397                                match PositionReport::from_fix_message(&message) {
398                                    Ok(position_report) => {
399                                        let position = position_report.to_position();
400                                        debug!(
401                                            "Parsed position: {} - Qty: {}, Avg Price: {}",
402                                            position.symbol,
403                                            position.quantity,
404                                            position.average_price
405                                        );
406                                        positions.push(position);
407                                    }
408                                    Err(e) => {
409                                        warn!("Failed to parse PositionReport: {}", e);
410                                    }
411                                }
412                            } else {
413                                debug!(
414                                    "Received PositionReport for different request: {}",
415                                    pos_req_id
416                                );
417                            }
418                        }
419                    }
420                }
421                Ok(None) => {
422                    // No message received, continue loop
423                    tokio::time::sleep(Duration::from_millis(10)).await;
424                }
425                Err(e) => {
426                    warn!("Error receiving message: {}", e);
427                    // Continue trying to receive more messages
428                    tokio::time::sleep(Duration::from_millis(100)).await;
429                }
430            }
431
432            // For now, we'll break after receiving some positions or after a reasonable time
433            // In a real implementation, we might wait for an end-of-transmission signal
434            if !positions.is_empty() && start_time.elapsed() > Duration::from_secs(5) {
435                debug!(
436                    "Received {} positions, stopping collection",
437                    positions.len()
438                );
439                break;
440            }
441        }
442
443        info!(
444            "Position request completed, received {} positions",
445            positions.len()
446        );
447        Ok(positions)
448    }
449
450    /// Generate authentication data according to Deribit FIX specification
451    /// Returns (raw_data, base64_password_hash)
452    pub fn generate_auth_data(&self, access_secret: &str) -> Result<(String, String)> {
453        // Generate timestamp (strictly increasing integer in milliseconds)
454        let timestamp = chrono::Utc::now().timestamp_millis();
455
456        // Generate random nonce (at least 32 bytes as recommended by Deribit)
457        let mut nonce_bytes = vec![0u8; 32];
458        for byte in nonce_bytes.iter_mut() {
459            *byte = rand::random::<u8>();
460        }
461        let nonce_b64 = BASE64_STANDARD.encode(&nonce_bytes);
462
463        // Create RawData: timestamp.nonce (separated by ASCII period)
464        let raw_data = format!("{timestamp}.{nonce_b64}");
465
466        // Calculate password hash: base64(sha256(RawData ++ access_secret))
467        let mut auth_data = raw_data.as_bytes().to_vec();
468        auth_data.extend_from_slice(access_secret.as_bytes());
469
470        debug!("Timestamp: {}", timestamp);
471        debug!("Nonce length: {} bytes", nonce_bytes.len());
472        debug!("Nonce (base64): {}", nonce_b64);
473        debug!("RawData: {}", raw_data);
474        debug!("Access secret: {}", access_secret);
475        debug!("Auth data length: {} bytes", auth_data.len());
476
477        let mut hasher = Sha256::new();
478        hasher.update(&auth_data);
479        let hash_result = hasher.finalize();
480        let password_hash = BASE64_STANDARD.encode(hash_result);
481
482        debug!("Password hash: {}", password_hash);
483
484        Ok((raw_data, password_hash))
485    }
486
487    /// Calculate application signature for registered apps
488    #[allow(dead_code)]
489    fn calculate_app_signature(&self, raw_data: &str, app_secret: &str) -> Result<String> {
490        let mut hasher = Sha256::new();
491        hasher.update(format!("{raw_data}{app_secret}").as_bytes());
492        let result = hasher.finalize();
493        Ok(BASE64_STANDARD.encode(result))
494    }
495
496    /// Get current session state
497    pub fn state(&self) -> SessionState {
498        self.state
499    }
500
501    /// Set session state (for testing)
502    pub fn set_state(&mut self, state: SessionState) {
503        self.state = state;
504    }
505
506    /// Process incoming FIX message
507    async fn process_message(&mut self, message: &FixMessage) -> Result<()> {
508        debug!("Processing FIX message: {:?}", message);
509
510        // Get message type
511        let msg_type_str = message.get_field(35).unwrap_or(&String::new()).clone();
512        let msg_type = MsgType::from_str(&msg_type_str).map_err(|_| {
513            DeribitFixError::MessageParsing(format!("Unknown message type: {msg_type_str}"))
514        })?;
515
516        match msg_type {
517            MsgType::Logon => {
518                info!("Received logon response");
519                self.state = SessionState::LoggedOn;
520            }
521            MsgType::Logout => {
522                info!("Received logout message");
523                self.state = SessionState::Disconnected;
524            }
525            MsgType::Heartbeat => {
526                debug!("Received heartbeat");
527            }
528            MsgType::TestRequest => {
529                debug!("Received test request, sending heartbeat response");
530                let test_req_id = message.get_field(112);
531                self.send_heartbeat(test_req_id.cloned()).await?;
532            }
533            _ => {
534                debug!("Received message type: {:?}", msg_type);
535            }
536        }
537
538        self.incoming_seq_num += 1;
539        Ok(())
540    }
541
542    /// Receive and process a FIX message from the connection
543    pub async fn receive_and_process_message(&mut self) -> Result<Option<FixMessage>> {
544        let message = if let Some(connection) = &self.connection {
545            let mut conn_guard = connection.lock().await;
546            conn_guard.receive_message().await?
547        } else {
548            None
549        };
550
551        if let Some(message) = message {
552            self.process_message(&message).await?;
553            Ok(Some(message))
554        } else {
555            Ok(None)
556        }
557    }
558}