use std::collections::HashMap;
use std::io::{prelude::*, Cursor, ErrorKind};
use std::net::TcpStream;
use std::ops::RangeInclusive;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use byteorder::{BigEndian, ReadBytesExt};
use crossbeam::channel::{self, Receiver, Sender};
use log::{debug, error, info, warn};
use time::macros::format_description;
use time::OffsetDateTime;
use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz};
use crate::messages::{encode_length, shared_channel_configuration, IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage};
use crate::{server_versions, Error};
use recorder::MessageRecorder;
mod connection;
mod recorder;
/// Lowest server version advertised to TWS/Gateway in the connection handshake.
const MIN_SERVER_VERSION: i32 = 100;
/// Highest server version advertised to TWS/Gateway in the connection handshake.
const MAX_SERVER_VERSION: i32 = server_versions::WSH_EVENT_DATA_FILTERS_DATE;
/// Number of reconnection attempts `Connection::reconnect` makes before giving up.
const MAX_RETRIES: i32 = 20;
/// Socket read timeout. Reads time out periodically so the dispatcher can notice a
/// shutdown request between messages.
const TWS_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Error codes treated as warnings. They are logged by `error_event` at warn level
/// instead of being routed to subscribers.
const WARNING_CODES: RangeInclusive<i32> = 2100..=2169;
/// Transport-independent interface used by the client to send requests and obtain
/// subscriptions for their responses.
pub(crate) trait MessageBus: Send + Sync {
    /// Sends `packet` and returns a subscription for responses correlated by `request_id`.
    fn send_request(&self, request_id: i32, packet: &RequestMessage) -> Result<InternalSubscription, Error>;
    /// Sends the cancel `packet` for the subscription registered under `request_id`.
    fn cancel_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error>;
    /// Sends `packet` and returns a subscription on the channel shared by all requests
    /// of type `message_id` (used for responses that carry no request id).
    fn send_shared_request(&self, message_id: OutgoingMessages, packet: &RequestMessage) -> Result<InternalSubscription, Error>;
    /// Sends the cancel `packet` for a shared subscription of type `message_id`.
    fn cancel_shared_subscription(&self, message_id: OutgoingMessages, packet: &RequestMessage) -> Result<(), Error>;
    /// Sends an order `packet` and returns a subscription for messages correlated by the
    /// order id (passed as `request_id`).
    fn send_order_request(&self, request_id: i32, packet: &RequestMessage) -> Result<InternalSubscription, Error>;
    /// Sends `packet` without creating a subscription.
    fn send_message(&self, packet: &RequestMessage) -> Result<(), Error>;
    /// Creates the subscription that receives order updates (status, open orders,
    /// executions, commission reports).
    fn create_order_update_subscription(&self) -> Result<InternalSubscription, Error>;
    /// Sends the cancel `packet` for the order subscription registered under `request_id`.
    fn cancel_order_subscription(&self, request_id: i32, packet: &RequestMessage) -> Result<(), Error>;
    /// Shuts the bus down and waits for any background work to finish.
    fn ensure_shutdown(&self);
    /// Test hook returning the request messages seen by the bus; empty by default.
    #[cfg(test)]
    fn request_messages(&self) -> Vec<RequestMessage> {
        vec![]
    }
}
/// Item delivered to subscribers: either a decoded message or an error notification
/// (for example shutdown, connection reset or cancellation).
pub(crate) type Response = Result<ResponseMessage, Error>;
/// Channels shared by every subscriber of a given request type, used for responses
/// that are routed by message type rather than by request id. The mapping comes from
/// `shared_channel_configuration::CHANNEL_MAPPINGS`.
#[derive(Debug)]
struct SharedChannels {
    // Incoming message type -> senders of every shared channel that should receive it.
    senders: HashMap<IncomingMessages, Vec<Arc<Sender<Response>>>>,
    // Outgoing request type -> receiving end of its shared channel.
    receivers: HashMap<OutgoingMessages, Arc<Receiver<Response>>>,
}
impl SharedChannels {
    /// Builds the shared channel table, registering one channel per entry of
    /// `shared_channel_configuration::CHANNEL_MAPPINGS`.
    pub fn new() -> Self {
        let mut instance = Self {
            senders: HashMap::new(),
            receivers: HashMap::new(),
        };
        for mapping in shared_channel_configuration::CHANNEL_MAPPINGS {
            instance.register(mapping.request, mapping.responses);
        }
        instance
    }

    /// Creates one unbounded channel for `outbound`.
    ///
    /// The receiver is stored under `outbound`, and a clone of the sender is added
    /// for each message type in `inbounds`.
    fn register(&mut self, outbound: OutgoingMessages, inbounds: &[IncomingMessages]) {
        let (sender, receiver) = channel::unbounded::<Response>();
        self.receivers.insert(outbound, Arc::new(receiver));
        let sender = Arc::new(sender);
        for inbound in inbounds {
            // One hash lookup via the entry API instead of contains_key + insert + get_mut().unwrap().
            self.senders.entry(*inbound).or_default().push(Arc::clone(&sender));
        }
    }

    /// Returns the shared receiver registered for `message_type`.
    ///
    /// # Panics
    /// Panics if `message_type` has no mapping in `shared_channel_configuration`.
    fn get_receiver(&self, message_type: OutgoingMessages) -> Arc<Receiver<Response>> {
        let receiver = self
            .receivers
            .get(&message_type)
            .unwrap_or_else(|| panic!("unsupported request message {message_type:?}. check mapping in messages::shared_channel_configuration"));
        Arc::clone(receiver)
    }

