use super::tokio_event_loop::{TokioLlEvent, TokioLlEventLoop};
use super::{Backend, Config, Configure};
/// Action the event loop should take after `on_inbound_message` has
/// classified an inbound message by its MsgType (35).
#[derive(Debug, Clone)]
enum Response<'a> {
    /// No follow-up action (returned for an inbound Logon, `35=A`).
    None,
    /// Returned for an inbound Heartbeat (`35=0`); the event loop uses it to
    /// reset its heartbeat timer.
    ResetHeartbeat,
    /// Returned for an inbound Logout (`35=5`).
    TerminateTransport,
    /// Any message not handled at the session level (or one without a
    /// readable MsgType), passed through unchanged.
    Application(Message<'a, &'a [u8]>),
    /// Already-encoded bytes to write to the peer (the Heartbeat that
    /// answers an inbound TestRequest, `35=1`).
    OutboundBytes(&'a [u8]),
}
use crate::tagvalue::Message;
use crate::tagvalue::{DecoderStreaming, Encoder};
use crate::{FieldMap, SetField, StreamingDecoder};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
/// A FIX session that runs over a tokio `TcpStream`.
///
/// `C` supplies the session configuration (defaults to [`Config`]); the
/// methods are implemented for `C: Configure` and `B: Backend`, where the
/// backend is notified of session events such as the Logon handshake.
#[derive(Debug)]
pub struct TokioFixConnection<B, C = Config> {
    config: C,
    backend: B,
    encoder: Encoder,
    // Scratch buffer that is cleared and reused each time a session-level
    // message (Logon, Heartbeat) is encoded.
    buffer: Vec<u8>,
    // Inbound MsgSeqNum (34) counter.
    msg_seq_num_inbound: MsgSeqNumCounter,
    // Outbound MsgSeqNum (34) counter, advanced for every message this
    // connection encodes.
    msg_seq_num_outbound: MsgSeqNumCounter,
}
/// Connection-side end of the channels that link a [`TokioFixConnection`]
/// to application code; [`TokioAppChannels`] holds the opposite ends.
#[derive(Debug)]
pub struct TokioAppHandler {
    // Delivers inbound messages to `TokioAppChannels::inbound_rx`.
    inbound_tx: mpsc::UnboundedSender<Message<'static, Vec<u8>>>,
    // Receives already-encoded bytes the application wants sent to the peer.
    outbound_rx: mpsc::UnboundedReceiver<Vec<u8>>,
}
impl TokioAppHandler {
    /// Creates a handler together with its application-side
    /// [`TokioAppChannels`].
    ///
    /// Two unbounded channels are created: messages passed to
    /// [`send_inbound`](Self::send_inbound) arrive on
    /// [`TokioAppChannels::inbound_rx`], and bytes sent on
    /// [`TokioAppChannels::outbound_tx`] are received by this handler.
    pub fn new() -> (Self, TokioAppChannels) {
        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
        let handler = Self { inbound_tx, outbound_rx };
        let channels = TokioAppChannels { inbound_rx, outbound_tx };
        (handler, channels)
    }

