1use crate::dag::WorkflowDag;
4use crate::error::{Result, WorkflowError};
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9use tokio::fs;
10
/// Serializable snapshot of a single workflow execution: overall status,
/// per-task progress, descriptive metadata, and the shared execution context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
    /// Identifier of the workflow definition this execution belongs to.
    pub workflow_id: String,
    /// Unique identifier of this execution; used as the on-disk file stem by
    /// `StatePersistence`.
    pub execution_id: String,
    /// Current lifecycle status of the whole workflow.
    pub status: WorkflowStatus,
    /// Per-task runtime state, keyed by task id.
    pub task_states: HashMap<String, TaskState>,
    /// Descriptive metadata: name, version, timing, owner, tags.
    pub metadata: WorkflowMetadata,
    /// Variables, parameters, and environment shared across tasks.
    pub context: ExecutionContext,
}
27
/// Lifecycle status of a workflow execution.
///
/// `Completed`, `Failed`, and `Cancelled` are terminal states
/// (see `WorkflowState::is_terminal`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// Created but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully (terminal).
    Completed,
    /// Finished with an error (terminal).
    Failed,
    /// Aborted before completion (terminal).
    Cancelled,
    /// Temporarily halted; switched back to `Running` on resume
    /// (see `WorkflowCheckpoint::prepare_for_resume`).
    Paused,
}
44
/// Runtime state of a single task within a workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskState {
    /// Id of this task (mirrors the key in `WorkflowState::task_states`).
    pub task_id: String,
    /// Current lifecycle status of the task.
    pub status: TaskStatus,
    /// Number of times the task has been started (incremented on each start).
    pub attempts: u32,
    /// When the most recent attempt started, if any.
    pub started_at: Option<DateTime<Utc>>,
    /// When the task reached a finished state, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Wall-clock run time in milliseconds, recorded when both start and
    /// completion times are known.
    pub duration_ms: Option<u64>,
    /// JSON output stored on successful completion, if any.
    pub output: Option<serde_json::Value>,
    /// Error message recorded when the task failed.
    pub error: Option<String>,
    /// Log lines accumulated while the task ran.
    pub logs: Vec<String>,
}
67
/// Lifecycle status of an individual task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Initialized but not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Deliberately not run (see `WorkflowState::skip_task`).
    Skipped,
    /// Cancelled before completion.
    Cancelled,
    /// Waiting to be retried; counted as runnable alongside `Pending` by
    /// `WorkflowCheckpoint::get_pending_tasks`.
    WaitingRetry,
}
86
/// Descriptive and timing metadata for a workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowMetadata {
    /// Human-readable workflow name.
    pub name: String,
    /// Workflow version string (defaults to "1.0.0" in `WorkflowState::new`).
    pub version: String,
    /// When the state object was created.
    pub created_at: DateTime<Utc>,
    /// When execution started, if it has.
    pub started_at: Option<DateTime<Utc>>,
    /// When execution reached a terminal state, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Total wall-clock duration in milliseconds, set on completion when a
    /// start time exists.
    pub duration_ms: Option<u64>,
    /// Optional owner identifier — presumably a user or team name; never set
    /// within this file, so callers supply it. TODO confirm semantics.
    pub owner: Option<String>,
    /// Free-form key/value tags.
    pub tags: HashMap<String, String>,
}
107
/// Mutable key/value context shared by tasks during execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
    /// Runtime variables (written via `WorkflowState::set_variable`).
    pub variables: HashMap<String, serde_json::Value>,
    /// Input parameters — not written anywhere in this file; presumably
    /// supplied by the caller at launch. TODO confirm.
    pub parameters: HashMap<String, serde_json::Value>,
    /// Environment variables exposed to tasks (not populated in this file).
    pub env: HashMap<String, String>,
}
118
119impl WorkflowState {
120 pub fn new(workflow_id: String, execution_id: String, name: String) -> Self {
122 Self {
123 workflow_id,
124 execution_id,
125 status: WorkflowStatus::Pending,
126 task_states: HashMap::new(),
127 metadata: WorkflowMetadata {
128 name,
129 version: "1.0.0".to_string(),
130 created_at: Utc::now(),
131 started_at: None,
132 completed_at: None,
133 duration_ms: None,
134 owner: None,
135 tags: HashMap::new(),
136 },
137 context: ExecutionContext {
138 variables: HashMap::new(),
139 parameters: HashMap::new(),
140 env: HashMap::new(),
141 },
142 }
143 }
144
145 pub fn init_task(&mut self, task_id: String) {
147 self.task_states.insert(
148 task_id.clone(),
149 TaskState {
150 task_id,
151 status: TaskStatus::Pending,
152 attempts: 0,
153 started_at: None,
154 completed_at: None,
155 duration_ms: None,
156 output: None,
157 error: None,
158 logs: Vec::new(),
159 },
160 );
161 }
162
163 pub fn start_task(&mut self, task_id: &str) -> Result<()> {
165 let task_state = self
166 .task_states
167 .get_mut(task_id)
168 .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
169
170 task_state.status = TaskStatus::Running;
171 task_state.started_at = Some(Utc::now());
172 task_state.attempts += 1;
173
174 Ok(())
175 }
176
177 pub fn complete_task(
179 &mut self,
180 task_id: &str,
181 output: Option<serde_json::Value>,
182 ) -> Result<()> {
183 let task_state = self
184 .task_states
185 .get_mut(task_id)
186 .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
187
188 task_state.status = TaskStatus::Completed;
189 task_state.completed_at = Some(Utc::now());
190 task_state.output = output;
191
192 if let Some(started) = task_state.started_at {
193 task_state.duration_ms = Some(
194 (Utc::now() - started)
195 .num_milliseconds()
196 .try_into()
197 .unwrap_or(0),
198 );
199 }
200
201 Ok(())
202 }
203
204 pub fn fail_task(&mut self, task_id: &str, error: String) -> Result<()> {
206 let task_state = self
207 .task_states
208 .get_mut(task_id)
209 .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
210
211 task_state.status = TaskStatus::Failed;
212 task_state.completed_at = Some(Utc::now());
213 task_state.error = Some(error);
214
215 if let Some(started) = task_state.started_at {
216 task_state.duration_ms = Some(
217 (Utc::now() - started)
218 .num_milliseconds()
219 .try_into()
220 .unwrap_or(0),
221 );
222 }
223
224 Ok(())
225 }
226
227 pub fn skip_task(&mut self, task_id: &str) -> Result<()> {
229 let task_state = self
230 .task_states
231 .get_mut(task_id)
232 .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
233
234 task_state.status = TaskStatus::Skipped;
235 task_state.completed_at = Some(Utc::now());
236
237 Ok(())
238 }
239
240 pub fn add_task_log(&mut self, task_id: &str, log: String) -> Result<()> {
242 let task_state = self
243 .task_states
244 .get_mut(task_id)
245 .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
246
247 task_state.logs.push(log);
248
249 Ok(())
250 }
251
252 pub fn start(&mut self) {
254 self.status = WorkflowStatus::Running;
255 self.metadata.started_at = Some(Utc::now());
256 }
257
258 pub fn complete(&mut self) {
260 self.status = WorkflowStatus::Completed;
261 self.metadata.completed_at = Some(Utc::now());
262
263 if let Some(started) = self.metadata.started_at {
264 self.metadata.duration_ms = Some(
265 (Utc::now() - started)
266 .num_milliseconds()
267 .try_into()
268 .unwrap_or(0),
269 );
270 }
271 }
272
273 pub fn fail(&mut self) {
275 self.status = WorkflowStatus::Failed;
276 self.metadata.completed_at = Some(Utc::now());
277
278 if let Some(started) = self.metadata.started_at {
279 self.metadata.duration_ms = Some(
280 (Utc::now() - started)
281 .num_milliseconds()
282 .try_into()
283 .unwrap_or(0),
284 );
285 }
286 }
287
288 pub fn cancel(&mut self) {
290 self.status = WorkflowStatus::Cancelled;
291 self.metadata.completed_at = Some(Utc::now());
292
293 if let Some(started) = self.metadata.started_at {
294 self.metadata.duration_ms = Some(
295 (Utc::now() - started)
296 .num_milliseconds()
297 .try_into()
298 .unwrap_or(0),
299 );
300 }
301 }
302
303 pub fn get_task_state(&self, task_id: &str) -> Option<&TaskState> {
305 self.task_states.get(task_id)
306 }
307
308 pub fn set_variable(&mut self, key: String, value: serde_json::Value) {
310 self.context.variables.insert(key, value);
311 }
312
313 pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
315 self.context.variables.get(key)
316 }
317
318 pub fn is_terminal(&self) -> bool {
320 matches!(
321 self.status,
322 WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled
323 )
324 }
325}
326
/// A durable snapshot of a workflow execution (state + DAG) that can be
/// persisted and later resumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
    /// Schema version of the checkpoint format; checked on load so files
    /// written by a newer format are rejected
    /// (see `StatePersistence::load_checkpoint`).
    pub version: u32,
    /// When the checkpoint was taken.
    pub created_at: DateTime<Utc>,
    /// Caller-supplied checkpoint counter; part of the on-disk file name.
    /// Presumably monotonically increasing per execution — TODO confirm.
    pub sequence: u64,
    /// Full workflow state at checkpoint time.
    pub state: WorkflowState,
    /// The workflow DAG, captured so dependencies can be evaluated on resume.
    pub dag: WorkflowDag,
}
341
342impl WorkflowCheckpoint {
343 pub const CURRENT_VERSION: u32 = 1;
345
346 pub fn new(state: WorkflowState, dag: WorkflowDag, sequence: u64) -> Self {
348 Self {
349 version: Self::CURRENT_VERSION,
350 created_at: Utc::now(),
351 sequence,
352 state,
353 dag,
354 }
355 }
356
357 pub fn get_pending_tasks(&self) -> Vec<String> {
359 self.state
360 .task_states
361 .iter()
362 .filter(|(_, ts)| matches!(ts.status, TaskStatus::Pending | TaskStatus::WaitingRetry))
363 .map(|(id, _)| id.clone())
364 .collect()
365 }
366
367 pub fn get_interrupted_tasks(&self) -> Vec<String> {
369 self.state
370 .task_states
371 .iter()
372 .filter(|(_, ts)| ts.status == TaskStatus::Running)
373 .map(|(id, _)| id.clone())
374 .collect()
375 }
376
377 pub fn get_completed_tasks(&self) -> Vec<String> {
379 self.state
380 .task_states
381 .iter()
382 .filter(|(_, ts)| ts.status == TaskStatus::Completed)
383 .map(|(id, _)| id.clone())
384 .collect()
385 }
386
387 pub fn get_failed_tasks(&self) -> Vec<String> {
389 self.state
390 .task_states
391 .iter()
392 .filter(|(_, ts)| ts.status == TaskStatus::Failed)
393 .map(|(id, _)| id.clone())
394 .collect()
395 }
396
397 pub fn get_skipped_tasks(&self) -> Vec<String> {
399 self.state
400 .task_states
401 .iter()
402 .filter(|(_, ts)| ts.status == TaskStatus::Skipped)
403 .map(|(id, _)| id.clone())
404 .collect()
405 }
406
407 pub fn are_dependencies_satisfied(&self, task_id: &str) -> bool {
409 let dependencies = self.dag.get_dependencies(task_id);
410 dependencies.iter().all(|dep_id| {
411 self.state
412 .task_states
413 .get(dep_id)
414 .map(|ts| ts.status == TaskStatus::Completed)
415 .unwrap_or(false)
416 })
417 }
418
419 pub fn get_ready_tasks(&self) -> Vec<String> {
421 self.get_pending_tasks()
422 .into_iter()
423 .filter(|task_id| self.are_dependencies_satisfied(task_id))
424 .collect()
425 }
426
427 pub fn prepare_for_resume(&mut self) -> Result<()> {
429 let interrupted = self.get_interrupted_tasks();
431 for task_id in interrupted {
432 if let Some(task_state) = self.state.task_states.get_mut(&task_id) {
433 task_state.status = TaskStatus::Pending;
434 }
436 }
437
438 if self.state.status == WorkflowStatus::Paused {
440 self.state.status = WorkflowStatus::Running;
441 }
442
443 Ok(())
444 }
445}
446
/// File-system persistence for workflow states and checkpoints.
///
/// States are written to `<state_dir>/<execution_id>.json`; checkpoints live
/// under `<state_dir>/checkpoints/`.
pub struct StatePersistence {
    /// Root directory for all persisted files (created on first save).
    state_dir: String,
}
452
453impl StatePersistence {
454 pub fn new(state_dir: String) -> Self {
456 Self { state_dir }
457 }
458
459 pub async fn save(&self, state: &WorkflowState) -> Result<()> {
461 let dir_path = Path::new(&self.state_dir);
462 fs::create_dir_all(dir_path).await.map_err(|e| {
463 WorkflowError::persistence(format!("Failed to create state dir: {}", e))
464 })?;
465
466 let file_path = dir_path.join(format!("{}.json", state.execution_id));
467 let json = serde_json::to_string_pretty(state)?;
468
469 fs::write(&file_path, json)
470 .await
471 .map_err(|e| WorkflowError::persistence(format!("Failed to write state: {}", e)))?;
472
473 Ok(())
474 }
475
476 pub async fn load(&self, execution_id: &str) -> Result<WorkflowState> {
478 let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
479
480 let json = fs::read_to_string(&file_path)
481 .await
482 .map_err(|e| WorkflowError::persistence(format!("Failed to read state: {}", e)))?;
483
484 let state = serde_json::from_str(&json)?;
485 Ok(state)
486 }
487
488 pub async fn delete(&self, execution_id: &str) -> Result<()> {
490 let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
491
492 fs::remove_file(&file_path)
493 .await
494 .map_err(|e| WorkflowError::persistence(format!("Failed to delete state: {}", e)))?;
495
496 Ok(())
497 }
498
499 pub async fn list(&self) -> Result<Vec<String>> {
501 let dir_path = Path::new(&self.state_dir);
502
503 if !dir_path.exists() {
504 return Ok(Vec::new());
505 }
506
507 let mut entries = fs::read_dir(dir_path)
508 .await
509 .map_err(|e| WorkflowError::persistence(format!("Failed to read state dir: {}", e)))?;
510
511 let mut execution_ids = Vec::new();
512
513 while let Some(entry) = entries
514 .next_entry()
515 .await
516 .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
517 {
518 let path = entry.path();
519 if path.extension().and_then(|s| s.to_str()) == Some("json") {
520 if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
521 execution_ids.push(stem.to_string());
522 }
523 }
524 }
525
526 Ok(execution_ids)
527 }
528
529 pub async fn save_checkpoint(&self, checkpoint: &WorkflowCheckpoint) -> Result<()> {
531 let dir_path = Path::new(&self.state_dir).join("checkpoints");
532 fs::create_dir_all(&dir_path).await.map_err(|e| {
533 WorkflowError::persistence(format!("Failed to create checkpoint dir: {}", e))
534 })?;
535
536 let file_path = dir_path.join(format!(
537 "{}_checkpoint_{}.json",
538 checkpoint.state.execution_id, checkpoint.sequence
539 ));
540 let json = serde_json::to_string_pretty(checkpoint)?;
541
542 fs::write(&file_path, json).await.map_err(|e| {
543 WorkflowError::persistence(format!("Failed to write checkpoint: {}", e))
544 })?;
545
546 let latest_path = dir_path.join(format!("{}_latest.json", checkpoint.state.execution_id));
548 let json_latest = serde_json::to_string_pretty(checkpoint)?;
549 fs::write(&latest_path, json_latest).await.map_err(|e| {
550 WorkflowError::persistence(format!("Failed to write latest checkpoint: {}", e))
551 })?;
552
553 Ok(())
554 }
555
556 pub async fn load_checkpoint(&self, execution_id: &str) -> Result<WorkflowCheckpoint> {
558 let latest_path = Path::new(&self.state_dir)
559 .join("checkpoints")
560 .join(format!("{}_latest.json", execution_id));
561
562 let json = fs::read_to_string(&latest_path)
563 .await
564 .map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
565
566 let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
567
568 if checkpoint.version > WorkflowCheckpoint::CURRENT_VERSION {
570 return Err(WorkflowError::persistence(format!(
571 "Checkpoint version {} is newer than supported version {}",
572 checkpoint.version,
573 WorkflowCheckpoint::CURRENT_VERSION
574 )));
575 }
576
577 Ok(checkpoint)
578 }
579
580 pub async fn load_checkpoint_by_sequence(
582 &self,
583 execution_id: &str,
584 sequence: u64,
585 ) -> Result<WorkflowCheckpoint> {
586 let file_path = Path::new(&self.state_dir)
587 .join("checkpoints")
588 .join(format!("{}_checkpoint_{}.json", execution_id, sequence));
589
590 let json = fs::read_to_string(&file_path)
591 .await
592 .map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
593
594 let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
595 Ok(checkpoint)
596 }
597
598 pub async fn delete_checkpoint(&self, execution_id: &str, sequence: u64) -> Result<()> {
600 let file_path = Path::new(&self.state_dir)
601 .join("checkpoints")
602 .join(format!("{}_checkpoint_{}.json", execution_id, sequence));
603
604 fs::remove_file(&file_path).await.map_err(|e| {
605 WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
606 })?;
607
608 Ok(())
609 }
610
611 pub async fn delete_all_checkpoints(&self, execution_id: &str) -> Result<()> {
613 let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
614
615 if !checkpoints_dir.exists() {
616 return Ok(());
617 }
618
619 let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
620 WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
621 })?;
622
623 let prefix = format!("{}_", execution_id);
624
625 while let Some(entry) = entries
626 .next_entry()
627 .await
628 .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
629 {
630 let path = entry.path();
631 if let Some(name) = path.file_name().and_then(|s| s.to_str()) {
632 if name.starts_with(&prefix) {
633 fs::remove_file(&path).await.map_err(|e| {
634 WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
635 })?;
636 }
637 }
638 }
639
640 Ok(())
641 }
642
643 pub async fn list_checkpoints(&self, execution_id: &str) -> Result<Vec<u64>> {
645 let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
646
647 if !checkpoints_dir.exists() {
648 return Ok(Vec::new());
649 }
650
651 let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
652 WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
653 })?;
654
655 let mut sequences = Vec::new();
656 let prefix = format!("{}_checkpoint_", execution_id);
657
658 while let Some(entry) = entries
659 .next_entry()
660 .await
661 .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
662 {
663 let path = entry.path();
664 if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
665 if name.starts_with(&prefix) {
666 if let Some(seq_str) = name.strip_prefix(&prefix) {
667 if let Ok(seq) = seq_str.parse::<u64>() {
668 sequences.push(seq);
669 }
670 }
671 }
672 }
673 }
674
675 sequences.sort();
676 Ok(sequences)
677 }
678
679 pub async fn checkpoint_exists(&self, execution_id: &str) -> bool {
681 let latest_path = Path::new(&self.state_dir)
682 .join("checkpoints")
683 .join(format!("{}_latest.json", execution_id));
684 latest_path.exists()
685 }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    /// Workflow-level transitions: Pending -> Running -> Completed, with the
    /// timing metadata stamped along the way.
    #[test]
    fn test_workflow_state_lifecycle() {
        let mut state = WorkflowState::new(
            "wf1".to_string(),
            "exec1".to_string(),
            "Test Workflow".to_string(),
        );

        // New states start Pending with no timestamps.
        assert_eq!(state.status, WorkflowStatus::Pending);

        state.start();
        assert_eq!(state.status, WorkflowStatus::Running);
        assert!(state.metadata.started_at.is_some());

        state.complete();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert!(state.metadata.completed_at.is_some());
        // Duration is computed because started_at was set before completing.
        assert!(state.metadata.duration_ms.is_some());
    }

    /// Task-level transitions: init -> start (attempt counted) -> complete
    /// with an output payload.
    #[test]
    fn test_task_state_lifecycle() {
        let mut state = WorkflowState::new(
            "wf1".to_string(),
            "exec1".to_string(),
            "Test Workflow".to_string(),
        );

        state.init_task("task1".to_string());
        assert_eq!(
            state.get_task_state("task1").map(|t| t.status),
            Some(TaskStatus::Pending)
        );

        state.start_task("task1").ok();
        assert_eq!(
            state.get_task_state("task1").map(|t| t.status),
            Some(TaskStatus::Running)
        );
        // Starting a task increments its attempt counter.
        assert_eq!(state.get_task_state("task1").map(|t| t.attempts), Some(1));

        state
            .complete_task("task1", Some(serde_json::json!({"result": "success"})))
            .ok();
        assert_eq!(
            state.get_task_state("task1").map(|t| t.status),
            Some(TaskStatus::Completed)
        );
    }

    /// Context variables round-trip through set_variable/get_variable.
    #[test]
    fn test_context_variables() {
        let mut state = WorkflowState::new(
            "wf1".to_string(),
            "exec1".to_string(),
            "Test Workflow".to_string(),
        );

        state.set_variable("key1".to_string(), serde_json::json!("value1"));
        assert_eq!(
            state.get_variable("key1"),
            Some(&serde_json::json!("value1"))
        );
    }
}