use std::{cell::Cell, io, net::SocketAddr, ops::RangeInclusive, rc::Rc, time::Duration};
use chrono::NaiveTime;
use easyfix_macros::fix_str;
use easyfix_messages::{
fields::{
DefaultApplVerId, EncryptMethod, FixStr, SeqNum, SessionStatus, Utc, UtcTimestamp,
},
messages::{FixtMessage, Header, Logon, Message, TestRequest, Trailer},
};
use easyfix_session::{
acceptor::{Acceptor, Connection},
application::{AsEvent, FixEvent},
messages_storage::{InMemoryStorage, MessagesStorage},
session_id::SessionId,
settings::{SessionSettings, Settings},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream},
runtime::Builder,
sync::mpsc,
task::LocalSet,
time::{sleep, timeout},
};
use tokio_stream::StreamExt;
// BeginString used both for the session identifier and for every header built by `serialize_msg`.
const BEGIN_STRING: &FixStr = fix_str!("FIXT.1.1");
// CompID of the acceptor side: the `sender_comp_id` in `settings()` and the `target_comp_id`
// of the messages the simulated client sends.
const SERVER_COMP_ID: &FixStr = fix_str!("server");
// CompID the simulated client places in `sender_comp_id` of the messages it sends.
const CLIENT_COMP_ID: &FixStr = fix_str!("client");
/// Builds the identifier of the single session these tests register with the acceptor.
///
/// The arguments are passed to `SessionId::new` in the order begin string, server CompID,
/// client CompID.
fn session_id() -> SessionId {
    let begin = BEGIN_STRING.to_owned();
    let server = SERVER_COMP_ID.to_owned();
    let client = CLIENT_COMP_ID.to_owned();
    SessionId::new(begin, server, client)
}
/// Acceptor-wide settings shared by every test in this file.
///
/// The acceptor identifies itself as [`SERVER_COMP_ID`], uses a heartbeat interval of 10 and
/// the disconnect thresholds spelled out below.
fn settings() -> Settings {
    let no_logon_deadline = Duration::from_secs(3);
    let no_logout_deadline = Duration::from_secs(5);
    let sender = SERVER_COMP_ID.to_owned();

    Settings {
        sender_comp_id: sender,
        sender_sub_id: None,
        heartbeat_interval: Some(10),
        auto_disconnect_after_no_logon_received: no_logon_deadline,
        auto_disconnect_after_no_heartbeat: 3,
        auto_disconnect_after_no_logout: no_logout_deadline,
    }
}
/// Per-session settings for the session returned by [`session_id`].
///
/// Both the session window and the logon window span 00:00:00 through 23:59:59, messages are
/// allowed at most 60 seconds of latency, and sequence numbers are reset on disconnect only.
fn session_settings() -> SessionSettings {
    let day_start = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
    let day_end = NaiveTime::from_hms_opt(23, 59, 59).unwrap();
    // The same ApplVerID value is used for both directions.
    let appl_ver_id = fix_str!("9");

    SessionSettings {
        session_id: session_id(),
        session_time: day_start..=day_end,
        logon_time: day_start..=day_end,
        send_redundant_resend_requests: false,
        check_comp_id: true,
        max_latency: Some(Duration::from_secs(60)),
        reset_on_logon: false,
        reset_on_logout: false,
        reset_on_disconnect: true,
        sender_default_appl_ver_id: appl_ver_id.to_owned(),
        target_default_appl_ver_id: appl_ver_id.to_owned(),
        persist: false,
        refresh_on_logon: false,
        enable_next_expected_msg_seq_num: false,
        verify_logout: true,
        verify_test_request_id: true,
    }
}
/// Connection source that yields in-memory duplex streams received over a channel
/// instead of accepting real sockets.
struct TestConnection {
/// Streams queued by the test body; each `accept` call returns the next one.
incoming: mpsc::UnboundedReceiver<DuplexStream>,
}
impl Connection for TestConnection {
    /// Waits for the next queued duplex stream and returns its read half, its write half and a
    /// fixed placeholder peer address.
    ///
    /// Once the sending side of the channel is gone this future never resolves, so no error is
    /// ever reported to the caller.
    async fn accept(
        &mut self,
    ) -> Result<
        (
            impl AsyncRead + Unpin + 'static,
            impl AsyncWrite + Unpin + 'static,
            SocketAddr,
        ),
        io::Error,
    > {
        let stream = match self.incoming.recv().await {
            Some(stream) => stream,
            // Channel closed: park forever rather than returning an error.
            None => std::future::pending().await,
        };
        let (read_half, write_half) = tokio::io::split(stream);
        let peer = "127.0.0.1:1".parse().unwrap();
        Ok((read_half, write_half, peer))
    }
}
/// Wraps `msg` in a header and default trailer and returns the serialized bytes.
///
/// The header is addressed from [`CLIENT_COMP_ID`] to [`SERVER_COMP_ID`] and carries the given
/// `msg_seq_num` and `sending_time`; every other header field keeps its default value.
fn serialize_msg(msg: Message, msg_seq_num: SeqNum, sending_time: UtcTimestamp) -> Vec<u8> {
    // The message type has to be read before `msg` is moved into the envelope.
    let header = Header {
        begin_string: BEGIN_STRING.to_owned(),
        msg_type: msg.msg_type(),
        sender_comp_id: CLIENT_COMP_ID.to_owned(),
        target_comp_id: SERVER_COMP_ID.to_owned(),
        msg_seq_num,
        sending_time,
        ..Default::default()
    };
    let envelope = FixtMessage {
        header: Box::new(header),
        body: Box::new(msg),
        trailer: Box::new(Trailer::default()),
    };
    envelope.serialize()
}
/// Builds the Logon body the simulated client sends: `EncryptMethod::NoneOther`, a heartbeat
/// interval of 10 and `DefaultApplVerId::Fix50Sp2`; all remaining fields are defaulted.
fn logon() -> Message {
    let body = Logon {
        encrypt_method: EncryptMethod::NoneOther,
        heart_bt_int: 10,
        default_appl_ver_id: DefaultApplVerId::Fix50Sp2,
        ..Default::default()
    };
    Message::Logon(body)
}
/// Builds a TestRequest body whose `test_req_id` is an owned copy of `id`.
fn test_request(id: &FixStr) -> Message {
    let test_req_id = id.to_owned();
    let body = TestRequest { test_req_id };
    Message::TestRequest(body)
}
/// Counts how many times `needle` occurs in `haystack`, including overlapping occurrences
/// (`b"aa"` occurs three times in `b"aaaa"`).
///
/// An empty `needle` is reported as zero occurrences. Without that guard the call would panic,
/// because `slice::windows` rejects a window size of 0.
fn count_occurrences(haystack: &[u8], needle: &[u8]) -> usize {
    if needle.is_empty() {
        return 0;
    }
    haystack
        .windows(needle.len())
        .filter(|window| *window == needle)
        .count()
}
/// Consumes acceptor events one at a time until a `FixEvent::Logon` has been seen.
///
/// # Panics
///
/// Panics with "event stream closed" if the acceptor's event stream ends first.
async fn pump_until_logon<S>(acceptor: &mut Acceptor<S>)
where
    S: MessagesStorage + 'static,
{
    let mut logged_on = false;
    while !logged_on {
        let mut emitter = acceptor.next().await.expect("event stream closed");
        logged_on = matches!(emitter.as_event(), FixEvent::Logon(..));
    }
}
/// Reads from `reader` into `buf` until `buf` contains at least `count` occurrences of
/// `needle`, consuming acceptor events while it waits.
///
/// Each iteration races the next acceptor event against the next read, so the acceptor keeps
/// being polled even when no bytes are available yet.
///
/// # Panics
///
/// Panics if the event stream ends, if a read fails, or if the reader reports end-of-stream
/// before enough occurrences have arrived.
async fn read_until_pumping<S, R>(
    acceptor: &mut Acceptor<S>,
    reader: &mut R,
    buf: &mut Vec<u8>,
    needle: &[u8],
    count: usize,
) where
    S: MessagesStorage + 'static,
    R: AsyncRead + Unpin,
{
    let mut scratch = [0u8; 4096];
    loop {
        if count_occurrences(buf, needle) >= count {
            return;
        }
        tokio::select! {
            entry = acceptor.next() => {
                let mut emitter = entry.expect("event stream closed");
                let _ = emitter.as_event();
            }
            read_result = reader.read(&mut scratch) => {
                let received = read_result.expect("read failed");
                assert_ne!(received, 0, "connection closed while waiting for {needle:?}");
                buf.extend_from_slice(&scratch[..received]);
            }
        }
    }
}
/// Drives `test` to completion on a fresh current-thread runtime inside a `LocalSet`.
///
/// Tracing output is configured from the environment; a failed initialisation (for example
/// because another test already installed a subscriber) is ignored.
///
/// # Panics
///
/// Panics with "test timed out" if `test` has not finished within 10 seconds, and with
/// "runtime" if the runtime cannot be built.
fn run_local_test(test: impl Future<Output = ()>) {
    let env_filter = tracing_subscriber::EnvFilter::from_default_env();
    let _ = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .try_init();

    let runtime = Builder::new_current_thread()
        .enable_time()
        .build()
        .expect("runtime");
    let local_set = LocalSet::new();

    // The timeout has to be created inside the async block so that it is built while the
    // runtime is driving it.
    let outcome = local_set.block_on(&runtime, async {
        let deadline = Duration::from_secs(10);
        timeout(deadline, test).await
    });
    outcome.expect("test timed out");
}
/// Logs a client on, then writes `FLOOD_LEN` test requests in one batch, each stamped with a
/// SendingTime 600 seconds in the past. The test passes when the acceptor surfaces
/// `FLOOD_LEN` `FixEvent::DeserializeError` events and afterwards `FLOOD_LEN` messages
/// containing `35=3` can be read back on the same connection.
#[test]
fn reject_flood_does_not_kill_connection_task() {
// Number of stale messages sent in one burst.
const FLOOD_LEN: usize = 20;
run_local_test(async {
let mut acceptor = Acceptor::new(settings(), Box::new(|_| InMemoryStorage::new()));
acceptor.register_session(session_id(), session_settings());
let (conn_tx, conn_rx) = mpsc::unbounded_channel();
acceptor.start(TestConnection { incoming: conn_rx });
// In-memory pipe: `server` is handed to the acceptor, `client` plays the counterparty.
let (client, server) = tokio::io::duplex(1024 * 1024);
conn_tx.send(server).expect("connection refused");
let (mut client_rx, mut client_tx) = tokio::io::split(client);
client_tx
.write_all(&serialize_msg(logon(), 1, UtcTimestamp::now()))
.await
.expect("logon write failed");
pump_until_logon(&mut acceptor).await;
let mut buf = Vec::new();
// Wait until one message with `35=A` has been read back from the acceptor.
read_until_pumping(&mut acceptor, &mut client_rx, &mut buf, b"\x0135=A\x01", 1).await;
// 600 s old, while `session_settings()` sets `max_latency` to 60 s.
// NOTE(review): presumably this is what makes the acceptor refuse each message — confirm.
let stale = UtcTimestamp::with_millis(Utc::now() - chrono::Duration::seconds(600));
let mut batch = Vec::new();
// Sequence numbers continue right after the logon (which used 1).
for seq_num in 2..2 + FLOOD_LEN as SeqNum {
batch.extend_from_slice(&serialize_msg(test_request(fix_str!("flood")), seq_num, stale));
}
// The whole flood goes out in a single write.
client_tx.write_all(&batch).await.expect("flood write failed");
// The test does not poll `acceptor` during this pause.
// NOTE(review): presumably this lets the connection task work through the batch while no
// events are being consumed — confirm.
sleep(Duration::from_millis(200)).await;
let mut deserialize_errors = 0;
// Consume events until one deserialize error per flooded message has been observed.
while deserialize_errors < FLOOD_LEN {
let mut entry = acceptor.next().await.expect("event stream closed");
if matches!(entry.as_event(), FixEvent::DeserializeError(..)) {
deserialize_errors += 1;
}
}
// Fresh buffer, so only bytes read from this point on are counted.
let mut buf = Vec::new();
read_until_pumping(&mut acceptor, &mut client_rx, &mut buf, b"\x0135=3\x01", FLOOD_LEN)
.await;
});
}
/// Message storage that wraps an `InMemoryStorage` and can be made to panic from `store`
/// through a flag shared with the test body.
struct PanickingStorage {
/// Backing storage that receives every call.
inner: InMemoryStorage,
/// When `true`, the next `store` call clears the flag and panics.
panic_armed: Rc<Cell<bool>>,
}
// Every operation is forwarded unchanged to the wrapped storage; `store` is the only method
// with extra behaviour (the injected panic).
impl MessagesStorage for PanickingStorage {
    /// Forwards to the inner storage.
    fn fetch_range(&mut self, range: RangeInclusive<SeqNum>) -> impl Iterator<Item = &[u8]> {
        self.inner.fetch_range(range)
    }

