// bybit_client/ws/trade.rs
//! WebSocket Trade API client for order management.
//!
//! This module provides a client for submitting, amending, and canceling orders
//! over a WebSocket connection, which offers lower latency than the REST API.
//!
//! # Example
//!
//! ```no_run
//! use bybit_client::ws::{WsTradeClient, CreateOrderRequest};
//! use bybit_client::ClientConfig;
//! use bybit_client::types::{Category, Side, OrderType, TimeInForce};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let config = ClientConfig::new("api_key", "api_secret");
//!     let client = WsTradeClient::connect(config).await?;
//!
//!     // Submit a new order.
//!     let result = client.create_order(CreateOrderRequest {
//!         category: Category::Linear,
//!         symbol: "BTCUSDT".to_string(),
//!         side: Side::Buy,
//!         order_type: OrderType::Limit,
//!         qty: "0.001".to_string(),
//!         price: Some("50000".to_string()),
//!         time_in_force: Some(TimeInForce::GTC),
//!         order_link_id: None,
//!         is_leverage: None,
//!         position_idx: None,
//!         reduce_only: None,
//!         close_on_trigger: None,
//!         take_profit: None,
//!         stop_loss: None,
//!         tpsl_mode: None,
//!         market_unit: None,
//!     }).await?;
//!
//!     println!("Order created: {}", result.order_id);
//!     Ok(())
//! }
//! ```
42
43use std::collections::HashMap;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use std::time::Duration;
47
48use futures_util::{SinkExt, StreamExt};
49use serde::{Deserialize, Serialize};
50use tokio::sync::{mpsc, oneshot, RwLock};
51use tokio::time::timeout;
52use tokio_tungstenite::tungstenite::Message;
53use tracing::{debug, error, info};
54
55use crate::auth::{current_timestamp_ms, sign_rest_request};
56use crate::config::ClientConfig;
57use crate::error::BybitError;
58use crate::types::{Category, OrderType, Side, TimeInForce};
59
/// How long a trade request waits for its response before it fails with
/// `BybitError::Timeout`, in milliseconds (ten seconds).
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
/// Receive window placed in the header of every signed request
/// (presumably milliseconds, like Bybit's `X-BAPI-RECV-WINDOW` — confirm).
const DEFAULT_RECV_WINDOW: u32 = 5_000;
64
65
/// Request for creating a new order, sent as the argument of `order.create`
/// (or as one element of an `order.create-batch` request).
///
/// Field names are serialized in camelCase, and every `Option` field that is
/// `None` is left out of the JSON payload entirely.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateOrderRequest {
    /// Product category.
    ///
    /// When the request is submitted through `WsTradeClient::batch_create_orders`,
    /// this value is overwritten with the batch's `category` argument.
    pub category: Category,
    /// Trading symbol (e.g. `"BTCUSDT"`).
    pub symbol: String,
    /// Order side (Buy/Sell).
    pub side: Side,
    /// Order type (Limit/Market).
    pub order_type: OrderType,
    /// Order quantity as a decimal string (e.g. `"0.001"`).
    pub qty: String,
    /// Order price as a decimal string (required for Limit orders).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub price: Option<String>,
    /// Time in force.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub time_in_force: Option<TimeInForce>,
    /// User custom order ID.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_link_id: Option<String>,
    /// Is leverage order (for spot margin).
    /// NOTE(review): presumably `0`/`1` rather than a bool — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_leverage: Option<i32>,
    /// Position index (for hedged mode).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position_idx: Option<i32>,
    /// Reduce only order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reduce_only: Option<bool>,
    /// Close on trigger.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub close_on_trigger: Option<bool>,
    /// Take profit price as a decimal string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub take_profit: Option<String>,
    /// Stop loss price as a decimal string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_loss: Option<String>,
    /// TP/SL mode (Full/Partial).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tpsl_mode: Option<String>,
    /// Market unit (baseCoin/quoteCoin for spot).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub market_unit: Option<String>,
}
114
/// Request for amending an existing order, sent as the argument of `order.amend`
/// (or as one element of an `order.amend-batch` request).
///
/// Identify the order with `order_id` or `order_link_id`. Only the fields that
/// are `Some` are serialized (in camelCase); everything else is left unchanged.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AmendOrderRequest {
    /// Product category.
    ///
    /// When the request is submitted through `WsTradeClient::batch_amend_orders`,
    /// this value is overwritten with the batch's `category` argument.
    pub category: Category,
    /// Trading symbol.
    pub symbol: String,
    /// Order ID (required if orderLinkId not provided).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_id: Option<String>,
    /// User custom order ID (required if orderId not provided).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_link_id: Option<String>,
    /// New quantity.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qty: Option<String>,
    /// New price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub price: Option<String>,
    /// New take profit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub take_profit: Option<String>,
    /// New stop loss.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_loss: Option<String>,
    /// New take-profit limit price (serialized as `tpLimitPrice`).
    /// NOTE(review): presumably the price of the limit order placed once the
    /// take profit triggers, not the trigger price itself — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tp_limit_price: Option<String>,
    /// New stop-loss limit price (serialized as `slLimitPrice`).
    /// NOTE(review): presumably the price of the limit order placed once the
    /// stop loss triggers, not the trigger price itself — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sl_limit_price: Option<String>,
}
148
/// Request for canceling an order, sent as the argument of `order.cancel`
/// (or as one element of an `order.cancel-batch` request).
///
/// Identify the order with `order_id` or `order_link_id`. A `None` identifier
/// is omitted from the camelCase JSON payload.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CancelOrderRequest {
    /// Product category.
    ///
    /// When the request is submitted through `WsTradeClient::batch_cancel_orders`,
    /// this value is overwritten with the batch's `category` argument.
    pub category: Category,
    /// Trading symbol.
    pub symbol: String,
    /// Order ID (required if orderLinkId not provided).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_id: Option<String>,
    /// User custom order ID (required if orderId not provided).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_link_id: Option<String>,
}
164
165
/// Result of a single order operation (create, amend or cancel).
///
/// Deserialized from the `data` object of a successful response, using
/// camelCase field names.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OrderResult {
    /// Order ID.
    pub order_id: String,
    /// User custom order ID; `None` when the field is absent from the response.
    #[serde(default)]
    pub order_link_id: Option<String>,
}
176
/// One entry of the per-order list returned by a batch order operation.
///
/// Deserialized with camelCase field names. `category` and `symbol` are kept as
/// the raw strings the server returns.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchOrderResult {
    /// Category, as returned by the server.
    pub category: String,
    /// Symbol.
    pub symbol: String,
    /// Order ID.
    pub order_id: String,
    /// User custom order ID; `None` when absent.
    #[serde(default)]
    pub order_link_id: Option<String>,
    /// Create type; `None` when absent.
    #[serde(default)]
    pub create_type: Option<String>,
}
194
/// WebSocket Trade API response.
///
/// The client's reader task tries to parse every incoming text frame into this
/// type (frames that fail to parse are ignored). `req_id` routes the response
/// back to the request that produced it.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WsTradeResponse {
    /// Request ID (echoed from the request's `reqId`).
    pub req_id: String,
    /// Return code (0 = success).
    pub ret_code: i32,
    /// Return message.
    pub ret_msg: String,
    /// Operation type (e.g. `"order.create"`).
    pub op: String,
    /// Response data; `serde_json::Value::Null` when the field is absent.
    #[serde(default)]
    pub data: serde_json::Value,
    /// Connection ID, if present.
    #[serde(default)]
    pub conn_id: Option<String>,
}
214
215impl WsTradeResponse {
216    /// Check if the response indicates success.
217    pub fn is_success(&self) -> bool {
218        self.ret_code == 0
219    }
220
221    /// Convert to a typed result.
222    pub fn into_result<T: for<'de> Deserialize<'de>>(self) -> Result<T, BybitError> {
223        if self.is_success() {
224            serde_json::from_value(self.data)
225                .map_err(|e| BybitError::Serialization(e))
226        } else {
227            Err(BybitError::api_error(self.ret_code, self.ret_msg))
228        }
229    }
230}
231
232
/// WebSocket Trade API request wrapper.
///
/// Serialized (camelCase) as `{"reqId": …, "op": …, "header": …, "args": [...]}`.
/// `args` holds one element for single-order operations, or every order of a batch.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct WsTradeRequest<T> {
    req_id: String,
    op: String,
    header: WsTradeHeader,
    args: Vec<T>,
}
242
/// Request header with authentication.
///
/// Each field is serialized under its `X-BAPI-*` header name. All values are
/// sent as strings.
///
/// NOTE(review): Bybit's published v5 WebSocket trade examples show a header
/// with only `X-BAPI-TIMESTAMP`/`X-BAPI-RECV-WINDOW` (plus an optional
/// `Referer`), with the connection authenticated once via an `auth` op. Confirm
/// that per-request `X-BAPI-API-KEY`/`X-BAPI-SIGN` are accepted.
#[derive(Debug, Serialize)]
struct WsTradeHeader {
    #[serde(rename = "X-BAPI-TIMESTAMP")]
    timestamp: String,
    #[serde(rename = "X-BAPI-RECV-WINDOW")]
    recv_window: String,
    #[serde(rename = "X-BAPI-API-KEY")]
    api_key: String,
    #[serde(rename = "X-BAPI-SIGN")]
    sign: String,
}
255
/// Pending request waiting for response.
///
/// Stored in the client's `pending` map under its request ID. The reader task
/// completes it with the matching response, or with an error when the
/// connection ends.
struct PendingRequest {
    sender: oneshot::Sender<Result<WsTradeResponse, BybitError>>,
}
260
261
/// WebSocket Trade API client for low-latency order management.
///
/// This client connects to the WebSocket Trade endpoint and provides
/// methods for creating, amending, and canceling orders with lower
/// latency compared to REST API.
///
/// Outgoing frames go through an unbounded channel to a background writer
/// task. A background reader task matches responses to waiting callers by
/// request ID.
pub struct WsTradeClient {
    /// Sender for outgoing WebSocket messages (drained by the writer task).
    tx: mpsc::UnboundedSender<Message>,
    /// Pending requests waiting for responses, keyed by request ID.
    pending: Arc<RwLock<HashMap<String, PendingRequest>>>,
    /// Request ID counter; request IDs are formatted as `req-<n>`.
    req_counter: AtomicU64,
    /// API key.
    api_key: String,
    /// API secret, used to sign each request header.
    api_secret: String,
    /// Receive window for requests.
    recv_window: u32,
    /// Connected flag. Set on connect; cleared by the reader task when the
    /// connection closes or errors, and by `disconnect`.
    connected: Arc<RwLock<bool>>,
}
283
284impl WsTradeClient {
285    /// Connect to the WebSocket Trade API.
286    pub async fn connect(config: ClientConfig) -> Result<Self, BybitError> {
287        let api_key = config
288            .api_key
289            .as_ref()
290            .ok_or_else(|| BybitError::Auth("API key required for trade API".to_string()))?
291            .clone();
292        let api_secret = config
293            .get_secret()
294            .ok_or_else(|| BybitError::Auth("API secret required for trade API".to_string()))?
295            .to_string();
296
297        let url = config.get_ws_trade_url();
298
299        info!("Connecting to WebSocket Trade API: {}", url);
300        let url = url.to_string();
301
302        let (ws_stream, _) = tokio_tungstenite::connect_async(&url)
303            .await
304            .map_err(|e| BybitError::WebSocket(format!("Connection failed: {}", e)))?;
305
306        let (mut write, mut read) = ws_stream.split();
307
308        let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
309        let pending: Arc<RwLock<HashMap<String, PendingRequest>>> =
310            Arc::new(RwLock::new(HashMap::new()));
311        let connected = Arc::new(RwLock::new(true));
312
313        let pending_clone = pending.clone();
314        let connected_clone = connected.clone();
315
316        tokio::spawn(async move {
317            while let Some(msg) = read.next().await {
318                match msg {
319                    Ok(Message::Text(text)) => {
320                        debug!("Trade API received: {}", text);
321
322                        if let Ok(response) = serde_json::from_str::<WsTradeResponse>(&text) {
323                            let mut pending = pending_clone.write().await;
324                            if let Some(request) = pending.remove(&response.req_id) {
325                                let _ = request.sender.send(Ok(response));
326                            }
327                        }
328                    }
329                    Ok(Message::Ping(data)) => {
330                        debug!("Trade API ping received");
331                        let _ = data;
332                    }
333                    Ok(Message::Close(_)) => {
334                        info!("Trade API connection closed");
335                        *connected_clone.write().await = false;
336                        break;
337                    }
338                    Err(e) => {
339                        error!("Trade API read error: {}", e);
340                        *connected_clone.write().await = false;
341                        break;
342                    }
343                    _ => {}
344                }
345            }
346
347            let mut pending = pending_clone.write().await;
348            for (_, request) in pending.drain() {
349                let _ = request
350                    .sender
351                    .send(Err(BybitError::WebSocket("Connection closed".to_string())));
352            }
353        });
354
355        tokio::spawn(async move {
356            while let Some(msg) = rx.recv().await {
357                if let Err(e) = write.send(msg).await {
358                    error!("Trade API write error: {}", e);
359                    break;
360                }
361            }
362        });
363
364        Ok(Self {
365            tx,
366            pending,
367            req_counter: AtomicU64::new(1),
368            api_key,
369            api_secret,
370            recv_window: DEFAULT_RECV_WINDOW,
371            connected,
372        })
373    }
374
375    /// Check if the client is connected.
376    pub async fn is_connected(&self) -> bool {
377        *self.connected.read().await
378    }
379
380    /// Generate a unique request ID.
381    fn generate_req_id(&self) -> String {
382        let counter = self.req_counter.fetch_add(1, Ordering::SeqCst);
383        format!("req-{}", counter)
384    }
385
386    /// Create authentication header for a request.
387    fn create_header(&self, args_json: &str) -> WsTradeHeader {
388        let timestamp = current_timestamp_ms();
389        let recv_window = self.recv_window;
390
391        let signature = sign_rest_request(
392            timestamp,
393            &self.api_key,
394            recv_window,
395            args_json,
396            &self.api_secret,
397        );
398
399        WsTradeHeader {
400            timestamp: timestamp.to_string(),
401            recv_window: recv_window.to_string(),
402            api_key: self.api_key.clone(),
403            sign: signature,
404        }
405    }
406
407    /// Send a trade request and wait for response.
408    async fn send_request<T: Serialize>(
409        &self,
410        op: &str,
411        args: Vec<T>,
412    ) -> Result<WsTradeResponse, BybitError> {
413        if !self.is_connected().await {
414            return Err(BybitError::WebSocket("Not connected".to_string()));
415        }
416
417        let req_id = self.generate_req_id();
418
419        let args_json = serde_json::to_string(&args)
420            .map_err(|e| BybitError::Serialization(e))?;
421
422        let header = self.create_header(&args_json);
423
424        let request = WsTradeRequest {
425            req_id: req_id.clone(),
426            op: op.to_string(),
427            header,
428            args,
429        };
430
431        let json = serde_json::to_string(&request)
432            .map_err(|e| BybitError::Serialization(e))?;
433
434        debug!("Trade API sending: {}", json);
435
436        let (tx, rx) = oneshot::channel();
437        {
438            let mut pending = self.pending.write().await;
439            pending.insert(req_id.clone(), PendingRequest { sender: tx });
440        }
441
442        self.tx
443            .send(Message::Text(json.into()))
444            .map_err(|e| BybitError::WebSocket(format!("Send failed: {}", e)))?;
445
446        let result = timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS), rx).await;
447
448        match result {
449            Ok(Ok(response)) => response,
450            Ok(Err(_)) => Err(BybitError::WebSocket("Response channel closed".to_string())),
451            Err(_) => {
452                let mut pending = self.pending.write().await;
453                pending.remove(&req_id);
454                Err(BybitError::Timeout)
455            }
456        }
457    }
458
459
460    /// Create a new order.
461    pub async fn create_order(
462        &self,
463        request: CreateOrderRequest,
464    ) -> Result<OrderResult, BybitError> {
465        let response = self.send_request("order.create", vec![request]).await?;
466        response.into_result()
467    }
468
469    /// Amend an existing order.
470    pub async fn amend_order(
471        &self,
472        request: AmendOrderRequest,
473    ) -> Result<OrderResult, BybitError> {
474        let response = self.send_request("order.amend", vec![request]).await?;
475        response.into_result()
476    }
477
478    /// Cancel an order.
479    pub async fn cancel_order(
480        &self,
481        request: CancelOrderRequest,
482    ) -> Result<OrderResult, BybitError> {
483        let response = self.send_request("order.cancel", vec![request]).await?;
484        response.into_result()
485    }
486
487
488    /// Create multiple orders in a single request (max 10).
489    pub async fn batch_create_orders(
490        &self,
491        category: Category,
492        orders: Vec<CreateOrderRequest>,
493    ) -> Result<Vec<BatchOrderResult>, BybitError> {
494        if orders.is_empty() {
495            return Ok(Vec::new());
496        }
497        if orders.len() > 10 {
498            return Err(BybitError::InvalidParameter(
499                "Batch order limit is 10".to_string(),
500            ));
501        }
502
503        let orders: Vec<_> = orders
504            .into_iter()
505            .map(|mut o| {
506                o.category = category.clone();
507                o
508            })
509            .collect();
510
511        let response = self.send_request("order.create-batch", orders).await?;
512
513        if response.is_success() {
514            let list = response
515                .data
516                .get("result")
517                .and_then(|r| r.get("list"))
518                .cloned()
519                .unwrap_or(serde_json::Value::Array(vec![]));
520            serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
521        } else {
522            Err(BybitError::api_error(response.ret_code, response.ret_msg))
523        }
524    }
525
526    /// Amend multiple orders in a single request (max 10).
527    pub async fn batch_amend_orders(
528        &self,
529        category: Category,
530        orders: Vec<AmendOrderRequest>,
531    ) -> Result<Vec<BatchOrderResult>, BybitError> {
532        if orders.is_empty() {
533            return Ok(Vec::new());
534        }
535        if orders.len() > 10 {
536            return Err(BybitError::InvalidParameter(
537                "Batch order limit is 10".to_string(),
538            ));
539        }
540
541        let orders: Vec<_> = orders
542            .into_iter()
543            .map(|mut o| {
544                o.category = category.clone();
545                o
546            })
547            .collect();
548
549        let response = self.send_request("order.amend-batch", orders).await?;
550
551        if response.is_success() {
552            let list = response
553                .data
554                .get("result")
555                .and_then(|r| r.get("list"))
556                .cloned()
557                .unwrap_or(serde_json::Value::Array(vec![]));
558            serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
559        } else {
560            Err(BybitError::api_error(response.ret_code, response.ret_msg))
561        }
562    }
563
564    /// Cancel multiple orders in a single request (max 10).
565    pub async fn batch_cancel_orders(
566        &self,
567        category: Category,
568        orders: Vec<CancelOrderRequest>,
569    ) -> Result<Vec<BatchOrderResult>, BybitError> {
570        if orders.is_empty() {
571            return Ok(Vec::new());
572        }
573        if orders.len() > 10 {
574            return Err(BybitError::InvalidParameter(
575                "Batch order limit is 10".to_string(),
576            ));
577        }
578
579        let orders: Vec<_> = orders
580            .into_iter()
581            .map(|mut o| {
582                o.category = category.clone();
583                o
584            })
585            .collect();
586
587        let response = self.send_request("order.cancel-batch", orders).await?;
588
589        if response.is_success() {
590            let list = response
591                .data
592                .get("result")
593                .and_then(|r| r.get("list"))
594                .cloned()
595                .unwrap_or(serde_json::Value::Array(vec![]));
596            serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
597        } else {
598            Err(BybitError::api_error(response.ret_code, response.ret_msg))
599        }
600    }
601
602    /// Disconnect from the WebSocket Trade API.
603    pub async fn disconnect(&self) {
604        *self.connected.write().await = false;
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to JSON, panicking with a message naming `what` on failure.
    fn to_json<T: Serialize>(value: &T, what: &str) -> String {
        match serde_json::to_string(value) {
            Ok(json) => json,
            Err(err) => panic!("Failed to serialize {}: {}", what, err),
        }
    }

    /// Parses a raw trade response frame, panicking if it is malformed.
    fn parse_response(raw: &str) -> WsTradeResponse {
        match serde_json::from_str(raw) {
            Ok(response) => response,
            Err(err) => panic!("Failed to parse trade response: {}", err),
        }
    }

    /// Asserts that `json` contains every one of `fragments`.
    fn assert_contains_all(json: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(json.contains(fragment));
        }
    }

    #[test]
    fn test_create_order_request_serialize() {
        let json = to_json(
            &CreateOrderRequest {
                category: Category::Linear,
                symbol: "BTCUSDT".to_string(),
                side: Side::Buy,
                order_type: OrderType::Limit,
                qty: "0.001".to_string(),
                price: Some("50000".to_string()),
                time_in_force: Some(TimeInForce::GTC),
                order_link_id: None,
                is_leverage: None,
                position_idx: None,
                reduce_only: None,
                close_on_trigger: None,
                take_profit: None,
                stop_loss: None,
                tpsl_mode: None,
                market_unit: None,
            },
            "create request",
        );

        assert_contains_all(
            &json,
            &[
                "\"category\":\"linear\"",
                "\"symbol\":\"BTCUSDT\"",
                "\"side\":\"Buy\"",
                "\"orderType\":\"Limit\"",
                "\"qty\":\"0.001\"",
                "\"price\":\"50000\"",
            ],
        );
    }

    #[test]
    fn test_amend_order_request_serialize() {
        let json = to_json(
            &AmendOrderRequest {
                category: Category::Linear,
                symbol: "BTCUSDT".to_string(),
                order_id: Some("order-123".to_string()),
                order_link_id: None,
                qty: Some("0.002".to_string()),
                price: Some("51000".to_string()),
                take_profit: None,
                stop_loss: None,
                tp_limit_price: None,
                sl_limit_price: None,
            },
            "amend request",
        );

        assert_contains_all(
            &json,
            &[
                "\"orderId\":\"order-123\"",
                "\"qty\":\"0.002\"",
                "\"price\":\"51000\"",
            ],
        );
    }

    #[test]
    fn test_cancel_order_request_serialize() {
        let json = to_json(
            &CancelOrderRequest {
                category: Category::Linear,
                symbol: "BTCUSDT".to_string(),
                order_id: Some("order-123".to_string()),
                order_link_id: None,
            },
            "cancel request",
        );

        assert_contains_all(&json, &["\"orderId\":\"order-123\""]);
        // A `None` identifier must be omitted entirely, not sent as null.
        assert!(!json.contains("orderLinkId"));
    }

    #[test]
    fn test_trade_response_deserialize() {
        let response = parse_response(
            r#"{
            "reqId": "req-1",
            "retCode": 0,
            "retMsg": "OK",
            "op": "order.create",
            "data": {
                "orderId": "order-456",
                "orderLinkId": "my-order-1"
            },
            "connId": "conn-123"
        }"#,
        );

        assert_eq!(response.req_id, "req-1");
        assert!(response.is_success());
        assert_eq!(response.op, "order.create");

        let result = match response.into_result::<OrderResult>() {
            Ok(result) => result,
            Err(err) => panic!("Expected successful result: {}", err),
        };
        assert_eq!(result.order_id, "order-456");
        assert_eq!(result.order_link_id.as_deref(), Some("my-order-1"));
    }

    #[test]
    fn test_trade_response_error() {
        let response = parse_response(
            r#"{
            "reqId": "req-1",
            "retCode": 10001,
            "retMsg": "Param error",
            "op": "order.create",
            "data": {}
        }"#,
        );

        assert!(!response.is_success());
        assert!(response.into_result::<OrderResult>().is_err());
    }
}