    /// Returns `true` if at least one shared channel listens for `message_type`.
    fn contains_sender(&self, message_type: IncomingMessages) -> bool {
        self.senders.contains_key(&message_type)
    }

    /// Sends a clone of `message` to every shared channel listening for `message_type`.
    /// Send failures are logged and otherwise ignored.
    fn send_message(&self, message_type: IncomingMessages, message: &ResponseMessage) {
        if let Some(senders) = self.senders.get(&message_type) {
            for sender in senders {
                if let Err(e) = sender.send(Ok(message.clone())) {
                    warn!("error sending message: {e}");
                }
            }
        }
    }

    /// Broadcasts `message` (typically an error notification) to every shared channel.
    fn notify_all(&self, message: &Response) {
        for senders in self.senders.values() {
            for sender in senders {
                if let Err(e) = sender.send(message.clone()) {
                    warn!("error sending notification: {e}");
                }
            }
        }
    }
}
/// Cleanup requests that a dropped subscription sends to the message bus's cleanup thread.
pub enum Signal {
    /// Release the routing entry registered for this request id.
    Request(i32),
    /// Release the routing entry registered for this order id.
    Order(i32),
    /// Release the sender of the order update stream.
    OrderUpdateStream,
}
/// Message bus that multiplexes a single TWS/Gateway connection between many subscribers.
///
/// A dispatcher thread reads incoming messages and routes them by request id, order id,
/// execution id or message type. A cleanup thread releases routing entries when
/// subscriptions are dropped.
#[derive(Debug)]
pub struct TcpMessageBus<S: Stream> {
    connection: Connection<S>,
    // Join handles of the dispatcher and cleanup threads.
    handles: Mutex<Vec<JoinHandle<()>>>,
    // Subscribers keyed by request id.
    requests: SenderHash<i32, Response>,
    // Subscribers keyed by order id.
    orders: SenderHash<i32, Response>,
    // Subscribers keyed by execution id; used to route commission reports.
    executions: SenderHash<String, Response>,
    // Channels shared by all subscribers of a given request type.
    shared_channels: SharedChannels,
    // Sending half handed to subscriptions so that dropping them can request cleanup.
    signals_send: Sender<Signal>,
    // Receiving half consumed by the cleanup thread.
    signals_recv: Receiver<Signal>,
    // Set once shutdown has been requested; polled by the background threads.
    shutdown_requested: AtomicBool,
    // Sender of the (single) order update stream subscription, if one exists.
    order_update_stream: Mutex<Option<Sender<Response>>>, }
