use std::collections::{BTreeMap, VecDeque};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use devboy_format_pipeline::adaptive_config::AdaptiveConfig;
use devboy_format_pipeline::enrichment::{PlannerOptions, TurnContext, build_plan};
use devboy_format_pipeline::layered_pipeline::{LayeredPipeline, ToolResponseInput};
use devboy_format_pipeline::projection::{extract_args, extract_host};
use devboy_format_pipeline::telemetry::{EnrichmentEffectiveness, JsonlSink, Layer, TelemetrySink};
use crate::protocol::{ToolCallParams, ToolCallResult, ToolResultContent};
use crate::speculation::{
PrefetchDispatcher, PrefetchOutcome, PrefetchRequest, SkipReason, SpeculationEngine,
};
/// Size of the sliding window kept in `SessionPipeline::recent_tools`: once
/// this many tool names are recorded, the oldest one is evicted before a new
/// one is pushed.
const RECENT_TOOLS_WINDOW: usize = 16;
/// A successful tool response whose largest text part is at most this many
/// bytes counts as "empty" and extends that tool's fail-fast streak; any
/// larger response resets the streak to zero.
const FAIL_FAST_EMPTY_THRESHOLD_BYTES: usize = 8;
/// Per-session state wrapped around a `LayeredPipeline`: response
/// dedup/formatting, recently used tools, enrichment counters, fail-fast
/// streaks and an optional speculative-prefetch engine.
///
/// Every field is behind an `Arc`, so cloning is cheap and all clones share
/// the same state.
#[derive(Clone)]
pub struct SessionPipeline {
    /// Layered pipeline that processes tool responses and prefetched bodies.
    inner: Arc<Mutex<LayeredPipeline>>,
    /// Adaptive configuration; not behind a lock, so read-only after `new`.
    config: Arc<AdaptiveConfig>,
    /// Names of the most recent successful tool calls, oldest first, capped
    /// at `RECENT_TOOLS_WINDOW` entries; feeds the prefetch planner.
    recent_tools: Arc<Mutex<VecDeque<String>>>,
    /// Aggregated enrichment counters (dedup savings, fail-fast skips,
    /// prefetch dispatch/waste statistics).
    enrichment: Arc<Mutex<EnrichmentEffectiveness>>,
    /// Per-tool count of consecutive "empty" responses, consulted by
    /// `should_skip`.
    fail_fast_streak: Arc<Mutex<BTreeMap<String, u32>>>,
    /// Optional speculation engine installed by `with_speculation`. Uses an
    /// async (tokio) mutex because `speculate_after` keeps it locked across
    /// `.await` points.
    speculation: Arc<tokio::sync::Mutex<Option<SpeculationEngine>>>,
}
impl SessionPipeline {
pub fn new(mut config: AdaptiveConfig) -> Self {
let defaults = devboy_format_pipeline::tool_defaults::default_tool_value_models();
for (name, model) in defaults {
config.tools.entry(name).or_insert(model);
}
let session_id = format!("mcp_{}", std::process::id());
let mut pipeline = LayeredPipeline::new(session_id.clone(), config.clone());
if config.telemetry.enabled
&& let Some(path) = resolve_telemetry_path(&config, &session_id)
{
match JsonlSink::open(&path) {
Ok(sink) => {
let arc: Arc<dyn TelemetrySink> = Arc::new(sink);
pipeline = pipeline.with_telemetry(arc);
tracing::info!(target: "devboy_mcp::telemetry", "telemetry sink opened at {}", path.display());
}
Err(e) => {
tracing::warn!(
target: "devboy_mcp::telemetry",
"telemetry sink at {} failed to open: {e} — running without telemetry",
path.display()
);
}
}
}
Self {
inner: Arc::new(Mutex::new(pipeline)),
config: Arc::new(config),
recent_tools: Arc::new(Mutex::new(VecDeque::with_capacity(RECENT_TOOLS_WINDOW))),
enrichment: Arc::new(Mutex::new(EnrichmentEffectiveness::default())),
fail_fast_streak: Arc::new(Mutex::new(BTreeMap::new())),
speculation: Arc::new(tokio::sync::Mutex::new(None)),
}
}
pub async fn with_speculation(self, dispatcher: Arc<dyn PrefetchDispatcher>) -> Self {
let engine = SpeculationEngine::new(self.config.enrichment.clone(), dispatcher);
*self.speculation.lock().await = Some(engine);
self
}
pub async fn shutdown(&self) {
if let Some(engine) = self.speculation.lock().await.as_mut() {
engine.shutdown().await;
}
}
pub fn enrichment_snapshot(&self) -> EnrichmentEffectiveness {
self.enrichment
.lock()
.map(|g| g.clone())
.unwrap_or_default()
}
pub fn recent_tools_snapshot(&self) -> Vec<String> {
self.recent_tools
.lock()
.map(|g| g.iter().cloned().collect())
.unwrap_or_default()
}
pub fn should_skip(&self, tool_name: &str) -> bool {
let Some(model) = self.config.effective_tool_value_model(tool_name) else {
return false;
};
let Some(threshold) = model.fail_fast_after_n else {
return false;
};
let streak = self
.fail_fast_streak
.lock()
.ok()
.and_then(|g| g.get(tool_name).copied())
.unwrap_or(0);
streak >= threshold
}
pub fn record_fail_fast_skip(&self, predicted_cost_tokens: u32) {
if let Ok(mut e) = self.enrichment.lock() {
e.record_fail_fast_skip(predicted_cost_tokens);
}
}
pub async fn speculate_after(
&self,
tool_name: &str,
prev_response_json: &serde_json::Value,
) -> String {
if !self.config.enrichment.enabled {
return String::new();
}
let mut engine_guard = self.speculation.lock().await;
let Some(engine) = engine_guard.as_mut() else {
return String::new();
};
if !engine.is_enabled() {
return String::new();
}
for outcome in engine.drain_pending().await {
if let PrefetchOutcome::Settled {
tool,
args,
body,
predicted_cost_tokens,
} = outcome
{
self.write_prefetch_to_cache(&tool, &args, &body, predicted_cost_tokens);
}
}
let recent = self.recent_tools_snapshot();
let ctx = TurnContext::new(&recent, self.config.enrichment.prefetch_budget_tokens);
let opts = PlannerOptions {
min_followup_probability: 0.3,
..PlannerOptions::default()
};
let plan = build_plan(&self.config, &ctx, opts);
let mut requests: Vec<PrefetchRequest> = Vec::new();
for call in &plan.calls {
let Some(model) = self.config.effective_tool_value_model(&call.tool) else {
continue;
};
if !model.is_speculatable() {
continue;
}
let Some(link) = self
.config
.effective_tool_value_model(tool_name)
.and_then(|m| m.follow_up.iter().find(|l| l.tool == call.tool))
else {
continue;
};
let arg_objects = extract_args(tool_name, prev_response_json, link);
if arg_objects.is_empty() {
continue;
}
for args in arg_objects {
let host = static_or_url_host(&args, model.rate_limit_host.as_deref());
requests.push(PrefetchRequest {
call: call.clone(),
args,
rate_limit_host: host,
});
}
}
if requests.is_empty() {
return String::new();
}
let total_to_dispatch = requests.len() as u32;
let skips = engine.dispatch(requests).await;
let dispatched = total_to_dispatch.saturating_sub(skips.len() as u32);
if let Ok(mut e) = self.enrichment.lock() {
for _ in 0..dispatched {
e.total_prefetches = e.total_prefetches.saturating_add(1);
e.record_prefetch_dispatched();
}
for s in &skips {
if let PrefetchOutcome::Skipped { reason, .. } = s {
let label = match reason {
SkipReason::HostSaturated => "host_saturated",
SkipReason::MaxParallelReached => "max_parallel_reached",
SkipReason::NotSpeculatable => "not_speculatable",
};
tracing::debug!(
target: "devboy_mcp::speculation",
"prefetch skipped: {label}"
);
}
}
}
let outcomes = engine.wait_within().await;
let mut hint_parts: Vec<String> = Vec::new();
for o in outcomes {
match o {
PrefetchOutcome::Settled {
tool,
args,
body,
predicted_cost_tokens,
} => {
self.write_prefetch_to_cache(&tool, &args, &body, predicted_cost_tokens);
hint_parts.push(format!("{tool}({})", short_args(&args)));
}
PrefetchOutcome::Failed { tool, error } => {
tracing::warn!(
target: "devboy_mcp::speculation",
"prefetch failed for {tool}: {error}"
);
if let Ok(mut e) = self.enrichment.lock() {
e.record_prefetch_wasted();
}
}
PrefetchOutcome::Skipped { .. } => {}
}
}
if hint_parts.is_empty() {
String::new()
} else {
format!(
"\n\n> [enrichment: pre-fetched {} in background — call as usual, results served from cache]",
hint_parts.join(", ")
)
}
}
fn write_prefetch_to_cache(
&self,
tool: &str,
args: &serde_json::Value,
body: &str,
predicted_cost_tokens: u32,
) {
let Ok(mut p) = self.inner.lock() else {
return;
};
let request_id = format!(
"prefetch_{}_{}",
tool,
short_args_hash(args)
);
let path = args.get("file_path").and_then(|v| v.as_str());
let input = ToolResponseInput {
tool_call_id: &request_id,
tool_name: tool,
file_path: path,
content: body,
is_sidechain: false,
ts_ms: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as i64)
.unwrap_or(0),
enricher_prefetched: true,
enricher_predicted_cost_tokens: predicted_cost_tokens,
};
let _out = p.process(input);
}
}
fn static_or_url_host(args: &serde_json::Value, static_host: Option<&str>) -> Option<String> {
if let Some(url) = args.get("url").and_then(|v| v.as_str())
&& let Some(h) = extract_host(url)
{
return Some(h);
}
static_host.map(String::from)
}
fn short_args(args: &serde_json::Value) -> String {
let Some(obj) = args.as_object() else {
return String::new();
};
for (_, v) in obj {
if let Some(s) = v.as_str() {
let mut t = s.to_string();
if t.len() > 40 {
t.truncate(40);
t.push('…');
}
return t;
}
}
String::new()
}
/// Hashes the JSON serialization of `args` with a djb2-style rolling hash
/// (seed 5381, `h = h * 33 + byte`, wrapping on overflow). The result is
/// rendered as lowercase hex, zero-padded to at least 8 digits.
fn short_args_hash(args: &serde_json::Value) -> String {
    let serialized = args.to_string();
    let digest = serialized
        .as_bytes()
        .iter()
        .fold(5381_u64, |acc, &byte| {
            acc.wrapping_mul(33).wrapping_add(u64::from(byte))
        });
    format!("{digest:08x}")
}
impl SessionPipeline {
    /// Forwards a compaction-boundary notification to the layered pipeline.
    /// Does nothing if the pipeline lock is poisoned.
    pub fn on_compaction_boundary(&self) {
        if let Ok(mut p) = self.inner.lock() {
            p.on_compaction_boundary();
        }
    }

