// llm_agent_runtime/orchestrator.rs

1//! # Module: Orchestrator
2//!
3//! ## Responsibility
4//! Provides a composable LLM pipeline with circuit breaking, retry, deduplication,
5//! and backpressure. Mirrors the public API of `tokio-prompt-orchestrator`.
6//!
7//! ## Guarantees
//! - Thread-safe: all stateful types wrap their state in `Arc<Mutex<_>>`
9//! - Circuit breaker opens after `threshold` failures within `window` calls
10//! - RetryPolicy delays grow exponentially and are capped at `MAX_RETRY_DELAY`
11//! - Deduplicator is deterministic and non-blocking
12//! - BackpressureGuard never exceeds declared capacity
13//! - Non-panicking: all operations return `Result`
14//!
15//! ## NOT Responsible For
16//! - Cross-node circuit breakers (single-process only, unless a distributed backend is provided)
17//! - Persistent deduplication (in-memory, bounded TTL)
18//! - Distributed backpressure
19
20use crate::error::AgentRuntimeError;
21use std::collections::HashMap;
22use std::sync::{Arc, Mutex};
23use std::time::{Duration, Instant};
24
25/// Maximum delay between retries — caps exponential growth.
26pub const MAX_RETRY_DELAY: Duration = Duration::from_secs(60);
27
28// ── Lock recovery helpers ──────────────────────────────────────────────────────
29
30/// Recover from a poisoned mutex by extracting the inner value.
31///
32/// Instead of propagating a lock-poison error, this logs a warning and
33/// returns the inner guard so the caller can continue operating.
34#[allow(dead_code)]
35fn recover_lock<'a, T>(
36    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
37    ctx: &str,
38) -> std::sync::MutexGuard<'a, T>
39where
40    T: ?Sized,
41{
42    match result {
43        Ok(guard) => guard,
44        Err(poisoned) => {
45            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
46            poisoned.into_inner()
47        }
48    }
49}
50
51/// Acquire a lock, recovering from poisoning and logging slow acquisitions.
52///
53/// Times the lock acquisition; if it exceeds 5 ms a warning is emitted so
54/// contention hot-spots can be identified in production logs.
55fn timed_lock<'a, T>(mutex: &'a Mutex<T>, ctx: &str) -> std::sync::MutexGuard<'a, T>
56where
57    T: ?Sized,
58{
59    let start = std::time::Instant::now();
60    let result = mutex.lock();
61    let elapsed = start.elapsed();
62    if elapsed > std::time::Duration::from_millis(5) {
63        tracing::warn!(
64            duration_ms = elapsed.as_millis(),
65            ctx = ctx,
66            "slow mutex acquisition"
67        );
68    }
69    match result {
70        Ok(guard) => guard,
71        Err(poisoned) => {
72            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
73            poisoned.into_inner()
74        }
75    }
76}
77
78// ── RetryPolicy ───────────────────────────────────────────────────────────────
79
/// Exponential backoff retry policy.
///
/// Delays double with each attempt (see [`RetryPolicy::delay_for`]) and are
/// capped at [`MAX_RETRY_DELAY`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts (including the first). Always >= 1; the
    /// constructor rejects zero.
    pub max_attempts: u32,
    /// Base delay for the first retry; doubled for each subsequent attempt.
    pub base_delay: Duration,
}
88
89impl RetryPolicy {
90    /// Create an exponential retry policy.
91    ///
92    /// # Arguments
93    /// * `max_attempts` — total attempt budget (must be ≥ 1)
94    /// * `base_ms` — base delay in milliseconds for attempt 1
95    ///
96    /// # Returns
97    /// - `Ok(RetryPolicy)` — on success
98    /// - `Err(AgentRuntimeError::Orchestration)` — if `max_attempts == 0`
99    pub fn exponential(max_attempts: u32, base_ms: u64) -> Result<Self, AgentRuntimeError> {
100        if max_attempts == 0 {
101            return Err(AgentRuntimeError::Orchestration(
102                "max_attempts must be >= 1".into(),
103            ));
104        }
105        Ok(Self {
106            max_attempts,
107            base_delay: Duration::from_millis(base_ms),
108        })
109    }
110
111    /// Compute the delay before the given attempt number (1-based).
112    ///
113    /// Delay = `base_delay * 2^(attempt-1)`, capped at `MAX_RETRY_DELAY`.
114    pub fn delay_for(&self, attempt: u32) -> Duration {
115        let exp = attempt.saturating_sub(1);
116        let multiplier = 1u64.checked_shl(exp.min(63)).unwrap_or(u64::MAX);
117        let millis = self
118            .base_delay
119            .as_millis()
120            .saturating_mul(multiplier as u128);
121        let raw = Duration::from_millis(millis.min(u64::MAX as u128) as u64);
122        raw.min(MAX_RETRY_DELAY)
123    }
124}
125
126// ── CircuitBreaker ────────────────────────────────────────────────────────────
127
/// Observable state of a [`CircuitBreaker`].
///
/// States: `Closed` (normal) → `Open` (fast-fail) → `HalfOpen` (probe).
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitState {
    /// Circuit is operating normally; requests pass through.
    Closed,
    /// Circuit has tripped; requests are fast-failed without calling the operation.
    Open {
        /// The instant at which the circuit was opened; compared against the
        /// recovery window to decide when a half-open probe is permitted.
        opened_at: Instant,
    },
    /// Recovery probe period; the next request will be attempted to test recovery.
    HalfOpen,
}
143
/// Backend for circuit breaker state storage.
///
/// Implement this trait to share circuit breaker state across processes
/// (e.g., via Redis). The in-process default is `InMemoryCircuitBreakerBackend`.
///
/// Note: Methods are synchronous to avoid pulling in `async-trait`. A
/// distributed backend (e.g., Redis) can internally spawn a Tokio runtime.
///
/// NOTE(review): callers perform read-then-write sequences across these
/// methods (see `CircuitBreaker::call`), so a distributed implementation
/// should tolerate races between concurrent processes.
pub trait CircuitBreakerBackend: Send + Sync {
    /// Increment the consecutive failure count for `service` and return the new count.
    fn increment_failures(&self, service: &str) -> u32;
    /// Reset the consecutive failure count for `service` to zero.
    fn reset_failures(&self, service: &str);
    /// Return the current consecutive failure count for `service`.
    fn get_failures(&self, service: &str) -> u32;
    /// Record the instant at which the circuit was opened for `service`.
    fn set_open_at(&self, service: &str, at: std::time::Instant);
    /// Clear the open-at timestamp, effectively moving the circuit to Closed or HalfOpen.
    fn clear_open_at(&self, service: &str);
    /// Return the instant at which the circuit was opened, or `None` if it is not open.
    fn get_open_at(&self, service: &str) -> Option<std::time::Instant>;
}
165
166// ── InMemoryCircuitBreakerBackend ─────────────────────────────────────────────
167
/// In-process circuit breaker backend backed by a `Mutex`.
///
/// Note: the `service` argument of the trait methods is ignored by this
/// implementation — it tracks a single counter, so every service name maps
/// to the same state.
pub struct InMemoryCircuitBreakerBackend {
    // Arc so handles shared across breakers observe the same state.
    inner: Arc<Mutex<InMemoryBackendState>>,
}

