use async_trait::async_trait;
use std::collections::HashMap;
use crate::models::tools::{ToolRegistry, Tool, ToolRegistryTrait, CombinedToolRegistry};
/// A unit of work placed at a node of a [`Graph`].
///
/// When control reaches the node, `Graph::run` calls [`Agent::run`] with the current
/// input and a registry combining the node's private tools with the graph-wide tools.
#[async_trait]
pub trait Agent: Send {
    /// Processes `input` and decides where control goes next.
    ///
    /// Returns the agent's textual output together with the id of the node that
    /// should run next, or `None` to end the graph run.
    async fn run(
        &mut self,
        input: &str,
        tool_registry: &(dyn ToolRegistryTrait + Send + Sync),
    ) -> (String, Option<i32>);

    /// Human-readable agent name (shown by `Graph::print`).
    fn get_name(&self) -> &str;
}
/// A set of agent nodes keyed by integer id, plus a tool registry shared by all nodes.
///
/// Edges live on the nodes themselves as adjacency lists; `Graph::add_edge` records
/// each edge in both directions.
pub struct Graph {
    /// All nodes, keyed by the caller-chosen id passed to `Graph::add_node`.
    nodes: HashMap<i32, Node>,
    /// Tools offered to every node's agent during `Graph::run`.
    tool_registry: ToolRegistry,
}
/// One vertex of a [`Graph`]: an agent, the ids of its neighbours, and its private tools.
struct Node {
    /// The agent executed when control reaches this node.
    agent: Box<dyn Agent>,
    /// Ids of adjacent nodes; `Graph::add_edge` never inserts the same id twice.
    neighbors: Vec<i32>,
    /// Tools visible only to this node's agent, alongside the graph-wide tools.
    tool_registry: ToolRegistry,
}
impl Graph {
    /// Creates an empty graph: no nodes and an empty shared tool registry.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            tool_registry: ToolRegistry::new(),
        }
    }

    /// Registers `tool`, backed by `function`, in the graph-wide registry.
    ///
    /// Tools registered here are offered to every node's agent during [`Graph::run`].
    pub fn register_tool<F>(&mut self, tool: Tool, function: F)
    where
        F: Fn(serde_json::Value) -> Result<serde_json::Value, String> + Send + Sync + 'static,
    {
        self.tool_registry.register_tool(tool, function);
    }

    /// Registers `tool`, backed by `function`, in the private registry of node `node_id`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Node {node_id} does not exist")` if no node with that id was added.
    pub fn register_tool_for_node<F>(
        &mut self,
        node_id: i32,
        tool: Tool,
        function: F,
    ) -> Result<(), String>
    where
        F: Fn(serde_json::Value) -> Result<serde_json::Value, String> + Send + Sync + 'static,
    {
        let node = self
            .nodes
            .get_mut(&node_id)
            .ok_or_else(|| format!("Node {} does not exist", node_id))?;
        node.tool_registry.register_tool(tool, function);
        Ok(())
    }

    /// Returns the private tool registry of node `node_id`, or `None` if it does not exist.
    pub fn get_node_tool_registry(&self, node_id: i32) -> Option<&ToolRegistry> {
        self.nodes.get(&node_id).map(|node| &node.tool_registry)
    }

    /// Returns the graph-wide tool registry shared by all nodes.
    pub fn get_shared_tool_registry(&self) -> &ToolRegistry {
        &self.tool_registry
    }

    /// Inserts node `id` running `agent`, with no neighbours and an empty private registry.
    ///
    /// If a node with `id` already exists it is replaced, and its agent, neighbour list
    /// and private tools are dropped. Other nodes that list `id` as a neighbour keep that
    /// entry, so the replacement inherits the incoming half of those edges only.
    pub fn add_node(&mut self, id: i32, agent: Box<dyn Agent>) {
        self.nodes.insert(
            id,
            Node {
                agent,
                neighbors: Vec::new(),
                tool_registry: ToolRegistry::new(),
            },
        );
    }

    /// Adds an undirected edge between `u` and `v` (recorded on both nodes, without duplicates).
    ///
    /// A self-loop (`u == v`) is recorded once.
    ///
    /// # Errors
    ///
    /// Returns `Err` and leaves the graph unchanged if either node does not exist.
    pub fn add_edge(&mut self, u: i32, v: i32) -> Result<(), String> {
        if !self.nodes.contains_key(&u) || !self.nodes.contains_key(&v) {
            return Err("One or both nodes do not exist".to_string());
        }
        for (from, to) in [(u, v), (v, u)] {
            if let Some(node) = self.nodes.get_mut(&from) {
                if !node.neighbors.contains(&to) {
                    node.neighbors.push(to);
                }
            }
        }
        Ok(())
    }

    /// Prints the adjacency list to stdout, one line per node, in ascending id order.
    pub fn print(&self) {
        println!("Adjacency list for the Graph:");
        // HashMap iteration order is unspecified; sort so the output is reproducible.
        let mut entries: Vec<(&i32, &Node)> = self.nodes.iter().collect();
        entries.sort_unstable_by_key(|(id, _)| **id);
        for (id, node) in entries {
            print!("{} (Agent: {}) -> ", id, node.agent.get_name());
            for neighbor in &node.neighbors {
                print!("{} ", neighbor);
            }
            println!();
        }
    }

    /// Runs agents starting at `start_id`, feeding each agent's output to the next one.
    ///
    /// Each step runs the current node's agent with a registry built from the node's
    /// private tools and the shared tools, then appends the agent's output and `'\n'` to
    /// the returned string. It then follows the id the agent returned.
    ///
    /// The run stops when an agent returns `None`. It also stops when the current id
    /// (initially `start_id`) names no node; in that case the line
    /// `"Error: Node {id} does not exist\n"` is appended before returning.
    ///
    /// The next id is taken verbatim from the agent. It is not checked against the edges
    /// added with [`Graph::add_edge`]. There is no step limit, so agents that keep routing
    /// to one another never terminate.
    pub async fn run(&mut self, start_id: i32, input: &str) -> String {
        let mut current_id = start_id;
        let mut current_input = input.to_string();
        let mut result = String::new();
        loop {
            let node = match self.nodes.get_mut(&current_id) {
                Some(node) => node,
                None => {
                    result.push_str(&format!("Error: Node {} does not exist\n", current_id));
                    break;
                }
            };
            // `self.nodes` (borrowed mutably through `node`) and `self.tool_registry` are
            // disjoint fields, as are `node.agent` and `node.tool_registry`, so these
            // borrows coexist safely. The previous raw-pointer version aliased a live
            // `&ToolRegistry` with a fresh `&mut HashMap`, which is undefined behaviour.
            let combined_registry = CombinedToolRegistry::new(
                &node.tool_registry as &(dyn ToolRegistryTrait + Send + Sync),
                &self.tool_registry as &(dyn ToolRegistryTrait + Send + Sync),
            );
            let (output, next_id) = node.agent.run(&current_input, &combined_registry).await;
            result.push_str(&output);
            result.push('\n');
            match next_id {
                // An unknown `next` is reported by the existence check at the loop top.
                Some(next) => {
                    current_id = next;
                    current_input = output;
                }
                None => break,
            }
        }
        result
    }
}

impl Default for Graph {
    /// Equivalent to [`Graph::new`].
    fn default() -> Self {
        Self::new()
    }
}