graph_flow/
storage_postgres.rs1use async_trait::async_trait;
2use serde_json;
3use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
4use std::sync::Arc;
5
6use crate::{Session, error::{Result, GraphError}, storage::SessionStorage};
7
8pub struct PostgresSessionStorage {
9 pool: Arc<Pool<Postgres>>,
10}
11
12impl PostgresSessionStorage {
13 pub async fn connect(database_url: &str) -> Result<Self> {
14 let pool = PgPoolOptions::new()
15 .max_connections(5)
16 .connect(database_url)
17 .await
18 .map_err(|e| GraphError::StorageError(format!("Failed to connect to Postgres: {e}")))?;
19
20 Self::migrate(&pool).await?;
21 Ok(Self { pool: Arc::new(pool) })
22 }
23
24 async fn migrate(pool: &Pool<Postgres>) -> Result<()> {
25 sqlx::query(
26 r#"
27 CREATE TABLE IF NOT EXISTS sessions (
28 id UUID PRIMARY KEY,
29 graph_id TEXT NOT NULL,
30 current_task_id TEXT NOT NULL,
31 status_message TEXT,
32 context JSONB NOT NULL,
33 created_at TIMESTAMPTZ DEFAULT NOW(),
34 updated_at TIMESTAMPTZ DEFAULT NOW()
35 );
36 "#,
37 )
38 .execute(pool)
39 .await
40 .map_err(|e| GraphError::StorageError(format!("Migration failed: {e}")))?;
41 Ok(())
42 }
43}
44
45#[async_trait]
46impl SessionStorage for PostgresSessionStorage {
47 async fn save(&self, session: Session) -> Result<()> {
48 let context_json = serde_json::to_value(&session.context)
49 .map_err(|e| GraphError::StorageError(format!("Context serialization failed: {e}")))?;
50
51 let mut tx = self.pool.begin().await
53 .map_err(|e| GraphError::StorageError(format!("Failed to start transaction: {e}")))?;
54
55 sqlx::query(
56 r#"
57 INSERT INTO sessions (id, graph_id, current_task_id, status_message, context, updated_at)
58 VALUES ($1::uuid, $2, $3, $4, $5, NOW())
59 ON CONFLICT (id) DO UPDATE
60 SET graph_id = EXCLUDED.graph_id,
61 current_task_id = EXCLUDED.current_task_id,
62 status_message = EXCLUDED.status_message,
63 context = EXCLUDED.context,
64 updated_at = NOW()
65 WHERE sessions.updated_at <= EXCLUDED.updated_at -- Prevent overwriting newer data
66 "#,
67 )
68 .bind(&session.id)
69 .bind(&session.graph_id)
70 .bind(&session.current_task_id)
71 .bind(&session.status_message)
72 .bind(&context_json)
73 .execute(&mut *tx)
74 .await
75 .map_err(|e| GraphError::StorageError(format!("Failed to save session: {e}")))?;
76
77 tx.commit().await
78 .map_err(|e| GraphError::StorageError(format!("Failed to commit transaction: {e}")))?;
79
80 Ok(())
81 }
82
83 async fn get(&self, id: &str) -> Result<Option<Session>> {
84 let row = sqlx::query_as::<_, (String, String, String, Option<String>, serde_json::Value)>(
85 r#"
86 SELECT id::text, graph_id, current_task_id, status_message, context
87 FROM sessions
88 WHERE id = $1::uuid
89 "#,
90 )
91 .bind(id)
92 .fetch_optional(&*self.pool)
93 .await
94 .map_err(|e| GraphError::StorageError(format!("Failed to fetch session: {e}")))?;
95
96 if let Some((session_id, graph_id, current_task_id, status_message, context_json)) = row {
97 let context: crate::Context = serde_json::from_value(context_json)
98 .map_err(|e| GraphError::StorageError(format!("Context deserialization failed: {e}")))?;
99 Ok(Some(Session {
100 id: session_id,
101 graph_id,
102 current_task_id,
103 status_message,
104 context,
105 }))
106 } else {
107 Ok(None)
108 }
109 }
110
111 async fn delete(&self, id: &str) -> Result<()> {
112 sqlx::query(
113 r#"
114 DELETE FROM sessions WHERE id = $1::uuid
115 "#,
116 )
117 .bind(id)
118 .execute(&*self.pool)
119 .await
120 .map_err(|e| GraphError::StorageError(format!("Failed to delete session: {e}")))?;
121 Ok(())
122 }
123}