    /// Invalidates what the layered pipeline holds for `file_path`, so a
    /// later read of that file is not answered with a reference hint.
    /// Does nothing if the pipeline lock is poisoned.
    pub fn invalidate_file(&self, file_path: &str) {
        if let Ok(mut p) = self.inner.lock() {
            p.invalidate_file(file_path);
        }
    }

    /// Runs every text part of a tool result through the layered pipeline.
    ///
    /// Parts the pipeline classifies as `Layer::L0` (a dedup hit) are replaced
    /// by the pipeline's output; all other parts keep their original text.
    /// Dedup hits and saved tokens are added to the enrichment counters.
    ///
    /// The call also updates two pieces of per-tool state:
    /// * the fail-fast streak — it grows when the largest original text part
    ///   is at most `FAIL_FAST_EMPTY_THRESHOLD_BYTES` bytes and resets
    ///   otherwise;
    /// * the recent-tools window, which records `params.name`.
    ///
    /// Error results (`is_error == Some(true)`) are returned unchanged and
    /// touch no state. If the pipeline lock is poisoned, the result is also
    /// returned unchanged.
    pub fn process(
        &self,
        request_id: &str,
        params: &ToolCallParams,
        result: ToolCallResult,
        ts_ms: i64,
    ) -> ToolCallResult {
        if result.is_error == Some(true) {
            return result;
        }
        let file_path = extract_file_path(params.arguments.as_ref());
        let mut new_content: Vec<ToolResultContent> = Vec::with_capacity(result.content.len());
        let mut p = match self.inner.lock() {
            Ok(g) => g,
            Err(_) => return result,
        };
        let mut total_dedup_hits: u32 = 0;
        let mut total_dedup_tokens_saved: u64 = 0;
        // Size (in bytes) of the largest original text part; drives the
        // fail-fast streak below.
        let mut max_original_bytes: usize = 0;
        for c in result.content {
            match c {
                ToolResultContent::Text { text } => {
                    max_original_bytes = max_original_bytes.max(text.len());
                    let input = ToolResponseInput {
                        tool_call_id: request_id,
                        tool_name: &params.name,
                        file_path: file_path.as_deref(),
                        content: &text,
                        is_sidechain: false,
                        ts_ms,
                        enricher_prefetched: false,
                        enricher_predicted_cost_tokens: 0,
                    };
                    let out = p.process(input);
                    let body = if matches!(out.layer, Layer::L0) {
                        total_dedup_hits = total_dedup_hits.saturating_add(1);
                        if out.tokens_saved > 0 {
                            total_dedup_tokens_saved =
                                total_dedup_tokens_saved.saturating_add(out.tokens_saved as u64);
                        }
                        out.output
                    } else {
                        text
                    };
                    new_content.push(ToolResultContent::Text { text: body });
                }
            }
        }
        // Release the pipeline lock before touching the other session state.
        drop(p);
        if total_dedup_hits > 0
            && let Ok(mut e) = self.enrichment.lock()
        {
            e.inference_calls_saved_dedup = e
                .inference_calls_saved_dedup
                .saturating_add(total_dedup_hits);
            e.inference_tokens_saved = e
                .inference_tokens_saved
                .saturating_add(total_dedup_tokens_saved);
        }
        if let Ok(mut streak) = self.fail_fast_streak.lock() {
            let entry = streak.entry(params.name.clone()).or_insert(0);
            if max_original_bytes <= FAIL_FAST_EMPTY_THRESHOLD_BYTES {
                *entry = entry.saturating_add(1);
            } else {
                *entry = 0;
            }
        }
        if let Ok(mut recent) = self.recent_tools.lock() {
            if recent.len() >= RECENT_TOOLS_WINDOW {
                recent.pop_front();
            }
            recent.push_back(params.name.clone());
        }
        ToolCallResult {
            content: new_content,
            is_error: result.is_error,
        }
    }
}
pub fn extract_file_path(args: Option<&serde_json::Value>) -> Option<String> {
let obj = args?.as_object()?;
for k in ["file_path", "path", "notebook_path"] {
if let Some(v) = obj.get(k).and_then(|v| v.as_str()) {
return Some(v.to_string());
}
}
None
}
/// Reports whether `name` is one of the tools treated as mutating:
/// `Edit`, `Write`, `MultiEdit` or `NotebookEdit`.
/// The comparison is exact and case-sensitive.
pub fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 4] = ["Edit", "Write", "MultiEdit", "NotebookEdit"];
    MUTATING_TOOLS.contains(&name)
}
/// Resolves where this session's JSONL telemetry file is written.
///
/// The directory is chosen in this order:
/// 1. `config.telemetry.path`, when configured;
/// 2. the `DEVBOY_TELEMETRY_DIR` environment variable, when set and non-empty;
/// 3. `$HOME/.devboy/telemetry`, when `HOME` is set and non-empty;
/// 4. `<temp_dir>/.devboy-telemetry` otherwise.
///
/// The file name is `<session_id>.jsonl`. This always returns `Some`; the
/// `Option` is kept for callers.
fn resolve_telemetry_path(config: &AdaptiveConfig, session_id: &str) -> Option<PathBuf> {
    // `var_os` (not `var`) so a non-UTF-8 directory is honoured instead of
    // silently falling through. Empty values count as unset, so they never
    // resolve to a path relative to the current working directory.
    let non_empty_env = |key: &str| {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };
    let dir = config
        .telemetry
        .path
        .as_deref()
        .map(|p| Path::new(p).to_path_buf())
        .or_else(|| non_empty_env("DEVBOY_TELEMETRY_DIR"))
        .or_else(|| non_empty_env("HOME").map(|home| home.join(".devboy").join("telemetry")))
        .unwrap_or_else(|| std::env::temp_dir().join(".devboy-telemetry"));
    Some(dir.join(format!("{session_id}.jsonl")))
}
// Unit tests for `SessionPipeline`: L0 dedup, file invalidation, error
// pass-through, telemetry sink creation, enrichment counters, fail-fast
// streaks and speculative prefetch driven by a stub dispatcher.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{ToolCallParams, ToolCallResult, ToolResultContent};
    use serde_json::json;

    /// Builds `Read` tool-call params whose `file_path` argument is `path`.
    fn read_params(path: &str) -> ToolCallParams {
        ToolCallParams {
            name: "Read".to_string(),
            arguments: Some(json!({"file_path": path})),
        }
    }

    /// Returns `seed` followed by 400 `x` characters; the dedup tests below
    /// rely on a body of this size being replaced by a reference hint.
    fn long_text(seed: &str) -> String {
        format!("{}{}", seed, "x".repeat(400))
    }

    // The first read passes through verbatim; an identical second read is
    // replaced by a much shorter `[ref...` hint.
    #[test]
    fn second_identical_read_emits_reference_hint() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-A:");
        let r1 = pipeline.process(
            "req_1",
            &read_params("/tmp/a.rs"),
            ToolCallResult::text(body.clone()),
            0,
        );
        let r2 = pipeline.process(
            "req_2",
            &read_params("/tmp/a.rs"),
            ToolCallResult::text(body.clone()),
            10,
        );
        let ToolResultContent::Text { text: t1 } = &r1.content[0];
        assert_eq!(t1, &body);
        let ToolResultContent::Text { text: t2 } = &r2.content[0];
        assert!(t2.len() < body.len() / 2, "expected hint, got `{t2}`");
        assert!(
            t2.contains("[ref:") || t2.contains("[ref "),
            "expected reference hint, got `{t2}`"
        );
    }

    // After `invalidate_file`, re-reading the same file returns the full body
    // again instead of a reference hint.
    #[test]
    fn edit_invalidation_busts_cache() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-B:");
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/b.rs"),
            ToolCallResult::text(body.clone()),
            0,
        );
        pipeline.invalidate_file("/tmp/b.rs");
        let r3 = pipeline.process(
            "req_3",
            &read_params("/tmp/b.rs"),
            ToolCallResult::text(body.clone()),
            10,
        );
        let ToolResultContent::Text { text: t3 } = &r3.content[0];
        assert_eq!(t3, &body, "expected fresh body after invalidation");
    }

    // A result flagged `is_error = Some(true)` is returned verbatim even when
    // its body duplicates an earlier successful read.
    #[test]
    fn errors_are_never_deduped() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("err:");
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/c.rs"),
            ToolCallResult::text(body.clone()),
            0,
        );
        let mut err = ToolCallResult::text(body.clone());
        err.is_error = Some(true);
        let r2 = pipeline.process("req_2", &read_params("/tmp/c.rs"), err, 10);
        let ToolResultContent::Text { text: t2 } = &r2.content[0];
        assert_eq!(t2, &body, "errors must pass through untouched");
    }

    // With a telemetry path configured but `enabled` left at its default, no
    // file may be created in that directory.
    #[test]
    fn telemetry_disabled_by_default_writes_no_files() {
        let tmp = tempfile::tempdir().unwrap();
        let mut cfg = AdaptiveConfig::default();
        cfg.telemetry.path = Some(tmp.path().to_string_lossy().into_owned());
        let pipeline = SessionPipeline::new(cfg);
        let body = long_text("file-T:");
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/t.rs"),
            ToolCallResult::text(body),
            0,
        );
        let entries: Vec<_> = std::fs::read_dir(tmp.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .collect();
        assert!(
            entries.is_empty(),
            "telemetry must be silent until explicitly enabled, found {entries:?}"
        );
    }

    // With telemetry enabled and flushing after every event, a `.jsonl` file
    // containing a `Read` event appears in the configured directory.
    #[test]
    fn telemetry_enabled_creates_jsonl_file() {
        let tmp = tempfile::tempdir().unwrap();
        let mut cfg = AdaptiveConfig::default();
        cfg.telemetry.enabled = true;
        cfg.telemetry.path = Some(tmp.path().to_string_lossy().into_owned());
        cfg.telemetry.flush_every_n = 1;
        let pipeline = SessionPipeline::new(cfg);
        let body = long_text("file-U:");
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/u.rs"),
            ToolCallResult::text(body),
            0,
        );
        let mut found = false;
        for entry in std::fs::read_dir(tmp.path()).unwrap() {
            let entry = entry.unwrap();
            if entry.path().extension().and_then(|s| s.to_str()) == Some("jsonl") {
                let contents = std::fs::read_to_string(entry.path()).unwrap();
                assert!(
                    contents.contains("\"endpoint_class\":\"Read\""),
                    "expected Read event in JSONL, got {contents}"
                );
                found = true;
                break;
            }
        }
        assert!(
            found,
            "expected at least one .jsonl file in {:?}",
            tmp.path()
        );
    }

    /// Builds a pipeline whose config gives `tool` a tool value model with
    /// `fail_fast_after_n = Some(threshold)` (other fields default).
    fn pipeline_with_fail_fast_on(tool: &str, threshold: u32) -> SessionPipeline {
        let mut cfg = AdaptiveConfig::default();
        let model = devboy_core::ToolValueModel {
            fail_fast_after_n: Some(threshold),
            ..devboy_core::ToolValueModel::default()
        };
        cfg.tools.insert(tool.to_string(), model);
        SessionPipeline::new(cfg)
    }

    /// Builds tool-call params for `name` with no arguments.
    fn empty_params(name: &str) -> ToolCallParams {
        ToolCallParams {
            name: name.to_string(),
            arguments: None,
        }
    }

    // One real L0 dedup hit adds exactly one saved call and a positive token
    // saving to the enrichment counters.
    #[test]
    fn dedup_hit_increments_inference_calls_saved_dedup() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        let body = long_text("file-D:");
        let _ = pipeline.process(
            "req_1",
            &read_params("/tmp/d.rs"),
            ToolCallResult::text(body.clone()),
            0,
        );
        let pre = pipeline.enrichment_snapshot();
        assert_eq!(pre.inference_calls_saved_dedup, 0);
        let _ = pipeline.process(
            "req_2",
            &read_params("/tmp/d.rs"),
            ToolCallResult::text(body),
            10,
        );
        let post = pipeline.enrichment_snapshot();
        assert_eq!(post.inference_calls_saved_dedup, 1);
        assert!(
            post.inference_tokens_saved > 0,
            "tokens_saved must be > 0 after a real L0 dedup, got {}",
            post.inference_tokens_saved
        );
        assert_eq!(post.total_calls_saved(), 1);
    }

    // The recent-tools window records tool names in call order, oldest first.
    #[test]
    fn recent_tools_window_records_calls_in_order() {
        let pipeline = SessionPipeline::new(AdaptiveConfig::default());
        for (i, name) in ["Glob", "Grep", "Read"].iter().enumerate() {
            let _ = pipeline.process(
                &format!("req_{i}"),
                &ToolCallParams {
                    name: (*name).to_string(),
                    arguments: None,
                },
                ToolCallResult::text(format!("body-{i}")),
                i as i64,
            );
        }
        assert_eq!(
            pipeline.recent_tools_snapshot(),
            vec!["Glob".to_string(), "Grep".into(), "Read".into()]
        );
    }

    // With threshold 2, the second consecutive empty response arms fail-fast
    // for that tool only. `Read` has no fail-fast model here, so it is never
    // skipped however many empty responses it returns.
    #[test]
    fn fail_fast_arms_after_n_consecutive_empty_responses() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        assert!(!pipeline.should_skip("ToolSearch"), "fresh streak");
        let _ = pipeline.process(
            "req_1",
            &empty_params("ToolSearch"),
            ToolCallResult::text(String::new()),
            0,
        );
        assert!(!pipeline.should_skip("ToolSearch"));
        let _ = pipeline.process(
            "req_2",
            &empty_params("ToolSearch"),
            ToolCallResult::text(String::new()),
            10,
        );
        assert!(pipeline.should_skip("ToolSearch"));
        for i in 0..5 {
            let _ = pipeline.process(
                &format!("rd_{i}"),
                &empty_params("Read"),
                ToolCallResult::text(String::new()),
                100 + i,
            );
        }
        assert!(!pipeline.should_skip("Read"));
    }

    // A non-empty response in between resets the streak, so empty/full/empty
    // does not arm a threshold of 2.
    #[test]
    fn fail_fast_streak_resets_on_non_empty_response() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        let _ = pipeline.process(
            "req_1",
            &empty_params("ToolSearch"),
            ToolCallResult::text(String::new()),
            0,
        );
        let _ = pipeline.process(
            "req_2",
            &empty_params("ToolSearch"),
            ToolCallResult::text("a real result".to_string()),
            10,
        );
        let _ = pipeline.process(
            "req_3",
            &empty_params("ToolSearch"),
            ToolCallResult::text(String::new()),
            20,
        );
        assert!(!pipeline.should_skip("ToolSearch"));
    }

    // Each recorded fail-fast skip adds one saved call and its predicted
    // token cost to the aggregator.
    #[test]
    fn record_fail_fast_skip_updates_aggregator() {
        let pipeline = pipeline_with_fail_fast_on("ToolSearch", 2);
        pipeline.record_fail_fast_skip(40);
        pipeline.record_fail_fast_skip(40);
        let s = pipeline.enrichment_snapshot();
        assert_eq!(s.inference_calls_saved_fail_fast, 2);
        assert_eq!(s.inference_tokens_saved, 80);
    }

    use crate::speculation::{PrefetchDispatcher, PrefetchError};
    use async_trait::async_trait;
    use serde_json::Value;

    /// Stub dispatcher: sleeps `delay_ms`, then returns the canned body
    /// registered for the tool, or `PrefetchError::Rejected` if none is.
    struct MapDispatcher {
        bodies: std::collections::HashMap<String, String>,
        delay_ms: u64,
    }

    #[async_trait]
    impl PrefetchDispatcher for MapDispatcher {
        async fn dispatch(
            &self,
            tool: &str,
            _args: serde_json::Value,
        ) -> Result<String, PrefetchError> {
            tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
            self.bodies
                .get(tool)
                .cloned()
                .ok_or_else(|| PrefetchError::Rejected(format!("no body for {tool}")))
        }
    }

    /// Config with the built-in tool value models and enrichment enabled:
    /// 500 ms prefetch timeout, at most 3 parallel prefetches and a
    /// 4 000-token prefetch budget.
    fn enrichment_on_config() -> AdaptiveConfig {
        let mut cfg = AdaptiveConfig {
            tools: devboy_format_pipeline::tool_defaults::default_tool_value_models(),
            ..AdaptiveConfig::default()
        };
        cfg.enrichment.enabled = true;
        cfg.enrichment.prefetch_timeout_ms = 500;
        cfg.enrichment.max_parallel_prefetches = 3;
        cfg.enrichment.prefetch_budget_tokens = 4_000;
        cfg
    }

    // A Glob result listing files triggers Read prefetches via the default
    // follow-up models; counters move and the hint names `Read`.
    #[tokio::test]
    async fn speculate_after_dispatches_glob_to_read_chain() {
        let cfg = enrichment_on_config();
        let mut bodies = std::collections::HashMap::new();
        bodies.insert("Read".into(), "long body of file/main.rs ".repeat(40));
        let dispatcher = Arc::new(MapDispatcher {
            bodies,
            delay_ms: 5,
        });
        let pipeline = SessionPipeline::new(cfg).with_speculation(dispatcher).await;
        let glob_body = "src/main.rs\nsrc/lib.rs\nsrc/api.rs\n";
        let _ = pipeline.process(
            "req_1",
            &ToolCallParams {
                name: "Glob".to_string(),
                arguments: Some(json!({"pattern": "src/**/*.rs"})),
            },
            ToolCallResult::text(glob_body.to_string()),
            0,
        );
        let prev_response = Value::String(glob_body.to_string());
        let hint = pipeline.speculate_after("Glob", &prev_response).await;
        let snap = pipeline.enrichment_snapshot();
        assert!(
            snap.total_prefetches > 0,
            "expected total_prefetches > 0, got {snap:?}"
        );
        assert!(
            snap.prefetch_dispatched > 0,
            "expected prefetch_dispatched > 0, got {snap:?}"
        );
        assert!(
            hint.contains("Read"),
            "expected Read in hint, got: {hint:?}"
        );
        pipeline.shutdown().await;
    }

    // With enrichment left disabled and no engine attached, speculation
    // returns an empty hint and touches no prefetch counters.
    #[tokio::test]
    async fn speculate_after_is_noop_when_disabled() {
        let pipeline = SessionPipeline::new(AdaptiveConfig {
            tools: devboy_format_pipeline::tool_defaults::default_tool_value_models(),
            ..AdaptiveConfig::default()
        });
        let _ = pipeline.process(
            "req_1",
            &ToolCallParams {
                name: "Glob".to_string(),
                arguments: Some(json!({"pattern": "src/**/*.rs"})),
            },
            ToolCallResult::text("src/main.rs\n".into()),
            0,
        );
        let hint = pipeline
            .speculate_after("Glob", &Value::String("src/main.rs\n".into()))
            .await;
        assert!(hint.is_empty(), "speculation must be silent when disabled");
        let snap = pipeline.enrichment_snapshot();
        assert_eq!(snap.total_prefetches, 0);
        assert_eq!(snap.prefetch_dispatched, 0);
    }

    // A prefetched body written into the pipeline cache produces a telemetry
    // event tagged `enricher_prefetched: true` that carries
    // `enricher_predicted_cost_tokens`. The pipeline is shut down and dropped
    // before the JSONL file is read.
    #[tokio::test]
    async fn prefetched_call_emits_telemetry_event_tagged_correctly() {
        let tmp = tempfile::tempdir().unwrap();
        let mut cfg = enrichment_on_config();
        cfg.telemetry.enabled = true;
        cfg.telemetry.path = Some(tmp.path().to_string_lossy().into_owned());
        cfg.telemetry.flush_every_n = 1;
        let mut bodies = std::collections::HashMap::new();
        bodies.insert("Read".into(), "fn main() {}\n".repeat(40));
        let dispatcher = Arc::new(MapDispatcher {
            bodies,
            delay_ms: 5,
        });
        let pipeline = SessionPipeline::new(cfg).with_speculation(dispatcher).await;
        let glob_body = "src/main.rs\n";
        let _ = pipeline.process(
            "req_1",
            &ToolCallParams {
                name: "Glob".to_string(),
                arguments: Some(json!({"pattern": "src/**/*.rs"})),
            },
            ToolCallResult::text(glob_body.into()),
            0,
        );
        let _hint = pipeline
            .speculate_after("Glob", &Value::String(glob_body.into()))
            .await;
        pipeline.shutdown().await;
        drop(pipeline);
        let mut prefetched_event_lines: Vec<String> = Vec::new();
        for entry in std::fs::read_dir(tmp.path()).unwrap() {
            let entry = entry.unwrap();
            if entry.path().extension().and_then(|s| s.to_str()) != Some("jsonl") {
                continue;
            }
            for line in std::fs::read_to_string(entry.path()).unwrap().lines() {
                if line.contains("\"enricher_prefetched\":true") {
                    prefetched_event_lines.push(line.into());
                }
            }
        }
        assert!(
            !prefetched_event_lines.is_empty(),
            "expected at least one event tagged enricher_prefetched=true"
        );
        assert!(
            prefetched_event_lines
                .iter()
                .any(|l| l.contains("\"enricher_predicted_cost_tokens\":")),
            "expected enricher_predicted_cost_tokens to be set in the event JSON"
        );
    }

    // A 1 ms wait window with a 200 ms dispatcher leaves prefetches in flight.
    // Shutdown must cope with that, and calling it twice must not hang or
    // panic.
    #[tokio::test]
    async fn shutdown_drains_pending_speculation() {
        let mut cfg = enrichment_on_config();
        cfg.enrichment.prefetch_timeout_ms = 1;
        let mut bodies = std::collections::HashMap::new();
        bodies.insert("Read".into(), "any body".into());
        let dispatcher = Arc::new(MapDispatcher {
            bodies,
            delay_ms: 200,
        });
        let pipeline = SessionPipeline::new(cfg).with_speculation(dispatcher).await;
        let _ = pipeline.process(
            "req_1",
            &ToolCallParams {
                name: "Glob".to_string(),
                arguments: Some(json!({"pattern": "x"})),
            },
            ToolCallResult::text("src/main.rs\n".into()),
            0,
        );
        let _hint = pipeline
            .speculate_after("Glob", &Value::String("src/main.rs\n".into()))
            .await;
        pipeline.shutdown().await;
        pipeline.shutdown().await;
    }

    // Each of the three recognized keys is honoured; an unrelated key or
    // missing arguments yield `None`.
    #[test]
    fn extract_file_path_handles_three_argument_names() {
        assert_eq!(
            extract_file_path(Some(&json!({"file_path": "/x"}))),
            Some("/x".into())
        );
        assert_eq!(
            extract_file_path(Some(&json!({"path": "/y"}))),
            Some("/y".into())
        );
        assert_eq!(
            extract_file_path(Some(&json!({"notebook_path": "/z"}))),
            Some("/z".into())
        );
        assert_eq!(extract_file_path(Some(&json!({"unrelated": "x"}))), None);
        assert_eq!(extract_file_path(None), None);
    }
}