1use crate::{BrokerClient, ExecutionError, OrderSide, OrderType, Result, Symbol, TimeInForce};
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::sync::{mpsc, oneshot};
19use tokio::time::timeout;
20use tracing::{debug, error, info, warn};
21use uuid::Uuid;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25#[serde(rename_all = "lowercase")]
26pub enum OrderStatus {
27 Pending,
29 Accepted,
31 PartiallyFilled,
33 Filled,
35 Cancelled,
37 Rejected,
39 Expired,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct OrderRequest {
46 pub symbol: Symbol,
47 pub side: OrderSide,
48 pub order_type: OrderType,
49 pub quantity: u32,
50 pub limit_price: Option<Decimal>,
51 pub stop_price: Option<Decimal>,
52 pub time_in_force: TimeInForce,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct OrderResponse {
58 pub order_id: String,
59 pub client_order_id: String,
60 pub status: OrderStatus,
61 pub filled_qty: u32,
62 pub filled_avg_price: Option<Decimal>,
63 pub submitted_at: DateTime<Utc>,
64 pub filled_at: Option<DateTime<Utc>>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct OrderUpdate {
70 pub order_id: String,
71 pub status: OrderStatus,
72 pub filled_qty: u32,
73 pub filled_avg_price: Option<Decimal>,
74 pub timestamp: DateTime<Utc>,
75}
76
77#[derive(Debug, Clone)]
79struct TrackedOrder {
80 request: OrderRequest,
81 response: Option<OrderResponse>,
82 status: OrderStatus,
83 created_at: DateTime<Utc>,
84 updated_at: DateTime<Utc>,
85}
86
87enum OrderMessage {
89 PlaceOrder {
90 request: OrderRequest,
91 response_tx: oneshot::Sender<Result<OrderResponse>>,
92 },
93 CancelOrder {
94 order_id: String,
95 response_tx: oneshot::Sender<Result<()>>,
96 },
97 GetOrderStatus {
98 order_id: String,
99 response_tx: oneshot::Sender<Result<OrderStatus>>,
100 },
101 UpdateOrder {
102 update: OrderUpdate,
103 },
104 Shutdown,
105}
106
107pub struct OrderManager {
109 message_tx: mpsc::Sender<OrderMessage>,
110 orders: Arc<DashMap<String, TrackedOrder>>,
111}
112
113impl OrderManager {
114 pub fn new<B: BrokerClient + 'static>(broker: Arc<B>) -> Self {
116 let (message_tx, message_rx) = mpsc::channel(1000);
117 let orders = Arc::new(DashMap::new());
118
119 let orders_clone = Arc::clone(&orders);
121 tokio::spawn(async move {
122 Self::actor_loop(broker, message_rx, orders_clone).await;
123 });
124
125 Self { message_tx, orders }
126 }
127
128 pub async fn place_order(&self, request: OrderRequest) -> Result<OrderResponse> {
132 let (response_tx, response_rx) = oneshot::channel();
133
134 self.message_tx
135 .send(OrderMessage::PlaceOrder {
136 request,
137 response_tx,
138 })
139 .await
140 .map_err(|e| ExecutionError::Order(format!("Failed to send message: {}", e)))?;
141
142 timeout(Duration::from_secs(10), response_rx)
144 .await
145 .map_err(|_| ExecutionError::Timeout)?
146 .map_err(|e| ExecutionError::Order(format!("Failed to receive response: {}", e)))?
147 }
148
149 pub async fn cancel_order(&self, order_id: String) -> Result<()> {
151 let (response_tx, response_rx) = oneshot::channel();
152
153 self.message_tx
154 .send(OrderMessage::CancelOrder {
155 order_id,
156 response_tx,
157 })
158 .await
159 .map_err(|e| ExecutionError::Order(format!("Failed to send message: {}", e)))?;
160
161 timeout(Duration::from_secs(5), response_rx)
162 .await
163 .map_err(|_| ExecutionError::Timeout)?
164 .map_err(|e| ExecutionError::Order(format!("Failed to receive response: {}", e)))?
165 }
166
167 pub async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus> {
169 if let Some(order) = self.orders.get(order_id) {
171 return Ok(order.status);
172 }
173
174 let (response_tx, response_rx) = oneshot::channel();
175
176 self.message_tx
177 .send(OrderMessage::GetOrderStatus {
178 order_id: order_id.to_string(),
179 response_tx,
180 })
181 .await
182 .map_err(|e| ExecutionError::Order(format!("Failed to send message: {}", e)))?;
183
184 timeout(Duration::from_secs(5), response_rx)
185 .await
186 .map_err(|_| ExecutionError::Timeout)?
187 .map_err(|e| ExecutionError::Order(format!("Failed to receive response: {}", e)))?
188 }
189
190 pub async fn handle_order_update(&self, update: OrderUpdate) -> Result<()> {
192 self.message_tx
193 .send(OrderMessage::UpdateOrder { update })
194 .await
195 .map_err(|e| ExecutionError::Order(format!("Failed to send update: {}", e)))?;
196
197 Ok(())
198 }
199
200 pub fn get_all_orders(&self) -> Vec<(String, OrderStatus)> {
202 self.orders
203 .iter()
204 .map(|entry| (entry.key().clone(), entry.value().status))
205 .collect()
206 }
207
208 pub async fn shutdown(&self) -> Result<()> {
210 self.message_tx
211 .send(OrderMessage::Shutdown)
212 .await
213 .map_err(|e| ExecutionError::Order(format!("Failed to send shutdown: {}", e)))?;
214
215 Ok(())
216 }
217
218 async fn actor_loop<B: BrokerClient + 'static>(
220 broker: Arc<B>,
221 mut message_rx: mpsc::Receiver<OrderMessage>,
222 orders: Arc<DashMap<String, TrackedOrder>>,
223 ) {
224 info!("Order manager actor started");
225
226 while let Some(message) = message_rx.recv().await {
227 match message {
228 OrderMessage::PlaceOrder {
229 request,
230 response_tx,
231 } => {
232 let result =
233 Self::handle_place_order(Arc::clone(&broker), &orders, request).await;
234 let _ = response_tx.send(result);
235 }
236
237 OrderMessage::CancelOrder {
238 order_id,
239 response_tx,
240 } => {
241 let result =
242 Self::handle_cancel_order(Arc::clone(&broker), &orders, &order_id).await;
243 let _ = response_tx.send(result);
244 }
245
246 OrderMessage::GetOrderStatus {
247 order_id,
248 response_tx,
249 } => {
250 let result =
251 Self::handle_get_status(Arc::clone(&broker), &orders, &order_id).await;
252 let _ = response_tx.send(result);
253 }
254
255 OrderMessage::UpdateOrder { update } => {
256 Self::handle_order_update_internal(&orders, update);
257 }
258
259 OrderMessage::Shutdown => {
260 info!("Order manager actor shutting down");
261 break;
262 }
263 }
264 }
265
266 info!("Order manager actor stopped");
267 }
268
269 async fn handle_place_order<B: BrokerClient + 'static>(
270 broker: Arc<B>,
271 orders: &Arc<DashMap<String, TrackedOrder>>,
272 request: OrderRequest,
273 ) -> Result<OrderResponse> {
274 debug!("Placing order: {:?}", request);
275
276 let response = retry_with_backoff(
278 || {
279 let broker = Arc::clone(&broker);
280 let req = request.clone();
281 Box::pin(async move { broker.place_order(req).await })
282 },
283 3,
284 Duration::from_millis(100),
285 )
286 .await?;
287
288 info!(
289 "Order placed: {} status={:?}",
290 response.order_id, response.status
291 );
292
293 orders.insert(
295 response.order_id.clone(),
296 TrackedOrder {
297 request: request.clone(),
298 response: Some(response.clone()),
299 status: response.status,
300 created_at: Utc::now(),
301 updated_at: Utc::now(),
302 },
303 );
304
305 Ok(response)
306 }
307
308 async fn handle_cancel_order<B: BrokerClient>(
309 broker: Arc<B>,
310 orders: &Arc<DashMap<String, TrackedOrder>>,
311 order_id: &str,
312 ) -> Result<()> {
313 debug!("Cancelling order: {}", order_id);
314
315 broker.cancel_order(order_id).await?;
316
317 if let Some(mut order) = orders.get_mut(order_id) {
319 order.status = OrderStatus::Cancelled;
320 order.updated_at = Utc::now();
321 }
322
323 info!("Order cancelled: {}", order_id);
324 Ok(())
325 }
326
327 async fn handle_get_status<B: BrokerClient>(
328 broker: Arc<B>,
329 orders: &Arc<DashMap<String, TrackedOrder>>,
330 order_id: &str,
331 ) -> Result<OrderStatus> {
332 if let Some(order) = orders.get(order_id) {
334 return Ok(order.status);
335 }
336
337 let order = broker.get_order(order_id).await?;
339
340 if let Some(mut tracked) = orders.get_mut(order_id) {
342 tracked.status = order.status;
343 tracked.updated_at = Utc::now();
344 }
345
346 Ok(order.status)
347 }
348
349 fn handle_order_update_internal(orders: &Arc<DashMap<String, TrackedOrder>>, update: OrderUpdate) {
350 if let Some(mut order) = orders.get_mut(&update.order_id) {
351 order.status = update.status;
352 order.updated_at = update.timestamp;
353
354 if let Some(ref mut response) = order.response {
355 response.status = update.status;
356 response.filled_qty = update.filled_qty;
357 response.filled_avg_price = update.filled_avg_price;
358 }
359
360 debug!(
361 "Order updated: {} status={:?} filled={}",
362 update.order_id, update.status, update.filled_qty
363 );
364 } else {
365 warn!("Received update for unknown order: {}", update.order_id);
366 }
367 }
368}
369
370async fn retry_with_backoff<F, T, E>(
372 mut f: F,
373 max_attempts: u32,
374 initial_delay: Duration,
375) -> Result<T>
376where
377 F: FnMut() -> std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<T, E>> + Send>>,
378 E: Into<ExecutionError>,
379{
380 let mut delay = initial_delay;
381
382 for attempt in 1..=max_attempts {
383 match f().await {
384 Ok(result) => return Ok(result),
385 Err(e) if attempt == max_attempts => {
386 error!("All {} retry attempts failed", max_attempts);
387 return Err(e.into());
388 }
389 Err(e) => {
390 warn!(
391 "Attempt {} failed, retrying in {:?}...",
392 attempt, delay
393 );
394 tokio::time::sleep(delay).await;
395 delay *= 2; }
397 }
398 }
399
400 unreachable!()
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the market-buy request shared by the tests below.
    fn sample_request() -> OrderRequest {
        OrderRequest {
            symbol: Symbol::new("AAPL").expect("Valid symbol"),
            side: OrderSide::Buy,
            order_type: OrderType::Market,
            quantity: 100,
            limit_price: None,
            stop_price: None,
            time_in_force: TimeInForce::Day,
        }
    }

    // Pure (de)serialization needs no async runtime, so a plain `#[test]` suffices.
    #[test]
    fn test_order_request_serialization() {
        let request = sample_request();

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: OrderRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(request.symbol, deserialized.symbol);
        assert_eq!(request.quantity, deserialized.quantity);
        assert_eq!(request.limit_price, deserialized.limit_price);
        assert_eq!(request.stop_price, deserialized.stop_price);
    }

    #[test]
    fn test_order_status_wire_format() {
        // `rename_all = "lowercase"` joins multi-word variants without a separator.
        let cases = [
            (OrderStatus::Pending, "\"pending\""),
            (OrderStatus::Accepted, "\"accepted\""),
            (OrderStatus::PartiallyFilled, "\"partiallyfilled\""),
            (OrderStatus::Filled, "\"filled\""),
            (OrderStatus::Cancelled, "\"cancelled\""),
            (OrderStatus::Rejected, "\"rejected\""),
            (OrderStatus::Expired, "\"expired\""),
        ];
        for &(status, expected) in &cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), expected);
            assert_eq!(serde_json::from_str::<OrderStatus>(expected).unwrap(), status);
        }
    }

    #[test]
    fn test_order_update_is_applied_to_tracked_order() {
        let orders: Arc<DashMap<String, TrackedOrder>> = Arc::new(DashMap::new());
        let now = Utc::now();
        orders.insert(
            "order-1".to_string(),
            TrackedOrder {
                request: sample_request(),
                response: Some(OrderResponse {
                    order_id: "order-1".to_string(),
                    client_order_id: "client-1".to_string(),
                    status: OrderStatus::Accepted,
                    filled_qty: 0,
                    filled_avg_price: None,
                    submitted_at: now,
                    filled_at: None,
                }),
                status: OrderStatus::Accepted,
                created_at: now,
                updated_at: now,
            },
        );

        let price = Decimal::new(15025, 2);
        OrderManager::handle_order_update_internal(
            &orders,
            OrderUpdate {
                order_id: "order-1".to_string(),
                status: OrderStatus::PartiallyFilled,
                filled_qty: 40,
                filled_avg_price: Some(price),
                timestamp: now,
            },
        );

        {
            let tracked = orders.get("order-1").expect("order is tracked");
            assert_eq!(tracked.status, OrderStatus::PartiallyFilled);
            assert_eq!(tracked.updated_at, now);
            let response = tracked.response.as_ref().expect("response is cached");
            assert_eq!(response.status, OrderStatus::PartiallyFilled);
            assert_eq!(response.filled_qty, 40);
            assert_eq!(response.filled_avg_price, Some(price));
        }

        // Updates for ids that were never placed are ignored rather than inserted.
        OrderManager::handle_order_update_internal(
            &orders,
            OrderUpdate {
                order_id: "unknown".to_string(),
                status: OrderStatus::Filled,
                filled_qty: 1,
                filled_avg_price: None,
                timestamp: now,
            },
        );
        assert!(!orders.contains_key("unknown"));
        assert_eq!(orders.len(), 1);
    }
}