use std::pin::Pin;
use std::sync::{Arc, Mutex};
use opi_agent::event::{AgentEvent, AgentEventSink};
use opi_agent::hooks::{
AgentHooks, BeforeToolCallContext, BeforeToolCallResult, ShouldStopAfterTurnContext,
};
use opi_agent::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
use opi_agent::message::AgentMessage;
use opi_agent::tool::{Tool, ToolError, ToolResult};
use opi_ai::message::{InputContent, Message, UserMessage};
use opi_ai::test_support::{self, MockProvider};
use serde_json::json;
use tokio_util::sync::CancellationToken;
struct MockTool {
name: String,
call_log: Arc<Mutex<Vec<serde_json::Value>>>,
}
impl MockTool {
fn new(name: &str, call_log: Arc<Mutex<Vec<serde_json::Value>>>) -> Self {
Self {
name: name.to_owned(),
call_log,
}
}
}
impl Tool for MockTool {
fn definition(&self) -> opi_ai::message::ToolDef {
opi_ai::message::ToolDef {
name: self.name.clone(),
description: format!("Mock tool: {}", self.name),
input_schema: json!({
"type": "object",
"properties": {
"arg": { "type": "string" }
},
"required": ["arg"]
}),
}
}
fn execute(
&self,
_call_id: &str,
arguments: serde_json::Value,
_signal: CancellationToken,
_on_update: Option<opi_agent::tool::UpdateCallback>,
) -> Pin<Box<dyn std::future::Future<Output = Result<ToolResult, ToolError>> + Send>> {
let log = self.call_log.clone();
log.lock().unwrap().push(arguments.clone());
let text = arguments
.get("arg")
.and_then(|v| v.as_str())
.unwrap_or("mock-result")
.to_owned();
Box::pin(async move {
Ok(ToolResult {
content: vec![opi_ai::message::OutputContent::Text {
text: format!("tool-result: {text}"),
}],
details: None,
is_error: false,
terminate: false,
})
})
}
}
struct TestHooks;
impl AgentHooks for TestHooks {
fn convert_to_llm(&self, messages: &[AgentMessage]) -> Result<Vec<Message>, AgentError> {
let mut result = Vec::new();
for msg in messages {
if let AgentMessage::Llm(m) = msg {
result.push(m.clone());
}
}
Ok(result)
}
fn should_stop_after_turn(
&self,
_ctx: ShouldStopAfterTurnContext,
) -> Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
Box::pin(async { false })
}
fn before_tool_call(
&self,
_ctx: BeforeToolCallContext,
) -> Pin<Box<dyn std::future::Future<Output = BeforeToolCallResult> + Send>> {
Box::pin(async { BeforeToolCallResult::Allow })
}
}
#[tokio::test]
async fn no_tool_turn_emits_lifecycle_events() {
let response_events = test_support::text_response("Hello!");
let provider = MockProvider::new("mock", vec![response_events]);
let collected_events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let event_collector = collected_events.clone();
let sink: AgentEventSink = Box::new(move |event| {
let name = match &event {
AgentEvent::AgentStart => "AgentStart",
AgentEvent::AgentEnd { .. } => "AgentEnd",
AgentEvent::TurnStart => "TurnStart",
AgentEvent::TurnEnd { .. } => "TurnEnd",
AgentEvent::MessageStart { .. } => "MessageStart",
AgentEvent::MessageUpdate { .. } => "MessageUpdate",
AgentEvent::MessageEnd { .. } => "MessageEnd",
AgentEvent::ToolExecutionStart { .. } => "ToolExecutionStart",
AgentEvent::ToolExecutionUpdate { .. } => "ToolExecutionUpdate",
AgentEvent::ToolExecutionEnd { .. } => "ToolExecutionEnd",
AgentEvent::QueueUpdate { .. } => "QueueUpdate",
_ => "Unknown",
};
event_collector.lock().unwrap().push(name.to_owned());
});
let context = AgentLoopContext {
provider: Box::new(provider),
tools: vec![],
messages: vec![AgentMessage::Llm(Message::User(UserMessage {
content: vec![InputContent::Text { text: "Hi".into() }],
timestamp_ms: 0,
}))],
model: "mock-model".into(),
system: None,
steering_queue: None,
follow_up_queue: None,
};
let config = AgentLoopConfig {
max_turns: 10,
..Default::default()
};
let hooks = TestHooks;
let result = opi_agent::agent_loop(context, config, &hooks, sink, CancellationToken::new())
.await
.unwrap();
assert!(result.len() >= 2);
let events = collected_events.lock().unwrap();
assert!(
events.contains(&"AgentStart".to_owned()),
"missing AgentStart"
);
assert!(events.contains(&"AgentEnd".to_owned()), "missing AgentEnd");
assert!(
events.contains(&"TurnStart".to_owned()),
"missing TurnStart"
);
assert!(events.contains(&"TurnEnd".to_owned()), "missing TurnEnd");
}
#[tokio::test]
async fn tool_use_turn_executes_tool_and_loops() {
let tool_call_log: Arc<Mutex<Vec<serde_json::Value>>> = Arc::new(Mutex::new(Vec::new()));
let first_response = test_support::tool_call_response("tc-1", "greet", r#"{"arg":"world"}"#);
let second_response = test_support::text_response("Done!");
let provider = MockProvider::new("mock", vec![first_response, second_response]);
let tool = MockTool::new("greet", tool_call_log.clone());
let collected_events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let event_collector = collected_events.clone();
let sink: AgentEventSink = Box::new(move |event| {
let name = match &event {
AgentEvent::AgentStart => "AgentStart",
AgentEvent::AgentEnd { .. } => "AgentEnd",
AgentEvent::TurnStart => "TurnStart",
AgentEvent::TurnEnd { .. } => "TurnEnd",
AgentEvent::MessageStart { .. } => "MessageStart",
AgentEvent::MessageUpdate { .. } => "MessageUpdate",
AgentEvent::MessageEnd { .. } => "MessageEnd",
AgentEvent::ToolExecutionStart { tool_name, .. } => {
event_collector
.lock()
.unwrap()
.push(format!("ToolExecutionStart:{tool_name}"));
return;
}
AgentEvent::ToolExecutionUpdate { .. } => "ToolExecutionUpdate",
AgentEvent::ToolExecutionEnd {
tool_name,
is_error,
..
} => {
let status = if *is_error { "err" } else { "ok" };
event_collector
.lock()
.unwrap()
.push(format!("ToolExecutionEnd:{tool_name}:{status}"));
return;
}
AgentEvent::QueueUpdate { .. } => "QueueUpdate",
_ => "Unknown",
};
event_collector.lock().unwrap().push(name.to_owned());
});
let context = AgentLoopContext {
provider: Box::new(provider),
tools: vec![Box::new(tool)],
messages: vec![AgentMessage::Llm(Message::User(UserMessage {
content: vec![InputContent::Text {
text: "Use the tool".into(),
}],
timestamp_ms: 0,
}))],
model: "mock-model".into(),
system: None,
steering_queue: None,
follow_up_queue: None,
};
let config = AgentLoopConfig {
max_turns: 10,
..Default::default()
};
let hooks = TestHooks;
let result = opi_agent::agent_loop(context, config, &hooks, sink, CancellationToken::new())
.await
.unwrap();
let log = tool_call_log.lock().unwrap();
assert_eq!(log.len(), 1, "tool should have been called exactly once");
assert_eq!(log[0]["arg"], "world");
assert!(
result.len() >= 3,
"expected at least 3 messages, got {}",
result.len()
);
let events = collected_events.lock().unwrap();
assert!(
events.iter().any(|e| e == "ToolExecutionStart:greet"),
"missing ToolExecutionStart for greet"
);
assert!(
events.iter().any(|e| e == "ToolExecutionEnd:greet:ok"),
"missing ToolExecutionEnd(ok) for greet"
);
let turn_starts = events.iter().filter(|e| *e == "TurnStart").count();
assert!(
turn_starts >= 2,
"expected at least 2 TurnStart events, got {}",
turn_starts
);
}
#[tokio::test]
async fn text_content_preserved_in_assistant_message() {
let response_events = test_support::text_response("Hello, world!");
let provider = MockProvider::new("mock", vec![response_events]);
let sink: AgentEventSink = Box::new(|_| {});
let context = AgentLoopContext {
provider: Box::new(provider),
tools: vec![],
messages: vec![AgentMessage::Llm(Message::User(UserMessage {
content: vec![InputContent::Text { text: "Hi".into() }],
timestamp_ms: 0,
}))],
model: "mock-model".into(),
system: None,
steering_queue: None,
follow_up_queue: None,
};
let config = AgentLoopConfig {
max_turns: 10,
..Default::default()
};
let hooks = TestHooks;
let result = opi_agent::agent_loop(context, config, &hooks, sink, CancellationToken::new())
.await
.unwrap();
let assistant = result
.iter()
.find_map(|m| match m {
AgentMessage::Llm(Message::Assistant(a)) => Some(a),
_ => None,
})
.expect("should have assistant message");
let has_text = assistant.content.iter().any(
|c| matches!(c, opi_ai::message::AssistantContent::Text { text } if text.contains("Hello")),
);
assert!(
has_text,
"assistant message must contain text, got: {:?}",
assistant.content
);
}