impl<S: Stream> TcpMessageBus<S> {
pub fn new(connection: Connection<S>) -> Result<TcpMessageBus<S>, Error> {
let (signals_send, signals_recv) = channel::unbounded();
Ok(TcpMessageBus {
connection,
handles: Mutex::new(Vec::default()),
requests: SenderHash::new(),
orders: SenderHash::new(),
executions: SenderHash::new(),
shared_channels: SharedChannels::new(),
signals_send,
signals_recv,
shutdown_requested: AtomicBool::new(false),
order_update_stream: Mutex::new(None),
})
}
fn is_shutting_down(&self) -> bool {
self.shutdown_requested.load(Ordering::SeqCst)
}
fn request_shutdown(&self) {
debug!("shutdown requested");
self.requests.notify_all(&Err(Error::Shutdown));
self.orders.notify_all(&Err(Error::Shutdown));
self.shared_channels.notify_all(&Err(Error::Shutdown));
self.requests.clear();
self.orders.clear();
self.executions.clear();
self.shutdown_requested.store(true, Ordering::Relaxed);
}
fn reset(&self) {
debug!("reset message bus");
self.requests.notify_all(&Err(Error::ConnectionReset));
self.orders.notify_all(&Err(Error::ConnectionReset));
self.shared_channels.notify_all(&Err(Error::ConnectionReset));
self.requests.clear();
self.orders.clear();
self.executions.clear();
}
fn clean_request(&self, request_id: i32) {
self.requests.remove(&request_id);
debug!("released request_id {}, requests.len()={}", request_id, self.requests.len());
}
fn clean_order(&self, order_id: i32) {
self.orders.remove(&order_id);
debug!("released order_id {}, orders.len()={}", order_id, self.orders.len());
}
fn clear_order_update_stream(&self) {
let mut stream = if let Ok(stream) = self.order_update_stream.lock() {
stream
} else {
warn!("failed to lock order_update_stream");
return;
};
*stream = None;
debug!("released order_update_stream");
}
fn read_message(&self) -> Response {
self.connection.read_message()
}
fn dispatch(&self, server_version: i32) -> Result<(), Error> {
const RECONNECT_CODES: &[ErrorKind] = &[ErrorKind::ConnectionReset, ErrorKind::ConnectionAborted, ErrorKind::UnexpectedEof];
const TIMEOUT_CODES: &[ErrorKind] = &[ErrorKind::WouldBlock, ErrorKind::TimedOut];
match self.read_message() {
Ok(message) => {
self.dispatch_message(server_version, message);
Ok(())
}
Err(Error::Io(e)) if TIMEOUT_CODES.contains(&e.kind()) => {
if self.is_shutting_down() {
debug!("dispatcher thread exiting");
return Err(Error::Shutdown);
}
Ok(())
}
Err(Error::Io(e)) if RECONNECT_CODES.contains(&e.kind()) => {
error!("error reading next message (will attempt reconnect): {:?}", e);
if let Err(reconnect_err) = self.connection.reconnect() {
error!("failed to reconnect to TWS/Gateway: {:?}", reconnect_err);
self.request_shutdown();
return Err(Error::ConnectionFailed);
}
info!("successfully reconnected to TWS/Gateway");
self.reset();
Ok(())
}
Err(err) => {
error!("error reading next message (shutting down): {:?}", err);
self.request_shutdown();
Err(err)
}
}
}
fn start_dispatcher_thread(self: &Arc<Self>, server_version: i32) -> JoinHandle<()> {
let message_bus = Arc::clone(self);
thread::spawn(move || {
loop {
match message_bus.dispatch(server_version) {
Ok(_) => continue,
Err(Error::Shutdown) | Err(Error::ConnectionFailed) => break,
Err(e) => {
error!("Dispatcher encountered an error: {:?}", e);
break;
}
}
}
debug!("Dispatcher thread finished.");
})
}
fn dispatch_message(&self, server_version: i32, message: ResponseMessage) {
match message.message_type() {
IncomingMessages::Error => {
let request_id = message.peek_int(2).unwrap_or(-1);
let error_code = message.peek_int(3).unwrap_or(0);
let is_warning = WARNING_CODES.contains(&error_code);
if request_id == UNSPECIFIED_REQUEST_ID || is_warning {
error_event(server_version, message).unwrap();
} else {
self.process_response(message);
}
}
IncomingMessages::OrderStatus
| IncomingMessages::OpenOrder
| IncomingMessages::OpenOrderEnd
| IncomingMessages::CompletedOrder
| IncomingMessages::CompletedOrdersEnd
| IncomingMessages::ExecutionData
| IncomingMessages::ExecutionDataEnd
| IncomingMessages::CommissionsReport => self.process_orders(message),
_ => self.process_response(message),
};
}
fn process_response(&self, message: ResponseMessage) {
let request_id = message.request_id().unwrap_or(-1); if self.requests.contains(&request_id) {
self.requests.send(&request_id, Ok(message)).unwrap();
} else if self.orders.contains(&request_id) {
self.orders.send(&request_id, Ok(message)).unwrap();
} else if self.shared_channels.contains_sender(message.message_type()) {
self.shared_channels.send_message(message.message_type(), &message);
} else {
info!("no recipient found for: {:?}", message)
}
}
fn process_orders(&self, message: ResponseMessage) {
match message.message_type() {
IncomingMessages::ExecutionData => {
let sent_to_update_stream = self.send_order_update(&message);
match (message.order_id(), message.request_id()) {
(Some(order_id), _) if self.orders.contains(&order_id) => {
if let Some(sender) = self.orders.copy_sender(order_id) {
if let Some(execution_id) = message.execution_id() {
self.executions.insert(execution_id, sender);
}
}
if let Err(e) = self.orders.send(&order_id, Ok(message)) {
warn!("error routing message for order_id({order_id}): {e}");
}
}
(_, Some(request_id)) if self.requests.contains(&request_id) => {
if let Some(sender) = self.requests.copy_sender(request_id) {
if let Some(execution_id) = message.execution_id() {
self.executions.insert(execution_id, sender);
}
}
if let Err(e) = self.requests.send(&request_id, Ok(message)) {
warn!("error routing message for request_id({request_id}): {e}");
}
}
_ => {
if !sent_to_update_stream {
warn!("could not route message {message:?}");
}
}
}
}
IncomingMessages::ExecutionDataEnd => {
match (message.order_id(), message.request_id()) {
(Some(order_id), _) if self.orders.contains(&order_id) => {
if let Err(e) = self.orders.send(&order_id, Ok(message)) {
warn!("error routing message for order_id({order_id}): {e}");
}
}
(_, Some(request_id)) if self.requests.contains(&request_id) => {
if let Err(e) = self.requests.send(&request_id, Ok(message)) {
warn!("error routing message for request_id({request_id}): {e}");
}
}
_ => {
warn!("could not route message {message:?}");
}
}
}
IncomingMessages::OpenOrder | IncomingMessages::OrderStatus => {
let sent_to_update_stream = self.send_order_update(&message);
if let Some(order_id) = message.order_id() {
if self.orders.contains(&order_id) {
if let Err(e) = self.orders.send(&order_id, Ok(message)) {
warn!("error routing message for order_id({order_id}): {e}");
}
} else if self.shared_channels.contains_sender(IncomingMessages::OpenOrder) {
self.shared_channels.send_message(message.message_type(), &message);
} else if !sent_to_update_stream {
warn!("could not route message {message:?}");
}
} else if !sent_to_update_stream {
warn!("could not route message {message:?}");
}
}
IncomingMessages::CompletedOrder | IncomingMessages::OpenOrderEnd | IncomingMessages::CompletedOrdersEnd => {
self.shared_channels.send_message(message.message_type(), &message);
}
IncomingMessages::CommissionsReport => {
let sent_to_update_stream = self.send_order_update(&message);
if let Some(execution_id) = message.execution_id() {
if let Err(e) = self.executions.send(&execution_id, Ok(message)) {
warn!("error sending commission report for execution {}: {}", execution_id, e);
}
} else if !sent_to_update_stream {
warn!("could not route commission report {message:?}");
}
}
_ => (),
}
}
fn send_order_update(&self, message: &ResponseMessage) -> bool {
if let Ok(order_update_stream) = self.order_update_stream.lock() {
if let Some(sender) = order_update_stream.as_ref() {
if let Err(e) = sender.send(Ok(message.clone())) {
warn!("error sending to order update stream: {e}");
return false;
}
return true;
}
}
false
}
fn start_cleanup_thread(self: &Arc<Self>, timeout: std::time::Duration) -> JoinHandle<()> {
let message_bus = Arc::clone(self);
thread::spawn(move || {
let signal_recv = message_bus.signals_recv.clone();
loop {
if let Ok(signal) = signal_recv.recv_timeout(timeout) {
match signal {
Signal::Request(request_id) => {
message_bus.clean_request(request_id);
}
Signal::Order(order_id) => {
message_bus.clean_order(order_id);
}
Signal::OrderUpdateStream => {
message_bus.clear_order_update_stream();
}
}
}
if message_bus.is_shutting_down() {
debug!("cleanup thread exiting");
return;
}
}
})
}
pub(crate) fn process_messages(self: &Arc<Self>, server_version: i32, timeout: std::time::Duration) -> Result<(), Error> {
let handle = self.start_dispatcher_thread(server_version);
self.add_join_handle(handle);
let handle = self.start_cleanup_thread(timeout);
self.add_join_handle(handle);
Ok(())
}
fn add_join_handle(&self, handle: JoinHandle<()>) {
let mut handles = self.handles.lock().unwrap();
handles.push(handle);
}
pub fn join(&self) {
let mut handles = self.handles.lock().unwrap();
for handle in handles.drain(..) {
if let Err(e) = handle.join() {
warn!("could not join thread: {e:?}");
}
}
}
}
/// Request id carried by error messages that are not tied to a specific request; such
/// messages are only logged, never routed.
const UNSPECIFIED_REQUEST_ID: i32 = -1;
impl<S: Stream> MessageBus for TcpMessageBus<S> {
    /// Registers a channel for `request_id`, sends `message`, and returns the subscription.
    ///
    /// If the write fails, the registration is removed again and the error is returned.
    fn send_request(&self, request_id: i32, message: &RequestMessage) -> Result<InternalSubscription, Error> {
        let (sender, receiver) = channel::unbounded();
        let sender_copy = sender.clone();
        // Register before writing so a fast response cannot arrive before its route exists.
        self.requests.insert(request_id, sender);
        if let Err(e) = self.connection.write_message(message) {
            // No subscription is handed out, so nothing would ever signal cleanup of this entry.
            self.requests.remove(&request_id);
            return Err(e);
        }
        let subscription = SubscriptionBuilder::new()
            .receiver(receiver)
            .sender(sender_copy)
            .signaler(self.signals_send.clone())
            .request_id(request_id)
            .build();
        Ok(subscription)
    }

