use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::{AgentMetadata, OrchestrationError, OrchestrationResult, SimpleAgent, Task};
pub struct AgentPool {
agents: Arc<RwLock<HashMap<String, Arc<dyn SimpleAgent>>>>,
}
impl AgentPool {
pub fn new() -> Self {
Self {
agents: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_agent(&self, agent: Arc<dyn SimpleAgent>) -> OrchestrationResult<()> {
let agent_id = agent.agent_id().to_string();
let mut agents = self.agents.write().await;
if agents.contains_key(&agent_id) {
return Err(OrchestrationError::AgentPoolError(format!(
"Agent {} already registered",
agent_id
)));
}
agents.insert(agent_id, agent);
Ok(())
}
pub async fn unregister_agent(&self, agent_id: &str) -> OrchestrationResult<()> {
let mut agents = self.agents.write().await;
if agents.remove(agent_id).is_none() {
return Err(OrchestrationError::AgentNotFound(agent_id.to_string()));
}
Ok(())
}
pub async fn get_agent(&self, agent_id: &str) -> OrchestrationResult<Arc<dyn SimpleAgent>> {
let agents = self.agents.read().await;
agents
.get(agent_id)
.cloned()
.ok_or_else(|| OrchestrationError::AgentNotFound(agent_id.to_string()))
}
pub async fn find_suitable_agents(
&self,
task: &Task,
) -> OrchestrationResult<Vec<Arc<dyn SimpleAgent>>> {
let agents = self.agents.read().await;
let mut suitable_agents = Vec::new();
for agent in agents.values() {
if agent.can_handle_task(task) {
suitable_agents.push(agent.clone());
}
}
Ok(suitable_agents)
}
pub async fn list_agents(&self) -> Vec<Arc<dyn SimpleAgent>> {
let agents = self.agents.read().await;
agents.values().cloned().collect()
}
pub async fn get_all_metadata(&self) -> Vec<AgentMetadata> {
let agents = self.agents.read().await;
agents.values().map(|agent| agent.metadata()).collect()
}
pub async fn agent_count(&self) -> usize {
let agents = self.agents.read().await;
agents.len()
}
pub async fn has_agent(&self, agent_id: &str) -> bool {
let agents = self.agents.read().await;
agents.contains_key(agent_id)
}
}
impl Default for AgentPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ExampleAgent, TaskComplexity};
#[tokio::test]
async fn test_agent_pool_registration() {
let pool = AgentPool::new();
let agent = Arc::new(ExampleAgent::new(
"test_agent".to_string(),
vec!["capability1".to_string()],
));
assert!(pool.register_agent(agent.clone()).await.is_ok());
assert_eq!(pool.agent_count().await, 1);
assert!(pool.has_agent("test_agent").await);
assert!(pool.register_agent(agent).await.is_err());
}
#[tokio::test]
async fn test_agent_pool_unregistration() {
let pool = AgentPool::new();
let agent = Arc::new(ExampleAgent::new(
"test_agent".to_string(),
vec!["capability1".to_string()],
));
pool.register_agent(agent).await.unwrap();
assert_eq!(pool.agent_count().await, 1);
pool.unregister_agent("test_agent").await.unwrap();
assert_eq!(pool.agent_count().await, 0);
assert!(!pool.has_agent("test_agent").await);
assert!(pool.unregister_agent("non_existent").await.is_err());
}
#[tokio::test]
async fn test_find_suitable_agents() {
let pool = AgentPool::new();
let agent1 = Arc::new(ExampleAgent::new(
"agent1".to_string(),
vec!["analysis".to_string()],
));
let agent2 = Arc::new(ExampleAgent::new(
"agent2".to_string(),
vec!["processing".to_string()],
));
pool.register_agent(agent1).await.unwrap();
pool.register_agent(agent2).await.unwrap();
let mut task = Task::new(
"test_task".to_string(),
"Test task".to_string(),
TaskComplexity::Simple,
1,
);
task.required_capabilities.push("analysis".to_string());
let suitable_agents = pool.find_suitable_agents(&task).await.unwrap();
assert_eq!(suitable_agents.len(), 1);
assert_eq!(suitable_agents[0].agent_id(), "agent1");
}
#[tokio::test]
async fn test_get_agent() {
let pool = AgentPool::new();
let agent = Arc::new(ExampleAgent::new(
"test_agent".to_string(),
vec!["capability1".to_string()],
));
pool.register_agent(agent).await.unwrap();
let retrieved_agent = pool.get_agent("test_agent").await.unwrap();
assert_eq!(retrieved_agent.agent_id(), "test_agent");
assert!(pool.get_agent("non_existent").await.is_err());
}
#[tokio::test]
async fn test_list_agents() {
let pool = AgentPool::new();
let agent1 = Arc::new(ExampleAgent::new("agent1".to_string(), vec![]));
let agent2 = Arc::new(ExampleAgent::new("agent2".to_string(), vec![]));
pool.register_agent(agent1).await.unwrap();
pool.register_agent(agent2).await.unwrap();
let agents = pool.list_agents().await;
assert_eq!(agents.len(), 2);
let agent_ids: std::collections::HashSet<&str> =
agents.iter().map(|a| a.agent_id()).collect();
assert!(agent_ids.contains("agent1"));
assert!(agent_ids.contains("agent2"));
}
}