1use std::collections::HashMap;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use std::time::Duration;
47
48use futures_util::{SinkExt, StreamExt};
49use serde::{Deserialize, Serialize};
50use tokio::sync::{mpsc, oneshot, RwLock};
51use tokio::time::timeout;
52use tokio_tungstenite::tungstenite::Message;
53use tracing::{debug, error, info};
54
55use crate::auth::{current_timestamp_ms, sign_rest_request};
56use crate::config::ClientConfig;
57use crate::error::BybitError;
58use crate::types::{Category, OrderType, Side, TimeInForce};
59
/// How long `send_request` waits for the matching reply, in milliseconds.
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
/// Value sent in the `X-BAPI-RECV-WINDOW` header and fed to the request signer.
const DEFAULT_RECV_WINDOW: u32 = 5_000;
64
65
66#[derive(Debug, Clone, Serialize)]
68#[serde(rename_all = "camelCase")]
69pub struct CreateOrderRequest {
70 pub category: Category,
72 pub symbol: String,
74 pub side: Side,
76 pub order_type: OrderType,
78 pub qty: String,
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub price: Option<String>,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub time_in_force: Option<TimeInForce>,
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub order_link_id: Option<String>,
89 #[serde(skip_serializing_if = "Option::is_none")]
91 pub is_leverage: Option<i32>,
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub position_idx: Option<i32>,
95 #[serde(skip_serializing_if = "Option::is_none")]
97 pub reduce_only: Option<bool>,
98 #[serde(skip_serializing_if = "Option::is_none")]
100 pub close_on_trigger: Option<bool>,
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub take_profit: Option<String>,
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub stop_loss: Option<String>,
107 #[serde(skip_serializing_if = "Option::is_none")]
109 pub tpsl_mode: Option<String>,
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub market_unit: Option<String>,
113}
114
115#[derive(Debug, Clone, Serialize)]
117#[serde(rename_all = "camelCase")]
118pub struct AmendOrderRequest {
119 pub category: Category,
121 pub symbol: String,
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub order_id: Option<String>,
126 #[serde(skip_serializing_if = "Option::is_none")]
128 pub order_link_id: Option<String>,
129 #[serde(skip_serializing_if = "Option::is_none")]
131 pub qty: Option<String>,
132 #[serde(skip_serializing_if = "Option::is_none")]
134 pub price: Option<String>,
135 #[serde(skip_serializing_if = "Option::is_none")]
137 pub take_profit: Option<String>,
138 #[serde(skip_serializing_if = "Option::is_none")]
140 pub stop_loss: Option<String>,
141 #[serde(skip_serializing_if = "Option::is_none")]
143 pub tp_limit_price: Option<String>,
144 #[serde(skip_serializing_if = "Option::is_none")]
146 pub sl_limit_price: Option<String>,
147}
148
149#[derive(Debug, Clone, Serialize)]
151#[serde(rename_all = "camelCase")]
152pub struct CancelOrderRequest {
153 pub category: Category,
155 pub symbol: String,
157 #[serde(skip_serializing_if = "Option::is_none")]
159 pub order_id: Option<String>,
160 #[serde(skip_serializing_if = "Option::is_none")]
162 pub order_link_id: Option<String>,
163}
164
165
166#[derive(Debug, Clone, Deserialize)]
168#[serde(rename_all = "camelCase")]
169pub struct OrderResult {
170 pub order_id: String,
172 #[serde(default)]
174 pub order_link_id: Option<String>,
175}
176
177#[derive(Debug, Clone, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct BatchOrderResult {
181 pub category: String,
183 pub symbol: String,
185 pub order_id: String,
187 #[serde(default)]
189 pub order_link_id: Option<String>,
190 #[serde(default)]
192 pub create_type: Option<String>,
193}
194
195#[derive(Debug, Clone, Deserialize)]
197#[serde(rename_all = "camelCase")]
198pub struct WsTradeResponse {
199 pub req_id: String,
201 pub ret_code: i32,
203 pub ret_msg: String,
205 pub op: String,
207 #[serde(default)]
209 pub data: serde_json::Value,
210 #[serde(default)]
212 pub conn_id: Option<String>,
213}
214
215impl WsTradeResponse {
216 pub fn is_success(&self) -> bool {
218 self.ret_code == 0
219 }
220
221 pub fn into_result<T: for<'de> Deserialize<'de>>(self) -> Result<T, BybitError> {
223 if self.is_success() {
224 serde_json::from_value(self.data)
225 .map_err(|e| BybitError::Serialization(e))
226 } else {
227 Err(BybitError::api_error(self.ret_code, self.ret_msg))
228 }
229 }
230}
231
232
233#[derive(Debug, Serialize)]
235#[serde(rename_all = "camelCase")]
236struct WsTradeRequest<T> {
237 req_id: String,
238 op: String,
239 header: WsTradeHeader,
240 args: Vec<T>,
241}
242
243#[derive(Debug, Serialize)]
245struct WsTradeHeader {
246 #[serde(rename = "X-BAPI-TIMESTAMP")]
247 timestamp: String,
248 #[serde(rename = "X-BAPI-RECV-WINDOW")]
249 recv_window: String,
250 #[serde(rename = "X-BAPI-API-KEY")]
251 api_key: String,
252 #[serde(rename = "X-BAPI-SIGN")]
253 sign: String,
254}
255
256struct PendingRequest {
258 sender: oneshot::Sender<Result<WsTradeResponse, BybitError>>,
259}
260
261
262pub struct WsTradeClient {
268 tx: mpsc::UnboundedSender<Message>,
270 pending: Arc<RwLock<HashMap<String, PendingRequest>>>,
272 req_counter: AtomicU64,
274 api_key: String,
276 api_secret: String,
278 recv_window: u32,
280 connected: Arc<RwLock<bool>>,
282}
283
284impl WsTradeClient {
285 pub async fn connect(config: ClientConfig) -> Result<Self, BybitError> {
287 let api_key = config
288 .api_key
289 .as_ref()
290 .ok_or_else(|| BybitError::Auth("API key required for trade API".to_string()))?
291 .clone();
292 let api_secret = config
293 .get_secret()
294 .ok_or_else(|| BybitError::Auth("API secret required for trade API".to_string()))?
295 .to_string();
296
297 let url = config.get_ws_trade_url();
298
299 info!("Connecting to WebSocket Trade API: {}", url);
300 let url = url.to_string();
301
302 let (ws_stream, _) = tokio_tungstenite::connect_async(&url)
303 .await
304 .map_err(|e| BybitError::WebSocket(format!("Connection failed: {}", e)))?;
305
306 let (mut write, mut read) = ws_stream.split();
307
308 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
309 let pending: Arc<RwLock<HashMap<String, PendingRequest>>> =
310 Arc::new(RwLock::new(HashMap::new()));
311 let connected = Arc::new(RwLock::new(true));
312
313 let pending_clone = pending.clone();
314 let connected_clone = connected.clone();
315
316 tokio::spawn(async move {
317 while let Some(msg) = read.next().await {
318 match msg {
319 Ok(Message::Text(text)) => {
320 debug!("Trade API received: {}", text);
321
322 if let Ok(response) = serde_json::from_str::<WsTradeResponse>(&text) {
323 let mut pending = pending_clone.write().await;
324 if let Some(request) = pending.remove(&response.req_id) {
325 let _ = request.sender.send(Ok(response));
326 }
327 }
328 }
329 Ok(Message::Ping(data)) => {
330 debug!("Trade API ping received");
331 let _ = data;
332 }
333 Ok(Message::Close(_)) => {
334 info!("Trade API connection closed");
335 *connected_clone.write().await = false;
336 break;
337 }
338 Err(e) => {
339 error!("Trade API read error: {}", e);
340 *connected_clone.write().await = false;
341 break;
342 }
343 _ => {}
344 }
345 }
346
347 let mut pending = pending_clone.write().await;
348 for (_, request) in pending.drain() {
349 let _ = request
350 .sender
351 .send(Err(BybitError::WebSocket("Connection closed".to_string())));
352 }
353 });
354
355 tokio::spawn(async move {
356 while let Some(msg) = rx.recv().await {
357 if let Err(e) = write.send(msg).await {
358 error!("Trade API write error: {}", e);
359 break;
360 }
361 }
362 });
363
364 Ok(Self {
365 tx,
366 pending,
367 req_counter: AtomicU64::new(1),
368 api_key,
369 api_secret,
370 recv_window: DEFAULT_RECV_WINDOW,
371 connected,
372 })
373 }
374
375 pub async fn is_connected(&self) -> bool {
377 *self.connected.read().await
378 }
379
380 fn generate_req_id(&self) -> String {
382 let counter = self.req_counter.fetch_add(1, Ordering::SeqCst);
383 format!("req-{}", counter)
384 }
385
386 fn create_header(&self, args_json: &str) -> WsTradeHeader {
388 let timestamp = current_timestamp_ms();
389 let recv_window = self.recv_window;
390
391 let signature = sign_rest_request(
392 timestamp,
393 &self.api_key,
394 recv_window,
395 args_json,
396 &self.api_secret,
397 );
398
399 WsTradeHeader {
400 timestamp: timestamp.to_string(),
401 recv_window: recv_window.to_string(),
402 api_key: self.api_key.clone(),
403 sign: signature,
404 }
405 }
406
407 async fn send_request<T: Serialize>(
409 &self,
410 op: &str,
411 args: Vec<T>,
412 ) -> Result<WsTradeResponse, BybitError> {
413 if !self.is_connected().await {
414 return Err(BybitError::WebSocket("Not connected".to_string()));
415 }
416
417 let req_id = self.generate_req_id();
418
419 let args_json = serde_json::to_string(&args)
420 .map_err(|e| BybitError::Serialization(e))?;
421
422 let header = self.create_header(&args_json);
423
424 let request = WsTradeRequest {
425 req_id: req_id.clone(),
426 op: op.to_string(),
427 header,
428 args,
429 };
430
431 let json = serde_json::to_string(&request)
432 .map_err(|e| BybitError::Serialization(e))?;
433
434 debug!("Trade API sending: {}", json);
435
436 let (tx, rx) = oneshot::channel();
437 {
438 let mut pending = self.pending.write().await;
439 pending.insert(req_id.clone(), PendingRequest { sender: tx });
440 }
441
442 self.tx
443 .send(Message::Text(json.into()))
444 .map_err(|e| BybitError::WebSocket(format!("Send failed: {}", e)))?;
445
446 let result = timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS), rx).await;
447
448 match result {
449 Ok(Ok(response)) => response,
450 Ok(Err(_)) => Err(BybitError::WebSocket("Response channel closed".to_string())),
451 Err(_) => {
452 let mut pending = self.pending.write().await;
453 pending.remove(&req_id);
454 Err(BybitError::Timeout)
455 }
456 }
457 }
458
459
460 pub async fn create_order(
462 &self,
463 request: CreateOrderRequest,
464 ) -> Result<OrderResult, BybitError> {
465 let response = self.send_request("order.create", vec![request]).await?;
466 response.into_result()
467 }
468
469 pub async fn amend_order(
471 &self,
472 request: AmendOrderRequest,
473 ) -> Result<OrderResult, BybitError> {
474 let response = self.send_request("order.amend", vec![request]).await?;
475 response.into_result()
476 }
477
478 pub async fn cancel_order(
480 &self,
481 request: CancelOrderRequest,
482 ) -> Result<OrderResult, BybitError> {
483 let response = self.send_request("order.cancel", vec![request]).await?;
484 response.into_result()
485 }
486
487
488 pub async fn batch_create_orders(
490 &self,
491 category: Category,
492 orders: Vec<CreateOrderRequest>,
493 ) -> Result<Vec<BatchOrderResult>, BybitError> {
494 if orders.is_empty() {
495 return Ok(Vec::new());
496 }
497 if orders.len() > 10 {
498 return Err(BybitError::InvalidParameter(
499 "Batch order limit is 10".to_string(),
500 ));
501 }
502
503 let orders: Vec<_> = orders
504 .into_iter()
505 .map(|mut o| {
506 o.category = category.clone();
507 o
508 })
509 .collect();
510
511 let response = self.send_request("order.create-batch", orders).await?;
512
513 if response.is_success() {
514 let list = response
515 .data
516 .get("result")
517 .and_then(|r| r.get("list"))
518 .cloned()
519 .unwrap_or(serde_json::Value::Array(vec![]));
520 serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
521 } else {
522 Err(BybitError::api_error(response.ret_code, response.ret_msg))
523 }
524 }
525
526 pub async fn batch_amend_orders(
528 &self,
529 category: Category,
530 orders: Vec<AmendOrderRequest>,
531 ) -> Result<Vec<BatchOrderResult>, BybitError> {
532 if orders.is_empty() {
533 return Ok(Vec::new());
534 }
535 if orders.len() > 10 {
536 return Err(BybitError::InvalidParameter(
537 "Batch order limit is 10".to_string(),
538 ));
539 }
540
541 let orders: Vec<_> = orders
542 .into_iter()
543 .map(|mut o| {
544 o.category = category.clone();
545 o
546 })
547 .collect();
548
549 let response = self.send_request("order.amend-batch", orders).await?;
550
551 if response.is_success() {
552 let list = response
553 .data
554 .get("result")
555 .and_then(|r| r.get("list"))
556 .cloned()
557 .unwrap_or(serde_json::Value::Array(vec![]));
558 serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
559 } else {
560 Err(BybitError::api_error(response.ret_code, response.ret_msg))
561 }
562 }
563
564 pub async fn batch_cancel_orders(
566 &self,
567 category: Category,
568 orders: Vec<CancelOrderRequest>,
569 ) -> Result<Vec<BatchOrderResult>, BybitError> {
570 if orders.is_empty() {
571 return Ok(Vec::new());
572 }
573 if orders.len() > 10 {
574 return Err(BybitError::InvalidParameter(
575 "Batch order limit is 10".to_string(),
576 ));
577 }
578
579 let orders: Vec<_> = orders
580 .into_iter()
581 .map(|mut o| {
582 o.category = category.clone();
583 o
584 })
585 .collect();
586
587 let response = self.send_request("order.cancel-batch", orders).await?;
588
589 if response.is_success() {
590 let list = response
591 .data
592 .get("result")
593 .and_then(|r| r.get("list"))
594 .cloned()
595 .unwrap_or(serde_json::Value::Array(vec![]));
596 serde_json::from_value(list).map_err(|e| BybitError::Serialization(e))
597 } else {
598 Err(BybitError::api_error(response.ret_code, response.ret_msg))
599 }
600 }
601
602 pub async fn disconnect(&self) {
604 *self.connected.write().await = false;
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every fragment occurs verbatim in `json`.
    fn assert_contains_all(json: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(json.contains(*fragment), "expected {} in {}", fragment, json);
        }
    }

    #[test]
    fn test_create_order_request_serialize() {
        let order = CreateOrderRequest {
            category: Category::Linear,
            symbol: "BTCUSDT".to_string(),
            side: Side::Buy,
            order_type: OrderType::Limit,
            qty: "0.001".to_string(),
            price: Some("50000".to_string()),
            time_in_force: Some(TimeInForce::GTC),
            order_link_id: None,
            is_leverage: None,
            position_idx: None,
            reduce_only: None,
            close_on_trigger: None,
            take_profit: None,
            stop_loss: None,
            tpsl_mode: None,
            market_unit: None,
        };

        let encoded = serde_json::to_string(&order)
            .unwrap_or_else(|err| panic!("Failed to serialize create request: {}", err));
        assert_contains_all(
            &encoded,
            &[
                "\"category\":\"linear\"",
                "\"symbol\":\"BTCUSDT\"",
                "\"side\":\"Buy\"",
                "\"orderType\":\"Limit\"",
                "\"qty\":\"0.001\"",
                "\"price\":\"50000\"",
            ],
        );
    }

    #[test]
    fn test_amend_order_request_serialize() {
        let amend = AmendOrderRequest {
            category: Category::Linear,
            symbol: "BTCUSDT".to_string(),
            order_id: Some("order-123".to_string()),
            order_link_id: None,
            qty: Some("0.002".to_string()),
            price: Some("51000".to_string()),
            take_profit: None,
            stop_loss: None,
            tp_limit_price: None,
            sl_limit_price: None,
        };

        let encoded = serde_json::to_string(&amend)
            .unwrap_or_else(|err| panic!("Failed to serialize amend request: {}", err));
        assert_contains_all(
            &encoded,
            &[
                "\"orderId\":\"order-123\"",
                "\"qty\":\"0.002\"",
                "\"price\":\"51000\"",
            ],
        );
    }

    #[test]
    fn test_cancel_order_request_serialize() {
        let cancel = CancelOrderRequest {
            category: Category::Linear,
            symbol: "BTCUSDT".to_string(),
            order_id: Some("order-123".to_string()),
            order_link_id: None,
        };

        let encoded = serde_json::to_string(&cancel)
            .unwrap_or_else(|err| panic!("Failed to serialize cancel request: {}", err));
        assert_contains_all(&encoded, &["\"orderId\":\"order-123\""]);
        // A `None` link id must be skipped entirely, not written as null.
        assert!(!encoded.contains("orderLinkId"));
    }

    #[test]
    fn test_trade_response_deserialize() {
        let raw = r#"{
            "reqId": "req-1",
            "retCode": 0,
            "retMsg": "OK",
            "op": "order.create",
            "data": {
                "orderId": "order-456",
                "orderLinkId": "my-order-1"
            },
            "connId": "conn-123"
        }"#;

        let reply: WsTradeResponse = serde_json::from_str(raw)
            .unwrap_or_else(|err| panic!("Failed to parse trade response: {}", err));
        assert_eq!(reply.req_id, "req-1");
        assert!(reply.is_success());
        assert_eq!(reply.op, "order.create");

        let outcome: OrderResult = reply
            .into_result()
            .unwrap_or_else(|err| panic!("Expected successful result: {}", err));
        assert_eq!(outcome.order_id, "order-456");
        assert_eq!(outcome.order_link_id, Some("my-order-1".to_string()));
    }

    #[test]
    fn test_trade_response_error() {
        let raw = r#"{
            "reqId": "req-1",
            "retCode": 10001,
            "retMsg": "Param error",
            "op": "order.create",
            "data": {}
        }"#;

        let reply: WsTradeResponse = serde_json::from_str(raw)
            .unwrap_or_else(|err| panic!("Failed to parse trade response: {}", err));
        assert!(!reply.is_success());

        let outcome: Result<OrderResult, _> = reply.into_result();
        assert!(outcome.is_err());
    }
}