graph_flow/
storage_postgres.rs

1use async_trait::async_trait;
2use serde_json;
3use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
4use std::sync::Arc;
5
6use crate::{Session, error::{Result, GraphError}, storage::SessionStorage};
7
8pub struct PostgresSessionStorage {
9    pool: Arc<Pool<Postgres>>,
10}
11
12impl PostgresSessionStorage {
13    pub async fn connect(database_url: &str) -> Result<Self> {
14        let pool = PgPoolOptions::new()
15            .max_connections(5)
16            .connect(database_url)
17            .await
18            .map_err(|e| GraphError::StorageError(format!("Failed to connect to Postgres: {e}")))?;
19
20        Self::migrate(&pool).await?;
21        Ok(Self { pool: Arc::new(pool) })
22    }
23
24    async fn migrate(pool: &Pool<Postgres>) -> Result<()> {
25        sqlx::query(
26            r#"
27            CREATE TABLE IF NOT EXISTS sessions (
28                id UUID PRIMARY KEY,
29                graph_id TEXT NOT NULL,
30                current_task_id TEXT NOT NULL,
31                status_message TEXT,
32                context JSONB NOT NULL,
33                created_at TIMESTAMPTZ DEFAULT NOW(),
34                updated_at TIMESTAMPTZ DEFAULT NOW()
35            );
36            "#,
37        )
38        .execute(pool)
39        .await
40        .map_err(|e| GraphError::StorageError(format!("Migration failed: {e}")))?;
41        Ok(())
42    }
43}
44
45#[async_trait]
46impl SessionStorage for PostgresSessionStorage {
47    async fn save(&self, session: Session) -> Result<()> {
48        let context_json = serde_json::to_value(&session.context)
49            .map_err(|e| GraphError::StorageError(format!("Context serialization failed: {e}")))?;
50
51        // Use a transaction to ensure atomicity
52        let mut tx = self.pool.begin().await
53            .map_err(|e| GraphError::StorageError(format!("Failed to start transaction: {e}")))?;
54
55        sqlx::query(
56            r#"
57            INSERT INTO sessions (id, graph_id, current_task_id, status_message, context, updated_at)
58            VALUES ($1::uuid, $2, $3, $4, $5, NOW())
59            ON CONFLICT (id) DO UPDATE
60            SET graph_id = EXCLUDED.graph_id,
61                current_task_id = EXCLUDED.current_task_id,
62                status_message = EXCLUDED.status_message,
63                context = EXCLUDED.context,
64                updated_at = NOW()
65            WHERE sessions.updated_at <= EXCLUDED.updated_at  -- Prevent overwriting newer data
66            "#,
67        )
68        .bind(&session.id)
69        .bind(&session.graph_id)
70        .bind(&session.current_task_id)
71        .bind(&session.status_message)
72        .bind(&context_json)
73        .execute(&mut *tx)
74        .await
75        .map_err(|e| GraphError::StorageError(format!("Failed to save session: {e}")))?;
76        
77        tx.commit().await
78            .map_err(|e| GraphError::StorageError(format!("Failed to commit transaction: {e}")))?;
79        
80        Ok(())
81    }
82
83    async fn get(&self, id: &str) -> Result<Option<Session>> {
84        let row = sqlx::query_as::<_, (String, String, String, Option<String>, serde_json::Value)>(
85            r#"
86            SELECT id::text, graph_id, current_task_id, status_message, context
87            FROM sessions
88            WHERE id = $1::uuid
89            "#,
90        )
91        .bind(id)
92        .fetch_optional(&*self.pool)
93        .await
94        .map_err(|e| GraphError::StorageError(format!("Failed to fetch session: {e}")))?;
95
96        if let Some((session_id, graph_id, current_task_id, status_message, context_json)) = row {
97            let context: crate::Context = serde_json::from_value(context_json)
98                .map_err(|e| GraphError::StorageError(format!("Context deserialization failed: {e}")))?;
99            Ok(Some(Session {
100                id: session_id,
101                graph_id,
102                current_task_id,
103                status_message,
104                context,
105            }))
106        } else {
107            Ok(None)
108        }
109    }
110
111    async fn delete(&self, id: &str) -> Result<()> {
112        sqlx::query(
113            r#"
114            DELETE FROM sessions WHERE id = $1::uuid
115            "#,
116        )
117        .bind(id)
118        .execute(&*self.pool)
119        .await
120        .map_err(|e| GraphError::StorageError(format!("Failed to delete session: {e}")))?;
121        Ok(())
122    }
123}