/// Mutable state guarded by the backend's mutex.
struct InMemoryBackendState {
    // Failures since the last success/reset.
    consecutive_failures: u32,
    // `Some(instant)` while the circuit is open; `None` otherwise.
    open_at: Option<std::time::Instant>,
}
177
178impl InMemoryCircuitBreakerBackend {
179    /// Create a new in-memory backend with all counters at zero.
180    pub fn new() -> Self {
181        Self {
182            inner: Arc::new(Mutex::new(InMemoryBackendState {
183                consecutive_failures: 0,
184                open_at: None,
185            })),
186        }
187    }
188}
189
190impl Default for InMemoryCircuitBreakerBackend {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl CircuitBreakerBackend for InMemoryCircuitBreakerBackend {
197    fn increment_failures(&self, _service: &str) -> u32 {
198        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::increment_failures");
199        state.consecutive_failures += 1;
200        state.consecutive_failures
201    }
202
203    fn reset_failures(&self, _service: &str) {
204        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::reset_failures");
205        state.consecutive_failures = 0;
206    }
207
208    fn get_failures(&self, _service: &str) -> u32 {
209        let state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_failures");
210        state.consecutive_failures
211    }
212
213    fn set_open_at(&self, _service: &str, at: std::time::Instant) {
214        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::set_open_at");
215        state.open_at = Some(at);
216    }
217
218    fn clear_open_at(&self, _service: &str) {
219        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::clear_open_at");
220        state.open_at = None;
221    }
222
223    fn get_open_at(&self, _service: &str) -> Option<std::time::Instant> {
224        let state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_open_at");
225        state.open_at
226    }
227}
228
229// ── CircuitBreaker ────────────────────────────────────────────────────────────
230
/// Circuit breaker guarding a fallible operation.
///
/// ## Guarantees
/// - Opens after `threshold` consecutive failures
/// - Transitions to `HalfOpen` after `recovery_window` has elapsed
/// - Closes on the first successful probe in `HalfOpen`
#[derive(Clone)]
pub struct CircuitBreaker {
    /// Consecutive failures required before the circuit opens (>= 1).
    threshold: u32,
    /// How long the circuit stays open before allowing a half-open probe.
    recovery_window: Duration,
    /// Service name used in error messages and logs.
    service: String,
    /// State storage; in-memory by default, replaceable via `with_backend`.
    backend: Arc<dyn CircuitBreakerBackend>,
}
244
245impl std::fmt::Debug for CircuitBreaker {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.debug_struct("CircuitBreaker")
248            .field("threshold", &self.threshold)
249            .field("recovery_window", &self.recovery_window)
250            .field("service", &self.service)
251            .finish()
252    }
253}
254
impl CircuitBreaker {
    /// Create a new circuit breaker backed by an in-memory backend.
    ///
    /// # Arguments
    /// * `service` — name used in error messages and logs
    /// * `threshold` — consecutive failures before opening (must be >= 1)
    /// * `recovery_window` — how long to stay open before probing
    ///
    /// # Returns
    /// - `Ok(CircuitBreaker)` — on success
    /// - `Err(AgentRuntimeError::Orchestration)` — if `threshold == 0`
    pub fn new(
        service: impl Into<String>,
        threshold: u32,
        recovery_window: Duration,
    ) -> Result<Self, AgentRuntimeError> {
        if threshold == 0 {
            return Err(AgentRuntimeError::Orchestration(
                "circuit breaker threshold must be >= 1".into(),
            ));
        }
        let service = service.into();
        Ok(Self {
            threshold,
            recovery_window,
            service,
            backend: Arc::new(InMemoryCircuitBreakerBackend::new()),
        })
    }

