use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::StreamExt;
use tracing::debug;
use zeph_llm::provider::{ChatResponse, ThinkingBlock, ToolUseRequest};
use zeph_llm::sse::{ToolSseEvent, ToolSseStream};
use super::SpeculationEngine;
use super::partial_json::{PartialJsonParser, PrefixState};
use super::prediction::{Prediction, PredictionSource};
/// Consumes a tool-call SSE stream and, while doing so, offers partially
/// streamed tool inputs to a [`SpeculationEngine`] for early dispatch.
///
/// Build one with [`SpeculativeStreamDrainer::new`] and run it to completion
/// with [`SpeculativeStreamDrainer::drive`], which consumes the drainer.
pub struct SpeculativeStreamDrainer {
    /// Source of SSE events; read until it yields `None` or an error event.
    stream: ToolSseStream,
    /// Engine that is asked to speculatively dispatch predicted tool calls.
    engine: Arc<SpeculationEngine>,
    /// Value copied into the `confidence` field of every prediction built
    /// from streamed input.
    confidence_threshold: f32,
    /// Incremental JSON parsers, one per tool block, keyed by the block's
    /// index within the stream.
    parsers: HashMap<usize, PartialJsonParser>,
}
impl SpeculativeStreamDrainer {
    /// Creates a drainer that reads `stream` and offers predictions to `engine`.
    ///
    /// `confidence_threshold` is the confidence attached to every prediction
    /// derived from partially streamed tool input. The parser table starts empty.
    #[must_use]
    pub fn new(
        stream: ToolSseStream,
        engine: Arc<SpeculationEngine>,
        confidence_threshold: f32,
    ) -> Self {
        Self {
            stream,
            engine,
            confidence_threshold,
            parsers: HashMap::new(),
        }
    }

    /// Drains the stream to completion and assembles the final response.
    ///
    /// While draining, each `InputJsonDelta` is fed to a per-index partial JSON
    /// parser. As soon as a tool block's input is a valid prefix with no
    /// required field missing, a [`Prediction`] is offered to the engine.
    /// At most one speculative dispatch is recorded per tool index; once it has
    /// fired, later deltas for that index are skipped without being parsed.
    ///
    /// The result is [`ChatResponse::ToolUse`] when at least one
    /// `ToolCallComplete` event was seen, otherwise [`ChatResponse::Text`] with
    /// the concatenated content chunks (possibly empty). Completed thinking
    /// blocks are only carried by the `ToolUse` variant.
    ///
    /// # Errors
    ///
    /// Returns the error carried by the first `ToolSseEvent::Error` event. Any
    /// events after it are not read.
    pub async fn drive(mut self) -> Result<ChatResponse, zeph_llm::LlmError> {
        let mut tool_calls: Vec<ToolUseRequest> = Vec::new();
        let mut thinking_blocks: Vec<ThinkingBlock> = Vec::new();
        let mut text_buf = String::new();
        // index -> (LLM-assigned call id, tool name)
        let mut tool_meta: HashMap<usize, (String, String)> = HashMap::new();
        // Tool indices for which a speculative dispatch has already been accepted.
        let mut dispatched: std::collections::HashSet<usize> = std::collections::HashSet::new();

        while let Some(event) = self.stream.next().await {
            match event {
                ToolSseEvent::ToolBlockStart { index, id, name } => {
                    tool_meta.insert(index, (id, name));
                }
                ToolSseEvent::InputJsonDelta { index, delta } => {
                    // The parse result is only ever used to decide on a dispatch,
                    // so once this index has dispatched there is nothing to gain
                    // from parsing further deltas.
                    if dispatched.contains(&index) {
                        continue;
                    }
                    let parser = self.parsers.entry(index).or_default();
                    let PrefixState::ValidPrefix {
                        known_leaves,
                        missing_required,
                    } = parser.push(&delta)
                    else {
                        continue;
                    };
                    if !missing_required.is_empty() {
                        continue;
                    }
                    // Without a known tool name there is nothing to predict.
                    let Some((_llm_id, name)) = tool_meta.get(&index) else {
                        continue;
                    };
                    let pred = Prediction {
                        tool_id: name.as_str().into(),
                        args: known_leaves,
                        confidence: self.confidence_threshold,
                        source: PredictionSource::StreamPartial,
                    };
                    if self
                        .engine
                        .try_dispatch(&pred, zeph_common::SkillTrustLevel::Trusted)
                    {
                        dispatched.insert(index);
                        debug!(tool = %name, index, "speculative dispatch fired from SSE delta");
                    }
                }
                ToolSseEvent::ToolCallComplete {
                    index,
                    id,
                    name,
                    full_json,
                } => {
                    let input = Self::parse_tool_input(&name, &full_json);
                    tool_meta.insert(index, (id.clone(), name.clone()));
                    tool_calls.push(ToolUseRequest {
                        id,
                        name: name.into(),
                        input,
                    });
                }
                ToolSseEvent::ThinkingBlockDone(block) => {
                    thinking_blocks.push(block);
                }
                // Incremental thinking text is intentionally ignored; only
                // completed blocks are collected.
                ToolSseEvent::ThinkingChunk(_) => {}
                ToolSseEvent::ContentChunk(text) => {
                    text_buf.push_str(&text);
                }
                ToolSseEvent::Compaction(_summary) => {
                    tracing::debug!(
                        "compaction summary received during tool stream (not yet surfaced)"
                    );
                }
                ToolSseEvent::Error(e) => {
                    return Err(e);
                }
            }
        }

        if tool_calls.is_empty() {
            return Ok(ChatResponse::Text(text_buf));
        }
        // An empty buffer is reported as "no text" rather than `Some("")`.
        let text = (!text_buf.is_empty()).then_some(text_buf);
        Ok(ChatResponse::ToolUse {
            text,
            tool_calls,
            thinking_blocks,
        })
    }

