use crate::edge::EdgeTarget;
use crate::error::{NodeError, RuntimeError};
use crate::graph::CompiledGraph;
use crate::node::NodeResult;
use crate::state::State;
use std::collections::HashSet;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
/// An event produced while streaming a graph run.
///
/// Every variant that carries a state owns a full copy of it, so consumers may keep events
/// after the run has moved on.
#[derive(Clone, Debug)]
pub enum StreamEvent<S>
where
    S: State,
{
    /// Marks the start of node `node_id`; `state` is the state at that point.
    NodeStart { node_id: String, state: S },
    /// Marks the end of node `node_id`; `state` is the state at that point.
    /// NOTE(review): not emitted by `Runtime::stream` in this file.
    NodeEnd { node_id: String, state: S },
    /// Carries an intermediate state snapshot.
    /// NOTE(review): not emitted by `Runtime::stream` in this file.
    StateUpdate { state: S },
    /// Carries a chunk of textual output.
    /// NOTE(review): not emitted by `Runtime::stream` in this file.
    Token { content: String },
    /// Final event of a successful run, holding the resulting state.
    Complete { final_state: S },
    /// The run paused; `reason` explains why and `state` is the state at the pause.
    /// NOTE(review): not emitted by `Runtime::stream` in this file.
    Interrupt { state: S, reason: String },
    /// The run failed; `error` is a textual description of the failure.
    /// NOTE(review): not emitted by `Runtime::stream` in this file.
    Error { error: String },
}
/// Execution settings for a [`Runtime`].
///
/// Start from [`RuntimeConfig::new`] (identical to `Default`) and chain the setters, each of
/// which consumes the config and hands back the updated value.
#[derive(Clone, Debug)]
pub struct RuntimeConfig {
    /// Streaming flag; defaults to `true`.
    /// NOTE(review): not read by `Runtime` in this file.
    pub stream: bool,
    /// Upper bound on node executions performed by `Runtime::invoke`; defaults to 25.
    pub recursion_limit: u32,
    /// Optional identifier for the run; defaults to `None`.
    /// NOTE(review): not read by `Runtime` in this file.
    pub thread_id: Option<String>,
    /// Free-form labels (duplicates collapse); defaults to empty.
    pub tags: HashSet<String>,
}

impl Default for RuntimeConfig {
    /// Streaming enabled, recursion limit 25, no thread id, no tags.
    fn default() -> Self {
        RuntimeConfig {
            stream: true,
            recursion_limit: 25,
            thread_id: Option::None,
            tags: HashSet::default(),
        }
    }
}

impl RuntimeConfig {
    /// Returns the default configuration; see the `Default` impl for the values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the recursion limit.
    pub fn recursion_limit(self, limit: u32) -> Self {
        Self {
            recursion_limit: limit,
            ..self
        }
    }

    /// Sets the thread id, replacing any previous one.
    pub fn thread_id(self, id: impl Into<String>) -> Self {
        Self {
            thread_id: Some(id.into()),
            ..self
        }
    }

    /// Adds one tag; adding a tag that is already present leaves the set unchanged.
    pub fn tag(self, tag: impl Into<String>) -> Self {
        let mut updated = self;
        updated.tags.insert(tag.into());
        updated
    }

    /// Turns the streaming flag on or off.
    pub fn streaming(self, enabled: bool) -> Self {
        Self {
            stream: enabled,
            ..self
        }
    }
}
/// Executes a [`CompiledGraph`] under the settings in a [`RuntimeConfig`].
pub struct Runtime<S>
where
    S: State,
{
    // The graph being executed; owned by the runtime for its whole lifetime.
    graph: CompiledGraph<S>,
    // Settings consulted during execution (currently only `recursion_limit` by `invoke`).
    config: RuntimeConfig,
}
impl<S: State> Runtime<S> {
    /// Creates a runtime that executes `graph` with the given `config`.
    pub fn new(graph: CompiledGraph<S>, config: RuntimeConfig) -> Self {
        Self { graph, config }
    }