    /// Forwards an inbound message to the application.
    ///
    /// # Errors
    ///
    /// Returns the message back inside [`mpsc::error::SendError`] if the
    /// receiving half ([`TokioAppChannels::inbound_rx`]) has been dropped.
    pub fn send_inbound(
        &self,
        message: Message<'static, Vec<u8>>,
    ) -> Result<(), mpsc::error::SendError<Message<'static, Vec<u8>>>> {
        self.inbound_tx.send(message)
    }

    /// Takes the next queued outbound message without waiting.
    ///
    /// # Errors
    ///
    /// [`mpsc::error::TryRecvError::Empty`] when nothing is queued, and
    /// [`mpsc::error::TryRecvError::Disconnected`] once every outbound
    /// sender has been dropped and the queue is empty.
    pub fn try_recv_outbound(
        &mut self,
    ) -> Result<Vec<u8>, mpsc::error::TryRecvError> {
        self.outbound_rx.try_recv()
    }

    /// Waits for the next outbound message queued by the application.
    ///
    /// Returns `None` once every clone of [`TokioAppChannels::outbound_tx`]
    /// has been dropped and the queue is drained. Prefer this over calling
    /// [`try_recv_outbound`](Self::try_recv_outbound) in a loop: it suspends
    /// the task instead of spinning. It delegates to tokio's
    /// `UnboundedReceiver::recv`, which is cancel-safe, so it can be used as
    /// a `tokio::select!` branch.
    pub async fn recv_outbound(&mut self) -> Option<Vec<u8>> {
        self.outbound_rx.recv().await
    }
}
/// Application-side end of the channels created by [`TokioAppHandler::new`].
#[derive(Debug)]
pub struct TokioAppChannels {
    /// Receives the messages passed to [`TokioAppHandler::send_inbound`].
    pub inbound_rx: mpsc::UnboundedReceiver<Message<'static, Vec<u8>>>,
    /// Queues already-encoded bytes; when the handler is attached to a
    /// connection's event loop they are written to the peer as-is.
    pub outbound_tx: mpsc::UnboundedSender<Vec<u8>>,
}
impl<C, B> TokioFixConnection<B, C>
where
    C: Configure,
    B: Backend,
{
    /// Creates a connection in its initial state. Both sequence-number
    /// counters start at 1 and no session has been established yet.
    pub fn new(config: C, backend: B) -> Self {
        Self {
            config,
            backend,
            encoder: Encoder::new(),
            buffer: Vec::new(),
            msg_seq_num_inbound: MsgSeqNumCounter::new(),
            msg_seq_num_outbound: MsgSeqNumCounter::new(),
        }
    }

    /// Runs a whole session over `stream`. It performs the Logon handshake,
    /// then drives the event loop until the session ends.
    ///
    /// If `app_handler` is given, bytes the application queues on the paired
    /// [`TokioAppChannels::outbound_tx`] are written to the peer verbatim.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - The handshake fails: an I/O error, an undecodable response, or a
    ///   response that is not a Logon.
    /// - An I/O error occurs while the session is running.
    pub async fn start_with_stream(
        &mut self,
        stream: TcpStream,
        mut decoder: DecoderStreaming<Vec<u8>>,
        app_handler: Option<TokioAppHandler>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let (mut reader, mut writer) = stream.into_split();
        self.establish_connection(&mut reader, &mut writer, &mut decoder)
            .await?;
        self.tokio_event_loop(reader, writer, decoder, app_handler).await
    }

    /// Sends our Logon (35=A) and waits for the peer's Logon in response.
    ///
    /// On success the response is counted against the inbound sequence
    /// number, and the backend is told about the outbound and inbound Logon
    /// and about the successful handshake.
    async fn establish_connection(
        &mut self,
        reader: &mut tokio::net::tcp::OwnedReadHalf,
        writer: &mut tokio::net::tcp::OwnedWriteHalf,
        decoder: &mut DecoderStreaming<Vec<u8>>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let logon = self.create_logon_message();
        writer.write_all(&logon).await?;
        self.backend.on_outbound_message(&logon).ok();

        // Feed the decoder exactly the bytes it asks for until it holds one
        // complete message. A decode error is fatal: retrying on a
        // half-parsed stream would otherwise loop forever.
        loop {
            let buffer = decoder.fillable();
            reader.read_exact(buffer).await?;
            match decoder.try_parse() {
                Ok(Some(())) => break,
                Ok(None) => continue,
                Err(_) => {
                    return Err("failed to decode the Logon response".into());
                }
            }
        }

        let logon_response = decoder.message();
        // The peer may answer with e.g. a Logout or Reject. That is not a
        // successful handshake and must not be reported as one.
        if !matches!(logon_response.get::<&[u8]>(35), Ok(b"A")) {
            return Err("expected a Logon (35=A) in response to our Logon".into());
        }
        self.on_logon(logon_response);
        self.backend.on_inbound_message(logon_response, true).ok();
        decoder.clear();
        self.msg_seq_num_inbound.next();
        self.backend.on_successful_handshake().ok();
        Ok(())
    }

    /// Drives the established session until it ends.
    ///
    /// Each iteration races the next low-level session event against the
    /// next outbound message queued by the application, when a handler is
    /// attached. The loop returns `Ok(())` in any of these cases:
    /// - The peer sends Logout (35=5).
    /// - The event loop reports `Logout`.
    /// - The event stream ends.
    ///
    /// It returns an error on I/O failure.
    async fn tokio_event_loop(
        &mut self,
        reader: tokio::net::tcp::OwnedReadHalf,
        mut writer: tokio::net::tcp::OwnedWriteHalf,
        decoder: DecoderStreaming<Vec<u8>>,
        mut app_handler: Option<TokioAppHandler>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut event_loop =
            TokioLlEventLoop::new(decoder, reader, self.heartbeat());
        // The outbound branch is only polled while there is an open queue to
        // read from. Without this guard the branch would resolve immediately
        // on every iteration (no handler, or a closed channel) and spin.
        let mut outbound_open = app_handler.is_some();
        loop {
            tokio::select! {
                event = event_loop.next_event() => {
                    match event {
                        Some(TokioLlEvent::Message(msg)) => {
                            let response = self.on_inbound_message(msg);
                            match response {
                                Response::OutboundBytes(bytes) => {
                                    writer.write_all(bytes).await?;
                                }
                                Response::ResetHeartbeat => {
                                    event_loop.ping_heartbeat();
                                }
                                Response::Application(_app_msg) => {
                                    // TODO: forward to the handler's inbound
                                    // channel once an owned `Message` can be
                                    // built from the borrowed one.
                                    if app_handler.is_some() {
                                        eprintln!("Message handling temporarily disabled due to unsafe code removal");
                                    }
                                }
                                Response::TerminateTransport => {
                                    // The peer logged out: end the session.
                                    return Ok(());
                                }
                                Response::None => {}
                            }
                        }
                        Some(TokioLlEvent::BadMessage(_err)) => {
                            // Malformed input is skipped; the session stays up.
                        }
                        Some(TokioLlEvent::IoError(_)) => {
                            return Err("I/O error in FIX connection".into());
                        }
                        Some(TokioLlEvent::Heartbeat) => {
                            let _ = self.backend.on_heartbeat_is_due();
                            let heartbeat = self.create_heartbeat();
                            writer.write_all(&heartbeat).await?;
                        }
                        Some(TokioLlEvent::TestRequest) => {
                            // NOTE(review): nothing is sent here yet.
                        }
                        Some(TokioLlEvent::Logout) | None => {
                            return Ok(());
                        }
                    }
                }
                // `recv()` parks the task until the application queues
                // something. It is cancel-safe, so losing the race to
                // `next_event()` drops no data.
                outbound = async {
                    match app_handler.as_mut() {
                        Some(handler) => handler.outbound_rx.recv().await,
                        None => None,
                    }
                }, if outbound_open => {
                    match outbound {
                        Some(msg_bytes) => {
                            writer.write_all(&msg_bytes).await?;
                            self.on_outbound_message(&msg_bytes).ok();
                        }
                        // Every sender is gone; stop polling this branch.
                        None => outbound_open = false,
                    }
                }
            }
        }
    }

    /// Encodes our Logon (35=A) into `self.buffer` and returns a copy.
    ///
    /// It consumes one outbound sequence number. EncryptMethod (98) is set
    /// to 0 and HeartBtInt (108) to the configured heartbeat in whole
    /// seconds.
    fn create_logon_message(&mut self) -> Vec<u8> {
        let begin_string = self.config.begin_string();
        let sender_comp_id = self.config.sender_comp_id();
        let target_comp_id = self.config.target_comp_id();
        // Saturate instead of silently truncating an absurdly large interval.
        let heartbeat =
            u32::try_from(self.config.heartbeat().as_secs()).unwrap_or(u32::MAX);
        let msg_seq_num = self.msg_seq_num_outbound.next();
        let sending_time = Self::sending_time();
        self.buffer.clear();
        let mut msg =
            self.encoder.start_message(begin_string, &mut self.buffer, b"A");
        msg.set(49, sender_comp_id);
        msg.set(56, target_comp_id);
        msg.set(52, sending_time.as_str());
        msg.set(34, msg_seq_num);
        msg.set(98, 0);
        msg.set(108, heartbeat);
        let (_, _) = msg.done();
        self.buffer.clone()
    }

    /// Classifies an inbound message by MsgType (35) and tells the event
    /// loop what to do next.
    ///
    /// Every message consumes one inbound sequence number, the same way the
    /// handshake's Logon response does.
    fn on_inbound_message<'a>(
        &'a mut self,
        msg: Message<'a, &'a [u8]>,
    ) -> Response<'a> {
        self.msg_seq_num_inbound.next();
        if let Ok(msg_type) = msg.get::<&[u8]>(35) {
            match msg_type {
                b"A" => {
                    self.on_logon(msg);
                    Response::None
                }
                b"0" => {
                    self.on_heartbeat(msg);
                    Response::ResetHeartbeat
                }
                b"1" => Response::OutboundBytes(self.on_test_request(msg)),
                b"5" => Response::TerminateTransport,
                _ => Response::Application(msg),
            }
        } else {
            Response::Application(msg)
        }
    }

    /// Hook for an inbound Logon; currently a no-op.
    fn on_logon(&mut self, _msg: Message<&[u8]>) {}

    /// Hook for an inbound Heartbeat; currently a no-op.
    fn on_heartbeat(&mut self, _msg: Message<&[u8]>) {}

    /// Builds the Heartbeat (35=0) that answers a TestRequest. It echoes the
    /// request's TestReqID (112) and returns the encoded bytes, which borrow
    /// `self.buffer`.
    fn on_test_request(&mut self, msg: Message<&[u8]>) -> &[u8] {
        // Leave TestReqID out when the request has none. Sending `112=`
        // with an empty value is not valid FIX.
        let test_req_id = msg.get::<&[u8]>(112).ok();
        self.encode_heartbeat(test_req_id)
    }

    /// Hook called after application bytes are written; currently a no-op.
    fn on_outbound_message(
        &mut self,
        _message: &[u8],
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// The configured heartbeat interval.
    fn heartbeat(&self) -> Duration {
        self.config.heartbeat()
    }

    /// Encodes an unsolicited Heartbeat (35=0) and returns an owned copy.
    fn create_heartbeat(&mut self) -> Vec<u8> {
        self.encode_heartbeat(None).to_vec()
    }

    /// Encodes a Heartbeat (35=0) into `self.buffer`, consuming one outbound
    /// sequence number, and returns the encoded bytes.
    ///
    /// When `test_req_id` is `Some`, it is written as TestReqID (112).
    fn encode_heartbeat(&mut self, test_req_id: Option<&[u8]>) -> &[u8] {
        let begin_string = self.config.begin_string();
        let msg_seq_num = self.msg_seq_num_outbound.next();
        let sender_comp_id = self.config.sender_comp_id();
        let target_comp_id = self.config.target_comp_id();
        let sending_time = Self::sending_time();
        self.buffer.clear();
        let mut msg =
            self.encoder.start_message(begin_string, &mut self.buffer, b"0");
        msg.set(49, sender_comp_id);
        msg.set(56, target_comp_id);
        msg.set(34, msg_seq_num);
        msg.set(52, sending_time.as_str());
        if let Some(id) = test_req_id {
            msg.set(112, id);
        }
        let (_, _) = msg.done();
        &self.buffer
    }

    /// The current UTC time in the `YYYYMMDD-HH:MM:SS.sss` layout used for
    /// SendingTime (52).
    fn sending_time() -> String {
        chrono::Utc::now().format("%Y%m%d-%H:%M:%S%.3f").to_string()
    }
}
/// Monotonic FIX MsgSeqNum (34) counter. The first number handed out is 1.
#[derive(Debug)]
pub struct MsgSeqNumCounter(u64);

