Skip to main content

enact_core/runner/
execution_runner.rs

1//! Runner - executes callables and graphs with flow features
2
3use crate::callable::Callable;
4use crate::graph::{Checkpoint, CheckpointStore, CompiledGraph, NodeState};
5use crate::kernel::{ExecutionError, ExecutionId, StepId, StepType};
6use crate::streaming::{EventEmitter, StreamEvent};
7use std::sync::Arc;
8use std::time::Instant;
9use tokio_util::sync::CancellationToken;
10
11/// Runner - executes agents/graphs with abort, pause, and event streaming
12pub struct Runner<S: CheckpointStore> {
13    execution_id: ExecutionId,
14    cancellation_token: CancellationToken,
15    checkpoint_store: Arc<S>,
16    emitter: EventEmitter,
17    paused: std::sync::atomic::AtomicBool,
18    start_time: Option<Instant>,
19}
20
21impl<S: CheckpointStore> Runner<S> {
22    /// Create a new runner
23    pub fn new(checkpoint_store: Arc<S>) -> Self {
24        Self {
25            execution_id: ExecutionId::new(),
26            cancellation_token: CancellationToken::new(),
27            checkpoint_store,
28            emitter: EventEmitter::new(),
29            paused: std::sync::atomic::AtomicBool::new(false),
30            start_time: None,
31        }
32    }
33
34    /// Get the execution ID
35    pub fn execution_id(&self) -> &ExecutionId {
36        &self.execution_id
37    }
38
39    /// Get the event emitter
40    pub fn emitter(&self) -> &EventEmitter {
41        &self.emitter
42    }
43
44    /// Cancel the run (abort signal)
45    pub fn cancel(&self) {
46        self.cancellation_token.cancel();
47        self.emitter.emit(StreamEvent::execution_cancelled(
48            &self.execution_id,
49            "Run cancelled by user",
50        ));
51    }
52
53    /// Check if cancelled
54    pub fn is_cancelled(&self) -> bool {
55        self.cancellation_token.is_cancelled()
56    }
57
58    /// Pause the run
59    pub async fn pause(&self) -> anyhow::Result<()> {
60        self.paused.store(true, std::sync::atomic::Ordering::SeqCst);
61        self.emitter.emit(StreamEvent::execution_paused(
62            &self.execution_id,
63            "Paused by user",
64        ));
65        Ok(())
66    }
67
68    /// Resume the run
69    pub fn resume(&self) {
70        self.paused
71            .store(false, std::sync::atomic::Ordering::SeqCst);
72        self.emitter
73            .emit(StreamEvent::execution_resumed(&self.execution_id));
74    }
75
76    /// Check if paused
77    pub fn is_paused(&self) -> bool {
78        self.paused.load(std::sync::atomic::Ordering::SeqCst)
79    }
80
81    /// Save checkpoint
82    ///
83    /// # Arguments
84    /// * `state` - The current node state to save
85    /// * `node` - Optional current node name
86    /// * `agent_name` - Optional agent name for restoration on resume
87    pub async fn save_checkpoint(
88        &self,
89        state: NodeState,
90        node: Option<&str>,
91        agent_name: Option<&str>,
92    ) -> anyhow::Result<Checkpoint> {
93        let mut checkpoint = Checkpoint::new(self.execution_id.clone()).with_state(state.data);
94
95        if let Some(n) = node {
96            checkpoint = checkpoint.with_node(n);
97        }
98
99        if let Some(name) = agent_name {
100            checkpoint = checkpoint.with_agent_name(name);
101        }
102
103        self.checkpoint_store.save(checkpoint.clone()).await?;
104        Ok(checkpoint)
105    }
106
107    /// Load latest checkpoint
108    pub async fn load_checkpoint(&self) -> anyhow::Result<Option<Checkpoint>> {
109        self.checkpoint_store
110            .load_latest(self.execution_id.as_str())
111            .await
112    }
113
114    /// Run a callable with event streaming
115    pub async fn run_callable<A: Callable + ?Sized>(
116        &mut self,
117        callable: &A,
118        input: &str,
119    ) -> anyhow::Result<String> {
120        self.start_time = Some(Instant::now());
121        self.emitter
122            .emit(StreamEvent::execution_start(&self.execution_id));
123
124        // Check for cancellation before running
125        if self.is_cancelled() {
126            anyhow::bail!("Run cancelled");
127        }
128
129        let result = callable.run(input).await;
130        let duration_ms = self
131            .start_time
132            .map(|t| t.elapsed().as_millis() as u64)
133            .unwrap_or(0);
134
135        match &result {
136            Ok(output) => {
137                self.emitter.emit(StreamEvent::execution_end(
138                    &self.execution_id,
139                    Some(output.clone()),
140                    duration_ms,
141                ));
142            }
143            Err(e) => {
144                let error = ExecutionError::kernel_internal(e.to_string());
145                self.emitter
146                    .emit(StreamEvent::execution_failed(&self.execution_id, error));
147            }
148        }
149
150        result
151    }
152
153    /// Run a compiled graph with event streaming
154    pub async fn run_graph(
155        &mut self,
156        graph: &CompiledGraph,
157        input: &str,
158    ) -> anyhow::Result<NodeState> {
159        self.start_time = Some(Instant::now());
160        self.emitter
161            .emit(StreamEvent::execution_start(&self.execution_id));
162
163        let mut state = NodeState::from_string(input);
164        let mut current_node = graph.entry_point().to_string();
165
166        loop {
167            // Check for cancellation
168            if self.is_cancelled() {
169                anyhow::bail!("Run cancelled");
170            }
171
172            // Check for pause - wait until resumed
173            while self.is_paused() {
174                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
175                if self.is_cancelled() {
176                    anyhow::bail!("Run cancelled while paused");
177                }
178            }
179
180            // Get the node
181            let node = graph
182                .get_node(&current_node)
183                .ok_or_else(|| anyhow::anyhow!("Node '{}' not found", current_node))?;
184
185            // Create step ID for this node execution
186            let step_id = StepId::new();
187            let step_start = Instant::now();
188
189            // Emit step started
190            self.emitter.emit(StreamEvent::step_start(
191                &self.execution_id,
192                &step_id,
193                StepType::FunctionNode, // Graph nodes are function nodes by default
194                current_node.clone(),
195            ));
196
197            // Execute node
198            state = node.execute(state).await?;
199
200            // Emit step completed
201            let step_duration = step_start.elapsed().as_millis() as u64;
202            self.emitter.emit(StreamEvent::step_end(
203                &self.execution_id,
204                &step_id,
205                Some(state.as_str().unwrap_or_default().to_string()),
206                step_duration,
207            ));
208
209            // Get next nodes
210            let output = state.as_str().unwrap_or_default();
211            let next = graph.get_next(&current_node, output);
212
213            if next.is_empty() {
214                break;
215            }
216
217            match &next[0] {
218                crate::graph::EdgeTarget::End => break,
219                crate::graph::EdgeTarget::Node(n) => {
220                    current_node = n.clone();
221                }
222            }
223        }
224
225        let duration_ms = self
226            .start_time
227            .map(|t| t.elapsed().as_millis() as u64)
228            .unwrap_or(0);
229        self.emitter.emit(StreamEvent::execution_end(
230            &self.execution_id,
231            Some(state.as_str().unwrap_or_default().to_string()),
232            duration_ms,
233        ));
234
235        Ok(state)
236    }
237}
238
239/// Runner with in-memory checkpoint store (default)
240pub type DefaultRunner = Runner<crate::graph::InMemoryCheckpointStore>;
241
242impl DefaultRunner {
243    /// Create a new runner with in-memory checkpoint store
244    pub fn default_new() -> Self {
245        Self::new(Arc::new(crate::graph::InMemoryCheckpointStore::new()))
246    }
247}
248
249#[cfg(test)]
250mod tests {
251    use super::*;
252    use crate::graph::InMemoryCheckpointStore;
253    use async_trait::async_trait;
254
255    /// Mock callable for testing
256    struct MockCallable {
257        name: String,
258        response: Result<String, String>,
259        delay_ms: Option<u64>,
260    }
261
262    impl MockCallable {
263        fn success(name: &str, response: &str) -> Self {
264            Self {
265                name: name.to_string(),
266                response: Ok(response.to_string()),
267                delay_ms: None,
268            }
269        }
270
271        fn failing(name: &str, error: &str) -> Self {
272            Self {
273                name: name.to_string(),
274                response: Err(error.to_string()),
275                delay_ms: None,
276            }
277        }
278    }
279
280    #[async_trait]
281    impl Callable for MockCallable {
282        fn name(&self) -> &str {
283            &self.name
284        }
285
286        async fn run(&self, input: &str) -> anyhow::Result<String> {
287            if let Some(delay) = self.delay_ms {
288                tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
289            }
290            match &self.response {
291                Ok(r) => Ok(format!("{}:{}", r, input)),
292                Err(e) => anyhow::bail!("{}", e),
293            }
294        }
295    }
296
297    // ============ Construction Tests ============
298
299    #[test]
300    fn test_runner_new() {
301        let store = Arc::new(InMemoryCheckpointStore::new());
302        let runner = Runner::new(store);
303
304        // Should have a valid execution ID
305        assert!(!runner.execution_id().as_str().is_empty());
306        // Should not be cancelled initially
307        assert!(!runner.is_cancelled());
308        // Should not be paused initially
309        assert!(!runner.is_paused());
310    }
311
312    #[test]
313    fn test_default_runner_new() {
314        let runner = DefaultRunner::default_new();
315        assert!(!runner.execution_id().as_str().is_empty());
316    }
317
318    #[test]
319    fn test_runner_execution_id_unique() {
320        let store = Arc::new(InMemoryCheckpointStore::new());
321        let runner1 = Runner::new(store.clone());
322        let runner2 = Runner::new(store);
323
324        // Each runner should have unique execution ID
325        assert_ne!(
326            runner1.execution_id().as_str(),
327            runner2.execution_id().as_str()
328        );
329    }
330
331    // ============ Cancellation Tests ============
332
333    #[test]
334    fn test_runner_cancel() {
335        let runner = DefaultRunner::default_new();
336
337        assert!(!runner.is_cancelled());
338        runner.cancel();
339        assert!(runner.is_cancelled());
340    }
341
342    #[tokio::test]
343    async fn test_runner_callable_checks_cancellation_before_run() {
344        let mut runner = DefaultRunner::default_new();
345        let callable = MockCallable::success("test", "response");
346
347        // Cancel before running
348        runner.cancel();
349
350        let result = runner.run_callable(&callable, "input").await;
351        assert!(result.is_err());
352        assert!(result.unwrap_err().to_string().contains("cancelled"));
353    }
354
355    // ============ Pause/Resume Tests ============
356
357    #[tokio::test]
358    async fn test_runner_pause_resume() {
359        let runner = DefaultRunner::default_new();
360
361        assert!(!runner.is_paused());
362
363        runner.pause().await.unwrap();
364        assert!(runner.is_paused());
365
366        runner.resume();
367        assert!(!runner.is_paused());
368    }
369
370    // ============ Run Callable Tests ============
371
372    #[tokio::test]
373    async fn test_run_callable_success() {
374        let mut runner = DefaultRunner::default_new();
375        let callable = MockCallable::success("test", "hello");
376
377        let result = runner.run_callable(&callable, "world").await;
378        assert!(result.is_ok());
379        assert_eq!(result.unwrap(), "hello:world");
380    }
381
382    #[tokio::test]
383    async fn test_run_callable_failure() {
384        let mut runner = DefaultRunner::default_new();
385        let callable = MockCallable::failing("test", "Something went wrong");
386
387        let result = runner.run_callable(&callable, "input").await;
388        assert!(result.is_err());
389        assert!(result
390            .unwrap_err()
391            .to_string()
392            .contains("Something went wrong"));
393    }
394
395    #[tokio::test]
396    async fn test_run_callable_emits_events() {
397        let mut runner = DefaultRunner::default_new();
398        let callable = MockCallable::success("test", "response");
399
400        runner.run_callable(&callable, "input").await.unwrap();
401
402        // Drain and check collected events
403        let events = runner.emitter().drain();
404
405        // Should have execution start and end events
406        assert!(events.len() >= 2);
407
408        // First event should be execution start
409        let first = &events[0];
410        assert!(matches!(first, StreamEvent::ExecutionStart { .. }));
411
412        // Last event should be execution end
413        let last = &events[events.len() - 1];
414        assert!(matches!(last, StreamEvent::ExecutionEnd { .. }));
415    }
416
417    #[tokio::test]
418    async fn test_run_callable_failure_emits_failed_event() {
419        let mut runner = DefaultRunner::default_new();
420        let callable = MockCallable::failing("test", "error message");
421
422        let _ = runner.run_callable(&callable, "input").await;
423
424        // Drain and check collected events
425        let events = runner.emitter().drain();
426
427        // Should have execution start and failed events
428        assert!(events.len() >= 2);
429
430        // Last event should be execution failed
431        let last = &events[events.len() - 1];
432        assert!(matches!(last, StreamEvent::ExecutionFailed { .. }));
433    }
434
435    // ============ Checkpoint Tests ============
436
437    #[tokio::test]
438    async fn test_runner_save_and_load_checkpoint() {
439        let runner = DefaultRunner::default_new();
440
441        // Save a checkpoint
442        let state = NodeState::from_string("test state data");
443        let checkpoint = runner
444            .save_checkpoint(state, Some("node1"), Some("test_agent"))
445            .await
446            .unwrap();
447
448        assert_eq!(checkpoint.current_node.as_ref().unwrap(), "node1");
449
450        // Load the checkpoint
451        let loaded = runner.load_checkpoint().await.unwrap();
452        assert!(loaded.is_some());
453
454        let loaded = loaded.unwrap();
455        assert_eq!(
456            loaded.state,
457            serde_json::Value::String("test state data".to_string())
458        );
459    }
460
461    #[tokio::test]
462    async fn test_runner_checkpoint_without_node() {
463        let runner = DefaultRunner::default_new();
464
465        let state = NodeState::from_string("some data");
466        let checkpoint = runner.save_checkpoint(state, None, None).await.unwrap();
467
468        assert!(checkpoint.current_node.is_none());
469        assert!(checkpoint.agent_name().is_none());
470    }
471
472    #[tokio::test]
473    async fn test_runner_checkpoint_with_agent_name() {
474        let runner = DefaultRunner::default_new();
475
476        let state = NodeState::from_string("agent state");
477        let checkpoint = runner
478            .save_checkpoint(state, Some("planning_node"), Some("planner"))
479            .await
480            .unwrap();
481
482        assert_eq!(checkpoint.current_node.as_ref().unwrap(), "planning_node");
483        assert_eq!(checkpoint.agent_name(), Some("planner"));
484
485        // Load and verify agent_name persists
486        let loaded = runner.load_checkpoint().await.unwrap().unwrap();
487        assert_eq!(loaded.agent_name(), Some("planner"));
488    }
489
490    #[tokio::test]
491    async fn test_runner_load_checkpoint_no_data() {
492        let runner = DefaultRunner::default_new();
493
494        // Without saving, load should return None
495        let loaded = runner.load_checkpoint().await.unwrap();
496        assert!(loaded.is_none());
497    }
498
499    // ============ Emitter Tests ============
500
501    #[test]
502    fn test_runner_emitter_access() {
503        let runner = DefaultRunner::default_new();
504        let emitter = runner.emitter();
505
506        // Should be able to emit and drain events
507        emitter.emit(StreamEvent::execution_start(runner.execution_id()));
508        let events = emitter.drain();
509        assert_eq!(events.len(), 1);
510    }
511
512    #[test]
513    fn test_emitter_mode() {
514        use crate::streaming::StreamMode;
515
516        let runner = DefaultRunner::default_new();
517        let emitter = runner.emitter();
518
519        // Default mode should be Full
520        assert_eq!(emitter.mode(), StreamMode::Full);
521    }
522}