Skip to main content

ai_agents_runtime/spawner/
registry.rs

//! Agent registry for tracking spawned agents and exchanging messages between them.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use tracing::{debug, info, warn};
11
12use crate::spec::AgentSpec;
13use crate::{Agent, RuntimeAgent};
14use ai_agents_core::{AgentError, AgentResponse, Result};
15
16use super::spawner::SpawnedAgent;
17
18/// Summary information for a registered agent, returned by `list()`.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SpawnedAgentInfo {
21    pub id: String,
22    pub name: String,
23    pub spawned_at: DateTime<Utc>,
24}
25
26/// Tracks spawned agents and provides inter-agent messaging.
27pub struct AgentRegistry {
28    agents: RwLock<HashMap<String, Arc<SpawnedAgent>>>,
29    hooks: Option<Arc<dyn RegistryHooks>>,
30    /// When true, `send()` prefixes messages with `[From {sender}]: `.
31    send_with_context: bool,
32}
33
34impl AgentRegistry {
35    pub fn new() -> Self {
36        Self {
37            agents: RwLock::new(HashMap::new()),
38            hooks: None,
39            send_with_context: true,
40        }
41    }
42
43    /// Attach lifecycle hooks to the registry.
44    pub fn with_hooks(mut self, hooks: Arc<dyn RegistryHooks>) -> Self {
45        self.hooks = Some(hooks);
46        self
47    }
48
49    /// Configure whether `send()` injects sender identity into messages.
50    pub fn with_send_context(mut self, enabled: bool) -> Self {
51        self.send_with_context = enabled;
52        self
53    }
54
55    /// Register a spawned agent. Returns error if the ID already exists.
56    pub async fn register(&self, agent: SpawnedAgent) -> Result<()> {
57        let id = agent.id.clone();
58        let spec_clone = agent.spec.clone();
59        {
60            let mut agents = self.agents.write();
61            if agents.contains_key(&id) {
62                return Err(AgentError::Config(format!(
63                    "Agent already registered: {}",
64                    id
65                )));
66            }
67            agents.insert(id.clone(), Arc::new(agent));
68        }
69        info!(agent_id = %id, "Agent registered in registry");
70        if let Some(ref hooks) = self.hooks {
71            hooks.on_agent_spawned(&id, &spec_clone).await;
72        }
73        Ok(())
74    }
75
76    /// Clone an Arc handle to a registered agent's RuntimeAgent.
77    pub fn get(&self, id: &str) -> Option<Arc<RuntimeAgent>> {
78        let agents = self.agents.read();
79        agents.get(id).map(|sa| Arc::clone(&sa.agent))
80    }
81
82    /// Get the full SpawnedAgent metadata (agent + spec + timestamp).
83    pub fn get_spawned(&self, id: &str) -> Option<Arc<SpawnedAgent>> {
84        let agents = self.agents.read();
85        agents.get(id).cloned()
86    }
87
88    /// List metadata for all registered agents.
89    pub fn list(&self) -> Vec<SpawnedAgentInfo> {
90        let agents = self.agents.read();
91        agents
92            .values()
93            .map(|sa| SpawnedAgentInfo {
94                id: sa.id.clone(),
95                name: sa.spec.name.clone(),
96                spawned_at: sa.spawned_at,
97            })
98            .collect()
99    }
100
101    /// List all registered agents with their specs serialized as YAML for session persistence.
102    pub fn list_with_specs(&self) -> Vec<ai_agents_core::SpawnedAgentEntry> {
103        let agents = self.agents.read();
104        agents
105            .values()
106            .filter_map(|sa| {
107                let spec_yaml = match serde_yaml::to_string(&sa.spec) {
108                    Ok(y) => y,
109                    Err(e) => {
110                        warn!(agent_id = %sa.id, error = %e, "Failed to serialize agent spec");
111                        return None;
112                    }
113                };
114                Some(ai_agents_core::SpawnedAgentEntry {
115                    id: sa.id.clone(),
116                    name: sa.spec.name.clone(),
117                    spec_yaml,
118                })
119            })
120            .collect()
121    }
122
123    /// Remove an agent from the registry and return it.
124    pub async fn remove(&self, id: &str) -> Option<Arc<SpawnedAgent>> {
125        let removed = {
126            let mut agents = self.agents.write();
127            agents.remove(id)
128        };
129        if removed.is_some() {
130            info!(agent_id = %id, "Agent removed from registry");
131            if let Some(ref hooks) = self.hooks {
132                hooks.on_agent_removed(id).await;
133            }
134        } else {
135            debug!(agent_id = %id, "Attempted to remove non-existent agent");
136        }
137        removed
138    }
139
140    /// Send a message from one agent to another and return the response.
141    pub async fn send(&self, from: &str, to: &str, message: &str) -> Result<AgentResponse> {
142        let target = {
143            // The read lock is held only long enough to clone the target Arc, then released before the async `chat()` call.
144            let agents = self.agents.read();
145            agents.get(to).map(|sa| Arc::clone(&sa.agent))
146        };
147        let target =
148            target.ok_or_else(|| AgentError::Other(format!("Target agent not found: {}", to)))?;
149
150        if let Some(ref hooks) = self.hooks {
151            hooks.on_message_sent(from, to, message).await;
152        }
153
154        let formatted = if self.send_with_context {
155            format!("[From {}]: {}", from, message)
156        } else {
157            message.to_string()
158        };
159
160        debug!(from = %from, to = %to, "Sending inter-agent message");
161        target.chat(&formatted).await
162    }
163
164    /// Broadcast a message to all agents except the sender.
165    ///
166    /// Clones all target Arcs under a single brief read lock, then drives all `chat()` calls concurrently after releasing the lock.
167    pub async fn broadcast(
168        &self,
169        from: &str,
170        message: &str,
171    ) -> Vec<(String, Result<AgentResponse>)> {
172        let targets: Vec<(String, Arc<RuntimeAgent>)> = {
173            let agents = self.agents.read();
174            agents
175                .iter()
176                .filter(|(id, _)| id.as_str() != from)
177                .map(|(id, sa)| (id.clone(), Arc::clone(&sa.agent)))
178                .collect()
179        };
180
181        if targets.is_empty() {
182            return Vec::new();
183        }
184
185        let formatted = if self.send_with_context {
186            format!("[From {}]: {}", from, message)
187        } else {
188            message.to_string()
189        };
190
191        debug!(
192            from = %from,
193            target_count = targets.len(),
194            "Broadcasting message"
195        );
196
197        let mut handles = Vec::with_capacity(targets.len());
198        for (id, agent) in targets {
199            let msg = formatted.clone();
200            handles.push(tokio::spawn(async move {
201                let result = agent.chat(&msg).await;
202                (id, result)
203            }));
204        }
205
206        let mut results = Vec::new();
207        for handle in handles {
208            match handle.await {
209                Ok((id, res)) => results.push((id, res)),
210                Err(e) => {
211                    warn!(error = %e, "Broadcast task panicked");
212                }
213            }
214        }
215        results
216    }
217
218    /// Number of currently registered agents.
219    pub fn count(&self) -> usize {
220        self.agents.read().len()
221    }
222
223    /// Returns true if the registry contains an agent with this ID.
224    pub fn contains(&self, id: &str) -> bool {
225        self.agents.read().contains_key(id)
226    }
227}
228
229impl Default for AgentRegistry {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235// Debug impl avoids printing agent internals.
236impl std::fmt::Debug for AgentRegistry {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        let count = self.agents.read().len();
239        f.debug_struct("AgentRegistry")
240            .field("agent_count", &count)
241            .field("send_with_context", &self.send_with_context)
242            .field("has_hooks", &self.hooks.is_some())
243            .finish()
244    }
245}
246
247/// Optional lifecycle hooks for registry events.
248#[async_trait]
249pub trait RegistryHooks: Send + Sync {
250    /// Called after an agent is successfully registered.
251    async fn on_agent_spawned(&self, _id: &str, _spec: &AgentSpec) {}
252
253    /// Called after an agent is removed from the registry.
254    async fn on_agent_removed(&self, _id: &str) {}
255
256    /// Called before a message is delivered via `send()`.
257    async fn on_message_sent(&self, _from: &str, _to: &str, _message: &str) {}
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentBuilder;
    use ai_agents_core::{
        ChatMessage, FinishReason, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider,
        LLMResponse,
    };
    use ai_agents_llm::LLMRegistry;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// LLM stub that replies with `Echo: <content of the last message>`.
    struct EchoProvider;

    #[async_trait]
    impl LLMProvider for EchoProvider {
        async fn complete(
            &self,
            messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<LLMResponse, LLMError> {
            let last = messages
                .last()
                .map(|m| m.content.clone())
                .unwrap_or_default();
            Ok(LLMResponse::new(
                format!("Echo: {}", last),
                FinishReason::Stop,
            ))
        }

        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> std::result::Result<
            Box<dyn futures::Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            Err(LLMError::Other("not implemented".into()))
        }

        fn provider_name(&self) -> &str {
            "echo"
        }

        fn supports(&self, _feature: LLMFeature) -> bool {
            false
        }
    }

    /// Hook implementation that counts how often each registry event fires.
    #[derive(Default)]
    struct CountingHooks {
        spawned: AtomicU32,
        removed: AtomicU32,
        sent: AtomicU32,
    }

    #[async_trait]
    impl RegistryHooks for CountingHooks {
        async fn on_agent_spawned(&self, _id: &str, _spec: &AgentSpec) {
            self.spawned.fetch_add(1, Ordering::Relaxed);
        }
        async fn on_agent_removed(&self, _id: &str) {
            self.removed.fetch_add(1, Ordering::Relaxed);
        }
        async fn on_message_sent(&self, _from: &str, _to: &str, _msg: &str) {
            self.sent.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Build a `RuntimeAgent` backed by `EchoProvider`.
    fn make_test_agent(name: &str) -> RuntimeAgent {
        let mut registry = LLMRegistry::new();
        registry.register("default", Arc::new(EchoProvider));

        AgentBuilder::new()
            .system_prompt(format!("You are {}.", name))
            .llm_registry(registry)
            .build()
            .unwrap()
    }

    /// Wrap a test agent in `SpawnedAgent` metadata whose ID and spec name are both `id`.
    fn make_spawned(id: &str) -> SpawnedAgent {
        let agent = make_test_agent(id);
        SpawnedAgent {
            id: id.to_string(),
            agent: Arc::new(agent),
            spec: AgentSpec {
                name: id.to_string(),
                ..AgentSpec::default()
            },
            spawned_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("agent_a")).await.unwrap();

        assert!(registry.get("agent_a").is_some());
        assert!(registry.get("agent_b").is_none());
        assert_eq!(registry.count(), 1);
    }

    #[tokio::test]
    async fn test_duplicate_register() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("dup")).await.unwrap();
        let result = registry.register(make_spawned("dup")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_duplicate_register_skips_hook() {
        let hooks = Arc::new(CountingHooks::default());
        let registry = AgentRegistry::new().with_hooks(hooks.clone());

        registry.register(make_spawned("dup")).await.unwrap();
        assert!(registry.register(make_spawned("dup")).await.is_err());

        assert_eq!(registry.count(), 1);
        assert_eq!(hooks.spawned.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("a")).await.unwrap();
        registry.register(make_spawned("b")).await.unwrap();

        assert_eq!(registry.list().len(), 2);

        let removed = registry.remove("a").await;
        assert!(removed.is_some());
        assert_eq!(registry.count(), 1);
        assert!(registry.get("a").is_none());
    }

    #[tokio::test]
    async fn test_remove_missing_is_noop() {
        let hooks = Arc::new(CountingHooks::default());
        let registry = AgentRegistry::new().with_hooks(hooks.clone());

        assert!(registry.remove("ghost").await.is_none());
        assert_eq!(hooks.removed.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_get_spawned_and_list_metadata() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("meta")).await.unwrap();

        let spawned = registry.get_spawned("meta").expect("agent was just registered");
        assert_eq!(spawned.id, "meta");
        assert_eq!(spawned.spec.name, "meta");
        assert!(registry.get_spawned("missing").is_none());

        let listed = registry.list();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, "meta");
        assert_eq!(listed[0].name, "meta");
        assert_eq!(listed[0].spawned_at, spawned.spawned_at);
    }

    #[tokio::test]
    async fn test_list_with_specs() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("p1")).await.unwrap();

        let entries = registry.list_with_specs();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, "p1");
        assert_eq!(entries[0].name, "p1");
        assert!(!entries[0].spec_yaml.is_empty());
    }

    #[tokio::test]
    async fn test_send_message() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("sender")).await.unwrap();
        registry.register(make_spawned("receiver")).await.unwrap();

        let response = registry.send("sender", "receiver", "hello").await.unwrap();
        assert!(response.content.contains("hello"));
    }

    #[tokio::test]
    async fn test_send_with_context_prefix() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("sender")).await.unwrap();
        registry.register(make_spawned("receiver")).await.unwrap();

        let response = registry.send("sender", "receiver", "hello").await.unwrap();
        // Context prefixing is on by default, so the echoed text carries the sender tag.
        assert!(response.content.contains("[From sender]"));
    }

    #[tokio::test]
    async fn test_send_to_missing() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("sender")).await.unwrap();

        let result = registry.send("sender", "nobody", "hello").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_broadcast() {
        let registry = AgentRegistry::new();
        registry
            .register(make_spawned("broadcaster"))
            .await
            .unwrap();
        registry.register(make_spawned("listener_1")).await.unwrap();
        registry.register(make_spawned("listener_2")).await.unwrap();

        let results = registry.broadcast("broadcaster", "hey everyone").await;
        // Should have 2 results (excluding broadcaster)
        assert_eq!(results.len(), 2);
        for (_, res) in &results {
            assert!(res.is_ok());
        }
    }

    #[tokio::test]
    async fn test_broadcast_with_no_other_agents() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("solo")).await.unwrap();

        let results = registry.broadcast("solo", "anyone?").await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_hooks() {
        let hooks = Arc::new(CountingHooks::default());

        let registry = AgentRegistry::new().with_hooks(hooks.clone());
        registry.register(make_spawned("h1")).await.unwrap();
        registry.register(make_spawned("h2")).await.unwrap();
        assert_eq!(hooks.spawned.load(Ordering::Relaxed), 2);

        registry.send("h1", "h2", "ping").await.unwrap();
        assert_eq!(hooks.sent.load(Ordering::Relaxed), 1);

        registry.remove("h1").await;
        assert_eq!(hooks.removed.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_contains() {
        let registry = AgentRegistry::new();
        assert!(!registry.contains("x"));
        registry.register(make_spawned("x")).await.unwrap();
        assert!(registry.contains("x"));
    }

    #[tokio::test]
    async fn test_send_without_context() {
        let registry = AgentRegistry::new().with_send_context(false);
        registry.register(make_spawned("a")).await.unwrap();
        registry.register(make_spawned("b")).await.unwrap();

        let response = registry.send("a", "b", "raw msg").await.unwrap();
        // Without context prefix, the message should be passed as-is
        assert!(response.content.contains("raw msg"));
        assert!(!response.content.contains("[From"));
    }

    #[tokio::test]
    async fn test_debug_summary() {
        let registry = AgentRegistry::new();
        registry.register(make_spawned("dbg")).await.unwrap();

        let rendered = format!("{:?}", registry);
        assert!(rendered.contains("agent_count: 1"));
        assert!(rendered.contains("has_hooks: false"));
    }
}