    /// Sends the cancel `message`, then notifies the subscriber with `Error::Cancelled`
    /// and removes its routing entry.
    fn cancel_subscription(&self, request_id: i32, message: &RequestMessage) -> Result<(), Error> {
        self.connection.write_message(message)?;
        if let Err(e) = self.requests.send(&request_id, Err(Error::Cancelled)) {
            info!("error sending cancel notification: {e}");
        }
        self.requests.remove(&request_id);
        Ok(())
    }

    /// Registers a channel for `order_id`, sends `message`, and returns the subscription.
    ///
    /// If the write fails, the registration is removed again and the error is returned.
    fn send_order_request(&self, order_id: i32, message: &RequestMessage) -> Result<InternalSubscription, Error> {
        let (sender, receiver) = channel::unbounded();
        let sender_copy = sender.clone();
        self.orders.insert(order_id, sender);
        if let Err(e) = self.connection.write_message(message) {
            // Same reasoning as `send_request`: avoid leaking an unreachable routing entry.
            self.orders.remove(&order_id);
            return Err(e);
        }
        let subscription = SubscriptionBuilder::new()
            .receiver(receiver)
            .sender(sender_copy)
            .signaler(self.signals_send.clone())
            .order_id(order_id)
            .build();
        Ok(subscription)
    }

    fn send_message(&self, message: &RequestMessage) -> Result<(), Error> {
        self.connection.write_message(message)?;
        Ok(())
    }

    /// Creates the single order update stream.
    ///
    /// # Errors
    /// Returns `Error::AlreadySubscribed` if one already exists.
    fn create_order_update_subscription(&self) -> Result<InternalSubscription, Error> {
        let mut order_update_stream = self.order_update_stream.lock().unwrap();
        if order_update_stream.is_some() {
            return Err(Error::AlreadySubscribed);
        }
        let (sender, receiver) = channel::unbounded();
        *order_update_stream = Some(sender);
        let subscription = SubscriptionBuilder::new().receiver(receiver).signaler(self.signals_send.clone()).build();
        Ok(subscription)
    }

    /// Sends the cancel `message`, then notifies the order subscriber with
    /// `Error::Cancelled` and removes its routing entry.
    fn cancel_order_subscription(&self, request_id: i32, message: &RequestMessage) -> Result<(), Error> {
        self.connection.write_message(message)?;
        if let Err(e) = self.orders.send(&request_id, Err(Error::Cancelled)) {
            info!("error sending cancel notification: {e}");
        }
        self.orders.remove(&request_id);
        Ok(())
    }

