Skip to main content

devboy_mcp/
speculation.rs

1//! Paper 3 — speculative tool-call dispatcher.
2//!
3//! Given an `EnrichmentPlan` from the planner, spawns each prefetch
4//! as a Tokio task, caps concurrency per `rate_limit_host` (so two
5//! providers hitting the same domain share the budget), waits up to
6//! `prefetch_timeout_ms` for results to land, and aborts everything
7//! still pending on session shutdown.
8//!
9//! ## Design
10//!
11//! - **Dispatcher trait** — [`PrefetchDispatcher`] abstracts the
12//!   actual `tools/call` path so tests can plug in a mock without
13//!   pulling in MCP transport. The real impl wraps the server's
14//!   handler and is wired in `SessionPipeline`.
//! - **Per-host concurrency cap** — a `Mutex<HashMap<host,
//!   in_flight>>` tracks in-flight prefetches per rate-limit host;
//!   the engine refuses to schedule a call when the cap is hit and
//!   reports it as skipped. `None` host = unlimited (local tool).
19//! - **Bounded synchronous wait** — [`SpeculationEngine::wait_within`]
20//!   blocks at most `prefetch_timeout_ms` collecting results that
21//!   landed in time; anything still pending keeps running in the
22//!   background and lands later via the dedup cache.
//! - **Cascade cancellation** — [`SpeculationEngine::shutdown`] aborts
//!   every pending task and waits for the aborts to land; `Drop` issues
//!   the same abort without waiting. No orphan IO.
25//!
26//! Telemetry counters (`prefetch_dispatched`, `prefetch_won_race`,
27//! `prefetch_wasted`) are updated by the caller; this module only
28//! reports the outcomes.
29
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::time::Duration;
33
34use async_trait::async_trait;
35use serde_json::Value;
36use tokio::sync::Mutex;
37use tokio::task::JoinSet;
38use tokio::time::timeout;
39
40use devboy_format_pipeline::adaptive_config::EnrichmentConfig;
41use devboy_format_pipeline::enrichment::PlannedCall;
42
/// Abstracts how a prefetch is actually executed.
///
/// The engine only ever talks to this trait, so tests can plug in a
/// mock (canned body or error after an optional sleep) while the
/// production implementation is described by the module docs as a
/// wrapper around the MCP server's `tools/call` handler.
///
/// Implementations must be `Send + Sync`: the engine shares one
/// instance behind an `Arc` across every spawned prefetch task.
#[async_trait]
pub trait PrefetchDispatcher: Send + Sync {
    /// Execute `tool_name` with `args` out-of-band and return the
    /// response body (the same string the LLM would receive).
    ///
    /// # Errors
    ///
    /// Any [`PrefetchError`] returned here is handed back to the
    /// engine's caller as [`PrefetchOutcome::Failed`]; it is meant for
    /// telemetry, not for the LLM stream.
    async fn dispatch(&self, tool_name: &str, args: Value) -> Result<String, PrefetchError>;
}
54
/// Reasons a prefetch can fail, as carried by
/// [`PrefetchOutcome::Failed`].
#[derive(Debug, thiserror::Error)]
pub enum PrefetchError {
    /// The dispatcher refused the call; the payload is the free-form
    /// reason that is appended to the error message.
    #[error("dispatcher rejected: {0}")]
    Rejected(String),
    /// I/O-level failure while dispatching. The engine also uses this
    /// variant to wrap a task join error (panic or cancellation), with
    /// the join error's text as the payload.
    #[error("dispatcher I/O: {0}")]
    Io(String),
    /// Host-level timeout reported by a dispatcher. Not constructed
    /// anywhere in this module.
    #[error("dispatcher timed out (host-level)")]
    HostTimeout,
}
64
/// Outcome of a single prefetch request.
///
/// `Settled` and `Failed` are produced when a spawned task is
/// collected ([`SpeculationEngine::wait_within`] or
/// [`SpeculationEngine::drain_pending`]); `Skipped` is produced
/// immediately by [`SpeculationEngine::dispatch`].
#[derive(Debug)]
pub enum PrefetchOutcome {
    /// Prefetch completed with a body. The body is ready to write
    /// into the dedup cache. `predicted_cost_tokens` carries the
    /// planner's admit-time estimate so callers can pass it through
    /// to telemetry (`PipelineEvent.enricher_predicted_cost_tokens`).
    Settled {
        /// Tool name the prefetch was dispatched for.
        tool: String,
        /// Arguments the prefetch was dispatched with.
        args: Value,
        /// Response body returned by the dispatcher.
        body: String,
        /// Planner's token-cost estimate, copied from the request.
        predicted_cost_tokens: u32,
    },
    /// Prefetch returned an error, or its task panicked / was
    /// cancelled before producing a result. In the latter case the
    /// tool name is not recoverable and is reported as `"<unknown>"`.
    Failed {
        /// Tool name, or `"<unknown>"` for a join failure.
        tool: String,
        /// Underlying prefetch error.
        error: PrefetchError,
    },
    /// Prefetch was rejected at scheduling time by the parallelism or
    /// per-host gate — the engine never even spawned it. Used by
    /// callers to attribute the `prefetch_dispatched` gap.
    Skipped {
        /// Tool name.
        tool: String,
        /// Reason the dispatch was skipped.
        reason: SkipReason,
    },
}
95
/// Why a request was reported as [`PrefetchOutcome::Skipped`] instead
/// of being spawned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkipReason {
    /// `rate_limit_host` saturated: the in-flight count has reached
    /// the per-host cap (a cap of 0 rejects every call).
    HostSaturated,
    /// `enrichment.max_parallel_prefetches` tasks were already spawned
    /// by the same `dispatch` call.
    MaxParallelReached,
    /// Tool's `side_effect_class` (or `speculate=Some(false)`) blocks
    /// speculation — surfaced for telemetry but the planner already
    /// filters these out. Never produced by this module itself.
    NotSpeculatable,
}
107
/// One unit of work the engine decides about. Public so the host can
/// produce the list (combining `EnrichmentPlan.calls` with extracted
/// args from `projection`) before handing it to the engine.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
    /// The planned call. The engine reads its `tool` name and its
    /// `estimated_cost_tokens` (echoed back on settled outcomes).
    pub call: PlannedCall,
    /// Arguments passed verbatim to the dispatcher and echoed back in
    /// [`PrefetchOutcome::Settled`].
    pub args: Value,
    /// Pre-computed rate-limit host for this call. `None` = uncapped.
    /// Built by the host from either the static
    /// `ToolValueModel.rate_limit_host` or the runtime
    /// `ToolEnricher::rate_limit_host(args)` — whichever resolved.
    pub rate_limit_host: Option<String>,
}
121
/// Per-host in-flight counter. Cheap clone (Arc): every clone shares
/// the same underlying map, so a slot acquired through one handle can
/// be released through another.
#[derive(Default, Clone)]
pub struct HostBudget {
    /// In-flight count keyed by rate-limit host. Hosts whose count
    /// drops back to zero are removed from the map on release.
    counts: Arc<Mutex<HashMap<String, u32>>>,
}
127
128impl HostBudget {
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Returns `true` and increments the count if `host` has room
134    /// under `cap`; returns `false` (no change) otherwise. `cap = 0`
135    /// blocks all calls to this host (defensive default).
136    pub async fn try_acquire(&self, host: &str, cap: u32) -> bool {
137        if cap == 0 {
138            return false;
139        }
140        let mut g = self.counts.lock().await;
141        let entry = g.entry(host.to_string()).or_insert(0);
142        if *entry >= cap {
143            return false;
144        }
145        *entry = entry.saturating_add(1);
146        true
147    }
148
149    /// Decrement the count for `host`. No-op if the count is already 0
150    /// (mismatched release-without-acquire — defensive, can't go neg).
151    pub async fn release(&self, host: &str) {
152        let mut g = self.counts.lock().await;
153        if let Some(entry) = g.get_mut(host) {
154            *entry = entry.saturating_sub(1);
155            if *entry == 0 {
156                g.remove(host);
157            }
158        }
159    }
160
161    /// Snapshot of in-flight counts — for telemetry / tests.
162    pub async fn snapshot(&self) -> HashMap<String, u32> {
163        self.counts.lock().await.clone()
164    }
165}
166
/// Speculation engine: owns the spawned prefetch tasks and the
/// per-host budget. The original notes call for one instance per
/// `SessionPipeline`. Dropping the engine aborts whatever is still
/// pending.
pub struct SpeculationEngine {
    /// Enrichment settings: enable flag, parallelism limit, wait
    /// budget and rate-limit switch.
    config: EnrichmentConfig,
    /// Executes the actual prefetch; shared with every spawned task.
    dispatcher: Arc<dyn PrefetchDispatcher>,
    /// In-flight counters per rate-limit host, shared with the tasks
    /// so each can release its slot when the call finishes.
    budget: HostBudget,
    /// Spawned prefetch tasks. Finished tasks are removed as they are
    /// collected by `wait_within` / `drain_pending`; tasks that miss
    /// the deadline stay here until a later collection, or until
    /// `shutdown` / `Drop` aborts them.
    join_set: JoinSet<TaskResult>,
    /// Default per-host concurrency cap when the call has a host.
    /// Picked at 4 — enough for top-3 prefetch fan-out plus one
    /// in-flight from the main flow without saturating typical APIs.
    per_host_cap: u32,
}
181
/// What a spawned prefetch task hands back through the `JoinSet`.
struct TaskResult {
    /// Tool name the task was spawned for.
    tool: String,
    /// Arguments the task was spawned with.
    args: Value,
    /// Dispatcher result: response body or the prefetch error.
    body: Result<String, PrefetchError>,
    /// Planner's token-cost estimate, copied from the request.
    predicted_cost_tokens: u32,
    /// Carried through for future per-host telemetry on the result
    /// path. Nothing in this module reads it today (hence the
    /// `allow`): the host slot is released inside the spawned task,
    /// before the result reaches the collector. Kept for symmetry
    /// with `PrefetchRequest.rate_limit_host`; downstream may surface
    /// it.
    #[allow(dead_code)]
    rate_limit_host: Option<String>,
}
195
196impl SpeculationEngine {
197    /// Build a fresh engine bound to a `dispatcher`. Settings come
198    /// from `config.enrichment`.
199    pub fn new(config: EnrichmentConfig, dispatcher: Arc<dyn PrefetchDispatcher>) -> Self {
200        Self {
201            config,
202            dispatcher,
203            budget: HostBudget::new(),
204            join_set: JoinSet::new(),
205            per_host_cap: 4,
206        }
207    }
208
209    /// Build with an explicit per-host cap. Useful for tests that want
210    /// `cap = 1` to force serialisation.
211    pub fn with_per_host_cap(mut self, cap: u32) -> Self {
212        self.per_host_cap = cap;
213        self
214    }
215
216    /// `true` when `enrichment.enabled = false` — host should skip
217    /// all dispatch and not even build a plan.
218    pub fn is_enabled(&self) -> bool {
219        self.config.enabled
220    }
221
222    /// Returns the configured wall-clock budget for one wait_within
223    /// call.
224    pub fn timeout(&self) -> Duration {
225        Duration::from_millis(self.config.prefetch_timeout_ms.into())
226    }
227
228    /// Number of currently-spawned tasks (incl. pending ones from
229    /// previous turns that have not been collected yet).
230    pub fn pending(&self) -> usize {
231        self.join_set.len()
232    }
233
234    /// Schedule `requests` honouring `max_parallel_prefetches` and
235    /// per-host caps. Requests rejected by either gate are reported
236    /// as `Skipped` outcomes and **not** spawned.
237    ///
238    /// Returns the immediately-known outcomes (skips). Settled /
239    /// failed outcomes come back from [`Self::wait_within`].
240    pub async fn dispatch(&mut self, requests: Vec<PrefetchRequest>) -> Vec<PrefetchOutcome> {
241        let mut skips = Vec::new();
242        let mut spawned = 0u32;
243        let max = self.config.max_parallel_prefetches;
244        for req in requests {
245            if spawned >= max {
246                skips.push(PrefetchOutcome::Skipped {
247                    tool: req.call.tool.clone(),
248                    reason: SkipReason::MaxParallelReached,
249                });
250                continue;
251            }
252            // Honour rate-limit host. Uncapped tools (None) fall
253            // through; capped ones must acquire a slot first.
254            if self.config.respect_rate_limits
255                && let Some(host) = &req.rate_limit_host
256                && !self.budget.try_acquire(host, self.per_host_cap).await
257            {
258                skips.push(PrefetchOutcome::Skipped {
259                    tool: req.call.tool.clone(),
260                    reason: SkipReason::HostSaturated,
261                });
262                continue;
263            }
264
265            let dispatcher = Arc::clone(&self.dispatcher);
266            let tool = req.call.tool.clone();
267            let args = req.args.clone();
268            let host = req.rate_limit_host.clone();
269            let predicted_cost_tokens = req.call.estimated_cost_tokens;
270            let budget = self.budget.clone();
271            let respects = self.config.respect_rate_limits;
272            self.join_set.spawn(async move {
273                let body = dispatcher.dispatch(&tool, args.clone()).await;
274                // Release the per-host slot whether we succeeded or
275                // failed — the call has stopped occupying the API.
276                if respects && let Some(h) = &host {
277                    budget.release(h).await;
278                }
279                TaskResult {
280                    tool,
281                    args,
282                    body,
283                    predicted_cost_tokens,
284                    rate_limit_host: host,
285                }
286            });
287            spawned += 1;
288        }
289        skips
290    }
291
292    /// Wait up to `prefetch_timeout_ms` collecting outcomes for tasks
293    /// that complete in time. Tasks still pending stay in the
294    /// JoinSet — their results arrive on the next `wait_within` (or
295    /// get cancelled by `shutdown`).
296    ///
297    /// The timeout is a **global deadline** for the whole call, not
298    /// a per-task budget — N slow prefetches finishing one-by-one
299    /// just under the threshold can no longer stall the response
300    /// for `N × prefetch_timeout_ms`.
301    ///
302    /// Returns `Settled` for every task that returned `Ok(body)`,
303    /// `Failed` for every `Err(...)`. Skipped outcomes from the most
304    /// recent `dispatch` call are not echoed here — the host already
305    /// has them.
306    pub async fn wait_within(&mut self) -> Vec<PrefetchOutcome> {
307        let mut out = Vec::new();
308        let deadline = tokio::time::Instant::now() + self.timeout();
309        loop {
310            if self.join_set.is_empty() {
311                break;
312            }
313            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
314            if remaining.is_zero() {
315                tracing::debug!(
316                    target: "devboy_mcp::speculation",
317                    "prefetch_timeout_ms reached with {} tasks still pending",
318                    self.join_set.len()
319                );
320                break;
321            }
322            match tokio::time::timeout_at(deadline, self.join_set.join_next()).await {
323                Ok(Some(Ok(task_result))) => {
324                    let predicted = task_result.predicted_cost_tokens;
325                    out.push(match task_result.body {
326                        Ok(body) => PrefetchOutcome::Settled {
327                            tool: task_result.tool,
328                            args: task_result.args,
329                            body,
330                            predicted_cost_tokens: predicted,
331                        },
332                        Err(error) => PrefetchOutcome::Failed {
333                            tool: task_result.tool,
334                            error,
335                        },
336                    });
337                }
338                Ok(Some(Err(join_err))) => {
339                    tracing::warn!(
340                        target: "devboy_mcp::speculation",
341                        "prefetch task panicked or was cancelled: {join_err}"
342                    );
343                    out.push(PrefetchOutcome::Failed {
344                        tool: "<unknown>".into(),
345                        error: PrefetchError::Io(join_err.to_string()),
346                    });
347                }
348                Ok(None) => break, // empty join_set
349                Err(_elapsed) => {
350                    // Global deadline expired. Remaining tasks stay in
351                    // the JoinSet so a future `drain_pending()` (or
352                    // another `wait_within` cycle) can still collect
353                    // their results into the dedup cache.
354                    tracing::debug!(
355                        target: "devboy_mcp::speculation",
356                        "prefetch_timeout_ms reached with {} tasks still pending",
357                        self.join_set.len()
358                    );
359                    break;
360                }
361            }
362        }
363        out
364    }
365
366    /// Best-effort drain of completed tasks **without blocking**.
367    ///
368    /// Returns outcomes for every task that has already finished,
369    /// leaves still-pending tasks alone. Lets the host call this on
370    /// the next turn (or before each `dispatch`) to recover bodies
371    /// that landed after the previous `wait_within` timed out, so
372    /// they can still be written into the dedup cache instead of
373    /// being silently lost on the next `shutdown`.
374    pub async fn drain_pending(&mut self) -> Vec<PrefetchOutcome> {
375        let mut out = Vec::new();
376        loop {
377            if self.join_set.is_empty() {
378                break;
379            }
380            // 0-duration timeout = "non-blocking poll".
381            match timeout(Duration::from_millis(0), self.join_set.join_next()).await {
382                Ok(Some(Ok(task_result))) => {
383                    let predicted = task_result.predicted_cost_tokens;
384                    out.push(match task_result.body {
385                        Ok(body) => PrefetchOutcome::Settled {
386                            tool: task_result.tool,
387                            args: task_result.args,
388                            body,
389                            predicted_cost_tokens: predicted,
390                        },
391                        Err(error) => PrefetchOutcome::Failed {
392                            tool: task_result.tool,
393                            error,
394                        },
395                    });
396                }
397                Ok(Some(Err(join_err))) => {
398                    out.push(PrefetchOutcome::Failed {
399                        tool: "<unknown>".into(),
400                        error: PrefetchError::Io(join_err.to_string()),
401                    });
402                }
403                Ok(None) | Err(_) => break,
404            }
405        }
406        out
407    }
408
409    /// Abort every still-pending task. Idempotent.
410    pub async fn shutdown(&mut self) {
411        self.join_set.abort_all();
412        // Drain so the abort_all signal is observed before the engine
413        // is dropped — without this, ASAN-style runtimes could complain
414        // about tasks holding outstanding waker references.
415        while self.join_set.join_next().await.is_some() {}
416    }
417}
418
/// Dropping the engine cancels whatever prefetches are still running.
impl Drop for SpeculationEngine {
    fn drop(&mut self) {
        // `JoinSet::abort_all` is a sync call — safe in Drop. We do
        // not poll for completion here (sync context); use the async
        // `shutdown` beforehand when the caller needs to know the
        // tasks are gone. `Tokio`'s task abort is delivered to each
        // worker on its next yield.
        self.join_set.abort_all();
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use devboy_format_pipeline::enrichment::PlannedCall;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double: counts calls, sleeps `delay_ms`, then either fails
    /// (for the tool named in `fail_for`) or returns a canned body.
    struct MockDispatcher {
        delay_ms: u64,
        call_count: Arc<AtomicU32>,
        fail_for: Option<String>,
    }

    #[async_trait]
    impl PrefetchDispatcher for MockDispatcher {
        async fn dispatch(&self, tool: &str, args: Value) -> Result<String, PrefetchError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            if self.fail_for.as_deref() == Some(tool) {
                return Err(PrefetchError::Io("simulated failure".into()));
            }
            Ok(format!("mock-body for {tool} args={args}"))
        }
    }

