use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rustrade_core::{ExchangeClient, MetricsSink, Order, OrderKind, Symbol};
use rustrade_supervisor::{RestartPolicy, TradingService};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
/// A non-market order registered with [`OrderTracker`] so the reaper can
/// reconcile it against the exchange and cancel it once it exceeds its TTL.
#[derive(Debug, Clone, PartialEq)]
pub struct TrackedOrder {
/// Order id as passed to `OrderTracker::record`; matched against the ids the
/// exchange reports from `get_open_orders` and used for `cancel_order`.
pub order_id: String,
/// Symbol the order was placed on.
pub symbol: Symbol,
/// Local UTC time at which the order was recorded. Used as the age reference
/// only when the exchange does not report a creation time for the order.
pub placed_at: DateTime<Utc>,
}
/// Shared registry of resting (non-market) orders, keyed by order id.
///
/// Cloning is cheap and every clone shares the same underlying map (it lives
/// behind an `Arc`). One handle can record orders while another (e.g. the
/// [`OrderReaperService`]) sweeps them.
#[derive(Clone, Default)]
pub struct OrderTracker {
// order_id -> tracked entry; tokio RwLock so guards may be held across awaits.
inner: Arc<RwLock<HashMap<String, TrackedOrder>>>,
}
impl OrderTracker {
pub fn new() -> Self {
Self::default()
}
pub(crate) async fn record(&self, order_id: String, order: &Order) {
if matches!(order.kind, OrderKind::Market) {
return;
}
self.inner.write().await.insert(
order_id.clone(),
TrackedOrder {
order_id,
symbol: order.symbol.clone(),
placed_at: Utc::now(),
},
);
}
pub(crate) async fn forget(&self, order_id: &str) {
self.inner.write().await.remove(order_id);
}
pub async fn len(&self) -> usize {
self.inner.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.inner.read().await.is_empty()
}
pub async fn snapshot(&self) -> Vec<TrackedOrder> {
self.inner.read().await.values().cloned().collect()
}
}
/// Periodic service that reconciles [`OrderTracker`] entries against the
/// exchange's open orders and cancels those resting longer than a TTL.
pub struct OrderReaperService {
/// Client used to list open orders and cancel stale ones.
exchange: Arc<dyn ExchangeClient>,
/// Shared registry of orders to watch.
tracker: OrderTracker,
/// Symbols visited on every sweep; tracked orders on any other symbol are
/// never reconciled or cancelled by this service.
symbols: Vec<Symbol>,
/// Minimum resting age at which an order is cancelled.
ttl: Duration,
/// Sleep between consecutive sweeps in `run`.
poll_cadence: Duration,
/// Sink for the `rustrade_orders_cancelled_ttl_total` counter.
metrics: Arc<dyn MetricsSink>,
/// Orders cancelled for exceeding `ttl`.
cancelled: AtomicU64,
/// Tracked orders dropped because the exchange no longer listed them as open.
reconciled: AtomicU64,
/// Sweeps started.
sweeps: AtomicU64,
}
impl OrderReaperService {
pub(crate) fn new(
exchange: Arc<dyn ExchangeClient>,
tracker: OrderTracker,
symbols: Vec<Symbol>,
ttl: Duration,
poll_cadence: Duration,
metrics: Arc<dyn MetricsSink>,
) -> Self {
Self {
exchange,
tracker,
symbols,
ttl,
poll_cadence,
metrics,
cancelled: AtomicU64::new(0),
reconciled: AtomicU64::new(0),
sweeps: AtomicU64::new(0),
}
}
pub fn cancelled(&self) -> u64 {
self.cancelled.load(Ordering::Relaxed)
}
pub fn reconciled(&self) -> u64 {
self.reconciled.load(Ordering::Relaxed)
}
pub fn sweeps(&self) -> u64 {
self.sweeps.load(Ordering::Relaxed)
}
pub(crate) async fn sweep_once(&self) {
self.sweeps.fetch_add(1, Ordering::Relaxed);
let now = Utc::now();
for symbol in &self.symbols {
let open = match self.exchange.get_open_orders(symbol).await {
Ok(o) => o,
Err(e) => {
tracing::warn!(symbol = %symbol, error = %e, "get_open_orders failed; skipping sweep for symbol");
continue;
}
};
let live: HashMap<&str, &rustrade_core::OpenOrder> =
open.iter().map(|o| (o.order_id.as_str(), o)).collect();
let tracked: Vec<TrackedOrder> = self
.tracker
.snapshot()
.await
.into_iter()
.filter(|t| &t.symbol == symbol)
.collect();
for t in tracked {
match live.get(t.order_id.as_str()) {
None => {
self.tracker.forget(&t.order_id).await;
self.reconciled.fetch_add(1, Ordering::Relaxed);
tracing::debug!(symbol = %symbol, order_id = %t.order_id, "reconciled away (no longer open)");
}
Some(oo) => {
let age_from = oo.created_at.unwrap_or(t.placed_at);
let age = now.signed_duration_since(age_from);
if age.num_milliseconds().max(0) as u128 >= self.ttl.as_millis() {
match self.exchange.cancel_order(symbol, &t.order_id).await {
Ok(_) => {
self.tracker.forget(&t.order_id).await;
self.cancelled.fetch_add(1, Ordering::Relaxed);
self.metrics.counter(
"rustrade_orders_cancelled_ttl_total",
&[("symbol", symbol.as_str())],
1,
);
tracing::info!(symbol = %symbol, order_id = %t.order_id, ttl_secs = self.ttl.as_secs(), "cancelled stale resting order (TTL)");
}
Err(e) => {
tracing::warn!(symbol = %symbol, order_id = %t.order_id, error = %e, "TTL cancel failed; will retry next sweep")
}
}
}
}
}
}
}
}
}
#[async_trait]
impl TradingService for OrderReaperService {
    /// Service name reported to the supervisor.
    fn name(&self) -> &str {
        "order-reaper"
    }