    /// Sends `message` and returns a subscription on the shared channel for `message_type`.
    fn send_shared_request(&self, message_type: OutgoingMessages, message: &RequestMessage) -> Result<InternalSubscription, Error> {
        self.connection.write_message(message)?;
        let shared_receiver = self.shared_channels.get_receiver(message_type);
        let subscription = SubscriptionBuilder::new()
            .shared_receiver(shared_receiver)
            .message_type(message_type)
            .build();
        Ok(subscription)
    }

    fn cancel_shared_subscription(&self, _message_type: OutgoingMessages, message: &RequestMessage) -> Result<(), Error> {
        self.connection.write_message(message)?;
        Ok(())
    }

    /// Requests shutdown, then joins the background threads.
    fn ensure_shutdown(&self) {
        self.request_shutdown();
        self.join();
    }
}
fn error_event(server_version: i32, mut packet: ResponseMessage) -> Result<(), Error> {
packet.skip();
let version = packet.next_int()?;
if version < 2 {
let message = packet.next_string()?;
error!("version 2 error: {}", message);
Ok(())
} else {
let request_id = packet.next_int()?;
let error_code = packet.next_int()?;
let error_message = packet.next_string()?;
let mut advanced_order_reject_json: String = "".to_string();
if server_version >= server_versions::ADVANCED_ORDER_REJECT {
advanced_order_reject_json = packet.next_string()?;
}
let is_warning = WARNING_CODES.contains(&error_code);
if is_warning {
warn!(
"request_id: {}, warning_code: {}, warning_message: {}, advanced_order_reject_json: {}",
request_id, error_code, error_message, advanced_order_reject_json
);
} else {
error!(
"request_id: {}, error_code: {}, error_message: {}, advanced_order_reject_json: {}",
request_id, error_code, error_message, advanced_order_reject_json
);
}
Ok(())
}
}
/// Thread-safe map from a routing key (request id, order id, execution id) to the
/// sending half of a subscriber's channel.
#[derive(Debug)]
struct SenderHash<K, V> {
    senders: RwLock<HashMap<K, Sender<V>>>,
}
impl<K: std::hash::Hash + Eq + std::fmt::Debug, V: std::fmt::Debug + Clone> SenderHash<K, V> {
    /// Creates an empty map.
    pub fn new() -> Self {
        let senders = RwLock::new(HashMap::new());
        Self { senders }
    }

    /// Sends `message` to the subscriber registered under `id`.
    ///
    /// A missing subscriber or a disconnected channel is only logged, so this always
    /// returns `Ok(())`.
    pub fn send(&self, id: &K, message: V) -> Result<(), Error> {
        let guard = self.senders.read().unwrap();
        debug!("senders: {guard:?}");
        match guard.get(id) {
            Some(sender) => {
                if let Err(err) = sender.send(message) {
                    warn!("error sending: {id:?}, {err}")
                }
            }
            None => warn!("no recipient found for: {id:?}, {message:?}"),
        }
        Ok(())
    }

    /// Returns a clone of the sender registered under `id`, if any.
    pub fn copy_sender(&self, id: K) -> Option<Sender<V>> {
        self.senders.read().unwrap().get(&id).cloned()
    }

    /// Registers `message` under `id`, returning the sender it replaced (if any).
    pub fn insert(&self, id: K, message: Sender<V>) -> Option<Sender<V>> {
        self.senders.write().unwrap().insert(id, message)
    }

    /// Unregisters and returns the sender stored under `id`.
    pub fn remove(&self, id: &K) -> Option<Sender<V>> {
        self.senders.write().unwrap().remove(id)
    }

    /// Returns `true` if a sender is registered under `id`.
    pub fn contains(&self, id: &K) -> bool {
        self.senders.read().unwrap().contains_key(id)
    }

    /// Number of registered senders.
    pub fn len(&self) -> usize {
        self.senders.read().unwrap().len()
    }

    /// Drops every registered sender.
    pub fn clear(&self) {
        self.senders.write().unwrap().clear();
    }

