use phi_core::provider::mock::*;
use phi_core::provider::{MockProvider, ModelConfig};
use phi_core::BasicAgent;
use phi_core::{LlmMessage, *};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
/// A single prompt against a text-only mock provider should emit at least one
/// event and leave exactly two messages in the agent's history
/// (presumably the user prompt and the assistant reply — confirm against `BasicAgent`).
#[tokio::test]
async fn test_agent_simple_prompt() {
    let mock = MockProvider::text("Hello!");
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(mock))
        .with_system_prompt("You are helpful.");

    let mut receiver = agent.prompt("Hi there").await;

    // Pull events until `try_recv` reports its first error.
    let drained: Vec<_> = std::iter::from_fn(|| receiver.try_recv().ok()).collect();

    assert!(!drained.is_empty());
    assert_eq!(agent.messages().len(), 2);
}
/// After a prompt has populated the history, `reset` should leave the agent
/// with no messages and not streaming.
#[tokio::test]
async fn test_agent_reset() {
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Hello!")))
        .with_system_prompt("test");

    // The returned receiver is not needed here; only the history matters.
    drop(agent.prompt("Hi").await);
    assert!(!agent.messages().is_empty());

    agent.reset();

    assert!(agent.messages().is_empty());
    assert!(!agent.is_streaming());
}
/// When the mock provider first requests the `echo` tool and then answers
/// with plain text, the prompt should finish with four messages in history
/// (presumably prompt, tool call, tool result, final reply — confirm against `BasicAgent`).
#[tokio::test]
async fn test_agent_with_tools() {
    /// Minimal tool that returns its `text` argument as text content.
    struct EchoTool;

    #[async_trait::async_trait]
    impl AgentTool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn label(&self) -> &str {
            "Echo"
        }

        fn description(&self) -> &str {
            "Echoes input"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}})
        }

        async fn execute(
            &self,
            params: serde_json::Value,
            _ctx: ToolContext,
        ) -> Result<ToolResult, ToolError> {
            // A missing or non-string `text` argument falls back to the empty string.
            let echoed = params
                .get("text")
                .and_then(|value| value.as_str())
                .unwrap_or("")
                .to_string();

            Ok(ToolResult {
                content: vec![Content::Text { text: echoed }],
                details: serde_json::Value::Null,
                child_loop_id: None,
            })
        }
    }

    // Scripted provider output: one tool-call response, then the closing text.
    let script = vec![
        MockResponse::ToolCalls(vec![MockToolCall {
            name: "echo".into(),
            arguments: serde_json::json!({"text": "hello"}),
        }]),
        MockResponse::Text("Echoed: hello".into()),
    ];

    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::new(script)))
        .with_system_prompt("test")
        .with_tools(vec![Arc::new(EchoTool)]);

    drop(agent.prompt("Echo hello").await);

    assert_eq!(agent.messages().len(), 4);
}
/// Values passed through the `with_*` builder methods and `ModelConfig::anthropic`
/// should be readable back from the agent's fields.
#[tokio::test]
async fn test_agent_builder_pattern() {
    let config = ModelConfig::anthropic("test-model", "test-model", "key123");
    let agent = BasicAgent::new(config)
        .with_provider_override(Arc::new(MockProvider::text("ok")))
        .with_system_prompt("sys")
        .with_thinking(ThinkingLevel::Medium)
        .with_max_tokens(4096);

    // Builder-supplied settings.
    assert_eq!(agent.system_prompt, "sys");

    // Model configuration carried over from the constructor argument.
    assert_eq!(agent.model_config.id, "test-model");
    assert_eq!(agent.model_config.api_key, "key123");

    // Remaining builder-supplied settings.
    assert_eq!(agent.thinking_level, ThinkingLevel::Medium);
    assert_eq!(agent.max_tokens, Some(4096));
}
/// An agent built with `with_messages` should expose exactly the messages it
/// was given, in the same order.
#[tokio::test]
async fn test_with_messages_builder() {
    let user_turn = AgentMessage::Llm(LlmMessage::new(Message::user("Hello")));
    let assistant_turn = AgentMessage::Llm(LlmMessage::new(Message::Assistant {
        content: vec![Content::Text {
            text: "Hi there!".into(),
        }],
        stop_reason: StopReason::Stop,
        model: "mock".into(),
        provider: "mock".into(),
        usage: Usage::default(),
        timestamp: 0,
        error_message: None,
    }));
    let saved = vec![user_turn, assistant_turn];

    // Clone so the original vector stays available for the comparison below.
    let agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("ok")))
        .with_messages(saved.clone());

    assert_eq!(agent.messages().len(), 2);
    assert_eq!(*agent.messages(), saved[..]);
}
/// Messages saved from one agent and restored into a fresh agent should
/// compare equal to the original agent's messages.
#[tokio::test]
async fn test_save_and_restore_messages() {
    let mut original = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Hello!")))
        .with_system_prompt("test");
    drop(original.prompt("Hi").await);

    let snapshot = original.save_messages().expect("save should succeed");

    // A second, independent agent that only receives the saved snapshot.
    let mut restored = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("ok")))
        .with_system_prompt("test");
    restored
        .restore_messages(&snapshot)
        .expect("restore should succeed");

    assert_eq!(original.messages(), restored.messages());
}
/// An agent that restores a saved one-prompt conversation and is then prompted
/// again should hold four messages alternating user / assistant.
#[tokio::test]
async fn test_agent_continues_after_restore() {
    // First agent: one prompt, then save its history.
    let mut first = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("First response")))
        .with_system_prompt("test");
    drop(first.prompt("Hello").await);
    let snapshot = first.save_messages().expect("save");

    // Second agent: restore the saved history, then prompt once more.
    let mut second = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Second response")))
        .with_system_prompt("test");
    second.restore_messages(&snapshot).expect("restore");
    drop(second.prompt("Follow up").await);

    assert_eq!(second.messages().len(), 4);

    let expected_roles = ["user", "assistant", "user", "assistant"];
    for (position, expected) in expected_roles.iter().enumerate() {
        assert_eq!(second.messages()[position].role(), *expected);
    }
}
/// `prompt_with_sender` should push at least one event into the caller's
/// channel, leave two messages in history, and leave the agent not streaming.
#[tokio::test]
async fn test_prompt_with_sender_streams_events() {
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Hello!")))
        .with_system_prompt("test");

    let (sender, mut receiver) = mpsc::unbounded_channel();
    let seen = Arc::new(AtomicUsize::new(0));

    // Background task that counts events until `recv` yields `None`.
    let consumer = {
        let seen = Arc::clone(&seen);
        tokio::spawn(async move {
            while receiver.recv().await.is_some() {
                seen.fetch_add(1, Ordering::SeqCst);
            }
        })
    };

    agent.prompt_with_sender("Hi there", sender).await;
    consumer.await.unwrap();

    assert!(seen.load(Ordering::SeqCst) > 0);
    assert_eq!(agent.messages().len(), 2);
    assert!(!agent.is_streaming());
}
/// On a two-worker multi-threaded runtime, a spawned consumer should receive
/// at least one event from `prompt_with_sender`, and the history should end
/// up with two messages.
///
/// NOTE(review): the counter is only read after the prompt and the consumer
/// have both finished, so this does not by itself show that events arrived
/// while the prompt was still running — confirm whether that is intended.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_prompt_with_sender_real_time_streaming() {
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Hello!")))
        .with_system_prompt("test");

    let (sender, mut receiver) = mpsc::unbounded_channel();
    let received = Arc::new(AtomicUsize::new(0));

    // Counting task; it finishes once `recv` yields `None`.
    let consumer = {
        let received = Arc::clone(&received);
        tokio::spawn(async move {
            while receiver.recv().await.is_some() {
                received.fetch_add(1, Ordering::SeqCst);
            }
        })
    };

    agent.prompt_with_sender("Hello", sender).await;
    consumer.await.unwrap();

    assert!(received.load(Ordering::SeqCst) > 0);
    assert_eq!(agent.messages().len(), 2);
}
/// `prompt_messages_with_sender` given a single user message should emit at
/// least one event and leave two messages in history.
#[tokio::test]
async fn test_prompt_messages_with_sender() {
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Response")))
        .with_system_prompt("test");

    let (sender, mut receiver) = mpsc::unbounded_channel();

    // Collect every event until `recv` yields `None`, then return the batch.
    let collector = tokio::spawn(async move {
        let mut collected = Vec::new();
        while let Some(next) = receiver.recv().await {
            collected.push(next);
        }
        collected
    });

    let input = vec![AgentMessage::Llm(LlmMessage::new(Message::user("Hello")))];
    agent.prompt_messages_with_sender(input, sender).await;

    let collected = collector.await.unwrap();
    assert!(!collected.is_empty());
    assert_eq!(agent.messages().len(), 2);
}
/// Continuing the loop from a pre-seeded history (user message, errored
/// assistant message, follow-up user message) with `ContinuationKind::Default`
/// should emit at least one event and leave the agent not streaming.
#[tokio::test]
async fn test_continue_loop_with_sender() {
    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Continued response")))
        .with_system_prompt("test");

    // Seed the history in order; the assistant turn carries an error stop reason.
    let seeded = vec![
        AgentMessage::Llm(LlmMessage::new(Message::user("Hello"))),
        AgentMessage::Llm(LlmMessage::new(Message::Assistant {
            content: vec![Content::Text { text: "Hi!".into() }],
            stop_reason: StopReason::Error,
            model: "mock".into(),
            provider: "mock".into(),
            usage: Usage::default(),
            timestamp: 0,
            error_message: Some("rate limited".into()),
        })),
        AgentMessage::Llm(LlmMessage::new(Message::user("Please try again"))),
    ];
    for message in seeded {
        agent.append_message(message);
    }

    let (sender, mut receiver) = mpsc::unbounded_channel();

    // Collect every event until `recv` yields `None`, then return the batch.
    let collector = tokio::spawn(async move {
        let mut collected = Vec::new();
        while let Some(next) = receiver.recv().await {
            collected.push(next);
        }
        collected
    });

    agent
        .continue_loop_with_sender(sender, ContinuationKind::Default)
        .await;

    let collected = collector.await.unwrap();
    assert!(!collected.is_empty());
    assert!(!agent.is_streaming());
}
/// After `prompt_with_sender` completes on a tool-equipped agent, the agent
/// should not be streaming and a follow-up `prompt` should bring the history
/// to four messages.
///
/// NOTE(review): the tool list itself is never inspected; the test only
/// observes that a second prompt still succeeds — confirm that is sufficient.
#[tokio::test]
async fn test_prompt_with_sender_tools_restored() {
    /// Placeholder tool that ignores its input and returns the text "ok".
    struct DummyTool;

    #[async_trait::async_trait]
    impl AgentTool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn label(&self) -> &str {
            "Dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(
            &self,
            _params: serde_json::Value,
            _ctx: ToolContext,
        ) -> Result<ToolResult, ToolError> {
            let reply = Content::Text { text: "ok".into() };
            Ok(ToolResult {
                content: vec![reply],
                details: serde_json::Value::Null,
                child_loop_id: None,
            })
        }
    }

    let mut agent = BasicAgent::new(ModelConfig::anthropic("mock", "mock", "test"))
        .with_provider_override(Arc::new(MockProvider::text("Hello!")))
        .with_system_prompt("test")
        .with_tools(vec![Arc::new(DummyTool)]);

    // First round: stream into a caller-owned channel and discard the events.
    let (sender, mut receiver) = mpsc::unbounded_channel();
    let sink = tokio::spawn(async move { while receiver.recv().await.is_some() {} });
    agent.prompt_with_sender("Hi", sender).await;
    sink.await.unwrap();
    assert!(!agent.is_streaming());

    // Second round: a regular prompt; drain events until `try_recv` errors.
    let mut follow_up = agent.prompt("Follow up").await;
    while follow_up.try_recv().is_ok() {}

    assert_eq!(agent.messages().len(), 4);
}