use std::{net::SocketAddr, time::Duration};
use futures::{SinkExt, StreamExt};
use simulator_api::{
BacktestError, BacktestRequest, BacktestResponse, BacktestStatus, CreateSessionParams,
SequencedResponse,
};
use simulator_client::{
BacktestClient, BacktestClientError, Continue, CreateSession, ReadyOutcome,
};
use tokio::net::TcpListener;
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{
Message,
handshake::server::{ErrorResponse, Request, Response},
},
};
/// Panics unless the handshake request carries `expected_api_key` in its `X-API-Key` header.
///
/// A missing header, or a header value that `to_str` rejects, is compared as `None`
/// and therefore fails the assertion as well.
fn assert_expected_api_key(req: &Request, expected_api_key: &str) {
    let presented = match req.headers().get("X-API-Key") {
        Some(value) => value.to_str().ok(),
        None => None,
    };
    assert_eq!(presented, Some(expected_api_key));
}
/// Completes the server side of the WebSocket handshake on `stream`.
///
/// The handshake callback asserts (and therefore panics on mismatch) that the
/// client sent `expected_api_key` in `X-API-Key`, then accepts the upgrade with
/// the unmodified default response.
// NOTE(review): the lint is presumably silenced because the handshake callback must
// return `Result<_, ErrorResponse>`, whose error type is large — confirm.
#[allow(clippy::result_large_err)]
async fn accept_with_expected_api_key(
    stream: tokio::net::TcpStream,
    expected_api_key: &'static str,
) -> tokio_tungstenite::tungstenite::Result<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>
{
    let check_api_key = move |req: &Request, resp: Response| -> Result<Response, ErrorResponse> {
        assert_expected_api_key(req, expected_api_key);
        Ok(resp)
    };
    accept_hdr_async(stream, check_api_key).await
}
/// Starts a one-shot WebSocket test server on an OS-assigned loopback port.
///
/// Returns the `ws://…/backtest` URL to connect to and the handle of the server
/// task. That task accepts exactly one TCP connection, performs the WebSocket
/// handshake (asserting the `X-API-Key` header equals `expected_api_key`), hands
/// the socket to `handler`, and then waits for the task `handler` returned.
///
/// Any failure along the way (bind, accept, handshake, or a panic inside the
/// handler task) is unwrapped and so surfaces as a panic of the returned task.
async fn spawn_server(
    expected_api_key: &'static str,
    handler: impl FnOnce(
            tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
        ) -> tokio::task::JoinHandle<()>
        + Send
        + 'static,
) -> (String, tokio::task::JoinHandle<()>) {
    // Port 0 lets the OS pick a free port; read it back to build the URL.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound: SocketAddr = listener.local_addr().unwrap();

    let server_task = tokio::spawn(async move {
        let (tcp, _peer) = listener.accept().await.unwrap();
        let socket = accept_with_expected_api_key(tcp, expected_api_key)
            .await
            .unwrap();
        // The handler spawns its own task; wait for it so its panics propagate.
        let session_task = handler(socket);
        session_task.await.unwrap();
    });

    (format!("ws://{bound}/backtest"), server_task)
}
#[tokio::test]
async fn creates_session_waits_ready_advances_and_closes() {
let (url, server) = spawn_server("k", |mut ws| {
tokio::spawn(async move {
eprintln!("server: waiting for resume request");
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
let BacktestRequest::CreateBacktestSession(request) = req else {
panic!("expected create");
};
let (
CreateSessionParams {
start_slot,
end_slot,
..
},
parallel,
) = request.into_request_and_parallel();
assert!(!parallel);
assert_eq!(start_slot, 100);
assert_eq!(end_slot, 105);
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s1".to_string(),
rpc_endpoint: "http://rpc".to_string(),
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
))
.await
.unwrap();
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::Continue(_)));
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::Status {
status: BacktestStatus::DecodedTransactions,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SlotNotification(101)).unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SlotNotification(102)).unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
))
.await
.unwrap();
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::CloseBacktestSession));
eprintln!("server: sending success");
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::Success).unwrap(),
))
.await
.unwrap();
})
})
.await;
let client = BacktestClient::builder().url(url).api_key("k").build();
let mut session = client
.create_session(
CreateSession::builder()
.start_slot(100)
.slot_count(5)
.build(),
)
.await
.unwrap();
assert_eq!(session.session_id(), Some("s1"));
assert_eq!(session.rpc_endpoint(), Some("http://rpc"));
let ready = session
.ensure_ready(Some(Duration::from_secs(2)))
.await
.unwrap();
assert_eq!(ready, ReadyOutcome::Ready);
let result = session
.advance(
Continue::builder().advance_count(2).build(),
Some(Duration::from_secs(2)),
|_| {},
)
.await
.unwrap();
assert_eq!(result.slot_notifications, 2);
assert_eq!(result.last_slot, Some(102));
assert!(result.ready_for_continue);
session.close(Some(Duration::from_secs(2))).await.unwrap();
server.await.unwrap();
}
#[tokio::test]
async fn surfaces_remote_error() {
let (url, server) = spawn_server("k", |mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::CreateBacktestSession(_)));
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s1".to_string(),
rpc_endpoint: "http://rpc".to_string(),
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::ReadyForContinue).unwrap(),
))
.await
.unwrap();
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
assert!(matches!(req, BacktestRequest::Continue(_)));
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::Error(BacktestError::NoMoreBlocks))
.unwrap(),
))
.await
.unwrap();
})
})
.await;
let client = BacktestClient::builder().url(url).api_key("k").build();
let mut session = client
.create_session(
CreateSession::builder()
.start_slot(100)
.end_slot(100)
.build(),
)
.await
.unwrap();
session
.ensure_ready(Some(Duration::from_secs(2)))
.await
.unwrap();
let err = session
.continue_until_ready(
Continue::builder().advance_count(1).build(),
Some(Duration::from_secs(2)),
|_| {},
)
.await
.unwrap_err();
assert!(matches!(
err,
BacktestClientError::Remote(BacktestError::NoMoreBlocks)
));
server.await.unwrap();
}
/// A parallel create request must be serialized with flat `params` (no nested
/// `request` object) and `create_sessions` must return the ids the server
/// reports in `SessionsCreated`.
#[tokio::test]
async fn creates_parallel_sessions_and_returns_ids() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let frame = ws.next().await.unwrap().unwrap();
            let text = match frame {
                Message::Text(text) => text,
                _ => panic!("expected text"),
            };

            // First inspect the raw JSON shape the client put on the wire.
            let payload: serde_json::Value = serde_json::from_str(&text).unwrap();
            let params = &payload["params"];
            assert_eq!(payload["method"], "createBacktestSession");
            assert_eq!(params["parallel"], true);
            assert_eq!(params["startSlot"], 100);
            assert!(params.get("request").is_none());

            // Then confirm the same text decodes into the typed request.
            let decoded: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::CreateBacktestSession(request) = decoded else {
                panic!("expected create");
            };
            let (_, parallel) = request.into_request_and_parallel();
            assert!(parallel);

            let reply = serde_json::to_string(&BacktestResponse::SessionsCreated {
                session_ids: vec!["s1".to_string(), "s2".to_string()],
            })
            .unwrap();
            ws.send(Message::Text(reply)).await.unwrap();
        })
    })
    .await;

    let client = BacktestClient::builder().url(url).api_key("k").build();
    let create = CreateSession::builder()
        .start_slot(100)
        .end_slot(105)
        .parallel(true)
        .build();
    let session_ids = client.create_sessions(create).await.unwrap();
    assert_eq!(session_ids, vec!["s1".to_string(), "s2".to_string()]);

    // Joining the server task surfaces any server-side assertion failure.
    server.await.unwrap();
}
#[tokio::test]
async fn creates_parallel_sessions_from_streamed_session_created_events() {
let (url, server) = spawn_server("k", |mut ws| {
tokio::spawn(async move {
let msg = ws.next().await.unwrap().unwrap();
let Message::Text(text) = msg else {
panic!("expected text");
};
let req: BacktestRequest = serde_json::from_str(&text).unwrap();
let BacktestRequest::CreateBacktestSession(request) = req else {
panic!("expected create");
};
let (_, parallel) = request.into_request_and_parallel();
assert!(parallel);
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s1".to_string(),
rpc_endpoint: "/backtest/s1".to_string(),
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionCreated {
session_id: "s2".to_string(),
rpc_endpoint: "/backtest/s2".to_string(),
task_id: None,
})
.unwrap(),
))
.await
.unwrap();
ws.send(Message::Text(
serde_json::to_string(&BacktestResponse::SessionsCreated {
session_ids: Vec::new(),
})
.unwrap(),
))
.await
.unwrap();
})
})
.await;
let client = BacktestClient::builder().url(url).api_key("k").build();
let mut streamed = Vec::new();
let session_ids = client
.create_sessions_with_progress(
CreateSession::builder()
.start_slot(100)
.end_slot(105)
.parallel(true)
.build(),
|session_id| streamed.push(session_id),
)
.await
.unwrap();
assert_eq!(streamed, vec!["s1".to_string(), "s2".to_string()]);
assert_eq!(session_ids, vec!["s1".to_string(), "s2".to_string()]);
server.await.unwrap();
}
/// `attach_session` must send an attach request carrying the session id and the
/// caller's last sequence number, and expose the id and RPC endpoint returned in
/// `SessionAttached`.
#[tokio::test]
async fn attaches_to_existing_session() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let frame = ws.next().await.unwrap().unwrap();
            let text = match frame {
                Message::Text(text) => text,
                _ => panic!("expected text"),
            };
            let decoded: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::AttachBacktestSession {
                session_id,
                last_sequence,
            } = decoded
            else {
                panic!("expected attach");
            };
            assert_eq!(session_id, "s1");
            assert_eq!(last_sequence, Some(7));

            eprintln!("server: sending session attached");
            let attached = BacktestResponse::SessionAttached {
                session_id: "s1".to_string(),
                rpc_endpoint: "http://rpc/s1".to_string(),
                task_id: None,
            };
            let reply = serde_json::to_string(&attached).unwrap();
            ws.send(Message::Text(reply)).await.unwrap();
        })
    })
    .await;

    let client = BacktestClient::builder().url(url).api_key("k").build();
    let session = client.attach_session("s1", Some(7)).await.unwrap();
    assert_eq!(session.session_id(), Some("s1"));
    assert_eq!(session.rpc_endpoint(), Some("http://rpc/s1"));

    // Joining the server task surfaces any server-side assertion failure.
    server.await.unwrap();
}
/// The session's `last_sequence` must follow the `seq_id` of sequenced responses
/// as they are consumed: `None` right after attaching, `Some(41)` once the
/// sequenced `ReadyForContinue` is consumed by `ensure_ready`, and `Some(42)`
/// once the sequenced status is consumed by `wait_for_status`.
#[tokio::test]
async fn tracks_last_sequence_from_sequenced_control_responses() {
    let (url, server) = spawn_server("k", |mut ws| {
        tokio::spawn(async move {
            let frame = ws.next().await.unwrap().unwrap();
            let Message::Text(text) = frame else {
                panic!("expected text");
            };
            let decoded: BacktestRequest = serde_json::from_str(&text).unwrap();
            let BacktestRequest::AttachBacktestSession {
                session_id,
                last_sequence,
            } = decoded
            else {
                panic!("expected attach");
            };
            assert_eq!(session_id, "s1");
            assert_eq!(last_sequence, None);

            // The attach acknowledgement itself is sent without a sequence wrapper.
            let attached = serde_json::to_string(&BacktestResponse::SessionAttached {
                session_id: "s1".to_string(),
                rpc_endpoint: "http://rpc/s1".to_string(),
                task_id: None,
            })
            .unwrap();
            ws.send(Message::Text(attached)).await.unwrap();

            // Follow-up control responses are wrapped with increasing sequence ids.
            let sequenced = [
                SequencedResponse {
                    seq_id: 41,
                    response: BacktestResponse::ReadyForContinue,
                },
                SequencedResponse {
                    seq_id: 42,
                    response: BacktestResponse::Status {
                        status: BacktestStatus::DecodedTransactions,
                    },
                },
            ];
            for envelope in sequenced {
                let json = serde_json::to_string(&envelope).unwrap();
                ws.send(Message::Text(json)).await.unwrap();
            }
        })
    })
    .await;

    let client = BacktestClient::builder().url(url).api_key("k").build();
    let mut session = client.attach_session("s1", None).await.unwrap();
    assert_eq!(session.last_sequence(), None);

    let ready = session
        .ensure_ready(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(ready, ReadyOutcome::Ready);
    assert_eq!(session.last_sequence(), Some(41));

    session
        .wait_for_status(
            BacktestStatus::DecodedTransactions,
            Some(Duration::from_secs(2)),
        )
        .await
        .unwrap();
    assert_eq!(session.last_sequence(), Some(42));

    // Joining the server task surfaces any server-side assertion failure.
    server.await.unwrap();
}