1use crate::error::AgentRuntimeError;
44use crate::util::timed_lock;
45use std::collections::HashMap;
46use std::sync::{Arc, Mutex};
47use std::time::{Duration, Instant};
48
/// Hard upper bound on any computed retry delay, regardless of policy
/// configuration or attempt number (enforced in `RetryPolicy::delay_for`).
pub const MAX_RETRY_DELAY: Duration = Duration::from_secs(60);
51
/// Exponential-backoff retry configuration.
///
/// Delays grow as `base_delay * 2^(attempt - 1)` and are capped at
/// [`MAX_RETRY_DELAY`]; see `delay_for`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    // Maximum number of attempts; >= 1 (enforced by `exponential`).
    pub max_attempts: u32,
    // Delay before the second attempt; doubles on each later attempt.
    pub base_delay: Duration,
}
62
63impl RetryPolicy {
64 pub fn exponential(max_attempts: u32, base_ms: u64) -> Result<Self, AgentRuntimeError> {
74 if max_attempts == 0 {
75 return Err(AgentRuntimeError::Orchestration(
76 "max_attempts must be >= 1".into(),
77 ));
78 }
79 if base_ms == 0 {
80 return Err(AgentRuntimeError::Orchestration(
81 "base_ms must be >= 1 to avoid zero-delay busy-loop retries".into(),
82 ));
83 }
84 Ok(Self {
85 max_attempts,
86 base_delay: Duration::from_millis(base_ms),
87 })
88 }
89
90 pub fn delay_for(&self, attempt: u32) -> Duration {
94 let exp = attempt.saturating_sub(1);
95 let multiplier = 1u64.checked_shl(exp.min(63)).unwrap_or(u64::MAX);
96 let millis = self
97 .base_delay
98 .as_millis()
99 .saturating_mul(multiplier as u128);
100 let raw = Duration::from_millis(millis.min(u64::MAX as u128) as u64);
101 raw.min(MAX_RETRY_DELAY)
102 }
103}
104
/// Observable state of a circuit breaker.
#[derive(Debug, Clone)]
pub enum CircuitState {
    /// Requests flow through normally.
    Closed,
    /// Requests are shed until the recovery window elapses.
    Open {
        /// Instant at which the breaker tripped.
        opened_at: Instant,
    },
    /// A probe request is allowed through to test recovery.
    HalfOpen,
}

