1use crate::callable::Callable;
4use crate::graph::{Checkpoint, CheckpointStore, CompiledGraph, NodeState};
5use crate::kernel::{ExecutionError, ExecutionId, StepId, StepType};
6use crate::streaming::{EventEmitter, StreamEvent};
7use std::sync::Arc;
8use std::time::Instant;
9use tokio_util::sync::CancellationToken;
10
11pub struct Runner<S: CheckpointStore> {
13 execution_id: ExecutionId,
14 cancellation_token: CancellationToken,
15 checkpoint_store: Arc<S>,
16 emitter: EventEmitter,
17 paused: std::sync::atomic::AtomicBool,
18 start_time: Option<Instant>,
19}
20
21impl<S: CheckpointStore> Runner<S> {
22 pub fn new(checkpoint_store: Arc<S>) -> Self {
24 Self {
25 execution_id: ExecutionId::new(),
26 cancellation_token: CancellationToken::new(),
27 checkpoint_store,
28 emitter: EventEmitter::new(),
29 paused: std::sync::atomic::AtomicBool::new(false),
30 start_time: None,
31 }
32 }
33
34 pub fn execution_id(&self) -> &ExecutionId {
36 &self.execution_id
37 }
38
39 pub fn emitter(&self) -> &EventEmitter {
41 &self.emitter
42 }
43
44 pub fn cancel(&self) {
46 self.cancellation_token.cancel();
47 self.emitter.emit(StreamEvent::execution_cancelled(
48 &self.execution_id,
49 "Run cancelled by user",
50 ));
51 }
52
53 pub fn is_cancelled(&self) -> bool {
55 self.cancellation_token.is_cancelled()
56 }
57
58 pub async fn pause(&self) -> anyhow::Result<()> {
60 self.paused.store(true, std::sync::atomic::Ordering::SeqCst);
61 self.emitter.emit(StreamEvent::execution_paused(
62 &self.execution_id,
63 "Paused by user",
64 ));
65 Ok(())
66 }
67
68 pub fn resume(&self) {
70 self.paused
71 .store(false, std::sync::atomic::Ordering::SeqCst);
72 self.emitter
73 .emit(StreamEvent::execution_resumed(&self.execution_id));
74 }
75
76 pub fn is_paused(&self) -> bool {
78 self.paused.load(std::sync::atomic::Ordering::SeqCst)
79 }
80
81 pub async fn save_checkpoint(
88 &self,
89 state: NodeState,
90 node: Option<&str>,
91 agent_name: Option<&str>,
92 ) -> anyhow::Result<Checkpoint> {
93 let mut checkpoint = Checkpoint::new(self.execution_id.clone()).with_state(state.data);
94
95 if let Some(n) = node {
96 checkpoint = checkpoint.with_node(n);
97 }
98
99 if let Some(name) = agent_name {
100 checkpoint = checkpoint.with_agent_name(name);
101 }
102
103 self.checkpoint_store.save(checkpoint.clone()).await?;
104 Ok(checkpoint)
105 }
106
107 pub async fn load_checkpoint(&self) -> anyhow::Result<Option<Checkpoint>> {
109 self.checkpoint_store
110 .load_latest(self.execution_id.as_str())
111 .await
112 }
113
114 pub async fn run_callable<A: Callable + ?Sized>(
116 &mut self,
117 callable: &A,
118 input: &str,
119 ) -> anyhow::Result<String> {
120 self.start_time = Some(Instant::now());
121 self.emitter
122 .emit(StreamEvent::execution_start(&self.execution_id));
123
124 if self.is_cancelled() {
126 anyhow::bail!("Run cancelled");
127 }
128
129 let result = callable.run(input).await;
130 let duration_ms = self
131 .start_time
132 .map(|t| t.elapsed().as_millis() as u64)
133 .unwrap_or(0);
134
135 match &result {
136 Ok(output) => {
137 self.emitter.emit(StreamEvent::execution_end(
138 &self.execution_id,
139 Some(output.clone()),
140 duration_ms,
141 ));
142 }
143 Err(e) => {
144 let error = ExecutionError::kernel_internal(e.to_string());
145 self.emitter
146 .emit(StreamEvent::execution_failed(&self.execution_id, error));
147 }
148 }
149
150 result
151 }
152
153 pub async fn run_graph(
155 &mut self,
156 graph: &CompiledGraph,
157 input: &str,
158 ) -> anyhow::Result<NodeState> {
159 self.start_time = Some(Instant::now());
160 self.emitter
161 .emit(StreamEvent::execution_start(&self.execution_id));
162
163 let mut state = NodeState::from_string(input);
164 let mut current_node = graph.entry_point().to_string();
165
166 loop {
167 if self.is_cancelled() {
169 anyhow::bail!("Run cancelled");
170 }
171
172 while self.is_paused() {
174 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
175 if self.is_cancelled() {
176 anyhow::bail!("Run cancelled while paused");
177 }
178 }
179
180 let node = graph
182 .get_node(¤t_node)
183 .ok_or_else(|| anyhow::anyhow!("Node '{}' not found", current_node))?;
184
185 let step_id = StepId::new();
187 let step_start = Instant::now();
188
189 self.emitter.emit(StreamEvent::step_start(
191 &self.execution_id,
192 &step_id,
193 StepType::FunctionNode, current_node.clone(),
195 ));
196
197 state = node.execute(state).await?;
199
200 let step_duration = step_start.elapsed().as_millis() as u64;
202 self.emitter.emit(StreamEvent::step_end(
203 &self.execution_id,
204 &step_id,
205 Some(state.as_str().unwrap_or_default().to_string()),
206 step_duration,
207 ));
208
209 let output = state.as_str().unwrap_or_default();
211 let next = graph.get_next(¤t_node, output);
212
213 if next.is_empty() {
214 break;
215 }
216
217 match &next[0] {
218 crate::graph::EdgeTarget::End => break,
219 crate::graph::EdgeTarget::Node(n) => {
220 current_node = n.clone();
221 }
222 }
223 }
224
225 let duration_ms = self
226 .start_time
227 .map(|t| t.elapsed().as_millis() as u64)
228 .unwrap_or(0);
229 self.emitter.emit(StreamEvent::execution_end(
230 &self.execution_id,
231 Some(state.as_str().unwrap_or_default().to_string()),
232 duration_ms,
233 ));
234
235 Ok(state)
236 }
237}
238
239pub type DefaultRunner = Runner<crate::graph::InMemoryCheckpointStore>;
241
242impl DefaultRunner {
243 pub fn default_new() -> Self {
245 Self::new(Arc::new(crate::graph::InMemoryCheckpointStore::new()))
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::InMemoryCheckpointStore;
    use async_trait::async_trait;

    /// [`Callable`] test double that answers with a canned response.
    ///
    /// On success it returns `"{response}:{input}"`; on failure it returns an
    /// error carrying the configured message. When `delay_ms` is set it sleeps
    /// for that long before answering.
    struct MockCallable {
        name: String,
        response: Result<String, String>,
        delay_ms: Option<u64>,
    }

    impl MockCallable {
        fn success(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: Ok(response.to_string()),
                delay_ms: None,
            }
        }

        fn failing(name: &str, error: &str) -> Self {
            Self {
                name: name.to_string(),
                response: Err(error.to_string()),
                delay_ms: None,
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            if let Some(delay) = self.delay_ms {
                tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
            }
            match &self.response {
                Ok(r) => Ok(format!("{}:{}", r, input)),
                Err(e) => anyhow::bail!("{}", e),
            }
        }
    }

    #[test]
    fn test_runner_new() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let runner = Runner::new(store);

        assert!(!runner.execution_id().as_str().is_empty());
        assert!(!runner.is_cancelled());
        assert!(!runner.is_paused());
    }

    #[test]
    fn test_default_runner_new() {
        let runner = DefaultRunner::default_new();
        assert!(!runner.execution_id().as_str().is_empty());
    }

    #[test]
    fn test_runner_execution_id_unique() {
        let store = Arc::new(InMemoryCheckpointStore::new());
        let runner1 = Runner::new(store.clone());
        let runner2 = Runner::new(store);

        assert_ne!(
            runner1.execution_id().as_str(),
            runner2.execution_id().as_str()
        );
    }

    #[test]
    fn test_runner_cancel() {
        let runner = DefaultRunner::default_new();

        assert!(!runner.is_cancelled());
        runner.cancel();
        assert!(runner.is_cancelled());
    }

    #[test]
    fn test_runner_cancel_emits_event() {
        let runner = DefaultRunner::default_new();

        runner.cancel();

        assert_eq!(runner.emitter().drain().len(), 1);
    }

    #[tokio::test]
    async fn test_runner_callable_checks_cancellation_before_run() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::success("test", "response");

        runner.cancel();

        let result = runner.run_callable(&callable, "input").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cancelled"));

        // Cancellation is reported by `cancel` itself, not as a failure.
        let events = runner.emitter().drain();
        assert!(!events
            .iter()
            .any(|e| matches!(e, StreamEvent::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_runner_pause_resume() {
        let runner = DefaultRunner::default_new();

        assert!(!runner.is_paused());

        runner.pause().await.unwrap();
        assert!(runner.is_paused());

        runner.resume();
        assert!(!runner.is_paused());
    }

    #[tokio::test]
    async fn test_runner_pause_resume_emit_events() {
        let runner = DefaultRunner::default_new();

        runner.pause().await.unwrap();
        runner.resume();

        assert_eq!(runner.emitter().drain().len(), 2);
    }

    #[tokio::test]
    async fn test_run_callable_success() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::success("test", "hello");

        let result = runner.run_callable(&callable, "world").await;
        assert_eq!(result.unwrap(), "hello:world");
    }

    #[tokio::test]
    async fn test_run_callable_failure() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::failing("test", "Something went wrong");

        let result = runner.run_callable(&callable, "input").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Something went wrong"));
    }

    #[tokio::test]
    async fn test_run_callable_emits_events() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::success("test", "response");

        runner.run_callable(&callable, "input").await.unwrap();

        let events = runner.emitter().drain();
        assert!(events.len() >= 2);

        let first = events.first().expect("at least one event was emitted");
        assert!(matches!(first, StreamEvent::ExecutionStart { .. }));

        let last = events.last().expect("at least one event was emitted");
        assert!(matches!(last, StreamEvent::ExecutionEnd { .. }));
    }

    #[tokio::test]
    async fn test_run_callable_failure_emits_failed_event() {
        let mut runner = DefaultRunner::default_new();
        let callable = MockCallable::failing("test", "error message");

        let _ = runner.run_callable(&callable, "input").await;

        let events = runner.emitter().drain();
        assert!(events.len() >= 2);

        let last = events.last().expect("at least one event was emitted");
        assert!(matches!(last, StreamEvent::ExecutionFailed { .. }));
    }

    #[tokio::test]
    async fn test_runner_save_and_load_checkpoint() {
        let runner = DefaultRunner::default_new();

        let state = NodeState::from_string("test state data");
        let checkpoint = runner
            .save_checkpoint(state, Some("node1"), Some("test_agent"))
            .await
            .unwrap();

        assert_eq!(checkpoint.current_node.as_ref().unwrap(), "node1");

        let loaded = runner.load_checkpoint().await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(
            loaded.state,
            serde_json::Value::String("test state data".to_string())
        );
    }

    #[tokio::test]
    async fn test_runner_checkpoint_without_node() {
        let runner = DefaultRunner::default_new();

        let state = NodeState::from_string("some data");
        let checkpoint = runner.save_checkpoint(state, None, None).await.unwrap();

        assert!(checkpoint.current_node.is_none());
        assert!(checkpoint.agent_name().is_none());
    }

    #[tokio::test]
    async fn test_runner_checkpoint_with_agent_name() {
        let runner = DefaultRunner::default_new();

        let state = NodeState::from_string("agent state");
        let checkpoint = runner
            .save_checkpoint(state, Some("planning_node"), Some("planner"))
            .await
            .unwrap();

        assert_eq!(checkpoint.current_node.as_ref().unwrap(), "planning_node");
        assert_eq!(checkpoint.agent_name(), Some("planner"));

        let loaded = runner.load_checkpoint().await.unwrap().unwrap();
        assert_eq!(loaded.agent_name(), Some("planner"));
    }

    #[tokio::test]
    async fn test_runner_load_checkpoint_no_data() {
        let runner = DefaultRunner::default_new();

        let loaded = runner.load_checkpoint().await.unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn test_runner_emitter_access() {
        let runner = DefaultRunner::default_new();
        let emitter = runner.emitter();

        emitter.emit(StreamEvent::execution_start(runner.execution_id()));
        let events = emitter.drain();
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_emitter_mode() {
        use crate::streaming::StreamMode;

        let runner = DefaultRunner::default_new();
        let emitter = runner.emitter();

        assert_eq!(emitter.mode(), StreamMode::Full);
    }
}