Skip to main content

deribit_fix/session/
fix_session.rs

1//! FIX session management
2
3use crate::config::gen_id;
4use crate::model::message::FixMessage;
5use crate::model::position::Position;
6use crate::model::request::{NewOrderRequest, OrderSide, OrderType, TimeInForce};
7use crate::model::types::MsgType;
8use crate::{
9    config::DeribitFixConfig,
10    connection::Connection,
11    error::{DeribitFixError, Result},
12    message::{MessageBuilder, PositionReport, RequestForPositions},
13};
14use base64::prelude::*;
15use chrono::Utc;
16use rand;
17use sha2::{Digest, Sha256};
18use std::str::FromStr;
19use std::sync::Arc;
20use tokio::sync::Mutex;
21use tracing::{debug, error, info, trace};
22
/// Lifecycle phases of a FIX session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// No session is established.
    Disconnected,
    /// A Logon message has been sent and its response is still pending.
    LogonSent,
    /// The session is logged on and active.
    LoggedOn,
    /// A Logout message has been sent and its confirmation is still pending.
    LogoutSent,
}
35
36/// FIX session manager
37pub struct Session {
38    config: DeribitFixConfig,
39    connection: Option<Arc<Mutex<Connection>>>,
40    state: SessionState,
41    outgoing_seq_num: u32,
42    incoming_seq_num: u32,
43}
44
45impl Session {
46    /// Create a new FIX session
47    pub fn new(config: &DeribitFixConfig, connection: Arc<Mutex<Connection>>) -> Result<Self> {
48        info!("Creating new FIX session");
49        Ok(Self {
50            config: config.clone(),
51            state: SessionState::Disconnected,
52            outgoing_seq_num: 1,
53            incoming_seq_num: 1,
54            connection: Some(connection),
55        })
56    }
57
58    /// Set the connection for this session
59    pub fn set_connection(&mut self, connection: Arc<Mutex<Connection>>) {
60        self.connection = Some(connection);
61    }
62
63    /// Get the current session state
64    pub fn get_state(&self) -> SessionState {
65        self.state
66    }
67
68    /// Send a FIX message through the connection
69    async fn send_message(&mut self, message: FixMessage) -> Result<()> {
70        if let Some(connection) = &self.connection {
71            let mut conn_guard = connection.lock().await;
72            conn_guard.send_message(&message).await?;
73            debug!("Sent FIX message: {}", message.to_string());
74        } else {
75            return Err(DeribitFixError::Connection(
76                "No connection available".to_string(),
77            ));
78        }
79        Ok(())
80    }
81
82    /// Perform FIX logon
83    pub async fn logon(&mut self) -> Result<()> {
84        info!("Performing FIX logon");
85
86        // Generate RawData and password hash according to Deribit FIX spec
87        let (raw_data, password_hash) = self.generate_auth_data(&self.config.password)?;
88
89        let mut message_builder = MessageBuilder::new()
90            .msg_type(MsgType::Logon)
91            .sender_comp_id(self.config.sender_comp_id.clone())
92            .target_comp_id(self.config.target_comp_id.clone())
93            .msg_seq_num(self.outgoing_seq_num)
94            .field(108, self.config.heartbeat_interval.to_string()) // HeartBtInt - Required
95            .field(96, raw_data.clone()) // RawData - Required (timestamp.nonce)
96            .field(553, self.config.username.clone()) // Username - Required
97            .field(554, password_hash); // Password - Required
98
99        // Add RawDataLength if needed (optional but recommended)
100        message_builder = message_builder.field(95, raw_data.len().to_string()); // RawDataLength
101
102        // Add optional Deribit-specific tags based on configuration
103        if let Some(use_wordsafe_tags) = &self.config.use_wordsafe_tags {
104            message_builder =
105                message_builder.field(9002, if *use_wordsafe_tags { "Y" } else { "N" }.to_string()); // UseWordsafeTags
106        }
107
108        // CancelOnDisconnect - always include based on config
109        message_builder = message_builder.field(
110            9001,
111            if self.config.cancel_on_disconnect {
112                "Y"
113            } else {
114                "N"
115            }
116            .to_string(),
117        ); // CancelOnDisconnect
118
119        if let Some(app_id) = &self.config.app_id {
120            message_builder = message_builder.field(9004, app_id.clone()); // DeribitAppId
121        }
122
123        if let Some(app_secret) = &self.config.app_secret
124            && let Some(raw_data_str) = raw_data
125                .split_once('.')
126                .map(|(timestamp, nonce)| format!("{}.{}", timestamp, nonce))
127            && let Ok(app_sig) = self.calculate_app_signature(&raw_data_str, app_secret)
128        {
129            message_builder = message_builder.field(9005, app_sig); // DeribitAppSig
130        }
131
132        if let Some(deribit_sequential) = &self.config.deribit_sequential {
133            message_builder = message_builder.field(
134                9007,
135                if *deribit_sequential { "Y" } else { "N" }.to_string(),
136            ); // DeribitSequential
137        }
138
139        if let Some(unsubscribe_exec_reports) = &self.config.unsubscribe_execution_reports {
140            message_builder = message_builder.field(
141                9009,
142                if *unsubscribe_exec_reports { "Y" } else { "N" }.to_string(),
143            ); // UnsubscribeExecutionReports
144        }
145
146        if let Some(connection_only_exec_reports) = &self.config.connection_only_execution_reports {
147            message_builder = message_builder.field(
148                9010,
149                if *connection_only_exec_reports {
150                    "Y"
151                } else {
152                    "N"
153                }
154                .to_string(),
155            ); // ConnectionOnlyExecutionReports
156        }
157
158        if let Some(report_fills_as_exec_reports) = &self.config.report_fills_as_exec_reports {
159            message_builder = message_builder.field(
160                9015,
161                if *report_fills_as_exec_reports {
162                    "Y"
163                } else {
164                    "N"
165                }
166                .to_string(),
167            ); // ReportFillsAsExecReports
168        }
169
170        if let Some(display_increment_steps) = &self.config.display_increment_steps {
171            message_builder = message_builder.field(
172                9018,
173                if *display_increment_steps { "Y" } else { "N" }.to_string(),
174            ); // DisplayIncrementSteps
175        }
176
177        // Add AppID if available
178        if let Some(app_id) = &self.config.app_id {
179            message_builder = message_builder.field(1128, app_id.clone()); // AppID
180        }
181
182        let logon_message = message_builder.build()?;
183
184        // Send the logon message
185        self.send_message(logon_message).await?;
186        self.state = SessionState::LogonSent;
187        self.outgoing_seq_num += 1;
188
189        info!("Logon message sent");
190        Ok(())
191    }
192
193    /// Perform FIX logout
194    pub async fn logout(&mut self) -> Result<()> {
195        self.logout_with_options(None, None).await
196    }
197
198    /// Perform FIX logout with optional parameters
199    pub async fn logout_with_options(
200        &mut self,
201        text: Option<String>,
202        dont_cancel_on_disconnect: Option<bool>,
203    ) -> Result<()> {
204        info!("Performing FIX logout");
205
206        let mut message_builder = MessageBuilder::new()
207            .msg_type(MsgType::Logout)
208            .sender_comp_id(self.config.sender_comp_id.clone())
209            .target_comp_id(self.config.target_comp_id.clone())
210            .msg_seq_num(self.outgoing_seq_num);
211
212        // Add Text field (tag 58) - optional
213        let logout_text = text.unwrap_or_else(|| "Normal logout".to_string());
214        message_builder = message_builder.field(58, logout_text); // Text
215
216        // Add DontCancelOnDisconnect field (tag 9003) - optional
217        if let Some(dont_cancel) = dont_cancel_on_disconnect {
218            message_builder =
219                message_builder.field(9003, if dont_cancel { "Y" } else { "N" }.to_string()); // DontCancelOnDisconnect
220        }
221
222        let logout_message = message_builder.build()?;
223
224        // Send the logout message
225        self.send_message(logout_message).await?;
226        self.state = SessionState::LogoutSent;
227        self.outgoing_seq_num += 1;
228
229        info!("Logout message sent");
230        Ok(())
231    }
232
233    /// Send a heartbeat message
234    pub async fn send_heartbeat(&mut self, test_req_id: Option<String>) -> Result<()> {
235        debug!("Sending heartbeat message");
236
237        let mut builder = MessageBuilder::new()
238            .msg_type(MsgType::Heartbeat)
239            .sender_comp_id(self.config.sender_comp_id.clone())
240            .target_comp_id(self.config.target_comp_id.clone())
241            .msg_seq_num(self.outgoing_seq_num);
242
243        if let Some(test_req_id) = test_req_id {
244            builder = builder.field(112, test_req_id); // TestReqID
245        }
246
247        let heartbeat_message = builder.build()?;
248
249        // Send the heartbeat message
250        self.send_message(heartbeat_message).await?;
251        self.outgoing_seq_num += 1;
252
253        debug!("Heartbeat message sent");
254        Ok(())
255    }
256
257    /// Send a new order
258    pub async fn send_new_order(&mut self, order: NewOrderRequest) -> Result<String> {
259        info!("Sending new order: {:?}", order);
260
261        // Use the client order ID if provided, otherwise generate one
262        let order_id = order
263            .client_order_id
264            .clone()
265            .unwrap_or_else(|| format!("ORDER_{}", gen_id()));
266
267        // Determine order type
268        let ord_type = match order.order_type {
269            OrderType::Market => "1",     // Market
270            OrderType::Limit => "2",      // Limit
271            OrderType::StopLimit => "4",  // Stop Limit
272            OrderType::StopMarket => "3", // Stop Market
273            _ => "2",                     // Default to Limit for other types
274        };
275
276        let mut builder = MessageBuilder::new()
277            .msg_type(MsgType::NewOrderSingle)
278            .sender_comp_id(self.config.sender_comp_id.clone())
279            .target_comp_id(self.config.target_comp_id.clone())
280            .msg_seq_num(self.outgoing_seq_num)
281            .field(11, order_id.clone()) // ClOrdID
282            .field(55, order.instrument_name.clone()) // Symbol
283            .field(
284                54,
285                match order.side {
286                    OrderSide::Buy => "1".to_string(),
287                    OrderSide::Sell => "2".to_string(),
288                },
289            ) // Side
290            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()) // TransactTime
291            .field(38, order.amount.to_string()) // OrderQty
292            .field(40, ord_type.to_string()); // OrdType
293
294        // Add price for limit orders
295        if order.order_type == OrderType::Limit || order.price.is_some() {
296            builder = builder.field(44, order.price.unwrap_or(0.0).to_string()); // Price
297        }
298
299        // Add time in force
300        let tif = match order.time_in_force {
301            TimeInForce::GoodTilCancelled => "1",
302            TimeInForce::ImmediateOrCancel => "3",
303            TimeInForce::FillOrKill => "4",
304            TimeInForce::GoodTilDay => "6",
305        };
306        builder = builder.field(59, tif.to_string()); // TimeInForce
307
308        // Add execution instructions
309        let mut exec_inst = String::new();
310        if order.post_only == Some(true) {
311            exec_inst.push('6'); // Participate don't initiate
312        }
313        if order.reduce_only == Some(true) {
314            exec_inst.push('E'); // Reduce only
315        }
316        if !exec_inst.is_empty() {
317            builder = builder.field(18, exec_inst); // ExecInst
318        }
319
320        // Add label if provided
321        if let Some(label) = &order.label {
322            builder = builder.field(100010, label.clone()); // Deribit label
323        }
324
325        let order_message = builder.build()?;
326
327        // Actually send the message
328        self.send_message(order_message).await?;
329        self.outgoing_seq_num += 1;
330
331        info!("New order message sent with ID: {}", order_id);
332        Ok(order_id)
333    }
334
335    /// Cancel an order
336    ///
337    /// # Arguments
338    /// * `order_id` - The order identifier (OrigClOrdID) to cancel
339    /// * `symbol` - Optional instrument symbol. Required when canceling by ClOrdID or DeribitLabel,
340    ///   but not required when using OrigClOrdID (fastest approach)
341    /// * `currency` - Optional currency to speed up search when using ClOrdID or DeribitLabel
342    pub async fn cancel_order(&mut self, order_id: String) -> Result<()> {
343        self.cancel_order_with_symbol(order_id, None).await
344    }
345
346    /// Cancel an order with optional symbol specification
347    ///
348    /// According to Deribit FIX documentation:
349    /// - Canceling by OrigClOrdId is fastest and recommended when possible
350    /// - Symbol is required only when OrigClOrdId is absent (canceling by ClOrdID or DeribitLabel)
351    /// - Currency can optionally speed up searches by DeribitLabel or ClOrdID
352    ///
353    /// # Arguments
354    /// * `order_id` - The order identifier (OrigClOrdID) to cancel
355    /// * `symbol` - Optional instrument symbol (e.g., "BTC-PERPETUAL")
356    /// * `currency` - Optional currency to speed up search
357    pub async fn cancel_order_with_symbol(
358        &mut self,
359        order_id: String,
360        symbol: Option<String>,
361    ) -> Result<()> {
362        info!("Cancelling order: {} with symbol: {:?}", order_id, symbol);
363
364        // Generate a proper unique cancel ID using random number instead of timestamp
365        let cancel_id = format!("CANCEL_{}", gen_id());
366
367        let mut builder = MessageBuilder::new()
368            .msg_type(MsgType::OrderCancelRequest)
369            .sender_comp_id(self.config.sender_comp_id.clone())
370            .target_comp_id(self.config.target_comp_id.clone())
371            .msg_seq_num(self.outgoing_seq_num)
372            .field(11, cancel_id) // ClOrdID - Original order identifier assigned by the user
373            .field(41, order_id) // OrigClOrdID - Order identifier assigned by Deribit
374            .field(60, Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()); // TransactTime
375
376        // Add symbol if provided - required when OrigClOrdId is absent
377        if let Some(symbol_value) = symbol {
378            builder = builder.field(55, symbol_value); // Symbol
379        }
380
381        let cancel_message = builder.build()?;
382
383        // Actually send the cancel message
384        self.send_message(cancel_message).await?;
385        self.outgoing_seq_num += 1;
386
387        info!("Order cancel message sent");
388        Ok(())
389    }
390
391    /// Subscribe to market data
392    pub async fn subscribe_market_data(&mut self, symbol: String) -> Result<()> {
393        info!("Subscribing to market data for: {}", symbol);
394
395        let request_id = format!("MDR_{}", gen_id());
396
397        let market_data_request = MessageBuilder::new()
398            .msg_type(MsgType::MarketDataRequest)
399            .sender_comp_id(self.config.sender_comp_id.clone())
400            .target_comp_id(self.config.target_comp_id.clone())
401            .msg_seq_num(self.outgoing_seq_num)
402            .field(262, request_id.clone()) // MDReqID
403            .field(263, "1".to_string()) // SubscriptionRequestType (1 = Snapshot + Updates)
404            .field(264, "0".to_string()) // MarketDepth (0 = Full Book)
405            .field(267, "2".to_string()) // NoMDEntryTypes
406            .field(269, "0".to_string()) // MDEntryType (0 = Bid)
407            .field(269, "1".to_string()) // MDEntryType (1 = Offer)
408            .field(146, "1".to_string()) // NoRelatedSym
409            .field(55, symbol.clone()) // Symbol
410            .build()?;
411
412        // Send the market data request
413        self.send_message(market_data_request).await?;
414        self.outgoing_seq_num += 1;
415
416        info!(
417            "Market data subscription request sent for symbol: {} with ID: {}",
418            symbol, request_id
419        );
420        Ok(())
421    }
422
423    /// Request positions asynchronously
424    pub async fn request_positions(&mut self) -> Result<Vec<Position>> {
425        use std::time::{Duration, Instant};
426        use tracing::{debug, info, warn};
427
428        info!("Requesting positions");
429
430        let request_id = format!("POS_{}", gen_id());
431
432        // Create typed position request
433        let position_request = RequestForPositions::all_positions(request_id.clone())
434            .with_clearing_date(Utc::now().format("%Y%m%d").to_string());
435
436        // Build the FIX message
437        let fix_message = position_request.to_fix_message(
438            self.config.sender_comp_id.clone(),
439            self.config.target_comp_id.clone(),
440            self.outgoing_seq_num,
441        )?;
442
443        // Send the position request
444        self.send_message(fix_message).await?;
445        self.outgoing_seq_num += 1;
446
447        info!(
448            "Position request sent, awaiting responses for request ID: {}",
449            request_id
450        );
451
452        // Collect position reports with correlation by PosReqID
453        let mut positions = Vec::new();
454        let timeout = Duration::from_secs(30); // 30 second timeout
455        let start_time = Instant::now();
456
457        loop {
458            // Check for timeout
459            if start_time.elapsed() > timeout {
460                warn!("Position request timed out after {:?}", timeout);
461                break;
462            }
463
464            // Receive and process messages
465            match self.receive_and_process_message().await {
466                Ok(Some(message)) => {
467                    // Check if this is a PositionReport message
468                    if let Some(msg_type_str) = message.get_field(35)
469                        && msg_type_str == "AP"
470                    {
471                        // PositionReport
472                        // Check if this position report matches our request ID
473                        if let Some(pos_req_id) = message.get_field(710) {
474                            if pos_req_id == &request_id {
475                                debug!("Received PositionReport for request: {}", request_id);
476
477                                // Check if this is an empty position report (no instrument name)
478                                if message.get_field(55).is_none() {
479                                    info!(
480                                        "Received empty position report - indicates no active positions for this request"
481                                    );
482                                    debug!(
483                                        "Empty PositionReport details: PosReqID={}, PosMaintRptID={:?}",
484                                        request_id,
485                                        message.get_field(721)
486                                    );
487                                    // This is an end-of-positions marker indicating no positions, continue waiting
488                                    continue;
489                                }
490
491                                match PositionReport::try_from_fix_message(&message) {
492                                    Ok(position) => {
493                                        debug!("Successfully parsed position: {:?}", position);
494                                        positions.push(position);
495                                    }
496                                    Err(e) => {
497                                        warn!("Failed to parse PositionReport: {}", e);
498                                        debug!("Message fields: {:?}", message);
499                                    }
500                                }
501                            } else {
502                                debug!(
503                                    "Received PositionReport for different request: {}",
504                                    pos_req_id
505                                );
506                            }
507                        }
508                    }
509                }
510                Ok(None) => {
511                    // No message received, continue loop
512                    tokio::time::sleep(Duration::from_millis(10)).await;
513                }
514                Err(e) => {
515                    warn!("Error receiving message: {}", e);
516                    // Continue trying to receive more messages
517                    tokio::time::sleep(Duration::from_millis(100)).await;
518                }
519            }
520
521            // For now, we'll break after receiving some positions or after a reasonable time
522            // In a real implementation, we might wait for an end-of-transmission signal
523            if !positions.is_empty() && start_time.elapsed() > Duration::from_secs(5) {
524                debug!(
525                    "Received {} positions, stopping collection",
526                    positions.len()
527                );
528                break;
529            }
530        }
531
532        info!(
533            "Position request completed, received {} positions",
534            positions.len()
535        );
536        Ok(positions)
537    }
538
539    /// Generate authentication data according to Deribit FIX specification
540    /// Returns (raw_data, base64_password_hash)
541    pub fn generate_auth_data(&self, access_secret: &str) -> Result<(String, String)> {
542        // Generate timestamp (strictly increasing integer in milliseconds)
543        let timestamp = Utc::now().timestamp_millis();
544
545        // Generate random nonce (at least 32 bytes as recommended by Deribit)
546        let mut nonce_bytes = vec![0u8; 32];
547        for byte in nonce_bytes.iter_mut() {
548            *byte = rand::random::<u8>();
549        }
550        let nonce_b64 = BASE64_STANDARD.encode(&nonce_bytes);
551
552        // Create RawData: timestamp.nonce (separated by ASCII period)
553        let raw_data = format!("{timestamp}.{nonce_b64}");
554
555        // Calculate password hash: base64(sha256(RawData ++ access_secret))
556        let mut auth_data = raw_data.as_bytes().to_vec();
557        auth_data.extend_from_slice(access_secret.as_bytes());
558
559        debug!("Auth Data at Timestamp: {}", timestamp);
560        trace!("Nonce length: {} bytes", nonce_bytes.len());
561        trace!("Nonce (base64): {}", nonce_b64);
562        trace!("RawData: {}", raw_data);
563        trace!("Access secret: {}", access_secret);
564        trace!("Auth data length: {} bytes", auth_data.len());
565
566        let mut hasher = Sha256::new();
567        hasher.update(&auth_data);
568        let hash_result = hasher.finalize();
569        let password_hash = BASE64_STANDARD.encode(hash_result);
570
571        debug!("Password hash: {}", password_hash);
572
573        Ok((raw_data, password_hash))
574    }
575
576    /// Calculate application signature for registered apps
577    #[allow(dead_code)]
578    fn calculate_app_signature(&self, raw_data: &str, app_secret: &str) -> Result<String> {
579        let mut hasher = Sha256::new();
580        hasher.update(format!("{raw_data}{app_secret}").as_bytes());
581        let result = hasher.finalize();
582        Ok(BASE64_STANDARD.encode(result))
583    }
584
585    /// Get current session state
586    pub fn state(&self) -> SessionState {
587        self.state
588    }
589
590    /// Set session state (for testing)
591    pub fn set_state(&mut self, state: SessionState) {
592        self.state = state;
593    }
594
595    /// Process incoming FIX message
596    async fn process_message(&mut self, message: &FixMessage) -> Result<()> {
597        debug!("Processing FIX message: {:?}", message);
598
599        // Get message type
600        let msg_type_str = message.get_field(35).unwrap_or(&String::new()).clone();
601
602        // Skip messages with empty or missing message type
603        if msg_type_str.is_empty() {
604            debug!("Skipping message with empty message type");
605            return Ok(());
606        }
607
608        let msg_type = MsgType::from_str(&msg_type_str).map_err(|_| {
609            DeribitFixError::MessageParsing(format!("Unknown message type: {msg_type_str}"))
610        })?;
611
612        match msg_type {
613            MsgType::Logon => {
614                info!("Received logon response");
615                self.state = SessionState::LoggedOn;
616            }
617            MsgType::Logout => {
618                info!("Received logout message");
619                self.state = SessionState::Disconnected;
620            }
621            MsgType::Heartbeat => {
622                debug!("Received heartbeat");
623            }
624            MsgType::TestRequest => {
625                debug!("Received test request, sending heartbeat response");
626                let test_req_id = message.get_field(112);
627                self.send_heartbeat(test_req_id.cloned()).await?;
628            }
629            MsgType::ExecutionReport => {
630                debug!("Received ExecutionReport: {:?}", message);
631                // ExecutionReport processing - let the client handle the details
632            }
633            MsgType::PositionReport => {
634                debug!("Received PositionReport: {:?}", message);
635                // PositionReport processing - let the client handle the details
636            }
637            MsgType::Reject => {
638                error!("Received Reject message: {:?}", message);
639                // Reject message processing
640            }
641            _ => {
642                debug!("Received message type: {:?}", msg_type);
643            }
644        }
645
646        self.incoming_seq_num += 1;
647        Ok(())
648    }
649
650    /// Receive and process a FIX message from the connection
651    pub async fn receive_and_process_message(&mut self) -> Result<Option<FixMessage>> {
652        let message = if let Some(connection) = &self.connection {
653            let mut conn_guard = connection.lock().await;
654            conn_guard.receive_message().await?
655        } else {
656            None
657        };
658
659        if let Some(message) = message {
660            self.process_message(&message).await?;
661            Ok(Some(message))
662        } else {
663            Ok(None)
664        }
665    }
666}