deribit_fix/session/
fix_session.rs

1//! FIX session management
2
3use crate::config::gen_id;
4use crate::model::message::FixMessage;
5use crate::model::types::MsgType;
6use crate::{
7    config::DeribitFixConfig,
8    connection::Connection,
9    error::{DeribitFixError, Result},
10    message::{MessageBuilder, PositionReport, RequestForPositions},
11};
12use base64::prelude::*;
13use chrono::Utc;
14use deribit_base::prelude::*;
15use rand;
16use sha2::{Digest, Sha256};
17use std::str::FromStr;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20use tracing::{debug, error, info, trace};
21
/// Lifecycle state of a FIX [`Session`].
///
/// A freshly created session is `Disconnected`. Writing a Logon moves it to
/// `LogonSent`; processing the counterparty's Logon moves it to `LoggedOn`.
/// Writing a Logout moves it to `LogoutSent`, and processing a received
/// Logout brings it back to `Disconnected`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// No FIX session is active.
    Disconnected,
    /// Our Logon has been written; the counterparty's Logon has not been processed yet.
    LogonSent,
    /// A Logon response has been received and processed.
    LoggedOn,
    /// Our Logout has been written; awaiting the counterparty's confirmation.
    LogoutSent,
}
34
35/// FIX session manager
36pub struct Session {
37    config: DeribitFixConfig,
38    connection: Option<Arc<Mutex<Connection>>>,
39    state: SessionState,
40    outgoing_seq_num: u32,
41    incoming_seq_num: u32,
42}
43
44impl Session {
45    /// Create a new FIX session
46    pub fn new(config: &DeribitFixConfig, connection: Arc<Mutex<Connection>>) -> Result<Self> {
47        info!("Creating new FIX session");
48        Ok(Self {
49            config: config.clone(),
50            state: SessionState::Disconnected,
51            outgoing_seq_num: 1,
52            incoming_seq_num: 1,
53            connection: Some(connection),
54        })
55    }
56
57    /// Set the connection for this session
58    pub fn set_connection(&mut self, connection: Arc<Mutex<Connection>>) {
59        self.connection = Some(connection);
60    }
61
62    /// Get the current session state
63    pub fn get_state(&self) -> SessionState {
64        self.state
65    }
66
67    /// Send a FIX message through the connection
68    async fn send_message(&mut self, message: FixMessage) -> Result<()> {
69        if let Some(connection) = &self.connection {
70            let mut conn_guard = connection.lock().await;
71            conn_guard.send_message(&message).await?;
72            debug!("Sent FIX message: {}", message.to_string());
73        } else {
74            return Err(DeribitFixError::Connection(
75                "No connection available".to_string(),
76            ));
77        }
78        Ok(())
79    }
80
81    /// Perform FIX logon
82    pub async fn logon(&mut self) -> Result<()> {
83        info!("Performing FIX logon");
84
85        // Generate RawData and password hash according to Deribit FIX spec
86        let (raw_data, password_hash) = self.generate_auth_data(&self.config.password)?;
87
88        let mut message_builder = MessageBuilder::new()
89            .msg_type(MsgType::Logon)
90            .sender_comp_id(self.config.sender_comp_id.clone())
91            .target_comp_id(self.config.target_comp_id.clone())
92            .msg_seq_num(self.outgoing_seq_num)
93            .field(108, self.config.heartbeat_interval.to_string()) // HeartBtInt - Required
94            .field(96, raw_data.clone()) // RawData - Required (timestamp.nonce)
95            .field(553, self.config.username.clone()) // Username - Required
96            .field(554, password_hash); // Password - Required
97
98        // Add RawDataLength if needed (optional but recommended)
99        message_builder = message_builder.field(95, raw_data.len().to_string()); // RawDataLength
100
101        // Add optional Deribit-specific tags based on configuration
102        if let Some(use_wordsafe_tags) = &self.config.use_wordsafe_tags {
103            message_builder =
104                message_builder.field(9002, if *use_wordsafe_tags { "Y" } else { "N" }.to_string()); // UseWordsafeTags
105        }
106
107        // CancelOnDisconnect - always include based on config
108        message_builder = message_builder.field(
109            9001,
110            if self.config.cancel_on_disconnect {
111                "Y"
112            } else {
113                "N"
114            }
115            .to_string(),
116        ); // CancelOnDisconnect
117
118        if let Some(app_id) = &self.config.app_id {
119            message_builder = message_builder.field(9004, app_id.clone()); // DeribitAppId
120        }
121
122        if let Some(app_secret) = &self.config.app_secret
123            && let Some(raw_data_str) = raw_data
124                .split_once('.')
125                .map(|(timestamp, nonce)| format!("{}.{}", timestamp, nonce))
126            && let Ok(app_sig) = self.calculate_app_signature(&raw_data_str, app_secret)
127        {
128            message_builder = message_builder.field(9005, app_sig); // DeribitAppSig
129        }
130
131        if let Some(deribit_sequential) = &self.config.deribit_sequential {
132            message_builder = message_builder.field(
133                9007,
134                if *deribit_sequential { "Y" } else { "N" }.to_string(),
135            ); // DeribitSequential
136        }
137
138        if let Some(unsubscribe_exec_reports) = &self.config.unsubscribe_execution_reports {
139            message_builder = message_builder.field(
140                9009,
141                if *unsubscribe_exec_reports { "Y" } else { "N" }.to_string(),
142            ); // UnsubscribeExecutionReports
143        }
144
145        if let Some(connection_only_exec_reports) = &self.config.connection_only_execution_reports {
146            message_builder = message_builder.field(
147                9010,
148                if *connection_only_exec_reports {
149                    "Y"
150                } else {
151                    "N"
152                }
153                .to_string(),
154            ); // ConnectionOnlyExecutionReports
155        }
156
157        if let Some(report_fills_as_exec_reports) = &self.config.report_fills_as_exec_reports {
158            message_builder = message_builder.field(
159                9015,
160                if *report_fills_as_exec_reports {
161                    "Y"
162                } else {
163                    "N"
164                }
165                .to_string(),
166            ); // ReportFillsAsExecReports
167        }
168
169        if let Some(display_increment_steps) = &self.config.display_increment_steps {
170            message_builder = message_builder.field(
171                9018,
172                if *display_increment_steps { "Y" } else { "N" }.to_string(),
173            ); // DisplayIncrementSteps
174        }
175
176        // Add AppID if available
177        if let Some(app_id) = &self.config.app_id {
178            message_builder = message_builder.field(1128, app_id.clone()); // AppID
179        }
180
181        let logon_message = message_builder.build()?;
182
183        // Send the logon message
184        self.send_message(logon_message).await?;
185        self.state = SessionState::LogonSent;
186        self.outgoing_seq_num += 1;
187
188        info!("Logon message sent");
189        Ok(())
190    }
191
192    /// Perform FIX logout
193    pub async fn logout(&mut self) -> Result<()> {
194        self.logout_with_options(None, None).await
195    }
196
197    /// Perform FIX logout with optional parameters
198    pub async fn logout_with_options(
199        &mut self,
200        text: Option<String>,
201        dont_cancel_on_disconnect: Option<bool>,
202    ) -> Result<()> {
203        info!("Performing FIX logout");
204
205        let mut message_builder = MessageBuilder::new()
206            .msg_type(MsgType::Logout)
207            .sender_comp_id(self.config.sender_comp_id.clone())
208            .target_comp_id(self.config.target_comp_id.clone())
209            .msg_seq_num(self.outgoing_seq_num);
210
211        // Add Text field (tag 58) - optional
212        let logout_text = text.unwrap_or_else(|| "Normal logout".to_string());
213        message_builder = message_builder.field(58, logout_text); // Text
214
215        // Add DontCancelOnDisconnect field (tag 9003) - optional
216        if let Some(dont_cancel) = dont_cancel_on_disconnect {
217            message_builder =
218                message_builder.field(9003, if dont_cancel { "Y" } else { "N" }.to_string()); // DontCancelOnDisconnect
219        }
220
221        let logout_message = message_builder.build()?;
222
223        // Send the logout message
224        self.send_message(logout_message).await?;
225        self.state = SessionState::LogoutSent;
226        self.outgoing_seq_num += 1;
227
228        info!("Logout message sent");
229        Ok(())
230    }
231
232    /// Send a heartbeat message
233    pub async fn send_heartbeat(&mut self, test_req_id: Option<String>) -> Result<()> {
234        debug!("Sending heartbeat message");
235
236        let mut builder = MessageBuilder::new()
237            .msg_type(MsgType::Heartbeat)
238            .sender_comp_id(self.config.sender_comp_id.clone())
239            .target_comp_id(self.config.target_comp_id.clone())
240            .msg_seq_num(self.outgoing_seq_num);
241
242        if let Some(test_req_id) = test_req_id {
243            builder = builder.field(112, test_req_id); // TestReqID
244        }
245
246        let heartbeat_message = builder.build()?;
247
248        // Send the heartbeat message
249        self.send_message(heartbeat_message).await?;
250        self.outgoing_seq_num += 1;
251
252        debug!("Heartbeat message sent");
253        Ok(())
254    }
255
256    /// Send a new order
257    pub async fn send_new_order(&mut self, order: NewOrderRequest) -> Result<String> {
258        info!("Sending new order: {:?}", order);
259
260        // Use the client order ID if provided, otherwise generate one
261        let order_id = order
262            .client_order_id
263            .clone()
264            .unwrap_or_else(|| format!("ORDER_{}", gen_id()));
265
266        // Determine order type
267        let ord_type = match order.order_type {
268            OrderType::Market => "1",     // Market
269            OrderType::Limit => "2",      // Limit
270            OrderType::StopLimit => "4",  // Stop Limit
271            OrderType::StopMarket => "3", // Stop Market
272            _ => "2",                     // Default to Limit for other types
273        };
274
275        let mut builder = MessageBuilder::new()
276            .msg_type(MsgType::NewOrderSingle)
277            .sender_comp_id(self.config.sender_comp_id.clone())
278            .target_comp_id(self.config.target_comp_id.clone())
279            .msg_seq_num(self.outgoing_seq_num)
280            .field(11, order_id.clone()) // ClOrdID
281            .field(55, order.instrument_name.clone()) // Symbol
282            .field(
283                54,
284                match order.side {
285                    deribit_base::model::order::OrderSide::Buy => "1".to_string(),
286                    deribit_base::model::order::OrderSide::Sell => "2".to_string(),
287                },
288            ) // Side
289            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()) // TransactTime
290            .field(38, order.amount.to_string()) // OrderQty
291            .field(40, ord_type.to_string()); // OrdType
292
293        // Add price for limit orders
294        if order.order_type == OrderType::Limit || order.price.is_some() {
295            builder = builder.field(44, order.price.unwrap_or(0.0).to_string()); // Price
296        }
297
298        // Add time in force
299        let tif = match order.time_in_force {
300            TimeInForce::GoodTilCancelled => "1",
301            TimeInForce::ImmediateOrCancel => "3",
302            TimeInForce::FillOrKill => "4",
303            TimeInForce::GoodTilDay => "6",
304        };
305        builder = builder.field(59, tif.to_string()); // TimeInForce
306
307        // Add execution instructions
308        let mut exec_inst = String::new();
309        if order.post_only == Some(true) {
310            exec_inst.push('6'); // Participate don't initiate
311        }
312        if order.reduce_only == Some(true) {
313            exec_inst.push('E'); // Reduce only
314        }
315        if !exec_inst.is_empty() {
316            builder = builder.field(18, exec_inst); // ExecInst
317        }
318
319        // Add label if provided
320        if let Some(label) = &order.label {
321            builder = builder.field(100010, label.clone()); // Deribit label
322        }
323
324        let order_message = builder.build()?;
325
326        // Actually send the message
327        self.send_message(order_message).await?;
328        self.outgoing_seq_num += 1;
329
330        info!("New order message sent with ID: {}", order_id);
331        Ok(order_id)
332    }
333
334    /// Cancel an order
335    ///
336    /// # Arguments
337    /// * `order_id` - The order identifier (OrigClOrdID) to cancel
338    /// * `symbol` - Optional instrument symbol. Required when canceling by ClOrdID or DeribitLabel,
339    ///   but not required when using OrigClOrdID (fastest approach)
340    /// * `currency` - Optional currency to speed up search when using ClOrdID or DeribitLabel
341    pub async fn cancel_order(&mut self, order_id: String) -> Result<()> {
342        self.cancel_order_with_symbol(order_id, None).await
343    }
344
345    /// Cancel an order with optional symbol specification
346    ///
347    /// According to Deribit FIX documentation:
348    /// - Canceling by OrigClOrdId is fastest and recommended when possible
349    /// - Symbol is required only when OrigClOrdId is absent (canceling by ClOrdID or DeribitLabel)
350    /// - Currency can optionally speed up searches by DeribitLabel or ClOrdID
351    ///
352    /// # Arguments
353    /// * `order_id` - The order identifier (OrigClOrdID) to cancel
354    /// * `symbol` - Optional instrument symbol (e.g., "BTC-PERPETUAL")
355    /// * `currency` - Optional currency to speed up search
356    pub async fn cancel_order_with_symbol(
357        &mut self,
358        order_id: String,
359        symbol: Option<String>,
360    ) -> Result<()> {
361        info!("Cancelling order: {} with symbol: {:?}", order_id, symbol);
362
363        // Generate a proper unique cancel ID using random number instead of timestamp
364        let cancel_id = format!("CANCEL_{}", gen_id());
365
366        let mut builder = MessageBuilder::new()
367            .msg_type(MsgType::OrderCancelRequest)
368            .sender_comp_id(self.config.sender_comp_id.clone())
369            .target_comp_id(self.config.target_comp_id.clone())
370            .msg_seq_num(self.outgoing_seq_num)
371            .field(11, cancel_id) // ClOrdID - Original order identifier assigned by the user
372            .field(41, order_id) // OrigClOrdID - Order identifier assigned by Deribit
373            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()); // TransactTime
374
375        // Add symbol if provided - required when OrigClOrdId is absent
376        if let Some(symbol_value) = symbol {
377            builder = builder.field(55, symbol_value); // Symbol
378        }
379
380        let cancel_message = builder.build()?;
381
382        // Actually send the cancel message
383        self.send_message(cancel_message).await?;
384        self.outgoing_seq_num += 1;
385
386        info!("Order cancel message sent");
387        Ok(())
388    }
389
390    /// Subscribe to market data
391    pub async fn subscribe_market_data(&mut self, symbol: String) -> Result<()> {
392        info!("Subscribing to market data for: {}", symbol);
393
394        let request_id = format!("MDR_{}", gen_id());
395
396        let market_data_request = MessageBuilder::new()
397            .msg_type(MsgType::MarketDataRequest)
398            .sender_comp_id(self.config.sender_comp_id.clone())
399            .target_comp_id(self.config.target_comp_id.clone())
400            .msg_seq_num(self.outgoing_seq_num)
401            .field(262, request_id.clone()) // MDReqID
402            .field(263, "1".to_string()) // SubscriptionRequestType (1 = Snapshot + Updates)
403            .field(264, "0".to_string()) // MarketDepth (0 = Full Book)
404            .field(267, "2".to_string()) // NoMDEntryTypes
405            .field(269, "0".to_string()) // MDEntryType (0 = Bid)
406            .field(269, "1".to_string()) // MDEntryType (1 = Offer)
407            .field(146, "1".to_string()) // NoRelatedSym
408            .field(55, symbol.clone()) // Symbol
409            .build()?;
410
411        // Send the market data request
412        self.send_message(market_data_request).await?;
413        self.outgoing_seq_num += 1;
414
415        info!(
416            "Market data subscription request sent for symbol: {} with ID: {}",
417            symbol, request_id
418        );
419        Ok(())
420    }
421
422    /// Request positions asynchronously
423    pub async fn request_positions(&mut self) -> Result<Vec<Position>> {
424        use std::time::{Duration, Instant};
425        use tracing::{debug, info, warn};
426
427        info!("Requesting positions");
428
429        let request_id = format!("POS_{}", gen_id());
430
431        // Create typed position request
432        let position_request = RequestForPositions::all_positions(request_id.clone())
433            .with_clearing_date(Utc::now().format("%Y%m%d").to_string());
434
435        // Build the FIX message
436        let fix_message = position_request.to_fix_message(
437            self.config.sender_comp_id.clone(),
438            self.config.target_comp_id.clone(),
439            self.outgoing_seq_num,
440        )?;
441
442        // Send the position request
443        self.send_message(fix_message).await?;
444        self.outgoing_seq_num += 1;
445
446        info!(
447            "Position request sent, awaiting responses for request ID: {}",
448            request_id
449        );
450
451        // Collect position reports with correlation by PosReqID
452        let mut positions = Vec::new();
453        let timeout = Duration::from_secs(30); // 30 second timeout
454        let start_time = Instant::now();
455
456        loop {
457            // Check for timeout
458            if start_time.elapsed() > timeout {
459                warn!("Position request timed out after {:?}", timeout);
460                break;
461            }
462
463            // Receive and process messages
464            match self.receive_and_process_message().await {
465                Ok(Some(message)) => {
466                    // Check if this is a PositionReport message
467                    if let Some(msg_type_str) = message.get_field(35)
468                        && msg_type_str == "AP"
469                    {
470                        // PositionReport
471                        // Check if this position report matches our request ID
472                        if let Some(pos_req_id) = message.get_field(710) {
473                            if pos_req_id == &request_id {
474                                debug!("Received PositionReport for request: {}", request_id);
475
476                                // Check if this is an empty position report (no instrument name)
477                                if message.get_field(55).is_none() {
478                                    info!(
479                                        "Received empty position report - indicates no active positions for this request"
480                                    );
481                                    debug!(
482                                        "Empty PositionReport details: PosReqID={}, PosMaintRptID={:?}",
483                                        request_id,
484                                        message.get_field(721)
485                                    );
486                                    // This is an end-of-positions marker indicating no positions, continue waiting
487                                    continue;
488                                }
489
490                                match PositionReport::try_from_fix_message(&message) {
491                                    Ok(position) => {
492                                        debug!("Successfully parsed position: {:?}", position);
493                                        positions.push(position);
494                                    }
495                                    Err(e) => {
496                                        warn!("Failed to parse PositionReport: {}", e);
497                                        debug!("Message fields: {:?}", message);
498                                    }
499                                }
500                            } else {
501                                debug!(
502                                    "Received PositionReport for different request: {}",
503                                    pos_req_id
504                                );
505                            }
506                        }
507                    }
508                }
509                Ok(None) => {
510                    // No message received, continue loop
511                    tokio::time::sleep(Duration::from_millis(10)).await;
512                }
513                Err(e) => {
514                    warn!("Error receiving message: {}", e);
515                    // Continue trying to receive more messages
516                    tokio::time::sleep(Duration::from_millis(100)).await;
517                }
518            }
519
520            // For now, we'll break after receiving some positions or after a reasonable time
521            // In a real implementation, we might wait for an end-of-transmission signal
522            if !positions.is_empty() && start_time.elapsed() > Duration::from_secs(5) {
523                debug!(
524                    "Received {} positions, stopping collection",
525                    positions.len()
526                );
527                break;
528            }
529        }
530
531        info!(
532            "Position request completed, received {} positions",
533            positions.len()
534        );
535        Ok(positions)
536    }
537
538    /// Generate authentication data according to Deribit FIX specification
539    /// Returns (raw_data, base64_password_hash)
540    pub fn generate_auth_data(&self, access_secret: &str) -> Result<(String, String)> {
541        // Generate timestamp (strictly increasing integer in milliseconds)
542        let timestamp = Utc::now().timestamp_millis();
543
544        // Generate random nonce (at least 32 bytes as recommended by Deribit)
545        let mut nonce_bytes = vec![0u8; 32];
546        for byte in nonce_bytes.iter_mut() {
547            *byte = rand::random::<u8>();
548        }
549        let nonce_b64 = BASE64_STANDARD.encode(&nonce_bytes);
550
551        // Create RawData: timestamp.nonce (separated by ASCII period)
552        let raw_data = format!("{timestamp}.{nonce_b64}");
553
554        // Calculate password hash: base64(sha256(RawData ++ access_secret))
555        let mut auth_data = raw_data.as_bytes().to_vec();
556        auth_data.extend_from_slice(access_secret.as_bytes());
557
558        debug!("Auth Data at Timestamp: {}", timestamp);
559        trace!("Nonce length: {} bytes", nonce_bytes.len());
560        trace!("Nonce (base64): {}", nonce_b64);
561        trace!("RawData: {}", raw_data);
562        trace!("Access secret: {}", access_secret);
563        trace!("Auth data length: {} bytes", auth_data.len());
564
565        let mut hasher = Sha256::new();
566        hasher.update(&auth_data);
567        let hash_result = hasher.finalize();
568        let password_hash = BASE64_STANDARD.encode(hash_result);
569
570        debug!("Password hash: {}", password_hash);
571
572        Ok((raw_data, password_hash))
573    }
574
575    /// Calculate application signature for registered apps
576    #[allow(dead_code)]
577    fn calculate_app_signature(&self, raw_data: &str, app_secret: &str) -> Result<String> {
578        let mut hasher = Sha256::new();
579        hasher.update(format!("{raw_data}{app_secret}").as_bytes());
580        let result = hasher.finalize();
581        Ok(BASE64_STANDARD.encode(result))
582    }
583
584    /// Get current session state
585    pub fn state(&self) -> SessionState {
586        self.state
587    }
588
589    /// Set session state (for testing)
590    pub fn set_state(&mut self, state: SessionState) {
591        self.state = state;
592    }
593
594    /// Process incoming FIX message
595    async fn process_message(&mut self, message: &FixMessage) -> Result<()> {
596        debug!("Processing FIX message: {:?}", message);
597
598        // Get message type
599        let msg_type_str = message.get_field(35).unwrap_or(&String::new()).clone();
600
601        // Skip messages with empty or missing message type
602        if msg_type_str.is_empty() {
603            debug!("Skipping message with empty message type");
604            return Ok(());
605        }
606
607        let msg_type = MsgType::from_str(&msg_type_str).map_err(|_| {
608            DeribitFixError::MessageParsing(format!("Unknown message type: {msg_type_str}"))
609        })?;
610
611        match msg_type {
612            MsgType::Logon => {
613                info!("Received logon response");
614                self.state = SessionState::LoggedOn;
615            }
616            MsgType::Logout => {
617                info!("Received logout message");
618                self.state = SessionState::Disconnected;
619            }
620            MsgType::Heartbeat => {
621                debug!("Received heartbeat");
622            }
623            MsgType::TestRequest => {
624                debug!("Received test request, sending heartbeat response");
625                let test_req_id = message.get_field(112);
626                self.send_heartbeat(test_req_id.cloned()).await?;
627            }
628            MsgType::ExecutionReport => {
629                debug!("Received ExecutionReport: {:?}", message);
630                // ExecutionReport processing - let the client handle the details
631            }
632            MsgType::PositionReport => {
633                debug!("Received PositionReport: {:?}", message);
634                // PositionReport processing - let the client handle the details
635            }
636            MsgType::Reject => {
637                error!("Received Reject message: {:?}", message);
638                // Reject message processing
639            }
640            _ => {
641                debug!("Received message type: {:?}", msg_type);
642            }
643        }
644
645        self.incoming_seq_num += 1;
646        Ok(())
647    }
648
649    /// Receive and process a FIX message from the connection
650    pub async fn receive_and_process_message(&mut self) -> Result<Option<FixMessage>> {
651        let message = if let Some(connection) = &self.connection {
652            let mut conn_guard = connection.lock().await;
653            conn_guard.receive_message().await?
654        } else {
655            None
656        };
657
658        if let Some(message) = message {
659            self.process_message(&message).await?;
660            Ok(Some(message))
661        } else {
662            Ok(None)
663        }
664    }
665}