use std::collections::HashMap;
use std::sync::Arc;

use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::Serialize;

9#[derive(Debug, Clone)]
11enum CircuitState {
12 Closed {
13 failures: u32,
14 },
15 Open {
16 opened_at: DateTime<Utc>,
17 failures: u32,
18 },
19 HalfOpen {
20 failures: u32,
21 },
22}
23
/// Verdict returned when asking whether a delivery to an agent may proceed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CanDeliver {
    /// The circuit is closed: deliver normally.
    Yes,
    /// The circuit is open and still cooling down: do not deliver.
    No,
    /// The circuit is half-open: deliver, treating the attempt as a probe
    /// whose success or failure should be reported back to the breaker.
    Probe,
}

32#[derive(Debug, Clone, Serialize)]
34pub struct CircuitStatus {
35 pub agent_id: String,
36 pub state: String,
37 pub failures: u32,
38 pub opened_at: Option<String>,
39}
40
/// Tunable thresholds for the circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitConfig {
    /// Failure count at which a closed circuit trips open.
    pub failure_threshold: u32,
    /// Seconds an open circuit must wait before probe deliveries are allowed.
    pub cooldown_seconds: u64,
}

48impl Default for CircuitConfig {
49 fn default() -> Self {
50 Self {
51 failure_threshold: 5,
52 cooldown_seconds: 60,
53 }
54 }
55}
56
57pub struct CircuitBreaker {
62 states: Arc<Mutex<HashMap<String, CircuitState>>>,
63 config: CircuitConfig,
64}
65
66impl CircuitBreaker {
67 pub fn new(config: CircuitConfig) -> Self {
68 Self {
69 states: Arc::new(Mutex::new(HashMap::new())),
70 config,
71 }
72 }
73
74 pub fn with_defaults() -> Self {
75 Self::new(CircuitConfig::default())
76 }
77
78 pub fn check(&self, agent_id: &str) -> CanDeliver {
80 let mut states = self.states.lock();
81 let state = states
82 .entry(agent_id.to_string())
83 .or_insert_with(|| CircuitState::Closed { failures: 0 });
84
85 match state {
86 CircuitState::Closed { .. } => CanDeliver::Yes,
87 CircuitState::Open { opened_at, .. } => {
88 let elapsed = (Utc::now() - *opened_at).num_seconds();
89 if elapsed >= self.config.cooldown_seconds as i64 {
90 let failures = match state {
91 CircuitState::Open { failures, .. } => *failures,
92 _ => 0,
93 };
94 *state = CircuitState::HalfOpen { failures };
95 CanDeliver::Probe
96 } else {
97 CanDeliver::No
98 }
99 }
100 CircuitState::HalfOpen { .. } => CanDeliver::Probe,
101 }
102 }
103
104 pub fn record_failure(&self, agent_id: &str) {
106 let mut states = self.states.lock();
107 let state = states
108 .entry(agent_id.to_string())
109 .or_insert_with(|| CircuitState::Closed { failures: 0 });
110
111 match state {
112 CircuitState::Closed { failures } => {
113 let new_failures = *failures + 1;
114 if new_failures >= self.config.failure_threshold {
115 *state = CircuitState::Open {
116 opened_at: Utc::now(),
117 failures: new_failures,
118 };
119 } else {
120 *state = CircuitState::Closed {
121 failures: new_failures,
122 };
123 }
124 }
125 CircuitState::Open { failures, .. } => {
126 *state = CircuitState::Open {
127 opened_at: Utc::now(),
128 failures: *failures + 1,
129 };
130 }
131 CircuitState::HalfOpen { failures } => {
132 *state = CircuitState::Open {
133 opened_at: Utc::now(),
134 failures: *failures + 1,
135 };
136 }
137 }
138 }
139
140 pub fn record_success(&self, agent_id: &str) {
142 let mut states = self.states.lock();
143 if let Some(state) = states.get_mut(agent_id) {
144 if matches!(state, CircuitState::HalfOpen { .. }) {
145 *state = CircuitState::Closed { failures: 0 };
146 }
147 }
148 }
149
150 pub fn reset(&self, agent_id: &str) {
152 let mut states = self.states.lock();
153 states.insert(agent_id.to_string(), CircuitState::Closed { failures: 0 });
154 }
155
156 pub fn remove(&self, agent_id: &str) {
158 let mut states = self.states.lock();
159 states.remove(agent_id);
160 }
161
162 pub fn evict_stale(&self) -> usize {
165 let mut states = self.states.lock();
166 let cutoff = Utc::now() - chrono::Duration::hours(1);
167 let before = states.len();
168 states.retain(|_, state| match state {
169 CircuitState::Open { opened_at, .. } => *opened_at > cutoff,
170 _ => true,
171 });
172 before - states.len()
173 }
174
175 pub fn get_state(&self, agent_id: &str) -> CircuitStatus {
177 let states = self.states.lock();
178 match states.get(agent_id) {
179 None => CircuitStatus {
180 agent_id: agent_id.to_string(),
181 state: "closed".to_string(),
182 failures: 0,
183 opened_at: None,
184 },
185 Some(CircuitState::Closed { failures }) => CircuitStatus {
186 agent_id: agent_id.to_string(),
187 state: "closed".to_string(),
188 failures: *failures,
189 opened_at: None,
190 },
191 Some(CircuitState::Open {
192 opened_at,
193 failures,
194 }) => CircuitStatus {
195 agent_id: agent_id.to_string(),
196 state: "open".to_string(),
197 failures: *failures,
198 opened_at: Some(opened_at.to_rfc3339()),
199 },
200 Some(CircuitState::HalfOpen { failures }) => CircuitStatus {
201 agent_id: agent_id.to_string(),
202 state: "half_open".to_string(),
203 failures: *failures,
204 opened_at: None,
205 },
206 }
207 }
208
209 pub fn list_active(&self) -> Vec<CircuitStatus> {
211 let states = self.states.lock();
212 let mut result = Vec::new();
213 for (agent_id, state) in states.iter() {
214 match state {
215 CircuitState::Closed { failures: 0 } => continue,
216 CircuitState::Closed { failures } => result.push(CircuitStatus {
217 agent_id: agent_id.clone(),
218 state: "closed".to_string(),
219 failures: *failures,
220 opened_at: None,
221 }),
222 CircuitState::Open {
223 opened_at,
224 failures,
225 } => result.push(CircuitStatus {
226 agent_id: agent_id.clone(),
227 state: "open".to_string(),
228 failures: *failures,
229 opened_at: Some(opened_at.to_rfc3339()),
230 }),
231 CircuitState::HalfOpen { failures } => result.push(CircuitStatus {
232 agent_id: agent_id.clone(),
233 state: "half_open".to_string(),
234 failures: *failures,
235 opened_at: None,
236 }),
237 }
238 }
239 result
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Failure threshold used by every breaker in these tests.
    const THRESHOLD: u32 = 3;

    /// Breaker that trips after `THRESHOLD` failures and cools down for 60 s.
    fn test_breaker() -> CircuitBreaker {
        CircuitBreaker::new(CircuitConfig {
            failure_threshold: THRESHOLD,
            cooldown_seconds: 60,
        })
    }

    /// Records exactly enough failures to trip a fresh circuit.
    fn trip(cb: &CircuitBreaker, agent_id: &str) {
        for _ in 0..THRESHOLD {
            cb.record_failure(agent_id);
        }
    }

    /// Overwrites the stored state so tests can simulate elapsed time or an
    /// in-progress probe without sleeping.
    fn force_state(cb: &CircuitBreaker, agent_id: &str, state: CircuitState) {
        cb.states.lock().insert(agent_id.to_string(), state);
    }

    /// Open state whose circuit tripped `seconds_ago` seconds in the past.
    fn opened_ago(seconds_ago: i64) -> CircuitState {
        CircuitState::Open {
            opened_at: Utc::now() - chrono::Duration::seconds(seconds_ago),
            failures: THRESHOLD,
        }
    }

    #[test]
    fn starts_closed() {
        let cb = test_breaker();
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
    }

    #[test]
    fn stays_closed_below_threshold() {
        let cb = test_breaker();
        cb.record_failure("agent1");
        cb.record_failure("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
    }

    #[test]
    fn opens_at_threshold() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);
        let status = cb.get_state("agent1");
        assert_eq!(status.state, "open");
    }

    #[test]
    fn half_open_after_cooldown() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);

        // Pretend the circuit opened two minutes ago (cooldown is one minute).
        force_state(&cb, "agent1", opened_ago(120));
        assert_eq!(cb.check("agent1"), CanDeliver::Probe);
    }

    #[test]
    fn half_open_success_closes() {
        let cb = test_breaker();
        force_state(&cb, "agent1", CircuitState::HalfOpen { failures: 3 });
        cb.record_success("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
        let status = cb.get_state("agent1");
        assert_eq!(status.failures, 0);
    }

    #[test]
    fn half_open_failure_reopens() {
        let cb = test_breaker();
        force_state(&cb, "agent1", CircuitState::HalfOpen { failures: 3 });
        cb.record_failure("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);
    }

    #[test]
    fn heartbeat_resets_closed() {
        let cb = test_breaker();
        // Stay below the threshold so the circuit is still closed when reset.
        cb.record_failure("agent1");
        cb.record_failure("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
        assert_eq!(cb.get_state("agent1").failures, 2);

        cb.reset("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
        assert_eq!(cb.get_state("agent1").failures, 0);
    }

    #[test]
    fn heartbeat_resets_open() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);

        cb.reset("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
        assert_eq!(cb.get_state("agent1").failures, 0);
    }

    #[test]
    fn heartbeat_resets_half_open() {
        let cb = test_breaker();
        force_state(&cb, "agent1", CircuitState::HalfOpen { failures: 3 });
        cb.reset("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
    }

    #[test]
    fn different_agents_independent() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);
        assert_eq!(cb.check("agent2"), CanDeliver::Yes);
    }

    #[test]
    fn remove_clears_state() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        cb.remove("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
    }

    #[test]
    fn list_active_skips_healthy() {
        let cb = test_breaker();
        trip(&cb, "agent1");
        // Healthy agents: one merely checked, one explicitly reset.
        assert_eq!(cb.check("agent2"), CanDeliver::Yes);
        cb.reset("agent3");

        let active = cb.list_active();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].agent_id, "agent1");
    }

    #[test]
    fn evict_stale_drops_only_old_open_circuits() {
        let cb = test_breaker();
        force_state(&cb, "old", opened_ago(2 * 60 * 60));
        trip(&cb, "fresh");
        force_state(&cb, "probing", CircuitState::HalfOpen { failures: 3 });

        assert_eq!(cb.evict_stale(), 1);
        assert_eq!(cb.get_state("old").state, "closed");
        assert_eq!(cb.get_state("fresh").state, "open");
        assert_eq!(cb.get_state("probing").state, "half_open");
    }

    #[test]
    fn full_lifecycle() {
        let cb = test_breaker();

        assert_eq!(cb.check("agent1"), CanDeliver::Yes);

        trip(&cb, "agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::No);

        // Simulate the cooldown having elapsed.
        force_state(&cb, "agent1", opened_ago(120));
        assert_eq!(cb.check("agent1"), CanDeliver::Probe);

        cb.record_success("agent1");
        assert_eq!(cb.check("agent1"), CanDeliver::Yes);
        assert_eq!(cb.get_state("agent1").failures, 0);
    }
}