use std::sync::{Arc, RwLock};
use tracing::{debug, info, instrument, warn};
use crate::error::RuntimeError;
use crate::graph::{transitions, CompiledGraph, NodeOutput};
use crate::state::AgentState;
/// Settings that control how a [`GraphRunner`] executes a compiled graph.
///
/// Construct with [`RunnerConfig::new`] / [`Default`] and refine with the
/// chained, consuming builder methods.
#[derive(Clone, Debug)]
pub struct RunnerConfig {
    /// Upper bound on node executions in a single run; exceeding it aborts the
    /// run with `RuntimeError::RecursionLimit`.
    pub max_iterations: u32,
    /// Verbosity flag. NOTE(review): not read by the runner code in this file.
    pub verbose: bool,
    /// Free-form labels for the run. NOTE(review): not read by the runner code
    /// in this file.
    pub tags: Vec<String>,
}

impl Default for RunnerConfig {
    /// Returns a config with [`RunnerConfig::DEFAULT_MAX_ITERATIONS`], verbosity
    /// off and no tags.
    fn default() -> Self {
        Self {
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
            verbose: false,
            tags: Vec::new(),
        }
    }
}

impl RunnerConfig {
    /// Iteration budget used by [`RunnerConfig::default`].
    pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

    /// Creates a config with default settings (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of node executions allowed per run.
    #[must_use]
    pub fn max_iterations(mut self, max: u32) -> Self {
        self.max_iterations = max;
        self
    }

    /// Alias for [`RunnerConfig::max_iterations`].
    #[must_use]
    pub fn recursion_limit(self, limit: u32) -> Self {
        self.max_iterations(limit)
    }

    /// Sets the verbosity flag.
    #[must_use]
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Appends a single tag, keeping insertion order.
    #[must_use]
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Appends every tag from `tags`, keeping their iteration order.
    #[must_use]
    pub fn tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.tags.extend(tags.into_iter().map(Into::into));
        self
    }
}
/// Executes a [`CompiledGraph`] node by node, starting at its entry point,
/// until routing reaches [`transitions::END`] or the iteration budget from
/// [`RunnerConfig`] is exhausted.
pub struct GraphRunner {
    /// The graph being executed.
    graph: CompiledGraph,
    /// Execution settings (iteration budget, verbosity, tags).
    config: RunnerConfig,
}

impl GraphRunner {
    /// Creates a runner for `graph` with the supplied `config`.
    pub fn new(graph: CompiledGraph, config: RunnerConfig) -> Self {
        Self { graph, config }
    }

    /// Creates a runner for `graph` using [`RunnerConfig::default`].
    pub fn with_defaults(graph: CompiledGraph) -> Self {
        Self::new(graph, RunnerConfig::default())
    }

    /// Runs the graph on a fresh shared state wrapping `initial_state` and
    /// returns the final state.
    ///
    /// # Errors
    /// Same conditions as [`GraphRunner::invoke_shared`].
    #[instrument(skip(self, initial_state), fields(graph_name = ?self.graph.name()))]
    pub async fn invoke(&self, initial_state: AgentState) -> Result<AgentState, RuntimeError> {
        let state = Arc::new(RwLock::new(initial_state));
        self.run_loop(state).await
    }

    /// Runs the graph against a caller-owned shared state.
    ///
    /// Node mutations are made through `state`, so they remain visible to the
    /// caller afterwards. The returned value is a clone of the state at completion.
    ///
    /// # Errors
    /// - `RuntimeError::RecursionLimit` if another node would have to run after
    ///   `max_iterations` executions.
    /// - `RuntimeError::NodeNotFound` if routing names a node the graph lacks.
    /// - the error built by `RuntimeError::node_failed` if a node's executor fails.
    /// - `RuntimeError::InvalidState` if the state lock is poisoned.
    #[instrument(skip(self, state), fields(graph_name = ?self.graph.name()))]
    pub async fn invoke_shared(
        &self,
        state: Arc<RwLock<AgentState>>,
    ) -> Result<AgentState, RuntimeError> {
        self.run_loop(state).await
    }