    /// Replace the default in-memory backend with a custom one.
    ///
    /// Useful for sharing circuit breaker state across processes.
    /// Builder-style: consumes and returns `self`.
    pub fn with_backend(mut self, backend: Arc<dyn CircuitBreakerBackend>) -> Self {
        self.backend = backend;
        self
    }

    /// Derive the current `CircuitState` from backend storage.
    ///
    /// Only distinguishes `Open` (open_at set) from `Closed`; it does NOT
    /// decode the HalfOpen encoding (failures >= threshold with open_at
    /// cleared) — `call` and `state` handle that. Currently unused.
    #[allow(dead_code)]
    fn current_state(&self) -> CircuitState {
        match self.backend.get_open_at(&self.service) {
            Some(opened_at) => CircuitState::Open { opened_at },
            None => {
                // We encode HalfOpen as failures == threshold but no open_at.
                // However, the transition to HalfOpen is done in `call`, so
                // outside of `call` we can only observe Closed here.
                // The `call` method manages the HalfOpen flag via a per-call
                // transient field — see below.
                CircuitState::Closed
            }
        }
    }

    /// Attempt to call `f`, respecting the circuit breaker state.
    ///
    /// State machine, as encoded in backend storage:
    /// - `Open`      = open_at set, recovery window not yet elapsed
    /// - `HalfOpen`  = open_at cleared but failures >= threshold
    /// - `Closed`    = open_at cleared and failures < threshold
    ///
    /// NOTE(review): the read-check-act sequence below spans several backend
    /// calls and is not atomic; with concurrent callers (or a distributed
    /// backend) more than one probe may run in HalfOpen — confirm this is
    /// acceptable for the deployment.
    ///
    /// # Returns
    /// - `Ok(T)` — if `f` succeeds (resets failure count)
    /// - `Err(AgentRuntimeError::CircuitOpen)` — if the breaker is open
    /// - `Err(AgentRuntimeError::Orchestration)` — if `f` fails (may open the breaker)
    #[tracing::instrument(skip(self, f))]
    pub fn call<T, E, F>(&self, f: F) -> Result<T, AgentRuntimeError>
    where
        F: FnOnce() -> Result<T, E>,
        E: std::fmt::Display,
    {
        // Determine effective state, potentially transitioning Open → HalfOpen.
        let effective_state = match self.backend.get_open_at(&self.service) {
            Some(opened_at) => {
                if opened_at.elapsed() >= self.recovery_window {
                    // Clear open_at to signal HalfOpen; failures remain at/above
                    // threshold so a later read still decodes as HalfOpen.
                    self.backend.clear_open_at(&self.service);
                    tracing::info!("circuit moved to half-open for {}", self.service);
                    CircuitState::HalfOpen
                } else {
                    CircuitState::Open { opened_at }
                }
            }
            None => {
                // Either Closed or HalfOpen (after a prior transition).
                // We distinguish by checking whether failures >= threshold
                // but no open_at is set — that means we are in HalfOpen.
                let failures = self.backend.get_failures(&self.service);
                if failures >= self.threshold {
                    CircuitState::HalfOpen
                } else {
                    CircuitState::Closed
                }
            }
        };

        tracing::debug!("circuit state: {:?}", effective_state);

        match effective_state {
            // Fast-fail without invoking `f` while the circuit is open.
            CircuitState::Open { .. } => {
                return Err(AgentRuntimeError::CircuitOpen {
                    service: self.service.clone(),
                });
            }
            // Closed passes through; HalfOpen lets this call act as the probe.
            CircuitState::Closed | CircuitState::HalfOpen => {}
        }

        // Execute the operation.
        match f() {
            Ok(val) => {
                // Success closes the circuit regardless of prior state.
                // Note: this info! fires on every success, not only on an
                // actual HalfOpen → Closed transition.
                self.backend.reset_failures(&self.service);
                self.backend.clear_open_at(&self.service);
                tracing::info!("circuit closed for {}", self.service);
                Ok(val)
            }
            Err(e) => {
                // A HalfOpen probe failure re-opens immediately, because the
                // incremented count is still >= threshold.
                let failures = self.backend.increment_failures(&self.service);
                if failures >= self.threshold {
                    let now = Instant::now();
                    self.backend.set_open_at(&self.service, now);
                    tracing::info!("circuit opened for {}", self.service);
                }
                Err(AgentRuntimeError::Orchestration(e.to_string()))
            }
        }
    }