    /// Engine wired to a fresh mock, plus the mock's call counter.
    fn counting_engine(
        config: EnrichmentConfig,
        delay_ms: u64,
        fail_for: Option<&str>,
    ) -> (SpeculationEngine, Arc<AtomicU32>) {
        let calls = Arc::new(AtomicU32::new(0));
        let mock = MockDispatcher {
            delay_ms,
            call_count: Arc::clone(&calls),
            fail_for: fail_for.map(String::from),
        };
        (SpeculationEngine::new(config, Arc::new(mock)), calls)
    }

    /// Same as `counting_engine` for tests that ignore the counter.
    fn mock_engine(
        config: EnrichmentConfig,
        delay_ms: u64,
        fail_for: Option<&str>,
    ) -> SpeculationEngine {
        counting_engine(config, delay_ms, fail_for).0
    }

    /// Minimal request for `tool`, optionally bound to a host.
    fn req(tool: &str, host: Option<&str>) -> PrefetchRequest {
        let call = PlannedCall {
            tool: tool.into(),
            projection: None,
            probability: 1.0,
            estimated_cost_bytes: 1024,
            estimated_cost_tokens: 256,
            value_class: devboy_core::ValueClass::Critical,
        };
        PrefetchRequest {
            call,
            args: serde_json::json!({"x": 1}),
            rate_limit_host: host.map(String::from),
        }
    }

