use super::content::ContentBlock;
use super::inference::{InferenceOverride, StopReason, StreamResult, TokenUsage};
use super::message::Message;
use super::tool::ToolDescriptor;
use async_trait::async_trait;
use thiserror::Error;

/// A fully assembled request for a single LLM inference call.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Identifier of the model as known to the upstream provider.
    pub upstream_model: String,
    /// Conversation messages sent to the model.
    pub messages: Vec<Message>,
    /// Tools the model is allowed to call.
    pub tools: Vec<ToolDescriptor>,
    /// System prompt content blocks.
    pub system: Vec<ContentBlock>,
    /// Optional per-request inference parameter overrides.
    pub overrides: Option<InferenceOverride>,
    /// Whether provider-side prompt caching should be requested.
    pub enable_prompt_cache: bool,
}

/// Errors that can occur while executing an inference request.
#[derive(Debug, Error)]
pub enum InferenceExecutionError {
#[error("provider error: {0}")]
Provider(String),
#[error("rate limited: {0}")]
RateLimited(String),
#[error("timeout: {0}")]
Timeout(String),
#[error("cancelled")]
Cancelled,
}

/// An incremental event emitted while streaming a model response.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of assistant-visible text.
    TextDelta(String),
    /// A chunk of reasoning ("thinking") text.
    ReasoningDelta(String),
    /// A tool call has started; its arguments arrive as deltas.
    ToolCallStart { id: String, name: String },
    /// An incremental fragment of a tool call's JSON arguments.
    ToolCallDelta { id: String, args_delta: String },
    /// The current content block has finished.
    ContentBlockStop,
    /// Token usage reported by the provider.
    Usage(TokenUsage),
    /// The model stopped, with the given reason.
    Stop(StopReason),
}

/// A pinned, boxed stream that yields [`LlmStreamEvent`]s or execution errors.
pub type InferenceStream = std::pin::Pin<
Box<dyn futures::Stream<Item = Result<LlmStreamEvent, InferenceExecutionError>> + Send>,
>;

/// Executes inference requests against a concrete LLM backend.
#[async_trait]
pub trait LlmExecutor: Send + Sync {
    /// Run the request to completion and return the collected result.
    async fn execute(
&self,
request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError>;
    /// Run the request as a stream of incremental events.
    ///
    /// The default implementation falls back to [`execute`](Self::execute) and
    /// replays the collected result as a synthetic event stream.
    fn execute_stream(
&self,
request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
        Box::pin(async move {
            // Collect the full result, then replay it as ordered events.
            let result = self.execute(request).await?;
let events = collected_to_stream_events(result);
Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
})
}
    /// A short, stable identifier for this executor.
    fn name(&self) -> &str;
}

/// Flattens a fully collected [`StreamResult`] into the ordered events a real
/// streaming backend would have produced: content deltas first, then tool
/// calls, then usage and the stop reason. Note that the synthetic stream never
/// emits [`LlmStreamEvent::ContentBlockStop`] boundaries.
pub fn collected_to_stream_events(
    result: StreamResult,
) -> Vec<Result<LlmStreamEvent, InferenceExecutionError>> {
    let mut events = Vec::new();
for block in &result.content {
match block {
ContentBlock::Text { text } if !text.is_empty() => {
events.push(Ok(LlmStreamEvent::TextDelta(text.clone())));
}
ContentBlock::Thinking { thinking } if !thinking.is_empty() => {
events.push(Ok(LlmStreamEvent::ReasoningDelta(thinking.clone())));
}
_ => {}
}
}
for call in &result.tool_calls {
events.push(Ok(LlmStreamEvent::ToolCallStart {
id: call.id.clone(),
name: call.name.clone(),
}));
        // Serializing JSON arguments should not fail; fall back to an empty
        // delta (skipped below) rather than aborting the conversion.
        let args = serde_json::to_string(&call.arguments).unwrap_or_default();
if !args.is_empty() {
events.push(Ok(LlmStreamEvent::ToolCallDelta {
id: call.id.clone(),
args_delta: args,
}));
}
}
if let Some(usage) = result.usage {
events.push(Ok(LlmStreamEvent::Usage(usage)));
}
if let Some(stop) = result.stop_reason {
events.push(Ok(LlmStreamEvent::Stop(stop)));
}
events
}

/// How the tool calls from a single assistant turn are executed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time, in order.
    #[default]
    Sequential,
    /// Execute tool calls in parallel, approving them as a single batch.
    ParallelBatchApproval,
    /// Execute tool calls in parallel, streaming results as they complete.
    ParallelStreaming,
}

#[cfg(test)]
mod tests {
use super::*;
use crate::contract::inference::{StopReason, TokenUsage};
use crate::contract::message::ToolCall;
use crate::contract::tool::ToolDescriptor;
use serde_json::json;
    /// Canned executor used to exercise the `LlmExecutor` contract.
    struct MockLlm {
response_text: String,
tool_calls: Vec<ToolCall>,
}
#[async_trait]
impl LlmExecutor for MockLlm {
async fn execute(
&self,
_request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Ok(StreamResult {
content: if self.response_text.is_empty() {
vec![]
} else {
vec![ContentBlock::text(self.response_text.clone())]
},
tool_calls: self.tool_calls.clone(),
usage: Some(TokenUsage {
prompt_tokens: Some(100),
completion_tokens: Some(50),
total_tokens: Some(150),
..Default::default()
}),
stop_reason: if self.tool_calls.is_empty() {
Some(StopReason::EndTurn)
} else {
Some(StopReason::ToolUse)
},
has_incomplete_tool_calls: false,
})
}
fn name(&self) -> &str {
"mock"
}
}
#[tokio::test]
async fn mock_llm_returns_text() {
let llm = MockLlm {
response_text: "Hello!".into(),
tool_calls: vec![],
};
let request = InferenceRequest {
upstream_model: "test-model".into(),
messages: vec![Message::user("hi")],
tools: vec![],
system: vec![],
overrides: None,
enable_prompt_cache: false,
};
let result = llm.execute(request).await.unwrap();
assert_eq!(result.text(), "Hello!");
assert!(!result.needs_tools());
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
}
#[tokio::test]
async fn mock_llm_returns_tool_calls() {
let llm = MockLlm {
response_text: String::new(),
tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
};
let request = InferenceRequest {
upstream_model: "test-model".into(),
messages: vec![Message::user("search for rust")],
tools: vec![ToolDescriptor::new("search", "search", "Web search")],
system: vec![ContentBlock::text("You are helpful.")],
overrides: None,
enable_prompt_cache: false,
};
let result = llm.execute(request).await.unwrap();
assert!(result.needs_tools());
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "search");
assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
}
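    // A check of the blanket `execute_stream` contract: the collected result
    // is replayed as ordered synthetic events (text delta, then usage, then
    // stop). Uses only names already defined in this module.
    #[tokio::test]
    async fn default_execute_stream_replays_collected_result() {
        use futures::StreamExt;
        let llm = MockLlm {
            response_text: "Hello!".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };
        let stream = llm.execute_stream(request).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(matches!(events[0], Ok(LlmStreamEvent::TextDelta(ref t)) if t == "Hello!"));
        assert!(matches!(events[1], Ok(LlmStreamEvent::Usage(_))));
        assert!(matches!(events[2], Ok(LlmStreamEvent::Stop(StopReason::EndTurn))));
    }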
#[tokio::test]
async fn mock_llm_with_overrides() {
let llm = MockLlm {
response_text: "ok".into(),
tool_calls: vec![],
};
let request = InferenceRequest {
upstream_model: "base-model".into(),
messages: vec![],
tools: vec![],
system: vec![],
overrides: Some(InferenceOverride {
temperature: Some(0.7),
..Default::default()
}),
enable_prompt_cache: false,
};
let result = llm.execute(request).await.unwrap();
assert_eq!(result.text(), "ok");
}
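    // Exercises `collected_to_stream_events` in isolation: content blocks
    // become deltas, each tool call becomes a start plus one arguments delta,
    // and the stop reason trails at the end.
    #[test]
    fn collected_to_stream_events_orders_events() {
        let result = StreamResult {
            content: vec![ContentBlock::text("searching")],
            tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
            usage: None,
            stop_reason: Some(StopReason::ToolUse),
            has_incomplete_tool_calls: false,
        };
        let events = collected_to_stream_events(result);
        assert_eq!(events.len(), 4);
        assert!(matches!(events[0], Ok(LlmStreamEvent::TextDelta(_))));
        assert!(matches!(
            events[1],
            Ok(LlmStreamEvent::ToolCallStart { ref name, .. }) if name == "search"
        ));
        assert!(matches!(
            events[2],
            Ok(LlmStreamEvent::ToolCallDelta { ref args_delta, .. }) if args_delta.contains("rust")
        ));
        assert!(matches!(events[3], Ok(LlmStreamEvent::Stop(StopReason::ToolUse))));
    }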
#[test]
fn llm_executor_name_is_exposed() {
let llm = MockLlm {
response_text: String::new(),
tool_calls: vec![],
};
assert_eq!(llm.name(), "mock");
}
#[test]
fn tool_execution_mode_default_is_sequential() {
assert_eq!(ToolExecutionMode::default(), ToolExecutionMode::Sequential);
}
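    // Pins down the `!is_empty()` guards in `collected_to_stream_events`:
    // empty text or thinking blocks must not produce delta events.
    #[test]
    fn collected_to_stream_events_skips_empty_blocks() {
        let result = StreamResult {
            content: vec![ContentBlock::text("")],
            tool_calls: vec![],
            usage: None,
            stop_reason: None,
            has_incomplete_tool_calls: false,
        };
        assert!(collected_to_stream_events(result).is_empty());
    }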
#[test]
fn inference_execution_error_display_strings_are_stable() {
assert_eq!(
InferenceExecutionError::Provider("provider failed".into()).to_string(),
"provider error: provider failed"
);
assert_eq!(
InferenceExecutionError::RateLimited("too many requests".into()).to_string(),
"rate limited: too many requests"
);
assert_eq!(
InferenceExecutionError::Timeout("slow backend".into()).to_string(),
"timeout: slow backend"
);
assert_eq!(InferenceExecutionError::Cancelled.to_string(), "cancelled");
}
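    // A failing executor, sketched here to confirm that an error from
    // `execute` surfaces through the blanket `execute_stream` (via its `?`).
    // `FailingLlm` is a test-only helper, not part of the contract.
    struct FailingLlm;
    #[async_trait]
    impl LlmExecutor for FailingLlm {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Err(InferenceExecutionError::RateLimited("too many requests".into()))
        }
        fn name(&self) -> &str {
            "failing"
        }
    }
    #[tokio::test]
    async fn default_execute_stream_propagates_execute_errors() {
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };
        // The stream type is not `Debug`, so match instead of `unwrap_err`.
        let err = match FailingLlm.execute_stream(request).await {
            Ok(_) => panic!("expected an error"),
            Err(e) => e,
        };
        assert!(matches!(err, InferenceExecutionError::RateLimited(_)));
    }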
}