use crate::order::{OrderKind, TimeInForce, id::ClientOrderId};
use fnv::FnvHashMap;
use ibapi::orders::{Action, Order, TimeInForce as IbTimeInForce, order_builder};
use parking_lot::{Mutex, RwLock};
use rust_decimal::Decimal;
use rustrade_instrument::{Side, instrument::name::InstrumentNameExchange};
use std::{sync::Arc, time::Instant};
/// Parameters of an order, kept next to its id mapping inside [`OrderIdMap`].
///
/// The map stores one `OrderContext` per registered IB order id and hands out
/// clones of it on lookup.
#[derive(Debug, Clone)]
pub struct OrderContext {
/// Name of the instrument the order refers to.
pub instrument: InstrumentNameExchange,
/// Direction of the order (buy or sell).
pub side: Side,
/// Order price.
pub price: Decimal,
/// Order quantity.
pub quantity: Decimal,
/// Kind of order (e.g. market or limit).
pub kind: OrderKind,
/// Time-in-force policy requested for the order.
pub time_in_force: TimeInForce,
}
/// Shared, lock-protected mapping between [`ClientOrderId`]s and `i32` IB order ids.
///
/// The state lives behind an `Arc<RwLock<..>>`, so cloning this handle yields
/// another handle onto the same underlying map rather than a copy of the data.
#[derive(Debug, Clone)]
pub struct OrderIdMap {
// Shared state; lookups take the read lock, mutations take the write lock.
inner: Arc<RwLock<OrderIdMapInner>>,
}
/// Lock-protected state behind [`OrderIdMap`]: one lookup table per direction.
#[derive(Debug, Default)]
struct OrderIdMapInner {
// Forward lookup: client order id -> IB order id.
cid_to_ib: FnvHashMap<ClientOrderId, i32>,
// Reverse lookup: IB order id -> (client order id, order context, registration time).
// The `Instant` is what `clear_stale` compares against its `max_age`.
ib_to_entry: FnvHashMap<i32, (ClientOrderId, OrderContext, Instant)>,
}
impl OrderIdMap {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(OrderIdMapInner::default())),
}
}
pub fn register(&self, client_id: ClientOrderId, ib_id: i32, context: OrderContext) {
let mut inner = self.inner.write();
inner.cid_to_ib.insert(client_id.clone(), ib_id);
inner
.ib_to_entry
.insert(ib_id, (client_id, context, Instant::now()));
}
pub fn get_ib_id(&self, client_id: &ClientOrderId) -> Option<i32> {
self.inner.read().cid_to_ib.get(client_id).copied()
}
pub fn get_client_id(&self, ib_id: i32) -> Option<ClientOrderId> {
self.inner
.read()
.ib_to_entry
.get(&ib_id)
.map(|(cid, _, _)| cid.clone())
}
pub fn get_client_id_and_context(&self, ib_id: i32) -> Option<(ClientOrderId, OrderContext)> {
self.inner
.read()
.ib_to_entry
.get(&ib_id)
.map(|(cid, ctx, _)| (cid.clone(), ctx.clone()))
}
pub fn remove_and_get_context(&self, ib_id: i32) -> Option<(ClientOrderId, OrderContext)> {
let mut inner = self.inner.write();
if let Some((client_id, ctx, _)) = inner.ib_to_entry.remove(&ib_id) {
inner.cid_to_ib.remove(&client_id);
Some((client_id, ctx))
} else {
None
}
}
pub fn remove_by_ib_id(&self, ib_id: i32) -> Option<ClientOrderId> {
let mut inner = self.inner.write();
if let Some((client_id, _, _)) = inner.ib_to_entry.remove(&ib_id) {
inner.cid_to_ib.remove(&client_id);
Some(client_id)
} else {
None
}
}
pub fn clear_stale(&self, max_age: std::time::Duration) -> usize {
let mut inner = self.inner.write();
let before = inner.ib_to_entry.len();
let stale_ids: Vec<i32> = inner
.ib_to_entry
.iter()
.filter(|(_, (_, _, registered_at))| registered_at.elapsed() >= max_age)
.map(|(ib_id, _)| *ib_id)
.collect();
for ib_id in stale_ids {
if let Some((client_id, _, _)) = inner.ib_to_entry.remove(&ib_id) {
inner.cid_to_ib.remove(&client_id);
}
}
before - inner.ib_to_entry.len()
}
pub fn len(&self) -> usize {
self.inner.read().cid_to_ib.len()
}
pub fn is_empty(&self) -> bool {
self.inner.read().cid_to_ib.is_empty()
}
}
impl Default for OrderIdMap {
fn default() -> Self {
Self::new()
}
}
/// Shared, mutex-protected set of IB order ids, each stamped with its insertion time.
///
/// The state lives behind an `Arc<Mutex<..>>`, so cloning this handle yields
/// another handle onto the same underlying set.
#[derive(Debug, Clone)]
pub struct PendingCancels {
// IB order id -> time of the most recent `insert` for that id.
inner: Arc<Mutex<FnvHashMap<i32, Instant>>>,
}
impl PendingCancels {
    /// Creates an empty set with room pre-allocated for eight ids.
    pub fn new() -> Self {
        let table = FnvHashMap::with_capacity_and_hasher(8, Default::default());
        Self {
            inner: Arc::new(Mutex::new(table)),
        }
    }

