frame_catalog/
database.rs

//! SQLite database for persistent storage.
//!
//! Stores conversation history, events, per-turn embeddings, chat messages, and related metadata.

5use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection, OptionalExtension};
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use uuid::Uuid;
10
11/// Database error
12#[derive(Debug, thiserror::Error)]
13pub enum DatabaseError {
14    #[error("SQLite error: {0}")]
15    Sqlite(#[from] rusqlite::Error),
16
17    #[error("Not found")]
18    NotFound,
19
20    #[error("Serialization error: {0}")]
21    Serialization(String),
22}
23
24pub type Result<T> = std::result::Result<T, DatabaseError>;
25
26/// Event stored in the database
27#[derive(Debug, Clone)]
28pub struct StoredEvent {
29    pub id: String,
30    pub conversation_id: String,
31    pub timestamp: DateTime<Utc>,
32    pub event_type: String,
33    pub content: String,
34    pub metadata: Option<String>,
35}
36
37/// Conversation metadata
38#[derive(Debug, Clone)]
39pub struct Conversation {
40    pub id: String,
41    pub start_time: DateTime<Utc>,
42    pub end_time: Option<DateTime<Utc>>,
43    pub turn_count: i32,
44    pub metadata: Option<String>,
45}
46
47/// Database connection pool (simple Arc<Mutex> wrapper for SQLite)
48#[derive(Clone)]
49pub struct Database {
50    conn: Arc<Mutex<Connection>>,
51}
52
53impl Database {
54    /// Create a new database connection
55    ///
56    /// # Arguments
57    /// * `path` - Path to SQLite database file (or ":memory:" for in-memory)
58    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
59        let conn = Connection::open(path)?;
60
61        // Enable foreign keys
62        conn.execute("PRAGMA foreign_keys = ON", [])?;
63
64        let db = Database {
65            conn: Arc::new(Mutex::new(conn)),
66        };
67
68        // Initialize schema
69        db.initialize_schema()?;
70
71        Ok(db)
72    }
73
74    /// Initialize database schema
75    pub fn initialize_schema(&self) -> Result<()> {
76        let conn = self.conn.lock().unwrap();
77
78        // Conversations table
79        conn.execute(
80            "CREATE TABLE IF NOT EXISTS conversations (
81                id TEXT PRIMARY KEY,
82                start_time TEXT NOT NULL,
83                end_time TEXT,
84                turn_count INTEGER NOT NULL DEFAULT 0,
85                metadata TEXT
86            )",
87            [],
88        )?;
89
90        // Events table
91        conn.execute(
92            "CREATE TABLE IF NOT EXISTS events (
93                id TEXT PRIMARY KEY,
94                conversation_id TEXT NOT NULL,
95                timestamp TEXT NOT NULL,
96                event_type TEXT NOT NULL,
97                content TEXT NOT NULL,
98                metadata TEXT,
99                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
100            )",
101            [],
102        )?;
103
104        // Indices for performance
105        conn.execute(
106            "CREATE INDEX IF NOT EXISTS idx_events_conversation
107             ON events(conversation_id, timestamp)",
108            [],
109        )?;
110
111        conn.execute(
112            "CREATE INDEX IF NOT EXISTS idx_events_type
113             ON events(event_type, timestamp)",
114            [],
115        )?;
116
117        // Conversation embeddings table for semantic search over conversation history
118        conn.execute(
119            "CREATE TABLE IF NOT EXISTS conversation_embeddings (
120                turn_id TEXT PRIMARY KEY,
121                conversation_id TEXT NOT NULL,
122                embedding BLOB NOT NULL,
123                FOREIGN KEY (turn_id) REFERENCES events(id) ON DELETE CASCADE,
124                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
125            )",
126            [],
127        )?;
128
129        conn.execute(
130            "CREATE INDEX IF NOT EXISTS idx_conv_embeddings
131             ON conversation_embeddings(conversation_id)",
132            [],
133        )?;
134
135        // Users table for identity tracking
136        conn.execute(
137            "CREATE TABLE IF NOT EXISTS users (
138                id TEXT PRIMARY KEY,
139                canonical_name TEXT NOT NULL,
140                aliases TEXT,
141                hidden_aliases TEXT,
142                verification_status TEXT NOT NULL DEFAULT 'unverified',
143                pattern_confidence REAL NOT NULL DEFAULT 0.0,
144                metadata TEXT,
145                first_seen TEXT NOT NULL,
146                last_seen TEXT NOT NULL
147            )",
148            [],
149        )?;
150
151        // User relationships to conversations
152        conn.execute(
153            "CREATE TABLE IF NOT EXISTS user_relationships (
154                user_id TEXT NOT NULL,
155                conversation_id TEXT NOT NULL,
156                role TEXT,
157                PRIMARY KEY (user_id, conversation_id),
158                FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
159                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
160            )",
161            [],
162        )?;
163
164        conn.execute(
165            "CREATE INDEX IF NOT EXISTS idx_user_relationships_user
166             ON user_relationships(user_id)",
167            [],
168        )?;
169
170        conn.execute(
171            "CREATE INDEX IF NOT EXISTS idx_user_relationships_conv
172             ON user_relationships(conversation_id)",
173            [],
174        )?;
175
176        // Typing patterns table for behavioral fingerprinting
177        conn.execute(
178            "CREATE TABLE IF NOT EXISTS typing_patterns (
179                user_id TEXT PRIMARY KEY,
180                common_words TEXT NOT NULL,
181                common_bigrams TEXT NOT NULL,
182                common_trigrams TEXT NOT NULL,
183                comma_frequency REAL NOT NULL,
184                period_frequency REAL NOT NULL,
185                exclamation_frequency REAL NOT NULL,
186                question_frequency REAL NOT NULL,
187                emoji_frequency REAL NOT NULL,
188                ellipsis_frequency REAL NOT NULL,
189                avg_sentence_length REAL NOT NULL,
190                avg_word_length REAL NOT NULL,
191                avg_message_length REAL NOT NULL,
192                capitalization_frequency REAL NOT NULL,
193                all_caps_frequency REAL NOT NULL,
194                code_block_frequency REAL NOT NULL,
195                technical_terms_frequency REAL NOT NULL,
196                formality_score REAL NOT NULL,
197                sample_count INTEGER NOT NULL,
198                total_characters INTEGER NOT NULL,
199                last_updated TEXT NOT NULL
200            )",
201            [],
202        )?;
203
204        // Relationships table for social graph
205        conn.execute(
206            "CREATE TABLE IF NOT EXISTS relationships (
207                id INTEGER PRIMARY KEY AUTOINCREMENT,
208                from_user_id TEXT NOT NULL,
209                to_user_id TEXT NOT NULL,
210                relationship_type TEXT NOT NULL,
211                confidence REAL NOT NULL,
212                source TEXT NOT NULL,
213                created_at TEXT NOT NULL,
214                metadata TEXT,
215                UNIQUE(from_user_id, to_user_id, relationship_type)
216            )",
217            [],
218        )?;
219
220        conn.execute(
221            "CREATE INDEX IF NOT EXISTS idx_relationships_from
222             ON relationships(from_user_id)",
223            [],
224        )?;
225
226        conn.execute(
227            "CREATE INDEX IF NOT EXISTS idx_relationships_to
228             ON relationships(to_user_id)",
229            [],
230        )?;
231
232        // Chat messages table for admin UI chat persistence
233        conn.execute(
234            "CREATE TABLE IF NOT EXISTS chat_messages (
235                id TEXT PRIMARY KEY,
236                instance_id TEXT NOT NULL,
237                role TEXT NOT NULL,
238                content TEXT NOT NULL,
239                timestamp TEXT NOT NULL
240            )",
241            [],
242        )?;
243
244        conn.execute(
245            "CREATE INDEX IF NOT EXISTS idx_chat_messages_instance
246             ON chat_messages(instance_id, timestamp)",
247            [],
248        )?;
249
250        Ok(())
251    }
252
253    /// Create a new conversation
254    pub fn create_conversation(&self, id: Uuid, metadata: Option<String>) -> Result<()> {
255        let conn = self.conn.lock().unwrap();
256        let now = Utc::now().to_rfc3339();
257
258        conn.execute(
259            "INSERT INTO conversations (id, start_time, turn_count, metadata)
260             VALUES (?1, ?2, 0, ?3)",
261            params![id.to_string(), now, metadata],
262        )?;
263
264        Ok(())
265    }
266
267    /// Get conversation by ID
268    pub fn get_conversation(&self, id: Uuid) -> Result<Conversation> {
269        let conn = self.conn.lock().unwrap();
270
271        let result = conn
272            .query_row(
273                "SELECT id, start_time, end_time, turn_count, metadata
274             FROM conversations WHERE id = ?1",
275                params![id.to_string()],
276                |row| {
277                    Ok(Conversation {
278                        id: row.get(0)?,
279                        start_time: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
280                            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
281                            .with_timezone(&Utc),
282                        end_time: row
283                            .get::<_, Option<String>>(2)?
284                            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
285                            .map(|dt| dt.with_timezone(&Utc)),
286                        turn_count: row.get(3)?,
287                        metadata: row.get(4)?,
288                    })
289                },
290            )
291            .optional()?;
292
293        result.ok_or(DatabaseError::NotFound)
294    }
295
296    /// End a conversation
297    pub fn end_conversation(&self, id: Uuid) -> Result<()> {
298        let conn = self.conn.lock().unwrap();
299        let now = Utc::now().to_rfc3339();
300
301        conn.execute(
302            "UPDATE conversations SET end_time = ?1 WHERE id = ?2",
303            params![now, id.to_string()],
304        )?;
305
306        Ok(())
307    }
308
309    /// Increment conversation turn count
310    pub fn increment_turn_count(&self, conversation_id: Uuid) -> Result<()> {
311        let conn = self.conn.lock().unwrap();
312
313        conn.execute(
314            "UPDATE conversations SET turn_count = turn_count + 1 WHERE id = ?1",
315            params![conversation_id.to_string()],
316        )?;
317
318        Ok(())
319    }
320
321    /// Store an event
322    pub fn store_event(
323        &self,
324        id: Uuid,
325        conversation_id: Uuid,
326        event_type: &str,
327        content: &str,
328        metadata: Option<String>,
329    ) -> Result<()> {
330        let conn = self.conn.lock().unwrap();
331        let now = Utc::now().to_rfc3339();
332
333        conn.execute(
334            "INSERT INTO events (id, conversation_id, timestamp, event_type, content, metadata)
335             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
336            params![
337                id.to_string(),
338                conversation_id.to_string(),
339                now,
340                event_type,
341                content,
342                metadata
343            ],
344        )?;
345
346        Ok(())
347    }
348
349    /// Get all events for a conversation
350    pub fn get_conversation_events(&self, conversation_id: Uuid) -> Result<Vec<StoredEvent>> {
351        let conn = self.conn.lock().unwrap();
352
353        let mut stmt = conn.prepare(
354            "SELECT id, conversation_id, timestamp, event_type, content, metadata
355             FROM events
356             WHERE conversation_id = ?1
357             ORDER BY timestamp ASC",
358        )?;
359
360        let events = stmt
361            .query_map(params![conversation_id.to_string()], |row| {
362                Ok(StoredEvent {
363                    id: row.get(0)?,
364                    conversation_id: row.get(1)?,
365                    timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
366                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
367                        .with_timezone(&Utc),
368                    event_type: row.get(3)?,
369                    content: row.get(4)?,
370                    metadata: row.get(5)?,
371                })
372            })?
373            .collect::<std::result::Result<Vec<_>, _>>()?;
374
375        Ok(events)
376    }
377
378    /// Get recent events by type
379    pub fn get_recent_events_by_type(
380        &self,
381        event_type: &str,
382        limit: usize,
383    ) -> Result<Vec<StoredEvent>> {
384        let conn = self.conn.lock().unwrap();
385
386        let mut stmt = conn.prepare(
387            "SELECT id, conversation_id, timestamp, event_type, content, metadata
388             FROM events
389             WHERE event_type = ?1
390             ORDER BY timestamp DESC
391             LIMIT ?2",
392        )?;
393
394        let events = stmt
395            .query_map(params![event_type, limit as i64], |row| {
396                Ok(StoredEvent {
397                    id: row.get(0)?,
398                    conversation_id: row.get(1)?,
399                    timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
400                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
401                        .with_timezone(&Utc),
402                    event_type: row.get(3)?,
403                    content: row.get(4)?,
404                    metadata: row.get(5)?,
405                })
406            })?
407            .collect::<std::result::Result<Vec<_>, _>>()?;
408
409        Ok(events)
410    }
411
412    /// Delete old events (for cleanup)
413    pub fn delete_old_events(&self, before: DateTime<Utc>) -> Result<usize> {
414        let conn = self.conn.lock().unwrap();
415        let timestamp = before.to_rfc3339();
416
417        let count = conn.execute(
418            "DELETE FROM events WHERE timestamp < ?1",
419            params![timestamp],
420        )?;
421
422        Ok(count)
423    }
424
425    /// Store an event with its embedding for conversation memory
426    pub fn store_event_with_embedding(
427        &self,
428        id: Uuid,
429        conversation_id: Uuid,
430        event_type: &str,
431        content: &str,
432        metadata: Option<String>,
433        embedding: &[f32],
434    ) -> Result<()> {
435        // Store the event first
436        self.store_event(id, conversation_id, event_type, content, metadata)?;
437
438        // Convert embedding to bytes
439        let mut embedding_bytes = Vec::with_capacity(embedding.len() * 4);
440        for &val in embedding {
441            embedding_bytes.extend_from_slice(&val.to_le_bytes());
442        }
443
444        // Store the embedding
445        let conn = self.conn.lock().unwrap();
446        conn.execute(
447            "INSERT INTO conversation_embeddings (turn_id, conversation_id, embedding)
448             VALUES (?1, ?2, ?3)",
449            params![id.to_string(), conversation_id.to_string(), embedding_bytes],
450        )?;
451
452        Ok(())
453    }
454
455    /// Search conversation history by semantic similarity
456    ///
457    /// # Arguments
458    /// * `conversation_id` - Conversation to search within
459    /// * `query_embedding` - Embedding vector of the query
460    /// * `top_k` - Number of results to return
461    ///
462    /// # Returns
463    /// Vector of (event, similarity_score) tuples, sorted by similarity (highest first)
464    pub fn search_conversation_history(
465        &self,
466        conversation_id: Uuid,
467        query_embedding: &[f32],
468        top_k: usize,
469    ) -> Result<Vec<(StoredEvent, f32)>> {
470        let conn = self.conn.lock().unwrap();
471
472        // Get all embeddings for this conversation
473        let mut stmt = conn.prepare(
474            "SELECT ce.turn_id, ce.embedding, e.id, e.conversation_id, e.timestamp,
475                    e.event_type, e.content, e.metadata
476             FROM conversation_embeddings ce
477             JOIN events e ON ce.turn_id = e.id
478             WHERE ce.conversation_id = ?1",
479        )?;
480
481        let mut results: Vec<(StoredEvent, f32)> = Vec::new();
482
483        let rows = stmt.query_map(params![conversation_id.to_string()], |row| {
484            let embedding_bytes: Vec<u8> = row.get(1)?;
485            let stored_embedding: Vec<f32> = embedding_bytes
486                .chunks_exact(4)
487                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
488                .collect();
489
490            let event = StoredEvent {
491                id: row.get(2)?,
492                conversation_id: row.get(3)?,
493                timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
494                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
495                    .with_timezone(&Utc),
496                event_type: row.get(5)?,
497                content: row.get(6)?,
498                metadata: row.get(7)?,
499            };
500
501            Ok((event, stored_embedding))
502        })?;
503
504        // Calculate cosine similarity for each result
505        for row in rows {
506            let (event, stored_embedding) = row?;
507            let similarity = cosine_similarity(query_embedding, &stored_embedding);
508            results.push((event, similarity));
509        }
510
511        // Sort by similarity (highest first)
512        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
513
514        // Return top K results
515        results.truncate(top_k);
516
517        Ok(results)
518    }
519
520    /// Get database statistics
521    pub fn get_stats(&self) -> Result<DatabaseStats> {
522        let conn = self.conn.lock().unwrap();
523
524        let conversation_count: i64 =
525            conn.query_row("SELECT COUNT(*) FROM conversations", [], |row| row.get(0))?;
526
527        let event_count: i64 =
528            conn.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?;
529
530        let active_conversations: i64 = conn.query_row(
531            "SELECT COUNT(*) FROM conversations WHERE end_time IS NULL",
532            [],
533            |row| row.get(0),
534        )?;
535
536        Ok(DatabaseStats {
537            conversation_count: conversation_count as usize,
538            event_count: event_count as usize,
539            active_conversations: active_conversations as usize,
540        })
541    }
542}
543
/// Cosine similarity of two vectors: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` instead of a ratio in two cases:
/// * the slices have different lengths;
/// * either vector has zero magnitude, which also avoids dividing by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));

        if norm_a != 0.0 && norm_b != 0.0 {
            let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
            return dot / (norm_a * norm_b);
        }
    }

    0.0
}
560
/// Row counts reported by `Database::get_stats`.
#[derive(Clone, Debug)]
pub struct DatabaseStats {
    /// Total rows in `conversations`.
    pub conversation_count: usize,
    /// Total rows in `events`.
    pub event_count: usize,
    /// Conversations whose `end_time` is still NULL.
    pub active_conversations: usize,
}
568
569// Additional Database methods for identity module
570impl Database {
571    /// Get a clone of the internal connection (for IdentityStore)
572    pub fn conn(&self) -> Arc<Mutex<Connection>> {
573        self.conn.clone()
574    }
575
576    /// Get thought count from thoughtchain table
577    pub fn get_thought_count(&self) -> Result<usize> {
578        let conn = self.conn.lock().unwrap();
579        let count: i64 =
580            conn.query_row("SELECT COUNT(*) FROM thoughtchain", [], |row| row.get(0))?;
581        Ok(count as usize)
582    }
583
584    /// Get identity count from identities table
585    pub fn get_identity_count(&self) -> Result<usize> {
586        let conn = self.conn.lock().unwrap();
587        let count: i64 = conn.query_row("SELECT COUNT(*) FROM identities", [], |row| row.get(0))?;
588        Ok(count as usize)
589    }
590
591    /// Get relationship count from relationships table
592    pub fn get_relationship_count(&self) -> Result<usize> {
593        let conn = self.conn.lock().unwrap();
594        let count: i64 =
595            conn.query_row("SELECT COUNT(*) FROM relationships", [], |row| row.get(0))?;
596        Ok(count as usize)
597    }
598
599    /// Get session count from sessions table
600    pub fn get_session_count(&self) -> Result<usize> {
601        let conn = self.conn.lock().unwrap();
602        let count: i64 = conn.query_row("SELECT COUNT(*) FROM sessions", [], |row| row.get(0))?;
603        Ok(count as usize)
604    }
605
606    /// Store a chat message
607    pub fn store_chat_message(
608        &self,
609        id: String,
610        instance_id: String,
611        role: String,
612        content: String,
613        timestamp: String,
614    ) -> Result<()> {
615        let conn = self.conn.lock().unwrap();
616        conn.execute(
617            "INSERT INTO chat_messages (id, instance_id, role, content, timestamp)
618             VALUES (?1, ?2, ?3, ?4, ?5)",
619            params![id, instance_id, role, content, timestamp],
620        )?;
621        Ok(())
622    }
623
624    /// Get chat messages for a specific instance
625    pub fn get_chat_messages(&self, instance_id: &str, limit: usize) -> Result<Vec<ChatMessage>> {
626        let conn = self.conn.lock().unwrap();
627        let mut stmt = conn.prepare(
628            "SELECT id, instance_id, role, content, timestamp
629             FROM chat_messages
630             WHERE instance_id = ?1
631             ORDER BY timestamp ASC
632             LIMIT ?2",
633        )?;
634
635        let messages = stmt
636            .query_map(params![instance_id, limit as i64], |row| {
637                Ok(ChatMessage {
638                    id: row.get(0)?,
639                    instance_id: row.get(1)?,
640                    role: row.get(2)?,
641                    content: row.get(3)?,
642                    timestamp: row.get(4)?,
643                })
644            })?
645            .collect::<std::result::Result<Vec<_>, _>>()?;
646
647        Ok(messages)
648    }
649}
650
/// One admin-UI chat message, as persisted in the `chat_messages` table.
///
/// Every field is stored verbatim as TEXT exactly as the caller passed it to
/// `Database::store_chat_message`.
#[derive(Clone, Debug)]
pub struct ChatMessage {
    /// Primary key chosen by the caller.
    pub id: String,
    /// Instance the message belongs to; `Database::get_chat_messages` filters on it.
    pub instance_id: String,
    /// Speaker role string as supplied by the caller (not validated here).
    pub role: String,
    /// Message body.
    pub content: String,
    /// Caller-supplied timestamp string; also the sort key when reading back.
    pub timestamp: String,
}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database with the full schema applied.
    fn open_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Fresh in-memory database plus one newly created conversation.
    fn db_with_conversation() -> (Database, Uuid) {
        let db = open_db();
        let conv_id = Uuid::new_v4();
        db.create_conversation(conv_id, None).unwrap();
        (db, conv_id)
    }

    #[test]
    fn test_database_creation() {
        let db = open_db();
        let stats = db.get_stats().unwrap();
        assert_eq!(stats.conversation_count, 0);
        assert_eq!(stats.event_count, 0);
    }

    #[test]
    fn test_create_conversation() {
        let db = open_db();
        let id = Uuid::new_v4();

        db.create_conversation(id, None).unwrap();

        let conv = db.get_conversation(id).unwrap();
        assert_eq!(conv.id, id.to_string());
        assert_eq!(conv.turn_count, 0);
        assert!(conv.end_time.is_none());
    }

    #[test]
    fn test_get_conversation_not_found() {
        let db = open_db();
        let result = db.get_conversation(Uuid::new_v4());
        assert!(matches!(result, Err(DatabaseError::NotFound)));
    }

    #[test]
    fn test_store_and_retrieve_events() {
        let (db, conv_id) = db_with_conversation();

        db.store_event(Uuid::new_v4(), conv_id, "UserMessage", "Hello", None)
            .unwrap();
        db.store_event(Uuid::new_v4(), conv_id, "AssistantMessage", "Hi there", None)
            .unwrap();

        let events = db.get_conversation_events(conv_id).unwrap();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].event_type, "UserMessage");
        assert_eq!(events[0].content, "Hello");
        assert_eq!(events[1].event_type, "AssistantMessage");
        assert_eq!(events[1].content, "Hi there");
    }

    #[test]
    fn test_store_event_requires_existing_conversation() {
        // Foreign keys are enabled in Database::new, so an orphan event is rejected.
        let db = open_db();
        let result = db.store_event(Uuid::new_v4(), Uuid::new_v4(), "test", "orphan", None);
        assert!(matches!(result, Err(DatabaseError::Sqlite(_))));
    }

    #[test]
    fn test_end_conversation() {
        let (db, id) = db_with_conversation();

        db.end_conversation(id).unwrap();

        let conv = db.get_conversation(id).unwrap();
        assert!(conv.end_time.is_some());
    }

    #[test]
    fn test_increment_turn_count() {
        let (db, id) = db_with_conversation();

        db.increment_turn_count(id).unwrap();
        db.increment_turn_count(id).unwrap();

        let conv = db.get_conversation(id).unwrap();
        assert_eq!(conv.turn_count, 2);
    }

    #[test]
    fn test_get_recent_events_by_type() {
        let (db, conv_id) = db_with_conversation();

        for i in 0..5 {
            db.store_event(
                Uuid::new_v4(),
                conv_id,
                "UserMessage",
                &format!("Message {}", i),
                None,
            )
            .unwrap();
        }
        for i in 0..3 {
            db.store_event(
                Uuid::new_v4(),
                conv_id,
                "SystemEvent",
                &format!("Event {}", i),
                None,
            )
            .unwrap();
        }

        let user_msgs = db.get_recent_events_by_type("UserMessage", 10).unwrap();
        assert_eq!(user_msgs.len(), 5);
        assert!(user_msgs.iter().all(|e| e.event_type == "UserMessage"));

        let sys_events = db.get_recent_events_by_type("SystemEvent", 2).unwrap();
        assert_eq!(sys_events.len(), 2);
        assert!(sys_events.iter().all(|e| e.event_type == "SystemEvent"));
    }

    #[test]
    fn test_delete_old_events() {
        let (db, conv_id) = db_with_conversation();
        for i in 0..3 {
            db.store_event(
                Uuid::new_v4(),
                conv_id,
                "UserMessage",
                &format!("Message {}", i),
                None,
            )
            .unwrap();
        }

        // A cutoff in the past removes nothing.
        let past = Utc::now() - chrono::Duration::hours(1);
        assert_eq!(db.delete_old_events(past).unwrap(), 0);
        assert_eq!(db.get_conversation_events(conv_id).unwrap().len(), 3);

        // A cutoff in the future removes everything stored so far.
        let future = Utc::now() + chrono::Duration::hours(1);
        assert_eq!(db.delete_old_events(future).unwrap(), 3);
        assert!(db.get_conversation_events(conv_id).unwrap().is_empty());
    }

    #[test]
    fn test_database_stats() {
        let db = open_db();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        db.create_conversation(id1, None).unwrap();
        db.create_conversation(id2, None).unwrap();

        db.store_event(Uuid::new_v4(), id1, "test", "content", None)
            .unwrap();
        db.store_event(Uuid::new_v4(), id1, "test", "content", None)
            .unwrap();
        db.store_event(Uuid::new_v4(), id2, "test", "content", None)
            .unwrap();

        db.end_conversation(id1).unwrap();

        let stats = db.get_stats().unwrap();
        assert_eq!(stats.conversation_count, 2);
        assert_eq!(stats.event_count, 3);
        assert_eq!(stats.active_conversations, 1);
    }

    #[test]
    fn test_store_event_with_embedding() {
        let (db, conv_id) = db_with_conversation();
        let event_id = Uuid::new_v4();

        // 384-dimensional test embedding, every component 0.5.
        let embedding = vec![0.5f32; 384];

        db.store_event_with_embedding(
            event_id,
            conv_id,
            "UserMessage",
            "Test message",
            None,
            &embedding,
        )
        .unwrap();

        let events = db.get_conversation_events(conv_id).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].content, "Test message");

        // The embedding is only observable through search.
        let query_embedding = vec![0.5f32; 384];
        let results = db
            .search_conversation_history(conv_id, &query_embedding, 10)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1 > 0.99); // Identical direction => similarity ~1.
    }

    #[test]
    fn test_search_conversation_history() {
        let (db, conv_id) = db_with_conversation();

        let embedding1 = vec![1.0f32; 384]; // All 1.0: same direction as the query.
        let embedding2 = vec![0.0f32; 384]; // Zero vector: similarity 0.0.
        let mut embedding3 = vec![0.0f32; 384];
        embedding3[0] = 1.0; // Unit vector e0: similarity 1/sqrt(384) > 0.

        for (content, embedding) in [
            ("First message", &embedding1),
            ("Second message", &embedding2),
            ("Third message", &embedding3),
        ] {
            db.store_event_with_embedding(
                Uuid::new_v4(),
                conv_id,
                "UserMessage",
                content,
                None,
                embedding,
            )
            .unwrap();
        }

        let query = vec![1.0f32; 384];
        let results = db.search_conversation_history(conv_id, &query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0.content, "First message");
        assert!(results[0].1 > 0.99);
        // The zero-vector message scores 0.0 and is cut off by top_k = 2.
        assert_eq!(results[1].0.content, "Third message");
        assert!(results[1].1 > 0.0);
    }

    #[test]
    fn test_search_conversation_history_empty() {
        let (db, conv_id) = db_with_conversation();

        let query = vec![1.0f32; 384];
        let results = db.search_conversation_history(conv_id, &query, 10).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_search_conversation_history_top_k() {
        let (db, conv_id) = db_with_conversation();

        // Embedding i is i * e0, so message 0 is the zero vector.
        for i in 0..5 {
            let mut embedding = vec![0.0f32; 384];
            embedding[0] = i as f32;
            db.store_event_with_embedding(
                Uuid::new_v4(),
                conv_id,
                "UserMessage",
                &format!("Message {}", i),
                None,
                &embedding,
            )
            .unwrap();
        }

        // Messages 1-4 all point along e0. Against this query they score the
        // same, 1/sqrt(384). Message 0 (the zero vector) scores 0.0.
        let query = vec![2.0f32; 384];
        let results = db.search_conversation_history(conv_id, &query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].1 >= results[1].1);
        assert!(results[1].1 >= results[2].1);
        // The zero-vector event ranks last and is excluded by top_k = 3.
        assert!(results
            .iter()
            .all(|(event, score)| *score > 0.0 && event.content != "Message 0"));
    }

    #[test]
    fn test_chat_messages_roundtrip() {
        let db = open_db();
        for i in 0..3 {
            db.store_chat_message(
                format!("msg-{}", i),
                "instance-a".to_string(),
                "user".to_string(),
                format!("hello {}", i),
                format!("2024-01-01T00:00:0{}Z", i),
            )
            .unwrap();
        }
        db.store_chat_message(
            "other".to_string(),
            "instance-b".to_string(),
            "user".to_string(),
            "elsewhere".to_string(),
            "2024-01-01T00:00:00Z".to_string(),
        )
        .unwrap();

        let messages = db.get_chat_messages("instance-a", 10).unwrap();
        let contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["hello 0", "hello 1", "hello 2"]);
        assert!(messages.iter().all(|m| m.instance_id == "instance-a"));
    }

    #[test]
    fn test_relationship_count_starts_at_zero() {
        let db = open_db();
        assert_eq!(db.get_relationship_count().unwrap(), 0);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);

        // Opposite vectors.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);

        // Different lengths return 0.0.
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);

        // A zero vector returns 0.0.
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }
}