helios_engine/
forest.rs

1//! # Forest of Agents Module
2//!
3//! This module implements the "Forest of Agents" feature, which allows multiple agents
4//! to interact with each other, share context, and collaborate on tasks.
5//!
6//! The ForestOfAgents struct manages a collection of agents and provides mechanisms
7//! for inter-agent communication and coordination.
8
9use crate::agent::{Agent, AgentBuilder};
10use crate::config::Config;
11use crate::error::{HeliosError, Result};
12use crate::tools::{Tool, ToolParameter, ToolResult};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
/// A unique identifier for an agent in the forest.
///
/// This is a plain alias for `String`, so the compiler does not distinguish
/// agent IDs from other strings.
pub type AgentId = String;
20
21/// A message sent between agents in the forest.
22#[derive(Debug, Clone)]
23pub struct ForestMessage {
24    /// The ID of the sender agent.
25    pub from: AgentId,
26    /// The ID of the recipient agent (None for broadcast).
27    pub to: Option<AgentId>,
28    /// The message content.
29    pub content: String,
30    /// Optional metadata associated with the message.
31    pub metadata: HashMap<String, String>,
32    /// Timestamp of the message.
33    pub timestamp: chrono::DateTime<chrono::Utc>,
34}
35
36impl ForestMessage {
37    /// Creates a new forest message.
38    pub fn new(from: AgentId, to: Option<AgentId>, content: String) -> Self {
39        Self {
40            from,
41            to,
42            content,
43            metadata: HashMap::new(),
44            timestamp: chrono::Utc::now(),
45        }
46    }
47
48    /// Creates a broadcast message to all agents.
49    pub fn broadcast(from: AgentId, content: String) -> Self {
50        Self::new(from, None, content)
51    }
52
53    /// Adds metadata to the message.
54    pub fn with_metadata(mut self, key: String, value: String) -> Self {
55        self.metadata.insert(key, value);
56        self
57    }
58}
59
/// Status of a task in the collaborative workflow.
///
/// The enum carries no data, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task has not been started yet.
    Pending,
    /// The task has been started but has not finished.
    InProgress,
    /// The task finished successfully.
    Completed,
    /// The task finished unsuccessfully.
    Failed,
}

