1use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
/// Borrowed inputs for inserting a single message with
/// `ConversationStore::store_message_with_tx`.
///
/// All fields are borrowed or `Copy`, so the whole struct is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct MessageParams<'a> {
    /// Identifier of the session the message belongs to.
    pub session_id: &'a str,
    /// Author role, stored verbatim in the `role` column.
    pub role: &'a str,
    /// Message text.
    pub content: &'a str,
    /// Position of the message within its session (`message_index` column).
    pub message_index: i32,
    /// Token count recorded for the message.
    pub tokens: i32,
    /// Importance score recorded for the message.
    pub importance_score: f32,
}
21
/// SQLite-backed store for sessions, messages and extracted details.
///
/// Wraps a shared r2d2 connection pool. Most methods check out their own pooled
/// connection; `store_message_with_tx` instead runs on a connection supplied by
/// the caller.
pub struct ConversationStore {
    /// Shared connection pool that the store's methods draw connections from.
    pool: Arc<Pool<SqliteConnectionManager>>,
}
26
27impl ConversationStore {
28 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
30 Self { pool }
31 }
32
33 fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
35 self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
36 }
37
38 pub fn get_conn_public(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
40 self.get_conn()
41 }
42
43 fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
47 let now = Utc::now().to_rfc3339();
48 conn.execute(
49 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
50 params![now, session_id],
51 )?;
52 Ok(())
53 }
54
55 pub fn store_message_with_tx(
57 &self,
58 tx: &mut Connection,
59 params: MessageParams,
60 ) -> anyhow::Result<StoredMessage> {
61 self.update_session_access_with_conn(tx, params.session_id)?;
63
64 let now = Utc::now();
65
66 tx.execute(
67 "INSERT INTO messages
68 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
69 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
70 params![
71 params.session_id,
72 params.message_index,
73 params.role,
74 params.content,
75 params.tokens,
76 now.to_rfc3339(),
77 params.importance_score,
78 false,
79 ],
80 )?;
81
82 let id = tx.last_insert_rowid();
83
84 Ok(StoredMessage {
85 id,
86 session_id: params.session_id.to_string(),
87 message_index: params.message_index,
88 role: params.role.to_string(),
89 content: params.content.to_string(),
90 tokens: params.tokens,
91 timestamp: now,
92 importance_score: params.importance_score,
93 embedding_generated: false,
94 })
95 }
96
97 pub fn store_messages_batch(
101 &self,
102 session_id: &str,
103 messages: &[(String, String, i32, i32, f32)], ) -> anyhow::Result<Vec<StoredMessage>> {
105 let mut conn = self.get_conn()?;
106
107 self.update_session_access_with_conn(&conn, session_id)?;
108
109 let now = Utc::now();
110 let now_str = now.to_rfc3339();
111 let mut stored_messages = Vec::new();
112
113 let tx = conn.transaction()?;
114 {
115 for (role, content, message_index, tokens, importance_score) in messages.iter() {
116 tx.execute(
117 "INSERT INTO messages
118 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
119 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
120 params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
121 )?;
122
123 let id = tx.last_insert_rowid();
124
125 stored_messages.push(StoredMessage {
126 id,
127 session_id: session_id.to_string(),
128 message_index: *message_index,
129 role: role.clone(),
130 content: content.clone(),
131 tokens: *tokens,
132 timestamp: now,
133 importance_score: *importance_score,
134 embedding_generated: false,
135 });
136
137 }
140 }
141 tx.commit()?;
142
143 debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
144 Ok(stored_messages)
145 }
146
147 pub fn store_details_batch(
149 &self,
150 details: &[(&str, i64, &str, &str, &str, f32)], ) -> anyhow::Result<()> {
152 if details.is_empty() { return Ok(()); }
153
154 let mut conn = self.get_conn()?;
155 let now = Utc::now().to_rfc3339();
156 let tx = conn.transaction()?;
157
158 for (session_id, message_id, detail_type, content, context, importance_score) in details {
159 tx.execute(
160 "INSERT INTO details
161 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
162 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
163 params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
164 )?;
165 }
166
167 tx.commit()?;
168 debug!("Stored {} details in batch", details.len());
169 Ok(())
170 }
171
172 pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
175 let session_id = Uuid::new_v4().to_string();
176 let now = Utc::now();
177 let metadata = metadata.unwrap_or_default();
178 let metadata_json = serde_json::to_string(&metadata)?;
179
180 let conn = self.get_conn()?;
181 conn.execute(
182 "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
183 params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
184 )?;
185
186 Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
187 }
188
189 pub fn create_session_with_id(&self, session_id: &str, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
191 let now = Utc::now();
192 let metadata = metadata.unwrap_or_default();
193 let metadata_json = serde_json::to_string(&metadata)?;
194
195 let conn = self.get_conn()?;
196 conn.execute(
197 "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
198 params![session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
199 )?;
200
201 info!("Created session with ID: {}", session_id);
202 Ok(Session { id: session_id.to_string(), created_at: now, last_accessed: now, metadata })
203 }
204
205 pub fn update_session_title(&self, session_id: &str, title: &str) -> anyhow::Result<()> {
207 let conn = self.get_conn()?;
208
209 let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
211 let mut rows = stmt.query([session_id])?;
212
213 if let Some(row) = rows.next()? {
214 let metadata_json: String = row.get(0)?;
215 let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
216 .unwrap_or_default();
217
218 metadata.title = Some(title.to_string());
220 let updated_metadata_json = serde_json::to_string(&metadata)?;
221
222 let now = Utc::now();
224 conn.execute(
225 "UPDATE sessions SET metadata = ?1, last_accessed = ?2 WHERE id = ?3",
226 params![updated_metadata_json, now.to_rfc3339(), session_id],
227 )?;
228
229 info!("Updated session {} title to: {}", session_id, title);
230 Ok(())
231 } else {
232 Err(anyhow::anyhow!("Session {} not found", session_id))
233 }
234 }
235
236 pub fn update_session_pinned(&self, session_id: &str, pinned: bool) -> anyhow::Result<()> {
237 let conn = self.get_conn()?;
238
239 let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
241 let mut rows = stmt.query([session_id])?;
242
243 if let Some(row) = rows.next()? {
244 let metadata_json: String = row.get(0)?;
245 let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
246 .unwrap_or_default();
247
248 metadata.pinned = pinned;
250 let updated_metadata_json = serde_json::to_string(&metadata)?;
251
252 conn.execute(
255 "UPDATE sessions SET metadata = ?1 WHERE id = ?2",
256 params![updated_metadata_json, session_id],
257 )?;
258
259 info!("Updated session {} pinned status to: {}", session_id, pinned);
260 Ok(())
261 } else {
262 Err(anyhow::anyhow!("Session {} not found", session_id))
263 }
264 }
265
266 pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
267 let conn = self.get_conn()?;
268 let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
269 let mut rows = stmt.query([session_id])?;
270
271 if let Some(row) = rows.next()? {
272 Ok(Some(self.row_to_session(row)?))
273 } else {
274 Ok(None)
275 }
276 }
277
278 pub fn get_all_sessions(&self) -> anyhow::Result<Vec<Session>> {
280 let conn = self.get_conn()?;
281 let mut stmt = conn.prepare(
282 "SELECT id, created_at, last_accessed, metadata FROM sessions ORDER BY last_accessed DESC"
283 )?;
284 let mut rows = stmt.query([])?;
285 let mut sessions = Vec::new();
286
287 while let Some(row) = rows.next()? {
288 sessions.push(self.row_to_session(row)?);
289 }
290
291 Ok(sessions)
292 }
293
294 fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
297 if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
298 return Some(dt.with_timezone(&Utc));
299 }
300 if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
301 return Some(dt.with_timezone(&Utc));
302 }
303 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
304 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
305 }
306 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
307 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
308 }
309 None
310 }
311
312 fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
313 let metadata_json: String = row.get(3)?;
314 let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
315 .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
316
317 let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
318 .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
319
320 let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
321 .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
322
323 Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
324 }
325
326 fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
327 let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
328 .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
329
330 Ok(StoredMessage {
331 id: row.get(0)?,
332 session_id: row.get(1)?,
333 message_index: row.get(2)?,
334 role: row.get(3)?,
335 content: row.get(4)?,
336 tokens: row.get(5)?,
337 timestamp,
338 importance_score: row.get(7)?,
339 embedding_generated: row.get(8)?,
340 })
341 }
342
343 pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
346 let conn = self.get_conn()?;
347 let mut stmt = conn.prepare(
348 "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
349 FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
350 )?;
351 let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
352 let mut messages = Vec::new();
353 while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
354 Ok(messages)
355 }
356
357 pub fn get_session_message_count(&self, session_id: &str) -> anyhow::Result<usize> {
358 let conn = self.get_conn()?;
359 let count: i64 = conn.query_row(
360 "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
361 [session_id],
362 |row| row.get(0)
363 )?;
364 Ok(count as usize)
365 }
366
367 pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
368 let conn = self.get_conn()?;
369 conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
370 Ok(())
371 }
372
373 pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
374 let conn = self.get_conn()?;
375 let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
376 info!("Deleted session {}", session_id);
377 Ok(deleted)
378 }
379
380 pub async fn search_messages_by_keywords(
384 &self,
385 session_id: &str,
386 keywords: &[String],
387 limit: usize,
388 ) -> anyhow::Result<Vec<StoredMessage>> {
389 let conn = self.get_conn()?;
390
391 let patterns: Vec<String> = keywords.iter()
393 .map(|k| format!("%{}%", k.to_lowercase()))
394 .collect();
395
396 let mut query = String::from(
398 "SELECT id, session_id, message_index, role, content, tokens,
399 timestamp, importance_score, embedding_generated
400 FROM messages
401 WHERE session_id = ?1"
402 );
403
404 for i in 0..patterns.len() {
405 query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
406 }
407
408 query.push_str(" ORDER BY timestamp DESC LIMIT ?");
409
410 let mut stmt = conn.prepare(&query)?;
411
412 let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
414 params.push(&session_id);
415 for pattern in &patterns {
416 params.push(pattern);
417 }
418 let limit_i64 = limit as i64;
420 params.push(&limit_i64);
421
422 let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
423 let mut messages = Vec::new();
424
425 while let Some(row) = rows.next()? {
426 messages.push(self.row_to_stored_message(row)?);
427 }
428
429 Ok(messages)
430 }
431
432 pub async fn search_messages_by_topic_across_sessions(
434 &self,
435 topic_keywords: &[String],
436 limit: usize,
437 session_id_filter: Option<&str>, ) -> anyhow::Result<Vec<StoredMessage>> {
439 let conn = self.get_conn()?;
440
441 let patterns: Vec<String> = topic_keywords.iter()
443 .map(|k| format!("%{}%", k.to_lowercase()))
444 .collect();
445
446 let mut query = String::from(
448 "SELECT m.id, m.session_id, m.message_index, m.role, m.content,
449 m.tokens, m.timestamp, m.importance_score, m.embedding_generated
450 FROM messages m
451 JOIN sessions s ON m.session_id = s.id
452 WHERE 1=1"
453 );
454
455 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
457 if let Some(session_id) = session_id_filter {
458 query.push_str(" AND m.session_id != ?");
459 params.push(Box::new(session_id.to_string())); }
461
462 for pattern in &patterns {
464 query.push_str(" AND LOWER(m.content) LIKE ?");
465 params.push(Box::new(pattern.clone())); }
467
468 query.push_str(" ORDER BY
470 m.importance_score DESC,
471 CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
472 s.last_accessed DESC,
473 m.timestamp DESC
474 LIMIT ?");
475
476 let limit_i64 = limit as i64;
478 params.push(Box::new(limit_i64));
479
480 let mut stmt = conn.prepare(&query)?;
481
482 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
484 .map(|p| p.as_ref())
485 .collect();
486
487 let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
488 let mut messages = Vec::new();
489
490 while let Some(row) = rows.next()? {
491 let timestamp_str: String = row.get(6)?;
492 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
493 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
494 .with_timezone(&chrono::Utc);
495
496 messages.push(StoredMessage {
497 id: row.get(0)?,
498 session_id: row.get(1)?,
499 message_index: row.get(2)?,
500 role: row.get(3)?,
501 content: row.get(4)?,
502 tokens: row.get(5)?,
503 timestamp,
504 importance_score: row.get(7)?,
505 embedding_generated: row.get(8)?,
506 });
507 }
508
509 Ok(messages)
510 }
511}