use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use opi_agent::event::AgentEvent;
use opi_agent::extension::{Extension, ExtensionCommand, ExtensionError, ExtensionRegistry};
use opi_agent::hooks::AgentHooks;
use opi_agent::loop_types::{AgentError, AgentLoopConfig};
use opi_agent::message::AgentMessage;
use opi_agent::sdk::SDK_SCHEMA_VERSION;
use opi_agent::tool::{ExecutionMode, Tool, ToolError, ToolResult};
use opi_ai::message::{AssistantContent, OutputContent, ToolDef};
use opi_ai::provider::Provider;
use opi_ai::test_support::{MockProvider, text_response, tool_call_response};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
/// One finished `sub-agent/run` invocation as kept in the extension's run history and
/// persisted through `serialize_state` / `restore_state`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
struct SubAgentRunRecord {
    /// Identifier produced by `alloc_run_id` (shape `run-N`).
    id: String,
    /// Prompt the child agent was started with.
    prompt: String,
    /// Outcome label: `"completed"`, `"cancelled"` or `"error"`.
    status: String,
    /// Final assistant text; populated only for completed runs.
    result: Option<String>,
    /// Failure description; populated only for cancelled or errored runs.
    error: Option<String>,
}
/// Test extension that spins up a fresh child agent for every `sub-agent/run` command and
/// records run outcomes plus parent/child event labels so tests can inspect them.
struct SubAgentExtension {
    /// Called once per child run to build that run's provider.
    provider_factory: Arc<dyn Fn() -> Box<dyn Provider> + Send + Sync>,
    /// Called once per child run to build that run's tool set.
    tools_factory: Arc<dyn Fn() -> Vec<Box<dyn Tool>> + Send + Sync>,
    /// Model identifier handed to every child agent.
    model: String,
    /// History of finished runs, appended as each run completes.
    runs: Arc<Mutex<Vec<SubAgentRunRecord>>>,
    /// Labels of events delivered to this extension via `on_event`.
    parent_events: Arc<Mutex<Vec<String>>>,
    /// Labels of events emitted by child agents (each prefixed with `Child`).
    child_events: Arc<Mutex<Vec<String>>>,
    /// Cancellation token of the child run currently awaiting its prompt, if any.
    active_child_cancel: Arc<Mutex<Option<CancellationToken>>>,
    /// Counter backing `alloc_run_id`.
    next_run_id: AtomicU64,
}
impl SubAgentExtension {
fn new(
provider_factory: Arc<dyn Fn() -> Box<dyn Provider> + Send + Sync>,
tools_factory: Arc<dyn Fn() -> Vec<Box<dyn Tool>> + Send + Sync>,
model: String,
) -> Self {
Self {
provider_factory,
tools_factory,
model,
runs: Arc::new(Mutex::new(Vec::new())),
parent_events: Arc::new(Mutex::new(Vec::new())),
child_events: Arc::new(Mutex::new(Vec::new())),
active_child_cancel: Arc::new(Mutex::new(None)),
next_run_id: AtomicU64::new(1),
}
}
fn alloc_run_id(&self) -> String {
let id = self.next_run_id.fetch_add(1, Ordering::SeqCst);
format!("run-{id}")
}
fn extract_final_text(messages: &[AgentMessage]) -> String {
messages
.iter()
.rev()
.find_map(|m| {
if let AgentMessage::Llm(opi_ai::message::Message::Assistant(a)) = m {
let text: String = a
.content
.iter()
.filter_map(|c| match c {
AssistantContent::Text { text } => Some(text.clone()),
_ => None,
})
.collect();
if !text.is_empty() {
return Some(text);
}
}
None
})
.unwrap_or_default()
}
}
impl Extension for SubAgentExtension {
fn name(&self) -> &str {
"sub-agent"
}
fn on_event(&self, event: &AgentEvent) {
let label = match event {
AgentEvent::AgentStart => "AgentStart".to_string(),
AgentEvent::AgentEnd { .. } => "AgentEnd".to_string(),
AgentEvent::TurnStart => "TurnStart".to_string(),
AgentEvent::TurnEnd { .. } => "TurnEnd".to_string(),
AgentEvent::ToolExecutionStart { tool_name, .. } => {
format!("ToolExecutionStart({tool_name})")
}
AgentEvent::ToolExecutionEnd { tool_name, .. } => {
format!("ToolExecutionEnd({tool_name})")
}
_ => "Other".to_string(),
};
self.parent_events.lock().unwrap().push(label);
}
fn on_command(
&self,
command: &ExtensionCommand,
) -> Pin<Box<dyn Future<Output = Result<Option<Value>, ExtensionError>> + Send>> {
let name = command.name.clone();
let args = command.args.clone();
let provider_factory = self.provider_factory.clone();
let tools_factory = self.tools_factory.clone();
let model = self.model.clone();
let runs = self.runs.clone();
let child_events = self.child_events.clone();
let active_cancel = self.active_child_cancel.clone();
let run_id = self.alloc_run_id();
Box::pin(async move {
match name.as_str() {
"sub-agent/run" => {
let prompt = args["prompt"].as_str().unwrap_or("").to_string();
let child_provider = provider_factory();
let child_tools = tools_factory();
let child_hooks = Box::new(ChildHooks) as Box<dyn AgentHooks>;
let mut child_agent = opi_agent::Agent::new(
child_provider,
child_tools,
model,
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
child_hooks,
);
let ce = child_events.clone();
child_agent.subscribe(Box::new(move |event: &AgentEvent| {
let label = match event {
AgentEvent::AgentStart => "ChildAgentStart".to_string(),
AgentEvent::AgentEnd { .. } => "ChildAgentEnd".to_string(),
AgentEvent::TurnStart => "ChildTurnStart".to_string(),
AgentEvent::TurnEnd { .. } => "ChildTurnEnd".to_string(),
AgentEvent::ToolExecutionStart { tool_name, .. } => {
format!("ChildToolExecutionStart({tool_name})")
}
AgentEvent::ToolExecutionEnd { tool_name, .. } => {
format!("ChildToolExecutionEnd({tool_name})")
}
_ => "ChildOther".to_string(),
};
ce.lock().unwrap().push(label);
}));
let child_token = child_agent.cancel_token();
*active_cancel.lock().unwrap() = Some(child_token);
let result = child_agent.prompt(&prompt).await;
*active_cancel.lock().unwrap() = None;
let record = match result {
Ok(messages) => {
let text = Self::extract_final_text(&messages);
SubAgentRunRecord {
id: run_id,
prompt,
status: "completed".to_string(),
result: Some(text),
error: None,
}
}
Err(AgentError::Cancelled) => SubAgentRunRecord {
id: run_id,
prompt,
status: "cancelled".to_string(),
result: None,
error: Some("cancelled".to_string()),
},
Err(e) => SubAgentRunRecord {
id: run_id,
prompt,
status: "error".to_string(),
result: None,
error: Some(e.to_string()),
},
};
let response_text = record
.result
.clone()
.unwrap_or_else(|| record.error.clone().unwrap_or_default());
let status = record.status.clone();
let rid = record.id.clone();
runs.lock().unwrap().push(record);
Ok(Some(serde_json::json!({
"run_id": rid,
"status": status,
"result": response_text,
"sdk_schema_version": SDK_SCHEMA_VERSION,
})))
}
"sub-agent/list" => {
let runs_guard = runs.lock().unwrap();
let run_list: Vec<Value> = runs_guard
.iter()
.map(|r| {
serde_json::json!({
"id": r.id,
"prompt": r.prompt,
"status": r.status,
"result": r.result,
"error": r.error,
})
})
.collect();
Ok(Some(serde_json::json!({ "runs": run_list })))
}
_ => Ok(None),
}
})
}
fn serialize_state(&self) -> Result<Option<Value>, ExtensionError> {
let runs = self.runs.lock().unwrap();
Ok(Some(serde_json::json!({
"runs": serde_json::to_value(&*runs).unwrap_or(serde_json::json!([])),
"sdk_schema_version": SDK_SCHEMA_VERSION,
})))
}
fn restore_state(&self, state: Value) -> Result<(), ExtensionError> {
if let Some(runs_val) = state["runs"].as_array() {
let mut runs = self.runs.lock().unwrap();
runs.clear();
for r in runs_val {
if let Ok(record) = serde_json::from_value::<SubAgentRunRecord>(r.clone()) {
runs.push(record);
}
}
}
Ok(())
}
}
/// Hooks installed on each child agent: the transcript handed to the LLM keeps only the
/// `AgentMessage::Llm` entries and drops every other agent message.
struct ChildHooks;

