1use crate::models::{Checkpoint, StepState, WorkflowState, WorkflowStatus};
7use crate::traits::{StateStore, StateStoreError, StateStoreResult};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
11use sqlx::{ConnectOptions, PgPool, Row};
12use std::str::FromStr;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17pub struct PostgresStateStore {
19 pool: PgPool,
20}
21
22impl PostgresStateStore {
23 pub async fn new(
43 database_url: impl AsRef<str>,
44 min_connections: Option<u32>,
45 max_connections: Option<u32>,
46 ) -> StateStoreResult<Self> {
47 let min_conn = min_connections.unwrap_or(5);
48 let max_conn = max_connections.unwrap_or(20);
49
50 info!(
51 "Initializing PostgreSQL state store (min_connections={}, max_connections={})",
52 min_conn, max_conn
53 );
54
55 let mut connect_opts = PgConnectOptions::from_str(database_url.as_ref())
57 .map_err(|e| StateStoreError::Configuration(format!("Invalid database URL: {}", e)))?;
58
59 connect_opts = connect_opts.log_statements(tracing::log::LevelFilter::Debug);
61
62 let pool = PgPoolOptions::new()
64 .min_connections(min_conn)
65 .max_connections(max_conn)
66 .acquire_timeout(Duration::from_secs(5))
67 .idle_timeout(Some(Duration::from_secs(300)))
68 .max_lifetime(Some(Duration::from_secs(1800)))
69 .connect_with(connect_opts)
70 .await
71 .map_err(|e| StateStoreError::Connection(format!("Failed to create connection pool: {}", e)))?;
72
73 info!("PostgreSQL connection pool established");
74
75 let store = Self { pool };
76
77 store.run_migrations().await?;
79
80 Ok(store)
81 }
82
83 async fn run_migrations(&self) -> StateStoreResult<()> {
85 info!("Running database migrations");
86
87 let migration_001 = include_str!("../migrations/001_initial_schema.sql");
89 let migration_002 = include_str!("../migrations/002_checkpoints.sql");
90
91 sqlx::query(migration_001)
93 .execute(&self.pool)
94 .await
95 .map_err(|e| StateStoreError::Database(format!("Migration 001 failed: {}", e)))?;
96
97 sqlx::query(migration_002)
98 .execute(&self.pool)
99 .await
100 .map_err(|e| StateStoreError::Database(format!("Migration 002 failed: {}", e)))?;
101
102 info!("Database migrations completed successfully");
103 Ok(())
104 }
105
106 pub fn pool(&self) -> &PgPool {
108 &self.pool
109 }
110}
111
112#[async_trait]
113impl StateStore for PostgresStateStore {
114 async fn save_workflow_state(&self, state: &WorkflowState) -> StateStoreResult<()> {
115 debug!("Saving workflow state: id={}, workflow_id={}", state.id, state.workflow_id);
116
117 let mut tx = self.pool.begin().await?;
118
119 let context_json = serde_json::to_string(&state.context)?;
121
122 sqlx::query(
124 r#"
125 INSERT INTO workflow_states (
126 id, workflow_id, workflow_name, status, user_id,
127 started_at, updated_at, completed_at, context, error
128 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
129 ON CONFLICT (id) DO UPDATE SET
130 status = EXCLUDED.status,
131 updated_at = EXCLUDED.updated_at,
132 completed_at = EXCLUDED.completed_at,
133 context = EXCLUDED.context,
134 error = EXCLUDED.error
135 "#
136 )
137 .bind(state.id)
138 .bind(&state.workflow_id)
139 .bind(&state.workflow_name)
140 .bind(state.status.to_string())
141 .bind(&state.user_id)
142 .bind(state.started_at)
143 .bind(state.updated_at)
144 .bind(state.completed_at)
145 .bind(context_json)
146 .bind(&state.error)
147 .execute(&mut *tx)
148 .await?;
149
150 for (step_id, step_state) in &state.steps {
152 let outputs_json = serde_json::to_string(&step_state.outputs)?;
153
154 sqlx::query(
155 r#"
156 INSERT INTO step_states (
157 workflow_state_id, step_id, status, started_at, completed_at,
158 outputs, error, retry_count
159 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
160 ON CONFLICT (workflow_state_id, step_id) DO UPDATE SET
161 status = EXCLUDED.status,
162 started_at = EXCLUDED.started_at,
163 completed_at = EXCLUDED.completed_at,
164 outputs = EXCLUDED.outputs,
165 error = EXCLUDED.error,
166 retry_count = EXCLUDED.retry_count
167 "#
168 )
169 .bind(state.id)
170 .bind(step_id)
171 .bind(step_state.status.to_string())
172 .bind(step_state.started_at)
173 .bind(step_state.completed_at)
174 .bind(outputs_json)
175 .bind(&step_state.error)
176 .bind(step_state.retry_count)
177 .execute(&mut *tx)
178 .await?;
179 }
180
181 tx.commit().await?;
182
183 debug!("Workflow state saved successfully: id={}", state.id);
184 Ok(())
185 }
186
187 async fn load_workflow_state(&self, id: &Uuid) -> StateStoreResult<WorkflowState> {
188 debug!("Loading workflow state: id={}", id);
189
190 let row = sqlx::query(
192 r#"
193 SELECT id, workflow_id, workflow_name, status, user_id,
194 started_at, updated_at, completed_at, context, error
195 FROM workflow_states
196 WHERE id = $1
197 "#
198 )
199 .bind(id)
200 .fetch_one(&self.pool)
201 .await?;
202
203 let workflow_id: Uuid = row.get("id");
204 let status_str: String = row.get("status");
205 let status = WorkflowStatus::from_str(&status_str)
206 .map_err(StateStoreError::InvalidState)?;
207
208 let context_str: String = row.get("context");
209 let context = serde_json::from_str(&context_str)?;
210
211 let mut state = WorkflowState {
212 id: workflow_id,
213 workflow_id: row.get("workflow_id"),
214 workflow_name: row.get("workflow_name"),
215 status,
216 user_id: row.get("user_id"),
217 started_at: row.get("started_at"),
218 updated_at: row.get("updated_at"),
219 completed_at: row.get("completed_at"),
220 context,
221 error: row.get("error"),
222 steps: Default::default(),
223 };
224
225 let step_rows = sqlx::query(
227 r#"
228 SELECT step_id, status, started_at, completed_at,
229 outputs, error, retry_count
230 FROM step_states
231 WHERE workflow_state_id = $1
232 "#
233 )
234 .bind(workflow_id)
235 .fetch_all(&self.pool)
236 .await?;
237
238 for step_row in step_rows {
239 let step_id: String = step_row.get("step_id");
240 let status_str: String = step_row.get("status");
241 let status = crate::models::StepStatus::from_str(&status_str)
242 .map_err(StateStoreError::InvalidState)?;
243
244 let outputs_str: Option<String> = step_row.get("outputs");
245 let outputs = if let Some(json_str) = outputs_str {
246 serde_json::from_str(&json_str)?
247 } else {
248 serde_json::Value::Null
249 };
250
251 let step_state = StepState {
252 step_id: step_id.clone(),
253 status,
254 started_at: step_row.get("started_at"),
255 completed_at: step_row.get("completed_at"),
256 outputs,
257 error: step_row.get("error"),
258 retry_count: step_row.get("retry_count"),
259 };
260
261 state.steps.insert(step_id, step_state);
262 }
263
264 debug!("Workflow state loaded successfully: id={}", id);
265 Ok(state)
266 }
267
268 async fn load_workflow_state_by_workflow_id(&self, workflow_id: &str) -> StateStoreResult<WorkflowState> {
269 debug!("Loading workflow state by workflow_id: {}", workflow_id);
270
271 let row = sqlx::query(
273 r#"
274 SELECT id
275 FROM workflow_states
276 WHERE workflow_id = $1
277 ORDER BY updated_at DESC
278 LIMIT 1
279 "#
280 )
281 .bind(workflow_id)
282 .fetch_one(&self.pool)
283 .await?;
284
285 let id: Uuid = row.get("id");
286 self.load_workflow_state(&id).await
287 }
288
289 async fn list_active_workflows(&self) -> StateStoreResult<Vec<WorkflowState>> {
290 debug!("Listing active workflows");
291
292 let rows = sqlx::query(
293 r#"
294 SELECT id
295 FROM workflow_states
296 WHERE status IN ('running', 'pending', 'paused')
297 ORDER BY updated_at DESC
298 "#
299 )
300 .fetch_all(&self.pool)
301 .await?;
302
303 let mut workflows = Vec::new();
304 for row in rows {
305 let id: Uuid = row.get("id");
306 match self.load_workflow_state(&id).await {
307 Ok(state) => workflows.push(state),
308 Err(e) => {
309 warn!("Failed to load workflow state {}: {}", id, e);
310 }
311 }
312 }
313
314 debug!("Found {} active workflows", workflows.len());
315 Ok(workflows)
316 }
317
318 async fn create_checkpoint(&self, checkpoint: &Checkpoint) -> StateStoreResult<()> {
319 debug!("Creating checkpoint: id={}, workflow_state_id={}", checkpoint.id, checkpoint.workflow_state_id);
320
321 let snapshot_json = serde_json::to_string(&checkpoint.snapshot)?;
322
323 sqlx::query(
324 r#"
325 INSERT INTO checkpoints (id, workflow_state_id, step_id, timestamp, snapshot)
326 VALUES ($1, $2, $3, $4, $5)
327 "#
328 )
329 .bind(checkpoint.id)
330 .bind(checkpoint.workflow_state_id)
331 .bind(&checkpoint.step_id)
332 .bind(checkpoint.timestamp)
333 .bind(snapshot_json)
334 .execute(&self.pool)
335 .await?;
336
337 self.cleanup_old_checkpoints(&checkpoint.workflow_state_id, 10).await?;
339
340 debug!("Checkpoint created successfully: id={}", checkpoint.id);
341 Ok(())
342 }
343
344 async fn get_latest_checkpoint(&self, workflow_state_id: &Uuid) -> StateStoreResult<Option<Checkpoint>> {
345 debug!("Getting latest checkpoint for workflow_state_id={}", workflow_state_id);
346
347 let row_opt = sqlx::query(
348 r#"
349 SELECT id, workflow_state_id, step_id, timestamp, snapshot
350 FROM checkpoints
351 WHERE workflow_state_id = $1
352 ORDER BY timestamp DESC
353 LIMIT 1
354 "#
355 )
356 .bind(workflow_state_id)
357 .fetch_optional(&self.pool)
358 .await?;
359
360 if let Some(row) = row_opt {
361 let snapshot_str: String = row.get("snapshot");
362 let snapshot = serde_json::from_str(&snapshot_str)?;
363
364 let checkpoint = Checkpoint {
365 id: row.get("id"),
366 workflow_state_id: row.get("workflow_state_id"),
367 step_id: row.get("step_id"),
368 timestamp: row.get("timestamp"),
369 snapshot,
370 };
371
372 debug!("Found latest checkpoint: id={}", checkpoint.id);
373 Ok(Some(checkpoint))
374 } else {
375 debug!("No checkpoints found for workflow_state_id={}", workflow_state_id);
376 Ok(None)
377 }
378 }
379
380 async fn restore_from_checkpoint(&self, checkpoint_id: &Uuid) -> StateStoreResult<WorkflowState> {
381 debug!("Restoring from checkpoint: id={}", checkpoint_id);
382
383 let row = sqlx::query(
384 r#"
385 SELECT snapshot
386 FROM checkpoints
387 WHERE id = $1
388 "#
389 )
390 .bind(checkpoint_id)
391 .fetch_one(&self.pool)
392 .await?;
393
394 let snapshot_str: String = row.get("snapshot");
395 let state: WorkflowState = serde_json::from_str(&snapshot_str)?;
396
397 debug!("Successfully restored state from checkpoint: id={}", checkpoint_id);
398 Ok(state)
399 }
400
401 async fn delete_old_states(&self, older_than: DateTime<Utc>) -> StateStoreResult<u64> {
402 debug!("Deleting states older than: {}", older_than);
403
404 let result = sqlx::query(
405 r#"
406 DELETE FROM workflow_states
407 WHERE updated_at < $1
408 AND status IN ('completed', 'failed')
409 "#
410 )
411 .bind(older_than)
412 .execute(&self.pool)
413 .await?;
414
415 let deleted = result.rows_affected();
416 debug!("Deleted {} old workflow states", deleted);
417 Ok(deleted)
418 }
419
420 async fn cleanup_old_checkpoints(&self, workflow_state_id: &Uuid, keep_count: usize) -> StateStoreResult<u64> {
421 debug!("Cleaning up old checkpoints for workflow_state_id={}, keeping last {}", workflow_state_id, keep_count);
422
423 let result = sqlx::query(
425 r#"
426 DELETE FROM checkpoints
427 WHERE workflow_state_id = $1
428 AND id NOT IN (
429 SELECT id FROM checkpoints
430 WHERE workflow_state_id = $1
431 ORDER BY timestamp DESC
432 LIMIT $2
433 )
434 "#
435 )
436 .bind(workflow_state_id)
437 .bind(keep_count as i64)
438 .execute(&self.pool)
439 .await?;
440
441 let deleted = result.rows_affected();
442 if deleted > 0 {
443 debug!("Cleaned up {} old checkpoints", deleted);
444 }
445 Ok(deleted)
446 }
447
448 async fn health_check(&self) -> StateStoreResult<()> {
449 debug!("Performing health check");
450
451 sqlx::query("SELECT 1")
453 .fetch_one(&self.pool)
454 .await
455 .map_err(|e| StateStoreError::Connection(format!("Health check failed: {}", e)))?;
456
457 debug!("Health check passed");
458 Ok(())
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::WorkflowState;
    use serde_json::json;

    /// Connection string used when `TEST_DATABASE_URL` is not set.
    const DEFAULT_TEST_DATABASE_URL: &str =
        "postgresql://postgres:postgres@localhost/test_workflows";

    /// Resolves the integration database URL, preferring `TEST_DATABASE_URL`.
    fn test_database_url() -> String {
        std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| DEFAULT_TEST_DATABASE_URL.to_string())
    }

    /// Builds the workflow fixture saved by the round-trip test, already
    /// marked as running.
    fn running_fixture() -> WorkflowState {
        let mut fixture = WorkflowState::new(
            "test-workflow-1",
            "Test Workflow",
            Some("user-123".to_string()),
            json!({"inputs": {"test": "value"}}),
        );
        fixture.mark_running();
        fixture
    }

    /// End-to-end round trip against a live PostgreSQL instance.
    ///
    /// Connects, health-checks, saves a running workflow, reloads it and lists
    /// active workflows. Ignored by default because it needs a reachable
    /// database; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_postgres_state_store_integration() {
        let url = test_database_url();
        let store = PostgresStateStore::new(&url, Some(2), Some(5))
            .await
            .expect("Failed to create state store");
        store.health_check().await.expect("Health check failed");

        let saved = running_fixture();
        store
            .save_workflow_state(&saved)
            .await
            .expect("Failed to save state");

        let reloaded = store
            .load_workflow_state(&saved.id)
            .await
            .expect("Failed to load state");
        assert_eq!(reloaded.workflow_id, saved.workflow_id);
        assert_eq!(reloaded.status, WorkflowStatus::Running);

        let active = store
            .list_active_workflows()
            .await
            .expect("Failed to list active workflows");
        assert!(!active.is_empty());

        println!("✅ PostgreSQL integration test passed");
    }
}