    /// Core execution loop shared by [`GraphRunner::invoke`] and
    /// [`GraphRunner::invoke_shared`].
    async fn run_loop(
        &self,
        state: Arc<RwLock<AgentState>>,
    ) -> Result<AgentState, RuntimeError> {
        let max_iterations = self.config.max_iterations;
        let mut current_node = self.graph.entry_point().to_string();
        let mut iterations: u32 = 0;

        info!(
            entry_point = %current_node,
            max_iterations = max_iterations,
            "Starting graph execution"
        );

        loop {
            // Check for completion before the budget: a run that reaches END
            // after exactly `max_iterations` executions finished within its
            // limit and must not be reported as a recursion-limit failure.
            if current_node == transitions::END {
                info!(iterations = iterations, "Graph execution completed");
                let guard = state
                    .read()
                    .map_err(|e| RuntimeError::InvalidState(e.to_string()))?;
                return Ok(guard.clone());
            }

            if iterations >= max_iterations {
                warn!(
                    iterations = iterations,
                    max = max_iterations,
                    "Maximum iterations exceeded"
                );
                return Err(RuntimeError::RecursionLimit(max_iterations));
            }

            let node = self
                .graph
                .get_node(&current_node)
                .ok_or_else(|| RuntimeError::NodeNotFound(current_node.clone()))?;

            debug!(node_id = %current_node, iteration = iterations, "Executing node");
            let output = node
                .executor
                .execute(Arc::clone(&state))
                .await
                .map_err(|e| RuntimeError::node_failed(&current_node, e))?;

            // Scope the write guard so the lock is released before routing
            // (which may take a read lock) and before the next await.
            {
                let mut guard = state
                    .write()
                    .map_err(|e| RuntimeError::InvalidState(e.to_string()))?;
                guard.increment_iteration();
            }
            iterations += 1;

            // `output` is not used after routing, so move its targets out
            // instead of cloning them.
            let next_node = match output {
                NodeOutput::Finish => {
                    info!(node_id = %current_node, "Node signaled finish");
                    transitions::END.to_string()
                }
                NodeOutput::Continue(Some(target)) => {
                    debug!(node_id = %current_node, target = %target, "Node specified next target");
                    target
                }
                NodeOutput::Continue(None) => {
                    // No explicit target: let the graph's edges decide, based
                    // on the current state.
                    let current_state = state
                        .read()
                        .map_err(|e| RuntimeError::InvalidState(e.to_string()))?;
                    match self.graph.get_next_node(&current_node, &current_state) {
                        Some(next) => {
                            debug!(node_id = %current_node, next = %next, "Following graph edge");
                            next
                        }
                        None => {
                            debug!(
                                node_id = %current_node,
                                "No outgoing edge, ending execution"
                            );
                            transitions::END.to_string()
                        }
                    }
                }
                NodeOutput::Route(target) => {
                    debug!(node_id = %current_node, target = %target, "Node routing to target");
                    target
                }
            };
            current_node = next_node;
        }
    }

    /// Returns the graph this runner executes.
    pub fn graph(&self) -> &CompiledGraph {
        &self.graph
    }

    /// Returns the runner's configuration.
    pub fn config(&self) -> &RunnerConfig {
        &self.config
    }
}