impl TaskStatus {
    /// Returns the snake_case label for this status
    /// (`"pending"`, `"in_progress"`, `"completed"` or `"failed"`).
    ///
    /// The label is a string literal, so the returned reference is `'static`
    /// and may outlive the `TaskStatus` value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskStatus::Pending => "pending",
            TaskStatus::InProgress => "in_progress",
            TaskStatus::Completed => "completed",
            TaskStatus::Failed => "failed",
        }
    }
}
79
80/// A task in the collaborative plan.
81#[derive(Debug, Clone)]
82pub struct TaskItem {
83    /// Unique identifier for the task.
84    pub id: String,
85    /// Description of the task.
86    pub description: String,
87    /// Agent assigned to this task.
88    pub assigned_to: AgentId,
89    /// Current status of the task.
90    pub status: TaskStatus,
91    /// Result/output from the task execution.
92    pub result: Option<String>,
93    /// Dependencies (task IDs that must complete before this one).
94    pub dependencies: Vec<String>,
95    /// Metadata about the task.
96    pub metadata: HashMap<String, String>,
97}
98
99impl TaskItem {
100    pub fn new(id: String, description: String, assigned_to: AgentId) -> Self {
101        Self {
102            id,
103            description,
104            assigned_to,
105            status: TaskStatus::Pending,
106            result: None,
107            dependencies: Vec::new(),
108            metadata: HashMap::new(),
109        }
110    }
111
112    pub fn with_dependencies(mut self, deps: Vec<String>) -> Self {
113        self.dependencies = deps;
114        self
115    }
116}
117
118/// A collaborative task plan created by the coordinator.
119#[derive(Debug, Clone)]
120pub struct TaskPlan {
121    /// Unique identifier for the plan.
122    pub plan_id: String,
123    /// Overall goal/objective.
124    pub objective: String,
125    /// Individual tasks in the plan (HashMap for O(1) lookup).
126    pub tasks: HashMap<String, TaskItem>,
127    /// Task order (maintains insertion order for iteration).
128    pub task_order: Vec<String>,
129    /// Timestamp when plan was created.
130    pub created_at: chrono::DateTime<chrono::Utc>,
131}
132
133impl TaskPlan {
134    pub fn new(plan_id: String, objective: String) -> Self {
135        Self {
136            plan_id,
137            objective,
138            tasks: HashMap::new(),
139            task_order: Vec::new(),
140            created_at: chrono::Utc::now(),
141        }
142    }
143
144    pub fn add_task(&mut self, task: TaskItem) {
145        self.task_order.push(task.id.clone());
146        self.tasks.insert(task.id.clone(), task);
147    }
148
149    pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut TaskItem> {
150        self.tasks.get_mut(task_id)
151    }
152
153    pub fn get_task(&self, task_id: &str) -> Option<&TaskItem> {
154        self.tasks.get(task_id)
155    }
156
157    /// Get next ready tasks with O(T * D) complexity instead of O(T²)
158    /// where T = number of tasks, D = average dependencies per task
159    pub fn get_next_ready_tasks(&self) -> Vec<&TaskItem> {
160        self.task_order
161            .iter()
162            .filter_map(|task_id| self.tasks.get(task_id))
163            .filter(|t| {
164                t.status == TaskStatus::Pending
165                    && t.dependencies.iter().all(|dep_id| {
166                        // O(1) HashMap lookup instead of O(T) Vec iteration
167                        self.tasks
168                            .get(dep_id)
169                            .map(|dt| dt.status == TaskStatus::Completed)
170                            .unwrap_or(false)
171                    })
172            })
173            .collect()
174    }
175
176    pub fn is_complete(&self) -> bool {
177        self.tasks
178            .values()
179            .all(|t| t.status == TaskStatus::Completed || t.status == TaskStatus::Failed)
180    }
181
182    pub fn get_progress(&self) -> (usize, usize) {
183        let completed = self
184            .tasks
185            .values()
186            .filter(|t| t.status == TaskStatus::Completed)
187            .count();
188        (completed, self.tasks.len())
189    }
190
191    /// Get all tasks in order
192    pub fn tasks_in_order(&self) -> Vec<&TaskItem> {
193        self.task_order
194            .iter()
195            .filter_map(|id| self.tasks.get(id))
196            .collect()
197    }
198}
199
200/// Shared context that can be accessed by all agents in the forest.
201#[derive(Debug, Clone)]
202pub struct SharedContext {
203    /// Key-value store for shared data.
204    pub data: HashMap<String, Value>,
205    /// Message history between agents.
206    pub message_history: Vec<ForestMessage>,
207    /// Global metadata.
208    pub metadata: HashMap<String, String>,
209    /// Current task plan being executed.
210    pub current_plan: Option<TaskPlan>,
211}
212
213impl SharedContext {
214    /// Creates a new empty shared context.
215    pub fn new() -> Self {
216        Self {
217            data: HashMap::new(),
218            message_history: Vec::new(),
219            metadata: HashMap::new(),
220            current_plan: None,
221        }
222    }
223
224    /// Sets a value in the shared context.
225    pub fn set(&mut self, key: String, value: Value) {
226        self.data.insert(key, value);
227    }
228
229    /// Gets a value from the shared context.
230    pub fn get(&self, key: &str) -> Option<&Value> {
231        self.data.get(key)
232    }
233
234    /// Removes a value from the shared context.
235    pub fn remove(&mut self, key: &str) -> Option<Value> {
236        self.data.remove(key)
237    }
238
239    /// Adds a message to the history.
240    pub fn add_message(&mut self, message: ForestMessage) {
241        self.message_history.push(message);
242    }
243
244    /// Gets recent messages (last N messages).
245    pub fn get_recent_messages(&self, limit: usize) -> &[ForestMessage] {
246        let len = self.message_history.len();
247        let start = len.saturating_sub(limit);
248        &self.message_history[start..]
249    }
250
251    /// Sets the current task plan.
252    pub fn set_plan(&mut self, plan: TaskPlan) {
253        self.current_plan = Some(plan);
254    }
255
256    /// Gets the current task plan.
257    pub fn get_plan(&self) -> Option<&TaskPlan> {
258        self.current_plan.as_ref()
259    }
260
261    /// Gets a mutable reference to the current task plan.
262    pub fn get_plan_mut(&mut self) -> Option<&mut TaskPlan> {
263        self.current_plan.as_mut()
264    }
265
266    /// Clears the current task plan.
267    pub fn clear_plan(&mut self) {
268        self.current_plan = None;
269    }
270}
271
272impl Default for SharedContext {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
/// The main Forest of Agents structure that manages multiple agents.
///
/// Besides the agents themselves it owns the shared context and the message queue;
/// both are wrapped in `Arc<RwLock<..>>` so that clones of the handles can be given
/// to the communication tools registered on each agent in `add_agent`.
pub struct ForestOfAgents {
    /// The agents in the forest, keyed by their IDs.
    agents: HashMap<AgentId, Agent>,
    /// Shared context accessible to all agents.
    shared_context: Arc<RwLock<SharedContext>>,
    /// Message queue for inter-agent communication; `send_message` pushes onto it
    /// and `process_messages` drains it.
    message_queue: Arc<RwLock<Vec<ForestMessage>>>,
    /// Maximum number of iterations for agent interactions.
    max_iterations: usize,
}
289
290impl ForestOfAgents {
291    /// Creates a new empty Forest of Agents.
292    pub fn new() -> Self {
293        Self {
294            agents: HashMap::new(),
295            shared_context: Arc::new(RwLock::new(SharedContext::new())),
296            message_queue: Arc::new(RwLock::new(Vec::new())),
297            max_iterations: 10,
298        }
299    }
300
301    /// Creates a new Forest of Agents with the specified max iterations.
302    pub fn with_max_iterations(max_iterations: usize) -> Self {
303        Self {
304            max_iterations,
305            ..Self::new()
306        }
307    }
308
309    /// Adds an agent to the forest.
310    ///
311    /// # Arguments
312    ///
313    /// * `id` - Unique identifier for the agent
314    /// * `agent` - The agent to add
315    ///
316    /// # Returns
317    ///
318    /// Returns an error if an agent with the same ID already exists.
319    pub fn add_agent(&mut self, id: AgentId, mut agent: Agent) -> Result<()> {
320        if self.agents.contains_key(&id) {
321            return Err(HeliosError::AgentError(format!(
322                "Agent with ID '{}' already exists",
323                id
324            )));
325        }
326
327        // Register communication tools for this agent
328        let send_message_tool = Box::new(SendMessageTool::new(
329            id.clone(),
330            Arc::clone(&self.message_queue),
331            Arc::clone(&self.shared_context),
332        ));
333        agent.register_tool(send_message_tool);
334
335        let delegate_task_tool = Box::new(DelegateTaskTool::new(
336            id.clone(),
337            Arc::clone(&self.message_queue),
338            Arc::clone(&self.shared_context),
339        ));
340        agent.register_tool(delegate_task_tool);
341
342        let share_context_tool = Box::new(ShareContextTool::new(
343            id.clone(),
344            Arc::clone(&self.shared_context),
345        ));
346        agent.register_tool(share_context_tool);
347
348        let update_task_memory_tool = Box::new(UpdateTaskMemoryTool::new(
349            id.clone(),
350            Arc::clone(&self.shared_context),
351        ));
352        agent.register_tool(update_task_memory_tool);
353
354        let create_plan_tool = Box::new(CreatePlanTool::new(
355            id.clone(),
356            Arc::clone(&self.shared_context),
357        ));
358        agent.register_tool(create_plan_tool);
359
360        self.agents.insert(id, agent);
361        Ok(())
362    }
363
364    /// Removes an agent from the forest.
365    ///
366    /// # Arguments
367    ///
368    /// * `id` - The ID of the agent to remove
369    ///
370    /// # Returns
371    ///
372    /// Returns the removed agent if it existed.
373    pub fn remove_agent(&mut self, id: &AgentId) -> Option<Agent> {
374        self.agents.remove(id)
375    }
376
377    /// Gets a reference to an agent by ID.
378    pub fn get_agent(&self, id: &AgentId) -> Option<&Agent> {
379        self.agents.get(id)
380    }
381
382    /// Gets a mutable reference to an agent by ID.
383    pub fn get_agent_mut(&mut self, id: &AgentId) -> Option<&mut Agent> {
384        self.agents.get_mut(id)
385    }
386
387    /// Lists all agent IDs in the forest.
388    pub fn list_agents(&self) -> Vec<AgentId> {
389        self.agents.keys().cloned().collect()
390    }
391
392    /// Sends a message from one agent to another.
393    ///
394    /// # Arguments
395    ///
396    /// * `from` - ID of the sending agent
397    /// * `to` - ID of the receiving agent (None for broadcast)
398    /// * `content` - Message content
399    ///
400    /// # Returns
401    ///
402    /// Returns an error if the sender doesn't exist.
403    pub async fn send_message(
404        &self,
405        from: &AgentId,
406        to: Option<&AgentId>,
407        content: String,
408    ) -> Result<()> {
409        if !self.agents.contains_key(from) {
410            return Err(HeliosError::AgentError(format!(
411                "Agent '{}' not found",
412                from
413            )));
414        }
415
416        let message = if let Some(to_id) = to {
417            ForestMessage::new(from.clone(), Some(to_id.clone()), content)
418        } else {
419            ForestMessage::broadcast(from.clone(), content)
420        };
421
422        let mut queue = self.message_queue.write().await;
423        queue.push(message.clone());
424
425        // Also add to shared context history
426        let mut context = self.shared_context.write().await;
427        context.add_message(message);
428
429        Ok(())
430    }
431
432    /// Processes pending messages in the queue.
433    pub async fn process_messages(&mut self) -> Result<()> {
434        let messages: Vec<ForestMessage> = {
435            let mut queue = self.message_queue.write().await;
436            queue.drain(..).collect()
437        };
438
439        for message in messages {
440            if let Some(to_id) = &message.to {
441                // Direct message
442                if let Some(agent) = self.agents.get_mut(to_id) {
443                    // Add the message as a user message to the agent's chat session
444                    let formatted_message =
445                        format!("Message from {}: {}", message.from, message.content);
446                    agent.chat_session_mut().add_user_message(formatted_message);
447                }
448            } else {
449                // Broadcast message - send to all agents except sender
450                for (agent_id, agent) in &mut self.agents {
451                    if agent_id != &message.from {
452                        let formatted_message =
453                            format!("Broadcast from {}: {}", message.from, message.content);
454                        agent.chat_session_mut().add_user_message(formatted_message);
455                    }
456                }
457            }
458        }
459
460        Ok(())
461    }
462
463    /// Executes a collaborative task across multiple agents with planning.
464    ///
465    /// # Arguments
466    ///
467    /// * `initiator` - ID of the coordinator agent (must create the plan)
468    /// * `task_description` - Description of the overall task
469    /// * `involved_agents` - IDs of agents available for task execution
470    ///
471    /// # Returns
472    ///
473    /// Returns the final result from the collaborative process.
474    pub async fn execute_collaborative_task(
475        &mut self,
476        initiator: &AgentId,
477        task_description: String,
478        involved_agents: Vec<AgentId>,
479    ) -> Result<String> {
480        // Verify all involved agents exist
481        for agent_id in &involved_agents {
482            if !self.agents.contains_key(agent_id) {
483                return Err(HeliosError::AgentError(format!(
484                    "Agent '{}' not found",
485                    agent_id
486                )));
487            }
488        }
489
490        if !self.agents.contains_key(initiator) {
491            return Err(HeliosError::AgentError(format!(
492                "Initiator agent '{}' not found",
493                initiator
494            )));
495        }
496        // Phase 1: Coordinator creates a plan
497        {
498            let mut context = self.shared_context.write().await;
499            context.set(
500                "current_task".to_string(),
501                Value::String(task_description.clone()),
502            );
503            context.set(
504                "involved_agents".to_string(),
505                Value::Array(
506                    involved_agents
507                        .iter()
508                        .map(|id| Value::String(id.clone()))
509                        .collect(),
510                ),
511            );
512            context.set(
513                "task_status".to_string(),
514                Value::String("planning".to_string()),
515            );
516        }
517
518        let coordinator = self.agents.get_mut(initiator).unwrap();
519        let planning_prompt = format!(
520            "You are coordinating a collaborative task. Create a detailed plan using the 'create_plan' tool.\n\n\
521            Task: {}\n\n\
522            Available team members and their expertise:\n{}\n\n\
523            Break this task into subtasks and assign each to the most appropriate agent. \
524            Use the create_plan tool with a JSON array of tasks. Each task should have:\n\
525            - id: unique identifier (e.g., 'task_1')\n\
526            - description: what needs to be done\n\
527            - assigned_to: agent name\n\
528            - dependencies: array of task IDs that must complete first (use [] if none)\n\n\
529            IMPORTANT: You MUST use the create_plan tool to create a plan before doing anything else. \
530            Do not try to complete the task yourself - just create the plan using the tool.",
531            task_description,
532            involved_agents.join(", ")
533        );
534
535        let _planning_result = coordinator.chat(planning_prompt).await?;
536
537        // Check if plan was actually created
538        let plan_exists = {
539            let context = self.shared_context.read().await;
540            context.get_plan().is_some()
541        };
542
543        if !plan_exists {
544            // Fallback: coordinator handles it directly
545            return Ok(_planning_result);
546        }
547
548        // Phase 2: Execute tasks according to the plan
549        let mut iteration = 0;
550        let max_task_iterations = self.max_iterations * 3; // Allow more iterations for complex plans
551
552        while iteration < max_task_iterations {
553            // Get next ready tasks
554            let ready_tasks: Vec<(String, String, AgentId)> = {
555                let context = self.shared_context.read().await;
556                if let Some(plan) = context.get_plan() {
557                    if plan.is_complete() {
558                        break;
559                    }
560                    plan.get_next_ready_tasks()
561                        .iter()
562                        .map(|t| (t.id.clone(), t.description.clone(), t.assigned_to.clone()))
563                        .collect()
564                } else {
565                    // No plan created, fall back to simple delegation
566                    let initiator_agent = self.agents.get_mut(initiator).unwrap();
567                    let result = initiator_agent
568                        .chat(format!(
569                            "Complete this task: {}\nYou can delegate to: {}",
570                            task_description,
571                            involved_agents.join(", ")
572                        ))
573                        .await?;
574                    return Ok(result);
575                }
576            };
577
578            if ready_tasks.is_empty() {
579                // Check if we're waiting for in-progress tasks
580                let has_in_progress = {
581                    let context = self.shared_context.read().await;
582                    context
583                        .get_plan()
584                        .map(|p| p.tasks.values().any(|t| t.status == TaskStatus::InProgress))
585                        .unwrap_or(false)
586                };
587
588                if !has_in_progress {
589                    break; // No tasks ready and none in progress
590                }
591
592                // Avoid infinite loop - if we've waited too long, break
593                if iteration > 5 {
594                    break;
595                }
596
597                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
598                iteration += 1;
599                continue;
600            }
601
602            // Execute ready tasks
603            for (task_id, task_desc, agent_id) in ready_tasks {
604                // Mark task as in progress
605                {
606                    let mut context = self.shared_context.write().await;
607                    if let Some(plan) = context.get_plan_mut() {
608                        if let Some(task) = plan.get_task_mut(&task_id) {
609                            task.status = TaskStatus::InProgress;
610                        }
611                    }
612                }
613
614                // Get shared memory context for the agent
615                let shared_memory_info = {
616                    let context = self.shared_context.read().await;
617                    let mut info = String::from("\n=== SHARED TASK MEMORY ===\n");
618
619                    if let Some(plan) = context.get_plan() {
620                        info.push_str(&format!("Overall Objective: {}\n", plan.objective));
621                        info.push_str(&format!(
622                            "Progress: {}/{} tasks completed\n\n",
623                            plan.get_progress().0,
624                            plan.get_progress().1
625                        ));
626
627                        info.push_str("Completed Tasks:\n");
628                        for task in plan.tasks_in_order() {
629                            if task.status == TaskStatus::Completed {
630                                info.push_str(&format!(
631                                    "  ✓ [{}] {}: {}\n",
632                                    task.assigned_to,
633                                    task.description,
634                                    task.result.as_ref().unwrap_or(&"No result".to_string())
635                                ));
636                            }
637                        }
638                    }
639
640                    info.push_str("\nShared Data:\n");
641                    for (key, value) in &context.data {
642                        if !key.starts_with("current_task")
643                            && !key.starts_with("involved_agents")
644                            && !key.starts_with("task_status")
645                        {
646                            info.push_str(&format!("  • {}: {}\n", key, value));
647                        }
648                    }
649                    info.push_str("=========================\n\n");
650                    info
651                };
652
653                // Execute the task
654                if let Some(agent) = self.agents.get_mut(&agent_id) {
655                    let task_prompt = format!(
656                        "{}Your assigned task: {}\n\n\
657                        Complete this task and use the 'update_task_memory' tool to save your results to the shared memory. \
658                        The task_id is '{}'. Include key findings and data that other agents might need.\n\n\
659                        Provide a complete response with your results.",
660                        shared_memory_info, task_desc, task_id
661                    );
662
663                    let result = agent.chat(task_prompt).await?;
664
665                    // If agent didn't update memory, do it automatically
666                    {
667                        let mut context = self.shared_context.write().await;
668                        if let Some(plan) = context.get_plan_mut() {
669                            if let Some(task) = plan.get_task_mut(&task_id) {
670                                if task.status == TaskStatus::InProgress {
671                                    task.status = TaskStatus::Completed;
672                                    task.result = Some(result.clone());
673                                }
674                            }
675                        }
676                    }
677                }
678            }
679
680            iteration += 1;
681        }
682
683        // Phase 3: Coordinator synthesizes final result
684
685        let final_summary = {
686            let context = self.shared_context.read().await;
687            let mut summary = String::from("=== TASK COMPLETION SUMMARY ===\n\n");
688
689            if let Some(plan) = context.get_plan() {
690                summary.push_str(&format!("Objective: {}\n", plan.objective));
691                summary.push_str(&format!(
692                    "Status: All tasks completed ({}/{} tasks)\n\n",
693                    plan.get_progress().0,
694                    plan.get_progress().1
695                ));
696
697                summary.push_str("Task Results:\n");
698                for task in plan.tasks_in_order() {
699                    summary.push_str(&format!("\n[{}] {}\n", task.assigned_to, task.description));
700                    if let Some(result) = &task.result {
701                        summary.push_str(&format!("Result: {}\n", result));
702                    }
703                }
704            }
705            summary
706        };
707
708        let coordinator = self.agents.get_mut(initiator).unwrap();
709        let synthesis_prompt = format!(
710            "Based on the completed tasks, provide a comprehensive final answer to the original request.\n\n\
711            Original Task: {}\n\n\
712            {}\n\n\
713            Synthesize all the information into a cohesive, complete response.",
714            task_description, final_summary
715        );
716
717        let final_result = coordinator.chat(synthesis_prompt).await?;
718
719        // Mark overall task as completed
720        {
721            let mut context = self.shared_context.write().await;
722            context.set(
723                "task_status".to_string(),
724                Value::String("completed".to_string()),
725            );
726        }
727
728        Ok(final_result)
729    }
730
731    /// Processes pending messages and triggers responses from agents.
732    ///
733    /// This method iterates through pending messages, delivers them to recipient agents,
734    /// and triggers their responses. It continues until no more messages are generated
735    /// or max_iterations is reached.
736    #[allow(dead_code)]
737    async fn process_messages_and_trigger_responses(
738        &mut self,
739        involved_agents: &[AgentId],
740    ) -> Result<()> {
741        let mut iteration = 0;
742
743        while iteration < self.max_iterations {
744            // First, deliver all pending messages
745            self.process_messages().await?;
746
747            // Track agents that received new messages and need to respond
748            let mut agents_to_respond = Vec::new();
749
750            for agent_id in involved_agents {
751                if let Some(agent) = self.agents.get(agent_id) {
752                    let messages = &agent.chat_session().messages;
753                    if !messages.is_empty() {
754                        let last_message = messages.last().unwrap();
755                        // If the last message is from a user (another agent), this agent should respond
756                        if last_message.role == crate::chat::Role::User {
757                            agents_to_respond.push(agent_id.clone());
758                        }
759                    }
760                }
761            }
762
763            // If no agents need to respond, we're done
764            if agents_to_respond.is_empty() {
765                break;
766            }
767
768            // Have each agent respond to their messages
769            for agent_id in agents_to_respond {
770                if let Some(agent) = self.agents.get_mut(&agent_id) {
771                    // Agent processes the message and may use tools to delegate or send messages
772                    let _response = agent.chat("").await?;
773                }
774            }
775
776            iteration += 1;
777        }
778
779        Ok(())
780    }
781
782    /// Gets the shared context.
783    pub async fn get_shared_context(&self) -> SharedContext {
784        self.shared_context.read().await.clone()
785    }
786
787    /// Sets a value in the shared context.
788    pub async fn set_shared_context(&self, key: String, value: Value) {
789        let mut context = self.shared_context.write().await;
790        context.set(key, value);
791    }
792}
793
794impl Default for ForestOfAgents {
795    fn default() -> Self {
796        Self::new()
797    }
798}
799
800/// A tool that allows agents to send messages to other agents.
801pub struct SendMessageTool {
802    agent_id: AgentId,
803    message_queue: Arc<RwLock<Vec<ForestMessage>>>,
804    shared_context: Arc<RwLock<SharedContext>>,
805}
806
807impl SendMessageTool {
808    /// Creates a new SendMessageTool.
809    pub fn new(
810        agent_id: AgentId,
811        message_queue: Arc<RwLock<Vec<ForestMessage>>>,
812        shared_context: Arc<RwLock<SharedContext>>,
813    ) -> Self {
814        Self {
815            agent_id,
816            message_queue,
817            shared_context,
818        }
819    }
820}
821
822#[async_trait::async_trait]
823impl Tool for SendMessageTool {
824    fn name(&self) -> &str {
825        "send_message"
826    }
827
828    fn description(&self) -> &str {
829        "Send a message to another agent or broadcast to all agents in the forest."
830    }
831
832    fn parameters(&self) -> HashMap<String, ToolParameter> {
833        let mut params = HashMap::new();
834        params.insert(
835            "to".to_string(),
836            ToolParameter {
837                param_type: "string".to_string(),
838                description: "ID of the recipient agent (leave empty for broadcast)".to_string(),
839                required: Some(false),
840            },
841        );
842        params.insert(
843            "message".to_string(),
844            ToolParameter {
845                param_type: "string".to_string(),
846                description: "The message content to send".to_string(),
847                required: Some(true),
848            },
849        );
850        params
851    }
852
853    async fn execute(&self, args: Value) -> Result<ToolResult> {
854        let message = args
855            .get("message")
856            .and_then(|v| v.as_str())
857            .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?
858            .to_string();
859
860        let to = args
861            .get("to")
862            .and_then(|v| v.as_str())
863            .map(|s| s.to_string());
864
865        let forest_message = if let Some(to_id) = &to {
866            ForestMessage::new(self.agent_id.clone(), Some(to_id.clone()), message)
867        } else {
868            ForestMessage::broadcast(self.agent_id.clone(), message)
869        };
870
871        {
872            let mut queue = self.message_queue.write().await;
873            queue.push(forest_message.clone());
874        }
875
876        {
877            let mut context = self.shared_context.write().await;
878            context.add_message(forest_message);
879        }
880
881        Ok(ToolResult::success("Message sent successfully"))
882    }
883}
884
/// A tool that allows agents to delegate tasks to other agents.
///
/// Each delegation is turned into a [`ForestMessage`] that is pushed onto the message
/// queue and also recorded in the shared context.
pub struct DelegateTaskTool {
    /// ID of the agent using this tool; used as the sender of every delegation message.
    agent_id: AgentId,
    /// Queue that delegation messages are appended to.
    message_queue: Arc<RwLock<Vec<ForestMessage>>>,
    /// Shared context that receives a copy of every delegation message.
    shared_context: Arc<RwLock<SharedContext>>,
}
891
impl DelegateTaskTool {
    /// Creates a new DelegateTaskTool.
    ///
    /// `agent_id` identifies the delegating agent, `message_queue` is the queue the
    /// delegation messages are pushed onto, and `shared_context` is the context they
    /// are additionally recorded in. The three values are stored as given.
    pub fn new(
        agent_id: AgentId,
        message_queue: Arc<RwLock<Vec<ForestMessage>>>,
        shared_context: Arc<RwLock<SharedContext>>,
    ) -> Self {
        Self {
            agent_id,
            message_queue,
            shared_context,
        }
    }
}
906
907#[async_trait::async_trait]
908impl Tool for DelegateTaskTool {
909    fn name(&self) -> &str {
910        "delegate_task"
911    }
912
913    fn description(&self) -> &str {
914        "Delegate a specific task to another agent for execution."
915    }
916
917    fn parameters(&self) -> HashMap<String, ToolParameter> {
918        let mut params = HashMap::new();
919        params.insert(
920            "to".to_string(),
921            ToolParameter {
922                param_type: "string".to_string(),
923                description: "ID of the agent to delegate the task to".to_string(),
924                required: Some(true),
925            },
926        );
927        params.insert(
928            "task".to_string(),
929            ToolParameter {
930                param_type: "string".to_string(),
931                description: "Description of the task to delegate".to_string(),
932                required: Some(true),
933            },
934        );
935        params.insert(
936            "context".to_string(),
937            ToolParameter {
938                param_type: "string".to_string(),
939                description: "Additional context or requirements for the task".to_string(),
940                required: Some(false),
941            },
942        );
943        params
944    }
945
946    async fn execute(&self, args: Value) -> Result<ToolResult> {
947        let to = args
948            .get("to")
949            .and_then(|v| v.as_str())
950            .ok_or_else(|| HeliosError::ToolError("Missing 'to' parameter".to_string()))?;
951
952        let task = args
953            .get("task")
954            .and_then(|v| v.as_str())
955            .ok_or_else(|| HeliosError::ToolError("Missing 'task' parameter".to_string()))?;
956
957        let context = args.get("context").and_then(|v| v.as_str()).unwrap_or("");
958
959        let message = if context.is_empty() {
960            format!("Task delegated: {}", task)
961        } else {
962            format!("Task delegated: {}\nContext: {}", task, context)
963        };
964
965        let forest_message =
966            ForestMessage::new(self.agent_id.clone(), Some(to.to_string()), message)
967                .with_metadata("type".to_string(), "task_delegation".to_string())
968                .with_metadata("task".to_string(), task.to_string());
969
970        {
971            let mut queue = self.message_queue.write().await;
972            queue.push(forest_message.clone());
973        }
974
975        {
976            let mut context_lock = self.shared_context.write().await;
977            context_lock.add_message(forest_message);
978        }
979
980        Ok(ToolResult::success(format!(
981            "Task delegated to agent '{}'",
982            to
983        )))
984    }
985}
986
/// A tool that allows agents to share information in the shared context.
///
/// Every shared value is stored together with metadata naming the agent that shared it.
pub struct ShareContextTool {
    /// ID of the agent using this tool; stored as `shared_by` in each entry's metadata.
    agent_id: AgentId,
    /// Shared context the entries are written to.
    shared_context: Arc<RwLock<SharedContext>>,
}
992
impl ShareContextTool {
    /// Creates a new ShareContextTool.
    ///
    /// `agent_id` identifies the sharing agent and `shared_context` is the context that
    /// shared entries are written to. Both values are stored as given.
    pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
        Self {
            agent_id,
            shared_context,
        }
    }
}
1002
1003#[async_trait::async_trait]
1004impl Tool for ShareContextTool {
1005    fn name(&self) -> &str {
1006        "share_context"
1007    }
1008
1009    fn description(&self) -> &str {
1010        "Share information in the shared context that all agents can access."
1011    }
1012
1013    fn parameters(&self) -> HashMap<String, ToolParameter> {
1014        let mut params = HashMap::new();
1015        params.insert(
1016            "key".to_string(),
1017            ToolParameter {
1018                param_type: "string".to_string(),
1019                description: "Key for the shared information".to_string(),
1020                required: Some(true),
1021            },
1022        );
1023        params.insert(
1024            "value".to_string(),
1025            ToolParameter {
1026                param_type: "string".to_string(),
1027                description: "Value to share".to_string(),
1028                required: Some(true),
1029            },
1030        );
1031        params.insert(
1032            "description".to_string(),
1033            ToolParameter {
1034                param_type: "string".to_string(),
1035                description: "Description of what this information represents".to_string(),
1036                required: Some(false),
1037            },
1038        );
1039        params
1040    }
1041
1042    async fn execute(&self, args: Value) -> Result<ToolResult> {
1043        let key = args
1044            .get("key")
1045            .and_then(|v| v.as_str())
1046            .ok_or_else(|| HeliosError::ToolError("Missing 'key' parameter".to_string()))?;
1047
1048        let value = args
1049            .get("value")
1050            .and_then(|v| v.as_str())
1051            .ok_or_else(|| HeliosError::ToolError("Missing 'value' parameter".to_string()))?;
1052
1053        let description = args
1054            .get("description")
1055            .and_then(|v| v.as_str())
1056            .unwrap_or("");
1057
1058        let mut context = self.shared_context.write().await;
1059
1060        // Store the value with its metadata in a nested object
1061        let metadata = serde_json::json!({
1062            "shared_by": self.agent_id,
1063            "timestamp": chrono::Utc::now().to_rfc3339(),
1064            "description": description
1065        });
1066
1067        let value_with_meta = serde_json::json!({
1068            "value": value,
1069            "metadata": metadata
1070        });
1071
1072        context.set(key.to_string(), value_with_meta);
1073
1074        Ok(ToolResult::success(format!(
1075            "Information shared with key '{}'",
1076            key
1077        )))
1078    }
1079}
1080
/// A tool for updating task memory with results and findings.
///
/// It marks a task of the current plan as completed and can attach extra data to the
/// shared context.
pub struct UpdateTaskMemoryTool {
    /// ID of the agent using this tool; stored as `agent` alongside any extra task data.
    agent_id: AgentId,
    /// Shared context holding the task plan and the shared data.
    shared_context: Arc<RwLock<SharedContext>>,
}
1086
impl UpdateTaskMemoryTool {
    /// Creates a new UpdateTaskMemoryTool.
    ///
    /// `agent_id` identifies the reporting agent and `shared_context` is the context
    /// whose task plan gets updated. Both values are stored as given.
    pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
        Self {
            agent_id,
            shared_context,
        }
    }
}
1095
1096#[async_trait::async_trait]
1097impl Tool for UpdateTaskMemoryTool {
1098    fn name(&self) -> &str {
1099        "update_task_memory"
1100    }
1101
1102    fn description(&self) -> &str {
1103        "Update the shared task memory with your results, findings, and data. This allows other agents to see your progress and use your outputs."
1104    }
1105
1106    fn parameters(&self) -> HashMap<String, ToolParameter> {
1107        let mut params = HashMap::new();
1108        params.insert(
1109            "task_id".to_string(),
1110            ToolParameter {
1111                param_type: "string".to_string(),
1112                description: "The ID of the task you're updating (e.g., 'task_1')".to_string(),
1113                required: Some(true),
1114            },
1115        );
1116        params.insert(
1117            "result".to_string(),
1118            ToolParameter {
1119                param_type: "string".to_string(),
1120                description: "Your results, findings, or output from completing the task"
1121                    .to_string(),
1122                required: Some(true),
1123            },
1124        );
1125        params.insert(
1126            "data".to_string(),
1127            ToolParameter {
1128                param_type: "string".to_string(),
1129                description: "Additional data or information to share (e.g., key findings, metrics, recommendations)".to_string(),
1130                required: Some(false),
1131            },
1132        );
1133        params
1134    }
1135
1136    async fn execute(&self, args: Value) -> Result<ToolResult> {
1137        let task_id = args
1138            .get("task_id")
1139            .and_then(|v| v.as_str())
1140            .ok_or_else(|| HeliosError::ToolError("Missing 'task_id' parameter".to_string()))?;
1141
1142        let result = args
1143            .get("result")
1144            .and_then(|v| v.as_str())
1145            .ok_or_else(|| HeliosError::ToolError("Missing 'result' parameter".to_string()))?;
1146
1147        let additional_data = args.get("data").and_then(|v| v.as_str()).unwrap_or("");
1148
1149        let mut context = self.shared_context.write().await;
1150
1151        // Update the task in the plan
1152        if let Some(plan) = context.get_plan_mut() {
1153            if let Some(task) = plan.get_task_mut(task_id) {
1154                task.status = TaskStatus::Completed;
1155                task.result = Some(result.to_string());
1156                let task_description = task.description.clone();
1157
1158                // Also store in shared data for easy access
1159                if !additional_data.is_empty() {
1160                    let data_key = format!("task_data_{}", task_id);
1161                    context.set(
1162                        data_key,
1163                        serde_json::json!({
1164                            "agent": self.agent_id,
1165                            "task": task_description,
1166                            "data": additional_data,
1167                            "timestamp": chrono::Utc::now().to_rfc3339()
1168                        }),
1169                    );
1170                }
1171
1172                return Ok(ToolResult::success(format!(
1173                    "Task '{}' marked as completed. Results saved to shared memory.",
1174                    task_id
1175                )));
1176            } else {
1177                return Err(HeliosError::ToolError(format!(
1178                    "Task '{}' not found in current plan",
1179                    task_id
1180                )));
1181            }
1182        }
1183
1184        Err(HeliosError::ToolError(
1185            "No active task plan found".to_string(),
1186        ))
1187    }
1188}
1189
/// A tool for the coordinator to create a task plan.
///
/// The plan is built from the tool arguments and installed in the shared context.
pub struct CreatePlanTool {
    /// ID of the agent using this tool.
    // Nothing in this type reads the field, hence the lint allowance; presumably it is
    // kept so the constructor matches the other forest tools (TODO confirm).
    #[allow(dead_code)]
    agent_id: AgentId,
    /// Shared context the created plan is stored in.
    shared_context: Arc<RwLock<SharedContext>>,
}
1196
impl CreatePlanTool {
    /// Creates a new CreatePlanTool.
    ///
    /// `agent_id` identifies the planning agent and `shared_context` is the context the
    /// plan is stored in. Both values are stored as given.
    pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
        Self {
            agent_id,
            shared_context,
        }
    }
}
1205
1206#[async_trait::async_trait]
1207impl Tool for CreatePlanTool {
1208    fn name(&self) -> &str {
1209        "create_plan"
1210    }
1211
1212    fn description(&self) -> &str {
1213        "Create a detailed task plan for collaborative work. Break down the overall objective into specific tasks and assign them to team members."
1214    }
1215
1216    fn parameters(&self) -> HashMap<String, ToolParameter> {
1217        let mut params = HashMap::new();
1218        params.insert(
1219            "objective".to_string(),
1220            ToolParameter {
1221                param_type: "string".to_string(),
1222                description: "The overall objective or goal of the plan".to_string(),
1223                required: Some(true),
1224            },
1225        );
1226        params.insert(
1227            "tasks".to_string(),
1228            ToolParameter {
1229                param_type: "string".to_string(),
1230                description: "JSON array of tasks. Each task must have: id (string), description (string), assigned_to (string), dependencies (array of task IDs)".to_string(),
1231                required: Some(true),
1232            },
1233        );
1234        params
1235    }
1236
1237    async fn execute(&self, args: Value) -> Result<ToolResult> {
1238        let objective = args
1239            .get("objective")
1240            .and_then(|v| v.as_str())
1241            .ok_or_else(|| HeliosError::ToolError("Missing 'objective' parameter".to_string()))?;
1242
1243        let tasks_json = args
1244            .get("tasks")
1245            .and_then(|v| v.as_str())
1246            .ok_or_else(|| HeliosError::ToolError("Missing 'tasks' parameter".to_string()))?;
1247
1248        // Parse the tasks JSON
1249        let tasks_array: Vec<Value> = serde_json::from_str(tasks_json)
1250            .map_err(|e| HeliosError::ToolError(format!("Invalid JSON for tasks: {}", e)))?;
1251
1252        let plan_id = format!("plan_{}", chrono::Utc::now().timestamp());
1253        let mut plan = TaskPlan::new(plan_id.clone(), objective.to_string());
1254
1255        for task_value in tasks_array {
1256            let task_obj = task_value.as_object().ok_or_else(|| {
1257                HeliosError::ToolError("Each task must be a JSON object".to_string())
1258            })?;
1259
1260            let id = task_obj
1261                .get("id")
1262                .and_then(|v| v.as_str())
1263                .ok_or_else(|| HeliosError::ToolError("Task missing 'id' field".to_string()))?
1264                .to_string();
1265
1266            let description = task_obj
1267                .get("description")
1268                .and_then(|v| v.as_str())
1269                .ok_or_else(|| {
1270                    HeliosError::ToolError("Task missing 'description' field".to_string())
1271                })?
1272                .to_string();
1273
1274            let assigned_to = task_obj
1275                .get("assigned_to")
1276                .and_then(|v| v.as_str())
1277                .ok_or_else(|| {
1278                    HeliosError::ToolError("Task missing 'assigned_to' field".to_string())
1279                })?
1280                .to_string();
1281
1282            let dependencies = task_obj
1283                .get("dependencies")
1284                .and_then(|v| v.as_array())
1285                .map(|arr| {
1286                    arr.iter()
1287                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
1288                        .collect::<Vec<String>>()
1289                })
1290                .unwrap_or_else(Vec::new);
1291
1292            let task = TaskItem::new(id, description, assigned_to).with_dependencies(dependencies);
1293            plan.add_task(task);
1294        }
1295
1296        let mut context = self.shared_context.write().await;
1297        context.set_plan(plan.clone());
1298
1299        let task_summary = plan
1300            .tasks_in_order()
1301            .iter()
1302            .map(|t| {
1303                format!(
1304                    "  • [{}] {} (assigned to: {})",
1305                    t.id, t.description, t.assigned_to
1306                )
1307            })
1308            .collect::<Vec<_>>()
1309            .join("\n");
1310
1311        Ok(ToolResult::success(format!(
1312            "Plan created with {} tasks:\n{}",
1313            plan.tasks.len(),
1314            task_summary
1315        )))
1316    }
1317}
1318
/// Builder for creating a Forest of Agents with multiple agents.
///
/// Collects a configuration, a list of agent builders, and an iteration limit, then
/// assembles them in [`ForestBuilder::build`].
pub struct ForestBuilder {
    /// Configuration applied to every agent; `build` fails when it is still `None`.
    config: Option<Config>,
    /// Agent IDs paired with their not-yet-built agents, in the order they were added.
    agents: Vec<(AgentId, AgentBuilder)>,
    /// Iteration limit handed to the forest when it is built.
    max_iterations: usize,
}
1325
1326impl ForestBuilder {
1327    /// Creates a new ForestBuilder.
1328    pub fn new() -> Self {
1329        Self {
1330            config: None,
1331            agents: Vec::new(),
1332            max_iterations: 10,
1333        }
1334    }
1335
1336    /// Sets the configuration for all agents in the forest.
1337    pub fn config(mut self, config: Config) -> Self {
1338        self.config = Some(config);
1339        self
1340    }
1341
1342    /// Adds an agent to the forest with a builder.
1343    pub fn agent(mut self, id: AgentId, builder: AgentBuilder) -> Self {
1344        self.agents.push((id, builder));
1345        self
1346    }
1347
1348    /// Sets the maximum iterations for agent interactions.
1349    pub fn max_iterations(mut self, max: usize) -> Self {
1350        self.max_iterations = max;
1351        self
1352    }
1353
1354    /// Builds the Forest of Agents.
1355    pub async fn build(self) -> Result<ForestOfAgents> {
1356        let config = self
1357            .config
1358            .ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
1359
1360        let mut forest = ForestOfAgents::with_max_iterations(self.max_iterations);
1361
1362        for (id, builder) in self.agents {
1363            let agent = builder.config(config.clone()).build().await?;
1364            forest.add_agent(id, agent)?;
1365        }
1366
1367        Ok(forest)
1368    }
1369}
1370
impl Default for ForestBuilder {
    /// Equivalent to [`ForestBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
1376
1377#[cfg(test)]
1378mod tests {
1379    use super::*;
1380    use crate::config::Config;
1381    use crate::tools::Tool;
1382    use serde_json::Value;
1383
1384    /// Tests basic ForestOfAgents creation and agent management.
1385    #[tokio::test]
1386    async fn test_forest_creation_and_agent_management() {
1387        let mut forest = ForestOfAgents::new();
1388        let config = Config::new_default();
1389
1390        // Create and add agents
1391        let agent1 = Agent::builder("agent1")
1392            .config(config.clone())
1393            .system_prompt("You are agent 1")
1394            .build()
1395            .await
1396            .unwrap();
1397
1398        let agent2 = Agent::builder("agent2")
1399            .config(config)
1400            .system_prompt("You are agent 2")
1401            .build()
1402            .await
1403            .unwrap();
1404
1405        // Add agents to forest
1406        forest.add_agent("agent1".to_string(), agent1).unwrap();
1407        forest.add_agent("agent2".to_string(), agent2).unwrap();
1408
1409        // Test agent listing
1410        let agents = forest.list_agents();
1411        assert_eq!(agents.len(), 2);
1412        assert!(agents.contains(&"agent1".to_string()));
1413        assert!(agents.contains(&"agent2".to_string()));
1414
1415        // Test agent retrieval
1416        assert!(forest.get_agent(&"agent1".to_string()).is_some());
1417        assert!(forest.get_agent(&"agent3".to_string()).is_none());
1418
1419        // Test duplicate agent addition
1420        let agent3 = Agent::builder("agent3")
1421            .config(Config::new_default())
1422            .build()
1423            .await
1424            .unwrap();
1425        let result = forest.add_agent("agent1".to_string(), agent3);
1426        assert!(result.is_err());
1427
1428        // Test agent removal
1429        let removed = forest.remove_agent(&"agent1".to_string());
1430        assert!(removed.is_some());
1431        assert_eq!(forest.list_agents().len(), 1);
1432        assert!(forest.get_agent(&"agent1".to_string()).is_none());
1433    }
1434
1435    /// Tests message sending between agents.
1436    #[tokio::test]
1437    async fn test_message_sending() {
1438        let mut forest = ForestOfAgents::new();
1439        let config = Config::new_default();
1440
1441        // Create and add agents
1442        let agent1 = Agent::builder("alice")
1443            .config(config.clone())
1444            .build()
1445            .await
1446            .unwrap();
1447
1448        let agent2 = Agent::builder("bob").config(config).build().await.unwrap();
1449
1450        forest.add_agent("alice".to_string(), agent1).unwrap();
1451        forest.add_agent("bob".to_string(), agent2).unwrap();
1452
1453        // Test direct message
1454        forest
1455            .send_message(
1456                &"alice".to_string(),
1457                Some(&"bob".to_string()),
1458                "Hello Bob!".to_string(),
1459            )
1460            .await
1461            .unwrap();
1462
1463        // Process messages
1464        forest.process_messages().await.unwrap();
1465
1466        // Check that Bob received the message
1467        let bob = forest.get_agent(&"bob".to_string()).unwrap();
1468        let messages = bob.chat_session().messages.clone();
1469        assert!(!messages.is_empty());
1470        let last_message = messages.last().unwrap();
1471        assert_eq!(last_message.role, crate::chat::Role::User);
1472        assert!(last_message
1473            .content
1474            .contains("Message from alice: Hello Bob!"));
1475
1476        // Test broadcast message
1477        let alice_message_count_before = forest
1478            .get_agent(&"alice".to_string())
1479            .unwrap()
1480            .chat_session()
1481            .messages
1482            .len();
1483        forest
1484            .send_message(&"alice".to_string(), None, "Hello everyone!".to_string())
1485            .await
1486            .unwrap();
1487        forest.process_messages().await.unwrap();
1488
1489        // Check that Bob received the broadcast, but Alice did not
1490        let alice = forest.get_agent(&"alice".to_string()).unwrap();
1491        assert_eq!(
1492            alice.chat_session().messages.len(),
1493            alice_message_count_before
1494        );
1495
1496        let bob = forest.get_agent(&"bob".to_string()).unwrap();
1497        let bob_messages = bob.chat_session().messages.clone();
1498        let bob_last = bob_messages.last().unwrap();
1499        assert!(bob_last
1500            .content
1501            .contains("Broadcast from alice: Hello everyone!"));
1502    }
1503
1504    /// Tests the SendMessageTool functionality.
1505    #[tokio::test]
1506    async fn test_send_message_tool() {
1507        let message_queue = Arc::new(RwLock::new(Vec::<ForestMessage>::new()));
1508        let shared_context = Arc::new(RwLock::new(SharedContext::new()));
1509
1510        let tool = SendMessageTool::new(
1511            "alice".to_string(),
1512            message_queue.clone(),
1513            shared_context.clone(),
1514        );
1515
1516        // Test sending a direct message
1517        let args = serde_json::json!({
1518            "to": "bob",
1519            "message": "Test message"
1520        });
1521
1522        let result = tool.execute(args).await.unwrap();
1523        assert!(result.success);
1524        assert_eq!(result.output, "Message sent successfully");
1525
1526        // Check message queue
1527        let queue = message_queue.read().await;
1528        assert_eq!(queue.len(), 1);
1529        let message = &queue[0];
1530        assert_eq!(message.from, "alice");
1531        assert_eq!(message.to, Some("bob".to_string()));
1532        assert_eq!(message.content, "Test message");
1533
1534        // Check shared context
1535        let context = shared_context.read().await;
1536        let messages = context.get_recent_messages(10);
1537        assert_eq!(messages.len(), 1);
1538        assert_eq!(messages[0].from, "alice");
1539
1540        // TODO: Test broadcast message - currently causes hang
1541        // The direct message functionality works correctly
1542    }
1543
1544    /// Tests the DelegateTaskTool functionality.
1545    #[tokio::test]
1546    async fn test_delegate_task_tool() {
1547        let message_queue = Arc::new(RwLock::new(Vec::new()));
1548        let shared_context = Arc::new(RwLock::new(SharedContext::new()));
1549
1550        let tool = DelegateTaskTool::new(
1551            "manager".to_string(),
1552            Arc::clone(&message_queue),
1553            Arc::clone(&shared_context),
1554        );
1555
1556        // Test task delegation
1557        let args = serde_json::json!({
1558            "to": "worker",
1559            "task": "Analyze the data",
1560            "context": "Use statistical methods"
1561        });
1562
1563        let result = tool.execute(args).await.unwrap();
1564        assert!(result.success);
1565        assert_eq!(result.output, "Task delegated to agent 'worker'");
1566
1567        // Check message queue
1568        let queue = message_queue.read().await;
1569        assert_eq!(queue.len(), 1);
1570        let message = &queue[0];
1571        assert_eq!(message.from, "manager");
1572        assert_eq!(message.to, Some("worker".to_string()));
1573        assert!(message.content.contains("Task delegated: Analyze the data"));
1574        assert!(message.content.contains("Context: Use statistical methods"));
1575
1576        // Check metadata
1577        assert_eq!(
1578            message.metadata.get("type"),
1579            Some(&"task_delegation".to_string())
1580        );
1581        assert_eq!(
1582            message.metadata.get("task"),
1583            Some(&"Analyze the data".to_string())
1584        );
1585    }
1586
1587    /// Tests the ShareContextTool functionality.
1588    #[tokio::test]
1589    async fn test_share_context_tool() {
1590        let shared_context = Arc::new(RwLock::new(SharedContext::new()));
1591
1592        let tool = ShareContextTool::new("researcher".to_string(), Arc::clone(&shared_context));
1593
1594        // Test sharing context
1595        let args = serde_json::json!({
1596            "key": "findings",
1597            "value": "Temperature affects reaction rate",
1598            "description": "Key experimental finding"
1599        });
1600
1601        let result = tool.execute(args).await.unwrap();
1602        assert!(result.success);
1603        assert_eq!(result.output, "Information shared with key 'findings'");
1604
1605        // Check shared context
1606        let context = shared_context.read().await;
1607        let findings_data = context.get("findings").unwrap();
1608        let findings_obj = findings_data.as_object().unwrap();
1609
1610        // Check the value
1611        assert_eq!(
1612            findings_obj.get("value").unwrap(),
1613            &Value::String("Temperature affects reaction rate".to_string())
1614        );
1615
1616        // Check metadata
1617        let metadata = findings_obj.get("metadata").unwrap();
1618        let metadata_obj = metadata.as_object().unwrap();
1619        assert_eq!(
1620            metadata_obj.get("shared_by").unwrap(),
1621            &Value::String("researcher".to_string())
1622        );
1623        assert_eq!(
1624            metadata_obj.get("description").unwrap(),
1625            &Value::String("Key experimental finding".to_string())
1626        );
1627        assert!(metadata_obj.contains_key("timestamp"));
1628    }
1629
1630    /// Tests the SharedContext functionality.
1631    #[tokio::test]
1632    async fn test_shared_context() {
1633        let mut context = SharedContext::new();
1634
1635        // Test setting and getting values
1636        context.set("key1".to_string(), Value::String("value1".to_string()));
1637        context.set("key2".to_string(), Value::Number(42.into()));
1638
1639        assert_eq!(
1640            context.get("key1"),
1641            Some(&Value::String("value1".to_string()))
1642        );
1643        assert_eq!(context.get("key2"), Some(&Value::Number(42.into())));
1644        assert_eq!(context.get("key3"), None);
1645
1646        // Test message history
1647        let msg1 = ForestMessage::new(
1648            "alice".to_string(),
1649            Some("bob".to_string()),
1650            "Hello".to_string(),
1651        );
1652        let msg2 = ForestMessage::broadcast("bob".to_string(), "Hi everyone".to_string());
1653
1654        context.add_message(msg1);
1655        context.add_message(msg2);
1656
1657        let messages = context.get_recent_messages(10);
1658        assert_eq!(messages.len(), 2);
1659        assert_eq!(messages[0].from, "alice");
1660        assert_eq!(messages[1].from, "bob");
1661
1662        // Test removing values
1663        let removed = context.remove("key1");
1664        assert_eq!(removed, Some(Value::String("value1".to_string())));
1665        assert_eq!(context.get("key1"), None);
1666    }
1667
1668    /// Tests collaborative task execution.
1669    #[tokio::test]
1670    async fn test_collaborative_task() {
1671        let mut forest = ForestOfAgents::new();
1672        let config = Config::new_default();
1673
1674        // Create agents with different roles
1675        let coordinator = Agent::builder("coordinator")
1676            .config(config.clone())
1677            .system_prompt(
1678                "You are a task coordinator. Break down tasks and delegate to specialists.",
1679            )
1680            .build()
1681            .await
1682            .unwrap();
1683
1684        let researcher = Agent::builder("researcher")
1685            .config(config.clone())
1686            .system_prompt("You are a researcher. Gather and analyze information.")
1687            .build()
1688            .await
1689            .unwrap();
1690
1691        let writer = Agent::builder("writer")
1692            .config(config)
1693            .system_prompt("You are a writer. Create clear, well-structured content.")
1694            .build()
1695            .await
1696            .unwrap();
1697
1698        forest
1699            .add_agent("coordinator".to_string(), coordinator)
1700            .unwrap();
1701        forest
1702            .add_agent("researcher".to_string(), researcher)
1703            .unwrap();
1704        forest.add_agent("writer".to_string(), writer).unwrap();
1705
1706        // Test that collaborative task setup works (without actually executing LLM calls)
1707        // We can't run the full collaborative task in unit tests due to LLM dependencies,
1708        // but we can test the setup and basic validation
1709
1710        // Test that agents exist validation works
1711        // (The actual task execution would require valid LLM API keys)
1712
1713        // Check that the forest has the expected agents
1714        assert_eq!(forest.list_agents().len(), 3);
1715        assert!(forest.get_agent(&"coordinator".to_string()).is_some());
1716        assert!(forest.get_agent(&"researcher".to_string()).is_some());
1717        assert!(forest.get_agent(&"writer".to_string()).is_some());
1718
1719        // Test that the method would set up shared context correctly by calling a minimal version
1720        // We'll test the context setup by manually calling the initial setup part
1721
1722        // Simulate the initial context setup that happens in execute_collaborative_task
1723        forest
1724            .set_shared_context(
1725                "current_task".to_string(),
1726                Value::String("Create a report on climate change impacts".to_string()),
1727            )
1728            .await;
1729        forest
1730            .set_shared_context(
1731                "involved_agents".to_string(),
1732                Value::Array(vec![
1733                    Value::String("researcher".to_string()),
1734                    Value::String("writer".to_string()),
1735                ]),
1736            )
1737            .await;
1738        forest
1739            .set_shared_context(
1740                "task_status".to_string(),
1741                Value::String("in_progress".to_string()),
1742            )
1743            .await;
1744
1745        // Check shared context was updated
1746        let context = forest.get_shared_context().await;
1747        assert_eq!(
1748            context.get("task_status"),
1749            Some(&Value::String("in_progress".to_string()))
1750        );
1751        assert!(context.get("current_task").is_some());
1752        assert!(context.get("involved_agents").is_some());
1753    }
1754
1755    /// Tests the ForestBuilder functionality.
1756    #[tokio::test]
1757    async fn test_forest_builder() {
1758        let config = Config::new_default();
1759
1760        let forest = ForestBuilder::new()
1761            .config(config)
1762            .agent(
1763                "agent1".to_string(),
1764                Agent::builder("agent1").system_prompt("Agent 1 prompt"),
1765            )
1766            .agent(
1767                "agent2".to_string(),
1768                Agent::builder("agent2").system_prompt("Agent 2 prompt"),
1769            )
1770            .max_iterations(5)
1771            .build()
1772            .await
1773            .unwrap();
1774
1775        assert_eq!(forest.list_agents().len(), 2);
1776        assert!(forest.get_agent(&"agent1".to_string()).is_some());
1777        assert!(forest.get_agent(&"agent2".to_string()).is_some());
1778        assert_eq!(forest.max_iterations, 5);
1779    }
1780
1781    /// Tests error handling in ForestOfAgents.
1782    #[tokio::test]
1783    async fn test_forest_error_handling() {
1784        let mut forest = ForestOfAgents::new();
1785
1786        // Test sending message from non-existent agent
1787        let result = forest
1788            .send_message(
1789                &"nonexistent".to_string(),
1790                Some(&"target".to_string()),
1791                "test".to_string(),
1792            )
1793            .await;
1794        assert!(result.is_err());
1795
1796        // Test collaborative task with non-existent initiator
1797        let result = forest
1798            .execute_collaborative_task(&"nonexistent".to_string(), "test task".to_string(), vec![])
1799            .await;
1800        assert!(result.is_err());
1801
1802        // Test collaborative task with non-existent involved agent
1803        let config = Config::new_default();
1804        let agent = Agent::builder("real_agent")
1805            .config(config)
1806            .build()
1807            .await
1808            .unwrap();
1809        forest.add_agent("real_agent".to_string(), agent).unwrap();
1810
1811        let result = forest
1812            .execute_collaborative_task(
1813                &"real_agent".to_string(),
1814                "test task".to_string(),
1815                vec!["nonexistent".to_string()],
1816            )
1817            .await;
1818        assert!(result.is_err());
1819    }
1820
1821    /// Tests ForestMessage creation and properties.
1822    #[tokio::test]
1823    async fn test_forest_message() {
1824        // Test direct message
1825        let msg = ForestMessage::new(
1826            "alice".to_string(),
1827            Some("bob".to_string()),
1828            "Hello".to_string(),
1829        );
1830        assert_eq!(msg.from, "alice");
1831        assert_eq!(msg.to, Some("bob".to_string()));
1832        assert_eq!(msg.content, "Hello");
1833
1834        // Test broadcast message
1835        let broadcast = ForestMessage::broadcast("alice".to_string(), "Announcement".to_string());
1836        assert_eq!(broadcast.from, "alice");
1837        assert!(broadcast.to.is_none());
1838        assert_eq!(broadcast.content, "Announcement");
1839
1840        // Test metadata
1841        let msg_with_meta = msg.with_metadata("priority".to_string(), "high".to_string());
1842        assert_eq!(
1843            msg_with_meta.metadata.get("priority"),
1844            Some(&"high".to_string())
1845        );
1846    }
1847}