    /// Return the current circuit state without mutating backend storage.
    ///
    /// Mirrors the decoding logic in `call` but performs no transitions:
    /// an expired Open circuit is *reported* as HalfOpen even though
    /// open_at has not been cleared yet.
    pub fn state(&self) -> Result<CircuitState, AgentRuntimeError> {
        let state = match self.backend.get_open_at(&self.service) {
            Some(opened_at) => {
                if opened_at.elapsed() >= self.recovery_window {
                    // Would transition to HalfOpen on next call; report HalfOpen.
                    let failures = self.backend.get_failures(&self.service);
                    if failures >= self.threshold {
                        CircuitState::HalfOpen
                    } else {
                        CircuitState::Closed
                    }
                } else {
                    CircuitState::Open { opened_at }
                }
            }
            None => {
                // Same decoding as in `call`: failures at/above threshold with
                // no open_at means a half-open probe is pending.
                let failures = self.backend.get_failures(&self.service);
                if failures >= self.threshold {
                    CircuitState::HalfOpen
                } else {
                    CircuitState::Closed
                }
            }
        };
        Ok(state)
    }

    /// Return the consecutive failure count as recorded by the backend.
    pub fn failure_count(&self) -> Result<u32, AgentRuntimeError> {
        Ok(self.backend.get_failures(&self.service))
    }
}
406
407// ── DeduplicationResult ───────────────────────────────────────────────────────
408
/// Result of a deduplication check.
#[derive(Debug, Clone, PartialEq)]
pub enum DeduplicationResult {
    /// This is a new, unseen request; it has been marked in-flight.
    New,
    /// A cached result exists for this key (completed within the TTL).
    Cached(String),
    /// A matching request is currently in-flight.
    InProgress,
}
419
/// Deduplicates requests by key within a TTL window.
///
/// ## Guarantees
/// - Deterministic: same key always maps to the same result
/// - Thread-safe via `Arc<Mutex<_>>`
/// - Entries expire after `ttl`
///
/// Note: in-flight markers also expire after `ttl`, so a request that runs
/// longer than the TTL may be registered again as `New`.
#[derive(Debug, Clone)]
pub struct Deduplicator {
    /// Time-to-live for both cached results and in-flight markers.
    ttl: Duration,
    /// Shared state; `Clone`d handles observe the same maps.
    inner: Arc<Mutex<DeduplicatorInner>>,
}

