Skip to main content

llm_agent_runtime/
orchestrator.rs

1//! # Module: Orchestrator
2//!
3//! ## Responsibility
4//! Provides a composable LLM pipeline with circuit breaking, retry, deduplication,
5//! and backpressure. Mirrors the public API of `tokio-prompt-orchestrator`.
6//!
7//! ## Guarantees
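//! - Thread-safe: mutable state lives behind `Arc<Mutex<_>>` (or a `Send + Sync` circuit-breaker backend); pipeline stages must be `Send + Sync`
//! - Circuit breaker opens after `threshold` consecutive failures and lets a probe through once `recovery_window` has elapsed
//! - RetryPolicy delays grow exponentially and are capped at `MAX_RETRY_DELAY`
//! - Deduplicator is deterministic and never waits on in-flight work (concurrent duplicates get `InProgress`)
//! - BackpressureGuard never exceeds declared capacity
//! - Non-panicking: fallible operations return `Result`, and poisoned locks are recovered instead of propagated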
14//!
15//! ## NOT Responsible For
16//! - Cross-node circuit breakers (single-process only, unless a distributed backend is provided)
17//! - Persistent deduplication (in-memory, bounded TTL)
18//! - Distributed backpressure
19
20use crate::error::AgentRuntimeError;
21use std::collections::HashMap;
22use std::sync::{Arc, Mutex};
23use std::time::{Duration, Instant};
24
/// Upper bound on any single delay produced by [`RetryPolicy::delay_for`].
///
/// Exponential growth is clamped to this value, so a retry loop never waits
/// more than one minute between attempts.
pub const MAX_RETRY_DELAY: Duration = Duration::from_millis(60_000);
27
28// ── Lock recovery helpers ──────────────────────────────────────────────────────
29
30/// Recover from a poisoned mutex by extracting the inner value.
31///
32/// Instead of propagating a lock-poison error, this logs a warning and
33/// returns the inner guard so the caller can continue operating.
34#[allow(dead_code)]
35fn recover_lock<'a, T>(
36    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
37    ctx: &str,
38) -> std::sync::MutexGuard<'a, T>
39where
40    T: ?Sized,
41{
42    match result {
43        Ok(guard) => guard,
44        Err(poisoned) => {
45            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
46            poisoned.into_inner()
47        }
48    }
49}
50
51/// Acquire a lock, recovering from poisoning and logging slow acquisitions.
52///
53/// Times the lock acquisition; if it exceeds 5 ms a warning is emitted so
54/// contention hot-spots can be identified in production logs.
55fn timed_lock<'a, T>(mutex: &'a Mutex<T>, ctx: &str) -> std::sync::MutexGuard<'a, T>
56where
57    T: ?Sized,
58{
59    let start = std::time::Instant::now();
60    let result = mutex.lock();
61    let elapsed = start.elapsed();
62    if elapsed > std::time::Duration::from_millis(5) {
63        tracing::warn!(
64            duration_ms = elapsed.as_millis(),
65            ctx = ctx,
66            "slow mutex acquisition"
67        );
68    }
69    match result {
70        Ok(guard) => guard,
71        Err(poisoned) => {
72            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
73            poisoned.into_inner()
74        }
75    }
76}
77
78// ── RetryPolicy ───────────────────────────────────────────────────────────────
79
80/// Exponential backoff retry policy.
81#[derive(Debug, Clone)]
82pub struct RetryPolicy {
83    /// Maximum number of attempts (including the first).
84    pub max_attempts: u32,
85    /// Base delay for the first retry.
86    pub base_delay: Duration,
87}
88
89impl RetryPolicy {
90    /// Create an exponential retry policy.
91    ///
92    /// # Arguments
93    /// * `max_attempts` — total attempt budget (must be ≥ 1)
94    /// * `base_ms` — base delay in milliseconds for attempt 1
95    ///
96    /// # Returns
97    /// - `Ok(RetryPolicy)` — on success
98    /// - `Err(AgentRuntimeError::Orchestration)` — if `max_attempts == 0`
99    pub fn exponential(max_attempts: u32, base_ms: u64) -> Result<Self, AgentRuntimeError> {
100        if max_attempts == 0 {
101            return Err(AgentRuntimeError::Orchestration(
102                "max_attempts must be >= 1".into(),
103            ));
104        }
105        Ok(Self {
106            max_attempts,
107            base_delay: Duration::from_millis(base_ms),
108        })
109    }
110
111    /// Compute the delay before the given attempt number (1-based).
112    ///
113    /// Delay = `base_delay * 2^(attempt-1)`, capped at `MAX_RETRY_DELAY`.
114    pub fn delay_for(&self, attempt: u32) -> Duration {
115        let exp = attempt.saturating_sub(1);
116        let multiplier = 1u64.checked_shl(exp.min(63)).unwrap_or(u64::MAX);
117        let millis = self
118            .base_delay
119            .as_millis()
120            .saturating_mul(multiplier as u128);
121        let raw = Duration::from_millis(millis.min(u64::MAX as u128) as u64);
122        raw.min(MAX_RETRY_DELAY)
123    }
124}
125
126// ── CircuitBreaker ────────────────────────────────────────────────────────────
127
/// Observable state of a circuit breaker.
///
/// Lifecycle: `Closed` (normal) → `Open` (fast-fail) → `HalfOpen` (probe),
/// returning to `Closed` once a probe succeeds.
#[derive(Clone, Debug, PartialEq)]
pub enum CircuitState {
    /// Normal operation: every request is forwarded to the guarded operation.
    Closed,
    /// Tripped: requests fail fast and the guarded operation is not invoked.
    Open {
        /// When the breaker tripped.
        opened_at: Instant,
    },
    /// Recovery window has elapsed; the next request is let through as a probe.
    HalfOpen,
}
143
/// Storage abstraction for circuit-breaker state.
///
/// The process-local default is `InMemoryCircuitBreakerBackend`. Implement
/// this trait to keep the state somewhere shared (for example Redis) so that
/// several processes observe the same breaker.
///
/// The methods are deliberately synchronous so the crate does not depend on
/// `async-trait`; a distributed backend can drive its own Tokio runtime
/// internally.
pub trait CircuitBreakerBackend: Send + Sync {
    /// Current consecutive-failure count recorded for `service`.
    fn get_failures(&self, service: &str) -> u32;
    /// Bump the consecutive-failure count for `service` and return the updated value.
    fn increment_failures(&self, service: &str) -> u32;
    /// Set the consecutive-failure count for `service` back to zero.
    fn reset_failures(&self, service: &str);

    /// Instant at which the circuit for `service` opened, or `None` if no such timestamp is stored.
    fn get_open_at(&self, service: &str) -> Option<Instant>;
    /// Record that the circuit for `service` opened at `at`.
    fn set_open_at(&self, service: &str, at: Instant);
    /// Forget the open timestamp for `service`; the breaker then reads as Closed or HalfOpen
    /// depending on the failure count.
    fn clear_open_at(&self, service: &str);
}
165
166// ── InMemoryCircuitBreakerBackend ─────────────────────────────────────────────
167
168/// In-process circuit breaker backend backed by a `Mutex`.
169pub struct InMemoryCircuitBreakerBackend {
170    inner: Arc<Mutex<InMemoryBackendState>>,
171}
172
173struct InMemoryBackendState {
174    consecutive_failures: u32,
175    open_at: Option<std::time::Instant>,
176}
177
178impl InMemoryCircuitBreakerBackend {
179    /// Create a new in-memory backend with all counters at zero.
180    pub fn new() -> Self {
181        Self {
182            inner: Arc::new(Mutex::new(InMemoryBackendState {
183                consecutive_failures: 0,
184                open_at: None,
185            })),
186        }
187    }
188}
189
190impl Default for InMemoryCircuitBreakerBackend {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl CircuitBreakerBackend for InMemoryCircuitBreakerBackend {
197    fn increment_failures(&self, _service: &str) -> u32 {
198        let mut state = timed_lock(
199            &self.inner,
200            "InMemoryCircuitBreakerBackend::increment_failures",
201        );
202        state.consecutive_failures += 1;
203        state.consecutive_failures
204    }
205
206    fn reset_failures(&self, _service: &str) {
207        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::reset_failures");
208        state.consecutive_failures = 0;
209    }
210
211    fn get_failures(&self, _service: &str) -> u32 {
212        let state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_failures");
213        state.consecutive_failures
214    }
215
216    fn set_open_at(&self, _service: &str, at: std::time::Instant) {
217        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::set_open_at");
218        state.open_at = Some(at);
219    }
220
221    fn clear_open_at(&self, _service: &str) {
222        let mut state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::clear_open_at");
223        state.open_at = None;
224    }
225
226    fn get_open_at(&self, _service: &str) -> Option<std::time::Instant> {
227        let state = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_open_at");
228        state.open_at
229    }
230}
231
232// ── CircuitBreaker ────────────────────────────────────────────────────────────
233
234/// Circuit breaker guarding a fallible operation.
235///
236/// ## Guarantees
237/// - Opens after `threshold` consecutive failures
238/// - Transitions to `HalfOpen` after `recovery_window` has elapsed
239/// - Closes on the first successful probe in `HalfOpen`
240#[derive(Clone)]
241pub struct CircuitBreaker {
242    threshold: u32,
243    recovery_window: Duration,
244    service: String,
245    backend: Arc<dyn CircuitBreakerBackend>,
246}
247
248impl std::fmt::Debug for CircuitBreaker {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("CircuitBreaker")
251            .field("threshold", &self.threshold)
252            .field("recovery_window", &self.recovery_window)
253            .field("service", &self.service)
254            .finish()
255    }
256}
257
258impl CircuitBreaker {
259    /// Create a new circuit breaker backed by an in-memory backend.
260    ///
261    /// # Arguments
262    /// * `service` — name used in error messages and logs
263    /// * `threshold` — consecutive failures before opening
264    /// * `recovery_window` — how long to stay open before probing
265    pub fn new(
266        service: impl Into<String>,
267        threshold: u32,
268        recovery_window: Duration,
269    ) -> Result<Self, AgentRuntimeError> {
270        if threshold == 0 {
271            return Err(AgentRuntimeError::Orchestration(
272                "circuit breaker threshold must be >= 1".into(),
273            ));
274        }
275        let service = service.into();
276        Ok(Self {
277            threshold,
278            recovery_window,
279            service,
280            backend: Arc::new(InMemoryCircuitBreakerBackend::new()),
281        })
282    }
283
284    /// Replace the default in-memory backend with a custom one.
285    ///
286    /// Useful for sharing circuit breaker state across processes.
287    pub fn with_backend(mut self, backend: Arc<dyn CircuitBreakerBackend>) -> Self {
288        self.backend = backend;
289        self
290    }
291
292    /// Derive the current `CircuitState` from backend storage.
293    #[allow(dead_code)]
294    fn current_state(&self) -> CircuitState {
295        match self.backend.get_open_at(&self.service) {
296            Some(opened_at) => CircuitState::Open { opened_at },
297            None => {
298                // We encode HalfOpen as failures == threshold but no open_at.
299                // However, the transition to HalfOpen is done in `call`, so
300                // outside of `call` we can only observe Closed here.
301                // The `call` method manages the HalfOpen flag via a per-call
302                // transient field — see below.
303                CircuitState::Closed
304            }
305        }
306    }
307
308    /// Attempt to call `f`, respecting the circuit breaker state.
309    ///
310    /// # Returns
311    /// - `Ok(T)` — if `f` succeeds (resets failure count)
312    /// - `Err(AgentRuntimeError::CircuitOpen)` — if the breaker is open
313    /// - `Err(...)` — if `f` fails (may open the breaker)
314    #[tracing::instrument(skip(self, f))]
315    pub fn call<T, E, F>(&self, f: F) -> Result<T, AgentRuntimeError>
316    where
317        F: FnOnce() -> Result<T, E>,
318        E: std::fmt::Display,
319    {
320        // Determine effective state, potentially transitioning Open → HalfOpen.
321        let effective_state = match self.backend.get_open_at(&self.service) {
322            Some(opened_at) => {
323                if opened_at.elapsed() >= self.recovery_window {
324                    // Clear open_at to signal HalfOpen; failures remain.
325                    self.backend.clear_open_at(&self.service);
326                    tracing::info!("circuit moved to half-open for {}", self.service);
327                    CircuitState::HalfOpen
328                } else {
329                    CircuitState::Open { opened_at }
330                }
331            }
332            None => {
333                // Either Closed or HalfOpen (after a prior transition).
334                // We distinguish by checking whether failures >= threshold
335                // but no open_at is set — that means we are in HalfOpen.
336                let failures = self.backend.get_failures(&self.service);
337                if failures >= self.threshold {
338                    CircuitState::HalfOpen
339                } else {
340                    CircuitState::Closed
341                }
342            }
343        };
344
345        tracing::debug!("circuit state: {:?}", effective_state);
346
347        match effective_state {
348            CircuitState::Open { .. } => {
349                return Err(AgentRuntimeError::CircuitOpen {
350                    service: self.service.clone(),
351                });
352            }
353            CircuitState::Closed | CircuitState::HalfOpen => {}
354        }
355
356        // Execute the operation.
357        match f() {
358            Ok(val) => {
359                self.backend.reset_failures(&self.service);
360                self.backend.clear_open_at(&self.service);
361                tracing::info!("circuit closed for {}", self.service);
362                Ok(val)
363            }
364            Err(e) => {
365                let failures = self.backend.increment_failures(&self.service);
366                if failures >= self.threshold {
367                    let now = Instant::now();
368                    self.backend.set_open_at(&self.service, now);
369                    tracing::info!("circuit opened for {}", self.service);
370                }
371                Err(AgentRuntimeError::Orchestration(e.to_string()))
372            }
373        }
374    }
375
376    /// Return the current circuit state.
377    pub fn state(&self) -> Result<CircuitState, AgentRuntimeError> {
378        let state = match self.backend.get_open_at(&self.service) {
379            Some(opened_at) => {
380                if opened_at.elapsed() >= self.recovery_window {
381                    // Would transition to HalfOpen on next call; report HalfOpen.
382                    let failures = self.backend.get_failures(&self.service);
383                    if failures >= self.threshold {
384                        CircuitState::HalfOpen
385                    } else {
386                        CircuitState::Closed
387                    }
388                } else {
389                    CircuitState::Open { opened_at }
390                }
391            }
392            None => {
393                let failures = self.backend.get_failures(&self.service);
394                if failures >= self.threshold {
395                    CircuitState::HalfOpen
396                } else {
397                    CircuitState::Closed
398                }
399            }
400        };
401        Ok(state)
402    }
403
404    /// Return the consecutive failure count.
405    pub fn failure_count(&self) -> Result<u32, AgentRuntimeError> {
406        Ok(self.backend.get_failures(&self.service))
407    }
408}
409
410// ── DeduplicationResult ───────────────────────────────────────────────────────
411
/// Outcome of a deduplication check performed by `Deduplicator::check_and_register`.
#[derive(Clone, Debug, PartialEq)]
pub enum DeduplicationResult {
    /// The key is unseen (or its previous entry expired); it is now registered as in flight.
    New,
    /// The key completed within the TTL; carries the stored result.
    Cached(String),
    /// The key is registered as in flight and has not been completed yet.
    InProgress,
}
422
423/// Deduplicates requests by key within a TTL window.
424///
425/// ## Guarantees
426/// - Deterministic: same key always maps to the same result
427/// - Thread-safe via `Arc<Mutex<_>>`
428/// - Entries expire after `ttl`
429#[derive(Debug, Clone)]
430pub struct Deduplicator {
431    ttl: Duration,
432    inner: Arc<Mutex<DeduplicatorInner>>,
433}
434
435#[derive(Debug)]
436struct DeduplicatorInner {
437    cache: HashMap<String, (String, Instant)>, // key → (result, inserted_at)
438    in_flight: HashMap<String, Instant>,       // key → started_at
439}
440
441impl Deduplicator {
442    /// Create a new deduplicator with the given TTL.
443    pub fn new(ttl: Duration) -> Self {
444        Self {
445            ttl,
446            inner: Arc::new(Mutex::new(DeduplicatorInner {
447                cache: HashMap::new(),
448                in_flight: HashMap::new(),
449            })),
450        }
451    }
452
453    /// Check whether `key` is new, cached, or in-flight.
454    ///
455    /// Marks the key as in-flight if it is new.
456    pub fn check_and_register(&self, key: &str) -> Result<DeduplicationResult, AgentRuntimeError> {
457        let mut inner = timed_lock(&self.inner, "Deduplicator::check_and_register");
458
459        let now = Instant::now();
460
461        // Expire stale cache entries
462        inner
463            .cache
464            .retain(|_, (_, ts)| now.duration_since(*ts) < self.ttl);
465        inner
466            .in_flight
467            .retain(|_, ts| now.duration_since(*ts) < self.ttl);
468
469        if let Some((result, _)) = inner.cache.get(key) {
470            return Ok(DeduplicationResult::Cached(result.clone()));
471        }
472
473        if inner.in_flight.contains_key(key) {
474            return Ok(DeduplicationResult::InProgress);
475        }
476
477        inner.in_flight.insert(key.to_owned(), now);
478        Ok(DeduplicationResult::New)
479    }
480
481    /// Complete a request: move from in-flight to cached with the given result.
482    pub fn complete(&self, key: &str, result: impl Into<String>) -> Result<(), AgentRuntimeError> {
483        let mut inner = timed_lock(&self.inner, "Deduplicator::complete");
484        inner.in_flight.remove(key);
485        inner
486            .cache
487            .insert(key.to_owned(), (result.into(), Instant::now()));
488        Ok(())
489    }
490}
491
492// ── BackpressureGuard ─────────────────────────────────────────────────────────
493
494/// Tracks in-flight work count and enforces a capacity limit.
495///
496/// ## Guarantees
497/// - Thread-safe via `Arc<Mutex<_>>`
498/// - `try_acquire` is non-blocking
499/// - `release` decrements the counter; no-op if counter is already 0
500/// - Optional soft limit emits a warning when depth reaches the threshold
501#[derive(Debug, Clone)]
502pub struct BackpressureGuard {
503    capacity: usize,
504    soft_capacity: Option<usize>,
505    inner: Arc<Mutex<usize>>,
506}
507
508impl BackpressureGuard {
509    /// Create a new guard with the given capacity.
510    ///
511    /// # Returns
512    /// - `Ok(BackpressureGuard)` — on success
513    /// - `Err(AgentRuntimeError::Orchestration)` — if `capacity == 0`
514    pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
515        if capacity == 0 {
516            return Err(AgentRuntimeError::Orchestration(
517                "BackpressureGuard capacity must be > 0".into(),
518            ));
519        }
520        Ok(Self {
521            capacity,
522            soft_capacity: None,
523            inner: Arc::new(Mutex::new(0)),
524        })
525    }
526
527    /// Set a soft capacity threshold. When depth reaches this level, a warning
528    /// is logged but the request is still accepted (up to hard capacity).
529    pub fn with_soft_limit(mut self, soft: usize) -> Result<Self, AgentRuntimeError> {
530        if soft >= self.capacity {
531            return Err(AgentRuntimeError::Orchestration(
532                "soft_capacity must be less than hard capacity".into(),
533            ));
534        }
535        self.soft_capacity = Some(soft);
536        Ok(self)
537    }
538
539    /// Try to acquire a slot.
540    ///
541    /// Emits a warning when the soft limit is reached (if configured), but
542    /// still accepts the request until hard capacity is exceeded.
543    ///
544    /// # Returns
545    /// - `Ok(())` — slot acquired
546    /// - `Err(AgentRuntimeError::BackpressureShed)` — hard capacity exceeded
547    pub fn try_acquire(&self) -> Result<(), AgentRuntimeError> {
548        let mut depth = timed_lock(&self.inner, "BackpressureGuard::try_acquire");
549        if *depth >= self.capacity {
550            return Err(AgentRuntimeError::BackpressureShed {
551                depth: *depth,
552                capacity: self.capacity,
553            });
554        }
555        *depth += 1;
556        if let Some(soft) = self.soft_capacity {
557            if *depth >= soft {
558                tracing::warn!(
559                    depth = *depth,
560                    soft_capacity = soft,
561                    hard_capacity = self.capacity,
562                    "backpressure approaching hard limit"
563                );
564            }
565        }
566        Ok(())
567    }
568
569    /// Release a previously acquired slot.
570    pub fn release(&self) -> Result<(), AgentRuntimeError> {
571        let mut depth = timed_lock(&self.inner, "BackpressureGuard::release");
572        *depth = depth.saturating_sub(1);
573        Ok(())
574    }
575
576    /// Return the current depth.
577    pub fn depth(&self) -> Result<usize, AgentRuntimeError> {
578        let depth = timed_lock(&self.inner, "BackpressureGuard::depth");
579        Ok(*depth)
580    }
581
582    /// Return the ratio of current depth to soft capacity as a value in `[0.0, ∞)`.
583    ///
584    /// Returns `0.0` if no soft limit has been configured.
585    /// Values above `1.0` mean the soft limit has been exceeded.
586    pub fn soft_depth_ratio(&self) -> f32 {
587        match self.soft_capacity {
588            None => 0.0,
589            Some(soft) => {
590                let depth = timed_lock(&self.inner, "BackpressureGuard::soft_depth_ratio");
591                *depth as f32 / soft as f32
592            }
593        }
594    }
595}
596
597// ── Pipeline ──────────────────────────────────────────────────────────────────
598
599/// A single named stage in the pipeline.
600pub struct Stage {
601    /// Human-readable name used in log output and error messages.
602    pub name: String,
603    /// The transform function; receives the current string and returns the transformed string.
604    pub handler: Box<dyn Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync>,
605}
606
607impl std::fmt::Debug for Stage {
608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609        f.debug_struct("Stage").field("name", &self.name).finish()
610    }
611}
612
613/// A composable pipeline that passes a string through a sequence of named stages.
614///
615/// ## Guarantees
616/// - Stages execute in insertion order
617/// - First stage failure short-circuits remaining stages
618/// - Non-panicking
619#[derive(Debug)]
620pub struct Pipeline {
621    stages: Vec<Stage>,
622}
623
624impl Pipeline {
625    /// Create a new empty pipeline.
626    pub fn new() -> Self {
627        Self { stages: Vec::new() }
628    }
629
630    /// Append a stage to the pipeline.
631    pub fn add_stage(
632        mut self,
633        name: impl Into<String>,
634        handler: impl Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync + 'static,
635    ) -> Self {
636        self.stages.push(Stage {
637            name: name.into(),
638            handler: Box::new(handler),
639        });
640        self
641    }
642
643    /// Execute the pipeline, passing `input` through each stage in order.
644    #[tracing::instrument(skip(self))]
645    pub fn run(&self, input: String) -> Result<String, AgentRuntimeError> {
646        let mut current = input;
647        for stage in &self.stages {
648            tracing::debug!(stage = %stage.name, "running pipeline stage");
649            current = (stage.handler)(current).map_err(|e| {
650                tracing::error!(stage = %stage.name, error = %e, "pipeline stage failed");
651                e
652            })?;
653        }
654        Ok(current)
655    }
656
657    /// Return the number of stages in the pipeline.
658    pub fn stage_count(&self) -> usize {
659        self.stages.len()
660    }
661}
662
663impl Default for Pipeline {
664    fn default() -> Self {
665        Self::new()
666    }
667}
668
669// ── Tests ─────────────────────────────────────────────────────────────────────
670
671#[cfg(test)]
672mod tests {
673    use super::*;
674
675    // ── RetryPolicy ───────────────────────────────────────────────────────────
676
677    #[test]
678    fn test_retry_policy_rejects_zero_attempts() {
679        assert!(RetryPolicy::exponential(0, 100).is_err());
680    }
681
682    #[test]
683    fn test_retry_policy_delay_attempt_1_equals_base() {
684        let p = RetryPolicy::exponential(3, 100).unwrap();
685        assert_eq!(p.delay_for(1), Duration::from_millis(100));
686    }
687
688    #[test]
689    fn test_retry_policy_delay_doubles_each_attempt() {
690        let p = RetryPolicy::exponential(5, 100).unwrap();
691        assert_eq!(p.delay_for(2), Duration::from_millis(200));
692        assert_eq!(p.delay_for(3), Duration::from_millis(400));
693        assert_eq!(p.delay_for(4), Duration::from_millis(800));
694    }
695
696    #[test]
697    fn test_retry_policy_delay_capped_at_max() {
698        let p = RetryPolicy::exponential(10, 10_000).unwrap();
699        assert_eq!(p.delay_for(10), MAX_RETRY_DELAY);
700    }
701
702    #[test]
703    fn test_retry_policy_delay_never_exceeds_max_for_any_attempt() {
704        let p = RetryPolicy::exponential(10, 1000).unwrap();
705        for attempt in 1..=10 {
706            assert!(p.delay_for(attempt) <= MAX_RETRY_DELAY);
707        }
708    }
709
710    // ── CircuitBreaker ────────────────────────────────────────────────────────
711
712    #[test]
713    fn test_circuit_breaker_rejects_zero_threshold() {
714        assert!(CircuitBreaker::new("svc", 0, Duration::from_secs(1)).is_err());
715    }
716
717    #[test]
718    fn test_circuit_breaker_starts_closed() {
719        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
720        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
721    }
722
723    #[test]
724    fn test_circuit_breaker_success_keeps_closed() {
725        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
726        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(42));
727        assert!(result.is_ok());
728        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
729    }
730
731    #[test]
732    fn test_circuit_breaker_opens_after_threshold_failures() {
733        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
734        for _ in 0..3 {
735            let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("oops".to_string()));
736        }
737        assert!(matches!(cb.state().unwrap(), CircuitState::Open { .. }));
738    }
739
740    #[test]
741    fn test_circuit_breaker_open_fast_fails() {
742        let cb = CircuitBreaker::new("svc", 1, Duration::from_secs(3600)).unwrap();
743        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
744        let result: Result<(), AgentRuntimeError> = cb.call(|| Ok::<(), AgentRuntimeError>(()));
745        assert!(matches!(result, Err(AgentRuntimeError::CircuitOpen { .. })));
746    }
747
748    #[test]
749    fn test_circuit_breaker_success_resets_failure_count() {
750        let cb = CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap();
751        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
752        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
753        let _: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(1));
754        assert_eq!(cb.failure_count().unwrap(), 0);
755    }
756
757    #[test]
758    fn test_circuit_breaker_half_open_on_recovery() {
759        // Use a zero recovery window to immediately go half-open
760        let cb = CircuitBreaker::new("svc", 1, Duration::ZERO).unwrap();
761        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
762        // After recovery window, next call should probe (half-open → closed on success)
763        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(99));
764        assert_eq!(result.unwrap_or(0), 99);
765        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
766    }
767
768    #[test]
769    fn test_circuit_breaker_with_custom_backend_uses_backend_state() {
770        // Build a custom backend and share it between two circuit breakers
771        // to verify that state is read from and written to the backend.
772        let shared_backend: Arc<dyn CircuitBreakerBackend> =
773            Arc::new(InMemoryCircuitBreakerBackend::new());
774
775        let cb1 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
776            .unwrap()
777            .with_backend(Arc::clone(&shared_backend));
778
779        let cb2 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
780            .unwrap()
781            .with_backend(Arc::clone(&shared_backend));
782
783        // Trigger one failure via cb1
784        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail".to_string()));
785
786        // cb2 should observe the failure recorded by cb1
787        assert_eq!(cb2.failure_count().unwrap(), 1);
788
789        // Trigger the second failure to open the circuit via cb1
790        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail again".to_string()));
791
792        // cb2 should now see the circuit as open
793        assert!(matches!(cb2.state().unwrap(), CircuitState::Open { .. }));
794    }
795
796    #[test]
797    fn test_in_memory_backend_increments_and_resets() {
798        let backend = InMemoryCircuitBreakerBackend::new();
799
800        assert_eq!(backend.get_failures("svc"), 0);
801
802        let count = backend.increment_failures("svc");
803        assert_eq!(count, 1);
804
805        let count = backend.increment_failures("svc");
806        assert_eq!(count, 2);
807
808        backend.reset_failures("svc");
809        assert_eq!(backend.get_failures("svc"), 0);
810
811        // open_at round-trip
812        assert!(backend.get_open_at("svc").is_none());
813        let now = Instant::now();
814        backend.set_open_at("svc", now);
815        assert!(backend.get_open_at("svc").is_some());
816        backend.clear_open_at("svc");
817        assert!(backend.get_open_at("svc").is_none());
818    }
819
820    // ── Deduplicator ──────────────────────────────────────────────────────────
821
822    #[test]
823    fn test_deduplicator_new_key_is_new() {
824        let d = Deduplicator::new(Duration::from_secs(60));
825        let r = d.check_and_register("key-1").unwrap();
826        assert_eq!(r, DeduplicationResult::New);
827    }
828
829    #[test]
830    fn test_deduplicator_second_check_is_in_progress() {
831        let d = Deduplicator::new(Duration::from_secs(60));
832        d.check_and_register("key-1").unwrap();
833        let r = d.check_and_register("key-1").unwrap();
834        assert_eq!(r, DeduplicationResult::InProgress);
835    }
836
837    #[test]
838    fn test_deduplicator_complete_makes_cached() {
839        let d = Deduplicator::new(Duration::from_secs(60));
840        d.check_and_register("key-1").unwrap();
841        d.complete("key-1", "result-value").unwrap();
842        let r = d.check_and_register("key-1").unwrap();
843        assert_eq!(r, DeduplicationResult::Cached("result-value".into()));
844    }
845
846    #[test]
847    fn test_deduplicator_different_keys_are_independent() {
848        let d = Deduplicator::new(Duration::from_secs(60));
849        d.check_and_register("key-a").unwrap();
850        let r = d.check_and_register("key-b").unwrap();
851        assert_eq!(r, DeduplicationResult::New);
852    }
853
854    #[test]
855    fn test_deduplicator_expired_entry_is_new() {
856        let d = Deduplicator::new(Duration::ZERO); // instant TTL
857        d.check_and_register("key-1").unwrap();
858        d.complete("key-1", "old").unwrap();
859        // Immediately expired — should be New again
860        let r = d.check_and_register("key-1").unwrap();
861        assert_eq!(r, DeduplicationResult::New);
862    }
863
864    // ── BackpressureGuard ─────────────────────────────────────────────────────
865
866    #[test]
867    fn test_backpressure_guard_rejects_zero_capacity() {
868        assert!(BackpressureGuard::new(0).is_err());
869    }
870
871    #[test]
872    fn test_backpressure_guard_acquire_within_capacity() {
873        let g = BackpressureGuard::new(5).unwrap();
874        assert!(g.try_acquire().is_ok());
875        assert_eq!(g.depth().unwrap(), 1);
876    }
877
878    #[test]
879    fn test_backpressure_guard_sheds_when_full() {
880        let g = BackpressureGuard::new(2).unwrap();
881        g.try_acquire().unwrap();
882        g.try_acquire().unwrap();
883        let result = g.try_acquire();
884        assert!(matches!(
885            result,
886            Err(AgentRuntimeError::BackpressureShed { .. })
887        ));
888    }
889
890    #[test]
891    fn test_backpressure_guard_release_decrements_depth() {
892        let g = BackpressureGuard::new(3).unwrap();
893        g.try_acquire().unwrap();
894        g.try_acquire().unwrap();
895        g.release().unwrap();
896        assert_eq!(g.depth().unwrap(), 1);
897    }
898
899    #[test]
900    fn test_backpressure_guard_release_on_empty_is_noop() {
901        let g = BackpressureGuard::new(3).unwrap();
902        g.release().unwrap(); // Should not fail
903        assert_eq!(g.depth().unwrap(), 0);
904    }
905
906    // ── Pipeline ──────────────────────────────────────────────────────────────
907
908    #[test]
909    fn test_pipeline_runs_stages_in_order() {
910        let p = Pipeline::new()
911            .add_stage("upper", |s| Ok(s.to_uppercase()))
912            .add_stage("append", |s| Ok(format!("{s}!")));
913        let result = p.run("hello".into()).unwrap();
914        assert_eq!(result, "HELLO!");
915    }
916
917    #[test]
918    fn test_pipeline_empty_pipeline_returns_input() {
919        let p = Pipeline::new();
920        assert_eq!(p.run("test".into()).unwrap(), "test");
921    }
922
923    #[test]
924    fn test_pipeline_stage_failure_short_circuits() {
925        let p = Pipeline::new()
926            .add_stage("fail", |_| {
927                Err(AgentRuntimeError::Orchestration("boom".into()))
928            })
929            .add_stage("never", |s| Ok(s));
930        assert!(p.run("input".into()).is_err());
931    }
932
933    #[test]
934    fn test_pipeline_stage_count() {
935        let p = Pipeline::new()
936            .add_stage("s1", |s| Ok(s))
937            .add_stage("s2", |s| Ok(s));
938        assert_eq!(p.stage_count(), 2);
939    }
940
941    // ── Item 13: BackpressureGuard soft limit ──────────────────────────────────
942
943    #[test]
944    fn test_backpressure_soft_limit_rejects_invalid_config() {
945        // soft >= capacity must be rejected
946        let g = BackpressureGuard::new(5).unwrap();
947        assert!(g.with_soft_limit(5).is_err());
948        let g = BackpressureGuard::new(5).unwrap();
949        assert!(g.with_soft_limit(6).is_err());
950    }
951
952    #[test]
953    fn test_backpressure_soft_limit_accepts_requests_below_soft() {
954        let g = BackpressureGuard::new(5)
955            .unwrap()
956            .with_soft_limit(2)
957            .unwrap();
958        // Both acquires below soft limit should succeed
959        assert!(g.try_acquire().is_ok());
960        assert!(g.try_acquire().is_ok());
961        assert_eq!(g.depth().unwrap(), 2);
962    }
963
964    #[test]
965    fn test_backpressure_with_soft_limit_still_sheds_at_hard_capacity() {
966        let g = BackpressureGuard::new(3)
967            .unwrap()
968            .with_soft_limit(2)
969            .unwrap();
970        g.try_acquire().unwrap();
971        g.try_acquire().unwrap();
972        g.try_acquire().unwrap(); // reaches hard limit
973        let result = g.try_acquire();
974        assert!(matches!(
975            result,
976            Err(AgentRuntimeError::BackpressureShed { .. })
977        ));
978    }
979
980    // ── Item 14: timed_lock concurrency correctness ───────────────────────────
981
982    #[test]
983    fn test_backpressure_concurrent_acquires_are_consistent() {
984        use std::sync::Arc;
985        use std::thread;
986
987        let g = Arc::new(BackpressureGuard::new(100).unwrap());
988        let mut handles = Vec::new();
989
990        for _ in 0..10 {
991            let g_clone = Arc::clone(&g);
992            handles.push(thread::spawn(move || {
993                g_clone.try_acquire().ok();
994            }));
995        }
996
997        for h in handles {
998            h.join().unwrap();
999        }
1000
1001        // All 10 threads acquired a slot; depth must be exactly 10
1002        assert_eq!(g.depth().unwrap(), 10);
1003    }
1004}