// aagt_core/agent/multi_agent.rs
use std::sync::Arc;

use async_trait::async_trait;
use dashmap::DashMap;
use tracing::{info, warn};

use crate::agent::memory::Memory;
use crate::agent::scheduler::Scheduler;
use crate::error::{Error, Result};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
17pub enum AgentRole {
18 Researcher,
20 Trader,
22 RiskAnalyst,
24 Strategist,
26 Assistant,
28 Custom(String),
30}
31
32impl AgentRole {
33 pub fn name(&self) -> &str {
35 match self {
36 Self::Researcher => "researcher",
37 Self::Trader => "trader",
38 Self::RiskAnalyst => "risk_analyst",
39 Self::Strategist => "strategist",
40 Self::Assistant => "assistant",
41 Self::Custom(name) => name,
42 }
43 }
44}
45
/// A message exchanged between agents, either through `Coordinator::route`
/// or between workflow steps in `Coordinator::orchestrate`.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Role of the sending agent. A broadcast via `Coordinator::route` skips
    /// the agent registered under this role.
    pub from: AgentRole,
    /// Addressed recipient. `None` makes `Coordinator::route` broadcast to every
    /// other registered agent. On a `MessageType::Handover` reply it names the
    /// role the work should be handed to.
    pub to: Option<AgentRole>,
    /// Free-form payload. In `orchestrate` this carries the running result
    /// from step to step.
    pub content: String,
    /// Intent of the message.
    pub msg_type: MessageType,
}
58
/// Intent carried by an [`AgentMessage`].
#[derive(Debug, Clone)]
pub enum MessageType {
    /// Asks the recipient to act on the content. `Coordinator::orchestrate`
    /// sends this to every workflow step except the last.
    Request,
    /// Generic reply. `orchestrate` takes its content as the new result.
    Response,
    /// Informational message. `Coordinator` gives it no special handling.
    Info,
    /// Sent by `Coordinator::orchestrate` to the final workflow step.
    Approval,
    /// Reply refusing the work. `orchestrate` aborts with an error when it
    /// receives one.
    Denial,
    /// Reply asking that the work continue with the role in `AgentMessage::to`
    /// (see `Coordinator::orchestrate`).
    Handover,
}
75
/// An agent that can take part in coordinator-driven conversations.
///
/// The trait requires `Send + Sync`. [`Coordinator`] stores agents as
/// `Arc<dyn MultiAgent>`.
#[async_trait]
pub trait MultiAgent: Send + Sync {
    /// The role this agent is registered under. `Coordinator::register` uses
    /// it as the registry key, replacing any agent already holding that role.
    fn role(&self) -> AgentRole;

    /// Handles an incoming message and optionally returns a reply.
    /// `Ok(None)` means "no reply".
    async fn handle_message(&self, message: AgentMessage) -> Result<Option<AgentMessage>>;