    /// Enabled, rate-limit-respecting config with the given knobs.
    fn cfg(timeout_ms: u32, max_parallel: u32) -> EnrichmentConfig {
        EnrichmentConfig {
            enabled: true,
            max_parallel_prefetches: max_parallel,
            prefetch_budget_tokens: 8000,
            prefetch_timeout_ms: timeout_ms,
            respect_rate_limits: true,
        }
    }

    #[tokio::test]
    async fn settled_outcome_returned_when_within_budget() {
        let (mut engine, calls) = counting_engine(cfg(500, 5), 10, None);
        let batch = vec![req("Read", None), req("Read", None)];
        let skips = engine.dispatch(batch).await;
        assert!(skips.is_empty(), "no skips expected: {skips:?}");
        let outcomes = engine.wait_within().await;
        assert_eq!(outcomes.len(), 2);
        for outcome in &outcomes {
            match outcome {
                PrefetchOutcome::Settled { body, .. } => assert!(body.contains("mock-body")),
                other => panic!("expected Settled, got {other:?}"),
            }
        }
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_leaves_slow_prefetches_pending() {
        // Dispatcher takes 500ms, the wait budget is only 50ms.
        let mut engine = mock_engine(cfg(50, 5), 500, None);
        engine.dispatch(vec![req("SlowTool", None)]).await;
        let outcomes = engine.wait_within().await;
        // Wall-clock budget hit before the dispatcher returned.
        assert!(
            outcomes.is_empty(),
            "expected no settled within 50ms timeout"
        );
        assert_eq!(engine.pending(), 1, "task must still be in JoinSet");
        engine.shutdown().await;
    }

    #[tokio::test]
    async fn max_parallel_skips_excess_requests() {
        let mut engine = mock_engine(cfg(500, 2), 5, None);
        let batch: Vec<_> = ["A", "B", "C", "D"]
            .into_iter()
            .map(|tool| req(tool, None))
            .collect();
        let skips = engine.dispatch(batch).await;
        assert_eq!(skips.len(), 2, "C+D must skip — max_parallel=2");
        for skipped in &skips {
            assert!(matches!(
                skipped,
                PrefetchOutcome::Skipped {
                    reason: SkipReason::MaxParallelReached,
                    ..
                }
            ));
        }
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 2);
    }