/// Equality compares only the variant; `opened_at` is deliberately ignored
/// so two `Open` states compare equal regardless of when they tripped.
impl PartialEq for CircuitState {
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

impl Eq for CircuitState {}
139
/// Storage abstraction for circuit-breaker bookkeeping, keyed by service name.
///
/// Implementations must be shareable across threads (`Send + Sync`). The
/// default is the process-local [`InMemoryCircuitBreakerBackend`]; a custom
/// backend can be substituted via `CircuitBreaker::with_backend`.
pub trait CircuitBreakerBackend: Send + Sync {
    /// Increments the consecutive-failure counter for `service` and returns the new count.
    fn increment_failures(&self, service: &str) -> u32;
    /// Resets the consecutive-failure counter for `service` to zero.
    fn reset_failures(&self, service: &str);
    /// Returns the consecutive-failure count for `service` (0 when untracked,
    /// per the in-memory implementation).
    fn get_failures(&self, service: &str) -> u32;
    /// Records the instant at which the circuit for `service` opened.
    fn set_open_at(&self, service: &str, at: std::time::Instant);
    /// Clears any recorded open instant for `service`.
    fn clear_open_at(&self, service: &str);
    /// Returns the instant the circuit for `service` opened, if recorded.
    fn get_open_at(&self, service: &str) -> Option<std::time::Instant>;
}
161
/// Default [`CircuitBreakerBackend`] that keeps all per-service state in a
/// process-local map behind a mutex. State is not persisted or shared
/// across processes.
pub struct InMemoryCircuitBreakerBackend {
    // service name -> failure counter / open timestamp.
    inner: Arc<Mutex<HashMap<String, InMemoryServiceState>>>,
}

/// Per-service bookkeeping entry for the in-memory backend.
#[derive(Default)]
struct InMemoryServiceState {
    // Failures observed since the last success/reset.
    consecutive_failures: u32,
    // Set while the circuit is open; `None` otherwise.
    open_at: Option<std::time::Instant>,
}
179
180impl InMemoryCircuitBreakerBackend {
181 pub fn new() -> Self {
183 Self {
184 inner: Arc::new(Mutex::new(HashMap::new())),
185 }
186 }
187}
188
189impl Default for InMemoryCircuitBreakerBackend {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195impl CircuitBreakerBackend for InMemoryCircuitBreakerBackend {
196 fn increment_failures(&self, service: &str) -> u32 {
197 let mut map = timed_lock(
198 &self.inner,
199 "InMemoryCircuitBreakerBackend::increment_failures",
200 );
201 let state = map.entry(service.to_owned()).or_default();
202 state.consecutive_failures += 1;
203 state.consecutive_failures
204 }
205
206 fn reset_failures(&self, service: &str) {
207 let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::reset_failures");
208 if let Some(state) = map.get_mut(service) {
209 state.consecutive_failures = 0;
210 }
211 }
212
213 fn get_failures(&self, service: &str) -> u32 {
214 let map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_failures");
215 map.get(service).map_or(0, |s| s.consecutive_failures)
216 }
217
218 fn set_open_at(&self, service: &str, at: std::time::Instant) {
219 let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::set_open_at");
220 map.entry(service.to_owned()).or_default().open_at = Some(at);
221 }
222
223 fn clear_open_at(&self, service: &str) {
224 let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::clear_open_at");
225 if let Some(state) = map.get_mut(service) {
226 state.open_at = None;
227 }
228 }
229
230 fn get_open_at(&self, service: &str) -> Option<std::time::Instant> {
231 let map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_open_at");
232 map.get(service).and_then(|s| s.open_at)
233 }
234}
235
/// Per-service circuit breaker: trips open after `threshold` consecutive
/// failures, sheds calls while open, and allows a half-open probe once
/// `recovery_window` has elapsed.
///
/// Cloning is cheap: clones share the same backend state via `Arc`.
#[derive(Clone)]
pub struct CircuitBreaker {
    // Consecutive failures required to open the circuit (>= 1).
    threshold: u32,
    // How long the circuit stays open before a half-open probe is allowed.
    recovery_window: Duration,
    // Service key under which state is tracked in the backend.
    service: String,
    // Pluggable state store; defaults to an in-memory backend.
    backend: Arc<dyn CircuitBreakerBackend>,
}
251
impl std::fmt::Debug for CircuitBreaker {
    // Manual impl because `dyn CircuitBreakerBackend` is not `Debug`;
    // the backend field is intentionally omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CircuitBreaker")
            .field("threshold", &self.threshold)
            .field("recovery_window", &self.recovery_window)
            .field("service", &self.service)
            .finish()
    }
}
261
262impl CircuitBreaker {
263 pub fn new(
270 service: impl Into<String>,
271 threshold: u32,
272 recovery_window: Duration,
273 ) -> Result<Self, AgentRuntimeError> {
274 if threshold == 0 {
275 return Err(AgentRuntimeError::Orchestration(
276 "circuit breaker threshold must be >= 1".into(),
277 ));
278 }
279 let service = service.into();
280 Ok(Self {
281 threshold,
282 recovery_window,
283 service,
284 backend: Arc::new(InMemoryCircuitBreakerBackend::new()),
285 })
286 }
287
288 pub fn with_backend(mut self, backend: Arc<dyn CircuitBreakerBackend>) -> Self {
292 self.backend = backend;
293 self
294 }
295
296 #[tracing::instrument(skip(self, f))]
305 pub fn call<T, E, F>(&self, f: F) -> Result<T, AgentRuntimeError>
306 where
307 F: FnOnce() -> Result<T, E>,
308 E: std::fmt::Display,
309 {
310 let effective_state = match self.backend.get_open_at(&self.service) {
312 Some(opened_at) => {
313 if opened_at.elapsed() >= self.recovery_window {
314 self.backend.clear_open_at(&self.service);
316 tracing::info!("circuit moved to half-open for {}", self.service);
317 CircuitState::HalfOpen
318 } else {
319 CircuitState::Open { opened_at }
320 }
321 }
322 None => {
323 let failures = self.backend.get_failures(&self.service);
327 if failures >= self.threshold {
328 CircuitState::HalfOpen
329 } else {
330 CircuitState::Closed
331 }
332 }
333 };
334
335 tracing::debug!("circuit state: {:?}", effective_state);
336
337 match effective_state {
338 CircuitState::Open { .. } => {
339 return Err(AgentRuntimeError::CircuitOpen {
340 service: self.service.clone(),
341 });
342 }
343 CircuitState::Closed | CircuitState::HalfOpen => {}
344 }
345
346 match f() {
348 Ok(val) => {
349 self.backend.reset_failures(&self.service);
350 self.backend.clear_open_at(&self.service);
351 tracing::info!("circuit closed for {}", self.service);
352 Ok(val)
353 }
354 Err(e) => {
355 let failures = self.backend.increment_failures(&self.service);
356 if failures >= self.threshold {
357 let now = Instant::now();
358 self.backend.set_open_at(&self.service, now);
359 tracing::info!("circuit opened for {}", self.service);
360 }
361 Err(AgentRuntimeError::Orchestration(e.to_string()))
362 }
363 }
364 }
365
366 pub fn state(&self) -> Result<CircuitState, AgentRuntimeError> {
368 let state = match self.backend.get_open_at(&self.service) {
369 Some(opened_at) => {
370 if opened_at.elapsed() >= self.recovery_window {
371 let failures = self.backend.get_failures(&self.service);
373 if failures >= self.threshold {
374 CircuitState::HalfOpen
375 } else {
376 CircuitState::Closed
377 }
378 } else {
379 CircuitState::Open { opened_at }
380 }
381 }
382 None => {
383 let failures = self.backend.get_failures(&self.service);
384 if failures >= self.threshold {
385 CircuitState::HalfOpen
386 } else {
387 CircuitState::Closed
388 }
389 }
390 };
391 Ok(state)
392 }
393
394 pub fn failure_count(&self) -> Result<u32, AgentRuntimeError> {
396 Ok(self.backend.get_failures(&self.service))
397 }
398}
399
/// Outcome of a deduplication check for a request key.
// `Eq` is derived alongside `PartialEq` (all payloads are `String`, which is
// `Eq`), so the type can be used where full equivalence is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeduplicationResult {
    /// Key not seen before; it has now been registered as in-flight.
    New,
    /// A completed result is cached for this key; payload is the cached value.
    Cached(String),
    /// Another request with this key is currently in flight.
    InProgress,
}
412
/// Request deduplicator with a TTL-bounded result cache.
///
/// Tracks in-flight keys and caches completed results so duplicate requests
/// within the TTL are answered from cache or reported as in progress.
/// Cloning is cheap: clones share state via `Arc`.
#[derive(Debug, Clone)]
pub struct Deduplicator {
    // Default time-to-live for cached results and in-flight markers.
    ttl: Duration,
    inner: Arc<Mutex<DeduplicatorInner>>,
}

