//! Storage backends for persisted workflow execution state.
//!
//! Provides in-memory, file-backed, and TTL-cached implementations of [`StateStore`].

use std::{
    collections::HashMap,
    sync::Arc,
    time::{Duration, SystemTime},
};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::workflow::{
    StageId, Version, WorkflowError, WorkflowId, WorkflowMetrics, WorkflowStatus,
};

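/// A serializable checkpoint of a single workflow instance's execution state.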
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
    /// Unique identifier of this workflow instance.
    pub instance_id: WorkflowId,
    /// Name of the workflow definition this instance was started from.
    pub workflow_id: String,
    /// Version of the workflow definition.
    pub version: Version,
    /// Current execution status.
    pub status: WorkflowStatus,
    /// Stage currently (or most recently) being executed.
    pub current_stage: StageId,
    /// Raw input provided to the workflow.
    pub input: HashMap<String, Vec<u8>>,
    /// Intermediate state produced during execution.
    pub state: HashMap<String, Vec<u8>>,
    /// Execution metrics collected so far.
    pub metrics: WorkflowMetrics,
    /// Number of retry attempts recorded per stage.
    pub retry_attempts: HashMap<StageId, u32>,
    /// When this instance was created.
    pub created_at: SystemTime,
    /// When this state was last updated.
    pub updated_at: SystemTime,
    /// Counter incremented on every persisted checkpoint.
    pub checkpoint_version: u64,
}

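/// Persistence backend for workflow checkpoints. Implementations must be
/// safe to share across tasks (`Send + Sync`).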
#[async_trait]
pub trait StateStore: Send + Sync {
    /// Persist the given state, replacing any previous checkpoint for the same instance.
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError>;

    /// Load the most recently saved state for an instance.
    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError>;

    /// Remove the persisted state for an instance.
    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError>;

    /// List the ids of all persisted workflow instances.
    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError>;

    /// List the ids of all instances currently in the given status.
    async fn list_by_status(&self, status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError>;

    /// Remove terminal workflows older than `retention`, returning how many were removed.
    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError>;
}

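/// A volatile [`StateStore`] that keeps all state in a process-local map,
/// intended for tests and single-process deployments.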
pub struct InMemoryStateStore {
    states: Arc<RwLock<HashMap<WorkflowId, WorkflowState>>>,
}

impl InMemoryStateStore {
    pub fn new() -> Self {
        Self {
            states: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

#[async_trait]
impl StateStore for InMemoryStateStore {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        let mut states = self.states.write().await;
        states.insert(state.instance_id, state.clone());
        debug!("Saved state for workflow {}", state.instance_id);
        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        let states = self.states.read().await;
        states.get(instance_id).cloned().ok_or_else(|| WorkflowError {
            code: "STATE_NOT_FOUND".to_string(),
            message: format!("State not found for workflow {}", instance_id),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        let mut states = self.states.write().await;
        states.remove(instance_id);
        debug!("Deleted state for workflow {}", instance_id);
        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        let states = self.states.read().await;
        Ok(states.keys().cloned().collect())
    }

    async fn list_by_status(&self, target_status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        let states = self.states.read().await;
        Ok(states
            .iter()
            .filter(|(_, state)| state.status == target_status)
            .map(|(id, _)| *id)
            .collect())
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let mut states = self.states.write().await;
        let now = SystemTime::now();
        let mut removed = 0;

        // Drop only terminal workflows whose last update is older than the retention window.
        states.retain(|_, state| {
            match &state.status {
                WorkflowStatus::Completed { .. } | WorkflowStatus::Failed { .. } | WorkflowStatus::Cancelled => {
                    if let Ok(age) = now.duration_since(state.updated_at) {
                        if age > retention {
                            removed += 1;
                            return false;
                        }
                    }
                }
                _ => {}
            }
            true
        });

        debug!("Cleaned up {} old workflow states", removed);
        Ok(removed)
    }
}

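/// A durable [`StateStore`] that persists each workflow as a JSON file under
/// `base_dir`. Writes go to a temporary file that is then renamed into place,
/// and a per-workflow mutex serializes concurrent access to each file.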
pub struct FileStateStore {
    /// Directory holding one JSON file per workflow instance.
    base_dir: std::path::PathBuf,
    /// Per-workflow mutexes used to serialize file access.
    locks: Arc<RwLock<HashMap<WorkflowId, Arc<tokio::sync::Mutex<()>>>>>,
}

impl FileStateStore {
    pub fn new(base_dir: std::path::PathBuf) -> Result<Self, WorkflowError> {
        std::fs::create_dir_all(&base_dir).map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to create state directory: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check directory permissions".to_string()],
        })?;

        Ok(Self {
            base_dir,
            locks: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    fn get_file_path(&self, instance_id: &WorkflowId) -> std::path::PathBuf {
        self.base_dir.join(format!("{}.json", hex::encode(&instance_id.0)))
    }

    async fn get_lock(&self, instance_id: &WorkflowId) -> Arc<tokio::sync::Mutex<()>> {
        let mut locks = self.locks.write().await;
        locks
            .entry(*instance_id)
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .clone()
    }
}

#[async_trait]
impl StateStore for FileStateStore {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        let lock = self.get_lock(&state.instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(&state.instance_id);
        let mut updated_state = state.clone();
        updated_state.updated_at = SystemTime::now();
        updated_state.checkpoint_version += 1;

        let json = serde_json::to_string_pretty(&updated_state).map_err(|e| WorkflowError {
            code: "SERIALIZATION_ERROR".to_string(),
            message: format!("Failed to serialize state: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })?;

        // Write to a temporary file and rename it into place so a crash never
        // leaves a partially written state file behind.
        let temp_path = path.with_extension("tmp");
        tokio::fs::write(&temp_path, json).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to write state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check disk space and permissions".to_string()],
        })?;

        tokio::fs::rename(&temp_path, &path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to rename state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check disk permissions".to_string()],
        })?;

        debug!("Saved state for workflow {} to {:?}", state.instance_id, path);
        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        let lock = self.get_lock(instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(instance_id);

        let json = tokio::fs::read_to_string(&path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check if workflow exists".to_string()],
        })?;

        let state = serde_json::from_str(&json).map_err(|e| WorkflowError {
            code: "DESERIALIZATION_ERROR".to_string(),
            message: format!("Failed to deserialize state: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["State file may be corrupted".to_string()],
        })?;

        debug!("Loaded state for workflow {} from {:?}", instance_id, path);
        Ok(state)
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        let lock = self.get_lock(instance_id).await;
        let _guard = lock.lock().await;

        let path = self.get_file_path(instance_id);

        tokio::fs::remove_file(&path).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to delete state file: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check file permissions".to_string()],
        })?;

        // Drop the per-workflow lock entry now that the file is gone.
        let mut locks = self.locks.write().await;
        locks.remove(instance_id);

        debug!("Deleted state for workflow {} at {:?}", instance_id, path);
        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        let mut entries = tokio::fs::read_dir(&self.base_dir).await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read state directory: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec!["Check directory permissions".to_string()],
        })?;

        let mut workflow_ids = Vec::new();

        while let Some(entry) = entries.next_entry().await.map_err(|e| WorkflowError {
            code: "STORAGE_ERROR".to_string(),
            message: format!("Failed to read directory entry: {}", e),
            stage: None,
            trace: None,
            recovery_hints: vec![],
        })? {
            if let Some(name) = entry.file_name().to_str() {
                if name.ends_with(".json") {
                    // Strip the ".json" suffix and decode the hex-encoded workflow id.
                    let id_str = &name[..name.len() - 5];
                    if let Ok(id_bytes) = hex::decode(id_str) {
                        if id_bytes.len() == 16 {
                            let mut id_array = [0u8; 16];
                            id_array.copy_from_slice(&id_bytes);
                            workflow_ids.push(WorkflowId(id_array));
                        }
                    }
                }
            }
        }

        Ok(workflow_ids)
    }

    async fn list_by_status(&self, target_status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        let all_ids = self.list().await?;
        let mut matching_ids = Vec::new();

        for id in all_ids {
            if let Ok(state) = self.load(&id).await {
                if state.status == target_status {
                    matching_ids.push(id);
                }
            }
        }

        Ok(matching_ids)
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let all_ids = self.list().await?;
        let now = SystemTime::now();
        let mut removed = 0;

        for id in all_ids {
            if let Ok(state) = self.load(&id).await {
                match &state.status {
                    WorkflowStatus::Completed { .. } | WorkflowStatus::Failed { .. } | WorkflowStatus::Cancelled => {
                        if let Ok(age) = now.duration_since(state.updated_at) {
                            if age > retention {
                                if self.delete(&id).await.is_ok() {
                                    removed += 1;
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }

        info!("Cleaned up {} old workflow states", removed);
        Ok(removed)
    }
}

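/// A [`StateStore`] decorator that adds a TTL-bounded read cache in front of
/// another store. Writes go through to the inner store and refresh the cache.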
pub struct CachedStateStore<S: StateStore> {
    /// Underlying store that remains the source of truth.
    inner: S,
    /// Recently seen states together with the time they were cached.
    cache: Arc<RwLock<HashMap<WorkflowId, (WorkflowState, SystemTime)>>>,
    /// How long a cache entry stays valid.
    ttl: Duration,
}

impl<S: StateStore> CachedStateStore<S> {
    pub fn new(inner: S, ttl: Duration) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
            ttl,
        }
    }

    /// Drop cache entries that have outlived the configured TTL.
    pub async fn cleanup_cache(&self) {
        let mut cache = self.cache.write().await;
        let now = SystemTime::now();

        cache.retain(|_, (_, timestamp)| {
            if let Ok(age) = now.duration_since(*timestamp) {
                age < self.ttl
            } else {
                true
            }
        });
    }
}

#[async_trait]
impl<S: StateStore> StateStore for CachedStateStore<S> {
    async fn save(&self, state: &WorkflowState) -> Result<(), WorkflowError> {
        self.inner.save(state).await?;

        let mut cache = self.cache.write().await;
        cache.insert(state.instance_id, (state.clone(), SystemTime::now()));

        Ok(())
    }

    async fn load(&self, instance_id: &WorkflowId) -> Result<WorkflowState, WorkflowError> {
        // Serve from the cache while a fresh entry exists.
        {
            let cache = self.cache.read().await;
            if let Some((state, timestamp)) = cache.get(instance_id) {
                if let Ok(age) = SystemTime::now().duration_since(*timestamp) {
                    if age < self.ttl {
                        return Ok(state.clone());
                    }
                }
            }
        }

        // Otherwise fall through to the inner store and refresh the cache.
        let state = self.inner.load(instance_id).await?;

        let mut cache = self.cache.write().await;
        cache.insert(*instance_id, (state.clone(), SystemTime::now()));

        Ok(state)
    }

    async fn delete(&self, instance_id: &WorkflowId) -> Result<(), WorkflowError> {
        self.inner.delete(instance_id).await?;

        let mut cache = self.cache.write().await;
        cache.remove(instance_id);

        Ok(())
    }

    async fn list(&self) -> Result<Vec<WorkflowId>, WorkflowError> {
        self.inner.list().await
    }

    async fn list_by_status(&self, status: WorkflowStatus) -> Result<Vec<WorkflowId>, WorkflowError> {
        self.inner.list_by_status(status).await
    }

    async fn cleanup(&self, retention: Duration) -> Result<u64, WorkflowError> {
        let result = self.inner.cleanup(retention).await?;

        self.cleanup_cache().await;

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStateStore::new();

        let state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "test_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Running { current_stage: StageId("stage1".to_string()) },
            current_stage: StageId("stage1".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        store.save(&state).await.unwrap();

        let loaded = store.load(&state.instance_id).await.unwrap();
        assert_eq!(loaded.instance_id, state.instance_id);
        assert_eq!(loaded.workflow_id, state.workflow_id);

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0], state.instance_id);

        store.delete(&state.instance_id).await.unwrap();

        assert!(store.load(&state.instance_id).await.is_err());
    }

    #[tokio::test]
    async fn test_file_store() {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = FileStateStore::new(temp_dir.path().to_path_buf()).unwrap();

        let old_state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "old_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Completed {
                result: crate::workflow::WorkflowResult {
                    output: HashMap::new(),
                    duration: Duration::from_secs(5),
                    metrics: WorkflowMetrics::default(),
                },
            },
            current_stage: StageId("final".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now() - Duration::from_secs(200),
            updated_at: SystemTime::now() - Duration::from_secs(200),
            checkpoint_version: 1,
        };

        store.save(&old_state).await.unwrap();

        // save() refreshes updated_at, so rewrite the file directly to simulate an old state.
        let path = store.get_file_path(&old_state.instance_id);
        let mut old_state_with_old_times = old_state.clone();
        old_state_with_old_times.updated_at = SystemTime::now() - Duration::from_secs(200);
        let json = serde_json::to_string_pretty(&old_state_with_old_times).unwrap();
        tokio::fs::write(&path, json).await.unwrap();

        let new_state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "new_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Completed {
                result: crate::workflow::WorkflowResult {
                    output: HashMap::new(),
                    duration: Duration::from_secs(5),
                    metrics: WorkflowMetrics::default(),
                },
            },
            current_stage: StageId("final".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        store.save(&new_state).await.unwrap();

        assert_eq!(store.list().await.unwrap().len(), 2);

        let removed = store.cleanup(Duration::from_secs(100)).await.unwrap();
        assert_eq!(removed, 1);

        let remaining = store.list().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], new_state.instance_id);
    }
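
    // A minimal sketch exercising CachedStateStore layered over InMemoryStateStore.
    // It reuses the same WorkflowState construction as the tests above; the
    // 60-second TTL is an arbitrary illustrative value.
    #[tokio::test]
    async fn test_cached_store() {
        let store = CachedStateStore::new(InMemoryStateStore::new(), Duration::from_secs(60));

        let state = WorkflowState {
            instance_id: WorkflowId::generate(),
            workflow_id: "cached_workflow".to_string(),
            version: Version { major: 1, minor: 0, patch: 0 },
            status: WorkflowStatus::Running { current_stage: StageId("stage1".to_string()) },
            current_stage: StageId("stage1".to_string()),
            input: HashMap::new(),
            state: HashMap::new(),
            metrics: WorkflowMetrics::default(),
            retry_attempts: HashMap::new(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
            checkpoint_version: 1,
        };

        // save() writes through to the inner store and populates the cache.
        store.save(&state).await.unwrap();

        // load() returns the same contents while the cache entry is fresh.
        let loaded = store.load(&state.instance_id).await.unwrap();
        assert_eq!(loaded.workflow_id, state.workflow_id);

        // delete() evicts the cached entry along with the inner store's copy.
        store.delete(&state.instance_id).await.unwrap();
        assert!(store.load(&state.instance_id).await.is_err());
    }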
}