use crate::callable::Callable;
use crate::graph::{Checkpoint, CheckpointStore, CompiledGraph, NodeState};
use crate::kernel::{ExecutionError, ExecutionId, StepId, StepType};
use crate::streaming::{EventEmitter, StreamEvent};
use std::sync::Arc;
use std::time::Instant;
use tokio_util::sync::CancellationToken;
/// Drives one execution, either a single [`Callable`] or a [`CompiledGraph`].
/// It publishes lifecycle events through an [`EventEmitter`] and persists
/// checkpoints through the store `S`.
pub struct Runner<S: CheckpointStore> {
    /// Identifier attached to every event and checkpoint of this execution.
    execution_id: ExecutionId,
    /// Triggered by `cancel()` and checked by the run methods.
    cancellation_token: CancellationToken,
    /// Backing store used by `save_checkpoint` / `load_checkpoint`.
    checkpoint_store: Arc<S>,
    /// Collects the stream events produced by this runner.
    emitter: EventEmitter,
    /// Set by `pause()` and cleared by `resume()`.
    paused: std::sync::atomic::AtomicBool,
    /// Moment the most recent run started; `None` until a run begins.
    start_time: Option<Instant>,
}
impl<S: CheckpointStore> Runner<S> {
pub fn new(checkpoint_store: Arc<S>) -> Self {
Self {
execution_id: ExecutionId::new(),
cancellation_token: CancellationToken::new(),
checkpoint_store,
emitter: EventEmitter::new(),
paused: std::sync::atomic::AtomicBool::new(false),
start_time: None,
}
}
pub fn execution_id(&self) -> &ExecutionId {
&self.execution_id
}
pub fn emitter(&self) -> &EventEmitter {
&self.emitter
}
pub fn cancel(&self) {
self.cancellation_token.cancel();
self.emitter.emit(StreamEvent::execution_cancelled(
&self.execution_id,
"Run cancelled by user",
));
}
pub fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
pub async fn pause(&self) -> anyhow::Result<()> {
self.paused.store(true, std::sync::atomic::Ordering::SeqCst);
self.emitter.emit(StreamEvent::execution_paused(
&self.execution_id,
"Paused by user",
));
Ok(())
}
pub fn resume(&self) {
self.paused
.store(false, std::sync::atomic::Ordering::SeqCst);
self.emitter
.emit(StreamEvent::execution_resumed(&self.execution_id));
}
pub fn is_paused(&self) -> bool {
self.paused.load(std::sync::atomic::Ordering::SeqCst)
}
pub async fn save_checkpoint(
&self,
state: NodeState,
node: Option<&str>,
agent_name: Option<&str>,
) -> anyhow::Result<Checkpoint> {
let mut checkpoint = Checkpoint::new(self.execution_id.clone()).with_state(state.data);
if let Some(n) = node {
checkpoint = checkpoint.with_node(n);
}
if let Some(name) = agent_name {
checkpoint = checkpoint.with_agent_name(name);
}
self.checkpoint_store.save(checkpoint.clone()).await?;
Ok(checkpoint)
}
pub async fn load_checkpoint(&self) -> anyhow::Result<Option<Checkpoint>> {
self.checkpoint_store
.load_latest(self.execution_id.as_str())
.await
}
pub async fn run_callable<A: Callable + ?Sized>(
&mut self,
callable: &A,
input: &str,
) -> anyhow::Result<String> {
self.start_time = Some(Instant::now());
self.emitter
.emit(StreamEvent::execution_start(&self.execution_id));
if self.is_cancelled() {
anyhow::bail!("Run cancelled");
}
let result = callable.run(input).await;
let duration_ms = self
.start_time
.map(|t| t.elapsed().as_millis() as u64)
.unwrap_or(0);
match &result {
Ok(output) => {
self.emitter.emit(StreamEvent::execution_end(
&self.execution_id,
Some(output.clone()),
duration_ms,
));
}
Err(e) => {
let error = ExecutionError::kernel_internal(e.to_string());
self.emitter
.emit(StreamEvent::execution_failed(&self.execution_id, error));
}
}
result
}
pub async fn run_graph(
&mut self,
graph: &CompiledGraph,
input: &str,
) -> anyhow::Result<NodeState> {
self.start_time = Some(Instant::now());
self.emitter
.emit(StreamEvent::execution_start(&self.execution_id));
let mut state = NodeState::from_string(input);
let mut current_node = graph.entry_point().to_string();
loop {
if self.is_cancelled() {
anyhow::bail!("Run cancelled");
}
while self.is_paused() {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if self.is_cancelled() {
anyhow::bail!("Run cancelled while paused");
}
}
let node = graph
.get_node(¤t_node)
.ok_or_else(|| anyhow::anyhow!("Node '{}' not found", current_node))?;
let step_id = StepId::new();
let step_start = Instant::now();
self.emitter.emit(StreamEvent::step_start(
&self.execution_id,
&step_id,
StepType::FunctionNode, current_node.clone(),
));
state = node.execute(state).await?;
let step_duration = step_start.elapsed().as_millis() as u64;
self.emitter.emit(StreamEvent::step_end(
&self.execution_id,
&step_id,
Some(state.as_str().unwrap_or_default().to_string()),
step_duration,
));
let output = state.as_str().unwrap_or_default();
let next = graph.get_next(¤t_node, output);
if next.is_empty() {
break;
}
match &next[0] {
crate::graph::EdgeTarget::End => break,
crate::graph::EdgeTarget::Node(n) => {
current_node = n.clone();
}
}
}
let duration_ms = self
.start_time
.map(|t| t.elapsed().as_millis() as u64)
.unwrap_or(0);
self.emitter.emit(StreamEvent::execution_end(
&self.execution_id,
Some(state.as_str().unwrap_or_default().to_string()),
duration_ms,
));
Ok(state)
}
}
pub type DefaultRunner = Runner<crate::graph::InMemoryCheckpointStore>;
impl DefaultRunner {
pub fn default_new() -> Self {
Self::new(Arc::new(crate::graph::InMemoryCheckpointStore::new()))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::InMemoryCheckpointStore;
    use async_trait::async_trait;

    /// Test double for `Callable`. It either answers `"<response>:<input>"` or
    /// fails with a fixed message, optionally after a delay.
    struct MockCallable {
        name: String,
        response: Result<String, String>,
        delay_ms: Option<u64>,
    }

    impl MockCallable {
        /// Shared constructor used by the success/failure shorthands.
        fn with_response(name: &str, response: Result<String, String>) -> Self {
            Self {
                name: name.to_string(),
                response,
                delay_ms: None,
            }
        }

        fn success(name: &str, response: &str) -> Self {
            Self::with_response(name, Ok(response.to_string()))
        }

        fn failing(name: &str, error: &str) -> Self {
            Self::with_response(name, Err(error.to_string()))
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            if let Some(wait_ms) = self.delay_ms {
                tokio::time::sleep(tokio::time::Duration::from_millis(wait_ms)).await;
            }
            let prefix = match &self.response {
                Ok(prefix) => prefix,
                Err(message) => anyhow::bail!("{}", message),
            };
            Ok(format!("{}:{}", prefix, input))
        }
    }

    #[test]
    fn test_runner_new() {
        let runner = Runner::new(Arc::new(InMemoryCheckpointStore::new()));
        assert!(!runner.execution_id().as_str().is_empty());
        assert!(!runner.is_cancelled());
        assert!(!runner.is_paused());
    }

    #[test]
    fn test_default_runner_new() {
        let runner = DefaultRunner::default_new();
        let id = runner.execution_id().as_str();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_runner_execution_id_unique() {
        let shared_store = Arc::new(InMemoryCheckpointStore::new());
        let first = Runner::new(Arc::clone(&shared_store));
        let second = Runner::new(shared_store);
        assert_ne!(first.execution_id().as_str(), second.execution_id().as_str());
    }

    #[test]
    fn test_runner_cancel() {
        let runner = DefaultRunner::default_new();
        assert!(!runner.is_cancelled());
        runner.cancel();
        assert!(runner.is_cancelled());
    }

    #[tokio::test]
    async fn test_runner_callable_checks_cancellation_before_run() {
        let mut runner = DefaultRunner::default_new();
        runner.cancel();
        let callable = MockCallable::success("test", "response");
        let err = runner
            .run_callable(&callable, "input")
            .await
            .expect_err("a cancelled runner must refuse to run");
        assert!(err.to_string().contains("cancelled"));
    }

    #[tokio::test]
    async fn test_runner_pause_resume() {
        let runner = DefaultRunner::default_new();
        assert!(!runner.is_paused());
        runner.pause().await.unwrap();
        assert!(runner.is_paused());
        runner.resume();
        assert!(!runner.is_paused());
    }

    #[tokio::test]
    async fn test_run_callable_success() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::success("test", "hello");
        let output = runner
            .run_callable(&callable, "world")
            .await
            .expect("callable should succeed");
        assert_eq!(output, "hello:world");
    }

    #[tokio::test]
    async fn test_run_callable_failure() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::failing("test", "Something went wrong");
        let err = runner.run_callable(&callable, "input").await.unwrap_err();
        assert!(err.to_string().contains("Something went wrong"));
    }