    /// The supervisor should restart this service only when it fails.
    fn restart_policy(&self) -> RestartPolicy {
        RestartPolicy::OnFailure
    }

    /// Repeatedly sleeps `poll_cadence` and then sweeps, until `cancel` fires.
    ///
    /// Cancellation is observed only during the sleep; a sweep already in
    /// progress runs to completion first. Always returns `Ok(())` after logging
    /// the final counters.
    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
        tracing::info!(
            ttl_secs = self.ttl.as_secs(),
            cadence_secs = self.poll_cadence.as_secs(),
            symbols = self.symbols.len(),
            "order-reaper starting"
        );
        loop {
            let stop = tokio::select! {
                _ = cancel.cancelled() => true,
                _ = tokio::time::sleep(self.poll_cadence) => false,
            };
            if stop {
                break;
            }
            self.sweep_once().await;
        }
        tracing::info!(
            sweeps = self.sweeps(),
            cancelled = self.cancelled(),
            reconciled = self.reconciled(),
            "order-reaper shutting down"
        );
        Ok(())
    }
}
// Unit tests for `OrderTracker` and `OrderReaperService::sweep_once`, driven by
// the in-memory `MockEx` exchange client defined below.
#[cfg(test)]
mod tests {
use super::*;
use rustrade_core::{Capability, NoopSink, Position, Price, Result, Side, Volume};
/// Buy limit order for 1.0 at 100.0 on `symbol`.
fn limit(symbol: &str) -> Order {
Order::limit(symbol, Side::Buy, Volume(1.0), Price(100.0))
}
/// Market orders are skipped by `record`; limit orders are tracked.
#[tokio::test]
async fn tracker_ignores_market_orders() {
let t = OrderTracker::new();
t.record(
"m1".into(),
&Order::market("BTCUSDT", Side::Buy, Volume(1.0)),
)
.await;
assert!(t.is_empty().await, "market orders must not be tracked");
t.record("l1".into(), &limit("BTCUSDT")).await;
assert_eq!(t.len().await, 1);
}
/// `forget` removes a previously recorded order.
#[tokio::test]
async fn tracker_forget_removes() {
let t = OrderTracker::new();
t.record("l1".into(), &limit("BTCUSDT")).await;
t.forget("l1").await;
assert!(t.is_empty().await);
}
/// Fake exchange. `open` is returned from every `get_open_orders` call
/// regardless of symbol. `cancels` records the id passed to each
/// `cancel_order` call, which always reports `Ok(true)`.
struct MockEx {
open: std::sync::Mutex<Vec<rustrade_core::OpenOrder>>,
cancels: std::sync::Mutex<Vec<String>>,
}
impl MockEx {
/// Wraps a mock with the given open-order list and no recorded cancels.
fn new(open: Vec<rustrade_core::OpenOrder>) -> Arc<Self> {
Arc::new(Self {
open: std::sync::Mutex::new(open),
cancels: std::sync::Mutex::new(Vec::new()),
})
}
}
// Only `get_open_orders` and `cancel_order` matter to the reaper; the other
// methods return fixed placeholder values.
#[async_trait]
impl ExchangeClient for MockEx {
fn name(&self) -> &str {
"mock"
}
async fn place_order(&self, _o: &Order) -> Result<String> {
Ok("x".into())
}
async fn cancel_all(&self, _s: &Symbol) -> Result<usize> {
Ok(0)
}
async fn close_position(&self, _s: &Symbol, _p: &Position) -> Result<String> {
Ok("c".into())
}
async fn get_position(&self, _s: &Symbol) -> Result<Position> {
Ok(Position::FLAT)
}
async fn get_balance(&self, _c: &str) -> Result<f64> {
Ok(0.0)
}
fn supports(&self, c: Capability) -> bool {
matches!(c, Capability::OrderTracking)
}
async fn get_open_orders(&self, _s: &Symbol) -> Result<Vec<rustrade_core::OpenOrder>> {
Ok(self.open.lock().unwrap().clone())
}
async fn cancel_order(&self, _s: &Symbol, order_id: &str) -> Result<bool> {
self.cancels.lock().unwrap().push(order_id.to_string());
Ok(true)
}
}
/// Open, unfilled BTCUSDT buy limit order (1.0 @ 100.0) with the given id and
/// exchange-reported creation time.
fn open_order(id: &str, created_at: Option<DateTime<Utc>>) -> rustrade_core::OpenOrder {
rustrade_core::OpenOrder {
order_id: id.into(),
client_id: None,
symbol: Symbol::from("BTCUSDT"),
side: Side::Buy,
kind: OrderKind::Limit,
limit_price: Some(Price(100.0)),
size: Volume(1.0),
filled: Volume(0.0),
status: rustrade_core::OrderStatus::Open,
created_at,
}
}
/// Reaper over `ex`/`tracker` that watches only BTCUSDT with the given TTL.
/// The 60s cadence is irrelevant here because tests call `sweep_once` directly.
fn reaper(ex: Arc<MockEx>, tracker: OrderTracker, ttl: Duration) -> OrderReaperService {
OrderReaperService::new(
ex,
tracker,
vec![Symbol::from("BTCUSDT")],
ttl,
Duration::from_secs(60),
Arc::new(NoopSink),
)
}
/// A tracked order the exchange no longer lists is forgotten without a cancel.
#[tokio::test]
async fn sweep_reconciles_away_vanished_order() {
let tracker = OrderTracker::new();
tracker.record("gone".into(), &limit("BTCUSDT")).await;
let ex = MockEx::new(vec![]);
let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));
svc.sweep_once().await;
assert!(
tracker.is_empty().await,
"vanished order should be reconciled away"
);
assert_eq!(svc.reconciled(), 1);
assert_eq!(svc.cancelled(), 0);
assert!(ex.cancels.lock().unwrap().is_empty());
}
/// An open order younger than the TTL stays tracked and is not cancelled.
#[tokio::test]
async fn sweep_keeps_fresh_resting_order() {
let tracker = OrderTracker::new();
tracker.record("fresh".into(), &limit("BTCUSDT")).await;
let ex = MockEx::new(vec![open_order("fresh", Some(Utc::now()))]);
let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));
svc.sweep_once().await;
assert_eq!(tracker.len().await, 1, "fresh order should remain tracked");
assert_eq!(svc.cancelled(), 0);
assert!(ex.cancels.lock().unwrap().is_empty());
}
/// An open order whose exchange creation time is older than the TTL is
/// cancelled exactly once and then forgotten.
#[tokio::test]
async fn sweep_cancels_order_past_ttl() {
let tracker = OrderTracker::new();
tracker.record("stale".into(), &limit("BTCUSDT")).await;
let created = Utc::now() - chrono::Duration::hours(1);
let ex = MockEx::new(vec![open_order("stale", Some(created))]);
let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(1));
svc.sweep_once().await;
assert_eq!(svc.cancelled(), 1, "stale order should be cancelled");
assert!(
tracker.is_empty().await,
"cancelled order should be forgotten"
);
assert_eq!(
ex.cancels.lock().unwrap().as_slice(),
&["stale".to_string()]
);
}
}