    /// Panics with "injected storage failure" if the shared flag is set, otherwise stores the
    /// message in the inner storage.
    fn store(&mut self, seq_num: SeqNum, data: &[u8]) {
        // `replace(false)` clears the flag while reading it, so one arming causes one panic.
        let was_armed = self.panic_armed.replace(false);
        if was_armed {
            panic!("injected storage failure");
        }
        self.inner.store(seq_num, data);
    }

    /// Forwards to the inner storage.
    fn next_sender_msg_seq_num(&self) -> SeqNum {
        self.inner.next_sender_msg_seq_num()
    }

    /// Forwards to the inner storage.
    fn next_target_msg_seq_num(&self) -> SeqNum {
        self.inner.next_target_msg_seq_num()
    }

    /// Forwards to the inner storage.
    fn set_next_sender_msg_seq_num(&mut self, seq_num: SeqNum) {
        self.inner.set_next_sender_msg_seq_num(seq_num);
    }

    /// Forwards to the inner storage.
    fn set_next_target_msg_seq_num(&mut self, seq_num: SeqNum) {
        self.inner.set_next_target_msg_seq_num(seq_num);
    }

    /// Forwards to the inner storage.
    fn incr_next_sender_msg_seq_num(&mut self) {
        self.inner.incr_next_sender_msg_seq_num();
    }

