Skip to main content

aagt_core/agent/
multi_agent.rs

1//! Multi-agent coordination system
2//!
3//! Enables multiple specialized agents to work together.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9use tracing::info;
10
11use crate::error::{Error, Result};
12use crate::agent::scheduler::Scheduler;
13use crate::agent::memory::Memory;
14
15/// Role of an agent in a multi-agent system
16#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
17pub enum AgentRole {
18    /// Research and analysis
19    Researcher,
20    /// Trade execution
21    Trader,
22    /// Risk assessment
23    RiskAnalyst,
24    /// Strategy planning
25    Strategist,
26    /// User interaction
27    Assistant,
28    /// Custom role
29    Custom(String),
30}
31
32impl AgentRole {
33    /// Get the role name
34    pub fn name(&self) -> &str {
35        match self {
36            Self::Researcher => "researcher",
37            Self::Trader => "trader",
38            Self::RiskAnalyst => "risk_analyst",
39            Self::Strategist => "strategist",
40            Self::Assistant => "assistant",
41            Self::Custom(name) => name,
42        }
43    }
44}
45
46/// Message between agents
47#[derive(Debug, Clone)]
48pub struct AgentMessage {
49    /// Sender role
50    pub from: AgentRole,
51    /// Target role (None = broadcast)
52    pub to: Option<AgentRole>,
53    /// Message content
54    pub content: String,
55    /// Message type
56    pub msg_type: MessageType,
57}
58
/// Intent carried by an inter-agent [`AgentMessage`].
///
/// This is a plain fieldless enum. It is therefore `Copy` and comparable with
/// `==`, so callers do not need `.clone()` or `matches!` just to inspect it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Asks the recipient to act on the content.
    Request,
    /// Answers an earlier request.
    Response,
    /// Shares information without expecting action.
    Info,
    /// Asks the recipient to approve the content.
    Approval,
    /// Rejects the content; orchestration treats this as a stop signal.
    Denial,
    /// Passes the work on to the agent named in the message's `to` field.
    Handover,
}
75
76/// Trait for agents that can participate in multi-agent systems
77#[async_trait]
78pub trait MultiAgent: Send + Sync {
79    /// Get this agent's role
80    fn role(&self) -> AgentRole;
81
82    /// Handle an incoming message from another agent
83    async fn handle_message(&self, message: AgentMessage) -> Result<Option<AgentMessage>>;
84
85    /// Process a user request
86    async fn process(&self, input: &str) -> Result<String>;
87}
88
89/// Coordinator for multi-agent systems
90pub struct Coordinator {
91    /// Registered agents
92    agents: DashMap<AgentRole, Arc<dyn MultiAgent>>,
93    /// Max rounds of coordination
94    max_rounds: usize,
95    /// Scheduler for proactive tasks
96    pub scheduler: tokio::sync::OnceCell<Arc<Scheduler>>,
97    /// Shared memory for the system
98    pub memory: tokio::sync::OnceCell<Arc<dyn Memory>>,
99}
100
101impl Coordinator {
102    /// Create a new coordinator
103    pub fn new() -> Self {
104        Self {
105            agents: DashMap::new(),
106            max_rounds: 10,
107            scheduler: tokio::sync::OnceCell::new(),
108            memory: tokio::sync::OnceCell::new(),
109        }
110    }
111
112    /// Set max coordination rounds
113    pub fn with_max_rounds(mut self, rounds: usize) -> Self {
114        self.max_rounds = rounds;
115        self
116    }
117
118    /// Register an agent
119    pub fn register(&self, agent: Arc<dyn MultiAgent>) {
120        self.agents.insert(agent.role(), agent);
121    }
122
123    /// Get an agent by role
124    pub fn get(&self, role: &AgentRole) -> Option<Arc<dyn MultiAgent>> {
125        self.agents.get(role).map(|r| Arc::clone(&r))
126    }
127
128    /// Start the background scheduler
129    pub async fn start_scheduler(self: &Arc<Self>) -> Arc<Scheduler> {
130        let scheduler = self.scheduler.get_or_init(|| async {
131            let scheduler = Arc::new(Scheduler::new(Arc::downgrade(self)).await);
132            
133            // Link scheduler to memory if available
134            if let Some(memory) = self.memory.get() {
135                memory.link_scheduler(Arc::downgrade(&scheduler));
136            }
137            
138            let s_clone = Arc::clone(&scheduler);
139            tokio::spawn(async move {
140                s_clone.run().await;
141            });
142            scheduler
143        }).await.clone();
144        
145        scheduler
146    }
147
148    /// Route a message to the appropriate agent
149    pub async fn route(&self, message: AgentMessage) -> Result<Option<AgentMessage>> {
150        if let Some(target_role) = &message.to {
151            // Directed message
152            if let Some(agent) = self.get(target_role) {
153                return agent.handle_message(message).await;
154            } else {
155                return Err(Error::AgentCommunication(format!(
156                    "No agent with role: {:?}",
157                    target_role
158                )));
159            }
160        }
161
162        // Broadcast message - send to all agents except sender
163        let from_role = message.from.clone();
164        let mut responses = Vec::new();
165
166        for entry in self.agents.iter() {
167            if entry.key() != &from_role {
168                if let Some(response) = entry.value().handle_message(message.clone()).await? {
169                    responses.push(response);
170                }
171            }
172        }
173
174        // Return first response for now (could aggregate in future)
175        Ok(responses.into_iter().next())
176    }
177
178    /// Orchestrate a task through a dynamic workflow of agents
179    pub async fn orchestrate(&self, task: &str, workflow: Vec<AgentRole>) -> Result<String> {
180        if workflow.is_empty() {
181            return Err(Error::AgentCoordination("Workflow cannot be empty".to_string()));
182        }
183
184        let lead_role = &workflow[0];
185        let lead = self
186            .get(lead_role)
187            .ok_or_else(|| Error::AgentCoordination(format!("No lead agent found for role: {:?}", lead_role)))?;
188
189        // 1. Initial processing by lead agent
190        let mut current_result = lead.process(task).await?;
191        let mut current_role = lead_role.clone();
192
193        // 2. Pass result through the rest of the workflow chain OR follow handovers
194        let mut i = 1;
195        while i < workflow.len() {
196            let next_role = &workflow[i];
197            if let Some(agent) = self.get(next_role) {
198                let msg_type = if i == workflow.len() - 1 {
199                    MessageType::Approval
200                } else {
201                    MessageType::Request
202                };
203
204                let message = AgentMessage {
205                    from: current_role.clone(),
206                    to: Some(next_role.clone()),
207                    content: current_result.clone(),
208                    msg_type,
209                };
210
211                if let Some(response) = agent.handle_message(message).await? {
212                    // Check for Handover
213                    if matches!(response.msg_type, MessageType::Handover) {
214                        // Dynamic handover: the agent specifies the next role in the content or target
215                        if let Some(handover_to) = response.to {
216                             // If target is specified, we diverted from static workflow
217                             if let Some(_handover_agent) = self.get(&handover_to) {
218                                 info!("Dynamic Handover from {:?} to {:?}", next_role, handover_to);
219                                 current_result = response.content;
220                                 current_role = handover_to;
221                                 // We don't increment i here, we stay in the loop to process the handover
222                                 // To prevent infinite loops, we should have a max_rounds check
223                                 continue; 
224                             }
225                        }
226                    }
227
228                    // Check for strict denial/stop signal
229                    if matches!(response.msg_type, MessageType::Denial) {
230                        return Err(Error::AgentCoordination(format!(
231                            "Agent {:?} denied processing: {}",
232                            next_role, response.content
233                        )));
234                    }
235                    current_result = response.content;
236                }
237                current_role = next_role.clone();
238            } else {
239                return Err(Error::AgentCoordination(format!(
240                    "Workflow failed: Agent {:?} not found",
241                    next_role
242                )));
243            }
244            i += 1;
245        }
246
247        Ok(current_result)
248    }
249
250    /// Get list of registered agent roles
251    pub fn roles(&self) -> Vec<AgentRole> {
252        self.agents.iter().map(|r| r.key().clone()).collect()
253    }
254
255    /// Set the shared memory for the coordinator
256    pub fn set_memory(&self, memory: Arc<dyn Memory>) {
257        if let Some(scheduler) = self.scheduler.get() {
258            memory.link_scheduler(Arc::downgrade(scheduler));
259        }
260        let _ = self.memory.set(memory);
261    }
262}
263
264impl Default for Coordinator {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that answers every message with a fixed reply of a fixed type.
    struct MockAgent {
        role: AgentRole,
        response: String,
        reply_type: MessageType,
    }

    impl MockAgent {
        /// Mock that answers with an ordinary `Response`.
        fn responding(role: AgentRole, response: &str) -> Arc<Self> {
            Arc::new(Self {
                role,
                response: response.to_string(),
                reply_type: MessageType::Response,
            })
        }

        /// Mock that rejects every message with a `Denial`.
        fn denying(role: AgentRole, reason: &str) -> Arc<Self> {
            Arc::new(Self {
                role,
                response: reason.to_string(),
                reply_type: MessageType::Denial,
            })
        }
    }

    #[async_trait]
    impl MultiAgent for MockAgent {
        fn role(&self) -> AgentRole {
            self.role.clone()
        }

        async fn handle_message(&self, _message: AgentMessage) -> Result<Option<AgentMessage>> {
            Ok(Some(AgentMessage {
                from: self.role.clone(),
                to: None,
                content: self.response.clone(),
                msg_type: self.reply_type.clone(),
            }))
        }

        async fn process(&self, _input: &str) -> Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Coordinator with a Researcher and a Trader registered.
    fn two_agent_coordinator() -> Coordinator {
        let coordinator = Coordinator::new();
        coordinator.register(MockAgent::responding(AgentRole::Researcher, "Research complete"));
        coordinator.register(MockAgent::responding(AgentRole::Trader, "Trade executed"));
        coordinator
    }

    /// Builds an `Info` message from `from`, addressed to `to`.
    fn info_message(from: AgentRole, to: Option<AgentRole>) -> AgentMessage {
        AgentMessage {
            from,
            to,
            content: "update".to_string(),
            msg_type: MessageType::Info,
        }
    }

    #[tokio::test]
    async fn test_coordinator() {
        let coordinator = two_agent_coordinator();
        assert_eq!(coordinator.roles().len(), 2);
    }

    #[tokio::test]
    async fn registering_same_role_replaces_agent() {
        let coordinator = two_agent_coordinator();
        coordinator.register(MockAgent::responding(AgentRole::Trader, "Trade v2"));
        assert_eq!(coordinator.roles().len(), 2);

        let trader = coordinator.get(&AgentRole::Trader).expect("trader registered");
        assert_eq!(trader.process("x").await.ok().as_deref(), Some("Trade v2"));
    }

    #[tokio::test]
    async fn orchestrate_passes_result_along_workflow() {
        let coordinator = two_agent_coordinator();
        let result = coordinator
            .orchestrate("analyse", vec![AgentRole::Researcher, AgentRole::Trader])
            .await;
        assert_eq!(result.ok().as_deref(), Some("Trade executed"));
    }

    #[tokio::test]
    async fn orchestrate_rejects_empty_and_unknown_workflows() {
        let coordinator = two_agent_coordinator();
        assert!(coordinator.orchestrate("task", Vec::new()).await.is_err());
        assert!(coordinator
            .orchestrate("task", vec![AgentRole::Strategist])
            .await
            .is_err());
        assert!(coordinator
            .orchestrate("task", vec![AgentRole::Researcher, AgentRole::Strategist])
            .await
            .is_err());
    }

    #[tokio::test]
    async fn orchestrate_stops_on_denial() {
        let coordinator = two_agent_coordinator();
        coordinator.register(MockAgent::denying(AgentRole::RiskAnalyst, "too risky"));
        let outcome = coordinator
            .orchestrate(
                "task",
                vec![AgentRole::Researcher, AgentRole::RiskAnalyst, AgentRole::Trader],
            )
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn route_delivers_directed_messages() {
        let coordinator = two_agent_coordinator();
        let reply = coordinator
            .route(info_message(AgentRole::Researcher, Some(AgentRole::Trader)))
            .await;
        let content = reply.ok().flatten().map(|m| m.content);
        assert_eq!(content.as_deref(), Some("Trade executed"));
    }

    #[tokio::test]
    async fn route_to_unknown_role_errors() {
        let coordinator = two_agent_coordinator();
        let reply = coordinator
            .route(info_message(AgentRole::Researcher, Some(AgentRole::Strategist)))
            .await;
        assert!(reply.is_err());
    }

    #[tokio::test]
    async fn route_broadcast_skips_sender() {
        let coordinator = two_agent_coordinator();
        let reply = coordinator
            .route(info_message(AgentRole::Researcher, None))
            .await;
        let content = reply.ok().flatten().map(|m| m.content);
        assert_eq!(content.as_deref(), Some("Trade executed"));

        let lonely = Coordinator::new();
        lonely.register(MockAgent::responding(AgentRole::Researcher, "Research complete"));
        let reply = lonely.route(info_message(AgentRole::Researcher, None)).await;
        assert!(matches!(reply, Ok(None)));
    }
}