use std::{net::SocketAddr, time::Duration};
use futures::{SinkExt, StreamExt};
use simulator_api::{
AgentStatsReport, BacktestError, BacktestRequest, BacktestResponse, BacktestStatus,
ContinueParams, CreateSessionParams, SequencedResponse, SessionEventKind, SessionSummary,
};
use simulator_client::{
BacktestClient, BacktestClientError, Continue, CreateSession, ManagedBacktestSession,
ReadyOutcome,
managed::{
ManagedEvent, ManagedParallelSession, ManagedSessionError, ParallelSubSession,
SubscriptionNotification, spawn_account_diff_subscription_manager,
},
};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::{
WebSocketStream, accept_async, accept_hdr_async,
tungstenite::{
Message,
handshake::server::{ErrorResponse, Request, Response},
},
};
use tokio_util::sync::CancellationToken;
/// Panics unless the handshake request carries an `X-API-Key` header equal to
/// `expected_api_key`.
///
/// A missing header, or one whose value cannot be read as a `&str`, is treated
/// as `None` and therefore fails the comparison.
fn assert_expected_api_key(req: &Request, expected_api_key: &str) {
    let header = req.headers().get("X-API-Key");
    let presented = header.and_then(|raw| raw.to_str().ok());
    assert_eq!(presented, Some(expected_api_key));
}
/// Completes the server side of a WebSocket handshake on `stream`, asserting
/// along the way that the client sent `expected_api_key` as its `X-API-Key`.
///
/// The key check runs inside the handshake callback, so a mismatch panics
/// during the accept instead of producing an error response.
// Clippy flags the error type of tungstenite's `Result` alias as large; this
// helper simply forwards that library alias, so boxing it would add noise.
#[allow(clippy::result_large_err)]
async fn accept_with_expected_api_key(
    stream: tokio::net::TcpStream,
    expected_api_key: &'static str,
) -> tokio_tungstenite::tungstenite::Result<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>
{
    let verify_key = move |req: &Request, resp: Response| {
        assert_expected_api_key(req, expected_api_key);
        Ok::<_, ErrorResponse>(resp)
    };
    accept_hdr_async(stream, verify_key).await
}
/// Starts a one-shot mock control server on an ephemeral loopback port.
///
/// The single accepted connection must pass the `X-API-Key` check against
/// `expected_api_key`. The resulting WebSocket is then handed to `handler`, and
/// the task `handler` returns is awaited. A panic in that task is re-raised by
/// the returned join handle.
///
/// Returns the `ws://…/backtest` URL to dial and the accept-and-serve task.
async fn spawn_server(
    expected_api_key: &'static str,
    handler: impl FnOnce(
        tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    ) -> tokio::task::JoinHandle<()>
    + Send
    + 'static,
) -> (String, tokio::task::JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound: SocketAddr = listener.local_addr().unwrap();
    let serve_once = async move {
        let (socket, _peer) = listener.accept().await.unwrap();
        let upgraded = accept_with_expected_api_key(socket, expected_api_key)
            .await
            .unwrap();
        let script = handler(upgraded);
        script.await.unwrap();
    };
    let task = tokio::spawn(serve_once);
    (format!("ws://{bound}/backtest"), task)
}
/// Starts a one-shot mock RPC (subscription) server on an ephemeral loopback
/// port.
///
/// The single accepted connection is upgraded without any header checks and
/// passed to `handler`. The task `handler` returns is awaited, and a panic in it
/// is re-raised by the returned join handle.
///
/// Returns an `http://host:port` base URL and the accept-and-serve task. The
/// client presumably derives its WebSocket URL from that base (TODO confirm).
async fn spawn_rpc_server(
    handler: impl FnOnce(WebSocketStream<TcpStream>) -> tokio::task::JoinHandle<()> + Send + 'static,
) -> (String, tokio::task::JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound: SocketAddr = listener.local_addr().unwrap();
    let serve_once = async move {
        let (socket, _peer) = listener.accept().await.unwrap();
        let upgraded = accept_async(socket).await.unwrap();
        let script = handler(upgraded);
        script.await.unwrap();
    };
    let task = tokio::spawn(serve_once);
    (format!("http://{bound}"), task)
}
/// Verifies that the account-diff subscription manager forwards every
/// notification the RPC server sent before `subscriptionComplete`. After that,
/// its notification channel must report closed (`recv` yields `None`).
#[tokio::test]
async fn subscription_drains_all_notifications_then_terminal_closes_channel() {
    // Mock RPC server: expect one `accountDiffSubscribe`, acknowledge it as
    // subscription id 1, stream five diffs (slots 0..5), then signal completion.
    let (url, server) = spawn_rpc_server(|mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: serde_json::Value = serde_json::from_str(&text).unwrap();
            assert_eq!(req["method"], "accountDiffSubscribe");
            let id = req["id"].clone();
            // JSON-RPC reply echoing the request id. `result` is the subscription id
            // that the notifications below reference.
            ws.send(Message::Text(
                serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": 1 }).to_string(),
            ))
            .await
            .unwrap();
            for slot in 0..5u64 {
                ws.send(Message::Text(
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "method": "accountDiffNotification",
                        "params": { "subscription": 1, "result": { "context": { "slot": slot } } },
                    })
                    .to_string(),
                ))
                .await
                .unwrap();
            }
            ws.send(Message::Text(
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "subscriptionComplete",
                    "params": { "subscription": 1 },
                })
                .to_string(),
            ))
            .await
            .unwrap();
            // Keep the socket open after completing. Presumably this ensures the
            // channel closes because of `subscriptionComplete`, not a disconnect.
            tokio::time::sleep(Duration::from_secs(2)).await;
        })
    })
    .await;
    let cancel = CancellationToken::new();
    // NOTE(review): the meaning of the trailing `None` argument is not visible here.
    let mut handle =
        spawn_account_diff_subscription_manager(url, vec!["prog".to_string()], cancel, None);
    let mut received = 0;
    // Drain until `recv` yields `None`. Every item received must be an account diff.
    while let Some(notification) = handle.notifications.recv().await {
        assert!(matches!(
            notification,
            SubscriptionNotification::AccountDiff(_)
        ));
        received += 1;
    }
    assert_eq!(received, 5);
    server.await.unwrap();
}
/// Verifies that `ManagedBacktestSession::next_event` delivers every pending
/// account-diff notification before it reports `Completed`. It also checks that
/// the completion summary and agent stats survive that drain unchanged.
///
/// Both mock servers are joined with `unwrap()`, so a panic or failed assertion
/// inside either mock fails this test instead of being silently discarded.
#[tokio::test]
async fn next_event_drains_subscriptions_before_completed() {
    // Mock RPC server: ack the subscription as id 1, push five diffs, then complete.
    let (rpc_url, rpc_server) = spawn_rpc_server(|mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: serde_json::Value = serde_json::from_str(&text).unwrap();
            assert_eq!(req["method"], "accountDiffSubscribe");
            let id = req["id"].clone();
            ws.send(Message::Text(
                serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": 1 }).to_string(),
            ))
            .await
            .unwrap();
            for slot in 0..5u64 {
                ws.send(Message::Text(
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "method": "accountDiffNotification",
                        "params": { "subscription": 1, "result": { "context": { "slot": slot } } },
                    })
                    .to_string(),
                ))
                .await
                .unwrap();
            }
            ws.send(Message::Text(
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "subscriptionComplete",
                    "params": { "subscription": 1 },
                })
                .to_string(),
            ))
            .await
            .unwrap();
            tokio::time::sleep(Duration::from_secs(2)).await;
        })
    })
    .await;
    // Mock control server: create the session pointing at the mock RPC endpoint,
    // report ready, then complete immediately with a summary and agent stats.
    // `Completed` is sent without waiting for the subscription to finish, so the
    // client itself must drain the pending diffs before surfacing `Completed`.
    let (ctrl_url, ctrl_server) = spawn_server("k", move |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            assert!(matches!(req, BacktestRequest::CreateBacktestSession(_)));
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s1".to_string(),
                    rpc_endpoint: rpc_url,
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::Completed {
                    summary: Some(SessionSummary {
                        correct_simulation: 7,
                        incorrect_simulation: 2,
                        ..Default::default()
                    }),
                    agent_stats: Some(vec![AgentStatsReport {
                        name: "agent-1".to_string(),
                        slots_processed: 6,
                        ..Default::default()
                    }]),
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            tokio::time::sleep(Duration::from_secs(2)).await;
        })
    })
    .await;
    let create = CreateSession::builder()
        .start_slot(100)
        .end_slot(105)
        .build()
        .into_request()
        .unwrap();
    let mut session = ManagedBacktestSession::start(ctrl_url, "k".to_string(), create)
        .await
        .unwrap();
    session.subscribe_account_diffs(vec!["prog".to_string()]);
    let mut account_diffs = 0;
    let (summary, agent_stats) = loop {
        match session.next_event().await.unwrap() {
            ManagedEvent::AccountDiff(_) => account_diffs += 1,
            ManagedEvent::Completed {
                summary,
                agent_stats,
            } => break (summary, agent_stats),
            ManagedEvent::Error(e) => panic!("unexpected error: {e}"),
            _ => {}
        }
    };
    assert_eq!(account_diffs, 5);
    let summary = summary.expect("summary must survive the completion drain");
    assert_eq!(summary.correct_simulation, 7);
    assert_eq!(summary.incorrect_simulation, 2);
    let agent_stats = agent_stats.expect("agent stats must survive the completion drain");
    assert_eq!(agent_stats.len(), 1);
    assert_eq!(agent_stats[0].name, "agent-1");
    assert_eq!(agent_stats[0].slots_processed, 6);
    session.shutdown().await;
    // Surface any panic or failed assertion from the mock servers, as the other
    // tests in this file do.
    rpc_server.await.unwrap();
    ctrl_server.await.unwrap();
}
#[tokio::test]
async fn slow_notification_stream_is_not_truncated_by_drain_timeout() {
const N: u64 = 20;
let (rpc_url, rpc_server) = spawn_rpc_server(|mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(req["method"], "accountDiffSubscribe");
let id = req["id"].clone();
ws.send(Message::Text(
serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": 1 }).to_string(),
))
.await
.unwrap();
for slot in 0..N {
ws.send(Message::Text(
serde_json::json!({
"jsonrpc": "2.0",
"method": "accountDiffNotification",
"params": { "subscription": 1, "result": { "context": { "slot": slot } } },
})
.to_string(),
))
.await
.unwrap();
if slot + 1 < N {
tokio::time::sleep(Duration::from_millis(50)).await;
}
}
ws.send(Message::Text(
serde_json::json!({
"jsonrpc": "2.0",
"method": "subscriptionComplete",
"params": { "subscription": 1 },
})
.to_string(),
))
.await
.unwrap();
tokio::time::sleep(Duration::from_secs(2)).await;
})
})
.await;
let (ctrl_url, ctrl_server) = spawn_server("k", move |mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::CreateBacktestSession(_)));
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s1".to_string(),
rpc_endpoint: rpc_url,
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::Completed {
summary: None,
agent_stats: None,
})
.unwrap(),
))
.await
.unwrap();
tokio::time::sleep(Duration::from_secs(2)).await;
})
})
.await;
let create = CreateSession::builder()
.start_slot(100)
.end_slot(105)
.build()
.into_request()
.unwrap();
let mut session = ManagedBacktestSession::start(ctrl_url, "k".to_string(), create)
.await
.unwrap();
session.set_completion_drain_timeout(Duration::from_millis(300));
session.subscribe_account_diffs(vec!["prog".to_string()]);
let mut account_diffs = 0u64;
loop {
match session.next_event().await.unwrap() {
ManagedEvent::AccountDiff(_) => account_diffs += 1,
ManagedEvent::Completed { .. } => break,
ManagedEvent::Error(e) => panic!("unexpected error: {e}"),
_ => {}
}
}
assert_eq!(
account_diffs, N,
"drain truncated a slow stream: {account_diffs}/{N} delivered (idle timeout regressed to a wall-clock cap?)"
);
session.shutdown().await;
let _ = rpc_server.await;
let _ = ctrl_server.await;
}
#[tokio::test]
async fn stalled_completion_drain_surfaces_failure_not_completed() {
let (rpc_url, rpc_server) = spawn_rpc_server(|mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: serde_json::Value = serde_json::from_str(&text).unwrap();
let id = req["id"].clone();
ws.send(Message::Text(
serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": 1 }).to_string(),
))
.await
.unwrap();
for slot in 0..3u64 {
ws.send(Message::Text(
serde_json::json!({
"jsonrpc": "2.0",
"method": "accountDiffNotification",
"params": { "subscription": 1, "result": { "context": { "slot": slot } } },
})
.to_string(),
))
.await
.unwrap();
}
tokio::time::sleep(Duration::from_secs(5)).await;
})
})
.await;
let (ctrl_url, ctrl_server) = spawn_server("k", move |mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::CreateBacktestSession(_)));
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s1".to_string(),
rpc_endpoint: rpc_url,
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::Completed {
summary: None,
agent_stats: None,
})
.unwrap(),
))
.await
.unwrap();
tokio::time::sleep(Duration::from_secs(5)).await;
})
})
.await;
let create = CreateSession::builder()
.start_slot(100)
.end_slot(105)
.build()
.into_request()
.unwrap();
let mut session = ManagedBacktestSession::start(ctrl_url, "k".to_string(), create)
.await
.unwrap();
session.set_completion_drain_timeout(Duration::from_millis(200));
session.subscribe_account_diffs(vec!["prog".to_string()]);
let mut account_diffs = 0u64;
loop {
match session.next_event().await {
Ok(ManagedEvent::AccountDiff(_)) => account_diffs += 1,
Ok(ManagedEvent::Completed { .. }) => {
panic!("a stalled drain must not report Completed");
}
Ok(ManagedEvent::Error(e)) => panic!("unexpected error event: {e}"),
Ok(_) => {}
Err(ManagedSessionError::SubscriptionFailed(_)) => break,
Err(e) => panic!("unexpected error: {e}"),
}
}
assert_eq!(account_diffs, 3);
session.shutdown().await;
let _ = rpc_server.await;
let _ = ctrl_server.await;
}
/// Verifies subscription resumption after the RPC connection drops.
///
/// The subscription manager must reconnect and resubscribe with
/// `replayFromSlot` set to the last slot it delivered. The replayed boundary
/// slot is then passed through to the consumer again, not deduplicated.
#[tokio::test]
async fn reconnect_resumes_from_last_slot_via_replay_cursor() {
    /// Returns the next text frame parsed as JSON, skipping any non-text frames.
    async fn next_subscribe(ws: &mut WebSocketStream<TcpStream>) -> serde_json::Value {
        loop {
            match ws.next().await.unwrap().unwrap() {
                Message::Text(t) => return serde_json::from_str(&t).unwrap(),
                _ => continue,
            }
        }
    }
    /// Sends one `accountDiffNotification` for subscription 1 at `slot`.
    async fn send_diff(ws: &mut WebSocketStream<TcpStream>, slot: u64) {
        ws.send(Message::Text(
            serde_json::json!({
                "jsonrpc": "2.0",
                "method": "accountDiffNotification",
                "params": { "subscription": 1, "result": { "context": { "slot": slot } } },
            })
            .to_string(),
        ))
        .await
        .unwrap();
    }
    // This server is hand-rolled (not `spawn_rpc_server`) because it must accept
    // two successive connections on the same listener.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr: SocketAddr = listener.local_addr().unwrap();
    let url = format!("http://{addr}");
    let server = tokio::spawn(async move {
        // First connection: the subscribe must not carry a cursor; deliver slots 0..=2.
        let (stream, _) = listener.accept().await.unwrap();
        let mut ws = accept_async(stream).await.unwrap();
        let req = next_subscribe(&mut ws).await;
        assert_eq!(req["method"], "accountDiffSubscribe");
        // `params[1]` is the options object of the positional JSON-RPC params.
        assert!(
            req["params"][1].get("replayFromSlot").is_none(),
            "first subscribe must not carry a resume cursor"
        );
        ws.send(Message::Text(
            serde_json::json!({ "jsonrpc": "2.0", "id": req["id"], "result": 1 }).to_string(),
        ))
        .await
        .unwrap();
        for slot in 0..=2u64 {
            send_diff(&mut ws, slot).await;
        }
        // Drop the socket (no explicit close is sent here) to simulate a lost connection.
        drop(ws);
        // Second connection: the resubscribe must resume from slot 2 (inclusive).
        let (stream, _) = listener.accept().await.unwrap();
        let mut ws = accept_async(stream).await.unwrap();
        let req = next_subscribe(&mut ws).await;
        assert_eq!(
            req["params"][1]["replayFromSlot"],
            serde_json::json!(2),
            "reconnect must resume from the last slot delivered (inclusive)"
        );
        ws.send(Message::Text(
            serde_json::json!({ "jsonrpc": "2.0", "id": req["id"], "result": 1 }).to_string(),
        ))
        .await
        .unwrap();
        // Replay starts at the cursor, so slot 2 is sent a second time.
        for slot in 2..=4u64 {
            send_diff(&mut ws, slot).await;
        }
        ws.send(Message::Text(
            serde_json::json!({
                "jsonrpc": "2.0",
                "method": "subscriptionComplete",
                "params": { "subscription": 1 },
            })
            .to_string(),
        ))
        .await
        .unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;
    });
    let cancel = CancellationToken::new();
    let mut handle =
        spawn_account_diff_subscription_manager(url, vec!["prog".to_string()], cancel, None);
    let mut slots = Vec::new();
    // Collect slots until `recv` yields `None`. Any non-`AccountDiff` item would
    // also end this `while let` early.
    while let Some(SubscriptionNotification::AccountDiff(diff)) = handle.notifications.recv().await
    {
        slots.push(diff.context.slot);
    }
    server.await.unwrap();
    // Slot 2 appears twice: the cursor is inclusive and the duplicate is not filtered.
    assert_eq!(slots, vec![0, 1, 2, 2, 3, 4]);
}
/// Verifies that the client opts into `zstd` compression when subscribing. It
/// must then decode notifications delivered as compressed binary frames,
/// including the terminal `subscriptionComplete`.
#[tokio::test]
async fn compressed_binary_notifications_are_decoded() {
    use simulator_api::ws_compression::{WS_COMPRESSION_LEVEL, WsStreamCompressor};
    let (url, server) = spawn_rpc_server(|mut ws| {
        tokio::spawn(async move {
            // The subscribe request itself arrives as plain text; skip anything else.
            let req: serde_json::Value = loop {
                match ws.next().await.unwrap().unwrap() {
                    Message::Text(t) => break serde_json::from_str(&t).unwrap(),
                    _ => continue,
                }
            };
            assert_eq!(req["method"], "accountDiffSubscribe");
            assert_eq!(
                req["params"][1]["compression"],
                serde_json::json!("zstd"),
                "client must opt into compression"
            );
            // The subscription ack is still sent uncompressed, as text.
            ws.send(Message::Text(
                serde_json::json!({ "jsonrpc": "2.0", "id": req["id"], "result": 1 }).to_string(),
            ))
            .await
            .unwrap();
            // One compressor instance is reused for every frame below. Presumably the
            // frames share compression state and must be decoded in order (TODO confirm).
            let mut comp = WsStreamCompressor::new(WS_COMPRESSION_LEVEL).unwrap();
            let mut frames: Vec<String> = (0..5u64)
                .map(|slot| {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "method": "accountDiffNotification",
                        "params": { "subscription": 1, "result": { "context": { "slot": slot } } },
                    })
                    .to_string()
                })
                .collect();
            frames.push(
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "subscriptionComplete",
                    "params": { "subscription": 1 },
                })
                .to_string(),
            );
            for json in frames {
                let frame = comp.compress(json.as_bytes()).unwrap();
                ws.send(Message::Binary(frame)).await.unwrap();
            }
            tokio::time::sleep(Duration::from_secs(2)).await;
        })
    })
    .await;
    let cancel = CancellationToken::new();
    let mut handle =
        spawn_account_diff_subscription_manager(url, vec!["prog".to_string()], cancel, None);
    let mut slots = Vec::new();
    // The loop ends when the channel closes after the compressed `subscriptionComplete`.
    while let Some(SubscriptionNotification::AccountDiff(diff)) = handle.notifications.recv().await
    {
        slots.push(diff.context.slot);
    }
    server.await.unwrap();
    assert_eq!(slots, vec![0, 1, 2, 3, 4]);
}
/// Verifies subscribing to two programs over one compressed connection.
///
/// A notification for the first subscription arrives *before* the second
/// subscribe handshake has been answered. It must still be delivered, as must
/// the later interleaved notifications for both subscriptions.
#[tokio::test]
async fn multi_program_compressed_handshake_buffers_interleaved_notifications() {
    use simulator_api::ws_compression::{WS_COMPRESSION_LEVEL, WsStreamCompressor};
    let (url, server) = spawn_rpc_server(|mut ws| {
        tokio::spawn(async move {
            /// Returns the next text frame parsed as JSON, skipping non-text frames.
            async fn next_req(ws: &mut WebSocketStream<TcpStream>) -> serde_json::Value {
                loop {
                    if let Message::Text(t) = ws.next().await.unwrap().unwrap() {
                        return serde_json::from_str(&t).unwrap();
                    }
                }
            }
            // Builds an `accountDiffNotification` for subscription `sub` at `slot`.
            let notif = |sub: u64, slot: u64| {
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "accountDiffNotification",
                    "params": { "subscription": sub, "result": { "context": { "slot": slot } } },
                })
                .to_string()
            };
            // Builds a `subscriptionComplete` for subscription `sub`.
            let complete = |sub: u64| {
                serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": "subscriptionComplete",
                    "params": { "subscription": sub },
                })
                .to_string()
            };
            // A single compressor shared by every binary frame on this connection.
            let mut comp = WsStreamCompressor::new(WS_COMPRESSION_LEVEL).unwrap();
            // First subscribe (one program): ack as subscription 1...
            let req1 = next_req(&mut ws).await;
            ws.send(Message::Text(
                serde_json::json!({ "jsonrpc": "2.0", "id": req1["id"], "result": 1 }).to_string(),
            ))
            .await
            .unwrap();
            // ...then push a diff for sub 1 before the second subscribe is even read.
            // The client has to hold on to it while its second handshake is pending.
            ws.send(Message::Binary(
                comp.compress(notif(1, 0).as_bytes()).unwrap(),
            ))
            .await
            .unwrap();
            // Second subscribe (the other program): ack as subscription 2.
            let req2 = next_req(&mut ws).await;
            ws.send(Message::Text(
                serde_json::json!({ "jsonrpc": "2.0", "id": req2["id"], "result": 2 }).to_string(),
            ))
            .await
            .unwrap();
            // Interleave traffic for both subscriptions, then complete both.
            for json in [notif(1, 1), notif(2, 0), complete(1), complete(2)] {
                ws.send(Message::Binary(comp.compress(json.as_bytes()).unwrap()))
                    .await
                    .unwrap();
            }
            tokio::time::sleep(Duration::from_secs(2)).await;
        })
    })
    .await;
    let cancel = CancellationToken::new();
    let mut handle = spawn_account_diff_subscription_manager(
        url,
        vec!["progA".to_string(), "progB".to_string()],
        cancel,
        None,
    );
    let mut slots = Vec::new();
    while let Some(SubscriptionNotification::AccountDiff(diff)) = handle.notifications.recv().await
    {
        slots.push(diff.context.slot);
    }
    server.await.unwrap();
    // Cross-subscription delivery order is not asserted, only the multiset of slots.
    slots.sort_unstable();
    assert_eq!(slots, vec![0, 0, 1]);
}
/// End-to-end happy path for a single `BacktestClient` session. The client:
/// - creates the session (checking the requested slot range on the wire);
/// - waits until it is ready;
/// - advances two slots, counting slot notifications;
/// - closes the session cleanly.
#[tokio::test]
async fn creates_session_waits_ready_advances_and_closes() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            // The first request on this connection is the session-create request.
            eprintln!("server: waiting for create request");
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::CreateBacktestSession(request) = req else {
                panic!("expected create");
            };
            let (
                CreateSessionParams {
                    start_slot,
                    end_slot,
                    ..
                },
                parallel,
            ) = request.into_request_and_parallel();
            assert!(!parallel);
            // The client built this with `start_slot(100).slot_count(5)`.
            assert_eq!(start_slot, 100);
            assert_eq!(end_slot, 105);
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "http://rpc".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
            ))
            .await
            .unwrap();
            // Answer the continue with a status, two slot notifications and a new ready.
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            assert!(matches!(req, BacktestRequest::Continue(_)));
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::Status {
                    status: BacktestStatus::DecodedTransactions,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SlotNotification(101)).unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SlotNotification(102)).unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
            ))
            .await
            .unwrap();
            // Finally the client must close the session.
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            assert!(matches!(req, BacktestRequest::CloseBacktestSession));
            eprintln!("server: sending success");
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::Success).unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let mut session = client
        .create_session(
            CreateSession::builder()
                .start_slot(100)
                .slot_count(5)
                .build(),
        )
        .await
        .unwrap();
    assert_eq!(session.session_id(), Some("s1"));
    assert_eq!(session.rpc_endpoint(), Some("http://rpc"));
    let ready = session
        .ensure_ready(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(ready, ReadyOutcome::Ready);
    let result = session
        .advance(
            Continue::builder().advance_count(2).build(),
            Some(Duration::from_secs(2)),
            |_| {},
        )
        .await
        .unwrap();
    assert_eq!(result.slot_notifications, 2);
    assert_eq!(result.last_slot, Some(102));
    assert!(result.ready_for_continue);
    session.close(Some(Duration::from_secs(2))).await.unwrap();
    server.await.unwrap();
}
/// Verifies that an `Error` response from the server during a continue is
/// surfaced as `BacktestClientError::Remote`. The server's error payload
/// (`NoMoreBlocks` here) must be preserved.
#[tokio::test]
async fn surfaces_remote_error() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            // Create the session and report it ready.
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            assert!(matches!(req, BacktestRequest::CreateBacktestSession(_)));
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "http://rpc".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
            ))
            .await
            .unwrap();
            // Reject the first continue with a remote error.
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            assert!(matches!(req, BacktestRequest::Continue(_)));
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::Error(BacktestError::NoMoreBlocks))
                    .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let mut session = client
        .create_session(
            CreateSession::builder()
                .start_slot(100)
                .end_slot(100)
                .build(),
        )
        .await
        .unwrap();
    session
        .ensure_ready(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    let err = session
        .continue_until_ready(
            Continue::builder().advance_count(1).build(),
            Some(Duration::from_secs(2)),
            |_| {},
        )
        .await
        .unwrap_err();
    assert!(matches!(
        err,
        BacktestClientError::Remote(BacktestError::NoMoreBlocks)
    ));
    server.await.unwrap();
}
/// Verifies the wire shape of a parallel create request. The JSON must use
/// method `createBacktestSession` with flattened `params` (no nested `request`
/// object) and `parallel: true`. `create_sessions` then returns the ids from a
/// `SessionsCreated` reply.
#[tokio::test]
async fn creates_parallel_sessions_and_returns_ids() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            // Check the raw JSON first, then the typed decoding of the same text.
            let payload: serde_json::Value = serde_json::from_str(&text).unwrap();
            assert_eq!(payload["method"], "createBacktestSession");
            assert_eq!(payload["params"]["parallel"], true);
            assert_eq!(payload["params"]["startSlot"], 100);
            assert!(payload["params"].get("request").is_none());
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::CreateBacktestSession(request) = req else {
                panic!("expected create");
            };
            let (_, parallel) = request.into_request_and_parallel();
            assert!(parallel);
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionsCreated {
                    session_ids: vec!["s1".to_string(), "s2".to_string()],
                })
                .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let session_ids = client
        .create_sessions(
            CreateSession::builder()
                .start_slot(100)
                .end_slot(105)
                .parallel(true)
                .build(),
        )
        .await
        .unwrap();
    assert_eq!(session_ids, vec!["s1".to_string(), "s2".to_string()]);
    server.await.unwrap();
}
/// Verifies that `create_sessions_with_progress` reports each streamed
/// `SessionCreated` to the progress callback. It must also return those ids even
/// when the terminal `SessionsCreated` message carries an empty list.
#[tokio::test]
async fn creates_parallel_sessions_from_streamed_session_created_events() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::CreateBacktestSession(request) = req else {
                panic!("expected create");
            };
            let (_, parallel) = request.into_request_and_parallel();
            assert!(parallel);
            // Stream one `SessionCreated` per sub-session...
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "/backtest/s1".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s2".to_string(),
                    rpc_endpoint: "/backtest/s2".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            // ...then terminate with an empty list. The client must keep the streamed ids.
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionsCreated {
                    session_ids: Vec::new(),
                })
                .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let mut streamed = Vec::new();
    let session_ids = client
        .create_sessions_with_progress(
            CreateSession::builder()
                .start_slot(100)
                .end_slot(105)
                .parallel(true)
                .build(),
            |session_id| streamed.push(session_id),
        )
        .await
        .unwrap();
    assert_eq!(streamed, vec!["s1".to_string(), "s2".to_string()]);
    assert_eq!(session_ids, vec!["s1".to_string(), "s2".to_string()]);
    server.await.unwrap();
}
/// Verifies that `attach_session` sends `AttachBacktestSession` with the given
/// id and resume sequence. The session and RPC endpoint reported by
/// `SessionAttached` must then be exposed on the returned session.
#[tokio::test]
async fn attaches_to_existing_session() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::AttachBacktestSession {
                session_id,
                last_sequence,
            } = req
            else {
                panic!("expected attach");
            };
            assert_eq!(session_id, "s1");
            // Matches the `Some(7)` passed to `attach_session` below.
            assert_eq!(last_sequence, Some(7));
            eprintln!("server: sending session attached");
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionAttached {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "http://rpc/s1".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let session = client.attach_session("s1", Some(7)).await.unwrap();
    assert_eq!(session.session_id(), Some("s1"));
    assert_eq!(session.rpc_endpoint(), Some("http://rpc/s1"));
    server.await.unwrap();
}
/// Verifies that `last_sequence()` starts as `None` after an attach without a
/// cursor. It must then advance to the `seq_id` of each `SequencedResponse`
/// that the session consumes.
#[tokio::test]
async fn tracks_last_sequence_from_sequenced_control_responses() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::AttachBacktestSession {
                session_id,
                last_sequence,
            } = req
            else {
                panic!("expected attach");
            };
            assert_eq!(session_id, "s1");
            assert_eq!(last_sequence, None);
            // The attach reply itself is not sequenced.
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionAttached {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "http://rpc/s1".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            // Sequenced responses: seq 41 (ready), then seq 42 (status).
            ws.send(Message::Text(
                serde_json::to_string(&SequencedResponse {
                    seq_id: 41,
                    response: BacktestResponse::ReadyForContinue,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&SequencedResponse {
                    seq_id: 42,
                    response: BacktestResponse::Status {
                        status: BacktestStatus::DecodedTransactions,
                    },
                })
                .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let client = BacktestClient::builder().url(url).api_key("k").build();
    let mut session = client.attach_session("s1", None).await.unwrap();
    assert_eq!(session.last_sequence(), None);
    let ready = session
        .ensure_ready(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(ready, ReadyOutcome::Ready);
    // Consuming the ready message (seq 41) records its sequence.
    assert_eq!(session.last_sequence(), Some(41));
    session
        .wait_for_status(
            BacktestStatus::DecodedTransactions,
            Some(Duration::from_secs(2)),
        )
        .await
        .unwrap();
    // Consuming the status message (seq 42) records its sequence.
    assert_eq!(session.last_sequence(), Some(42));
    server.await.unwrap();
}
/// Reads frames from `ws` until a text frame arrives and returns its payload.
///
/// Ping, pong and raw frames are skipped. Any other frame kind (for example
/// binary or close) panics with the offending message. The function also
/// panics if the stream ends or yields an error.
async fn read_text(ws: &mut WebSocketStream<TcpStream>) -> String {
    loop {
        let frame = ws.next().await.unwrap().unwrap();
        match frame {
            Message::Text(text) => break text,
            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
            other => panic!("expected text, got {other:?}"),
        }
    }
}
/// Sends one `SessionEventV2` envelope for `session_id` as a JSON text frame.
///
/// The envelope carries the given `seq_id` and `event`. Panics if serialization
/// or the send fails.
async fn send_session_event(
    ws: &mut WebSocketStream<TcpStream>,
    session_id: &str,
    seq_id: u64,
    event: SessionEventKind,
) {
    let payload = serde_json::to_string(&BacktestResponse::SessionEventV2 {
        session_id: session_id.to_owned(),
        seq_id,
        event,
    })
    .unwrap();
    ws.send(Message::Text(payload)).await.unwrap();
}
/// Drives one parallel sub-session until it reports `Completed`.
///
/// Every `ReadyForContinue` is answered with a continue using
/// `advance_count: 100`, no transactions and default account-state overrides.
/// Every other event is ignored. Returns the sub-session id together with the
/// summary carried by `Completed`. Panics if `next_event` or `send_continue`
/// returns an error.
async fn drive_sub_to_completion(
    mut session: ParallelSubSession,
) -> (String, Option<SessionSummary>) {
    let summary = loop {
        match session.next_event().await.unwrap() {
            ManagedEvent::Completed { summary, .. } => break summary,
            ManagedEvent::ReadyForContinue => {
                let next = ContinueParams {
                    advance_count: 100,
                    transactions: Vec::new(),
                    modify_account_states: Default::default(),
                };
                session.send_continue(next).await.unwrap();
            }
            _ => {}
        }
    };
    (session.session_info().session_id.clone(), summary)
}
/// Verifies that `ManagedParallelSession` demultiplexes `SessionEventV2` traffic
/// from one control connection into per-sub-session handles. Each handle is
/// then driven to completion concurrently, and each completion carries its own
/// summary.
#[tokio::test]
async fn parallel_multiplex_drives_all_sub_sessions_to_completion() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let msg = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = msg else {
                panic!("expected text");
            };
            let req: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::CreateBacktestSession(request) = req else {
                panic!("expected create");
            };
            let (_, parallel) = request.into_request_and_parallel();
            assert!(parallel, "create should request the parallel path");
            // Announce each sub-session and emit a seq-1 event for it. These events
            // arrive before the `SessionsCreatedV2` summary below.
            for (session_id, rpc, start) in [
                ("s1", "http://rpc/s1", 100u64),
                ("s2", "http://rpc/s2", 200u64),
            ] {
                ws.send(Message::Text(
                    serde_json::to_string(&BacktestResponse::SessionCreated {
                        session_id: session_id.to_string(),
                        rpc_endpoint: rpc.to_string(),
                        task_id: None,
                    })
                    .unwrap(),
                ))
                .await
                .unwrap();
                send_session_event(
                    &mut ws,
                    session_id,
                    1,
                    SessionEventKind::SlotNotification(start),
                )
                .await;
            }
            // Summary of the parallel create: control id plus per-sub-session ranges.
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionsCreatedV2 {
                    control_session_id: "parallel_test".to_string(),
                    session_ids: vec!["s1".to_string(), "s2".to_string()],
                    task_ids: Vec::new(),
                    start_slots: vec![100, 200],
                    end_slots: vec![199, 299],
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            send_session_event(&mut ws, "s1", 2, SessionEventKind::ReadyForContinue).await;
            send_session_event(&mut ws, "s2", 2, SessionEventKind::ReadyForContinue).await;
            // Expect one per-session continue from each sub-session (in any order).
            // Complete that sub-session with a distinguishable summary.
            for _ in 0..2 {
                let text = read_text(&mut ws).await;
                let req: BacktestRequest = serde_json::from_str(&text).unwrap();
                let BacktestRequest::ContinueSessionV1(c) = req else {
                    panic!("expected continueSessionV1, got {req:?}");
                };
                let correct_simulation = if c.session_id == "s1" { 1 } else { 2 };
                send_session_event(
                    &mut ws,
                    &c.session_id,
                    3,
                    SessionEventKind::Completed {
                        summary: Some(SessionSummary {
                            correct_simulation,
                            ..Default::default()
                        }),
                    },
                )
                .await;
            }
        })
    })
    .await;
    let create = CreateSession::builder()
        .start_slot(100)
        .end_slot(299)
        .parallel(true)
        .build()
        .into_request()
        .unwrap();
    let mut parallel = ManagedParallelSession::start_with_cancel(
        url,
        "k".to_string(),
        create,
        CancellationToken::new(),
    )
    .await
    .unwrap();
    assert_eq!(parallel.control_session_id(), "parallel_test");
    let subs = parallel.take_sub_sessions();
    assert_eq!(subs.len(), 2);
    // Drive each sub-session on its own task so they progress concurrently.
    let mut handles = Vec::new();
    for sub in subs {
        handles.push(tokio::spawn(drive_sub_to_completion(sub)));
    }
    let mut completed = Vec::new();
    for handle in handles {
        completed.push(handle.await.unwrap());
    }
    // Sort by session id so the checks below don't depend on completion order.
    completed.sort_by(|(a, _), (b, _)| a.cmp(b));
    let ids: Vec<&str> = completed.iter().map(|(id, _)| id.as_str()).collect();
    assert_eq!(ids, ["s1", "s2"]);
    for ((id, summary), expected) in completed.iter().zip([1, 2]) {
        let summary = summary
            .as_ref()
            .unwrap_or_else(|| panic!("sub-session {id} should carry a summary"));
        assert_eq!(summary.correct_simulation, expected);
    }
    parallel.shutdown().await;
    server.await.unwrap();
}
/// Verifies reconnection of `ManagedParallelSession` after its control
/// connection drops.
///
/// The session must reconnect and send `AttachParallelControlSessionV2` with
/// the last sequence seen per sub-session. The sub-session must then continue
/// to completion on the new connection.
///
/// NOTE(review): the server re-sends seq 1 (a `SlotNotification`) after the
/// re-attach. `drive_sub_to_completion` ignores that event kind, so this test
/// does not directly assert that the replayed duplicate is filtered out. It
/// checks only the resume cursor and that completion is still reached.
#[tokio::test]
async fn parallel_multiplex_reconnects_and_dedups_replay() {
    // This server is hand-rolled because it must accept two successive connections.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr: SocketAddr = listener.local_addr().unwrap();
    let url = format!("ws://{addr}/backtest");
    let server = tokio::spawn(async move {
        // First connection: read the create request without inspecting it, announce
        // s1 with one seq-1 event, summarize, then drop the socket.
        let (stream, _) = listener.accept().await.unwrap();
        let mut ws = accept_with_expected_api_key(stream, "k").await.unwrap();
        let _ = ws.next().await.unwrap().unwrap();
        ws.send(Message::Text(
            serde_json::to_string(&BacktestResponse::SessionCreated {
                session_id: "s1".to_string(),
                rpc_endpoint: "http://rpc/s1".to_string(),
                task_id: None,
            })
            .unwrap(),
        ))
        .await
        .unwrap();
        send_session_event(&mut ws, "s1", 1, SessionEventKind::SlotNotification(100)).await;
        ws.send(Message::Text(
            serde_json::to_string(&BacktestResponse::SessionsCreatedV2 {
                control_session_id: "p".to_string(),
                session_ids: vec!["s1".to_string()],
                task_ids: Vec::new(),
                start_slots: vec![100],
                end_slots: vec![199],
            })
            .unwrap(),
        ))
        .await
        .unwrap();
        drop(ws);
        // Second connection: the client must re-attach using its per-session cursor.
        let (stream, _) = listener.accept().await.unwrap();
        let mut ws = accept_with_expected_api_key(stream, "k").await.unwrap();
        let text = read_text(&mut ws).await;
        let req: BacktestRequest = serde_json::from_str(&text).unwrap();
        let BacktestRequest::AttachParallelControlSessionV2 {
            control_session_id,
            last_sequences,
        } = req
        else {
            panic!("expected attachParallelControlSessionV2, got {req:?}");
        };
        assert_eq!(control_session_id, "p");
        assert_eq!(
            last_sequences.get("s1"),
            Some(&1),
            "client should resume from the last seq it saw"
        );
        ws.send(Message::Text(
            serde_json::to_string(&BacktestResponse::ParallelSessionAttachedV2 {
                control_session_id: "p".to_string(),
                session_ids: vec!["s1".to_string()],
                task_ids: Vec::new(),
            })
            .unwrap(),
        ))
        .await
        .unwrap();
        // Replay the already-seen seq 1, then new events from seq 2 onward.
        send_session_event(&mut ws, "s1", 1, SessionEventKind::SlotNotification(100)).await;
        send_session_event(&mut ws, "s1", 2, SessionEventKind::ReadyForContinue).await;
        let text = read_text(&mut ws).await;
        let req: BacktestRequest = serde_json::from_str(&text).unwrap();
        let BacktestRequest::ContinueSessionV1(c) = req else {
            panic!("expected continueSessionV1, got {req:?}");
        };
        assert_eq!(c.session_id, "s1");
        send_session_event(
            &mut ws,
            "s1",
            3,
            SessionEventKind::Completed { summary: None },
        )
        .await;
    });
    let create = CreateSession::builder()
        .start_slot(100)
        .end_slot(199)
        .parallel(true)
        .build()
        .into_request()
        .unwrap();
    let mut parallel = ManagedParallelSession::start_with_cancel(
        url,
        "k".to_string(),
        create,
        CancellationToken::new(),
    )
    .await
    .unwrap();
    let mut subs = parallel.take_sub_sessions();
    assert_eq!(subs.len(), 1);
    let (session_id, _) = drive_sub_to_completion(subs.pop().unwrap()).await;
    assert_eq!(session_id, "s1");
    parallel.shutdown().await;
    server.await.unwrap();
}
/// Verifies that `ManagedParallelSession::start_with_cancel` fails with
/// `ManagedSessionError::Create` when `SessionsCreatedV2` lists a sub-session
/// but supplies no start/end slot ranges for it.
#[tokio::test]
async fn parallel_multiplex_rejects_missing_sub_session_ranges() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            // The create request is read but not inspected.
            let _ = ws.next().await.unwrap().unwrap();
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionCreated {
                    session_id: "s1".to_string(),
                    rpc_endpoint: "http://rpc/s1".to_string(),
                    task_id: None,
                })
                .unwrap(),
            ))
            .await
            .unwrap();
            // One session id but empty `start_slots`/`end_slots`: the malformed input.
            ws.send(Message::Text(
                serde_json::to_string(&BacktestResponse::SessionsCreatedV2 {
                    control_session_id: "p".to_string(),
                    session_ids: vec!["s1".to_string()],
                    task_ids: Vec::new(),
                    start_slots: Vec::new(),
                    end_slots: Vec::new(),
                })
                .unwrap(),
            ))
            .await
            .unwrap();
        })
    })
    .await;
    let create = CreateSession::builder()
        .start_slot(100)
        .end_slot(199)
        .parallel(true)
        .build()
        .into_request()
        .unwrap();
    let result = ManagedParallelSession::start_with_cancel(
        url,
        "k".to_string(),
        create,
        CancellationToken::new(),
    )
    .await;
    let err = result
        .err()
        .expect("create should fail without sub-session ranges");
    assert!(
        matches!(err, ManagedSessionError::Create(_)),
        "expected a create error, got {err:?}"
    );
    server.await.unwrap();
}