    /// Parses the final JSON input of a completed tool call.
    ///
    /// Parsing is best-effort: any payload that is not valid JSON yields an
    /// empty JSON object so the tool call is still surfaced to the caller.
    /// A blank payload is treated as "no arguments" and is not reported; a
    /// non-blank payload that fails to parse is logged as a warning (tool name
    /// and parser error only, never the payload itself).
    fn parse_tool_input(tool_name: &str, full_json: &str) -> serde_json::Value {
        let empty_object = || serde_json::Value::Object(serde_json::Map::new());
        if full_json.trim().is_empty() {
            return empty_object();
        }
        serde_json::from_str(full_json).unwrap_or_else(|err| {
            tracing::warn!(
                tool = %tool_name,
                error = %err,
                "tool call input is not valid JSON; falling back to an empty object"
            );
            empty_object()
        })
    }
}
/// Runs `engine.try_commit(call)` under a fixed two-second deadline.
///
/// If the commit finishes in time, its result is returned unchanged. If the
/// deadline elapses first, the timeout is logged at debug level and `None` is
/// returned.
pub async fn try_commit_with_timeout(
    engine: &SpeculationEngine,
    call: &zeph_tools::ToolCall,
) -> Option<Result<Option<zeph_tools::ToolOutput>, zeph_tools::ToolError>> {
    const COMMIT_TIMEOUT: Duration = Duration::from_secs(2);

    let commit = engine.try_commit(call);
    tokio::time::timeout(COMMIT_TIMEOUT, commit)
        .await
        .unwrap_or_else(|_elapsed| {
            debug!(tool_id = %call.tool_id, "speculative try_commit timed out after 2s");
            None
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::tools::SpeculativeConfig;
    use zeph_tools::{ToolCall, ToolError, ToolExecutor, ToolOutput};

    /// Executor that does nothing and reports every tool as non-speculatable.
    /// Shared by all tests that do not care about dispatch.
    struct NullExec;

    impl ToolExecutor for NullExec {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            _: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        fn is_tool_speculatable(&self, _: &str) -> bool {
            false
        }
    }

    /// Builds an engine backed by [`NullExec`] with the default configuration.
    fn null_engine() -> Arc<SpeculationEngine> {
        Arc::new(SpeculationEngine::new(
            Arc::new(NullExec),
            SpeculativeConfig::default(),
        ))
    }

    /// A single completed `bash` tool call at index 0 carrying `full_json`.
    fn bash_call_complete(full_json: &str) -> ToolSseEvent {
        ToolSseEvent::ToolCallComplete {
            index: 0,
            id: "toolu_01".into(),
            name: "bash".into(),
            full_json: full_json.into(),
        }
    }

    #[tokio::test]
    async fn drainer_new_has_empty_parsers() {
        let drainer =
            SpeculativeStreamDrainer::new(Box::pin(tokio_stream::empty()), null_engine(), 0.8);
        assert!(drainer.parsers.is_empty());
    }

    #[tokio::test]
    async fn tool_block_start_enables_incremental_dispatch() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use zeph_config::tools::SpeculationMode;

        /// Executor that counts how often the engine asks whether a tool is
        /// speculatable; used as the observable signal of a dispatch attempt.
        struct SpyExec {
            count: Arc<AtomicUsize>,
        }

        impl ToolExecutor for SpyExec {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                _: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            fn is_tool_speculatable(&self, _: &str) -> bool {
                self.count.fetch_add(1, Ordering::Relaxed);
                true
            }
        }

        let dispatch_count = Arc::new(AtomicUsize::new(0));
        let config = SpeculativeConfig {
            mode: SpeculationMode::Decoding,
            ..Default::default()
        };
        let engine = Arc::new(SpeculationEngine::new(
            Arc::new(SpyExec {
                count: Arc::clone(&dispatch_count),
            }),
            config,
        ));

        let events = vec![
            ToolSseEvent::ToolBlockStart {
                index: 0,
                id: "toolu_01".into(),
                name: "bash".into(),
            },
            ToolSseEvent::InputJsonDelta {
                index: 0,
                delta: r#"{"command":"ls"}"#.into(),
            },
            bash_call_complete(r#"{"command":"ls"}"#),
        ];
        let drainer =
            SpeculativeStreamDrainer::new(Box::pin(tokio_stream::iter(events)), engine, 0.0);
        let result = drainer.drive().await.unwrap();

        assert!(matches!(result, ChatResponse::ToolUse { .. }));
        assert!(
            dispatch_count.load(Ordering::Relaxed) > 0,
            "dispatch should have been attempted"
        );
    }

    #[tokio::test]
    async fn drive_empty_stream_returns_text_empty() {
        let drainer =
            SpeculativeStreamDrainer::new(Box::pin(tokio_stream::empty()), null_engine(), 0.8);
        let result = drainer.drive().await.unwrap();
        assert!(matches!(result, ChatResponse::Text(s) if s.is_empty()));
    }

    #[tokio::test]
    async fn drive_content_chunk_returns_text() {
        let events = vec![ToolSseEvent::ContentChunk("Hello world".into())];
        let drainer = SpeculativeStreamDrainer::new(
            Box::pin(tokio_stream::iter(events)),
            null_engine(),
            0.8,
        );
        let result = drainer.drive().await.unwrap();
        assert!(matches!(result, ChatResponse::Text(s) if s == "Hello world"));
    }

    #[tokio::test]
    async fn drive_tool_call_complete_returns_tool_use() {
        let events = vec![bash_call_complete(r#"{"command":"ls"}"#)];
        let drainer = SpeculativeStreamDrainer::new(
            Box::pin(tokio_stream::iter(events)),
            null_engine(),
            0.8,
        );
        let result = drainer.drive().await.unwrap();
        match result {
            ChatResponse::ToolUse { tool_calls, .. } => {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "toolu_01");
                assert_eq!(tool_calls[0].name, "bash");
            }
            other @ ChatResponse::Text(_) => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn drive_malformed_tool_json_falls_back_to_empty_object() {
        // Unparseable input must not drop the tool call; it is surfaced with
        // an empty JSON object as its input.
        let events = vec![bash_call_complete("{not json")];
        let drainer = SpeculativeStreamDrainer::new(
            Box::pin(tokio_stream::iter(events)),
            null_engine(),
            0.8,
        );
        let result = drainer.drive().await.unwrap();
        match result {
            ChatResponse::ToolUse { tool_calls, .. } => {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(
                    tool_calls[0].input,
                    serde_json::Value::Object(serde_json::Map::new())
                );
            }
            other @ ChatResponse::Text(_) => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn drive_text_and_tool_call_keeps_text() {
        let events = vec![
            ToolSseEvent::ContentChunk("Let me check".into()),
            bash_call_complete(r#"{"command":"ls"}"#),
        ];
        let drainer = SpeculativeStreamDrainer::new(
            Box::pin(tokio_stream::iter(events)),
            null_engine(),
            0.8,
        );
        let result = drainer.drive().await.unwrap();
        match result {
            ChatResponse::ToolUse {
                text, tool_calls, ..
            } => {
                assert_eq!(text.as_deref(), Some("Let me check"));
                assert_eq!(tool_calls.len(), 1);
            }
            other @ ChatResponse::Text(_) => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn drive_error_event_propagates() {
        let events = vec![ToolSseEvent::Error(zeph_llm::LlmError::SseParse(
            "boom".into(),
        ))];
        let drainer = SpeculativeStreamDrainer::new(
            Box::pin(tokio_stream::iter(events)),
            null_engine(),
            0.8,
        );
        let result = drainer.drive().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("boom"));
    }
}