use std::{
net::SocketAddr,
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::Duration,
};
use axum::{
Router,
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
routing::get,
};
use nautilus_bybit::{
common::enums::{
BybitEnvironment, BybitOrderSide, BybitOrderType, BybitProductType, BybitTimeInForce,
},
websocket::{
client::BybitWebSocketClient,
messages::{BybitWsAmendOrderParams, BybitWsCancelOrderParams, BybitWsPlaceOrderParams},
},
};
use nautilus_common::testing::wait_until_async;
use nautilus_model::{
data::BarType,
identifiers::{InstrumentId, StrategyId, TraderId},
instruments::{CurrencyPair, InstrumentAny},
types::{Currency, Price, Quantity},
};
use rstest::rstest;
use serde_json::json;
use ustr::Ustr;
#[derive(Clone)]
struct TestServerState {
connection_count: Arc<tokio::sync::Mutex<usize>>,
subscriptions: Arc<tokio::sync::Mutex<Vec<String>>>,
subscription_events: Arc<tokio::sync::Mutex<Vec<(String, bool)>>>, fail_next_subscriptions: Arc<tokio::sync::Mutex<Vec<String>>>,
auth_response_delay_ms: Arc<tokio::sync::Mutex<Option<u64>>>,
authenticated: Arc<AtomicBool>,
disconnect_trigger: Arc<AtomicBool>,
ping_count: Arc<AtomicUsize>,
pong_count: Arc<AtomicUsize>,
}
impl Default for TestServerState {
fn default() -> Self {
Self {
connection_count: Arc::new(tokio::sync::Mutex::new(0)),
subscriptions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
subscription_events: Arc::new(tokio::sync::Mutex::new(Vec::new())),
fail_next_subscriptions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
auth_response_delay_ms: Arc::new(tokio::sync::Mutex::new(None)),
authenticated: Arc::new(AtomicBool::new(false)),
disconnect_trigger: Arc::new(AtomicBool::new(false)),
ping_count: Arc::new(AtomicUsize::new(0)),
pong_count: Arc::new(AtomicUsize::new(0)),
}
}
}
impl TestServerState {
#[allow(dead_code)]
async fn reset(&self) {
*self.connection_count.lock().await = 0;
self.subscriptions.lock().await.clear();
self.subscription_events.lock().await.clear();
self.fail_next_subscriptions.lock().await.clear();
*self.auth_response_delay_ms.lock().await = None;
self.authenticated.store(false, Ordering::Relaxed);
self.disconnect_trigger.store(false, Ordering::Relaxed);
self.ping_count.store(0, Ordering::Relaxed);
self.pong_count.store(0, Ordering::Relaxed);
}
#[allow(dead_code)]
async fn set_subscription_failures(&self, topics: Vec<String>) {
*self.fail_next_subscriptions.lock().await = topics;
}
#[allow(dead_code)]
async fn set_auth_delay(&self, delay_ms: u64) {
*self.auth_response_delay_ms.lock().await = Some(delay_ms);
}
#[allow(dead_code)]
async fn subscription_events(&self) -> Vec<(String, bool)> {
self.subscription_events.lock().await.clone()
}
#[allow(dead_code)]
async fn clear_subscription_events(&self) {
self.subscription_events.lock().await.clear();
}
}
async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<TestServerState>) -> Response {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
async fn handle_socket(mut socket: WebSocket, state: TestServerState) {
{
let mut count = state.connection_count.lock().await;
*count += 1;
}
let state_clone = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_millis(100));
loop {
interval.tick().await;
if state_clone.disconnect_trigger.load(Ordering::Relaxed) {
break;
}
}
});
loop {
if state.disconnect_trigger.load(Ordering::Relaxed) {
break;
}
let msg_opt = match tokio::time::timeout(Duration::from_millis(50), socket.recv()).await {
Ok(opt) => opt,
Err(_) => continue,
};
let Some(msg) = msg_opt else {
break;
};
let msg = match msg {
Ok(m) => m,
Err(_) => break,
};
if state.disconnect_trigger.load(Ordering::Relaxed) {
break;
}
match msg {
Message::Text(text) => {
let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) else {
continue;
};
let op = value.get("op").and_then(|v| v.as_str());
match op {
Some("ping") => {
state.ping_count.fetch_add(1, Ordering::Relaxed);
let pong_response = json!({
"success": true,
"ret_msg": "pong",
"conn_id": "test-conn-id",
"req_id": value.get("req_id").and_then(|v| v.as_str()).unwrap_or(""),
"op": "pong"
});
if socket
.send(Message::Text(pong_response.to_string().into()))
.await
.is_err()
{
break;
}
}
Some("auth") => {
if let Some(delay_ms) = *state.auth_response_delay_ms.lock().await {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
let api_key = value
.get("args")
.and_then(|a| a.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.as_str());
if api_key == Some("test_api_key") {
state.authenticated.store(true, Ordering::Relaxed);
let auth_response = json!({
"success": true,
"ret_msg": "",
"op": "auth",
"conn_id": "test-conn-id"
});
if socket
.send(Message::Text(auth_response.to_string().into()))
.await
.is_err()
{
break;
}
} else {
let auth_response = json!({
"success": false,
"ret_msg": "Invalid API key",
"op": "auth",
"conn_id": "test-conn-id"
});
if socket
.send(Message::Text(auth_response.to_string().into()))
.await
.is_err()
{
break;
}
}
}
Some("subscribe") => {
let args = value.get("args").and_then(|a| a.as_array());
let mut failed_topics = Vec::new();
if let Some(topics) = args {
let fail_list = state.fail_next_subscriptions.lock().await.clone();
for topic in topics {
if let Some(topic_str) = topic.as_str() {
let should_fail = fail_list.contains(&topic_str.to_string());
state
.subscription_events
.lock()
.await
.push((topic_str.to_string(), !should_fail));
if should_fail {
failed_topics.push(topic_str);
} else {
let mut subs = state.subscriptions.lock().await;
if !subs.contains(&topic_str.to_string()) {
subs.push(topic_str.to_string());
}
}
}
}
}
if failed_topics.is_empty() {
let sub_response = json!({
"success": true,
"ret_msg": "",
"conn_id": "test-conn-id",
"req_id": value.get("req_id").and_then(|v| v.as_str()).unwrap_or(""),
"op": "subscribe"
});
if socket
.send(Message::Text(sub_response.to_string().into()))
.await
.is_err()
{
break;
}
} else {
let error_response = json!({
"success": false,
"ret_msg": format!("Subscription failed for topics: {:?}", failed_topics),
"ret_code": 10001,
"conn_id": "test-conn-id",
"req_id": value.get("req_id").and_then(|v| v.as_str()).unwrap_or(""),
"op": "subscribe"
});
if socket
.send(Message::Text(error_response.to_string().into()))
.await
.is_err()
{
break;
}
}
if let Some(topics) = args
&& let Some(first_topic) = topics.first().and_then(|t| t.as_str())
{
if first_topic.contains("publicTrade") {
let trade_msg = load_test_data("ws_public_trade.json");
if socket
.send(Message::Text(trade_msg.to_string().into()))
.await
.is_err()
{
break;
}
} else if first_topic.contains("orderbook") {
let orderbook_msg = load_test_data("ws_orderbook_snapshot.json");
if socket
.send(Message::Text(orderbook_msg.to_string().into()))
.await
.is_err()
{
break;
}
}
}
}
Some("unsubscribe") => {
let args = value.get("args").and_then(|a| a.as_array());
if let Some(topics) = args {
for topic in topics {
if let Some(topic_str) = topic.as_str() {
let mut events = state.subscription_events.lock().await;
events.retain(|(t, _)| t != topic_str);
drop(events);
let mut subs = state.subscriptions.lock().await;
subs.retain(|s| s != topic_str);
}
}
}
let unsub_response = json!({
"success": true,
"ret_msg": "",
"conn_id": "test-conn-id",
"req_id": value.get("req_id").and_then(|v| v.as_str()).unwrap_or(""),
"op": "unsubscribe"
});
if socket
.send(Message::Text(unsub_response.to_string().into()))
.await
.is_err()
{
break;
}
}
Some("order.place") => {
let req_id = value.get("req_id").and_then(|v| v.as_str());
let response = json!({
"success": true,
"ret_msg": "",
"conn_id": "test-conn-id",
"req_id": req_id.unwrap_or(""),
"op": "order.place"
});
if socket
.send(Message::Text(response.to_string().into()))
.await
.is_err()
{
break;
}
}
Some("order.amend") => {
let req_id = value.get("req_id").and_then(|v| v.as_str());
let response = json!({
"success": true,
"ret_msg": "",
"conn_id": "test-conn-id",
"req_id": req_id.unwrap_or(""),
"op": "order.amend"
});
if socket
.send(Message::Text(response.to_string().into()))
.await
.is_err()
{
break;
}
}
Some("order.cancel") => {
let req_id = value.get("req_id").and_then(|v| v.as_str());
let response = json!({
"success": true,
"ret_msg": "",
"conn_id": "test-conn-id",
"req_id": req_id.unwrap_or(""),
"op": "order.cancel"
});
if socket
.send(Message::Text(response.to_string().into()))
.await
.is_err()
{
break;
}
}
_ => {}
}
}
Message::Ping(_) => {
state.ping_count.fetch_add(1, Ordering::Relaxed);
if socket.send(Message::Pong(vec![].into())).await.is_err() {
break;
}
}
Message::Pong(_) => {
state.pong_count.fetch_add(1, Ordering::Relaxed);
}
Message::Close(_) => {
break;
}
_ => {}
}
if state.disconnect_trigger.load(Ordering::Relaxed) {
break;
}
}
let mut count = state.connection_count.lock().await;
*count = count.saturating_sub(1);
}
fn load_test_data(filename: &str) -> serde_json::Value {
let path = format!("test_data/{filename}");
let content = std::fs::read_to_string(path).expect("Failed to read test data");
serde_json::from_str(&content).expect("Failed to parse test data")
}
fn make_linear_pair(raw_symbol: &str, base: &str, quote: &str) -> CurrencyPair {
CurrencyPair::new(
format!("{raw_symbol}-LINEAR.BYBIT").into(),
raw_symbol.into(),
Currency::from(base),
Currency::from(quote),
2,
5,
Price::from("0.01"),
Quantity::from("0.00001"),
None, None, None, None, None, None, None, None, None, None, None, None, None, 0.into(), 0.into(), )
}
fn create_test_router(state: TestServerState) -> Router {
Router::new()
.route("/v5/public/linear", get(handle_websocket))
.route("/v5/private", get(handle_websocket))
.with_state(state)
}
async fn start_test_server()
-> Result<(SocketAddr, TestServerState), Box<dyn std::error::Error + Send + Sync>> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let state = TestServerState::default();
let router = create_test_router(state.clone());
tokio::spawn(async move {
axum::serve(listener, router).await.unwrap();
});
tokio::time::sleep(Duration::from_millis(100)).await;
Ok((addr, state))
}
#[allow(dead_code)]
async fn wait_for_subscription_events<F>(
state: &TestServerState,
timeout: Duration,
mut predicate: F,
) -> Vec<(String, bool)>
where
F: FnMut(&[(String, bool)]) -> bool,
{
let state_clone = state.clone();
let poll = async {
loop {
let events = state_clone.subscription_events().await;
if predicate(&events) {
return events;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
};
match tokio::time::timeout(timeout, poll).await {
Ok(events) => events,
Err(_) => state.subscription_events().await,
}
}
async fn wait_for_connection_count(state: &TestServerState, expected: usize, timeout: Duration) {
wait_until_async(
|| {
let state = state.clone();
async move { *state.connection_count.lock().await == expected }
},
timeout,
)
.await;
}
#[rstest]
#[tokio::test]
async fn test_public_client_connection() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
assert!(client.is_active());
assert_eq!(*state.connection_count.lock().await, 1);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_private_client_authentication() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _result = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
assert!(*state.connection_count.lock().await > 0);
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_authentication_failure() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("invalid_key".to_string()),
Some("invalid_secret".to_string()),
Some(ws_url),
None,
);
let _result = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
assert!(!state.authenticated.load(Ordering::Relaxed));
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_ping_pong() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
Some(1), );
client.connect().await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
wait_until_async(
|| async { state.ping_count.load(Ordering::Relaxed) > 0 },
Duration::from_secs(3),
)
.await;
assert!(state.ping_count.load(Ordering::Relaxed) > 0);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_subscription_lifecycle() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics.clone()).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert!(
subs.iter()
.any(|(topic, ok)| topic == "publicTrade.BTCUSDT" && *ok)
);
client.unsubscribe(topics).await.unwrap();
wait_until_async(
|| async { state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
assert!(state.subscription_events.lock().await.is_empty());
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_message_routing() {
let (addr, _state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics).await.unwrap();
wait_until_async(
|| async { client.subscription_count() > 0 },
Duration::from_secs(5),
)
.await;
assert!(client.subscription_count() > 0);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_reconnection_flow() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics.clone()).await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
let initial_count = *state.connection_count.lock().await;
assert_eq!(initial_count, 1);
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_multiple_subscriptions() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec![
"publicTrade.BTCUSDT".to_string(),
"publicTrade.ETHUSDT".to_string(),
"orderbook.50.BTCUSDT".to_string(),
];
client.subscribe(topics).await.unwrap();
wait_until_async(
|| async { state.subscription_events.lock().await.len() >= 3 },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert_eq!(subs.len(), 3);
assert!(
subs.iter()
.any(|(topic, ok)| topic == "publicTrade.BTCUSDT" && *ok)
);
assert!(
subs.iter()
.any(|(topic, ok)| topic == "publicTrade.ETHUSDT" && *ok)
);
assert!(
subs.iter()
.any(|(topic, ok)| topic == "orderbook.50.BTCUSDT" && *ok)
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_wait_until_active_timeout() {
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some("ws://127.0.0.1:9999/invalid".to_string()),
None,
);
let _ = client.connect().await;
let result = client.wait_until_active(0.5).await;
assert!(result.is_err());
}
#[rstest]
#[tokio::test]
async fn test_heartbeat_timeout_reconnection() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
Some(1), );
client.connect().await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_sends_pong_for_text_ping() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
Some(1),
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.ping_count.load(Ordering::Relaxed) > 0 },
Duration::from_secs(3),
)
.await;
assert!(state.ping_count.load(Ordering::Relaxed) > 0);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_sends_pong_for_control_ping() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
assert!(client.is_active());
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_reauth_after_disconnect() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_login_failure_emits_error() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("invalid_key".to_string()),
Some("invalid_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
assert!(!state.authenticated.load(Ordering::Relaxed));
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_unauthenticated_private_subscription_fails() {
let (addr, _state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let result = client.subscribe_orders().await;
assert!(result.is_err());
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_subscription_after_reconnection() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics.clone()).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(client.subscription_count() > 0);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_subscription_restoration_tracking() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec![
"publicTrade.BTCUSDT".to_string(),
"orderbook.50.ETHUSDT".to_string(),
];
client.subscribe(topics).await.unwrap();
wait_until_async(
|| async { state.subscription_events.lock().await.len() >= 2 },
Duration::from_secs(5),
)
.await;
let initial_count = client.subscription_count();
assert_eq!(initial_count, 2);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_reconnection_retries_failed_subscriptions() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
state.subscription_events.lock().await.clear();
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_trade_subscription_flow() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let instrument_id = InstrumentId::from("BTCUSDT-LINEAR.BYBIT");
client.subscribe_trades(instrument_id).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert!(
subs.iter()
.any(|(topic, ok)| topic.contains("publicTrade") && *ok)
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_orderbook_subscription_flow() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let instrument_id = InstrumentId::from("BTCUSDT-LINEAR.BYBIT");
client.subscribe_orderbook(instrument_id, 50).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert!(
subs.iter()
.any(|(topic, ok)| topic.contains("orderbook") && *ok)
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_ticker_subscription_flow() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let instrument_id = InstrumentId::from("BTCUSDT-LINEAR.BYBIT");
client.subscribe_ticker(instrument_id).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert!(
subs.iter()
.any(|(topic, ok)| topic.contains("ticker") && *ok)
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_klines_subscription_flow() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let bar_type = BarType::from("BTCUSDT-LINEAR.BYBIT-1-MINUTE-LAST-EXTERNAL");
client.subscribe_bars(bar_type).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let subs = state.subscription_events.lock().await.clone();
assert!(
subs.iter()
.any(|(topic, ok)| topic.contains("kline") && *ok)
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_private_orders_subscription() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
let _ = client.subscribe_orders().await;
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_private_executions_subscription() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
let _ = client.subscribe_executions().await;
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_private_wallet_subscription() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
let _ = client.subscribe_wallet().await;
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_rapid_consecutive_reconnections() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics.clone()).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let initial_connection_count = *state.connection_count.lock().await;
for i in 0..3 {
state.clear_subscription_events().await;
wait_until_async(
|| {
let state = state.clone();
async move { state.subscription_events().await.is_empty() }
},
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
let _ = client.subscribe(vec![format!("publicTrade.ETH{i}")]).await;
tokio::time::sleep(Duration::from_millis(200)).await;
state.disconnect_trigger.store(false, Ordering::Relaxed);
}
tokio::time::sleep(Duration::from_secs(2)).await;
let final_count = *state.connection_count.lock().await;
assert!(
final_count >= initial_connection_count,
"Expected connection to be maintained or reconnected"
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_reconnection_race_condition() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec!["publicTrade.BTCUSDT".to_string()];
client.subscribe(topics).await.unwrap();
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
state.clear_subscription_events().await;
wait_until_async(
|| {
let state = state.clone();
async move { state.subscription_events().await.is_empty() }
},
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
let _ = client
.subscribe(vec!["orderbook.50.ETHUSDT".to_string()])
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
state.disconnect_trigger.store(false, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(50)).await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
tokio::time::sleep(Duration::from_millis(100)).await;
state.disconnect_trigger.store(false, Ordering::Relaxed);
tokio::time::sleep(Duration::from_secs(3)).await;
assert!(*state.connection_count.lock().await >= 1);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_reconnection_waits_for_delayed_auth_ack() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
state.set_auth_delay(500).await;
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
let _ = client.connect().await;
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(3),
)
.await;
let _ = client.subscribe_orders().await;
tokio::time::sleep(Duration::from_millis(1000)).await;
assert!(
*state.connection_count.lock().await > 0,
"Connection should be maintained during delayed auth"
);
let _ = client.close().await;
}
#[rstest]
#[tokio::test]
async fn test_multiple_partial_subscription_failures() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
let topics = vec![
"publicTrade.BTCUSDT".to_string(),
"publicTrade.ETHUSDT".to_string(),
"orderbook.50.BTCUSDT".to_string(),
];
client.subscribe(topics.clone()).await.unwrap();
wait_until_async(
|| async { state.subscription_events.lock().await.len() >= 3 },
Duration::from_secs(5),
)
.await;
state
.set_subscription_failures(vec!["publicTrade.SOLUSDT".to_string()])
.await;
state.clear_subscription_events().await;
wait_until_async(
|| {
let state = state.clone();
async move { state.subscription_events().await.is_empty() }
},
Duration::from_secs(5),
)
.await;
let mixed_topics = vec![
"publicTrade.SOLUSDT".to_string(),
"orderbook.50.ETHUSDT".to_string(),
];
let _ = client.subscribe(mixed_topics).await;
wait_until_async(
|| async { !state.subscription_events.lock().await.is_empty() },
Duration::from_secs(5),
)
.await;
let events = state.subscription_events().await;
assert!(
!events.is_empty(),
"Should have subscription events even with partial failures"
);
let has_failure = events.iter().any(|(_, success)| !success);
assert!(has_failure, "Should have at least one failed subscription");
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_is_active_false_during_reconnection() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(|| async { client.is_active() }, Duration::from_secs(2)).await;
assert!(client.is_active(), "Client should be active after connect");
state.disconnect_trigger.store(true, Ordering::Relaxed);
let _ = client
.subscribe(vec!["publicTrade.BTCUSDT".to_string()])
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let active_during_reconnect = client.is_active();
state.disconnect_trigger.store(false, Ordering::Relaxed);
if !active_during_reconnect {
assert!(!active_during_reconnect);
}
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_sends_pong_for_text_ping_message() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
Some(1), );
client.connect().await.unwrap();
wait_until_async(
|| async { *state.connection_count.lock().await > 0 },
Duration::from_secs(5),
)
.await;
wait_until_async(
|| async { state.ping_count.load(Ordering::Relaxed) > 0 },
Duration::from_secs(3),
)
.await;
assert!(
state.ping_count.load(Ordering::Relaxed) > 0,
"Server should have received ping messages"
);
client.close().await.unwrap();
}
#[cfg(test)]
mod conditional_order_tests {
use nautilus_bybit::{
common::enums::{
BybitOrderSide, BybitOrderType, BybitProductType, BybitTimeInForce, BybitTriggerType,
},
websocket::{client::BybitWebSocketClient, messages::BybitWsPlaceOrderParams},
};
use nautilus_model::{
enums::{OrderSide, OrderType, TimeInForce},
identifiers::{ClientOrderId, InstrumentId},
types::{Price, Quantity},
};
use rstest::rstest;
#[rstest]
fn test_buy_stop_market_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::StopMarket,
BybitOrderSide::Buy,
Some(Price::from("4500.00")),
None,
);
assert_eq!(params.trigger_direction, Some(1)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.order_type, BybitOrderType::Market);
}
#[rstest]
fn test_sell_stop_market_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::StopMarket,
BybitOrderSide::Sell,
Some(Price::from("4500.00")),
None,
);
assert_eq!(params.trigger_direction, Some(2)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.order_type, BybitOrderType::Market);
}
#[rstest]
fn test_buy_stop_limit_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::StopLimit,
BybitOrderSide::Buy,
Some(Price::from("4500.00")),
Some(Price::from("4505.00")),
);
assert_eq!(params.trigger_direction, Some(1)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.price.as_ref().unwrap(), "4505.00");
assert_eq!(params.order_type, BybitOrderType::Limit);
}
#[rstest]
fn test_sell_stop_limit_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::StopLimit,
BybitOrderSide::Sell,
Some(Price::from("4500.00")),
Some(Price::from("4495.00")),
);
assert_eq!(params.trigger_direction, Some(2)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.price.as_ref().unwrap(), "4495.00");
assert_eq!(params.order_type, BybitOrderType::Limit);
}
#[rstest]
fn test_buy_market_if_touched_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::MarketIfTouched,
BybitOrderSide::Buy,
Some(Price::from("4500.00")),
None,
);
assert_eq!(params.trigger_direction, Some(2)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.order_type, BybitOrderType::Market);
}
#[rstest]
fn test_sell_market_if_touched_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::MarketIfTouched,
BybitOrderSide::Sell,
Some(Price::from("5500.00")),
None,
);
assert_eq!(params.trigger_direction, Some(1)); assert_eq!(params.trigger_price.as_ref().unwrap(), "5500.00");
assert_eq!(params.order_type, BybitOrderType::Market);
}
#[rstest]
fn test_buy_limit_if_touched_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::LimitIfTouched,
BybitOrderSide::Buy,
Some(Price::from("4500.00")),
Some(Price::from("4505.00")),
);
assert_eq!(params.trigger_direction, Some(2)); assert_eq!(params.trigger_price.as_ref().unwrap(), "4500.00");
assert_eq!(params.price.as_ref().unwrap(), "4505.00");
assert_eq!(params.order_type, BybitOrderType::Limit);
}
#[rstest]
fn test_sell_limit_if_touched_order_trigger_direction() {
let params = create_conditional_order_params_with_side(
OrderType::LimitIfTouched,
BybitOrderSide::Sell,
Some(Price::from("5500.00")),
Some(Price::from("5495.00")),
);
assert_eq!(params.trigger_direction, Some(1)); assert_eq!(params.trigger_price.as_ref().unwrap(), "5500.00");
assert_eq!(params.price.as_ref().unwrap(), "5495.00");
assert_eq!(params.order_type, BybitOrderType::Limit);
}
#[rstest]
fn test_reduce_only_false_omitted() {
let params = create_conditional_order_params_with_reduce_only(
OrderType::StopMarket,
Some(Price::from("4500.00")),
None,
Some(false),
);
assert!(params.reduce_only.is_none());
}
#[rstest]
fn test_reduce_only_explicit_true() {
let params = create_conditional_order_params_with_reduce_only(
OrderType::StopMarket,
Some(Price::from("4500.00")),
None,
Some(true),
);
assert!(params.reduce_only.is_some());
assert!(params.reduce_only.unwrap());
}
fn create_conditional_order_params_with_side(
order_type: OrderType,
side: BybitOrderSide,
trigger_price: Option<Price>,
price: Option<Price>,
) -> BybitWsPlaceOrderParams {
let client = BybitWebSocketClient::new_public(None, None);
let nautilus_side = match side {
BybitOrderSide::Buy => OrderSide::Buy,
BybitOrderSide::Sell => OrderSide::Sell,
BybitOrderSide::Unknown => panic!("Unknown side not supported in tests"),
};
client
.build_place_order_params(
BybitProductType::Linear,
InstrumentId::from("ETHUSDT-LINEAR.BYBIT"),
ClientOrderId::from("test-order-1"),
nautilus_side,
order_type,
Quantity::from("0.01"),
false, Some(TimeInForce::Gtc),
price,
trigger_price,
None, None, false, None, None, )
.unwrap()
}
fn create_conditional_order_params_with_reduce_only(
order_type: OrderType,
trigger_price: Option<Price>,
price: Option<Price>,
reduce_only: Option<bool>,
) -> BybitWsPlaceOrderParams {
let is_stop_order = matches!(
order_type,
OrderType::StopMarket
| OrderType::StopLimit
| OrderType::MarketIfTouched
| OrderType::LimitIfTouched
);
let bybit_order_type = match order_type {
OrderType::Market | OrderType::StopMarket | OrderType::MarketIfTouched => {
BybitOrderType::Market
}
OrderType::Limit | OrderType::StopLimit | OrderType::LimitIfTouched => {
BybitOrderType::Limit
}
_ => panic!("Unsupported order type"),
};
if is_stop_order {
BybitWsPlaceOrderParams {
category: BybitProductType::Linear,
symbol: "ETHUSDT".into(),
side: BybitOrderSide::Buy,
order_type: bybit_order_type,
qty: "0.01".to_string(),
is_leverage: None,
market_unit: None,
price: price.map(|p| p.to_string()),
time_in_force: Some(BybitTimeInForce::Gtc),
order_link_id: Some("test-order-1".to_string()),
reduce_only: reduce_only.filter(|&r| r),
close_on_trigger: None,
trigger_price: trigger_price.map(|p| p.to_string()),
trigger_by: Some(BybitTriggerType::LastPrice),
trigger_direction: None,
tpsl_mode: None,
take_profit: None,
stop_loss: None,
tp_trigger_by: None,
sl_trigger_by: None,
sl_trigger_price: None,
tp_trigger_price: None,
sl_order_type: None,
tp_order_type: None,
sl_limit_price: None,
tp_limit_price: None,
}
} else {
BybitWsPlaceOrderParams {
category: BybitProductType::Linear,
symbol: "ETHUSDT".into(),
side: BybitOrderSide::Buy,
order_type: bybit_order_type,
qty: "0.01".to_string(),
is_leverage: None,
market_unit: None,
price: price.map(|p| p.to_string()),
time_in_force: Some(BybitTimeInForce::Gtc),
order_link_id: Some("test-order-1".to_string()),
reduce_only: reduce_only.filter(|&r| r),
close_on_trigger: None,
trigger_price: None,
trigger_by: None,
trigger_direction: None,
tpsl_mode: None,
take_profit: None,
stop_loss: None,
tp_trigger_by: None,
sl_trigger_by: None,
sl_trigger_price: None,
tp_trigger_price: None,
sl_order_type: None,
tp_order_type: None,
sl_limit_price: None,
tp_limit_price: None,
}
}
}
}
#[rstest]
#[tokio::test]
async fn test_is_active_lifecycle() {
let (addr, _state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_key".to_string()),
Some("test_secret".to_string()),
Some(ws_url),
None,
);
assert!(
!client.is_active(),
"Client should not be active before connect"
);
client.connect().await.unwrap();
client.wait_until_active(5.0).await.unwrap();
assert!(
client.is_active(),
"Client should be active after connect completes"
);
client.close().await.unwrap();
wait_until_async(|| async { !client.is_active() }, Duration::from_secs(2)).await;
assert!(
!client.is_active(),
"Client should not be active after close"
);
}
#[tokio::test]
async fn test_is_active_false_after_close() {
let (addr, _state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_key".to_string()),
Some("test_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
client.wait_until_active(5.0).await.unwrap();
assert!(
client.is_active(),
"Expected is_active() to be true after connect"
);
client.close().await.unwrap();
wait_until_async(|| async { !client.is_active() }, Duration::from_secs(2)).await;
assert!(
!client.is_active(),
"Expected is_active() to be false after close"
);
assert!(
client.is_closed(),
"Expected is_closed() to be true after close"
);
}
#[rstest]
#[tokio::test]
async fn test_subscribe_after_stream_call() {
let (addr, _state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/public/linear");
let mut client = BybitWebSocketClient::new_public_with(
BybitProductType::Linear,
BybitEnvironment::Mainnet,
Some(ws_url),
None,
);
client.connect().await.unwrap();
client.wait_until_active(5.0).await.unwrap();
let _stream = client.stream();
tokio::spawn(async move {
tokio::pin!(_stream);
});
tokio::time::sleep(Duration::from_millis(100)).await;
let result = client
.subscribe(vec!["publicTrade.BTCUSDT".to_string()])
.await;
assert!(
result.is_ok(),
"Subscribe should work after stream() is called, but got error: {:?}",
result.err()
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_unsubscribed_private_channel_not_resubscribed_after_disconnect() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url.clone()),
None,
);
client.connect().await.unwrap();
let instrument_id = InstrumentId::from("BTCUSDT.BYBIT");
client.subscribe_trades(instrument_id).await.unwrap();
client.subscribe_positions().await.unwrap();
wait_for_subscription_events(&state, Duration::from_secs(5), |events| {
events
.iter()
.any(|(topic, ok)| topic == "publicTrade.BTCUSDT" && *ok)
&& events.iter().any(|(topic, ok)| topic == "position" && *ok)
})
.await;
{
let subs = state.subscriptions.lock().await;
assert!(subs.contains(&"publicTrade.BTCUSDT".to_string()));
assert!(subs.contains(&"position".to_string()));
}
client.unsubscribe_positions().await.unwrap();
wait_until_async(
|| {
let state = state.clone();
async move {
let subs = state.subscriptions.lock().await;
!subs.contains(&"position".to_string())
}
},
Duration::from_secs(5),
)
.await;
{
let subs = state.subscriptions.lock().await;
assert!(!subs.contains(&"position".to_string()));
assert!(subs.contains(&"publicTrade.BTCUSDT".to_string()));
}
state.clear_subscription_events().await;
wait_until_async(
|| {
let state = state.clone();
async move { state.subscription_events().await.is_empty() }
},
Duration::from_secs(5),
)
.await;
state.disconnect_trigger.store(true, Ordering::Relaxed);
wait_for_connection_count(&state, 0, Duration::from_secs(2)).await;
state.disconnect_trigger.store(false, Ordering::Relaxed);
client.wait_until_active(10.0).await.unwrap();
wait_for_connection_count(&state, 1, Duration::from_secs(5)).await;
wait_for_subscription_events(&state, Duration::from_secs(10), |events| {
events
.iter()
.any(|(topic, ok)| topic == "publicTrade.BTCUSDT" && *ok)
})
.await;
let subs = state.subscriptions.lock().await;
let events = state.subscription_events().await;
assert!(
subs.contains(&"publicTrade.BTCUSDT".to_string()),
"Trade subscription should be restored after reconnection"
);
assert!(
!subs.contains(&"position".to_string()),
"Position subscription should NOT be restored after unsubscribe and reconnect"
);
assert!(
!events.iter().any(|(topic, _ok)| topic == "position"),
"Position should not appear in subscription events after reconnect; events={events:?}"
);
assert!(
events
.iter()
.any(|(topic, ok)| topic == "publicTrade.BTCUSDT" && *ok),
"Trade subscription should be restored; events={events:?}"
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_batch_place_orders_with_cache_keys() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let btcusdt_linear = make_linear_pair("BTCUSDT", "BTC", "USDT");
client.cache_instrument(InstrumentAny::CurrencyPair(btcusdt_linear));
let orders = vec![BybitWsPlaceOrderParams {
category: BybitProductType::Linear,
symbol: Ustr::from("BTCUSDT"),
side: BybitOrderSide::Buy,
order_type: BybitOrderType::Limit,
qty: "0.001".to_string(),
is_leverage: None,
market_unit: None,
price: Some("50000.0".to_string()),
time_in_force: Some(BybitTimeInForce::Gtc),
order_link_id: Some("test-order-1".to_string()),
reduce_only: None,
close_on_trigger: None,
trigger_price: None,
trigger_by: None,
trigger_direction: None,
tpsl_mode: None,
take_profit: None,
stop_loss: None,
tp_trigger_by: None,
sl_trigger_by: None,
sl_trigger_price: None,
tp_trigger_price: None,
sl_order_type: None,
tp_order_type: None,
sl_limit_price: None,
tp_limit_price: None,
}];
let trader_id = TraderId::from("TRADER-001");
let strategy_id = StrategyId::from("STRATEGY-001");
let result = client
.batch_place_orders(trader_id, strategy_id, orders)
.await;
assert!(
result.is_ok(),
"Batch place orders should succeed with proper cache keys"
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_batch_amend_orders() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let orders = vec![BybitWsAmendOrderParams {
category: BybitProductType::Linear,
symbol: Ustr::from("BTCUSDT"),
order_id: None,
order_link_id: Some("test-order-1".to_string()),
qty: Some("0.002".to_string()),
price: Some("51000.0".to_string()),
trigger_price: None,
take_profit: None,
stop_loss: None,
tp_trigger_by: None,
sl_trigger_by: None,
}];
let trader_id = TraderId::from("TRADER-001");
let strategy_id = StrategyId::from("STRATEGY-001");
let result = client
.batch_amend_orders(trader_id, strategy_id, orders)
.await;
assert!(result.is_ok(), "Batch amend orders should succeed");
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_batch_cancel_orders() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let trader_id = TraderId::from("TESTER-001");
let strategy_id = StrategyId::from("S-001");
let btcusdt_linear = make_linear_pair("BTCUSDT", "BTC", "USDT");
let ethusdt_linear = make_linear_pair("ETHUSDT", "ETH", "USDT");
client.cache_instrument(InstrumentAny::CurrencyPair(btcusdt_linear));
client.cache_instrument(InstrumentAny::CurrencyPair(ethusdt_linear));
let orders = vec![
BybitWsCancelOrderParams {
category: BybitProductType::Linear,
symbol: Ustr::from("BTCUSDT"),
order_id: None,
order_link_id: Some("test-order-1".to_string()),
},
BybitWsCancelOrderParams {
category: BybitProductType::Linear,
symbol: Ustr::from("ETHUSDT"),
order_id: None,
order_link_id: Some("test-order-2".to_string()),
},
];
let result = client
.batch_cancel_orders(trader_id, strategy_id, orders)
.await;
assert!(result.is_ok(), "Batch cancel orders should succeed");
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_batch_cancel_orders_chunking_over_20() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let trader_id = TraderId::from("TESTER-001");
let strategy_id = StrategyId::from("S-001");
let btcusdt_linear = make_linear_pair("BTCUSDT", "BTC", "USDT");
client.cache_instrument(InstrumentAny::CurrencyPair(btcusdt_linear));
let orders: Vec<BybitWsCancelOrderParams> = (0..25)
.map(|i| BybitWsCancelOrderParams {
category: BybitProductType::Linear,
symbol: Ustr::from("BTCUSDT"),
order_id: Some(format!("order-{i}")),
order_link_id: Some(format!("client-order-{i}")),
})
.collect();
let result = client
.batch_cancel_orders(trader_id, strategy_id, orders)
.await;
assert!(result.is_ok(), "Batch cancel with chunking should succeed");
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_batch_cancel_orders_empty_list() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let trader_id = TraderId::from("TESTER-001");
let strategy_id = StrategyId::from("S-001");
let orders: Vec<BybitWsCancelOrderParams> = vec![];
let result = client
.batch_cancel_orders(trader_id, strategy_id, orders)
.await;
assert!(
result.is_ok(),
"Batch cancel with empty list should succeed"
);
client.close().await.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_build_cancel_order_params_requires_order_id() {
let (addr, state) = start_test_server().await.unwrap();
let ws_url = format!("ws://{addr}/v5/private");
let mut client = BybitWebSocketClient::new_private(
BybitEnvironment::Mainnet,
Some("test_api_key".to_string()),
Some("test_api_secret".to_string()),
Some(ws_url),
None,
);
client.connect().await.unwrap();
wait_until_async(
|| async { state.authenticated.load(Ordering::Relaxed) },
Duration::from_secs(5),
)
.await;
let btcusdt_linear = make_linear_pair("BTCUSDT", "BTC", "USDT");
client.cache_instrument(InstrumentAny::CurrencyPair(btcusdt_linear.clone()));
let result =
client.build_cancel_order_params(BybitProductType::Linear, btcusdt_linear.id, None, None);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Either venue_order_id or client_order_id must be provided")
);
client.close().await.unwrap();
}