    /// Adds `ib_id`, stamping it with the current time.
    ///
    /// Inserting an id that is already present keeps a single entry and
    /// refreshes its timestamp.
    pub fn insert(&self, ib_id: i32) {
        let mut pending = self.inner.lock();
        pending.insert(ib_id, Instant::now());
    }

    /// Removes `ib_id`; returns `true` if it was present.
    #[must_use]
    pub fn remove(&self, ib_id: i32) -> bool {
        let mut pending = self.inner.lock();
        pending.remove(&ib_id).is_some()
    }

    /// Drops every id inserted at least `max_age` ago and returns how many were dropped.
    #[must_use]
    pub fn clear_stale(&self, max_age: std::time::Duration) -> usize {
        let mut pending = self.inner.lock();
        let mut dropped = 0usize;
        pending.retain(|_, inserted_at| {
            let fresh = inserted_at.elapsed() < max_age;
            if !fresh {
                dropped += 1;
            }
            fresh
        });
        dropped
    }

    /// Number of ids currently held.
    #[must_use]
    pub fn len(&self) -> usize {
        let pending = self.inner.lock();
        pending.len()
    }

    /// Returns `true` when no id is held.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let pending = self.inner.lock();
        pending.is_empty()
    }
}
impl Default for PendingCancels {
fn default() -> Self {
Self::new()
}
}
pub fn side_to_action(side: rustrade_instrument::Side) -> Action {
match side {
rustrade_instrument::Side::Buy => Action::Buy,
rustrade_instrument::Side::Sell => Action::Sell,
}
}
/// Reasons an order cannot be translated into its IB representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrderMappingError {
    /// The order requested `post_only`, which has no IB equivalent here.
    PostOnlyNotSupported,
    /// The price could not be converted to `f64`; carries the price as text.
    InvalidPrice(String),
}

impl std::fmt::Display for OrderMappingError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PostOnlyNotSupported => f.write_str("post_only not supported by IB"),
            Self::InvalidPrice(raw) => {
                write!(f, "invalid price for f64 conversion: {raw}")
            }
        }
    }
}