    /// Sends a clone of `message` to every registered subscriber, logging failures.
    pub fn notify_all(&self, message: &V) {
        let guard = self.senders.read().unwrap();
        guard.values().for_each(|sender| {
            if let Err(e) = sender.send(message.clone()) {
                warn!("error sending notification: {e}");
            }
        });
    }
}
/// Handle returned to callers for receiving responses. It is backed either by a
/// dedicated channel or by a shared channel.
#[derive(Debug, Default)]
pub(crate) struct InternalSubscription {
    // Dedicated channel for a single request/order (or the order update stream).
    receiver: Option<Receiver<Response>>,
    // Clone of the dedicated channel's sender; `cancel` uses it to push `Error::Cancelled`.
    sender: Option<Sender<Response>>,
    // Receiver shared by all subscriptions of the same outgoing message type.
    shared_receiver: Option<Arc<Receiver<Response>>>,
    // Channel to the bus's cleanup thread; used on drop to release routing entries.
    signaler: Option<Sender<Signal>>,
    pub(crate) request_id: Option<i32>,
    pub(crate) order_id: Option<i32>,
    pub(crate) message_type: Option<OutgoingMessages>,
}
impl InternalSubscription {
pub(crate) fn next(&self) -> Option<Response> {
if let Some(receiver) = &self.receiver {
Self::receive(receiver)
} else if let Some(receiver) = &self.shared_receiver {
Self::receive(receiver)
} else {
None
}
}
pub(crate) fn try_next(&self) -> Option<Response> {
if let Some(receiver) = &self.receiver {
Self::try_receive(receiver)
} else if let Some(receiver) = &self.shared_receiver {
Self::try_receive(receiver)
} else {
None
}
}
pub(crate) fn next_timeout(&self, timeout: Duration) -> Option<Response> {
if let Some(receiver) = &self.receiver {
Self::timeout_receive(receiver, timeout)
} else if let Some(receiver) = &self.shared_receiver {
Self::timeout_receive(receiver, timeout)
} else {
None
}
}
pub(crate) fn cancel(&self) {
if let Some(sender) = &self.sender {
if let Err(e) = sender.send(Err(Error::Cancelled)) {
warn!("error sending cancel notification: {e}")
}
}
}
fn receive(receiver: &Receiver<Response>) -> Option<Response> {
receiver.recv().ok()
}
fn try_receive(receiver: &Receiver<Response>) -> Option<Response> {
receiver.try_recv().ok()
}
fn timeout_receive(receiver: &Receiver<Response>, timeout: Duration) -> Option<Response> {
receiver.recv_timeout(timeout).ok()
}
}
impl Drop for InternalSubscription {
    /// Asks the cleanup thread to release whatever routing entry backs this subscription.
    ///
    /// Priority is request id, then order id; a subscription with a signaler but neither
    /// id is the order update stream. Shared subscriptions have no signaler and send nothing.
    fn drop(&mut self) {
        if let Some(signaler) = &self.signaler {
            let signal = match (self.request_id, self.order_id) {
                (Some(request_id), _) => Signal::Request(request_id),
                (None, Some(order_id)) => Signal::Order(order_id),
                (None, None) => Signal::OrderUpdateStream,
            };
            if let Err(e) = signaler.send(signal) {
                warn!("error sending drop signal: {e}");
            }
        }
    }
}
/// Builder for [`InternalSubscription`]. A valid configuration has either a dedicated
/// `receiver` together with a `signaler`, or a `shared_receiver`.
pub(crate) struct SubscriptionBuilder {
    receiver: Option<Receiver<Response>>,
    sender: Option<Sender<Response>>,
    shared_receiver: Option<Arc<Receiver<Response>>>,
    signaler: Option<Sender<Signal>>,
    order_id: Option<i32>,
    request_id: Option<i32>,
    message_type: Option<OutgoingMessages>,
}
impl SubscriptionBuilder {
    /// Starts a builder with every field unset.
    pub(crate) fn new() -> Self {
        Self {
            receiver: None,
            sender: None,
            shared_receiver: None,
            signaler: None,
            order_id: None,
            request_id: None,
            message_type: None,
        }
    }

    pub(crate) fn receiver(self, receiver: Receiver<Response>) -> Self {
        Self { receiver: Some(receiver), ..self }
    }

    pub(crate) fn sender(self, sender: Sender<Response>) -> Self {
        Self { sender: Some(sender), ..self }
    }

    pub(crate) fn shared_receiver(self, shared_receiver: Arc<Receiver<Response>>) -> Self {
        Self {
            shared_receiver: Some(shared_receiver),
            ..self
        }
    }

    pub(crate) fn signaler(self, signaler: Sender<Signal>) -> Self {
        Self { signaler: Some(signaler), ..self }
    }

    pub(crate) fn order_id(self, order_id: i32) -> Self {
        Self { order_id: Some(order_id), ..self }
    }

    pub(crate) fn request_id(self, request_id: i32) -> Self {
        Self { request_id: Some(request_id), ..self }
    }

    pub(crate) fn message_type(self, message_type: OutgoingMessages) -> Self {
        Self {
            message_type: Some(message_type),
            ..self
        }
    }