    /// Creates a runtime for `graph` using [`RuntimeConfig::default`].
    pub fn with_defaults(graph: CompiledGraph<S>) -> Self {
        Self::new(graph, RuntimeConfig::default())
    }

    /// Runs the graph from its entry point until a node ends the run.
    ///
    /// Each iteration runs the current node. A node may return one of three results:
    /// - `NodeResult::Continue`: the outgoing edge of that node (if any) is resolved against
    ///   the new state to pick the next node. A node without an outgoing edge, or an edge that
    ///   resolves to `EdgeTarget::End`, finishes the run with the current state.
    /// - `NodeResult::End`: the run finishes with the state the node returned.
    /// - `NodeResult::Interrupt`: the run stops with an error.
    ///
    /// At most `config.recursion_limit` nodes are executed.
    ///
    /// # Errors
    ///
    /// - [`RuntimeError::RecursionLimit`] when the limit is reached before the run ends
    ///   (a limit of 0 fails before any node runs).
    /// - [`RuntimeError::NodeNotFound`] when the current node id is not in the graph.
    /// - The error built by `RuntimeError::node_failed` when a node returns an error.
    /// - The error built by `RuntimeError::interrupted` when a node returns
    ///   `NodeResult::Interrupt`. The interrupted state is not handed back to the caller.
    pub async fn invoke(&self, initial_state: S) -> Result<S, RuntimeError> {
        let mut state = initial_state;
        let mut current_node = self.graph.entry_point.clone();
        // Number of nodes executed so far; checked before every run.
        let mut step_count = 0u32;
        loop {
            if step_count >= self.config.recursion_limit {
                return Err(RuntimeError::RecursionLimit(self.config.recursion_limit));
            }
            let node = self
                .graph
                .get_node(&current_node)
                .ok_or_else(|| RuntimeError::NodeNotFound(current_node.clone()))?;
            let result = node
                .run(state)
                .await
                .map_err(|e| RuntimeError::node_failed(&current_node, e))?;
            match result {
                NodeResult::Continue(new_state) => {
                    state = new_state;
                    let target = self
                        .graph
                        .get_edge(&current_node)
                        .map(|edge| edge.resolve(&state));
                    match target {
                        Some(EdgeTarget::Node(next)) => current_node = next,
                        // No outgoing edge is treated the same as an explicit end.
                        Some(EdgeTarget::End) | None => return Ok(state),
                    }
                }
                // `RuntimeError::interrupted` only takes the reason, so the state is dropped.
                NodeResult::Interrupt { reason, .. } => {
                    return Err(RuntimeError::interrupted(reason));
                }
                NodeResult::End(final_state) => return Ok(final_state),
            }
            step_count += 1;
        }
    }

    /// Returns a stream of [`StreamEvent`]s for a run starting from `initial_state`.
    ///
    /// The events are produced by a task spawned with `tokio::spawn`, so this must be called
    /// from within a Tokio runtime. Events travel over a bounded channel with capacity 100.
    ///
    /// NOTE(review): this is currently a placeholder. It does not execute any node. It emits
    /// exactly one `NodeStart` for the entry point and then one `Complete` carrying the
    /// unchanged `initial_state`. `config.stream` and `config.recursion_limit` are not
    /// consulted.
    pub fn stream(&self, initial_state: S) -> impl Stream<Item = StreamEvent<S>>
    where
        S: Clone + 'static,
    {
        let (tx, rx) = mpsc::channel(100);
        // The spawned task must be 'static, so copy what it needs out of the borrowed graph.
        let graph = CompiledGraphSnapshot {
            entry_point: self.graph.entry_point.clone(),
            node_ids: self.graph.node_ids().map(|s| s.to_string()).collect(),
        };
        tokio::spawn(async move {
            // A send error only means the receiver was dropped; nobody is listening any more.
            let _ = tx
                .send(StreamEvent::NodeStart {
                    node_id: graph.entry_point,
                    state: initial_state.clone(),
                })
                .await;
            let _ = tx
                .send(StreamEvent::Complete {
                    final_state: initial_state,
                })
                .await;
        });
        ReceiverStream::new(rx)
    }

