1use std::collections::HashMap;
31use std::sync::Arc;
32use std::time::Duration;
33
34use async_trait::async_trait;
35use serde_json::Value;
36use tokio::sync::Mutex;
37use tokio::task::JoinSet;
38use tokio::time::timeout;
39
40use devboy_format_pipeline::adaptive_config::EnrichmentConfig;
41use devboy_format_pipeline::enrichment::PlannedCall;
42
43#[async_trait]
47pub trait PrefetchDispatcher: Send + Sync {
48 async fn dispatch(&self, tool_name: &str, args: Value) -> Result<String, PrefetchError>;
53}
54
55#[derive(Debug, thiserror::Error)]
56pub enum PrefetchError {
57 #[error("dispatcher rejected: {0}")]
58 Rejected(String),
59 #[error("dispatcher I/O: {0}")]
60 Io(String),
61 #[error("dispatcher timed out (host-level)")]
62 HostTimeout,
63}
64
65#[derive(Debug)]
68pub enum PrefetchOutcome {
69 Settled {
74 tool: String,
75 args: Value,
76 body: String,
77 predicted_cost_tokens: u32,
78 },
79 Failed {
81 tool: String,
82 error: PrefetchError,
84 },
85 Skipped {
89 tool: String,
91 reason: SkipReason,
93 },
94}
95
/// Why a prefetch request was not dispatched.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SkipReason {
    /// The request's rate-limit host already has its maximum number of
    /// prefetches in flight.
    HostSaturated,
    /// The dispatch call had already spawned `max_parallel_prefetches` tasks.
    MaxParallelReached,
    /// The call is not eligible for speculation.
    // NOTE(review): not produced anywhere in this file — presumably set by
    // callers that filter requests before dispatch; confirm.
    NotSpeculatable,
}
107
108#[derive(Debug, Clone)]
112pub struct PrefetchRequest {
113 pub call: PlannedCall,
114 pub args: Value,
115 pub rate_limit_host: Option<String>,
120}
121
122#[derive(Default, Clone)]
124pub struct HostBudget {
125 counts: Arc<Mutex<HashMap<String, u32>>>,
126}
127
128impl HostBudget {
129 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub async fn try_acquire(&self, host: &str, cap: u32) -> bool {
137 if cap == 0 {
138 return false;
139 }
140 let mut g = self.counts.lock().await;
141 let entry = g.entry(host.to_string()).or_insert(0);
142 if *entry >= cap {
143 return false;
144 }
145 *entry = entry.saturating_add(1);
146 true
147 }
148
149 pub async fn release(&self, host: &str) {
152 let mut g = self.counts.lock().await;
153 if let Some(entry) = g.get_mut(host) {
154 *entry = entry.saturating_sub(1);
155 if *entry == 0 {
156 g.remove(host);
157 }
158 }
159 }
160
161 pub async fn snapshot(&self) -> HashMap<String, u32> {
163 self.counts.lock().await.clone()
164 }
165}
166
167pub struct SpeculationEngine {
170 config: EnrichmentConfig,
171 dispatcher: Arc<dyn PrefetchDispatcher>,
172 budget: HostBudget,
173 join_set: JoinSet<TaskResult>,
176 per_host_cap: u32,
180}
181
182struct TaskResult {
183 tool: String,
184 args: Value,
185 body: Result<String, PrefetchError>,
186 predicted_cost_tokens: u32,
187 #[allow(dead_code)]
193 rate_limit_host: Option<String>,
194}
195
196impl SpeculationEngine {
197 pub fn new(config: EnrichmentConfig, dispatcher: Arc<dyn PrefetchDispatcher>) -> Self {
200 Self {
201 config,
202 dispatcher,
203 budget: HostBudget::new(),
204 join_set: JoinSet::new(),
205 per_host_cap: 4,
206 }
207 }
208
209 pub fn with_per_host_cap(mut self, cap: u32) -> Self {
212 self.per_host_cap = cap;
213 self
214 }
215
216 pub fn is_enabled(&self) -> bool {
219 self.config.enabled
220 }
221
222 pub fn timeout(&self) -> Duration {
225 Duration::from_millis(self.config.prefetch_timeout_ms.into())
226 }
227
228 pub fn pending(&self) -> usize {
231 self.join_set.len()
232 }
233
234 pub async fn dispatch(&mut self, requests: Vec<PrefetchRequest>) -> Vec<PrefetchOutcome> {
241 let mut skips = Vec::new();
242 let mut spawned = 0u32;
243 let max = self.config.max_parallel_prefetches;
244 for req in requests {
245 if spawned >= max {
246 skips.push(PrefetchOutcome::Skipped {
247 tool: req.call.tool.clone(),
248 reason: SkipReason::MaxParallelReached,
249 });
250 continue;
251 }
252 if self.config.respect_rate_limits
255 && let Some(host) = &req.rate_limit_host
256 && !self.budget.try_acquire(host, self.per_host_cap).await
257 {
258 skips.push(PrefetchOutcome::Skipped {
259 tool: req.call.tool.clone(),
260 reason: SkipReason::HostSaturated,
261 });
262 continue;
263 }
264
265 let dispatcher = Arc::clone(&self.dispatcher);
266 let tool = req.call.tool.clone();
267 let args = req.args.clone();
268 let host = req.rate_limit_host.clone();
269 let predicted_cost_tokens = req.call.estimated_cost_tokens;
270 let budget = self.budget.clone();
271 let respects = self.config.respect_rate_limits;
272 self.join_set.spawn(async move {
273 let body = dispatcher.dispatch(&tool, args.clone()).await;
274 if respects && let Some(h) = &host {
277 budget.release(h).await;
278 }
279 TaskResult {
280 tool,
281 args,
282 body,
283 predicted_cost_tokens,
284 rate_limit_host: host,
285 }
286 });
287 spawned += 1;
288 }
289 skips
290 }
291
292 pub async fn wait_within(&mut self) -> Vec<PrefetchOutcome> {
307 let mut out = Vec::new();
308 let deadline = tokio::time::Instant::now() + self.timeout();
309 loop {
310 if self.join_set.is_empty() {
311 break;
312 }
313 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
314 if remaining.is_zero() {
315 tracing::debug!(
316 target: "devboy_mcp::speculation",
317 "prefetch_timeout_ms reached with {} tasks still pending",
318 self.join_set.len()
319 );
320 break;
321 }
322 match tokio::time::timeout_at(deadline, self.join_set.join_next()).await {
323 Ok(Some(Ok(task_result))) => {
324 let predicted = task_result.predicted_cost_tokens;
325 out.push(match task_result.body {
326 Ok(body) => PrefetchOutcome::Settled {
327 tool: task_result.tool,
328 args: task_result.args,
329 body,
330 predicted_cost_tokens: predicted,
331 },
332 Err(error) => PrefetchOutcome::Failed {
333 tool: task_result.tool,
334 error,
335 },
336 });
337 }
338 Ok(Some(Err(join_err))) => {
339 tracing::warn!(
340 target: "devboy_mcp::speculation",
341 "prefetch task panicked or was cancelled: {join_err}"
342 );
343 out.push(PrefetchOutcome::Failed {
344 tool: "<unknown>".into(),
345 error: PrefetchError::Io(join_err.to_string()),
346 });
347 }
348 Ok(None) => break, Err(_elapsed) => {
350 tracing::debug!(
355 target: "devboy_mcp::speculation",
356 "prefetch_timeout_ms reached with {} tasks still pending",
357 self.join_set.len()
358 );
359 break;
360 }
361 }
362 }
363 out
364 }
365
366 pub async fn drain_pending(&mut self) -> Vec<PrefetchOutcome> {
375 let mut out = Vec::new();
376 loop {
377 if self.join_set.is_empty() {
378 break;
379 }
380 match timeout(Duration::from_millis(0), self.join_set.join_next()).await {
382 Ok(Some(Ok(task_result))) => {
383 let predicted = task_result.predicted_cost_tokens;
384 out.push(match task_result.body {
385 Ok(body) => PrefetchOutcome::Settled {
386 tool: task_result.tool,
387 args: task_result.args,
388 body,
389 predicted_cost_tokens: predicted,
390 },
391 Err(error) => PrefetchOutcome::Failed {
392 tool: task_result.tool,
393 error,
394 },
395 });
396 }
397 Ok(Some(Err(join_err))) => {
398 out.push(PrefetchOutcome::Failed {
399 tool: "<unknown>".into(),
400 error: PrefetchError::Io(join_err.to_string()),
401 });
402 }
403 Ok(None) | Err(_) => break,
404 }
405 }
406 out
407 }
408
409 pub async fn shutdown(&mut self) {
411 self.join_set.abort_all();
412 while self.join_set.join_next().await.is_some() {}
416 }
417}
418
419impl Drop for SpeculationEngine {
420 fn drop(&mut self) {
421 self.join_set.abort_all();
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use devboy_format_pipeline::enrichment::PlannedCall;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Dispatcher double: counts calls, sleeps for `delay_ms`, and fails for
    /// the tool named in `fail_for`.
    struct MockDispatcher {
        delay_ms: u64,
        call_count: Arc<AtomicU32>,
        fail_for: Option<String>,
    }

    #[async_trait]
    impl PrefetchDispatcher for MockDispatcher {
        async fn dispatch(&self, tool: &str, args: Value) -> Result<String, PrefetchError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            match self.fail_for.as_deref() {
                Some(doomed) if doomed == tool => {
                    Err(PrefetchError::Io("simulated failure".into()))
                }
                _ => Ok(format!("mock-body for {tool} args={args}")),
            }
        }
    }

    /// Builds a mock dispatcher and returns it with its call counter.
    fn mock(delay_ms: u64, fail_for: Option<&str>) -> (Arc<MockDispatcher>, Arc<AtomicU32>) {
        let calls = Arc::new(AtomicU32::new(0));
        let dispatcher = Arc::new(MockDispatcher {
            delay_ms,
            call_count: Arc::clone(&calls),
            fail_for: fail_for.map(String::from),
        });
        (dispatcher, calls)
    }

    /// Builds a request for `tool`, optionally budgeted against `host`.
    fn req(tool: &str, host: Option<&str>) -> PrefetchRequest {
        let call = PlannedCall {
            tool: tool.into(),
            projection: None,
            probability: 1.0,
            estimated_cost_bytes: 1024,
            estimated_cost_tokens: 256,
            value_class: devboy_core::ValueClass::Critical,
        };
        PrefetchRequest {
            call,
            args: serde_json::json!({"x": 1}),
            rate_limit_host: host.map(String::from),
        }
    }

    /// Enabled, rate-limit-respecting config with the given timeout and cap.
    fn cfg(timeout_ms: u32, max_parallel: u32) -> EnrichmentConfig {
        EnrichmentConfig {
            enabled: true,
            max_parallel_prefetches: max_parallel,
            prefetch_budget_tokens: 8000,
            prefetch_timeout_ms: timeout_ms,
            respect_rate_limits: true,
        }
    }

    #[tokio::test]
    async fn settled_outcome_returned_when_within_budget() {
        let (dispatcher, calls) = mock(10, None);
        let mut engine = SpeculationEngine::new(cfg(500, 5), dispatcher);

        let skips = engine
            .dispatch(vec![req("Read", None), req("Read", None)])
            .await;
        assert!(skips.is_empty(), "no skips expected: {skips:?}");

        let outcomes = engine.wait_within().await;
        assert_eq!(outcomes.len(), 2);
        for outcome in outcomes {
            match outcome {
                PrefetchOutcome::Settled { body, .. } => assert!(body.contains("mock-body")),
                other => panic!("expected Settled, got {other:?}"),
            }
        }
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_leaves_slow_prefetches_pending() {
        let (dispatcher, _calls) = mock(500, None);
        let mut engine = SpeculationEngine::new(cfg(50, 5), dispatcher);

        engine.dispatch(vec![req("SlowTool", None)]).await;
        let outcomes = engine.wait_within().await;

        assert!(
            outcomes.is_empty(),
            "expected no settled within 50ms timeout"
        );
        assert_eq!(engine.pending(), 1, "task must still be in JoinSet");
        engine.shutdown().await;
    }

    #[tokio::test]
    async fn max_parallel_skips_excess_requests() {
        let (dispatcher, _calls) = mock(5, None);
        let mut engine = SpeculationEngine::new(cfg(500, 2), dispatcher);

        let batch = vec![
            req("A", None),
            req("B", None),
            req("C", None),
            req("D", None),
        ];
        let skips = engine.dispatch(batch).await;

        assert_eq!(skips.len(), 2, "C+D must skip — max_parallel=2");
        assert!(skips.iter().all(|s| matches!(
            s,
            PrefetchOutcome::Skipped {
                reason: SkipReason::MaxParallelReached,
                ..
            }
        )));
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 2);
    }

    #[tokio::test]
    async fn host_saturation_is_observed_across_dispatches() {
        let (dispatcher, _calls) = mock(100, None);
        let mut engine = SpeculationEngine::new(cfg(500, 10), dispatcher).with_per_host_cap(1);

        // First call takes the only slot for the host.
        let skips1 = engine
            .dispatch(vec![req("ToolA", Some("api.github.com"))])
            .await;
        assert!(skips1.is_empty());

        // Second call finds the host saturated while ToolA is in flight.
        let skips2 = engine
            .dispatch(vec![req("ToolB", Some("api.github.com"))])
            .await;
        assert_eq!(skips2.len(), 1);
        assert!(matches!(
            skips2[0],
            PrefetchOutcome::Skipped {
                reason: SkipReason::HostSaturated,
                ..
            }
        ));

        // Once ToolA has been collected the slot is available again.
        engine.wait_within().await;
        let skips3 = engine
            .dispatch(vec![req("ToolC", Some("api.github.com"))])
            .await;
        assert!(skips3.is_empty(), "after drain the slot must be free");
        engine.wait_within().await;
    }

    #[tokio::test]
    async fn different_hosts_share_no_budget() {
        let (dispatcher, _calls) = mock(5, None);
        let mut engine = SpeculationEngine::new(cfg(500, 10), dispatcher).with_per_host_cap(1);

        let batch = vec![
            req("A", Some("api.github.com")),
            req("B", Some("gitlab.example.com")),
            req("C", Some("api.openai.com")),
        ];
        let skips = engine.dispatch(batch).await;

        assert!(skips.is_empty(), "different hosts must each get a slot");
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 3);
    }

    #[tokio::test]
    async fn dispatcher_failure_surfaces_as_failed_outcome() {
        let (dispatcher, _calls) = mock(5, Some("Bad"));
        let mut engine = SpeculationEngine::new(cfg(500, 5), dispatcher);

        engine
            .dispatch(vec![req("Bad", None), req("Good", None)])
            .await;
        let outcomes = engine.wait_within().await;

        assert_eq!(outcomes.len(), 2);
        let failed = outcomes
            .iter()
            .find(|o| matches!(o, PrefetchOutcome::Failed { tool, .. } if tool == "Bad"));
        assert!(failed.is_some(), "expected Failed for Bad");
    }

    #[tokio::test]
    async fn shutdown_aborts_pending_tasks() {
        let (dispatcher, _calls) = mock(10_000, None);
        let mut engine = SpeculationEngine::new(cfg(50, 5), dispatcher);

        engine.dispatch(vec![req("LongRunning", None)]).await;
        engine.shutdown().await;

        assert_eq!(engine.pending(), 0, "shutdown must drain JoinSet");
    }

    #[tokio::test]
    async fn host_budget_release_after_failure() {
        let (dispatcher, _calls) = mock(5, Some("Failing"));
        let mut engine = SpeculationEngine::new(cfg(500, 5), dispatcher).with_per_host_cap(1);

        engine
            .dispatch(vec![req("Failing", Some("host.example.org"))])
            .await;
        engine.wait_within().await;

        let snap = engine.budget.snapshot().await;
        let still_held = snap.get("host.example.org").copied().unwrap_or(0);
        assert!(
            still_held == 0,
            "host budget must release on failure: {snap:?}"
        );
    }

    #[tokio::test]
    async fn stress_50_requests_3_hosts_cap_2_per_host() {
        let (dispatcher, _calls) = mock(5, None);
        let mut engine = SpeculationEngine::new(cfg(2_000, 6), dispatcher).with_per_host_cap(2);

        // 50 requests spread round-robin over three hosts.
        let hosts = ["api.github.com", "api.openai.com", "gitlab.com"];
        let requests: Vec<PrefetchRequest> = hosts
            .iter()
            .copied()
            .cycle()
            .take(50)
            .map(|host| req("ToolX", Some(host)))
            .collect();

        let skips = engine.dispatch(requests).await;
        assert!(
            skips.len() >= 44,
            "expected ≥ 44 skipped (cap 6 + per-host limits), got {}",
            skips.len()
        );

        let settled = engine.wait_within().await;
        let settled_ok = settled
            .iter()
            .filter(|o| matches!(o, PrefetchOutcome::Settled { .. }))
            .count();
        assert!(
            settled_ok <= 6,
            "settled must respect max_parallel=6, got {settled_ok}"
        );

        engine.shutdown().await;
        assert_eq!(engine.pending(), 0);
    }

    #[tokio::test]
    async fn rate_limit_disabled_in_config_lets_everything_through() {
        let (dispatcher, _calls) = mock(5, None);
        let cfg_no_rl = EnrichmentConfig {
            respect_rate_limits: false,
            ..cfg(500, 10)
        };
        let mut engine = SpeculationEngine::new(cfg_no_rl, dispatcher).with_per_host_cap(1);

        // Same host three times with cap 1: only passes because limiting is off.
        let batch = vec![
            req("A", Some("api.github.com")),
            req("B", Some("api.github.com")),
            req("C", Some("api.github.com")),
        ];
        let skips = engine.dispatch(batch).await;

        assert!(skips.is_empty());
        let settled = engine.wait_within().await;
        assert_eq!(settled.len(), 3);
    }
}