#![cfg(all(feature = "test-utils", feature = "tokio-comp"))]
#![allow(missing_docs, reason = "internal test utilities — public for cross-crate test use")]
use crate::models::streaming::StreamMessage;
use crate::websocket::aio::WebSocketClient;
use crate::websocket::{ConnectionConfig, ReconnectionConfig};
use crate::AuthRequest;
use futures_util::{SinkExt, StreamExt};
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::Message;
/// Handle to an in-process mock WebSocket server listening on a loopback port.
///
/// A background task serves the connection; this handle lets a test pre-seed subscription IDs
/// and push server-initiated frames or a close to the connected client.
#[derive(Debug)]
pub struct MockWsServer {
    /// Loopback address the listener was bound to.
    addr: SocketAddr,
    /// IDs returned, front first, in `subscribed` acknowledgements; shared with the serving task.
    pending_sub_ids: Arc<Mutex<VecDeque<String>>>,
    /// Channel to the serving task carrying test-requested frames and closes.
    inject_tx: mpsc::UnboundedSender<MockInjection>,
}
/// A server-initiated action requested by a test and carried out by the serving task.
enum MockInjection {
    /// Send a Close frame with this status `code` and `reason`, then stop serving.
    Close {
        code: u16,
        reason: String,
    },
    /// Serialize the message to JSON and send it to the client as a text frame.
    Frame(StreamMessage),
}
impl MockWsServer {
    /// Starts the mock server on an ephemeral `127.0.0.1` port.
    ///
    /// The listener is handed to a detached Tokio task running `run_accept_loop`; the returned
    /// handle is how a test steers that task.
    ///
    /// # Panics
    ///
    /// Panics if the loopback listener cannot be bound or its local address cannot be read.
    pub async fn start() -> Self {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind ephemeral port");
        let addr = listener.local_addr().expect("local_addr");

        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<MockInjection>();
        let pending_sub_ids = Arc::new(Mutex::new(VecDeque::<String>::new()));
        tokio::spawn(run_accept_loop(
            listener,
            Arc::clone(&pending_sub_ids),
            inject_rx,
        ));

        Self {
            addr,
            pending_sub_ids,
            inject_tx,
        }
    }

    /// Returns the `ws://` URL a client should connect to, including the streaming path.
    pub fn url(&self) -> String {
        let addr = self.addr;
        format!("ws://{addr}/marketdata/v1.0/stock/streaming")
    }

    /// Returns the socket address the mock's listener was bound to.
    pub fn address(&self) -> SocketAddr {
        self.addr
    }

    /// Queues `id` to be used in the acknowledgement of an upcoming `subscribe` request.
    ///
    /// Queued IDs are consumed in FIFO order, one per `subscribe`.
    pub async fn next_subscribe_id(&self, id: impl Into<String>) {
        let mut queue = self.pending_sub_ids.lock().await;
        queue.push_back(id.into());
    }

    /// Asks the serving task to send `frame` to the client as a JSON text frame.
    ///
    /// Best-effort: if the serving task has already exited, the frame is silently dropped.
    pub async fn inject_frame(&self, frame: StreamMessage) {
        let injection = MockInjection::Frame(frame);
        let _ = self.inject_tx.send(injection);
    }

    /// Asks the serving task to send a Close frame with `code` and `reason` and stop serving.
    ///
    /// Best-effort: if the serving task has already exited, the request is silently dropped.
    pub async fn close(&self, code: u16, reason: impl Into<String>) {
        let injection = MockInjection::Close {
            code,
            reason: reason.into(),
        };
        let _ = self.inject_tx.send(injection);
    }
}
/// Starts a [`MockWsServer`] and constructs a [`WebSocketClient`] configured for its URL.
///
/// The client uses the placeholder API key `"mock-test-key"` and is built with
/// `ReconnectionConfig::disabled()`. Note that the mock only ever serves one connection.
pub async fn aio_pair() -> (MockWsServer, WebSocketClient) {
    let mock = MockWsServer::start().await;
    let client = WebSocketClient::with_reconnection_config(
        ConnectionConfig::builder(mock.url(), AuthRequest::with_api_key("mock-test-key")).build(),
        ReconnectionConfig::disabled(),
    );
    (mock, client)
}
async fn run_accept_loop(
listener: TcpListener,
pending_sub_ids: Arc<Mutex<VecDeque<String>>>,
mut inject_rx: mpsc::UnboundedReceiver<MockInjection>,
) {
let Ok((stream, _peer)) = listener.accept().await else {
return;
};
let mut ws = match tokio_tungstenite::accept_async(stream).await {
Ok(ws) => ws,
Err(_) => return,
};
loop {
tokio::select! {
client_frame = ws.next() => {
match client_frame {
Some(Ok(Message::Text(text))) => {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
let event = json.get("event").and_then(|v| v.as_str()).unwrap_or("");
match event {
"auth" => {
let ack = serde_json::json!({ "event": "authenticated" });
let _ = ws.send(Message::Text(ack.to_string().into())).await;
}
"subscribe" => {
let id = pending_sub_ids
.lock()
.await
.pop_front()
.unwrap_or_else(|| "mock-id".to_string());
let channel = json
.get("channel")
.and_then(|v| v.as_str())
.unwrap_or("trades");
let symbol = json
.get("symbol")
.and_then(|v| v.as_str())
.unwrap_or("");
let ack = serde_json::json!({
"event": "subscribed",
"id": id,
"channel": channel,
"symbol": symbol,
});
let _ = ws.send(Message::Text(ack.to_string().into())).await;
}
_ => {}
}
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Err(_)) => break,
_ => {}
}
}
injection = inject_rx.recv() => {
match injection {
Some(MockInjection::Frame(frame)) => {
if let Ok(text) = serde_json::to_string(&frame) {
let _ = ws.send(Message::Text(text.into())).await;
}
}
Some(MockInjection::Close { code, reason }) => {
let _ = ws
.send(Message::Close(Some(
tokio_tungstenite::tungstenite::protocol::CloseFrame {
code: code.into(),
reason: reason.into(),
},
)))
.await;
break;
}
None => break,
}
}
}
}
}