neuron_runtime/
sub_agent.rs1use std::collections::HashMap;
7
8use neuron_loop::{AgentLoop, AgentResult, LoopConfig};
9use neuron_tool::ToolRegistry;
10use neuron_types::{
11 ContextStrategy, Message, Provider, SubAgentError, SystemPrompt, ToolContext,
12};
13
14#[derive(Debug, Clone)]
16pub struct SubAgentConfig {
17 pub system_prompt: SystemPrompt,
19 pub tools: Vec<String>,
21 pub model: Option<String>,
23 pub max_depth: usize,
25 pub max_turns: Option<usize>,
27}
28
29impl SubAgentConfig {
30 #[must_use]
32 pub fn new(system_prompt: SystemPrompt) -> Self {
33 Self {
34 system_prompt,
35 tools: Vec::new(),
36 model: None,
37 max_depth: 1,
38 max_turns: None,
39 }
40 }
41
42 #[must_use]
44 pub fn with_tools(mut self, tools: Vec<String>) -> Self {
45 self.tools = tools;
46 self
47 }
48
49 #[must_use]
51 pub fn with_max_depth(mut self, max_depth: usize) -> Self {
52 self.max_depth = max_depth;
53 self
54 }
55
56 #[must_use]
58 pub fn with_max_turns(mut self, max_turns: usize) -> Self {
59 self.max_turns = Some(max_turns);
60 self
61 }
62}
63
64pub struct SubAgentManager {
69 configs: HashMap<String, SubAgentConfig>,
70}
71
72impl SubAgentManager {
73 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 configs: HashMap::new(),
78 }
79 }
80
81 pub fn register(&mut self, name: impl Into<String>, config: SubAgentConfig) {
83 self.configs.insert(name.into(), config);
84 }
85
86 #[must_use]
88 pub fn get(&self, name: &str) -> Option<&SubAgentConfig> {
89 self.configs.get(name)
90 }
91
92 #[must_use = "this returns a Result that should be handled"]
104 #[allow(clippy::too_many_arguments)]
105 pub async fn spawn<P: Provider, C: ContextStrategy>(
106 &self,
107 name: &str,
108 provider: P,
109 context: C,
110 parent_tools: &ToolRegistry,
111 user_message: Message,
112 tool_ctx: &ToolContext,
113 current_depth: usize,
114 ) -> Result<AgentResult, SubAgentError> {
115 let config = self
116 .configs
117 .get(name)
118 .ok_or_else(|| SubAgentError::NotFound(name.to_string()))?;
119
120 if current_depth >= config.max_depth {
121 return Err(SubAgentError::MaxDepthExceeded(config.max_depth));
122 }
123
124 let mut filtered_tools = ToolRegistry::new();
126 for tool_name in &config.tools {
127 if let Some(tool) = parent_tools.get(tool_name) {
128 filtered_tools.register_dyn(tool);
129 }
130 }
131
132 let loop_config = LoopConfig {
133 system_prompt: config.system_prompt.clone(),
134 max_turns: config.max_turns,
135 parallel_tool_execution: false,
136 };
137
138 let mut neuron_loop = AgentLoop::new(provider, filtered_tools, context, loop_config);
139 let result = neuron_loop.run(user_message, tool_ctx).await?;
140 Ok(result)
141 }
142
143 #[must_use = "this returns a Result that should be handled"]
159 pub async fn spawn_parallel<P, C>(
160 &self,
161 tasks: Vec<(String, P, C, Message)>,
162 parent_tools: &ToolRegistry,
163 tool_ctx: &ToolContext,
164 current_depth: usize,
165 ) -> Vec<Result<AgentResult, SubAgentError>>
166 where
167 P: Provider,
168 C: ContextStrategy,
169 {
170 let mut names = Vec::with_capacity(tasks.len());
172 let mut rest = Vec::with_capacity(tasks.len());
173 for (name, provider, context, message) in tasks {
174 names.push(name);
175 rest.push((provider, context, message));
176 }
177
178 let futs: Vec<_> = names
182 .iter()
183 .zip(rest)
184 .map(|(name, (provider, context, message))| {
185 self.spawn(name, provider, context, parent_tools, message, tool_ctx, current_depth)
186 })
187 .collect();
188
189 futures::future::join_all(futs).await
190 }
191}
192
193impl Default for SubAgentManager {
194 fn default() -> Self {
195 Self::new()
196 }
197}