use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio::time::timeout;
use devboy_format_pipeline::adaptive_config::EnrichmentConfig;
use devboy_format_pipeline::enrichment::PlannedCall;
/// Executes one speculative tool call on behalf of [`SpeculationEngine`].
///
/// The engine keeps a single dispatcher in an `Arc` and calls it from every
/// task it spawns on its `JoinSet`, hence the `Send + Sync` bound.
#[async_trait]
pub trait PrefetchDispatcher: Send + Sync {
/// Calls `tool_name` with `args` and returns the response body.
///
/// `Ok(body)` is reported to callers as [`PrefetchOutcome::Settled`];
/// `Err(e)` as [`PrefetchOutcome::Failed`].
async fn dispatch(&self, tool_name: &str, args: Value) -> Result<String, PrefetchError>;
}
/// Error reported for a single prefetch.
#[derive(Debug, thiserror::Error)]
pub enum PrefetchError {
/// Dispatcher-reported rejection. Not constructed in this file; presumably
/// the call was refused before running — confirm with dispatcher impls.
#[error("dispatcher rejected: {0}")]
Rejected(String),
/// Failure while executing the call. The engine also uses this variant to
/// report a prefetch task that panicked or was cancelled.
#[error("dispatcher I/O: {0}")]
Io(String),
/// Dispatcher-reported timeout (not constructed in this file). Distinct from
/// `prefetch_timeout_ms`: when that budget runs out the engine just stops
/// waiting and produces no error.
#[error("dispatcher timed out (host-level)")]
HostTimeout,
}
/// Result of one prefetch request, as returned by the engine's `dispatch`
/// (skips only), `wait_within` and `drain_pending`.
#[derive(Debug)]
pub enum PrefetchOutcome {
/// The dispatcher returned a body.
Settled {
tool: String,
args: Value,
body: String,
/// Copied from the request's `PlannedCall::estimated_cost_tokens`; the
/// engine does not measure the actual size of `body`.
predicted_cost_tokens: u32,
},
/// The dispatcher returned an error, or the task panicked / was cancelled
/// (in which case `tool` is `"<unknown>"`).
Failed {
tool: String,
error: PrefetchError,
},
/// The request was never dispatched.
Skipped {
tool: String,
reason: SkipReason,
},
}
/// Why a request was turned away without being dispatched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkipReason {
/// Its `rate_limit_host` already had `per_host_cap` prefetches holding slots.
HostSaturated,
/// The same `dispatch` call had already spawned `max_parallel_prefetches`
/// tasks.
MaxParallelReached,
/// Not produced in this file; presumably set by callers for calls that must
/// not run speculatively — confirm.
NotSpeculatable,
}
/// One call the engine is asked to run speculatively.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
/// Planned call; the engine reads its `tool` and `estimated_cost_tokens`.
pub call: PlannedCall,
/// JSON arguments handed to the dispatcher.
pub args: Value,
/// Rate-limit key. When `respect_rate_limits` is on, at most `per_host_cap`
/// requests sharing this host hold a slot at once; `None` opts out.
pub rate_limit_host: Option<String>,
}
/// Per-host counter of in-flight prefetches.
///
/// Clones are cheap and share one counter map (through the `Arc`), so a slot
/// taken through one clone can be handed back through another.
#[derive(Default, Clone)]
pub struct HostBudget {
    /// Host -> number of slots currently held. A host disappears from the map
    /// as soon as its count drops back to zero.
    counts: Arc<Mutex<HashMap<String, u32>>>,
}

impl HostBudget {
    /// Creates a budget with no slots held for any host.
    pub fn new() -> Self {
        Default::default()
    }

    /// Tries to take one slot for `host`, allowing at most `cap` holders.
    ///
    /// Returns `true` when a slot was taken (the caller owes a matching
    /// [`HostBudget::release`]) and `false` when `host` is already at `cap`.
    /// A `cap` of zero never grants a slot.
    pub async fn try_acquire(&self, host: &str, cap: u32) -> bool {
        if cap == 0 {
            return false;
        }
        let mut counts = self.counts.lock().await;
        let held = counts.entry(host.to_string()).or_insert(0);
        let has_room = *held < cap;
        if has_room {
            *held = held.saturating_add(1);
        }
        has_room
    }