    /// Forwards to the inner storage.
    fn incr_next_target_msg_seq_num(&mut self) {
        self.inner.incr_next_target_msg_seq_num();
    }

    /// Forwards to the inner storage.
    fn reset(&mut self) {
        self.inner.reset();
    }
}
/// Logs a client on, arms the storage panic flag and sends a test request, then waits until
/// the acceptor reports the session as inactive. A second connection must then be able to log
/// on to the same session and receive a message containing `35=5` after a logout is requested.
#[test]
fn panicked_connection_task_releases_session() {
run_local_test(async {
// Flag shared between the test body and every storage created by the factory below.
let panic_armed = Rc::new(Cell::new(false));
let storage_panic_armed = panic_armed.clone();
let mut acceptor = Acceptor::new(
settings(),
Box::new(move |_| PanickingStorage {
inner: InMemoryStorage::new(),
panic_armed: storage_panic_armed.clone(),
}),
);
acceptor.register_session(session_id(), session_settings());
let (conn_tx, conn_rx) = mpsc::unbounded_channel();
acceptor.start(TestConnection { incoming: conn_rx });
let (client, server) = tokio::io::duplex(1024 * 1024);
conn_tx.send(server).expect("connection refused");
// `_client_rx` is a named binding, so the read half lives until the end of the test body.
let (_client_rx, mut client_tx) = tokio::io::split(client);
client_tx
.write_all(&serialize_msg(logon(), 1, UtcTimestamp::now()))
.await
.expect("logon write failed");
pump_until_logon(&mut acceptor).await;
// Armed only after the logon went through: the next `PanickingStorage::store` call panics.
panic_armed.set(true);
// NOTE(review): presumably handling this request makes the acceptor call `store` inside the
// connection task — confirm against the session implementation.
client_tx
.write_all(&serialize_msg(test_request(fix_str!("boom")), 2, UtcTimestamp::now()))
.await
.expect("test request write failed");
// Keep consuming events, at most 20 ms per wait, until the session is reported inactive.
// If that never happens, the 10 s limit in `run_local_test` fails the test.
while acceptor.is_session_active(&session_id()).expect("unknown session") {
if let Ok(Some(mut entry)) = timeout(Duration::from_millis(20), acceptor.next()).await
{
let _ = entry.as_event();
}
}
// Second connection for the same session; it logs on with sequence number 1 again
// (`session_settings()` sets `reset_on_disconnect` to true).
let (client2, server2) = tokio::io::duplex(1024 * 1024);
conn_tx.send(server2).expect("connection refused");
let (mut client2_rx, mut client2_tx) = tokio::io::split(client2);
client2_tx
.write_all(&serialize_msg(logon(), 1, UtcTimestamp::now()))
.await
.expect("logon write failed");
pump_until_logon(&mut acceptor).await;
// Request a logout and wait for a message with `35=5` on the new connection.
acceptor.disable_with_logout(Some(SessionStatus::SessionLogoutComplete), None);
let mut buf = Vec::new();
read_until_pumping(&mut acceptor, &mut client2_rx, &mut buf, b"\x0135=5\x01", 1).await;
});
}