use crate::agent::{Agent, AgentBuilder};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::tools::{Tool, ToolParameter, ToolResult};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
pub type AgentId = String;
#[derive(Debug, Clone)]
pub struct ForestMessage {
pub from: AgentId,
pub to: Option<AgentId>,
pub content: String,
pub metadata: HashMap<String, String>,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl ForestMessage {
pub fn new(from: AgentId, to: Option<AgentId>, content: String) -> Self {
Self {
from,
to,
content,
metadata: HashMap::new(),
timestamp: chrono::Utc::now(),
}
}
pub fn broadcast(from: AgentId, content: String) -> Self {
Self::new(from, None, content)
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
Pending,
InProgress,
Completed,
Failed,
}
impl TaskStatus {
pub fn as_str(&self) -> &str {
match self {
TaskStatus::Pending => "pending",
TaskStatus::InProgress => "in_progress",
TaskStatus::Completed => "completed",
TaskStatus::Failed => "failed",
}
}
}
#[derive(Debug, Clone)]
pub struct TaskItem {
pub id: String,
pub description: String,
pub assigned_to: AgentId,
pub status: TaskStatus,
pub result: Option<String>,
pub dependencies: Vec<String>,
pub metadata: HashMap<String, String>,
}
impl TaskItem {
pub fn new(id: String, description: String, assigned_to: AgentId) -> Self {
Self {
id,
description,
assigned_to,
status: TaskStatus::Pending,
result: None,
dependencies: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn with_dependencies(mut self, deps: Vec<String>) -> Self {
self.dependencies = deps;
self
}
}
#[derive(Debug, Clone)]
pub struct TaskPlan {
pub plan_id: String,
pub objective: String,
pub tasks: HashMap<String, TaskItem>,
pub task_order: Vec<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
}
impl TaskPlan {
pub fn new(plan_id: String, objective: String) -> Self {
Self {
plan_id,
objective,
tasks: HashMap::new(),
task_order: Vec::new(),
created_at: chrono::Utc::now(),
}
}
pub fn add_task(&mut self, task: TaskItem) {
self.task_order.push(task.id.clone());
self.tasks.insert(task.id.clone(), task);
}
pub fn get_task_mut(&mut self, task_id: &str) -> Option<&mut TaskItem> {
self.tasks.get_mut(task_id)
}
pub fn get_task(&self, task_id: &str) -> Option<&TaskItem> {
self.tasks.get(task_id)
}
pub fn get_next_ready_tasks(&self) -> Vec<&TaskItem> {
self.task_order
.iter()
.filter_map(|task_id| self.tasks.get(task_id))
.filter(|t| {
t.status == TaskStatus::Pending
&& t.dependencies.iter().all(|dep_id| {
self.tasks
.get(dep_id)
.map(|dt| dt.status == TaskStatus::Completed)
.unwrap_or(false)
})
})
.collect()
}
pub fn is_complete(&self) -> bool {
self.tasks
.values()
.all(|t| t.status == TaskStatus::Completed || t.status == TaskStatus::Failed)
}
pub fn get_progress(&self) -> (usize, usize) {
let completed = self
.tasks
.values()
.filter(|t| t.status == TaskStatus::Completed)
.count();
(completed, self.tasks.len())
}
pub fn tasks_in_order(&self) -> Vec<&TaskItem> {
self.task_order
.iter()
.filter_map(|id| self.tasks.get(id))
.collect()
}
}
#[derive(Debug, Clone)]
pub struct SharedContext {
pub data: HashMap<String, Value>,
pub message_history: Vec<ForestMessage>,
pub metadata: HashMap<String, String>,
pub current_plan: Option<TaskPlan>,
}
impl SharedContext {
pub fn new() -> Self {
Self {
data: HashMap::new(),
message_history: Vec::new(),
metadata: HashMap::new(),
current_plan: None,
}
}
pub fn set(&mut self, key: String, value: Value) {
self.data.insert(key, value);
}
pub fn get(&self, key: &str) -> Option<&Value> {
self.data.get(key)
}
pub fn remove(&mut self, key: &str) -> Option<Value> {
self.data.remove(key)
}
pub fn add_message(&mut self, message: ForestMessage) {
self.message_history.push(message);
}
pub fn get_recent_messages(&self, limit: usize) -> &[ForestMessage] {
let len = self.message_history.len();
let start = len.saturating_sub(limit);
&self.message_history[start..]
}
pub fn set_plan(&mut self, plan: TaskPlan) {
self.current_plan = Some(plan);
}
pub fn get_plan(&self) -> Option<&TaskPlan> {
self.current_plan.as_ref()
}
pub fn get_plan_mut(&mut self) -> Option<&mut TaskPlan> {
self.current_plan.as_mut()
}
pub fn clear_plan(&mut self) {
self.current_plan = None;
}
}
impl Default for SharedContext {
fn default() -> Self {
Self::new()
}
}
pub struct ForestOfAgents {
agents: HashMap<AgentId, Agent>,
shared_context: Arc<RwLock<SharedContext>>,
message_queue: Arc<RwLock<Vec<ForestMessage>>>,
max_iterations: usize,
}
impl ForestOfAgents {
pub fn new() -> Self {
Self {
agents: HashMap::new(),
shared_context: Arc::new(RwLock::new(SharedContext::new())),
message_queue: Arc::new(RwLock::new(Vec::new())),
max_iterations: 10,
}
}
pub fn with_max_iterations(max_iterations: usize) -> Self {
Self {
max_iterations,
..Self::new()
}
}
pub fn add_agent(&mut self, id: AgentId, mut agent: Agent) -> Result<()> {
if self.agents.contains_key(&id) {
return Err(HeliosError::AgentError(format!(
"Agent with ID '{}' already exists",
id
)));
}
let send_message_tool = Box::new(SendMessageTool::new(
id.clone(),
Arc::clone(&self.message_queue),
Arc::clone(&self.shared_context),
));
agent.register_tool(send_message_tool);
let delegate_task_tool = Box::new(DelegateTaskTool::new(
id.clone(),
Arc::clone(&self.message_queue),
Arc::clone(&self.shared_context),
));
agent.register_tool(delegate_task_tool);
let share_context_tool = Box::new(ShareContextTool::new(
id.clone(),
Arc::clone(&self.shared_context),
));
agent.register_tool(share_context_tool);
let update_task_memory_tool = Box::new(UpdateTaskMemoryTool::new(
id.clone(),
Arc::clone(&self.shared_context),
));
agent.register_tool(update_task_memory_tool);
let create_plan_tool = Box::new(CreatePlanTool::new(
id.clone(),
Arc::clone(&self.shared_context),
));
agent.register_tool(create_plan_tool);
self.agents.insert(id, agent);
Ok(())
}
pub fn remove_agent(&mut self, id: &AgentId) -> Option<Agent> {
self.agents.remove(id)
}
pub fn get_agent(&self, id: &AgentId) -> Option<&Agent> {
self.agents.get(id)
}
pub fn get_agent_mut(&mut self, id: &AgentId) -> Option<&mut Agent> {
self.agents.get_mut(id)
}
pub fn list_agents(&self) -> Vec<AgentId> {
self.agents.keys().cloned().collect()
}
pub async fn send_message(
&self,
from: &AgentId,
to: Option<&AgentId>,
content: String,
) -> Result<()> {
if !self.agents.contains_key(from) {
return Err(HeliosError::AgentError(format!(
"Agent '{}' not found",
from
)));
}
let message = if let Some(to_id) = to {
ForestMessage::new(from.clone(), Some(to_id.clone()), content)
} else {
ForestMessage::broadcast(from.clone(), content)
};
let mut queue = self.message_queue.write().await;
queue.push(message.clone());
let mut context = self.shared_context.write().await;
context.add_message(message);
Ok(())
}
pub async fn process_messages(&mut self) -> Result<()> {
let messages: Vec<ForestMessage> = {
let mut queue = self.message_queue.write().await;
queue.drain(..).collect()
};
for message in messages {
if let Some(to_id) = &message.to {
if let Some(agent) = self.agents.get_mut(to_id) {
let formatted_message =
format!("Message from {}: {}", message.from, message.content);
agent.chat_session_mut().add_user_message(formatted_message);
}
} else {
for (agent_id, agent) in &mut self.agents {
if agent_id != &message.from {
let formatted_message =
format!("Broadcast from {}: {}", message.from, message.content);
agent.chat_session_mut().add_user_message(formatted_message);
}
}
}
}
Ok(())
}
pub async fn execute_collaborative_task(
&mut self,
initiator: &AgentId,
task_description: String,
involved_agents: Vec<AgentId>,
) -> Result<String> {
for agent_id in &involved_agents {
if !self.agents.contains_key(agent_id) {
return Err(HeliosError::AgentError(format!(
"Agent '{}' not found",
agent_id
)));
}
}
if !self.agents.contains_key(initiator) {
return Err(HeliosError::AgentError(format!(
"Initiator agent '{}' not found",
initiator
)));
}
{
let mut context = self.shared_context.write().await;
context.set(
"current_task".to_string(),
Value::String(task_description.clone()),
);
context.set(
"involved_agents".to_string(),
Value::Array(
involved_agents
.iter()
.map(|id| Value::String(id.clone()))
.collect(),
),
);
context.set(
"task_status".to_string(),
Value::String("planning".to_string()),
);
}
let coordinator = self.agents.get_mut(initiator).unwrap();
let planning_prompt = format!(
"You are coordinating a collaborative task. Create a detailed plan using the 'create_plan' tool.\n\n\
Task: {}\n\n\
Available team members and their expertise:\n{}\n\n\
Break this task into subtasks and assign each to the most appropriate agent. \
Use the create_plan tool with a JSON array of tasks. Each task should have:\n\
- id: unique identifier (e.g., 'task_1')\n\
- description: what needs to be done\n\
- assigned_to: agent name\n\
- dependencies: array of task IDs that must complete first (use [] if none)\n\n\
IMPORTANT: You MUST use the create_plan tool to create a plan before doing anything else. \
Do not try to complete the task yourself - just create the plan using the tool.",
task_description,
involved_agents.join(", ")
);
let _planning_result = coordinator.chat(planning_prompt).await?;
let plan_exists = {
let context = self.shared_context.read().await;
context.get_plan().is_some()
};
if !plan_exists {
return Ok(_planning_result);
}
let mut iteration = 0;
let max_task_iterations = self.max_iterations * 3;
while iteration < max_task_iterations {
let ready_tasks: Vec<(String, String, AgentId)> = {
let context = self.shared_context.read().await;
if let Some(plan) = context.get_plan() {
if plan.is_complete() {
break;
}
plan.get_next_ready_tasks()
.iter()
.map(|t| (t.id.clone(), t.description.clone(), t.assigned_to.clone()))
.collect()
} else {
let initiator_agent = self.agents.get_mut(initiator).unwrap();
let result = initiator_agent
.chat(format!(
"Complete this task: {}\nYou can delegate to: {}",
task_description,
involved_agents.join(", ")
))
.await?;
return Ok(result);
}
};
if ready_tasks.is_empty() {
let has_in_progress = {
let context = self.shared_context.read().await;
context
.get_plan()
.map(|p| p.tasks.values().any(|t| t.status == TaskStatus::InProgress))
.unwrap_or(false)
};
if !has_in_progress {
break; }
if iteration > 5 {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
iteration += 1;
continue;
}
for (task_id, task_desc, agent_id) in ready_tasks {
{
let mut context = self.shared_context.write().await;
if let Some(plan) = context.get_plan_mut() {
if let Some(task) = plan.get_task_mut(&task_id) {
task.status = TaskStatus::InProgress;
}
}
}
let shared_memory_info = {
let context = self.shared_context.read().await;
let mut info = String::from("\n=== SHARED TASK MEMORY ===\n");
if let Some(plan) = context.get_plan() {
info.push_str(&format!("Overall Objective: {}\n", plan.objective));
info.push_str(&format!(
"Progress: {}/{} tasks completed\n\n",
plan.get_progress().0,
plan.get_progress().1
));
info.push_str("Completed Tasks:\n");
for task in plan.tasks_in_order() {
if task.status == TaskStatus::Completed {
info.push_str(&format!(
" ✓ [{}] {}: {}\n",
task.assigned_to,
task.description,
task.result.as_ref().unwrap_or(&"No result".to_string())
));
}
}
}
info.push_str("\nShared Data:\n");
for (key, value) in &context.data {
if !key.starts_with("current_task")
&& !key.starts_with("involved_agents")
&& !key.starts_with("task_status")
{
info.push_str(&format!(" • {}: {}\n", key, value));
}
}
info.push_str("=========================\n\n");
info
};
if let Some(agent) = self.agents.get_mut(&agent_id) {
let task_prompt = format!(
"{}Your assigned task: {}\n\n\
Complete this task and use the 'update_task_memory' tool to save your results to the shared memory. \
The task_id is '{}'. Include key findings and data that other agents might need.\n\n\
Provide a complete response with your results.",
shared_memory_info, task_desc, task_id
);
let result = agent.chat(task_prompt).await?;
{
let mut context = self.shared_context.write().await;
if let Some(plan) = context.get_plan_mut() {
if let Some(task) = plan.get_task_mut(&task_id) {
if task.status == TaskStatus::InProgress {
task.status = TaskStatus::Completed;
task.result = Some(result.clone());
}
}
}
}
}
}
iteration += 1;
}
let final_summary = {
let context = self.shared_context.read().await;
let mut summary = String::from("=== TASK COMPLETION SUMMARY ===\n\n");
if let Some(plan) = context.get_plan() {
summary.push_str(&format!("Objective: {}\n", plan.objective));
summary.push_str(&format!(
"Status: All tasks completed ({}/{} tasks)\n\n",
plan.get_progress().0,
plan.get_progress().1
));
summary.push_str("Task Results:\n");
for task in plan.tasks_in_order() {
summary.push_str(&format!("\n[{}] {}\n", task.assigned_to, task.description));
if let Some(result) = &task.result {
summary.push_str(&format!("Result: {}\n", result));
}
}
}
summary
};
let coordinator = self.agents.get_mut(initiator).unwrap();
let synthesis_prompt = format!(
"Based on the completed tasks, provide a comprehensive final answer to the original request.\n\n\
Original Task: {}\n\n\
{}\n\n\
Synthesize all the information into a cohesive, complete response.",
task_description, final_summary
);
let final_result = coordinator.chat(synthesis_prompt).await?;
{
let mut context = self.shared_context.write().await;
context.set(
"task_status".to_string(),
Value::String("completed".to_string()),
);
}
Ok(final_result)
}
#[allow(dead_code)]
async fn process_messages_and_trigger_responses(
&mut self,
involved_agents: &[AgentId],
) -> Result<()> {
let mut iteration = 0;
while iteration < self.max_iterations {
self.process_messages().await?;
let mut agents_to_respond = Vec::new();
for agent_id in involved_agents {
if let Some(agent) = self.agents.get(agent_id) {
let messages = &agent.chat_session().messages;
if !messages.is_empty() {
let last_message = messages.last().unwrap();
if last_message.role == crate::chat::Role::User {
agents_to_respond.push(agent_id.clone());
}
}
}
}
if agents_to_respond.is_empty() {
break;
}
for agent_id in agents_to_respond {
if let Some(agent) = self.agents.get_mut(&agent_id) {
let _response = agent.chat("").await?;
}
}
iteration += 1;
}
Ok(())
}
pub async fn get_shared_context(&self) -> SharedContext {
self.shared_context.read().await.clone()
}
pub async fn set_shared_context(&self, key: String, value: Value) {
let mut context = self.shared_context.write().await;
context.set(key, value);
}
}
impl Default for ForestOfAgents {
fn default() -> Self {
Self::new()
}
}
pub struct SendMessageTool {
agent_id: AgentId,
message_queue: Arc<RwLock<Vec<ForestMessage>>>,
shared_context: Arc<RwLock<SharedContext>>,
}
impl SendMessageTool {
pub fn new(
agent_id: AgentId,
message_queue: Arc<RwLock<Vec<ForestMessage>>>,
shared_context: Arc<RwLock<SharedContext>>,
) -> Self {
Self {
agent_id,
message_queue,
shared_context,
}
}
}
#[async_trait::async_trait]
impl Tool for SendMessageTool {
fn name(&self) -> &str {
"send_message"
}
fn description(&self) -> &str {
"Send a message to another agent or broadcast to all agents in the forest."
}
fn parameters(&self) -> HashMap<String, ToolParameter> {
let mut params = HashMap::new();
params.insert(
"to".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "ID of the recipient agent (leave empty for broadcast)".to_string(),
required: Some(false),
},
);
params.insert(
"message".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "The message content to send".to_string(),
required: Some(true),
},
);
params
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let message = args
.get("message")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?
.to_string();
let to = args
.get("to")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let forest_message = if let Some(to_id) = &to {
ForestMessage::new(self.agent_id.clone(), Some(to_id.clone()), message)
} else {
ForestMessage::broadcast(self.agent_id.clone(), message)
};
{
let mut queue = self.message_queue.write().await;
queue.push(forest_message.clone());
}
{
let mut context = self.shared_context.write().await;
context.add_message(forest_message);
}
Ok(ToolResult::success("Message sent successfully"))
}
}
pub struct DelegateTaskTool {
agent_id: AgentId,
message_queue: Arc<RwLock<Vec<ForestMessage>>>,
shared_context: Arc<RwLock<SharedContext>>,
}
impl DelegateTaskTool {
pub fn new(
agent_id: AgentId,
message_queue: Arc<RwLock<Vec<ForestMessage>>>,
shared_context: Arc<RwLock<SharedContext>>,
) -> Self {
Self {
agent_id,
message_queue,
shared_context,
}
}
}
#[async_trait::async_trait]
impl Tool for DelegateTaskTool {
fn name(&self) -> &str {
"delegate_task"
}
fn description(&self) -> &str {
"Delegate a specific task to another agent for execution."
}
fn parameters(&self) -> HashMap<String, ToolParameter> {
let mut params = HashMap::new();
params.insert(
"to".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "ID of the agent to delegate the task to".to_string(),
required: Some(true),
},
);
params.insert(
"task".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Description of the task to delegate".to_string(),
required: Some(true),
},
);
params.insert(
"context".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Additional context or requirements for the task".to_string(),
required: Some(false),
},
);
params
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let to = args
.get("to")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'to' parameter".to_string()))?;
let task = args
.get("task")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'task' parameter".to_string()))?;
let context = args.get("context").and_then(|v| v.as_str()).unwrap_or("");
let message = if context.is_empty() {
format!("Task delegated: {}", task)
} else {
format!("Task delegated: {}\nContext: {}", task, context)
};
let forest_message =
ForestMessage::new(self.agent_id.clone(), Some(to.to_string()), message)
.with_metadata("type".to_string(), "task_delegation".to_string())
.with_metadata("task".to_string(), task.to_string());
{
let mut queue = self.message_queue.write().await;
queue.push(forest_message.clone());
}
{
let mut context_lock = self.shared_context.write().await;
context_lock.add_message(forest_message);
}
Ok(ToolResult::success(format!(
"Task delegated to agent '{}'",
to
)))
}
}
pub struct ShareContextTool {
agent_id: AgentId,
shared_context: Arc<RwLock<SharedContext>>,
}
impl ShareContextTool {
pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
Self {
agent_id,
shared_context,
}
}
}
#[async_trait::async_trait]
impl Tool for ShareContextTool {
fn name(&self) -> &str {
"share_context"
}
fn description(&self) -> &str {
"Share information in the shared context that all agents can access."
}
fn parameters(&self) -> HashMap<String, ToolParameter> {
let mut params = HashMap::new();
params.insert(
"key".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Key for the shared information".to_string(),
required: Some(true),
},
);
params.insert(
"value".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Value to share".to_string(),
required: Some(true),
},
);
params.insert(
"description".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Description of what this information represents".to_string(),
required: Some(false),
},
);
params
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'key' parameter".to_string()))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'value' parameter".to_string()))?;
let description = args
.get("description")
.and_then(|v| v.as_str())
.unwrap_or("");
let mut context = self.shared_context.write().await;
let metadata = serde_json::json!({
"shared_by": self.agent_id,
"timestamp": chrono::Utc::now().to_rfc3339(),
"description": description
});
let value_with_meta = serde_json::json!({
"value": value,
"metadata": metadata
});
context.set(key.to_string(), value_with_meta);
Ok(ToolResult::success(format!(
"Information shared with key '{}'",
key
)))
}
}
pub struct UpdateTaskMemoryTool {
agent_id: AgentId,
shared_context: Arc<RwLock<SharedContext>>,
}
impl UpdateTaskMemoryTool {
pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
Self {
agent_id,
shared_context,
}
}
}
#[async_trait::async_trait]
impl Tool for UpdateTaskMemoryTool {
fn name(&self) -> &str {
"update_task_memory"
}
fn description(&self) -> &str {
"Update the shared task memory with your results, findings, and data. This allows other agents to see your progress and use your outputs."
}
fn parameters(&self) -> HashMap<String, ToolParameter> {
let mut params = HashMap::new();
params.insert(
"task_id".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "The ID of the task you're updating (e.g., 'task_1')".to_string(),
required: Some(true),
},
);
params.insert(
"result".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Your results, findings, or output from completing the task"
.to_string(),
required: Some(true),
},
);
params.insert(
"data".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "Additional data or information to share (e.g., key findings, metrics, recommendations)".to_string(),
required: Some(false),
},
);
params
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let task_id = args
.get("task_id")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'task_id' parameter".to_string()))?;
let result = args
.get("result")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'result' parameter".to_string()))?;
let additional_data = args.get("data").and_then(|v| v.as_str()).unwrap_or("");
let mut context = self.shared_context.write().await;
if let Some(plan) = context.get_plan_mut() {
if let Some(task) = plan.get_task_mut(task_id) {
task.status = TaskStatus::Completed;
task.result = Some(result.to_string());
let task_description = task.description.clone();
if !additional_data.is_empty() {
let data_key = format!("task_data_{}", task_id);
context.set(
data_key,
serde_json::json!({
"agent": self.agent_id,
"task": task_description,
"data": additional_data,
"timestamp": chrono::Utc::now().to_rfc3339()
}),
);
}
return Ok(ToolResult::success(format!(
"Task '{}' marked as completed. Results saved to shared memory.",
task_id
)));
} else {
return Err(HeliosError::ToolError(format!(
"Task '{}' not found in current plan",
task_id
)));
}
}
Err(HeliosError::ToolError(
"No active task plan found".to_string(),
))
}
}
pub struct CreatePlanTool {
#[allow(dead_code)]
agent_id: AgentId,
shared_context: Arc<RwLock<SharedContext>>,
}
impl CreatePlanTool {
pub fn new(agent_id: AgentId, shared_context: Arc<RwLock<SharedContext>>) -> Self {
Self {
agent_id,
shared_context,
}
}
}
#[async_trait::async_trait]
impl Tool for CreatePlanTool {
fn name(&self) -> &str {
"create_plan"
}
fn description(&self) -> &str {
"Create a detailed task plan for collaborative work. Break down the overall objective into specific tasks and assign them to team members."
}
fn parameters(&self) -> HashMap<String, ToolParameter> {
let mut params = HashMap::new();
params.insert(
"objective".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "The overall objective or goal of the plan".to_string(),
required: Some(true),
},
);
params.insert(
"tasks".to_string(),
ToolParameter {
param_type: "string".to_string(),
description: "JSON array of tasks. Each task must have: id (string), description (string), assigned_to (string), dependencies (array of task IDs)".to_string(),
required: Some(true),
},
);
params
}
async fn execute(&self, args: Value) -> Result<ToolResult> {
let objective = args
.get("objective")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'objective' parameter".to_string()))?;
let tasks_json = args
.get("tasks")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Missing 'tasks' parameter".to_string()))?;
let tasks_array: Vec<Value> = serde_json::from_str(tasks_json)
.map_err(|e| HeliosError::ToolError(format!("Invalid JSON for tasks: {}", e)))?;
let plan_id = format!("plan_{}", chrono::Utc::now().timestamp());
let mut plan = TaskPlan::new(plan_id.clone(), objective.to_string());
for task_value in tasks_array {
let task_obj = task_value.as_object().ok_or_else(|| {
HeliosError::ToolError("Each task must be a JSON object".to_string())
})?;
let id = task_obj
.get("id")
.and_then(|v| v.as_str())
.ok_or_else(|| HeliosError::ToolError("Task missing 'id' field".to_string()))?
.to_string();
let description = task_obj
.get("description")
.and_then(|v| v.as_str())
.ok_or_else(|| {
HeliosError::ToolError("Task missing 'description' field".to_string())
})?
.to_string();
let assigned_to = task_obj
.get("assigned_to")
.and_then(|v| v.as_str())
.ok_or_else(|| {
HeliosError::ToolError("Task missing 'assigned_to' field".to_string())
})?
.to_string();
let dependencies = task_obj
.get("dependencies")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<String>>()
})
.unwrap_or_else(Vec::new);
let task = TaskItem::new(id, description, assigned_to).with_dependencies(dependencies);
plan.add_task(task);
}
let mut context = self.shared_context.write().await;
context.set_plan(plan.clone());
let task_summary = plan
.tasks_in_order()
.iter()
.map(|t| {
format!(
" • [{}] {} (assigned to: {})",
t.id, t.description, t.assigned_to
)
})
.collect::<Vec<_>>()
.join("\n");
Ok(ToolResult::success(format!(
"Plan created with {} tasks:\n{}",
plan.tasks.len(),
task_summary
)))
}
}
pub struct ForestBuilder {
config: Option<Config>,
agents: Vec<(AgentId, AgentBuilder)>,
max_iterations: usize,
}
impl ForestBuilder {
pub fn new() -> Self {
Self {
config: None,
agents: Vec::new(),
max_iterations: 10,
}
}
pub fn config(mut self, config: Config) -> Self {
self.config = Some(config);
self
}
pub fn agent(mut self, id: AgentId, builder: AgentBuilder) -> Self {
self.agents.push((id, builder));
self
}
pub fn agents(mut self, agents: Vec<(AgentId, AgentBuilder)>) -> Self {
self.agents.extend(agents);
self
}
pub fn max_iterations(mut self, max: usize) -> Self {
self.max_iterations = max;
self
}
pub async fn build(self) -> Result<ForestOfAgents> {
let config = self
.config
.ok_or_else(|| HeliosError::AgentError("Config is required".to_string()))?;
let mut forest = ForestOfAgents::with_max_iterations(self.max_iterations);
for (id, builder) in self.agents {
let agent = builder.config(config.clone()).build().await?;
forest.add_agent(id, agent)?;
}
Ok(forest)
}
}
impl Default for ForestBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
use crate::tools::Tool;
use serde_json::Value;
#[tokio::test]
async fn test_forest_creation_and_agent_management() {
let mut forest = ForestOfAgents::new();
let config = Config::new_default();
let agent1 = Agent::builder("agent1")
.config(config.clone())
.system_prompt("You are agent 1")
.build()
.await
.unwrap();
let agent2 = Agent::builder("agent2")
.config(config)
.system_prompt("You are agent 2")
.build()
.await
.unwrap();
forest.add_agent("agent1".to_string(), agent1).unwrap();
forest.add_agent("agent2".to_string(), agent2).unwrap();
let agents = forest.list_agents();
assert_eq!(agents.len(), 2);
assert!(agents.contains(&"agent1".to_string()));
assert!(agents.contains(&"agent2".to_string()));
assert!(forest.get_agent(&"agent1".to_string()).is_some());
assert!(forest.get_agent(&"agent3".to_string()).is_none());
let agent3 = Agent::builder("agent3")
.config(Config::new_default())
.build()
.await
.unwrap();
let result = forest.add_agent("agent1".to_string(), agent3);
assert!(result.is_err());
let removed = forest.remove_agent(&"agent1".to_string());
assert!(removed.is_some());
assert_eq!(forest.list_agents().len(), 1);
assert!(forest.get_agent(&"agent1".to_string()).is_none());
}
#[tokio::test]
async fn test_message_sending() {
let mut forest = ForestOfAgents::new();
let config = Config::new_default();
let agent1 = Agent::builder("alice")
.config(config.clone())
.build()
.await
.unwrap();
let agent2 = Agent::builder("bob").config(config).build().await.unwrap();
forest.add_agent("alice".to_string(), agent1).unwrap();
forest.add_agent("bob".to_string(), agent2).unwrap();
forest
.send_message(
&"alice".to_string(),
Some(&"bob".to_string()),
"Hello Bob!".to_string(),
)
.await
.unwrap();
forest.process_messages().await.unwrap();
let bob = forest.get_agent(&"bob".to_string()).unwrap();
let messages = bob.chat_session().messages.clone();
assert!(!messages.is_empty());
let last_message = messages.last().unwrap();
assert_eq!(last_message.role, crate::chat::Role::User);
assert!(last_message
.content
.contains("Message from alice: Hello Bob!"));
let alice_message_count_before = forest
.get_agent(&"alice".to_string())
.unwrap()
.chat_session()
.messages
.len();
forest
.send_message(&"alice".to_string(), None, "Hello everyone!".to_string())
.await
.unwrap();
forest.process_messages().await.unwrap();
let alice = forest.get_agent(&"alice".to_string()).unwrap();
assert_eq!(
alice.chat_session().messages.len(),
alice_message_count_before
);
let bob = forest.get_agent(&"bob".to_string()).unwrap();
let bob_messages = bob.chat_session().messages.clone();
let bob_last = bob_messages.last().unwrap();
assert!(bob_last
.content
.contains("Broadcast from alice: Hello everyone!"));
}
#[tokio::test]
async fn test_send_message_tool() {
let message_queue = Arc::new(RwLock::new(Vec::<ForestMessage>::new()));
let shared_context = Arc::new(RwLock::new(SharedContext::new()));
let tool = SendMessageTool::new(
"alice".to_string(),
message_queue.clone(),
shared_context.clone(),
);
let args = serde_json::json!({
"to": "bob",
"message": "Test message"
});
let result = tool.execute(args).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "Message sent successfully");
let queue = message_queue.read().await;
assert_eq!(queue.len(), 1);
let message = &queue[0];
assert_eq!(message.from, "alice");
assert_eq!(message.to, Some("bob".to_string()));
assert_eq!(message.content, "Test message");
let context = shared_context.read().await;
let messages = context.get_recent_messages(10);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].from, "alice");
}
#[tokio::test]
async fn test_delegate_task_tool() {
let message_queue = Arc::new(RwLock::new(Vec::new()));
let shared_context = Arc::new(RwLock::new(SharedContext::new()));
let tool = DelegateTaskTool::new(
"manager".to_string(),
Arc::clone(&message_queue),
Arc::clone(&shared_context),
);
let args = serde_json::json!({
"to": "worker",
"task": "Analyze the data",
"context": "Use statistical methods"
});
let result = tool.execute(args).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "Task delegated to agent 'worker'");
let queue = message_queue.read().await;
assert_eq!(queue.len(), 1);
let message = &queue[0];
assert_eq!(message.from, "manager");
assert_eq!(message.to, Some("worker".to_string()));
assert!(message.content.contains("Task delegated: Analyze the data"));
assert!(message.content.contains("Context: Use statistical methods"));
assert_eq!(
message.metadata.get("type"),
Some(&"task_delegation".to_string())
);
assert_eq!(
message.metadata.get("task"),
Some(&"Analyze the data".to_string())
);
}
#[tokio::test]
async fn test_share_context_tool() {
let shared_context = Arc::new(RwLock::new(SharedContext::new()));
let tool = ShareContextTool::new("researcher".to_string(), Arc::clone(&shared_context));
let args = serde_json::json!({
"key": "findings",
"value": "Temperature affects reaction rate",
"description": "Key experimental finding"
});
let result = tool.execute(args).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "Information shared with key 'findings'");
let context = shared_context.read().await;
let findings_data = context.get("findings").unwrap();
let findings_obj = findings_data.as_object().unwrap();
assert_eq!(
findings_obj.get("value").unwrap(),
&Value::String("Temperature affects reaction rate".to_string())
);
let metadata = findings_obj.get("metadata").unwrap();
let metadata_obj = metadata.as_object().unwrap();
assert_eq!(
metadata_obj.get("shared_by").unwrap(),
&Value::String("researcher".to_string())
);
assert_eq!(
metadata_obj.get("description").unwrap(),
&Value::String("Key experimental finding".to_string())
);
assert!(metadata_obj.contains_key("timestamp"));
}
#[tokio::test]
async fn test_shared_context() {
let mut context = SharedContext::new();
context.set("key1".to_string(), Value::String("value1".to_string()));
context.set("key2".to_string(), Value::Number(42.into()));
assert_eq!(
context.get("key1"),
Some(&Value::String("value1".to_string()))
);
assert_eq!(context.get("key2"), Some(&Value::Number(42.into())));
assert_eq!(context.get("key3"), None);
let msg1 = ForestMessage::new(
"alice".to_string(),
Some("bob".to_string()),
"Hello".to_string(),
);
let msg2 = ForestMessage::broadcast("bob".to_string(), "Hi everyone".to_string());
context.add_message(msg1);
context.add_message(msg2);
let messages = context.get_recent_messages(10);
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].from, "alice");
assert_eq!(messages[1].from, "bob");
let removed = context.remove("key1");
assert_eq!(removed, Some(Value::String("value1".to_string())));
assert_eq!(context.get("key1"), None);
}
#[tokio::test]
async fn test_collaborative_task() {
let mut forest = ForestOfAgents::new();
let config = Config::new_default();
let coordinator = Agent::builder("coordinator")
.config(config.clone())
.system_prompt(
"You are a task coordinator. Break down tasks and delegate to specialists.",
)
.build()
.await
.unwrap();
let researcher = Agent::builder("researcher")
.config(config.clone())
.system_prompt("You are a researcher. Gather and analyze information.")
.build()
.await
.unwrap();
let writer = Agent::builder("writer")
.config(config)
.system_prompt("You are a writer. Create clear, well-structured content.")
.build()
.await
.unwrap();
forest
.add_agent("coordinator".to_string(), coordinator)
.unwrap();
forest
.add_agent("researcher".to_string(), researcher)
.unwrap();
forest.add_agent("writer".to_string(), writer).unwrap();
assert_eq!(forest.list_agents().len(), 3);
assert!(forest.get_agent(&"coordinator".to_string()).is_some());
assert!(forest.get_agent(&"researcher".to_string()).is_some());
assert!(forest.get_agent(&"writer".to_string()).is_some());
forest
.set_shared_context(
"current_task".to_string(),
Value::String("Create a report on climate change impacts".to_string()),
)
.await;
forest
.set_shared_context(
"involved_agents".to_string(),
Value::Array(vec![
Value::String("researcher".to_string()),
Value::String("writer".to_string()),
]),
)
.await;
forest
.set_shared_context(
"task_status".to_string(),
Value::String("in_progress".to_string()),
)
.await;
let context = forest.get_shared_context().await;
assert_eq!(
context.get("task_status"),
Some(&Value::String("in_progress".to_string()))
);
assert!(context.get("current_task").is_some());
assert!(context.get("involved_agents").is_some());
}
#[tokio::test]
async fn test_forest_builder() {
let config = Config::new_default();
let forest = ForestBuilder::new()
.config(config)
.agent(
"agent1".to_string(),
Agent::builder("agent1").system_prompt("Agent 1 prompt"),
)
.agent(
"agent2".to_string(),
Agent::builder("agent2").system_prompt("Agent 2 prompt"),
)
.max_iterations(5)
.build()
.await
.unwrap();
assert_eq!(forest.list_agents().len(), 2);
assert!(forest.get_agent(&"agent1".to_string()).is_some());
assert!(forest.get_agent(&"agent2".to_string()).is_some());
assert_eq!(forest.max_iterations, 5);
}
#[tokio::test]
async fn test_forest_error_handling() {
let mut forest = ForestOfAgents::new();
let result = forest
.send_message(
&"nonexistent".to_string(),
Some(&"target".to_string()),
"test".to_string(),
)
.await;
assert!(result.is_err());
let result = forest
.execute_collaborative_task(&"nonexistent".to_string(), "test task".to_string(), vec![])
.await;
assert!(result.is_err());
let config = Config::new_default();
let agent = Agent::builder("real_agent")
.config(config)
.build()
.await
.unwrap();
forest.add_agent("real_agent".to_string(), agent).unwrap();
let result = forest
.execute_collaborative_task(
&"real_agent".to_string(),
"test task".to_string(),
vec!["nonexistent".to_string()],
)
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_forest_message() {
let msg = ForestMessage::new(
"alice".to_string(),
Some("bob".to_string()),
"Hello".to_string(),
);
assert_eq!(msg.from, "alice");
assert_eq!(msg.to, Some("bob".to_string()));
assert_eq!(msg.content, "Hello");
let broadcast = ForestMessage::broadcast("alice".to_string(), "Announcement".to_string());
assert_eq!(broadcast.from, "alice");
assert!(broadcast.to.is_none());
assert_eq!(broadcast.content, "Announcement");
let msg_with_meta = msg.with_metadata("priority".to_string(), "high".to_string());
assert_eq!(
msg_with_meta.metadata.get("priority"),
Some(&"high".to_string())
);
}
}