llm_orchestrator_state/
postgres.rs

1// Copyright (c) 2025 LLM DevOps
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! PostgreSQL implementation of the StateStore trait.
5
6use crate::models::{Checkpoint, StepState, WorkflowState, WorkflowStatus};
7use crate::traits::{StateStore, StateStoreError, StateStoreResult};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
11use sqlx::{ConnectOptions, PgPool, Row};
12use std::str::FromStr;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17/// PostgreSQL state store implementation.
18pub struct PostgresStateStore {
19    pool: PgPool,
20}
21
22impl PostgresStateStore {
23    /// Create a new PostgreSQL state store with connection pooling.
24    ///
25    /// # Arguments
26    /// * `database_url` - PostgreSQL connection string
27    /// * `min_connections` - Minimum number of connections in pool (default: 5)
28    /// * `max_connections` - Maximum number of connections in pool (default: 20)
29    ///
30    /// # Example
31    /// ```no_run
32    /// # use llm_orchestrator_state::postgres::PostgresStateStore;
33    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
34    /// let store = PostgresStateStore::new(
35    ///     "postgresql://user:pass@localhost/workflows",
36    ///     Some(5),
37    ///     Some(20),
38    /// ).await?;
39    /// # Ok(())
40    /// # }
41    /// ```
42    pub async fn new(
43        database_url: impl AsRef<str>,
44        min_connections: Option<u32>,
45        max_connections: Option<u32>,
46    ) -> StateStoreResult<Self> {
47        let min_conn = min_connections.unwrap_or(5);
48        let max_conn = max_connections.unwrap_or(20);
49
50        info!(
51            "Initializing PostgreSQL state store (min_connections={}, max_connections={})",
52            min_conn, max_conn
53        );
54
55        // Parse connection options
56        let mut connect_opts = PgConnectOptions::from_str(database_url.as_ref())
57            .map_err(|e| StateStoreError::Configuration(format!("Invalid database URL: {}", e)))?;
58
59        // Configure logging
60        connect_opts = connect_opts.log_statements(tracing::log::LevelFilter::Debug);
61
62        // Build connection pool
63        let pool = PgPoolOptions::new()
64            .min_connections(min_conn)
65            .max_connections(max_conn)
66            .acquire_timeout(Duration::from_secs(5))
67            .idle_timeout(Some(Duration::from_secs(300)))
68            .max_lifetime(Some(Duration::from_secs(1800)))
69            .connect_with(connect_opts)
70            .await
71            .map_err(|e| StateStoreError::Connection(format!("Failed to create connection pool: {}", e)))?;
72
73        info!("PostgreSQL connection pool established");
74
75        let store = Self { pool };
76
77        // Run migrations
78        store.run_migrations().await?;
79
80        Ok(store)
81    }
82
83    /// Run database migrations.
84    async fn run_migrations(&self) -> StateStoreResult<()> {
85        info!("Running database migrations");
86
87        // Read migration files
88        let migration_001 = include_str!("../migrations/001_initial_schema.sql");
89        let migration_002 = include_str!("../migrations/002_checkpoints.sql");
90
91        // Execute migrations
92        sqlx::query(migration_001)
93            .execute(&self.pool)
94            .await
95            .map_err(|e| StateStoreError::Database(format!("Migration 001 failed: {}", e)))?;
96
97        sqlx::query(migration_002)
98            .execute(&self.pool)
99            .await
100            .map_err(|e| StateStoreError::Database(format!("Migration 002 failed: {}", e)))?;
101
102        info!("Database migrations completed successfully");
103        Ok(())
104    }
105
106    /// Get the connection pool (for advanced use cases).
107    pub fn pool(&self) -> &PgPool {
108        &self.pool
109    }
110}
111
112#[async_trait]
113impl StateStore for PostgresStateStore {
114    async fn save_workflow_state(&self, state: &WorkflowState) -> StateStoreResult<()> {
115        debug!("Saving workflow state: id={}, workflow_id={}", state.id, state.workflow_id);
116
117        let mut tx = self.pool.begin().await?;
118
119        // Serialize context to JSON string
120        let context_json = serde_json::to_string(&state.context)?;
121
122        // Upsert workflow state
123        sqlx::query(
124            r#"
125            INSERT INTO workflow_states (
126                id, workflow_id, workflow_name, status, user_id,
127                started_at, updated_at, completed_at, context, error
128            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
129            ON CONFLICT (id) DO UPDATE SET
130                status = EXCLUDED.status,
131                updated_at = EXCLUDED.updated_at,
132                completed_at = EXCLUDED.completed_at,
133                context = EXCLUDED.context,
134                error = EXCLUDED.error
135            "#
136        )
137        .bind(state.id)
138        .bind(&state.workflow_id)
139        .bind(&state.workflow_name)
140        .bind(state.status.to_string())
141        .bind(&state.user_id)
142        .bind(state.started_at)
143        .bind(state.updated_at)
144        .bind(state.completed_at)
145        .bind(context_json)
146        .bind(&state.error)
147        .execute(&mut *tx)
148        .await?;
149
150        // Save step states
151        for (step_id, step_state) in &state.steps {
152            let outputs_json = serde_json::to_string(&step_state.outputs)?;
153
154            sqlx::query(
155                r#"
156                INSERT INTO step_states (
157                    workflow_state_id, step_id, status, started_at, completed_at,
158                    outputs, error, retry_count
159                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
160                ON CONFLICT (workflow_state_id, step_id) DO UPDATE SET
161                    status = EXCLUDED.status,
162                    started_at = EXCLUDED.started_at,
163                    completed_at = EXCLUDED.completed_at,
164                    outputs = EXCLUDED.outputs,
165                    error = EXCLUDED.error,
166                    retry_count = EXCLUDED.retry_count
167                "#
168            )
169            .bind(state.id)
170            .bind(step_id)
171            .bind(step_state.status.to_string())
172            .bind(step_state.started_at)
173            .bind(step_state.completed_at)
174            .bind(outputs_json)
175            .bind(&step_state.error)
176            .bind(step_state.retry_count)
177            .execute(&mut *tx)
178            .await?;
179        }
180
181        tx.commit().await?;
182
183        debug!("Workflow state saved successfully: id={}", state.id);
184        Ok(())
185    }
186
187    async fn load_workflow_state(&self, id: &Uuid) -> StateStoreResult<WorkflowState> {
188        debug!("Loading workflow state: id={}", id);
189
190        // Load workflow state
191        let row = sqlx::query(
192            r#"
193            SELECT id, workflow_id, workflow_name, status, user_id,
194                   started_at, updated_at, completed_at, context, error
195            FROM workflow_states
196            WHERE id = $1
197            "#
198        )
199        .bind(id)
200        .fetch_one(&self.pool)
201        .await?;
202
203        let workflow_id: Uuid = row.get("id");
204        let status_str: String = row.get("status");
205        let status = WorkflowStatus::from_str(&status_str)
206            .map_err(StateStoreError::InvalidState)?;
207
208        let context_str: String = row.get("context");
209        let context = serde_json::from_str(&context_str)?;
210
211        let mut state = WorkflowState {
212            id: workflow_id,
213            workflow_id: row.get("workflow_id"),
214            workflow_name: row.get("workflow_name"),
215            status,
216            user_id: row.get("user_id"),
217            started_at: row.get("started_at"),
218            updated_at: row.get("updated_at"),
219            completed_at: row.get("completed_at"),
220            context,
221            error: row.get("error"),
222            steps: Default::default(),
223        };
224
225        // Load step states
226        let step_rows = sqlx::query(
227            r#"
228            SELECT step_id, status, started_at, completed_at,
229                   outputs, error, retry_count
230            FROM step_states
231            WHERE workflow_state_id = $1
232            "#
233        )
234        .bind(workflow_id)
235        .fetch_all(&self.pool)
236        .await?;
237
238        for step_row in step_rows {
239            let step_id: String = step_row.get("step_id");
240            let status_str: String = step_row.get("status");
241            let status = crate::models::StepStatus::from_str(&status_str)
242                .map_err(StateStoreError::InvalidState)?;
243
244            let outputs_str: Option<String> = step_row.get("outputs");
245            let outputs = if let Some(json_str) = outputs_str {
246                serde_json::from_str(&json_str)?
247            } else {
248                serde_json::Value::Null
249            };
250
251            let step_state = StepState {
252                step_id: step_id.clone(),
253                status,
254                started_at: step_row.get("started_at"),
255                completed_at: step_row.get("completed_at"),
256                outputs,
257                error: step_row.get("error"),
258                retry_count: step_row.get("retry_count"),
259            };
260
261            state.steps.insert(step_id, step_state);
262        }
263
264        debug!("Workflow state loaded successfully: id={}", id);
265        Ok(state)
266    }
267
268    async fn load_workflow_state_by_workflow_id(&self, workflow_id: &str) -> StateStoreResult<WorkflowState> {
269        debug!("Loading workflow state by workflow_id: {}", workflow_id);
270
271        // Get the most recent state for this workflow_id
272        let row = sqlx::query(
273            r#"
274            SELECT id
275            FROM workflow_states
276            WHERE workflow_id = $1
277            ORDER BY updated_at DESC
278            LIMIT 1
279            "#
280        )
281        .bind(workflow_id)
282        .fetch_one(&self.pool)
283        .await?;
284
285        let id: Uuid = row.get("id");
286        self.load_workflow_state(&id).await
287    }
288
289    async fn list_active_workflows(&self) -> StateStoreResult<Vec<WorkflowState>> {
290        debug!("Listing active workflows");
291
292        let rows = sqlx::query(
293            r#"
294            SELECT id
295            FROM workflow_states
296            WHERE status IN ('running', 'pending', 'paused')
297            ORDER BY updated_at DESC
298            "#
299        )
300        .fetch_all(&self.pool)
301        .await?;
302
303        let mut workflows = Vec::new();
304        for row in rows {
305            let id: Uuid = row.get("id");
306            match self.load_workflow_state(&id).await {
307                Ok(state) => workflows.push(state),
308                Err(e) => {
309                    warn!("Failed to load workflow state {}: {}", id, e);
310                }
311            }
312        }
313
314        debug!("Found {} active workflows", workflows.len());
315        Ok(workflows)
316    }
317
318    async fn create_checkpoint(&self, checkpoint: &Checkpoint) -> StateStoreResult<()> {
319        debug!("Creating checkpoint: id={}, workflow_state_id={}", checkpoint.id, checkpoint.workflow_state_id);
320
321        let snapshot_json = serde_json::to_string(&checkpoint.snapshot)?;
322
323        sqlx::query(
324            r#"
325            INSERT INTO checkpoints (id, workflow_state_id, step_id, timestamp, snapshot)
326            VALUES ($1, $2, $3, $4, $5)
327            "#
328        )
329        .bind(checkpoint.id)
330        .bind(checkpoint.workflow_state_id)
331        .bind(&checkpoint.step_id)
332        .bind(checkpoint.timestamp)
333        .bind(snapshot_json)
334        .execute(&self.pool)
335        .await?;
336
337        // Cleanup old checkpoints (keep last 10)
338        self.cleanup_old_checkpoints(&checkpoint.workflow_state_id, 10).await?;
339
340        debug!("Checkpoint created successfully: id={}", checkpoint.id);
341        Ok(())
342    }
343
344    async fn get_latest_checkpoint(&self, workflow_state_id: &Uuid) -> StateStoreResult<Option<Checkpoint>> {
345        debug!("Getting latest checkpoint for workflow_state_id={}", workflow_state_id);
346
347        let row_opt = sqlx::query(
348            r#"
349            SELECT id, workflow_state_id, step_id, timestamp, snapshot
350            FROM checkpoints
351            WHERE workflow_state_id = $1
352            ORDER BY timestamp DESC
353            LIMIT 1
354            "#
355        )
356        .bind(workflow_state_id)
357        .fetch_optional(&self.pool)
358        .await?;
359
360        if let Some(row) = row_opt {
361            let snapshot_str: String = row.get("snapshot");
362            let snapshot = serde_json::from_str(&snapshot_str)?;
363
364            let checkpoint = Checkpoint {
365                id: row.get("id"),
366                workflow_state_id: row.get("workflow_state_id"),
367                step_id: row.get("step_id"),
368                timestamp: row.get("timestamp"),
369                snapshot,
370            };
371
372            debug!("Found latest checkpoint: id={}", checkpoint.id);
373            Ok(Some(checkpoint))
374        } else {
375            debug!("No checkpoints found for workflow_state_id={}", workflow_state_id);
376            Ok(None)
377        }
378    }
379
380    async fn restore_from_checkpoint(&self, checkpoint_id: &Uuid) -> StateStoreResult<WorkflowState> {
381        debug!("Restoring from checkpoint: id={}", checkpoint_id);
382
383        let row = sqlx::query(
384            r#"
385            SELECT snapshot
386            FROM checkpoints
387            WHERE id = $1
388            "#
389        )
390        .bind(checkpoint_id)
391        .fetch_one(&self.pool)
392        .await?;
393
394        let snapshot_str: String = row.get("snapshot");
395        let state: WorkflowState = serde_json::from_str(&snapshot_str)?;
396
397        debug!("Successfully restored state from checkpoint: id={}", checkpoint_id);
398        Ok(state)
399    }
400
401    async fn delete_old_states(&self, older_than: DateTime<Utc>) -> StateStoreResult<u64> {
402        debug!("Deleting states older than: {}", older_than);
403
404        let result = sqlx::query(
405            r#"
406            DELETE FROM workflow_states
407            WHERE updated_at < $1
408              AND status IN ('completed', 'failed')
409            "#
410        )
411        .bind(older_than)
412        .execute(&self.pool)
413        .await?;
414
415        let deleted = result.rows_affected();
416        debug!("Deleted {} old workflow states", deleted);
417        Ok(deleted)
418    }
419
420    async fn cleanup_old_checkpoints(&self, workflow_state_id: &Uuid, keep_count: usize) -> StateStoreResult<u64> {
421        debug!("Cleaning up old checkpoints for workflow_state_id={}, keeping last {}", workflow_state_id, keep_count);
422
423        // PostgreSQL approach: delete checkpoints not in the top N
424        let result = sqlx::query(
425            r#"
426            DELETE FROM checkpoints
427            WHERE workflow_state_id = $1
428              AND id NOT IN (
429                SELECT id FROM checkpoints
430                WHERE workflow_state_id = $1
431                ORDER BY timestamp DESC
432                LIMIT $2
433              )
434            "#
435        )
436        .bind(workflow_state_id)
437        .bind(keep_count as i64)
438        .execute(&self.pool)
439        .await?;
440
441        let deleted = result.rows_affected();
442        if deleted > 0 {
443            debug!("Cleaned up {} old checkpoints", deleted);
444        }
445        Ok(deleted)
446    }
447
448    async fn health_check(&self) -> StateStoreResult<()> {
449        debug!("Performing health check");
450
451        // Simple query to verify database connectivity
452        sqlx::query("SELECT 1")
453            .fetch_one(&self.pool)
454            .await
455            .map_err(|e| StateStoreError::Connection(format!("Health check failed: {}", e)))?;
456
457        debug!("Health check passed");
458        Ok(())
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::WorkflowState;
    use serde_json::json;

    // Integration tests require a running PostgreSQL instance
    // These are disabled by default - run with:
    // TEST_DATABASE_URL=postgresql://... cargo test -- --ignored

    /// Round-trips a workflow state through a real database: save, load by
    /// id, then confirm it shows up in the active-workflow listing.
    #[tokio::test]
    #[ignore]
    async fn test_postgres_state_store_integration() {
        let database_url = std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgresql://postgres:postgres@localhost/test_workflows".to_string());

        let store = PostgresStateStore::new(&database_url, Some(2), Some(5))
            .await
            .expect("Failed to create state store");

        // Test health check
        store.health_check().await.expect("Health check failed");

        // Create test workflow state
        let mut state = WorkflowState::new(
            "test-workflow-1",
            "Test Workflow",
            Some("user-123".to_string()),
            json!({"inputs": {"test": "value"}}),
        );
        state.mark_running();

        // Save state
        store.save_workflow_state(&state).await.expect("Failed to save state");

        // Load state
        let loaded = store.load_workflow_state(&state.id).await.expect("Failed to load state");
        assert_eq!(loaded.id, state.id);
        assert_eq!(loaded.workflow_id, state.workflow_id);
        assert_eq!(loaded.status, WorkflowStatus::Running);

        // List active workflows. Check for this specific workflow: a shared
        // test database may already contain unrelated active rows, so a
        // non-empty list alone proves nothing.
        let active = store.list_active_workflows().await.expect("Failed to list active workflows");
        assert!(
            active.iter().any(|wf| wf.id == state.id),
            "saved running workflow should appear in the active list"
        );

        println!("✅ PostgreSQL integration test passed");
    }
}