// offline_intelligence/memory_db/conversation_store.rs

//! Conversation storage and retrieval operations with batch support and safe parsing
2
3use crate::memory_db::schema::*;
4use rusqlite::{params, Result, Row, Connection};
5use chrono::{DateTime, Utc, NaiveDateTime};
6use uuid::Uuid;
7use tracing::{info, debug, warn};
8use std::sync::Arc;
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11
12/// Manages conversation storage and retrieval using a connection pool
13pub struct ConversationStore {
14    pool: Arc<Pool<SqliteConnectionManager>>,
15}
16
17impl ConversationStore {
18    /// Create a new conversation store with a shared connection pool
19    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
20        Self { pool }
21    }
22
23    /// Internal helper to get a connection from the pool
24    fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
25        self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
26    }
27
    // --- Transactional & Internal Helpers ---
29
30    /// Internal helper to update session access using an existing connection or transaction
31    fn update_session_access_with_conn(&self, conn: &Connection, session_id: &str) -> Result<()> {
32        let now = Utc::now().to_rfc3339();
33        conn.execute(
34            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
35            params![now, session_id],
36        )?;
37        Ok(())
38    }
39
40    /// Store a message using an external transaction
41    pub fn store_message_with_tx(
42        &self,
43        tx: &mut Connection, // Accepts a connection or transaction
44        session_id: &str,
45        role: &str,
46        content: &str,
47        message_index: i32,
48        tokens: i32,
49        importance_score: f32,
50    ) -> anyhow::Result<StoredMessage> {
51        // Update session access time
52        self.update_session_access_with_conn(tx, session_id)?;
53        
54        let now = Utc::now();
55        
56        tx.execute(
57            "INSERT INTO messages 
58             (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
59             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
60            params![
61                session_id,
62                message_index,
63                role,
64                content,
65                tokens,
66                now.to_rfc3339(),
67                importance_score,
68                false,
69            ],
70        )?;
71        
72        let id = tx.last_insert_rowid();
73        
74        Ok(StoredMessage {
75            id,
76            session_id: session_id.to_string(),
77            message_index,
78            role: role.to_string(),
79            content: content.to_string(),
80            tokens,
81            timestamp: now,
82            importance_score,
83            embedding_generated: false,
84        })
85    }
86
    // --- Batch Operations ---
88
89    /// Store multiple messages in batch
90    pub fn store_messages_batch(
91        &self,
92        session_id: &str,
93        messages: &[(String, String, i32, i32, f32)], // (role, content, index, tokens, importance)
94    ) -> anyhow::Result<Vec<StoredMessage>> {
95        let mut conn = self.get_conn()?;
96        
97        self.update_session_access_with_conn(&conn, session_id)?;
98        
99        let now = Utc::now();
100        let now_str = now.to_rfc3339();
101        let mut stored_messages = Vec::new();
102        
103        let tx = conn.transaction()?;
104        {
105            for (idx, (role, content, message_index, tokens, importance_score)) in messages.iter().enumerate() {
106                tx.execute(
107                    "INSERT INTO messages 
108                     (session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated)
109                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
110                    params![session_id, message_index, role, content, tokens, &now_str, importance_score, false],
111                )?;
112                
113                let id = tx.last_insert_rowid();
114                
115                stored_messages.push(StoredMessage {
116                    id,
117                    session_id: session_id.to_string(),
118                    message_index: *message_index,
119                    role: role.clone(),
120                    content: content.clone(),
121                    tokens: *tokens,
122                    timestamp: now,
123                    importance_score: *importance_score,
124                    embedding_generated: false,
125                });
126
127                // Periodic commit check can be handled by outer logic or left to the full transaction
128                // Note: manual "COMMIT; BEGIN;" inside a rusqlite Transaction is generally discouraged.
129            }
130        }
131        tx.commit()?;
132        
133        debug!("Stored {} messages in batch for session {}", messages.len(), session_id);
134        Ok(stored_messages)
135    }
136
137    /// Store details in batch
138    pub fn store_details_batch(
139        &self,
140        details: &[(&str, i64, &str, &str, &str, f32)], // (session_id, message_id, type, content, context, importance)
141    ) -> anyhow::Result<()> {
142        if details.is_empty() { return Ok(()); }
143        
144        let mut conn = self.get_conn()?;
145        let now = Utc::now().to_rfc3339();
146        let tx = conn.transaction()?;
147        
148        for (session_id, message_id, detail_type, content, context, importance_score) in details {
149            tx.execute(
150                "INSERT INTO details 
151                 (session_id, message_id, detail_type, content, context, importance_score, accessed_count, last_accessed)
152                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
153                params![session_id, message_id, detail_type, content, context, importance_score, 0, &now],
154            )?;
155        }
156        
157        tx.commit()?;
158        debug!("Stored {} details in batch", details.len());
159        Ok(())
160    }
161
    // --- Session & Message Management ---
163
164    pub fn create_session(&self, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
165        self.create_session_with_id(None, metadata)
166    }
167
168    pub fn create_session_with_id(&self, session_id: Option<String>, metadata: Option<SessionMetadata>) -> anyhow::Result<Session> {
169        let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
170        let now = Utc::now();
171        let metadata = metadata.unwrap_or_default();
172        let metadata_json = serde_json::to_string(&metadata)?;
173        
174        let conn = self.get_conn()?;
175        conn.execute(
176            "INSERT INTO sessions (id, created_at, last_accessed, metadata) VALUES (?1, ?2, ?3, ?4)",
177            params![&session_id, now.to_rfc3339(), now.to_rfc3339(), metadata_json],
178        )?;
179        
180        Ok(Session { id: session_id, created_at: now, last_accessed: now, metadata })
181    }
182
183    pub fn get_session(&self, session_id: &str) -> anyhow::Result<Option<Session>> {
184        let conn = self.get_conn()?;
185        let mut stmt = conn.prepare("SELECT id, created_at, last_accessed, metadata FROM sessions WHERE id = ?1")?;
186        let mut rows = stmt.query([session_id])?;
187        
188        if let Some(row) = rows.next()? {
189            Ok(Some(self.row_to_session(row)?))
190        } else {
191            Ok(None)
192        }
193    }
194
    // --- Parsing Logic ---
196
197    fn parse_datetime_safe(datetime_str: &str) -> Option<DateTime<Utc>> {
198        if let Ok(dt) = DateTime::parse_from_rfc3339(datetime_str) {
199            return Some(dt.with_timezone(&Utc));
200        }
201        if let Ok(dt) = DateTime::parse_from_str(datetime_str, "%+") {
202            return Some(dt.with_timezone(&Utc));
203        }
204        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S") {
205            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
206        }
207        if let Ok(dt) = NaiveDateTime::parse_from_str(datetime_str, "%Y-%m-%d %H:%M:%S%.f") {
208            return Some(DateTime::from_naive_utc_and_offset(dt, Utc));
209        }
210        None
211    }
212
213    fn row_to_session(&self, row: &Row) -> anyhow::Result<Session> {
214        let metadata_json: String = row.get(3)?;
215        let metadata: SessionMetadata = serde_json::from_str(&metadata_json)
216            .map_err(|e| anyhow::anyhow!("Metadata JSON error: {}", e))?;
217        
218        let created_at = Self::parse_datetime_safe(&row.get::<_, String>(1)?)
219            .unwrap_or_else(|| { warn!("Failed parse created_at"); Utc::now() });
220            
221        let last_accessed = Self::parse_datetime_safe(&row.get::<_, String>(2)?)
222            .unwrap_or_else(|| { warn!("Failed parse last_accessed"); Utc::now() });
223        
224        Ok(Session { id: row.get(0)?, created_at, last_accessed, metadata })
225    }
226
227    fn row_to_stored_message(&self, row: &Row) -> anyhow::Result<StoredMessage> {
228        let timestamp = Self::parse_datetime_safe(&row.get::<_, String>(6)?)
229            .unwrap_or_else(|| { warn!("Failed parse message timestamp"); Utc::now() });
230        
231        Ok(StoredMessage {
232            id: row.get(0)?,
233            session_id: row.get(1)?,
234            message_index: row.get(2)?,
235            role: row.get(3)?,
236            content: row.get(4)?,
237            tokens: row.get(5)?,
238            timestamp,
239            importance_score: row.get(7)?,
240            embedding_generated: row.get(8)?,
241        })
242    }
243
    // --- Standard Operations (Existing) ---
