1#![allow(dead_code)]
8
9use crate::agency::models::{ModelConfig, ModelProvider};
10use crate::agency::tools::Tool;
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum AgentStatus {
20 #[default]
21 Idle,
22 Thinking,
23 Executing,
24 WaitingForTool,
25 WaitingForInput,
26 Completed,
27 Failed,
28 Cancelled,
29}
30
31impl std::fmt::Display for AgentStatus {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 AgentStatus::Idle => write!(f, "idle"),
35 AgentStatus::Thinking => write!(f, "thinking"),
36 AgentStatus::Executing => write!(f, "executing"),
37 AgentStatus::WaitingForTool => write!(f, "waiting_for_tool"),
38 AgentStatus::WaitingForInput => write!(f, "waiting_for_input"),
39 AgentStatus::Completed => write!(f, "completed"),
40 AgentStatus::Failed => write!(f, "failed"),
41 AgentStatus::Cancelled => write!(f, "cancelled"),
42 }
43 }
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
48#[serde(rename_all = "lowercase")]
49pub enum AgentRole {
50 #[default]
51 Assistant,
52 Coordinator,
53 Researcher,
54 Coder,
55 Reviewer,
56 Analyst,
57 Writer,
58 Executor,
59 Household,
61 Business,
63 Tester,
65 Custom,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct AgentConfig {
71 pub name: String,
73 pub description: String,
75 pub instruction: String,
77 #[serde(default)]
79 pub role: AgentRole,
80 #[serde(default)]
82 pub model: ModelConfig,
83 #[serde(default)]
85 pub tools: Vec<String>,
86 #[serde(default)]
88 pub sub_agents: Vec<String>,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
91 pub output_key: Option<String>,
92 #[serde(default, skip_serializing_if = "Option::is_none")]
94 pub max_iterations: Option<u32>,
95 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
97 pub metadata: HashMap<String, serde_json::Value>,
98 pub created_at: DateTime<Utc>,
100 pub updated_at: DateTime<Utc>,
102}
103
104impl Default for AgentConfig {
105 fn default() -> Self {
106 let now = Utc::now();
107 Self {
108 name: String::new(),
109 description: String::new(),
110 instruction: String::new(),
111 role: AgentRole::default(),
112 model: ModelConfig::default(),
113 tools: Vec::new(),
114 sub_agents: Vec::new(),
115 output_key: None,
116 max_iterations: None,
117 metadata: HashMap::new(),
118 created_at: now,
119 updated_at: now,
120 }
121 }
122}
123
124#[derive(Debug)]
126pub struct Agent {
127 pub config: AgentConfig,
129 pub registered_tools: Vec<Arc<Tool>>,
131 pub sub_agents: Vec<Arc<Agent>>,
133 status: RwLock<AgentStatus>,
135}
136
137impl Clone for Agent {
138 fn clone(&self) -> Self {
139 Self {
140 config: self.config.clone(),
141 registered_tools: self.registered_tools.clone(),
142 sub_agents: self.sub_agents.clone(),
143 status: RwLock::new(*self.status.read().unwrap()),
144 }
145 }
146}
147
148impl Agent {
149 pub fn new(config: AgentConfig) -> Self {
151 Self {
152 config,
153 registered_tools: Vec::new(),
154 sub_agents: Vec::new(),
155 status: RwLock::new(AgentStatus::Idle),
156 }
157 }
158
159 pub fn name(&self) -> &str {
161 &self.config.name
162 }
163
164 pub fn description(&self) -> &str {
166 &self.config.description
167 }
168
169 pub fn instruction(&self) -> &str {
171 &self.config.instruction
172 }
173
174 pub fn model(&self) -> &ModelConfig {
176 &self.config.model
177 }
178
179 pub fn has_tools(&self) -> bool {
181 !self.registered_tools.is_empty()
182 }
183
184 pub fn has_sub_agents(&self) -> bool {
186 !self.sub_agents.is_empty()
187 }
188
189 pub fn get_tool(&self, name: &str) -> Option<&Arc<Tool>> {
191 self.registered_tools.iter().find(|t| t.name == name)
192 }
193
194 pub fn status(&self) -> AgentStatus {
196 *self.status.read().unwrap()
197 }
198
199 pub fn set_status(&self, status: AgentStatus) {
201 *self.status.write().unwrap() = status;
202 }
203
204 pub fn tool_definitions(&self) -> Vec<serde_json::Value> {
206 self.registered_tools
207 .iter()
208 .map(|tool| tool.to_function_definition())
209 .collect()
210 }
211}
212
213#[derive(Default)]
215pub struct AgentBuilder {
216 config: AgentConfig,
217 tools: Vec<Arc<Tool>>,
218 sub_agents: Vec<Arc<Agent>>,
219}
220
221impl AgentBuilder {
222 pub fn new(name: impl Into<String>) -> Self {
224 let mut builder = Self::default();
225 builder.config.name = name.into();
226 builder.config.created_at = Utc::now();
227 builder.config.updated_at = Utc::now();
228 builder
229 }
230
231 pub fn description(mut self, description: impl Into<String>) -> Self {
233 self.config.description = description.into();
234 self
235 }
236
237 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
239 self.config.instruction = instruction.into();
240 self
241 }
242
243 pub fn role(mut self, role: AgentRole) -> Self {
245 self.config.role = role;
246 self
247 }
248
249 pub fn model(mut self, model: impl Into<String>) -> Self {
251 let model_name = model.into();
252 let provider = infer_provider(&model_name);
253 self.config.model = ModelConfig {
254 model: model_name,
255 provider,
256 ..Default::default()
257 };
258 self
259 }
260
261 pub fn model_config(mut self, config: ModelConfig) -> Self {
263 self.config.model = config;
264 self
265 }
266
267 pub fn temperature(mut self, temperature: f32) -> Self {
269 self.config.model.temperature = temperature.clamp(0.0, 2.0);
270 self
271 }
272
273 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
275 self.config.model.max_tokens = Some(max_tokens);
276 self
277 }
278
279 pub fn tool(mut self, tool: Tool) -> Self {
281 let name = tool.name.clone();
282 self.tools.push(Arc::new(tool));
283 self.config.tools.push(name);
284 self
285 }
286
287 pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
289 for tool in tools {
290 self = self.tool(tool);
291 }
292 self
293 }
294
295 pub fn sub_agent(mut self, agent: Agent) -> Self {
297 let name = agent.config.name.clone();
298 self.sub_agents.push(Arc::new(agent));
299 self.config.sub_agents.push(name);
300 self
301 }
302
303 pub fn sub_agents(mut self, agents: impl IntoIterator<Item = Agent>) -> Self {
305 for agent in agents {
306 self = self.sub_agent(agent);
307 }
308 self
309 }
310
311 pub fn output_key(mut self, key: impl Into<String>) -> Self {
313 self.config.output_key = Some(key.into());
314 self
315 }
316
317 pub fn max_iterations(mut self, max: u32) -> Self {
319 self.config.max_iterations = Some(max);
320 self
321 }
322
323 pub fn metadata(mut self, key: impl Into<String>, value: impl Serialize) -> Self {
325 if let Ok(v) = serde_json::to_value(value) {
326 self.config.metadata.insert(key.into(), v);
327 }
328 self
329 }
330
331 pub fn build(self) -> Agent {
333 Agent {
334 config: self.config,
335 registered_tools: self.tools,
336 sub_agents: self.sub_agents,
337 status: RwLock::new(AgentStatus::Idle),
338 }
339 }
340}
341
342fn infer_provider(model: &str) -> ModelProvider {
344 let model_lower = model.to_lowercase();
345 if model_lower.contains("gemini") || model_lower.contains("palm") {
346 ModelProvider::Google
347 } else if model_lower.contains("gpt")
348 || model_lower.contains("o1")
349 || model_lower.contains("davinci")
350 {
351 ModelProvider::OpenAI
352 } else if model_lower.contains("claude") {
353 ModelProvider::Anthropic
354 } else if model_lower.contains("llama")
355 || model_lower.contains("mistral")
356 || model_lower.contains("codellama")
357 {
358 ModelProvider::Ollama
359 } else {
360 ModelProvider::OpenAICompatible
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_builder() {
        let agent = AgentBuilder::new("test_agent")
            .description("A test agent")
            .instruction("You are a helpful assistant.")
            .model("gemini-2.5-flash")
            .temperature(0.5)
            .build();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.description(), "A test agent");
        assert_eq!(agent.instruction(), "You are a helpful assistant.");
        assert_eq!(agent.config.model.model, "gemini-2.5-flash");
        assert_eq!(agent.config.model.temperature, 0.5);
        assert_eq!(agent.config.model.provider, ModelProvider::Google);
        assert_eq!(agent.status(), AgentStatus::Idle);
        assert!(!agent.has_tools());
        assert!(!agent.has_sub_agents());
    }

    #[test]
    fn test_builder_optional_fields() {
        let agent = AgentBuilder::new("configured")
            .role(AgentRole::Coder)
            .output_key("result")
            .max_iterations(5)
            .max_tokens(1024)
            .metadata("priority", 3)
            .temperature(5.0)
            .build();

        assert_eq!(agent.config.role, AgentRole::Coder);
        assert_eq!(agent.config.output_key.as_deref(), Some("result"));
        assert_eq!(agent.config.max_iterations, Some(5));
        assert_eq!(agent.config.model.max_tokens, Some(1024));
        assert_eq!(
            agent.config.metadata.get("priority"),
            Some(&serde_json::Value::from(3))
        );
        // Out-of-range temperatures are clamped to the upper bound.
        assert_eq!(agent.config.model.temperature, 2.0);
    }

    #[test]
    fn test_status_and_clone_are_independent() {
        let agent = AgentBuilder::new("worker").build();
        agent.set_status(AgentStatus::WaitingForTool);

        let copy = agent.clone();
        assert_eq!(copy.status(), AgentStatus::WaitingForTool);

        copy.set_status(AgentStatus::Completed);
        assert_eq!(agent.status(), AgentStatus::WaitingForTool);
        assert_eq!(AgentStatus::WaitingForTool.to_string(), "waiting_for_tool");
    }

    #[test]
    fn test_infer_provider() {
        assert_eq!(infer_provider("gemini-2.5-flash"), ModelProvider::Google);
        assert_eq!(infer_provider("gpt-4o"), ModelProvider::OpenAI);
        assert_eq!(infer_provider("GPT-4-Turbo"), ModelProvider::OpenAI);
        assert_eq!(infer_provider("o1-mini"), ModelProvider::OpenAI);
        assert_eq!(infer_provider("claude-3-opus"), ModelProvider::Anthropic);
        assert_eq!(infer_provider("llama-3.2-8b"), ModelProvider::Ollama);
        assert_eq!(infer_provider("codellama:13b"), ModelProvider::Ollama);
        assert_eq!(infer_provider("mistral-7b"), ModelProvider::Ollama);
        assert_eq!(infer_provider("qwen2.5-72b"), ModelProvider::OpenAICompatible);
    }
}