1use crate::{BrokerClient, BrokerError, ExecutionError, OrderRequest, OrderResponse, Result};
10use parking_lot::RwLock;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tracing::{debug, error, info, warn};
14
/// Policy used by `OrderRouter` to choose which broker receives an order.
///
/// The type is a plain field-less enum, so it also derives `PartialEq`/`Eq`
/// to let callers and tests compare strategies directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingStrategy {
    /// Rotate the starting broker on every order and walk the remaining
    /// brokers in order until one accepts the order.
    RoundRobin,
    /// Presumably meant to prefer the broker with the lowest fee.
    /// NOTE(review): `OrderRouter::route_order` currently routes this variant
    /// exactly like `RoundRobin`.
    LowestFee,
    /// Presumably meant to prefer the broker with the fastest execution.
    /// NOTE(review): `OrderRouter::route_order` currently routes this variant
    /// exactly like `RoundRobin`.
    FastestExecution,
    /// Always try the first registered broker first, then the remaining
    /// brokers in registration order.
    PrimaryWithFallback,
}
27
/// Internal state of a `CircuitBreaker`.
#[derive(Debug, Clone)]
enum CircuitState {
    /// Requests are allowed. `failure_count` is the number of failures
    /// recorded since the breaker was created or since the last recorded
    /// success (a success resets it to 0).
    Closed { failure_count: u32 },
    /// Requests are rejected. `opened_at` is the instant the breaker opened;
    /// `CircuitBreaker::try_reset` compares its age against the reset timeout.
    Open { opened_at: Instant },
    /// Trial state entered from `Open` by `CircuitBreaker::try_reset`.
    /// A recorded success closes the breaker; a recorded failure reopens it.
    HalfOpen,
}
35
/// Failure-counting circuit breaker with closed / open / half-open states.
struct CircuitBreaker {
    /// Current state, behind a lock so it can be updated through `&self`.
    state: RwLock<CircuitState>,
    /// Number of failures (since the last success) at which a closed breaker
    /// opens.
    failure_threshold: u32,
    /// Minimum time the breaker stays open before `try_reset` may move it to
    /// the half-open state.
    reset_timeout: Duration,
}
42
43impl CircuitBreaker {
44 fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
45 Self {
46 state: RwLock::new(CircuitState::Closed { failure_count: 0 }),
47 failure_threshold,
48 reset_timeout,
49 }
50 }
51
52 fn is_open(&self) -> bool {
53 let state = self.state.read();
54 matches!(*state, CircuitState::Open { .. })
55 }
56
57 fn record_success(&self) {
58 let mut state = self.state.write();
59 *state = CircuitState::Closed { failure_count: 0 };
60 }
61
62 fn record_failure(&self) {
63 let mut state = self.state.write();
64 match *state {
65 CircuitState::Closed { failure_count } => {
66 let new_count = failure_count + 1;
67 if new_count >= self.failure_threshold {
68 *state = CircuitState::Open {
69 opened_at: Instant::now(),
70 };
71 warn!("Circuit breaker opened after {} failures", new_count);
72 } else {
73 *state = CircuitState::Closed {
74 failure_count: new_count,
75 };
76 }
77 }
78 CircuitState::HalfOpen => {
79 *state = CircuitState::Open {
80 opened_at: Instant::now(),
81 };
82 warn!("Circuit breaker reopened after failure in half-open state");
83 }
84 CircuitState::Open { .. } => {}
85 }
86 }
87
88 fn try_reset(&self) -> bool {
89 let mut state = self.state.write();
90 if let CircuitState::Open { opened_at } = *state {
91 if opened_at.elapsed() >= self.reset_timeout {
92 *state = CircuitState::HalfOpen;
93 info!("Circuit breaker entering half-open state");
94 return true;
95 }
96 }
97 false
98 }
99}
100
/// A broker client paired with its own circuit breaker.
struct ProtectedBroker {
    /// The wrapped broker that actually places orders.
    broker: Arc<dyn BrokerClient>,
    /// Breaker updated with the outcome of every order sent to `broker`.
    circuit_breaker: CircuitBreaker,
    /// Label for this broker, used in log messages and status reports.
    name: String,
}
107
108impl ProtectedBroker {
109 fn new(broker: Arc<dyn BrokerClient>, name: String) -> Self {
110 Self {
111 broker,
112 circuit_breaker: CircuitBreaker::new(3, Duration::from_secs(30)),
113 name,
114 }
115 }
116
117 async fn place_order(&self, order: OrderRequest) -> Result<OrderResponse> {
118 if self.circuit_breaker.is_open() {
120 self.circuit_breaker.try_reset();
121 if self.circuit_breaker.is_open() {
122 return Err(ExecutionError::CircuitBreakerOpen);
123 }
124 }
125
126 match self.broker.place_order(order).await {
128 Ok(response) => {
129 self.circuit_breaker.record_success();
130 Ok(response)
131 }
132 Err(e) => {
133 self.circuit_breaker.record_failure();
134 Err(e.into())
135 }
136 }
137 }
138
139 fn is_available(&self) -> bool {
140 !self.circuit_breaker.is_open()
141 }
142}
143
/// Routes orders to one of several brokers according to a `RoutingStrategy`.
pub struct OrderRouter {
    /// Registered brokers in the order they were added; the first one is the
    /// primary for `RoutingStrategy::PrimaryWithFallback`.
    brokers: Vec<ProtectedBroker>,
    /// Strategy applied by `route_order`.
    strategy: RoutingStrategy,
    /// Index into `brokers` at which the next round-robin attempt starts.
    current_index: RwLock<usize>,
}
150
151impl OrderRouter {
152 pub fn new(strategy: RoutingStrategy) -> Self {
154 Self {
155 brokers: Vec::new(),
156 strategy,
157 current_index: RwLock::new(0),
158 }
159 }
160
161 pub fn add_broker(mut self, broker: Arc<dyn BrokerClient>, name: String) -> Self {
163 self.brokers.push(ProtectedBroker::new(broker, name));
164 self
165 }
166
167 pub async fn route_order(&self, order: OrderRequest) -> Result<OrderResponse> {
169 if self.brokers.is_empty() {
170 return Err(ExecutionError::Order(
171 "No brokers available".to_string(),
172 ));
173 }
174
175 match self.strategy {
176 RoutingStrategy::RoundRobin => self.route_round_robin(order).await,
177 RoutingStrategy::PrimaryWithFallback => self.route_primary_with_fallback(order).await,
178 RoutingStrategy::LowestFee | RoutingStrategy::FastestExecution => {
179 self.route_round_robin(order).await
182 }
183 }
184 }
185
186 async fn route_round_robin(&self, order: OrderRequest) -> Result<OrderResponse> {
188 let start_index = {
189 let mut index = self.current_index.write();
190 let current = *index;
191 *index = (current + 1) % self.brokers.len();
192 current
193 };
194
195 for i in 0..self.brokers.len() {
197 let broker_index = (start_index + i) % self.brokers.len();
198 let broker = &self.brokers[broker_index];
199
200 if !broker.is_available() {
201 debug!(
202 "Broker {} unavailable (circuit breaker open), trying next",
203 broker.name
204 );
205 continue;
206 }
207
208 match broker.place_order(order.clone()).await {
209 Ok(response) => {
210 info!("Order routed to broker: {}", broker.name);
211 return Ok(response);
212 }
213 Err(e) => {
214 warn!(
215 "Failed to place order on broker {}: {}",
216 broker.name, e
217 );
218 continue;
219 }
220 }
221 }
222
223 error!("All brokers failed to execute order");
224 Err(ExecutionError::Order(
225 "All brokers failed to execute order".to_string(),
226 ))
227 }
228
229 async fn route_primary_with_fallback(&self, order: OrderRequest) -> Result<OrderResponse> {
231 if let Some(primary) = self.brokers.first() {
233 if primary.is_available() {
234 match primary.place_order(order.clone()).await {
235 Ok(response) => {
236 info!("Order routed to primary broker: {}", primary.name);
237 return Ok(response);
238 }
239 Err(e) => {
240 warn!(
241 "Primary broker {} failed: {}, trying fallbacks",
242 primary.name, e
243 );
244 }
245 }
246 } else {
247 warn!(
248 "Primary broker {} unavailable, trying fallbacks",
249 primary.name
250 );
251 }
252 }
253
254 for (i, broker) in self.brokers.iter().enumerate().skip(1) {
256 if !broker.is_available() {
257 continue;
258 }
259
260 match broker.place_order(order.clone()).await {
261 Ok(response) => {
262 info!(
263 "Order routed to fallback broker #{}: {}",
264 i, broker.name
265 );
266 return Ok(response);
267 }
268 Err(e) => {
269 warn!("Fallback broker {} failed: {}", broker.name, e);
270 continue;
271 }
272 }
273 }
274
275 error!("All brokers (primary and fallbacks) failed");
276 Err(ExecutionError::Order(
277 "All brokers failed to execute order".to_string(),
278 ))
279 }
280
281 pub fn get_broker_status(&self) -> Vec<(String, bool)> {
283 self.brokers
284 .iter()
285 .map(|b| (b.name.clone(), b.is_available()))
286 .collect()
287 }
288
289 pub fn available_brokers(&self) -> usize {
291 self.brokers.iter().filter(|b| b.is_available()).count()
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// The breaker opens exactly when the failure threshold is reached.
    #[test]
    fn test_circuit_breaker_opens_after_failures() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));

        assert!(!cb.is_open());

        cb.record_failure();
        assert!(!cb.is_open());

        cb.record_failure();
        assert!(!cb.is_open());

        cb.record_failure();
        assert!(cb.is_open());
    }

    /// A success clears the failure counter, so a full run of failures is
    /// needed again before the breaker opens.
    #[test]
    fn test_circuit_breaker_resets_on_success() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));

        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_open());

        cb.record_success();
        assert!(!cb.is_open());

        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_open());

        cb.record_failure();
        assert!(cb.is_open());
    }

    /// `try_reset` must not leave the open state before the timeout elapses.
    #[test]
    fn test_circuit_breaker_stays_open_before_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(30));

        cb.record_failure();
        assert!(cb.is_open());

        assert!(!cb.try_reset());
        assert!(cb.is_open());
    }

    /// With a zero timeout, `try_reset` moves an open breaker to half-open
    /// and reports the transition exactly once.
    #[test]
    fn test_circuit_breaker_half_open_after_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(0));

        cb.record_failure();
        assert!(cb.is_open());

        assert!(cb.try_reset());
        assert!(!cb.is_open());

        // Already half-open: nothing left to reset.
        assert!(!cb.try_reset());
        assert!(!cb.is_open());
    }

    /// A single failure while half-open reopens the breaker, regardless of
    /// the failure threshold.
    #[test]
    fn test_circuit_breaker_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(0));

        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert!(cb.is_open());

        assert!(cb.try_reset());
        assert!(!cb.is_open());

        cb.record_failure();
        assert!(cb.is_open());
    }

    /// A success while half-open closes the breaker with a fresh counter.
    #[test]
    fn test_circuit_breaker_half_open_success_closes() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(0));

        cb.record_failure();
        cb.record_failure();
        assert!(cb.is_open());

        assert!(cb.try_reset());
        cb.record_success();
        assert!(!cb.is_open());

        // The counter restarted from zero: one failure is not enough.
        cb.record_failure();
        assert!(!cb.is_open());

        cb.record_failure();
        assert!(cb.is_open());
    }
}