1use crate::types::{Artifact, Result, Span, Session, Message, MessageRole, SessionWorkingSet, SessionWithMessages, WorkingSet, CompilerConfig};
7use crate::index::VectorIndex;
8use rusqlite::{params, Connection, OptionalExtension};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use serde::{Serialize, Deserialize};
14
15#[derive(Clone)]
17pub struct Database {
18 conn: Arc<Mutex<Connection>>,
19 vector_index: Arc<RwLock<Option<Arc<VectorIndex>>>>,
21 index_dirty: Arc<AtomicBool>,
23 db_path: PathBuf,
25 build_lock: Arc<Mutex<()>>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
31pub enum IndexLoadKind {
32 LoadedFromCache,
34 BuiltFromSpans,
36 CachedInMemory,
38}
39impl Database {
40 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
50 let db_path = path.as_ref().to_path_buf();
51 let conn = Connection::open(&db_path)?;
52
53 let schema_001 = r#"-- AvocadoDB Initial Schema
55-- Phase 1: Simple SQLite-compatible schema for deterministic context compilation
56
57-- Artifacts table: stores ingested documents
58CREATE TABLE IF NOT EXISTS artifacts (
59 id TEXT PRIMARY KEY, -- UUID v4
60 path TEXT NOT NULL UNIQUE, -- File path or identifier
61 content TEXT NOT NULL, -- Full document text
62 content_hash TEXT NOT NULL, -- SHA256 of content
63 metadata TEXT, -- JSON string with arbitrary metadata
64 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
65 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
66);
67
68-- Spans table: stores document fragments with embeddings
69CREATE TABLE IF NOT EXISTS spans (
70 id TEXT PRIMARY KEY, -- UUID v4
71 artifact_id TEXT NOT NULL, -- Foreign key to artifacts
72 start_line INTEGER NOT NULL, -- Starting line number (1-indexed)
73 end_line INTEGER NOT NULL, -- Ending line number (inclusive)
74 text TEXT NOT NULL, -- Actual span text
75 embedding BLOB, -- Serialized f32 vector (1536 dims for ada-002)
76 embedding_model TEXT, -- e.g., "text-embedding-ada-002"
77 token_count INTEGER, -- Estimated token count
78 metadata TEXT, -- JSON string
79 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
80 FOREIGN KEY (artifact_id) REFERENCES artifacts(id) ON DELETE CASCADE
81);
82
83-- Indexes for performance
84CREATE INDEX IF NOT EXISTS idx_spans_artifact ON spans(artifact_id);
85CREATE INDEX IF NOT EXISTS idx_spans_lines ON spans(artifact_id, start_line, end_line);
86CREATE INDEX IF NOT EXISTS idx_artifacts_path ON artifacts(path);
87CREATE INDEX IF NOT EXISTS idx_artifacts_hash ON artifacts(content_hash);
88
89-- Enable WAL mode for better concurrency
90PRAGMA journal_mode = WAL;
91PRAGMA foreign_keys = ON;
92"#;
93
94 let schema_without_pragma = schema_001
96 .lines()
97 .filter(|line| {
98 let trimmed = line.trim();
99 !trimmed.starts_with("PRAGMA") && !trimmed.starts_with("-- Enable WAL")
100 })
101 .collect::<Vec<_>>()
102 .join("\n");
103
104 conn.execute_batch(&schema_without_pragma)?;
105
106 let schema_002 = r#"-- AvocadoDB Session Management Schema
108-- Phase 2, Priority 1: Session tracking for conversation history and agent memory
109
110-- Sessions table: tracks conversation sessions
111CREATE TABLE IF NOT EXISTS sessions (
112 id TEXT PRIMARY KEY, -- UUID v4
113 user_id TEXT, -- Optional user identifier
114 title TEXT, -- Optional session title (auto-generated or user-provided)
115 metadata TEXT, -- JSON string with arbitrary metadata
116 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
117 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
118 last_message_at TIMESTAMP -- For sorting/filtering
119);
120
121-- Messages table: stores individual conversation turns
122CREATE TABLE IF NOT EXISTS messages (
123 id TEXT PRIMARY KEY, -- UUID v4
124 session_id TEXT NOT NULL, -- Foreign key to sessions
125 role TEXT NOT NULL, -- 'user', 'assistant', 'system', 'tool'
126 content TEXT NOT NULL, -- Message content
127 metadata TEXT, -- JSON string (tool calls, citations, etc.)
128 sequence_number INTEGER NOT NULL, -- Order within session (0-indexed)
129 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
130 FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
131);
132
133-- Working set associations: links compiled contexts to sessions
134CREATE TABLE IF NOT EXISTS session_working_sets (
135 id TEXT PRIMARY KEY, -- UUID v4
136 session_id TEXT NOT NULL, -- Foreign key to sessions
137 message_id TEXT, -- Optional: which message triggered this compilation
138 working_set_id TEXT NOT NULL, -- Reference to working set (stored as JSON for now)
139 query TEXT NOT NULL, -- Query that generated this working set
140 config TEXT, -- JSON string of CompilerConfig used
141 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
142 FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE,
143 FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
144);
145
146-- Indexes for performance
147CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id);
148CREATE INDEX IF NOT EXISTS idx_sessions_updated ON sessions(updated_at DESC);
149CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, sequence_number);
150CREATE INDEX IF NOT EXISTS idx_working_sets_session ON session_working_sets(session_id);
151"#;
152 conn.execute_batch(schema_002)?;
153
154 conn.pragma_update(None, "journal_mode", "WAL")?;
156 conn.pragma_update(None, "foreign_keys", true)?;
157
158 Ok(Self {
159 conn: Arc::new(Mutex::new(conn)),
160 vector_index: Arc::new(RwLock::new(None)),
161 index_dirty: Arc::new(AtomicBool::new(true)),
162 db_path,
163 build_lock: Arc::new(Mutex::new(())),
164 })
165 }
166
167 pub fn insert_artifact(&self, artifact: &Artifact) -> Result<()> {
177 let conn = self.conn.lock()
178 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
179 conn.execute(
180 "INSERT INTO artifacts (id, path, content, content_hash, metadata, created_at)
181 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
182 params![
183 artifact.id,
184 artifact.path,
185 artifact.content,
186 artifact.content_hash,
187 artifact.metadata.as_ref().map(|m| m.to_string()),
188 artifact.created_at.to_rfc3339(),
189 ],
190 )?;
191 self.index_dirty.store(true, Ordering::Release);
193 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
195 Ok(())
196 }
197
198 pub fn insert_spans(&self, spans: &[Span]) -> Result<()> {
208 let mut conn = self.conn.lock()
209 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
210 let tx = conn.transaction()?;
211
212 for span in spans {
213 tx.execute(
214 "INSERT INTO spans (
215 id, artifact_id, start_line, end_line, text,
216 embedding, embedding_model, token_count, metadata
217 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
218 params![
219 span.id,
220 span.artifact_id,
221 span.start_line as i64,
222 span.end_line as i64,
223 span.text,
224 span.embedding.as_ref().map(|e| serialize_embedding(e)),
225 span.embedding_model,
226 span.token_count as i64,
227 span.metadata.as_ref().map(|m| m.to_string()),
228 ],
229 )?;
230 }
231
232 tx.commit()?;
233 self.index_dirty.store(true, Ordering::Release);
235 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
237 Ok(())
238 }
239
240 pub fn get_vector_index(&self) -> Result<Arc<VectorIndex>> {
249 Ok(self.get_vector_index_with_kind()?.0)
250 }
251
252 pub fn get_vector_index_with_kind(&self) -> Result<(Arc<VectorIndex>, IndexLoadKind)> {
257 if self.index_dirty.load(Ordering::Acquire) {
259 let _guard = self.build_lock.lock()
261 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Build lock poisoned: {}", e)))?;
262 if !self.index_dirty.load(Ordering::Acquire) {
264 let cached = self.vector_index.read()
265 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
266 let idx = cached.as_ref()
267 .cloned()
268 .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache empty after build")))?
269 ;
270 return Ok((idx, IndexLoadKind::CachedInMemory));
271 }
272 let cache_dir = self.get_index_cache_dir();
274 if let Ok(index) = self.load_index_from_disk(&cache_dir) {
275 let mut cached = self.vector_index.write()
279 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
280 *cached = Some(index.clone());
281 self.index_dirty.store(false, Ordering::Release);
282 return Ok((index, IndexLoadKind::LoadedFromCache));
283 }
284
285 let spans = self.get_all_spans()?;
288 let index = Arc::new(VectorIndex::build(spans));
289
290 let _ = self.save_index_to_disk(&cache_dir, &index);
295
296 let mut cached = self.vector_index.write()
298 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
299 *cached = Some(index.clone());
300
301 self.index_dirty.store(false, Ordering::Release);
303
304 Ok((index, IndexLoadKind::BuiltFromSpans))
305 } else {
306 let cached = self.vector_index.read()
308 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
309 let idx = cached.as_ref()
310 .cloned()
311 .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache is None but not dirty - this should not happen")))?;
312 Ok((idx, IndexLoadKind::CachedInMemory))
313 }
314 }
315
316 fn get_index_cache_dir(&self) -> PathBuf {
318 let mut cache_dir = self.db_path.clone();
320 cache_dir.set_extension("sqlite.idx");
321 cache_dir
322 }
323
324 fn calculate_spans_hash(&self) -> Result<String> {
326 let spans = self.get_all_spans()?;
327 let mut hasher = Sha256::new();
328 for span in &spans {
329 hasher.update(span.id.as_bytes());
330 if let Some(emb) = &span.embedding {
331 hasher.update(&emb.len().to_le_bytes());
332 }
333 }
334 Ok(format!("{:x}", hasher.finalize()))
335 }
336
337 fn load_index_from_disk(&self, cache_dir: &Path) -> Result<Arc<VectorIndex>> {
339 match VectorIndex::load_from_disk(cache_dir) {
341 Ok(Some(index)) => {
342 let current_hash = self.calculate_spans_hash()?;
344 let cached_spans = index.spans();
345 let mut hasher = Sha256::new();
346 for span in cached_spans {
347 hasher.update(span.id.as_bytes());
348 if let Some(emb) = &span.embedding {
349 hasher.update(&emb.len().to_le_bytes());
350 }
351 }
352 let cached_hash = format!("{:x}", hasher.finalize());
353
354 if cached_hash == current_hash {
355 Ok(Arc::new(index))
356 } else {
357 Err(crate::types::Error::NotFound("Index cache is stale".to_string()))
358 }
359 }
360 Ok(None) => Err(crate::types::Error::NotFound("Index cache not found".to_string())),
361 Err(e) => Err(e),
362 }
363 }
364
365 fn save_index_to_disk(&self, cache_dir: &Path, index: &VectorIndex) -> Result<()> {
367 index.save_to_disk(cache_dir)
369 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Failed to save index to disk: {}", e)))?;
370 Ok(())
371 }
372
373 pub fn get_all_spans(&self) -> Result<Vec<Span>> {
379 let conn = self.conn.lock()
380 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
381 let mut stmt = conn.prepare(
382 "SELECT id, artifact_id, start_line, end_line, text,
383 embedding, embedding_model, token_count, metadata
384 FROM spans",
385 )?;
386
387 let spans = stmt
388 .query_map([], |row| {
389 Ok(Span {
390 id: row.get(0)?,
391 artifact_id: row.get(1)?,
392 start_line: row.get::<_, i64>(2)? as usize,
393 end_line: row.get::<_, i64>(3)? as usize,
394 text: row.get(4)?,
395 embedding: row
396 .get::<_, Option<Vec<u8>>>(5)?
397 .map(|bytes| deserialize_embedding(&bytes)),
398 embedding_model: row.get(6)?,
399 token_count: row.get::<_, i64>(7)? as usize,
400 metadata: row
401 .get::<_, Option<String>>(8)?
402 .and_then(|s| serde_json::from_str(&s).ok()),
403 })
404 })?
405 .collect::<std::result::Result<Vec<_>, _>>()?;
406
407 Ok(spans)
408 }
409
410 pub fn get_artifact(&self, artifact_id: &str) -> Result<Option<Artifact>> {
420 let conn = self.conn.lock()
421 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
422 let mut stmt = conn.prepare(
423 "SELECT id, path, content, content_hash, metadata, created_at
424 FROM artifacts WHERE id = ?1",
425 )?;
426
427 let artifact = stmt
428 .query_row(params![artifact_id], |row| {
429 Ok(Artifact {
430 id: row.get(0)?,
431 path: row.get(1)?,
432 content: row.get(2)?,
433 content_hash: row.get(3)?,
434 metadata: row
435 .get::<_, Option<String>>(4)?
436 .and_then(|s| serde_json::from_str(&s).ok()),
437 created_at: row
438 .get::<_, String>(5)?
439 .parse()
440 .unwrap_or_else(|_| chrono::Utc::now()),
441 })
442 })
443 .optional()?;
444
445 Ok(artifact)
446 }
447
448 pub fn get_artifact_by_path(&self, path: &str) -> Result<Option<Artifact>> {
452 let conn = self.conn.lock()
453 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
454 let mut stmt = conn.prepare(
455 "SELECT id, path, content, content_hash, metadata, created_at
456 FROM artifacts WHERE path = ?1",
457 )?;
458
459 let artifact = stmt
460 .query_row(params![path], |row| {
461 Ok(Artifact {
462 id: row.get(0)?,
463 path: row.get(1)?,
464 content: row.get(2)?,
465 content_hash: row.get(3)?,
466 metadata: row
467 .get::<_, Option<String>>(4)?
468 .and_then(|s| serde_json::from_str(&s).ok()),
469 created_at: row
470 .get::<_, String>(5)?
471 .parse()
472 .unwrap_or_else(|_| chrono::Utc::now()),
473 })
474 })
475 .optional()?;
476
477 Ok(artifact)
478 }
479
480 pub fn search_spans(&self, query: &str, limit: usize) -> Result<Vec<Span>> {
491 let conn = self.conn.lock()
492 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
493 let mut stmt = conn.prepare(
494 "SELECT id, artifact_id, start_line, end_line, text,
495 embedding, embedding_model, token_count, metadata
496 FROM spans
497 WHERE text LIKE ?1
498 LIMIT ?2",
499 )?;
500
501 let pattern = format!("%{}%", query);
502 let spans = stmt
503 .query_map(params![pattern, limit as i64], |row| {
504 Ok(Span {
505 id: row.get(0)?,
506 artifact_id: row.get(1)?,
507 start_line: row.get::<_, i64>(2)? as usize,
508 end_line: row.get::<_, i64>(3)? as usize,
509 text: row.get(4)?,
510 embedding: row
511 .get::<_, Option<Vec<u8>>>(5)?
512 .map(|bytes| deserialize_embedding(&bytes)),
513 embedding_model: row.get(6)?,
514 token_count: row.get::<_, i64>(7)? as usize,
515 metadata: row
516 .get::<_, Option<String>>(8)?
517 .and_then(|s| serde_json::from_str(&s).ok()),
518 })
519 })?
520 .collect::<std::result::Result<Vec<_>, _>>()?;
521
522 Ok(spans)
523 }
524
525 pub fn get_stats(&self) -> Result<(usize, usize, usize)> {
531 let conn = self.conn.lock()
532 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
533
534 let artifacts_count: i64 = conn.query_row("SELECT COUNT(*) FROM artifacts", [], |row| {
535 row.get(0)
536 })?;
537
538 let spans_count: i64 = conn.query_row("SELECT COUNT(*) FROM spans", [], |row| row.get(0))?;
539
540 let total_tokens: i64 = conn
541 .query_row("SELECT COALESCE(SUM(token_count), 0) FROM spans", [], |row| {
542 row.get(0)
543 })?;
544
545 Ok((
546 artifacts_count as usize,
547 spans_count as usize,
548 total_tokens as usize,
549 ))
550 }
551
552 pub fn clear(&self) -> Result<()> {
554 let conn = self.conn.lock()
555 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
556 conn.execute("DELETE FROM spans", [])?;
557 conn.execute("DELETE FROM artifacts", [])?;
558 let mut cached = self.vector_index.write()
560 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
561 *cached = None;
562 self.index_dirty.store(true, Ordering::Release);
563 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
565 Ok(())
566 }
567
568 pub fn create_session(&self, user_id: Option<&str>, title: Option<&str>) -> Result<Session> {
581 let conn = self.conn.lock()
582 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
583
584 let id = uuid::Uuid::new_v4().to_string();
585 let now = chrono::Utc::now();
586
587 conn.execute(
588 "INSERT INTO sessions (id, user_id, title, metadata, created_at, updated_at, last_message_at)
589 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
590 params![
591 id,
592 user_id,
593 title,
594 None::<String>, now.to_rfc3339(),
596 now.to_rfc3339(),
597 None::<String>, ],
599 )?;
600
601 Ok(Session {
602 id,
603 user_id: user_id.map(|s| s.to_string()),
604 title: title.map(|s| s.to_string()),
605 metadata: None,
606 created_at: now,
607 updated_at: now,
608 last_message_at: None,
609 })
610 }
611
612 pub fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
622 let conn = self.conn.lock()
623 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
624
625 let mut stmt = conn.prepare(
626 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
627 FROM sessions WHERE id = ?1",
628 )?;
629
630 let session = stmt.query_row(params![session_id], |row| {
631 Ok(Session {
632 id: row.get(0)?,
633 user_id: row.get(1)?,
634 title: row.get(2)?,
635 metadata: row.get::<_, Option<String>>(3)?
636 .and_then(|s| serde_json::from_str(&s).ok()),
637 created_at: row.get::<_, String>(4)?
638 .parse()
639 .unwrap_or_else(|_| chrono::Utc::now()),
640 updated_at: row.get::<_, String>(5)?
641 .parse()
642 .unwrap_or_else(|_| chrono::Utc::now()),
643 last_message_at: row.get::<_, Option<String>>(6)?
644 .and_then(|s| s.parse().ok()),
645 })
646 }).optional()?;
647
648 Ok(session)
649 }
650
651 pub fn list_sessions(&self, user_id: Option<&str>, limit: Option<usize>) -> Result<Vec<Session>> {
662 let conn = self.conn.lock()
663 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
664
665 let limit_val = limit.unwrap_or(100) as i64;
666
667 let mut sessions = Vec::new();
668
669 if let Some(uid) = user_id {
670 let mut stmt = conn.prepare(
671 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
672 FROM sessions WHERE user_id = ?1
673 ORDER BY updated_at DESC
674 LIMIT ?2"
675 )?;
676
677 let rows = stmt.query_map(params![uid, limit_val], |row| {
678 Ok(Session {
679 id: row.get(0)?,
680 user_id: row.get(1)?,
681 title: row.get(2)?,
682 metadata: row.get::<_, Option<String>>(3)?
683 .and_then(|s| serde_json::from_str(&s).ok()),
684 created_at: row.get::<_, String>(4)?
685 .parse()
686 .unwrap_or_else(|_| chrono::Utc::now()),
687 updated_at: row.get::<_, String>(5)?
688 .parse()
689 .unwrap_or_else(|_| chrono::Utc::now()),
690 last_message_at: row.get::<_, Option<String>>(6)?
691 .and_then(|s| s.parse().ok()),
692 })
693 })?;
694
695 for row in rows {
696 sessions.push(row?);
697 }
698 } else {
699 let mut stmt = conn.prepare(
700 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
701 FROM sessions
702 ORDER BY updated_at DESC
703 LIMIT ?1"
704 )?;
705
706 let rows = stmt.query_map(params![limit_val], |row| {
707 Ok(Session {
708 id: row.get(0)?,
709 user_id: row.get(1)?,
710 title: row.get(2)?,
711 metadata: row.get::<_, Option<String>>(3)?
712 .and_then(|s| serde_json::from_str(&s).ok()),
713 created_at: row.get::<_, String>(4)?
714 .parse()
715 .unwrap_or_else(|_| chrono::Utc::now()),
716 updated_at: row.get::<_, String>(5)?
717 .parse()
718 .unwrap_or_else(|_| chrono::Utc::now()),
719 last_message_at: row.get::<_, Option<String>>(6)?
720 .and_then(|s| s.parse().ok()),
721 })
722 })?;
723
724 for row in rows {
725 sessions.push(row?);
726 }
727 }
728
729 Ok(sessions)
730 }
731
732 pub fn update_session(
744 &self,
745 session_id: &str,
746 title: Option<&str>,
747 metadata: Option<&serde_json::Value>,
748 ) -> Result<()> {
749 let conn = self.conn.lock()
750 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
751
752 let now = chrono::Utc::now();
753
754 conn.execute(
755 "UPDATE sessions
756 SET title = COALESCE(?1, title),
757 metadata = COALESCE(?2, metadata),
758 updated_at = ?3
759 WHERE id = ?4",
760 params![
761 title,
762 metadata.map(|m| m.to_string()),
763 now.to_rfc3339(),
764 session_id,
765 ],
766 )?;
767
768 Ok(())
769 }
770
771 pub fn delete_session(&self, session_id: &str) -> Result<()> {
781 let conn = self.conn.lock()
782 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
783
784 conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
785
786 Ok(())
787 }
788
789 pub fn add_message(
802 &self,
803 session_id: &str,
804 role: MessageRole,
805 content: &str,
806 metadata: Option<&serde_json::Value>,
807 ) -> Result<Message> {
808 let mut conn = self.conn.lock()
809 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
810
811 let tx = conn.transaction()?;
812
813 let sequence_number: i64 = tx.query_row(
815 "SELECT COALESCE(MAX(sequence_number), -1) + 1 FROM messages WHERE session_id = ?1",
816 params![session_id],
817 |row| row.get(0),
818 )?;
819
820 let id = uuid::Uuid::new_v4().to_string();
821 let now = chrono::Utc::now();
822
823 tx.execute(
824 "INSERT INTO messages (id, session_id, role, content, metadata, sequence_number, created_at)
825 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
826 params![
827 id,
828 session_id,
829 role.as_str(),
830 content,
831 metadata.map(|m| m.to_string()),
832 sequence_number,
833 now.to_rfc3339(),
834 ],
835 )?;
836
837 tx.execute(
839 "UPDATE sessions
840 SET last_message_at = ?1, updated_at = ?1
841 WHERE id = ?2",
842 params![now.to_rfc3339(), session_id],
843 )?;
844
845 tx.commit()?;
846
847 Ok(Message {
848 id,
849 session_id: session_id.to_string(),
850 role,
851 content: content.to_string(),
852 metadata: metadata.cloned(),
853 sequence_number: sequence_number as usize,
854 created_at: now,
855 })
856 }
857
858 pub fn get_messages(&self, session_id: &str, limit: Option<usize>) -> Result<Vec<Message>> {
869 let conn = self.conn.lock()
870 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
871
872 let mut messages = Vec::new();
873
874 if let Some(lim) = limit {
875 let mut stmt = conn.prepare(
876 "SELECT id, session_id, role, content, metadata, sequence_number, created_at
877 FROM messages
878 WHERE session_id = ?1
879 ORDER BY sequence_number ASC
880 LIMIT ?2"
881 )?;
882
883 let rows = stmt.query_map(params![session_id, lim as i64], |row| {
884 let role_str: String = row.get(2)?;
885 let role = MessageRole::from_str(&role_str)
886 .unwrap_or(MessageRole::User);
887
888 Ok(Message {
889 id: row.get(0)?,
890 session_id: row.get(1)?,
891 role,
892 content: row.get(3)?,
893 metadata: row.get::<_, Option<String>>(4)?
894 .and_then(|s| serde_json::from_str(&s).ok()),
895 sequence_number: row.get::<_, i64>(5)? as usize,
896 created_at: row.get::<_, String>(6)?
897 .parse()
898 .unwrap_or_else(|_| chrono::Utc::now()),
899 })
900 })?;
901
902 for row in rows {
903 messages.push(row?);
904 }
905 } else {
906 let mut stmt = conn.prepare(
907 "SELECT id, session_id, role, content, metadata, sequence_number, created_at
908 FROM messages
909 WHERE session_id = ?1
910 ORDER BY sequence_number ASC"
911 )?;
912
913 let rows = stmt.query_map(params![session_id], |row| {
914 let role_str: String = row.get(2)?;
915 let role = MessageRole::from_str(&role_str)
916 .unwrap_or(MessageRole::User);
917
918 Ok(Message {
919 id: row.get(0)?,
920 session_id: row.get(1)?,
921 role,
922 content: row.get(3)?,
923 metadata: row.get::<_, Option<String>>(4)?
924 .and_then(|s| serde_json::from_str(&s).ok()),
925 sequence_number: row.get::<_, i64>(5)? as usize,
926 created_at: row.get::<_, String>(6)?
927 .parse()
928 .unwrap_or_else(|_| chrono::Utc::now()),
929 })
930 })?;
931
932 for row in rows {
933 messages.push(row?);
934 }
935 }
936
937 Ok(messages)
938 }
939
940 pub fn associate_working_set(
954 &self,
955 session_id: &str,
956 message_id: Option<&str>,
957 working_set: &WorkingSet,
958 query: &str,
959 config: &CompilerConfig,
960 ) -> Result<SessionWorkingSet> {
961 let conn = self.conn.lock()
962 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
963
964 let id = uuid::Uuid::new_v4().to_string();
965 let working_set_id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now();
967
968 conn.execute(
969 "INSERT INTO session_working_sets (id, session_id, message_id, working_set_id, query, config, created_at)
970 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
971 params![
972 id,
973 session_id,
974 message_id,
975 working_set_id,
976 query,
977 serde_json::to_string(config)?,
978 now.to_rfc3339(),
979 ],
980 )?;
981
982 Ok(SessionWorkingSet {
983 id,
984 session_id: session_id.to_string(),
985 message_id: message_id.map(|s| s.to_string()),
986 working_set: working_set.clone(),
987 query: query.to_string(),
988 config: config.clone(),
989 created_at: now,
990 })
991 }
992
993 pub fn get_session_full(&self, session_id: &str) -> Result<Option<SessionWithMessages>> {
1003 let session = self.get_session(session_id)?;
1004
1005 if session.is_none() {
1006 return Ok(None);
1007 }
1008
1009 let session = session.unwrap();
1010 let messages = self.get_messages(session_id, None)?;
1011
1012 let conn = self.conn.lock()
1014 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
1015
1016 let mut stmt = conn.prepare(
1017 "SELECT id, session_id, message_id, working_set_id, query, config, created_at
1018 FROM session_working_sets
1019 WHERE session_id = ?1
1020 ORDER BY created_at ASC",
1021 )?;
1022
1023 let working_sets = stmt.query_map(params![session_id], |row| {
1024 let config_str: String = row.get(5)?;
1025 let config: CompilerConfig = serde_json::from_str(&config_str)
1026 .unwrap_or_default();
1027
1028 let working_set = WorkingSet {
1032 text: String::new(),
1033 spans: Vec::new(),
1034 citations: Vec::new(),
1035 tokens_used: 0,
1036 query: row.get::<_, String>(4)?,
1037 compilation_time_ms: 0,
1038 };
1039
1040 Ok(SessionWorkingSet {
1041 id: row.get(0)?,
1042 session_id: row.get(1)?,
1043 message_id: row.get(2)?,
1044 working_set,
1045 query: row.get(4)?,
1046 config,
1047 created_at: row.get::<_, String>(6)?
1048 .parse()
1049 .unwrap_or_else(|_| chrono::Utc::now()),
1050 })
1051 })?
1052 .collect::<std::result::Result<Vec<_>, _>>()?;
1053
1054 Ok(Some(SessionWithMessages {
1055 session,
1056 messages,
1057 working_sets,
1058 }))
1059 }
1060}
1061
/// Encodes an embedding as a little-endian `f32` byte blob (4 bytes per component).
///
/// This is the storage format of the `spans.embedding` column; `deserialize_embedding`
/// is its inverse.
fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    // `flat_map` would hide the exact output length from `collect`, so reserve it up front.
    let mut bytes = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
