Skip to main content

offline_intelligence/memory_db/
conversation_store.rs

1//! Conversation storage and retrieval operations with batch support and safe parsing
2
3use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
/// Borrowed field values for one row to be inserted into the `messages` table.
///
/// Every string field borrows from the caller for `'a`, so constructing this
/// struct performs no allocation. All fields are `Copy`, hence the derives.
#[derive(Debug, Clone, Copy)]
pub struct MessageParams<'a> {
    /// Identifier of the session the message belongs to.
    pub session_id: &'a str,
    /// Speaker role, written verbatim to the `role` column.
    pub role: &'a str,
    /// Message text.
    pub content: &'a str,
    /// Position of the message within its session.
    pub message_index: i32,
    /// Token count recorded alongside the message.
    pub tokens: i32,
    /// Importance score recorded alongside the message.
    pub importance_score: f32,
}
21
22/// Manages conversation storage and retrieval using a connection pool
23pub struct ConversationStore {
24    pool: Arc<Pool<SqliteConnectionManager>>,
25}
26
27impl ConversationStore {
28    /// Create a new conversation store with a shared connection pool
29    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
30        Self { pool }
31    }
32
33    /// Internal helper to get a connection from the pool
34    fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
35        self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
36    }
37
38    /// Public connection accessor for cross-module queries (e.g., search_api)
39    pub fn get_conn_public(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
40        self.get_conn()
41    }
42
43    // --- Transactional & Internal Helpers ---
44
45    /// Internal helper to update session access using an existing connection or transaction
46    fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
47        let now = Utc::now().to_rfc3339();
48        conn.execute(
49            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
50            params![now, session_id],
51        )?;
52        Ok(())
53    }
54
55    /// Store a message using an external transaction
56    pub fn store_message_with_tx(
57        &self,
58        tx: &mut Connection,
59        params: MessageParams,
60    ) -> anyhow::Result<StoredMessage> {
61        // Update session access time
62        self.update_session_access_with_conn(tx, params.session_id)?;
63        
64        let now = Utc::now();
65        
66        tx.execute(
67            "INSERT INTO messages 
68             (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
69             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
70            params![
71                params.session_id,
72                params.message_index,
73                params.role,
74                params.content,
75                params.tokens,
76                now.to_rfc3339(),
77                params.importance_score,
78                false,
79            ],
80        )?;
81        
82        let id = tx.last_insert_rowid();
83        
84        Ok(StoredMessage {
85            id,
86            session_id: params.session_id.to_string(),
87            message_index: params.message_index,
88            role: params.role.to_string(),
89            content: params.content.to_string(),
90            tokens: params.tokens,
91            timestamp: now,
92            importance_score: params.importance_score,
93            embedding_generated: false,
94        })
95    }
96
97    // --- Batch Operations ---
98
99    /// Store multiple messages in batch
100    pub fn store_messages_batch(
101        &self,
102        session_id: &str,
103        messages: &[(String, String, i32, i32, f32)], // (role, content, index, tokens, importance)
104    ) -> anyhow::Result<Vec<StoredMessage>> {
105        let mut conn = self.get_conn()?;
106        
107        self.update_session_access_with_conn(&conn, session_id)?;
108        
109        let now = Utc::now();
110        let now_str = now.to_rfc3339();
111        let mut stored_messages = Vec::new();
112        
113        let tx = conn.transaction()?;
114        {
115            for (role, content, message_index, tokens, importance_score) in messages.iter() {
116                tx.execute(
117                    "INSERT INTO messages 
118                     (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
119                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
120                    params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
121                )?;
122                
123                let id = tx.last_insert_rowid();
124                
125                stored_messages.push(StoredMessage {
126                    id,
127                    session_id: session_id.to_string(),
128                    message_index: *message_index,
129                    role: role.clone(),
130                    content: content.clone(),
131                    tokens: *tokens,
132                    timestamp: now,
133                    importance_score: *importance_score,
134                    embedding_generated: false,
135                });
136
137                // Periodic commit check can be handled by outer logic or left to the full transaction
138                // Note: manual "COMMIT; BEGIN;" inside a rusqlite Transaction is generally discouraged.
139            }
140        }
141        tx.commit()?;
142        
143        debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
144        Ok(stored_messages)
145    }
146
147    /// Store details in batch
148    pub fn store_details_batch(
149        &self,
150        details: &[(&str, i64, &str, &str, &str, f32)], // (session_id, message_id, type, content, context, importance)
151    ) -> anyhow::Result<()> {
152        if details.is_empty() { return Ok(()); }
153        
154        let mut conn = self.get_conn()?;
155        let now = Utc::now().to_rfc3339();
156        let tx = conn.transaction()?;
157        
158        for (session_id, message_id, detail_type, content, context, importance_score) in details {
159            tx.execute(
160                "INSERT INTO details 
161                 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
162                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
163                params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
164            )?;
165        }
166        
167        tx.commit()?;
168        debug!("Stored {} details in batch", details.len());
169        Ok(())
170    }
171
172    // --- Session & Message Management ---
173
174    pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
175        let session_id = Uuid::new_v4().to_string();
176        let now = Utc::now();
177        let metadata = metadata.unwrap_or_default();
178        let metadata_json = serde_json::to_string(&metadata)?;
179        
180        let conn = self.get_conn()?;
181        conn.execute(
182            "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
183            params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
184        )?;
185        
186        Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
187    }
188
189    /// Chat persistence: Create session with frontend-provided ID to maintain ID consistency across frontend and backend
190    pub fn create_session_with_id(&self, session_id: &str, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
191        let now = Utc::now();
192        let metadata = metadata.unwrap_or_default();
193        let metadata_json = serde_json::to_string(&metadata)?;
194        
195        let conn = self.get_conn()?;
196        conn.execute(
197            "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
198            params![session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
199        )?;
200        
201        info!("Created session with ID: {}", session_id);
202        Ok(Session { id: session_id.to_string(), created_at: now, last_accessed: now, metadata })
203    }
204
205    /// Chat persistence: Update session title after auto-generation, also refresh last_accessed
206    pub fn update_session_title(&self, session_id: &str, title: &str) -> anyhow::Result<()> {
207        let conn = self.get_conn()?;
208        
209        // Fetch current metadata
210        let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
211        let mut rows = stmt.query([session_id])?;
212        
213        if let Some(row) = rows.next()? {
214            let metadata_json: String = row.get(0)?;
215            let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
216                .unwrap_or_default();
217            
218            // Update title
219            metadata.title = Some(title.to_string());
220            let updated_metadata_json = serde_json::to_string(&metadata)?;
221            
222            // Update session with new metadata and timestamp
223            let now = Utc::now();
224            conn.execute(
225                "UPDATE sessions SET metadata = ?1, last_accessed = ?2 WHERE id = ?3",
226                params![updated_metadata_json, now.to_rfc3339(), session_id],
227            )?;
228            
229            info!("Updated session {} title to: {}", session_id, title);
230            Ok(())
231        } else {
232            Err(anyhow::anyhow!("Session {} not found", session_id))
233        }
234    }
235
236    pub fn update_session_pinned(&self, session_id: &str, pinned: bool) -> anyhow::Result<()> {
237        let conn = self.get_conn()?;
238        
239        // Fetch current metadata
240        let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
241        let mut rows = stmt.query([session_id])?;
242        
243        if let Some(row) = rows.next()? {
244            let metadata_json: String = row.get(0)?;
245            let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
246                .unwrap_or_default();
247            
248            // Update pinned status
249            metadata.pinned = pinned;
250            let updated_metadata_json = serde_json::to_string(&metadata)?;
251            
252            // Update session with new metadata only; pinning is a UI organization action
253            // and does not constitute accessing the conversation content, so don't update last_accessed
254            conn.execute(
255                "UPDATE sessions SET metadata = ?1 WHERE id = ?2",
256                params![updated_metadata_json, session_id],
257            )?;
258            
259            info!("Updated session {} pinned status to: {}", session_id, pinned);
260            Ok(())
261        } else {
262            Err(anyhow::anyhow!("Session {} not found", session_id))
263        }
264    }
265
266    pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
267        let conn = self.get_conn()?;
268        let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
269        let mut rows = stmt.query([session_id])?;
270        
271        if let Some(row) = rows.next()? {
272            Ok(Some(self.row_to_session(row)?))
273        } else {
274            Ok(None)
275        }
276    }
277
278    /// Chat persistence: Retrieve all sessions for sidebar display, ordered by recency
279    pub fn get_all_sessions(&self) -> anyhow::Result<Vec<Session>> {
280        let conn = self.get_conn()?;
281        let mut stmt = conn.prepare(
282            "SELECT id, created_at, last_accessed, metadata FROM sessions ORDER BY last_accessed DESC"
283        )?;
284        let mut rows = stmt.query([])?;
285        let mut sessions = Vec::new();
286        
287        while let Some(row) = rows.next()? {
288            sessions.push(self.row_to_session(row)?);
289        }
290        
291        Ok(sessions)
292    }
293
294    // --- Parsing Logic ---
295
296    fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
297        if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
298            return Some(dt.with_timezone(&Utc));
299        }
300        if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
301            return Some(dt.with_timezone(&Utc));
302        }
303        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
304            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
305        }
306        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
307            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
308        }
309        None
310    }
311
312    fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
313        let metadata_json: String = row.get(3)?;
314        let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
315            .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
316        
317        let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
318            .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
319            
320        let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
321            .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
322        
323        Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
324    }
325
326    fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
327        let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
328            .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
329        
330        Ok(StoredMessage {
331            id: row.get(0)?,
332            session_id: row.get(1)?,
333            message_index: row.get(2)?,
334            role: row.get(3)?,
335            content: row.get(4)?,
336            tokens: row.get(5)?,
337            timestamp,
338            importance_score: row.get(7)?,
339            embedding_generated: row.get(8)?,
340        })
341    }
342
343    // --- Standard Operations (Existing) ---
344
345    pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
346        let conn = self.get_conn()?;
347        let mut stmt = conn.prepare(
348            "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
349             FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
350        )?;
351        let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
352        let mut messages = Vec::new();
353        while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
354        Ok(messages)
355    }
356
357    pub fn get_session_message_count(&self, session_id: &str) -> anyhow::Result<usize> {
358        let conn = self.get_conn()?;
359        let count: i64 = conn.query_row(
360            "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
361            [session_id],
362            |row| row.get(0)
363        )?;
364        Ok(count as usize)
365    }
366
367    pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
368        let conn = self.get_conn()?;
369        conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
370        Ok(())
371    }
372
373    pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
374        let conn = self.get_conn()?;
375        let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
376        info!("Deleted session {}", session_id);
377        Ok(deleted)
378    }
379
380    // --- New Method for Cache Management ---
381
382    /// Search messages by keywords
383    pub async fn search_messages_by_keywords(
384        &self,
385        session_id: &str,
386        keywords: &[String],
387        limit: usize,
388    ) -> anyhow::Result<Vec<StoredMessage>> {
389        let conn = self.get_conn()?;
390        
391        // Build search patterns
392        let patterns: Vec<String> = keywords.iter()
393            .map(|k| format!("%{}%", k.to_lowercase()))
394            .collect();
395        
396        // Build query
397        let mut query = String::from(
398            "SELECT id, session_id, message_index, role, content, tokens, 
399                    timestamp, importance_score, embedding_generated
400             FROM messages 
401             WHERE session_id = ?1"
402        );
403        
404        for i in 0..patterns.len() {
405            query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
406        }
407        
408        query.push_str(" ORDER BY timestamp DESC LIMIT ?");
409        
410        let mut stmt = conn.prepare(&query)?;
411        
412        // Build parameters
413        let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
414        params.push(&session_id);
415        for pattern in &patterns {
416            params.push(pattern);
417        }
418        // FIX: Store in variable to avoid temporary reference
419        let limit_i64 = limit as i64;
420        params.push(&limit_i64);
421        
422        let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
423        let mut messages = Vec::new();
424        
425        while let Some(row) = rows.next()? {
426            messages.push(self.row_to_stored_message(row)?);
427        }
428        
429        Ok(messages)
430    }
431
432    /// Search messages by topic keywords across sessions
433    pub async fn search_messages_by_topic_across_sessions(
434        &self,
435        topic_keywords: &[String],
436        limit: usize,
437        session_id_filter: Option<&str>, // Optional: exclude or include specific sessions
438    ) -> anyhow::Result<Vec<StoredMessage>> {
439        let conn = self.get_conn()?;
440        
441        // Build search patterns
442        let patterns: Vec<String> = topic_keywords.iter()
443            .map(|k| format!("%{}%", k.to_lowercase()))
444            .collect();
445        
446        // Build query with session filtering
447        let mut query = String::from(
448            "SELECT m.id, m.session_id, m.message_index, m.role, m.content, 
449                    m.tokens, m.timestamp, m.importance_score, m.embedding_generated
450             FROM messages m
451             JOIN sessions s ON m.session_id = s.id
452             WHERE 1=1"
453        );
454        
455        // Add session filter if provided - use Box<dyn ToSql> to store owned values
456        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
457        if let Some(session_id) = session_id_filter {
458            query.push_str(" AND m.session_id != ?");
459            params.push(Box::new(session_id.to_string())); // Store owned string
460        }
461        
462        // Add keyword search
463        for pattern in &patterns {
464            query.push_str(" AND LOWER(m.content) LIKE ?");
465            params.push(Box::new(pattern.clone())); // Clone the pattern
466        }
467        
468        // Order by relevance (keyword matches + recency + importance)
469        query.push_str(" ORDER BY 
470            m.importance_score DESC,
471            CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
472            s.last_accessed DESC,
473            m.timestamp DESC
474            LIMIT ?");
475        
476        // Store limit in variable
477        let limit_i64 = limit as i64;
478        params.push(Box::new(limit_i64));
479        
480        let mut stmt = conn.prepare(&query)?;
481        
482        // Convert params to references for the query
483        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
484            .map(|p| p.as_ref())
485            .collect();
486        
487        let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
488        let mut messages = Vec::new();
489        
490        while let Some(row) = rows.next()? {
491            let timestamp_str: String = row.get(6)?;
492            let timestamp = chrono::DateTime::parse_from_rfc3339(&timestamp_str)
493                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
494                .with_timezone(&chrono::Utc);
495            
496            messages.push(StoredMessage {
497                id: row.get(0)?,
498                session_id: row.get(1)?,
499                message_index: row.get(2)?,
500                role: row.get(3)?,
501                content: row.get(4)?,
502                tokens: row.get(5)?,
503                timestamp,
504                importance_score: row.get(7)?,
505                embedding_generated: row.get(8)?,
506            });
507        }
508        
509        Ok(messages)
510    }
511}