engram/
message_postgres.rs1use crate::message::{ChatMessage, MessageId, MessageStore};
7use crate::scope::Scope;
8use crate::store::MemoryError;
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use sqlx::PgPool;
12use uuid::Uuid;
13
/// DDL statements that set up the Postgres schema used by the message store.
///
/// The list is ordered: the first statement creates the `messages` table and
/// the following two add indexes on `(conversation_id, seq)` and
/// `(org_id, user_id)`. Every statement carries `IF NOT EXISTS`, so running
/// the list against a database that already has these objects does not fail
/// on their presence.
const PG_MESSAGE_STORE_DDL: &[&str] = &[
    // One row per chat message; `seq` orders messages inside a conversation.
    r#"
    CREATE TABLE IF NOT EXISTS messages (
        id UUID PRIMARY KEY,
        conversation_id TEXT NOT NULL,
        role TEXT NOT NULL,
        content TEXT NOT NULL,
        org_id TEXT NOT NULL DEFAULT 'default',
        user_id TEXT,
        seq INTEGER NOT NULL,
        created_at TIMESTAMPTZ NOT NULL,
        metadata JSONB NOT NULL DEFAULT 'null'
    )
    "#,
    // Index matching the `conversation_id` filter + `seq` ordering of reads.
    "CREATE INDEX IF NOT EXISTS idx_pg_messages_conversation ON messages (conversation_id, seq)",
    // Index for lookups by organisation (and user).
    "CREATE INDEX IF NOT EXISTS idx_pg_messages_org_user ON messages (org_id, user_id)",
];
37
38pub struct PostgresMessageStore {
43 pool: PgPool,
44}
45
46impl PostgresMessageStore {
47 pub fn new(pool: PgPool) -> Self {
48 Self { pool }
49 }
50
51 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
53 let pool = PgPool::connect(database_url).await?;
54 Ok(Self { pool })
55 }
56
57 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
59 for stmt in PG_MESSAGE_STORE_DDL {
60 sqlx::query(stmt).execute(&self.pool).await?;
61 }
62 Ok(())
63 }
64}
65
66#[derive(sqlx::FromRow)]
71struct MessageRow {
72 id: Uuid,
73 conversation_id: String,
74 role: String,
75 content: String,
76 org_id: String,
77 user_id: Option<String>,
78 seq: i32,
79 created_at: DateTime<Utc>,
80 metadata: serde_json::Value,
81}
82
83fn row_to_message(row: MessageRow) -> Result<ChatMessage, MemoryError> {
88 let metadata: serde_json::Map<String, serde_json::Value> = match &row.metadata {
89 serde_json::Value::Null => serde_json::Map::new(),
90 serde_json::Value::Object(map) => map.clone(),
91 other => serde_json::from_value(other.clone())
92 .map_err(|e| MemoryError::Serialization(e.to_string()))?,
93 };
94
95 Ok(ChatMessage {
96 id: row.id,
97 conversation_id: row.conversation_id,
98 role: row.role,
99 content: row.content,
100 scope: Scope {
101 org_id: row.org_id,
102 agent_id: None,
103 user_id: row.user_id,
104 session_id: None,
105 },
106 seq: row.seq,
107 created_at: row.created_at,
108 metadata,
109 })
110}
111
112#[async_trait]
117impl MessageStore for PostgresMessageStore {
118 async fn save_messages(
119 &self,
120 conversation_id: &str,
121 messages: &[ChatMessage],
122 scope: &Scope,
123 ) -> Result<Vec<MessageId>, MemoryError> {
124 let mut ids = Vec::with_capacity(messages.len());
125
126 for msg in messages {
127 let metadata_json = if msg.metadata.is_empty() {
128 serde_json::Value::Null
129 } else {
130 serde_json::to_value(&msg.metadata)
131 .map_err(|e| MemoryError::Serialization(e.to_string()))?
132 };
133
134 sqlx::query(
135 r#"
136 INSERT INTO messages
137 (id, conversation_id, role, content, org_id, user_id, seq, created_at, metadata)
138 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
139 "#,
140 )
141 .bind(msg.id)
142 .bind(conversation_id)
143 .bind(&msg.role)
144 .bind(&msg.content)
145 .bind(&scope.org_id)
146 .bind(scope.user_id.as_deref())
147 .bind(msg.seq)
148 .bind(msg.created_at)
149 .bind(&metadata_json)
150 .execute(&self.pool)
151 .await
152 .map_err(|e| MemoryError::Database(e.to_string()))?;
153
154 ids.push(msg.id);
155 }
156
157 Ok(ids)
158 }
159
160 async fn get_messages(
161 &self,
162 conversation_id: &str,
163 last_n: Option<usize>,
164 scope: &Scope,
165 ) -> Result<Vec<ChatMessage>, MemoryError> {
166 let rows = match last_n {
167 Some(n) => {
168 let sql = r#"
170 SELECT * FROM (
171 SELECT * FROM messages
172 WHERE conversation_id = $1 AND org_id = $2
173 ORDER BY seq DESC
174 LIMIT $3
175 ) sub ORDER BY seq ASC
176 "#;
177
178 sqlx::query_as::<_, MessageRow>(sql)
179 .bind(conversation_id)
180 .bind(&scope.org_id)
181 .bind(n as i64)
182 .fetch_all(&self.pool)
183 .await
184 .map_err(|e| MemoryError::Database(e.to_string()))?
185 }
186 None => {
187 let sql = r#"
188 SELECT * FROM messages
189 WHERE conversation_id = $1 AND org_id = $2
190 ORDER BY seq ASC
191 "#;
192
193 sqlx::query_as::<_, MessageRow>(sql)
194 .bind(conversation_id)
195 .bind(&scope.org_id)
196 .fetch_all(&self.pool)
197 .await
198 .map_err(|e| MemoryError::Database(e.to_string()))?
199 }
200 };
201
202 rows.into_iter().map(row_to_message).collect()
203 }
204
205 async fn list_conversations(&self, scope: &Scope) -> Result<Vec<String>, MemoryError> {
206 let sql = r#"
207 SELECT DISTINCT conversation_id
208 FROM messages
209 WHERE org_id = $1
210 ORDER BY conversation_id
211 "#;
212
213 let rows: Vec<(String,)> = sqlx::query_as(sql)
214 .bind(&scope.org_id)
215 .fetch_all(&self.pool)
216 .await
217 .map_err(|e| MemoryError::Database(e.to_string()))?;
218
219 Ok(rows.into_iter().map(|(c,)| c).collect())
220 }
221
222 async fn delete_messages(
223 &self,
224 conversation_id: &str,
225 scope: &Scope,
226 ) -> Result<u64, MemoryError> {
227 let result = sqlx::query("DELETE FROM messages WHERE conversation_id = $1 AND org_id = $2")
228 .bind(conversation_id)
229 .bind(&scope.org_id)
230 .execute(&self.pool)
231 .await
232 .map_err(|e| MemoryError::Database(e.to_string()))?;
233
234 Ok(result.rows_affected())
235 }
236}