use std::sync::{Arc, Mutex};
use std::time::Duration;
use time_tz::timezones;
use super::*;
use crate::client::r#async::Client;
use crate::common::test_utils::helpers::{error_frame, managed_accounts_frame, next_valid_id_frame};
use crate::messages::IncomingMessages;
use crate::server_versions;
use crate::transport::common::MAX_RECONNECT_ATTEMPTS;
use crate::transport::r#async::{AsyncTcpMessageBus, MemoryStream};
/// Client id passed to every stubbed `AsyncConnection` in these tests.
const CLIENT_ID: i32 = 100;
/// Server version sent in the stubbed handshake preambles. It equals the
/// minimum version `establish_connection` accepts; one less is rejected with
/// `Error::ServerVersion`.
const SERVER_VERSION: i32 = server_versions::PROTOBUF_REST_MESSAGES_3;
/// Queues a complete, successful server handshake on `stream`.
///
/// The queued data is:
/// - the `SERVER_VERSION` / connection-time preamble;
/// - a next-valid-id frame carrying 90;
/// - a managed-accounts frame for "DU1234567".
fn push_handshake(stream: &MemoryStream) {
    let preamble = format!("{SERVER_VERSION}\020240120 12:00:00 EST\0").into_bytes();
    stream.push_inbound(preamble);
    stream.push_inbound(next_valid_id_frame(90));
    stream.push_inbound(managed_accounts_frame("DU1234567"));
}
/// Builds a message body: `msg_id` as 4 big-endian bytes, followed by the
/// UTF-8 bytes of `payload`.
fn binary_text(msg_id: i32, payload: &str) -> Vec<u8> {
    [msg_id.to_be_bytes().as_slice(), payload.as_bytes()].concat()
}
#[tokio::test]
async fn establish_connection_rejects_pre_protobuf_server() {
let stream = MemoryStream::default();
let connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
let too_old = server_versions::PROTOBUF_REST_MESSAGES_3 - 1;
let handshake = format!("{}\020240120 12:00:00 EST\0", too_old);
stream.push_inbound(handshake.into_bytes());
let err = connection.establish_connection().await.expect_err("must reject old server");
match err {
crate::errors::Error::ServerVersion(required, got, ref msg) => {
assert_eq!(required, server_versions::PROTOBUF_REST_MESSAGES_3);
assert_eq!(got, too_old);
assert!(msg.contains("protobuf"), "message should mention protobuf: {msg}");
}
other => panic!("expected Error::ServerVersion, got {other:?}"),
}
let captured = stream.captured();
let expected = connection.connection_handler.format_handshake();
assert_eq!(captured, expected, "no bytes should follow the handshake when version check fails");
}
/// Verifies that a successful handshake populates the connection's client id
/// and server version, plus the next order id, managed accounts and time zone
/// in its metadata.
#[tokio::test]
async fn establish_connection_populates_metadata() {
    let stream = MemoryStream::default();
    let conn = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
    push_handshake(&stream);

    conn.establish_connection().await.expect("establish_connection failed");

    // Values exposed directly on the connection.
    assert_eq!(conn.client_id, CLIENT_ID);
    assert_eq!(conn.server_version(), SERVER_VERSION);

    // Values derived from the queued handshake frames.
    let meta = conn.connection_metadata().await;
    assert_eq!(meta.next_order_id, 90);
    assert_eq!(meta.managed_accounts, "DU1234567");
    assert_eq!(meta.time_zone, Some(timezones::db::EST));
}
/// Verifies that `Client::disconnect` finishes within two seconds and leaves
/// the client reporting itself as not connected.
#[tokio::test]
async fn disconnect_completes() {
    let client = make_client().await;
    let limit = Duration::from_secs(2);
    let outcome = tokio::time::timeout(limit, client.disconnect()).await;
    outcome.expect("disconnect did not complete in time");
    assert!(!client.is_connected());
}
/// Verifies that calling `Client::disconnect` twice in a row completes within
/// two seconds and leaves the client disconnected.
#[tokio::test]
async fn disconnect_is_idempotent() {
    let client = make_client().await;
    let disconnect_twice = async {
        for _ in 0..2 {
            client.disconnect().await;
        }
    };
    tokio::time::timeout(Duration::from_secs(2), disconnect_twice)
        .await
        .expect("repeated disconnect did not complete in time");
    assert!(!client.is_connected());
}
/// Builds a `Client` over an in-memory stream that has already completed a
/// successful handshake.
///
/// `process_messages` is started on the client's message bus before the
/// client is returned.
async fn make_client() -> Client {
    let stream = MemoryStream::default();
    let connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
    push_handshake(&stream);
    connection.establish_connection().await.expect("establish_connection failed");

    let version = connection.server_version();
    let bus = AsyncTcpMessageBus::new(connection)
        .map(Arc::new)
        .expect("AsyncTcpMessageBus::new");
    Arc::clone(&bus)
        .process_messages(version, Duration::ZERO)
        .expect("process_messages");

    Client::stubbed(bus, version)
}
/// Verifies that the startup callback and the notice broadcast channel stay
/// attached to the connection across a second `establish_connection` call.
///
/// The startup callback must fire once per handshake, and farm-status notices
/// from both handshakes must arrive on the same receiver.
#[tokio::test]
async fn handshake_callbacks_and_notice_stream_survive_reconnect() {
    let stream = MemoryStream::default();
    let mut connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);

    let startup_count = Arc::new(Mutex::new(0_usize));
    let counter = Arc::clone(&startup_count);
    connection.startup_callback = Some(Arc::new(move |_: crate::connection::common::StartupMessage| {
        *counter.lock().unwrap() += 1;
    }));
    let mut notice_rx = connection.notice_sender.subscribe();

    let handshake_bytes = format!("{}\020240120 12:00:00 EST\0", SERVER_VERSION).into_bytes();

    // Queues one full handshake round:
    // - the version preamble;
    // - an OpenOrderEnd frame;
    // - a farm-status error/notice frame;
    // - next-valid-id and managed-accounts frames.
    let queue_handshake = |next_id, notice_code, notice_text: &str| {
        stream.push_inbound(handshake_bytes.clone());
        stream.push_inbound(binary_text(IncomingMessages::OpenOrderEnd as i32, "1\0"));
        stream.push_inbound(error_frame(-1, notice_code, notice_text));
        stream.push_inbound(next_valid_id_frame(next_id));
        stream.push_inbound(managed_accounts_frame("DU1234567"));
    };

    // First handshake.
    queue_handshake(90, 2104, "farm OK");
    connection.establish_connection().await.expect("first establish_connection failed");
    assert_eq!(*startup_count.lock().unwrap(), 1, "startup callback should fire on first handshake");
    let first = notice_rx.try_recv().expect("first farm-status notice should be on the stream");
    assert_eq!(first.code, 2104);

    // Second handshake on the same connection object.
    queue_handshake(91, 2106, "HMDS farm OK");
    connection.establish_connection().await.expect("second establish_connection failed");
    assert_eq!(*startup_count.lock().unwrap(), 2, "startup callback should fire on reconnect handshake");
    let second = notice_rx.try_recv().expect("second farm-status notice should be on the same stream");
    assert_eq!(second.code, 2106);
}
/// Verifies that the `Debug` output of `AsyncConnection` names the type and
/// includes the client id.
#[test]
fn debug_impl_formats_connection() {
    let connection = AsyncConnection::stubbed(MemoryStream::default(), CLIENT_ID);
    let rendered = format!("{connection:?}");
    let id = CLIENT_ID.to_string();
    for needle in ["AsyncConnection", id.as_str()] {
        assert!(rendered.contains(needle), "{rendered}");
    }
}
#[tokio::test]
async fn handshake_unexpected_eof_returns_connection_rejected() {
let stream = MemoryStream::default();
let connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
stream.close();
let err = connection.handshake().await.expect_err("must surface rejection error");
match err {
crate::errors::Error::ConnectionRejected(ref msg) => {
assert!(msg.contains("server may be rejecting"), "unexpected message: {msg}");
}
other => panic!("expected Error::ConnectionRejected, got {other:?}"),
}
}
/// Verifies that `reconnect` succeeds when the stream is configured with
/// `set_reconnect_failures(3)` and a fresh handshake is queued.
///
/// The server version must be the same after the reconnect.
#[tokio::test]
async fn reconnect_succeeds_after_transient_failures() {
    let stream = MemoryStream::default();
    let conn = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);

    push_handshake(&stream);
    conn.establish_connection().await.expect("initial establish_connection failed");
    assert_eq!(conn.server_version(), SERVER_VERSION);

    // Presumably makes the first three reconnect attempts fail; confirm against
    // MemoryStream::set_reconnect_failures.
    stream.set_reconnect_failures(3);
    push_handshake(&stream);

    conn.reconnect().await.expect("reconnect must succeed after transient failures");
    assert_eq!(conn.server_version(), SERVER_VERSION);
}
#[tokio::test]
async fn reconnect_returns_connection_failed_after_exhausting_attempts() {
let stream = MemoryStream::default();
let connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
push_handshake(&stream);
connection.establish_connection().await.expect("initial establish_connection failed");
stream.set_reconnect_failures(MAX_RECONNECT_ATTEMPTS as usize);
let err = connection.reconnect().await.expect_err("must give up after MAX_RECONNECT_ATTEMPTS");
assert!(matches!(err, crate::errors::Error::ConnectionFailed), "got {err:?}");
}
/// Verifies how `reconnect` handles connection metadata.
///
/// While `reconnect` is waiting for the server's handshake reply, the metadata
/// is reset: `client_id` is kept, and every other checked field is zero, empty
/// or `None`. Once a handshake reply is queued and the reconnect finishes, the
/// metadata is populated again.
#[tokio::test]
async fn reconnect_clears_metadata_while_waiting_for_handshake() {
    let stream = MemoryStream::default();
    let connection = AsyncConnection::stubbed(stream.clone(), CLIENT_ID);
    push_handshake(&stream);
    connection.establish_connection().await.expect("initial establish_connection failed");
    let metadata = connection.connection_metadata().await;
    assert_eq!(metadata.server_version, SERVER_VERSION);
    assert_eq!(metadata.next_order_id, 90);
    assert_eq!(metadata.managed_accounts, "DU1234567");

    // Everything captured so far belongs to the initial handshake. Any growth
    // means the reconnect has begun writing its own handshake.
    let initial_capture_len = stream.captured().len();
    let connection = Arc::new(connection);
    let conn_for_task = Arc::clone(&connection);
    // No handshake reply is queued yet. The reconnect is therefore expected to
    // wait after writing its handshake; that is the window inspected below.
    let reconnect_task = tokio::spawn(async move { conn_for_task.reconnect().await });
    tokio::time::timeout(Duration::from_secs(2), async {
        while stream.captured().len() <= initial_capture_len {
            // If reconnect exits early (e.g. with an error), fail right away with
            // an accurate message. Otherwise the loop would spin until the
            // timeout and blame the wrong phase.
            assert!(
                !reconnect_task.is_finished(),
                "reconnect task exited before writing its handshake"
            );
            tokio::task::yield_now().await;
        }
    })
    .await
    .expect("reconnect must reach handshake-write phase");

    let metadata = connection.connection_metadata().await;
    assert_eq!(metadata.client_id, CLIENT_ID);
    assert_eq!(metadata.server_version, 0);
    assert_eq!(metadata.next_order_id, 0);
    assert_eq!(metadata.managed_accounts, "");
    assert!(metadata.connection_time.is_none());
    assert!(metadata.time_zone.is_none());

    // Let the pending reconnect complete and check the metadata is restored.
    push_handshake(&stream);
    reconnect_task.await.expect("reconnect task panicked").expect("reconnect failed");
    let metadata = connection.connection_metadata().await;
    assert_eq!(metadata.server_version, SERVER_VERSION);
    assert_eq!(metadata.next_order_id, 90);
    assert_eq!(metadata.managed_accounts, "DU1234567");
    assert_eq!(metadata.time_zone, Some(timezones::db::EST));
}