impl std::error::Error for OrderMappingError {}
/// Translates a [`TimeInForce`] into the IB time-in-force value.
///
/// # Errors
///
/// Returns [`OrderMappingError::PostOnlyNotSupported`] for
/// `GoodUntilCancelled { post_only: true }`.
pub fn time_in_force_to_ib(tif: &TimeInForce) -> Result<IbTimeInForce, OrderMappingError> {
    match tif {
        TimeInForce::GoodUntilCancelled { post_only: true } => {
            Err(OrderMappingError::PostOnlyNotSupported)
        }
        TimeInForce::GoodUntilCancelled { post_only: false } => Ok(IbTimeInForce::GoodTilCanceled),
        TimeInForce::GoodUntilEndOfDay => Ok(IbTimeInForce::Day),
        TimeInForce::FillOrKill => Ok(IbTimeInForce::FillOrKill),
        TimeInForce::ImmediateOrCancel => Ok(IbTimeInForce::ImmediateOrCancel),
    }
}
/// Builds an IB [`Order`] from the given side, quantity, kind, price and time in force.
///
/// `price` is only consulted for limit orders; market orders ignore it.
///
/// # Errors
///
/// * [`OrderMappingError::PostOnlyNotSupported`] if `tif` requests `post_only`
///   (checked before anything else).
/// * [`OrderMappingError::InvalidPrice`] if a limit price can be converted to `f64`
///   neither directly nor by parsing its decimal text.
pub fn build_ib_order(
    side: rustrade_instrument::Side,
    quantity: f64,
    kind: &OrderKind,
    price: rust_decimal::Decimal,
    tif: &TimeInForce,
) -> Result<Order, OrderMappingError> {
    let ib_tif = time_in_force_to_ib(tif)?;
    let action = side_to_action(side);

    let mut order = match kind {
        OrderKind::Market => order_builder::market_order(action, quantity),
        OrderKind::Limit => {
            let direct: Result<f64, _> = price.try_into();
            let limit_price = match direct {
                Ok(value) => value,
                // Fall back to round-tripping through the decimal's text form.
                Err(_) => price
                    .to_string()
                    .parse::<f64>()
                    .map_err(|_| OrderMappingError::InvalidPrice(price.to_string()))?,
            };
            order_builder::limit_order(action, quantity, limit_price)
        }
    };

    order.tif = ib_tif;
    Ok(order)
}
// Unit tests for the id maps and the order conversion helpers.
#[cfg(test)]
// `unwrap` is acceptable here: a panic is exactly how a failing test should surface.
#[allow(clippy::unwrap_used)] mod tests {
use super::*;
use rust_decimal::Decimal;
// Shared fixture: a GTC (non post-only) limit buy of 100 AAPL at 150.
fn test_context() -> OrderContext {
OrderContext {
instrument: rustrade_instrument::instrument::name::InstrumentNameExchange::from("AAPL"),
side: Side::Buy,
price: Decimal::from(150),
quantity: Decimal::from(100),
kind: OrderKind::Limit,
time_in_force: TimeInForce::GoodUntilCancelled { post_only: false },
}
}
// A registered pair must be resolvable in both directions, with its context intact.
#[test]
fn test_order_id_map_basic() {
let map = OrderIdMap::new();
let cid = ClientOrderId::new("order-123");
let ctx = test_context();
map.register(cid.clone(), 42, ctx.clone());
assert_eq!(map.get_ib_id(&cid), Some(42));
assert_eq!(map.get_client_id(42), Some(cid.clone()));
assert_eq!(map.len(), 1);
let (retrieved_cid, retrieved_ctx) = map.get_client_id_and_context(42).unwrap();
assert_eq!(retrieved_cid, cid);
assert_eq!(retrieved_ctx.side, Side::Buy);
assert_eq!(retrieved_ctx.price, Decimal::from(150));
}
// Removing by IB id must clear both lookup directions.
#[test]
fn test_order_id_map_remove() {
let map = OrderIdMap::new();
let cid = ClientOrderId::new("order-456");
map.register(cid.clone(), 100, test_context());
assert_eq!(map.len(), 1);
let removed = map.remove_by_ib_id(100);
assert_eq!(removed, Some(cid.clone()));
assert!(map.is_empty());
assert!(map.get_ib_id(&cid).is_none());
assert!(map.get_client_id(100).is_none());
assert!(map.get_client_id_and_context(100).is_none());
}
// Each side maps onto the IB action of the same name.
#[test]
fn test_side_conversion() {
assert!(matches!(
side_to_action(rustrade_instrument::Side::Buy),
Action::Buy
));
assert!(matches!(
side_to_action(rustrade_instrument::Side::Sell),
Action::Sell
));
}
// Every time-in-force variant converts, except post-only GTC which is rejected.
#[test]
fn test_time_in_force_conversion() {
assert_eq!(
time_in_force_to_ib(&TimeInForce::GoodUntilCancelled { post_only: false }),
Ok(IbTimeInForce::GoodTilCanceled)
);
assert!(matches!(
time_in_force_to_ib(&TimeInForce::GoodUntilCancelled { post_only: true }),
Err(OrderMappingError::PostOnlyNotSupported)
));
assert_eq!(
time_in_force_to_ib(&TimeInForce::GoodUntilEndOfDay),
Ok(IbTimeInForce::Day)
);
assert_eq!(
time_in_force_to_ib(&TimeInForce::FillOrKill),
Ok(IbTimeInForce::FillOrKill)
);
assert_eq!(
time_in_force_to_ib(&TimeInForce::ImmediateOrCancel),
Ok(IbTimeInForce::ImmediateOrCancel)
);
}
// Market orders are built with order type "MKT"; the zero price passed here is not used.
#[test]
fn test_build_market_order() {
let order = build_ib_order(
rustrade_instrument::Side::Buy,
100.0,
&OrderKind::Market,
rust_decimal::Decimal::ZERO,
&TimeInForce::GoodUntilEndOfDay,
)
.unwrap();
assert_eq!(order.action, Action::Buy);
assert_eq!(order.total_quantity, 100.0);
assert_eq!(order.order_type, "MKT");
}
// Limit orders are built with order type "LMT".
#[test]
fn test_build_limit_order() {
let order = build_ib_order(
rustrade_instrument::Side::Sell,
50.0,
&OrderKind::Limit,
Decimal::try_from(150.5).unwrap(),
&TimeInForce::GoodUntilCancelled { post_only: false },
)
.unwrap();
assert_eq!(order.action, Action::Sell);
assert_eq!(order.total_quantity, 50.0);
assert_eq!(order.order_type, "LMT");
}
// Remove-and-get returns the stored pair once; a second call finds nothing.
#[test]
fn test_order_id_map_remove_and_get_context() {
let map = OrderIdMap::new();
let cid = ClientOrderId::new("order-789");
let ctx = test_context();
map.register(cid.clone(), 50, ctx);
assert_eq!(map.len(), 1);
let result = map.remove_and_get_context(50);
assert!(result.is_some());
let (retrieved_cid, retrieved_ctx) = result.unwrap();
assert_eq!(retrieved_cid, cid);
assert_eq!(retrieved_ctx.side, Side::Buy);
assert!(map.is_empty());
assert!(map.get_client_id(50).is_none());
assert!(map.get_ib_id(&cid).is_none());
assert!(map.remove_and_get_context(50).is_none());
}
// A zero max age treats every entry as stale; a one-hour max age keeps fresh entries.
#[test]
fn test_order_id_map_clear_stale() {
use std::time::Duration;
let map = OrderIdMap::new();
map.register(ClientOrderId::new("old-1"), 1, test_context());
map.register(ClientOrderId::new("old-2"), 2, test_context());
let cleared = map.clear_stale(Duration::ZERO);
assert_eq!(cleared, 2);
assert!(map.is_empty());
map.register(ClientOrderId::new("new-1"), 10, test_context());
map.register(ClientOrderId::new("new-2"), 20, test_context());
let cleared = map.clear_stale(Duration::from_secs(3600));
assert_eq!(cleared, 0);
assert_eq!(map.len(), 2);
}
// `remove` reports whether the id was present, so repeated or unknown removals return false.
#[test]
fn test_pending_cancels_insert_remove() {
let cancels = PendingCancels::new();
assert!(cancels.is_empty());
cancels.insert(42);
assert_eq!(cancels.len(), 1);
assert!(!cancels.is_empty());
assert!(cancels.remove(42));
assert!(cancels.is_empty());
assert!(!cancels.remove(42));
assert!(!cancels.remove(999));
}
// Several ids can be pending at once and are removed independently, in any order.
#[test]
fn test_pending_cancels_multiple() {
let cancels = PendingCancels::new();
cancels.insert(1);
cancels.insert(2);
cancels.insert(3);
assert_eq!(cancels.len(), 3);
assert!(cancels.remove(2));
assert_eq!(cancels.len(), 2);
assert!(cancels.remove(1));
assert!(cancels.remove(3));
assert!(cancels.is_empty());
}
// A zero max age treats every id as stale; a one-hour max age keeps fresh ids.
#[test]
fn test_pending_cancels_clear_stale() {
use std::time::Duration;
let cancels = PendingCancels::new();
cancels.insert(1);
cancels.insert(2);
let cleared = cancels.clear_stale(Duration::ZERO);
assert_eq!(cleared, 2);
assert!(cancels.is_empty());
cancels.insert(10);
cancels.insert(20);
let cleared = cancels.clear_stale(Duration::from_secs(3600));
assert_eq!(cleared, 0);
assert_eq!(cancels.len(), 2);
}
// Inserting the same id twice keeps a single entry.
#[test]
fn test_pending_cancels_duplicate_insert() {
let cancels = PendingCancels::new();
cancels.insert(42);
cancels.insert(42);
assert_eq!(cancels.len(), 1);
assert!(cancels.remove(42));
assert!(cancels.is_empty());
}
}