Skip to main content

engram/
message_sqlite.rs

1//! SQLite-backed `MessageStore` implementation.
2//!
3//! All `DateTime<Utc>` values are stored as RFC 3339 strings.
4//! UUIDs are stored as TEXT. `metadata` is a JSON object string (or `"null"`
5//! when empty).
6
7use crate::message::{ChatMessage, MessageId, MessageStore};
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use uuid::Uuid;
14
15// ---------------------------------------------------------------------------
16// DDL
17// ---------------------------------------------------------------------------
18
/// Schema for the `messages` table and its secondary indexes.
///
/// [`SqliteMessageStore::migrate`] splits this text on `;` and executes each
/// piece separately, so no statement may contain a literal `;` (for example
/// inside a string default).
///
/// Column encodings:
/// - `id`: the message UUID rendered as text (primary key).
/// - `seq`: ordering key; reads order by it within a conversation.
/// - `created_at`: RFC 3339 timestamp string.
/// - `metadata`: JSON object text, or the literal `null` when empty.
///
/// NOTE(review): nothing enforces a unique `(conversation_id, seq)` pair;
/// confirm callers guarantee it if ordering must be total.
pub const MESSAGE_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS messages (
    id              TEXT PRIMARY KEY,
    conversation_id TEXT NOT NULL,
    role            TEXT NOT NULL,
    content         TEXT NOT NULL,
    org_id          TEXT NOT NULL DEFAULT 'default',
    user_id         TEXT,
    seq             INTEGER NOT NULL,
    created_at      TEXT NOT NULL,
    metadata        TEXT NOT NULL DEFAULT 'null'
);
CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages (conversation_id, seq);
CREATE INDEX IF NOT EXISTS idx_messages_org_user ON messages (org_id, user_id);
"#;
34
35// ---------------------------------------------------------------------------
36// SqliteMessageStore
37// ---------------------------------------------------------------------------
38
39pub struct SqliteMessageStore {
40    pool: SqlitePool,
41}
42
43impl SqliteMessageStore {
44    pub fn new(pool: SqlitePool) -> Self {
45        Self { pool }
46    }
47
48    /// Open a connection pool from a database URL and return a store.
49    pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
50        let pool = SqlitePool::connect(database_url).await?;
51        Ok(Self { pool })
52    }
53
54    /// Apply the DDL. Safe to call multiple times (uses `IF NOT EXISTS`).
55    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
56        for stmt in MESSAGE_STORE_DDL.split(';') {
57            let stmt = stmt.trim();
58            if stmt.is_empty() {
59                continue;
60            }
61            sqlx::query(stmt).execute(&self.pool).await?;
62        }
63        Ok(())
64    }
65}
66
67// ---------------------------------------------------------------------------
68// Internal row type
69// ---------------------------------------------------------------------------
70
/// Raw shape of one `messages` row as decoded by sqlx.
///
/// Fields are matched to columns by name (`SELECT *`); [`row_to_message`]
/// converts a row into a [`ChatMessage`].
#[derive(sqlx::FromRow)]
struct MessageRow {
    /// Message UUID as text.
    id: String,
    conversation_id: String,
    role: String,
    content: String,
    org_id: String,
    /// `NULL` when the message was saved with a scope that has no user.
    user_id: Option<String>,
    seq: i32,
    /// RFC 3339 timestamp string.
    created_at: String,
    /// JSON object text, or the literal `null` when there is no metadata.
    metadata: String,
}
83
84// ---------------------------------------------------------------------------
85// Conversion helpers
86// ---------------------------------------------------------------------------
87
88fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
89    DateTime::parse_from_rfc3339(s)
90        .map(|dt| dt.with_timezone(&Utc))
91        .map_err(|e| MemoryError::Serialization(e.to_string()))
92}
93
94fn row_to_message(row: MessageRow) -> Result<ChatMessage, MemoryError> {
95    let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
96
97    let metadata: serde_json::Map<String, serde_json::Value> =
98        if row.metadata == "null" || row.metadata.is_empty() {
99            serde_json::Map::new()
100        } else {
101            serde_json::from_str(&row.metadata)
102                .map_err(|e| MemoryError::Serialization(e.to_string()))?
103        };
104
105    Ok(ChatMessage {
106        id,
107        conversation_id: row.conversation_id,
108        role: row.role,
109        content: row.content,
110        scope: Scope {
111            org_id: row.org_id,
112            agent_id: None,
113            user_id: row.user_id,
114            session_id: None,
115        },
116        seq: row.seq,
117        created_at: parse_dt(&row.created_at)?,
118        metadata,
119    })
120}
121
122// ---------------------------------------------------------------------------
123// MessageStore implementation
124// ---------------------------------------------------------------------------
125
126#[async_trait]
127impl MessageStore for SqliteMessageStore {
128    async fn save_messages(
129        &self,
130        conversation_id: &str,
131        messages: &[ChatMessage],
132        scope: &Scope,
133    ) -> Result<Vec<MessageId>, MemoryError> {
134        let mut ids = Vec::with_capacity(messages.len());
135
136        for msg in messages {
137            let metadata_json = if msg.metadata.is_empty() {
138                "null".to_string()
139            } else {
140                serde_json::to_string(&msg.metadata)
141                    .map_err(|e| MemoryError::Serialization(e.to_string()))?
142            };
143
144            sqlx::query(
145                r#"
146                INSERT INTO messages
147                    (id, conversation_id, role, content, org_id, user_id, seq, created_at, metadata)
148                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
149                "#,
150            )
151            .bind(msg.id.to_string())
152            .bind(conversation_id)
153            .bind(&msg.role)
154            .bind(&msg.content)
155            .bind(&scope.org_id)
156            .bind(scope.user_id.as_deref())
157            .bind(msg.seq)
158            .bind(msg.created_at.to_rfc3339())
159            .bind(metadata_json)
160            .execute(&self.pool)
161            .await
162            .map_err(|e| MemoryError::Database(e.to_string()))?;
163
164            ids.push(msg.id);
165        }
166
167        Ok(ids)
168    }
169
170    async fn get_messages(
171        &self,
172        conversation_id: &str,
173        last_n: Option<usize>,
174        scope: &Scope,
175    ) -> Result<Vec<ChatMessage>, MemoryError> {
176        let rows = match last_n {
177            Some(n) => {
178                // Subquery: pick the last N rows by seq DESC, then re-order ASC.
179                let sql = r#"
180                    SELECT * FROM (
181                        SELECT * FROM messages
182                        WHERE conversation_id = ? AND org_id = ?
183                        ORDER BY seq DESC
184                        LIMIT ?
185                    ) sub ORDER BY seq ASC
186                "#;
187
188                sqlx::query_as::<_, MessageRow>(sql)
189                    .bind(conversation_id)
190                    .bind(&scope.org_id)
191                    .bind(n as i64)
192                    .fetch_all(&self.pool)
193                    .await
194                    .map_err(|e| MemoryError::Database(e.to_string()))?
195            }
196            None => {
197                let sql = r#"
198                    SELECT * FROM messages
199                    WHERE conversation_id = ? AND org_id = ?
200                    ORDER BY seq ASC
201                "#;
202
203                sqlx::query_as::<_, MessageRow>(sql)
204                    .bind(conversation_id)
205                    .bind(&scope.org_id)
206                    .fetch_all(&self.pool)
207                    .await
208                    .map_err(|e| MemoryError::Database(e.to_string()))?
209            }
210        };
211
212        rows.into_iter().map(row_to_message).collect()
213    }
214
215    async fn list_conversations(&self, scope: &Scope) -> Result<Vec<String>, MemoryError> {
216        let sql = r#"
217            SELECT DISTINCT conversation_id
218            FROM messages
219            WHERE org_id = ?
220            ORDER BY conversation_id
221        "#;
222
223        let rows: Vec<(String,)> = sqlx::query_as(sql)
224            .bind(&scope.org_id)
225            .fetch_all(&self.pool)
226            .await
227            .map_err(|e| MemoryError::Database(e.to_string()))?;
228
229        Ok(rows.into_iter().map(|(c,)| c).collect())
230    }
231
232    async fn delete_messages(
233        &self,
234        conversation_id: &str,
235        scope: &Scope,
236    ) -> Result<u64, MemoryError> {
237        let result = sqlx::query("DELETE FROM messages WHERE conversation_id = ? AND org_id = ?")
238            .bind(conversation_id)
239            .bind(&scope.org_id)
240            .execute(&self.pool)
241            .await
242            .map_err(|e| MemoryError::Database(e.to_string()))?;
243
244        Ok(result.rows_affected())
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Tests
250// ---------------------------------------------------------------------------
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store with the schema applied.
    async fn test_store() -> SqliteMessageStore {
        let store = SqliteMessageStore::open("sqlite::memory:").await.unwrap();
        store.migrate().await.unwrap();
        store
    }

    #[tokio::test]
    async fn message_round_trip() {
        let store = test_store().await;
        let scope = Scope::user("acme", "alice");

        let msgs = vec![
            ChatMessage::new("conv-1", "user", "Hello", scope.clone(), 0),
            ChatMessage::new("conv-1", "assistant", "Hi there!", scope.clone(), 1),
            ChatMessage::new("conv-1", "user", "How are you?", scope.clone(), 2),
        ];

        let ids = store.save_messages("conv-1", &msgs, &scope).await.unwrap();
        assert_eq!(ids.len(), 3);

        let all = store.get_messages("conv-1", None, &scope).await.unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].role, "user");
        assert_eq!(all[1].role, "assistant");
        assert_eq!(all[2].content, "How are you?");

        let last2 = store.get_messages("conv-1", Some(2), &scope).await.unwrap();
        assert_eq!(last2.len(), 2);
        assert_eq!(last2[0].role, "assistant");
        assert_eq!(last2[1].role, "user");

        let convs = store.list_conversations(&scope).await.unwrap();
        assert_eq!(convs, vec!["conv-1"]);

        let deleted = store.delete_messages("conv-1", &scope).await.unwrap();
        assert_eq!(deleted, 3);

        let empty = store.get_messages("conv-1", None, &scope).await.unwrap();
        assert!(empty.is_empty());
    }

    #[tokio::test]
    async fn last_n_edges() {
        let store = test_store().await;
        let scope = Scope::user("acme", "alice");
        let msgs = vec![ChatMessage::new("conv-1", "user", "Hello", scope.clone(), 0)];
        store.save_messages("conv-1", &msgs, &scope).await.unwrap();

        // LIMIT 0 returns nothing.
        let none = store.get_messages("conv-1", Some(0), &scope).await.unwrap();
        assert!(none.is_empty());

        // Asking for more rows than exist returns everything.
        let all = store.get_messages("conv-1", Some(10), &scope).await.unwrap();
        assert_eq!(all.len(), 1);
    }

    #[tokio::test]
    async fn other_org_is_isolated() {
        let store = test_store().await;
        let acme = Scope::user("acme", "alice");
        let globex = Scope::user("globex", "bob");

        let msgs = vec![ChatMessage::new("conv-1", "user", "Hello", acme.clone(), 0)];
        store.save_messages("conv-1", &msgs, &acme).await.unwrap();

        assert!(store.get_messages("conv-1", None, &globex).await.unwrap().is_empty());
        assert!(store.list_conversations(&globex).await.unwrap().is_empty());
        assert_eq!(store.delete_messages("conv-1", &globex).await.unwrap(), 0);

        // The owning org still sees its data after the foreign delete attempt.
        assert_eq!(store.get_messages("conv-1", None, &acme).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn metadata_and_timestamp_round_trip() {
        let store = test_store().await;
        let scope = Scope::user("acme", "alice");

        let mut msg = ChatMessage::new("conv-1", "user", "Hello", scope.clone(), 0);
        msg.metadata
            .insert("source".to_string(), serde_json::Value::from("web"));
        let id = msg.id;
        let created = msg.created_at;

        store.save_messages("conv-1", &[msg], &scope).await.unwrap();

        let got = store.get_messages("conv-1", None, &scope).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].id, id);
        assert_eq!(got[0].created_at, created);
        assert_eq!(
            got[0].metadata.get("source"),
            Some(&serde_json::Value::from("web"))
        );
    }
}