impl MsgSeqNumCounter {
    /// A fresh counter whose next number is 1.
    pub const fn new() -> Self {
        MsgSeqNumCounter(1)
    }

    /// Hands out the current sequence number and advances by one.
    pub fn next(&mut self) -> u64 {
        let following = self.0 + 1;
        std::mem::replace(&mut self.0, following)
    }

    /// The number the next call to [`next`](Self::next) will return. This
    /// does not advance the counter.
    pub fn expected(&self) -> u64 {
        let Self(upcoming) = self;
        *upcoming
    }
}

impl Default for MsgSeqNumCounter {
    /// Same as [`MsgSeqNumCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Bytes sent on `TokioAppChannels::outbound_tx` must reach the paired
    /// handler in order. After that the queue reports empty, and it reports
    /// disconnected once the sender is gone.
    #[tokio::test]
    async fn tokio_app_handler_channels() {
        let (mut handler, channels) = TokioAppHandler::new();
        assert!(channels.outbound_tx.send(vec![1, 2, 3]).is_ok());
        assert_eq!(handler.try_recv_outbound(), Ok(vec![1, 2, 3]));
        assert_eq!(
            handler.try_recv_outbound(),
            Err(mpsc::error::TryRecvError::Empty)
        );
        drop(channels);
        assert_eq!(
            handler.try_recv_outbound(),
            Err(mpsc::error::TryRecvError::Disconnected)
        );
    }

    /// The sequence counter starts at 1 and `expected` peeks without
    /// advancing.
    #[test]
    fn msg_seq_num_counter_starts_at_one() {
        let mut counter = MsgSeqNumCounter::default();
        assert_eq!(counter.expected(), 1);
        assert_eq!(counter.next(), 1);
        assert_eq!(counter.next(), 2);
        assert_eq!(counter.expected(), 3);
    }
}