1use crate::{NodeId, WorkflowId};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use uuid::Uuid;
6
7#[cfg(feature = "openapi")]
8use utoipa::ToSchema;
9
10pub type ExecutionId = Uuid;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[cfg_attr(feature = "openapi", derive(ToSchema))]
16pub struct ExecutionContext {
17 #[cfg_attr(feature = "openapi", schema(value_type = String))]
19 pub execution_id: ExecutionId,
20
21 #[cfg_attr(feature = "openapi", schema(value_type = String))]
23 pub workflow_id: WorkflowId,
24
25 pub started_at: DateTime<Utc>,
27
28 pub completed_at: Option<DateTime<Utc>>,
30
31 pub state: ExecutionState,
33
34 #[cfg_attr(feature = "openapi", schema(value_type = HashMap<String, NodeExecutionResult>))]
36 pub node_results: HashMap<NodeId, NodeExecutionResult>,
37
38 #[serde(default)]
40 pub variables: HashMap<String, serde_json::Value>,
41
42 #[serde(default)]
44 pub checkpoint: Option<ExecutionCheckpoint>,
45}
46
47impl ExecutionContext {
48 pub fn new(workflow_id: WorkflowId) -> Self {
49 Self {
50 execution_id: Uuid::new_v4(),
51 workflow_id,
52 started_at: Utc::now(),
53 completed_at: None,
54 state: ExecutionState::Running,
55 node_results: HashMap::new(),
56 variables: HashMap::new(),
57 checkpoint: None,
58 }
59 }
60
61 pub fn create_checkpoint(&mut self) -> ExecutionCheckpoint {
63 let checkpoint = ExecutionCheckpoint {
64 timestamp: Utc::now(),
65 completed_nodes: self.node_results.keys().copied().collect(),
66 variables: self.variables.clone(),
67 state: self.state.clone(),
68 };
69 self.checkpoint = Some(checkpoint.clone());
70 checkpoint
71 }
72
73 pub fn resume_from_checkpoint(
75 checkpoint: ExecutionCheckpoint,
76 workflow_id: WorkflowId,
77 ) -> Self {
78 let variables = checkpoint.variables.clone();
79 let state = checkpoint.state.clone();
80 Self {
81 execution_id: Uuid::new_v4(),
82 workflow_id,
83 started_at: checkpoint.timestamp,
84 completed_at: None,
85 state,
86 node_results: HashMap::new(), variables,
88 checkpoint: Some(checkpoint),
89 }
90 }
91
92 pub fn can_resume(&self) -> bool {
94 self.checkpoint.is_some() && matches!(self.state, ExecutionState::Paused)
95 }
96
97 pub fn pause(&mut self) {
99 self.state = ExecutionState::Paused;
100 self.create_checkpoint();
101 }
102
103 pub fn resume(&mut self) {
105 if self.state == ExecutionState::Paused {
106 self.state = ExecutionState::Running;
107 }
108 }
109
110 pub fn cancel(&mut self) {
112 self.state = ExecutionState::Cancelled;
113 self.mark_completed();
114 }
115
116 pub fn mark_completed(&mut self) {
118 if self.completed_at.is_none() {
119 self.completed_at = Some(Utc::now());
120 }
121 }
122
123 pub fn record_node_result(&mut self, node_id: NodeId, result: NodeExecutionResult) {
125 self.node_results.insert(node_id, result);
126 }
127
128 pub fn get_node_result(&self, node_id: &NodeId) -> Option<&NodeExecutionResult> {
130 self.node_results.get(node_id)
131 }
132
133 pub fn set_variable(&mut self, key: String, value: serde_json::Value) {
135 self.variables.insert(key, value);
136 }
137
138 pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
140 self.variables.get(key)
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
146#[cfg_attr(feature = "openapi", derive(ToSchema))]
147pub enum ExecutionState {
148 Running,
150
151 Completed,
153
154 Failed(String),
156
157 Cancelled,
159
160 Paused,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166#[cfg_attr(feature = "openapi", derive(ToSchema))]
167pub struct NodeExecutionResult {
168 pub started_at: DateTime<Utc>,
170
171 pub completed_at: Option<DateTime<Utc>>,
173
174 pub result: ExecutionResult,
176
177 #[serde(default)]
179 pub retry_count: u32,
180
181 #[serde(default)]
183 pub metrics: Option<NodeMetrics>,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Default)]
188#[cfg_attr(feature = "openapi", derive(ToSchema))]
189pub struct NodeMetrics {
190 pub duration_ms: Option<u64>,
192
193 #[serde(default)]
195 pub token_usage: Option<TokenUsage>,
196
197 #[serde(default)]
199 pub cost_usd: Option<f64>,
200
201 #[serde(default)]
203 pub api_calls: u32,
204
205 #[serde(default)]
207 pub bytes_transferred: u64,
208
209 #[serde(default)]
211 pub memory_bytes: Option<u64>,
212
213 #[serde(default)]
215 pub custom: HashMap<String, serde_json::Value>,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
220#[cfg_attr(feature = "openapi", derive(ToSchema))]
221pub struct TokenUsage {
222 pub input_tokens: u32,
224
225 pub output_tokens: u32,
227
228 pub total_tokens: u32,
230
231 #[serde(default)]
233 pub cached_tokens: Option<u32>,
234}
235
236impl TokenUsage {
237 pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
239 Self {
240 input_tokens,
241 output_tokens,
242 total_tokens: input_tokens + output_tokens,
243 cached_tokens: None,
244 }
245 }
246
247 pub fn estimate_cost(&self, input_price_per_1k: f64, output_price_per_1k: f64) -> f64 {
249 let input_cost = (self.input_tokens as f64 / 1000.0) * input_price_per_1k;
250 let output_cost = (self.output_tokens as f64 / 1000.0) * output_price_per_1k;
251 input_cost + output_cost
252 }
253}
254
255impl NodeExecutionResult {
256 pub fn new() -> Self {
257 Self {
258 started_at: Utc::now(),
259 completed_at: None,
260 result: ExecutionResult::Pending,
261 retry_count: 0,
262 metrics: None,
263 }
264 }
265
266 pub fn complete(mut self, result: ExecutionResult) -> Self {
267 let completed = Utc::now();
268 let duration_ms = (completed - self.started_at).num_milliseconds() as u64;
269 self.completed_at = Some(completed);
270 self.result = result;
271
272 if let Some(ref mut metrics) = self.metrics {
274 metrics.duration_ms = Some(duration_ms);
275 } else {
276 self.metrics = Some(NodeMetrics {
277 duration_ms: Some(duration_ms),
278 ..Default::default()
279 });
280 }
281
282 self
283 }
284
285 pub fn with_metrics(mut self, metrics: NodeMetrics) -> Self {
287 self.metrics = Some(metrics);
288 self
289 }
290
291 pub fn with_token_usage(mut self, usage: TokenUsage) -> Self {
293 if let Some(ref mut metrics) = self.metrics {
294 metrics.token_usage = Some(usage);
295 } else {
296 self.metrics = Some(NodeMetrics {
297 token_usage: Some(usage),
298 ..Default::default()
299 });
300 }
301 self
302 }
303
304 pub fn duration_ms(&self) -> Option<u64> {
306 self.metrics.as_ref().and_then(|m| m.duration_ms)
307 }
308
309 pub fn total_tokens(&self) -> Option<u32> {
311 self.metrics
312 .as_ref()
313 .and_then(|m| m.token_usage.as_ref())
314 .map(|t| t.total_tokens)
315 }
316
317 pub fn cost_usd(&self) -> Option<f64> {
319 self.metrics.as_ref().and_then(|m| m.cost_usd)
320 }
321}
322
323impl Default for NodeExecutionResult {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
331#[cfg_attr(feature = "openapi", derive(ToSchema))]
332pub enum ExecutionResult {
333 Pending,
335
336 Success(serde_json::Value),
338
339 Failure(String),
341
342 Skipped,
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348#[cfg_attr(feature = "openapi", derive(ToSchema))]
349pub struct ExecutionCheckpoint {
350 pub timestamp: DateTime<Utc>,
352
353 #[cfg_attr(feature = "openapi", schema(value_type = Vec<String>))]
355 pub completed_nodes: Vec<NodeId>,
356
357 pub variables: HashMap<String, serde_json::Value>,
359
360 pub state: ExecutionState,
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_context() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        let node_id = Uuid::new_v4();
        let result = NodeExecutionResult::new().complete(ExecutionResult::Success(
            serde_json::json!({"output": "test"}),
        ));

        ctx.record_node_result(node_id, result);

        assert!(ctx.get_node_result(&node_id).is_some());
        assert_eq!(ctx.state, ExecutionState::Running);
    }

    #[test]
    fn test_execution_context_new() {
        let workflow_id = Uuid::new_v4();
        let ctx = ExecutionContext::new(workflow_id);

        assert_eq!(ctx.workflow_id, workflow_id);
        assert_eq!(ctx.state, ExecutionState::Running);
        assert_eq!(ctx.node_results.len(), 0);
        assert_eq!(ctx.variables.len(), 0);
        assert!(ctx.completed_at.is_none());
        assert!(ctx.checkpoint.is_none());
    }

    #[test]
    fn test_execution_context_pause_resume() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        assert_eq!(ctx.state, ExecutionState::Running);
        assert!(!ctx.can_resume());

        ctx.pause();
        assert_eq!(ctx.state, ExecutionState::Paused);
        assert!(ctx.can_resume());
        assert!(ctx.checkpoint.is_some());

        ctx.resume();
        assert_eq!(ctx.state, ExecutionState::Running);
    }

    #[test]
    fn test_resume_ignores_non_paused_states() {
        let mut ctx = ExecutionContext::new(Uuid::new_v4());

        ctx.resume();
        assert_eq!(ctx.state, ExecutionState::Running);

        ctx.cancel();
        ctx.resume();
        assert_eq!(ctx.state, ExecutionState::Cancelled);
    }

    #[test]
    fn test_execution_context_cancel() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.cancel();
        assert_eq!(ctx.state, ExecutionState::Cancelled);
        assert!(ctx.completed_at.is_some());
    }

    #[test]
    fn test_execution_context_mark_completed() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        assert!(ctx.completed_at.is_none());

        ctx.mark_completed();
        assert!(ctx.completed_at.is_some());

        let first_completion = ctx.completed_at.unwrap();
        ctx.mark_completed();
        assert_eq!(ctx.completed_at.unwrap(), first_completion);
    }

    #[test]
    fn test_execution_context_variables() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.set_variable("key1".to_string(), serde_json::json!("value1"));
        ctx.set_variable("key2".to_string(), serde_json::json!(42));

        assert_eq!(ctx.get_variable("key1"), Some(&serde_json::json!("value1")));
        assert_eq!(ctx.get_variable("key2"), Some(&serde_json::json!(42)));
        assert_eq!(ctx.get_variable("key3"), None);
    }

    #[test]
    fn test_execution_context_checkpoint() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.set_variable("var1".to_string(), serde_json::json!("test"));

        let checkpoint = ctx.create_checkpoint();

        assert_eq!(checkpoint.variables.len(), 1);
        assert_eq!(checkpoint.state, ExecutionState::Running);
        assert!(ctx.checkpoint.is_some());
    }

    #[test]
    fn test_checkpoint_lists_recorded_nodes() {
        let mut ctx = ExecutionContext::new(Uuid::new_v4());
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();

        ctx.record_node_result(first, NodeExecutionResult::new());
        ctx.record_node_result(second, NodeExecutionResult::new());

        let checkpoint = ctx.create_checkpoint();

        assert_eq!(checkpoint.completed_nodes.len(), 2);
        assert!(checkpoint.completed_nodes.contains(&first));
        assert!(checkpoint.completed_nodes.contains(&second));
    }

    #[test]
    fn test_execution_context_resume_from_checkpoint() {
        let workflow_id = Uuid::new_v4();
        let mut original_ctx = ExecutionContext::new(workflow_id);

        original_ctx.set_variable("var1".to_string(), serde_json::json!("test"));
        let checkpoint = original_ctx.create_checkpoint();

        let resumed_ctx = ExecutionContext::resume_from_checkpoint(checkpoint, workflow_id);

        assert_eq!(resumed_ctx.workflow_id, workflow_id);
        assert_eq!(resumed_ctx.state, ExecutionState::Running);
        assert!(resumed_ctx.checkpoint.is_some());
        assert_eq!(resumed_ctx.variables.len(), 1);
        assert_eq!(
            resumed_ctx.get_variable("var1"),
            Some(&serde_json::json!("test"))
        );
    }

    #[test]
    fn test_execution_context_serde_roundtrip() {
        let mut ctx = ExecutionContext::new(Uuid::new_v4());
        let node_id = Uuid::new_v4();
        ctx.record_node_result(
            node_id,
            NodeExecutionResult::new().complete(ExecutionResult::Skipped),
        );
        ctx.set_variable("k".to_string(), serde_json::json!([1, 2, 3]));
        ctx.create_checkpoint();

        let json = serde_json::to_string(&ctx).expect("context serializes");
        let back: ExecutionContext = serde_json::from_str(&json).expect("context deserializes");

        assert_eq!(back.execution_id, ctx.execution_id);
        assert_eq!(back.workflow_id, ctx.workflow_id);
        assert_eq!(back.state, ctx.state);
        assert_eq!(back.variables, ctx.variables);
        assert_eq!(
            back.get_node_result(&node_id).map(|r| &r.result),
            Some(&ExecutionResult::Skipped)
        );
        assert!(back.checkpoint.is_some());
    }

    #[test]
    fn test_node_execution_result_new() {
        let result = NodeExecutionResult::new();

        assert_eq!(result.retry_count, 0);
        assert!(result.completed_at.is_none());
        assert!(result.metrics.is_none());
        assert_eq!(result.result, ExecutionResult::Pending);
    }

    #[test]
    fn test_node_execution_result_complete() {
        let result = NodeExecutionResult::new().complete(ExecutionResult::Success(
            serde_json::json!({"data": "test"}),
        ));

        assert!(result.completed_at.is_some());
        assert!(matches!(result.result, ExecutionResult::Success(_)));
        assert!(result.metrics.is_some());
        assert!(result.duration_ms().is_some());
    }

    #[test]
    fn test_complete_keeps_existing_metrics() {
        let metrics = NodeMetrics {
            cost_usd: Some(0.25),
            api_calls: 2,
            ..Default::default()
        };

        let result = NodeExecutionResult::new()
            .with_metrics(metrics)
            .complete(ExecutionResult::Skipped);

        assert_eq!(result.cost_usd(), Some(0.25));
        assert!(result.duration_ms().is_some());
        assert_eq!(result.metrics.as_ref().map(|m| m.api_calls), Some(2));
    }

    #[test]
    fn test_node_execution_result_with_metrics() {
        let metrics = NodeMetrics {
            duration_ms: Some(100),
            token_usage: Some(TokenUsage {
                input_tokens: 50,
                output_tokens: 30,
                total_tokens: 80,
                cached_tokens: None,
            }),
            cost_usd: Some(0.001),
            api_calls: 1,
            bytes_transferred: 1024,
            memory_bytes: Some(128),
            custom: Default::default(),
        };

        let result = NodeExecutionResult::new().with_metrics(metrics.clone());

        assert!(result.metrics.is_some());
        let result_metrics = result.metrics.unwrap();
        assert_eq!(result_metrics.duration_ms, Some(100));
        assert_eq!(result_metrics.cost_usd, Some(0.001));
        assert_eq!(result_metrics.api_calls, 1);
        assert_eq!(result_metrics.bytes_transferred, 1024);
    }

    #[test]
    fn test_with_token_usage() {
        let fresh = NodeExecutionResult::new().with_token_usage(TokenUsage::new(50, 30));
        assert_eq!(fresh.total_tokens(), Some(80));
        assert_eq!(fresh.duration_ms(), None);

        let existing = NodeExecutionResult::new()
            .with_metrics(NodeMetrics {
                api_calls: 3,
                ..Default::default()
            })
            .with_token_usage(TokenUsage::new(1, 2));
        assert_eq!(existing.total_tokens(), Some(3));
        assert_eq!(existing.metrics.as_ref().map(|m| m.api_calls), Some(3));
    }

    #[test]
    fn test_execution_result_serde_roundtrip() {
        let variants = vec![
            ExecutionResult::Pending,
            ExecutionResult::Success(serde_json::json!({"a": [1, 2]})),
            ExecutionResult::Failure("boom".to_string()),
            ExecutionResult::Skipped,
        ];

        for variant in &variants {
            let json = serde_json::to_string(variant).expect("result serializes");
            let back: ExecutionResult = serde_json::from_str(&json).expect("result deserializes");
            assert_eq!(&back, variant);
        }
    }

    #[test]
    fn test_execution_state_serde_roundtrip() {
        let states = vec![
            ExecutionState::Running,
            ExecutionState::Completed,
            ExecutionState::Failed("error".to_string()),
            ExecutionState::Cancelled,
            ExecutionState::Paused,
        ];

        for state in &states {
            let json = serde_json::to_string(state).expect("state serializes");
            let back: ExecutionState = serde_json::from_str(&json).expect("state deserializes");
            assert_eq!(&back, state);
        }
    }

    #[test]
    fn test_token_usage_new() {
        assert_eq!(
            TokenUsage::new(100, 50),
            TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 150,
                cached_tokens: None,
            }
        );
    }

    #[test]
    fn test_token_usage_estimate_cost() {
        let cost = TokenUsage::new(2000, 500).estimate_cost(0.01, 0.03);
        assert!((cost - 0.035).abs() < 1e-12);
    }

    #[test]
    fn test_node_metrics_default() {
        let metrics = NodeMetrics::default();

        assert_eq!(metrics.duration_ms, None);
        assert_eq!(metrics.token_usage, None);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.api_calls, 0);
        assert_eq!(metrics.bytes_transferred, 0);
        assert_eq!(metrics.memory_bytes, None);
    }

    #[test]
    fn test_node_metrics_deserialize_missing_fields() {
        let metrics: NodeMetrics =
            serde_json::from_str("{}").expect("every NodeMetrics field is optional");

        assert_eq!(metrics.duration_ms, None);
        assert_eq!(metrics.token_usage, None);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.api_calls, 0);
        assert_eq!(metrics.bytes_transferred, 0);
        assert_eq!(metrics.memory_bytes, None);
        assert!(metrics.custom.is_empty());
    }
}