1066
/// Decodes a little-endian `f32` byte blob as produced by `serialize_embedding`.
///
/// Each complete 4-byte group becomes one component. A trailing partial group (fewer than
/// 4 bytes) is ignored.
fn deserialize_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for group in bytes.chunks_exact(4) {
        let raw = [group[0], group[1], group[2], group[3]];
        values.push(f32::from_le_bytes(raw));
    }
    values
}
1074
1075#[cfg(test)]
1076mod tests {
1077 use super::*;
1078 use uuid::Uuid;
1079
1080 #[test]
1081 fn test_database_creation() {
1082 let db = Database::new(":memory:").unwrap();
1083 let (artifacts, spans, tokens) = db.get_stats().unwrap();
1084 assert_eq!(artifacts, 0);
1085 assert_eq!(spans, 0);
1086 assert_eq!(tokens, 0);
1087 }
1088
1089 #[test]
1090 fn test_insert_artifact() {
1091 let db = Database::new(":memory:").unwrap();
1092
1093 let artifact = Artifact {
1094 id: Uuid::new_v4().to_string(),
1095 path: "test.txt".to_string(),
1096 content: "Test content".to_string(),
1097 content_hash: "hash123".to_string(),
1098 metadata: None,
1099 created_at: chrono::Utc::now(),
1100 };
1101
1102 db.insert_artifact(&artifact).unwrap();
1103
1104 let (count, _, _) = db.get_stats().unwrap();
1105 assert_eq!(count, 1);
1106 }
1107
1108 #[test]
1109 fn test_embedding_serialization() {
1110 let original = vec![1.0, 2.5, -3.14, 0.0];
1111 let bytes = serialize_embedding(&original);
1112 let restored = deserialize_embedding(&bytes);
1113
1114 assert_eq!(original.len(), restored.len());
1115 for (a, b) in original.iter().zip(restored.iter()) {
1116 assert!((a - b).abs() < 0.0001);
1117 }
1118 }
1119
1120 #[test]
1123 fn test_create_session() {
1124 let db = Database::new(":memory:").unwrap();
1125
1126 let session = db.create_session(Some("user123"), Some("Test Session")).unwrap();
1127
1128 assert!(!session.id.is_empty());
1129 assert_eq!(session.user_id, Some("user123".to_string()));
1130 assert_eq!(session.title, Some("Test Session".to_string()));
1131 assert!(session.metadata.is_none());
1132 assert!(session.last_message_at.is_none());
1133 }
1134
1135 #[test]
1136 fn test_get_session() {
1137 let db = Database::new(":memory:").unwrap();
1138
1139 let created = db.create_session(Some("user456"), Some("Another Session")).unwrap();
1140 let retrieved = db.get_session(&created.id).unwrap();
1141
1142 assert!(retrieved.is_some());
1143 let session = retrieved.unwrap();
1144 assert_eq!(session.id, created.id);
1145 assert_eq!(session.user_id, created.user_id);
1146 assert_eq!(session.title, created.title);
1147 }
1148
1149 #[test]
1150 fn test_get_nonexistent_session() {
1151 let db = Database::new(":memory:").unwrap();
1152
1153 let result = db.get_session("nonexistent-id").unwrap();
1154 assert!(result.is_none());
1155 }
1156
1157 #[test]
1158 fn test_list_sessions() {
1159 let db = Database::new(":memory:").unwrap();
1160
1161 db.create_session(Some("user1"), Some("Session 1")).unwrap();
1163 db.create_session(Some("user1"), Some("Session 2")).unwrap();
1164 db.create_session(Some("user2"), Some("Session 3")).unwrap();
1165
1166 let all_sessions = db.list_sessions(None, None).unwrap();
1168 assert_eq!(all_sessions.len(), 3);
1169
1170 let user1_sessions = db.list_sessions(Some("user1"), None).unwrap();
1172 assert_eq!(user1_sessions.len(), 2);
1173
1174 let user2_sessions = db.list_sessions(Some("user2"), None).unwrap();
1176 assert_eq!(user2_sessions.len(), 1);
1177
1178 let limited = db.list_sessions(None, Some(2)).unwrap();
1180 assert_eq!(limited.len(), 2);
1181 }
1182
1183 #[test]
1184 fn test_update_session() {
1185 let db = Database::new(":memory:").unwrap();
1186
1187 let session = db.create_session(Some("user1"), Some("Original Title")).unwrap();
1188
1189 db.update_session(&session.id, Some("Updated Title"), None).unwrap();
1191
1192 let updated = db.get_session(&session.id).unwrap().unwrap();
1193 assert_eq!(updated.title, Some("Updated Title".to_string()));
1194
1195 let metadata = serde_json::json!({"key": "value"});
1197 db.update_session(&session.id, None, Some(&metadata)).unwrap();
1198
1199 let updated2 = db.get_session(&session.id).unwrap().unwrap();
1200 assert!(updated2.metadata.is_some());
1201 assert_eq!(updated2.metadata.unwrap()["key"], "value");
1202 }
1203
1204 #[test]
1205 fn test_delete_session() {
1206 let db = Database::new(":memory:").unwrap();
1207
1208 let session = db.create_session(Some("user1"), Some("To Delete")).unwrap();
1209
1210 assert!(db.get_session(&session.id).unwrap().is_some());
1212
1213 db.delete_session(&session.id).unwrap();
1215
1216 assert!(db.get_session(&session.id).unwrap().is_none());
1218 }
1219
1220 #[test]
1221 fn test_add_message() {
1222 let db = Database::new(":memory:").unwrap();
1223
1224 let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
1225
1226 let msg1 = db.add_message(&session.id, MessageRole::User, "Hello", None).unwrap();
1228 assert_eq!(msg1.sequence_number, 0);
1229 assert_eq!(msg1.content, "Hello");
1230 assert_eq!(msg1.role.as_str(), "user");
1231
1232 let msg2 = db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None).unwrap();
1234 assert_eq!(msg2.sequence_number, 1);
1235 assert_eq!(msg2.content, "Hi there!");
1236 assert_eq!(msg2.role.as_str(), "assistant");
1237
1238 let updated_session = db.get_session(&session.id).unwrap().unwrap();
1240 assert!(updated_session.last_message_at.is_some());
1241 }
1242
1243 #[test]
1244 fn test_add_message_with_metadata() {
1245 let db = Database::new(":memory:").unwrap();
1246
1247 let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
1248
1249 let metadata = serde_json::json!({"tool": "search", "query": "test"});
1250 let msg = db.add_message(&session.id, MessageRole::Tool, "Result", Some(&metadata)).unwrap();
1251
1252 assert!(msg.metadata.is_some());
1253 assert_eq!(msg.metadata.unwrap()["tool"], "search");
1254 }
1255
1256 #[test]
1257 fn test_get_messages() {
1258 let db = Database::new(":memory:").unwrap();
1259
1260 let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
1261
1262 db.add_message(&session.id, MessageRole::User, "Message 1", None).unwrap();
1264 db.add_message(&session.id, MessageRole::Assistant, "Message 2", None).unwrap();
1265 db.add_message(&session.id, MessageRole::User, "Message 3", None).unwrap();
1266
1267 let messages = db.get_messages(&session.id, None).unwrap();
1269 assert_eq!(messages.len(), 3);
1270 assert_eq!(messages[0].sequence_number, 0);
1271 assert_eq!(messages[1].sequence_number, 1);
1272 assert_eq!(messages[2].sequence_number, 2);
1273
1274 let limited = db.get_messages(&session.id, Some(2)).unwrap();
1276 assert_eq!(limited.len(), 2);
1277 }
1278
1279 #[test]
1280 fn test_message_ordering() {
1281 let db = Database::new(":memory:").unwrap();
1282
1283 let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
1284
1285 db.add_message(&session.id, MessageRole::User, "First", None).unwrap();
1287 db.add_message(&session.id, MessageRole::Assistant, "Second", None).unwrap();
1288 db.add_message(&session.id, MessageRole::User, "Third", None).unwrap();
1289
1290 let messages = db.get_messages(&session.id, None).unwrap();
1291
1292 assert_eq!(messages[0].content, "First");
1294 assert_eq!(messages[1].content, "Second");
1295 assert_eq!(messages[2].content, "Third");
1296
1297 for (i, msg) in messages.iter().enumerate() {
1299 assert_eq!(msg.sequence_number, i);
1300 }
1301 }
1302
1303 #[test]
1304 fn test_associate_working_set() {
1305 let db = Database::new(":memory:").unwrap();
1306
1307 let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
1308 let message = db.add_message(&session.id, MessageRole::User, "Query", None).unwrap();
1309
1310 let working_set = WorkingSet {
1312 text: "Test context".to_string(),
1313 spans: Vec::new(),
1314 citations: Vec::new(),
1315 tokens_used: 100,
1316 query: "test query".to_string(),
1317 compilation_time_ms: 50,
1318 };
1319
1320 let config = CompilerConfig::default();
1321
1322 let sws = db.associate_working_set(
1323 &session.id,
1324 Some(&message.id),
1325 &working_set,
1326 "test query",
1327 &config,
1328 ).unwrap();
1329
1330 assert_eq!(sws.session_id, session.id);
1331 assert_eq!(sws.message_id, Some(message.id));
1332 assert_eq!(sws.query, "test query");
1333 assert_eq!(sws.working_set.text, "Test context");
1334 }
1335
1336 #[test]
1337 fn test_get_session_full() {
1338 let db = Database::new(":memory:").unwrap();
1339
1340 let session = db.create_session(Some("user1"), Some("Full Session")).unwrap();
1341
1342 let msg1 = db.add_message(&session.id, MessageRole::User, "Hello", None).unwrap();
1344 db.add_message(&session.id, MessageRole::Assistant, "Hi!", None).unwrap();
1345
1346 let working_set = WorkingSet {
1348 text: "Context".to_string(),
1349 spans: Vec::new(),
1350 citations: Vec::new(),
1351 tokens_used: 50,
1352 query: "test".to_string(),
1353 compilation_time_ms: 25,
1354 };
1355
1356 db.associate_working_set(
1357 &session.id,
1358 Some(&msg1.id),
1359 &working_set,
1360 "test",
1361 &CompilerConfig::default(),
1362 ).unwrap();
1363
1364 let full = db.get_session_full(&session.id).unwrap();
1366 assert!(full.is_some());
1367
1368 let swm = full.unwrap();
1369 assert_eq!(swm.session.id, session.id);
1370 assert_eq!(swm.messages.len(), 2);
1371 assert_eq!(swm.working_sets.len(), 1);
1372 }
1373
1374 #[test]
1375 fn test_delete_session_cascade() {
1376 let db = Database::new(":memory:").unwrap();
1377
1378 let session = db.create_session(Some("user1"), Some("To Delete")).unwrap();
1379
1380 db.add_message(&session.id, MessageRole::User, "Message 1", None).unwrap();
1382 db.add_message(&session.id, MessageRole::Assistant, "Message 2", None).unwrap();
1383
1384 let messages_before = db.get_messages(&session.id, None).unwrap();
1386 assert_eq!(messages_before.len(), 2);
1387
1388 db.delete_session(&session.id).unwrap();
1390
1391 let messages_after = db.get_messages(&session.id, None).unwrap();
1393 assert_eq!(messages_after.len(), 0);
1394 }
1395
1396 #[test]
1397 fn test_message_role_conversion() {
1398 assert_eq!(MessageRole::User.as_str(), "user");
1399 assert_eq!(MessageRole::Assistant.as_str(), "assistant");
1400 assert_eq!(MessageRole::System.as_str(), "system");
1401 assert_eq!(MessageRole::Tool.as_str(), "tool");
1402
1403 assert!(matches!(MessageRole::from_str("user").unwrap(), MessageRole::User));
1404 assert!(matches!(MessageRole::from_str("assistant").unwrap(), MessageRole::Assistant));
1405 assert!(matches!(MessageRole::from_str("system").unwrap(), MessageRole::System));
1406 assert!(matches!(MessageRole::from_str("tool").unwrap(), MessageRole::Tool));
1407
1408 assert!(MessageRole::from_str("invalid").is_err());
1409 }
1410}