#![allow(dead_code)]
pub mod cache;
pub mod partial_json;
pub mod paste;
pub mod prediction;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use zeph_common::SkillTrustLevel;
use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args};
use prediction::Prediction;
pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
/// Counters describing speculative-execution activity.
///
/// `SpeculationEngine::end_turn` returns these values and resets them to their
/// defaults; `SpeculationEngine::metrics_snapshot` returns a copy without resetting.
#[derive(Debug, Default, Clone)]
pub struct SpeculativeMetrics {
/// Incremented by `try_commit` each time a cached handle matches the requested call.
pub committed: u32,
/// Incremented once per `cancel_for` invocation.
pub cancelled: u32,
/// Not written anywhere in this module's engine code.
/// NOTE(review): presumably counts handles evicted when the cache is full — confirm.
pub evicted_oldest: u32,
/// Incremented by `try_dispatch` when the executor reports that the call
/// requires confirmation and the dispatch is therefore skipped.
pub skipped_confirmation: u32,
/// Not written anywhere in this module's engine code.
/// NOTE(review): presumably milliseconds spent on discarded speculative work — confirm.
pub wasted_ms: u64,
}
/// Runs predicted tool calls ahead of time and hands their results over when a
/// matching real call arrives.
///
/// Dropping the engine cancels every cached handle and aborts the stored sweeper handle.
pub struct SpeculationEngine {
/// Executor consulted for speculatability / confirmation and used to run the calls.
executor: Arc<dyn ErasedToolExecutor>,
/// Configuration; the engine reads `mode`, `max_in_flight` and `ttl_seconds` from it.
config: SpeculativeConfig,
/// In-flight speculative handles, matched by tool id and argument hash.
cache: SpeculativeCache,
/// Activity counters, returned and reset by `end_turn`.
metrics: parking_lot::Mutex<SpeculativeMetrics>,
/// Task handle stored at construction; taken and aborted when the engine is dropped.
sweeper: parking_lot::Mutex<Option<zeph_common::task_supervisor::TaskHandle>>,
/// Supervisor used to spawn tasks when present; when `None`, `try_dispatch`
/// creates a temporary supervisor for each dispatched call.
task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
}
impl SpeculationEngine {
#[must_use]
pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
Self::new_with_supervisor(executor, config, None)
}
#[must_use]
pub fn new_with_supervisor(
executor: Arc<dyn ErasedToolExecutor>,
config: SpeculativeConfig,
supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
) -> Self {
let cache = SpeculativeCache::new(config.max_in_flight);
let shared = cache.shared_inner();
let sweeper_handle = if let Some(sup) = &supervisor {
Some(sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
name: "agent.speculative.sweeper",
restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
factory: move || {
let shared = Arc::clone(&shared);
async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
SpeculativeCache::sweep_expired_inner(&shared);
}
}
},
}))
} else {
let jh = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
SpeculativeCache::sweep_expired_inner(&shared);
}
});
let cancel = tokio_util::sync::CancellationToken::new();
let tmp_sup = zeph_common::TaskSupervisor::new(cancel);
let h = tmp_sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
name: "agent.speculative.sweeper",
restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
factory: || async {},
});
let abort = jh.abort_handle();
std::mem::forget(jh); drop(abort);
Some(h)
};
Self {
executor,
config,
cache,
metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
sweeper: parking_lot::Mutex::new(sweeper_handle),
task_supervisor: supervisor,
}
}
#[must_use]
pub fn mode(&self) -> SpeculationMode {
self.config.mode
}
/// Returns `true` unless the configured mode is [`SpeculationMode::Off`].
#[must_use]
pub fn is_active(&self) -> bool {
    let configured = self.config.mode;
    configured != SpeculationMode::Off
}
/// Attempts to start `prediction` as a speculative tool call.
///
/// Returns `true` when a task was spawned and its handle inserted into the
/// cache, and `false` when the dispatch was skipped. A dispatch is skipped when:
///
/// - `trust_level` is anything other than [`SkillTrustLevel::Trusted`];
/// - the executor reports the tool as not speculatable;
/// - the executor reports that the concrete call requires confirmation
///   (this case also increments `skipped_confirmation`).
///
/// The spawned task races the tool call against a cancellation token; if the
/// token fires first the task resolves to a `ToolError::Execution` carrying the
/// message `"speculative cancelled"`. The handle's TTL deadline is
/// `config.ttl_seconds` from now.
pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
    if trust_level != SkillTrustLevel::Trusted {
        return false;
    }
    let tool_id = &prediction.tool_id;
    if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
        return false;
    }

    let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
    let args_hash = hash_args(&call.params);
    if self.executor.requires_confirmation_erased(&call) {
        // Guard is a statement temporary, so the lock is released before logging.
        self.metrics.lock().skipped_confirmation += 1;
        debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
        return false;
    }

    let cancel = CancellationToken::new();
    // Built once and handed to whichever supervisor is used below. `call` is
    // moved in rather than cloned: nothing else needs it past this point.
    let work = {
        let executor = Arc::clone(&self.executor);
        let cancelled = cancel.child_token();
        move || async move {
            tokio::select! {
                result = executor.execute_tool_call_erased(&call) => result,
                () = cancelled.cancelled() => {
                    Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
                }
            }
        }
    };

    let task_name: Arc<str> = Arc::from(format!(
        "agent.speculative.dispatch.{}",
        uuid::Uuid::new_v4()
    ));
    let join = match &self.task_supervisor {
        Some(sup) => sup.spawn_oneshot(task_name, work),
        None => {
            // No supervisor configured: use a throwaway one just for this task.
            let throwaway =
                Arc::new(zeph_common::TaskSupervisor::new(CancellationToken::new()));
            throwaway.spawn_oneshot(task_name, work)
        }
    };

    let handle = SpeculativeHandle {
        key: HandleKey {
            tool_id: tool_id.clone(),
            args_hash,
        },
        join,
        cancel,
        ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
        started_at: std::time::Instant::now(),
    };
    debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
    self.cache.insert(handle);
    true
}
/// Claims the speculative result for `call`, if one is cached.
///
/// The cache is searched by the call's tool id and the hash of its parameters.
/// Returns `None` when nothing matches; otherwise increments `committed` and
/// returns the awaited outcome of the matching handle's `commit()`.
pub async fn try_commit(
    &self,
    call: &ToolCall,
) -> Option<Result<Option<ToolOutput>, ToolError>> {
    let args_hash = hash_args(&call.params);
    let matched = self.cache.take_match(&call.tool_id, &args_hash)?;

    // The guard lives only for this statement, so no lock is held across the await.
    self.metrics.lock().committed += 1;
    debug!(tool_id = %call.tool_id, "speculative commit");

    let outcome = matched.commit().await;
    Some(outcome)
}
/// Cancels the cached speculative work for `tool_id`.
///
/// The `cancelled` counter is incremented by one on every invocation, whether
/// or not the cache actually held a handle for the tool.
pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
    debug!(tool_id = %tool_id, "speculative cancel for tool");
    self.cache.cancel_by_tool_id(tool_id);
    self.metrics.lock().cancelled += 1;
}
pub fn end_turn(&self) -> SpeculativeMetrics {
self.cache.cancel_all();
let m = self.metrics.lock().clone();
*self.metrics.lock() = SpeculativeMetrics::default();
m
}
#[must_use]
pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
self.metrics.lock().clone()
}
}
impl Drop for SpeculationEngine {
    /// Cancels all cached speculative handles, then aborts the stored sweeper
    /// handle if it has not already been taken.
    fn drop(&mut self) {
        self.cache.cancel_all();
        let mut slot = self.sweeper.lock();
        if let Some(sweeper) = slot.take() {
            sweeper.abort();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};

    /// Executor stub: every tool is speculatable and every tool call succeeds
    /// with a fixed `"ok"` output.
    struct AlwaysOkExecutor;

    impl ToolExecutor for AlwaysOkExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let output = ToolOutput {
                tool_name: zeph_common::ToolName::new("test"),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            };
            Ok(Some(output))
        }

