1use crate::types::{Artifact, Result, Span, Session, Message, MessageRole, SessionWorkingSet, SessionWithMessages, WorkingSet, CompilerConfig};
7use crate::index::VectorIndex;
8use rusqlite::{params, Connection, OptionalExtension};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use serde::{Serialize, Deserialize};
14
15#[derive(Clone)]
17pub struct Database {
18 conn: Arc<Mutex<Connection>>,
19 vector_index: Arc<RwLock<Option<Arc<VectorIndex>>>>,
21 index_dirty: Arc<AtomicBool>,
23 db_path: PathBuf,
25 build_lock: Arc<Mutex<()>>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
31pub enum IndexLoadKind {
32 LoadedFromCache,
34 BuiltFromSpans,
36 CachedInMemory,
38}
39impl Database {
40 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
50 let db_path = path.as_ref().to_path_buf();
51 let conn = Connection::open(&db_path)?;
52
53 let schema_001 = r#"-- AvocadoDB Initial Schema
55-- Phase 1: Simple SQLite-compatible schema for deterministic context compilation
56
57-- Artifacts table: stores ingested documents
58CREATE TABLE IF NOT EXISTS artifacts (
59 id TEXT PRIMARY KEY, -- UUID v4
60 path TEXT NOT NULL UNIQUE, -- File path or identifier
61 content TEXT NOT NULL, -- Full document text
62 content_hash TEXT NOT NULL, -- SHA256 of content
63 metadata TEXT, -- JSON string with arbitrary metadata
64 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
65 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
66);
67
68-- Spans table: stores document fragments with embeddings
69CREATE TABLE IF NOT EXISTS spans (
70 id TEXT PRIMARY KEY, -- UUID v4
71 artifact_id TEXT NOT NULL, -- Foreign key to artifacts
72 start_line INTEGER NOT NULL, -- Starting line number (1-indexed)
73 end_line INTEGER NOT NULL, -- Ending line number (inclusive)
74 text TEXT NOT NULL, -- Actual span text
75 embedding BLOB, -- Serialized f32 vector (1536 dims for ada-002)
76 embedding_model TEXT, -- e.g., "text-embedding-ada-002"
77 token_count INTEGER, -- Estimated token count
78 metadata TEXT, -- JSON string
79 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
80 FOREIGN KEY (artifact_id) REFERENCES artifacts(id) ON DELETE CASCADE
81);
82
83-- Indexes for performance
84CREATE INDEX IF NOT EXISTS idx_spans_artifact ON spans(artifact_id);
85CREATE INDEX IF NOT EXISTS idx_spans_lines ON spans(artifact_id, start_line, end_line);
86CREATE INDEX IF NOT EXISTS idx_artifacts_path ON artifacts(path);
87CREATE INDEX IF NOT EXISTS idx_artifacts_hash ON artifacts(content_hash);
88
89-- Enable WAL mode for better concurrency
90PRAGMA journal_mode = WAL;
91PRAGMA foreign_keys = ON;
92"#;
93
94 let schema_without_pragma = schema_001
96 .lines()
97 .filter(|line| {
98 let trimmed = line.trim();
99 !trimmed.starts_with("PRAGMA") && !trimmed.starts_with("-- Enable WAL")
100 })
101 .collect::<Vec<_>>()
102 .join("\n");
103
104 conn.execute_batch(&schema_without_pragma)?;
105
106 let schema_002 = r#"-- AvocadoDB Session Management Schema
108-- Phase 2, Priority 1: Session tracking for conversation history and agent memory
109
110-- Sessions table: tracks conversation sessions
111CREATE TABLE IF NOT EXISTS sessions (
112 id TEXT PRIMARY KEY, -- UUID v4
113 user_id TEXT, -- Optional user identifier
114 title TEXT, -- Optional session title (auto-generated or user-provided)
115 metadata TEXT, -- JSON string with arbitrary metadata
116 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
117 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
118 last_message_at TIMESTAMP -- For sorting/filtering
119);
120
121-- Messages table: stores individual conversation turns
122CREATE TABLE IF NOT EXISTS messages (
123 id TEXT PRIMARY KEY, -- UUID v4
124 session_id TEXT NOT NULL, -- Foreign key to sessions
125 role TEXT NOT NULL, -- 'user', 'assistant', 'system', 'tool'
126 content TEXT NOT NULL, -- Message content
127 metadata TEXT, -- JSON string (tool calls, citations, etc.)
128 sequence_number INTEGER NOT NULL, -- Order within session (0-indexed)
129 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
130 FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
131);
132
133-- Working set associations: links compiled contexts to sessions
134CREATE TABLE IF NOT EXISTS session_working_sets (
135 id TEXT PRIMARY KEY, -- UUID v4
136 session_id TEXT NOT NULL, -- Foreign key to sessions
137 message_id TEXT, -- Optional: which message triggered this compilation
138 working_set_id TEXT NOT NULL, -- Reference to working set (stored as JSON for now)
139 query TEXT NOT NULL, -- Query that generated this working set
140 config TEXT, -- JSON string of CompilerConfig used
141 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
142 FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE,
143 FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
144);
145
146-- Indexes for performance
147CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id);
148CREATE INDEX IF NOT EXISTS idx_sessions_updated ON sessions(updated_at DESC);
149CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, sequence_number);
150CREATE INDEX IF NOT EXISTS idx_working_sets_session ON session_working_sets(session_id);
151"#;
152 conn.execute_batch(schema_002)?;
153
154 conn.pragma_update(None, "journal_mode", "WAL")?;
156 conn.pragma_update(None, "foreign_keys", true)?;
157
158 Ok(Self {
159 conn: Arc::new(Mutex::new(conn)),
160 vector_index: Arc::new(RwLock::new(None)),
161 index_dirty: Arc::new(AtomicBool::new(true)),
162 db_path,
163 build_lock: Arc::new(Mutex::new(())),
164 })
165 }
166
167 pub fn insert_artifact(&self, artifact: &Artifact) -> Result<()> {
177 let conn = self.conn.lock()
178 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
179 conn.execute(
180 "INSERT INTO artifacts (id, path, content, content_hash, metadata, created_at)
181 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
182 params![
183 artifact.id,
184 artifact.path,
185 artifact.content,
186 artifact.content_hash,
187 artifact.metadata.as_ref().map(|m| m.to_string()),
188 artifact.created_at.to_rfc3339(),
189 ],
190 )?;
191 self.index_dirty.store(true, Ordering::Release);
193 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
195 Ok(())
196 }
197
198 pub fn insert_spans(&self, spans: &[Span]) -> Result<()> {
208 let mut conn = self.conn.lock()
209 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
210 let tx = conn.transaction()?;
211
212 for span in spans {
213 tx.execute(
214 "INSERT INTO spans (
215 id, artifact_id, start_line, end_line, text,
216 embedding, embedding_model, token_count, metadata
217 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
218 params![
219 span.id,
220 span.artifact_id,
221 span.start_line as i64,
222 span.end_line as i64,
223 span.text,
224 span.embedding.as_ref().map(|e| serialize_embedding(e)),
225 span.embedding_model,
226 span.token_count as i64,
227 span.metadata.as_ref().map(|m| m.to_string()),
228 ],
229 )?;
230 }
231
232 tx.commit()?;
233 self.index_dirty.store(true, Ordering::Release);
235 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
237 Ok(())
238 }
239
240 pub fn get_vector_index(&self) -> Result<Arc<VectorIndex>> {
249 Ok(self.get_vector_index_with_kind()?.0)
250 }
251
252 pub fn get_vector_index_with_kind(&self) -> Result<(Arc<VectorIndex>, IndexLoadKind)> {
257 if self.index_dirty.load(Ordering::Acquire) {
259 let _guard = self.build_lock.lock()
261 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Build lock poisoned: {}", e)))?;
262 if !self.index_dirty.load(Ordering::Acquire) {
264 let cached = self.vector_index.read()
265 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
266 let idx = cached.as_ref()
267 .cloned()
268 .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache empty after build")))?
269 ;
270 return Ok((idx, IndexLoadKind::CachedInMemory));
271 }
272 let cache_dir = self.get_index_cache_dir();
274 if let Ok(index) = self.load_index_from_disk(&cache_dir) {
275 let mut cached = self.vector_index.write()
279 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
280 *cached = Some(index.clone());
281 self.index_dirty.store(false, Ordering::Release);
282 return Ok((index, IndexLoadKind::LoadedFromCache));
283 }
284
285 let spans = self.get_all_spans()?;
288 let index = Arc::new(VectorIndex::build(spans));
289
290 let _ = self.save_index_to_disk(&cache_dir, &index);
295
296 let mut cached = self.vector_index.write()
298 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
299 *cached = Some(index.clone());
300
301 self.index_dirty.store(false, Ordering::Release);
303
304 Ok((index, IndexLoadKind::BuiltFromSpans))
305 } else {
306 let cached = self.vector_index.read()
308 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
309 let idx = cached.as_ref()
310 .cloned()
311 .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache is None but not dirty - this should not happen")))?;
312 Ok((idx, IndexLoadKind::CachedInMemory))
313 }
314 }
315
316 fn get_index_cache_dir(&self) -> PathBuf {
318 let mut cache_dir = self.db_path.clone();
320 cache_dir.set_extension("sqlite.idx");
321 cache_dir
322 }
323
324 fn calculate_spans_hash(&self) -> Result<String> {
326 let spans = self.get_all_spans()?;
327 let mut hasher = Sha256::new();
328 for span in &spans {
329 hasher.update(span.id.as_bytes());
330 if let Some(emb) = &span.embedding {
331 hasher.update(&emb.len().to_le_bytes());
332 }
333 }
334 Ok(format!("{:x}", hasher.finalize()))
335 }
336
337 fn load_index_from_disk(&self, cache_dir: &Path) -> Result<Arc<VectorIndex>> {
339 match VectorIndex::load_from_disk(cache_dir) {
341 Ok(Some(index)) => {
342 let current_hash = self.calculate_spans_hash()?;
344 let cached_spans = index.spans();
345 let mut hasher = Sha256::new();
346 for span in cached_spans {
347 hasher.update(span.id.as_bytes());
348 if let Some(emb) = &span.embedding {
349 hasher.update(&emb.len().to_le_bytes());
350 }
351 }
352 let cached_hash = format!("{:x}", hasher.finalize());
353
354 if cached_hash == current_hash {
355 Ok(Arc::new(index))
356 } else {
357 Err(crate::types::Error::NotFound("Index cache is stale".to_string()))
358 }
359 }
360 Ok(None) => Err(crate::types::Error::NotFound("Index cache not found".to_string())),
361 Err(e) => Err(e),
362 }
363 }
364
365 fn save_index_to_disk(&self, cache_dir: &Path, index: &VectorIndex) -> Result<()> {
367 index.save_to_disk(cache_dir)
369 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Failed to save index to disk: {}", e)))?;
370 Ok(())
371 }
372
373 pub fn get_all_spans(&self) -> Result<Vec<Span>> {
379 let conn = self.conn.lock()
380 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
381 let mut stmt = conn.prepare(
382 "SELECT id, artifact_id, start_line, end_line, text,
383 embedding, embedding_model, token_count, metadata
384 FROM spans",
385 )?;
386
387 let spans = stmt
388 .query_map([], |row| {
389 Ok(Span {
390 id: row.get(0)?,
391 artifact_id: row.get(1)?,
392 start_line: row.get::<_, i64>(2)? as usize,
393 end_line: row.get::<_, i64>(3)? as usize,
394 text: row.get(4)?,
395 embedding: row
396 .get::<_, Option<Vec<u8>>>(5)?
397 .map(|bytes| deserialize_embedding(&bytes)),
398 embedding_model: row.get(6)?,
399 token_count: row.get::<_, i64>(7)? as usize,
400 metadata: row
401 .get::<_, Option<String>>(8)?
402 .and_then(|s| serde_json::from_str(&s).ok()),
403 })
404 })?
405 .collect::<std::result::Result<Vec<_>, _>>()?;
406
407 Ok(spans)
408 }
409
410 pub fn get_artifact(&self, artifact_id: &str) -> Result<Option<Artifact>> {
420 let conn = self.conn.lock()
421 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
422 let mut stmt = conn.prepare(
423 "SELECT id, path, content, content_hash, metadata, created_at
424 FROM artifacts WHERE id = ?1",
425 )?;
426
427 let artifact = stmt
428 .query_row(params![artifact_id], |row| {
429 Ok(Artifact {
430 id: row.get(0)?,
431 path: row.get(1)?,
432 content: row.get(2)?,
433 content_hash: row.get(3)?,
434 metadata: row
435 .get::<_, Option<String>>(4)?
436 .and_then(|s| serde_json::from_str(&s).ok()),
437 created_at: row
438 .get::<_, String>(5)?
439 .parse()
440 .unwrap_or_else(|_| chrono::Utc::now()),
441 })
442 })
443 .optional()?;
444
445 Ok(artifact)
446 }
447
448 pub fn get_artifact_by_path(&self, path: &str) -> Result<Option<Artifact>> {
452 let conn = self.conn.lock()
453 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
454 let mut stmt = conn.prepare(
455 "SELECT id, path, content, content_hash, metadata, created_at
456 FROM artifacts WHERE path = ?1",
457 )?;
458
459 let artifact = stmt
460 .query_row(params![path], |row| {
461 Ok(Artifact {
462 id: row.get(0)?,
463 path: row.get(1)?,
464 content: row.get(2)?,
465 content_hash: row.get(3)?,
466 metadata: row
467 .get::<_, Option<String>>(4)?
468 .and_then(|s| serde_json::from_str(&s).ok()),
469 created_at: row
470 .get::<_, String>(5)?
471 .parse()
472 .unwrap_or_else(|_| chrono::Utc::now()),
473 })
474 })
475 .optional()?;
476
477 Ok(artifact)
478 }
479
480 pub fn determine_ingest_action(&self, path: &str, content_hash: &str) -> Result<crate::types::IngestAction> {
495 match self.get_artifact_by_path(path)? {
496 Some(existing) => {
497 if existing.content_hash == content_hash {
498 Ok(crate::types::IngestAction::Skip {
499 artifact_id: existing.id,
500 reason: "Content unchanged (same hash)".to_string(),
501 })
502 } else {
503 Ok(crate::types::IngestAction::Update {
504 artifact_id: existing.id,
505 })
506 }
507 }
508 None => Ok(crate::types::IngestAction::Create),
509 }
510 }
511
512 pub fn delete_artifact(&self, artifact_id: &str) -> Result<usize> {
522 let conn = self.conn.lock()
523 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
524
525 let spans_deleted = conn.execute(
527 "DELETE FROM spans WHERE artifact_id = ?1",
528 params![artifact_id],
529 )?;
530
531 conn.execute(
533 "DELETE FROM artifacts WHERE id = ?1",
534 params![artifact_id],
535 )?;
536
537 self.index_dirty.store(true, std::sync::atomic::Ordering::Release);
539
540 let cache_dir = self.db_path.with_extension("sqlite.idx");
542 if cache_dir.exists() {
543 let _ = std::fs::remove_dir_all(&cache_dir);
544 }
545
546 Ok(spans_deleted)
547 }
548
549 pub fn search_spans(&self, query: &str, limit: usize) -> Result<Vec<Span>> {
560 let conn = self.conn.lock()
561 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
562 let mut stmt = conn.prepare(
563 "SELECT id, artifact_id, start_line, end_line, text,
564 embedding, embedding_model, token_count, metadata
565 FROM spans
566 WHERE text LIKE ?1
567 LIMIT ?2",
568 )?;
569
570 let pattern = format!("%{}%", query);
571 let spans = stmt
572 .query_map(params![pattern, limit as i64], |row| {
573 Ok(Span {
574 id: row.get(0)?,
575 artifact_id: row.get(1)?,
576 start_line: row.get::<_, i64>(2)? as usize,
577 end_line: row.get::<_, i64>(3)? as usize,
578 text: row.get(4)?,
579 embedding: row
580 .get::<_, Option<Vec<u8>>>(5)?
581 .map(|bytes| deserialize_embedding(&bytes)),
582 embedding_model: row.get(6)?,
583 token_count: row.get::<_, i64>(7)? as usize,
584 metadata: row
585 .get::<_, Option<String>>(8)?
586 .and_then(|s| serde_json::from_str(&s).ok()),
587 })
588 })?
589 .collect::<std::result::Result<Vec<_>, _>>()?;
590
591 Ok(spans)
592 }
593
594 pub fn get_stats(&self) -> Result<(usize, usize, usize)> {
600 let conn = self.conn.lock()
601 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
602
603 let artifacts_count: i64 = conn.query_row("SELECT COUNT(*) FROM artifacts", [], |row| {
604 row.get(0)
605 })?;
606
607 let spans_count: i64 = conn.query_row("SELECT COUNT(*) FROM spans", [], |row| row.get(0))?;
608
609 let total_tokens: i64 = conn
610 .query_row("SELECT COALESCE(SUM(token_count), 0) FROM spans", [], |row| {
611 row.get(0)
612 })?;
613
614 Ok((
615 artifacts_count as usize,
616 spans_count as usize,
617 total_tokens as usize,
618 ))
619 }
620
621 pub fn clear(&self) -> Result<()> {
623 let conn = self.conn.lock()
624 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
625 conn.execute("DELETE FROM spans", [])?;
626 conn.execute("DELETE FROM artifacts", [])?;
627 let mut cached = self.vector_index.write()
629 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
630 *cached = None;
631 self.index_dirty.store(true, Ordering::Release);
632 let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
634 Ok(())
635 }
636
637 pub fn create_session(&self, user_id: Option<&str>, title: Option<&str>) -> Result<Session> {
650 let conn = self.conn.lock()
651 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
652
653 let id = uuid::Uuid::new_v4().to_string();
654 let now = chrono::Utc::now();
655
656 conn.execute(
657 "INSERT INTO sessions (id, user_id, title, metadata, created_at, updated_at, last_message_at)
658 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
659 params![
660 id,
661 user_id,
662 title,
663 None::<String>, now.to_rfc3339(),
665 now.to_rfc3339(),
666 None::<String>, ],
668 )?;
669
670 Ok(Session {
671 id,
672 user_id: user_id.map(|s| s.to_string()),
673 title: title.map(|s| s.to_string()),
674 metadata: None,
675 created_at: now,
676 updated_at: now,
677 last_message_at: None,
678 })
679 }
680
681 pub fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
691 let conn = self.conn.lock()
692 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
693
694 let mut stmt = conn.prepare(
695 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
696 FROM sessions WHERE id = ?1",
697 )?;
698
699 let session = stmt.query_row(params![session_id], |row| {
700 Ok(Session {
701 id: row.get(0)?,
702 user_id: row.get(1)?,
703 title: row.get(2)?,
704 metadata: row.get::<_, Option<String>>(3)?
705 .and_then(|s| serde_json::from_str(&s).ok()),
706 created_at: row.get::<_, String>(4)?
707 .parse()
708 .unwrap_or_else(|_| chrono::Utc::now()),
709 updated_at: row.get::<_, String>(5)?
710 .parse()
711 .unwrap_or_else(|_| chrono::Utc::now()),
712 last_message_at: row.get::<_, Option<String>>(6)?
713 .and_then(|s| s.parse().ok()),
714 })
715 }).optional()?;
716
717 Ok(session)
718 }
719
720 pub fn list_sessions(&self, user_id: Option<&str>, limit: Option<usize>) -> Result<Vec<Session>> {
731 let conn = self.conn.lock()
732 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
733
734 let limit_val = limit.unwrap_or(100) as i64;
735
736 let mut sessions = Vec::new();
737
738 if let Some(uid) = user_id {
739 let mut stmt = conn.prepare(
740 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
741 FROM sessions WHERE user_id = ?1
742 ORDER BY updated_at DESC
743 LIMIT ?2"
744 )?;
745
746 let rows = stmt.query_map(params![uid, limit_val], |row| {
747 Ok(Session {
748 id: row.get(0)?,
749 user_id: row.get(1)?,
750 title: row.get(2)?,
751 metadata: row.get::<_, Option<String>>(3)?
752 .and_then(|s| serde_json::from_str(&s).ok()),
753 created_at: row.get::<_, String>(4)?
754 .parse()
755 .unwrap_or_else(|_| chrono::Utc::now()),
756 updated_at: row.get::<_, String>(5)?
757 .parse()
758 .unwrap_or_else(|_| chrono::Utc::now()),
759 last_message_at: row.get::<_, Option<String>>(6)?
760 .and_then(|s| s.parse().ok()),
761 })
762 })?;
763
764 for row in rows {
765 sessions.push(row?);
766 }
767 } else {
768 let mut stmt = conn.prepare(
769 "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
770 FROM sessions
771 ORDER BY updated_at DESC
772 LIMIT ?1"
773 )?;
774
775 let rows = stmt.query_map(params![limit_val], |row| {
776 Ok(Session {
777 id: row.get(0)?,
778 user_id: row.get(1)?,
779 title: row.get(2)?,
780 metadata: row.get::<_, Option<String>>(3)?
781 .and_then(|s| serde_json::from_str(&s).ok()),
782 created_at: row.get::<_, String>(4)?
783 .parse()
784 .unwrap_or_else(|_| chrono::Utc::now()),
785 updated_at: row.get::<_, String>(5)?
786 .parse()
787 .unwrap_or_else(|_| chrono::Utc::now()),
788 last_message_at: row.get::<_, Option<String>>(6)?
789 .and_then(|s| s.parse().ok()),
790 })
791 })?;
792
793 for row in rows {
794 sessions.push(row?);
795 }
796 }
797
798 Ok(sessions)
799 }
800
801 pub fn update_session(
813 &self,
814 session_id: &str,
815 title: Option<&str>,
816 metadata: Option<&serde_json::Value>,
817 ) -> Result<()> {
818 let conn = self.conn.lock()
819 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
820
821 let now = chrono::Utc::now();
822
823 conn.execute(
824 "UPDATE sessions
825 SET title = COALESCE(?1, title),
826 metadata = COALESCE(?2, metadata),
827 updated_at = ?3
828 WHERE id = ?4",
829 params![
830 title,
831 metadata.map(|m| m.to_string()),
832 now.to_rfc3339(),
833 session_id,
834 ],
835 )?;
836
837 Ok(())
838 }
839
840 pub fn delete_session(&self, session_id: &str) -> Result<()> {
850 let conn = self.conn.lock()
851 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
852
853 conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
854
855 Ok(())
856 }
857
858 pub fn add_message(
871 &self,
872 session_id: &str,
873 role: MessageRole,
874 content: &str,
875 metadata: Option<&serde_json::Value>,
876 ) -> Result<Message> {
877 let mut conn = self.conn.lock()
878 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
879
880 let tx = conn.transaction()?;
881
882 let sequence_number: i64 = tx.query_row(
884 "SELECT COALESCE(MAX(sequence_number), -1) + 1 FROM messages WHERE session_id = ?1",
885 params![session_id],
886 |row| row.get(0),
887 )?;
888
889 let id = uuid::Uuid::new_v4().to_string();
890 let now = chrono::Utc::now();
891
892 tx.execute(
893 "INSERT INTO messages (id, session_id, role, content, metadata, sequence_number, created_at)
894 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
895 params![
896 id,
897 session_id,
898 role.as_str(),
899 content,
900 metadata.map(|m| m.to_string()),
901 sequence_number,
902 now.to_rfc3339(),
903 ],
904 )?;
905
906 tx.execute(
908 "UPDATE sessions
909 SET last_message_at = ?1, updated_at = ?1
910 WHERE id = ?2",
911 params![now.to_rfc3339(), session_id],
912 )?;
913
914 tx.commit()?;
915
916 Ok(Message {
917 id,
918 session_id: session_id.to_string(),
919 role,
920 content: content.to_string(),
921 metadata: metadata.cloned(),
922 sequence_number: sequence_number as usize,
923 created_at: now,
924 })
925 }
926
927 pub fn get_messages(&self, session_id: &str, limit: Option<usize>) -> Result<Vec<Message>> {
938 let conn = self.conn.lock()
939 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
940
941 let mut messages = Vec::new();
942
943 if let Some(lim) = limit {
944 let mut stmt = conn.prepare(
945 "SELECT id, session_id, role, content, metadata, sequence_number, created_at
946 FROM messages
947 WHERE session_id = ?1
948 ORDER BY sequence_number ASC
949 LIMIT ?2"
950 )?;
951
952 let rows = stmt.query_map(params![session_id, lim as i64], |row| {
953 let role_str: String = row.get(2)?;
954 let role = MessageRole::from_str(&role_str)
955 .unwrap_or(MessageRole::User);
956
957 Ok(Message {
958 id: row.get(0)?,
959 session_id: row.get(1)?,
960 role,
961 content: row.get(3)?,
962 metadata: row.get::<_, Option<String>>(4)?
963 .and_then(|s| serde_json::from_str(&s).ok()),
964 sequence_number: row.get::<_, i64>(5)? as usize,
965 created_at: row.get::<_, String>(6)?
966 .parse()
967 .unwrap_or_else(|_| chrono::Utc::now()),
968 })
969 })?;
970
971 for row in rows {
972 messages.push(row?);
973 }
974 } else {
975 let mut stmt = conn.prepare(
976 "SELECT id, session_id, role, content, metadata, sequence_number, created_at
977 FROM messages
978 WHERE session_id = ?1
979 ORDER BY sequence_number ASC"
980 )?;
981
982 let rows = stmt.query_map(params![session_id], |row| {
983 let role_str: String = row.get(2)?;
984 let role = MessageRole::from_str(&role_str)
985 .unwrap_or(MessageRole::User);
986
987 Ok(Message {
988 id: row.get(0)?,
989 session_id: row.get(1)?,
990 role,
991 content: row.get(3)?,
992 metadata: row.get::<_, Option<String>>(4)?
993 .and_then(|s| serde_json::from_str(&s).ok()),
994 sequence_number: row.get::<_, i64>(5)? as usize,
995 created_at: row.get::<_, String>(6)?
996 .parse()
997 .unwrap_or_else(|_| chrono::Utc::now()),
998 })
999 })?;
1000
1001 for row in rows {
1002 messages.push(row?);
1003 }
1004 }
1005
1006 Ok(messages)
1007 }
1008
1009 pub fn associate_working_set(
1023 &self,
1024 session_id: &str,
1025 message_id: Option<&str>,
1026 working_set: &WorkingSet,
1027 query: &str,
1028 config: &CompilerConfig,
1029 ) -> Result<SessionWorkingSet> {
1030 let conn = self.conn.lock()
1031 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
1032
1033 let id = uuid::Uuid::new_v4().to_string();
1034 let working_set_id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now();
1036
1037 conn.execute(
1038 "INSERT INTO session_working_sets (id, session_id, message_id, working_set_id, query, config, created_at)
1039 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
1040 params![
1041 id,
1042 session_id,
1043 message_id,
1044 working_set_id,
1045 query,
1046 serde_json::to_string(config)?,
1047 now.to_rfc3339(),
1048 ],
1049 )?;
1050
1051 Ok(SessionWorkingSet {
1052 id,
1053 session_id: session_id.to_string(),
1054 message_id: message_id.map(|s| s.to_string()),
1055 working_set: working_set.clone(),
1056 query: query.to_string(),
1057 config: config.clone(),
1058 created_at: now,
1059 })
1060 }
1061
1062 pub fn get_session_full(&self, session_id: &str) -> Result<Option<SessionWithMessages>> {
1072 let session = self.get_session(session_id)?;
1073
1074 if session.is_none() {
1075 return Ok(None);
1076 }
1077
1078 let session = session.unwrap();
1079 let messages = self.get_messages(session_id, None)?;
1080
1081 let conn = self.conn.lock()
1083 .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
1084
1085 let mut stmt = conn.prepare(
1086 "SELECT id, session_id, message_id, working_set_id, query, config, created_at
1087 FROM session_working_sets
1088 WHERE session_id = ?1
1089 ORDER BY created_at ASC",
1090 )?;
1091
1092 let working_sets = stmt.query_map(params![session_id], |row| {
1093 let config_str: String = row.get(5)?;
1094 let config: CompilerConfig = serde_json::from_str(&config_str)
1095 .unwrap_or_default();
1096
1097 let working_set = WorkingSet {
1101 text: String::new(),
1102 spans: Vec::new(),
1103 citations: Vec::new(),
1104 tokens_used: 0,
1105 query: row.get::<_, String>(4)?,
1106 compilation_time_ms: 0,
1107 manifest: None,
1108 explain: None,
1109 };
1110
1111 Ok(SessionWorkingSet {
1112 id: row.get(0)?,
1113 session_id: row.get(1)?,
1114 message_id: row.get(2)?,
1115 working_set,
1116 query: row.get(4)?,
1117 config,
1118 created_at: row.get::<_, String>(6)?
1119 .parse()
1120 .unwrap_or_else(|_| chrono::Utc::now()),
1121 })
1122 })?
1123 .collect::<std::result::Result<Vec<_>, _>>()?;
1124
1125 Ok(Some(SessionWithMessages {
1126 session,
1127 messages,
1128 working_sets,
1129 }))
1130 }
1131}
1132
/// Encodes an embedding as the little-endian bytes of each `f32`, in order.
///
/// The output is exactly `4 * embedding.len()` bytes; `deserialize_embedding`
/// performs the inverse conversion.
fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
1137
/// Decodes a blob written by `serialize_embedding` back into `f32` values.
///
/// Bytes are consumed in 4-byte little-endian groups. A trailing group shorter
/// than 4 bytes, which a well-formed blob never has, is silently ignored.
fn deserialize_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for group in bytes.chunks_exact(4) {
        let word = [group[0], group[1], group[2], group[3]];
        values.push(f32::from_le_bytes(word));
    }
    values
}
1145
1146#[cfg(test)]
1147mod tests {
1148 use super::*;
1149 use uuid::Uuid;
1150
1151 #[test]
1152 fn test_database_creation() {
1153 let db = Database::new(":memory:").unwrap();
1154 let (artifacts, spans, tokens) = db.get_stats().unwrap();
1155 assert_eq!(artifacts, 0);
1156 assert_eq!(spans, 0);
1157 assert_eq!(tokens, 0);
1158 }
1159
1160 #[test]
1161 fn test_insert_artifact() {
1162 let db = Database::new(":memory:").unwrap();
1163
1164 let artifact = Artifact {
1165 id: Uuid::new_v4().to_string(),
1166 path: "test.txt".to_string(),
1167 content: "Test content".to_string(),
1168 content_hash: "hash123".to_string(),
1169 metadata: None,
1170 created_at: chrono::Utc::now(),
1171 };
1172
1173 db.insert_artifact(&artifact).unwrap();
1174
1175 let (count, _, _) = db.get_stats().unwrap();
1176 assert_eq!(count, 1);
1177 }
1178
1179 #[test]
1180 fn test_embedding_serialization() {
1181 let original = vec![1.0, 2.5, -3.14, 0.0];
1182 let bytes = serialize_embedding(&original);
1183 let restored = deserialize_embedding(&bytes);
1184
1185 assert_eq!(original.len(), restored.len());
1186 for (a, b) in original.iter().zip(restored.iter()) {
1187 assert!((a - b).abs() < 0.0001);
1188 }
1189 }
1190
1191 #[test]
1194 fn test_create_session() {
1195 let db = Database::new(":memory:").unwrap();
1196
1197 let session = db.create_session(Some("user123"), Some("Test Session")).unwrap();
1198
1199 assert!(!session.id.is_empty());
1200 assert_eq!(session.user_id, Some("user123".to_string()));
1201 assert_eq!(session.title, Some("Test Session".to_string()));
1202 assert!(session.metadata.is_none());
1203 assert!(session.last_message_at.is_none());
1204 }
1205
1206 #[test]
1207 fn test_get_session() {
1208 let db = Database::new(":memory:").unwrap();
1209
1210 let created = db.create_session(Some("user456"), Some("Another Session")).unwrap();
1211 let retrieved = db.get_session(&created.id).unwrap();
1212
1213 assert!(retrieved.is_some());
1214 let session = retrieved.unwrap();
1215 assert_eq!(session.id, created.id);
1216 assert_eq!(session.user_id, created.user_id);
1217 assert_eq!(session.title, created.title);
1218 }
1219
1220 #[test]
1221 fn test_get_nonexistent_session() {
1222 let db = Database::new(":memory:").unwrap();
1223
1224 let result = db.get_session("nonexistent-id").unwrap();
1225 assert!(result.is_none());
1226 }
1227
1228 #[test]
1229 fn test_list_sessions() {
1230 let db = Database::new(":memory:").unwrap();
1231
1232 db.create_session(Some("user1"), Some("Session 1")).unwrap();
1234 db.create_session(Some("user1"), Some("Session 2")).unwrap();
1235 db.create_session(Some("user2"), Some("Session 3")).unwrap();
1236
1237 let all_sessions = db.list_sessions(None, None).unwrap();
1239 assert_eq!(all_sessions.len(), 3);
1240
1241 let user1_sessions = db.list_sessions(Some("user1"), None).unwrap();
1243 assert_eq!(user1_sessions.len(), 2);
1244
1245 let user2_sessions = db.list_sessions(Some("user2"), None).unwrap();
1247 assert_eq!(user2_sessions.len(), 1);
1248
1249 let limited = db.list_sessions(None, Some(2)).unwrap();
1251 assert_eq!(limited.len(), 2);
1252 }
1253
1254 #[test]
1255 fn test_update_session() {
1256 let db = Database::new(":memory:").unwrap();
1257
1258 let session = db.create_session(Some("user1"), Some("Original Title")).unwrap();
1259
1260 db.update_session(&session.id, Some("Updated Title"), None).unwrap();
1262
1263 let updated = db.get_session(&session.id).unwrap().unwrap();
1264 assert_eq!(updated.title, Some("Updated Title".to_string()));
1265
1266 let metadata = serde_json::json!({"key": "value"});
1268 db.update_session(&session.id, None, Some(&metadata)).unwrap();
1269
1270 let updated2 = db.get_session(&session.id).unwrap().unwrap();
1271 assert!(updated2.metadata.is_some());
1272 assert_eq!(updated2.metadata.unwrap()["key"], "value");
1273 }
1274
/// A deleted session is no longer returned by `get_session`.
#[test]
fn test_delete_session() {
    let db = Database::new(":memory:").unwrap();
    let doomed = db.create_session(Some("user1"), Some("To Delete")).unwrap();

    let before = db.get_session(&doomed.id).unwrap();
    assert!(before.is_some(), "session should exist before deletion");

    db.delete_session(&doomed.id).unwrap();

    let after = db.get_session(&doomed.id).unwrap();
    assert!(after.is_none(), "session should be gone after deletion");
}
1290
/// Messages get consecutive sequence numbers starting at 0 and keep their role/content;
/// adding one is expected to set the session's `last_message_at`.
#[test]
fn test_add_message() {
    let db = Database::new(":memory:").unwrap();
    let chat = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

    let greeting = db
        .add_message(&chat.id, MessageRole::User, "Hello", None)
        .unwrap();
    assert_eq!(greeting.sequence_number, 0);
    assert_eq!(greeting.content, "Hello");
    assert_eq!(greeting.role.as_str(), "user");

    let reply = db
        .add_message(&chat.id, MessageRole::Assistant, "Hi there!", None)
        .unwrap();
    assert_eq!(reply.sequence_number, 1);
    assert_eq!(reply.content, "Hi there!");
    assert_eq!(reply.role.as_str(), "assistant");

    let refreshed = db
        .get_session(&chat.id)
        .unwrap()
        .expect("session should exist after adding messages");
    assert!(refreshed.last_message_at.is_some());
}
1313
/// JSON metadata passed to `add_message` is present on the returned message.
#[test]
fn test_add_message_with_metadata() {
    let db = Database::new(":memory:").unwrap();
    let chat = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

    let tool_info = serde_json::json!({"tool": "search", "query": "test"});
    let result_msg = db
        .add_message(&chat.id, MessageRole::Tool, "Result", Some(&tool_info))
        .unwrap();

    let attached = result_msg
        .metadata
        .expect("metadata should be returned with the message");
    assert_eq!(attached["tool"], "search");
}
1326
/// `get_messages` returns every message with sequence numbers 0, 1, 2, ...
/// and respects the optional limit.
#[test]
fn test_get_messages() {
    let db = Database::new(":memory:").unwrap();
    let chat = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

    for (role, body) in [
        (MessageRole::User, "Message 1"),
        (MessageRole::Assistant, "Message 2"),
        (MessageRole::User, "Message 3"),
    ] {
        db.add_message(&chat.id, role, body, None).unwrap();
    }

    let everything = db.get_messages(&chat.id, None).unwrap();
    assert_eq!(everything.len(), 3);
    for (idx, expected_seq) in [0, 1, 2].iter().copied().enumerate() {
        assert_eq!(everything[idx].sequence_number, expected_seq);
    }

    let first_two = db.get_messages(&chat.id, Some(2)).unwrap();
    assert_eq!(first_two.len(), 2);
}
1349
/// Messages come back from `get_messages` in insertion order, with sequence
/// numbers matching their position (0, 1, 2, ...).
#[test]
fn test_message_ordering() {
    let db = Database::new(":memory:").unwrap();

    let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

    db.add_message(&session.id, MessageRole::User, "First", None).unwrap();
    db.add_message(&session.id, MessageRole::Assistant, "Second", None).unwrap();
    db.add_message(&session.id, MessageRole::User, "Third", None).unwrap();

    let messages = db.get_messages(&session.id, None).unwrap();

    // Pin the count before indexing. Otherwise a missing row fails with an
    // opaque index-out-of-bounds panic, and an extra row goes unnoticed by
    // the content checks below.
    assert_eq!(messages.len(), 3);

    assert_eq!(messages[0].content, "First");
    assert_eq!(messages[1].content, "Second");
    assert_eq!(messages[2].content, "Third");

    // Sequence numbers must be contiguous and start at 0.
    for (i, msg) in messages.iter().enumerate() {
        assert_eq!(msg.sequence_number, i);
    }
}
1373
/// A compiled working set can be attached to a session/message pair.
/// The stored record echoes the ids, query and context text back.
#[test]
fn test_associate_working_set() {
    let db = Database::new(":memory:").unwrap();
    let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
    let prompt = db
        .add_message(&session.id, MessageRole::User, "Query", None)
        .unwrap();

    // Minimal working set: context text and stats only, no spans or citations.
    let compiled = WorkingSet {
        query: "test query".to_string(),
        text: "Test context".to_string(),
        tokens_used: 100,
        compilation_time_ms: 50,
        spans: Vec::new(),
        citations: Vec::new(),
        manifest: None,
        explain: None,
    };

    let stored = db
        .associate_working_set(
            &session.id,
            Some(&prompt.id),
            &compiled,
            "test query",
            &CompilerConfig::default(),
        )
        .unwrap();

    assert_eq!(stored.session_id, session.id);
    assert_eq!(stored.message_id, Some(prompt.id));
    assert_eq!(stored.query, "test query");
    assert_eq!(stored.working_set.text, "Test context");
}
1408
/// `get_session_full` bundles the session with all of its messages and
/// associated working sets.
#[test]
fn test_get_session_full() {
    let db = Database::new(":memory:").unwrap();
    let session = db.create_session(Some("user1"), Some("Full Session")).unwrap();

    let opener = db
        .add_message(&session.id, MessageRole::User, "Hello", None)
        .unwrap();
    db.add_message(&session.id, MessageRole::Assistant, "Hi!", None)
        .unwrap();

    let compiled = WorkingSet {
        query: "test".to_string(),
        text: "Context".to_string(),
        tokens_used: 50,
        compilation_time_ms: 25,
        spans: Vec::new(),
        citations: Vec::new(),
        manifest: None,
        explain: None,
    };
    db.associate_working_set(
        &session.id,
        Some(&opener.id),
        &compiled,
        "test",
        &CompilerConfig::default(),
    )
    .unwrap();

    let bundle = db
        .get_session_full(&session.id)
        .unwrap()
        .expect("full session should be found");
    assert_eq!(bundle.session.id, session.id);
    assert_eq!(bundle.messages.len(), 2);
    assert_eq!(bundle.working_sets.len(), 1);
}
1448
/// Deleting a session also removes its messages.
#[test]
fn test_delete_session_cascade() {
    let db = Database::new(":memory:").unwrap();
    let session = db.create_session(Some("user1"), Some("To Delete")).unwrap();

    for (role, body) in [
        (MessageRole::User, "Message 1"),
        (MessageRole::Assistant, "Message 2"),
    ] {
        db.add_message(&session.id, role, body, None).unwrap();
    }

    // Number of messages currently stored for this session.
    let message_count = || db.get_messages(&session.id, None).unwrap().len();

    assert_eq!(message_count(), 2);
    db.delete_session(&session.id).unwrap();
    assert_eq!(message_count(), 0);
}
1470
/// `MessageRole` renders to lowercase names, and parses back from them.
/// Unknown names produce an error.
#[test]
fn test_message_role_conversion() {
    let names = [
        (MessageRole::User, "user"),
        (MessageRole::Assistant, "assistant"),
        (MessageRole::System, "system"),
        (MessageRole::Tool, "tool"),
    ];
    for (role, name) in names {
        assert_eq!(role.as_str(), name);
    }

    let parsed_user = MessageRole::from_str("user").unwrap();
    assert!(matches!(parsed_user, MessageRole::User));
    let parsed_assistant = MessageRole::from_str("assistant").unwrap();
    assert!(matches!(parsed_assistant, MessageRole::Assistant));
    let parsed_system = MessageRole::from_str("system").unwrap();
    assert!(matches!(parsed_system, MessageRole::System));
    let parsed_tool = MessageRole::from_str("tool").unwrap();
    assert!(matches!(parsed_tool, MessageRole::Tool));

    assert!(MessageRole::from_str("invalid").is_err());
}
1485}