agentic_payments/system/
pool.rs

1//! Agent pool management
2
3use crate::agents::{BasicVerificationAgent, VerificationAgent};
4use crate::error::{Error, Result};
5use dashmap::DashMap;
6use std::sync::Arc;
7use uuid::Uuid;
8
9/// Agent pool for managing verification agents
10pub struct AgentPool {
11    agents: DashMap<Uuid, Arc<dyn VerificationAgent>>,
12    max_size: usize,
13}
14
15impl AgentPool {
16    /// Create a new agent pool
17    pub fn new(max_size: usize) -> Self {
18        Self {
19            agents: DashMap::new(),
20            max_size,
21        }
22    }
23
24    /// Add an agent to the pool
25    pub async fn add_agent(&mut self, agent: Arc<dyn VerificationAgent>) -> Result<()> {
26        if self.agents.len() >= self.max_size {
27            return Err(Error::agent_pool(format!(
28                "Pool is full (max: {})",
29                self.max_size
30            )));
31        }
32
33        let agent_id = agent.id();
34
35        // Perform health check before adding
36        agent.health_check().await?;
37
38        self.agents.insert(agent_id, agent);
39        tracing::info!("Added agent {} to pool", agent_id);
40
41        Ok(())
42    }
43
44    /// Remove an agent from the pool
45    pub async fn remove_agent(&mut self, agent_id: Uuid) -> Result<()> {
46        self.agents
47            .remove(&agent_id)
48            .ok_or_else(|| Error::agent_pool(format!("Agent {} not found in pool", agent_id)))?;
49
50        tracing::info!("Removed agent {} from pool", agent_id);
51        Ok(())
52    }
53
54    /// Get an agent by ID
55    pub fn get_agent(&self, agent_id: Uuid) -> Option<Arc<dyn VerificationAgent>> {
56        self.agents.get(&agent_id).map(|r| Arc::clone(&r))
57    }
58
59    /// Get all agents
60    pub fn get_all_agents(&self) -> Vec<Arc<dyn VerificationAgent>> {
61        self.agents.iter().map(|r| Arc::clone(&r)).collect()
62    }
63
64    /// Get pool size
65    pub fn size(&self) -> usize {
66        self.agents.len()
67    }
68
69    /// Check if pool is empty
70    pub fn is_empty(&self) -> bool {
71        self.agents.is_empty()
72    }
73
74    /// Scale the pool to target size
75    pub async fn scale(&mut self, target_size: usize) -> Result<()> {
76        if target_size > self.max_size {
77            return Err(Error::agent_pool(format!(
78                "Target size {} exceeds maximum {}",
79                target_size, self.max_size
80            )));
81        }
82
83        let current_size = self.size();
84
85        if target_size > current_size {
86            // Scale up - add agents
87            let agents_to_add = target_size - current_size;
88            for _ in 0..agents_to_add {
89                let agent = BasicVerificationAgent::new()?;
90                self.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
91                    .await?;
92            }
93            tracing::info!("Scaled pool up from {} to {} agents", current_size, target_size);
94        } else if target_size < current_size {
95            // Scale down - remove agents
96            let agents_to_remove = current_size - target_size;
97            let agent_ids: Vec<Uuid> = self.agents.iter().take(agents_to_remove).map(|r| *r.key()).collect();
98
99            for agent_id in agent_ids {
100                self.remove_agent(agent_id).await?;
101            }
102            tracing::info!("Scaled pool down from {} to {} agents", current_size, target_size);
103        }
104
105        Ok(())
106    }
107
108    /// Perform health check on all agents
109    pub async fn health_check_all(&self) -> Result<()> {
110        let mut unhealthy_agents = Vec::new();
111
112        for entry in self.agents.iter() {
113            let agent_id = *entry.key();
114            let agent = entry.value();
115
116            if let Err(e) = agent.health_check().await {
117                tracing::warn!("Agent {} failed health check: {}", agent_id, e);
118                unhealthy_agents.push(agent_id);
119            }
120        }
121
122        if !unhealthy_agents.is_empty() {
123            return Err(Error::health_check(format!(
124                "{} agents failed health check: {:?}",
125                unhealthy_agents.len(),
126                unhealthy_agents
127            )));
128        }
129
130        Ok(())
131    }
132
133    /// Shutdown all agents and clear the pool
134    pub async fn shutdown(&mut self) -> Result<()> {
135        tracing::info!("Shutting down agent pool with {} agents", self.size());
136        self.agents.clear();
137        tracing::info!("Agent pool shutdown complete");
138        Ok(())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = AgentPool::new(10);
        assert_eq!(pool.size(), 0);
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_add_agent() {
        let mut pool = AgentPool::new(10);
        let agent = BasicVerificationAgent::new().unwrap();
        let agent_id = agent.id();

        pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
            .await
            .unwrap();

        assert_eq!(pool.size(), 1);
        assert!(pool.get_agent(agent_id).is_some());
    }

    #[tokio::test]
    async fn test_remove_agent() {
        let mut pool = AgentPool::new(10);
        let agent = BasicVerificationAgent::new().unwrap();
        let agent_id = agent.id();

        pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
            .await
            .unwrap();
        pool.remove_agent(agent_id).await.unwrap();

        assert_eq!(pool.size(), 0);
        assert!(pool.get_agent(agent_id).is_none());
    }

    #[tokio::test]
    async fn test_remove_missing_agent_fails() {
        let mut pool = AgentPool::new(10);
        // The pool is empty, so any ID (here the nil UUID) is absent.
        assert!(pool.remove_agent(Uuid::nil()).await.is_err());
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_scale_up() {
        let mut pool = AgentPool::new(10);
        pool.scale(5).await.unwrap();
        assert_eq!(pool.size(), 5);
    }

    #[tokio::test]
    async fn test_scale_down() {
        let mut pool = AgentPool::new(10);
        pool.scale(5).await.unwrap();
        pool.scale(3).await.unwrap();
        assert_eq!(pool.size(), 3);
    }

    #[tokio::test]
    async fn test_scale_to_zero() {
        let mut pool = AgentPool::new(10);
        pool.scale(3).await.unwrap();
        pool.scale(0).await.unwrap();
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_scale_beyond_max_fails() {
        let mut pool = AgentPool::new(2);
        assert!(pool.scale(3).await.is_err());
        // A rejected target must not partially populate the pool.
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_pool_max_size() {
        let mut pool = AgentPool::new(3);
        pool.scale(3).await.unwrap();

        let agent = BasicVerificationAgent::new().unwrap();
        let result = pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_all_agents_and_health_check_all() {
        let mut pool = AgentPool::new(10);
        pool.scale(4).await.unwrap();
        assert_eq!(pool.get_all_agents().len(), 4);
        // Every agent passed a health check on admission; a full sweep should pass too.
        pool.health_check_all().await.unwrap();
    }

    #[tokio::test]
    async fn test_shutdown_clears_pool() {
        let mut pool = AgentPool::new(10);
        pool.scale(3).await.unwrap();
        pool.shutdown().await.unwrap();
        assert!(pool.is_empty());
        assert!(pool.get_all_agents().is_empty());
    }
}