use crate::context::{ContextConfig, compact_messages};
use crate::error::CoreError;
use crate::lifecycle::LifecycleHook;
use crate::protocol::{
AgentEvent, ChatMessage, ModelDirective, ModelStopReason, ModelTurn, RunStopReason, TokenUsage,
ToolCall, ToolDefinition, ToolResult, ToolResultSummary,
};
use crate::state::AppState;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// An incremental chunk of model output handed to the `on_delta` callback of
/// [`Provider::complete_streaming`].
///
/// The payload borrows the provider's buffer, so a handler that wants to keep
/// the text must copy it (e.g. with `to_owned()`). Because both variants only
/// hold a `&str`, the event is `Copy` and can be forwarded to several sinks
/// without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamEvent<'a> {
    /// A fragment of regular text output.
    Text(&'a str),
    /// A fragment of output the provider labels as reasoning.
    Reasoning(&'a str),
}
/// Shared slot holding the currently active [`Provider`].
///
/// The outer `Arc` shares the lock between owners. The inner `Arc` lets a
/// reader clone the current provider out and release the lock right away.
pub type SwappableProviderHandle = Arc<std::sync::RwLock<Arc<dyn Provider>>>;

/// Constructs [`Provider`] instances from a textual specification.
///
/// The format of `spec` is defined by each implementation.
pub trait ProviderFactory: Send + Sync {
    /// Builds a provider described by `spec`.
    ///
    /// # Errors
    /// Returns a [`CoreError`] when `spec` cannot be turned into a provider.
    fn build(&self, spec: &str) -> Result<Arc<dyn Provider>, CoreError>;

    /// Identifiers of the providers this factory can build.
    fn available_providers(&self) -> Vec<String>;
}
/// Lets an approval gate publish [`AgentEvent`]s through a pluggable callback.
pub trait ApprovalGateHook: Send + Sync {
    /// Installs `handler` as the receiver of gate events.
    fn set_event_handler(&self, handler: Arc<dyn Fn(AgentEvent) + Send + Sync>);

    /// Removes the previously installed event handler.
    fn clear_event_handler(&self);
}

/// Settles approval requests that are waiting for a decision.
pub trait ApprovalResolver: Send + Sync {
    /// Applies `decision` (with an optional free-form `reason`) to the approval
    /// identified by `approval_id`.
    ///
    /// NOTE(review): the meaning of the returned flag is implementation-defined;
    /// presumably `true` when a pending approval was found and resolved.
    fn resolve_approval(&self, approval_id: &str, decision: &str, reason: Option<String>) -> bool;

    /// Ids of approvals that are still awaiting a decision.
    fn pending_approval_ids(&self) -> Vec<String>;
}
/// Everything a [`Provider`] receives to produce a single model turn.
#[derive(Debug, Clone)]
pub struct ProviderRequest {
    /// Identifier of the run this request belongs to.
    pub run_id: String,
    /// Identifier of the owning session.
    pub session_id: String,
    /// Iteration number within the run (the orchestrator counts from 1).
    pub iteration: u32,
    /// Conversation history to send to the model.
    pub messages: Vec<ChatMessage>,
    /// Definitions of the tools the model may call.
    pub tools: Vec<ToolDefinition>,
    /// Copy of the application state at the time of the request.
    pub state: AppState,
}
/// A model backend that turns a [`ProviderRequest`] into a [`ModelTurn`].
pub trait Provider: Send + Sync {
    /// Identifier of this provider (reported in the `RunStarted` event).
    fn name(&self) -> &str;

    /// Performs a synchronous, non-streaming completion.
    ///
    /// # Errors
    /// Returns a [`CoreError`] when the backend call fails.
    fn complete(&self, request: &ProviderRequest) -> Result<ModelTurn, CoreError>;

    /// Whether this provider emits incremental output from
    /// [`Provider::complete_streaming`]. Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Completion that may report partial output through `_on_delta`.
    ///
    /// The default ignores the callback and simply delegates to
    /// [`Provider::complete`].
    fn complete_streaming(
        &self,
        request: &ProviderRequest,
        _on_delta: &dyn Fn(StreamEvent<'_>),
    ) -> Result<ModelTurn, CoreError> {
        self.complete(request)
    }

    /// Context window size of the underlying model, if known (units are up to
    /// the implementation — presumably tokens). Defaults to `None`.
    fn context_window(&self) -> Option<u32> {
        None
    }
}
/// Identifies the run position a tool call is executed in.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Identifier of the current run.
    pub run_id: String,
    /// Identifier of the owning session.
    pub session_id: String,
    /// Iteration of the run that requested the call.
    pub iteration: u32,
}

/// A capability the model can invoke by name.
pub trait Tool: Send + Sync {
    /// Describes the tool (name, description, input schema, ...) to the model.
    fn definition(&self) -> ToolDefinition;

    /// Executes `call` within `ctx`.
    ///
    /// # Errors
    /// Returns a [`CoreError`] when execution fails.
    fn execute(&self, call: &ToolCall, ctx: &ToolContext) -> Result<ToolResult, CoreError>;
}
/// Read-only interception points around model calls, tool calls and run end.
///
/// Every method defaults to a no-op returning `Ok(())`. When run by
/// [`Orchestrator`], an error from any model or tool hook ends the run with
/// `RunStopReason::BlockedByPolicy`. An error from `on_run_finished` is only
/// logged.
pub trait Middleware: Send + Sync {
    /// Inspects the request before it is sent to the provider.
    fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// Inspects the provider's response to `_request`.
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Inspects a tool call before the tool executes.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// Inspects a successful tool result.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// Observes the final output of a run.
    fn on_run_finished(&self, _output: &RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
/// Interception points that may rewrite values in place.
///
/// This is the mutable counterpart of [`Middleware`]: each hook receives a
/// `&mut` to the value it may change. Every method defaults to a no-op
/// returning `Ok(())`.
pub trait TurnMiddleware: Send + Sync {
    /// May modify the request before it reaches the provider.
    fn before_model_call(&self, _request: &mut ProviderRequest) -> Result<(), CoreError> {
        Ok(())
    }

    /// May modify the provider's response (e.g. rewrite directives).
    fn after_model_call(
        &self,
        _request: &ProviderRequest,
        _response: &mut ModelTurn,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// May modify a tool call before the tool executes.
    fn pre_tool_call(&self, _context: &ToolContext, _call: &mut ToolCall) -> Result<(), CoreError> {
        Ok(())
    }

    /// May modify a successful tool result before it is recorded.
    fn post_tool_call(
        &self,
        _context: &ToolContext,
        _result: &mut ToolResult,
    ) -> Result<(), CoreError> {
        Ok(())
    }

    /// May modify the final output of a run.
    fn on_run_finished(&self, _output: &mut RunOutput) -> Result<(), CoreError> {
        Ok(())
    }
}
struct LegacyMiddlewareAdapter {
inner: Arc<dyn Middleware>,
}
impl LegacyMiddlewareAdapter {
fn new(inner: Arc<dyn Middleware>) -> Self {
Self { inner }
}
}
impl TurnMiddleware for LegacyMiddlewareAdapter {
fn before_model_call(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
self.inner.before_model_call(request)
}
fn after_model_call(
&self,
request: &ProviderRequest,
response: &mut ModelTurn,
) -> Result<(), CoreError> {
self.inner.after_model_call(request, response)
}
fn pre_tool_call(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
self.inner.pre_tool_call(context, call)
}
fn post_tool_call(
&self,
context: &ToolContext,
result: &mut ToolResult,
) -> Result<(), CoreError> {
self.inner.post_tool_call(context, result)
}
fn on_run_finished(&self, output: &mut RunOutput) -> Result<(), CoreError> {
self.inner.on_run_finished(output)
}
}
/// Name-indexed collection of the [`Tool`]s available to a run.
///
/// Backed by a `BTreeMap`, so [`ToolRegistry::definitions`] lists tools in
/// ascending name order. Cloning copies the name index and shares the tool
/// instances via `Arc`.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    tools: BTreeMap<String, Arc<dyn Tool>>,
}

impl ToolRegistry {
    /// Adds `tool`, keyed by the `name` field of its [`ToolDefinition`].
    ///
    /// Registering a second tool under an existing name replaces the first.
    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
        let key = tool.definition().name;
        let shared: Arc<dyn Tool> = Arc::new(tool);
        self.tools.insert(key, shared);
    }

    /// Returns a shared handle to the tool registered as `tool_name`, if any.
    pub fn get(&self, tool_name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(tool_name).map(Arc::clone)
    }

    /// Definitions of all registered tools, ordered by name.
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        let mut defs = Vec::with_capacity(self.tools.len());
        for tool in self.tools.values() {
            defs.push(tool.definition());
        }
        defs
    }
}
/// Tuning knobs for an [`Orchestrator`].
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Maximum number of model round-trips per run.
    pub max_iterations: u32,
    /// History compaction settings; `None` disables compaction.
    pub context: Option<ContextConfig>,
    /// Context-compiler settings (`None` by default).
    pub context_compiler: Option<crate::context_compiler::ContextCompilerConfig>,
}