/// Maps guarded by the deduplicator's mutex.
#[derive(Debug)]
struct DeduplicatorInner {
    cache: HashMap<String, (String, Instant)>, // key → (result, inserted_at)
    in_flight: HashMap<String, Instant>,       // key → started_at
}
437
438impl Deduplicator {
439    /// Create a new deduplicator with the given TTL.
440    pub fn new(ttl: Duration) -> Self {
441        Self {
442            ttl,
443            inner: Arc::new(Mutex::new(DeduplicatorInner {
444                cache: HashMap::new(),
445                in_flight: HashMap::new(),
446            })),
447        }
448    }
449
450    /// Check whether `key` is new, cached, or in-flight.
451    ///
452    /// Marks the key as in-flight if it is new.
453    pub fn check_and_register(&self, key: &str) -> Result<DeduplicationResult, AgentRuntimeError> {
454        let mut inner = timed_lock(&self.inner, "Deduplicator::check_and_register");
455
456        let now = Instant::now();
457
458        // Expire stale cache entries
459        inner
460            .cache
461            .retain(|_, (_, ts)| now.duration_since(*ts) < self.ttl);
462        inner
463            .in_flight
464            .retain(|_, ts| now.duration_since(*ts) < self.ttl);
465
466        if let Some((result, _)) = inner.cache.get(key) {
467            return Ok(DeduplicationResult::Cached(result.clone()));
468        }
469
470        if inner.in_flight.contains_key(key) {
471            return Ok(DeduplicationResult::InProgress);
472        }
473
474        inner.in_flight.insert(key.to_owned(), now);
475        Ok(DeduplicationResult::New)
476    }
477
478    /// Complete a request: move from in-flight to cached with the given result.
479    pub fn complete(&self, key: &str, result: impl Into<String>) -> Result<(), AgentRuntimeError> {
480        let mut inner = timed_lock(&self.inner, "Deduplicator::complete");
481        inner.in_flight.remove(key);
482        inner
483            .cache
484            .insert(key.to_owned(), (result.into(), Instant::now()));
485        Ok(())
486    }
487}
488
489// ── BackpressureGuard ─────────────────────────────────────────────────────────
490
/// Tracks in-flight work count and enforces a capacity limit.
///
/// ## Guarantees
/// - Thread-safe via `Arc<Mutex<_>>`
/// - `try_acquire` is non-blocking
/// - `release` decrements the counter; no-op if counter is already 0
/// - Optional soft limit emits a warning when depth reaches the threshold
#[derive(Debug, Clone)]
pub struct BackpressureGuard {
    /// Hard limit: `try_acquire` fails once depth reaches this value.
    capacity: usize,
    /// Optional warning threshold, strictly below `capacity`.
    soft_capacity: Option<usize>,
    /// Current in-flight count; shared across `Clone`d handles.
    inner: Arc<Mutex<usize>>,
}
504
505impl BackpressureGuard {
506    /// Create a new guard with the given capacity.
507    ///
508    /// # Returns
509    /// - `Ok(BackpressureGuard)` — on success
510    /// - `Err(AgentRuntimeError::Orchestration)` — if `capacity == 0`
511    pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
512        if capacity == 0 {
513            return Err(AgentRuntimeError::Orchestration(
514                "BackpressureGuard capacity must be > 0".into(),
515            ));
516        }
517        Ok(Self {
518            capacity,
519            soft_capacity: None,
520            inner: Arc::new(Mutex::new(0)),
521        })
522    }
523
524    /// Set a soft capacity threshold. When depth reaches this level, a warning
525    /// is logged but the request is still accepted (up to hard capacity).
526    pub fn with_soft_limit(mut self, soft: usize) -> Result<Self, AgentRuntimeError> {
527        if soft >= self.capacity {
528            return Err(AgentRuntimeError::Orchestration(
529                "soft_capacity must be less than hard capacity".into(),
530            ));
531        }
532        self.soft_capacity = Some(soft);
533        Ok(self)
534    }
535
536    /// Try to acquire a slot.
537    ///
538    /// Emits a warning when the soft limit is reached (if configured), but
539    /// still accepts the request until hard capacity is exceeded.
540    ///
541    /// # Returns
542    /// - `Ok(())` — slot acquired
543    /// - `Err(AgentRuntimeError::BackpressureShed)` — hard capacity exceeded
544    pub fn try_acquire(&self) -> Result<(), AgentRuntimeError> {
545        let mut depth = timed_lock(&self.inner, "BackpressureGuard::try_acquire");
546        if *depth >= self.capacity {
547            return Err(AgentRuntimeError::BackpressureShed {
548                depth: *depth,
549                capacity: self.capacity,
550            });
551        }
552        *depth += 1;
553        if let Some(soft) = self.soft_capacity {
554            if *depth >= soft {
555                tracing::warn!(
556                    depth = *depth,
557                    soft_capacity = soft,
558                    hard_capacity = self.capacity,
559                    "backpressure approaching hard limit"
560                );
561            }
562        }
563        Ok(())
564    }
565
566    /// Release a previously acquired slot.
567    pub fn release(&self) -> Result<(), AgentRuntimeError> {
568        let mut depth = timed_lock(&self.inner, "BackpressureGuard::release");
569        *depth = depth.saturating_sub(1);
570        Ok(())
571    }
572
573    /// Return the current depth.
574    pub fn depth(&self) -> Result<usize, AgentRuntimeError> {
575        let depth = timed_lock(&self.inner, "BackpressureGuard::depth");
576        Ok(*depth)
577    }
578
579    /// Return the ratio of current depth to soft capacity as a value in `[0.0, ∞)`.
580    ///
581    /// Returns `0.0` if no soft limit has been configured.
582    /// Values above `1.0` mean the soft limit has been exceeded.
583    pub fn soft_depth_ratio(&self) -> f32 {
584        match self.soft_capacity {
585            None => 0.0,
586            Some(soft) => {
587                let depth = timed_lock(&self.inner, "BackpressureGuard::soft_depth_ratio");
588                *depth as f32 / soft as f32
589            }
590        }
591    }
592}
593
594// ── Pipeline ──────────────────────────────────────────────────────────────────
595
/// A single named stage in the pipeline.
pub struct Stage {
    /// Human-readable name used in log output and error messages.
    pub name: String,
    /// The transform function; receives the current string and returns the
    /// transformed string, or an error that aborts the pipeline.
    pub handler: Box<dyn Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync>,
}
603
604impl std::fmt::Debug for Stage {
605    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606        f.debug_struct("Stage").field("name", &self.name).finish()
607    }
608}
609
/// A composable pipeline that passes a string through a sequence of named stages.
///
/// ## Guarantees
/// - Stages execute in insertion order
/// - First stage failure short-circuits remaining stages
/// - Non-panicking
#[derive(Debug)]
pub struct Pipeline {
    /// Stages in execution order; appended to by `add_stage`.
    stages: Vec<Stage>,
}
620
621impl Pipeline {
622    /// Create a new empty pipeline.
623    pub fn new() -> Self {
624        Self { stages: Vec::new() }
625    }
626
627    /// Append a stage to the pipeline.
628    pub fn add_stage(
629        mut self,
630        name: impl Into<String>,
631        handler: impl Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync + 'static,
632    ) -> Self {
633        self.stages.push(Stage {
634            name: name.into(),
635            handler: Box::new(handler),
636        });
637        self
638    }
639
640    /// Execute the pipeline, passing `input` through each stage in order.
641    #[tracing::instrument(skip(self))]
642    pub fn run(&self, input: String) -> Result<String, AgentRuntimeError> {
643        let mut current = input;
644        for stage in &self.stages {
645            tracing::debug!(stage = %stage.name, "running pipeline stage");
646            current = (stage.handler)(current).map_err(|e| {
647                tracing::error!(stage = %stage.name, error = %e, "pipeline stage failed");
648                e
649            })?;
650        }
651        Ok(current)
652    }
653
654    /// Return the number of stages in the pipeline.
655    pub fn stage_count(&self) -> usize {
656        self.stages.len()
657    }
658}
659
660impl Default for Pipeline {
661    fn default() -> Self {
662        Self::new()
663    }
664}
665
666// ── Tests ─────────────────────────────────────────────────────────────────────
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671
672    // ── RetryPolicy ───────────────────────────────────────────────────────────
673
674    #[test]
675    fn test_retry_policy_rejects_zero_attempts() {
676        assert!(RetryPolicy::exponential(0, 100).is_err());
677    }
678
679    #[test]
680    fn test_retry_policy_delay_attempt_1_equals_base() {
681        let p = RetryPolicy::exponential(3, 100).unwrap();
682        assert_eq!(p.delay_for(1), Duration::from_millis(100));
683    }
684
685    #[test]
686    fn test_retry_policy_delay_doubles_each_attempt() {
687        let p = RetryPolicy::exponential(5, 100).unwrap();
688        assert_eq!(p.delay_for(2), Duration::from_millis(200));
689        assert_eq!(p.delay_for(3), Duration::from_millis(400));
690        assert_eq!(p.delay_for(4), Duration::from_millis(800));
691    }
692
693    #[test]
694    fn test_retry_policy_delay_capped_at_max() {
695        let p = RetryPolicy::exponential(10, 10_000).unwrap();
696        assert_eq!(p.delay_for(10), MAX_RETRY_DELAY);
697    }
698
699    #[test]
700    fn test_retry_policy_delay_never_exceeds_max_for_any_attempt() {
701        let p = RetryPolicy::exponential(10, 1000).unwrap();
702        for attempt in 1..=10 {
703            assert!(p.delay_for(attempt) <= MAX_RETRY_DELAY);
704        }
705    }
706
707    // ── CircuitBreaker ────────────────────────────────────────────────────────
708
709    #[test]
710    fn test_circuit_breaker_rejects_zero_threshold() {
711        assert!(CircuitBreaker::new("svc", 0, Duration::from_secs(1)).is_err());
712    }
713
714    #[test]
715    fn test_circuit_breaker_starts_closed() {
716        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
717        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
718    }
719
720    #[test]
721    fn test_circuit_breaker_success_keeps_closed() {
722        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
723        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(42));
724        assert!(result.is_ok());
725        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
726    }
727
728    #[test]
729    fn test_circuit_breaker_opens_after_threshold_failures() {
730        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
731        for _ in 0..3 {
732            let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("oops".to_string()));
733        }
734        assert!(matches!(cb.state().unwrap(), CircuitState::Open { .. }));
735    }
736
737    #[test]
738    fn test_circuit_breaker_open_fast_fails() {
739        let cb = CircuitBreaker::new("svc", 1, Duration::from_secs(3600)).unwrap();
740        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
741        let result: Result<(), AgentRuntimeError> = cb.call(|| Ok::<(), AgentRuntimeError>(()));
742        assert!(matches!(result, Err(AgentRuntimeError::CircuitOpen { .. })));
743    }
744
745    #[test]
746    fn test_circuit_breaker_success_resets_failure_count() {
747        let cb = CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap();
748        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
749        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
750        let _: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(1));
751        assert_eq!(cb.failure_count().unwrap(), 0);
752    }
753
754    #[test]
755    fn test_circuit_breaker_half_open_on_recovery() {
756        // Use a zero recovery window to immediately go half-open
757        let cb = CircuitBreaker::new("svc", 1, Duration::ZERO).unwrap();
758        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
759        // After recovery window, next call should probe (half-open → closed on success)
760        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(99));
761        assert_eq!(result.unwrap_or(0), 99);
762        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
763    }
764
765    #[test]
766    fn test_circuit_breaker_with_custom_backend_uses_backend_state() {
767        // Build a custom backend and share it between two circuit breakers
768        // to verify that state is read from and written to the backend.
769        let shared_backend: Arc<dyn CircuitBreakerBackend> =
770            Arc::new(InMemoryCircuitBreakerBackend::new());
771
772        let cb1 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
773            .unwrap()
774            .with_backend(Arc::clone(&shared_backend));
775
776        let cb2 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
777            .unwrap()
778            .with_backend(Arc::clone(&shared_backend));
779
780        // Trigger one failure via cb1
781        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail".to_string()));
782
783        // cb2 should observe the failure recorded by cb1
784        assert_eq!(cb2.failure_count().unwrap(), 1);
785
786        // Trigger the second failure to open the circuit via cb1
787        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail again".to_string()));
788
789        // cb2 should now see the circuit as open
790        assert!(matches!(cb2.state().unwrap(), CircuitState::Open { .. }));
791    }
792
793    #[test]
794    fn test_in_memory_backend_increments_and_resets() {
795        let backend = InMemoryCircuitBreakerBackend::new();
796
797        assert_eq!(backend.get_failures("svc"), 0);
798
799        let count = backend.increment_failures("svc");
800        assert_eq!(count, 1);
801
802        let count = backend.increment_failures("svc");
803        assert_eq!(count, 2);
804
805        backend.reset_failures("svc");
806        assert_eq!(backend.get_failures("svc"), 0);
807
808        // open_at round-trip
809        assert!(backend.get_open_at("svc").is_none());
810        let now = Instant::now();
811        backend.set_open_at("svc", now);
812        assert!(backend.get_open_at("svc").is_some());
813        backend.clear_open_at("svc");
814        assert!(backend.get_open_at("svc").is_none());
815    }
816
817    // ── Deduplicator ──────────────────────────────────────────────────────────
818
819    #[test]
820    fn test_deduplicator_new_key_is_new() {
821        let d = Deduplicator::new(Duration::from_secs(60));
822        let r = d.check_and_register("key-1").unwrap();
823        assert_eq!(r, DeduplicationResult::New);
824    }
825
826    #[test]
827    fn test_deduplicator_second_check_is_in_progress() {
828        let d = Deduplicator::new(Duration::from_secs(60));
829        d.check_and_register("key-1").unwrap();
830        let r = d.check_and_register("key-1").unwrap();
831        assert_eq!(r, DeduplicationResult::InProgress);
832    }
833
834    #[test]
835    fn test_deduplicator_complete_makes_cached() {
836        let d = Deduplicator::new(Duration::from_secs(60));
837        d.check_and_register("key-1").unwrap();
838        d.complete("key-1", "result-value").unwrap();
839        let r = d.check_and_register("key-1").unwrap();
840        assert_eq!(r, DeduplicationResult::Cached("result-value".into()));
841    }
842
843    #[test]
844    fn test_deduplicator_different_keys_are_independent() {
845        let d = Deduplicator::new(Duration::from_secs(60));
846        d.check_and_register("key-a").unwrap();
847        let r = d.check_and_register("key-b").unwrap();
848        assert_eq!(r, DeduplicationResult::New);
849    }
850
851    #[test]
852    fn test_deduplicator_expired_entry_is_new() {
853        let d = Deduplicator::new(Duration::ZERO); // instant TTL
854        d.check_and_register("key-1").unwrap();
855        d.complete("key-1", "old").unwrap();
856        // Immediately expired — should be New again
857        let r = d.check_and_register("key-1").unwrap();
858        assert_eq!(r, DeduplicationResult::New);
859    }
860
861    // ── BackpressureGuard ─────────────────────────────────────────────────────
862
863    #[test]
864    fn test_backpressure_guard_rejects_zero_capacity() {
865        assert!(BackpressureGuard::new(0).is_err());
866    }
867
868    #[test]
869    fn test_backpressure_guard_acquire_within_capacity() {
870        let g = BackpressureGuard::new(5).unwrap();
871        assert!(g.try_acquire().is_ok());
872        assert_eq!(g.depth().unwrap(), 1);
873    }
874
875    #[test]
876    fn test_backpressure_guard_sheds_when_full() {
877        let g = BackpressureGuard::new(2).unwrap();
878        g.try_acquire().unwrap();
879        g.try_acquire().unwrap();
880        let result = g.try_acquire();
881        assert!(matches!(
882            result,
883            Err(AgentRuntimeError::BackpressureShed { .. })
884        ));
885    }
886
887    #[test]
888    fn test_backpressure_guard_release_decrements_depth() {
889        let g = BackpressureGuard::new(3).unwrap();
890        g.try_acquire().unwrap();
891        g.try_acquire().unwrap();
892        g.release().unwrap();
893        assert_eq!(g.depth().unwrap(), 1);
894    }
895
896    #[test]
897    fn test_backpressure_guard_release_on_empty_is_noop() {
898        let g = BackpressureGuard::new(3).unwrap();
899        g.release().unwrap(); // Should not fail
900        assert_eq!(g.depth().unwrap(), 0);
901    }
902
903    // ── Pipeline ──────────────────────────────────────────────────────────────
904
905    #[test]
906    fn test_pipeline_runs_stages_in_order() {
907        let p = Pipeline::new()
908            .add_stage("upper", |s| Ok(s.to_uppercase()))
909            .add_stage("append", |s| Ok(format!("{s}!")));
910        let result = p.run("hello".into()).unwrap();
911        assert_eq!(result, "HELLO!");
912    }
913
914    #[test]
915    fn test_pipeline_empty_pipeline_returns_input() {
916        let p = Pipeline::new();
917        assert_eq!(p.run("test".into()).unwrap(), "test");
918    }
919
920    #[test]
921    fn test_pipeline_stage_failure_short_circuits() {
922        let p = Pipeline::new()
923            .add_stage("fail", |_| {
924                Err(AgentRuntimeError::Orchestration("boom".into()))
925            })
926            .add_stage("never", |s| Ok(s));
927        assert!(p.run("input".into()).is_err());
928    }
929
930    #[test]
931    fn test_pipeline_stage_count() {
932        let p = Pipeline::new()
933            .add_stage("s1", |s| Ok(s))
934            .add_stage("s2", |s| Ok(s));
935        assert_eq!(p.stage_count(), 2);
936    }
937
938    // ── Item 13: BackpressureGuard soft limit ──────────────────────────────────
939
940    #[test]
941    fn test_backpressure_soft_limit_rejects_invalid_config() {
942        // soft >= capacity must be rejected
943        let g = BackpressureGuard::new(5).unwrap();
944        assert!(g.with_soft_limit(5).is_err());
945        let g = BackpressureGuard::new(5).unwrap();
946        assert!(g.with_soft_limit(6).is_err());
947    }
948
949    #[test]
950    fn test_backpressure_soft_limit_accepts_requests_below_soft() {
951        let g = BackpressureGuard::new(5)
952            .unwrap()
953            .with_soft_limit(2)
954            .unwrap();
955        // Both acquires below soft limit should succeed
956        assert!(g.try_acquire().is_ok());
957        assert!(g.try_acquire().is_ok());
958        assert_eq!(g.depth().unwrap(), 2);
959    }
960
961    #[test]
962    fn test_backpressure_with_soft_limit_still_sheds_at_hard_capacity() {
963        let g = BackpressureGuard::new(3)
964            .unwrap()
965            .with_soft_limit(2)
966            .unwrap();
967        g.try_acquire().unwrap();
968        g.try_acquire().unwrap();
969        g.try_acquire().unwrap(); // reaches hard limit
970        let result = g.try_acquire();
971        assert!(matches!(
972            result,
973            Err(AgentRuntimeError::BackpressureShed { .. })
974        ));
975    }
976
977    // ── Item 14: timed_lock concurrency correctness ───────────────────────────
978
979    #[test]
980    fn test_backpressure_concurrent_acquires_are_consistent() {
981        use std::sync::Arc;
982        use std::thread;
983
984        let g = Arc::new(BackpressureGuard::new(100).unwrap());
985        let mut handles = Vec::new();
986
987        for _ in 0..10 {
988            let g_clone = Arc::clone(&g);
989            handles.push(thread::spawn(move || {
990                g_clone.try_acquire().ok();
991            }));
992        }
993
994        for h in handles {
995            h.join().unwrap();
996        }
997
998        // All 10 threads acquired a slot; depth must be exactly 10
999        assert_eq!(g.depth().unwrap(), 10);
1000    }
1001}