pub mod cache;
pub mod partial_json;
pub mod paste;
pub mod prediction;
pub mod stream_drainer;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use zeph_common::SkillTrustLevel;
use zeph_tools::{ErasedToolExecutor, ToolCall, ToolError, ToolOutput};
use cache::{HandleKey, SpeculativeCache, SpeculativeHandle, hash_args, hash_context};
use prediction::Prediction;
pub use zeph_config::tools::{SpeculationMode, SpeculativeConfig};
/// Handle to the background task that periodically sweeps expired cache entries.
///
/// The variant records how the sweeper was spawned so that it can be aborted later.
enum SweepHandle {
/// Sweeper spawned through a `TaskSupervisor`.
Supervised(zeph_common::task_supervisor::TaskHandle),
/// Sweeper spawned directly with `tokio::spawn`.
Raw(tokio::task::JoinHandle<()>),
}
impl SweepHandle {
fn abort(self) {
match self {
SweepHandle::Supervised(h) => h.abort(),
SweepHandle::Raw(h) => h.abort(),
}
}
}
/// Counters describing speculative execution activity.
///
/// `SpeculationEngine::end_turn` returns the accumulated values and resets them to
/// their defaults.
#[derive(Debug, Default, Clone)]
pub struct SpeculativeMetrics {
/// Number of speculative handles matched and committed by `try_commit`.
pub committed: u32,
/// Number of `cancel_for` invocations (incremented once per call).
pub cancelled: u32,
/// NOTE(review): not written anywhere in this file section; presumably counts
/// handles evicted when the cache is at capacity — confirm against the cache module.
pub evicted_oldest: u32,
/// Number of predictions skipped because the tool call requires confirmation.
pub skipped_confirmation: u32,
/// NOTE(review): not written anywhere in this file section; presumably the time in
/// milliseconds spent on speculative work that was discarded — confirm.
pub wasted_ms: u64,
}
/// Runs predicted tool calls ahead of time and hands their results over when the
/// real call with matching tool id, arguments and context arrives.
///
/// Dropping the engine cancels every in-flight speculative handle and aborts the
/// background sweeper task.
pub struct SpeculationEngine {
/// Executor used to run (and vet) speculative tool calls.
executor: Arc<dyn ErasedToolExecutor>,
/// Speculation settings (mode, confidence threshold, TTL, in-flight limit).
config: SpeculativeConfig,
/// In-flight speculative handles.
cache: SpeculativeCache,
/// Counters accumulated since the last `end_turn`.
metrics: parking_lot::Mutex<SpeculativeMetrics>,
/// Background task that sweeps expired handles; taken and aborted on drop.
sweeper: Option<SweepHandle>,
/// Optional supervisor used to spawn the sweeper and the speculative tasks.
task_supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
}
impl SpeculationEngine {
/// Creates an engine that is not attached to a task supervisor.
///
/// Equivalent to [`Self::new_with_supervisor`] with `supervisor` set to `None`; the
/// expiry sweeper is then started with `tokio::spawn`.
#[must_use]
pub fn new(executor: Arc<dyn ErasedToolExecutor>, config: SpeculativeConfig) -> Self {
    let supervisor = None;
    Self::new_with_supervisor(executor, config, supervisor)
}
/// Creates an engine and starts the background sweeper for expired cache entries.
///
/// The cache is sized from `config.max_in_flight`. The sweeper wakes every five
/// seconds (missed ticks are skipped) and calls
/// `SpeculativeCache::sweep_expired_inner` on the cache's shared state.
///
/// * With `Some(supervisor)` the sweeper is registered as the run-once task
///   `"agent.speculative.sweeper"` on that supervisor.
/// * With `None` the sweeper is started with `tokio::spawn`.
///
/// The supervisor, when given, is also kept for spawning speculative tasks later.
#[must_use]
pub fn new_with_supervisor(
    executor: Arc<dyn ErasedToolExecutor>,
    config: SpeculativeConfig,
    supervisor: Option<Arc<zeph_common::TaskSupervisor>>,
) -> Self {
    /// Period between two sweeps of expired speculative handles.
    const SWEEP_PERIOD: Duration = Duration::from_secs(5);

    let cache = SpeculativeCache::new(config.max_in_flight);
    let sweep_state = cache.shared_inner();

    let sweeper = match supervisor.as_ref() {
        Some(sup) => {
            let task = sup.spawn(zeph_common::task_supervisor::TaskDescriptor {
                name: "agent.speculative.sweeper",
                restart: zeph_common::task_supervisor::RestartPolicy::RunOnce,
                factory: move || {
                    // Each invocation of the factory gets its own reference to the state.
                    let state = Arc::clone(&sweep_state);
                    async move {
                        let mut ticker = tokio::time::interval(SWEEP_PERIOD);
                        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
                        loop {
                            ticker.tick().await;
                            SpeculativeCache::sweep_expired_inner(&state);
                        }
                    }
                },
            });
            SweepHandle::Supervised(task)
        }
        None => {
            let join = tokio::spawn(async move {
                let mut ticker = tokio::time::interval(SWEEP_PERIOD);
                ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
                loop {
                    ticker.tick().await;
                    SpeculativeCache::sweep_expired_inner(&sweep_state);
                }
            });
            SweepHandle::Raw(join)
        }
    };

    Self {
        executor,
        config,
        cache,
        metrics: parking_lot::Mutex::new(SpeculativeMetrics::default()),
        sweeper: Some(sweeper),
        task_supervisor: supervisor,
    }
}
#[must_use]
pub fn mode(&self) -> SpeculationMode {
self.config.mode
}
/// Returns `true` unless the configured mode is [`SpeculationMode::Off`].
#[must_use]
pub fn is_active(&self) -> bool {
    !matches!(self.config.mode, SpeculationMode::Off)
}
#[must_use]
pub fn confidence_threshold(&self) -> f32 {
self.config.confidence_threshold
}
/// Attempts to start executing `prediction` speculatively.
///
/// Returns `true` when a speculative task was spawned and its handle stored in the
/// cache, and `false` when the prediction was skipped. A prediction is skipped when:
///
/// * `trust_level` is anything other than [`SkillTrustLevel::Trusted`];
/// * the executor reports the tool as not speculatable;
/// * the executor reports that the resulting call requires confirmation (this case
///   also increments `skipped_confirmation`).
///
/// The spawned task races the tool execution against a cancellation token; when the
/// token fires first the task resolves to a `ToolError::Execution` carrying the
/// message `"speculative cancelled"`. The stored handle expires `config.ttl_seconds`
/// after dispatch.
pub fn try_dispatch(&self, prediction: &Prediction, trust_level: SkillTrustLevel) -> bool {
    if trust_level != SkillTrustLevel::Trusted {
        return false;
    }
    let tool_id = &prediction.tool_id;
    if !self.executor.is_tool_speculatable_erased(tool_id.as_ref()) {
        return false;
    }
    let call = prediction.to_tool_call(format!("spec-{}", uuid::Uuid::new_v4()));
    if self.executor.requires_confirmation_erased(&call) {
        // The guard is a temporary, so the lock is released before logging.
        self.metrics.lock().skipped_confirmation += 1;
        debug!(tool_id = %tool_id, "speculative skip: requires_confirmation");
        return false;
    }

    // Hash only once the prediction has passed every rejection check, and do it
    // before `call` is moved into the task so no clone of the call is needed.
    let key = HandleKey {
        tool_id: tool_id.clone(),
        args_hash: hash_args(&call.params),
        context_hash: hash_context(call.context.as_ref()),
    };

    let exec = Arc::clone(&self.executor);
    let cancel = CancellationToken::new();
    let cancel_child = cancel.child_token();
    // Single definition of the speculative task, shared by both spawn paths.
    let task = move || async move {
        tokio::select! {
            result = exec.execute_tool_call_erased(&call) => result,
            () = cancel_child.cancelled() => {
                Err(ToolError::Execution(std::io::Error::other("speculative cancelled")))
            }
        }
    };

    let task_name: Arc<str> = Arc::from(format!(
        "agent.speculative.dispatch.{}",
        uuid::Uuid::new_v4()
    ));
    let join = match &self.task_supervisor {
        Some(sup) => sup.spawn_oneshot(task_name, task),
        None => {
            // NOTE(review): without a configured supervisor a throwaway one is created
            // per dispatch and dropped right after spawning — confirm that dropping a
            // `TaskSupervisor` does not affect tasks it has already spawned.
            let fallback = Arc::new(zeph_common::TaskSupervisor::new(CancellationToken::new()));
            fallback.spawn_oneshot(task_name, task)
        }
    };

    let handle = SpeculativeHandle {
        key,
        join,
        cancel,
        ttl_deadline: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
        started_at: std::time::Instant::now(),
    };
    debug!(tool_id = %tool_id, confidence = prediction.confidence, "speculative dispatch");
    self.cache.insert(handle);
    true
}
/// Looks for a speculative handle matching `call` and, if one exists, commits it.
///
/// The match is made on the call's tool id together with the hashes of its
/// parameters and context. Returns `None` when the cache holds no matching handle;
/// otherwise increments `committed` and returns `Some` with the result of awaiting
/// the handle's `commit`.
pub async fn try_commit(
    &self,
    call: &ToolCall,
) -> Option<Result<Option<ToolOutput>, ToolError>> {
    let args_hash = hash_args(&call.params);
    let context_hash = hash_context(call.context.as_ref());
    let handle = self
        .cache
        .take_match(&call.tool_id, &args_hash, &context_hash)?;

    // The guard is a temporary, so the lock is never held across the await below.
    self.metrics.lock().committed += 1;
    debug!(tool_id = %call.tool_id, "speculative commit");

    let outcome = handle.commit().await;
    Some(outcome)
}
/// Cancels the speculative work cached for `tool_id` and bumps the `cancelled` counter.
///
/// NOTE(review): the counter is incremented once per call, independent of how many
/// handles (if any) were actually cancelled — confirm this is the intended meaning.
pub fn cancel_for(&self, tool_id: &zeph_common::ToolName) {
    debug!(tool_id = %tool_id, "speculative cancel for tool");
    self.cache.cancel_by_tool_id(tool_id);
    self.metrics.lock().cancelled += 1;
}
/// Ends the current turn: cancels every cached speculative handle, then returns the
/// metrics accumulated so far and resets the stored metrics to their defaults.
pub fn end_turn(&self) -> SpeculativeMetrics {
    self.cache.cancel_all();
    let mut current = self.metrics.lock();
    std::mem::take(&mut *current)
}
#[must_use]
pub fn metrics_snapshot(&self) -> SpeculativeMetrics {
self.metrics.lock().clone()
}
}
impl Drop for SpeculationEngine {
    /// Cancels all cached speculative handles and aborts the sweeper task, if one is
    /// still held.
    fn drop(&mut self) {
        self.cache.cancel_all();
        let Some(sweeper) = self.sweeper.take() else {
            return;
        };
        sweeper.abort();
    }
}
#[cfg(test)]
mod tests {
use super::*;
use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};
/// Test executor: every tool call succeeds and every tool is reported as speculatable.
struct AlwaysOkExecutor;
impl ToolExecutor for AlwaysOkExecutor {
    /// Always succeeds without producing any output.
    async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
        Ok(None)
    }

    /// Always succeeds with a fixed output attributed to the tool `"test"`.
    async fn execute_tool_call(
        &self,
        _call: &ToolCall,
    ) -> Result<Option<ToolOutput>, ToolError> {
        let output = ToolOutput {
            tool_name: zeph_common::ToolName::new("test"),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        };
        Ok(Some(output))
    }

    /// Every tool is considered safe to run speculatively in these tests.
    fn is_tool_speculatable(&self, _: &str) -> bool {
        true
    }
}
/// A trusted, speculatable prediction is dispatched, and a call built from the same
/// prediction then finds and commits the speculative handle.
#[tokio::test]
async fn dispatch_and_commit_succeeds() {
    let exec: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let config = SpeculativeConfig {
        mode: SpeculationMode::Decoding,
        ..Default::default()
    };
    let engine = SpeculationEngine::new(exec, config);
    let pred = Prediction {
        tool_id: zeph_common::ToolName::new("test"),
        args: serde_json::Map::new(),
        confidence: 0.9,
        source: prediction::PredictionSource::StreamPartial,
    };

    let dispatched = engine.try_dispatch(&pred, SkillTrustLevel::Trusted);
    assert!(
        dispatched,
        "trusted speculatable prediction must be dispatched"
    );

    // The call id is not part of the cache key, so a call built from the same
    // prediction must match the handle stored by `try_dispatch`.
    let call = pred.to_tool_call(String::from("commit-test"));
    let outcome = engine.try_commit(&call).await;
    assert!(
        outcome.is_some(),
        "matching call must commit the speculative handle"
    );
    assert_eq!(
        engine.metrics_snapshot().committed,
        1,
        "commit must be counted in metrics"
    );
}
/// A prediction coming from a quarantined skill must never be dispatched.
#[tokio::test]
async fn untrusted_skill_skips_dispatch() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let engine = SpeculationEngine::new(
        executor,
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        },
    );
    let prediction = Prediction {
        tool_id: zeph_common::ToolName::new("test"),
        args: serde_json::Map::new(),
        confidence: 0.9,
        source: prediction::PredictionSource::StreamPartial,
    };

    assert!(
        !engine.try_dispatch(&prediction, SkillTrustLevel::Quarantined),
        "untrusted skill must not dispatch speculatively"
    );
}
/// After `cancel_for` the cache no longer holds the handle dispatched for that tool.
#[tokio::test]
async fn cancel_for_removes_handle() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let engine = SpeculationEngine::new(
        executor,
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        },
    );
    let tool = zeph_common::ToolName::new("test");
    let prediction = Prediction {
        tool_id: zeph_common::ToolName::new("test"),
        args: serde_json::Map::new(),
        confidence: 0.9,
        source: prediction::PredictionSource::StreamPartial,
    };

    let _ = engine.try_dispatch(&prediction, SkillTrustLevel::Trusted);
    engine.cancel_for(&tool);

    assert!(
        engine.cache.is_empty(),
        "cancel_for must remove handle from cache"
    );
}
/// `end_turn` empties the cache and leaves the stored metrics at their defaults.
#[tokio::test]
async fn end_turn_cancels_handles_and_resets_metrics() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let engine = SpeculationEngine::new(
        executor,
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        },
    );
    let prediction = Prediction {
        tool_id: zeph_common::ToolName::new("test"),
        args: serde_json::Map::new(),
        confidence: 0.9,
        source: prediction::PredictionSource::StreamPartial,
    };

    let _ = engine.try_dispatch(&prediction, SkillTrustLevel::Trusted);
    assert!(
        !engine.cache.is_empty(),
        "precondition: handle must be in cache before end_turn"
    );

    // The returned metrics are irrelevant here; only the post-state is checked.
    let _ = engine.end_turn();
    assert!(
        engine.cache.is_empty(),
        "end_turn must cancel all in-flight handles"
    );

    let after = engine.metrics_snapshot();
    assert_eq!(after.committed, 0, "metrics must reset after end_turn");
    assert_eq!(after.cancelled, 0, "metrics must reset after end_turn");
}
/// `is_active` is `false` for `SpeculationMode::Off` and `true` for `Decoding`.
#[tokio::test]
async fn is_active_reflects_mode() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    // Builds an engine that differs from the default configuration only in its mode.
    let engine_with = |mode| {
        SpeculationEngine::new(
            Arc::clone(&executor),
            SpeculativeConfig {
                mode,
                ..Default::default()
            },
        )
    };

    let engine_off = engine_with(SpeculationMode::Off);
    assert!(!engine_off.is_active(), "mode=Off means is_active()=false");

    let engine_on = engine_with(SpeculationMode::Decoding);
    assert!(
        engine_on.is_active(),
        "mode=Decoding means is_active()=true"
    );
}
/// Dropping an engine must not abort unrelated tasks, and a freshly constructed
/// engine must hold a sweeper handle.
#[tokio::test]
async fn sweeper_aborted_on_drop() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let first = SpeculationEngine::new(
        Arc::clone(&executor),
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        },
    );

    // Long-lived bystander task; it signals once it is running.
    let (started_tx, started_rx) = tokio::sync::oneshot::channel::<()>();
    let bystander = tokio::spawn(async move {
        let _ = started_tx.send(());
        tokio::time::sleep(Duration::from_hours(1)).await;
    });
    let _ = started_rx.await;

    drop(first);
    tokio::task::yield_now().await;
    assert!(!bystander.is_finished(), "unrelated task must not be aborted");
    bystander.abort();

    let second = SpeculationEngine::new(executor, SpeculativeConfig::default());
    assert!(
        second.sweeper.is_some(),
        "sweeper handle must be Some after construction"
    );
    drop(second);
}
/// An engine built with a supervisor holds a sweeper handle and can be dropped cleanly.
#[tokio::test]
async fn sweeper_supervised_aborted_on_drop() {
    let executor: Arc<dyn ErasedToolExecutor> = Arc::new(AlwaysOkExecutor);
    let supervisor = Arc::new(zeph_common::TaskSupervisor::new(
        tokio_util::sync::CancellationToken::new(),
    ));
    let engine = SpeculationEngine::new_with_supervisor(
        Arc::clone(&executor),
        SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        },
        Some(supervisor),
    );

    assert!(
        engine.sweeper.is_some(),
        "sweeper handle must be Some with supervisor"
    );
    drop(engine);
}
}