Skip to main content

oxigdal_workflow/engine/
runtime.rs

1//! Workflow runtime for managing multiple workflow executions.
2
3use crate::dag::WorkflowDag;
4use crate::engine::executor::{ExecutorConfig, TaskExecutor, WorkflowExecutor};
5use crate::engine::state::{WorkflowState, WorkflowStatus};
6use crate::error::{Result, WorkflowError};
7use dashmap::DashMap;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tokio::task::JoinHandle;
12use tracing::{debug, info};
13
14use uuid::Uuid;
15
16/// Workflow definition.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct WorkflowDefinition {
19    /// Workflow ID.
20    pub id: String,
21    /// Workflow name.
22    pub name: String,
23    /// Workflow version.
24    pub version: String,
25    /// Workflow DAG.
26    pub dag: WorkflowDag,
27    /// Workflow description.
28    pub description: Option<String>,
29}
30
31/// Active workflow execution.
32struct ActiveExecution {
33    /// Execution handle.
34    handle: JoinHandle<Result<WorkflowState>>,
35    /// Workflow state.
36    state: Arc<RwLock<WorkflowState>>,
37}
38
39/// Workflow runtime for managing workflow executions.
40pub struct WorkflowRuntime<E: TaskExecutor> {
41    /// Executor configuration.
42    config: ExecutorConfig,
43    /// Task executor.
44    task_executor: Arc<E>,
45    /// Registered workflow definitions.
46    workflows: Arc<DashMap<String, WorkflowDefinition>>,
47    /// Active executions.
48    executions: Arc<DashMap<String, ActiveExecution>>,
49}
50
51impl<E: TaskExecutor + Clone + 'static> WorkflowRuntime<E> {
52    /// Create a new workflow runtime.
53    pub fn new(config: ExecutorConfig, task_executor: E) -> Self {
54        Self {
55            config,
56            task_executor: Arc::new(task_executor),
57            workflows: Arc::new(DashMap::new()),
58            executions: Arc::new(DashMap::new()),
59        }
60    }
61
62    /// Register a workflow definition.
63    pub fn register_workflow(&self, definition: WorkflowDefinition) -> Result<()> {
64        if self.workflows.contains_key(&definition.id) {
65            return Err(WorkflowError::already_exists(format!(
66                "Workflow '{}'",
67                definition.id
68            )));
69        }
70
71        // Validate the DAG
72        definition.dag.validate()?;
73
74        info!("Registering workflow: {}", definition.id);
75        self.workflows.insert(definition.id.clone(), definition);
76
77        Ok(())
78    }
79
80    /// Unregister a workflow definition.
81    pub fn unregister_workflow(&self, workflow_id: &str) -> Result<()> {
82        self.workflows
83            .remove(workflow_id)
84            .ok_or_else(|| WorkflowError::not_found(format!("Workflow '{}'", workflow_id)))?;
85
86        info!("Unregistered workflow: {}", workflow_id);
87        Ok(())
88    }
89
90    /// Get a workflow definition.
91    pub fn get_workflow(&self, workflow_id: &str) -> Option<WorkflowDefinition> {
92        self.workflows
93            .get(workflow_id)
94            .map(|entry| entry.value().clone())
95    }
96
97    /// List all registered workflows.
98    pub fn list_workflows(&self) -> Vec<WorkflowDefinition> {
99        self.workflows
100            .iter()
101            .map(|entry| entry.value().clone())
102            .collect()
103    }
104
105    /// Start a workflow execution.
106    pub fn start_workflow(&self, workflow_id: &str) -> Result<String> {
107        let definition = self
108            .get_workflow(workflow_id)
109            .ok_or_else(|| WorkflowError::not_found(format!("Workflow '{}'", workflow_id)))?;
110
111        let execution_id = Uuid::new_v4().to_string();
112
113        info!(
114            "Starting workflow execution: workflow_id={}, execution_id={}",
115            workflow_id, execution_id
116        );
117
118        let executor = WorkflowExecutor::new(self.config.clone(), (*self.task_executor).clone());
119
120        let wf_id = workflow_id.to_string();
121        let exec_id = execution_id.clone();
122        let dag = definition.dag.clone();
123
124        // Create initial state
125        let state = WorkflowState::new(wf_id.clone(), exec_id.clone(), definition.name.clone());
126        let state_arc = Arc::new(RwLock::new(state));
127        let state_arc_clone = Arc::clone(&state_arc);
128
129        // Spawn execution task
130        let handle = tokio::spawn(async move { executor.execute(wf_id, exec_id, dag).await });
131
132        self.executions.insert(
133            execution_id.clone(),
134            ActiveExecution {
135                handle,
136                state: state_arc_clone,
137            },
138        );
139
140        Ok(execution_id)
141    }
142
143    /// Get the status of a workflow execution.
144    pub async fn get_execution_status(&self, execution_id: &str) -> Result<WorkflowStatus> {
145        let execution = self
146            .executions
147            .get(execution_id)
148            .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
149
150        let state = execution.state.read().await;
151        Ok(state.status)
152    }
153
154    /// Get the full state of a workflow execution.
155    pub async fn get_execution_state(&self, execution_id: &str) -> Result<WorkflowState> {
156        let execution = self
157            .executions
158            .get(execution_id)
159            .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
160
161        let state = execution.state.read().await;
162        Ok(state.clone())
163    }
164
165    /// Cancel a workflow execution.
166    pub async fn cancel_execution(&self, execution_id: &str) -> Result<()> {
167        let (_, execution) = self
168            .executions
169            .remove(execution_id)
170            .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
171
172        debug!("Cancelling workflow execution: {}", execution_id);
173
174        // Cancel the execution
175        execution.handle.abort();
176
177        // Update state
178        let mut state = execution.state.write().await;
179        state.cancel();
180
181        info!("Cancelled workflow execution: {}", execution_id);
182
183        Ok(())
184    }
185
186    /// Wait for a workflow execution to complete.
187    pub async fn wait_for_completion(&self, execution_id: &str) -> Result<WorkflowState> {
188        let (_, execution) = self
189            .executions
190            .remove(execution_id)
191            .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
192
193        debug!("Waiting for workflow execution: {}", execution_id);
194
195        match execution.handle.await {
196            Ok(result) => result,
197            Err(e) => {
198                if e.is_cancelled() {
199                    let state = execution.state.read().await;
200                    Ok(state.clone())
201                } else {
202                    Err(WorkflowError::execution(format!(
203                        "Execution task panicked: {}",
204                        e
205                    )))
206                }
207            }
208        }
209    }
210
211    /// List all active executions.
212    pub fn list_active_executions(&self) -> Vec<String> {
213        self.executions
214            .iter()
215            .map(|entry| entry.key().clone())
216            .collect()
217    }
218
219    /// Get the number of active executions.
220    pub fn active_execution_count(&self) -> usize {
221        self.executions.len()
222    }
223
224    /// Clean up completed executions.
225    pub async fn cleanup_completed(&self) -> usize {
226        let mut completed = Vec::new();
227
228        for entry in self.executions.iter() {
229            let execution_id = entry.key().clone();
230            let state = entry.value().state.read().await;
231
232            if state.is_terminal() {
233                completed.push(execution_id);
234            }
235        }
236
237        let count = completed.len();
238
239        for execution_id in completed {
240            self.executions.remove(&execution_id);
241        }
242
243        if count > 0 {
244            info!("Cleaned up {} completed executions", count);
245        }
246
247        count
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::graph::{ResourceRequirements, RetryPolicy, TaskNode};
    use crate::engine::executor::{ExecutionContext, TaskExecutor, TaskOutput};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// Task executor that succeeds immediately with a fixed JSON payload.
    #[derive(Clone)]
    struct DummyExecutor;

    #[async_trait]
    impl TaskExecutor for DummyExecutor {
        async fn execute(
            &self,
            _task: &TaskNode,
            _context: &ExecutionContext,
        ) -> Result<TaskOutput> {
            Ok(TaskOutput {
                data: Some(serde_json::json!({"result": "success"})),
                logs: vec!["Task executed".to_string()],
            })
        }
    }

    /// Build a single-task workflow with ID "wf1".
    ///
    /// Panics if the task cannot be added, so broken fixtures fail loudly
    /// instead of silently producing an empty DAG.
    fn create_test_workflow() -> WorkflowDefinition {
        let mut dag = WorkflowDag::new();
        dag.add_task(TaskNode {
            id: "task1".to_string(),
            name: "Task 1".to_string(),
            description: None,
            config: serde_json::json!({}),
            retry: RetryPolicy::default(),
            timeout_secs: Some(60),
            resources: ResourceRequirements::default(),
            metadata: HashMap::new(),
        })
        .expect("adding a task to an empty DAG should succeed");

        WorkflowDefinition {
            id: "wf1".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            dag,
            description: Some("Test workflow".to_string()),
        }
    }

    fn new_runtime() -> WorkflowRuntime<DummyExecutor> {
        WorkflowRuntime::new(ExecutorConfig::default(), DummyExecutor)
    }

    #[tokio::test]
    async fn test_register_workflow() {
        let runtime = new_runtime();

        runtime
            .register_workflow(create_test_workflow())
            .expect("registering a valid workflow should succeed");

        assert!(runtime.get_workflow("wf1").is_some());
        assert_eq!(runtime.list_workflows().len(), 1);
    }

    #[tokio::test]
    async fn test_register_duplicate_workflow_fails() {
        let runtime = new_runtime();
        runtime
            .register_workflow(create_test_workflow())
            .expect("first registration should succeed");

        assert!(runtime.register_workflow(create_test_workflow()).is_err());
        assert_eq!(runtime.list_workflows().len(), 1);
    }

    #[tokio::test]
    async fn test_unregister_workflow() {
        let runtime = new_runtime();
        runtime
            .register_workflow(create_test_workflow())
            .expect("registration should succeed");

        assert!(runtime.unregister_workflow("wf1").is_ok());
        assert!(runtime.get_workflow("wf1").is_none());
        assert!(runtime.unregister_workflow("wf1").is_err());
    }

    #[tokio::test]
    async fn test_start_workflow() {
        let runtime = new_runtime();
        runtime
            .register_workflow(create_test_workflow())
            .expect("registration should succeed");

        let execution_id = runtime
            .start_workflow("wf1")
            .expect("starting a registered workflow should succeed");

        assert_eq!(runtime.active_execution_count(), 1);
        assert!(runtime.list_active_executions().contains(&execution_id));
    }

    #[tokio::test]
    async fn test_start_unknown_workflow_fails() {
        let runtime = new_runtime();
        assert!(runtime.start_workflow("missing").is_err());
        assert_eq!(runtime.active_execution_count(), 0);
    }

    #[tokio::test]
    async fn test_unknown_execution_queries_fail() {
        let runtime = new_runtime();
        assert!(runtime.get_execution_status("missing").await.is_err());
        assert!(runtime.get_execution_state("missing").await.is_err());
        assert!(runtime.cancel_execution("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_wait_for_completion_removes_execution() {
        let runtime = new_runtime();
        runtime
            .register_workflow(create_test_workflow())
            .expect("registration should succeed");
        let execution_id = runtime
            .start_workflow("wf1")
            .expect("start should succeed");

        // The outcome depends on WorkflowExecutor; only bookkeeping is checked here.
        let _ = runtime.wait_for_completion(&execution_id).await;

        assert_eq!(runtime.active_execution_count(), 0);
        assert!(runtime.get_execution_status(&execution_id).await.is_err());
    }
}