use anyhow::Result;
use async_trait::async_trait;
use petgraph::{graph::NodeIndex, visit::EdgeRef, Graph};
use std::{collections::HashMap, sync::Arc};
/// Identifier assigned to a node by `AgentEngine::add_node`.
pub type NodeId = usize;
/// Registry of tools, keyed by the string each tool reports from `Tool::name`.
pub type ToolRegistry<Ctx> = HashMap<String, Arc<dyn Tool<Ctx>>>;
/// Registry of node implementations, keyed by their [`NodeId`].
pub type NodeRegistry<Ctx> = HashMap<NodeId, Arc<dyn Node<Ctx>>>;
/// Predicate guarding a transition between two nodes; evaluated against the shared context.
pub type Condition<Ctx> = Box<dyn Fn(&Ctx) -> bool + Sync + Send>;
/// A named, asynchronous action that operates on the shared context.
#[async_trait]
pub trait Tool<Ctx>: Send + Sync {
    /// Key under which `AgentEngine::add_tools` stores this tool in the registry.
    fn name(&self) -> &str;

    /// Runs the tool with mutable access to the shared context.
    async fn execute(&self, ctx: &mut Ctx) -> Result<()>;
}
/// A state in an `AgentEngine` graph.
#[async_trait]
pub trait Node<Ctx>: Send + Sync {
    /// Called by `AgentEngine::run` each time execution arrives at this node.
    ///
    /// Receives the engine's shared context mutably, plus the registry of tools
    /// that were registered through `AgentEngine::add_tools`. An `Err` returned
    /// here stops the run and is propagated to the caller of `run`.
    async fn enter(&self, ctx: &mut Ctx, tools: &ToolRegistry<Ctx>) -> Result<()>;
}
/// Runs a graph of [`Node`]s, moving between them through conditional
/// transitions evaluated against a shared context.
pub struct AgentEngine<Ctx> {
    /// Transition graph: each node weight is the `NodeId` of a registered node,
    /// each edge weight is the condition guarding that transition.
    graph: Graph<NodeId, Condition<Ctx>>,
    /// Node implementations, keyed by the id stored as the graph node weight.
    nodes: NodeRegistry<Ctx>,
    /// Tools keyed by name; handed to every node's `enter`.
    tools: ToolRegistry<Ctx>,
    /// Graph index where `run` starts. After a successful run, this is the
    /// node at which that run stopped.
    current: Option<NodeIndex>,
    /// Shared state: passed mutably to nodes and immutably to conditions.
    context: Ctx,
    /// Id the next call to `add_node` will assign (ids count up from 0).
    next_node_id: NodeId,
}
impl<Ctx> AgentEngine<Ctx> {
pub fn new(context: Ctx) -> Self {
Self {
graph: Graph::new(),
nodes: HashMap::new(),
tools: HashMap::new(),
current: None,
context,
next_node_id: 0,
}
}
pub fn add_tools<T>(&mut self, tools: impl IntoIterator<Item = T>)
where
T: Tool<Ctx> + 'static,
{
for tool in tools {
let name = tool.name().to_owned();
self.tools.insert(name, Arc::new(tool));
}
}
pub fn add_node<N: Node<Ctx> + 'static>(&mut self, node: N) -> NodeId {
let node_id = self.next_node_id;
self.next_node_id += 1;
self.graph.add_node(node_id);
self.nodes.insert(node_id, Arc::new(node));
node_id
}
pub fn add_transition<F>(&mut self, from: NodeId, to: NodeId, condition: F) -> Result<()>
where
F: Fn(&Ctx) -> bool + Send + Sync + 'static,
{
use anyhow::anyhow;
let from_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == from)
.ok_or_else(|| anyhow!("Source node not found: {}", from))?;
let to_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == to)
.ok_or_else(|| anyhow!("Target node not found: {}", to))?;
self.graph.add_edge(from_idx, to_idx, Box::new(condition));
Ok(())
}
pub fn set_start_node(&mut self, id: NodeId) -> Result<()> {
use anyhow::anyhow;
let start_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == id)
.ok_or_else(|| anyhow!("Start node not found: {}", id))?;
self.current = Some(start_idx);
Ok(())
}
pub async fn run(&mut self) -> Result<()> {
use anyhow::anyhow;
let mut current_idx = self
.current
.ok_or_else(|| anyhow!("No start node set. Call set_start_node() first"))?;
loop {
let node_id = match self.graph.node_weight(current_idx) {
Some(id) => *id,
None => break,
};
let node = self
.nodes
.get(&node_id)
.ok_or_else(|| anyhow!("Node not found: {}", node_id))?;
node.enter(&mut self.context, &self.tools).await?;
let next_edge = self
.graph
.edges(current_idx)
.find(|edge| (edge.weight())(&self.context));
match next_edge {
Some(edge) => {
current_idx = edge.target();
}
None => {
break;
}
}
}
self.current = Some(current_idx);
Ok(())
}
}