/// Alternative name for [`GraphRunner`].
pub type Runtime = GraphRunner;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::NodeError;
    use crate::graph::{GraphBuilder, NodeExecutor};
    use crate::state::SharedState;
    use async_trait::async_trait;

    /// Test node: increments the integer `count` context entry, then either
    /// continues to `next` or finishes the run when `next` is `None`.
    struct CounterNode {
        id: String,
        next: Option<String>,
    }

    /// Builds a [`CounterNode`] from string slices.
    fn counter(id: &str, next: Option<&str>) -> CounterNode {
        CounterNode {
            id: id.to_string(),
            next: next.map(String::from),
        }
    }

    #[async_trait]
    impl NodeExecutor for CounterNode {
        fn id(&self) -> &str {
            &self.id
        }

        async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
            // Scope the write guard so it is released before the node returns.
            {
                let mut st = state
                    .write()
                    .map_err(|err| NodeError::execution_failed(err.to_string()))?;
                let previous = st.get_context::<i32>("count").unwrap_or(0);
                st.set_context("count", previous + 1);
            }
            Ok(if let Some(target) = &self.next {
                NodeOutput::continue_to(target.clone())
            } else {
                NodeOutput::finish()
            })
        }
    }

    #[tokio::test]
    async fn test_simple_execution() {
        let graph = GraphBuilder::new()
            .add_node(counter("counter", None))
            .set_entry_point("counter")
            .compile()
            .unwrap();

        let final_state = GraphRunner::with_defaults(graph)
            .invoke(AgentState::new())
            .await
            .unwrap();

        assert_eq!(final_state.get_context::<i32>("count"), Some(1));
    }

    #[tokio::test]
    async fn test_chained_execution() {
        let graph = GraphBuilder::new()
            .add_node(counter("first", Some("second")))
            .add_node(counter("second", Some("third")))
            .add_node(counter("third", None))
            .set_entry_point("first")
            .compile()
            .unwrap();

        let final_state = GraphRunner::with_defaults(graph)
            .invoke(AgentState::new())
            .await
            .unwrap();

        assert_eq!(final_state.get_context::<i32>("count"), Some(3));
    }

    #[tokio::test]
    async fn test_max_iterations() {
        // A node that always routes back to itself, so only the budget stops it.
        struct InfiniteNode;

        #[async_trait]
        impl NodeExecutor for InfiniteNode {
            fn id(&self) -> &str {
                "infinite"
            }

            async fn execute(&self, _state: SharedState) -> Result<NodeOutput, NodeError> {
                Ok(NodeOutput::continue_to("infinite"))
            }
        }

        let graph = GraphBuilder::new()
            .add_node(InfiniteNode)
            .set_entry_point("infinite")
            .compile()
            .unwrap();
        let config = RunnerConfig::default().max_iterations(5);

        let outcome = GraphRunner::new(graph, config)
            .invoke(AgentState::new())
            .await;

        assert!(matches!(outcome, Err(RuntimeError::RecursionLimit(5))));
    }

    #[tokio::test]
    async fn test_edge_based_routing() {
        // Appends its own id to the `visited` context string and defers routing
        // to the graph's edges.
        struct SimpleNode {
            id: String,
        }

        #[async_trait]
        impl NodeExecutor for SimpleNode {
            fn id(&self) -> &str {
                &self.id
            }

            async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
                {
                    let mut st = state
                        .write()
                        .map_err(|err| NodeError::execution_failed(err.to_string()))?;
                    let mut trail = st.get_context::<String>("visited").unwrap_or_default();
                    trail.push_str(&self.id);
                    st.set_context("visited", trail);
                }
                Ok(NodeOutput::cont())
            }
        }

        let simple = |id: &str| SimpleNode { id: id.to_string() };
        let graph = GraphBuilder::new()
            .add_node(simple("a"))
            .add_node(simple("b"))
            .add_node(simple("c"))
            .set_entry_point("a")
            .add_edge("a", "b")
            .add_edge("b", "c")
            .add_edge_to_end("c")
            .compile()
            .unwrap();

        let final_state = GraphRunner::with_defaults(graph)
            .invoke(AgentState::new())
            .await
            .unwrap();

        let expected = Some("abc".to_string());
        assert_eq!(final_state.get_context::<String>("visited"), expected);
    }

    #[test]
    fn test_runner_config() {
        let cfg = RunnerConfig::new()
            .tag("test")
            .max_iterations(50)
            .tag("example")
            .verbose(true);

        assert_eq!(cfg.max_iterations, 50);
        assert!(cfg.verbose);
        assert_eq!(cfg.tags, vec!["test", "example"]);
    }
}