// ai_agents_runtime/spawner/registry.rs
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};

use crate::spec::AgentSpec;
use crate::{Agent, RuntimeAgent};
use ai_agents_core::{AgentError, AgentResponse, Result};

use super::spawner::SpawnedAgent;

18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SpawnedAgentInfo {
21 pub id: String,
22 pub name: String,
23 pub spawned_at: DateTime<Utc>,
24}
25
26pub struct AgentRegistry {
28 agents: RwLock<HashMap<String, Arc<SpawnedAgent>>>,
29 hooks: Option<Arc<dyn RegistryHooks>>,
30 send_with_context: bool,
32}
33
34impl AgentRegistry {
35 pub fn new() -> Self {
36 Self {
37 agents: RwLock::new(HashMap::new()),
38 hooks: None,
39 send_with_context: true,
40 }
41 }
42
43 pub fn with_hooks(mut self, hooks: Arc<dyn RegistryHooks>) -> Self {
45 self.hooks = Some(hooks);
46 self
47 }
48
49 pub fn with_send_context(mut self, enabled: bool) -> Self {
51 self.send_with_context = enabled;
52 self
53 }
54
55 pub async fn register(&self, agent: SpawnedAgent) -> Result<()> {
57 let id = agent.id.clone();
58 let spec_clone = agent.spec.clone();
59 {
60 let mut agents = self.agents.write();
61 if agents.contains_key(&id) {
62 return Err(AgentError::Config(format!(
63 "Agent already registered: {}",
64 id
65 )));
66 }
67 agents.insert(id.clone(), Arc::new(agent));
68 }
69 info!(agent_id = %id, "Agent registered in registry");
70 if let Some(ref hooks) = self.hooks {
71 hooks.on_agent_spawned(&id, &spec_clone).await;
72 }
73 Ok(())
74 }
75
76 pub fn get(&self, id: &str) -> Option<Arc<RuntimeAgent>> {
78 let agents = self.agents.read();
79 agents.get(id).map(|sa| Arc::clone(&sa.agent))
80 }
81
82 pub fn get_spawned(&self, id: &str) -> Option<Arc<SpawnedAgent>> {
84 let agents = self.agents.read();
85 agents.get(id).cloned()
86 }
87
88 pub fn list(&self) -> Vec<SpawnedAgentInfo> {
90 let agents = self.agents.read();
91 agents
92 .values()
93 .map(|sa| SpawnedAgentInfo {
94 id: sa.id.clone(),
95 name: sa.spec.name.clone(),
96 spawned_at: sa.spawned_at,
97 })
98 .collect()
99 }
100
101 pub fn list_with_specs(&self) -> Vec<ai_agents_core::SpawnedAgentEntry> {
103 let agents = self.agents.read();
104 agents
105 .values()
106 .filter_map(|sa| {
107 let spec_yaml = match serde_yaml::to_string(&sa.spec) {
108 Ok(y) => y,
109 Err(e) => {
110 warn!(agent_id = %sa.id, error = %e, "Failed to serialize agent spec");
111 return None;
112 }
113 };
114 Some(ai_agents_core::SpawnedAgentEntry {
115 id: sa.id.clone(),
116 name: sa.spec.name.clone(),
117 spec_yaml,
118 })
119 })
120 .collect()
121 }
122
123 pub async fn remove(&self, id: &str) -> Option<Arc<SpawnedAgent>> {
125 let removed = {
126 let mut agents = self.agents.write();
127 agents.remove(id)
128 };
129 if removed.is_some() {
130 info!(agent_id = %id, "Agent removed from registry");
131 if let Some(ref hooks) = self.hooks {
132 hooks.on_agent_removed(id).await;
133 }
134 } else {
135 debug!(agent_id = %id, "Attempted to remove non-existent agent");
136 }
137 removed
138 }
139
140 pub async fn send(&self, from: &str, to: &str, message: &str) -> Result<AgentResponse> {
142 let target = {
143 let agents = self.agents.read();
145 agents.get(to).map(|sa| Arc::clone(&sa.agent))
146 };
147 let target =
148 target.ok_or_else(|| AgentError::Other(format!("Target agent not found: {}", to)))?;
149
150 if let Some(ref hooks) = self.hooks {
151 hooks.on_message_sent(from, to, message).await;
152 }
153
154 let formatted = if self.send_with_context {
155 format!("[From {}]: {}", from, message)
156 } else {
157 message.to_string()
158 };
159
160 debug!(from = %from, to = %to, "Sending inter-agent message");
161 target.chat(&formatted).await
162 }
163
164 pub async fn broadcast(
168 &self,
169 from: &str,
170 message: &str,
171 ) -> Vec<(String, Result<AgentResponse>)> {
172 let targets: Vec<(String, Arc<RuntimeAgent>)> = {
173 let agents = self.agents.read();
174 agents
175 .iter()
176 .filter(|(id, _)| id.as_str() != from)
177 .map(|(id, sa)| (id.clone(), Arc::clone(&sa.agent)))
178 .collect()
179 };
180
181 if targets.is_empty() {
182 return Vec::new();
183 }
184
185 let formatted = if self.send_with_context {
186 format!("[From {}]: {}", from, message)
187 } else {
188 message.to_string()
189 };
190
191 debug!(
192 from = %from,
193 target_count = targets.len(),
194 "Broadcasting message"
195 );
196
197 let mut handles = Vec::with_capacity(targets.len());
198 for (id, agent) in targets {
199 let msg = formatted.clone();
200 handles.push(tokio::spawn(async move {
201 let result = agent.chat(&msg).await;
202 (id, result)
203 }));
204 }
205
206 let mut results = Vec::new();
207 for handle in handles {
208 match handle.await {
209 Ok((id, res)) => results.push((id, res)),
210 Err(e) => {
211 warn!(error = %e, "Broadcast task panicked");
212 }
213 }
214 }
215 results
216 }
217
218 pub fn count(&self) -> usize {
220 self.agents.read().len()
221 }
222
223 pub fn contains(&self, id: &str) -> bool {
225 self.agents.read().contains_key(id)
226 }
227}
228
229impl Default for AgentRegistry {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235impl std::fmt::Debug for AgentRegistry {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 let count = self.agents.read().len();
239 f.debug_struct("AgentRegistry")
240 .field("agent_count", &count)
241 .field("send_with_context", &self.send_with_context)
242 .field("has_hooks", &self.hooks.is_some())
243 .finish()
244 }
245}
246
247#[async_trait]
249pub trait RegistryHooks: Send + Sync {
250 async fn on_agent_spawned(&self, _id: &str, _spec: &AgentSpec) {}
252
253 async fn on_agent_removed(&self, _id: &str) {}
255
256 async fn on_message_sent(&self, _from: &str, _to: &str, _message: &str) {}
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentBuilder;
    use ai_agents_core::{
        ChatMessage, FinishReason, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider,
        LLMResponse,
    };
    use ai_agents_llm::LLMRegistry;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Deterministic provider that answers every prompt with `Echo: <last message>`,
    /// so tests can inspect exactly what text reached an agent.
    struct EchoProvider;

    #[async_trait]
    impl LLMProvider for EchoProvider {
        async fn complete(
            &self,
            messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<LLMResponse, LLMError> {
            let echoed = messages
                .last()
                .map(|msg| msg.content.clone())
                .unwrap_or_default();
            let reply = format!("Echo: {}", echoed);
            Ok(LLMResponse::new(reply, FinishReason::Stop))
        }

        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<
            Box<dyn futures::Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            // Streaming is never exercised by these tests.
            Err(LLMError::Other("not implemented".into()))
        }

        fn provider_name(&self) -> &str {
            "echo"
        }

        fn supports(&self, _feature: LLMFeature) -> bool {
            false
        }
    }

    /// Builds a runtime agent backed by `EchoProvider` under the `"default"` slot.
    fn make_test_agent(name: &str) -> RuntimeAgent {
        let mut llms = LLMRegistry::new();
        llms.register("default", Arc::new(EchoProvider));

        let prompt = format!("You are {}.", name);
        AgentBuilder::new()
            .system_prompt(prompt)
            .llm_registry(llms)
            .build()
            .unwrap()
    }

    /// Wraps a fresh echo agent in a `SpawnedAgent` whose id and spec name are both `id`.
    fn make_spawned(id: &str) -> SpawnedAgent {
        let spec = AgentSpec {
            name: id.to_string(),
            ..AgentSpec::default()
        };
        SpawnedAgent {
            id: id.to_string(),
            agent: Arc::new(make_test_agent(id)),
            spec,
            spawned_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let reg = AgentRegistry::new();
        reg.register(make_spawned("agent_a")).await.unwrap();

        assert!(reg.get("agent_a").is_some());
        assert!(reg.get("agent_b").is_none());
        assert_eq!(reg.count(), 1);
    }

    #[tokio::test]
    async fn test_duplicate_register() {
        let reg = AgentRegistry::new();
        reg.register(make_spawned("dup")).await.unwrap();

        let second = reg.register(make_spawned("dup")).await;
        assert!(second.is_err());
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let reg = AgentRegistry::new();
        for id in ["a", "b"] {
            reg.register(make_spawned(id)).await.unwrap();
        }
        assert_eq!(reg.list().len(), 2);

        assert!(reg.remove("a").await.is_some());
        assert_eq!(reg.count(), 1);
        assert!(reg.get("a").is_none());
    }

    #[tokio::test]
    async fn test_send_message() {
        let reg = AgentRegistry::new();
        for id in ["sender", "receiver"] {
            reg.register(make_spawned(id)).await.unwrap();
        }

        let reply = reg.send("sender", "receiver", "hello").await.unwrap();
        assert!(reply.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_send_to_missing() {
        let reg = AgentRegistry::new();
        reg.register(make_spawned("sender")).await.unwrap();

        let outcome = reg.send("sender", "nobody", "hello").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_broadcast() {
        let reg = AgentRegistry::new();
        for id in ["broadcaster", "listener_1", "listener_2"] {
            reg.register(make_spawned(id)).await.unwrap();
        }

        let results = reg.broadcast("broadcaster", "hey everyone").await;
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, res)| res.is_ok()));
    }

    #[tokio::test]
    async fn test_hooks() {
        /// Hook implementation that just counts how often each callback fires.
        #[derive(Default)]
        struct CountingHooks {
            spawned: AtomicU32,
            removed: AtomicU32,
            sent: AtomicU32,
        }

        #[async_trait]
        impl RegistryHooks for CountingHooks {
            async fn on_agent_spawned(&self, _id: &str, _spec: &AgentSpec) {
                self.spawned.fetch_add(1, Ordering::Relaxed);
            }

            async fn on_agent_removed(&self, _id: &str) {
                self.removed.fetch_add(1, Ordering::Relaxed);
            }

            async fn on_message_sent(&self, _from: &str, _to: &str, _msg: &str) {
                self.sent.fetch_add(1, Ordering::Relaxed);
            }
        }

        let hooks = Arc::new(CountingHooks::default());
        let reg = AgentRegistry::new().with_hooks(hooks.clone());

        for id in ["h1", "h2"] {
            reg.register(make_spawned(id)).await.unwrap();
        }
        assert_eq!(hooks.spawned.load(Ordering::Relaxed), 2);

        reg.send("h1", "h2", "ping").await.unwrap();
        assert_eq!(hooks.sent.load(Ordering::Relaxed), 1);

        let _ = reg.remove("h1").await;
        assert_eq!(hooks.removed.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_contains() {
        let reg = AgentRegistry::new();
        assert!(!reg.contains("x"));

        reg.register(make_spawned("x")).await.unwrap();
        assert!(reg.contains("x"));
    }

    #[tokio::test]
    async fn test_send_without_context() {
        let reg = AgentRegistry::new().with_send_context(false);
        for id in ["a", "b"] {
            reg.register(make_spawned(id)).await.unwrap();
        }

        let reply = reg.send("a", "b", "raw msg").await.unwrap();
        assert!(reply.content.contains("raw msg"));
        // With context disabled the sender prefix must not be injected.
        assert!(!reply.content.contains("[From"));
    }
}