Skip to main content

entelix_core/tools/
retry_layer.rs

1//! `RetryToolLayer` — `tower::Layer<Service<ToolInvocation>>` that
2//! turns the metadata-level [`RetryHint`](crate::tools::RetryHint)
3//! contract into runtime retry behaviour.
4//!
5//! Tool authors annotate intent on
6//! [`ToolMetadata::retry_hint`](crate::tools::ToolMetadata) when a
7//! tool is *idempotent* and *transport-bound* (HTTP fetch, RPC call,
8//! search adapter). Without this layer the metadata is documentation;
9//! with it the runtime honours the hint — re-invoking on transient
10//! failures up to the configured budget.
11//!
12//! ## Contract
13//!
14//! - **No hint, no retry.** Tools that have not opted in
15//!   (`retry_hint == None`) pass through the layer unchanged. The
16//!   default for non-idempotent tools is fail-fast, regardless of
17//!   error category.
18//! - **Hint present + retryable error → retry.** The layer reads
19//!   `metadata.retry_hint.max_attempts` for the cap, applies the
20//!   layer's `RetryClassifier` (default: matches transient errors
21//!   per [`ToolErrorKind::is_retryable`](crate::ToolErrorKind)),
22//!   and waits `hint.initial_backoff * 2^attempt` (jittered, capped
23//!   at the layer's max-backoff) between attempts. Vendor
24//!   `Retry-After` hints (`RetryDecision::after`) override the
25//!   computed delay when present.
26//! - **Cancellation-aware sleep.** Backoff sleeps respect
27//!   [`ExecutionContext::cancellation`](crate::ExecutionContext) —
28//!   a cancellation during backoff returns
29//!   [`Error::Cancelled`](crate::Error::Cancelled) immediately, no
30//!   final attempt.
31//!
32//! ## Composition order
33//!
34//! Wire `RetryToolLayer` *innermost* (closest to the leaf service)
35//! so observability layers (`OtelLayer`, `ToolEventLayer`) emit one
36//! event per retry attempt rather than one event for the entire
37//! retry envelope. Mirrors the pattern transport-side
38//! `RetryService` / `OtelLayer` use for model invocations.
39//!
40//! ```ignore
41//! use entelix_core::ToolRegistry;
42//! use entelix_core::tools::RetryToolLayer;
43//!
44//! let registry = ToolRegistry::new()
45//!     .layer(RetryToolLayer::new())          // innermost
46//!     .layer(my_observability_layer)         // outermost
47//!     .register(my_tool)?;
48//! # Ok::<(), entelix_core::Error>(())
49//! ```
50
51use std::sync::Arc;
52use std::sync::atomic::{AtomicU64, Ordering};
53use std::task::{Context, Poll};
54use std::time::{Duration, SystemTime, UNIX_EPOCH};
55
56use futures::future::BoxFuture;
57use rand::SeedableRng;
58use rand::rngs::SmallRng;
59use serde_json::Value;
60use tower::{Layer, Service, ServiceExt};
61
62use crate::backoff::ExponentialBackoff;
63use crate::error::{Error, Result};
64use crate::service::ToolInvocation;
65use crate::transports::{DefaultRetryClassifier, RetryClassifier};
66
/// Backoff ceiling used by [`RetryToolLayer::new`] when the caller does
/// not supply one via [`RetryToolLayer::with_max_backoff`].
///
/// A tool's hint drives an exponentially growing delay; this ceiling
/// keeps a single wait bounded even when the hint's base delay or
/// attempt budget is set unreasonably high.
pub const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);
71
72/// `tower::Layer` that retries tool dispatches per the wrapped tool's
73/// [`RetryHint`](crate::tools::RetryHint) metadata.
74///
75/// Cloning is cheap — internal state is `Arc`-backed.
76#[derive(Clone)]
77pub struct RetryToolLayer {
78    classifier: Arc<dyn RetryClassifier>,
79    max_backoff: Duration,
80}
81
82impl RetryToolLayer {
83    /// Patch-version-stable identifier surfaced through
84    /// `ToolRegistry::layer_names`. Distinguished from the
85    /// transport-level [`crate::transports::RetryLayer`] (`"retry"`)
86    /// — this layer drives per-tool retries from the wrapped tool's
87    /// [`RetryHint`](crate::tools::RetryHint) metadata, not from a
88    /// global [`crate::transports::RetryPolicy`]. Renaming this
89    /// constant is a breaking change for dashboards keyed off the
90    /// value.
91    pub const NAME: &'static str = "tool_retry";
92
93    /// Build with the default classifier ([`DefaultRetryClassifier`])
94    /// and [`DEFAULT_MAX_BACKOFF`] cap.
95    #[must_use]
96    pub fn new() -> Self {
97        Self {
98            classifier: Arc::new(DefaultRetryClassifier),
99            max_backoff: DEFAULT_MAX_BACKOFF,
100        }
101    }
102
103    /// Replace the [`RetryClassifier`] consulted on each failure.
104    /// Operators with custom retry policy (e.g. retry only on
105    /// `Transient`, ignore `RateLimit`) install their own
106    /// classifier here.
107    #[must_use]
108    pub fn with_classifier(mut self, classifier: Arc<dyn RetryClassifier>) -> Self {
109        self.classifier = classifier;
110        self
111    }
112
113    /// Override the per-attempt backoff cap. The geometric growth of
114    /// `hint.initial_backoff * 2^attempt` is clamped to this value
115    /// before jitter is applied.
116    #[must_use]
117    pub const fn with_max_backoff(mut self, max: Duration) -> Self {
118        self.max_backoff = max;
119        self
120    }
121}
122
123impl Default for RetryToolLayer {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl std::fmt::Debug for RetryToolLayer {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("RetryToolLayer")
132            .field("max_backoff", &self.max_backoff)
133            .finish_non_exhaustive()
134    }
135}
136
137impl<S> Layer<S> for RetryToolLayer
138where
139    S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
140    S::Future: Send + 'static,
141{
142    type Service = RetryToolService<S>;
143
144    fn layer(&self, inner: S) -> Self::Service {
145        RetryToolService {
146            inner,
147            classifier: Arc::clone(&self.classifier),
148            max_backoff: self.max_backoff,
149        }
150    }
151}
152
153impl crate::NamedLayer for RetryToolLayer {
154    fn layer_name(&self) -> &'static str {
155        Self::NAME
156    }
157}
158
159/// `Service<ToolInvocation>` produced by [`RetryToolLayer`].
160#[derive(Clone)]
161pub struct RetryToolService<Inner> {
162    inner: Inner,
163    classifier: Arc<dyn RetryClassifier>,
164    max_backoff: Duration,
165}
166
167impl<Inner> std::fmt::Debug for RetryToolService<Inner> {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("RetryToolService")
170            .field("max_backoff", &self.max_backoff)
171            .finish_non_exhaustive()
172    }
173}
174
175impl<Inner> Service<ToolInvocation> for RetryToolService<Inner>
176where
177    Inner: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
178    Inner::Future: Send + 'static,
179{
180    type Response = Value;
181    type Error = Error;
182    type Future = BoxFuture<'static, Result<Value>>;
183
184    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
185        self.inner.poll_ready(cx)
186    }
187
188    fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
189        let mut inner = self.inner.clone();
190        let classifier = Arc::clone(&self.classifier);
191        let max_backoff = self.max_backoff;
192
193        Box::pin(async move {
194            let hint = invocation.metadata.retry_hint;
195            // Tools without an explicit hint pass through unchanged
196            // — the metadata-level fail-fast contract for
197            // non-idempotent tools.
198            let Some(hint) = hint else {
199                return inner.ready().await?.call(invocation).await;
200            };
201
202            let max_attempts = hint.max_attempts.max(1);
203            // Per-tool baseline + layer-level cap → fresh backoff
204            // strategy. Each invocation gets its own tuned schedule
205            // tied to that tool's hint.
206            let backoff = ExponentialBackoff::new(hint.initial_backoff, max_backoff);
207            let mut rng = SmallRng::seed_from_u64(seed_from_time());
208            let mut attempt: u32 = 0;
209
210            loop {
211                let ctx_token = invocation.ctx.cancellation();
212                if ctx_token.is_cancelled() {
213                    return Err(Error::Cancelled);
214                }
215
216                let cloned = invocation.clone();
217                let result = inner.ready().await?.call(cloned).await;
218
219                match result {
220                    Ok(value) => return Ok(value),
221                    Err(err) => {
222                        attempt = attempt.saturating_add(1);
223                        let exhausted = attempt >= max_attempts;
224                        let decision = classifier.should_retry(&err, attempt - 1);
225                        if exhausted || !decision.retry {
226                            return Err(err);
227                        }
228                        let computed = backoff.delay_for_attempt(attempt - 1, &mut rng);
229                        // Vendor hints win over self-jitter (mirrors
230                        // model-side `RetryService`, invariant 17).
231                        let delay = decision
232                            .after
233                            .map_or(computed, |hint| hint.min(max_backoff));
234
235                        tokio::select! {
236                            () = tokio::time::sleep(delay) => {}
237                            () = ctx_token.cancelled() => return Err(Error::Cancelled),
238                        }
239                    }
240                }
241            }
242        })
243    }
244}
245
/// Build a seed for a per-call jitter RNG.
///
/// The seed is the wall-clock nanoseconds since the Unix epoch (low
/// 64 bits) XORed with a process-local call counter, so two calls that
/// read the same clock value still receive different seeds.
fn seed_from_time() -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // silent-fallback-ok: jitter seed only — `now() < UNIX_EPOCH`
    // cannot happen on a sane clock, and the per-process atomic
    // counter XORed below already breaks ties so a 0 nanos
    // contribution still yields uncorrelated low-order bits.
    let clock_bits = match SystemTime::now().duration_since(UNIX_EPOCH) {
        // Wrapping arithmetic yields exactly the low 64 bits of the
        // full nanosecond count, without a narrowing cast.
        Ok(elapsed) => elapsed
            .as_secs()
            .wrapping_mul(1_000_000_000)
            .wrapping_add(u64::from(elapsed.subsec_nanos())),
        Err(_) => 0,
    };
    let ticket = COUNTER.fetch_add(1, Ordering::Relaxed);
    clock_bits ^ ticket
}
260        }
261    });
262    let bump = COUNTER.fetch_add(1, Ordering::Relaxed);
263    nanos ^ bump
264}