mod io;
pub(crate) use io::{AsyncStream, AsyncTcpSocket};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use futures::StreamExt;
use log::{debug, error, info, warn};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::task;
use tokio::time::Duration;
use tokio_stream::wrappers::BroadcastStream;
use crate::connection::r#async::AsyncConnection;
use crate::messages::{shared_channel_configuration, IncomingMessages, Notice, OutgoingMessages, ResponseMessage};
use crate::Error;
use super::common::log_orphan;
use super::routing::{
determine_routing, is_warning_error, order_routing_strategy, DecodedError, OrderRoutingStrategy, RoutingDecision, UNSPECIFIED_REQUEST_ID,
};
use super::RoutedItem;
/// Capacity of every broadcast channel the bus creates. A subscriber that falls more than this
/// many items behind loses the oldest unread items.
pub(crate) const BROADCAST_CHANNEL_CAPACITY: usize = 1_024;
/// Notification sent to the message bus's cleanup task when an [`AsyncInternalSubscription`]
/// created with cleanup is dropped; identifies the routing entry that subscription owned.
#[derive(Debug, Clone)]
pub enum CleanupSignal {
/// Release the request channel registered under this request id.
Request(i32),
/// Release the order channel registered under this order id.
Order(i32),
/// A shared-channel subscription ended. The cleanup task only logs this; the shared
/// channel stays registered for other subscribers.
Shared(OutgoingMessages),
/// Clear the order-update stream so that a new one can be created.
OrderUpdateStream,
}
/// Async transport abstraction: writes encoded requests to the server and hands back
/// subscriptions that receive the responses routed to them.
#[async_trait]
pub trait AsyncMessageBus: Send + Sync {
/// Registers a response channel for `request_id`, writes `message`, and returns a
/// subscription to the responses routed to that id.
async fn send_request(&self, request_id: i32, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error>;
/// Like `send_request`, but the channel is keyed by `order_id`.
async fn send_order_request(&self, order_id: i32, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error>;
/// Writes `message` and returns a subscription to the shared channel configured for
/// `message_type`.
async fn send_shared_request(&self, message_type: OutgoingMessages, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error>;
/// Writes `message` without creating any subscription.
async fn send_message(&self, message: Vec<u8>) -> Result<(), Error>;
/// Writes the cancel `message` and stops routing responses for `request_id`.
#[allow(dead_code)]
async fn cancel_subscription(&self, request_id: i32, message: Vec<u8>) -> Result<(), Error>;
/// Writes the cancel `message` and stops routing responses for `order_id`.
#[allow(dead_code)]
async fn cancel_order_subscription(&self, order_id: i32, message: Vec<u8>) -> Result<(), Error>;
/// Creates the single order-update subscription; a second concurrent one is rejected.
async fn create_order_update_subscription(&self) -> Result<AsyncInternalSubscription, Error>;
/// Returns a stream of notices broadcast by the connection.
fn notice_subscribe(&self) -> crate::subscriptions::notice_stream::async_impl::NoticeStream;
/// Requests shutdown and waits for message processing to stop.
async fn ensure_shutdown(&self);
/// Requests shutdown without awaiting anything; callable from non-async code.
fn request_shutdown_sync(&self);
/// Whether the bus is connected and no shutdown has been requested.
fn is_connected(&self) -> bool;
}
/// Receiving side of a message-bus subscription.
///
/// Yields items routed to one broadcast channel. When created with `with_cleanup`, dropping it
/// sends a [`CleanupSignal`] so the bus can release the routing entry.
pub struct AsyncInternalSubscription {
// Extra receiver used as the source for `resubscribe()` when cloning; in this file it is never
// polled.
template_receiver: broadcast::Receiver<RoutedItem>,
// Receiver that `next()` actually polls.
pub(crate) stream: BroadcastStream<RoutedItem>,
// Channel to the bus's cleanup task; `None` for subscriptions created without cleanup.
cleanup_sender: Option<mpsc::UnboundedSender<CleanupSignal>>,
// Signal to send on drop; `None` when no cleanup is required.
cleanup_signal: Option<CleanupSignal>,
// Set once the signal has been sent so it is never sent twice.
cleanup_sent: bool,
}
impl Clone for AsyncInternalSubscription {
    /// Creates an independent subscription on the same broadcast channel.
    ///
    /// Both receivers of the clone come from `resubscribe()`, so the clone only observes items
    /// sent after it was created. The cleanup sender and signal are shared, and the clone starts
    /// with its own `cleanup_sent == false`.
    ///
    /// NOTE(review): because of that, dropping *any* clone sends the cleanup signal. For
    /// request/order/order-update subscriptions, this makes the bus release the routing entry that
    /// the remaining clones still read from. Confirm this is intended.
    fn clone(&self) -> Self {
        let template_receiver = self.template_receiver.resubscribe();
        let polled_receiver = self.template_receiver.resubscribe();
        Self {
            stream: BroadcastStream::new(polled_receiver),
            template_receiver,
            cleanup_signal: self.cleanup_signal.clone(),
            cleanup_sender: self.cleanup_sender.clone(),
            cleanup_sent: false,
        }
    }
}
impl AsyncInternalSubscription {
#[cfg(test)]
pub(crate) fn new(receiver: broadcast::Receiver<RoutedItem>) -> Self {
let template = receiver.resubscribe();
Self {
template_receiver: template,
stream: BroadcastStream::new(receiver),
cleanup_sender: None,
cleanup_signal: None,
cleanup_sent: false,
}
}
pub(crate) fn with_cleanup(
receiver: broadcast::Receiver<RoutedItem>,
cleanup_sender: mpsc::UnboundedSender<CleanupSignal>,
cleanup_signal: CleanupSignal,
) -> Self {
let template = receiver.resubscribe();
Self {
template_receiver: template,
stream: BroadcastStream::new(receiver),
cleanup_sender: Some(cleanup_sender),
cleanup_signal: Some(cleanup_signal),
cleanup_sent: false,
}
}
pub async fn next(&mut self) -> Option<Result<ResponseMessage, Error>> {
loop {
match self.stream.next().await? {
Ok(item) => {
if let Some(legacy) = item.into_legacy() {
return Some(legacy);
}
}
Err(_lagged) => continue,
}
}
}
#[cfg(test)]
pub(crate) async fn next_routed(&mut self) -> Option<RoutedItem> {
loop {
match self.stream.next().await? {
Ok(item) => return Some(item),
Err(_lagged) => continue,
}
}
}
#[cfg(test)]
pub(crate) fn try_next_routed(&mut self) -> Option<RoutedItem> {
use futures::FutureExt;
loop {
match self.stream.next().now_or_never()? {
Some(Ok(item)) => return Some(item),
Some(Err(_lagged)) => continue,
None => return None,
}
}
}
fn send_cleanup_signal(&mut self) {
if !self.cleanup_sent {
if let (Some(sender), Some(signal)) = (&self.cleanup_sender, &self.cleanup_signal) {
let _ = sender.send(signal.clone());
self.cleanup_sent = true;
}
}
}
}
impl Drop for AsyncInternalSubscription {
    /// Tells the bus that this subscription's routing entry may be released.
    ///
    /// Does nothing when no cleanup sender/signal is configured or the signal was already sent.
    fn drop(&mut self) {
        Self::send_cleanup_signal(self)
    }
}
/// Sending half of the broadcast channels through which the bus delivers routed items.
type BroadcastSender = broadcast::Sender<RoutedItem>;
/// Message bus over an async connection.
///
/// Routes every incoming message to the per-request, per-order, shared, or order-update
/// broadcast channel that should receive it.
pub struct AsyncTcpMessageBus<S: AsyncStream = AsyncTcpSocket> {
// Connection used for reads, writes, reconnects and the notice channel.
connection: Arc<AsyncConnection<S>>,
// Per-request senders keyed by request id.
request_channels: Arc<RwLock<HashMap<i32, BroadcastSender>>>,
// Shared-channel senders keyed by incoming message type; one type may feed several channels.
shared_channel_senders: Arc<RwLock<HashMap<IncomingMessages, Vec<BroadcastSender>>>>,
// One receiver per shared channel, keyed by outgoing request type; used for `resubscribe()`.
shared_channel_receivers: Arc<RwLock<HashMap<OutgoingMessages, broadcast::Receiver<RoutedItem>>>>,
// Per-order senders keyed by order id.
order_channels: Arc<RwLock<HashMap<i32, BroadcastSender>>>,
// Execution id -> sender that received that execution's ExecutionData (for ByExecutionId
// routing).
execution_channels: Arc<RwLock<HashMap<String, BroadcastSender>>>,
// Sender of the single order-update stream, if one is active.
order_update_stream: Arc<RwLock<Option<BroadcastSender>>>,
// Cloned into subscriptions so that dropping them notifies the cleanup task.
cleanup_sender: mpsc::UnboundedSender<CleanupSignal>,
// Join handle of the task spawned by `process_messages`.
process_task: Arc<RwLock<Option<task::JoinHandle<()>>>>,
// Set once shutdown has been requested.
shutdown_requested: Arc<AtomicBool>,
// Wakes the processing loop when shutdown is requested.
shutdown_notify: Arc<Notify>,
// Cleared on connection loss or shutdown; set again after a successful reconnect.
connected: Arc<AtomicBool>,
}
impl<S: AsyncStream> Drop for AsyncTcpMessageBus<S> {
    /// Flags shutdown and wakes any task waiting on the shutdown notification.
    ///
    /// NOTE: the task spawned by `process_messages` holds an `Arc<Self>`, so this only runs
    /// after that task has exited (or if it was never started).
    fn drop(&mut self) {
        debug!("dropping async tcp message bus");
        let flag = &self.shutdown_requested;
        flag.store(true, Ordering::Relaxed);
        self.shutdown_notify.notify_waiters();
    }
}
impl<S: AsyncStream> AsyncTcpMessageBus<S> {
/// Creates a message bus over `connection` and spawns its cleanup task.
///
/// One broadcast channel is created for every mapping in
/// `shared_channel_configuration::CHANNEL_MAPPINGS`. Its receiver is stored under the mapping's
/// outgoing request type, and a sender clone is stored under each incoming response type.
///
/// The cleanup task consumes the `CleanupSignal`s sent by dropped subscriptions. It removes the
/// matching request/order channel, or clears the order-update stream, and runs until every
/// cleanup sender is dropped. Must be called from within a Tokio runtime. Always returns `Ok`.
pub fn new(connection: AsyncConnection<S>) -> Result<Self, Error> {
    let (cleanup_sender, cleanup_receiver) = mpsc::unbounded_channel();

    let mut shared_channel_senders = HashMap::new();
    let mut shared_channel_receivers = HashMap::new();
    for mapping in shared_channel_configuration::CHANNEL_MAPPINGS {
        let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
        shared_channel_receivers.insert(mapping.request, receiver);
        for response_type in mapping.responses {
            shared_channel_senders.entry(*response_type).or_insert_with(Vec::new).push(sender.clone());
        }
    }

    let message_bus = Self {
        connection: Arc::new(connection),
        request_channels: Arc::new(RwLock::new(HashMap::new())),
        shared_channel_senders: Arc::new(RwLock::new(shared_channel_senders)),
        shared_channel_receivers: Arc::new(RwLock::new(shared_channel_receivers)),
        order_channels: Arc::new(RwLock::new(HashMap::new())),
        execution_channels: Arc::new(RwLock::new(HashMap::new())),
        order_update_stream: Arc::new(RwLock::new(None)),
        cleanup_sender,
        process_task: Arc::new(RwLock::new(None)),
        shutdown_requested: Arc::new(AtomicBool::new(false)),
        shutdown_notify: Arc::new(Notify::new()),
        connected: Arc::new(AtomicBool::new(true)),
    };

    let request_channels = message_bus.request_channels.clone();
    let order_channels = message_bus.order_channels.clone();
    let execution_channels = message_bus.execution_channels.clone();
    let order_update_stream = message_bus.order_update_stream.clone();
    task::spawn(async move {
        let mut receiver = cleanup_receiver;
        while let Some(signal) = receiver.recv().await {
            match signal {
                CleanupSignal::Request(request_id) => {
                    request_channels.write().await.remove(&request_id);
                    debug!("Cleaned up request channel for ID: {request_id}");
                    // execution_channels holds sender clones. Without pruning, entries for ended
                    // subscriptions (and their buffered items) would live until a reconnect.
                    execution_channels.write().await.retain(|_, sender| sender.receiver_count() > 0);
                }
                CleanupSignal::Order(order_id) => {
                    order_channels.write().await.remove(&order_id);
                    debug!("Cleaned up order channel for ID: {order_id}");
                    // Same pruning as for request channels.
                    execution_channels.write().await.retain(|_, sender| sender.receiver_count() > 0);
                }
                CleanupSignal::Shared(message_type) => {
                    debug!("Subscription for shared channel {:?} ended (channel remains active)", message_type);
                }
                CleanupSignal::OrderUpdateStream => {
                    *order_update_stream.write().await = None;
                    debug!("Cleaned up order update stream ownership");
                }
            }
        }
    });

    Ok(message_bus)
}
pub fn process_messages(self: Arc<Self>, _server_version: i32, _reconnect_delay: Duration) -> Result<(), Error> {
let message_bus = self.clone();
let shutdown_notify = self.shutdown_notify.clone();
let handle = task::spawn(async move {
loop {
tokio::select! {
_ = shutdown_notify.notified() => {
debug!("Shutdown notification received, stopping message processing");
break;
}
result = message_bus.read_and_route_message() => {
use crate::client::error_handler::{is_connection_error, is_timeout_error};
match result {
Ok(_) => continue,
Err(ref err) if is_timeout_error(err) => {
if message_bus.shutdown_requested.load(Ordering::Relaxed) {
debug!("dispatcher task exiting");
break;
}
continue;
}
Err(ref err) if is_connection_error(err) => {
error!("Connection error detected, attempting to reconnect: {err:?}");
message_bus.connected.store(false, Ordering::Relaxed);
match message_bus.connection.reconnect().await {
Ok(_) => {
info!("Successfully reconnected to TWS/Gateway");
message_bus.connected.store(true, Ordering::Relaxed);
message_bus.reset_channels().await;
}
Err(e) => {
error!("Failed to reconnect to TWS/Gateway: {e:?}");
message_bus.request_shutdown().await;
break;
}
}
continue;
}
Err(Error::Shutdown) => {
error!("Received shutdown signal, stopping message processing.");
break;
}
Err(err) => {
error!("Error processing message (shutting down): {err:?}");
message_bus.request_shutdown().await;
break;
}
}
}
}
}
});
let process_task = self.process_task.clone();
tokio::spawn(async move {
let mut task_guard = process_task.write().await;
*task_guard = Some(handle);
});
Ok(())
}
/// Reads one message from the connection and dispatches it according to `determine_routing`.
///
/// Read errors are propagated. A server shutdown message triggers `request_shutdown` and yields
/// `Err(Error::Shutdown)`.
pub(crate) async fn read_and_route_message(&self) -> Result<(), Error> {
    let message = self.connection.read_message().await?;
    let decision = determine_routing(&message);
    match decision {
        RoutingDecision::ByRequestId(request_id) => self.route_to_request_channel(request_id, message).await,
        RoutingDecision::ByOrderId(order_id) => self.route_to_order_channel(order_id, message).await,
        RoutingDecision::ByMessageType(message_type) | RoutingDecision::SharedMessage(message_type) => {
            self.route_to_shared_channel(message_type, message).await
        }
        RoutingDecision::Error(payload) => self.route_error_message(message, payload).await,
        RoutingDecision::Shutdown => {
            debug!("Received shutdown message, calling request_shutdown");
            self.request_shutdown().await;
            Err(Error::Shutdown)
        }
    }
}
/// Fails every in-flight request and order subscription; used after a successful reconnect.
///
/// Each registered request/order channel receives `Error::ConnectionReset` and is then removed
/// from its map. All execution-id mappings are cleared as well. Shared channels and the
/// order-update stream are left untouched.
async fn reset_channels(&self) {
    debug!("resetting message bus channels");
    // Notify and remove under one write lock per map. This way a channel registered concurrently
    // can never be cleared without first receiving the reset error.
    for map in [&self.request_channels, &self.order_channels] {
        let mut channels = map.write().await;
        for (_, sender) in channels.drain() {
            let _ = sender.send(Error::ConnectionReset.into());
        }
    }
    self.execution_channels.write().await.clear();
}
/// Marks the bus as disconnected and shutting down, wakes the processing loop, and drops every
/// routing entry.
///
/// Dropping the senders closes their broadcast channels, so pending subscriptions end instead of
/// waiting forever.
async fn request_shutdown(&self) {
    debug!("shutdown requested");
    self.connected.store(false, Ordering::Relaxed);
    self.shutdown_requested.store(true, Ordering::Relaxed);
    self.shutdown_notify.notify_waiters();

    self.request_channels.write().await.clear();
    self.order_channels.write().await.clear();
    // execution_channels holds clones of request/order senders. A broadcast channel only closes
    // once every sender is gone, so these must be dropped too.
    self.execution_channels.write().await.clear();
    self.shared_channel_senders.write().await.clear();
    self.shared_channel_receivers.write().await.clear();
    *self.order_update_stream.write().await = None;
}
/// Routes an error message whose decoded form is `payload`.
///
/// The raw message is first offered to the order-update stream. An error without a request id
/// (`UNSPECIFIED_REQUEST_ID`) is logged and sent as a `Notice` on the connection's notice channel.
/// Otherwise it is delivered to the owning subscription:
/// - as `RoutedItem::Notice` when `is_warning_error` classifies the code as a warning;
/// - as `RoutedItem::Error` otherwise.
///
/// Always returns `Ok(())`.
async fn route_error_message(&self, message: ResponseMessage, payload: DecodedError) -> Result<(), Error> {
    let sent_to_update_stream = self.send_order_update(&message).await;
    let request_id = payload.request_id;

    if request_id == UNSPECIFIED_REQUEST_ID {
        let notice = Notice::from(payload);
        super::common::log_unrouted_notice(&notice);
        // The send result is ignored; presumably it only fails when nobody subscribed to notices.
        let _ = self.connection.notice_sender.send(notice);
        return Ok(());
    }

    let item = if is_warning_error(payload.error_code) {
        RoutedItem::Notice(Notice::from(payload))
    } else {
        RoutedItem::Error(Error::from(payload))
    };
    self.deliver_to_request_id(request_id, item, sent_to_update_stream).await;
    Ok(())
}
/// Delivers an error/notice `item` to the subscription owning `request_id`.
///
/// Request channels are searched first, then order channels, since the id may refer to an
/// order. If neither knows the id, the item is logged as an orphan, unless it already went to
/// the order-update stream.
async fn deliver_to_request_id(&self, request_id: i32, item: RoutedItem, sent_to_update_stream: bool) {
    for map in [&self.request_channels, &self.order_channels] {
        if let Some(sender) = map.read().await.get(&request_id) {
            let _ = sender.send(item);
            return;
        }
    }
    if !sent_to_update_stream {
        log_orphan(request_id, &item);
    }
}
/// Forwards `message` to the subscription registered under `request_id`.
///
/// Messages for unknown ids, and failed sends, are dropped silently. Always returns `Ok(())`.
async fn route_to_request_channel(&self, request_id: i32, message: ResponseMessage) -> Result<(), Error> {
    let channels = self.request_channels.read().await;
    let sender = match channels.get(&request_id) {
        Some(sender) => sender,
        None => return Ok(()),
    };
    let _ = sender.send(message.into());
    Ok(())
}
/// Routes a message whose routing decision was `ByOrderId(order_id)`.
///
/// The message is always offered to the order-update stream first. It is then dispatched by
/// `order_routing_strategy` for its message type:
///
/// * `ExecutionData`: to the order channel for `message.order_id()`, else the request channel
///   for `message.request_id()`. The chosen sender is also recorded under the execution id, so
///   later `ByExecutionId` messages reach the same subscription.
/// * `ExecutionDataEnd`: same lookup order, without recording a mapping.
/// * `OrderOrShared`: if the message has an order id, to that order channel, or else to every
///   shared channel registered for the message type.
/// * `ByExecutionId`: to the channel recorded for `message.execution_id()`.
/// * `SharedOnly`: to every shared channel registered for the message type.
/// * `ByOrderId`: to the order channel for `order_id` (skipped when negative).
///
/// A warning is logged when the message reached neither a subscription nor the order-update
/// stream. Misses under `ByExecutionId` are not logged. Always returns `Ok(())`.
async fn route_to_order_channel(&self, order_id: i32, message: ResponseMessage) -> Result<(), Error> {
    let routed = self.send_order_update(&message).await;
    match order_routing_strategy(message.message_type()) {
        OrderRoutingStrategy::ExecutionData => {
            if let Some(actual_order_id) = message.order_id() {
                let channels = self.order_channels.read().await;
                if let Some(sender) = channels.get(&actual_order_id) {
                    self.store_execution_mapping(&message, sender).await;
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
            if let Some(req_id) = message.request_id() {
                let channels = self.request_channels.read().await;
                if let Some(sender) = channels.get(&req_id) {
                    self.store_execution_mapping(&message, sender).await;
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
            if !routed {
                warn!("could not route ExecutionData message {:?}", message);
            }
        }
        OrderRoutingStrategy::ExecutionDataEnd => {
            if let Some(actual_order_id) = message.order_id() {
                let channels = self.order_channels.read().await;
                if let Some(sender) = channels.get(&actual_order_id) {
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
            if let Some(req_id) = message.request_id() {
                let channels = self.request_channels.read().await;
                if let Some(sender) = channels.get(&req_id) {
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
            // Like the other arms, only warn when the message reached no consumer at all.
            if !routed {
                warn!("could not route ExecutionDataEnd message {:?}", message);
            }
        }
        OrderRoutingStrategy::OrderOrShared => {
            if let Some(actual_order_id) = message.order_id() {
                let channels = self.order_channels.read().await;
                if let Some(sender) = channels.get(&actual_order_id) {
                    let _ = sender.send(message.into());
                    return Ok(());
                }
                drop(channels);
                let shared_channels = self.shared_channel_senders.read().await;
                if let Some(senders) = shared_channels.get(&message.message_type()) {
                    for sender in senders {
                        let _ = sender.send(message.clone().into());
                    }
                    return Ok(());
                }
            }
            if !routed {
                warn!("could not route message {:?}", message);
            }
        }
        OrderRoutingStrategy::ByExecutionId => {
            if let Some(execution_id) = message.execution_id() {
                let exec_channels = self.execution_channels.read().await;
                if let Some(sender) = exec_channels.get(&execution_id) {
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
        }
        OrderRoutingStrategy::SharedOnly => {
            let shared_channels = self.shared_channel_senders.read().await;
            if let Some(senders) = shared_channels.get(&message.message_type()) {
                for sender in senders {
                    let _ = sender.send(message.clone().into());
                }
                return Ok(());
            }
            if !routed {
                warn!("could not route message {:?}", message);
            }
        }
        OrderRoutingStrategy::ByOrderId => {
            if order_id >= 0 {
                let channels = self.order_channels.read().await;
                if let Some(sender) = channels.get(&order_id) {
                    let _ = sender.send(message.into());
                    return Ok(());
                }
            }
            if !routed {
                warn!("could not route message {:?}", message);
            }
        }
    }
    Ok(())
}
/// Records `sender` as the destination for later messages carrying the same execution id.
///
/// Does nothing when `message` has no execution id. An existing mapping for the id is replaced.
async fn store_execution_mapping(&self, message: &ResponseMessage, sender: &BroadcastSender) {
    let execution_id = match message.execution_id() {
        Some(id) => id,
        None => return,
    };
    let mut by_execution = self.execution_channels.write().await;
    by_execution.insert(execution_id, sender.clone());
}
/// Broadcasts `message` to every shared channel registered for `message_type`.
///
/// Order-related types (OpenOrder, OrderStatus, ExecutionData, CommissionsReport,
/// CompletedOrder) are first offered to the order-update stream. Failed sends are logged as
/// warnings. Always returns `Ok(())`.
async fn route_to_shared_channel(&self, message_type: IncomingMessages, message: ResponseMessage) -> Result<(), Error> {
    let feeds_order_updates = matches!(
        message_type,
        IncomingMessages::OpenOrder
            | IncomingMessages::OrderStatus
            | IncomingMessages::ExecutionData
            | IncomingMessages::CommissionsReport
            | IncomingMessages::CompletedOrder
    );
    if feeds_order_updates {
        self.send_order_update(&message).await;
    }

    let channels = self.shared_channel_senders.read().await;
    for sender in channels.get(&message_type).into_iter().flatten() {
        if let Err(e) = sender.send(message.clone().into()) {
            warn!("error sending to shared channel for {message_type:?}: {e}");
        }
    }
    Ok(())
}
/// Forwards a copy of `message` to the order-update stream, if one is active.
///
/// Returns `true` only when the stream exists and the send succeeded. A failed send is logged.
async fn send_order_update(&self, message: &ResponseMessage) -> bool {
    let guard = self.order_update_stream.read().await;
    match guard.as_ref() {
        None => false,
        Some(sender) => match sender.send(message.clone().into()) {
            Ok(_) => true,
            Err(e) => {
                warn!("error sending to order update stream: {e}");
                false
            }
        },
    }
}
}
#[async_trait]
impl<S: AsyncStream> AsyncMessageBus for AsyncTcpMessageBus<S> {
/// Registers a channel for `request_id`, writes `message`, and returns a subscription to it.
///
/// The channel is registered before writing so that an early response cannot arrive while no
/// channel exists. If the write fails, the registration is undone and the error is returned.
async fn send_request(&self, request_id: i32, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error> {
    let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
    self.request_channels.write().await.insert(request_id, sender);

    if let Err(err) = self.connection.write_message(&message).await {
        // No subscription is handed out, so no cleanup signal will ever arrive for this id.
        self.request_channels.write().await.remove(&request_id);
        return Err(err);
    }

    Ok(AsyncInternalSubscription::with_cleanup(
        receiver,
        self.cleanup_sender.clone(),
        CleanupSignal::Request(request_id),
    ))
}
/// Registers a channel for `order_id`, writes `message`, and returns a subscription to it.
///
/// The channel is registered before writing so that an early response cannot arrive while no
/// channel exists. If the write fails, the registration is undone and the error is returned.
async fn send_order_request(&self, order_id: i32, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error> {
    let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
    self.order_channels.write().await.insert(order_id, sender);

    if let Err(err) = self.connection.write_message(&message).await {
        // No subscription is handed out, so no cleanup signal will ever arrive for this id.
        self.order_channels.write().await.remove(&order_id);
        return Err(err);
    }

    Ok(AsyncInternalSubscription::with_cleanup(
        receiver,
        self.cleanup_sender.clone(),
        CleanupSignal::Order(order_id),
    ))
}
/// Subscribes to the shared channel for `message_type`, then writes `message`.
///
/// The subscription only sees items sent after it was created. Returns
/// `Error::InvalidArgument` if no shared channel is configured for `message_type`.
async fn send_shared_request(&self, message_type: OutgoingMessages, message: Vec<u8>) -> Result<AsyncInternalSubscription, Error> {
    let receiver = match self.shared_channel_receivers.read().await.get(&message_type) {
        Some(template) => template.resubscribe(),
        None => {
            return Err(Error::InvalidArgument(format!(
                "No shared channel configured for message type: {:?}",
                message_type
            )));
        }
    };

    self.connection.write_message(&message).await?;

    let subscription = AsyncInternalSubscription::with_cleanup(
        receiver,
        self.cleanup_sender.clone(),
        CleanupSignal::Shared(message_type),
    );
    Ok(subscription)
}
/// Writes `message` to the connection without registering any subscription.
async fn send_message(&self, message: Vec<u8>) -> Result<(), Error> {
    let connection = &self.connection;
    connection.write_message(&message).await
}
/// Writes the cancel `message`, then ends the subscription for `request_id`.
///
/// Subscribers receive `Error::Cancelled` before the channel is unregistered. A failed write is
/// returned and leaves the channel in place.
async fn cancel_subscription(&self, request_id: i32, message: Vec<u8>) -> Result<(), Error> {
    self.connection.write_message(&message).await?;
    let mut channels = self.request_channels.write().await;
    if let Some(sender) = channels.remove(&request_id) {
        let _ = sender.send(Error::Cancelled.into());
    }
    Ok(())
}
/// Writes the cancel `message`, then ends the subscription for `order_id`.
///
/// Subscribers receive `Error::Cancelled` before the channel is unregistered. A failed write is
/// returned and leaves the channel in place.
async fn cancel_order_subscription(&self, order_id: i32, message: Vec<u8>) -> Result<(), Error> {
    self.connection.write_message(&message).await?;
    let mut channels = self.order_channels.write().await;
    if let Some(sender) = channels.remove(&order_id) {
        let _ = sender.send(Error::Cancelled.into());
    }
    Ok(())
}
/// Creates the order-update stream and returns its subscription.
///
/// Returns `Error::AlreadySubscribed` while a stream is registered. The stream is cleared when
/// the cleanup task handles the `OrderUpdateStream` signal sent by the dropped subscription.
///
/// NOTE(review): that cleanup is asynchronous. Immediately after dropping the old subscription,
/// this may still report `AlreadySubscribed`. A stream created before the old signal is
/// processed could also be cleared by it. Verify callers tolerate this.
async fn create_order_update_subscription(&self) -> Result<AsyncInternalSubscription, Error> {
    let mut slot = self.order_update_stream.write().await;
    if slot.is_some() {
        return Err(Error::AlreadySubscribed);
    }

    let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
    *slot = Some(sender);

    let subscription = AsyncInternalSubscription::with_cleanup(
        receiver,
        self.cleanup_sender.clone(),
        CleanupSignal::OrderUpdateStream,
    );
    Ok(subscription)
}
/// Returns a stream over a fresh subscription to the connection's notice channel.
fn notice_subscribe(&self) -> crate::subscriptions::notice_stream::async_impl::NoticeStream {
    use crate::subscriptions::notice_stream::async_impl::NoticeStream;

    let receiver = self.connection.notice_sender.subscribe();
    NoticeStream::new(receiver)
}
async fn ensure_shutdown(&self) {
debug!("ensure_shutdown called");
self.request_shutdown().await;
let task_handle = {
let mut task_guard = self.process_task.write().await;
task_guard.take()
};
if let Some(handle) = task_handle {
debug!("Waiting for processing task to finish");
if let Err(e) = handle.await {
warn!("Error joining processing task: {e}");
}
debug!("Processing task finished");
}
}
/// Non-async shutdown request: updates the flags and wakes the processing loop.
///
/// Unlike the async `request_shutdown`, it leaves the routing maps untouched.
fn request_shutdown_sync(&self) {
    debug!("sync shutdown requested");
    let (connected, requested) = (&self.connected, &self.shutdown_requested);
    connected.store(false, Ordering::Relaxed);
    requested.store(true, Ordering::Relaxed);
    self.shutdown_notify.notify_waiters();
}
/// Reports `true` while the connection is up and no shutdown has been requested.
fn is_connected(&self) -> bool {
    let link_up = self.connected.load(Ordering::Relaxed);
    link_up && !self.shutdown_requested.load(Ordering::Relaxed)
}
}
#[cfg(test)]
mod memory;
#[cfg(test)]
pub(crate) use memory::MemoryStream;
#[cfg(test)]
pub(crate) mod test_listener;
#[cfg(test)]
#[path = "async_tests.rs"]
mod tests;