    #[tokio::test]
    async fn test_run_callable_emits_events() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::success("test", "response");
        runner.run_callable(&callable, "input").await.unwrap();
        let events = runner.emitter().drain();
        assert!(events.len() >= 2);
        assert!(matches!(
            events.first(),
            Some(StreamEvent::ExecutionStart { .. })
        ));
        assert!(matches!(
            events.last(),
            Some(StreamEvent::ExecutionEnd { .. })
        ));
    }

    #[tokio::test]
    async fn test_run_callable_failure_emits_failed_event() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::failing("test", "error message");
        let _ = runner.run_callable(&callable, "input").await;
        let events = runner.emitter().drain();
        assert!(events.len() >= 2);
        assert!(matches!(
            events.last(),
            Some(StreamEvent::ExecutionFailed { .. })
        ));
    }

    #[tokio::test]
    async fn test_runner_save_and_load_checkpoint() {
        let runner = DefaultRunner::default_new();
        let saved = runner
            .save_checkpoint(
                NodeState::from_string("test state data"),
                Some("node1"),
                Some("test_agent"),
            )
            .await
            .unwrap();
        assert_eq!(saved.current_node.as_ref().unwrap(), "node1");
        let restored = runner
            .load_checkpoint()
            .await
            .unwrap()
            .expect("a checkpoint was just saved");
        assert_eq!(
            restored.state,
            serde_json::Value::String("test state data".to_string())
        );
    }

    #[tokio::test]
    async fn test_runner_checkpoint_without_node() {
        let runner = DefaultRunner::default_new();
        let saved = runner
            .save_checkpoint(NodeState::from_string("some data"), None, None)
            .await
            .unwrap();
        assert!(saved.current_node.is_none());
        assert!(saved.agent_name().is_none());
    }

    #[tokio::test]
    async fn test_runner_checkpoint_with_agent_name() {
        let runner = DefaultRunner::default_new();
        let saved = runner
            .save_checkpoint(
                NodeState::from_string("agent state"),
                Some("planning_node"),
                Some("planner"),
            )
            .await
            .unwrap();
        assert_eq!(saved.current_node.as_ref().unwrap(), "planning_node");
        assert_eq!(saved.agent_name(), Some("planner"));
        let restored = runner.load_checkpoint().await.unwrap().unwrap();
        assert_eq!(restored.agent_name(), Some("planner"));
    }

    #[tokio::test]
    async fn test_runner_load_checkpoint_no_data() {
        let runner = DefaultRunner::default_new();
        let nothing = runner.load_checkpoint().await.unwrap();
        assert!(nothing.is_none());
    }

    #[test]
    fn test_runner_emitter_access() {
        let runner = DefaultRunner::default_new();
        let sink = runner.emitter();
        sink.emit(StreamEvent::execution_start(runner.execution_id()));
        assert_eq!(sink.drain().len(), 1);
    }

    #[test]
    fn test_emitter_mode() {
        use crate::streaming::StreamMode;
        let runner = DefaultRunner::default_new();
        assert_eq!(runner.emitter().mode(), StreamMode::Full);
    }
}