    /// Assembles the subscription.
    ///
    /// A dedicated `receiver` + `signaler` pair takes precedence over a `shared_receiver`.
    ///
    /// # Panics
    /// Panics with "bad configuration" when neither combination is present.
    pub(crate) fn build(self) -> InternalSubscription {
        match (self.receiver, self.signaler, self.shared_receiver) {
            (Some(receiver), Some(signaler), _) => InternalSubscription {
                receiver: Some(receiver),
                sender: self.sender,
                shared_receiver: None,
                signaler: Some(signaler),
                request_id: self.request_id,
                order_id: self.order_id,
                message_type: self.message_type,
            },
            (_, _, Some(shared)) => InternalSubscription {
                receiver: None,
                sender: None,
                shared_receiver: Some(shared),
                signaler: None,
                request_id: self.request_id,
                order_id: self.order_id,
                message_type: self.message_type,
            },
            _ => panic!("bad configuration"),
        }
    }
}
/// Information gathered while establishing a connection.
#[derive(Default, Clone, Debug)]
pub(crate) struct ConnectionMetadata {
    // Taken from the NextValidId message received after start_api.
    pub(crate) next_order_id: i32,
    pub(crate) client_id: i32,
    // Negotiated during the handshake.
    pub(crate) server_version: i32,
    // Raw string taken from the ManagedAccounts message.
    pub(crate) managed_accounts: String,
    // Parsed from the handshake acknowledgement; None if it could not be parsed.
    pub(crate) connection_time: Option<OffsetDateTime>,
    pub(crate) time_zone: Option<&'static Tz>,
}
/// TCP transport with separately locked read and write handles, both clones of the same
/// `TcpStream`, so reading and writing do not block each other.
#[derive(Debug)]
pub(crate) struct TcpSocket {
    reader: Mutex<TcpStream>,
    writer: Mutex<TcpStream>,
    // Address used to re-open the stream on reconnect.
    connection_url: String,
}
impl TcpSocket {
    /// Wraps an already connected `stream`. Its read timeout is set to `TWS_READ_TIMEOUT`,
    /// and `connection_url` is remembered for reconnects.
    pub fn new(stream: TcpStream, connection_url: &str) -> Result<Self, Error> {
        let write_half = stream.try_clone()?;
        stream.set_read_timeout(Some(TWS_READ_TIMEOUT))?;
        let (reader, writer) = (Mutex::new(stream), Mutex::new(write_half));
        Ok(Self {
            reader,
            writer,
            connection_url: connection_url.to_owned(),
        })
    }
}
impl Reconnect for TcpSocket {
    /// Opens a new stream to `connection_url` and swaps it into both the reader and writer.
    fn reconnect(&self) -> Result<(), Error> {
        let fresh = TcpStream::connect(&self.connection_url)?;
        fresh.set_read_timeout(Some(TWS_READ_TIMEOUT))?;
        let mut reader_guard = self.reader.lock()?;
        *reader_guard = fresh.try_clone()?;
        let mut writer_guard = self.writer.lock()?;
        *writer_guard = fresh;
        Ok(())
    }

    /// Blocks the current thread for `duration`.
    fn sleep(&self, duration: std::time::Duration) {
        thread::sleep(duration);
    }
}
/// Transport capable of re-establishing its underlying connection.
pub(crate) trait Reconnect {
    /// Replaces the underlying connection with a freshly opened one.
    fn reconnect(&self) -> Result<(), Error>;
    /// Waits for `duration` between reconnection attempts (a trait method so tests can stub it).
    fn sleep(&self, duration: std::time::Duration);
}
/// Everything a [`Connection`] needs from its transport: message I/O, reconnection, and
/// the ability to be shared across the bus's threads.
pub(crate) trait Stream: Io + Reconnect + Sync + Send + 'static + std::fmt::Debug {}
impl Stream for TcpSocket {}
/// Reads the 4-byte big-endian length prefix that precedes every message.
fn read_header(reader: &mut impl Read) -> Result<usize, Error> {
    let mut header = [0_u8; 4];
    reader.read_exact(&mut header)?;
    let length = Cursor::new(&header[..]).read_u32::<BigEndian>()?;
    Ok(length as usize)
}
fn read_message(reader: &mut impl Read) -> Result<Vec<u8>, Error> {
let message_size = read_header(reader)?;
let mut data = vec![0_u8; message_size];
reader.read_exact(&mut data)?;
Ok(data)
}
impl Io for TcpSocket {
    /// Reads one length-prefixed message while holding the reader lock.
    fn read_message(&self) -> Result<Vec<u8>, Error> {
        let mut guard = self.reader.lock()?;
        read_message(&mut *guard)
    }