    /// Hands one slot for `host` back. Unknown hosts are ignored; the host's
    /// entry is removed once its count reaches zero.
    pub async fn release(&self, host: &str) {
        let mut counts = self.counts.lock().await;
        let Some(held) = counts.get_mut(host) else {
            return;
        };
        *held = held.saturating_sub(1);
        if *held == 0 {
            counts.remove(host);
        }
    }

    /// Copy of the current per-host counts (hosts holding no slot are absent).
    pub async fn snapshot(&self) -> HashMap<String, u32> {
        let counts = self.counts.lock().await;
        (*counts).clone()
    }
}
/// Runs speculative tool calls (prefetches) on a `JoinSet` and collects
/// their results under a time budget.
///
/// Admission is limited per `dispatch` call by `max_parallel_prefetches`
/// and, when `respect_rate_limits` is set, per host by `per_host_cap`.
pub struct SpeculationEngine {
/// Switches and limits (`enabled`, `max_parallel_prefetches`,
/// `prefetch_timeout_ms`, `respect_rate_limits`).
config: EnrichmentConfig,
/// Performs the actual tool call; an `Arc` clone goes into every task.
dispatcher: Arc<dyn PrefetchDispatcher>,
/// Per-host in-flight counts; clones are handed to spawned tasks so they
/// can return their slot.
budget: HostBudget,
/// Spawned prefetch tasks that have not been collected yet.
join_set: JoinSet<TaskResult>,
/// Max concurrent slots per `rate_limit_host` (4 unless overridden by
/// `with_per_host_cap`).
per_host_cap: u32,
}
/// Value a spawned prefetch task returns to the `JoinSet` when it completes.
struct TaskResult {
/// Tool that was called.
tool: String,
/// Arguments it was called with (returned to callers in `Settled`).
args: Value,
/// Dispatcher result: response body or error.
body: Result<String, PrefetchError>,
/// Copied from the request's `PlannedCall::estimated_cost_tokens`.
predicted_cost_tokens: u32,
// Carried along with the result but not read anywhere in this file, hence
// the allow.
#[allow(dead_code)]
rate_limit_host: Option<String>,
}
impl SpeculationEngine {
pub fn new(config: EnrichmentConfig, dispatcher: Arc<dyn PrefetchDispatcher>) -> Self {
Self {
config,
dispatcher,
budget: HostBudget::new(),
join_set: JoinSet::new(),
per_host_cap: 4,
}
}
pub fn with_per_host_cap(mut self, cap: u32) -> Self {
self.per_host_cap = cap;
self
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn timeout(&self) -> Duration {
Duration::from_millis(self.config.prefetch_timeout_ms.into())
}
pub fn pending(&self) -> usize {
self.join_set.len()
}
pub async fn dispatch(&mut self, requests: Vec<PrefetchRequest>) -> Vec<PrefetchOutcome> {
let mut skips = Vec::new();
let mut spawned = 0u32;
let max = self.config.max_parallel_prefetches;
for req in requests {
if spawned >= max {
skips.push(PrefetchOutcome::Skipped {
tool: req.call.tool.clone(),
reason: SkipReason::MaxParallelReached,
});
continue;
}
if self.config.respect_rate_limits
&& let Some(host) = &req.rate_limit_host
&& !self.budget.try_acquire(host, self.per_host_cap).await
{
skips.push(PrefetchOutcome::Skipped {
tool: req.call.tool.clone(),
reason: SkipReason::HostSaturated,
});
continue;
}
let dispatcher = Arc::clone(&self.dispatcher);
let tool = req.call.tool.clone();
let args = req.args.clone();
let host = req.rate_limit_host.clone();
let predicted_cost_tokens = req.call.estimated_cost_tokens;
let budget = self.budget.clone();
let respects = self.config.respect_rate_limits;
self.join_set.spawn(async move {
let body = dispatcher.dispatch(&tool, args.clone()).await;
if respects && let Some(h) = &host {
budget.release(h).await;
}
TaskResult {
tool,
args,
body,
predicted_cost_tokens,
rate_limit_host: host,
}
});
spawned += 1;
}
skips
}
pub async fn wait_within(&mut self) -> Vec<PrefetchOutcome> {
let mut out = Vec::new();
let deadline = tokio::time::Instant::now() + self.timeout();
loop {
if self.join_set.is_empty() {
break;
}
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
tracing::debug!(
target: "devboy_mcp::speculation",
"prefetch_timeout_ms reached with {} tasks still pending",
self.join_set.len()
);
break;
}
match tokio::time::timeout_at(deadline, self.join_set.join_next()).await {
Ok(Some(Ok(task_result))) => {
let predicted = task_result.predicted_cost_tokens;
out.push(match task_result.body {
Ok(body) => PrefetchOutcome::Settled {
tool: task_result.tool,
args: task_result.args,
body,
predicted_cost_tokens: predicted,
},
Err(error) => PrefetchOutcome::Failed {
tool: task_result.tool,
error,
},
});
}
Ok(Some(Err(join_err))) => {
tracing::warn!(
target: "devboy_mcp::speculation",
"prefetch task panicked or was cancelled: {join_err}"
);
out.push(PrefetchOutcome::Failed {
tool: "<unknown>".into(),
error: PrefetchError::Io(join_err.to_string()),
});
}
Ok(None) => break, Err(_elapsed) => {
tracing::debug!(
target: "devboy_mcp::speculation",
"prefetch_timeout_ms reached with {} tasks still pending",
self.join_set.len()
);
break;
}
}
}
out
}
pub async fn drain_pending(&mut self) -> Vec<PrefetchOutcome> {
let mut out = Vec::new();
loop {
if self.join_set.is_empty() {
break;
}
match timeout(Duration::from_millis(0), self.join_set.join_next()).await {
Ok(Some(Ok(task_result))) => {
let predicted = task_result.predicted_cost_tokens;
out.push(match task_result.body {
Ok(body) => PrefetchOutcome::Settled {
tool: task_result.tool,
args: task_result.args,
body,
predicted_cost_tokens: predicted,
},
Err(error) => PrefetchOutcome::Failed {
tool: task_result.tool,
error,
},
});
}
Ok(Some(Err(join_err))) => {
out.push(PrefetchOutcome::Failed {
tool: "<unknown>".into(),
error: PrefetchError::Io(join_err.to_string()),
});
}
Ok(None) | Err(_) => break,
}
}
out
}
pub async fn shutdown(&mut self) {
self.join_set.abort_all();
while self.join_set.join_next().await.is_some() {}
}
}
impl Drop for SpeculationEngine {
fn drop(&mut self) {
self.join_set.abort_all();
}
}
// Unit tests for the speculation engine. They drive a mock dispatcher with
// configurable latency and failure, so several assertions depend on the
// relative size of `delay_ms` and the configured timeout.
#[cfg(test)]
mod tests {
use super::*;
use devboy_format_pipeline::enrichment::PlannedCall;
use std::sync::atomic::{AtomicU32, Ordering};
// Test dispatcher: counts calls, sleeps `delay_ms`, then fails with
// `PrefetchError::Io` when the tool equals `fail_for`, else returns a body
// containing "mock-body".
struct MockDispatcher {
delay_ms: u64,
call_count: Arc<AtomicU32>,
fail_for: Option<String>,
}
#[async_trait]
impl PrefetchDispatcher for MockDispatcher {
async fn dispatch(&self, tool: &str, args: Value) -> Result<String, PrefetchError> {
self.call_count.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
if Some(tool.to_string()) == self.fail_for {
return Err(PrefetchError::Io("simulated failure".into()));
}
Ok(format!("mock-body for {tool} args={args}"))
}
}
// Builds a request for `tool` with a fixed planned cost (256 tokens) and an
// optional rate-limit host.
fn req(tool: &str, host: Option<&str>) -> PrefetchRequest {
PrefetchRequest {
call: PlannedCall {
tool: tool.into(),
projection: None,
probability: 1.0,
estimated_cost_bytes: 1024,
estimated_cost_tokens: 256,
value_class: devboy_core::ValueClass::Critical,
},
args: serde_json::json!({"x": 1}),
rate_limit_host: host.map(String::from),
}
}
// Enabled config that respects rate limits, with the given wait budget and
// per-dispatch parallel cap.
fn cfg(timeout_ms: u32, max_parallel: u32) -> EnrichmentConfig {
EnrichmentConfig {
enabled: true,
max_parallel_prefetches: max_parallel,
prefetch_budget_tokens: 8000,
prefetch_timeout_ms: timeout_ms,
respect_rate_limits: true,
}
}
// Fast tasks (10ms) inside a 500ms budget all come back as Settled.
#[tokio::test]
async fn settled_outcome_returned_when_within_budget() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(500, 5),
Arc::new(MockDispatcher {
delay_ms: 10,
call_count: count.clone(),
fail_for: None,
}),
);
let skips = engine
.dispatch(vec![req("Read", None), req("Read", None)])
.await;
assert!(skips.is_empty(), "no skips expected: {skips:?}");
let outcomes = engine.wait_within().await;
assert_eq!(outcomes.len(), 2);
for o in outcomes {
match o {
PrefetchOutcome::Settled { body, .. } => assert!(body.contains("mock-body")),
other => panic!("expected Settled, got {other:?}"),
}
}
assert_eq!(count.load(Ordering::SeqCst), 2);
}
// A 500ms task with a 50ms budget: wait_within returns nothing and the task
// is left in the JoinSet.
#[tokio::test]
async fn timeout_leaves_slow_prefetches_pending() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(50, 5),
Arc::new(MockDispatcher {
delay_ms: 500,
call_count: count.clone(),
fail_for: None,
}),
);
engine.dispatch(vec![req("SlowTool", None)]).await;
let outcomes = engine.wait_within().await;
assert!(
outcomes.is_empty(),
"expected no settled within 50ms timeout"
);
assert_eq!(engine.pending(), 1, "task must still be in JoinSet");
engine.shutdown().await;
}
// With max_parallel = 2, the 3rd and 4th requests of one dispatch call are
// skipped as MaxParallelReached.
#[tokio::test]
async fn max_parallel_skips_excess_requests() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(500, 2),
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: None,
}),
);
let skips = engine
.dispatch(vec![
req("A", None),
req("B", None),
req("C", None),
req("D", None),
])
.await;
assert_eq!(skips.len(), 2, "C+D must skip — max_parallel=2");
for s in &skips {
assert!(matches!(
s,
PrefetchOutcome::Skipped {
reason: SkipReason::MaxParallelReached,
..
}
));
}
let settled = engine.wait_within().await;
assert_eq!(settled.len(), 2);
}
// Per-host cap of 1: a second dispatch to the same host is skipped while the
// first is in flight, and succeeds again once the first has been collected.
#[tokio::test]
async fn host_saturation_is_observed_across_dispatches() {
let count = Arc::new(AtomicU32::new(0));
let dispatcher = Arc::new(MockDispatcher {
delay_ms: 100,
call_count: count.clone(),
fail_for: None,
});
let mut engine = SpeculationEngine::new(cfg(500, 10), dispatcher).with_per_host_cap(1);
let skips1 = engine
.dispatch(vec![req("ToolA", Some("api.github.com"))])
.await;
assert!(skips1.is_empty());
let skips2 = engine
.dispatch(vec![req("ToolB", Some("api.github.com"))])
.await;
assert_eq!(skips2.len(), 1);
assert!(matches!(
skips2[0],
PrefetchOutcome::Skipped {
reason: SkipReason::HostSaturated,
..
}
));
engine.wait_within().await;
let skips3 = engine
.dispatch(vec![req("ToolC", Some("api.github.com"))])
.await;
assert!(skips3.is_empty(), "after drain the slot must be free");
engine.wait_within().await;
}
// Slots are counted per host: a cap of 1 still admits one request for each
// of three distinct hosts.
#[tokio::test]
async fn different_hosts_share_no_budget() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(500, 10),
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: None,
}),
)
.with_per_host_cap(1);
let skips = engine
.dispatch(vec![
req("A", Some("api.github.com")),
req("B", Some("gitlab.example.com")),
req("C", Some("api.openai.com")),
])
.await;
assert!(skips.is_empty(), "different hosts must each get a slot");
let settled = engine.wait_within().await;
assert_eq!(settled.len(), 3);
}
// A dispatcher error is reported as Failed with the tool name preserved.
#[tokio::test]
async fn dispatcher_failure_surfaces_as_failed_outcome() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(500, 5),
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: Some("Bad".into()),
}),
);
engine
.dispatch(vec![req("Bad", None), req("Good", None)])
.await;
let outcomes = engine.wait_within().await;
assert_eq!(outcomes.len(), 2);
let failed = outcomes
.iter()
.find(|o| matches!(o, PrefetchOutcome::Failed { tool, .. } if tool == "Bad"));
assert!(failed.is_some(), "expected Failed for Bad");
}
// shutdown aborts a 10s task and leaves the JoinSet empty.
#[tokio::test]
async fn shutdown_aborts_pending_tasks() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(50, 5),
Arc::new(MockDispatcher {
delay_ms: 10_000,
call_count: count.clone(),
fail_for: None,
}),
);
engine.dispatch(vec![req("LongRunning", None)]).await;
engine.shutdown().await;
assert_eq!(engine.pending(), 0, "shutdown must drain JoinSet");
}
// A dispatcher that returns Err still hands its host slot back.
#[tokio::test]
async fn host_budget_release_after_failure() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(500, 5),
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: Some("Failing".into()),
}),
)
.with_per_host_cap(1);
engine
.dispatch(vec![req("Failing", Some("host.example.org"))])
.await;
engine.wait_within().await;
let snap = engine.budget.snapshot().await;
assert!(
!snap.contains_key("host.example.org")
|| snap.get("host.example.org").copied() == Some(0),
"host budget must release on failure: {snap:?}"
);
}
// 50 requests round-robin over 3 hosts with max_parallel = 6 and cap 2 per
// host: at most 6 are spawned, and the rest are skipped.
#[tokio::test]
async fn stress_50_requests_3_hosts_cap_2_per_host() {
let count = Arc::new(AtomicU32::new(0));
let mut engine = SpeculationEngine::new(
cfg(2_000, 6),
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: None,
}),
)
.with_per_host_cap(2);
let hosts = ["api.github.com", "api.openai.com", "gitlab.com"];
let mut requests = Vec::new();
for i in 0..50 {
requests.push(req("ToolX", Some(hosts[i % hosts.len()])));
}
let skips = engine.dispatch(requests).await;
assert!(
skips.len() >= 44,
"expected ≥ 44 skipped (cap 6 + per-host limits), got {}",
skips.len()
);
let settled = engine.wait_within().await;
let settled_ok = settled
.iter()
.filter(|o| matches!(o, PrefetchOutcome::Settled { .. }))
.count();
assert!(
settled_ok <= 6,
"settled must respect max_parallel=6, got {settled_ok}"
);
engine.shutdown().await;
assert_eq!(engine.pending(), 0);
}
// With respect_rate_limits = false, the per-host cap is ignored entirely.
#[tokio::test]
async fn rate_limit_disabled_in_config_lets_everything_through() {
let count = Arc::new(AtomicU32::new(0));
let mut cfg_no_rl = cfg(500, 10);
cfg_no_rl.respect_rate_limits = false;
let mut engine = SpeculationEngine::new(
cfg_no_rl,
Arc::new(MockDispatcher {
delay_ms: 5,
call_count: count.clone(),
fail_for: None,
}),
)
.with_per_host_cap(1);
let skips = engine
.dispatch(vec![
req("A", Some("api.github.com")),
req("B", Some("api.github.com")),
req("C", Some("api.github.com")),
])
.await;
assert!(skips.is_empty());
let settled = engine.wait_within().await;
assert_eq!(settled.len(), 3);
}
}