use anyhow::Result;
use async_trait::async_trait;
use petgraph::{graph::NodeIndex, visit::EdgeRef, Graph};
use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Arc};
/// Identifier under which a node implementation is stored in the registry.
type NodeId = usize;
/// Guard attached to a transition edge; evaluated against the engine's context.
type Condition<Ctx> = Box<dyn Fn(&Ctx) -> bool + Send + Sync + 'static>;
/// Node implementations, keyed by the `NodeId` stored in the graph.
type NodeRegistry<Ctx> = HashMap<NodeId, Arc<dyn NodeTrait<Ctx> + 'static>>;
/// Marker trait for the key type that names nodes in an `AgentEngine`.
///
/// Any copyable, comparable, hashable and debuggable type qualifies (for example a
/// fieldless `enum`). The blanket impl below makes every such type a `NodeKind`
/// automatically; `Clone` and `PartialEq` are implied by `Copy` and `Eq`.
pub trait NodeKind: Copy + Eq + Hash + Debug {}

impl<T> NodeKind for T where T: Copy + Eq + Hash + Debug {}
#[async_trait]
pub trait NodeTrait<Ctx>: Send + Sync {
async fn enter(&self, ctx: &mut Ctx) -> Result<()>;
}
pub struct AgentEngine<Ctx, N>
where
N: NodeKind,
{
graph: Graph<NodeId, Condition<Ctx>>,
nodes: NodeRegistry<Ctx>,
current: Option<NodeIndex>,
context: Ctx,
next_node_id: NodeId,
node_map: HashMap<N, NodeId>,
}
impl<Ctx, N> AgentEngine<Ctx, N>
where
N: NodeKind,
{
pub fn builder(context: Ctx) -> AgentEngineBuilder<Ctx, N> {
AgentEngineBuilder::new(context)
}
fn new(context: Ctx) -> Self {
Self {
graph: Graph::new(),
nodes: HashMap::new(),
current: None,
context,
next_node_id: 0,
node_map: HashMap::new(),
}
}
pub fn add_node<T: NodeTrait<Ctx> + 'static>(&mut self, kind: N, node: T) -> NodeId {
let node_id = self.next_node_id;
self.next_node_id += 1;
let _idx = self.graph.add_node(node_id);
self.nodes.insert(node_id, Arc::new(node));
self.node_map.insert(kind, node_id);
node_id
}
pub fn add_transition<F>(&mut self, from: N, to: N, condition: F) -> Result<()>
where
F: Fn(&Ctx) -> bool + Send + Sync + 'static,
{
use anyhow::anyhow;
let from_id = self.node_map.get(&from).copied().ok_or_else(|| {
anyhow!("Source node {:?} not found", from)
})?;
let to_id = self.node_map.get(&to).copied().ok_or_else(|| {
anyhow!("Target node {:?} not found", to)
})?;
let from_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == from_id)
.ok_or_else(|| anyhow!("Source node not found: {:?}", from))?;
let to_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == to_id)
.ok_or_else(|| anyhow!("Target node not found: {:?}", to))?;
self.graph.add_edge(from_idx, to_idx, Box::new(condition));
Ok(())
}
pub fn set_start_node(&mut self, kind: N) -> Result<()> {
use anyhow::anyhow;
let id = self.node_map.get(&kind).copied().ok_or_else(|| {
anyhow!("Start node {:?} not found", kind)
})?;
let start_idx = self
.graph
.node_indices()
.find(|idx| self.graph[*idx] == id)
.ok_or_else(|| anyhow!("Start node not found: {:?}", kind))?;
self.current = Some(start_idx);
Ok(())
}
pub async fn run(&mut self) -> Result<()> {
use anyhow::anyhow;
let mut current_idx = self
.current
.ok_or_else(|| anyhow!("Start node not set, call set_start_node() first"))?;
loop {
let node_id = match self.graph.node_weight(current_idx) {
Some(id) => *id,
None => break,
};
let node = self
.nodes
.get(&node_id)
.ok_or_else(|| anyhow!("Node does not exist: {}", node_id))?;
node.enter(&mut self.context).await?;
let next_edge = self
.graph
.edges(current_idx)
.find(|edge| (edge.weight())(&self.context));
match next_edge {
Some(edge) => {
current_idx = edge.target();
}
None => {
break;
}
}
}
self.current = Some(current_idx);
Ok(())
}
}
/// Fluent builder for `AgentEngine`.
///
/// Nodes and transitions are applied to the wrapped engine as soon as they are added.
/// The start node is only recorded, and is resolved when `build` is called.
pub struct AgentEngineBuilder<Ctx, N>
where
    N: NodeKind,
{
    /// Engine under construction.
    engine: AgentEngine<Ctx, N>,
    /// Start node requested via `with_start_node`; applied by `build`.
    start_node: Option<N>,
}

impl<Ctx, N> AgentEngineBuilder<Ctx, N>
where
    N: NodeKind,
{
    /// Wraps a fresh, empty engine that owns `context`.
    fn new(context: Ctx) -> Self {
        let engine = AgentEngine::new(context);
        Self {
            engine,
            start_node: None,
        }
    }

    /// Registers `node` under `kind` on the wrapped engine.
    pub fn with_node(self, kind: N, node: impl NodeTrait<Ctx> + 'static) -> Self {
        let mut builder = self;
        builder.engine.add_node(kind, node);
        builder
    }

    /// Adds a conditional transition `from -> to` on the wrapped engine.
    ///
    /// # Errors
    /// Fails if either endpoint has not already been added with `with_node`, so nodes
    /// must be registered before the transitions that reference them.
    pub fn with_transition(
        mut self,
        from: N,
        to: N,
        condition: impl Fn(&Ctx) -> bool + Send + Sync + 'static,
    ) -> Result<Self> {
        self.engine
            .add_transition(from, to, condition)
            .map(|()| self)
    }

    /// Records `kind` as the start node; it is only validated by `build`.
    pub fn with_start_node(self, kind: N) -> Self {
        Self {
            start_node: Some(kind),
            ..self
        }
    }

    /// Finishes construction, applying the recorded start node if one was given.
    ///
    /// Without a start node the engine is still returned, but its `run` will fail until
    /// `set_start_node` is called.
    ///
    /// # Errors
    /// Fails if the recorded start node was never registered with `with_node`.
    pub fn build(self) -> Result<AgentEngine<Ctx, N>> {
        let Self {
            mut engine,
            start_node,
        } = self;
        match start_node {
            Some(kind) => engine.set_start_node(kind)?,
            None => {}
        }
        Ok(engine)
    }
}