245
246    pub fn get_session_messages(&self, session_id: &str, limit: Option<i32>, offset: Option<i32>) -> anyhow::Result<Vec<StoredMessage>> {
247        let conn = self.get_conn()?;
248        let mut stmt = conn.prepare(
249            "SELECT id, session_id, message_index, role, content, tokens, timestamp, importance_score, embedding_generated
250             FROM messages WHERE session_id = ?1 ORDER BY message_index LIMIT ?2 OFFSET ?3"
251        )?;
252        let mut rows = stmt.query(params![session_id, limit.unwrap_or(1000), offset.unwrap_or(0)])?;
253        let mut messages = Vec::new();
254        while let Some(row) = rows.next()? { messages.push(self.row_to_stored_message(row)?); }
255        Ok(messages)
256    }
257
258    pub fn mark_embedding_generated(&self, message_id: i64) -> anyhow::Result<()> {
259        let conn = self.get_conn()?;
260        conn.execute("UPDATE messages SET embedding_generated = TRUE WHERE id = ?1", [message_id])?;
261        Ok(())
262    }
263
264    pub fn delete_session(&self, session_id: &str) -> anyhow::Result<usize> {
265        let conn = self.get_conn()?;
266        let deleted = conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])?;
267        info!("Deleted session {}", session_id);
268        Ok(deleted)
269    }
270
    // --- New Method for Cache Management ---
