1use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection, OptionalExtension};
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use uuid::Uuid;
10
11#[derive(Debug, thiserror::Error)]
13pub enum DatabaseError {
14 #[error("SQLite error: {0}")]
15 Sqlite(#[from] rusqlite::Error),
16
17 #[error("Not found")]
18 NotFound,
19
20 #[error("Serialization error: {0}")]
21 Serialization(String),
22}
23
24pub type Result<T> = std::result::Result<T, DatabaseError>;
25
26#[derive(Debug, Clone)]
28pub struct StoredEvent {
29 pub id: String,
30 pub conversation_id: String,
31 pub timestamp: DateTime<Utc>,
32 pub event_type: String,
33 pub content: String,
34 pub metadata: Option<String>,
35}
36
37#[derive(Debug, Clone)]
39pub struct Conversation {
40 pub id: String,
41 pub start_time: DateTime<Utc>,
42 pub end_time: Option<DateTime<Utc>>,
43 pub turn_count: i32,
44 pub metadata: Option<String>,
45}
46
47#[derive(Clone)]
49pub struct Database {
50 conn: Arc<Mutex<Connection>>,
51}
52
53impl Database {
54 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
59 let conn = Connection::open(path)?;
60
61 conn.execute("PRAGMA foreign_keys = ON", [])?;
63
64 let db = Database {
65 conn: Arc::new(Mutex::new(conn)),
66 };
67
68 db.initialize_schema()?;
70
71 Ok(db)
72 }
73
74 pub fn initialize_schema(&self) -> Result<()> {
76 let conn = self.conn.lock().unwrap();
77
78 conn.execute(
80 "CREATE TABLE IF NOT EXISTS conversations (
81 id TEXT PRIMARY KEY,
82 start_time TEXT NOT NULL,
83 end_time TEXT,
84 turn_count INTEGER NOT NULL DEFAULT 0,
85 metadata TEXT
86 )",
87 [],
88 )?;
89
90 conn.execute(
92 "CREATE TABLE IF NOT EXISTS events (
93 id TEXT PRIMARY KEY,
94 conversation_id TEXT NOT NULL,
95 timestamp TEXT NOT NULL,
96 event_type TEXT NOT NULL,
97 content TEXT NOT NULL,
98 metadata TEXT,
99 FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
100 )",
101 [],
102 )?;
103
104 conn.execute(
106 "CREATE INDEX IF NOT EXISTS idx_events_conversation
107 ON events(conversation_id, timestamp)",
108 [],
109 )?;
110
111 conn.execute(
112 "CREATE INDEX IF NOT EXISTS idx_events_type
113 ON events(event_type, timestamp)",
114 [],
115 )?;
116
117 conn.execute(
119 "CREATE TABLE IF NOT EXISTS conversation_embeddings (
120 turn_id TEXT PRIMARY KEY,
121 conversation_id TEXT NOT NULL,
122 embedding BLOB NOT NULL,
123 FOREIGN KEY (turn_id) REFERENCES events(id) ON DELETE CASCADE,
124 FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
125 )",
126 [],
127 )?;
128
129 conn.execute(
130 "CREATE INDEX IF NOT EXISTS idx_conv_embeddings
131 ON conversation_embeddings(conversation_id)",
132 [],
133 )?;
134
135 conn.execute(
137 "CREATE TABLE IF NOT EXISTS users (
138 id TEXT PRIMARY KEY,
139 canonical_name TEXT NOT NULL,
140 aliases TEXT,
141 hidden_aliases TEXT,
142 verification_status TEXT NOT NULL DEFAULT 'unverified',
143 pattern_confidence REAL NOT NULL DEFAULT 0.0,
144 metadata TEXT,
145 first_seen TEXT NOT NULL,
146 last_seen TEXT NOT NULL
147 )",
148 [],
149 )?;
150
151 conn.execute(
153 "CREATE TABLE IF NOT EXISTS user_relationships (
154 user_id TEXT NOT NULL,
155 conversation_id TEXT NOT NULL,
156 role TEXT,
157 PRIMARY KEY (user_id, conversation_id),
158 FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
159 FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
160 )",
161 [],
162 )?;
163
164 conn.execute(
165 "CREATE INDEX IF NOT EXISTS idx_user_relationships_user
166 ON user_relationships(user_id)",
167 [],
168 )?;
169
170 conn.execute(
171 "CREATE INDEX IF NOT EXISTS idx_user_relationships_conv
172 ON user_relationships(conversation_id)",
173 [],
174 )?;
175
176 conn.execute(
178 "CREATE TABLE IF NOT EXISTS typing_patterns (
179 user_id TEXT PRIMARY KEY,
180 common_words TEXT NOT NULL,
181 common_bigrams TEXT NOT NULL,
182 common_trigrams TEXT NOT NULL,
183 comma_frequency REAL NOT NULL,
184 period_frequency REAL NOT NULL,
185 exclamation_frequency REAL NOT NULL,
186 question_frequency REAL NOT NULL,
187 emoji_frequency REAL NOT NULL,
188 ellipsis_frequency REAL NOT NULL,
189 avg_sentence_length REAL NOT NULL,
190 avg_word_length REAL NOT NULL,
191 avg_message_length REAL NOT NULL,
192 capitalization_frequency REAL NOT NULL,
193 all_caps_frequency REAL NOT NULL,
194 code_block_frequency REAL NOT NULL,
195 technical_terms_frequency REAL NOT NULL,
196 formality_score REAL NOT NULL,
197 sample_count INTEGER NOT NULL,
198 total_characters INTEGER NOT NULL,
199 last_updated TEXT NOT NULL
200 )",
201 [],
202 )?;
203
204 conn.execute(
206 "CREATE TABLE IF NOT EXISTS relationships (
207 id INTEGER PRIMARY KEY AUTOINCREMENT,
208 from_user_id TEXT NOT NULL,
209 to_user_id TEXT NOT NULL,
210 relationship_type TEXT NOT NULL,
211 confidence REAL NOT NULL,
212 source TEXT NOT NULL,
213 created_at TEXT NOT NULL,
214 metadata TEXT,
215 UNIQUE(from_user_id, to_user_id, relationship_type)
216 )",
217 [],
218 )?;
219
220 conn.execute(
221 "CREATE INDEX IF NOT EXISTS idx_relationships_from
222 ON relationships(from_user_id)",
223 [],
224 )?;
225
226 conn.execute(
227 "CREATE INDEX IF NOT EXISTS idx_relationships_to
228 ON relationships(to_user_id)",
229 [],
230 )?;
231
232 conn.execute(
234 "CREATE TABLE IF NOT EXISTS chat_messages (
235 id TEXT PRIMARY KEY,
236 instance_id TEXT NOT NULL,
237 role TEXT NOT NULL,
238 content TEXT NOT NULL,
239 timestamp TEXT NOT NULL
240 )",
241 [],
242 )?;
243
244 conn.execute(
245 "CREATE INDEX IF NOT EXISTS idx_chat_messages_instance
246 ON chat_messages(instance_id, timestamp)",
247 [],
248 )?;
249
250 Ok(())
251 }
252
253 pub fn create_conversation(&self, id: Uuid, metadata: Option<String>) -> Result<()> {
255 let conn = self.conn.lock().unwrap();
256 let now = Utc::now().to_rfc3339();
257
258 conn.execute(
259 "INSERT INTO conversations (id, start_time, turn_count, metadata)
260 VALUES (?1, ?2, 0, ?3)",
261 params![id.to_string(), now, metadata],
262 )?;
263
264 Ok(())
265 }
266
267 pub fn get_conversation(&self, id: Uuid) -> Result<Conversation> {
269 let conn = self.conn.lock().unwrap();
270
271 let result = conn
272 .query_row(
273 "SELECT id, start_time, end_time, turn_count, metadata
274 FROM conversations WHERE id = ?1",
275 params![id.to_string()],
276 |row| {
277 Ok(Conversation {
278 id: row.get(0)?,
279 start_time: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
280 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
281 .with_timezone(&Utc),
282 end_time: row
283 .get::<_, Option<String>>(2)?
284 .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
285 .map(|dt| dt.with_timezone(&Utc)),
286 turn_count: row.get(3)?,
287 metadata: row.get(4)?,
288 })
289 },
290 )
291 .optional()?;
292
293 result.ok_or(DatabaseError::NotFound)
294 }
295
296 pub fn end_conversation(&self, id: Uuid) -> Result<()> {
298 let conn = self.conn.lock().unwrap();
299 let now = Utc::now().to_rfc3339();
300
301 conn.execute(
302 "UPDATE conversations SET end_time = ?1 WHERE id = ?2",
303 params![now, id.to_string()],
304 )?;
305
306 Ok(())
307 }
308
309 pub fn increment_turn_count(&self, conversation_id: Uuid) -> Result<()> {
311 let conn = self.conn.lock().unwrap();
312
313 conn.execute(
314 "UPDATE conversations SET turn_count = turn_count + 1 WHERE id = ?1",
315 params![conversation_id.to_string()],
316 )?;
317
318 Ok(())
319 }
320
321 pub fn store_event(
323 &self,
324 id: Uuid,
325 conversation_id: Uuid,
326 event_type: &str,
327 content: &str,
328 metadata: Option<String>,
329 ) -> Result<()> {
330 let conn = self.conn.lock().unwrap();
331 let now = Utc::now().to_rfc3339();
332
333 conn.execute(
334 "INSERT INTO events (id, conversation_id, timestamp, event_type, content, metadata)
335 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
336 params![
337 id.to_string(),
338 conversation_id.to_string(),
339 now,
340 event_type,
341 content,
342 metadata
343 ],
344 )?;
345
346 Ok(())
347 }
348
349 pub fn get_conversation_events(&self, conversation_id: Uuid) -> Result<Vec<StoredEvent>> {
351 let conn = self.conn.lock().unwrap();
352
353 let mut stmt = conn.prepare(
354 "SELECT id, conversation_id, timestamp, event_type, content, metadata
355 FROM events
356 WHERE conversation_id = ?1
357 ORDER BY timestamp ASC",
358 )?;
359
360 let events = stmt
361 .query_map(params![conversation_id.to_string()], |row| {
362 Ok(StoredEvent {
363 id: row.get(0)?,
364 conversation_id: row.get(1)?,
365 timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
366 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
367 .with_timezone(&Utc),
368 event_type: row.get(3)?,
369 content: row.get(4)?,
370 metadata: row.get(5)?,
371 })
372 })?
373 .collect::<std::result::Result<Vec<_>, _>>()?;
374
375 Ok(events)
376 }
377
378 pub fn get_recent_events_by_type(
380 &self,
381 event_type: &str,
382 limit: usize,
383 ) -> Result<Vec<StoredEvent>> {
384 let conn = self.conn.lock().unwrap();
385
386 let mut stmt = conn.prepare(
387 "SELECT id, conversation_id, timestamp, event_type, content, metadata
388 FROM events
389 WHERE event_type = ?1
390 ORDER BY timestamp DESC
391 LIMIT ?2",
392 )?;
393
394 let events = stmt
395 .query_map(params![event_type, limit as i64], |row| {
396 Ok(StoredEvent {
397 id: row.get(0)?,
398 conversation_id: row.get(1)?,
399 timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
400 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
401 .with_timezone(&Utc),
402 event_type: row.get(3)?,
403 content: row.get(4)?,
404 metadata: row.get(5)?,
405 })
406 })?
407 .collect::<std::result::Result<Vec<_>, _>>()?;
408
409 Ok(events)
410 }
411
412 pub fn delete_old_events(&self, before: DateTime<Utc>) -> Result<usize> {
414 let conn = self.conn.lock().unwrap();
415 let timestamp = before.to_rfc3339();
416
417 let count = conn.execute(
418 "DELETE FROM events WHERE timestamp < ?1",
419 params![timestamp],
420 )?;
421
422 Ok(count)
423 }
424
425 pub fn store_event_with_embedding(
427 &self,
428 id: Uuid,
429 conversation_id: Uuid,
430 event_type: &str,
431 content: &str,
432 metadata: Option<String>,
433 embedding: &[f32],
434 ) -> Result<()> {
435 self.store_event(id, conversation_id, event_type, content, metadata)?;
437
438 let mut embedding_bytes = Vec::with_capacity(embedding.len() * 4);
440 for &val in embedding {
441 embedding_bytes.extend_from_slice(&val.to_le_bytes());
442 }
443
444 let conn = self.conn.lock().unwrap();
446 conn.execute(
447 "INSERT INTO conversation_embeddings (turn_id, conversation_id, embedding)
448 VALUES (?1, ?2, ?3)",
449 params![id.to_string(), conversation_id.to_string(), embedding_bytes],
450 )?;
451
452 Ok(())
453 }
454
455 pub fn search_conversation_history(
465 &self,
466 conversation_id: Uuid,
467 query_embedding: &[f32],
468 top_k: usize,
469 ) -> Result<Vec<(StoredEvent, f32)>> {
470 let conn = self.conn.lock().unwrap();
471
472 let mut stmt = conn.prepare(
474 "SELECT ce.turn_id, ce.embedding, e.id, e.conversation_id, e.timestamp,
475 e.event_type, e.content, e.metadata
476 FROM conversation_embeddings ce
477 JOIN events e ON ce.turn_id = e.id
478 WHERE ce.conversation_id = ?1",
479 )?;
480
481 let mut results: Vec<(StoredEvent, f32)> = Vec::new();
482
483 let rows = stmt.query_map(params![conversation_id.to_string()], |row| {
484 let embedding_bytes: Vec<u8> = row.get(1)?;
485 let stored_embedding: Vec<f32> = embedding_bytes
486 .chunks_exact(4)
487 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
488 .collect();
489
490 let event = StoredEvent {
491 id: row.get(2)?,
492 conversation_id: row.get(3)?,
493 timestamp: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
494 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
495 .with_timezone(&Utc),
496 event_type: row.get(5)?,
497 content: row.get(6)?,
498 metadata: row.get(7)?,
499 };
500
501 Ok((event, stored_embedding))
502 })?;
503
504 for row in rows {
506 let (event, stored_embedding) = row?;
507 let similarity = cosine_similarity(query_embedding, &stored_embedding);
508 results.push((event, similarity));
509 }
510
511 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
513
514 results.truncate(top_k);
516
517 Ok(results)
518 }
519
520 pub fn get_stats(&self) -> Result<DatabaseStats> {
522 let conn = self.conn.lock().unwrap();
523
524 let conversation_count: i64 =
525 conn.query_row("SELECT COUNT(*) FROM conversations", [], |row| row.get(0))?;
526
527 let event_count: i64 =
528 conn.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?;
529
530 let active_conversations: i64 = conn.query_row(
531 "SELECT COUNT(*) FROM conversations WHERE end_time IS NULL",
532 [],
533 |row| row.get(0),
534 )?;
535
536 Ok(DatabaseStats {
537 conversation_count: conversation_count as usize,
538 event_count: event_count as usize,
539 active_conversations: active_conversations as usize,
540 })
541 }
542}
543
/// Cosine similarity of two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the lengths differ or either vector has zero magnitude
/// (including empty input). Sums are accumulated in `f64` so that large `f32`
/// components do not overflow to infinity (which produced `NaN`), and tiny ones
/// do not underflow to zero (which wrongly produced `0.0`). A `NaN` component
/// still propagates to a `NaN` result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio a hair outside [-1, 1]; clamp it back.
    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    cosine.clamp(-1.0, 1.0) as f32
}
560
/// Row counts reported by `Database::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatabaseStats {
    /// Total number of conversations.
    pub conversation_count: usize,
    /// Total number of stored events across all conversations.
    pub event_count: usize,
    /// Conversations that have not been ended (`end_time` is NULL).
    pub active_conversations: usize,
}
568
569impl Database {
571 pub fn conn(&self) -> Arc<Mutex<Connection>> {
573 self.conn.clone()
574 }
575
576 pub fn get_thought_count(&self) -> Result<usize> {
578 let conn = self.conn.lock().unwrap();
579 let count: i64 =
580 conn.query_row("SELECT COUNT(*) FROM thoughtchain", [], |row| row.get(0))?;
581 Ok(count as usize)
582 }
583
584 pub fn get_identity_count(&self) -> Result<usize> {
586 let conn = self.conn.lock().unwrap();
587 let count: i64 = conn.query_row("SELECT COUNT(*) FROM identities", [], |row| row.get(0))?;
588 Ok(count as usize)
589 }
590
591 pub fn get_relationship_count(&self) -> Result<usize> {
593 let conn = self.conn.lock().unwrap();
594 let count: i64 =
595 conn.query_row("SELECT COUNT(*) FROM relationships", [], |row| row.get(0))?;
596 Ok(count as usize)
597 }
598
599 pub fn get_session_count(&self) -> Result<usize> {
601 let conn = self.conn.lock().unwrap();
602 let count: i64 = conn.query_row("SELECT COUNT(*) FROM sessions", [], |row| row.get(0))?;
603 Ok(count as usize)
604 }
605
606 pub fn store_chat_message(
608 &self,
609 id: String,
610 instance_id: String,
611 role: String,
612 content: String,
613 timestamp: String,
614 ) -> Result<()> {
615 let conn = self.conn.lock().unwrap();
616 conn.execute(
617 "INSERT INTO chat_messages (id, instance_id, role, content, timestamp)
618 VALUES (?1, ?2, ?3, ?4, ?5)",
619 params![id, instance_id, role, content, timestamp],
620 )?;
621 Ok(())
622 }
623
624 pub fn get_chat_messages(&self, instance_id: &str, limit: usize) -> Result<Vec<ChatMessage>> {
626 let conn = self.conn.lock().unwrap();
627 let mut stmt = conn.prepare(
628 "SELECT id, instance_id, role, content, timestamp
629 FROM chat_messages
630 WHERE instance_id = ?1
631 ORDER BY timestamp ASC
632 LIMIT ?2",
633 )?;
634
635 let messages = stmt
636 .query_map(params![instance_id, limit as i64], |row| {
637 Ok(ChatMessage {
638 id: row.get(0)?,
639 instance_id: row.get(1)?,
640 role: row.get(2)?,
641 content: row.get(3)?,
642 timestamp: row.get(4)?,
643 })
644 })?
645 .collect::<std::result::Result<Vec<_>, _>>()?;
646
647 Ok(messages)
648 }
649}
650
/// One row of the `chat_messages` table. All fields are stored and returned
/// verbatim as text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Caller-chosen message id (primary key).
    pub id: String,
    /// Id of the instance the message belongs to; messages are queried by it.
    pub instance_id: String,
    /// Speaker role as supplied by the caller.
    pub role: String,
    /// Message text.
    pub content: String,
    /// Caller-supplied timestamp text; messages are ordered by it as text.
    pub timestamp: String,
}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding width used by the vector-search tests.
    const DIM: usize = 384;

    /// Opens a private in-memory database for a single test.
    fn memory_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Opens an in-memory database and creates one conversation in it.
    fn db_with_conversation() -> (Database, Uuid) {
        let db = memory_db();
        let conv_id = Uuid::new_v4();
        db.create_conversation(conv_id, None).unwrap();
        (db, conv_id)
    }

    /// Stores a `UserMessage` event with a fresh id and the given embedding.
    fn store_embedded(db: &Database, conv_id: Uuid, content: &str, embedding: &[f32]) {
        db.store_event_with_embedding(
            Uuid::new_v4(),
            conv_id,
            "UserMessage",
            content,
            None,
            embedding,
        )
        .unwrap();
    }

    /// True when `actual` lies within 0.001 of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn test_database_creation() {
        let stats = memory_db().get_stats().unwrap();
        assert_eq!(stats.conversation_count, 0);
        assert_eq!(stats.event_count, 0);
    }

    #[test]
    fn test_create_conversation() {
        let (db, id) = db_with_conversation();

        let conv = db.get_conversation(id).unwrap();
        assert_eq!(conv.id, id.to_string());
        assert_eq!(conv.turn_count, 0);
        assert!(conv.end_time.is_none());
    }

    #[test]
    fn test_store_and_retrieve_events() {
        let (db, conv_id) = db_with_conversation();
        let script = [("UserMessage", "Hello"), ("AssistantMessage", "Hi there")];

        for &(kind, text) in &script {
            db.store_event(Uuid::new_v4(), conv_id, kind, text, None)
                .unwrap();
        }

        let events = db.get_conversation_events(conv_id).unwrap();
        assert_eq!(events.len(), script.len());
        for (event, &(kind, text)) in events.iter().zip(&script) {
            assert_eq!(event.event_type, kind);
            assert_eq!(event.content, text);
        }
    }

    #[test]
    fn test_end_conversation() {
        let (db, id) = db_with_conversation();
        db.end_conversation(id).unwrap();

        assert!(db.get_conversation(id).unwrap().end_time.is_some());
    }

    #[test]
    fn test_increment_turn_count() {
        let (db, id) = db_with_conversation();
        for _ in 0..2 {
            db.increment_turn_count(id).unwrap();
        }

        assert_eq!(db.get_conversation(id).unwrap().turn_count, 2);
    }

    #[test]
    fn test_get_recent_events_by_type() {
        let (db, conv_id) = db_with_conversation();
        let batches = [("UserMessage", "Message", 5), ("SystemEvent", "Event", 3)];

        for &(kind, label, count) in &batches {
            for i in 0..count {
                let text = format!("{} {}", label, i);
                db.store_event(Uuid::new_v4(), conv_id, kind, &text, None)
                    .unwrap();
            }
        }

        let user_msgs = db.get_recent_events_by_type("UserMessage", 10).unwrap();
        assert_eq!(user_msgs.len(), 5);

        let sys_events = db.get_recent_events_by_type("SystemEvent", 2).unwrap();
        assert_eq!(sys_events.len(), 2);
    }

    #[test]
    fn test_database_stats() {
        let db = memory_db();
        let (id1, id2) = (Uuid::new_v4(), Uuid::new_v4());

        for &id in &[id1, id2] {
            db.create_conversation(id, None).unwrap();
        }
        for &owner in &[id1, id1, id2] {
            db.store_event(Uuid::new_v4(), owner, "test", "content", None)
                .unwrap();
        }
        db.end_conversation(id1).unwrap();

        let stats = db.get_stats().unwrap();
        assert_eq!(stats.conversation_count, 2);
        assert_eq!(stats.event_count, 3);
        assert_eq!(stats.active_conversations, 1);
    }

    #[test]
    fn test_store_event_with_embedding() {
        let (db, conv_id) = db_with_conversation();
        let embedding = [0.5f32; DIM];

        store_embedded(&db, conv_id, "Test message", &embedding);

        let events = db.get_conversation_events(conv_id).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].content, "Test message");

        // Querying with the identical vector should score ~1.
        let results = db
            .search_conversation_history(conv_id, &embedding, 10)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn test_search_conversation_history() {
        let (db, conv_id) = db_with_conversation();
        let mut one_hot = [0.0f32; DIM];
        one_hot[0] = 1.0;

        store_embedded(&db, conv_id, "First message", &[1.0f32; DIM]);
        store_embedded(&db, conv_id, "Second message", &[0.0f32; DIM]);
        store_embedded(&db, conv_id, "Third message", &one_hot);

        let results = db
            .search_conversation_history(conv_id, &[1.0f32; DIM], 2)
            .unwrap();

        assert_eq!(results.len(), 2);
        // The all-ones embedding matches the all-ones query exactly.
        assert_eq!(results[0].0.content, "First message");
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn test_search_conversation_history_empty() {
        let (db, conv_id) = db_with_conversation();

        let results = db
            .search_conversation_history(conv_id, &[1.0f32; DIM], 10)
            .unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_search_conversation_history_top_k() {
        let (db, conv_id) = db_with_conversation();

        for i in 0..5 {
            let mut embedding = [0.0f32; DIM];
            embedding[0] = i as f32;
            store_embedded(&db, conv_id, &format!("Message {}", i), &embedding);
        }

        let results = db
            .search_conversation_history(conv_id, &[2.0f32; DIM], 3)
            .unwrap();

        assert_eq!(results.len(), 3);
        // Scores must be non-increasing.
        assert!(results.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_cosine_similarity() {
        let v: [f32; 3] = [1.0, 2.0, 3.0];
        let neg_v: [f32; 3] = [-1.0, -2.0, -3.0];

        // Identical, orthogonal and opposite directions.
        assert!(approx_eq(cosine_similarity(&v, &v), 1.0));
        assert!(approx_eq(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0));
        assert!(approx_eq(cosine_similarity(&v, &neg_v), -1.0));

        // Mismatched lengths and zero-magnitude input both yield exactly 0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &v), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0, 0.0], &v), 0.0);
    }
}