#![cfg(all(feature = "test-utils", feature = "tokio-comp"))]
#![allow(missing_docs, reason = "internal test utilities — public for cross-crate test use")]
use crate::models::streaming::StreamMessage;
use crate::websocket::aio::WebSocketClient;
use crate::websocket::{ConnectionConfig, ReconnectionConfig};
use crate::AuthRequest;
use futures_util::{SinkExt, StreamExt};
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio_tungstenite::tungstenite::Message;
/// Test-side handle to one client slot of a `MockWsServer`.
///
/// Built in `MockWsServer::start_with_capacity` together with the `AcceptSeed`
/// for the same slot: the sub-id queue is shared with it, and the two senders
/// below feed receivers held by that seed.
struct ClientHandle {
/// Ids handed out (front first) in `subscribed` acks for this client.
pending_sub_ids: Arc<Mutex<VecDeque<String>>>,
/// Sends frames / close requests for the server task to push to this client.
inject_tx: mpsc::UnboundedSender<MockInjection>,
/// One-shot trigger that makes the server task drop the socket; `None` once fired.
transport_drop: Mutex<Option<oneshot::Sender<()>>>,
}
/// In-process mock WebSocket server for tests.
///
/// Listens on an ephemeral `127.0.0.1` port, acknowledges `auth` and
/// `subscribe` events, and lets a test inject frames, send a Close frame, or
/// drop the transport for each connected client slot.
pub struct MockWsServer {
/// Address of the bound listener.
addr: SocketAddr,
/// One handle per client slot, indexed by `client_idx`.
clients: Vec<ClientHandle>,
}
/// A test-initiated action delivered to the task serving one client.
enum MockInjection {
/// Serialize to JSON and send to the client as a text frame.
Frame(StreamMessage),
/// Send a Close frame with this code and reason, then stop serving the client.
Close { code: u16, reason: String },
}
impl MockWsServer {
pub async fn start() -> Self {
Self::start_with_capacity(1).await
}
pub async fn start_with_capacity(capacity: usize) -> Self {
assert!(
capacity > 0,
"MockWsServer::start_with_capacity: capacity must be > 0"
);
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind ephemeral port");
let addr = listener.local_addr().expect("local_addr");
let mut clients = Vec::with_capacity(capacity);
let mut accept_seeds = Vec::with_capacity(capacity);
for _ in 0..capacity {
let pending_sub_ids: Arc<Mutex<VecDeque<String>>> =
Arc::new(Mutex::new(VecDeque::new()));
let (inject_tx, inject_rx) = mpsc::unbounded_channel::<MockInjection>();
let (drop_tx, drop_rx) = oneshot::channel::<()>();
clients.push(ClientHandle {
pending_sub_ids: Arc::clone(&pending_sub_ids),
inject_tx,
transport_drop: Mutex::new(Some(drop_tx)),
});
accept_seeds.push(AcceptSeed {
pending_sub_ids,
inject_rx,
drop_rx,
});
}
tokio::spawn(async move {
run_accept_loop(listener, accept_seeds).await;
});
Self { addr, clients }
}
pub fn url(&self) -> String {
format!("ws://{}/marketdata/v1.0/stock/streaming", self.addr)
}
pub fn address(&self) -> SocketAddr {
self.addr
}
pub fn capacity(&self) -> usize {
self.clients.len()
}
pub async fn next_subscribe_id(&self, id: impl Into<String>) {
self.assert_single_client("next_subscribe_id");
self.next_subscribe_id_for(0, id).await;
}
pub async fn next_subscribe_id_for(&self, client_idx: usize, id: impl Into<String>) {
let client = self.client_or_panic(client_idx, "next_subscribe_id_for");
client.pending_sub_ids.lock().await.push_back(id.into());
}
pub async fn inject_frame(&self, frame: StreamMessage) {
self.assert_single_client("inject_frame");
self.inject_frame_for(0, frame).await;
}
pub async fn inject_frame_for(&self, client_idx: usize, frame: StreamMessage) {
let client = self.client_or_panic(client_idx, "inject_frame_for");
let _ = client.inject_tx.send(MockInjection::Frame(frame));
}
pub async fn close(&self, code: u16, reason: impl Into<String>) {
self.assert_single_client("close");
self.close_for(0, code, reason).await;
}
pub async fn close_for(&self, client_idx: usize, code: u16, reason: impl Into<String>) {
let client = self.client_or_panic(client_idx, "close_for");
let _ = client.inject_tx.send(MockInjection::Close {
code,
reason: reason.into(),
});
}
pub async fn drop_transport(&self) {
self.assert_single_client("drop_transport");
self.drop_transport_for(0).await;
}
pub async fn drop_transport_for(&self, client_idx: usize) {
let client = self.client_or_panic(client_idx, "drop_transport_for");
let mut slot = client.transport_drop.lock().await;
if let Some(tx) = slot.take() {
let _ = tx.send(());
}
}
fn assert_single_client(&self, method: &str) {
assert!(
self.clients.len() == 1,
"MockWsServer::{method} requires capacity == 1; use {method}_for(client_idx, ...) for multi-client mocks"
);
}
fn client_or_panic(&self, client_idx: usize, method: &str) -> &ClientHandle {
let capacity = self.clients.len();
self.clients.get(client_idx).unwrap_or_else(|| {
panic!(
"MockWsServer::{method}: client_idx {client_idx} out of range (capacity = {capacity})"
)
})
}
}
/// Server-side state for one client slot.
///
/// The accept loop hands it to the task that serves the connection assigned
/// to that slot.
struct AcceptSeed {
/// Shared with the matching `ClientHandle`; ids are popped for `subscribed` acks.
pending_sub_ids: Arc<Mutex<VecDeque<String>>>,
/// Receives frames / close requests injected through `MockWsServer`.
inject_rx: mpsc::UnboundedReceiver<MockInjection>,
/// Completes when the test asks to drop the transport (or its sender is dropped).
drop_rx: oneshot::Receiver<()>,
}
/// Starts a single-slot `MockWsServer` and builds one `WebSocketClient`
/// configured for it.
///
/// This is `aio_pair_n(1)` with the client taken out of the `Vec`.
pub async fn aio_pair() -> (MockWsServer, WebSocketClient) {
    let (server, mut clients) = aio_pair_n(1).await;
    // `aio_pair_n(1)` builds exactly one client, so index 0 is always present.
    (server, clients.remove(0))
}
/// Starts a `MockWsServer` with `n` client slots and builds `n`
/// `WebSocketClient`s targeting its URL.
///
/// Every client authenticates with the fixed API key `"mock-test-key"` and
/// has reconnection disabled. The clients are only constructed here.
/// NOTE(review): whether construction itself connects depends on
/// `WebSocketClient` — confirm.
///
/// # Panics
///
/// Panics if `n` is 0, via `MockWsServer::start_with_capacity`.
pub async fn aio_pair_n(n: usize) -> (MockWsServer, Vec<WebSocketClient>) {
    let server = MockWsServer::start_with_capacity(n).await;
    let clients: Vec<WebSocketClient> = (0..n)
        .map(|_| {
            let credentials = AuthRequest::with_api_key("mock-test-key");
            let connection = ConnectionConfig::builder(server.url(), credentials).build();
            WebSocketClient::with_reconnection_config(connection, ReconnectionConfig::disabled())
        })
        .collect();
    (server, clients)
}
/// Assigns each seed, in order, to the next connection that completes a
/// WebSocket handshake, and spawns a `run_client_loop` task for it.
///
/// A connection whose handshake fails is discarded *without* consuming its
/// seed. Otherwise a stray non-WebSocket connection would orphan that slot's
/// injection channel and sub-id queue, and shift every later client to the
/// wrong `client_idx`.
///
/// Returns once every seed has been handed out, or as soon as `accept` itself
/// fails (no further clients can connect).
async fn run_accept_loop(listener: TcpListener, seeds: Vec<AcceptSeed>) {
    for seed in seeds {
        let ws = loop {
            let Ok((stream, _peer)) = listener.accept().await else {
                return;
            };
            match tokio_tungstenite::accept_async(stream).await {
                Ok(ws) => break ws,
                // Handshake failed: keep this seed for the next connection.
                Err(_) => continue,
            }
        };
        tokio::spawn(run_client_loop(ws, seed));
    }
}
async fn run_client_loop(
mut ws: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
seed: AcceptSeed,
) {
let AcceptSeed {
pending_sub_ids,
mut inject_rx,
mut drop_rx,
} = seed;
loop {
tokio::select! {
_ = &mut drop_rx => {
drop(ws);
return;
}
client_frame = ws.next() => {
match client_frame {
Some(Ok(Message::Text(text))) => {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
let event = json.get("event").and_then(|v| v.as_str()).unwrap_or("");
match event {
"auth" => {
let ack = serde_json::json!({ "event": "authenticated" });
let _ = ws.send(Message::Text(ack.to_string().into())).await;
}
"subscribe" => {
let id = pending_sub_ids
.lock()
.await
.pop_front()
.unwrap_or_else(|| "mock-id".to_string());
let channel = json
.get("channel")
.and_then(|v| v.as_str())
.unwrap_or("trades");
let symbol = json
.get("symbol")
.and_then(|v| v.as_str())
.unwrap_or("");
let ack = serde_json::json!({
"event": "subscribed",
"id": id,
"channel": channel,
"symbol": symbol,
});
let _ = ws.send(Message::Text(ack.to_string().into())).await;
}
_ => {}
}
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Err(_)) => break,
_ => {}
}
}
injection = inject_rx.recv() => {
match injection {
Some(MockInjection::Frame(frame)) => {
if let Ok(text) = serde_json::to_string(&frame) {
let _ = ws.send(Message::Text(text.into())).await;
}
}
Some(MockInjection::Close { code, reason }) => {
let _ = ws
.send(Message::Close(Some(
tokio_tungstenite::tungstenite::protocol::CloseFrame {
code: code.into(),
reason: reason.into(),
},
)))
.await;
break;
}
None => break,
}
}
}
}
}