272
273    /// Search messages by keywords
274    pub async fn search_messages_by_keywords(
275        &self,
276        session_id: &str,
277        keywords: &[String],
278        limit: usize,
279    ) -> anyhow::Result<Vec<StoredMessage>> {
280        let conn = self.get_conn()?;
281        
282        // Build search patterns
283        let patterns: Vec<String> = keywords.iter()
284            .map(|k| format!("%{}%", k.to_lowercase()))
285            .collect();
286        
287        // Build query
288        let mut query = String::from(
289            "SELECT id, session_id, message_index, role, content, tokens, 
290                    timestamp, importance_score, embedding_generated
291             FROM messages 
292             WHERE session_id = ?1"
293        );
294        
295        for i in 0..patterns.len() {
296            query.push_str(&format!(" AND LOWER(content) LIKE ?{}", i + 2));
297        }
298        
299        query.push_str(" ORDER BY timestamp DESC LIMIT ?");
300        
301        let mut stmt = conn.prepare(&query)?;
302        
303        // Build parameters
304        let mut params: Vec<&dyn rusqlite::ToSql> = Vec::new();
305        params.push(&session_id);
306        for pattern in &patterns {
307            params.push(pattern);
308        }
309        // FIX: Store in variable to avoid temporary reference
310        let limit_i64 = limit as i64;
311        params.push(&limit_i64);
312        
313        let mut rows = stmt.query(rusqlite::params_from_iter(params))?;
314        let mut messages = Vec::new();
315        
316        while let Some(row) = rows.next()? {
317            messages.push(self.row_to_stored_message(row)?);
318        }
319        
320        Ok(messages)
321    }
322
323    /// Search messages by topic keywords across sessions
324    pub async fn search_messages_by_topic_across_sessions(
325        &self,
326        topic_keywords: &[String],
327        limit: usize,
328        session_id_filter: Option<&str>, // Optional: exclude or include specific sessions
329    ) -> anyhow::Result<Vec<StoredMessage>> {
330        let conn = self.get_conn()?;
331        
332        // Build search patterns
333        let patterns: Vec<String> = topic_keywords.iter()
334            .map(|k| format!("%{}%", k.to_lowercase()))
335            .collect();
336        
337        // Build query with session filtering
338        let mut query = String::from(
339            "SELECT m.id, m.session_id, m.message_index, m.role, m.content, 
340                    m.tokens, m.timestamp, m.importance_score, m.embedding_generated
341             FROM messages m
342             JOIN sessions s ON m.session_id = s.id
343             WHERE 1=1"
344        );
345        
346        // Add session filter if provided - use Box<dyn ToSql> to store owned values
347        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
348        if let Some(session_id) = session_id_filter {
349            query.push_str(" AND m.session_id != ?");
350            params.push(Box::new(session_id.to_string())); // Store owned string
351        }
352        
353        // Add keyword search
354        for i in 0..patterns.len() {
355            query.push_str(" AND LOWER(m.content) LIKE ?");
356            params.push(Box::new(patterns[i].clone())); // Clone the pattern
357        }
358        
359        // Order by relevance (keyword matches + recency + importance)
360        query.push_str(" ORDER BY 
361            m.importance_score DESC,
362            CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END, -- Prioritize assistant responses
363            s.last_accessed DESC,
364            m.timestamp DESC
365            LIMIT ?");
366        
367        // Store limit in variable
368        let limit_i64 = limit as i64;
369        params.push(Box::new(limit_i64));
370        
371        let mut stmt = conn.prepare(&query)?;
372        
373        // Convert params to references for the query
374        let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
375            .map(|p| p.as_ref())
376            .collect();
377        
378        let mut rows = stmt.query(rusqlite::params_from_iter(param_refs))?;
379        let mut messages = Vec::new();
380        
381        while let Some(row) = rows.next()? {
382            let timestamp_str: String = row.get(6)?;
383            let timestamp = chrono::DateTime::parse_from_rfc3339(&timestamp_str)
384                .map_err(|e| anyhow::anyhow!("Failed to parse timestamp: {}", e))?
385                .with_timezone(&chrono::Utc);
386            
387            messages.push(StoredMessage {
388                id: row.get(0)?,
389                session_id: row.get(1)?,
390                message_index: row.get(2)?,
391                role: row.get(3)?,
392                content: row.get(4)?,
393                tokens: row.get(5)?,
394                timestamp,
395                importance_score: row.get(7)?,
396                embedding_generated: row.get(8)?,
397            });
398        }
399        
400        Ok(messages)
401    }
402}