// Fields were previously jammed onto a single line with the closing brace;
// reformatted for readability. No behavioral change.
#[derive(Debug)]
struct DeduplicatorInner {
    // key -> (completed result, completion time).
    cache: HashMap<String, (String, Instant)>,
    // key -> registration time for requests still running.
    in_flight: HashMap<String, Instant>,
}
430
431impl Deduplicator {
432 pub fn new(ttl: Duration) -> Self {
434 Self {
435 ttl,
436 inner: Arc::new(Mutex::new(DeduplicatorInner {
437 cache: HashMap::new(),
438 in_flight: HashMap::new(),
439 })),
440 }
441 }
442
443 pub fn check_and_register(&self, key: &str) -> Result<DeduplicationResult, AgentRuntimeError> {
447 let mut inner = timed_lock(&self.inner, "Deduplicator::check_and_register");
448
449 let now = Instant::now();
450
451 inner
453 .cache
454 .retain(|_, (_, ts)| now.duration_since(*ts) < self.ttl);
455 inner
456 .in_flight
457 .retain(|_, ts| now.duration_since(*ts) < self.ttl);
458
459 if let Some((result, _)) = inner.cache.get(key) {
460 return Ok(DeduplicationResult::Cached(result.clone()));
461 }
462
463 if inner.in_flight.contains_key(key) {
464 return Ok(DeduplicationResult::InProgress);
465 }
466
467 inner.in_flight.insert(key.to_owned(), now);
468 Ok(DeduplicationResult::New)
469 }
470
471 pub fn check(&self, key: &str, ttl: std::time::Duration) -> Result<DeduplicationResult, AgentRuntimeError> {
476 let mut inner = timed_lock(&self.inner, "Deduplicator::check");
477 let now = Instant::now();
478
479 inner.cache.retain(|_, (_, ts)| now.duration_since(*ts) < ttl);
480 inner.in_flight.retain(|_, ts| now.duration_since(*ts) < ttl);
481
482 if let Some((result, _)) = inner.cache.get(key) {
483 return Ok(DeduplicationResult::Cached(result.clone()));
484 }
485
486 if inner.in_flight.contains_key(key) {
487 return Ok(DeduplicationResult::InProgress);
488 }
489
490 inner.in_flight.insert(key.to_owned(), now);
491 Ok(DeduplicationResult::New)
492 }
493
494 pub fn dedup_many(
499 &self,
500 requests: &[(&str, std::time::Duration)],
501 ) -> Result<Vec<DeduplicationResult>, AgentRuntimeError> {
502 requests
503 .iter()
504 .map(|(key, ttl)| self.check(key, *ttl))
505 .collect()
506 }
507
508 pub fn complete(&self, key: &str, result: impl Into<String>) -> Result<(), AgentRuntimeError> {
510 let mut inner = timed_lock(&self.inner, "Deduplicator::complete");
511 inner.in_flight.remove(key);
512 inner
513 .cache
514 .insert(key.to_owned(), (result.into(), Instant::now()));
515 Ok(())
516 }
517
518 pub fn fail(&self, key: &str) -> Result<(), AgentRuntimeError> {
523 let mut inner = timed_lock(&self.inner, "Deduplicator::fail");
524 inner.in_flight.remove(key);
525 Ok(())
526 }
527}
528
/// Counting admission guard that sheds work beyond a hard capacity and
/// warns once depth reaches an optional soft limit.
///
/// Cloning is cheap: clones share the same depth counter via `Arc`.
#[derive(Debug, Clone)]
pub struct BackpressureGuard {
    // Hard limit: acquisitions at or above this depth are rejected.
    capacity: usize,
    // Optional warning threshold; strictly less than `capacity` when set.
    soft_capacity: Option<usize>,
    // Current number of outstanding acquisitions.
    inner: Arc<Mutex<usize>>,
}
544
545impl BackpressureGuard {
546 pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
552 if capacity == 0 {
553 return Err(AgentRuntimeError::Orchestration(
554 "BackpressureGuard capacity must be > 0".into(),
555 ));
556 }
557 Ok(Self {
558 capacity,
559 soft_capacity: None,
560 inner: Arc::new(Mutex::new(0)),
561 })
562 }
563
564 pub fn with_soft_limit(mut self, soft: usize) -> Result<Self, AgentRuntimeError> {
567 if soft >= self.capacity {
568 return Err(AgentRuntimeError::Orchestration(
569 "soft_capacity must be less than hard capacity".into(),
570 ));
571 }
572 self.soft_capacity = Some(soft);
573 Ok(self)
574 }
575
576 pub fn try_acquire(&self) -> Result<(), AgentRuntimeError> {
585 let mut depth = timed_lock(&self.inner, "BackpressureGuard::try_acquire");
586 if *depth >= self.capacity {
587 return Err(AgentRuntimeError::BackpressureShed {
588 depth: *depth,
589 capacity: self.capacity,
590 });
591 }
592 *depth += 1;
593 if let Some(soft) = self.soft_capacity {
594 if *depth >= soft {
595 tracing::warn!(
596 depth = *depth,
597 soft_capacity = soft,
598 hard_capacity = self.capacity,
599 "backpressure approaching hard limit"
600 );
601 }
602 }
603 Ok(())
604 }
605
606 pub fn release(&self) -> Result<(), AgentRuntimeError> {
608 let mut depth = timed_lock(&self.inner, "BackpressureGuard::release");
609 *depth = depth.saturating_sub(1);
610 Ok(())
611 }
612
613 pub fn hard_capacity(&self) -> usize {
615 self.capacity
616 }
617
618 pub fn depth(&self) -> Result<usize, AgentRuntimeError> {
620 let depth = timed_lock(&self.inner, "BackpressureGuard::depth");
621 Ok(*depth)
622 }
623
624 pub fn soft_depth_ratio(&self) -> f32 {
629 match self.soft_capacity {
630 None => 0.0,
631 Some(soft) => {
632 let depth = timed_lock(&self.inner, "BackpressureGuard::soft_depth_ratio");
633 *depth as f32 / soft as f32
634 }
635 }
636 }
637}
638
/// Output of `Pipeline::execute_timed`.
#[derive(Debug)]
pub struct PipelineResult {
    /// Final string produced by the last stage (or its error handler).
    pub output: String,
    /// `(stage_index, elapsed_millis)` per stage, in execution order.
    pub stage_timings: Vec<(usize, u64)>,
}
649
/// A single named transformation step in a `Pipeline`.
pub struct Stage {
    /// Human-readable stage name, used in logs and passed to error handlers.
    pub name: String,
    /// Transformation applied to the current pipeline value.
    pub handler: Box<dyn Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync>,
}