        fn is_tool_speculatable(&self, _: &str) -> bool {
            true
        }
    }

    /// Builds an unsupervised engine in `Decoding` mode backed by [`AlwaysOkExecutor`].
    fn decoding_engine() -> SpeculationEngine {
        let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };
        SpeculationEngine::new(exec, config)
    }

    /// Builds a prediction for the `"test"` tool with empty arguments.
    fn test_prediction() -> Prediction {
        Prediction {
            tool_id: zeph_common::ToolName::new("test"),
            args: serde_json::Map::new(),
            confidence: 0.9,
            source: prediction::PredictionSource::StreamPartial,
        }
    }

    #[tokio::test]
    async fn dispatch_and_commit_succeeds() {
        let engine = decoding_engine();
        let pred = test_prediction();
        // NOTE(review): the outcome is discarded, so this only checks that a
        // trusted dispatch does not panic; despite the name, `try_commit` is
        // never exercised here.
        let _ = engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn untrusted_skill_skips_dispatch() {
        let engine = decoding_engine();
        let pred = test_prediction();
        let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Quarantined);
        assert!(
            !dispatched,
            "untrusted skill must not dispatch speculatively"
        );
    }

    #[tokio::test]
    async fn cancel_for_removes_handle() {
        let engine = decoding_engine();
        let pred = test_prediction();
        engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
        engine.cancel_for(&zeph_common::ToolName::new("test"));
        assert!(
            engine.cache.is_empty(),
            "cancel_for must remove handle from cache"
        );
    }
}