1use crate::message::{ChatMessage, MessageId, MessageStore};
8use crate::scope::Scope;
9use crate::store::MemoryError;
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use sqlx::SqlitePool;
13use uuid::Uuid;
14
/// Schema for the `messages` table used by [`SqliteMessageStore`].
///
/// [`SqliteMessageStore::migrate`] applies it by splitting this string on `;` and
/// executing each non-empty fragment as its own statement. No statement may therefore
/// contain a literal `;` (for example inside a string default or a trigger body).
///
/// Column notes:
/// - `seq`: ordering key within a conversation. Reads sort by it, and the
///   `(conversation_id, seq)` index backs those reads.
/// - `user_id`: nullable; `NULL` when the saving scope had no user.
/// - `created_at`: timestamp text; the store writes RFC 3339 strings.
/// - `metadata`: JSON text; the literal `'null'` (also the column default) means
///   "no metadata".
///
/// NOTE(review): nothing enforces uniqueness of `(conversation_id, seq)` — confirm
/// whether duplicate sequence numbers within a conversation are acceptable.
pub const MESSAGE_STORE_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS messages (
 id TEXT PRIMARY KEY,
 conversation_id TEXT NOT NULL,
 role TEXT NOT NULL,
 content TEXT NOT NULL,
 org_id TEXT NOT NULL DEFAULT 'default',
 user_id TEXT,
 seq INTEGER NOT NULL,
 created_at TEXT NOT NULL,
 metadata TEXT NOT NULL DEFAULT 'null'
);
CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages (conversation_id, seq);
CREATE INDEX IF NOT EXISTS idx_messages_org_user ON messages (org_id, user_id);
"#;
34
35pub struct SqliteMessageStore {
40 pool: SqlitePool,
41}
42
43impl SqliteMessageStore {
44 pub fn new(pool: SqlitePool) -> Self {
45 Self { pool }
46 }
47
48 pub async fn open(database_url: &str) -> Result<Self, sqlx::Error> {
50 let pool = SqlitePool::connect(database_url).await?;
51 Ok(Self { pool })
52 }
53
54 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
56 for stmt in MESSAGE_STORE_DDL.split(';') {
57 let stmt = stmt.trim();
58 if stmt.is_empty() {
59 continue;
60 }
61 sqlx::query(stmt).execute(&self.pool).await?;
62 }
63 Ok(())
64 }
65}
66
/// One raw row of the `messages` table as decoded by `sqlx`.
///
/// `derive(FromRow)` maps columns to fields by name, so field names must match the
/// column names in the schema exactly. Converted into a [`ChatMessage`] by
/// `row_to_message`.
#[derive(sqlx::FromRow)]
struct MessageRow {
    /// Message UUID in textual form.
    id: String,
    conversation_id: String,
    role: String,
    content: String,
    org_id: String,
    /// `NULL` in the table when the message was saved without a user in scope.
    user_id: Option<String>,
    seq: i32,
    /// RFC 3339 timestamp text.
    created_at: String,
    /// JSON object text, or the literal `"null"` when there is no metadata.
    metadata: String,
}
83
84fn parse_dt(s: &str) -> Result<DateTime<Utc>, MemoryError> {
89 DateTime::parse_from_rfc3339(s)
90 .map(|dt| dt.with_timezone(&Utc))
91 .map_err(|e| MemoryError::Serialization(e.to_string()))
92}
93
94fn row_to_message(row: MessageRow) -> Result<ChatMessage, MemoryError> {
95 let id = Uuid::parse_str(&row.id).map_err(|e| MemoryError::Serialization(e.to_string()))?;
96
97 let metadata: serde_json::Map<String, serde_json::Value> =
98 if row.metadata == "null" || row.metadata.is_empty() {
99 serde_json::Map::new()
100 } else {
101 serde_json::from_str(&row.metadata)
102 .map_err(|e| MemoryError::Serialization(e.to_string()))?
103 };
104
105 Ok(ChatMessage {
106 id,
107 conversation_id: row.conversation_id,
108 role: row.role,
109 content: row.content,
110 scope: Scope {
111 org_id: row.org_id,
112 agent_id: None,
113 user_id: row.user_id,
114 session_id: None,
115 },
116 seq: row.seq,
117 created_at: parse_dt(&row.created_at)?,
118 metadata,
119 })
120}
121
122#[async_trait]
127impl MessageStore for SqliteMessageStore {
128 async fn save_messages(
129 &self,
130 conversation_id: &str,
131 messages: &[ChatMessage],
132 scope: &Scope,
133 ) -> Result<Vec<MessageId>, MemoryError> {
134 let mut ids = Vec::with_capacity(messages.len());
135
136 for msg in messages {
137 let metadata_json = if msg.metadata.is_empty() {
138 "null".to_string()
139 } else {
140 serde_json::to_string(&msg.metadata)
141 .map_err(|e| MemoryError::Serialization(e.to_string()))?
142 };
143
144 sqlx::query(
145 r#"
146 INSERT INTO messages
147 (id, conversation_id, role, content, org_id, user_id, seq, created_at, metadata)
148 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
149 "#,
150 )
151 .bind(msg.id.to_string())
152 .bind(conversation_id)
153 .bind(&msg.role)
154 .bind(&msg.content)
155 .bind(&scope.org_id)
156 .bind(scope.user_id.as_deref())
157 .bind(msg.seq)
158 .bind(msg.created_at.to_rfc3339())
159 .bind(metadata_json)
160 .execute(&self.pool)
161 .await
162 .map_err(|e| MemoryError::Database(e.to_string()))?;
163
164 ids.push(msg.id);
165 }
166
167 Ok(ids)
168 }
169
170 async fn get_messages(
171 &self,
172 conversation_id: &str,
173 last_n: Option<usize>,
174 scope: &Scope,
175 ) -> Result<Vec<ChatMessage>, MemoryError> {
176 let rows = match last_n {
177 Some(n) => {
178 let sql = r#"
180 SELECT * FROM (
181 SELECT * FROM messages
182 WHERE conversation_id = ? AND org_id = ?
183 ORDER BY seq DESC
184 LIMIT ?
185 ) sub ORDER BY seq ASC
186 "#;
187
188 sqlx::query_as::<_, MessageRow>(sql)
189 .bind(conversation_id)
190 .bind(&scope.org_id)
191 .bind(n as i64)
192 .fetch_all(&self.pool)
193 .await
194 .map_err(|e| MemoryError::Database(e.to_string()))?
195 }
196 None => {
197 let sql = r#"
198 SELECT * FROM messages
199 WHERE conversation_id = ? AND org_id = ?
200 ORDER BY seq ASC
201 "#;
202
203 sqlx::query_as::<_, MessageRow>(sql)
204 .bind(conversation_id)
205 .bind(&scope.org_id)
206 .fetch_all(&self.pool)
207 .await
208 .map_err(|e| MemoryError::Database(e.to_string()))?
209 }
210 };
211
212 rows.into_iter().map(row_to_message).collect()
213 }
214
215 async fn list_conversations(&self, scope: &Scope) -> Result<Vec<String>, MemoryError> {
216 let sql = r#"
217 SELECT DISTINCT conversation_id
218 FROM messages
219 WHERE org_id = ?
220 ORDER BY conversation_id
221 "#;
222
223 let rows: Vec<(String,)> = sqlx::query_as(sql)
224 .bind(&scope.org_id)
225 .fetch_all(&self.pool)
226 .await
227 .map_err(|e| MemoryError::Database(e.to_string()))?;
228
229 Ok(rows.into_iter().map(|(c,)| c).collect())
230 }
231
232 async fn delete_messages(
233 &self,
234 conversation_id: &str,
235 scope: &Scope,
236 ) -> Result<u64, MemoryError> {
237 let result = sqlx::query("DELETE FROM messages WHERE conversation_id = ? AND org_id = ?")
238 .bind(conversation_id)
239 .bind(&scope.org_id)
240 .execute(&self.pool)
241 .await
242 .map_err(|e| MemoryError::Database(e.to_string()))?;
243
244 Ok(result.rows_affected())
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store with the schema applied.
    async fn test_store() -> SqliteMessageStore {
        let store = SqliteMessageStore::open("sqlite::memory:").await.unwrap();
        store.migrate().await.unwrap();
        store
    }

    /// Helper: three alternating messages in `conv-1` under `scope`, seq 0..=2.
    fn three_messages(scope: &Scope) -> Vec<ChatMessage> {
        vec![
            ChatMessage::new("conv-1", "user", "Hello", scope.clone(), 0),
            ChatMessage::new("conv-1", "assistant", "Hi there!", scope.clone(), 1),
            ChatMessage::new("conv-1", "user", "How are you?", scope.clone(), 2),
        ]
    }

    #[tokio::test]
    async fn message_round_trip() {
        let store = test_store().await;
        let scope = Scope::user("acme", "alice");

        let msgs = three_messages(&scope);

        let ids = store.save_messages("conv-1", &msgs, &scope).await.unwrap();
        assert_eq!(ids.len(), 3);

        let all = store.get_messages("conv-1", None, &scope).await.unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].role, "user");
        assert_eq!(all[1].role, "assistant");
        assert_eq!(all[2].content, "How are you?");

        let last2 = store.get_messages("conv-1", Some(2), &scope).await.unwrap();
        assert_eq!(last2.len(), 2);
        assert_eq!(last2[0].role, "assistant");
        assert_eq!(last2[1].role, "user");

        let convs = store.list_conversations(&scope).await.unwrap();
        assert_eq!(convs, vec!["conv-1"]);

        let deleted = store.delete_messages("conv-1", &scope).await.unwrap();
        assert_eq!(deleted, 3);

        let empty = store.get_messages("conv-1", None, &scope).await.unwrap();
        assert!(empty.is_empty());
    }

    /// Messages saved under one org must be invisible to, and undeletable by, another org.
    #[tokio::test]
    async fn other_org_cannot_see_or_delete_messages() {
        let store = test_store().await;
        let acme = Scope::user("acme", "alice");
        let globex = Scope::user("globex", "bob");

        let msgs = vec![ChatMessage::new("conv-1", "user", "Hello", acme.clone(), 0)];
        store.save_messages("conv-1", &msgs, &acme).await.unwrap();

        assert!(store
            .get_messages("conv-1", None, &globex)
            .await
            .unwrap()
            .is_empty());
        assert!(store.list_conversations(&globex).await.unwrap().is_empty());
        assert_eq!(store.delete_messages("conv-1", &globex).await.unwrap(), 0);

        let kept = store.get_messages("conv-1", None, &acme).await.unwrap();
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].content, "Hello");
    }

    /// `last_n` edge cases and saving an empty batch.
    #[tokio::test]
    async fn last_n_bounds_and_empty_save() {
        let store = test_store().await;
        let scope = Scope::user("acme", "alice");

        let none: Vec<ChatMessage> = Vec::new();
        let ids = store.save_messages("conv-1", &none, &scope).await.unwrap();
        assert!(ids.is_empty());

        store
            .save_messages("conv-1", &three_messages(&scope), &scope)
            .await
            .unwrap();

        let zero = store.get_messages("conv-1", Some(0), &scope).await.unwrap();
        assert!(zero.is_empty());

        let more_than_stored = store.get_messages("conv-1", Some(10), &scope).await.unwrap();
        assert_eq!(more_than_stored.len(), 3);
        assert_eq!(more_than_stored[0].content, "Hello");
        assert_eq!(more_than_stored[2].content, "How are you?");
    }
}