impl std::fmt::Debug for Stage {
    // Manual impl: the boxed handler closure is not `Debug`, so only the
    // stage name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Stage").field("name", &self.name).finish()
    }
}
663
/// Fallback invoked when a stage fails:
/// `(stage_name, error_message) -> replacement value`.
type StageErrorHandler = Box<dyn Fn(&str, &str) -> String + Send + Sync>;

/// Sequential string-transformation pipeline with optional per-stage
/// error recovery.
pub struct Pipeline {
    // Stages run in insertion order.
    stages: Vec<Stage>,
    // When set, a failing stage's output is replaced by this handler's
    // return value instead of aborting the pipeline.
    error_handler: Option<StageErrorHandler>,
}
677
impl std::fmt::Debug for Pipeline {
    // Manual impl: the error-handler closure is not `Debug`, so only its
    // presence is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Pipeline")
            .field("stages", &self.stages)
            .field("has_error_handler", &self.error_handler.is_some())
            .finish()
    }
}
686
687impl Pipeline {
688 pub fn new() -> Self {
690 Self { stages: Vec::new(), error_handler: None }
691 }
692
693 pub fn with_error_handler(
699 mut self,
700 handler: impl Fn(&str, &str) -> String + Send + Sync + 'static,
701 ) -> Self {
702 self.error_handler = Some(Box::new(handler));
703 self
704 }
705
706 pub fn add_stage(
708 mut self,
709 name: impl Into<String>,
710 handler: impl Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync + 'static,
711 ) -> Self {
712 self.stages.push(Stage {
713 name: name.into(),
714 handler: Box::new(handler),
715 });
716 self
717 }
718
719 #[tracing::instrument(skip(self))]
721 pub fn run(&self, input: String) -> Result<String, AgentRuntimeError> {
722 let mut current = input;
723 for stage in &self.stages {
724 tracing::debug!(stage = %stage.name, "running pipeline stage");
725 match (stage.handler)(current) {
726 Ok(out) => current = out,
727 Err(e) => {
728 tracing::error!(stage = %stage.name, error = %e, "pipeline stage failed");
729 if let Some(ref handler) = self.error_handler {
730 current = handler(&stage.name, &e.to_string());
731 } else {
732 return Err(e);
733 }
734 }
735 }
736 }
737 Ok(current)
738 }
739
740 pub fn execute_timed(&self, input: String) -> Result<PipelineResult, AgentRuntimeError> {
742 let mut current = input;
743 let mut stage_timings = Vec::new();
744 for (idx, stage) in self.stages.iter().enumerate() {
745 let start = std::time::Instant::now();
746 tracing::debug!(stage = %stage.name, "running timed pipeline stage");
747 match (stage.handler)(current) {
748 Ok(out) => current = out,
749 Err(e) => {
750 tracing::error!(stage = %stage.name, error = %e, "timed pipeline stage failed");
751 if let Some(ref handler) = self.error_handler {
752 current = handler(&stage.name, &e.to_string());
753 } else {
754 return Err(e);
755 }
756 }
757 }
758 let duration_ms = start.elapsed().as_millis() as u64;
759 stage_timings.push((idx, duration_ms));
760 }
761 Ok(PipelineResult {
762 output: current,
763 stage_timings,
764 })
765 }
766
767 pub fn stage_count(&self) -> usize {
769 self.stages.len()
770 }
771}
772
impl Default for Pipeline {
    /// Equivalent to `Pipeline::new`: an empty pipeline.
    fn default() -> Self {
        Self::new()
    }
}
778
#[cfg(test)]
mod tests {
    use super::*;

