aagt_core/agent/
multi_agent.rs1use std::sync::Arc;
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::{Error, Result};
11use crate::agent::scheduler::Scheduler;
12use crate::agent::memory::Memory;
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
16pub enum AgentRole {
17 Researcher,
19 Trader,
21 RiskAnalyst,
23 Strategist,
25 Assistant,
27 Custom(String),
29}
30
31impl AgentRole {
32 pub fn name(&self) -> &str {
34 match self {
35 Self::Researcher => "researcher",
36 Self::Trader => "trader",
37 Self::RiskAnalyst => "risk_analyst",
38 Self::Strategist => "strategist",
39 Self::Assistant => "assistant",
40 Self::Custom(name) => name,
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct AgentMessage {
48 pub from: AgentRole,
50 pub to: Option<AgentRole>,
52 pub content: String,
54 pub msg_type: MessageType,
56}
57
/// Classifies an [`AgentMessage`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare message types
/// directly (`==`, `assert_eq!`) instead of pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Asks the recipient to do some work.
    Request,
    /// Carries the result of handling an earlier message.
    Response,
    /// Informational message.
    Info,
    /// Approval message; the orchestrated workflow sends this type to its
    /// final agent.
    Approval,
    /// Refusal; a denial returned during an orchestrated workflow aborts it.
    Denial,
}
72
73#[async_trait]
75pub trait MultiAgent: Send + Sync {
76 fn role(&self) -> AgentRole;
78
79 async fn handle_message(&self, message: AgentMessage) -> Result<Option<AgentMessage>>;
81
82 async fn process(&self, input: &str) -> Result<String>;
84}
85
86pub struct Coordinator {
88 agents: DashMap<AgentRole, Arc<dyn MultiAgent>>,
90 max_rounds: usize,
92 pub scheduler: tokio::sync::OnceCell<Arc<Scheduler>>,
94 pub memory: tokio::sync::OnceCell<Arc<dyn Memory>>,
96}
97
98impl Coordinator {
99 pub fn new() -> Self {
101 Self {
102 agents: DashMap::new(),
103 max_rounds: 10,
104 scheduler: tokio::sync::OnceCell::new(),
105 memory: tokio::sync::OnceCell::new(),
106 }
107 }
108
109 pub fn with_max_rounds(mut self, rounds: usize) -> Self {
111 self.max_rounds = rounds;
112 self
113 }
114
115 pub fn register(&self, agent: Arc<dyn MultiAgent>) {
117 self.agents.insert(agent.role(), agent);
118 }
119
120 pub fn get(&self, role: &AgentRole) -> Option<Arc<dyn MultiAgent>> {
122 self.agents.get(role).map(|r| Arc::clone(&r))
123 }
124
125 pub async fn start_scheduler(self: &Arc<Self>) -> Arc<Scheduler> {
127 let scheduler = self.scheduler.get_or_init(|| async {
128 let scheduler = Arc::new(Scheduler::new(Arc::downgrade(self)).await);
129
130 if let Some(memory) = self.memory.get() {
132 memory.link_scheduler(Arc::downgrade(&scheduler));
133 }
134
135 let s_clone = Arc::clone(&scheduler);
136 tokio::spawn(async move {
137 s_clone.run().await;
138 });
139 scheduler
140 }).await.clone();
141
142 scheduler
143 }
144
145 pub async fn route(&self, message: AgentMessage) -> Result<Option<AgentMessage>> {
147 if let Some(target_role) = &message.to {
148 if let Some(agent) = self.get(target_role) {
150 return agent.handle_message(message).await;
151 } else {
152 return Err(Error::AgentCommunication(format!(
153 "No agent with role: {:?}",
154 target_role
155 )));
156 }
157 }
158
159 let from_role = message.from.clone();
161 let mut responses = Vec::new();
162
163 for entry in self.agents.iter() {
164 if entry.key() != &from_role {
165 if let Some(response) = entry.value().handle_message(message.clone()).await? {
166 responses.push(response);
167 }
168 }
169 }
170
171 Ok(responses.into_iter().next())
173 }
174
175 pub async fn orchestrate(&self, task: &str, workflow: Vec<AgentRole>) -> Result<String> {
177 if workflow.is_empty() {
178 return Err(Error::AgentCoordination("Workflow cannot be empty".to_string()));
179 }
180
181 let lead_role = &workflow[0];
182 let lead = self
183 .get(lead_role)
184 .ok_or_else(|| Error::AgentCoordination(format!("No lead agent found for role: {:?}", lead_role)))?;
185
186 let mut current_result = lead.process(task).await?;
188
189 for (i, role) in workflow.iter().enumerate().skip(1) {
191 if let Some(agent) = self.get(role) {
192 let msg_type = if i == workflow.len() - 1 {
195 MessageType::Approval
196 } else {
197 MessageType::Request
198 };
199
200 let message = AgentMessage {
201 from: workflow[i-1].clone(),
202 to: Some(role.clone()),
203 content: current_result.clone(),
204 msg_type,
205 };
206
207 if let Some(response) = agent.handle_message(message).await? {
208 if matches!(response.msg_type, MessageType::Denial) {
210 return Err(Error::AgentCoordination(format!(
211 "Agent {:?} denied processing: {}",
212 role, response.content
213 )));
214 }
215 current_result = response.content;
216 }
217 } else {
218 return Err(Error::AgentCoordination(format!(
219 "Workflow failed: Agent {:?} not found",
220 role
221 )));
222 }
223 }
224
225 Ok(current_result)
226 }
227
228 pub fn roles(&self) -> Vec<AgentRole> {
230 self.agents.iter().map(|r| r.key().clone()).collect()
231 }
232
233 pub fn set_memory(&self, memory: Arc<dyn Memory>) {
235 if let Some(scheduler) = self.scheduler.get() {
236 memory.link_scheduler(Arc::downgrade(scheduler));
237 }
238 let _ = self.memory.set(memory);
239 }
240}
241
242impl Default for Coordinator {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that reports a fixed role and answers every call with the
    /// same canned string.
    struct MockAgent {
        role: AgentRole,
        response: String,
    }

    #[async_trait]
    impl MultiAgent for MockAgent {
        fn role(&self) -> AgentRole {
            self.role.clone()
        }

        async fn handle_message(&self, _message: AgentMessage) -> Result<Option<AgentMessage>> {
            let reply = AgentMessage {
                from: self.role.clone(),
                to: None,
                content: self.response.clone(),
                msg_type: MessageType::Response,
            };
            Ok(Some(reply))
        }

        async fn process(&self, _input: &str) -> Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Builds a shareable mock agent for `role` that always answers `response`.
    fn mock(role: AgentRole, response: &str) -> Arc<dyn MultiAgent> {
        Arc::new(MockAgent {
            role,
            response: response.to_string(),
        })
    }

    /// Registering two agents with distinct roles yields two known roles.
    #[tokio::test]
    async fn test_coordinator() {
        let coordinator = Coordinator::new();

        coordinator.register(mock(AgentRole::Researcher, "Research complete"));
        coordinator.register(mock(AgentRole::Trader, "Trade executed"));

        assert_eq!(coordinator.roles().len(), 2);
    }
}