ai_agents_runtime/spawner/
spawner.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use chrono::Utc;
8use minijinja::Environment;
9use tracing::info;
10
11use crate::AgentBuilder;
12use crate::RuntimeAgent;
13use crate::spec::AgentSpec;
14use ai_agents_core::{AgentError, AgentStorage, Result};
15use ai_agents_llm::LLMRegistry;
16
17use super::storage::NamespacedStorage;
18
/// A spawner template after resolution: the raw template text plus optional metadata.
///
/// `content` is rendered by [`AgentSpawner`] with MiniJinja and the result is parsed
/// as agent YAML.
#[derive(Debug, Clone)]
pub struct ResolvedTemplate {
    /// Template source text (MiniJinja syntax).
    pub content: String,
    /// Optional human-readable description of the template.
    pub description: Option<String>,
    /// Optional variable-name -> description map; not consulted during rendering here.
    pub variables: Option<HashMap<String, String>>,
}

impl ResolvedTemplate {
    /// Wraps bare template text with no description and no variable metadata.
    pub fn from_content(content: impl Into<String>) -> Self {
        let text: String = content.into();
        let (description, variables) = (None, None);
        Self {
            content: text,
            description,
            variables,
        }
    }
}
40
41pub struct SpawnedAgent {
43 pub id: String,
45 pub agent: Arc<RuntimeAgent>,
47 pub spec: AgentSpec,
49 pub spawned_at: chrono::DateTime<Utc>,
51}
52
53impl std::fmt::Debug for SpawnedAgent {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("SpawnedAgent")
56 .field("id", &self.id)
57 .field("spawned_at", &self.spawned_at)
58 .finish_non_exhaustive()
59 }
60}
61
62pub struct AgentSpawner {
64 llm_registry: Option<LLMRegistry>,
66
67 storage: Option<Arc<dyn AgentStorage>>,
69
70 shared_context: HashMap<String, serde_json::Value>,
72
73 max_agents: Option<usize>,
75
76 name_prefix: Option<String>,
78
79 templates: HashMap<String, ResolvedTemplate>,
81
82 allowed_tools: Option<Vec<String>>,
84
85 counter: AtomicU32,
87
88 agent_count: AtomicU32,
90}
91
92impl AgentSpawner {
93 pub fn new() -> Self {
94 Self {
95 llm_registry: None,
96 storage: None,
97 shared_context: HashMap::new(),
98 max_agents: None,
99 name_prefix: None,
100 templates: HashMap::new(),
101 allowed_tools: None,
102 counter: AtomicU32::new(1),
103 agent_count: AtomicU32::new(0),
104 }
105 }
106
107 pub fn with_shared_llms(mut self, registry: LLMRegistry) -> Self {
109 self.llm_registry = Some(registry);
110 self
111 }
112
113 pub fn with_shared_storage(mut self, storage: Arc<dyn AgentStorage>) -> Self {
115 self.storage = Some(storage);
116 self
117 }
118
119 pub fn with_shared_context(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
121 self.shared_context.insert(key.into(), value);
122 self
123 }
124
125 pub fn with_shared_context_map(mut self, ctx: HashMap<String, serde_json::Value>) -> Self {
127 self.shared_context.extend(ctx);
128 self
129 }
130
131 pub fn with_max_agents(mut self, max: usize) -> Self {
133 self.max_agents = Some(max);
134 self
135 }
136
137 pub fn with_name_prefix(mut self, prefix: impl Into<String>) -> Self {
139 self.name_prefix = Some(prefix.into());
140 self
141 }
142
143 pub fn with_template(
145 mut self,
146 name: impl Into<String>,
147 yaml_template: impl Into<String>,
148 ) -> Self {
149 self.templates
150 .insert(name.into(), ResolvedTemplate::from_content(yaml_template));
151 self
152 }
153
154 pub fn with_templates(mut self, templates: HashMap<String, ResolvedTemplate>) -> Self {
156 self.templates.extend(templates);
157 self
158 }
159
160 pub fn with_allowed_tools(mut self, tools: Vec<String>) -> Self {
162 self.allowed_tools = Some(tools);
163 self
164 }
165
166 pub async fn spawn_from_yaml(&self, yaml: &str) -> Result<SpawnedAgent> {
168 let mut spec: AgentSpec = serde_yaml::from_str(yaml)?;
169 spec.validate()?;
170 self.enforce_tool_allowlist(&mut spec);
171 self.spawn_from_spec(spec).await
172 }
173
174 pub async fn spawn_from_spec(&self, spec: AgentSpec) -> Result<SpawnedAgent> {
176 let agent_id = self.generate_id(&spec.name);
177 self.spawn_inner(agent_id, spec).await
178 }
179
180 pub async fn spawn_with_id(&self, id: String, spec: AgentSpec) -> Result<SpawnedAgent> {
182 self.spawn_inner(id, spec).await
183 }
184
185 async fn spawn_inner(&self, agent_id: String, spec: AgentSpec) -> Result<SpawnedAgent> {
187 self.check_spawn_limit()?;
188
189 let mut builder = AgentBuilder::from_spec(spec.clone());
191
192 if let Some(ref shared_reg) = self.llm_registry {
194 builder = builder.llm_registry(shared_reg.clone());
195 }
196
197 if !spec.llms.is_empty() {
199 builder = builder.auto_configure_llms()?;
200 }
201
202 builder = builder.auto_configure_features()?;
204
205 if let Some(ref shared_storage) = self.storage {
207 let namespaced = Arc::new(NamespacedStorage::new(
208 Arc::clone(shared_storage),
209 &agent_id,
210 ));
211 builder = builder.storage(namespaced);
212 }
213
214 let agent = builder.build()?;
215
216 for (key, value) in &self.shared_context {
218 let _ = agent.set_context(key, value.clone());
219 }
220
221 self.agent_count.fetch_add(1, Ordering::Relaxed);
222
223 info!(agent_id = %agent_id, name = %spec.name, "Agent spawned");
224
225 Ok(SpawnedAgent {
226 id: agent_id,
227 agent: Arc::new(agent),
228 spec,
229 spawned_at: Utc::now(),
230 })
231 }
232
233 pub async fn spawn_from_template(
239 &self,
240 template_name: &str,
241 variables: HashMap<String, String>,
242 ) -> Result<SpawnedAgent> {
243 let template = self.templates.get(template_name).ok_or_else(|| {
244 AgentError::Config(format!("Spawner template not found: {}", template_name))
245 })?;
246
247 let rendered = self.render_template(&template.content, &variables)?;
248 self.spawn_from_yaml(&rendered).await
249 }
250
251 pub fn spawned_count(&self) -> u32 {
253 self.agent_count.load(Ordering::Relaxed)
254 }
255
256 pub fn notify_agent_removed(&self) {
258 let prev = self.agent_count.load(Ordering::Relaxed);
259 if prev > 0 {
260 self.agent_count.fetch_sub(1, Ordering::Relaxed);
261 }
262 }
263
264 pub fn llm_registry(&self) -> Option<&LLMRegistry> {
266 self.llm_registry.as_ref()
267 }
268
269 pub fn shared_storage(&self) -> Option<&Arc<dyn AgentStorage>> {
271 self.storage.as_ref()
272 }
273
274 pub fn templates(&self) -> &HashMap<String, ResolvedTemplate> {
276 &self.templates
277 }
278
279 fn check_spawn_limit(&self) -> Result<()> {
280 if let Some(max) = self.max_agents {
281 let current = self.agent_count.load(Ordering::Relaxed) as usize;
282 if current >= max {
283 return Err(AgentError::Config(format!(
284 "Spawn limit exceeded: {}/{}",
285 current, max
286 )));
287 }
288 }
289 Ok(())
290 }
291
292 fn generate_id(&self, spec_name: &str) -> String {
293 if let Some(ref prefix) = self.name_prefix {
294 let n = self.counter.fetch_add(1, Ordering::Relaxed);
295 format!("{}{:03}", prefix, n)
296 } else {
297 spec_name.to_lowercase().replace(' ', "_")
298 }
299 }
300
301 fn enforce_tool_allowlist(&self, spec: &mut AgentSpec) {
303 if let Some(ref allowed) = self.allowed_tools {
304 if let Some(ref mut tools) = spec.tools {
305 let before = tools.len();
306 tools.retain(|t| allowed.contains(&t.name().to_string()));
307 let removed = before - tools.len();
308 if removed > 0 {
309 tracing::warn!(
310 removed_count = removed,
311 "Stripped disallowed tools from spawned agent spec"
312 );
313 }
314 }
315 }
316 }
317
318 fn render_template(
320 &self,
321 template_str: &str,
322 variables: &HashMap<String, String>,
323 ) -> Result<String> {
324 let mut env = Environment::new();
325 env.add_template("_spawn", template_str)
326 .map_err(|e| AgentError::TemplateError(format!("template parse error: {}", e)))?;
327
328 let tmpl = env
329 .get_template("_spawn")
330 .map_err(|e| AgentError::TemplateError(format!("template load error: {}", e)))?;
331
332 let mut ctx = serde_json::Map::new();
334
335 for (k, v) in variables {
336 ctx.insert(k.clone(), serde_json::Value::String(v.clone()));
337 }
338
339 let context_obj = serde_json::Value::Object(
341 self.shared_context
342 .iter()
343 .map(|(k, v)| (k.clone(), v.clone()))
344 .collect(),
345 );
346 ctx.insert("context".to_string(), context_obj);
347
348 let ctx_value = serde_json::Value::Object(ctx);
349 let mj_value = minijinja::Value::from_serialize(&ctx_value);
350
351 tmpl.render(mj_value)
352 .map_err(|e| AgentError::TemplateError(format!("template render error: {}", e)))
353 }
354}
355
356impl Default for AgentSpawner {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_id_with_prefix() {
        let spawner = AgentSpawner::new().with_name_prefix("npc_");
        assert_eq!(spawner.generate_id("Gormund"), "npc_001");
        assert_eq!(spawner.generate_id("Elena"), "npc_002");
    }

    #[test]
    fn test_generate_id_without_prefix() {
        let spawner = AgentSpawner::new();
        assert_eq!(spawner.generate_id("My Agent"), "my_agent");
        assert_eq!(spawner.generate_id("TestBot"), "testbot");
    }

    #[test]
    fn test_check_spawn_limit() {
        let spawner = AgentSpawner::new().with_max_agents(2);
        assert!(spawner.check_spawn_limit().is_ok());
        spawner.agent_count.store(2, Ordering::Relaxed);
        assert!(spawner.check_spawn_limit().is_err());
    }

    /// Without `with_max_agents`, the limit check never fails.
    #[test]
    fn test_check_spawn_limit_unbounded() {
        let spawner = AgentSpawner::new();
        spawner.agent_count.store(u32::MAX, Ordering::Relaxed);
        assert!(spawner.check_spawn_limit().is_ok());
    }

    /// Removal notifications decrement the count but never underflow past zero.
    #[test]
    fn test_notify_agent_removed_saturates_at_zero() {
        let spawner = AgentSpawner::new();
        spawner.notify_agent_removed();
        assert_eq!(spawner.spawned_count(), 0);

        spawner.agent_count.store(2, Ordering::Relaxed);
        spawner.notify_agent_removed();
        assert_eq!(spawner.spawned_count(), 1);
    }

    #[test]
    fn test_render_template_basic() {
        let spawner = AgentSpawner::new()
            .with_shared_context("world_name", serde_json::json!("Fantasy Land"));

        let template =
            "name: {{ name }}\nsystem_prompt: You are {{ name }} in {{ context.world_name }}.";
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "Gormund".to_string());

        let rendered = spawner.render_template(template, &vars).unwrap();
        assert!(rendered.contains("name: Gormund"));
        assert!(rendered.contains("Fantasy Land"));
    }

    /// The shared context is inserted after caller variables, so it shadows a
    /// caller-supplied variable that is also named `context`.
    #[test]
    fn test_render_template_shared_context_shadows_variable_named_context() {
        let spawner = AgentSpawner::new()
            .with_shared_context("world_name", serde_json::json!("Fantasy Land"));
        let mut vars = HashMap::new();
        vars.insert("context".to_string(), "user supplied".to_string());

        let rendered = spawner
            .render_template("{{ context.world_name }}", &vars)
            .unwrap();
        assert_eq!(rendered, "Fantasy Land");
    }

    #[test]
    fn test_enforce_tool_allowlist() {
        let spawner = AgentSpawner::new()
            .with_allowed_tools(vec!["echo".to_string(), "calculator".to_string()]);

        let yaml = r#"
name: Test
system_prompt: test
tools:
  - echo
  - calculator
  - file
  - http
"#;
        let mut spec: AgentSpec = serde_yaml::from_str(yaml).unwrap();
        spawner.enforce_tool_allowlist(&mut spec);

        let tool_names: Vec<_> = spec
            .tools
            .as_ref()
            .unwrap()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert_eq!(tool_names, vec!["echo", "calculator"]);
    }

    /// With no allowlist configured, every declared tool is kept.
    #[test]
    fn test_enforce_tool_allowlist_without_allowlist_keeps_all_tools() {
        let spawner = AgentSpawner::new();
        let yaml = "name: Test\nsystem_prompt: test\ntools:\n  - echo\n  - http\n";
        let mut spec: AgentSpec = serde_yaml::from_str(yaml).unwrap();
        spawner.enforce_tool_allowlist(&mut spec);

        let tool_names: Vec<_> = spec
            .tools
            .as_ref()
            .unwrap()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert_eq!(tool_names, vec!["echo", "http"]);
    }

    #[test]
    fn test_with_template_plain_string() {
        let spawner =
            AgentSpawner::new().with_template("basic", "name: {{ name }}\nsystem_prompt: hi");
        let tpl = spawner.templates().get("basic").unwrap();
        assert_eq!(tpl.content, "name: {{ name }}\nsystem_prompt: hi");
        assert!(tpl.description.is_none());
        assert!(tpl.variables.is_none());
    }

    #[test]
    fn test_with_templates_resolved() {
        let mut templates = HashMap::new();
        templates.insert(
            "base".to_string(),
            ResolvedTemplate {
                content: "name: {{ name }}".to_string(),
                description: Some("Test template".to_string()),
                variables: Some({
                    let mut v = HashMap::new();
                    v.insert("role".to_string(), "occupation".to_string());
                    v
                }),
            },
        );
        let spawner = AgentSpawner::new().with_templates(templates);
        let tpl = spawner.templates().get("base").unwrap();
        assert_eq!(tpl.description.as_deref(), Some("Test template"));
        assert_eq!(
            tpl.variables.as_ref().unwrap().get("role").unwrap(),
            "occupation"
        );
    }
}