use std::sync::{Arc, Mutex};
use crate::core::Swarm;
use crate::error::SwarmError;
use crate::event::{AgentEvent, EventSubscriber};
use crate::tool::{ClosureTool, InvocationArgs, Tool};
use crate::types::{
AgentFunction, AgentFunctionHandler, ContextVariables, Instructions, Message, ResultType,
};
use async_trait::async_trait;
use serde_json::json;
struct CollectingSubscriber {
events: Mutex<Vec<AgentEvent>>,
}
impl CollectingSubscriber {
fn new() -> Arc<Self> {
Arc::new(Self {
events: Mutex::new(Vec::new()),
})
}
fn collected(&self) -> Vec<AgentEvent> {
self.events.lock().unwrap().clone()
}
}
#[async_trait]
impl EventSubscriber for CollectingSubscriber {
async fn on_event(&self, event: &AgentEvent) {
self.events.lock().unwrap().push(event.clone());
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Agent;
use wiremock::matchers::method;
use wiremock::{Mock, MockServer, ResponseTemplate};
fn text_agent(name: &str, instructions: &str) -> Agent {
Agent::new(name, "gpt-4", Instructions::Text(instructions.to_string()))
.expect("agent creation failed")
}
#[tokio::test]
async fn test_closure_tool_executes() {
let fn_arc: Arc<AgentFunctionHandler> = Arc::new(|_ctx: ContextVariables| {
Box::pin(async move { Ok(ResultType::Value("hello from closure".to_string())) })
});
let agent_fn = AgentFunction::new("greet", fn_arc, false).expect("valid agent function");
let tool = ClosureTool::from_agent_function(agent_fn).with_description("greet the user");
assert_eq!(tool.name(), "greet");
assert_eq!(tool.description(), "greet the user");
let args = InvocationArgs::from_value(json!({})).unwrap();
let result = tool.execute(args).await.expect("execute failed");
assert_eq!(result, json!("hello from closure"));
}
#[test]
fn test_max_iterations_error_fields() {
let err = SwarmError::MaxIterationsError { max: 5, actual: 7 };
let msg = err.to_string();
assert!(msg.contains('5'), "expected max in message: {}", msg);
assert!(msg.contains('7'), "expected actual in message: {}", msg);
match err {
SwarmError::MaxIterationsError { max, actual } => {
assert_eq!(max, 5);
assert_eq!(actual, 7);
}
_ => panic!("wrong variant"),
}
}
#[tokio::test]
async fn test_loop_events_emitted() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 0,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "done"
},
"finish_reason": "stop"
}],
"usage": null
})))
.mount(&mock_server)
.await;
let collector = CollectingSubscriber::new();
let agent = text_agent("helper", "You are a helpful assistant.");
let swarm = Swarm::builder()
.with_api_key("sk-test".to_string())
.with_api_url(mock_server.uri())
.with_agent(agent.clone())
.with_subscriber(collector.clone())
.build()
.expect("build failed");
let messages = vec![Message::user("hi").expect("message failed")];
swarm
.run(
agent,
messages,
ContextVariables::new(),
None,
false,
false,
5,
)
.await
.expect("run failed");
let events = collector.collected();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::LoopStart { .. })),
"LoopStart not emitted; got: {:?}",
events.iter().map(|e| e.to_string()).collect::<Vec<_>>()
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::LlmRequest { .. })),
"LlmRequest not emitted"
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::LlmResponse { .. })),
"LlmResponse not emitted"
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::LoopEnd { .. })),
"LoopEnd not emitted"
);
}
#[tokio::test]
async fn test_open_ai_provider_complete() {
use crate::provider::{CompletionRequest, LlmProvider, OpenAiProvider};
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(
ResponseTemplate::new(200).set_body_raw(
json!({
"id": "cmp-1",
"object": "chat.completion",
"created": 0,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": { "role": "assistant", "content": "pong" },
"finish_reason": "stop"
}],
"usage": null
})
.to_string(),
"application/json",
),
)
.mount(&mock_server)
.await;
let provider = OpenAiProvider::new(reqwest::Client::new(), "sk-test", mock_server.uri());
let req = CompletionRequest::new("gpt-4", vec![Message::user("ping").expect("msg")]);
let resp = provider.complete(req).await.expect("complete failed");
assert_eq!(resp.text().as_deref(), Some("pong"));
assert_eq!(resp.model, "gpt-4");
}
}