1use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
12pub struct ConversationStore {
14 pool: Arc<Pool<SqliteConnectionManager>>,
15}
16
17impl ConversationStore {
18 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
20 Self { pool }
21 }
22
23 fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
25 self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
26 }
27
28 fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
32 let now = Utc::now().to_rfc3339();
33 conn.execute(
34 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
35 params![now, session_id],
36 )?;
37 Ok(())
38 }
39
40 pub fn store_message_with_tx(
42 &self,
43 tx: &mut Connection, session_id: &str,
45 role: &str,
46 content: &str,
47 message_index: i32,
48 tokens: i32,
49 importance_score: f32,
50 ) -> anyhow::Result<StoredMessage> {
51 self.update_session_access_with_conn(tx, session_id)?;
53
54 let now = Utc::now();
55
56 tx.execute(
57 "INSERT INTO messages
58 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
59 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
60 params![
61 session_id,
62 message_index,
63 role,
64 content,
65 tokens,
66 now.to_rfc3339(),
67 importance_score,
68 false,
69 ],
70 )?;
71
72 let id = tx.last_insert_rowid();
73
74 Ok(StoredMessage {
75 id,
76 session_id: session_id.to_string(),
77 message_index,
78 role: role.to_string(),
79 content: content.to_string(),
80 tokens,
81 timestamp: now,
82 importance_score,
83 embedding_generated: false,
84 })
85 }
86
87 pub fn store_messages_batch(
91 &self,
92 session_id: &str,
93 messages: &[(String, String, i32, i32, f32)], ) -> anyhow::Result<Vec<StoredMessage>> {
95 let mut conn = self.get_conn()?;
96
97 self.update_session_access_with_conn(&conn, session_id)?;
98
99 let now = Utc::now();
100 let now_str = now.to_rfc3339();
101 let mut stored_messages = Vec::new();
102
103 let tx = conn.transaction()?;
104 {
105 for (idx, (role, content, message_index, tokens, importance_score)) in messages.iter().enumerate() {
106 tx.execute(
107 "INSERT INTO messages
108 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
109 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
110 params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
111 )?;
112
113 let id = tx.last_insert_rowid();
114
115 stored_messages.push(StoredMessage {
116 id,
117 session_id: session_id.to_string(),
118 message_index: *message_index,
119 role: role.clone(),
120 content: content.clone(),
121 tokens: *tokens,
122 timestamp: now,
123 importance_score: *importance_score,
124 embedding_generated: false,
125 });
126
127 }
130 }
131 tx.commit()?;
132
133 debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
134 Ok(stored_messages)
135 }
136
137 pub fn store_details_batch(
139 &self,
140 details: &[(&str, i64, &str, &str, &str, f32)], ) -> anyhow::Result<()> {
142 if details.is_empty() { return Ok(()); }
143
144 let mut conn = self.get_conn()?;
145 let now = Utc::now().to_rfc3339();
146 let tx = conn.transaction()?;
147
148 for (session_id, message_id, detail_type, content, context, importance_score) in details {
149 tx.execute(
150 "INSERT INTO details
151 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
152 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
153 params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
154 )?;
155 }
156
157 tx.commit()?;
158 debug!("Stored {} details in batch", details.len());
159 Ok(())
160 }
161
162 pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
165 self.create_session_with_id(None, metadata)
166 }
167
168 pub fn create_session_with_id(&self, session_id: Option<String>, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
169 let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
170 let now = Utc::now();
171 let metadata = metadata.unwrap_or_default();
172 let metadata_json = serde_json::to_string(&metadata)?;
173
174 let conn = self.get_conn()?;
175 conn.execute(
176 "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
177 params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
178 )?;
179
180 Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
181 }
182
183 pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
184 let conn = self.get_conn()?;
185 let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
186 let mut rows = stmt.query([session_id])?;
187
188 if let Some(row) = rows.next()? {
189 Ok(Some(self.row_to_session(row)?))
190 } else {
191 Ok(None)
192 }
193 }
194
195 fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
198 if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
199 return Some(dt.with_timezone(&Utc));
200 }
201 if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
202 return Some(dt.with_timezone(&Utc));
203 }
204 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
205 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
206 }
207 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
208 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
209 }
210 None
211 }
212
213 fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
214 let metadata_json: String = row.get(3)?;
215 let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
216 .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
217
218 let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
219 .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
220
221 let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
222 .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
223
224 Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
225 }
226
227 fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
228 let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
229 .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
230
231 Ok(StoredMessage {
232 id: row.get(0)?,
233 session_id: row.get(1)?,
234 message_index: row.get(2)?,
235 role: row.get(3)?,
236 content: row.get(4)?,
237 tokens: row.get(5)?,
238 timestamp,
239 importance_score: row.get(7)?,
240 embedding_generated: row.get(8)?,
241 })
242 }
243
244 pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
247 let conn = self.get_conn()?;
248 let mut stmt = conn.prepare(
249 "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
250 FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
251 )?;
252 let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
253 let mut messages = Vec::new();
254 while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
255 Ok(messages)
256 }
257
258 pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
259 let conn = self.get_conn()?;
260 conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
261 Ok(())
262 }
263
264 pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
265 let conn = self.get_conn()?;
266 let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
267 info!("Deleted session {}", session_id);
268 Ok(deleted)
269 }
270
271 pub async fn search_messages_by_keywords(
275 &self,
276 session_id: &str,
277 keywords: &[String],
278 limit: usize,
279 ) -> anyhow::Result<Vec<StoredMessage>> {
280 let conn = self.get_conn()?;
281
282 let patterns: Vec<String> = keywords.iter()
284 .map(|k| format!("%{}%", k.to_lowercase()))
285 .collect();
286
287 let mut query = String::from(
289 "SELECT id, session_id, message_index, role, content, tokens,
290 timestamp, importance_score, embedding_generated
291 FROM messages
292 WHERE session_id = ?1"
293 );
294
295 for i in 0..patterns.len() {
296 query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
297 }
298
299 query.push_str(" ORDER BY timestamp DESC LIMIT ?");
300
301 let mut stmt = conn.prepare(&query)?;
302
303 let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
305 params.push(&session_id);
306 for pattern in &patterns {
307 params.push(pattern);
308 }
309 let limit_i64 = limit as i64;
311 params.push(&limit_i64);
312
313 let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
314 let mut messages = Vec::new();
315
316 while let Some(row) = rows.next()? {
317 messages.push(self.row_to_stored_message(row)?);
318 }
319
320 Ok(messages)
321 }
322
323 pub async fn search_messages_by_topic_across_sessions(
325 &self,
326 topic_keywords: &[String],
327 limit: usize,
328 session_id_filter: Option<&str>, ) -> anyhow::Result<Vec<StoredMessage>> {
330 let conn = self.get_conn()?;
331
332 let patterns: Vec<String> = topic_keywords.iter()
334 .map(|k| format!("%{}%", k.to_lowercase()))
335 .collect();
336
337 let mut query = String::from(
339 "SELECT m.id, m.session_id, m.message_index, m.role, m.content,
340 m.tokens, m.timestamp, m.importance_score, m.embedding_generated
341 FROM messages m
342 JOIN sessions s ON m.session_id = s.id
343 WHERE 1=1"
344 );
345
346 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
348 if let Some(session_id) = session_id_filter {
349 query.push_str(" AND m.session_id != ?");
350 params.push(Box::new(session_id.to_string())); }
352
353 for i in 0..patterns.len() {
355 query.push_str(" AND LOWER(m.content) LIKE ?");
356 params.push(Box::new(patterns[i].clone())); }
358
359 query.push_str(" ORDER BY
361 m.importance_score DESC,
362 CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
363 s.last_accessed DESC,
364 m.timestamp DESC
365 LIMIT ?");
366
367 let limit_i64 = limit as i64;
369 params.push(Box::new(limit_i64));
370
371 let mut stmt = conn.prepare(&query)?;
372
373 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
375 .map(|p| p.as_ref())
376 .collect();
377
378 let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
379 let mut messages = Vec::new();
380
381 while let Some(row) = rows.next()? {
382 let timestamp_str: String = row.get(6)?;
383 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
384 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
385 .with_timezone(&chrono::Utc);
386
387 messages.push(StoredMessage {
388 id: row.get(0)?,
389 session_id: row.get(1)?,
390 message_index: row.get(2)?,
391 role: row.get(3)?,
392 content: row.get(4)?,
393 tokens: row.get(5)?,
394 timestamp,
395 importance_score: row.get(7)?,
396 embedding_generated: row.get(8)?,
397 });
398 }
399
400 Ok(messages)
401 }
402}