impl AgentHooks for ChildHooks {
    fn convert_to_llm(
        &self,
        messages: &[AgentMessage],
    ) -> Result<Vec<opi_ai::message::Message>, AgentError> {
        let mut llm_messages = Vec::new();
        for message in messages {
            if let AgentMessage::Llm(inner) = message {
                llm_messages.push(inner.clone());
            }
        }
        Ok(llm_messages)
    }
}
/// Parallel-mode test tool that ignores its arguments and immediately succeeds with the
/// text `child-ok`.
struct DummyTool {
    name: String,
}

impl DummyTool {
    fn new(name: &str) -> Self {
        Self {
            name: String::from(name),
        }
    }
}

impl Tool for DummyTool {
    /// Definition named after the tool, with an empty-object input schema.
    fn definition(&self) -> ToolDef {
        let tool_name = &self.name;
        let spec = serde_json::json!({
            "name": tool_name,
            "description": format!("{tool_name} tool"),
            "input_schema": { "type": "object", "properties": {} }
        });
        serde_json::from_value(spec).unwrap()
    }

    fn execute(
        &self,
        _call_id: &str,
        _arguments: Value,
        _signal: CancellationToken,
        _on_update: Option<opi_agent::tool::UpdateCallback>,
    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
        Box::pin(async {
            let content = vec![OutputContent::Text {
                text: "child-ok".into(),
            }];
            Ok(ToolResult {
                content,
                details: None,
                is_error: false,
                terminate: false,
            })
        })
    }

