use log::{debug, error, info, warn};
use time::macros::format_description;
use time::OffsetDateTime;
use time_tz::{OffsetResult, PrimitiveDateTimeExt, Tz};
use crate::accounts::AccountUpdate;
use crate::common::timezone::find_timezone;
use crate::errors::Error;
use crate::messages::{
encode_length, encode_protobuf_message, IncomingMessages, Notice, OutgoingMessages, ResponseMessage, HANDSHAKE_DECODE_FAILURE_CODE,
HANDSHAKE_UNKNOWN_FRAME_CODE, PROTOBUF_MSG_ID,
};
use crate::orders::{CommissionReport, ExecutionData, OrderData, OrderStatus};
use crate::server_versions;
/// A frame received while the connection handshake is still in progress,
/// handed to the optional `startup` callback of [`StartupHandshakeContext`].
///
/// Payload variants carry the decoded order / execution / account data; the
/// `*End` variants are payload-less and are forwarded for the corresponding
/// `...End` frames.
#[derive(Debug)]
#[non_exhaustive]
// NOTE(review): variants are intentionally unboxed; presumably the order payloads
// are much larger than the rest — confirm boxing is not worth the indirection.
#[allow(clippy::large_enum_variant)]
pub enum StartupMessage {
// Decoded `OpenOrder` frame.
OpenOrder(OrderData),
// Decoded `OrderStatus` frame.
OrderStatus(OrderStatus),
// Forwarded for an `OpenOrderEnd` frame.
OpenOrderEnd,
// Decoded account frame (`AccountValue`, `PortfolioValue`, `AccountUpdateTime`, `AccountDownloadEnd`).
AccountUpdate(AccountUpdate),
// Decoded `ExecutionData` frame.
Execution(ExecutionData),
// Decoded `CommissionsReport` frame.
CommissionReport(CommissionReport),
// Decoded `CompletedOrder` frame.
CompletedOrder(OrderData),
// Forwarded for an `ExecutionDataEnd` frame.
ExecutionDataEnd,
// Forwarded for a `CompletedOrdersEnd` frame.
CompletedOrdersEnd,
}
impl StartupMessage {
    /// Returns the [`IncomingMessages`] kind of the wire frame this startup
    /// message corresponds to.
    ///
    /// For [`StartupMessage::AccountUpdate`] the kind depends on the wrapped
    /// [`AccountUpdate`] variant: `AccountValue`, `PortfolioValue`,
    /// `AccountUpdateTime`, or `AccountDownloadEnd`.
    pub fn message_type(&self) -> IncomingMessages {
        match self {
            // Order stream.
            Self::OpenOrder(_) => IncomingMessages::OpenOrder,
            Self::OrderStatus(_) => IncomingMessages::OrderStatus,
            Self::OpenOrderEnd => IncomingMessages::OpenOrderEnd,
            // Account stream: one frame kind per account-update variant.
            Self::AccountUpdate(AccountUpdate::AccountValue(_)) => IncomingMessages::AccountValue,
            Self::AccountUpdate(AccountUpdate::PortfolioValue(_)) => IncomingMessages::PortfolioValue,
            Self::AccountUpdate(AccountUpdate::UpdateTime(_)) => IncomingMessages::AccountUpdateTime,
            Self::AccountUpdate(AccountUpdate::End) => IncomingMessages::AccountDownloadEnd,
            // Execution stream.
            Self::Execution(_) => IncomingMessages::ExecutionData,
            Self::CommissionReport(_) => IncomingMessages::CommissionsReport,
            Self::ExecutionDataEnd => IncomingMessages::ExecutionDataEnd,
            // Completed-order stream.
            Self::CompletedOrder(_) => IncomingMessages::CompletedOrder,
            Self::CompletedOrdersEnd => IncomingMessages::CompletedOrdersEnd,
        }
    }
}
/// Destination for [`Notice`]s raised during the handshake: server `Error`
/// frames plus the notices synthesized for decode failures and unrouted frames.
pub(crate) trait NoticeSink: Send + Sync {
/// Hands `notice` to the sink. Returns nothing, so delivery failures cannot be
/// surfaced to the caller.
fn deliver(&self, notice: Notice);
}
/// With the `sync` feature, notices are forwarded to the sync transport's
/// `NoticeBroadcaster`.
#[cfg(feature = "sync")]
impl NoticeSink for crate::transport::sync::NoticeBroadcaster {
fn deliver(&self, notice: Notice) {
self.broadcast(notice);
}
}
/// With the `async` feature, a tokio broadcast
/// [`Sender`](tokio::sync::broadcast::Sender) can serve as the notice sink.
#[cfg(feature = "async")]
impl NoticeSink for tokio::sync::broadcast::Sender<Notice> {
    fn deliver(&self, notice: Notice) {
        // Best-effort: tokio only rejects a broadcast send when no receiver is
        // subscribed, in which case the notice is simply dropped.
        self.send(notice).ok();
    }
}
/// Borrowed routing targets used while processing frames during the handshake.
pub(crate) struct StartupHandshakeContext<'a> {
/// Optional callback receiving decoded [`StartupMessage`]s. When `None`,
/// payload frames are not decoded at all (see `dispatch_unsolicited_message`).
pub startup: Option<&'a (dyn Fn(StartupMessage) + Send + Sync)>,
/// Receives server error notices and synthesized handshake notices.
pub notice_sink: &'a (dyn NoticeSink + Sync),
}
/// Result of parsing the server's handshake reply, together with the version
/// range that was advertised to the server.
#[derive(Debug, Clone)]
// NOTE(review): presumably some fields are not read by callers yet — confirm
// whether this allow is still needed.
#[allow(dead_code)]
pub struct HandshakeData {
/// Lowest server version advertised in the handshake.
pub min_version: i32,
/// Highest server version advertised in the handshake.
pub max_version: i32,
/// Server version reported in the handshake reply (first field).
pub server_version: i32,
/// Server time string reported in the handshake reply (second field), unparsed.
pub server_time: String,
}
/// Wire-level steps of the TWS / IB Gateway connection handshake.
pub trait ConnectionProtocol {
/// Error returned by the parsing steps.
type Error;
/// Encodes the initial handshake frame sent to the server.
fn format_handshake(&self) -> Vec<u8>;
/// Parses the server's reply to the handshake frame.
fn parse_handshake_response(&self, response: &mut ResponseMessage) -> Result<HandshakeData, Self::Error>;
/// Encodes the start-API request for `client_id`, once the negotiated
/// `server_version` is known.
fn format_start_api(&self, client_id: i32, server_version: i32) -> Vec<u8>;
/// Processes one frame received after start-API: extracts account details
/// if the frame carries them, and routes any other frame through `ctx`.
fn parse_account_info(
&self,
server_version: i32,
message: &mut ResponseMessage,
ctx: &StartupHandshakeContext<'_>,
) -> Result<AccountInfo, Self::Error>;
}
/// Account details extracted from a single handshake frame; each field is
/// `None` unless that frame carried it.
#[derive(Debug, Clone, Default)]
pub struct AccountInfo {
/// Next valid order id, from a `NextValidId` frame.
pub next_order_id: Option<i32>,
/// Managed-accounts list as sent by the server, from a `ManagedAccounts` frame.
pub managed_accounts: Option<String>,
}
/// Default [`ConnectionProtocol`] implementation: advertises
/// `v{min_version}..{max_version}` in the handshake and uses the protobuf
/// start-API / account-info messages.
#[derive(Debug)]
pub struct ConnectionHandler {
/// Lowest server version advertised in the handshake.
pub min_version: i32,
/// Highest server version advertised in the handshake.
pub max_version: i32,
}
impl Default for ConnectionHandler {
    /// Advertises the version range `PROTOBUF_REST_MESSAGES_3` through
    /// `UPDATE_CONFIG`.
    fn default() -> Self {
        let min_version = server_versions::PROTOBUF_REST_MESSAGES_3;
        let max_version = server_versions::UPDATE_CONFIG;
        Self { min_version, max_version }
    }
}
impl ConnectionProtocol for ConnectionHandler {
type Error = Error;
fn format_handshake(&self) -> Vec<u8> {
let version_string = format!("v{}..{}", self.min_version, self.max_version);
debug!("Handshake version: {version_string}");
let mut handshake = Vec::from(b"API\0");
handshake.extend_from_slice(&encode_length(&version_string));
handshake
}
fn parse_handshake_response(&self, response: &mut ResponseMessage) -> Result<HandshakeData, Self::Error> {
let server_version = response.next_int()?;
let server_time = response.next_string()?;
Ok(HandshakeData {
min_version: self.min_version,
max_version: self.max_version,
server_version,
server_time,
})
}
fn format_start_api(&self, client_id: i32, _server_version: i32) -> Vec<u8> {
use prost::Message;
let request = crate::proto::StartApiRequest {
client_id: Some(client_id),
optional_capabilities: None,
};
encode_protobuf_message(OutgoingMessages::StartApi as i32, &request.encode_to_vec())
}
fn parse_account_info(
&self,
server_version: i32,
message: &mut ResponseMessage,
ctx: &StartupHandshakeContext<'_>,
) -> Result<AccountInfo, Self::Error> {
use prost::Message;
let mut info = AccountInfo::default();
match message.message_type() {
IncomingMessages::NextValidId => {
let proto = crate::proto::NextValidId::decode(message.require_proto()?)?;
info.next_order_id = proto.order_id;
}
IncomingMessages::ManagedAccounts => {
let proto = crate::proto::ManagedAccounts::decode(message.require_proto()?)?;
info.managed_accounts = proto.accounts_list;
}
_ => dispatch_unsolicited_message(server_version, message, ctx),
}
Ok(info)
}
}
/// Routes a frame that arrived during the handshake but is not an account-info
/// frame handled by the caller.
///
/// * `Error` frames become a [`Notice`], are logged (warnings and system
///   messages at `info`, anything else at `error`), and are always delivered
///   to `ctx.notice_sink`.
/// * Order, execution, commission, and account-update frames are decoded and
///   passed to `ctx.startup`, wrapped in the matching [`StartupMessage`]
///   variant. With no callback registered they are not decoded. A decode
///   failure is reported to the sink as a synthesized notice carrying
///   `HANDSHAKE_DECODE_FAILURE_CODE`.
/// * `...End` frames are forwarded to the callback (if any) as unit variants.
/// * Any other frame is logged and reported with `HANDSHAKE_UNKNOWN_FRAME_CODE`.
///
/// `_server_version` is currently unused.
pub(crate) fn dispatch_unsolicited_message(_server_version: i32, message: &mut ResponseMessage, ctx: &StartupHandshakeContext<'_>) {
    use crate::accounts::common::decode_account_update_message;
    use crate::orders::common::{decode_commission_report, decode_completed_order, decode_execution_data, decode_open_order, decode_order_status};

    /// Decodes a payload-bearing frame and passes it to the startup callback.
    fn forward_decoded<T>(
        ctx: &StartupHandshakeContext<'_>,
        kind: IncomingMessages,
        decode: impl FnOnce() -> Result<T, Error>,
        wrap: impl FnOnce(T) -> StartupMessage,
    ) {
        // No listener means nothing to deliver, so the decode is skipped.
        let callback = match ctx.startup {
            Some(callback) => callback,
            None => return,
        };
        match decode() {
            Err(e) => ctx.notice_sink.deliver(Notice::synthesized(
                HANDSHAKE_DECODE_FAILURE_CODE,
                format!("handshake decoder failed for {kind:?}: {e}"),
            )),
            Ok(value) => callback(wrap(value)),
        }
    }

    /// Passes a payload-less marker to the startup callback, if one is set.
    fn forward_marker(ctx: &StartupHandshakeContext<'_>, marker: StartupMessage) {
        if let Some(callback) = ctx.startup {
            callback(marker);
        }
    }

    let kind = message.message_type();
    match kind {
        IncomingMessages::Error => {
            let notice = Notice::from(&*message);
            let routine = notice.is_warning() || notice.is_system_message();
            if routine {
                info!("{notice}");
            } else {
                error!("Error during account info: {notice}");
            }
            ctx.notice_sink.deliver(notice);
        }

        // Payload-bearing frames.
        IncomingMessages::OpenOrder => forward_decoded(ctx, kind, || decode_open_order(message), StartupMessage::OpenOrder),
        IncomingMessages::OrderStatus => forward_decoded(ctx, kind, || decode_order_status(message), StartupMessage::OrderStatus),
        IncomingMessages::AccountValue
        | IncomingMessages::PortfolioValue
        | IncomingMessages::AccountUpdateTime
        | IncomingMessages::AccountDownloadEnd => {
            forward_decoded(ctx, kind, || decode_account_update_message(message), StartupMessage::AccountUpdate)
        }
        IncomingMessages::ExecutionData => forward_decoded(ctx, kind, || decode_execution_data(message), StartupMessage::Execution),
        IncomingMessages::CommissionsReport => {
            forward_decoded(ctx, kind, || decode_commission_report(message), StartupMessage::CommissionReport)
        }
        IncomingMessages::CompletedOrder => forward_decoded(ctx, kind, || decode_completed_order(message), StartupMessage::CompletedOrder),

        // End-of-stream markers.
        IncomingMessages::OpenOrderEnd => forward_marker(ctx, StartupMessage::OpenOrderEnd),
        IncomingMessages::ExecutionDataEnd => forward_marker(ctx, StartupMessage::ExecutionDataEnd),
        IncomingMessages::CompletedOrdersEnd => forward_marker(ctx, StartupMessage::CompletedOrdersEnd),

        _ => {
            warn!("unrouted handshake frame: {kind:?}");
            ctx.notice_sink.deliver(Notice::synthesized(
                HANDSHAKE_UNKNOWN_FRAME_CODE,
                format!("unsolicited handshake frame with no typed variant: {kind:?}"),
            ));
        }
    }
}
pub(crate) fn require_protobuf_support(server_version: i32) -> Result<(), Error> {
if server_version < server_versions::PROTOBUF_REST_MESSAGES_3 {
return Err(Error::ServerVersion(
server_versions::PROTOBUF_REST_MESSAGES_3,
server_version,
format!(
"protobuf transport — rust-ibapi 3.x requires TWS or IB Gateway with server version {} or later; please upgrade",
server_versions::PROTOBUF_REST_MESSAGES_3
),
));
}
Ok(())
}
pub fn parse_connection_time(connection_time: &str) -> Result<(Option<OffsetDateTime>, Option<&'static Tz>), Error> {
let parts: Vec<&str> = connection_time.split(' ').collect();
if parts.len() < 3 {
error!("Invalid connection time format: {connection_time}");
return Ok((None, None));
}
let tz_name = if parts.len() > 3 { parts[2..].join(" ") } else { parts[2].to_string() };
let zones = find_timezone(&tz_name);
if zones.is_empty() {
return Err(Error::UnsupportedTimeZone(tz_name));
}
let timezone = zones[0];
let format = format_description!("[year][month][day] [hour]:[minute]:[second]");
let date_str = format!("{} {}", parts[0], parts[1]);
let date = time::PrimitiveDateTime::parse(date_str.as_str(), format);
match date {
Ok(connected_at) => match connected_at.assume_timezone(timezone) {
OffsetResult::Some(date) => Ok((Some(date), Some(timezone))),
_ => {
log::warn!("Error setting timezone");
Ok((None, Some(timezone)))
}
},
Err(err) => {
log::warn!("Could not parse connection time from {date_str}: {err}");
Ok((None, Some(timezone)))
}
}
}
pub fn parse_raw_message(data: &[u8]) -> (ResponseMessage, Option<String>) {
let msg_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
if msg_id > PROTOBUF_MSG_ID {
let real_type = msg_id - PROTOBUF_MSG_ID;
debug!("<- protobuf msg_id={real_type}");
let message = ResponseMessage::from_protobuf(real_type, data[4..].to_vec());
(message, None)
} else {
let raw_string = String::from_utf8_lossy(&data[4..]).into_owned();
debug!("<- {raw_string:?}");
let mut fields = vec![msg_id.to_string()];
fields.extend(raw_string.split_terminator('\0').map(|s| s.to_string()));
let message = ResponseMessage {
i: 0,
fields,
raw_bytes: None,
};
(message, Some(raw_string))
}
}
// Unit tests for this module live in `common_tests.rs`, next to this file.
#[cfg(test)]
#[path = "common_tests.rs"]
mod tests;