use std::pin::Pin;
use std::sync::{Arc, Mutex};
use futures_util::StreamExt;
use futures_util::stream;
use opi_agent::agent::Agent;
use opi_agent::event::AgentEvent;
use opi_agent::hooks::{
AgentHooks, BeforeToolCallContext, BeforeToolCallResult, ShouldStopAfterTurnContext,
};
use opi_agent::loop_types::{AgentError, AgentLoopConfig};
use opi_agent::message::AgentMessage;
use opi_ai::message::{AssistantContent, AssistantMessage, InputContent, Message};
use opi_ai::provider::{EventStream, Provider, ProviderError, Request};
use opi_ai::stream::{AssistantStreamEvent, StopReason, Usage};
/// Scripted provider used by the tests: each request is answered with the next
/// queued sequence of stream events.
struct MockProvider {
    /// Identifier reported through `Provider::id`.
    id: String,
    /// Pending scripted responses, one event sequence per expected request.
    responses: Arc<Mutex<Vec<Vec<AssistantStreamEvent>>>>,
}
impl MockProvider {
    /// Creates a mock named `id` that will replay `responses` in order.
    fn new(id: &str, responses: Vec<Vec<AssistantStreamEvent>>) -> Self {
        let queue = Arc::new(Mutex::new(responses));
        let name = String::from(id);
        Self {
            id: name,
            responses: queue,
        }
    }
}
impl Provider for MockProvider {
    fn id(&self) -> &str {
        &self.id
    }
    fn models(&self) -> &[opi_ai::provider::ModelInfo] {
        &[]
    }
    /// Replays the next scripted response (front of the queue) as a stream of `Ok` events.
    ///
    /// # Panics
    /// Panics with a descriptive message if the agent issues more requests than
    /// responses were scripted. Without this, the failure would be an opaque
    /// `Vec::remove` index panic.
    fn stream(&self, _request: Request) -> EventStream {
        // Pop inside a scope so the mutex guard is released before any panic below;
        // otherwise the panic would also poison the mutex.
        let next = {
            let mut queue = self.responses.lock().unwrap();
            (!queue.is_empty()).then(|| queue.remove(0))
        };
        let events = next.unwrap_or_else(|| {
            panic!(
                "MockProvider `{}` received more requests than scripted responses",
                self.id
            )
        });
        Box::pin(stream::iter(events.into_iter().map(Ok::<_, ProviderError>)))
    }
}
/// Minimal hooks for the tests: forward only LLM messages, never stop early,
/// and allow every tool call.
struct TestHooks;
impl AgentHooks for TestHooks {
    fn convert_to_llm(&self, messages: &[AgentMessage]) -> Result<Vec<Message>, AgentError> {
        // Keep the wrapped LLM message of every `AgentMessage::Llm`; drop anything else.
        let llm_only: Vec<Message> = messages
            .iter()
            .filter_map(|entry| {
                if let AgentMessage::Llm(inner) = entry {
                    Some(inner.clone())
                } else {
                    None
                }
            })
            .collect();
        Ok(llm_only)
    }
    fn should_stop_after_turn(
        &self,
        _context: ShouldStopAfterTurnContext,
    ) -> Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
        // Never request an early stop.
        Box::pin(std::future::ready(false))
    }
    fn before_tool_call(
        &self,
        _context: BeforeToolCallContext,
    ) -> Pin<Box<dyn std::future::Future<Output = BeforeToolCallResult> + Send>> {
        // Every tool call is permitted.
        let verdict = async { BeforeToolCallResult::Allow };
        Box::pin(verdict)
    }
}
/// Builds an empty assistant message with these fixed values:
/// - provider `"mock"`, model `"mock-model"`, API `ApiKind::Anthropic`;
/// - default usage and `StopReason::Stop`;
/// - no response metadata, no error, and timestamp 0.
fn base_assistant() -> AssistantMessage {
    AssistantMessage {
        // Attribution.
        api: opi_ai::ApiKind::Anthropic,
        provider: "mock".into(),
        model: "mock-model".into(),
        // Payload and outcome.
        content: Vec::new(),
        stop_reason: StopReason::Stop,
        usage: Usage::default(),
        error_message: None,
        // Optional response metadata, left unset.
        response_model: None,
        response_id: None,
        timestamp_ms: 0,
    }
}
fn text_response(text: &str) -> Vec<AssistantStreamEvent> {
let mut partial = base_assistant();
partial
.content
.push(AssistantContent::Text { text: text.into() });
vec![
AssistantStreamEvent::Start {
partial: base_assistant(),
},
AssistantStreamEvent::TextDelta {
content_index: 0,
delta: text.into(),
partial: partial.clone(),
},
AssistantStreamEvent::Done {
reason: StopReason::Stop,
message: partial,
},
]
}
#[tokio::test]
async fn prompt_sends_user_message_and_returns_result() {
    // One scripted reply is enough: a single prompt drives a single provider call.
    let scripted = vec![text_response("Hello!")];
    let mut agent = Agent::new(
        Box::new(MockProvider::new("mock", scripted)),
        vec![],
        "mock-model".into(),
        None,
        AgentLoopConfig::default(),
        Box::new(TestHooks),
    );

    let messages = agent.prompt("Hi there").await.unwrap();
    let count = messages.len();
    assert!(count >= 2, "expected at least 2 messages, got {}", count);

    // The transcript must open with the user's prompt, verbatim.
    let user_msg = match &messages[0] {
        AgentMessage::Llm(Message::User(msg)) => msg,
        _ => panic!("first message should be user message"),
    };
    match &user_msg.content[0] {
        InputContent::Text { text } => assert_eq!(text, "Hi there"),
        _ => panic!("expected text content"),
    }
}
#[tokio::test]
async fn prompt_accumulates_state_across_calls() {
let provider = MockProvider::new(
"mock",
vec![text_response("First"), text_response("Second")],
);
let mut agent = Agent::new(
Box::new(provider),
vec![],
"mock-model".into(),
None,
AgentLoopConfig::default(),
Box::new(TestHooks),
);
let r1 = agent.prompt("Hello").await.unwrap();
assert!(r1.len() >= 2);
let r2 = agent.prompt("World").await.unwrap();
assert!(
r2.len() >= 4,
"expected at least 4 messages after two prompts, got {}",
r2.len()
);
}
#[tokio::test]
async fn continue_appends_message_and_runs_loop() {
    // Two scripted replies: one for the initial prompt, one for the continuation.
    let scripted = vec![text_response("First"), text_response("Continued")];
    let mut agent = Agent::new(
        Box::new(MockProvider::new("mock", scripted)),
        vec![],
        "mock-model".into(),
        None,
        AgentLoopConfig::default(),
        Box::new(TestHooks),
    );

    let after_prompt = agent.prompt("Hello").await.unwrap();
    assert!(after_prompt.len() >= 2);

    let after_continue = agent.continue_("Tell me more").await.unwrap();
    let total = after_continue.len();
    assert!(
        total >= 4,
        "expected at least 4 messages after prompt+continue, got {}",
        total
    );
}
#[tokio::test]
async fn abort_cancels_running_loop() {
struct BlockingProvider;
impl Provider for BlockingProvider {
fn id(&self) -> &str {
"blocking"
}
fn models(&self) -> &[opi_ai::provider::ModelInfo] {
&[]
}
fn stream(&self, request: Request) -> EventStream {
let cancel = request.cancel;
Box::pin(
futures_util::stream::once(async move {
Ok(AssistantStreamEvent::Start {
partial: base_assistant(),
})
})
.chain(futures_util::stream::unfold((), move |()| {
let cancel = cancel.clone();
async move {
cancel.cancelled().await;
None }
})),
)
}
}
let mut agent = Agent::new(
Box::new(BlockingProvider),
vec![],
"mock-model".into(),
None,
AgentLoopConfig::default(),
Box::new(TestHooks),
);
let token = agent.cancel_token();
let handle = tokio::spawn(async move { agent.prompt("Hello").await });
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
token.cancel();
let result = handle.await.unwrap();
assert!(
matches!(result, Err(AgentError::Cancelled)),
"expected Cancelled error, got {:?}",
result
);
}
#[tokio::test]
async fn subscribe_receives_events() {
let provider = MockProvider::new("mock", vec![text_response("Response")]);
let collected: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let mut agent = Agent::new(
Box::new(provider),
vec![],
"mock-model".into(),
None,
AgentLoopConfig::default(),
Box::new(TestHooks),
);
agent.subscribe(Box::new(move |event| {
let name = match event {
AgentEvent::AgentStart => "AgentStart",
AgentEvent::AgentEnd { .. } => "AgentEnd",
AgentEvent::TurnStart => "TurnStart",
AgentEvent::TurnEnd { .. } => "TurnEnd",
AgentEvent::MessageStart { .. } => "MessageStart",
AgentEvent::MessageUpdate { .. } => "MessageUpdate",
AgentEvent::MessageEnd { .. } => "MessageEnd",
AgentEvent::ToolExecutionStart { .. } => "ToolExecutionStart",
AgentEvent::ToolExecutionUpdate { .. } => "ToolExecutionUpdate",
AgentEvent::ToolExecutionEnd { .. } => "ToolExecutionEnd",
_ => "Unknown",
};
collected_clone.lock().unwrap().push(name.to_owned());
}));
let result = agent.prompt("Hello").await.unwrap();
assert!(result.len() >= 2);
let events = collected.lock().unwrap();
assert!(
events.contains(&"AgentStart".to_owned()),
"subscriber should receive AgentStart, got {:?}",
*events
);
assert!(
events.contains(&"AgentEnd".to_owned()),
"subscriber should receive AgentEnd, got {:?}",
*events
);
}