1pub mod builtin;
6
7use crate::config::PermissionAction;
8use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
9use crate::session::Session;
10use crate::swarm::{Actor, ActorStatus, Handler, SwarmMessage};
11use crate::tool::{Tool, ToolRegistry, ToolResult};
12use anyhow::Result;
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17
/// Static description of an agent: its identity, how it may be used, and the
/// sampling settings forwarded to the provider.
///
/// The `Serialize` derive emits fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentInfo {
    /// Agent name; used as the key in [`AgentRegistry`] and as the swarm actor id.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Role(s) the agent may take (see [`AgentMode`]).
    pub mode: AgentMode,
    // NOTE(review): presumably marks agents shipped with the program (see the
    // `builtin` module) as opposed to user-defined ones — confirm.
    pub native: bool,
    /// When `true`, the agent is left out of [`AgentRegistry::list_primary`].
    pub hidden: bool,
    /// Model id to request; `Agent::execute` falls back to `"gpt-4o"` when `None`.
    pub model: Option<String>,
    /// Temperature forwarded unchanged in every completion request.
    pub temperature: Option<f32>,
    /// `top_p` value forwarded unchanged in every completion request.
    pub top_p: Option<f32>,
    /// Maximum provider round-trips per `Agent::execute` call; defaults to 100.
    pub max_steps: Option<usize>,
}
31
/// Where an agent may be used.
///
/// Serialized in lowercase: `"primary"`, `"subagent"`, `"all"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AgentMode {
    /// Top-level agent; returned by [`AgentRegistry::list_primary`] unless hidden.
    Primary,
    // NOTE(review): presumably an agent meant to be invoked by another agent
    // rather than directly — confirm.
    Subagent,
    // NOTE(review): presumably usable in both the primary and subagent roles — confirm.
    All,
}
39
/// A runnable agent: an [`AgentInfo`] bound to a completion provider, a tool
/// registry and a system prompt.
pub struct Agent {
    /// Identity and sampling settings.
    pub info: AgentInfo,
    /// Backend that serves completion requests.
    pub provider: Arc<dyn Provider>,
    /// Tools advertised to the model and dispatched by name.
    pub tools: ToolRegistry,
    /// Per-tool permission settings keyed by tool name; starts empty in `Agent::new`.
    // NOTE(review): `execute_tool` only logs the matching entry; confirm where
    // (or whether) these permissions are enforced.
    pub permissions: HashMap<String, PermissionAction>,
    /// Sent as a `Role::System` message at the start of every completion request.
    system_prompt: String,
}
48
49impl Agent {
50 pub fn new(
52 info: AgentInfo,
53 provider: Arc<dyn Provider>,
54 tools: ToolRegistry,
55 system_prompt: String,
56 ) -> Self {
57 Self {
58 info,
59 provider,
60 tools,
61 permissions: HashMap::new(),
62 system_prompt,
63 }
64 }
65
66 pub async fn execute(&self, session: &mut Session, prompt: &str) -> Result<AgentResponse> {
68 session.add_message(Message {
70 role: Role::User,
71 content: vec![ContentPart::Text {
72 text: prompt.to_string(),
73 }],
74 });
75
76 let mut steps = 0;
77 let max_steps = self.info.max_steps.unwrap_or(100);
78
79 loop {
80 steps += 1;
81 if steps > max_steps {
82 anyhow::bail!("Exceeded maximum steps ({})", max_steps);
83 }
84
85 let request = CompletionRequest {
87 messages: self.build_messages(session),
88 tools: self.tools.definitions(),
89 model: self
90 .info
91 .model
92 .clone()
93 .unwrap_or_else(|| "gpt-4o".to_string()),
94 temperature: self.info.temperature,
95 top_p: self.info.top_p,
96 max_tokens: None,
97 stop: vec![],
98 };
99
100 let response = self.provider.complete(request).await?;
102 session.add_message(response.message.clone());
103
104 let tool_calls: Vec<_> = response
106 .message
107 .content
108 .iter()
109 .filter_map(|p| match p {
110 ContentPart::ToolCall {
111 id,
112 name,
113 arguments,
114 } => Some((id.clone(), name.clone(), arguments.clone())),
115 _ => None,
116 })
117 .collect();
118
119 if tool_calls.is_empty() {
120 let text = response
122 .message
123 .content
124 .iter()
125 .filter_map(|p| match p {
126 ContentPart::Text { text } => Some(text.clone()),
127 _ => None,
128 })
129 .collect::<Vec<_>>()
130 .join("\n");
131
132 return Ok(AgentResponse {
133 text,
134 tool_uses: session.tool_uses.clone(),
135 usage: session.usage.clone(),
136 });
137 }
138
139 for (id, name, arguments) in tool_calls {
141 let result = self.execute_tool(&name, &arguments).await;
142
143 session.tool_uses.push(ToolUse {
144 id: id.clone(),
145 name: name.clone(),
146 input: arguments.clone(),
147 output: result.output.clone(),
148 success: result.success,
149 });
150
151 session.add_message(Message {
152 role: Role::Tool,
153 content: vec![ContentPart::ToolResult {
154 tool_call_id: id,
155 content: result.output,
156 }],
157 });
158 }
159 }
160 }
161
162 fn build_messages(&self, session: &Session) -> Vec<Message> {
164 let mut messages = vec![Message {
165 role: Role::System,
166 content: vec![ContentPart::Text {
167 text: self.system_prompt.clone(),
168 }],
169 }];
170 messages.extend(session.messages.clone());
171 messages
172 }
173
174 async fn execute_tool(&self, name: &str, arguments: &str) -> ToolResult {
176 if let Some(permission) = self.permissions.get(name) {
178 tracing::debug!(tool = name, permission = ?permission, "Checking tool permission");
179 }
182
183 match self.tools.get(name) {
184 Some(tool) => {
185 let args: serde_json::Value = match serde_json::from_str(arguments) {
186 Ok(v) => v,
187 Err(e) => {
188 return ToolResult {
189 output: format!("Failed to parse arguments: {}", e),
190 success: false,
191 metadata: HashMap::new(),
192 };
193 }
194 };
195
196 match tool.execute(args).await {
197 Ok(result) => result,
198 Err(e) => ToolResult {
199 output: format!("Tool execution failed: {}", e),
200 success: false,
201 metadata: HashMap::new(),
202 },
203 }
204 }
205 None => {
206 let available_tools = self.tools.list().iter().map(|s| s.to_string()).collect();
208 let invalid_tool = crate::tool::invalid::InvalidTool::with_context(
209 name.to_string(),
210 available_tools,
211 );
212 let args = serde_json::json!({
213 "requested_tool": name,
214 "args": serde_json::from_str::<serde_json::Value>(arguments).unwrap_or(serde_json::json!({}))
215 });
216 match invalid_tool.execute(args).await {
217 Ok(result) => result,
218 Err(e) => ToolResult {
219 output: format!("Unknown tool: {}. Error: {}", name, e),
220 success: false,
221 metadata: HashMap::new(),
222 },
223 }
224 }
225 }
226 }
227
228 pub fn get_tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
230 self.tools.get(name)
231 }
232
233 pub fn register_tool(&mut self, tool: Arc<dyn Tool>) {
235 self.tools.register(tool);
236 }
237
238 pub fn list_tools(&self) -> Vec<&str> {
240 self.tools.list()
241 }
242
243 pub fn has_tool(&self, name: &str) -> bool {
245 self.tools.get(name).is_some()
246 }
247}
248
249#[async_trait]
251impl Actor for Agent {
252 fn actor_id(&self) -> &str {
253 &self.info.name
254 }
255
256 fn actor_status(&self) -> ActorStatus {
257 ActorStatus::Ready
259 }
260
261 async fn initialize(&mut self) -> Result<()> {
262 tracing::info!(
265 "Agent '{}' initialized for swarm participation",
266 self.info.name
267 );
268 Ok(())
269 }
270
271 async fn shutdown(&mut self) -> Result<()> {
272 tracing::info!("Agent '{}' shutting down", self.info.name);
273 Ok(())
274 }
275}
276
277#[async_trait]
279impl Handler<SwarmMessage> for Agent {
280 type Response = SwarmMessage;
281
282 async fn handle(&mut self, message: SwarmMessage) -> Result<Self::Response> {
283 match message {
284 SwarmMessage::ExecuteTask {
285 task_id,
286 instruction,
287 } => {
288 let mut session = Session::new().await?;
290
291 match self.execute(&mut session, &instruction).await {
293 Ok(response) => Ok(SwarmMessage::TaskCompleted {
294 task_id,
295 result: response.text,
296 }),
297 Err(e) => Ok(SwarmMessage::TaskFailed {
298 task_id,
299 error: e.to_string(),
300 }),
301 }
302 }
303 SwarmMessage::ToolRequest { tool_id, arguments } => {
304 let result = if let Some(tool) = self.get_tool(&tool_id) {
306 match tool.execute(arguments).await {
307 Ok(r) => r,
308 Err(e) => ToolResult::error(format!("Tool execution failed: {}", e)),
309 }
310 } else {
311 let available_tools = self.tools.list().iter().map(|s| s.to_string()).collect();
313 let invalid_tool = crate::tool::invalid::InvalidTool::with_context(
314 tool_id.clone(),
315 available_tools,
316 );
317 let args = serde_json::json!({
318 "requested_tool": tool_id,
319 "args": arguments
320 });
321 match invalid_tool.execute(args).await {
322 Ok(r) => r,
323 Err(e) => ToolResult::error(format!("Tool '{}' not found: {}", tool_id, e)),
324 }
325 };
326
327 Ok(SwarmMessage::ToolResponse { tool_id, result })
328 }
329 _ => {
330 Ok(SwarmMessage::TaskFailed {
332 task_id: "unknown".to_string(),
333 error: "Unsupported message type".to_string(),
334 })
335 }
336 }
337 }
338}
339
/// Final result of [`Agent::execute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResponse {
    /// Text parts of the last model response (the one without tool calls), joined with `"\n"`.
    pub text: String,
    /// Copy of the session's tool-use log at return time. `execute` does not reset that
    /// log, so it may include calls made earlier on the same session.
    pub tool_uses: Vec<ToolUse>,
    /// Copy of the session's usage counters at return time.
    pub usage: crate::provider::Usage,
}
347
/// Record of one tool call made during [`Agent::execute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUse {
    /// Tool-call id from the model response; echoed back as the result's `tool_call_id`.
    pub id: String,
    /// Tool name requested by the model (it may not be a registered tool).
    pub name: String,
    /// Raw argument string as sent by the model; not guaranteed to be valid JSON.
    pub input: String,
    /// Output text of the tool result sent back to the model.
    pub output: String,
    /// The tool result's `success` flag.
    pub success: bool,
}
357
/// Collection of agent descriptions ([`AgentInfo`]) indexed by agent name.
pub struct AgentRegistry {
    /// Agents keyed by `AgentInfo::name`.
    agents: HashMap<String, AgentInfo>,
}
362
363impl AgentRegistry {
364 #[allow(dead_code)]
365 pub fn new() -> Self {
366 Self {
367 agents: HashMap::new(),
368 }
369 }
370
371 pub fn register(&mut self, info: AgentInfo) {
373 self.agents.insert(info.name.clone(), info);
374 }
375
376 #[allow(dead_code)]
378 pub fn get(&self, name: &str) -> Option<&AgentInfo> {
379 self.agents.get(name)
380 }
381
382 pub fn list(&self) -> Vec<&AgentInfo> {
384 self.agents.values().collect()
385 }
386
387 #[allow(dead_code)]
389 pub fn list_primary(&self) -> Vec<&AgentInfo> {
390 self.agents
391 .values()
392 .filter(|a| a.mode == AgentMode::Primary && !a.hidden)
393 .collect()
394 }
395
396 pub fn with_builtins() -> Self {
398 let mut registry = Self::new();
399
400 registry.register(builtin::build_agent());
401 registry.register(builtin::plan_agent());
402 registry.register(builtin::explore_agent());
403
404 registry
405 }
406}
407
408impl Default for AgentRegistry {
409 fn default() -> Self {
410 Self::with_builtins()
411 }
412}