    /// Writes `buf` in full while holding the writer lock.
    fn write_all(&self, buf: &[u8]) -> Result<(), Error> {
        self.writer.lock()?.write_all(buf)?;
        Ok(())
    }
}
/// Raw message-level I/O over a transport.
pub(crate) trait Io {
    /// Reads the payload of the next message (in `TcpSocket`, without its length prefix).
    fn read_message(&self) -> Result<Vec<u8>, Error>;
    /// Writes `buf` to the transport in full.
    fn write_all(&self, buf: &[u8]) -> Result<(), Error>;
}
/// A TWS/Gateway API session on top of a [`Stream`]: handshake, request encoding,
/// response decoding, and reconnection.
#[derive(Debug)]
pub(crate) struct Connection<S: Stream> {
    // Sent to the server in the StartApi message.
    client_id: i32,
    socket: S,
    connection_metadata: Mutex<ConnectionMetadata>,
    // Number of attempts `reconnect` makes before returning ConnectionFailed.
    max_retries: i32,
    // Records outgoing requests and incoming responses.
    recorder: MessageRecorder,
}
impl<S: Stream> Connection<S> {
pub fn connect(socket: S, client_id: i32) -> Result<Self, Error> {
let connection = Self {
client_id,
socket,
connection_metadata: Mutex::new(ConnectionMetadata {
client_id,
..Default::default()
}),
max_retries: MAX_RETRIES,
recorder: MessageRecorder::from_env(),
};
connection.establish_connection()?;
Ok(connection)
}
pub fn connection_metadata(&self) -> ConnectionMetadata {
let metadata = self.connection_metadata.lock().unwrap();
metadata.clone()
}
pub fn reconnect(&self) -> Result<(), Error> {
let mut backoff = FibonacciBackoff::new(30);
for i in 0..self.max_retries {
let next_delay = backoff.next_delay();
info!("next reconnection attempt in {next_delay:#?}");
self.socket.sleep(next_delay);
match self.socket.reconnect() {
Ok(_) => {
info!("reconnected !!!");
self.establish_connection()?;
return Ok(());
}
Err(e) => {
error!("reconnection attempt {i} of {} failed: {e}", self.max_retries);
}
}
}
Err(Error::ConnectionFailed)
}
fn establish_connection(&self) -> Result<(), Error> {
self.handshake()?;
self.start_api()?;
self.receive_account_info()?;
Ok(())
}
fn write_message(&self, message: &RequestMessage) -> Result<(), Error> {
self.recorder.record_request(message);
let encoded = message.encode();
debug!("-> {encoded:?}");
let length_encoded = encode_length(&encoded);
self.socket.write_all(&length_encoded)?;
Ok(())
}
fn read_message(&self) -> Response {
let data = self.socket.read_message()?;
let raw_string = String::from_utf8(data)?;
debug!("<- {:?}", raw_string);
let message = ResponseMessage::from(&raw_string);
self.recorder.record_response(&message);
Ok(message)
}
fn handshake(&self) -> Result<(), Error> {
let version = &format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}");
debug!("-> {version:?}");
let mut handshake = Vec::from(b"API\0");
handshake.extend_from_slice(&encode_length(version));
self.socket.write_all(&handshake)?;
let ack = self.read_message();
let mut connection_metadata = self.connection_metadata.lock()?;
match ack {
Ok(mut response) => {
connection_metadata.server_version = response.next_int()?;
let time = response.next_string()?;
(connection_metadata.connection_time, connection_metadata.time_zone) = parse_connection_time(time.as_str());
}
Err(Error::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
return Err(Error::Simple(format!("The server may be rejecting connections from this host: {err}")));
}
Err(err) => {
return Err(err);
}
}
Ok(())
}
fn start_api(&self) -> Result<(), Error> {
const VERSION: i32 = 2;
let prelude = &mut RequestMessage::default();
prelude.push_field(&OutgoingMessages::StartApi);
prelude.push_field(&VERSION);
prelude.push_field(&self.client_id);
if self.server_version() > server_versions::OPTIONAL_CAPABILITIES {
prelude.push_field(&"");
}
self.write_message(prelude)?;
Ok(())
}
fn server_version(&self) -> i32 {
let connection_metadata = self.connection_metadata.lock().unwrap();
connection_metadata.server_version
}
fn receive_account_info(&self) -> Result<(), Error> {
let mut saw_next_order_id: bool = false;
let mut saw_managed_accounts: bool = false;
let mut attempts = 0;
const MAX_ATTEMPTS: i32 = 100;
loop {
let mut message = self.read_message()?;
match message.message_type() {
IncomingMessages::NextValidId => {
saw_next_order_id = true;
message.skip(); message.skip();
let mut connection_metadata = self.connection_metadata.lock()?;
connection_metadata.next_order_id = message.next_int()?;
}
IncomingMessages::ManagedAccounts => {
saw_managed_accounts = true;
message.skip(); message.skip();
let mut connection_metadata = self.connection_metadata.lock()?;
connection_metadata.managed_accounts = message.next_string()?;
}
IncomingMessages::Error => {
error!("message: {message:?}")
}
_ => info!("message: {message:?}"),
}
attempts += 1;
if (saw_next_order_id && saw_managed_accounts) || attempts > MAX_ATTEMPTS {
break;
}
}
Ok(())
}
#[cfg(test)]
pub(crate) fn stubbed(socket: S, client_id: i32) -> Connection<S> {
Connection {
client_id,
socket,
connection_metadata: Mutex::new(ConnectionMetadata {
client_id,
..Default::default()
}),
max_retries: MAX_RETRIES,
recorder: MessageRecorder::new(false, String::from("")),
}
}
}
/// Backoff whose delays follow the Fibonacci sequence (1, 2, 3, 5, 8, … seconds),
/// capped at `max` seconds.
struct FibonacciBackoff {
    previous: u64,
    current: u64,
    max: u64,
}

impl FibonacciBackoff {
    /// Creates a backoff whose delays never exceed `max` seconds.
    fn new(max: u64) -> Self {
        FibonacciBackoff {
            previous: 0,
            current: 1,
            max,
        }
    }

    /// Returns the next delay and advances the sequence.
    fn next_delay(&mut self) -> Duration {
        // Saturate rather than overflow: plain `+` panics in debug builds once the
        // sequence passes u64::MAX (after roughly 93 calls).
        let next = self.previous.saturating_add(self.current);
        self.previous = self.current;
        self.current = next;
        Duration::from_secs(next.min(self.max))
    }
}
fn parse_connection_time(connection_time: &str) -> (Option<OffsetDateTime>, Option<&'static Tz>) {
let parts: Vec<&str> = connection_time.split(' ').collect();
let zones = timezones::find_by_name(parts[2]);
if zones.is_empty() {
error!("time zone not found for {}", parts[2]);
return (None, None);
}
let timezone = zones[0];
let format = format_description!("[year][month][day] [hour]:[minute]:[second]");
let date_str = format!("{} {}", parts[0], parts[1]);
let date = time::PrimitiveDateTime::parse(date_str.as_str(), format);
match date {
Ok(connected_at) => match connected_at.assume_timezone(timezone) {
OffsetResult::Some(date) => (Some(date), Some(timezone)),
_ => {
warn!("error setting timezone");
(None, Some(timezone))
}
},
Err(err) => {
warn!("could not parse connection time from {date_str}: {err}");
(None, Some(timezone))
}
}
}
#[cfg(test)]
mod tests;