// offline_intelligence/memory_db/conversation_store.rs

//! Conversation storage and retrieval operations with batch support and safe parsing.
2
3use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
/// Parameters for storing a single message via `ConversationStore::store_message_with_tx`.
///
/// All fields are borrowed or plain numbers, so the struct is `Copy` and cheap to pass around,
/// log (`Debug`), or reuse for a retry.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MessageParams<'a> {
    /// Id of the session the message belongs to.
    pub session_id: &'a str,
    /// Message role string as stored in the `role` column (e.g. `"assistant"`).
    pub role: &'a str,
    /// Message text.
    pub content: &'a str,
    /// Ordering key within the session; messages are read back sorted by this value.
    pub message_index: i32,
    /// Presumably the message's token count — TODO confirm against the tokenizer caller.
    pub tokens: i32,
    /// Importance score stored alongside the message and used to rank search results.
    pub importance_score: f32,
}
21
22/// Manages conversation storage and retrieval using a connection pool
23pub struct ConversationStore {
24    pool: Arc<Pool<SqliteConnectionManager>>,
25}
26
27impl ConversationStore {
28    /// Create a new conversation store with a shared connection pool
29    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
30        Self { pool }
31    }
32
33    /// Internal helper to get a connection from the pool
34    fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
35        self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
36    }
37
38    /// Public connection accessor for cross-module queries (e.g., search_api)
39    pub fn get_conn_public(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
40        self.get_conn()
41    }
42
43    // --- Transactional & Internal Helpers ---
44
45    /// Internal helper to update session access using an existing connection or transaction
46    fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
47        let now = Utc::now().to_rfc3339();
48        conn.execute(
49            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
50            params![now, session_id],
51        )?;
52        Ok(())
53    }
54
55    /// Store a message using an external transaction
56    pub fn store_message_with_tx(
57        &self,
58        tx: &mut Connection,
59        params: MessageParams,
60    ) -> anyhow::Result<StoredMessage> {
61        // Update session access time
62        self.update_session_access_with_conn(tx, params.session_id)?;
63        
64        let now = Utc::now();
65        
66        tx.execute(
67            "INSERT INTO messages 
68             (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
69             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
70            params![
71                params.session_id,
72                params.message_index,
73                params.role,
74                params.content,
75                params.tokens,
76                now.to_rfc3339(),
77                params.importance_score,
78                false,
79            ],
80        )?;
81        
82        let id = tx.last_insert_rowid();
83        
84        Ok(StoredMessage {
85            id,
86            session_id: params.session_id.to_string(),
87            message_index: params.message_index,
88            role: params.role.to_string(),
89            content: params.content.to_string(),
90            tokens: params.tokens,
91            timestamp: now,
92            importance_score: params.importance_score,
93            embedding_generated: false,
94        })
95    }
96
97    // --- Batch Operations ---
98
99    /// Store multiple messages in batch
100    pub fn store_messages_batch(
101        &self,
102        session_id: &str,
103        messages: &[(String, String, i32, i32, f32)], // (role, content, index, tokens, importance)
104    ) -> anyhow::Result<Vec<StoredMessage>> {
105        let mut conn = self.get_conn()?;
106        
107        self.update_session_access_with_conn(&conn, session_id)?;
108        
109        let now = Utc::now();
110        let now_str = now.to_rfc3339();
111        let mut stored_messages = Vec::new();
112        
113        let tx = conn.transaction()?;
114        {
115            for (role, content, message_index, tokens, importance_score) in messages.iter() {
116                tx.execute(
117                    "INSERT INTO messages 
118                     (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
119                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
120                    params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
121                )?;
122                
123                let id = tx.last_insert_rowid();
124                
125                stored_messages.push(StoredMessage {
126                    id,
127                    session_id: session_id.to_string(),
128                    message_index: *message_index,
129                    role: role.clone(),
130                    content: content.clone(),
131                    tokens: *tokens,
132                    timestamp: now,
133                    importance_score: *importance_score,
134                    embedding_generated: false,
135                });
136
137                // Periodic commit check can be handled by outer logic or left to the full transaction
138                // Note: manual "COMMIT; BEGIN;" inside a rusqlite Transaction is generally discouraged.
139            }
140        }
141        tx.commit()?;
142        
143        debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
144        Ok(stored_messages)
145    }
146
147    /// Store details in batch
148    pub fn store_details_batch(
149        &self,
150        details: &[(&str, i64, &str, &str, &str, f32)], // (session_id, message_id, type, content, context, importance)
151    ) -> anyhow::Result<()> {
152        if details.is_empty() { return Ok(()); }
153        
154        let mut conn = self.get_conn()?;
155        let now = Utc::now().to_rfc3339();
156        let tx = conn.transaction()?;
157        
158        for (session_id, message_id, detail_type, content, context, importance_score) in details {
159            tx.execute(
160                "INSERT INTO details 
161                 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
162                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
163                params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
164            )?;
165        }
166        
167        tx.commit()?;
168        debug!("Stored {} details in batch", details.len());
169        Ok(())
170    }
171
172    // --- Session & Message Management ---
173
174    pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
175        let session_id = Uuid::new_v4().to_string();
176        let now = Utc::now();
177        let metadata = metadata.unwrap_or_default();
178        let metadata_json = serde_json::to_string(&metadata)?;
179        
180        let conn = self.get_conn()?;
181        conn.execute(
182            "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
183            params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
184        )?;
185        
186        Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
187    }
188
189    /// Chat persistence: Create session with frontend-provided ID to maintain ID consistency across frontend and backend
190    pub fn create_session_with_id(&self, session_id: &str, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
191        let now = Utc::now();
192        let metadata = metadata.unwrap_or_default();
193        let metadata_json = serde_json::to_string(&metadata)?;
194        
195        let conn = self.get_conn()?;
196        conn.execute(
197            "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
198            params![session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
199        )?;
200        
201        info!("Created session with ID: {}", session_id);
202        Ok(Session { id: session_id.to_string(), created_at: now, last_accessed: now, metadata })
203    }
204
205    /// Chat persistence: Update session title after auto-generation, also refresh last_accessed.
206    /// Handles the race condition where title update arrives before the session is created
207    /// by the streaming endpoint, using INSERT OR IGNORE to avoid UNIQUE constraint failures.
208    pub fn update_session_title(&self, session_id: &str, title: &str) -> anyhow::Result<()> {
209        let conn = self.get_conn()?;
210        let now = Utc::now();
211
212        // Ensure the session exists first (INSERT OR IGNORE handles concurrent creation)
213        let default_metadata = SessionMetadata {
214            title: Some(title.to_string()),
215            ..Default::default()
216        };
217        let default_metadata_json = serde_json::to_string(&default_metadata)?;
218
219        conn.execute(
220            "INSERT OR IGNORE INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
221            params![session_id, now.to_rfc3339(), now.to_rfc3339(), default_metadata_json],
222        )?;
223
224        // Now fetch and update metadata (session is guaranteed to exist)
225        let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
226        let mut rows = stmt.query([session_id])?;
227
228        if let Some(row) = rows.next()? {
229            let metadata_json: String = row.get(0)?;
230            let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
231                .unwrap_or_default();
232
233            metadata.title = Some(title.to_string());
234            let updated_metadata_json = serde_json::to_string(&metadata)?;
235
236            conn.execute(
237                "UPDATE sessions SET metadata = ?1, last_accessed = ?2 WHERE id = ?3",
238                params![updated_metadata_json, now.to_rfc3339(), session_id],
239            )?;
240
241            info!("Updated session {} title to: {}", session_id, title);
242        }
243
244        Ok(())
245    }
246
247    pub fn update_session_pinned(&self, session_id: &str, pinned: bool) -> anyhow::Result<()> {
248        let conn = self.get_conn()?;
249        
250        // Fetch current metadata
251        let mut stmt = conn.prepare("SELECT metadata FROM sessions WHERE id = ?1")?;
252        let mut rows = stmt.query([session_id])?;
253        
254        if let Some(row) = rows.next()? {
255            let metadata_json: String = row.get(0)?;
256            let mut metadata: SessionMetadata = serde_json::from_str(&metadata_json)
257                .unwrap_or_default();
258            
259            // Update pinned status
260            metadata.pinned = pinned;
261            let updated_metadata_json = serde_json::to_string(&metadata)?;
262            
263            // Update session with new metadata only; pinning is a UI organization action
264            // and does not constitute accessing the conversation content, so don't update last_accessed
265            conn.execute(
266                "UPDATE sessions SET metadata = ?1 WHERE id = ?2",
267                params![updated_metadata_json, session_id],
268            )?;
269            
270            info!("Updated session {} pinned status to: {}", session_id, pinned);
271            Ok(())
272        } else {
273            Err(anyhow::anyhow!("Session {} not found", session_id))
274        }
275    }
276
277    pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
278        let conn = self.get_conn()?;
279        let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
280        let mut rows = stmt.query([session_id])?;
281        
282        if let Some(row) = rows.next()? {
283            Ok(Some(self.row_to_session(row)?))
284        } else {
285            Ok(None)
286        }
287    }
288
289    /// Chat persistence: Retrieve all sessions for sidebar display, ordered by recency
290    pub fn get_all_sessions(&self) -> anyhow::Result<Vec<Session>> {
291        let conn = self.get_conn()?;
292        let mut stmt = conn.prepare(
293            "SELECT id, created_at, last_accessed, metadata FROM sessions ORDER BY last_accessed DESC"
294        )?;
295        let mut rows = stmt.query([])?;
296        let mut sessions = Vec::new();
297        
298        while let Some(row) = rows.next()? {
299            sessions.push(self.row_to_session(row)?);
300        }
301        
302        Ok(sessions)
303    }
304
305    // --- Parsing Logic ---
306
307    fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
308        if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
309            return Some(dt.with_timezone(&Utc));
310        }
311        if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
312            return Some(dt.with_timezone(&Utc));
313        }
314        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
315            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
316        }
317        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
318            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
319        }
320        None
321    }
322
323    fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
324        let metadata_json: String = row.get(3)?;
325        let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
326            .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
327        
328        let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
329            .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
330            
331        let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
332            .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
333        
334        Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
335    }
336
337    fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
338        let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
339            .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
340        
341        Ok(StoredMessage {
342            id: row.get(0)?,
343            session_id: row.get(1)?,
344            message_index: row.get(2)?,
345            role: row.get(3)?,
346            content: row.get(4)?,
347            tokens: row.get(5)?,
348            timestamp,
349            importance_score: row.get(7)?,
350            embedding_generated: row.get(8)?,
351        })
352    }
353
354    // --- Standard Operations (Existing) ---
355
356    pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
357        let conn = self.get_conn()?;
358        let mut stmt = conn.prepare(
359            "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
360             FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
361        )?;
362        let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
363        let mut messages = Vec::new();
364        while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
365        Ok(messages)
366    }
367
368    pub fn get_session_message_count(&self, session_id: &str) -> anyhow::Result<usize> {
369        let conn = self.get_conn()?;
370        let count: i64 = conn.query_row(
371            "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
372            [session_id],
373            |row| row.get(0)
374        )?;
375        Ok(count as usize)
376    }
377
378    pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
379        let conn = self.get_conn()?;
380        conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
381        Ok(())
382    }
383
384    pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
385        let conn = self.get_conn()?;
386        let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
387        info!("Deleted session {}", session_id);
388        Ok(deleted)
389    }
390
391    // --- New Method for Cache Management ---
392
393    /// Search messages by keywords
394    pub async fn search_messages_by_keywords(
395        &self,
396        session_id: &str,
397        keywords: &[String],
398        limit: usize,
399    ) -> anyhow::Result<Vec<StoredMessage>> {
400        let conn = self.get_conn()?;
401        
402        // Build search patterns
403        let patterns: Vec<String> = keywords.iter()
404            .map(|k| format!("%{}%", k.to_lowercase()))
405            .collect();
406        
407        // Build query
408        let mut query = String::from(
409            "SELECT id, session_id, message_index, role, content, tokens, 
410                    timestamp, importance_score, embedding_generated
411             FROM messages 
412             WHERE session_id = ?1"
413        );
414        
415        for i in 0..patterns.len() {
416            query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
417        }
418        
419        query.push_str(" ORDER BY timestamp DESC LIMIT ?");
420        
421        let mut stmt = conn.prepare(&query)?;
422        
423        // Build parameters
424        let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
425        params.push(&session_id);
426        for pattern in &patterns {
427            params.push(pattern);
428        }
429        // FIX: Store in variable to avoid temporary reference
430        let limit_i64 = limit as i64;
431        params.push(&limit_i64);
432        
433        let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
434        let mut messages = Vec::new();
435        
436        while let Some(row) = rows.next()? {
437            messages.push(self.row_to_stored_message(row)?);
438        }
439        
440        Ok(messages)
441    }
442
443    /// Search messages by topic keywords across sessions
444    pub async fn search_messages_by_topic_across_sessions(
445        &self,
446        topic_keywords: &[String],
447        limit: usize,
448        session_id_filter: Option<&str>, // Optional: exclude or include specific sessions
449    ) -> anyhow::Result<Vec<StoredMessage>> {
450        let conn = self.get_conn()?;
451        
452        // Build search patterns
453        let patterns: Vec<String> = topic_keywords.iter()
454            .map(|k| format!("%{}%", k.to_lowercase()))
455            .collect();
456        
457        // Build query with session filtering
458        let mut query = String::from(
459            "SELECT m.id, m.session_id, m.message_index, m.role, m.content, 
460                    m.tokens, m.timestamp, m.importance_score, m.embedding_generated
461             FROM messages m
462             JOIN sessions s ON m.session_id = s.id
463             WHERE 1=1"
464        );
465        
466        // Add session filter if provided - use Box<dyn ToSql> to store owned values
467        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
468        if let Some(session_id) = session_id_filter {
469            query.push_str(" AND m.session_id != ?");
470            params.push(Box::new(session_id.to_string())); // Store owned string
471        }
472        
473        // Add keyword search
474        for pattern in &patterns {
475            query.push_str(" AND LOWER(m.content) LIKE ?");
476            params.push(Box::new(pattern.clone())); // Clone the pattern
477        }
478        
479        // Order by relevance (keyword matches + recency + importance)
480        query.push_str(" ORDER BY 
481            m.importance_score DESC,
482            CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
483            s.last_accessed DESC,
484            m.timestamp DESC
485            LIMIT ?");
486        
487        // Store limit in variable
488        let limit_i64 = limit as i64;
489        params.push(Box::new(limit_i64));
490        
491        let mut stmt = conn.prepare(&query)?;
492        
493        // Convert params to references for the query
494        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
495            .map(|p| p.as_ref())
496            .collect();
497        
498        let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
499        let mut messages = Vec::new();
500        
501        while let Some(row) = rows.next()? {
502            let timestamp_str: String = row.get(6)?;
503            let timestamp = chrono::DateTime::parse_from_rfc3339(&timestamp_str)
504                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
505                .with_timezone(&chrono::Utc);
506            
507            messages.push(StoredMessage {
508                id: row.get(0)?,
509                session_id: row.get(1)?,
510                message_index: row.get(2)?,
511                role: row.get(3)?,
512                content: row.get(4)?,
513                tokens: row.get(5)?,
514                timestamp,
515                importance_score: row.get(7)?,
516                embedding_generated: row.get(8)?,
517            });
518        }
519        
520        Ok(messages)
521    }
522}