agentic_payments/system/
pool.rs1use crate::agents::{BasicVerificationAgent, VerificationAgent};
4use crate::error::{Error, Result};
5use dashmap::DashMap;
6use std::sync::Arc;
7use uuid::Uuid;
8
9pub struct AgentPool {
11 agents: DashMap<Uuid, Arc<dyn VerificationAgent>>,
12 max_size: usize,
13}
14
15impl AgentPool {
16 pub fn new(max_size: usize) -> Self {
18 Self {
19 agents: DashMap::new(),
20 max_size,
21 }
22 }
23
24 pub async fn add_agent(&mut self, agent: Arc<dyn VerificationAgent>) -> Result<()> {
26 if self.agents.len() >= self.max_size {
27 return Err(Error::agent_pool(format!(
28 "Pool is full (max: {})",
29 self.max_size
30 )));
31 }
32
33 let agent_id = agent.id();
34
35 agent.health_check().await?;
37
38 self.agents.insert(agent_id, agent);
39 tracing::info!("Added agent {} to pool", agent_id);
40
41 Ok(())
42 }
43
44 pub async fn remove_agent(&mut self, agent_id: Uuid) -> Result<()> {
46 self.agents
47 .remove(&agent_id)
48 .ok_or_else(|| Error::agent_pool(format!("Agent {} not found in pool", agent_id)))?;
49
50 tracing::info!("Removed agent {} from pool", agent_id);
51 Ok(())
52 }
53
54 pub fn get_agent(&self, agent_id: Uuid) -> Option<Arc<dyn VerificationAgent>> {
56 self.agents.get(&agent_id).map(|r| Arc::clone(&r))
57 }
58
59 pub fn get_all_agents(&self) -> Vec<Arc<dyn VerificationAgent>> {
61 self.agents.iter().map(|r| Arc::clone(&r)).collect()
62 }
63
64 pub fn size(&self) -> usize {
66 self.agents.len()
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.agents.is_empty()
72 }
73
74 pub async fn scale(&mut self, target_size: usize) -> Result<()> {
76 if target_size > self.max_size {
77 return Err(Error::agent_pool(format!(
78 "Target size {} exceeds maximum {}",
79 target_size, self.max_size
80 )));
81 }
82
83 let current_size = self.size();
84
85 if target_size > current_size {
86 let agents_to_add = target_size - current_size;
88 for _ in 0..agents_to_add {
89 let agent = BasicVerificationAgent::new()?;
90 self.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
91 .await?;
92 }
93 tracing::info!("Scaled pool up from {} to {} agents", current_size, target_size);
94 } else if target_size < current_size {
95 let agents_to_remove = current_size - target_size;
97 let agent_ids: Vec<Uuid> = self.agents.iter().take(agents_to_remove).map(|r| *r.key()).collect();
98
99 for agent_id in agent_ids {
100 self.remove_agent(agent_id).await?;
101 }
102 tracing::info!("Scaled pool down from {} to {} agents", current_size, target_size);
103 }
104
105 Ok(())
106 }
107
108 pub async fn health_check_all(&self) -> Result<()> {
110 let mut unhealthy_agents = Vec::new();
111
112 for entry in self.agents.iter() {
113 let agent_id = *entry.key();
114 let agent = entry.value();
115
116 if let Err(e) = agent.health_check().await {
117 tracing::warn!("Agent {} failed health check: {}", agent_id, e);
118 unhealthy_agents.push(agent_id);
119 }
120 }
121
122 if !unhealthy_agents.is_empty() {
123 return Err(Error::health_check(format!(
124 "{} agents failed health check: {:?}",
125 unhealthy_agents.len(),
126 unhealthy_agents
127 )));
128 }
129
130 Ok(())
131 }
132
133 pub async fn shutdown(&mut self) -> Result<()> {
135 tracing::info!("Shutting down agent pool with {} agents", self.size());
136 self.agents.clear();
137 tracing::info!("Agent pool shutdown complete");
138 Ok(())
139 }
140}
141
#[cfg(test)]
mod tests {
    //! Unit tests for `AgentPool`: construction, add/remove, scaling, capacity
    //! limits, health checks, and shutdown.
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = AgentPool::new(10);
        assert_eq!(pool.size(), 0);
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_add_agent() {
        let mut pool = AgentPool::new(10);
        let agent = BasicVerificationAgent::new().unwrap();
        let agent_id = agent.id();

        pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
            .await
            .unwrap();

        assert_eq!(pool.size(), 1);
        assert!(pool.get_agent(agent_id).is_some());
    }

    #[tokio::test]
    async fn test_remove_agent() {
        let mut pool = AgentPool::new(10);
        let agent = BasicVerificationAgent::new().unwrap();
        let agent_id = agent.id();

        pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>)
            .await
            .unwrap();
        pool.remove_agent(agent_id).await.unwrap();

        assert_eq!(pool.size(), 0);
        assert!(pool.get_agent(agent_id).is_none());
    }

    #[tokio::test]
    async fn test_remove_missing_agent_errors() {
        let mut pool = AgentPool::new(10);
        // The nil UUID is never registered, so removal must fail.
        assert!(pool.remove_agent(Uuid::nil()).await.is_err());
        assert!(pool.is_empty());
    }

    #[tokio::test]
    async fn test_scale_up() {
        let mut pool = AgentPool::new(10);
        pool.scale(5).await.unwrap();
        assert_eq!(pool.size(), 5);
        assert_eq!(pool.get_all_agents().len(), 5);
    }

    #[tokio::test]
    async fn test_scale_down() {
        let mut pool = AgentPool::new(10);
        pool.scale(5).await.unwrap();
        pool.scale(3).await.unwrap();
        assert_eq!(pool.size(), 3);
    }

    #[tokio::test]
    async fn test_scale_beyond_max_errors() {
        let mut pool = AgentPool::new(2);
        assert!(pool.scale(3).await.is_err());
        // A rejected scale request must leave the pool untouched.
        assert_eq!(pool.size(), 0);
    }

    #[tokio::test]
    async fn test_pool_max_size() {
        let mut pool = AgentPool::new(3);
        pool.scale(3).await.unwrap();

        let agent = BasicVerificationAgent::new().unwrap();
        let result = pool.add_agent(Arc::new(agent) as Arc<dyn VerificationAgent>).await;
        assert!(result.is_err());
        assert_eq!(pool.size(), 3);
    }

    #[tokio::test]
    async fn test_health_check_all_empty_pool() {
        let pool = AgentPool::new(4);
        pool.health_check_all().await.unwrap();
    }

    #[tokio::test]
    async fn test_health_check_all_healthy_pool() {
        let mut pool = AgentPool::new(4);
        // Every agent added via `scale` has already passed a health check.
        pool.scale(3).await.unwrap();
        pool.health_check_all().await.unwrap();
    }

    #[tokio::test]
    async fn test_shutdown_clears_pool() {
        let mut pool = AgentPool::new(4);
        pool.scale(2).await.unwrap();
        pool.shutdown().await.unwrap();
        assert!(pool.is_empty());
        assert!(pool.get_all_agents().is_empty());
    }
}