ant_quic/workflow/state_store.rs

//! Workflow State Store
//!
//! This module provides persistence for workflow state, enabling workflow
//! recovery and fault tolerance across system restarts.
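//!
//! A minimal usage sketch (marked `ignore`: the import path and directory
//! below are assumptions made for illustration, not confirmed re-exports):
//!
//! ```ignore
//! use std::time::Duration;
//! use ant_quic::workflow::state_store::{FileStateStore, StateStore};
//! use ant_quic::workflow::WorkflowError;
//!
//! async fn recover_all(store: &FileStateStore) -> Result<(), WorkflowError> {
//!     // Reload every persisted instance after a restart.
//!     for id in store.list().await? {
//!         let state = store.load(&id).await?;
//!         tracing::debug!("recovered workflow {} at stage {:?}", id, state.current_stage);
//!     }
//!     // Drop workflows in a terminal state that are older than one day.
//!     store.cleanup(Duration::from_secs(24 * 60 * 60)).await?;
//!     Ok(())
//! }
//! ```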

use std::{
    collections::HashMap,
    sync::Arc,
    time::{Duration, SystemTime},
};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::workflow::{
    StageId, Version, WorkflowError, WorkflowId, WorkflowMetrics, WorkflowStatus,
};

/// Workflow state that can be persisted
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
    /// Workflow instance ID
    pub instance_id: WorkflowId,
    /// Workflow definition ID
    pub workflow_id: String,
    /// Workflow version
    pub version: Version,
    /// Current status
    pub status: WorkflowStatus,
    /// Current stage
    pub current_stage: StageId,
    /// Input data
    pub input: HashMap<String, Vec<u8>>,
    /// Workflow state data
    pub state: HashMap<String, Vec<u8>>,
    /// Metrics
    pub metrics: WorkflowMetrics,
    /// Retry attempts per stage
    pub retry_attempts: HashMap<StageId, u32>,
    /// Creation timestamp
    pub created_at: SystemTime,
    /// Last updated timestamp
    pub updated_at: SystemTime,
    /// Checkpoint version
    pub checkpoint_version: u64,
}

/// Trait for workflow state persistence
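///
/// Any backend (in-memory, file-based, or cached) is used through this trait.
/// A generic-use sketch (`ignore`d; the helper function is illustrative only):
///
/// ```ignore
/// async fn checkpoint<S: StateStore>(store: &S, state: &WorkflowState) -> Result<(), WorkflowError> {
///     // Persist the current state so the workflow can be recovered later.
///     store.save(state).await
/// }
/// ```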
#[async_trait]
pub trait StateStore: Send + Sync {
    /// Save workflow state
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError>;

    /// Load workflow state
    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError>;

    /// Delete workflow state
    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError>;

    /// List all workflow instances
    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError>;

    /// List workflow instances by status
    async fn list_by_status(&self, status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError>;

    /// Clean up workflows in a terminal state (completed, failed, or cancelled) older than the retention period
    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError>;
}

/// In-memory state store for testing
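///
/// State lives in a `HashMap` behind an async `RwLock`, so nothing touches
/// disk and instances are lost on restart. A sketch (`ignore`d; `state` is a
/// prepared `WorkflowState`):
///
/// ```ignore
/// let store = InMemoryStateStore::new();
/// store.save(&state).await?;
/// let restored = store.load(&state.instance_id).await?;
/// assert_eq!(restored.workflow_id, state.workflow_id);
/// ```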
pub struct InMemoryStateStore {
    states: Arc<RwLock<HashMap<WorkflowId, WorkflowState>>>,
}

impl InMemoryStateStore {
    /// Create a new in-memory state store
    pub fn new() -> Self {
        Self {
            states: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

#[async_trait]
impl StateStore for InMemoryStateStore {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        let mut states = self.states.write().await;
        states.insert(state.instance_id, state.clone());
        debug!("Saved state for workflow {}", state.instance_id);
        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        let states = self.states.read().await;
        states.get(instance_id).cloned().ok_or_else(|| WorkflowError {
            code: "STATE_NOT_FOUND".to_string(),
            message: format!("State not found for workflow {}", instance_id),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        let mut states = self.states.write().await;
        states.remove(instance_id);
        debug!("Deleted state for workflow {}", instance_id);
        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        let states = self.states.read().await;
        Ok(states.keys().cloned().collect())
    }

    async fn list_by_status(&self, target_status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        let states = self.states.read().await;
        Ok(states.iter()
            .filter(|(_, state)| state.status == target_status)
            .map(|(id, _)| *id)
            .collect())
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let mut states = self.states.write().await;
        let now = SystemTime::now();
        let mut removed = 0;

        states.retain(|_, state| {
            match &state.status {
                WorkflowStatus::Completed { .. } | WorkflowStatus::Failed { .. } | WorkflowStatus::Cancelled => {
                    if let Ok(age) = now.duration_since(state.updated_at) {
                        if age > retention {
                            removed += 1;
                            return false;
                        }
                    }
                }
                _ => {}
            }
            true
        });

        debug!("Cleaned up {} old workflow states", removed);
        Ok(removed)
    }
}

/// File-based state store for production use
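///
/// Each instance is stored as pretty-printed JSON in its own file under
/// `base_dir`, written to a temporary file and then renamed so a crash
/// mid-write cannot corrupt an existing checkpoint. A construction sketch
/// (`ignore`d; the directory is illustrative):
///
/// ```ignore
/// let store = FileStateStore::new(std::path::PathBuf::from("./workflow-state"))?;
/// store.save(&state).await?; // also bumps `checkpoint_version` and refreshes `updated_at`
/// ```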
pub struct FileStateStore {
    /// Base directory for state files
    base_dir: std::path::PathBuf,
    /// Per-instance locks serializing concurrent access to each state file
    locks: Arc<RwLock<HashMap<WorkflowId, Arc<tokio::sync::Mutex<()>>>>>,
}

impl FileStateStore {
    /// Create a new file-based state store
    pub fn new(base_dir: std::path::PathBuf) -> Result<Self, WorkflowError> {
        // Create base directory if it doesn't exist
        std::fs::create_dir_all(&base_dir).map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to create state directory: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check directory permissions".to_string()],
        })?;

        Ok(Self {
            base_dir,
            locks: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Get the file path for a workflow instance
    fn get_file_path(&self, instance_id: &WorkflowId) -> std::path::PathBuf {
        // Use full hex encoding of the WorkflowId bytes
        self.base_dir.join(format!("{}.json", hex::encode(&instance_id.0)))
    }

    /// Get or create a lock for a workflow instance
    async fn get_lock(&self, instance_id: &WorkflowId) -> Arc<tokio::sync::Mutex<()>> {
        let mut locks = self.locks.write().await;
        locks.entry(*instance_id)
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .clone()
    }
}

#[async_trait]
impl StateStore for FileStateStore {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        let lock = self.get_lock(&state.instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(&state.instance_id);
        let mut updated_state = state.clone();
        updated_state.updated_at = SystemTime::now();
        updated_state.checkpoint_version += 1;

        let json = serde_json::to_string_pretty(&updated_state).map_err(|e| WorkflowError {
            code: "SERIALIZATION_ERROR".to_string(),
            message: format!("Failed to serialize state: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })?;

        // Write to temporary file first
        let temp_path = path.with_extension("tmp");
        tokio::fs::write(&temp_path, json).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to write state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check disk space and permissions".to_string()],
        })?;

        // Atomically rename to final path
        tokio::fs::rename(&temp_path, &path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to rename state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check disk permissions".to_string()],
        })?;

        debug!("Saved state for workflow {} to {:?}", state.instance_id, path);
        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        let lock = self.get_lock(instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(instance_id);

        let json = tokio::fs::read_to_string(&path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check if workflow exists".to_string()],
        })?;

        let state = serde_json::from_str(&json).map_err(|e| WorkflowError {
            code: "DESERIALIZATION_ERROR".to_string(),
            message: format!("Failed to deserialize state: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["State file may be corrupted".to_string()],
        })?;

        debug!("Loaded state for workflow {} from {:?}", instance_id, path);
        Ok(state)
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        let lock = self.get_lock(instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(instance_id);

        tokio::fs::remove_file(&path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to delete state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check file permissions".to_string()],
        })?;

        // Remove lock
        let mut locks = self.locks.write().await;
        locks.remove(instance_id);

        debug!("Deleted state for workflow {} at {:?}", instance_id, path);
        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        let mut entries = tokio::fs::read_dir(&self.base_dir).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read state directory: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check directory permissions".to_string()],
        })?;

        let mut workflow_ids = Vec::new();

        while let Some(entry) = entries.next_entry().await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read directory entry: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })? {
            if let Some(name) = entry.file_name().to_str() {
                if name.ends_with(".json") {
                    // Parse workflow ID from filename
                    let id_str = &name[..name.len() - 5];
                    if let Ok(id_bytes) = hex::decode(id_str) {
                        if id_bytes.len() == 16 {
                            let mut id_array = [0u8; 16];
                            id_array.copy_from_slice(&id_bytes);
                            workflow_ids.push(WorkflowId(id_array));
                        }
                    }
                }
            }
        }

        Ok(workflow_ids)
    }

    async fn list_by_status(&self, target_status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        let all_ids = self.list().await?;
        let mut matching_ids = Vec::new();

        for id in all_ids {
            if let Ok(state) = self.load(&id).await {
                if state.status == target_status {
                    matching_ids.push(id);
                }
            }
        }

        Ok(matching_ids)
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let all_ids = self.list().await?;
        let now = SystemTime::now();
        let mut removed = 0;

        for id in all_ids {
            if let Ok(state) = self.load(&id).await {
                match &state.status {
                    WorkflowStatus::Completed { .. } | WorkflowStatus::Failed { .. } | WorkflowStatus::Cancelled => {
                        if let Ok(age) = now.duration_since(state.updated_at) {
                            if age > retention {
                                if self.delete(&id).await.is_ok() {
                                    removed += 1;
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }

        info!("Cleaned up {} old workflow states", removed);
        Ok(removed)
    }
}

/// State store with caching for improved performance
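///
/// Loads are served from an in-memory cache until an entry is older than
/// `ttl`; saves write through to the inner store and refresh the cache. A
/// sketch (`ignore`d; the wrapped store and TTL are illustrative):
///
/// ```ignore
/// let store = CachedStateStore::new(
///     FileStateStore::new(std::path::PathBuf::from("./workflow-state"))?,
///     Duration::from_secs(30),
/// );
/// let state = store.load(&instance_id).await?; // cache hit for up to 30 seconds
/// ```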
pub struct CachedStateStore<S: StateStore> {
    /// Underlying state store
    inner: S,
    /// In-memory cache
    cache: Arc<RwLock<HashMap<WorkflowId, (WorkflowState, SystemTime)>>>,
    /// Cache TTL
    ttl: Duration,
}

impl<S: StateStore> CachedStateStore<S> {
    /// Create a new cached state store
    pub fn new(inner: S, ttl: Duration) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
            ttl,
        }
    }

    /// Clean up expired cache entries
    pub async fn cleanup_cache(&self) {
        let mut cache = self.cache.write().await;
        let now = SystemTime::now();

        cache.retain(|_, (_, timestamp)| {
            if let Ok(age) = now.duration_since(*timestamp) {
                age < self.ttl
            } else {
                true
            }
        });
    }
}

#[async_trait]
impl<S: StateStore> StateStore for CachedStateStore<S> {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        // Save to underlying store
        self.inner.save(state).await?;

        // Update cache
        let mut cache = self.cache.write().await;
        cache.insert(state.instance_id, (state.clone(), SystemTime::now()));

        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        // Check cache first
        {
            let cache = self.cache.read().await;
            if let Some((state, timestamp)) = cache.get(instance_id) {
                if let Ok(age) = SystemTime::now().duration_since(*timestamp) {
                    if age < self.ttl {
                        return Ok(state.clone());
                    }
                }
            }
        }

        // Load from underlying store
        let state = self.inner.load(instance_id).await?;

        // Update cache
        let mut cache = self.cache.write().await;
        cache.insert(*instance_id, (state.clone(), SystemTime::now()));

        Ok(state)
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        // Delete from underlying store
        self.inner.delete(instance_id).await?;

        // Remove from cache
        let mut cache = self.cache.write().await;
        cache.remove(instance_id);

        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        self.inner.list().await
    }

    async fn list_by_status(&self, status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        self.inner.list_by_status(status).await
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let result = self.inner.cleanup(retention).await?;

        // Clean cache as well
        self.cleanup_cache().await;

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStateStore::new();

        let state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "test_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Running { current_stage: StageId("stage1".to_string()) },
            current_stage: StageId("stage1".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        // Save state
        store.save(&state).await.unwrap();

        // Load state
        let loaded = store.load(&state.instance_id).await.unwrap();
        assert_eq!(loaded.instance_id, state.instance_id);
        assert_eq!(loaded.workflow_id, state.workflow_id);

        // List workflows
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0], state.instance_id);

        // Delete state
        store.delete(&state.instance_id).await.unwrap();

        // Verify deleted
        assert!(store.load(&state.instance_id).await.is_err());
    }

    #[tokio::test]
    async fn test_file_store() {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = FileStateStore::new(temp_dir.path().to_path_buf()).unwrap();

        // Create an old completed workflow
        let old_state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "old_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Completed {
                result: crate::workflow::WorkflowResult {
                    output: HashMap::new(),
                    duration: Duration::from_secs(5),
                    metrics: WorkflowMetrics::default(),
                }
            },
            current_stage: StageId("final".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now() - Duration::from_secs(200),
            updated_at: SystemTime::now() - Duration::from_secs(200),
            checkpoint_version: 1,
        };

        // Save the old workflow using the store's save method first,
        // then manually update the file to have old timestamps
        store.save(&old_state).await.unwrap();

        // Now manually update the file's content to have old timestamps
        let path = store.get_file_path(&old_state.instance_id);
        let mut old_state_with_old_times = old_state.clone();
        old_state_with_old_times.updated_at = SystemTime::now() - Duration::from_secs(200);
        let json = serde_json::to_string_pretty(&old_state_with_old_times).unwrap();
        tokio::fs::write(&path, json).await.unwrap();

        // Create a new workflow to verify it's not deleted
        let new_state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "new_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Completed {
                result: crate::workflow::WorkflowResult {
                    output: HashMap::new(),
                    duration: Duration::from_secs(5),
                    metrics: WorkflowMetrics::default(),
                }
            },
            current_stage: StageId("final".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        // Save the new workflow normally
        store.save(&new_state).await.unwrap();

        // Verify both exist
        assert_eq!(store.list().await.unwrap().len(), 2);

        // Cleanup old workflows (older than 100 seconds)
        let removed = store.cleanup(Duration::from_secs(100)).await.unwrap();
        assert_eq!(removed, 1);

        // Verify only the new workflow remains
        let remaining = store.list().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], new_state.instance_id);
    }
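
    // A sketch of a round-trip through `CachedStateStore` over the in-memory
    // backend; field values mirror the tests above and are illustrative.
    #[tokio::test]
    async fn test_cached_store_roundtrip() {
        let store = CachedStateStore::new(InMemoryStateStore::new(), Duration::from_secs(60));

        let state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "cached_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Running { current_stage: StageId("stage1".to_string()) },
            current_stage: StageId("stage1".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        // Save writes through to the inner store and populates the cache.
        store.save(&state).await.unwrap();

        // A load within the TTL is served from the cache.
        let loaded = store.load(&state.instance_id).await.unwrap();
        assert_eq!(loaded.instance_id, state.instance_id);

        // Delete removes the entry from both the inner store and the cache.
        store.delete(&state.instance_id).await.unwrap();
        assert!(store.load(&state.instance_id).await.is_err());
    }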
}