1use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
/// Borrowed field values for one row to be inserted into the `messages` table.
///
/// The string fields borrow from the caller for `'a`, so building a value does
/// not allocate. Every field is `Copy`, which is why the struct can be too.
#[derive(Debug, Clone, Copy)]
pub struct MessageParams<'a> {
    /// Identifier of the session the message belongs to.
    pub session_id: &'a str,
    /// Value written to the `role` column.
    pub role: &'a str,
    /// Text written to the `content` column.
    pub content: &'a str,
    /// Value written to the `message_index` column.
    pub message_index: i32,
    /// Value written to the `tokens` column.
    pub tokens: i32,
    /// Value written to the `importance_score` column.
    pub importance_score: f32,
}
21
22pub struct ConversationStore {
24 pool: Arc<Pool<SqliteConnectionManager>>,
25}
26
27impl ConversationStore {
28 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
30 Self { pool }
31 }
32
33 fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
35 self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
36 }
37
38 pub fn get_conn_public(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
40 self.get_conn()
41 }
42
43 fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
47 let now = Utc::now().to_rfc3339();
48 conn.execute(
49 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
50 params![now, session_id],
51 )?;
52 Ok(())
53 }
54
55 pub fn store_message_with_tx(
57 &self,
58 tx: &mut Connection,
59 params: MessageParams,
60 ) -> anyhow::Result<StoredMessage> {
61 self.update_session_access_with_conn(tx, params.session_id)?;
63
64 let now = Utc::now();
65
66 tx.execute(
67 "INSERT INTO messages
68 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
69 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
70 params![
71 params.session_id,
72 params.message_index,
73 params.role,
74 params.content,
75 params.tokens,
76 now.to_rfc3339(),
77 params.importance_score,
78 false,
79 ],
80 )?;
81
82 let id = tx.last_insert_rowid();
83
84 Ok(StoredMessage {
85 id,
86 session_id: params.session_id.to_string(),
87 message_index: params.message_index,
88 role: params.role.to_string(),
89 content: params.content.to_string(),
90 tokens: params.tokens,
91 timestamp: now,
92 importance_score: params.importance_score,
93 embedding_generated: false,
94 })
95 }
96
97 pub fn store_messages_batch(
101 &self,
102 session_id: &str,
103 messages: &[(String, String, i32, i32, f32)], ) -> anyhow::Result<Vec<StoredMessage>> {
105 let mut conn = self.get_conn()?;
106
107 self.update_session_access_with_conn(&conn, session_id)?;
108
109 let now = Utc::now();
110 let now_str = now.to_rfc3339();
111 let mut stored_messages = Vec::new();
112
113 let tx = conn.transaction()?;
114 {
115 for (role, content, message_index, tokens, importance_score) in messages.iter() {
116 tx.execute(
117 "INSERT INTO messages
118 (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
119 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
120 params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
121 )?;
122
123 let id = tx.last_insert_rowid();
124
125 stored_messages.push(StoredMessage {
126 id,
127 session_id: session_id.to_string(),
128 message_index: *message_index,
129 role: role.clone(),
130 content: content.clone(),
131 tokens: *tokens,
132 timestamp: now,
133 importance_score: *importance_score,
134 embedding_generated: false,
135 });
136
137 }
140 }
141 tx.commit()?;
142
143 debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
144 Ok(stored_messages)
145 }
146
147 pub fn store_details_batch(
149 &self,
150 details: &[(&str, i64, &str, &str, &str, f32)], ) -> anyhow::Result<()> {
152 if details.is_empty() { return Ok(()); }
153
154 let mut conn = self.get_conn()?;
155 let now = Utc::now().to_rfc3339();
156 let tx = conn.transaction()?;
157
158 for (session_id, message_id, detail_type, content, context, importance_score) in details {
159 tx.execute(
160 "INSERT INTO details
161 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
162 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
163 params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
164 )?;
165 }
166
167 tx.commit()?;
168 debug!("Stored {} details in batch", details.len());
169 Ok(())
170 }
171
172 pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
175 let session_id = Uuid::new_v4().to_string();
176 let now = Utc::now();
177 let metadata = metadata.unwrap_or_default();
178 let metadata_json = serde_json::to_string(&metadata)?;
179
180 let conn = self.get_conn()?;
181 conn.execute(
182 "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
183 params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
184 )?;
185
186 Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
187 }
188
189 pub fn create_session_with_id(&self, session_id: &str, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
191 let now = Utc::now();
192 let metadata = metadata.unwrap_or_default();
193 let metadata_json = serde_json::to_string(&metadata)?;
194
195 let conn = self.get_conn()?;
196 conn.execute(
197 "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
198 params![session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
199 )?;
200
201 info!("Created session with ID: {}", session_id);
202 Ok(Session { id: session_id.to_string(), created_at: now, last_accessed: now, metadata })
203 }
204
205 pub fn update_session_title(&self, session_id: &str, title: &str) -> anyhow::Result<()> {
209 let conn = self.get_conn()?;
210 let now = Utc::now();
211
212 let default_metadata = SessionMetadata {
214 title: Some(title.to_string()),
215 ..Default::default()
216 };
217 let default_metadata_json = serde_json::to_string(&default_metadata)?;
218
219 conn.execute(
220 "INSERT OR IGNORE INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
221 params![session_id, now.to_rfc3339(), now.to_rfc3339(), default_metadata_json],
222 )?;
223
224 let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
226 let mut rows = stmt.query([session_id])?;
227
228 if let Some(row) = rows.next()? {
229 let metadata_json: String = row.get(0)?;
230 let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
231 .unwrap_or_default();
232
233 metadata.title = Some(title.to_string());
234 let updated_metadata_json = serde_json::to_string(&metadata)?;
235
236 conn.execute(
237 "UPDATE sessions SET metadata = ?1, last_accessed = ?2 WHERE id = ?3",
238 params![updated_metadata_json, now.to_rfc3339(), session_id],
239 )?;
240
241 info!("Updated session {} title to: {}", session_id, title);
242 }
243
244 Ok(())
245 }
246
247 pub fn update_session_pinned(&self, session_id: &str, pinned: bool) -> anyhow::Result<()> {
248 let conn = self.get_conn()?;
249
250 let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
252 let mut rows = stmt.query([session_id])?;
253
254 if let Some(row) = rows.next()? {
255 let metadata_json: String = row.get(0)?;
256 let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
257 .unwrap_or_default();
258
259 metadata.pinned = pinned;
261 let updated_metadata_json = serde_json::to_string(&metadata)?;
262
263 conn.execute(
266 "UPDATE sessions SET metadata = ?1 WHERE id = ?2",
267 params![updated_metadata_json, session_id],
268 )?;
269
270 info!("Updated session {} pinned status to: {}", session_id, pinned);
271 Ok(())
272 } else {
273 Err(anyhow::anyhow!("Session {} not found", session_id))
274 }
275 }
276
277 pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
278 let conn = self.get_conn()?;
279 let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
280 let mut rows = stmt.query([session_id])?;
281
282 if let Some(row) = rows.next()? {
283 Ok(Some(self.row_to_session(row)?))
284 } else {
285 Ok(None)
286 }
287 }
288
289 pub fn get_all_sessions(&self) -> anyhow::Result<Vec<Session>> {
291 let conn = self.get_conn()?;
292 let mut stmt = conn.prepare(
293 "SELECT id, created_at, last_accessed, metadata FROM sessions ORDER BY last_accessed DESC"
294 )?;
295 let mut rows = stmt.query([])?;
296 let mut sessions = Vec::new();
297
298 while let Some(row) = rows.next()? {
299 sessions.push(self.row_to_session(row)?);
300 }
301
302 Ok(sessions)
303 }
304
305 fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
308 if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
309 return Some(dt.with_timezone(&Utc));
310 }
311 if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
312 return Some(dt.with_timezone(&Utc));
313 }
314 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
315 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
316 }
317 if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
318 return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
319 }
320 None
321 }
322
323 fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
324 let metadata_json: String = row.get(3)?;
325 let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
326 .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
327
328 let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
329 .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
330
331 let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
332 .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
333
334 Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
335 }
336
337 fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
338 let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
339 .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
340
341 Ok(StoredMessage {
342 id: row.get(0)?,
343 session_id: row.get(1)?,
344 message_index: row.get(2)?,
345 role: row.get(3)?,
346 content: row.get(4)?,
347 tokens: row.get(5)?,
348 timestamp,
349 importance_score: row.get(7)?,
350 embedding_generated: row.get(8)?,
351 })
352 }
353
354 pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
357 let conn = self.get_conn()?;
358 let mut stmt = conn.prepare(
359 "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
360 FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
361 )?;
362 let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
363 let mut messages = Vec::new();
364 while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
365 Ok(messages)
366 }
367
368 pub fn get_session_message_count(&self, session_id: &str) -> anyhow::Result<usize> {
369 let conn = self.get_conn()?;
370 let count: i64 = conn.query_row(
371 "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
372 [session_id],
373 |row| row.get(0)
374 )?;
375 Ok(count as usize)
376 }
377
378 pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
379 let conn = self.get_conn()?;
380 conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
381 Ok(())
382 }
383
384 pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
385 let conn = self.get_conn()?;
386 let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
387 info!("Deleted session {}", session_id);
388 Ok(deleted)
389 }
390
391 pub async fn search_messages_by_keywords(
395 &self,
396 session_id: &str,
397 keywords: &[String],
398 limit: usize,
399 ) -> anyhow::Result<Vec<StoredMessage>> {
400 let conn = self.get_conn()?;
401
402 let patterns: Vec<String> = keywords.iter()
404 .map(|k| format!("%{}%", k.to_lowercase()))
405 .collect();
406
407 let mut query = String::from(
409 "SELECT id, session_id, message_index, role, content, tokens,
410 timestamp, importance_score, embedding_generated
411 FROM messages
412 WHERE session_id = ?1"
413 );
414
415 for i in 0..patterns.len() {
416 query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
417 }
418
419 query.push_str(" ORDER BY timestamp DESC LIMIT ?");
420
421 let mut stmt = conn.prepare(&query)?;
422
423 let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
425 params.push(&session_id);
426 for pattern in &patterns {
427 params.push(pattern);
428 }
429 let limit_i64 = limit as i64;
431 params.push(&limit_i64);
432
433 let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
434 let mut messages = Vec::new();
435
436 while let Some(row) = rows.next()? {
437 messages.push(self.row_to_stored_message(row)?);
438 }
439
440 Ok(messages)
441 }
442
443 pub async fn search_messages_by_topic_across_sessions(
445 &self,
446 topic_keywords: &[String],
447 limit: usize,
448 session_id_filter: Option<&str>, ) -> anyhow::Result<Vec<StoredMessage>> {
450 let conn = self.get_conn()?;
451
452 let patterns: Vec<String> = topic_keywords.iter()
454 .map(|k| format!("%{}%", k.to_lowercase()))
455 .collect();
456
457 let mut query = String::from(
459 "SELECT m.id, m.session_id, m.message_index, m.role, m.content,
460 m.tokens, m.timestamp, m.importance_score, m.embedding_generated
461 FROM messages m
462 JOIN sessions s ON m.session_id = s.id
463 WHERE 1=1"
464 );
465
466 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
468 if let Some(session_id) = session_id_filter {
469 query.push_str(" AND m.session_id != ?");
470 params.push(Box::new(session_id.to_string())); }
472
473 for pattern in &patterns {
475 query.push_str(" AND LOWER(m.content) LIKE ?");
476 params.push(Box::new(pattern.clone())); }
478
479 query.push_str(" ORDER BY
481 m.importance_score DESC,
482 CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
483 s.last_accessed DESC,
484 m.timestamp DESC
485 LIMIT ?");
486
487 let limit_i64 = limit as i64;
489 params.push(Box::new(limit_i64));
490
491 let mut stmt = conn.prepare(&query)?;
492
493 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
495 .map(|p| p.as_ref())
496 .collect();
497
498 let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
499 let mut messages = Vec::new();
500
501 while let Some(row) = rows.next()? {
502 let timestamp_str: String = row.get(6)?;
503 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
504 .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
505 .with_timezone(&chrono::Utc);
506
507 messages.push(StoredMessage {
508 id: row.get(0)?,
509 session_id: row.get(1)?,
510 message_index: row.get(2)?,
511 role: row.get(3)?,
512 content: row.get(4)?,
513 tokens: row.get(5)?,
514 timestamp,
515 importance_score: row.get(7)?,
516 embedding_generated: row.get(8)?,
517 });
518 }
519
520 Ok(messages)
521 }
522}