    /// Borrows the graph this runtime executes.
    pub fn graph(&self) -> &CompiledGraph<S> {
        &self.graph
    }

    /// Borrows the configuration this runtime was created with.
    pub fn config(&self) -> &RuntimeConfig {
        &self.config
    }
}
/// Owned copy of the graph data that `Runtime::stream` moves into its spawned task.
///
/// The task must be `'static`, so it cannot borrow the runtime's `CompiledGraph`.
struct CompiledGraphSnapshot {
    /// Id of the node the run starts from (copied from `CompiledGraph::entry_point`).
    entry_point: String,
    /// Ids yielded by `CompiledGraph::node_ids`, converted to owned strings.
    /// NOTE(review): never read in this file.
    node_ids: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::NodeError;
    use crate::graph::GraphBuilder;
    use crate::node::Node;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Minimal state for these tests: a counter that nodes increment.
    #[derive(Clone, Debug, Default, Serialize, Deserialize)]
    struct TestState {
        counter: u32,
    }

    impl State for TestState {
        fn schema() -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    /// Adds one to the counter and continues along the node's outgoing edge.
    struct IncrementNode;

    #[async_trait]
    impl Node<TestState> for IncrementNode {
        fn id(&self) -> &str {
            "increment"
        }

        async fn run(&self, mut state: TestState) -> Result<NodeResult<TestState>, NodeError> {
            state.counter += 1;
            Ok(NodeResult::Continue(state))
        }
    }

    /// Ends the run, returning the state unchanged.
    struct EndNode;

    #[async_trait]
    impl Node<TestState> for EndNode {
        fn id(&self) -> &str {
            "end"
        }

        async fn run(&self, state: TestState) -> Result<NodeResult<TestState>, NodeError> {
            Ok(NodeResult::End(state))
        }
    }

    #[tokio::test]
    async fn test_simple_execution() {
        let graph = GraphBuilder::<TestState>::new()
            .add_node(IncrementNode)
            .add_node(EndNode)
            .set_entry_point("increment")
            .add_edge("increment", "end")
            .compile()
            .unwrap();
        let runtime = Runtime::with_defaults(graph);
        let state = runtime
            .invoke(TestState::default())
            .await
            .expect("increment -> end should complete");
        assert_eq!(state.counter, 1);
    }

    #[tokio::test]
    async fn test_recursion_limit() {
        // Increments and continues; wired to itself below so the run never terminates.
        struct LoopNode;

        #[async_trait]
        impl Node<TestState> for LoopNode {
            fn id(&self) -> &str {
                "loop"
            }

            async fn run(&self, mut state: TestState) -> Result<NodeResult<TestState>, NodeError> {
                state.counter += 1;
                Ok(NodeResult::Continue(state))
            }
        }

        let graph = GraphBuilder::<TestState>::new()
            .add_node(LoopNode)
            .set_entry_point("loop")
            .add_edge("loop", "loop")
            .compile()
            .unwrap();
        let runtime = Runtime::new(graph, RuntimeConfig::default().recursion_limit(5));
        let result = runtime.invoke(TestState::default()).await;
        assert!(matches!(result, Err(RuntimeError::RecursionLimit(5))));
    }

    #[test]
    fn test_runtime_config() {
        let config = RuntimeConfig::new()
            .recursion_limit(50)
            .thread_id("thread-123")
            .tag("test")
            .tag("example")
            .streaming(false);
        assert_eq!(config.recursion_limit, 50);
        assert_eq!(config.thread_id, Some("thread-123".to_string()));
        assert_eq!(config.tags.len(), 2);
        assert!(config.tags.contains("test"));
        assert!(config.tags.contains("example"));
        assert!(!config.stream);
    }

    #[test]
    fn test_runtime_config_defaults() {
        let config = RuntimeConfig::default();
        assert!(config.stream);
        assert_eq!(config.recursion_limit, 25);
        assert_eq!(config.thread_id, None);
        assert!(config.tags.is_empty());
    }
}