    /// Processes free-form input and returns a textual result.
    /// `Coordinator::orchestrate` calls this on the lead (first) workflow role.
    async fn process(&self, input: &str) -> Result<String>;
}
88
89pub struct Coordinator {
91 agents: DashMap<AgentRole, Arc<dyn MultiAgent>>,
93 max_rounds: usize,
95 pub scheduler: tokio::sync::OnceCell<Arc<Scheduler>>,
97 pub memory: tokio::sync::OnceCell<Arc<dyn Memory>>,
99}
100
101impl Coordinator {
102 pub fn new() -> Self {
104 Self {
105 agents: DashMap::new(),
106 max_rounds: 10,
107 scheduler: tokio::sync::OnceCell::new(),
108 memory: tokio::sync::OnceCell::new(),
109 }
110 }
111
112 pub fn with_max_rounds(mut self, rounds: usize) -> Self {
114 self.max_rounds = rounds;
115 self
116 }
117
118 pub fn register(&self, agent: Arc<dyn MultiAgent>) {
120 self.agents.insert(agent.role(), agent);
121 }
122
123 pub fn get(&self, role: &AgentRole) -> Option<Arc<dyn MultiAgent>> {
125 self.agents.get(role).map(|r| Arc::clone(&r))
126 }
127
128 pub async fn start_scheduler(self: &Arc<Self>) -> Arc<Scheduler> {
130 let scheduler = self.scheduler.get_or_init(|| async {
131 let scheduler = Arc::new(Scheduler::new(Arc::downgrade(self)).await);
132
133 if let Some(memory) = self.memory.get() {
135 memory.link_scheduler(Arc::downgrade(&scheduler));
136 }
137
138 let s_clone = Arc::clone(&scheduler);
139 tokio::spawn(async move {
140 s_clone.run().await;
141 });
142 scheduler
143 }).await.clone();
144
145 scheduler
146 }
147
148 pub async fn route(&self, message: AgentMessage) -> Result<Option<AgentMessage>> {
150 if let Some(target_role) = &message.to {
151 if let Some(agent) = self.get(target_role) {
153 return agent.handle_message(message).await;
154 } else {
155 return Err(Error::AgentCommunication(format!(
156 "No agent with role: {:?}",
157 target_role
158 )));
159 }
160 }
161
162 let from_role = message.from.clone();
164 let mut responses = Vec::new();
165
166 for entry in self.agents.iter() {
167 if entry.key() != &from_role {
168 if let Some(response) = entry.value().handle_message(message.clone()).await? {
169 responses.push(response);
170 }
171 }
172 }
173
174 Ok(responses.into_iter().next())
176 }
177
178 pub async fn orchestrate(&self, task: &str, workflow: Vec<AgentRole>) -> Result<String> {
180 if workflow.is_empty() {
181 return Err(Error::AgentCoordination("Workflow cannot be empty".to_string()));
182 }
183
184 let lead_role = &workflow[0];
185 let lead = self
186 .get(lead_role)
187 .ok_or_else(|| Error::AgentCoordination(format!("No lead agent found for role: {:?}", lead_role)))?;
188
189 let mut current_result = lead.process(task).await?;
191 let mut current_role = lead_role.clone();
192
193 let mut i = 1;
195 while i < workflow.len() {
196 let next_role = &workflow[i];
197 if let Some(agent) = self.get(next_role) {
198 let msg_type = if i == workflow.len() - 1 {
199 MessageType::Approval
200 } else {
201 MessageType::Request
202 };
203
204 let message = AgentMessage {
205 from: current_role.clone(),
206 to: Some(next_role.clone()),
207 content: current_result.clone(),
208 msg_type,
209 };
210
211 if let Some(response) = agent.handle_message(message).await? {
212 if matches!(response.msg_type, MessageType::Handover) {
214 if let Some(handover_to) = response.to {
216 if let Some(_handover_agent) = self.get(&handover_to) {
218 info!("Dynamic Handover from {:?} to {:?}", next_role, handover_to);
219 current_result = response.content;
220 current_role = handover_to;
221 continue;
224 }
225 }
226 }
227
228 if matches!(response.msg_type, MessageType::Denial) {
230 return Err(Error::AgentCoordination(format!(
231 "Agent {:?} denied processing: {}",
232 next_role, response.content
233 )));
234 }
235 current_result = response.content;
236 }
237 current_role = next_role.clone();
238 } else {
239 return Err(Error::AgentCoordination(format!(
240 "Workflow failed: Agent {:?} not found",
241 next_role
242 )));
243 }
244 i += 1;
245 }
246
247 Ok(current_result)
248 }
249
250 pub fn roles(&self) -> Vec<AgentRole> {
252 self.agents.iter().map(|r| r.key().clone()).collect()
253 }
254
255 pub fn set_memory(&self, memory: Arc<dyn Memory>) {
257 if let Some(scheduler) = self.scheduler.get() {
258 memory.link_scheduler(Arc::downgrade(scheduler));
259 }
260 let _ = self.memory.set(memory);
261 }
262}
263
264impl Default for Coordinator {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that answers every message, and every `process` call,
    /// with a fixed string.
    struct MockAgent {
        role: AgentRole,
        response: String,
    }

    #[async_trait]
    impl MultiAgent for MockAgent {
        fn role(&self) -> AgentRole {
            self.role.clone()
        }

        async fn handle_message(&self, _message: AgentMessage) -> Result<Option<AgentMessage>> {
            Ok(Some(AgentMessage {
                from: self.role.clone(),
                to: None,
                content: self.response.clone(),
                msg_type: MessageType::Response,
            }))
        }

        async fn process(&self, _input: &str) -> Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Builds a coordinator with a researcher and a trader registered.
    fn coordinator_with_mocks() -> Coordinator {
        let coordinator = Coordinator::new();

        coordinator.register(Arc::new(MockAgent {
            role: AgentRole::Researcher,
            response: "Research complete".to_string(),
        }));

        coordinator.register(Arc::new(MockAgent {
            role: AgentRole::Trader,
            response: "Trade executed".to_string(),
        }));

        coordinator
    }

    /// Builds a `Request` from `from`, optionally addressed to `to`.
    fn request(from: AgentRole, to: Option<AgentRole>) -> AgentMessage {
        AgentMessage {
            from,
            to,
            content: "ping".to_string(),
            msg_type: MessageType::Request,
        }
    }

    #[tokio::test]
    async fn test_coordinator() {
        let coordinator = coordinator_with_mocks();
        assert_eq!(coordinator.roles().len(), 2);
    }

    #[tokio::test]
    async fn test_route_direct_broadcast_and_missing() {
        let coordinator = coordinator_with_mocks();

        let direct = match coordinator
            .route(request(AgentRole::Researcher, Some(AgentRole::Trader)))
            .await
        {
            Ok(Some(reply)) => reply,
            _ => panic!("expected a direct reply from the trader"),
        };
        assert_eq!(direct.content, "Trade executed");

        // A broadcast skips the sender, so only the trader can answer.
        let broadcast = match coordinator.route(request(AgentRole::Researcher, None)).await {
            Ok(Some(reply)) => reply,
            _ => panic!("expected a broadcast reply from the trader"),
        };
        assert_eq!(broadcast.content, "Trade executed");

        let missing = coordinator
            .route(request(AgentRole::Researcher, Some(AgentRole::Strategist)))
            .await;
        assert!(missing.is_err());
    }

    #[tokio::test]
    async fn test_orchestrate() {
        let coordinator = coordinator_with_mocks();

        let result = coordinator
            .orchestrate("analyze", vec![AgentRole::Researcher, AgentRole::Trader])
            .await;
        assert!(matches!(result.as_deref(), Ok("Trade executed")));

        assert!(coordinator.orchestrate("analyze", Vec::new()).await.is_err());

        let missing_step = coordinator
            .orchestrate("analyze", vec![AgentRole::Researcher, AgentRole::RiskAnalyst])
            .await;
        assert!(missing_step.is_err());
    }
}
315}