    // ----- RetryPolicy -----

    #[test]
    fn test_retry_policy_rejects_zero_attempts() {
        assert!(RetryPolicy::exponential(0, 100).is_err());
    }

    #[test]
    fn test_retry_policy_delay_attempt_1_equals_base() {
        let p = RetryPolicy::exponential(3, 100).unwrap();
        assert_eq!(p.delay_for(1), Duration::from_millis(100));
    }

    #[test]
    fn test_retry_policy_delay_doubles_each_attempt() {
        let p = RetryPolicy::exponential(5, 100).unwrap();
        assert_eq!(p.delay_for(2), Duration::from_millis(200));
        assert_eq!(p.delay_for(3), Duration::from_millis(400));
        assert_eq!(p.delay_for(4), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_policy_delay_capped_at_max() {
        let p = RetryPolicy::exponential(10, 10_000).unwrap();
        assert_eq!(p.delay_for(10), MAX_RETRY_DELAY);
    }

    #[test]
    fn test_retry_policy_delay_never_exceeds_max_for_any_attempt() {
        let p = RetryPolicy::exponential(10, 1000).unwrap();
        for attempt in 1..=10 {
            assert!(p.delay_for(attempt) <= MAX_RETRY_DELAY);
        }
    }

    // ----- CircuitBreaker -----

    #[test]
    fn test_circuit_breaker_rejects_zero_threshold() {
        assert!(CircuitBreaker::new("svc", 0, Duration::from_secs(1)).is_err());
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_success_keeps_closed() {
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(42));
        assert!(result.is_ok());
        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold_failures() {
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        for _ in 0..3 {
            let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("oops".to_string()));
        }
        assert!(matches!(cb.state().unwrap(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_circuit_breaker_open_fast_fails() {
        let cb = CircuitBreaker::new("svc", 1, Duration::from_secs(3600)).unwrap();
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let result: Result<(), AgentRuntimeError> = cb.call(|| Ok::<(), AgentRuntimeError>(()));
        assert!(matches!(result, Err(AgentRuntimeError::CircuitOpen { .. })));
    }

    #[test]
    fn test_circuit_breaker_success_resets_failure_count() {
        let cb = CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap();
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let _: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(1));
        assert_eq!(cb.failure_count().unwrap(), 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_on_recovery() {
        // A zero recovery window makes the circuit immediately probeable.
        let cb = CircuitBreaker::new("svc", 1, Duration::ZERO).unwrap();
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(99));
        assert_eq!(result.unwrap_or(0), 99);
        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_with_custom_backend_uses_backend_state() {
        // Two breaker instances sharing one backend must observe each
        // other's failures.
        let shared_backend: Arc<dyn CircuitBreakerBackend> =
            Arc::new(InMemoryCircuitBreakerBackend::new());

        let cb1 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
            .unwrap()
            .with_backend(Arc::clone(&shared_backend));

        let cb2 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
            .unwrap()
            .with_backend(Arc::clone(&shared_backend));

        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail".to_string()));

        assert_eq!(cb2.failure_count().unwrap(), 1);

        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail again".to_string()));

        assert!(matches!(cb2.state().unwrap(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_in_memory_backend_increments_and_resets() {
        let backend = InMemoryCircuitBreakerBackend::new();

        assert_eq!(backend.get_failures("svc"), 0);

        let count = backend.increment_failures("svc");
        assert_eq!(count, 1);

        let count = backend.increment_failures("svc");
        assert_eq!(count, 2);

        backend.reset_failures("svc");
        assert_eq!(backend.get_failures("svc"), 0);

        assert!(backend.get_open_at("svc").is_none());
        let now = Instant::now();
        backend.set_open_at("svc", now);
        assert!(backend.get_open_at("svc").is_some());
        backend.clear_open_at("svc");
        assert!(backend.get_open_at("svc").is_none());
    }

    // ----- Deduplicator -----

    #[test]
    fn test_deduplicator_new_key_is_new() {
        let d = Deduplicator::new(Duration::from_secs(60));
        let r = d.check_and_register("key-1").unwrap();
        assert_eq!(r, DeduplicationResult::New);
    }

    #[test]
    fn test_deduplicator_second_check_is_in_progress() {
        let d = Deduplicator::new(Duration::from_secs(60));
        d.check_and_register("key-1").unwrap();
        let r = d.check_and_register("key-1").unwrap();
        assert_eq!(r, DeduplicationResult::InProgress);
    }

    #[test]
    fn test_deduplicator_complete_makes_cached() {
        let d = Deduplicator::new(Duration::from_secs(60));
        d.check_and_register("key-1").unwrap();
        d.complete("key-1", "result-value").unwrap();
        let r = d.check_and_register("key-1").unwrap();
        assert_eq!(r, DeduplicationResult::Cached("result-value".into()));
    }

    #[test]
    fn test_deduplicator_different_keys_are_independent() {
        let d = Deduplicator::new(Duration::from_secs(60));
        d.check_and_register("key-a").unwrap();
        let r = d.check_and_register("key-b").unwrap();
        assert_eq!(r, DeduplicationResult::New);
    }

    #[test]
    fn test_deduplicator_expired_entry_is_new() {
        // A zero TTL expires entries immediately.
        let d = Deduplicator::new(Duration::ZERO);
        d.check_and_register("key-1").unwrap();
        d.complete("key-1", "old").unwrap();
        let r = d.check_and_register("key-1").unwrap();
        assert_eq!(r, DeduplicationResult::New);
    }

    // ----- BackpressureGuard -----

    #[test]
    fn test_backpressure_guard_rejects_zero_capacity() {
        assert!(BackpressureGuard::new(0).is_err());
    }

    #[test]
    fn test_backpressure_guard_acquire_within_capacity() {
        let g = BackpressureGuard::new(5).unwrap();
        assert!(g.try_acquire().is_ok());
        assert_eq!(g.depth().unwrap(), 1);
    }

    #[test]
    fn test_backpressure_guard_sheds_when_full() {
        let g = BackpressureGuard::new(2).unwrap();
        g.try_acquire().unwrap();
        g.try_acquire().unwrap();
        let result = g.try_acquire();
        assert!(matches!(
            result,
            Err(AgentRuntimeError::BackpressureShed { .. })
        ));
    }

    #[test]
    fn test_backpressure_guard_release_decrements_depth() {
        let g = BackpressureGuard::new(3).unwrap();
        g.try_acquire().unwrap();
        g.try_acquire().unwrap();
        g.release().unwrap();
        assert_eq!(g.depth().unwrap(), 1);
    }

    #[test]
    fn test_backpressure_guard_release_on_empty_is_noop() {
        let g = BackpressureGuard::new(3).unwrap();
        g.release().unwrap();
        assert_eq!(g.depth().unwrap(), 0);
    }

    // ----- Pipeline -----

    #[test]
    fn test_pipeline_runs_stages_in_order() {
        let p = Pipeline::new()
            .add_stage("upper", |s| Ok(s.to_uppercase()))
            .add_stage("append", |s| Ok(format!("{s}!")));
        let result = p.run("hello".into()).unwrap();
        assert_eq!(result, "HELLO!");
    }

    #[test]
    fn test_pipeline_empty_pipeline_returns_input() {
        let p = Pipeline::new();
        assert_eq!(p.run("test".into()).unwrap(), "test");
    }

    #[test]
    fn test_pipeline_stage_failure_short_circuits() {
        let p = Pipeline::new()
            .add_stage("fail", |_| {
                Err(AgentRuntimeError::Orchestration("boom".into()))
            })
            .add_stage("never", |s| Ok(s));
        assert!(p.run("input".into()).is_err());
    }

    #[test]
    fn test_pipeline_stage_count() {
        let p = Pipeline::new()
            .add_stage("s1", |s| Ok(s))
            .add_stage("s2", |s| Ok(s));
        assert_eq!(p.stage_count(), 2);
    }

    #[test]
    fn test_pipeline_execute_timed_captures_stage_durations() {
        let p = Pipeline::new()
            .add_stage("s1", |s| Ok(format!("{s}1")))
            .add_stage("s2", |s| Ok(format!("{s}2")));
        let result = p.execute_timed("x".to_string()).unwrap();
        assert_eq!(result.output, "x12");
        assert_eq!(result.stage_timings.len(), 2);
        assert_eq!(result.stage_timings[0].0, 0);
        assert_eq!(result.stage_timings[1].0, 1);
    }

    // ----- BackpressureGuard soft limits -----

    #[test]
    fn test_backpressure_soft_limit_rejects_invalid_config() {
        let g = BackpressureGuard::new(5).unwrap();
        assert!(g.with_soft_limit(5).is_err());
        let g = BackpressureGuard::new(5).unwrap();
        assert!(g.with_soft_limit(6).is_err());
    }

    #[test]
    fn test_backpressure_soft_limit_accepts_requests_below_soft() {
        let g = BackpressureGuard::new(5)
            .unwrap()
            .with_soft_limit(2)
            .unwrap();
        assert!(g.try_acquire().is_ok());
        assert!(g.try_acquire().is_ok());
        assert_eq!(g.depth().unwrap(), 2);
    }

    #[test]
    fn test_backpressure_with_soft_limit_still_sheds_at_hard_capacity() {
        let g = BackpressureGuard::new(3)
            .unwrap()
            .with_soft_limit(2)
            .unwrap();
        g.try_acquire().unwrap();
        g.try_acquire().unwrap();
        g.try_acquire().unwrap();
        let result = g.try_acquire();
        assert!(matches!(
            result,
            Err(AgentRuntimeError::BackpressureShed { .. })
        ));
    }

    #[test]
    fn test_backpressure_hard_capacity_matches_new() {
        let g = BackpressureGuard::new(7).unwrap();
        assert_eq!(g.hard_capacity(), 7);
    }

    // ----- Pipeline error handling -----

    #[test]
    fn test_pipeline_error_handler_recovers_from_stage_failure() {
        let p = Pipeline::new()
            .add_stage("fail_stage", |_| {
                Err(AgentRuntimeError::Orchestration("oops".into()))
            })
            .add_stage("append", |s| Ok(format!("{s}-recovered")))
            .with_error_handler(|stage_name, _err| format!("recovered_from_{stage_name}"));
        let result = p.run("input".to_string()).unwrap();
        assert_eq!(result, "recovered_from_fail_stage-recovered");
    }

    // ----- CircuitState equality -----

    #[test]
    fn test_circuit_state_eq() {
        assert_eq!(CircuitState::Closed, CircuitState::Closed);
        assert_eq!(CircuitState::HalfOpen, CircuitState::HalfOpen);
        assert_eq!(
            CircuitState::Open { opened_at: std::time::Instant::now() },
            CircuitState::Open { opened_at: std::time::Instant::now() }
        );
        assert_ne!(CircuitState::Closed, CircuitState::HalfOpen);
        assert_ne!(CircuitState::Closed, CircuitState::Open { opened_at: std::time::Instant::now() });
    }

    // ----- Deduplicator batch API -----

    #[test]
    fn test_dedup_many_independent_keys() {
        let d = Deduplicator::new(Duration::from_secs(60));
        let ttl = Duration::from_secs(60);
        let results = d.dedup_many(&[("key-a", ttl), ("key-b", ttl), ("key-c", ttl)]).unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| matches!(r, DeduplicationResult::New)));
    }

    // ----- Concurrency -----

    #[test]
    fn test_concurrent_circuit_breaker_opens_under_concurrent_failures() {
        use std::sync::Arc;
        use std::thread;

        let cb = Arc::new(
            CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap(),
        );
        let n_threads = 8;
        let failures_per_thread = 2;

        let mut handles = Vec::new();
        for _ in 0..n_threads {
            let cb = Arc::clone(&cb);
            handles.push(thread::spawn(move || {
                for _ in 0..failures_per_thread {
                    let _ = cb.call(|| Err::<(), &str>("fail"));
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }

        // 16 total failures against a threshold of 5 must trip the circuit.
        let state = cb.state().unwrap();
        assert!(
            matches!(state, CircuitState::Open { .. }),
            "circuit should be open after many concurrent failures; got: {state:?}"
        );
    }

    #[test]
    fn test_per_service_tracking_is_independent() {
        let backend = Arc::new(InMemoryCircuitBreakerBackend::new());

        let cb_a = CircuitBreaker::new("service-a", 3, Duration::from_secs(60))
            .unwrap()
            .with_backend(Arc::clone(&backend) as Arc<dyn CircuitBreakerBackend>);
        let cb_b = CircuitBreaker::new("service-b", 3, Duration::from_secs(60))
            .unwrap()
            .with_backend(Arc::clone(&backend) as Arc<dyn CircuitBreakerBackend>);

        for _ in 0..3 {
            let _ = cb_a.call(|| Err::<(), &str>("fail"));
        }

        let state_b = cb_b.state().unwrap();
        assert_eq!(
            state_b,
            CircuitState::Closed,
            "service-b should be unaffected by service-a failures"
        );

        let state_a = cb_a.state().unwrap();
        assert!(
            matches!(state_a, CircuitState::Open { .. }),
            "service-a should be open"
        );
    }

    #[test]
    fn test_backpressure_concurrent_acquires_are_consistent() {
        use std::sync::Arc;
        use std::thread;

        let g = Arc::new(BackpressureGuard::new(100).unwrap());
        let mut handles = Vec::new();

        for _ in 0..10 {
            let g_clone = Arc::clone(&g);
            handles.push(thread::spawn(move || {
                g_clone.try_acquire().ok();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(g.depth().unwrap(), 10);
    }
}