    fn execution_mode(&self) -> ExecutionMode {
        ExecutionMode::Parallel
    }
}
/// Sequential-mode test tool named `blocking` whose execution only completes once its
/// cancellation signal fires; it then reports the text `cancelled`.
struct BlockingTool;

impl Tool for BlockingTool {
    fn definition(&self) -> ToolDef {
        let spec = serde_json::json!({
            "name": "blocking",
            "description": "A tool that blocks until cancelled",
            "input_schema": { "type": "object", "properties": {} }
        });
        serde_json::from_value(spec).unwrap()
    }

    fn execute(
        &self,
        _call_id: &str,
        _arguments: Value,
        signal: CancellationToken,
        _on_update: Option<opi_agent::tool::UpdateCallback>,
    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
        Box::pin(async move {
            // Park until whoever owns the token cancels it.
            signal.cancelled().await;
            let content = vec![OutputContent::Text {
                text: "cancelled".into(),
            }];
            Ok(ToolResult {
                content,
                details: None,
                is_error: false,
                terminate: false,
            })
        })
    }

    fn execution_mode(&self) -> ExecutionMode {
        ExecutionMode::Sequential
    }
}
/// A successful child run returns its final text to the parent, lands in the history and
/// emits start/end events on the child event log.
#[tokio::test]
async fn child_run_completes_and_result_routed_to_parent() {
    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new(
                "child",
                vec![text_response("Child says hello")],
            )) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let history = extension.runs.clone();
    let events = extension.child_events.clone();

    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();
    let run = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": "hello" }));
    let response = registry.dispatch_command(&run).await.unwrap().unwrap();

    assert_eq!(response["status"], "completed");
    assert_eq!(response["result"], "Child says hello");
    assert_eq!(response["sdk_schema_version"], SDK_SCHEMA_VERSION);

    {
        let records = history.lock().unwrap();
        assert_eq!(records.len(), 1);
        let only = &records[0];
        assert_eq!(only.status, "completed");
        assert_eq!(only.result.as_deref(), Some("Child says hello"));
    }

    let events = events.lock().unwrap();
    assert!(
        events.iter().any(|e| e == "ChildAgentStart"),
        "should have ChildAgentStart, got: {events:?}"
    );
    assert!(
        events.iter().any(|e| e == "ChildAgentEnd"),
        "should have ChildAgentEnd, got: {events:?}"
    );
}
/// A child that calls a tool before answering still completes, and the tool execution is
/// visible on the child event log.
#[tokio::test]
async fn child_run_with_tool_call_completes() {
    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new(
                "child",
                vec![
                    tool_call_response("tc_1", "read", r#"{"path":"/tmp/f"}"#),
                    text_response("Read result: contents"),
                ],
            )) as Box<dyn Provider>
        }),
        Arc::new(|| vec![Box::new(DummyTool::new("read"))]),
        "mock:child-model".into(),
    );
    let events = extension.child_events.clone();

    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();
    let run = ExtensionCommand::new(
        "sub-agent/run",
        serde_json::json!({ "prompt": "read the file" }),
    );
    let response = registry.dispatch_command(&run).await.unwrap().unwrap();

    assert_eq!(response["status"], "completed");
    assert_eq!(response["result"], "Read result: contents");

    let events = events.lock().unwrap();
    assert!(
        events.iter().any(|e| e == "ChildToolExecutionStart(read)"),
        "should have ChildToolExecutionStart(read), got: {events:?}"
    );
    assert!(
        events.iter().any(|e| e == "ChildToolExecutionEnd(read)"),
        "should have ChildToolExecutionEnd(read), got: {events:?}"
    );
}
/// A provider failure inside the child becomes an `"error"` run whose error text reaches
/// both the command response and the run history.
#[tokio::test]
async fn child_provider_error_propagates_to_parent() {
    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new_with_errors(
                "child",
                vec![opi_ai::test_support::MockResponse::Error(
                    opi_ai::provider::ProviderError::AuthFailed("bad key".into()),
                )],
            )) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let history = extension.runs.clone();

    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();
    let run = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": "test" }));
    let response = registry.dispatch_command(&run).await.unwrap().unwrap();

    assert_eq!(response["status"], "error");
    let reported = response["result"].as_str().unwrap();
    assert!(
        reported.contains("authentication failed"),
        "error should mention auth failure, got: {}",
        response["result"]
    );

    let records = history.lock().unwrap();
    assert_eq!(records.len(), 1);
    let only = &records[0];
    assert_eq!(only.status, "error");
    assert!(only.error.is_some());
}
/// Cancelling the child's published token while its only tool is blocked ends the run as
/// `"cancelled"`, records it in the history and clears the active-token slot.
#[tokio::test]
async fn child_run_cancelled_mid_execution() {
    // Bound the wait for the child to publish its token, so a regression fails the test
    // instead of spinning forever.
    const MAX_YIELDS: usize = 10_000;

    let ext = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new(
                "child",
                vec![
                    tool_call_response("tc_1", "blocking", "{}"),
                    text_response("done"),
                ],
            )) as Box<dyn Provider>
        }),
        Arc::new(|| vec![Box::new(BlockingTool)]),
        "mock:child-model".into(),
    );
    let child_events = ext.child_events.clone();
    let runs = ext.runs.clone();
    let active_slot = ext.active_child_cancel.clone();
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(ext)).unwrap();
    let registry = Arc::new(registry);

    let cmd = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": "run" }));
    let task = {
        let registry = Arc::clone(&registry);
        tokio::spawn(async move { registry.dispatch_command(&cmd).await.unwrap().unwrap() })
    };

    // Let the spawned run make progress until it publishes the child's token.
    let mut published = None;
    for _ in 0..MAX_YIELDS {
        tokio::task::yield_now().await;
        let current = active_slot.lock().unwrap().clone();
        if current.is_some() {
            published = current;
            break;
        }
    }
    published
        .expect("child run never published its cancellation token")
        .cancel();

    let result = task.await.unwrap();
    assert_eq!(result["status"], "cancelled");
    assert!(
        active_slot.lock().unwrap().is_none(),
        "active child token should be cleared once the run finishes"
    );

    let runs_guard = runs.lock().unwrap();
    assert_eq!(runs_guard.len(), 1);
    assert_eq!(runs_guard[0].status, "cancelled");

    let ce = child_events.lock().unwrap();
    assert!(
        ce.iter().any(|e| e == "ChildAgentStart"),
        "should have ChildAgentStart, got: {ce:?}"
    );
}
/// Every lifecycle event of a child run that uses a tool is recorded on the child event
/// log, and the run itself completes.
#[tokio::test]
async fn child_events_observable_by_parent() {
    let ext = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new(
                "child",
                vec![
                    tool_call_response("tc_1", "search", r#"{"query":"test"}"#),
                    text_response("Found results"),
                ],
            )) as Box<dyn Provider>
        }),
        Arc::new(|| vec![Box::new(DummyTool::new("search"))]),
        "mock:child-model".into(),
    );
    let child_events = ext.child_events.clone();
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(ext)).unwrap();
    let cmd = ExtensionCommand::new(
        "sub-agent/run",
        serde_json::json!({ "prompt": "search for X" }),
    );
    let result = registry.dispatch_command(&cmd).await.unwrap().unwrap();
    assert_eq!(result["status"], "completed");

    let ce = child_events.lock().unwrap();
    // Table-driven so a failure names the missing label and shows everything recorded.
    for expected in [
        "ChildAgentStart",
        "ChildTurnStart",
        "ChildToolExecutionStart(search)",
        "ChildToolExecutionEnd(search)",
        "ChildTurnEnd",
        "ChildAgentEnd",
    ] {
        assert!(
            ce.iter().any(|e| e == expected),
            "missing {expected}, got: {ce:?}"
        );
    }
}
/// Events pushed through the registry-wrapped sink reach the extension's `on_event`.
#[tokio::test]
async fn extension_receives_parent_agent_events() {
    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new("child", vec![text_response("ok")])) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let observed = extension.parent_events.clone();
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();

    let noop = Box::new(|_: AgentEvent| {}) as Box<dyn Fn(AgentEvent) + Send + Sync>;
    let sink = registry.wrap_event_sink(noop);
    for event in [
        AgentEvent::AgentStart,
        AgentEvent::TurnStart,
        AgentEvent::ToolExecutionStart {
            tool_call_id: "tc_1".into(),
            tool_name: "read".into(),
            args: serde_json::json!({}),
        },
    ] {
        sink(event);
    }

    let observed = observed.lock().unwrap();
    assert!(
        observed.iter().any(|e| e == "AgentStart"),
        "should have AgentStart"
    );
    assert!(
        observed.iter().any(|e| e == "TurnStart"),
        "should have TurnStart"
    );
    assert!(
        observed.iter().any(|e| e == "ToolExecutionStart(read)"),
        "should have ToolExecutionStart(read)"
    );
}
/// Each run gets its own provider from the factory, so consecutive runs see independent
/// scripted replies and are recorded separately, in order.
#[tokio::test]
async fn multiple_child_runs_have_isolated_state() {
    const CASES: [(&str, &str); 2] = [("first", "alpha"), ("second", "beta")];

    let factory_calls = Arc::new(AtomicU64::new(0));
    let extension = SubAgentExtension::new(
        Arc::new(move || {
            // The first provider answers "alpha"; every later one answers "beta".
            let reply = if factory_calls.fetch_add(1, Ordering::SeqCst) == 0 {
                "alpha"
            } else {
                "beta"
            };
            Box::new(MockProvider::new("child", vec![text_response(reply)])) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let history = extension.runs.clone();
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();

    for (prompt, reply) in CASES {
        let run = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": prompt }));
        let response = registry.dispatch_command(&run).await.unwrap().unwrap();
        assert_eq!(response["status"], "completed");
        assert_eq!(response["result"], reply);
    }

    let records = history.lock().unwrap();
    assert_eq!(records.len(), 2);
    for (record, (prompt, reply)) in records.iter().zip(CASES) {
        assert_eq!(record.prompt, prompt);
        assert_eq!(record.result.as_deref(), Some(reply));
    }
}
/// `sub-agent/list` reports every previous run in the order the runs finished.
#[tokio::test]
async fn run_history_visible_via_list_command() {
    const PROMPTS: [&str; 2] = ["task A", "task B"];

    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new("child", vec![text_response("ok")])) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();

    for prompt in PROMPTS {
        let run = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": prompt }));
        let _ = registry.dispatch_command(&run).await.unwrap().unwrap();
    }

    let list = ExtensionCommand::new("sub-agent/list", serde_json::json!({}));
    let listing = registry.dispatch_command(&list).await.unwrap().unwrap();
    let entries = listing["runs"].as_array().unwrap();
    assert_eq!(entries.len(), 2);
    for (entry, prompt) in entries.iter().zip(PROMPTS) {
        assert_eq!(entry["prompt"], prompt);
        assert_eq!(entry["status"], "completed");
    }
}
/// History serialized from one registry can be restored into a fresh one and is then
/// visible through `sub-agent/list`.
#[tokio::test]
async fn run_history_round_trips_through_serialization() {
    let build = || {
        SubAgentExtension::new(
            Arc::new(|| {
                Box::new(MockProvider::new("child", vec![text_response("ok")])) as Box<dyn Provider>
            }),
            Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
            "mock:child-model".into(),
        )
    };

    let mut source = ExtensionRegistry::new();
    source.register(Box::new(build())).unwrap();
    let run = ExtensionCommand::new("sub-agent/run", serde_json::json!({ "prompt": "test" }));
    let _ = source.dispatch_command(&run).await.unwrap().unwrap();
    let snapshot = source.serialize_states().unwrap();
    assert!(snapshot["sub-agent"]["runs"].is_array());

    let mut target = ExtensionRegistry::new();
    target.register(Box::new(build())).unwrap();
    target.restore_states(snapshot).unwrap();

    let list = ExtensionCommand::new("sub-agent/list", serde_json::json!({}));
    let listing = target.dispatch_command(&list).await.unwrap().unwrap();
    let entries = listing["runs"].as_array().unwrap();
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0]["prompt"], "test");
    assert_eq!(entries[0]["status"], "completed");
}
/// Commands outside the `sub-agent/*` set are not handled and yield `None`.
#[tokio::test]
async fn unknown_command_returns_none() {
    let extension = SubAgentExtension::new(
        Arc::new(|| {
            Box::new(MockProvider::new("child", vec![text_response("ok")])) as Box<dyn Provider>
        }),
        Arc::new(Vec::new as fn() -> Vec<Box<dyn Tool>>),
        "mock:child-model".into(),
    );
    let mut registry = ExtensionRegistry::new();
    registry.register(Box::new(extension)).unwrap();

    let unknown = ExtensionCommand::new("unknown/command", serde_json::json!({}));
    let outcome = registry.dispatch_command(&unknown).await.unwrap();
    assert!(
        outcome.is_none(),
        "unknown command should return None, got: {outcome:?}"
    );
}