//! # Module: Orchestrator
//!
//! ## Responsibility
//! Provides a composable LLM pipeline with circuit breaking, retry, deduplication,
//! and backpressure. Mirrors the public API of `tokio-prompt-orchestrator`.
//!
//! ## Guarantees
//! - Thread-safe: all types wrap state in `Arc<Mutex<_>>` or atomics
//! - Circuit breaker opens after `threshold` consecutive failures
//! - RetryPolicy delays grow exponentially and are capped at [`MAX_RETRY_DELAY`]
//! - Deduplicator is deterministic and non-blocking
//! - BackpressureGuard never exceeds declared hard capacity
//! - Non-panicking: all operations return `Result`
//!
//! ## NOT Responsible For
//! - Cross-node circuit breakers (single-process only, unless a distributed backend is provided)
//! - Persistent deduplication (in-memory, bounded TTL)
//! - Distributed backpressure
//!
//! ## Composing the Primitives
//!
//! The four primitives are designed to be layered. A typical production setup:
//!
//! ```text
//! request
//!   │
//!   ▼
//! BackpressureGuard  ← shed if too many in-flight requests
//!   │
//!   ▼
//! Deduplicator       ← return cached result for duplicate keys
//!   │
//!   ▼
//! CircuitBreaker     ← fast-fail if the downstream is unhealthy
//!   │
//!   ▼
//! RetryPolicy        ← retry transient failures with exponential backoff
//!   │
//!   ▼
//! Pipeline           ← transform request/response through named stages
//! ```
43use crate::error::AgentRuntimeError;
44use crate::util::timed_lock;
45use std::collections::HashMap;
46use std::sync::{Arc, Mutex};
47use std::time::{Duration, Instant};
48
/// Maximum delay between retries — caps the exponential growth computed by
/// `RetryPolicy::delay_for`.
pub const MAX_RETRY_DELAY: Duration = Duration::from_secs(60);
51
52// ── RetryPolicy ───────────────────────────────────────────────────────────────
53
/// Exponential backoff retry policy.
///
/// Delays grow as `base_delay * 2^(attempt-1)` and are capped at
/// [`MAX_RETRY_DELAY`]; see `delay_for`. Construct via `exponential`,
/// which validates both fields.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts (including the first).
    pub max_attempts: u32,
    /// Base delay for the first retry; guaranteed non-zero when built via `exponential`.
    pub base_delay: Duration,
}
62
63impl RetryPolicy {
64    /// Create an exponential retry policy.
65    ///
66    /// # Arguments
67    /// * `max_attempts` — total attempt budget (must be ≥ 1)
68    /// * `base_ms` — base delay in milliseconds for attempt 1
69    ///
70    /// # Returns
71    /// - `Ok(RetryPolicy)` — on success
72    /// - `Err(AgentRuntimeError::Orchestration)` — if `max_attempts == 0`
73    pub fn exponential(max_attempts: u32, base_ms: u64) -> Result<Self, AgentRuntimeError> {
74        if max_attempts == 0 {
75            return Err(AgentRuntimeError::Orchestration(
76                "max_attempts must be >= 1".into(),
77            ));
78        }
79        if base_ms == 0 {
80            return Err(AgentRuntimeError::Orchestration(
81                "base_ms must be >= 1 to avoid zero-delay busy-loop retries".into(),
82            ));
83        }
84        Ok(Self {
85            max_attempts,
86            base_delay: Duration::from_millis(base_ms),
87        })
88    }
89
90    /// Compute the delay before the given attempt number (1-based).
91    ///
92    /// Delay = `base_delay * 2^(attempt-1)`, capped at `MAX_RETRY_DELAY`.
93    pub fn delay_for(&self, attempt: u32) -> Duration {
94        let exp = attempt.saturating_sub(1);
95        let multiplier = 1u64.checked_shl(exp.min(63)).unwrap_or(u64::MAX);
96        let millis = self
97            .base_delay
98            .as_millis()
99            .saturating_mul(multiplier as u128);
100        let raw = Duration::from_millis(millis.min(u64::MAX as u128) as u64);
101        raw.min(MAX_RETRY_DELAY)
102    }
103}
104
105// ── CircuitBreaker ────────────────────────────────────────────────────────────
106
/// Circuit breaker state machine.
///
/// States: `Closed` (normal) → `Open` (fast-fail) → `HalfOpen` (probe).
///
/// Note: `PartialEq` is implemented manually because the `Open` variant
/// contains `std::time::Instant` which does not implement `Eq`. The manual
/// implementation compares only the variant discriminant, not the timestamp.
#[derive(Debug, Clone)]
pub enum CircuitState {
    /// Circuit is operating normally; requests pass through.
    Closed,
    /// Circuit has tripped; requests are fast-failed without calling the operation.
    Open {
        /// The instant at which the circuit was opened.
        opened_at: Instant,
    },
    /// Recovery probe period; the next request will be attempted to test recovery.
    HalfOpen,
}
126
127impl PartialEq for CircuitState {
128    fn eq(&self, other: &Self) -> bool {
129        match (self, other) {
130            (CircuitState::Closed, CircuitState::Closed) => true,
131            (CircuitState::Open { .. }, CircuitState::Open { .. }) => true,
132            (CircuitState::HalfOpen, CircuitState::HalfOpen) => true,
133            _ => false,
134        }
135    }
136}
137
138impl Eq for CircuitState {}
139
/// Backend for circuit breaker state storage.
///
/// Implement this trait to share circuit breaker state across processes
/// (e.g., via Redis). The in-process default is `InMemoryCircuitBreakerBackend`.
///
/// Note: Methods are synchronous to avoid pulling in `async-trait`. A
/// distributed backend (e.g., Redis) can internally spawn a Tokio runtime.
pub trait CircuitBreakerBackend: Send + Sync {
    /// Increment the consecutive failure count for `service` and return the new count.
    fn increment_failures(&self, service: &str) -> u32;
    /// Reset the consecutive failure count for `service` to zero.
    fn reset_failures(&self, service: &str);
    /// Return the current consecutive failure count for `service`.
    fn get_failures(&self, service: &str) -> u32;
    /// Record the instant at which the circuit was opened for `service`.
    fn set_open_at(&self, service: &str, at: std::time::Instant);
    /// Clear the open-at timestamp, effectively moving the circuit to Closed or HalfOpen.
    fn clear_open_at(&self, service: &str);
    /// Return the instant at which the circuit was opened, or `None` if it is not open.
    fn get_open_at(&self, service: &str) -> Option<std::time::Instant>;
}
161
162// ── InMemoryCircuitBreakerBackend ─────────────────────────────────────────────
163
/// In-process circuit breaker backend backed by a `Mutex<HashMap>`.
///
/// Each service name gets its own independent failure counter and open-at
/// timestamp.  Multiple `CircuitBreaker` instances that share the same
/// backend (via [`CircuitBreaker::with_backend`]) will correctly track
/// failures per service rather than sharing a single counter.
pub struct InMemoryCircuitBreakerBackend {
    // Keyed by service name; `Arc` lets cloned handles share one state map.
    inner: Arc<Mutex<HashMap<String, InMemoryServiceState>>>,
}
173
/// Per-service circuit state: failure counter plus optional open timestamp.
/// `Default` yields zero failures and a closed (not-open) circuit.
#[derive(Default)]
struct InMemoryServiceState {
    // Consecutive failures observed since the last success/reset.
    consecutive_failures: u32,
    // `Some(instant)` while the circuit is open; `None` when closed/half-open.
    open_at: Option<std::time::Instant>,
}
179
180impl InMemoryCircuitBreakerBackend {
181    /// Create a new in-memory backend with all counters at zero.
182    pub fn new() -> Self {
183        Self {
184            inner: Arc::new(Mutex::new(HashMap::new())),
185        }
186    }
187}
188
189impl Default for InMemoryCircuitBreakerBackend {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl CircuitBreakerBackend for InMemoryCircuitBreakerBackend {
196    fn increment_failures(&self, service: &str) -> u32 {
197        let mut map = timed_lock(
198            &self.inner,
199            "InMemoryCircuitBreakerBackend::increment_failures",
200        );
201        let state = map.entry(service.to_owned()).or_default();
202        state.consecutive_failures += 1;
203        state.consecutive_failures
204    }
205
206    fn reset_failures(&self, service: &str) {
207        let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::reset_failures");
208        if let Some(state) = map.get_mut(service) {
209            state.consecutive_failures = 0;
210        }
211    }
212
213    fn get_failures(&self, service: &str) -> u32 {
214        let map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_failures");
215        map.get(service).map_or(0, |s| s.consecutive_failures)
216    }
217
218    fn set_open_at(&self, service: &str, at: std::time::Instant) {
219        let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::set_open_at");
220        map.entry(service.to_owned()).or_default().open_at = Some(at);
221    }
222
223    fn clear_open_at(&self, service: &str) {
224        let mut map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::clear_open_at");
225        if let Some(state) = map.get_mut(service) {
226            state.open_at = None;
227        }
228    }
229
230    fn get_open_at(&self, service: &str) -> Option<std::time::Instant> {
231        let map = timed_lock(&self.inner, "InMemoryCircuitBreakerBackend::get_open_at");
232        map.get(service).and_then(|s| s.open_at)
233    }
234}
235
236// ── CircuitBreaker ────────────────────────────────────────────────────────────
237
/// Circuit breaker guarding a fallible operation.
///
/// ## Guarantees
/// - Opens after `threshold` consecutive failures
/// - Transitions to `HalfOpen` after `recovery_window` has elapsed
/// - Closes on the first successful probe in `HalfOpen`
#[derive(Clone)]
pub struct CircuitBreaker {
    // Consecutive failures required to open; validated >= 1 in `new`.
    threshold: u32,
    // Time the circuit stays Open before allowing a half-open probe.
    recovery_window: Duration,
    // Service name used as the backend key and in errors/logs.
    service: String,
    // Pluggable state store; cloned breakers share the same backend.
    backend: Arc<dyn CircuitBreakerBackend>,
}
251
252impl std::fmt::Debug for CircuitBreaker {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.debug_struct("CircuitBreaker")
255            .field("threshold", &self.threshold)
256            .field("recovery_window", &self.recovery_window)
257            .field("service", &self.service)
258            .finish()
259    }
260}
261
262impl CircuitBreaker {
263    /// Create a new circuit breaker backed by an in-memory backend.
264    ///
265    /// # Arguments
266    /// * `service` — name used in error messages and logs
267    /// * `threshold` — consecutive failures before opening
268    /// * `recovery_window` — how long to stay open before probing
269    pub fn new(
270        service: impl Into<String>,
271        threshold: u32,
272        recovery_window: Duration,
273    ) -> Result<Self, AgentRuntimeError> {
274        if threshold == 0 {
275            return Err(AgentRuntimeError::Orchestration(
276                "circuit breaker threshold must be >= 1".into(),
277            ));
278        }
279        let service = service.into();
280        Ok(Self {
281            threshold,
282            recovery_window,
283            service,
284            backend: Arc::new(InMemoryCircuitBreakerBackend::new()),
285        })
286    }
287
288    /// Replace the default in-memory backend with a custom one.
289    ///
290    /// Useful for sharing circuit breaker state across processes.
291    pub fn with_backend(mut self, backend: Arc<dyn CircuitBreakerBackend>) -> Self {
292        self.backend = backend;
293        self
294    }
295
296    /// Attempt to call `f`, respecting the circuit breaker state.
297    ///
298    /// # Errors
299    /// - `AgentRuntimeError::CircuitOpen` — the breaker is in the `Open` state
300    ///   and the recovery window has not yet elapsed
301    /// - `AgentRuntimeError::Orchestration` — `f` returned an error; the error
302    ///   message is the `Display` of the inner error. This call may open the
303    ///   breaker if it pushes the consecutive failure count above `threshold`.
304    #[tracing::instrument(skip(self, f))]
305    pub fn call<T, E, F>(&self, f: F) -> Result<T, AgentRuntimeError>
306    where
307        F: FnOnce() -> Result<T, E>,
308        E: std::fmt::Display,
309    {
310        // Determine effective state, potentially transitioning Open → HalfOpen.
311        let effective_state = match self.backend.get_open_at(&self.service) {
312            Some(opened_at) => {
313                if opened_at.elapsed() >= self.recovery_window {
314                    // Clear open_at to signal HalfOpen; failures remain.
315                    self.backend.clear_open_at(&self.service);
316                    tracing::info!("circuit moved to half-open for {}", self.service);
317                    CircuitState::HalfOpen
318                } else {
319                    CircuitState::Open { opened_at }
320                }
321            }
322            None => {
323                // Either Closed or HalfOpen (after a prior transition).
324                // We distinguish by checking whether failures >= threshold
325                // but no open_at is set — that means we are in HalfOpen.
326                let failures = self.backend.get_failures(&self.service);
327                if failures >= self.threshold {
328                    CircuitState::HalfOpen
329                } else {
330                    CircuitState::Closed
331                }
332            }
333        };
334
335        tracing::debug!("circuit state: {:?}", effective_state);
336
337        match effective_state {
338            CircuitState::Open { .. } => {
339                return Err(AgentRuntimeError::CircuitOpen {
340                    service: self.service.clone(),
341                });
342            }
343            CircuitState::Closed | CircuitState::HalfOpen => {}
344        }
345
346        // Execute the operation.
347        match f() {
348            Ok(val) => {
349                self.backend.reset_failures(&self.service);
350                self.backend.clear_open_at(&self.service);
351                tracing::info!("circuit closed for {}", self.service);
352                Ok(val)
353            }
354            Err(e) => {
355                let failures = self.backend.increment_failures(&self.service);
356                if failures >= self.threshold {
357                    let now = Instant::now();
358                    self.backend.set_open_at(&self.service, now);
359                    tracing::info!("circuit opened for {}", self.service);
360                }
361                Err(AgentRuntimeError::Orchestration(e.to_string()))
362            }
363        }
364    }
365
366    /// Return the current circuit state.
367    pub fn state(&self) -> Result<CircuitState, AgentRuntimeError> {
368        let state = match self.backend.get_open_at(&self.service) {
369            Some(opened_at) => {
370                if opened_at.elapsed() >= self.recovery_window {
371                    // Would transition to HalfOpen on next call; report HalfOpen.
372                    let failures = self.backend.get_failures(&self.service);
373                    if failures >= self.threshold {
374                        CircuitState::HalfOpen
375                    } else {
376                        CircuitState::Closed
377                    }
378                } else {
379                    CircuitState::Open { opened_at }
380                }
381            }
382            None => {
383                let failures = self.backend.get_failures(&self.service);
384                if failures >= self.threshold {
385                    CircuitState::HalfOpen
386                } else {
387                    CircuitState::Closed
388                }
389            }
390        };
391        Ok(state)
392    }
393
394    /// Return the consecutive failure count.
395    pub fn failure_count(&self) -> Result<u32, AgentRuntimeError> {
396        Ok(self.backend.get_failures(&self.service))
397    }
398}
399
400// ── DeduplicationResult ───────────────────────────────────────────────────────
401
/// Result of a deduplication check.
#[derive(Debug, Clone, PartialEq)]
pub enum DeduplicationResult {
    /// This is a new, unseen request; it has been registered as in-flight.
    New,
    /// A cached result exists for this key; the payload is the stored result string.
    Cached(String),
    /// A matching request is currently in-flight; no result is available yet.
    InProgress,
}
412
/// Deduplicates requests by key within a TTL window.
///
/// ## Guarantees
/// - Deterministic: same key always maps to the same result
/// - Thread-safe via `Arc<Mutex<_>>`
/// - Entries expire after `ttl`
#[derive(Debug, Clone)]
pub struct Deduplicator {
    // Default expiry window applied by `check_and_register`.
    ttl: Duration,
    // Shared cache + in-flight tracking; clones share the same state.
    inner: Arc<Mutex<DeduplicatorInner>>,
}
424
/// Mutable deduplication state, guarded by the `Mutex` in `Deduplicator`.
#[derive(Debug)]
struct DeduplicatorInner {
    cache: HashMap<String, (String, Instant)>, // key → (result, inserted_at)
    in_flight: HashMap<String, Instant>,       // key → started_at
}
430
431impl Deduplicator {
432    /// Create a new deduplicator with the given TTL.
433    pub fn new(ttl: Duration) -> Self {
434        Self {
435            ttl,
436            inner: Arc::new(Mutex::new(DeduplicatorInner {
437                cache: HashMap::new(),
438                in_flight: HashMap::new(),
439            })),
440        }
441    }
442
443    /// Check whether `key` is new, cached, or in-flight.
444    ///
445    /// Marks the key as in-flight if it is new.
446    pub fn check_and_register(&self, key: &str) -> Result<DeduplicationResult, AgentRuntimeError> {
447        let mut inner = timed_lock(&self.inner, "Deduplicator::check_and_register");
448
449        let now = Instant::now();
450
451        // Expire stale cache entries
452        inner
453            .cache
454            .retain(|_, (_, ts)| now.duration_since(*ts) < self.ttl);
455        inner
456            .in_flight
457            .retain(|_, ts| now.duration_since(*ts) < self.ttl);
458
459        if let Some((result, _)) = inner.cache.get(key) {
460            return Ok(DeduplicationResult::Cached(result.clone()));
461        }
462
463        if inner.in_flight.contains_key(key) {
464            return Ok(DeduplicationResult::InProgress);
465        }
466
467        inner.in_flight.insert(key.to_owned(), now);
468        Ok(DeduplicationResult::New)
469    }
470
471    /// Check deduplication state for a key with a per-call TTL override.
472    ///
473    /// Marks the key as in-flight if it is new. Ignores the stored TTL and uses
474    /// `ttl` instead for expiry checks.
475    pub fn check(&self, key: &str, ttl: std::time::Duration) -> Result<DeduplicationResult, AgentRuntimeError> {
476        let mut inner = timed_lock(&self.inner, "Deduplicator::check");
477        let now = Instant::now();
478
479        inner.cache.retain(|_, (_, ts)| now.duration_since(*ts) < ttl);
480        inner.in_flight.retain(|_, ts| now.duration_since(*ts) < ttl);
481
482        if let Some((result, _)) = inner.cache.get(key) {
483            return Ok(DeduplicationResult::Cached(result.clone()));
484        }
485
486        if inner.in_flight.contains_key(key) {
487            return Ok(DeduplicationResult::InProgress);
488        }
489
490        inner.in_flight.insert(key.to_owned(), now);
491        Ok(DeduplicationResult::New)
492    }
493
494    /// Check deduplication state for multiple keys at once.
495    ///
496    /// Returns results in the same order as `requests`.
497    /// Each entry is `(key, ttl)` — same signature as `check`.
498    pub fn dedup_many(
499        &self,
500        requests: &[(&str, std::time::Duration)],
501    ) -> Result<Vec<DeduplicationResult>, AgentRuntimeError> {
502        requests
503            .iter()
504            .map(|(key, ttl)| self.check(key, *ttl))
505            .collect()
506    }
507
508    /// Complete a request: move from in-flight to cached with the given result.
509    pub fn complete(&self, key: &str, result: impl Into<String>) -> Result<(), AgentRuntimeError> {
510        let mut inner = timed_lock(&self.inner, "Deduplicator::complete");
511        inner.in_flight.remove(key);
512        inner
513            .cache
514            .insert(key.to_owned(), (result.into(), Instant::now()));
515        Ok(())
516    }
517
518    /// Remove a key from in-flight tracking without caching a result.
519    ///
520    /// Call this when an in-flight operation fails so that subsequent callers
521    /// are not permanently blocked by a stuck `InProgress` entry for the full TTL.
522    pub fn fail(&self, key: &str) -> Result<(), AgentRuntimeError> {
523        let mut inner = timed_lock(&self.inner, "Deduplicator::fail");
524        inner.in_flight.remove(key);
525        Ok(())
526    }
527}
528
529// ── BackpressureGuard ─────────────────────────────────────────────────────────
530
/// Tracks in-flight work count and enforces a capacity limit.
///
/// ## Guarantees
/// - Thread-safe via `Arc<Mutex<_>>`
/// - `try_acquire` is non-blocking
/// - `release` decrements the counter; no-op if counter is already 0
/// - Optional soft limit emits a warning when depth reaches the threshold
#[derive(Debug, Clone)]
pub struct BackpressureGuard {
    // Hard limit: `try_acquire` sheds once depth reaches this value.
    capacity: usize,
    // Optional warning threshold; strictly less than `capacity` when set.
    soft_capacity: Option<usize>,
    // Current in-flight count; clones share the same counter.
    inner: Arc<Mutex<usize>>,
}
544
545impl BackpressureGuard {
546    /// Create a new guard with the given capacity.
547    ///
548    /// # Returns
549    /// - `Ok(BackpressureGuard)` — on success
550    /// - `Err(AgentRuntimeError::Orchestration)` — if `capacity == 0`
551    pub fn new(capacity: usize) -> Result<Self, AgentRuntimeError> {
552        if capacity == 0 {
553            return Err(AgentRuntimeError::Orchestration(
554                "BackpressureGuard capacity must be > 0".into(),
555            ));
556        }
557        Ok(Self {
558            capacity,
559            soft_capacity: None,
560            inner: Arc::new(Mutex::new(0)),
561        })
562    }
563
564    /// Set a soft capacity threshold. When depth reaches this level, a warning
565    /// is logged but the request is still accepted (up to hard capacity).
566    pub fn with_soft_limit(mut self, soft: usize) -> Result<Self, AgentRuntimeError> {
567        if soft >= self.capacity {
568            return Err(AgentRuntimeError::Orchestration(
569                "soft_capacity must be less than hard capacity".into(),
570            ));
571        }
572        self.soft_capacity = Some(soft);
573        Ok(self)
574    }
575
576    /// Try to acquire a slot.
577    ///
578    /// Emits a warning when the soft limit is reached (if configured), but
579    /// still accepts the request until hard capacity is exceeded.
580    ///
581    /// # Returns
582    /// - `Ok(())` — slot acquired
583    /// - `Err(AgentRuntimeError::BackpressureShed)` — hard capacity exceeded
584    pub fn try_acquire(&self) -> Result<(), AgentRuntimeError> {
585        let mut depth = timed_lock(&self.inner, "BackpressureGuard::try_acquire");
586        if *depth >= self.capacity {
587            return Err(AgentRuntimeError::BackpressureShed {
588                depth: *depth,
589                capacity: self.capacity,
590            });
591        }
592        *depth += 1;
593        if let Some(soft) = self.soft_capacity {
594            if *depth >= soft {
595                tracing::warn!(
596                    depth = *depth,
597                    soft_capacity = soft,
598                    hard_capacity = self.capacity,
599                    "backpressure approaching hard limit"
600                );
601            }
602        }
603        Ok(())
604    }
605
606    /// Release a previously acquired slot.
607    pub fn release(&self) -> Result<(), AgentRuntimeError> {
608        let mut depth = timed_lock(&self.inner, "BackpressureGuard::release");
609        *depth = depth.saturating_sub(1);
610        Ok(())
611    }
612
613    /// Return the hard capacity (maximum concurrent slots) configured for this guard.
614    pub fn hard_capacity(&self) -> usize {
615        self.capacity
616    }
617
618    /// Return the current depth.
619    pub fn depth(&self) -> Result<usize, AgentRuntimeError> {
620        let depth = timed_lock(&self.inner, "BackpressureGuard::depth");
621        Ok(*depth)
622    }
623
624    /// Return the ratio of current depth to soft capacity as a value in `[0.0, ∞)`.
625    ///
626    /// Returns `0.0` if no soft limit has been configured.
627    /// Values above `1.0` mean the soft limit has been exceeded.
628    pub fn soft_depth_ratio(&self) -> f32 {
629        match self.soft_capacity {
630            None => 0.0,
631            Some(soft) => {
632                let depth = timed_lock(&self.inner, "BackpressureGuard::soft_depth_ratio");
633                *depth as f32 / soft as f32
634            }
635        }
636    }
637}
638
639// ── Pipeline ──────────────────────────────────────────────────────────────────
640
/// Result of executing a pipeline, including per-stage timing.
#[derive(Debug)]
pub struct PipelineResult {
    /// Final output value after all stages.
    pub output: String,
    /// Per-stage timing: (stage_index, duration_ms), in execution order.
    pub stage_timings: Vec<(usize, u64)>,
}
649
/// A single named stage in the pipeline.
pub struct Stage {
    /// Human-readable name used in log output and error messages.
    pub name: String,
    /// The transform function; receives the current string and returns the transformed string.
    pub handler: Box<dyn Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync>,
}
657
658impl std::fmt::Debug for Stage {
659    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
660        f.debug_struct("Stage").field("name", &self.name).finish()
661    }
662}
663
/// Error handler callback type for pipeline stage failures.
/// Called as `handler(stage_name, error_message)`; its return value becomes
/// the input to the next stage.
type StageErrorHandler = Box<dyn Fn(&str, &str) -> String + Send + Sync>;
666
/// A composable pipeline that passes a string through a sequence of named stages.
///
/// ## Guarantees
/// - Stages execute in insertion order
/// - First stage failure short-circuits remaining stages (unless an error handler is set)
/// - Non-panicking
pub struct Pipeline {
    // Ordered list of stages; executed front-to-back.
    stages: Vec<Stage>,
    // Optional recovery callback; see `with_error_handler`.
    error_handler: Option<StageErrorHandler>,
}
677
678impl std::fmt::Debug for Pipeline {
679    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
680        f.debug_struct("Pipeline")
681            .field("stages", &self.stages)
682            .field("has_error_handler", &self.error_handler.is_some())
683            .finish()
684    }
685}
686
687impl Pipeline {
688    /// Create a new empty pipeline.
689    pub fn new() -> Self {
690        Self { stages: Vec::new(), error_handler: None }
691    }
692
693    /// Attach a recovery callback for stage failures.
694    ///
695    /// When a stage fails, `handler(stage_name, error_message)` is called.
696    /// The returned string becomes the input to the next stage.
697    /// If no handler is set, stage failures propagate as errors.
698    pub fn with_error_handler(
699        mut self,
700        handler: impl Fn(&str, &str) -> String + Send + Sync + 'static,
701    ) -> Self {
702        self.error_handler = Some(Box::new(handler));
703        self
704    }
705
706    /// Append a stage to the pipeline.
707    pub fn add_stage(
708        mut self,
709        name: impl Into<String>,
710        handler: impl Fn(String) -> Result<String, AgentRuntimeError> + Send + Sync + 'static,
711    ) -> Self {
712        self.stages.push(Stage {
713            name: name.into(),
714            handler: Box::new(handler),
715        });
716        self
717    }
718
719    /// Execute the pipeline, passing `input` through each stage in order.
720    #[tracing::instrument(skip(self))]
721    pub fn run(&self, input: String) -> Result<String, AgentRuntimeError> {
722        let mut current = input;
723        for stage in &self.stages {
724            tracing::debug!(stage = %stage.name, "running pipeline stage");
725            match (stage.handler)(current) {
726                Ok(out) => current = out,
727                Err(e) => {
728                    tracing::error!(stage = %stage.name, error = %e, "pipeline stage failed");
729                    if let Some(ref handler) = self.error_handler {
730                        current = handler(&stage.name, &e.to_string());
731                    } else {
732                        return Err(e);
733                    }
734                }
735            }
736        }
737        Ok(current)
738    }
739
740    /// Execute the pipeline with per-stage timing.
741    pub fn execute_timed(&self, input: String) -> Result<PipelineResult, AgentRuntimeError> {
742        let mut current = input;
743        let mut stage_timings = Vec::new();
744        for (idx, stage) in self.stages.iter().enumerate() {
745            let start = std::time::Instant::now();
746            tracing::debug!(stage = %stage.name, "running timed pipeline stage");
747            match (stage.handler)(current) {
748                Ok(out) => current = out,
749                Err(e) => {
750                    tracing::error!(stage = %stage.name, error = %e, "timed pipeline stage failed");
751                    if let Some(ref handler) = self.error_handler {
752                        current = handler(&stage.name, &e.to_string());
753                    } else {
754                        return Err(e);
755                    }
756                }
757            }
758            let duration_ms = start.elapsed().as_millis() as u64;
759            stage_timings.push((idx, duration_ms));
760        }
761        Ok(PipelineResult {
762            output: current,
763            stage_timings,
764        })
765    }
766
767    /// Return the number of stages in the pipeline.
768    pub fn stage_count(&self) -> usize {
769        self.stages.len()
770    }
771}
772
impl Default for Pipeline {
    /// Equivalent to [`Pipeline::new`]: an empty pipeline with no error handler.
    fn default() -> Self {
        Self::new()
    }
}
778
779// ── Tests ─────────────────────────────────────────────────────────────────────
780
781#[cfg(test)]
782mod tests {
783    use super::*;
784
785    // ── RetryPolicy ───────────────────────────────────────────────────────────
786
    #[test]
    fn test_retry_policy_rejects_zero_attempts() {
        // A zero attempt budget is a configuration error, not a silent no-op.
        assert!(RetryPolicy::exponential(0, 100).is_err());
    }

    #[test]
    fn test_retry_policy_delay_attempt_1_equals_base() {
        // Attempt 1 uses the base delay unscaled (2^0 multiplier).
        let p = RetryPolicy::exponential(3, 100).unwrap();
        assert_eq!(p.delay_for(1), Duration::from_millis(100));
    }

    #[test]
    fn test_retry_policy_delay_doubles_each_attempt() {
        // Exponential growth: delay = base * 2^(attempt-1).
        let p = RetryPolicy::exponential(5, 100).unwrap();
        assert_eq!(p.delay_for(2), Duration::from_millis(200));
        assert_eq!(p.delay_for(3), Duration::from_millis(400));
        assert_eq!(p.delay_for(4), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_policy_delay_capped_at_max() {
        // 10s base * 2^9 would be ~85 minutes; the cap clamps it to MAX_RETRY_DELAY.
        let p = RetryPolicy::exponential(10, 10_000).unwrap();
        assert_eq!(p.delay_for(10), MAX_RETRY_DELAY);
    }

    #[test]
    fn test_retry_policy_delay_never_exceeds_max_for_any_attempt() {
        // The cap must hold across the whole attempt range, not just the last attempt.
        let p = RetryPolicy::exponential(10, 1000).unwrap();
        for attempt in 1..=10 {
            assert!(p.delay_for(attempt) <= MAX_RETRY_DELAY);
        }
    }
819
820    // ── CircuitBreaker ────────────────────────────────────────────────────────
821
    #[test]
    fn test_circuit_breaker_rejects_zero_threshold() {
        // A breaker that opens after zero failures could never pass a request.
        assert!(CircuitBreaker::new("svc", 0, Duration::from_secs(1)).is_err());
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        // Fresh breakers begin in Closed with a zero failure count.
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_success_keeps_closed() {
        // Successful calls pass through and leave the breaker Closed.
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(42));
        assert!(result.is_ok());
        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold_failures() {
        // Exactly `threshold` consecutive failures trips the breaker to Open.
        let cb = CircuitBreaker::new("svc", 3, Duration::from_secs(60)).unwrap();
        for _ in 0..3 {
            let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("oops".to_string()));
        }
        assert!(matches!(cb.state().unwrap(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_circuit_breaker_open_fast_fails() {
        // While Open (long recovery window), calls fail without invoking the closure.
        let cb = CircuitBreaker::new("svc", 1, Duration::from_secs(3600)).unwrap();
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let result: Result<(), AgentRuntimeError> = cb.call(|| Ok::<(), AgentRuntimeError>(()));
        assert!(matches!(result, Err(AgentRuntimeError::CircuitOpen { .. })));
    }

    #[test]
    fn test_circuit_breaker_success_resets_failure_count() {
        // A single success below the threshold clears accumulated failures.
        let cb = CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap();
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
        let _: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(1));
        assert_eq!(cb.failure_count().unwrap(), 0);
    }
866
867    #[test]
868    fn test_circuit_breaker_half_open_on_recovery() {
869        // Use a zero recovery window to immediately go half-open
870        let cb = CircuitBreaker::new("svc", 1, Duration::ZERO).unwrap();
871        let _: Result<(), AgentRuntimeError> = cb.call(|| Err::<(), _>("fail".to_string()));
872        // After recovery window, next call should probe (half-open → closed on success)
873        let result: Result<i32, AgentRuntimeError> = cb.call(|| Ok::<i32, AgentRuntimeError>(99));
874        assert_eq!(result.unwrap_or(0), 99);
875        assert_eq!(cb.state().unwrap(), CircuitState::Closed);
876    }
877
878    #[test]
879    fn test_circuit_breaker_with_custom_backend_uses_backend_state() {
880        // Build a custom backend and share it between two circuit breakers
881        // to verify that state is read from and written to the backend.
882        let shared_backend: Arc<dyn CircuitBreakerBackend> =
883            Arc::new(InMemoryCircuitBreakerBackend::new());
884
885        let cb1 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
886            .unwrap()
887            .with_backend(Arc::clone(&shared_backend));
888
889        let cb2 = CircuitBreaker::new("svc", 2, Duration::from_secs(60))
890            .unwrap()
891            .with_backend(Arc::clone(&shared_backend));
892
893        // Trigger one failure via cb1
894        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail".to_string()));
895
896        // cb2 should observe the failure recorded by cb1
897        assert_eq!(cb2.failure_count().unwrap(), 1);
898
899        // Trigger the second failure to open the circuit via cb1
900        let _: Result<(), AgentRuntimeError> = cb1.call(|| Err::<(), _>("fail again".to_string()));
901
902        // cb2 should now see the circuit as open
903        assert!(matches!(cb2.state().unwrap(), CircuitState::Open { .. }));
904    }
905
906    #[test]
907    fn test_in_memory_backend_increments_and_resets() {
908        let backend = InMemoryCircuitBreakerBackend::new();
909
910        assert_eq!(backend.get_failures("svc"), 0);
911
912        let count = backend.increment_failures("svc");
913        assert_eq!(count, 1);
914
915        let count = backend.increment_failures("svc");
916        assert_eq!(count, 2);
917
918        backend.reset_failures("svc");
919        assert_eq!(backend.get_failures("svc"), 0);
920
921        // open_at round-trip
922        assert!(backend.get_open_at("svc").is_none());
923        let now = Instant::now();
924        backend.set_open_at("svc", now);
925        assert!(backend.get_open_at("svc").is_some());
926        backend.clear_open_at("svc");
927        assert!(backend.get_open_at("svc").is_none());
928    }
929
930    // ── Deduplicator ──────────────────────────────────────────────────────────
931
932    #[test]
933    fn test_deduplicator_new_key_is_new() {
934        let d = Deduplicator::new(Duration::from_secs(60));
935        let r = d.check_and_register("key-1").unwrap();
936        assert_eq!(r, DeduplicationResult::New);
937    }
938
939    #[test]
940    fn test_deduplicator_second_check_is_in_progress() {
941        let d = Deduplicator::new(Duration::from_secs(60));
942        d.check_and_register("key-1").unwrap();
943        let r = d.check_and_register("key-1").unwrap();
944        assert_eq!(r, DeduplicationResult::InProgress);
945    }
946
947    #[test]
948    fn test_deduplicator_complete_makes_cached() {
949        let d = Deduplicator::new(Duration::from_secs(60));
950        d.check_and_register("key-1").unwrap();
951        d.complete("key-1", "result-value").unwrap();
952        let r = d.check_and_register("key-1").unwrap();
953        assert_eq!(r, DeduplicationResult::Cached("result-value".into()));
954    }
955
956    #[test]
957    fn test_deduplicator_different_keys_are_independent() {
958        let d = Deduplicator::new(Duration::from_secs(60));
959        d.check_and_register("key-a").unwrap();
960        let r = d.check_and_register("key-b").unwrap();
961        assert_eq!(r, DeduplicationResult::New);
962    }
963
964    #[test]
965    fn test_deduplicator_expired_entry_is_new() {
966        let d = Deduplicator::new(Duration::ZERO); // instant TTL
967        d.check_and_register("key-1").unwrap();
968        d.complete("key-1", "old").unwrap();
969        // Immediately expired — should be New again
970        let r = d.check_and_register("key-1").unwrap();
971        assert_eq!(r, DeduplicationResult::New);
972    }
973
974    // ── BackpressureGuard ─────────────────────────────────────────────────────
975
976    #[test]
977    fn test_backpressure_guard_rejects_zero_capacity() {
978        assert!(BackpressureGuard::new(0).is_err());
979    }
980
981    #[test]
982    fn test_backpressure_guard_acquire_within_capacity() {
983        let g = BackpressureGuard::new(5).unwrap();
984        assert!(g.try_acquire().is_ok());
985        assert_eq!(g.depth().unwrap(), 1);
986    }
987
988    #[test]
989    fn test_backpressure_guard_sheds_when_full() {
990        let g = BackpressureGuard::new(2).unwrap();
991        g.try_acquire().unwrap();
992        g.try_acquire().unwrap();
993        let result = g.try_acquire();
994        assert!(matches!(
995            result,
996            Err(AgentRuntimeError::BackpressureShed { .. })
997        ));
998    }
999
1000    #[test]
1001    fn test_backpressure_guard_release_decrements_depth() {
1002        let g = BackpressureGuard::new(3).unwrap();
1003        g.try_acquire().unwrap();
1004        g.try_acquire().unwrap();
1005        g.release().unwrap();
1006        assert_eq!(g.depth().unwrap(), 1);
1007    }
1008
1009    #[test]
1010    fn test_backpressure_guard_release_on_empty_is_noop() {
1011        let g = BackpressureGuard::new(3).unwrap();
1012        g.release().unwrap(); // Should not fail
1013        assert_eq!(g.depth().unwrap(), 0);
1014    }
1015
1016    // ── Pipeline ──────────────────────────────────────────────────────────────
1017
1018    #[test]
1019    fn test_pipeline_runs_stages_in_order() {
1020        let p = Pipeline::new()
1021            .add_stage("upper", |s| Ok(s.to_uppercase()))
1022            .add_stage("append", |s| Ok(format!("{s}!")));
1023        let result = p.run("hello".into()).unwrap();
1024        assert_eq!(result, "HELLO!");
1025    }
1026
1027    #[test]
1028    fn test_pipeline_empty_pipeline_returns_input() {
1029        let p = Pipeline::new();
1030        assert_eq!(p.run("test".into()).unwrap(), "test");
1031    }
1032
1033    #[test]
1034    fn test_pipeline_stage_failure_short_circuits() {
1035        let p = Pipeline::new()
1036            .add_stage("fail", |_| {
1037                Err(AgentRuntimeError::Orchestration("boom".into()))
1038            })
1039            .add_stage("never", |s| Ok(s));
1040        assert!(p.run("input".into()).is_err());
1041    }
1042
1043    #[test]
1044    fn test_pipeline_stage_count() {
1045        let p = Pipeline::new()
1046            .add_stage("s1", |s| Ok(s))
1047            .add_stage("s2", |s| Ok(s));
1048        assert_eq!(p.stage_count(), 2);
1049    }
1050
1051    #[test]
1052    fn test_pipeline_execute_timed_captures_stage_durations() {
1053        let p = Pipeline::new()
1054            .add_stage("s1", |s| Ok(format!("{s}1")))
1055            .add_stage("s2", |s| Ok(format!("{s}2")));
1056        let result = p.execute_timed("x".to_string()).unwrap();
1057        assert_eq!(result.output, "x12");
1058        assert_eq!(result.stage_timings.len(), 2);
1059        assert_eq!(result.stage_timings[0].0, 0);
1060        assert_eq!(result.stage_timings[1].0, 1);
1061    }
1062
1063    // ── Item 13: BackpressureGuard soft limit ──────────────────────────────────
1064
1065    #[test]
1066    fn test_backpressure_soft_limit_rejects_invalid_config() {
1067        // soft >= capacity must be rejected
1068        let g = BackpressureGuard::new(5).unwrap();
1069        assert!(g.with_soft_limit(5).is_err());
1070        let g = BackpressureGuard::new(5).unwrap();
1071        assert!(g.with_soft_limit(6).is_err());
1072    }
1073
1074    #[test]
1075    fn test_backpressure_soft_limit_accepts_requests_below_soft() {
1076        let g = BackpressureGuard::new(5)
1077            .unwrap()
1078            .with_soft_limit(2)
1079            .unwrap();
1080        // Both acquires below soft limit should succeed
1081        assert!(g.try_acquire().is_ok());
1082        assert!(g.try_acquire().is_ok());
1083        assert_eq!(g.depth().unwrap(), 2);
1084    }
1085
1086    #[test]
1087    fn test_backpressure_with_soft_limit_still_sheds_at_hard_capacity() {
1088        let g = BackpressureGuard::new(3)
1089            .unwrap()
1090            .with_soft_limit(2)
1091            .unwrap();
1092        g.try_acquire().unwrap();
1093        g.try_acquire().unwrap();
1094        g.try_acquire().unwrap(); // reaches hard limit
1095        let result = g.try_acquire();
1096        assert!(matches!(
1097            result,
1098            Err(AgentRuntimeError::BackpressureShed { .. })
1099        ));
1100    }
1101
1102    // ── #4/#31 BackpressureGuard::hard_capacity ───────────────────────────────
1103
1104    #[test]
1105    fn test_backpressure_hard_capacity_matches_new() {
1106        let g = BackpressureGuard::new(7).unwrap();
1107        assert_eq!(g.hard_capacity(), 7);
1108    }
1109
1110    // ── #10 Pipeline::with_error_handler ──────────────────────────────────────
1111
1112    #[test]
1113    fn test_pipeline_error_handler_recovers_from_stage_failure() {
1114        let p = Pipeline::new()
1115            .add_stage("fail_stage", |_| {
1116                Err(AgentRuntimeError::Orchestration("oops".into()))
1117            })
1118            .add_stage("append", |s| Ok(format!("{s}-recovered")))
1119            .with_error_handler(|stage_name, _err| format!("recovered_from_{stage_name}"));
1120        let result = p.run("input".to_string()).unwrap();
1121        assert_eq!(result, "recovered_from_fail_stage-recovered");
1122    }
1123
1124    // ── #11/#32 CircuitState PartialEq/Eq ────────────────────────────────────
1125
1126    #[test]
1127    fn test_circuit_state_eq() {
1128        assert_eq!(CircuitState::Closed, CircuitState::Closed);
1129        assert_eq!(CircuitState::HalfOpen, CircuitState::HalfOpen);
1130        assert_eq!(
1131            CircuitState::Open { opened_at: std::time::Instant::now() },
1132            CircuitState::Open { opened_at: std::time::Instant::now() }
1133        );
1134        assert_ne!(CircuitState::Closed, CircuitState::HalfOpen);
1135        assert_ne!(CircuitState::Closed, CircuitState::Open { opened_at: std::time::Instant::now() });
1136    }
1137
1138    // ── #18 Deduplicator::dedup_many ──────────────────────────────────────────
1139
1140    #[test]
1141    fn test_dedup_many_independent_keys() {
1142        let d = Deduplicator::new(Duration::from_secs(60));
1143        let ttl = Duration::from_secs(60);
1144        let results = d.dedup_many(&[("key-a", ttl), ("key-b", ttl), ("key-c", ttl)]).unwrap();
1145        assert_eq!(results.len(), 3);
1146        assert!(results.iter().all(|r| matches!(r, DeduplicationResult::New)));
1147    }
1148
1149    // ── Task 11: Concurrent CircuitBreaker state transition tests ─────────────
1150
1151    #[test]
1152    fn test_concurrent_circuit_breaker_opens_under_concurrent_failures() {
1153        use std::sync::Arc;
1154        use std::thread;
1155
1156        let cb = Arc::new(
1157            CircuitBreaker::new("svc", 5, Duration::from_secs(60)).unwrap(),
1158        );
1159        let n_threads = 8;
1160        let failures_per_thread = 2;
1161
1162        let mut handles = Vec::new();
1163        for _ in 0..n_threads {
1164            let cb = Arc::clone(&cb);
1165            handles.push(thread::spawn(move || {
1166                for _ in 0..failures_per_thread {
1167                    let _ = cb.call(|| Err::<(), &str>("fail"));
1168                }
1169            }));
1170        }
1171        for h in handles {
1172            h.join().unwrap();
1173        }
1174
1175        // After n_threads * failures_per_thread = 16 failures with threshold=5,
1176        // the circuit must be Open.
1177        let state = cb.state().unwrap();
1178        assert!(
1179            matches!(state, CircuitState::Open { .. }),
1180            "circuit should be open after many concurrent failures; got: {state:?}"
1181        );
1182    }
1183
1184    #[test]
1185    fn test_per_service_tracking_is_independent() {
1186        let backend = Arc::new(InMemoryCircuitBreakerBackend::new());
1187
1188        let cb_a = CircuitBreaker::new("service-a", 3, Duration::from_secs(60))
1189            .unwrap()
1190            .with_backend(Arc::clone(&backend) as Arc<dyn CircuitBreakerBackend>);
1191        let cb_b = CircuitBreaker::new("service-b", 3, Duration::from_secs(60))
1192            .unwrap()
1193            .with_backend(Arc::clone(&backend) as Arc<dyn CircuitBreakerBackend>);
1194
1195        // Fail service-a 3 times → opens
1196        for _ in 0..3 {
1197            let _ = cb_a.call(|| Err::<(), &str>("fail"));
1198        }
1199
1200        // service-b should still be Closed
1201        let state_b = cb_b.state().unwrap();
1202        assert_eq!(
1203            state_b,
1204            CircuitState::Closed,
1205            "service-b should be unaffected by service-a failures"
1206        );
1207
1208        // service-a should be Open
1209        let state_a = cb_a.state().unwrap();
1210        assert!(
1211            matches!(state_a, CircuitState::Open { .. }),
1212            "service-a should be open"
1213        );
1214    }
1215
1216    // ── Item 14: timed_lock concurrency correctness ───────────────────────────
1217
1218    #[test]
1219    fn test_backpressure_concurrent_acquires_are_consistent() {
1220        use std::sync::Arc;
1221        use std::thread;
1222
1223        let g = Arc::new(BackpressureGuard::new(100).unwrap());
1224        let mut handles = Vec::new();
1225
1226        for _ in 0..10 {
1227            let g_clone = Arc::clone(&g);
1228            handles.push(thread::spawn(move || {
1229                g_clone.try_acquire().ok();
1230            }));
1231        }
1232
1233        for h in handles {
1234            h.join().unwrap();
1235        }
1236
1237        // All 10 threads acquired a slot; depth must be exactly 10
1238        assert_eq!(g.depth().unwrap(), 10);
1239    }
1240}