use std::collections::HashMap;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use log::{debug, error, info, warn};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::task;
use tokio::time::Duration;
/// Capacity of the request, order, shared and order-update broadcast channels created by
/// the bus. A receiver that falls further behind than this observes `RecvError::Lagged`.
const BROADCAST_CHANNEL_CAPACITY: usize = 1 << 10;
/// Message sent by a dropped subscription telling the bus which resource to release.
#[derive(Clone, Debug)]
pub enum CleanupSignal {
    /// Remove the request channel registered under this request id.
    Request(i32),
    /// Remove the order channel registered under this order id.
    Order(i32),
    /// A shared-channel subscription ended; the shared channel itself stays registered.
    Shared(OutgoingMessages),
    /// Release the sender of the order update stream so a new subscription can be created.
    OrderUpdateStream,
}
use crate::connection::r#async::AsyncConnection;
use crate::messages::{shared_channel_configuration, IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage};
use crate::Error;
use super::routing::{determine_routing, is_warning_error, RoutingDecision, UNSPECIFIED_REQUEST_ID};
#[async_trait]
pub trait AsyncMessageBus: Send + Sync {
async fn send_request(&self, request_id: i32, message: RequestMessage) -> Result<AsyncInternalSubscription, Error>;
async fn send_order_request(&self, order_id: i32, message: RequestMessage) -> Result<AsyncInternalSubscription, Error>;
async fn send_shared_request(&self, message_type: OutgoingMessages, message: RequestMessage) -> Result<AsyncInternalSubscription, Error>;
async fn send_message(&self, message: RequestMessage) -> Result<(), Error>;
#[allow(dead_code)]
async fn cancel_subscription(&self, request_id: i32, message: RequestMessage) -> Result<(), Error>;
#[allow(dead_code)]
async fn cancel_order_subscription(&self, order_id: i32, message: RequestMessage) -> Result<(), Error>;
async fn create_order_update_subscription(&self) -> Result<AsyncInternalSubscription, Error>;
async fn ensure_shutdown(&self);
fn request_shutdown_sync(&self);
fn is_connected(&self) -> bool;
#[cfg(test)]
fn request_messages(&self) -> Vec<RequestMessage> {
vec![]
}
}
/// A receiver for messages routed by the bus, optionally paired with a cleanup signal.
///
/// When built with `with_cleanup`, dropping the value sends its [`CleanupSignal`] once on the
/// cleanup channel, unless `take_receiver` consumed it first.
pub struct AsyncInternalSubscription {
    /// Broadcast receiver that messages are read from.
    pub(crate) receiver: broadcast::Receiver<ResponseMessage>,
    /// Channel the cleanup signal is sent on; `None` when no cleanup is wanted.
    cleanup_sender: Option<mpsc::UnboundedSender<CleanupSignal>>,
    /// Identifies which bus resource to release on drop.
    cleanup_signal: Option<CleanupSignal>,
    /// Set once the signal was sent (or deliberately suppressed) so it goes out at most once.
    cleanup_sent: bool,
}
impl Clone for AsyncInternalSubscription {
    /// Creates another receiver on the same broadcast channel.
    ///
    /// The clone starts at the channel's current tail (`resubscribe`), so it only sees messages
    /// sent after this call. It carries its own copy of the cleanup sender and signal with
    /// `cleanup_sent` reset. Dropping *either* the clone or the original therefore sends the
    /// cleanup signal, which can remove the channel for both.
    fn clone(&self) -> Self {
        let receiver = self.receiver.resubscribe();
        let cleanup_sender = self.cleanup_sender.clone();
        let cleanup_signal = self.cleanup_signal.clone();
        Self {
            receiver,
            cleanup_sender,
            cleanup_signal,
            cleanup_sent: false,
        }
    }
}
impl AsyncInternalSubscription {
pub fn new(receiver: broadcast::Receiver<ResponseMessage>) -> Self {
Self {
receiver,
cleanup_sender: None,
cleanup_signal: None,
cleanup_sent: false,
}
}
pub fn with_cleanup(
receiver: broadcast::Receiver<ResponseMessage>,
cleanup_sender: mpsc::UnboundedSender<CleanupSignal>,
cleanup_signal: CleanupSignal,
) -> Self {
Self {
receiver,
cleanup_sender: Some(cleanup_sender),
cleanup_signal: Some(cleanup_signal),
cleanup_sent: false,
}
}
pub async fn next(&mut self) -> Option<Result<ResponseMessage, Error>> {
loop {
match self.receiver.recv().await {
Ok(msg) => return Some(Ok(msg)),
Err(broadcast::error::RecvError::Closed) => return None,
Err(broadcast::error::RecvError::Lagged(_)) => {
continue;
}
}
}
}
pub fn take_receiver(mut self) -> broadcast::Receiver<ResponseMessage> {
self.cleanup_sender = None;
self.cleanup_signal = None;
self.cleanup_sent = true;
let (dummy_sender, dummy_receiver) = broadcast::channel(1);
drop(dummy_sender); mem::replace(&mut self.receiver, dummy_receiver)
}
fn send_cleanup_signal(&mut self) {
if !self.cleanup_sent {
if let (Some(sender), Some(signal)) = (&self.cleanup_sender, &self.cleanup_signal) {
let _ = sender.send(signal.clone());
self.cleanup_sent = true;
}
}
}
}
impl Drop for AsyncInternalSubscription {
    /// Emits the subscription's cleanup signal (at most once) so the bus can release the
    /// associated channel or stream ownership.
    fn drop(&mut self) {
        Self::send_cleanup_signal(self)
    }
}
/// Sending half of a broadcast channel carrying routed [`ResponseMessage`]s.
type BroadcastSender = broadcast::Sender<ResponseMessage>;
/// [`AsyncMessageBus`] implementation on top of an [`AsyncConnection`].
///
/// Incoming messages are routed to per-request, per-order, per-execution or shared broadcast
/// channels. Order-related messages are additionally mirrored to the order update stream
/// when one is subscribed.
pub struct AsyncTcpMessageBus {
    /// Connection used for reading, writing and reconnecting.
    connection: Arc<AsyncConnection>,
    /// Senders of subscriptions keyed by request id.
    request_channels: Arc<RwLock<HashMap<i32, BroadcastSender>>>,
    /// Senders of shared channels keyed by incoming message type; one type may feed several channels.
    shared_channel_senders: Arc<RwLock<HashMap<IncomingMessages, Vec<BroadcastSender>>>>,
    /// Template receivers of shared channels keyed by outgoing request type; subscribers `resubscribe` from them.
    shared_channel_receivers: Arc<RwLock<HashMap<OutgoingMessages, broadcast::Receiver<ResponseMessage>>>>,
    /// Senders of subscriptions keyed by order id.
    order_channels: Arc<RwLock<HashMap<i32, BroadcastSender>>>,
    /// Execution id -> sender that received that execution, so its commission report follows it.
    execution_channels: Arc<RwLock<HashMap<String, BroadcastSender>>>,
    /// Sender of the order update stream, if a subscription exists.
    order_update_stream: Arc<RwLock<Option<BroadcastSender>>>,
    /// Cloned into subscriptions so they can request cleanup when dropped.
    cleanup_sender: mpsc::UnboundedSender<CleanupSignal>,
    /// Join handle of the message processing task, once it has been started.
    process_task: Arc<RwLock<Option<task::JoinHandle<()>>>>,
    /// Set once shutdown has been requested.
    shutdown_requested: Arc<AtomicBool>,
    /// Wakes the processing task when shutdown is requested.
    shutdown_notify: Arc<Notify>,
    /// Cleared on connection loss or shutdown; set again after a successful reconnect.
    connected: Arc<AtomicBool>,
}
impl Drop for AsyncTcpMessageBus {
    /// Flags shutdown and wakes any task currently waiting on the shutdown notification.
    ///
    /// Only signals: channels are not cleared and no task is joined here. `process_messages`
    /// moves an `Arc<Self>` into the processing task, so this runs only after that task has
    /// ended, or if it was never started.
    fn drop(&mut self) {
        debug!("dropping async tcp message bus");
        let Self {
            shutdown_requested,
            shutdown_notify,
            ..
        } = self;
        shutdown_requested.store(true, Ordering::Relaxed);
        shutdown_notify.notify_waiters();
    }
}
impl AsyncTcpMessageBus {
pub fn new(connection: AsyncConnection) -> Result<Self, Error> {
let (cleanup_sender, cleanup_receiver) = mpsc::unbounded_channel();
let mut shared_channel_senders = HashMap::new();
let mut shared_channel_receivers = HashMap::new();
for mapping in shared_channel_configuration::CHANNEL_MAPPINGS {
let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
shared_channel_receivers.insert(mapping.request, receiver);
for response_type in mapping.responses {
shared_channel_senders.entry(*response_type).or_insert_with(Vec::new).push(sender.clone());
}
}
let message_bus = Self {
connection: Arc::new(connection),
request_channels: Arc::new(RwLock::new(HashMap::new())),
shared_channel_senders: Arc::new(RwLock::new(shared_channel_senders)),
shared_channel_receivers: Arc::new(RwLock::new(shared_channel_receivers)),
order_channels: Arc::new(RwLock::new(HashMap::new())),
execution_channels: Arc::new(RwLock::new(HashMap::new())),
order_update_stream: Arc::new(RwLock::new(None)),
cleanup_sender,
process_task: Arc::new(RwLock::new(None)),
shutdown_requested: Arc::new(AtomicBool::new(false)),
shutdown_notify: Arc::new(Notify::new()),
connected: Arc::new(AtomicBool::new(true)),
};
let request_channels = message_bus.request_channels.clone();
let order_channels = message_bus.order_channels.clone();
let order_update_stream = message_bus.order_update_stream.clone();
task::spawn(async move {
let mut receiver = cleanup_receiver;
while let Some(signal) = receiver.recv().await {
match signal {
CleanupSignal::Request(request_id) => {
let mut channels = request_channels.write().await;
channels.remove(&request_id);
debug!("Cleaned up request channel for ID: {request_id}");
}
CleanupSignal::Order(order_id) => {
let mut channels = order_channels.write().await;
channels.remove(&order_id);
debug!("Cleaned up order channel for ID: {order_id}");
}
CleanupSignal::Shared(message_type) => {
debug!("Subscription for shared channel {:?} ended (channel remains active)", message_type);
}
CleanupSignal::OrderUpdateStream => {
let mut stream = order_update_stream.write().await;
*stream = None;
debug!("Cleaned up order update stream ownership");
}
}
}
});
Ok(message_bus)
}
pub fn process_messages(self: Arc<Self>, _server_version: i32, _reconnect_delay: Duration) -> Result<(), Error> {
let message_bus = self.clone();
let shutdown_notify = self.shutdown_notify.clone();
let handle = task::spawn(async move {
loop {
tokio::select! {
_ = shutdown_notify.notified() => {
debug!("Shutdown notification received, stopping message processing");
break;
}
result = message_bus.read_and_route_message() => {
use crate::client::error_handler::{is_connection_error, is_timeout_error};
match result {
Ok(_) => continue,
Err(ref err) if is_timeout_error(err) => {
if message_bus.shutdown_requested.load(Ordering::Relaxed) {
debug!("dispatcher task exiting");
break;
}
continue;
}
Err(ref err) if is_connection_error(err) => {
error!("Connection error detected, attempting to reconnect: {err:?}");
message_bus.connected.store(false, Ordering::Relaxed);
match message_bus.connection.reconnect().await {
Ok(_) => {
info!("Successfully reconnected to TWS/Gateway");
message_bus.connected.store(true, Ordering::Relaxed);
message_bus.reset_channels().await;
}
Err(e) => {
error!("Failed to reconnect to TWS/Gateway: {e:?}");
message_bus.request_shutdown().await;
break;
}
}
continue;
}
Err(Error::Shutdown) => {
error!("Received shutdown signal, stopping message processing.");
break;
}
Err(err) => {
error!("Error processing message (shutting down): {err:?}");
message_bus.request_shutdown().await;
break;
}
}
}
}
}
});
let process_task = self.process_task.clone();
tokio::spawn(async move {
let mut task_guard = process_task.write().await;
*task_guard = Some(handle);
});
Ok(())
}
async fn read_and_route_message(&self) -> Result<(), Error> {
let message = self.connection.read_message().await?;
match determine_routing(&message) {
RoutingDecision::ByRequestId(request_id) => self.route_to_request_channel(request_id, message).await,
RoutingDecision::ByOrderId(order_id) => self.route_to_order_channel(order_id, message).await,
RoutingDecision::ByMessageType(message_type) => self.route_to_shared_channel(message_type, message).await,
RoutingDecision::SharedMessage(message_type) => self.route_to_shared_channel(message_type, message).await,
RoutingDecision::Error { request_id, error_code } => self.route_error_message(message, request_id, error_code).await,
RoutingDecision::Shutdown => {
debug!("Received shutdown message, calling request_shutdown");
self.request_shutdown().await;
Err(Error::Shutdown)
}
}
}
async fn reset_channels(&self) {
debug!("resetting message bus channels");
{
let channels = self.request_channels.read().await;
for (_, sender) in channels.iter() {
let error_msg = ResponseMessage::from("ConnectionReset");
let _ = sender.send(error_msg);
}
}
{
let channels = self.order_channels.read().await;
for (_, sender) in channels.iter() {
let error_msg = ResponseMessage::from("ConnectionReset");
let _ = sender.send(error_msg);
}
}
{
let mut channels = self.request_channels.write().await;
channels.clear();
}
{
let mut channels = self.order_channels.write().await;
channels.clear();
}
{
let mut channels = self.execution_channels.write().await;
channels.clear();
}
}
async fn request_shutdown(&self) {
debug!("shutdown requested");
self.connected.store(false, Ordering::Relaxed);
self.shutdown_requested.store(true, Ordering::Relaxed);
self.shutdown_notify.notify_waiters();
{
let mut channels = self.request_channels.write().await;
channels.clear();
}
{
let mut channels = self.order_channels.write().await;
channels.clear();
}
{
let mut channels = self.shared_channel_senders.write().await;
channels.clear();
}
{
let mut channels = self.shared_channel_receivers.write().await;
channels.clear();
}
{
let mut order_update_stream = self.order_update_stream.write().await;
*order_update_stream = None;
}
}
async fn route_error_message(&self, message: ResponseMessage, request_id: i32, error_code: i32) -> Result<(), Error> {
let _ = self.send_order_update(&message).await;
let error_msg = message.error_message();
if request_id == UNSPECIFIED_REQUEST_ID || is_warning_error(error_code) {
if is_warning_error(error_code) {
warn!("Warning - Request ID: {request_id}, Code: {error_code}, Message: {error_msg}");
} else {
error!("Error - Request ID: {request_id}, Code: {error_code}, Message: {error_msg}");
}
} else {
info!("Error message - Request ID: {request_id}, Code: {error_code}, Message: {error_msg}");
let sent_to_update_stream = if message.order_id().is_some() {
self.send_order_update(&message).await
} else {
false
};
let channels = self.request_channels.read().await;
if let Some(sender) = channels.get(&request_id) {
let _ = sender.send(message);
} else {
let order_channels = self.order_channels.read().await;
if let Some(sender) = order_channels.get(&request_id) {
let _ = sender.send(message.clone());
} else if !sent_to_update_stream && message.order_id().is_some() {
info!("order error message has no recipient: {:?}", message);
}
}
}
Ok(())
}
async fn route_to_request_channel(&self, request_id: i32, message: ResponseMessage) -> Result<(), Error> {
let channels = self.request_channels.read().await;
if let Some(sender) = channels.get(&request_id) {
let _ = sender.send(message);
}
Ok(())
}
async fn route_to_order_channel(&self, order_id: i32, message: ResponseMessage) -> Result<(), Error> {
let routed = self.send_order_update(&message).await;
let message_type = message.message_type();
match message_type {
IncomingMessages::ExecutionData => {
let order_id = message.order_id();
let request_id = message.request_id();
if let Some(actual_order_id) = order_id {
let channels = self.order_channels.read().await;
if let Some(sender) = channels.get(&actual_order_id) {
if let Some(execution_id) = message.execution_id() {
let mut exec_channels = self.execution_channels.write().await;
exec_channels.insert(execution_id, sender.clone());
}
let _ = sender.send(message);
return Ok(());
}
}
if let Some(req_id) = request_id {
let channels = self.request_channels.read().await;
if let Some(sender) = channels.get(&req_id) {
if let Some(execution_id) = message.execution_id() {
let mut exec_channels = self.execution_channels.write().await;
exec_channels.insert(execution_id, sender.clone());
}
let _ = sender.send(message);
return Ok(());
}
}
if !routed {
warn!("could not route ExecutionData message {:?}", message);
}
}
IncomingMessages::CommissionsReport => {
if let Some(execution_id) = message.execution_id() {
let exec_channels = self.execution_channels.read().await;
if let Some(sender) = exec_channels.get(&execution_id) {
let _ = sender.send(message);
return Ok(());
}
}
}
IncomingMessages::OpenOrder | IncomingMessages::OrderStatus => {
if let Some(actual_order_id) = message.order_id() {
let channels = self.order_channels.read().await;
if let Some(sender) = channels.get(&actual_order_id) {
let _ = sender.send(message);
return Ok(());
}
let shared_channels = self.shared_channel_senders.read().await;
if let Some(senders) = shared_channels.get(&message_type) {
for sender in senders {
if let Err(e) = sender.send(message.clone()) {
warn!("error sending to shared channel for {message_type:?}: {e}");
}
}
return Ok(());
}
}
if !routed {
warn!("could not route message {:?}", message);
}
}
IncomingMessages::ExecutionDataEnd => {
let order_id = message.order_id();
let request_id = message.request_id();
if let Some(actual_order_id) = order_id {
let channels = self.order_channels.read().await;
if let Some(sender) = channels.get(&actual_order_id) {
let _ = sender.send(message);
return Ok(());
}
}
if let Some(req_id) = request_id {
let channels = self.request_channels.read().await;
if let Some(sender) = channels.get(&req_id) {
let _ = sender.send(message);
return Ok(());
}
}
warn!("could not route ExecutionDataEnd message {:?}", message);
}
IncomingMessages::CompletedOrder | IncomingMessages::OpenOrderEnd | IncomingMessages::CompletedOrdersEnd => {
let shared_channels = self.shared_channel_senders.read().await;
if let Some(senders) = shared_channels.get(&message_type) {
for sender in senders {
if let Err(e) = sender.send(message.clone()) {
warn!("error sending to shared channel for {message_type:?}: {e}");
}
}
return Ok(());
}
if !routed {
warn!("could not route message {:?}", message);
}
}
_ => {
if order_id >= 0 {
let channels = self.order_channels.read().await;
if let Some(sender) = channels.get(&order_id) {
let _ = sender.send(message);
return Ok(());
}
}
if !routed {
warn!("could not route message {:?}", message);
}
}
}
Ok(())
}
async fn route_to_shared_channel(&self, message_type: IncomingMessages, message: ResponseMessage) -> Result<(), Error> {
match message_type {
IncomingMessages::OpenOrder
| IncomingMessages::OrderStatus
| IncomingMessages::ExecutionData
| IncomingMessages::CommissionsReport
| IncomingMessages::CompletedOrder => {
self.send_order_update(&message).await;
}
_ => {}
}
let channels = self.shared_channel_senders.read().await;
if let Some(senders) = channels.get(&message_type) {
for sender in senders {
if let Err(e) = sender.send(message.clone()) {
warn!("error sending to shared channel for {message_type:?}: {e}");
}
}
}
Ok(())
}
async fn send_order_update(&self, message: &ResponseMessage) -> bool {
let order_update_stream = self.order_update_stream.read().await;
if let Some(sender) = order_update_stream.as_ref() {
if let Err(e) = sender.send(message.clone()) {
warn!("error sending to order update stream: {e}");
return false;
}
return true;
}
false
}
}
#[async_trait]
impl AsyncMessageBus for AsyncTcpMessageBus {
async fn send_request(&self, request_id: i32, message: RequestMessage) -> Result<AsyncInternalSubscription, Error> {
let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
{
let mut channels = self.request_channels.write().await;
channels.insert(request_id, sender);
}
self.connection.write_message(&message).await?;
Ok(AsyncInternalSubscription::with_cleanup(
receiver,
self.cleanup_sender.clone(),
CleanupSignal::Request(request_id),
))
}
async fn send_order_request(&self, order_id: i32, message: RequestMessage) -> Result<AsyncInternalSubscription, Error> {
let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
{
let mut channels = self.order_channels.write().await;
channels.insert(order_id, sender);
}
self.connection.write_message(&message).await?;
Ok(AsyncInternalSubscription::with_cleanup(
receiver,
self.cleanup_sender.clone(),
CleanupSignal::Order(order_id),
))
}
async fn send_shared_request(&self, message_type: OutgoingMessages, message: RequestMessage) -> Result<AsyncInternalSubscription, Error> {
let receiver = {
let channels = self.shared_channel_receivers.read().await;
if let Some(receiver) = channels.get(&message_type) {
receiver.resubscribe()
} else {
return Err(Error::Simple(format!(
"No shared channel configured for message type: {:?}",
message_type
)));
}
};
self.connection.write_message(&message).await?;
Ok(AsyncInternalSubscription::with_cleanup(
receiver,
self.cleanup_sender.clone(),
CleanupSignal::Shared(message_type),
))
}
async fn send_message(&self, message: RequestMessage) -> Result<(), Error> {
self.connection.write_message(&message).await
}
async fn cancel_subscription(&self, request_id: i32, message: RequestMessage) -> Result<(), Error> {
self.connection.write_message(&message).await?;
let channels = self.request_channels.read().await;
if let Some(sender) = channels.get(&request_id) {
let _ = sender.send(ResponseMessage::from("Cancelled"));
}
let mut channels = self.request_channels.write().await;
channels.remove(&request_id);
Ok(())
}
async fn cancel_order_subscription(&self, order_id: i32, message: RequestMessage) -> Result<(), Error> {
self.connection.write_message(&message).await?;
let channels = self.order_channels.read().await;
if let Some(sender) = channels.get(&order_id) {
let _ = sender.send(ResponseMessage::from("Cancelled"));
}
let mut channels = self.order_channels.write().await;
channels.remove(&order_id);
Ok(())
}
async fn create_order_update_subscription(&self) -> Result<AsyncInternalSubscription, Error> {
let mut order_update_stream = self.order_update_stream.write().await;
if order_update_stream.is_some() {
return Err(Error::AlreadySubscribed);
}
let (sender, receiver) = broadcast::channel(BROADCAST_CHANNEL_CAPACITY);
*order_update_stream = Some(sender);
Ok(AsyncInternalSubscription::with_cleanup(
receiver,
self.cleanup_sender.clone(),
CleanupSignal::OrderUpdateStream,
))
}
async fn ensure_shutdown(&self) {
debug!("ensure_shutdown called");
self.request_shutdown().await;
let task_handle = {
let mut task_guard = self.process_task.write().await;
task_guard.take()
};
if let Some(handle) = task_handle {
debug!("Waiting for processing task to finish");
if let Err(e) = handle.await {
warn!("Error joining processing task: {e}");
}
debug!("Processing task finished");
}
}
fn request_shutdown_sync(&self) {
debug!("sync shutdown requested");
self.connected.store(false, Ordering::Relaxed);
self.shutdown_requested.store(true, Ordering::Relaxed);
self.shutdown_notify.notify_waiters();
}
fn is_connected(&self) -> bool {
self.connected.load(Ordering::Relaxed) && !self.shutdown_requested.load(Ordering::Relaxed)
}
}