Skip to main content

engram/
message_postgres.rs

1//! PostgreSQL-backed `MessageStore` implementation.
2//!
3//! Uses native Postgres types: `UUID` for ids, `TIMESTAMPTZ` for timestamps,
4//! `JSONB` for metadata. Placeholder syntax uses `$1, $2, …`.
5
6use crate::message::{ChatMessage, MessageId, MessageStore};
7use crate::scope::Scope;
8use crate::store::MemoryError;
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use sqlx::PgPool;
12use uuid::Uuid;
13
14// ---------------------------------------------------------------------------
15// DDL
16// ---------------------------------------------------------------------------
17
/// DDL statements for the Postgres messages table. Each element is a single
/// statement to be executed independently.
///
/// `PostgresMessageStore::migrate` runs these one after another, in array
/// order. Every statement uses `IF NOT EXISTS`, so re-running them against an
/// existing schema is a no-op. The flip side is that editing a column
/// definition here will NOT alter a `messages` table that already exists.
///
/// `metadata` is `NOT NULL` and defaults to the JSON value `null` (not SQL
/// `NULL`). Only `org_id` and `user_id` of a scope have columns here.
const PG_MESSAGE_STORE_DDL: &[&str] = &[
    r#"
    CREATE TABLE IF NOT EXISTS messages (
        id              UUID PRIMARY KEY,
        conversation_id TEXT NOT NULL,
        role            TEXT NOT NULL,
        content         TEXT NOT NULL,
        org_id          TEXT NOT NULL DEFAULT 'default',
        user_id         TEXT,
        seq             INTEGER NOT NULL,
        created_at      TIMESTAMPTZ NOT NULL,
        metadata        JSONB NOT NULL DEFAULT 'null'
    )
    "#,
    // Presumably backs the per-conversation reads that `ORDER BY seq`.
    "CREATE INDEX IF NOT EXISTS idx_pg_messages_conversation ON messages (conversation_id, seq)",
    // NOTE(review): the queries in this file filter on org_id but never on
    // user_id; confirm this composite index is still the one wanted.
    "CREATE INDEX IF NOT EXISTS idx_pg_messages_org_user ON messages (org_id, user_id)",
];
37
38// ---------------------------------------------------------------------------
39// PostgresMessageStore
40// ---------------------------------------------------------------------------
41
42pub struct PostgresMessageStore {
43    pool: PgPool,
44}
45
46impl PostgresMessageStore {
47    pub fn new(pool: PgPool) -> Self {
48        Self { pool }
49    }
50
51    /// Open a connection pool from a database URL and return a store.
52    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
53        let pool = PgPool::connect(database_url).await?;
54        Ok(Self { pool })
55    }
56
57    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
58    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
59        for stmt in PG_MESSAGE_STORE_DDL {
60            sqlx::query(stmt).execute(&self.pool).await?;
61        }
62        Ok(())
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Internal row type
68// ---------------------------------------------------------------------------
69
70#[derive(sqlx::FromRow)]
71struct MessageRow {
72    id: Uuid,
73    conversation_id: String,
74    role: String,
75    content: String,
76    org_id: String,
77    user_id: Option<String>,
78    seq: i32,
79    created_at: DateTime<Utc>,
80    metadata: serde_json::Value,
81}
82
83// ---------------------------------------------------------------------------
84// Conversion helpers
85// ---------------------------------------------------------------------------
86
87fn row_to_message(row: MessageRow) -> Result<ChatMessage, MemoryError> {
88    let metadata: serde_json::Map<String, serde_json::Value> = match &row.metadata {
89        serde_json::Value::Null => serde_json::Map::new(),
90        serde_json::Value::Object(map) => map.clone(),
91        other => serde_json::from_value(other.clone())
92            .map_err(|e| MemoryError::Serialization(e.to_string()))?,
93    };
94
95    Ok(ChatMessage {
96        id: row.id,
97        conversation_id: row.conversation_id,
98        role: row.role,
99        content: row.content,
100        scope: Scope {
101            org_id: row.org_id,
102            agent_id: None,
103            user_id: row.user_id,
104            session_id: None,
105        },
106        seq: row.seq,
107        created_at: row.created_at,
108        metadata,
109    })
110}
111
112// ---------------------------------------------------------------------------
113// MessageStore implementation
114// ---------------------------------------------------------------------------
115
116#[async_trait]
117impl MessageStore for PostgresMessageStore {
118    async fn save_messages(
119        &self,
120        conversation_id: &str,
121        messages: &[ChatMessage],
122        scope: &Scope,
123    ) -> Result<Vec<MessageId>, MemoryError> {
124        let mut ids = Vec::with_capacity(messages.len());
125
126        for msg in messages {
127            let metadata_json = if msg.metadata.is_empty() {
128                serde_json::Value::Null
129            } else {
130                serde_json::to_value(&msg.metadata)
131                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
132            };
133
134            sqlx::query(
135                r#"
136                INSERT INTO messages
137                    (id, conversation_id, role, content, org_id, user_id, seq, created_at, metadata)
138                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
139                "#,
140            )
141            .bind(msg.id)
142            .bind(conversation_id)
143            .bind(&msg.role)
144            .bind(&msg.content)
145            .bind(&scope.org_id)
146            .bind(scope.user_id.as_deref())
147            .bind(msg.seq)
148            .bind(msg.created_at)
149            .bind(&metadata_json)
150            .execute(&self.pool)
151            .await
152            .map_err(|e| MemoryError::Database(e.to_string()))?;
153
154            ids.push(msg.id);
155        }
156
157        Ok(ids)
158    }
159
160    async fn get_messages(
161        &self,
162        conversation_id: &str,
163        last_n: Option<usize>,
164        scope: &Scope,
165    ) -> Result<Vec<ChatMessage>, MemoryError> {
166        let rows = match last_n {
167            Some(n) => {
168                // Subquery: pick the last N rows by seq DESC, then re-order ASC.
169                let sql = r#"
170                    SELECT * FROM (
171                        SELECT * FROM messages
172                        WHERE conversation_id = $1 AND org_id = $2
173                        ORDER BY seq DESC
174                        LIMIT $3
175                    ) sub ORDER BY seq ASC
176                "#;
177
178                sqlx::query_as::<_, MessageRow>(sql)
179                    .bind(conversation_id)
180                    .bind(&scope.org_id)
181                    .bind(n as i64)
182                    .fetch_all(&self.pool)
183                    .await
184                    .map_err(|e| MemoryError::Database(e.to_string()))?
185            }
186            None => {
187                let sql = r#"
188                    SELECT * FROM messages
189                    WHERE conversation_id = $1 AND org_id = $2
190                    ORDER BY seq ASC
191                "#;
192
193                sqlx::query_as::<_, MessageRow>(sql)
194                    .bind(conversation_id)
195                    .bind(&scope.org_id)
196                    .fetch_all(&self.pool)
197                    .await
198                    .map_err(|e| MemoryError::Database(e.to_string()))?
199            }
200        };
201
202        rows.into_iter().map(row_to_message).collect()
203    }
204
205    async fn list_conversations(&self, scope: &Scope) -> Result<Vec<String>, MemoryError> {
206        let sql = r#"
207            SELECT DISTINCT conversation_id
208            FROM messages
209            WHERE org_id = $1
210            ORDER BY conversation_id
211        "#;
212
213        let rows: Vec<(String,)> = sqlx::query_as(sql)
214            .bind(&scope.org_id)
215            .fetch_all(&self.pool)
216            .await
217            .map_err(|e| MemoryError::Database(e.to_string()))?;
218
219        Ok(rows.into_iter().map(|(c,)| c).collect())
220    }
221
222    async fn delete_messages(
223        &self,
224        conversation_id: &str,
225        scope: &Scope,
226    ) -> Result<u64, MemoryError> {
227        let result = sqlx::query("DELETE FROM messages WHERE conversation_id = $1 AND org_id = $2")
228            .bind(conversation_id)
229            .bind(&scope.org_id)
230            .execute(&self.pool)
231            .await
232            .map_err(|e| MemoryError::Database(e.to_string()))?;
233
234        Ok(result.rows_affected())
235    }
236}