impl Default for OrchestratorConfig {
    /// 24 iterations, default compaction enabled, no context compiler.
    fn default() -> Self {
        let context = Some(ContextConfig::default());
        OrchestratorConfig {
            context,
            context_compiler: None,
            max_iterations: 24,
        }
    }
}
/// Everything needed to start a run.
#[derive(Debug, Clone)]
pub struct RunInput {
    /// Identifier of the run.
    pub run_id: String,
    /// Identifier of the owning session.
    pub session_id: String,
    /// Conversation branch the run belongs to.
    pub branch_id: String,
    /// Conversation history the run starts from.
    pub messages: Vec<ChatMessage>,
    /// Application state the run starts from.
    pub state: AppState,
}

/// Result of a run, returned even when the run stopped early.
#[derive(Debug, Clone)]
pub struct RunOutput {
    /// Identifier of the run (copied from the input).
    pub run_id: String,
    /// Identifier of the owning session (copied from the input).
    pub session_id: String,
    /// Conversation branch (copied from the input).
    pub branch_id: String,
    /// Every event emitted during the run, in emission order.
    pub events: Vec<AgentEvent>,
    /// Final conversation history, including model and tool messages.
    pub messages: Vec<ChatMessage>,
    /// Final application state after all applied patches.
    pub state: AppState,
    /// Why the run stopped.
    pub reason: RunStopReason,
    /// Text of the last final-answer directive, if one was produced.
    pub final_answer: Option<String>,
    /// Token usage accumulated across model turns.
    pub total_usage: TokenUsage,
}
pub struct Orchestrator {
provider: Arc<std::sync::RwLock<Arc<dyn Provider>>>,
tools: ToolRegistry,
turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
lifecycle_hooks: Vec<Arc<dyn LifecycleHook>>,
config: OrchestratorConfig,
}
impl Orchestrator {
pub fn new(
provider: Arc<dyn Provider>,
tools: ToolRegistry,
middlewares: Vec<Arc<dyn Middleware>>,
config: OrchestratorConfig,
) -> Self {
Self {
provider: Arc::new(std::sync::RwLock::new(provider)),
tools,
turn_middlewares: middlewares
.into_iter()
.map(|middleware| {
Arc::new(LegacyMiddlewareAdapter::new(middleware)) as Arc<dyn TurnMiddleware>
})
.collect(),
lifecycle_hooks: Vec::new(),
config,
}
}
pub fn with_turn_middlewares(
provider: Arc<dyn Provider>,
tools: ToolRegistry,
turn_middlewares: Vec<Arc<dyn TurnMiddleware>>,
config: OrchestratorConfig,
) -> Self {
Self {
provider: Arc::new(std::sync::RwLock::new(provider)),
tools,
turn_middlewares,
lifecycle_hooks: Vec::new(),
config,
}
}
pub fn with_lifecycle_hooks(mut self, hooks: Vec<Arc<dyn LifecycleHook>>) -> Self {
self.lifecycle_hooks = hooks;
self
}
pub fn add_lifecycle_hook(&mut self, hook: Arc<dyn LifecycleHook>) {
self.lifecycle_hooks.push(hook);
}
pub fn swap_provider(&self, new_provider: Arc<dyn Provider>) -> Result<String, CoreError> {
let name = new_provider.name().to_string();
let mut guard = self
.provider
.write()
.map_err(|e| CoreError::LockPoisoned(format!("provider write lock: {e}")))?;
*guard = new_provider;
Ok(name)
}
pub fn provider_name(&self) -> Result<String, CoreError> {
let guard = self
.provider
.read()
.map_err(|e| CoreError::LockPoisoned(format!("provider read lock: {e}")))?;
Ok(guard.name().to_string())
}
pub fn run(&self, input: RunInput, event_handler: impl FnMut(AgentEvent)) -> RunOutput {
self.run_cancellable(input, None, event_handler)
}
pub fn run_cancellable(
&self,
input: RunInput,
cancel: Option<&Arc<AtomicBool>>,
mut event_handler: impl FnMut(AgentEvent),
) -> RunOutput {
let mut events = Vec::new();
let mut messages = input.messages;
let mut state = input.state;
let mut final_answer: Option<String> = None;
let mut stop_reason = RunStopReason::BudgetExceeded;
let mut total_iterations = 0;
let mut total_usage = TokenUsage::default();
let provider = match self.provider.read() {
Ok(guard) => guard.clone(),
Err(e) => {
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: format!("provider lock poisoned: {e}"),
};
event_handler(err_event.clone());
return RunOutput {
run_id: input.run_id,
session_id: input.session_id,
branch_id: input.branch_id,
events: vec![err_event],
final_answer: None,
messages,
state,
reason: RunStopReason::Error,
total_usage: TokenUsage::default(),
};
}
};
let start_event = AgentEvent::RunStarted {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
provider: provider.name().to_string(),
max_iterations: self.config.max_iterations,
};
event_handler(start_event.clone());
events.push(start_event);
for hook in &self.lifecycle_hooks {
hook.on_session_start(&input.session_id);
}
for iteration in 1..=self.config.max_iterations {
if let Some(flag) = cancel
&& flag.load(Ordering::Relaxed)
{
stop_reason = RunStopReason::Cancelled;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: "run cancelled".to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
total_iterations = iteration;
let iter_event = AgentEvent::IterationStarted {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
};
event_handler(iter_event.clone());
events.push(iter_event);
if let Some(ref ctx_config) = self.config.context
&& let Some(result) = compact_messages(&messages, ctx_config)
{
let compact_event = AgentEvent::ContextCompacted {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
dropped_count: result.dropped_count,
tokens_before: result.tokens_before,
tokens_after: result.tokens_after,
};
event_handler(compact_event.clone());
events.push(compact_event);
messages = result.messages;
}
let mut provider_request = ProviderRequest {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
messages: messages.clone(),
tools: self.tools.definitions(),
state: state.clone(),
};
if let Err(err) = self.run_before_model(&mut provider_request) {
stop_reason = RunStopReason::BlockedByPolicy;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
for hook in &self.lifecycle_hooks {
hook.pre_llm_call(&provider_request);
}
let mut model_turn = match provider.complete(&provider_request) {
Ok(turn) => turn,
Err(err) => {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
};
for hook in &self.lifecycle_hooks {
hook.post_llm_call(&provider_request);
}
if let Err(err) = self.run_after_model(&provider_request, &mut model_turn) {
stop_reason = RunStopReason::BlockedByPolicy;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
if let Some(ref usage) = model_turn.usage {
total_usage.accumulate(usage);
}
let output_event = AgentEvent::ModelOutput {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
stop_reason: model_turn.stop_reason,
directive_count: model_turn.directives.len(),
usage: model_turn.usage,
};
event_handler(output_event.clone());
events.push(output_event);
let mut requested_tool = false;
for directive in model_turn.directives {
match directive {
ModelDirective::Text { delta } => {
let delta_event = AgentEvent::TextDelta {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
delta: delta.clone(),
};
event_handler(delta_event.clone());
events.push(delta_event);
messages.push(ChatMessage::assistant(delta));
}
ModelDirective::ToolCall { mut call } => {
requested_tool = true;
let tc_event = AgentEvent::ToolCallRequested {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call: call.clone(),
};
event_handler(tc_event.clone());
events.push(tc_event);
let context = ToolContext {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
};
if let Err(err) = self.run_pre_tool(&context, &mut call) {
stop_reason = RunStopReason::BlockedByPolicy;
let err_event = AgentEvent::ToolCallFailed {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
let Some(tool) = self.tools.get(&call.tool_name) else {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::ToolCallFailed {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
error: format!(
"{}",
CoreError::ToolNotFound {
tool_name: call.tool_name.clone(),
}
),
};
event_handler(err_event.clone());
events.push(err_event);
break;
};
for hook in &self.lifecycle_hooks {
hook.pre_tool_call(&call.tool_name, &call.input);
}
match tool.execute(&call, &context) {
Ok(mut result) => {
if let Err(err) = self.run_post_tool(&context, &mut result) {
stop_reason = RunStopReason::BlockedByPolicy;
let err_event = AgentEvent::ToolCallFailed {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
if let Some(patch) = &result.state_patch {
match state.apply_patch(patch) {
Ok(()) => {
let patch_event = AgentEvent::StatePatched {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
patch: patch.clone(),
revision: state.revision,
};
event_handler(patch_event.clone());
events.push(patch_event);
}
Err(err) => {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::ToolCallFailed {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
}
}
let result_str = serde_json::to_string(&result.output)
.unwrap_or_else(|_| "{}".to_string());
for hook in &self.lifecycle_hooks {
hook.post_tool_call(&call.tool_name, &result_str);
}
let completed_event = AgentEvent::ToolCallCompleted {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
result: ToolResultSummary::from(&result),
};
event_handler(completed_event.clone());
events.push(completed_event);
messages.push(ChatMessage::tool_result(
&result.call_id,
serde_json::to_string(&result.output)
.unwrap_or_else(|_| "{}".to_string()),
));
}
Err(err) => {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::ToolCallFailed {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
}
}
ModelDirective::StatePatch { patch } => match state.apply_patch(&patch) {
Ok(()) => {
let patch_event = AgentEvent::StatePatched {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
patch: patch.clone(),
revision: state.revision,
};
event_handler(patch_event.clone());
events.push(patch_event);
}
Err(err) => {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: err.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
},
ModelDirective::FinalAnswer { text } => {
final_answer = Some(text.clone());
let delta_event = AgentEvent::TextDelta {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
iteration,
delta: text.clone(),
};
event_handler(delta_event.clone());
events.push(delta_event);
messages.push(ChatMessage::assistant(text));
}
}
}
if matches!(
stop_reason,
RunStopReason::Error | RunStopReason::BlockedByPolicy | RunStopReason::Cancelled
) {
break;
}
match model_turn.stop_reason {
ModelStopReason::EndTurn => {
stop_reason = RunStopReason::Completed;
break;
}
ModelStopReason::NeedsUser => {
stop_reason = RunStopReason::NeedsUser;
break;
}
ModelStopReason::Safety => {
stop_reason = RunStopReason::BlockedByPolicy;
break;
}
ModelStopReason::ToolUse => {
if !requested_tool {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: "model requested tool_use stop reason without tool call"
.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
}
ModelStopReason::MaxTokens | ModelStopReason::Unknown => {
if !requested_tool {
stop_reason = RunStopReason::Error;
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: "model returned non-terminal stop reason without tool call"
.to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
break;
}
}
}
}
if total_iterations == self.config.max_iterations
&& stop_reason == RunStopReason::BudgetExceeded
{
let err_event = AgentEvent::RunErrored {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
error: "max iteration budget exceeded".to_string(),
};
event_handler(err_event.clone());
events.push(err_event);
}
let finished_event = AgentEvent::RunFinished {
run_id: input.run_id.clone(),
session_id: input.session_id.clone(),
reason: stop_reason,
total_iterations,
final_answer: final_answer.clone(),
usage: if total_usage.total() > 0 {
Some(total_usage)
} else {
None
},
};
event_handler(finished_event.clone());
events.push(finished_event);
let mut output = RunOutput {
run_id: input.run_id,
session_id: input.session_id,
branch_id: input.branch_id,
events,
messages,
state,
reason: stop_reason,
final_answer,
total_usage,
};
if let Err(e) = self
.turn_middlewares
.iter()
.try_for_each(|m| m.on_run_finished(&mut output))
{
tracing::warn!(error = %e, "middleware on_run_finished failed (non-fatal)");
}
for hook in &self.lifecycle_hooks {
hook.on_session_end(&output.session_id, &output);
}
output
}
fn run_before_model(&self, request: &mut ProviderRequest) -> Result<(), CoreError> {
self.turn_middlewares
.iter()
.try_for_each(|middleware| middleware.before_model_call(request))
}
fn run_after_model(
&self,
request: &ProviderRequest,
response: &mut ModelTurn,
) -> Result<(), CoreError> {
self.turn_middlewares
.iter()
.try_for_each(|middleware| middleware.after_model_call(request, response))
}
fn run_pre_tool(&self, context: &ToolContext, call: &mut ToolCall) -> Result<(), CoreError> {
self.turn_middlewares
.iter()
.try_for_each(|middleware| middleware.pre_tool_call(context, call))
}
fn run_post_tool(
&self,
context: &ToolContext,
result: &mut ToolResult,
) -> Result<(), CoreError> {
self.turn_middlewares
.iter()
.try_for_each(|middleware| middleware.post_tool_call(context, result))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{
ModelDirective, ModelStopReason, ModelTurn, StatePatch, StatePatchFormat, StatePatchSource,
};
use serde_json::json;
use std::sync::Mutex;
struct ScriptedProvider {
turns: Vec<ModelTurn>,
cursor: Mutex<usize>,
}
impl Provider for ScriptedProvider {
fn name(&self) -> &str {
"scripted"
}
fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
let mut cursor = self
.cursor
.lock()
.map_err(|_| CoreError::Provider("scripted provider lock poisoned".to_string()))?;
let idx = *cursor;
let Some(turn) = self.turns.get(idx) else {
return Err(CoreError::Provider("no scripted turn left".to_string()));
};
*cursor += 1;
Ok(turn.clone())
}
}
struct EchoTool;
impl Tool for EchoTool {
fn definition(&self) -> ToolDefinition {
ToolDefinition {
name: "echo".to_string(),
description: "Echoes the provided value".to_string(),
input_schema: json!({
"type": "object",
"properties": { "value": { "type": "string" } },
"required": ["value"]
}),
title: None,
output_schema: None,
annotations: None,
category: None,
tags: Vec::new(),
timeout_secs: None,
}
}
fn execute(&self, call: &ToolCall, _ctx: &ToolContext) -> Result<ToolResult, CoreError> {
let value = call.input.get("value").cloned().unwrap_or(json!(null));
Ok(ToolResult {
call_id: call.call_id.clone(),
tool_name: call.tool_name.clone(),
output: json!({ "echo": value.clone() }),
content: None,
is_error: false,
state_patch: Some(StatePatch {
format: StatePatchFormat::MergePatch,
patch: json!({ "last_echo": value }),
source: StatePatchSource::Tool,
}),
})
}
}
#[test]
fn orchestrator_runs_tool_then_finishes() {
let provider = ScriptedProvider {
turns: vec![
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "call-1".to_string(),
tool_name: "echo".to_string(),
input: json!({ "value": "hello" }),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
ModelTurn {
directives: vec![ModelDirective::FinalAnswer {
text: "done".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
},
],
cursor: Mutex::new(0),
};
let mut tools = ToolRegistry::default();
tools.register(EchoTool);
let orchestrator = Orchestrator::new(
Arc::new(provider),
tools,
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "session-1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::Completed);
assert_eq!(output.final_answer.as_deref(), Some("done"));
assert_eq!(output.state.revision, 1);
assert_eq!(output.state.data["last_echo"], "hello");
assert!(
output
.events
.iter()
.any(|event| matches!(event, AgentEvent::ToolCallCompleted { .. }))
);
assert!(output.events.iter().any(|event| matches!(
event,
AgentEvent::RunFinished {
reason: RunStopReason::Completed,
..
}
)));
}
#[test]
fn provider_error_stops_run() {
struct FailProvider;
impl Provider for FailProvider {
fn name(&self) -> &str {
"fail"
}
fn complete(&self, _request: &ProviderRequest) -> Result<ModelTurn, CoreError> {
Err(CoreError::Provider("connection refused".to_string()))
}
}
let orchestrator = Orchestrator::new(
Arc::new(FailProvider),
ToolRegistry::default(),
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::Error);
assert!(
output
.events
.iter()
.any(|e| matches!(e, AgentEvent::RunErrored { .. }))
);
}
#[test]
fn tool_not_found_stops_run() {
let provider = ScriptedProvider {
turns: vec![ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "c1".to_string(),
tool_name: "nonexistent".to_string(),
input: json!({}),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
}],
cursor: Mutex::new(0),
};
let orchestrator = Orchestrator::new(
Arc::new(provider),
ToolRegistry::default(),
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::Error);
assert!(
output
.events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallFailed { .. }))
);
}
#[test]
fn middleware_blocks_model_call() {
struct BlockMiddleware;
impl Middleware for BlockMiddleware {
fn before_model_call(&self, _request: &ProviderRequest) -> Result<(), CoreError> {
Err(CoreError::Middleware("blocked by policy".to_string()))
}
}
let provider = ScriptedProvider {
turns: vec![ModelTurn {
directives: vec![ModelDirective::Text {
delta: "hi".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
}],
cursor: Mutex::new(0),
};
let orchestrator = Orchestrator::new(
Arc::new(provider),
ToolRegistry::default(),
vec![Arc::new(BlockMiddleware)],
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::BlockedByPolicy);
}
#[test]
fn turn_middleware_can_rewrite_calls_and_responses() {
struct RewriteMiddleware;
impl TurnMiddleware for RewriteMiddleware {
fn after_model_call(
&self,
_request: &ProviderRequest,
response: &mut ModelTurn,
) -> Result<(), CoreError> {
for directive in &mut response.directives {
if let ModelDirective::FinalAnswer { text } = directive {
*text = "rewritten answer".to_string();
}
}
Ok(())
}
fn pre_tool_call(
&self,
_context: &ToolContext,
call: &mut ToolCall,
) -> Result<(), CoreError> {
call.input = json!({ "value": "rewritten input" });
Ok(())
}
}
let provider = ScriptedProvider {
turns: vec![
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "call-1".to_string(),
tool_name: "echo".to_string(),
input: json!({ "value": "original input" }),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
ModelTurn {
directives: vec![ModelDirective::FinalAnswer {
text: "original answer".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
},
],
cursor: Mutex::new(0),
};
let mut tools = ToolRegistry::default();
tools.register(EchoTool);
let orchestrator = Orchestrator::with_turn_middlewares(
Arc::new(provider),
tools,
vec![Arc::new(RewriteMiddleware)],
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "session-1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::Completed);
assert_eq!(output.final_answer.as_deref(), Some("rewritten answer"));
assert_eq!(output.state.data["last_echo"], "rewritten input");
}
#[test]
fn budget_exceeded_when_iterations_exhausted() {
let provider = ScriptedProvider {
turns: vec![
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "c1".to_string(),
tool_name: "echo".to_string(),
input: json!({"value": "1"}),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "c2".to_string(),
tool_name: "echo".to_string(),
input: json!({"value": "2"}),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
],
cursor: Mutex::new(0),
};
let mut tools = ToolRegistry::default();
tools.register(EchoTool);
let orchestrator = Orchestrator::new(
Arc::new(provider),
tools,
Vec::new(),
OrchestratorConfig {
max_iterations: 2,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::BudgetExceeded);
}
#[test]
fn text_only_response_completes() {
let provider = ScriptedProvider {
turns: vec![ModelTurn {
directives: vec![ModelDirective::Text {
delta: "Hello, world!".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
}],
cursor: Mutex::new(0),
};
let orchestrator = Orchestrator::new(
Arc::new(provider),
ToolRegistry::default(),
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("hi")],
state: AppState::default(),
},
|_| {},
);
assert_eq!(output.reason, RunStopReason::Completed);
assert!(output.messages.iter().any(|m| m.content == "Hello, world!"));
}
#[test]
fn event_handler_receives_all_events() {
let provider = ScriptedProvider {
turns: vec![ModelTurn {
directives: vec![ModelDirective::FinalAnswer {
text: "done".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
}],
cursor: Mutex::new(0),
};
let orchestrator = Orchestrator::new(
Arc::new(provider),
ToolRegistry::default(),
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let received = Arc::new(Mutex::new(Vec::new()));
let received_clone = received.clone();
orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
move |event| {
received_clone.lock().unwrap().push(event);
},
);
let events = received.lock().unwrap();
assert!(events.len() >= 4); assert!(matches!(events[0], AgentEvent::RunStarted { .. }));
assert!(matches!(
events.last().unwrap(),
AgentEvent::RunFinished { .. }
));
}
#[test]
fn tool_result_includes_call_id() {
let provider = ScriptedProvider {
turns: vec![
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "my-call-id".to_string(),
tool_name: "echo".to_string(),
input: json!({"value": "test"}),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
ModelTurn {
directives: vec![ModelDirective::FinalAnswer {
text: "ok".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
},
],
cursor: Mutex::new(0),
};
let mut tools = ToolRegistry::default();
tools.register(EchoTool);
let orchestrator = Orchestrator::new(
Arc::new(provider),
tools,
Vec::new(),
OrchestratorConfig {
max_iterations: 4,
context: None,
context_compiler: None,
},
);
let output = orchestrator.run(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
|_| {},
);
let tool_msg = output
.messages
.iter()
.find(|m| m.role == crate::protocol::Role::Tool)
.expect("should have tool message");
assert_eq!(tool_msg.tool_call_id.as_deref(), Some("my-call-id"));
}
#[test]
fn cancellation_stops_run() {
let provider = ScriptedProvider {
turns: vec![
ModelTurn {
directives: vec![ModelDirective::ToolCall {
call: ToolCall {
call_id: "c1".to_string(),
tool_name: "echo".to_string(),
input: json!({"value": "1"}),
},
}],
stop_reason: ModelStopReason::ToolUse,
usage: None,
},
ModelTurn {
directives: vec![ModelDirective::FinalAnswer {
text: "should not reach".to_string(),
}],
stop_reason: ModelStopReason::EndTurn,
usage: None,
},
],
cursor: Mutex::new(0),
};
let mut tools = ToolRegistry::default();
tools.register(EchoTool);
let orchestrator = Orchestrator::new(
Arc::new(provider),
tools,
Vec::new(),
OrchestratorConfig {
max_iterations: 10,
context: None,
context_compiler: None,
},
);
let cancel = Arc::new(AtomicBool::new(false));
let cancel_clone = cancel.clone();
let call_count = Arc::new(Mutex::new(0u32));
let call_count_clone = call_count.clone();
let output = orchestrator.run_cancellable(
RunInput {
run_id: "run-1".to_string(),
session_id: "s1".to_string(),
branch_id: "main".to_string(),
messages: vec![ChatMessage::user("test")],
state: AppState::default(),
},
Some(&cancel_clone),
move |event| {
if matches!(event, AgentEvent::ToolCallCompleted { .. }) {
let mut count = call_count_clone.lock().unwrap();
*count += 1;
if *count >= 1 {
cancel.store(true, Ordering::Relaxed);
}
}
},
);
assert_eq!(output.reason, RunStopReason::Cancelled);
assert!(output.final_answer.is_none());
}
#[test]
fn swappable_provider_handle_swap() {
    /// Stub provider identified only by its name; every completion attempt fails.
    struct NamedStub(&'static str);
    impl Provider for NamedStub {
        fn name(&self) -> &str {
            self.0
        }
        fn complete(&self, _: &ProviderRequest) -> Result<ModelTurn, CoreError> {
            Err(CoreError::Provider("stub provider".into()))
        }
    }

    // Put provider A behind the shared, swappable handle.
    let initial: Arc<dyn Provider> = Arc::new(NamedStub("provider-a"));
    let handle: SwappableProviderHandle = Arc::new(std::sync::RwLock::new(initial));
    assert_eq!(handle.read().unwrap().name(), "provider-a");

    // Replace the inner provider in place. The write guard is a temporary that is
    // dropped at the end of this statement, so the read below cannot block on it.
    *handle.write().unwrap() = Arc::new(NamedStub("provider-b"));
    assert_eq!(handle.read().unwrap().name(), "provider-b");
}
#[test]
fn token_usage_accumulated() {
    // Builds a per-turn usage record with no cache activity; only input/output vary.
    let usage = |input_tokens, output_tokens| {
        Some(TokenUsage {
            input_tokens,
            output_tokens,
            cache_read_tokens: 0,
            cache_creation_tokens: 0,
        })
    };

    // Turn 1 asks for an `echo` tool call; turn 2 answers. Each reports its own usage.
    let tool_turn = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "echo".to_string(),
                input: json!({"value": "hi"}),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: usage(100, 50),
    };
    let answer_turn = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "done".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: usage(200, 30),
    };
    let provider = ScriptedProvider {
        turns: vec![tool_turn, answer_turn],
        cursor: Mutex::new(0),
    };

    let mut tools = ToolRegistry::default();
    tools.register(EchoTool);
    let config = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(provider), tools, Vec::new(), config);

    let input = RunInput {
        run_id: "run-1".to_string(),
        session_id: "s1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let output = orchestrator.run(input, |_| {});

    assert_eq!(output.reason, RunStopReason::Completed);
    // The run total must be the sum of both turns' usage.
    let total = &output.total_usage;
    assert_eq!(total.input_tokens, 100 + 200);
    assert_eq!(total.output_tokens, 50 + 30);
    assert_eq!(total.total(), 300 + 80);
}
#[test]
fn lifecycle_hooks_fire_during_run() {
    use crate::lifecycle::LifecycleHook;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that records how often each lifecycle callback fires.
    /// All counters start at zero (`AtomicU32::default()`).
    #[derive(Default)]
    struct CountingHook {
        pre_tool: AtomicU32,
        post_tool: AtomicU32,
        pre_llm: AtomicU32,
        post_llm: AtomicU32,
        session_start: AtomicU32,
        session_end: AtomicU32,
    }

    /// Adds one to the given counter.
    fn bump(counter: &AtomicU32) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    impl LifecycleHook for CountingHook {
        fn pre_tool_call(&self, _tool_name: &str, _input: &serde_json::Value) {
            bump(&self.pre_tool);
        }
        fn post_tool_call(&self, _tool_name: &str, _result: &str) {
            bump(&self.post_tool);
        }
        fn pre_llm_call(&self, _request: &ProviderRequest) {
            bump(&self.pre_llm);
        }
        fn post_llm_call(&self, _request: &ProviderRequest) {
            bump(&self.post_llm);
        }
        fn on_session_start(&self, _session_id: &str) {
            bump(&self.session_start);
        }
        fn on_session_end(&self, _session_id: &str, _output: &RunOutput) {
            bump(&self.session_end);
        }
    }

    // Two scripted turns: one `echo` tool call followed by a final answer.
    let tool_turn = ModelTurn {
        directives: vec![ModelDirective::ToolCall {
            call: ToolCall {
                call_id: "call-1".to_string(),
                tool_name: "echo".to_string(),
                input: json!({ "value": "hello" }),
            },
        }],
        stop_reason: ModelStopReason::ToolUse,
        usage: None,
    };
    let answer_turn = ModelTurn {
        directives: vec![ModelDirective::FinalAnswer {
            text: "done".to_string(),
        }],
        stop_reason: ModelStopReason::EndTurn,
        usage: None,
    };
    let provider = ScriptedProvider {
        turns: vec![tool_turn, answer_turn],
        cursor: Mutex::new(0),
    };

    let mut tools = ToolRegistry::default();
    tools.register(EchoTool);
    let hook = Arc::new(CountingHook::default());
    let config = OrchestratorConfig {
        max_iterations: 4,
        context: None,
        context_compiler: None,
    };
    let orchestrator = Orchestrator::new(Arc::new(provider), tools, Vec::new(), config)
        .with_lifecycle_hooks(vec![hook.clone()]);

    let input = RunInput {
        run_id: "run-1".to_string(),
        session_id: "session-1".to_string(),
        branch_id: "main".to_string(),
        messages: vec![ChatMessage::user("test")],
        state: AppState::default(),
    };
    let output = orchestrator.run(input, |_| {});

    let seen = |counter: &AtomicU32| counter.load(Ordering::Relaxed);
    assert_eq!(output.reason, RunStopReason::Completed);
    // One session, two model calls (tool request + final answer), one tool execution.
    assert_eq!(seen(&hook.session_start), 1);
    assert_eq!(seen(&hook.session_end), 1);
    assert_eq!(seen(&hook.pre_llm), 2);
    assert_eq!(seen(&hook.post_llm), 2);
    assert_eq!(seen(&hook.pre_tool), 1);
    assert_eq!(seen(&hook.post_tool), 1);
}
}