oxify_model/
execution.rs

1use crate::{NodeId, WorkflowId};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use uuid::Uuid;
6
7#[cfg(feature = "openapi")]
8use utoipa::ToSchema;
9
/// Unique identifier for a workflow execution.
///
/// A plain alias for [`Uuid`]; [`ExecutionContext::new`] assigns a fresh random (v4) value.
pub type ExecutionId = Uuid;
12
/// Context for workflow execution: the mutable state of one run of a workflow.
///
/// Holds per-node results, shared variables, the lifecycle state and the most
/// recent checkpoint. `variables` and `checkpoint` fall back to their defaults
/// (empty map / `None`) when absent from deserialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ExecutionContext {
    /// Unique execution identifier. A fresh v4 UUID is generated for every
    /// context, including one built by `resume_from_checkpoint`.
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub execution_id: ExecutionId,

    /// The workflow being executed
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub workflow_id: WorkflowId,

    /// When the execution started. For a context built by
    /// `resume_from_checkpoint` this is the checkpoint's timestamp, not the
    /// time of resumption.
    pub started_at: DateTime<Utc>,

    /// When the execution completed (if finished). `None` until
    /// `mark_completed` (also invoked by `cancel`) records it.
    pub completed_at: Option<DateTime<Utc>>,

    /// Current execution state
    pub state: ExecutionState,

    /// Node execution results, keyed by node id. Recording a result for the
    /// same node again replaces the previous one. The OpenAPI schema exposes
    /// this as a string-keyed map.
    #[cfg_attr(feature = "openapi", schema(value_type = HashMap<String, NodeExecutionResult>))]
    pub node_results: HashMap<NodeId, NodeExecutionResult>,

    /// Global variables/context available to all nodes
    #[serde(default)]
    pub variables: HashMap<String, serde_json::Value>,

    /// Most recent checkpoint, kept for resume capability; `can_resume`
    /// requires it to be present.
    #[serde(default)]
    pub checkpoint: Option<ExecutionCheckpoint>,
}
46
47impl ExecutionContext {
48    pub fn new(workflow_id: WorkflowId) -> Self {
49        Self {
50            execution_id: Uuid::new_v4(),
51            workflow_id,
52            started_at: Utc::now(),
53            completed_at: None,
54            state: ExecutionState::Running,
55            node_results: HashMap::new(),
56            variables: HashMap::new(),
57            checkpoint: None,
58        }
59    }
60
61    /// Create a checkpoint of the current execution state
62    pub fn create_checkpoint(&mut self) -> ExecutionCheckpoint {
63        let checkpoint = ExecutionCheckpoint {
64            timestamp: Utc::now(),
65            completed_nodes: self.node_results.keys().copied().collect(),
66            variables: self.variables.clone(),
67            state: self.state.clone(),
68        };
69        self.checkpoint = Some(checkpoint.clone());
70        checkpoint
71    }
72
73    /// Resume from a checkpoint
74    pub fn resume_from_checkpoint(
75        checkpoint: ExecutionCheckpoint,
76        workflow_id: WorkflowId,
77    ) -> Self {
78        let variables = checkpoint.variables.clone();
79        let state = checkpoint.state.clone();
80        Self {
81            execution_id: Uuid::new_v4(),
82            workflow_id,
83            started_at: checkpoint.timestamp,
84            completed_at: None,
85            state,
86            node_results: HashMap::new(), // Will be restored by engine
87            variables,
88            checkpoint: Some(checkpoint),
89        }
90    }
91
92    /// Check if execution can be resumed
93    pub fn can_resume(&self) -> bool {
94        self.checkpoint.is_some() && matches!(self.state, ExecutionState::Paused)
95    }
96
97    /// Pause execution
98    pub fn pause(&mut self) {
99        self.state = ExecutionState::Paused;
100        self.create_checkpoint();
101    }
102
103    /// Resume paused execution
104    pub fn resume(&mut self) {
105        if self.state == ExecutionState::Paused {
106            self.state = ExecutionState::Running;
107        }
108    }
109
110    /// Cancel execution
111    pub fn cancel(&mut self) {
112        self.state = ExecutionState::Cancelled;
113        self.mark_completed();
114    }
115
116    /// Mark execution as completed
117    pub fn mark_completed(&mut self) {
118        if self.completed_at.is_none() {
119            self.completed_at = Some(Utc::now());
120        }
121    }
122
123    /// Record the result of a node execution
124    pub fn record_node_result(&mut self, node_id: NodeId, result: NodeExecutionResult) {
125        self.node_results.insert(node_id, result);
126    }
127
128    /// Get the result of a previous node execution
129    pub fn get_node_result(&self, node_id: &NodeId) -> Option<&NodeExecutionResult> {
130        self.node_results.get(node_id)
131    }
132
133    /// Set a variable in the execution context
134    pub fn set_variable(&mut self, key: String, value: serde_json::Value) {
135        self.variables.insert(key, value);
136    }
137
138    /// Get a variable from the execution context
139    pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
140        self.variables.get(key)
141    }
142}
143
/// State of workflow execution.
///
/// `ExecutionContext::resume` only moves `Paused` back to `Running`.
/// `ExecutionContext::can_resume` likewise requires `Paused`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub enum ExecutionState {
    /// Execution is currently running (the initial state of a new context)
    Running,

    /// Execution completed successfully
    Completed,

    /// Execution failed; the `String` presumably describes the error
    Failed(String),

    /// Execution was cancelled
    Cancelled,

    /// Execution is paused
    Paused,
}
163
/// Result of executing a single node.
///
/// `NodeExecutionResult::new()` creates a pending record. `complete` then fills
/// in the outcome, end time and duration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct NodeExecutionResult {
    /// When this node started executing (set to the current time by `new`)
    pub started_at: DateTime<Utc>,

    /// When this node finished executing; `None` until `complete` is called
    pub completed_at: Option<DateTime<Utc>>,

    /// The result of the execution (`Pending` until `complete` is called)
    pub result: ExecutionResult,

    /// Number of retry attempts made (0 means no retries); defaults to 0 when
    /// missing from serialized input
    #[serde(default)]
    pub retry_count: u32,

    /// Execution metrics (token usage, costs, etc.); `complete` always
    /// populates at least `duration_ms`
    #[serde(default)]
    pub metrics: Option<NodeMetrics>,
}
185
/// Execution metrics for a node.
///
/// Every field is `None`, zero or empty in `NodeMetrics::default()`, so callers
/// can fill in only what they track.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct NodeMetrics {
    /// Execution duration in milliseconds (wall-clock time between the node's
    /// `started_at` and completion; filled in by `NodeExecutionResult::complete`)
    pub duration_ms: Option<u64>,

    /// Token usage for LLM nodes
    #[serde(default)]
    pub token_usage: Option<TokenUsage>,

    /// Estimated cost in USD (for LLM API calls)
    #[serde(default)]
    pub cost_usd: Option<f64>,

    /// API calls made by this node
    #[serde(default)]
    pub api_calls: u32,

    /// Bytes transferred (input + output)
    #[serde(default)]
    pub bytes_transferred: u64,

    /// Memory usage in bytes (if tracked)
    #[serde(default)]
    pub memory_bytes: Option<u64>,

    /// Custom metrics (provider-specific), free-form JSON values by name
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
}
217
/// Token usage for LLM nodes.
///
/// `total_tokens` is stored, not computed. `TokenUsage::new` fills it in from the
/// input and output counts, but a struct literal or deserialized data can hold
/// any value.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TokenUsage {
    /// Input/prompt tokens
    pub input_tokens: u32,

    /// Output/completion tokens
    pub output_tokens: u32,

    /// Total tokens (input + output)
    pub total_tokens: u32,

    /// Cached tokens (if applicable); not used by `estimate_cost`
    #[serde(default)]
    pub cached_tokens: Option<u32>,
}
235
236impl TokenUsage {
237    /// Create new token usage record
238    pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
239        Self {
240            input_tokens,
241            output_tokens,
242            total_tokens: input_tokens + output_tokens,
243            cached_tokens: None,
244        }
245    }
246
247    /// Estimate cost based on provider pricing
248    pub fn estimate_cost(&self, input_price_per_1k: f64, output_price_per_1k: f64) -> f64 {
249        let input_cost = (self.input_tokens as f64 / 1000.0) * input_price_per_1k;
250        let output_cost = (self.output_tokens as f64 / 1000.0) * output_price_per_1k;
251        input_cost + output_cost
252    }
253}
254
255impl NodeExecutionResult {
256    pub fn new() -> Self {
257        Self {
258            started_at: Utc::now(),
259            completed_at: None,
260            result: ExecutionResult::Pending,
261            retry_count: 0,
262            metrics: None,
263        }
264    }
265
266    pub fn complete(mut self, result: ExecutionResult) -> Self {
267        let completed = Utc::now();
268        let duration_ms = (completed - self.started_at).num_milliseconds() as u64;
269        self.completed_at = Some(completed);
270        self.result = result;
271
272        // Auto-populate duration if metrics exist
273        if let Some(ref mut metrics) = self.metrics {
274            metrics.duration_ms = Some(duration_ms);
275        } else {
276            self.metrics = Some(NodeMetrics {
277                duration_ms: Some(duration_ms),
278                ..Default::default()
279            });
280        }
281
282        self
283    }
284
285    /// Add metrics to the execution result
286    pub fn with_metrics(mut self, metrics: NodeMetrics) -> Self {
287        self.metrics = Some(metrics);
288        self
289    }
290
291    /// Add token usage to the execution result
292    pub fn with_token_usage(mut self, usage: TokenUsage) -> Self {
293        if let Some(ref mut metrics) = self.metrics {
294            metrics.token_usage = Some(usage);
295        } else {
296            self.metrics = Some(NodeMetrics {
297                token_usage: Some(usage),
298                ..Default::default()
299            });
300        }
301        self
302    }
303
304    /// Get execution duration in milliseconds
305    pub fn duration_ms(&self) -> Option<u64> {
306        self.metrics.as_ref().and_then(|m| m.duration_ms)
307    }
308
309    /// Get total token count
310    pub fn total_tokens(&self) -> Option<u32> {
311        self.metrics
312            .as_ref()
313            .and_then(|m| m.token_usage.as_ref())
314            .map(|t| t.total_tokens)
315    }
316
317    /// Get estimated cost
318    pub fn cost_usd(&self) -> Option<f64> {
319        self.metrics.as_ref().and_then(|m| m.cost_usd)
320    }
321}
322
323impl Default for NodeExecutionResult {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
/// Result of a node execution.
///
/// `NodeExecutionResult::new` starts every node at `Pending`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub enum ExecutionResult {
    /// Node hasn't executed yet
    Pending,

    /// Node executed successfully; carries its output as a JSON value
    Success(serde_json::Value),

    /// Node execution failed; the `String` presumably describes the error
    Failure(String),

    /// Node execution was skipped (e.g., conditional branch not taken)
    Skipped,
}
345
/// Checkpoint for resumable execution.
///
/// Produced by `ExecutionContext::create_checkpoint` and consumed by
/// `ExecutionContext::resume_from_checkpoint`. Node results themselves are not
/// stored, only the ids of the nodes that had one.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ExecutionCheckpoint {
    /// When this checkpoint was created
    pub timestamp: DateTime<Utc>,

    /// Nodes that have been completed. When filled in by `create_checkpoint`,
    /// this lists every node with an entry in `node_results`, regardless of
    /// outcome, in unspecified order.
    #[cfg_attr(feature = "openapi", schema(value_type = Vec<String>))]
    pub completed_nodes: Vec<NodeId>,

    /// Variables at checkpoint time (a copy, not shared with the context)
    pub variables: HashMap<String, serde_json::Value>,

    /// Execution state at checkpoint
    pub state: ExecutionState,
}
363
#[cfg(test)]
mod tests {
    //! Unit tests for the execution model: context lifecycle, checkpoints,
    //! node results/metrics, token accounting and the serde form of the enums.
    use super::*;

