1use std::collections::{BTreeMap, VecDeque};
27use std::path::{Path, PathBuf};
28use std::sync::{Arc, Mutex};
29
30use devboy_format_pipeline::adaptive_config::AdaptiveConfig;
31use devboy_format_pipeline::enrichment::{PlannerOptions, TurnContext, build_plan};
32use devboy_format_pipeline::layered_pipeline::{LayeredPipeline, ToolResponseInput};
33use devboy_format_pipeline::projection::{extract_args, extract_host};
34use devboy_format_pipeline::telemetry::{EnrichmentEffectiveness, JsonlSink, Layer, TelemetrySink};
35
36use crate::protocol::{ToolCallParams, ToolCallResult, ToolResultContent};
37use crate::speculation::{
38 PrefetchDispatcher, PrefetchOutcome, PrefetchRequest, SkipReason, SpeculationEngine,
39};
40
/// Maximum number of tool names kept in the per-session recent-tools window.
/// `SessionPipeline::process` evicts the oldest name once this many are stored.
const RECENT_TOOLS_WINDOW: usize = 16;
45
/// A tool response whose largest text item is at most this many bytes counts as
/// "empty" for fail-fast streak tracking in `SessionPipeline::process`.
const FAIL_FAST_EMPTY_THRESHOLD_BYTES: usize = 8;
51
/// Per-session wrapper around [`LayeredPipeline`] plus the bookkeeping used for
/// dedup accounting, fail-fast skipping and speculative prefetching.
///
/// Every field lives behind an `Arc`, so cloning a `SessionPipeline` produces a
/// second handle to the same shared state, not an independent copy.
#[derive(Clone)]
pub struct SessionPipeline {
    /// The layered response pipeline. Guarded by a std `Mutex`; every lock of it
    /// in this file happens in synchronous code (never across an `.await`).
    inner: Arc<Mutex<LayeredPipeline>>,
    /// Effective configuration, including the merged built-in tool defaults.
    config: Arc<AdaptiveConfig>,
    /// Names of recent non-error tool calls, oldest first, capped at
    /// `RECENT_TOOLS_WINDOW`; fed to the prefetch planner as turn context.
    recent_tools: Arc<Mutex<VecDeque<String>>>,
    /// Running effectiveness counters (dedup, fail-fast and prefetch accounting).
    enrichment: Arc<Mutex<EnrichmentEffectiveness>>,
    /// Per-tool count of consecutive "empty" responses, consulted by `should_skip`.
    fail_fast_streak: Arc<Mutex<BTreeMap<String, u32>>>,
    /// Speculative prefetch engine; `None` until `with_speculation` installs one.
    /// A `tokio::sync::Mutex` because `speculate_after` holds the guard across
    /// `.await` points.
    speculation: Arc<tokio::sync::Mutex<Option<SpeculationEngine>>>,
}
76
77impl SessionPipeline {
78 pub fn new(mut config: AdaptiveConfig) -> Self {
88 let defaults = devboy_format_pipeline::tool_defaults::default_tool_value_models();
93 for (name, model) in defaults {
94 config.tools.entry(name).or_insert(model);
95 }
96
97 let session_id = format!("mcp_{}", std::process::id());
98 let mut pipeline = LayeredPipeline::new(session_id.clone(), config.clone());
99
100 if config.telemetry.enabled
101 && let Some(path) = resolve_telemetry_path(&config, &session_id)
102 {
103 match JsonlSink::open(&path) {
104 Ok(sink) => {
105 let arc: Arc<dyn TelemetrySink> = Arc::new(sink);
106 pipeline = pipeline.with_telemetry(arc);
107 tracing::info!(target: "devboy_mcp::telemetry", "telemetry sink opened at {}", path.display());
108 }
109 Err(e) => {
110 tracing::warn!(
111 target: "devboy_mcp::telemetry",
112 "telemetry sink at {} failed to open: {e} — running without telemetry",
113 path.display()
114 );
115 }
116 }
117 }
118
119 Self {
120 inner: Arc::new(Mutex::new(pipeline)),
121 config: Arc::new(config),
122 recent_tools: Arc::new(Mutex::new(VecDeque::with_capacity(RECENT_TOOLS_WINDOW))),
123 enrichment: Arc::new(Mutex::new(EnrichmentEffectiveness::default())),
124 fail_fast_streak: Arc::new(Mutex::new(BTreeMap::new())),
125 speculation: Arc::new(tokio::sync::Mutex::new(None)),
126 }
127 }
128
129 pub async fn with_speculation(self, dispatcher: Arc<dyn PrefetchDispatcher>) -> Self {
135 let engine = SpeculationEngine::new(self.config.enrichment.clone(), dispatcher);
136 *self.speculation.lock().await = Some(engine);
137 self
138 }
139
140 pub async fn shutdown(&self) {
146 if let Some(engine) = self.speculation.lock().await.as_mut() {
147 engine.shutdown().await;
148 }
149 }
150
151 pub fn enrichment_snapshot(&self) -> EnrichmentEffectiveness {
156 self.enrichment
157 .lock()
158 .map(|g| g.clone())
159 .unwrap_or_default()
160 }
161
162 pub fn recent_tools_snapshot(&self) -> Vec<String> {
165 self.recent_tools
166 .lock()
167 .map(|g| g.iter().cloned().collect())
168 .unwrap_or_default()
169 }
170
171 pub fn should_skip(&self, tool_name: &str) -> bool {
183 let Some(model) = self.config.effective_tool_value_model(tool_name) else {
184 return false;
185 };
186 let Some(threshold) = model.fail_fast_after_n else {
187 return false;
188 };
189 let streak = self
190 .fail_fast_streak
191 .lock()
192 .ok()
193 .and_then(|g| g.get(tool_name).copied())
194 .unwrap_or(0);
195 streak >= threshold
196 }
197
198 pub fn record_fail_fast_skip(&self, predicted_cost_tokens: u32) {
204 if let Ok(mut e) = self.enrichment.lock() {
205 e.record_fail_fast_skip(predicted_cost_tokens);
206 }
207 }
208
209 pub async fn speculate_after(
233 &self,
234 tool_name: &str,
235 prev_response_json: &serde_json::Value,
236 ) -> String {
237 if !self.config.enrichment.enabled {
239 return String::new();
240 }
241 let mut engine_guard = self.speculation.lock().await;
242 let Some(engine) = engine_guard.as_mut() else {
243 return String::new();
244 };
245 if !engine.is_enabled() {
246 return String::new();
247 }
248
249 for outcome in engine.drain_pending().await {
254 if let PrefetchOutcome::Settled {
255 tool,
256 args,
257 body,
258 predicted_cost_tokens,
259 } = outcome
260 {
261 self.write_prefetch_to_cache(&tool, &args, &body, predicted_cost_tokens);
262 }
263 }
264
265 let recent = self.recent_tools_snapshot();
269 let ctx = TurnContext::new(&recent, self.config.enrichment.prefetch_budget_tokens);
270 let opts = PlannerOptions {
277 min_followup_probability: 0.3,
278 ..PlannerOptions::default()
279 };
280 let plan = build_plan(&self.config, &ctx, opts);
281
282 let mut requests: Vec<PrefetchRequest> = Vec::new();
288 for call in &plan.calls {
289 let Some(model) = self.config.effective_tool_value_model(&call.tool) else {
290 continue;
291 };
292 if !model.is_speculatable() {
293 continue;
294 }
295 let Some(link) = self
299 .config
300 .effective_tool_value_model(tool_name)
301 .and_then(|m| m.follow_up.iter().find(|l| l.tool == call.tool))
302 else {
303 continue;
304 };
305 let arg_objects = extract_args(tool_name, prev_response_json, link);
306 if arg_objects.is_empty() {
307 continue;
308 }
309 for args in arg_objects {
310 let host = static_or_url_host(&args, model.rate_limit_host.as_deref());
311 requests.push(PrefetchRequest {
312 call: call.clone(),
313 args,
314 rate_limit_host: host,
315 });
316 }
317 }
318
319 if requests.is_empty() {
320 return String::new();
321 }
322
323 let total_to_dispatch = requests.len() as u32;
327 let skips = engine.dispatch(requests).await;
328 let dispatched = total_to_dispatch.saturating_sub(skips.len() as u32);
329 if let Ok(mut e) = self.enrichment.lock() {
330 for _ in 0..dispatched {
331 e.total_prefetches = e.total_prefetches.saturating_add(1);
332 e.record_prefetch_dispatched();
333 }
334 for s in &skips {
337 if let PrefetchOutcome::Skipped { reason, .. } = s {
338 let label = match reason {
339 SkipReason::HostSaturated => "host_saturated",
340 SkipReason::MaxParallelReached => "max_parallel_reached",
341 SkipReason::NotSpeculatable => "not_speculatable",
342 };
343 tracing::debug!(
344 target: "devboy_mcp::speculation",
345 "prefetch skipped: {label}"
346 );
347 }
348 }
349 }
350
351 let outcomes = engine.wait_within().await;
355 let mut hint_parts: Vec<String> = Vec::new();
356 for o in outcomes {
357 match o {
358 PrefetchOutcome::Settled {
359 tool,
360 args,
361 body,
362 predicted_cost_tokens,
363 } => {
364 self.write_prefetch_to_cache(&tool, &args, &body, predicted_cost_tokens);
365 hint_parts.push(format!("{tool}({})", short_args(&args)));
366 }
367 PrefetchOutcome::Failed { tool, error } => {
368 tracing::warn!(
369 target: "devboy_mcp::speculation",
370 "prefetch failed for {tool}: {error}"
371 );
372 if let Ok(mut e) = self.enrichment.lock() {
373 e.record_prefetch_wasted();
374 }
375 }
376 PrefetchOutcome::Skipped { .. } => {}
377 }
378 }
379
380 if hint_parts.is_empty() {
381 String::new()
382 } else {
383 format!(
384 "\n\n> [enrichment: pre-fetched {} in background — call as usual, results served from cache]",
385 hint_parts.join(", ")
386 )
387 }
388 }
389
390 fn write_prefetch_to_cache(
397 &self,
398 tool: &str,
399 args: &serde_json::Value,
400 body: &str,
401 predicted_cost_tokens: u32,
402 ) {
403 let Ok(mut p) = self.inner.lock() else {
404 return;
405 };
406 let request_id = format!(
407 "prefetch_{}_{}",
408 tool,
409 short_args_hash(args)
412 );
413 let path = args.get("file_path").and_then(|v| v.as_str());
414 let input = ToolResponseInput {
415 tool_call_id: &request_id,
416 tool_name: tool,
417 file_path: path,
418 content: body,
419 is_sidechain: false,
420 ts_ms: std::time::SystemTime::now()
421 .duration_since(std::time::UNIX_EPOCH)
422 .map(|d| d.as_millis() as i64)
423 .unwrap_or(0),
424 enricher_prefetched: true,
428 enricher_predicted_cost_tokens: predicted_cost_tokens,
429 };
430 let _out = p.process(input);
433 }
434}
435
436fn static_or_url_host(args: &serde_json::Value, static_host: Option<&str>) -> Option<String> {
440 if let Some(url) = args.get("url").and_then(|v| v.as_str())
441 && let Some(h) = extract_host(url)
442 {
443 return Some(h);
444 }
445 static_host.map(String::from)
446}
447
448fn short_args(args: &serde_json::Value) -> String {
452 let Some(obj) = args.as_object() else {
453 return String::new();
454 };
455 for (_, v) in obj {
456 if let Some(s) = v.as_str() {
457 let mut t = s.to_string();
458 if t.len() > 40 {
459 t.truncate(40);
460 t.push('…');
461 }
462 return t;
463 }
464 }
465 String::new()
466}
467
468fn short_args_hash(args: &serde_json::Value) -> String {
473 let s = args.to_string();
474 let mut h: u64 = 5381;
475 for b in s.bytes() {
476 h = h.wrapping_mul(33).wrapping_add(b as u64);
477 }
478 format!("{h:08x}")
479}
480
481impl SessionPipeline {
482 pub fn on_compaction_boundary(&self) {
485 if let Ok(mut p) = self.inner.lock() {
486 p.on_compaction_boundary();
487 }
488 }
489
490 pub fn invalidate_file(&self, file_path: &str) {
495 if let Ok(mut p) = self.inner.lock() {
496 p.invalidate_file(file_path);
497 }
498 }
499
500 pub fn process(
507 &self,
508 request_id: &str,
509 params: &ToolCallParams,
510 result: ToolCallResult,
511 ts_ms: i64,
512 ) -> ToolCallResult {
513 if result.is_error == Some(true) {
516 return result;
517 }
518
519 let file_path = extract_file_path(params.arguments.as_ref());
520
521 let mut new_content: Vec<ToolResultContent> = Vec::with_capacity(result.content.len());
522 let mut p = match self.inner.lock() {
523 Ok(g) => g,
524 Err(_) => return result,
527 };
528
529 let mut total_dedup_hits: u32 = 0;
532 let mut total_dedup_tokens_saved: u64 = 0;
533 let mut max_original_chars: usize = 0;
534
535 for c in result.content {
536 match c {
537 ToolResultContent::Text { text } => {
538 max_original_chars = max_original_chars.max(text.len());
539 let input = ToolResponseInput {
540 tool_call_id: request_id,
541 tool_name: ¶ms.name,
542 file_path: file_path.as_deref(),
543 content: &text,
544 is_sidechain: false,
545 ts_ms,
546 enricher_prefetched: false,
548 enricher_predicted_cost_tokens: 0,
549 };
550 let out = p.process(input);
551 if matches!(out.layer, Layer::L0) {
552 total_dedup_hits = total_dedup_hits.saturating_add(1);
553 if out.tokens_saved > 0 {
556 total_dedup_tokens_saved =
557 total_dedup_tokens_saved.saturating_add(out.tokens_saved as u64);
558 }
559 }
560 let body = if matches!(out.layer, Layer::L0) {
565 out.output
566 } else {
567 text
568 };
569 new_content.push(ToolResultContent::Text { text: body });
570 }
571 }
572 }
573
574 drop(p);
578
579 if total_dedup_hits > 0
583 && let Ok(mut e) = self.enrichment.lock()
584 {
585 e.inference_calls_saved_dedup = e
586 .inference_calls_saved_dedup
587 .saturating_add(total_dedup_hits);
588 e.inference_tokens_saved = e
589 .inference_tokens_saved
590 .saturating_add(total_dedup_tokens_saved);
591 }
592
593 if let Ok(mut streak) = self.fail_fast_streak.lock() {
594 let entry = streak.entry(params.name.clone()).or_insert(0);
595 if max_original_chars <= FAIL_FAST_EMPTY_THRESHOLD_BYTES {
596 *entry = entry.saturating_add(1);
597 } else {
598 *entry = 0;
599 }
600 }
601
602 if let Ok(mut recent) = self.recent_tools.lock() {
603 if recent.len() >= RECENT_TOOLS_WINDOW {
604 recent.pop_front();
605 }
606 recent.push_back(params.name.clone());
607 }
608
609 ToolCallResult {
610 content: new_content,
611 is_error: result.is_error,
612 }
613 }
614}
615
616pub fn extract_file_path(args: Option<&serde_json::Value>) -> Option<String> {
619 let obj = args?.as_object()?;
620 for k in ["file_path", "path", "notebook_path"] {
621 if let Some(v) = obj.get(k).and_then(|v| v.as_str()) {
622 return Some(v.to_string());
623 }
624 }
625 None
626}
627
/// Returns `true` when `name` is one of the editing tools `Edit`, `Write`,
/// `MultiEdit` or `NotebookEdit` (exact, case-sensitive match).
pub fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 4] = ["Edit", "Write", "MultiEdit", "NotebookEdit"];
    MUTATING_TOOLS.contains(&name)
}
633
634fn resolve_telemetry_path(config: &AdaptiveConfig, session_id: &str) -> Option<PathBuf> {
638 let dir: PathBuf = if let Some(p) = config.telemetry.path.as_deref() {
639 Path::new(p).to_path_buf()
640 } else if let Ok(env_dir) = std::env::var("DEVBOY_TELEMETRY_DIR") {
641 PathBuf::from(env_dir)
642 } else if let Some(home) = std::env::var_os("HOME").map(PathBuf::from) {
643 home.join(".devboy").join("telemetry")
644 } else {
645 std::env::temp_dir().join(".devboy-telemetry")
646 };
647 Some(dir.join(format!("{session_id}.jsonl")))
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{ToolCallParams, ToolCallResult, ToolResultContent};
    use crate::speculation::{PrefetchDispatcher, PrefetchError};
    use async_trait::async_trait;
    use serde_json::Value;
    use serde_json::json;

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// `Read` parameters pointing at `path`.
    fn read_params(path: &str) -> ToolCallParams {
        ToolCallParams {
            name: "Read".to_string(),
            arguments: Some(json!({ "file_path": path })),
        }
    }

    /// Parameters for `name` with no arguments at all.
    fn empty_params(name: &str) -> ToolCallParams {
        ToolCallParams {
            name: name.to_string(),
            arguments: None,
        }
    }

    /// `Glob` parameters for `pattern`.
    fn glob_params(pattern: &str) -> ToolCallParams {
        ToolCallParams {
            name: "Glob".to_string(),
            arguments: Some(json!({ "pattern": pattern })),
        }
    }

    /// `seed` followed by 400 bytes of filler, so a reference hint is far shorter
    /// than the body it stands in for.
    fn long_text(seed: &str) -> String {
        let filler = "x".repeat(400);
        format!("{seed}{filler}")
    }

    /// Text of the first content item in `result`.
    fn first_text(result: &ToolCallResult) -> &str {
        let ToolResultContent::Text { text } = &result.content[0];
        text
    }

    /// Default config whose telemetry directory is `dir`; the enabled flag is left alone.
    fn config_with_telemetry_dir(dir: &std::path::Path) -> AdaptiveConfig {
        let mut cfg = AdaptiveConfig::default();
        cfg.telemetry.path = Some(dir.to_string_lossy().into_owned());
        cfg
    }

    /// Pipeline where `tool` arms fail-fast after `threshold` consecutive empty responses.
    fn pipeline_with_fail_fast_on(tool: &str, threshold: u32) -> SessionPipeline {
        let mut cfg = AdaptiveConfig::default();
        cfg.tools.insert(
            tool.to_string(),
            devboy_core::ToolValueModel {
                fail_fast_after_n: Some(threshold),
                ..devboy_core::ToolValueModel::default()
            },
        );
        SessionPipeline::new(cfg)
    }

    // ------------------------------------------------------------------
    // Dedup
    // ------------------------------------------------------------------

    #[test]
    fn second_identical_read_emits_reference_hint() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-A:");
        let params = read_params("/tmp/a.rs");

        let first = pipeline.process("req_1", &params, ToolCallResult::text(body.clone()), 0);
        let second = pipeline.process("req_2", &params, ToolCallResult::text(body.clone()), 10);

        assert_eq!(first_text(&first), body);
        let t2 = first_text(&second);
        assert!(t2.len() < body.len() / 2, "expected hint, got `{t2}`");
        assert!(
            t2.contains("[ref:") || t2.contains("[ref "),
            "expected reference hint, got `{t2}`"
        );
    }

    #[test]
    fn edit_invalidation_busts_cache() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-B:");
        let params = read_params("/tmp/b.rs");

        let _ = pipeline.process("req_1", &params, ToolCallResult::text(body.clone()), 0);
        pipeline.invalidate_file("/tmp/b.rs");
        let reread = pipeline.process("req_3", &params, ToolCallResult::text(body.clone()), 10);

        assert_eq!(
            first_text(&reread),
            body,
            "expected fresh body after invalidation"
        );
    }

    #[test]
    fn errors_are_never_deduped() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("err:");
        let params = read_params("/tmp/c.rs");

        let _ = pipeline.process("req_1", &params, ToolCallResult::text(body.clone()), 0);
        let failing = ToolCallResult {
            is_error: Some(true),
            ..ToolCallResult::text(body.clone())
        };
        let returned = pipeline.process("req_2", &params, failing, 10);

        assert_eq!(
            first_text(&returned),
            body,
            "errors must pass through untouched"
        );
    }

    #[test]
    fn dedup_hit_increments_inference_calls_saved_dedup() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-D:");
        let params = read_params("/tmp/d.rs");

        let _ = pipeline.process("req_1", &params, ToolCallResult::text(body.clone()), 0);
        assert_eq!(pipeline.enrichment_snapshot().inference_calls_saved_dedup, 0);

        let _ = pipeline.process("req_2", &params, ToolCallResult::text(body), 10);
        let post = pipeline.enrichment_snapshot();
        assert_eq!(post.inference_calls_saved_dedup, 1);
        assert!(
            post.inference_tokens_saved > 0,
            "tokens_saved must be > 0 after a real L0 dedup, got {}",
            post.inference_tokens_saved
        );
        assert_eq!(post.total_calls_saved(), 1);
    }

    // ------------------------------------------------------------------
    // Telemetry
    // ------------------------------------------------------------------

    #[test]
    fn telemetry_disabled_by_default_writes_no_files() {
        let tmp = tempfile::tempdir().unwrap();
        let pipeline = SessionPipeline::new(config_with_telemetry_dir(tmp.path()));
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/t.rs"),
            ToolCallResult::text(long_text("file-T:")),
            0,
        );

        let entries: Vec<_> = std::fs::read_dir(tmp.path()).unwrap().flatten().collect();
        assert!(
            entries.is_empty(),
            "telemetry must be silent until explicitly enabled, found {entries:?}"
        );
    }

    #[test]
    fn telemetry_enabled_creates_jsonl_file() {
        let tmp = tempfile::tempdir().unwrap();
        let mut cfg = config_with_telemetry_dir(tmp.path());
        cfg.telemetry.enabled = true;
        cfg.telemetry.flush_every_n = 1;
        let pipeline = SessionPipeline::new(cfg);
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/u.rs"),
            ToolCallResult::text(long_text("file-U:")),
            0,
        );

        let jsonl = std::fs::read_dir(tmp.path())
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .find(|path| path.extension().and_then(|s| s.to_str()) == Some("jsonl"))
            .unwrap_or_else(|| panic!("expected at least one .jsonl file in {:?}", tmp.path()));
        let contents = std::fs::read_to_string(&jsonl).unwrap();
        assert!(
            contents.contains("\"endpoint_class\":\"Read\""),
            "expected Read event in JSONL, got {contents}"
        );
    }

    // ------------------------------------------------------------------
    // Recent tools & fail-fast
    // ------------------------------------------------------------------

    #[test]
    fn recent_tools_window_records_calls_in_order() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let names = ["Glob", "Grep", "Read"];
        for (i, name) in names.into_iter().enumerate() {
            let _ = pipeline.process(
                &format!("req_{i}"),
                &empty_params(name),
                ToolCallResult::text(format!("body-{i}")),
                i as i64,
            );
        }
        assert_eq!(
            pipeline.recent_tools_snapshot(),
            names.map(String::from).to_vec()
        );
    }

    #[test]
    fn fail_fast_arms_after_n_consecutive_empty_responses() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        assert!(!pipeline.should_skip("ToolSearch"), "fresh streak");

        let search = empty_params("ToolSearch");
        let _ = pipeline.process("req_1", &search, ToolCallResult::text(String::new()), 0);
        assert!(!pipeline.should_skip("ToolSearch"));

        let _ = pipeline.process("req_2", &search, ToolCallResult::text(String::new()), 10);
        assert!(pipeline.should_skip("ToolSearch"));

        // A tool without a fail-fast threshold never arms, however many empties it sees.
        let read = empty_params("Read");
        for i in 0..5 {
            let _ = pipeline.process(
                &format!("rd_{i}"),
                &read,
                ToolCallResult::text(String::new()),
                100 + i,
            );
        }
        assert!(!pipeline.should_skip("Read"));
    }

    #[test]
    fn fail_fast_streak_resets_on_non_empty_response() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        let search = empty_params("ToolSearch");

        let _ = pipeline.process("req_1", &search, ToolCallResult::text(String::new()), 0);
        let _ = pipeline.process(
            "req_2",
            &search,
            ToolCallResult::text("a real result".to_string()),
            10,
        );
        let _ = pipeline.process("req_3", &search, ToolCallResult::text(String::new()), 20);

        assert!(!pipeline.should_skip("ToolSearch"));
    }

    #[test]
    fn record_fail_fast_skip_updates_aggregator() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        for _ in 0..2 {
            pipeline.record_fail_fast_skip(40);
        }
        let s = pipeline.enrichment_snapshot();
        assert_eq!(s.inference_calls_saved_fail_fast, 2);
        assert_eq!(s.inference_tokens_saved, 80);
    }

    // ------------------------------------------------------------------
    // Speculation
    // ------------------------------------------------------------------

    /// Dispatcher that sleeps `delay_ms`, then answers from a fixed tool → body map.
    struct MapDispatcher {
        bodies: std::collections::HashMap<String, String>,
        delay_ms: u64,
    }

    #[async_trait]
    impl PrefetchDispatcher for MapDispatcher {
        async fn dispatch(
            &self,
            tool: &str,
            _args: serde_json::Value,
        ) -> Result<String, PrefetchError> {
            tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
            match self.bodies.get(tool) {
                Some(body) => Ok(body.clone()),
                None => Err(PrefetchError::Rejected(format!("no body for {tool}"))),
            }
        }
    }

    /// `MapDispatcher` that only answers `tool`, always with `body`.
    fn single_body_dispatcher(tool: &str, body: String, delay_ms: u64) -> Arc<MapDispatcher> {
        Arc::new(MapDispatcher {
            bodies: std::iter::once((tool.to_string(), body)).collect(),
            delay_ms,
        })
    }

    /// Built-in tool models with enrichment switched on and generous prefetch limits.
    fn enrichment_on_config() -> AdaptiveConfig {
        let mut cfg = AdaptiveConfig {
            tools: devboy_format_pipeline::tool_defaults::default_tool_value_models(),
            ..AdaptiveConfig::default()
        };
        let enrichment = &mut cfg.enrichment;
        enrichment.enabled = true;
        enrichment.prefetch_timeout_ms = 500;
        enrichment.max_parallel_prefetches = 3;
        enrichment.prefetch_budget_tokens = 4_000;
        cfg
    }

    #[tokio::test]
    async fn speculate_after_dispatches_glob_to_read_chain() {
        let dispatcher =
            single_body_dispatcher("Read", "long body of file/main.rs ".repeat(40), 5);
        let pipeline = SessionPipeline::new(enrichment_on_config())
            .with_speculation(dispatcher)
            .await;

        let glob_body = "src/main.rs\nsrc/lib.rs\nsrc/api.rs\n";
        let _ = pipeline.process(
            "req_1",
            &glob_params("src/**/*.rs"),
            ToolCallResult::text(glob_body.to_string()),
            0,
        );

        let hint = pipeline
            .speculate_after("Glob", &Value::String(glob_body.to_string()))
            .await;

        let snap = pipeline.enrichment_snapshot();
        assert!(
            snap.total_prefetches > 0,
            "expected total_prefetches > 0, got {snap:?}"
        );
        assert!(
            snap.prefetch_dispatched > 0,
            "expected prefetch_dispatched > 0, got {snap:?}"
        );
        assert!(hint.contains("Read"), "expected Read in hint, got: {hint:?}");
        pipeline.shutdown().await;
    }

    #[tokio::test]
    async fn speculate_after_is_noop_when_disabled() {
        let cfg = AdaptiveConfig {
            tools: devboy_format_pipeline::tool_defaults::default_tool_value_models(),
            ..AdaptiveConfig::default()
        };
        let pipeline = SessionPipeline::new(cfg);
        let listing = "src/main.rs\n";
        let _ = pipeline.process(
            "req_1",
            &glob_params("src/**/*.rs"),
            ToolCallResult::text(listing.into()),
            0,
        );

        let hint = pipeline
            .speculate_after("Glob", &Value::String(listing.into()))
            .await;

        assert!(hint.is_empty(), "speculation must be silent when disabled");
        let snap = pipeline.enrichment_snapshot();
        assert_eq!(snap.total_prefetches, 0);
        assert_eq!(snap.prefetch_dispatched, 0);
    }

    #[tokio::test]
    async fn prefetched_call_emits_telemetry_event_tagged_correctly() {
        let tmp = tempfile::tempdir().unwrap();
        let mut cfg = enrichment_on_config();
        cfg.telemetry.enabled = true;
        cfg.telemetry.path = Some(tmp.path().to_string_lossy().into_owned());
        cfg.telemetry.flush_every_n = 1;

        let dispatcher = single_body_dispatcher("Read", "fn main() {}\n".repeat(40), 5);
        let pipeline = SessionPipeline::new(cfg).with_speculation(dispatcher).await;

        let glob_body = "src/main.rs\n";
        let _ = pipeline.process(
            "req_1",
            &glob_params("src/**/*.rs"),
            ToolCallResult::text(glob_body.into()),
            0,
        );
        let _hint = pipeline
            .speculate_after("Glob", &Value::String(glob_body.into()))
            .await;
        pipeline.shutdown().await;
        // Drop the pipeline (and with it the telemetry sink) before reading the files.
        drop(pipeline);

        let prefetched_event_lines: Vec<String> = std::fs::read_dir(tmp.path())
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .filter(|path| path.extension().and_then(|s| s.to_str()) == Some("jsonl"))
            .flat_map(|path| {
                std::fs::read_to_string(path)
                    .unwrap()
                    .lines()
                    .filter(|line| line.contains("\"enricher_prefetched\":true"))
                    .map(String::from)
                    .collect::<Vec<_>>()
            })
            .collect();

        assert!(
            !prefetched_event_lines.is_empty(),
            "expected at least one event tagged enricher_prefetched=true"
        );
        assert!(
            prefetched_event_lines
                .iter()
                .any(|l| l.contains("\"enricher_predicted_cost_tokens\":")),
            "expected enricher_predicted_cost_tokens to be set in the event JSON"
        );
    }

    #[tokio::test]
    async fn shutdown_drains_pending_speculation() {
        let mut cfg = enrichment_on_config();
        cfg.enrichment.prefetch_timeout_ms = 1;
        // The dispatcher is far slower than the 1 ms wait window, so work is still
        // pending when shutdown runs.
        let dispatcher = single_body_dispatcher("Read", "any body".into(), 200);
        let pipeline = SessionPipeline::new(cfg).with_speculation(dispatcher).await;

        let _ = pipeline.process(
            "req_1",
            &glob_params("x"),
            ToolCallResult::text("src/main.rs\n".into()),
            0,
        );
        let _hint = pipeline
            .speculate_after("Glob", &Value::String("src/main.rs\n".into()))
            .await;

        // Shutting down twice must not panic or hang.
        pipeline.shutdown().await;
        pipeline.shutdown().await;
    }

    #[test]
    fn extract_file_path_handles_three_argument_names() {
        let cases = [
            (json!({"file_path": "/x"}), Some("/x")),
            (json!({"path": "/y"}), Some("/y")),
            (json!({"notebook_path": "/z"}), Some("/z")),
            (json!({"unrelated": "x"}), None),
        ];
        for (args, expected) in cases {
            assert_eq!(extract_file_path(Some(&args)), expected.map(String::from));
        }
        assert_eq!(extract_file_path(None), None);
    }
}