1use crate::dag::WorkflowDag;
4use crate::engine::executor::{ExecutorConfig, TaskExecutor, WorkflowExecutor};
5use crate::engine::state::{WorkflowState, WorkflowStatus};
6use crate::error::{Result, WorkflowError};
7use dashmap::DashMap;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use tokio::task::JoinHandle;
12use tracing::{debug, info};
13
14use uuid::Uuid;
15
/// A registered, versioned workflow: a named DAG of tasks plus metadata.
///
/// Serializable so definitions can be persisted or exchanged over the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDefinition {
    /// Unique identifier; used as the registry key in [`WorkflowRuntime`].
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Version string (format is caller-defined; not parsed here).
    pub version: String,
    /// The task graph executed when this workflow is started.
    pub dag: WorkflowDag,
    /// Optional free-form description.
    pub description: Option<String>,
}
30
/// Bookkeeping for one in-flight workflow execution: the spawned task's
/// handle (for abort/join) and a shared view of its state.
struct ActiveExecution {
    /// Handle of the tokio task driving the execution; aborted on cancel,
    /// awaited by `wait_for_completion`.
    handle: JoinHandle<Result<WorkflowState>>,
    /// Shared state slot read by status queries and updated on cancellation.
    state: Arc<RwLock<WorkflowState>>,
}
38
/// Runtime that owns the workflow registry and tracks active executions.
///
/// Generic over the [`TaskExecutor`] used to run individual tasks.
pub struct WorkflowRuntime<E: TaskExecutor> {
    /// Executor configuration cloned into each new `WorkflowExecutor`.
    config: ExecutorConfig,
    /// Shared task executor; cloned per execution when spawning.
    task_executor: Arc<E>,
    /// Registered workflow definitions, keyed by workflow id.
    workflows: Arc<DashMap<String, WorkflowDefinition>>,
    /// In-flight executions, keyed by execution id (a UUID).
    executions: Arc<DashMap<String, ActiveExecution>>,
}
50
51impl<E: TaskExecutor + Clone + 'static> WorkflowRuntime<E> {
52 pub fn new(config: ExecutorConfig, task_executor: E) -> Self {
54 Self {
55 config,
56 task_executor: Arc::new(task_executor),
57 workflows: Arc::new(DashMap::new()),
58 executions: Arc::new(DashMap::new()),
59 }
60 }
61
62 pub fn register_workflow(&self, definition: WorkflowDefinition) -> Result<()> {
64 if self.workflows.contains_key(&definition.id) {
65 return Err(WorkflowError::already_exists(format!(
66 "Workflow '{}'",
67 definition.id
68 )));
69 }
70
71 definition.dag.validate()?;
73
74 info!("Registering workflow: {}", definition.id);
75 self.workflows.insert(definition.id.clone(), definition);
76
77 Ok(())
78 }
79
80 pub fn unregister_workflow(&self, workflow_id: &str) -> Result<()> {
82 self.workflows
83 .remove(workflow_id)
84 .ok_or_else(|| WorkflowError::not_found(format!("Workflow '{}'", workflow_id)))?;
85
86 info!("Unregistered workflow: {}", workflow_id);
87 Ok(())
88 }
89
90 pub fn get_workflow(&self, workflow_id: &str) -> Option<WorkflowDefinition> {
92 self.workflows
93 .get(workflow_id)
94 .map(|entry| entry.value().clone())
95 }
96
97 pub fn list_workflows(&self) -> Vec<WorkflowDefinition> {
99 self.workflows
100 .iter()
101 .map(|entry| entry.value().clone())
102 .collect()
103 }
104
105 pub fn start_workflow(&self, workflow_id: &str) -> Result<String> {
107 let definition = self
108 .get_workflow(workflow_id)
109 .ok_or_else(|| WorkflowError::not_found(format!("Workflow '{}'", workflow_id)))?;
110
111 let execution_id = Uuid::new_v4().to_string();
112
113 info!(
114 "Starting workflow execution: workflow_id={}, execution_id={}",
115 workflow_id, execution_id
116 );
117
118 let executor = WorkflowExecutor::new(self.config.clone(), (*self.task_executor).clone());
119
120 let wf_id = workflow_id.to_string();
121 let exec_id = execution_id.clone();
122 let dag = definition.dag.clone();
123
124 let state = WorkflowState::new(wf_id.clone(), exec_id.clone(), definition.name.clone());
126 let state_arc = Arc::new(RwLock::new(state));
127 let state_arc_clone = Arc::clone(&state_arc);
128
129 let handle = tokio::spawn(async move { executor.execute(wf_id, exec_id, dag).await });
131
132 self.executions.insert(
133 execution_id.clone(),
134 ActiveExecution {
135 handle,
136 state: state_arc_clone,
137 },
138 );
139
140 Ok(execution_id)
141 }
142
143 pub async fn get_execution_status(&self, execution_id: &str) -> Result<WorkflowStatus> {
145 let execution = self
146 .executions
147 .get(execution_id)
148 .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
149
150 let state = execution.state.read().await;
151 Ok(state.status)
152 }
153
154 pub async fn get_execution_state(&self, execution_id: &str) -> Result<WorkflowState> {
156 let execution = self
157 .executions
158 .get(execution_id)
159 .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
160
161 let state = execution.state.read().await;
162 Ok(state.clone())
163 }
164
165 pub async fn cancel_execution(&self, execution_id: &str) -> Result<()> {
167 let (_, execution) = self
168 .executions
169 .remove(execution_id)
170 .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
171
172 debug!("Cancelling workflow execution: {}", execution_id);
173
174 execution.handle.abort();
176
177 let mut state = execution.state.write().await;
179 state.cancel();
180
181 info!("Cancelled workflow execution: {}", execution_id);
182
183 Ok(())
184 }
185
186 pub async fn wait_for_completion(&self, execution_id: &str) -> Result<WorkflowState> {
188 let (_, execution) = self
189 .executions
190 .remove(execution_id)
191 .ok_or_else(|| WorkflowError::not_found(format!("Execution '{}'", execution_id)))?;
192
193 debug!("Waiting for workflow execution: {}", execution_id);
194
195 match execution.handle.await {
196 Ok(result) => result,
197 Err(e) => {
198 if e.is_cancelled() {
199 let state = execution.state.read().await;
200 Ok(state.clone())
201 } else {
202 Err(WorkflowError::execution(format!(
203 "Execution task panicked: {}",
204 e
205 )))
206 }
207 }
208 }
209 }
210
211 pub fn list_active_executions(&self) -> Vec<String> {
213 self.executions
214 .iter()
215 .map(|entry| entry.key().clone())
216 .collect()
217 }
218
219 pub fn active_execution_count(&self) -> usize {
221 self.executions.len()
222 }
223
224 pub async fn cleanup_completed(&self) -> usize {
226 let mut completed = Vec::new();
227
228 for entry in self.executions.iter() {
229 let execution_id = entry.key().clone();
230 let state = entry.value().state.read().await;
231
232 if state.is_terminal() {
233 completed.push(execution_id);
234 }
235 }
236
237 let count = completed.len();
238
239 for execution_id in completed {
240 self.executions.remove(&execution_id);
241 }
242
243 if count > 0 {
244 info!("Cleaned up {} completed executions", count);
245 }
246
247 count
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::graph::{ResourceRequirements, RetryPolicy, TaskNode};
    use crate::engine::executor::{ExecutionContext, TaskExecutor, TaskOutput};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// Task executor stub that always succeeds with a fixed payload.
    // Derive Clone instead of hand-writing the trivial impl.
    #[derive(Clone, Copy)]
    struct DummyExecutor;

    #[async_trait]
    impl TaskExecutor for DummyExecutor {
        async fn execute(
            &self,
            _task: &TaskNode,
            _context: &ExecutionContext,
        ) -> Result<TaskOutput> {
            Ok(TaskOutput {
                data: Some(serde_json::json!({"result": "success"})),
                logs: vec!["Task executed".to_string()],
            })
        }
    }

    /// Builds a minimal single-task workflow ("wf1") used by every test.
    fn create_test_workflow() -> WorkflowDefinition {
        let mut dag = WorkflowDag::new();
        // Fail loudly instead of `.ok()`: a broken fixture should abort the
        // test here, not surface as a confusing downstream assertion.
        dag.add_task(TaskNode {
            id: "task1".to_string(),
            name: "Task 1".to_string(),
            description: None,
            config: serde_json::json!({}),
            retry: RetryPolicy::default(),
            timeout_secs: Some(60),
            resources: ResourceRequirements::default(),
            metadata: HashMap::new(),
        })
        .expect("adding a single task to an empty DAG must succeed");

        WorkflowDefinition {
            id: "wf1".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            dag,
            description: Some("Test workflow".to_string()),
        }
    }

    #[tokio::test]
    async fn test_register_workflow() {
        let runtime = WorkflowRuntime::new(ExecutorConfig::default(), DummyExecutor);
        let workflow = create_test_workflow();

        runtime
            .register_workflow(workflow)
            .expect("registering a fresh workflow must succeed");

        assert!(runtime.get_workflow("wf1").is_some());
    }

    #[tokio::test]
    async fn test_start_workflow() {
        let runtime = WorkflowRuntime::new(ExecutorConfig::default(), DummyExecutor);
        let workflow = create_test_workflow();

        // Setup must not swallow errors with `.ok()`.
        runtime
            .register_workflow(workflow)
            .expect("registering a fresh workflow must succeed");

        let execution_id = runtime.start_workflow("wf1");
        assert!(execution_id.is_ok());
    }
}