    #[test]
    fn test_execution_context() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        let node_id = Uuid::new_v4();
        let result = NodeExecutionResult::new().complete(ExecutionResult::Success(
            serde_json::json!({"output": "test"}),
        ));

        ctx.record_node_result(node_id, result);

        assert!(ctx.get_node_result(&node_id).is_some());
        assert_eq!(ctx.state, ExecutionState::Running);
    }

    #[test]
    fn test_execution_context_new() {
        let workflow_id = Uuid::new_v4();
        let ctx = ExecutionContext::new(workflow_id);

        assert_eq!(ctx.workflow_id, workflow_id);
        assert_eq!(ctx.state, ExecutionState::Running);
        assert_eq!(ctx.node_results.len(), 0);
        assert_eq!(ctx.variables.len(), 0);
        assert!(ctx.completed_at.is_none());
        assert!(ctx.checkpoint.is_none());
    }

    #[test]
    fn test_execution_context_pause_resume() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        assert_eq!(ctx.state, ExecutionState::Running);
        assert!(!ctx.can_resume());

        ctx.pause();
        assert_eq!(ctx.state, ExecutionState::Paused);
        assert!(ctx.can_resume());
        assert!(ctx.checkpoint.is_some());

        ctx.resume();
        assert_eq!(ctx.state, ExecutionState::Running);
        // The checkpoint is still there, but a running execution is not resumable.
        assert!(!ctx.can_resume());
    }

    #[test]
    fn test_execution_context_cancel() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.cancel();
        assert_eq!(ctx.state, ExecutionState::Cancelled);
        assert!(ctx.completed_at.is_some());
    }

    #[test]
    fn test_execution_context_mark_completed() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        assert!(ctx.completed_at.is_none());

        ctx.mark_completed();
        assert!(ctx.completed_at.is_some());

        let first_completion = ctx.completed_at.unwrap();
        ctx.mark_completed(); // Should not update
        assert_eq!(ctx.completed_at.unwrap(), first_completion);
    }

    #[test]
    fn test_execution_context_variables() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.set_variable("key1".to_string(), serde_json::json!("value1"));
        ctx.set_variable("key2".to_string(), serde_json::json!(42));

        assert_eq!(ctx.get_variable("key1"), Some(&serde_json::json!("value1")));
        assert_eq!(ctx.get_variable("key2"), Some(&serde_json::json!(42)));
        assert_eq!(ctx.get_variable("key3"), None);
    }

    #[test]
    fn test_execution_context_checkpoint() {
        let workflow_id = Uuid::new_v4();
        let mut ctx = ExecutionContext::new(workflow_id);

        ctx.set_variable("var1".to_string(), serde_json::json!("test"));
        let node_id = Uuid::new_v4();
        ctx.record_node_result(node_id, NodeExecutionResult::new());

        let checkpoint = ctx.create_checkpoint();

        assert_eq!(checkpoint.variables.len(), 1);
        assert_eq!(checkpoint.completed_nodes, vec![node_id]);
        assert_eq!(checkpoint.state, ExecutionState::Running);
        assert!(ctx.checkpoint.is_some());
    }

    #[test]
    fn test_execution_context_resume_from_checkpoint() {
        let workflow_id = Uuid::new_v4();
        let mut original_ctx = ExecutionContext::new(workflow_id);

        original_ctx.set_variable("var1".to_string(), serde_json::json!("test"));
        let checkpoint = original_ctx.create_checkpoint();
        let checkpoint_time = checkpoint.timestamp;

        let resumed_ctx = ExecutionContext::resume_from_checkpoint(checkpoint, workflow_id);

        assert_eq!(resumed_ctx.workflow_id, workflow_id);
        assert_ne!(resumed_ctx.execution_id, original_ctx.execution_id);
        assert_eq!(resumed_ctx.started_at, checkpoint_time);
        assert_eq!(resumed_ctx.state, ExecutionState::Running);
        assert!(resumed_ctx.completed_at.is_none());
        assert!(resumed_ctx.node_results.is_empty());
        assert!(resumed_ctx.checkpoint.is_some());
        assert_eq!(resumed_ctx.variables.len(), 1);
        assert_eq!(
            resumed_ctx.get_variable("var1"),
            Some(&serde_json::json!("test"))
        );
    }

    #[test]
    fn test_node_execution_result_new() {
        let result = NodeExecutionResult::new();

        assert_eq!(result.retry_count, 0);
        assert!(result.completed_at.is_none());
        assert!(result.metrics.is_none());
        assert_eq!(result.result, ExecutionResult::Pending);
    }

    #[test]
    fn test_node_execution_result_complete() {
        let result = NodeExecutionResult::new().complete(ExecutionResult::Success(
            serde_json::json!({"data": "test"}),
        ));

        assert!(result.completed_at.is_some());
        assert!(matches!(result.result, ExecutionResult::Success(_)));
        // `complete` always records a duration, creating metrics if needed.
        assert!(result.duration_ms().is_some());
        assert_eq!(result.total_tokens(), None);
        assert_eq!(result.cost_usd(), None);
    }

    #[test]
    fn test_node_execution_result_complete_keeps_metrics() {
        let metrics = NodeMetrics {
            api_calls: 3,
            ..Default::default()
        };

        let result = NodeExecutionResult::new()
            .with_metrics(metrics)
            .complete(ExecutionResult::Skipped);

        let result_metrics = result.metrics.expect("complete keeps attached metrics");
        assert_eq!(result_metrics.api_calls, 3);
        assert!(result_metrics.duration_ms.is_some());
    }

    #[test]
    fn test_node_execution_result_with_metrics() {
        let metrics = NodeMetrics {
            duration_ms: Some(100),
            token_usage: Some(TokenUsage {
                input_tokens: 50,
                output_tokens: 30,
                total_tokens: 80,
                cached_tokens: None,
            }),
            cost_usd: Some(0.001),
            api_calls: 1,
            bytes_transferred: 1024,
            memory_bytes: Some(128),
            custom: Default::default(),
        };

        let result = NodeExecutionResult::new().with_metrics(metrics);

        assert!(result.metrics.is_some());
        assert_eq!(result.total_tokens(), Some(80));
        let result_metrics = result.metrics.unwrap();
        assert_eq!(result_metrics.duration_ms, Some(100));
        assert_eq!(result_metrics.cost_usd, Some(0.001));
        assert_eq!(result_metrics.api_calls, 1);
        assert_eq!(result_metrics.bytes_transferred, 1024);
    }

    #[test]
    fn test_node_execution_result_with_token_usage() {
        let result = NodeExecutionResult::new()
            .with_token_usage(TokenUsage::new(10, 5))
            .complete(ExecutionResult::Failure("boom".to_string()));

        assert_eq!(result.total_tokens(), Some(15));
        assert!(result.duration_ms().is_some());
        assert_eq!(result.cost_usd(), None);
    }

    #[test]
    fn test_execution_result_variants() {
        let variants = vec![
            ExecutionResult::Pending,
            ExecutionResult::Success(serde_json::json!({"k": [1, 2, 3]})),
            ExecutionResult::Failure("test".to_string()),
            ExecutionResult::Skipped,
        ];

        for variant in variants {
            let json = serde_json::to_string(&variant).unwrap();
            let back: ExecutionResult = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant);
        }

        // Payloads take part in equality.
        assert_ne!(
            ExecutionResult::Failure("a".to_string()),
            ExecutionResult::Failure("b".to_string())
        );
        assert_ne!(ExecutionResult::Pending, ExecutionResult::Skipped);
    }

    #[test]
    fn test_execution_state_variants() {
        let states = [
            ExecutionState::Running,
            ExecutionState::Completed,
            ExecutionState::Failed("error".to_string()),
            ExecutionState::Cancelled,
            ExecutionState::Paused,
        ];

        for state in &states {
            let json = serde_json::to_string(state).unwrap();
            let back: ExecutionState = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, state);
        }

        // The failure message is part of equality.
        assert_ne!(
            ExecutionState::Failed("a".to_string()),
            ExecutionState::Failed("b".to_string())
        );
        assert_ne!(ExecutionState::Running, ExecutionState::Paused);
    }

    #[test]
    fn test_token_usage() {
        let token_usage = TokenUsage::new(100, 50);

        assert_eq!(token_usage.input_tokens, 100);
        assert_eq!(token_usage.output_tokens, 50);
        assert_eq!(token_usage.total_tokens, 150);
        assert_eq!(token_usage.cached_tokens, None);

        // 1000 input tokens at 1.0/1k plus 500 output tokens at 2.0/1k = 2.0.
        let cost = TokenUsage::new(1000, 500).estimate_cost(1.0, 2.0);
        assert!((cost - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_node_metrics_default() {
        let metrics = NodeMetrics::default();

        assert_eq!(metrics.duration_ms, None);
        assert_eq!(metrics.token_usage, None);
        assert_eq!(metrics.cost_usd, None);
        assert_eq!(metrics.api_calls, 0);
        assert_eq!(metrics.bytes_transferred, 0);
        assert_eq!(metrics.memory_bytes, None);
        assert!(metrics.custom.is_empty());
    }
}