// drm_core/strategy/order_tracker.rs

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use chrono::{DateTime, Utc};

use crate::models::{Order, OrderStatus};

/// Lifecycle events an order can go through.
///
/// [`OrderTracker`] currently emits only [`PartialFill`](Self::PartialFill),
/// [`Filled`](Self::Filled) and [`Cancelled`](Self::Cancelled); the remaining variants exist so
/// callbacks can share one event type with other order sources.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OrderEvent {
    /// The order was created.
    Created,
    /// A trade filled part, but not all, of the order's size.
    PartialFill,
    /// Cumulative fills reached the order's full size.
    Filled,
    /// The order was cancelled.
    Cancelled,
    /// The order was rejected.
    Rejected,
    /// The order expired.
    Expired,
}

18#[derive(Debug, Clone)]
19pub struct TrackedOrder {
20    pub order: Order,
21    pub total_filled: f64,
22    pub created_time: DateTime<Utc>,
23}
24
25impl TrackedOrder {
26    pub fn new(order: Order) -> Self {
27        let filled = order.filled;
28        Self {
29            order,
30            total_filled: filled,
31            created_time: Utc::now(),
32        }
33    }
34}
35
36pub type OrderCallback = Arc<dyn Fn(OrderEvent, &Order, f64) + Send + Sync>;
37
38pub struct OrderTracker {
39    tracked_orders: RwLock<HashMap<String, TrackedOrder>>,
40    callbacks: RwLock<Vec<OrderCallback>>,
41    verbose: bool,
42}
43
44impl OrderTracker {
45    pub fn new(verbose: bool) -> Self {
46        Self {
47            tracked_orders: RwLock::new(HashMap::new()),
48            callbacks: RwLock::new(Vec::new()),
49            verbose,
50        }
51    }
52
53    pub fn on_fill<F>(&self, callback: F) -> &Self
54    where
55        F: Fn(OrderEvent, &Order, f64) + Send + Sync + 'static,
56    {
57        let mut callbacks = self.callbacks.write().unwrap();
58        callbacks.push(Arc::new(callback));
59        self
60    }
61
62    pub fn track_order(&self, order: Order) {
63        let order_id = order.id.clone();
64        let mut tracked = self.tracked_orders.write().unwrap();
65
66        if tracked.contains_key(&order_id) {
67            return;
68        }
69
70        if self.verbose {
71            let id_preview = if order_id.len() > 16 {
72                &order_id[..16]
73            } else {
74                &order_id
75            };
76            println!("Tracking order {id_preview}...");
77        }
78
79        tracked.insert(order_id, TrackedOrder::new(order));
80    }
81
82    pub fn untrack_order(&self, order_id: &str) {
83        let mut tracked = self.tracked_orders.write().unwrap();
84        tracked.remove(order_id);
85    }
86
87    pub fn handle_trade(
88        &self,
89        order_id: &str,
90        fill_size: f64,
91        fill_price: f64,
92        market_id: Option<&str>,
93        outcome: Option<&str>,
94    ) {
95        let (event, updated_order) = {
96            let mut tracked = self.tracked_orders.write().unwrap();
97
98            let tracked_order = match tracked.get_mut(order_id) {
99                Some(t) => t,
100                None => return,
101            };
102
103            tracked_order.total_filled += fill_size;
104
105            let is_complete = tracked_order.total_filled >= tracked_order.order.size;
106            let new_status = if is_complete {
107                OrderStatus::Filled
108            } else {
109                OrderStatus::PartiallyFilled
110            };
111
112            let updated_order = Order {
113                id: tracked_order.order.id.clone(),
114                market_id: market_id
115                    .map(|s| s.to_string())
116                    .unwrap_or_else(|| tracked_order.order.market_id.clone()),
117                outcome: outcome
118                    .map(|s| s.to_string())
119                    .unwrap_or_else(|| tracked_order.order.outcome.clone()),
120                side: tracked_order.order.side,
121                price: fill_price,
122                size: tracked_order.order.size,
123                filled: tracked_order.total_filled,
124                status: new_status,
125                created_at: tracked_order.order.created_at,
126                updated_at: Some(Utc::now()),
127            };
128
129            tracked_order.order = updated_order.clone();
130
131            let event = if is_complete {
132                OrderEvent::Filled
133            } else {
134                OrderEvent::PartialFill
135            };
136
137            (event, updated_order)
138        };
139
140        self.emit(event, &updated_order, fill_size);
141
142        if event == OrderEvent::Filled {
143            self.untrack_order(order_id);
144        }
145    }
146
147    pub fn handle_cancel(&self, order_id: &str) {
148        let order = {
149            let tracked = self.tracked_orders.read().unwrap();
150            tracked.get(order_id).map(|t| t.order.clone())
151        };
152
153        if let Some(order) = order {
154            self.emit(OrderEvent::Cancelled, &order, 0.0);
155            self.untrack_order(order_id);
156        }
157    }
158
159    fn emit(&self, event: OrderEvent, order: &Order, fill_size: f64) {
160        let callbacks = self.callbacks.read().unwrap();
161        for callback in callbacks.iter() {
162            callback(event, order, fill_size);
163        }
164    }
165
166    pub fn tracked_count(&self) -> usize {
167        self.tracked_orders.read().unwrap().len()
168    }
169
170    pub fn get_tracked_orders(&self) -> Vec<Order> {
171        self.tracked_orders
172            .read()
173            .unwrap()
174            .values()
175            .map(|t| t.order.clone())
176            .collect()
177    }
178
179    pub fn clear(&self) {
180        self.tracked_orders.write().unwrap().clear();
181    }
182}
183
184impl Default for OrderTracker {
185    fn default() -> Self {
186        Self::new(false)
187    }
188}
189
190pub fn create_fill_logger() -> impl Fn(OrderEvent, &Order, f64) + Send + Sync {
191    move |event: OrderEvent, order: &Order, fill_size: f64| {
192        let side_str = format!("{:?}", order.side).to_uppercase();
193
194        match event {
195            OrderEvent::Filled => {
196                println!(
197                    "FILLED {} {} {:.2} @ {:.4}",
198                    order.outcome, side_str, fill_size, order.price
199                );
200            }
201            OrderEvent::PartialFill => {
202                println!(
203                    "PARTIAL {} {} +{:.2} ({:.2}/{:.2}) @ {:.4}",
204                    order.outcome, side_str, fill_size, order.filled, order.size, order.price
205                );
206            }
207            OrderEvent::Cancelled => {
208                println!(
209                    "CANCELLED {} {} {:.2} @ {:.4} (filled: {:.2})",
210                    order.outcome, side_str, order.size, order.price, order.filled
211                );
212            }
213            _ => {}
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::OrderSide;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Every event delivered to a callback, with the order snapshot and fill size it carried.
    type EventLog = Arc<Mutex<Vec<(OrderEvent, Order, f64)>>>;

    fn make_test_order(id: &str, size: f64) -> Order {
        Order {
            id: id.to_string(),
            market_id: "test-market".to_string(),
            outcome: "Yes".to_string(),
            side: OrderSide::Buy,
            price: 0.50,
            size,
            filled: 0.0,
            status: OrderStatus::Open,
            created_at: Utc::now(),
            updated_at: None,
        }
    }

    /// Registers a callback on `tracker` that appends every event to the returned log.
    fn record_events(tracker: &OrderTracker) -> EventLog {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        tracker.on_fill(move |event, order, fill_size| {
            sink.lock().unwrap().push((event, order.clone(), fill_size));
        });
        log
    }

    #[test]
    fn test_track_order() {
        // given
        let tracker = OrderTracker::new(false);
        let order = make_test_order("order-1", 10.0);

        // when
        tracker.track_order(order);

        // then
        assert_eq!(tracker.tracked_count(), 1);
    }

    #[test]
    fn test_partial_fill() {
        // given
        let tracker = OrderTracker::new(false);
        let order = make_test_order("order-1", 10.0);
        tracker.track_order(order);

        let fill_count = Arc::new(AtomicUsize::new(0));
        let fill_count_clone = fill_count.clone();

        tracker.on_fill(move |event, _, _| {
            if event == OrderEvent::PartialFill {
                fill_count_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // when
        tracker.handle_trade("order-1", 3.0, 0.50, None, None);

        // then
        assert_eq!(fill_count.load(Ordering::SeqCst), 1);
        assert_eq!(tracker.tracked_count(), 1);
    }

    #[test]
    fn test_complete_fill() {
        // given
        let tracker = OrderTracker::new(false);
        let order = make_test_order("order-1", 10.0);
        tracker.track_order(order);

        let filled = Arc::new(AtomicUsize::new(0));
        let filled_clone = filled.clone();

        tracker.on_fill(move |event, _, _| {
            if event == OrderEvent::Filled {
                filled_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // when
        tracker.handle_trade("order-1", 10.0, 0.50, None, None);

        // then
        assert_eq!(filled.load(Ordering::SeqCst), 1);
        assert_eq!(tracker.tracked_count(), 0);
    }

    #[test]
    fn test_cancel() {
        // given
        let tracker = OrderTracker::new(false);
        let order = make_test_order("order-1", 10.0);
        tracker.track_order(order);

        let cancelled = Arc::new(AtomicUsize::new(0));
        let cancelled_clone = cancelled.clone();

        tracker.on_fill(move |event, _, _| {
            if event == OrderEvent::Cancelled {
                cancelled_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // when
        tracker.handle_cancel("order-1");

        // then
        assert_eq!(cancelled.load(Ordering::SeqCst), 1);
        assert_eq!(tracker.tracked_count(), 0);
    }

    #[test]
    fn test_duplicate_track_keeps_existing_state() {
        // given
        let tracker = OrderTracker::new(false);
        tracker.track_order(make_test_order("order-1", 10.0));
        let log = record_events(&tracker);
        tracker.handle_trade("order-1", 3.0, 0.50, None, None);

        // when: re-tracking must not reset the accumulated fill
        tracker.track_order(make_test_order("order-1", 10.0));
        tracker.handle_trade("order-1", 7.0, 0.50, None, None);

        // then
        let events: Vec<OrderEvent> = log.lock().unwrap().iter().map(|(e, _, _)| *e).collect();
        assert_eq!(events, vec![OrderEvent::PartialFill, OrderEvent::Filled]);
        assert_eq!(tracker.tracked_count(), 0);
    }

    #[test]
    fn test_unknown_order_is_ignored() {
        // given
        let tracker = OrderTracker::new(false);
        tracker.track_order(make_test_order("order-1", 10.0));
        let log = record_events(&tracker);

        // when
        tracker.handle_trade("missing", 5.0, 0.50, None, None);
        tracker.handle_cancel("missing");

        // then
        assert!(log.lock().unwrap().is_empty());
        assert_eq!(tracker.tracked_count(), 1);
    }

    #[test]
    fn test_partial_then_complete_fill_snapshots() {
        // given
        let tracker = OrderTracker::new(false);
        tracker.track_order(make_test_order("order-1", 10.0));
        let log = record_events(&tracker);

        // when
        tracker.handle_trade("order-1", 4.0, 0.48, None, None);
        tracker.handle_trade("order-1", 6.0, 0.52, Some("market-2"), Some("No"));

        // then
        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 2);

        let (first_event, first_order, first_size) = &entries[0];
        assert_eq!(*first_event, OrderEvent::PartialFill);
        assert_eq!(*first_size, 4.0);
        assert_eq!(first_order.filled, 4.0);
        assert_eq!(first_order.price, 0.48);
        assert_eq!(first_order.market_id, "test-market");
        assert_eq!(first_order.outcome, "Yes");
        assert!(matches!(first_order.status, OrderStatus::PartiallyFilled));
        assert!(first_order.updated_at.is_some());

        let (second_event, second_order, second_size) = &entries[1];
        assert_eq!(*second_event, OrderEvent::Filled);
        assert_eq!(*second_size, 6.0);
        assert_eq!(second_order.filled, 10.0);
        assert_eq!(second_order.price, 0.52);
        assert_eq!(second_order.market_id, "market-2");
        assert_eq!(second_order.outcome, "No");
        assert!(matches!(second_order.status, OrderStatus::Filled));
        drop(entries);

        assert_eq!(tracker.tracked_count(), 0);
    }

    #[test]
    fn test_get_tracked_orders_reflects_fills() {
        // given
        let tracker = OrderTracker::new(false);
        tracker.track_order(make_test_order("order-1", 10.0));
        tracker.track_order(make_test_order("order-2", 5.0));

        // when
        tracker.handle_trade("order-1", 3.0, 0.50, None, None);

        // then
        let mut orders = tracker.get_tracked_orders();
        orders.sort_by(|a, b| a.id.cmp(&b.id));
        assert_eq!(orders.len(), 2);
        assert_eq!(orders[0].id, "order-1");
        assert_eq!(orders[0].filled, 3.0);
        assert!(matches!(orders[0].status, OrderStatus::PartiallyFilled));
        assert_eq!(orders[1].id, "order-2");
        assert_eq!(orders[1].filled, 0.0);
        assert!(matches!(orders[1].status, OrderStatus::Open));
    }

    #[test]
    fn test_clear() {
        // given
        let tracker = OrderTracker::new(false);
        tracker.track_order(make_test_order("order-1", 10.0));
        tracker.track_order(make_test_order("order-2", 5.0));

        // when
        tracker.clear();

        // then
        assert_eq!(tracker.tracked_count(), 0);
        assert!(tracker.get_tracked_orders().is_empty());
    }
}