Skip to main content

aagt_core/agent/
multi_agent.rs

1//! Multi-agent coordination system
2//!
3//! Enables multiple specialized agents to work together.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::error::{Error, Result};
11use crate::agent::scheduler::Scheduler;
12use crate::agent::memory::Memory;
13
14/// Role of an agent in a multi-agent system
15#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
16pub enum AgentRole {
17    /// Research and analysis
18    Researcher,
19    /// Trade execution
20    Trader,
21    /// Risk assessment
22    RiskAnalyst,
23    /// Strategy planning
24    Strategist,
25    /// User interaction
26    Assistant,
27    /// Custom role
28    Custom(String),
29}
30
31impl AgentRole {
32    /// Get the role name
33    pub fn name(&self) -> &str {
34        match self {
35            Self::Researcher => "researcher",
36            Self::Trader => "trader",
37            Self::RiskAnalyst => "risk_analyst",
38            Self::Strategist => "strategist",
39            Self::Assistant => "assistant",
40            Self::Custom(name) => name,
41        }
42    }
43}
44
45/// Message between agents
46#[derive(Debug, Clone)]
47pub struct AgentMessage {
48    /// Sender role
49    pub from: AgentRole,
50    /// Target role (None = broadcast)
51    pub to: Option<AgentRole>,
52    /// Message content
53    pub content: String,
54    /// Message type
55    pub msg_type: MessageType,
56}
57
/// Intent carried by an [`AgentMessage`].
///
/// This is a fieldless enum, so it is `Copy` and supports `==` and hashing.
/// Callers can therefore write `msg_type == MessageType::Denial` or use it
/// as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Request for action. `Coordinator::orchestrate` sends this to every
    /// agent after the lead except the last one.
    Request,
    /// Response to a request.
    Response,
    /// Information shared with the recipient.
    Info,
    /// Approval request. `Coordinator::orchestrate` sends this to the final
    /// agent of a multi-agent workflow.
    Approval,
    /// Denial response. When a reply of this type is received,
    /// `Coordinator::orchestrate` aborts the workflow with an error.
    Denial,
}
72
73/// Trait for agents that can participate in multi-agent systems
74#[async_trait]
75pub trait MultiAgent: Send + Sync {
76    /// Get this agent's role
77    fn role(&self) -> AgentRole;
78
79    /// Handle an incoming message from another agent
80    async fn handle_message(&self, message: AgentMessage) -> Result<Option<AgentMessage>>;
81
82    /// Process a user request
83    async fn process(&self, input: &str) -> Result<String>;
84}
85
86/// Coordinator for multi-agent systems
87pub struct Coordinator {
88    /// Registered agents
89    agents: DashMap<AgentRole, Arc<dyn MultiAgent>>,
90    /// Max rounds of coordination
91    max_rounds: usize,
92    /// Scheduler for proactive tasks
93    pub scheduler: tokio::sync::OnceCell<Arc<Scheduler>>,
94    /// Shared memory for the system
95    pub memory: tokio::sync::OnceCell<Arc<dyn Memory>>,
96}
97
98impl Coordinator {
99    /// Create a new coordinator
100    pub fn new() -> Self {
101        Self {
102            agents: DashMap::new(),
103            max_rounds: 10,
104            scheduler: tokio::sync::OnceCell::new(),
105            memory: tokio::sync::OnceCell::new(),
106        }
107    }
108
109    /// Set max coordination rounds
110    pub fn with_max_rounds(mut self, rounds: usize) -> Self {
111        self.max_rounds = rounds;
112        self
113    }
114
115    /// Register an agent
116    pub fn register(&self, agent: Arc<dyn MultiAgent>) {
117        self.agents.insert(agent.role(), agent);
118    }
119
120    /// Get an agent by role
121    pub fn get(&self, role: &AgentRole) -> Option<Arc<dyn MultiAgent>> {
122        self.agents.get(role).map(|r| Arc::clone(&r))
123    }
124
125    /// Start the background scheduler
126    pub async fn start_scheduler(self: &Arc<Self>) -> Arc<Scheduler> {
127        let scheduler = self.scheduler.get_or_init(|| async {
128            let scheduler = Arc::new(Scheduler::new(Arc::downgrade(self)).await);
129            
130            // Link scheduler to memory if available
131            if let Some(memory) = self.memory.get() {
132                memory.link_scheduler(Arc::downgrade(&scheduler));
133            }
134            
135            let s_clone = Arc::clone(&scheduler);
136            tokio::spawn(async move {
137                s_clone.run().await;
138            });
139            scheduler
140        }).await.clone();
141        
142        scheduler
143    }
144
145    /// Route a message to the appropriate agent
146    pub async fn route(&self, message: AgentMessage) -> Result<Option<AgentMessage>> {
147        if let Some(target_role) = &message.to {
148            // Directed message
149            if let Some(agent) = self.get(target_role) {
150                return agent.handle_message(message).await;
151            } else {
152                return Err(Error::AgentCommunication(format!(
153                    "No agent with role: {:?}",
154                    target_role
155                )));
156            }
157        }
158
159        // Broadcast message - send to all agents except sender
160        let from_role = message.from.clone();
161        let mut responses = Vec::new();
162
163        for entry in self.agents.iter() {
164            if entry.key() != &from_role {
165                if let Some(response) = entry.value().handle_message(message.clone()).await? {
166                    responses.push(response);
167                }
168            }
169        }
170
171        // Return first response for now (could aggregate in future)
172        Ok(responses.into_iter().next())
173    }
174
175    /// Orchestrate a task through a dynamic workflow of agents
176    pub async fn orchestrate(&self, task: &str, workflow: Vec<AgentRole>) -> Result<String> {
177        if workflow.is_empty() {
178            return Err(Error::AgentCoordination("Workflow cannot be empty".to_string()));
179        }
180
181        let lead_role = &workflow[0];
182        let lead = self
183            .get(lead_role)
184            .ok_or_else(|| Error::AgentCoordination(format!("No lead agent found for role: {:?}", lead_role)))?;
185
186        // 1. Initial processing by lead agent
187        let mut current_result = lead.process(task).await?;
188
189        // 2. Pass result through the rest of the workflow chain
190        for (i, role) in workflow.iter().enumerate().skip(1) {
191            if let Some(agent) = self.get(role) {
192                // Determine message type based on position in chain
193                // Last agent usually gives final approval/response
194                let msg_type = if i == workflow.len() - 1 {
195                    MessageType::Approval
196                } else {
197                    MessageType::Request
198                };
199
200                let message = AgentMessage {
201                    from: workflow[i-1].clone(),
202                    to: Some(role.clone()),
203                    content: current_result.clone(),
204                    msg_type,
205                };
206
207                if let Some(response) = agent.handle_message(message).await? {
208                    // Check for strict denial/stop signal
209                    if matches!(response.msg_type, MessageType::Denial) {
210                        return Err(Error::AgentCoordination(format!(
211                            "Agent {:?} denied processing: {}",
212                            role, response.content
213                        )));
214                    }
215                    current_result = response.content;
216                }
217            } else {
218                return Err(Error::AgentCoordination(format!(
219                    "Workflow failed: Agent {:?} not found",
220                    role
221                )));
222            }
223        }
224
225        Ok(current_result)
226    }
227
228    /// Get list of registered agent roles
229    pub fn roles(&self) -> Vec<AgentRole> {
230        self.agents.iter().map(|r| r.key().clone()).collect()
231    }
232
233    /// Set the shared memory for the coordinator
234    pub fn set_memory(&self, memory: Arc<dyn Memory>) {
235        if let Some(scheduler) = self.scheduler.get() {
236            memory.link_scheduler(Arc::downgrade(scheduler));
237        }
238        let _ = self.memory.set(memory);
239    }
240}
241
242impl Default for Coordinator {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that answers every message with a fixed text and returns
    /// the same text from `process`. When `deny` is set, its replies are
    /// tagged `MessageType::Denial` instead of `MessageType::Response`.
    struct MockAgent {
        role: AgentRole,
        response: String,
        deny: bool,
    }

    impl MockAgent {
        /// Agent that always replies with `response`.
        fn new(role: AgentRole, response: &str) -> Self {
            Self {
                role,
                response: response.to_string(),
                deny: false,
            }
        }

        /// Agent that always denies, using `reason` as the reply content.
        fn denying(role: AgentRole, reason: &str) -> Self {
            Self {
                deny: true,
                ..Self::new(role, reason)
            }
        }
    }

    #[async_trait]
    impl MultiAgent for MockAgent {
        fn role(&self) -> AgentRole {
            self.role.clone()
        }

        async fn handle_message(&self, _message: AgentMessage) -> Result<Option<AgentMessage>> {
            let msg_type = if self.deny {
                MessageType::Denial
            } else {
                MessageType::Response
            };
            Ok(Some(AgentMessage {
                from: self.role.clone(),
                to: None,
                content: self.response.clone(),
                msg_type,
            }))
        }

        async fn process(&self, _input: &str) -> Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Builds a coordinator with `agents` registered in order.
    fn coordinator_with(agents: Vec<MockAgent>) -> Coordinator {
        let coordinator = Coordinator::new();
        for agent in agents {
            coordinator.register(Arc::new(agent));
        }
        coordinator
    }

    /// Researcher + trader pair shared by most tests.
    fn research_and_trade() -> Coordinator {
        coordinator_with(vec![
            MockAgent::new(AgentRole::Researcher, "Research complete"),
            MockAgent::new(AgentRole::Trader, "Trade executed"),
        ])
    }

    /// A request message with the given routing.
    fn ping(from: AgentRole, to: Option<AgentRole>) -> AgentMessage {
        AgentMessage {
            from,
            to,
            content: "ping".to_string(),
            msg_type: MessageType::Request,
        }
    }

    #[tokio::test]
    async fn test_coordinator() {
        let coordinator = research_and_trade();
        assert_eq!(coordinator.roles().len(), 2);
    }

    #[tokio::test]
    async fn register_replaces_agent_with_same_role() {
        let coordinator = coordinator_with(vec![
            MockAgent::new(AgentRole::Trader, "first"),
            MockAgent::new(AgentRole::Trader, "second"),
        ]);
        assert_eq!(coordinator.roles(), vec![AgentRole::Trader]);

        let trader = coordinator
            .get(&AgentRole::Trader)
            .expect("trader should be registered");
        let output = trader.process("task").await;
        assert!(matches!(output, Ok(ref text) if text == "second"));
    }

    #[tokio::test]
    async fn route_directed_reaches_target() {
        let coordinator = research_and_trade();
        let reply = coordinator
            .route(ping(AgentRole::Researcher, Some(AgentRole::Trader)))
            .await;
        assert!(matches!(reply, Ok(Some(ref msg)) if msg.content == "Trade executed"));
    }

    #[tokio::test]
    async fn route_directed_to_missing_role_fails() {
        let coordinator = research_and_trade();
        let reply = coordinator
            .route(ping(AgentRole::Researcher, Some(AgentRole::Strategist)))
            .await;
        assert!(matches!(reply, Err(Error::AgentCommunication(_))));
    }

    #[tokio::test]
    async fn route_broadcast_skips_sender() {
        let coordinator = research_and_trade();
        let reply = coordinator.route(ping(AgentRole::Researcher, None)).await;
        // Only the trader is eligible, so its reply is the one returned.
        assert!(matches!(reply, Ok(Some(ref msg)) if msg.content == "Trade executed"));
    }

    #[tokio::test]
    async fn orchestrate_chains_results() {
        let coordinator = research_and_trade();
        let result = coordinator
            .orchestrate("task", vec![AgentRole::Researcher, AgentRole::Trader])
            .await;
        assert!(matches!(result, Ok(ref text) if text == "Trade executed"));
    }

    #[tokio::test]
    async fn orchestrate_lead_only_returns_process_output() {
        let coordinator = research_and_trade();
        let result = coordinator
            .orchestrate("task", vec![AgentRole::Researcher])
            .await;
        assert!(matches!(result, Ok(ref text) if text == "Research complete"));
    }

    #[tokio::test]
    async fn orchestrate_rejects_empty_workflow() {
        let coordinator = research_and_trade();
        let result = coordinator.orchestrate("task", Vec::new()).await;
        assert!(matches!(result, Err(Error::AgentCoordination(_))));
    }

    #[tokio::test]
    async fn orchestrate_fails_on_missing_agent() {
        let coordinator = research_and_trade();
        let result = coordinator
            .orchestrate("task", vec![AgentRole::Researcher, AgentRole::Strategist])
            .await;
        assert!(matches!(result, Err(Error::AgentCoordination(_))));
    }

    #[tokio::test]
    async fn orchestrate_stops_on_denial() {
        let coordinator = coordinator_with(vec![
            MockAgent::new(AgentRole::Researcher, "Research complete"),
            MockAgent::denying(AgentRole::RiskAnalyst, "too risky"),
            MockAgent::new(AgentRole::Trader, "Trade executed"),
        ]);
        let result = coordinator
            .orchestrate(
                "task",
                vec![
                    AgentRole::Researcher,
                    AgentRole::RiskAnalyst,
                    AgentRole::Trader,
                ],
            )
            .await;
        assert!(matches!(result, Err(Error::AgentCoordination(_))));
    }
}