    #[tokio::test]
    async fn host_saturation_is_observed_across_dispatches() {
        // Cap = 1 in-flight per host, max_parallel large enough.
        let mut engine = mock_engine(cfg(500, 10), 100, None).with_per_host_cap(1);
        let github = Some("api.github.com");

        // First call grabs the slot.
        let first = engine.dispatch(vec![req("ToolA", github)]).await;
        assert!(first.is_empty());

        // Second call to the same host while first is in-flight: SKIP.
        let second = engine.dispatch(vec![req("ToolB", github)]).await;
        assert_eq!(second.len(), 1);
        assert!(matches!(
            second[0],
            PrefetchOutcome::Skipped {
                reason: SkipReason::HostSaturated,
                ..
            }
        ));

        // Drain the first one so the slot frees up.
        engine.wait_within().await;

        // Now the same host has room again.
        let third = engine.dispatch(vec![req("ToolC", github)]).await;
        assert!(third.is_empty(), "after drain the slot must be free");
        engine.wait_within().await;
    }

    #[tokio::test]
    async fn different_hosts_share_no_budget() {
        let mut engine = mock_engine(cfg(500, 10), 5, None).with_per_host_cap(1);
        let batch = vec![
            req("A", Some("api.github.com")),
            req("B", Some("gitlab.example.com")),
            req("C", Some("api.openai.com")),
        ];
        let skips = engine.dispatch(batch).await;
        assert!(skips.is_empty(), "different hosts must each get a slot");
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 3);
    }

    #[tokio::test]
    async fn dispatcher_failure_surfaces_as_failed_outcome() {
        let mut engine = mock_engine(cfg(500, 5), 5, Some("Bad"));
        engine
            .dispatch(vec![req("Bad", None), req("Good", None)])
            .await;
        let outcomes = engine.wait_within().await;
        assert_eq!(outcomes.len(), 2);
        let failed = outcomes
            .iter()
            .find(|o| matches!(o, PrefetchOutcome::Failed { tool, .. } if tool == "Bad"));
        assert!(failed.is_some(), "expected Failed for Bad");
    }

    #[tokio::test]
    async fn shutdown_aborts_pending_tasks() {
        // 10s dispatcher: the task can only go away by being aborted.
        let mut engine = mock_engine(cfg(50, 5), 10_000, None);
        engine.dispatch(vec![req("LongRunning", None)]).await;
        // Don't wait for them — go straight to shutdown.
        engine.shutdown().await;
        assert_eq!(engine.pending(), 0, "shutdown must drain JoinSet");
    }

    #[tokio::test]
    async fn host_budget_release_after_failure() {
        let mut engine = mock_engine(cfg(500, 5), 5, Some("Failing")).with_per_host_cap(1);
        engine
            .dispatch(vec![req("Failing", Some("host.example.org"))])
            .await;
        engine.wait_within().await;
        // After failure, the slot must have been released.
        let snap = engine.budget.snapshot().await;
        assert!(
            !snap.contains_key("host.example.org")
                || snap.get("host.example.org").copied() == Some(0),
            "host budget must release on failure: {snap:?}"
        );
    }

    #[tokio::test]
    async fn stress_50_requests_3_hosts_cap_2_per_host() {
        // QA: realistic load — 50 prefetches fan out across 3 hosts
        // (api.github.com, api.openai.com, gitlab.com), per-host cap 2
        // and max_parallel 6. Expect:
        //   - settled count ≤ max_parallel × ceil(host_groups / cap)
        //   - skip count = 50 − settled
        //   - no panics, no orphan tasks at the end.
        let mut engine = mock_engine(cfg(2_000, 6), 5, None).with_per_host_cap(2);
        let hosts = ["api.github.com", "api.openai.com", "gitlab.com"];
        // Round-robin over the hosts, 50 requests in total.
        let requests: Vec<_> = hosts
            .iter()
            .cycle()
            .take(50)
            .map(|host| req("ToolX", Some(*host)))
            .collect();
        let skips = engine.dispatch(requests).await;
        // First batch: only 6 fit (max_parallel), rest skip on max_parallel
        // — host saturation kicks in only after some have completed.
        assert!(
            skips.len() >= 44,
            "expected ≥ 44 skipped (cap 6 + per-host limits), got {}",
            skips.len()
        );
        // Settled count = up to 6 (max_parallel cap)
        let settled = engine.wait_within().await;
        let settled_ok = settled
            .iter()
            .filter(|o| matches!(o, PrefetchOutcome::Settled { .. }))
            .count();
        assert!(
            settled_ok <= 6,
            "settled must respect max_parallel=6, got {settled_ok}"
        );
        engine.shutdown().await;
        // No orphans.
        assert_eq!(engine.pending(), 0);
    }

    #[tokio::test]
    async fn rate_limit_disabled_in_config_lets_everything_through() {
        let config = EnrichmentConfig {
            respect_rate_limits: false,
            ..cfg(500, 10)
        };
        let mut engine = mock_engine(config, 5, None).with_per_host_cap(1);
        // Three calls to the same host with cap=1 — but respect=false
        // means the cap is not enforced.
        let batch: Vec<_> = ["A", "B", "C"]
            .into_iter()
            .map(|tool| req(tool, Some("api.github.com")))
            .collect();
        let skips = engine.dispatch(batch).await;
        assert!(skips.is_empty());
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 3);
    }
}