avocado_core/
db.rs

1//! Database operations using SQLite
2//!
3//! This module handles all database interactions using rusqlite.
4//! SQLite is sufficient for Phase 1 (can handle 10K+ documents easily).
5
6use crate::types::{Artifact, Result, Span, Session, Message, MessageRole, SessionWorkingSet, SessionWithMessages, WorkingSet, CompilerConfig};
7use crate::index::VectorIndex;
8use rusqlite::{params, Connection, OptionalExtension};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use serde::{Serialize, Deserialize};
14
15/// Database connection wrapper with thread-safe access
16#[derive(Clone)]
17pub struct Database {
18    conn: Arc<Mutex<Connection>>,
19    // Cached vector index to avoid rebuilding on every compile request
20    vector_index: Arc<RwLock<Option<Arc<VectorIndex>>>>,
21    // Flag to track if index needs rebuilding (invalidated on ingest)
22    index_dirty: Arc<AtomicBool>,
23    // Path to database file (for index cache location)
24    db_path: PathBuf,
25    // Serialize builds for this Database to avoid concurrent heavy index builds
26    build_lock: Arc<Mutex<()>>,
27}
28
29/// How the ANN index was obtained
30#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
31pub enum IndexLoadKind {
32    /// Loaded from on-disk cache (ANN persistence)
33    LoadedFromCache,
34    /// Built from spans in the database
35    BuiltFromSpans,
36    /// Returned from in-memory cache (no disk or build)
37    CachedInMemory,
38}
39impl Database {
40    /// Create a new database connection and run migrations
41    ///
42    /// # Arguments
43    ///
44    /// * `path` - Path to the SQLite database file
45    ///
46    /// # Returns
47    ///
48    /// A new Database instance
49    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
50        let db_path = path.as_ref().to_path_buf();
51        let conn = Connection::open(&db_path)?;
52
53        // Run initial migration (without PRAGMA statements)
54        let schema_001 = r#"-- AvocadoDB Initial Schema
55-- Phase 1: Simple SQLite-compatible schema for deterministic context compilation
56
57-- Artifacts table: stores ingested documents
58CREATE TABLE IF NOT EXISTS artifacts (
59    id TEXT PRIMARY KEY,                      -- UUID v4
60    path TEXT NOT NULL UNIQUE,                -- File path or identifier
61    content TEXT NOT NULL,                    -- Full document text
62    content_hash TEXT NOT NULL,               -- SHA256 of content
63    metadata TEXT,                            -- JSON string with arbitrary metadata
64    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
65    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
66);
67
68-- Spans table: stores document fragments with embeddings
69CREATE TABLE IF NOT EXISTS spans (
70    id TEXT PRIMARY KEY,                      -- UUID v4
71    artifact_id TEXT NOT NULL,                -- Foreign key to artifacts
72    start_line INTEGER NOT NULL,              -- Starting line number (1-indexed)
73    end_line INTEGER NOT NULL,                -- Ending line number (inclusive)
74    text TEXT NOT NULL,                       -- Actual span text
75    embedding BLOB,                           -- Serialized f32 vector (1536 dims for ada-002)
76    embedding_model TEXT,                     -- e.g., "text-embedding-ada-002"
77    token_count INTEGER,                      -- Estimated token count
78    metadata TEXT,                            -- JSON string
79    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
80    FOREIGN KEY (artifact_id) REFERENCES artifacts(id) ON DELETE CASCADE
81);
82
83-- Indexes for performance
84CREATE INDEX IF NOT EXISTS idx_spans_artifact ON spans(artifact_id);
85CREATE INDEX IF NOT EXISTS idx_spans_lines ON spans(artifact_id, start_line, end_line);
86CREATE INDEX IF NOT EXISTS idx_artifacts_path ON artifacts(path);
87CREATE INDEX IF NOT EXISTS idx_artifacts_hash ON artifacts(content_hash);
88
89-- Enable WAL mode for better concurrency
90PRAGMA journal_mode = WAL;
91PRAGMA foreign_keys = ON;
92"#;
93
94        // Execute the schema without PRAGMAs
95        let schema_without_pragma = schema_001
96            .lines()
97            .filter(|line| {
98                let trimmed = line.trim();
99                !trimmed.starts_with("PRAGMA") && !trimmed.starts_with("-- Enable WAL")
100            })
101            .collect::<Vec<_>>()
102            .join("\n");
103
104        conn.execute_batch(&schema_without_pragma)?;
105
106        // Run session management migration
107        let schema_002 = r#"-- AvocadoDB Session Management Schema
108-- Phase 2, Priority 1: Session tracking for conversation history and agent memory
109
110-- Sessions table: tracks conversation sessions
111CREATE TABLE IF NOT EXISTS sessions (
112    id TEXT PRIMARY KEY,                      -- UUID v4
113    user_id TEXT,                             -- Optional user identifier
114    title TEXT,                               -- Optional session title (auto-generated or user-provided)
115    metadata TEXT,                            -- JSON string with arbitrary metadata
116    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
117    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
118    last_message_at TIMESTAMP                 -- For sorting/filtering
119);
120
121-- Messages table: stores individual conversation turns
122CREATE TABLE IF NOT EXISTS messages (
123    id TEXT PRIMARY KEY,                      -- UUID v4
124    session_id TEXT NOT NULL,                 -- Foreign key to sessions
125    role TEXT NOT NULL,                       -- 'user', 'assistant', 'system', 'tool'
126    content TEXT NOT NULL,                    -- Message content
127    metadata TEXT,                            -- JSON string (tool calls, citations, etc.)
128    sequence_number INTEGER NOT NULL,         -- Order within session (0-indexed)
129    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
130    FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
131);
132
133-- Working set associations: links compiled contexts to sessions
134CREATE TABLE IF NOT EXISTS session_working_sets (
135    id TEXT PRIMARY KEY,                      -- UUID v4
136    session_id TEXT NOT NULL,                 -- Foreign key to sessions
137    message_id TEXT,                          -- Optional: which message triggered this compilation
138    working_set_id TEXT NOT NULL,             -- Reference to working set (stored as JSON for now)
139    query TEXT NOT NULL,                      -- Query that generated this working set
140    config TEXT,                              -- JSON string of CompilerConfig used
141    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
142    FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE,
143    FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
144);
145
146-- Indexes for performance
147CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id);
148CREATE INDEX IF NOT EXISTS idx_sessions_updated ON sessions(updated_at DESC);
149CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, sequence_number);
150CREATE INDEX IF NOT EXISTS idx_working_sets_session ON session_working_sets(session_id);
151"#;
152        conn.execute_batch(schema_002)?;
153
154        // Execute PRAGMAs separately (they return results)
155        conn.pragma_update(None, "journal_mode", "WAL")?;
156        conn.pragma_update(None, "foreign_keys", true)?;
157
158        Ok(Self {
159            conn: Arc::new(Mutex::new(conn)),
160            vector_index: Arc::new(RwLock::new(None)),
161            index_dirty: Arc::new(AtomicBool::new(true)),
162            db_path,
163            build_lock: Arc::new(Mutex::new(())),
164        })
165    }
166
167    /// Insert an artifact into the database
168    ///
169    /// # Arguments
170    ///
171    /// * `artifact` - The artifact to insert
172    ///
173    /// # Returns
174    ///
175    /// Ok(()) if successful
176    pub fn insert_artifact(&self, artifact: &Artifact) -> Result<()> {
177        let conn = self.conn.lock()
178            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
179        conn.execute(
180            "INSERT INTO artifacts (id, path, content, content_hash, metadata, created_at)
181             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
182            params![
183                artifact.id,
184                artifact.path,
185                artifact.content,
186                artifact.content_hash,
187                artifact.metadata.as_ref().map(|m| m.to_string()),
188                artifact.created_at.to_rfc3339(),
189            ],
190        )?;
191        // Invalidate cached index since we added a new artifact
192        self.index_dirty.store(true, Ordering::Release);
193        // Delete index cache directory since it's now stale
194        let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
195        Ok(())
196    }
197
198    /// Insert multiple spans in a transaction
199    ///
200    /// # Arguments
201    ///
202    /// * `spans` - Vector of spans to insert
203    ///
204    /// # Returns
205    ///
206    /// Ok(()) if successful
207    pub fn insert_spans(&self, spans: &[Span]) -> Result<()> {
208        let mut conn = self.conn.lock()
209            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
210        let tx = conn.transaction()?;
211
212        for span in spans {
213            tx.execute(
214                "INSERT INTO spans (
215                    id, artifact_id, start_line, end_line, text,
216                    embedding, embedding_model, token_count, metadata
217                ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
218                params![
219                    span.id,
220                    span.artifact_id,
221                    span.start_line as i64,
222                    span.end_line as i64,
223                    span.text,
224                    span.embedding.as_ref().map(|e| serialize_embedding(e)),
225                    span.embedding_model,
226                    span.token_count as i64,
227                    span.metadata.as_ref().map(|m| m.to_string()),
228                ],
229            )?;
230        }
231
232        tx.commit()?;
233        // Invalidate cached index since we added new spans
234        self.index_dirty.store(true, Ordering::Release);
235        // Delete index cache directory since it's now stale
236        let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
237        Ok(())
238    }
239
240    /// Get or build the cached vector index
241    ///
242    /// The index is cached and only rebuilt when data changes (on ingest).
243    /// Phase 2.1: Tries to load from disk first, then builds if needed.
244    ///
245    /// # Returns
246    ///
247    /// A reference-counted vector index
248    pub fn get_vector_index(&self) -> Result<Arc<VectorIndex>> {
249        Ok(self.get_vector_index_with_kind()?.0)
250    }
251
252    /// Get or build the cached vector index and return how it was obtained
253    ///
254    /// Returns the index and an indicator of whether it was loaded from cache
255    /// or freshly built from spans.
256    pub fn get_vector_index_with_kind(&self) -> Result<(Arc<VectorIndex>, IndexLoadKind)> {
257        // Check if index needs rebuilding
258        if self.index_dirty.load(Ordering::Acquire) {
259            // Ensure only one thread builds/loads at a time for this database
260            let _guard = self.build_lock.lock()
261                .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Build lock poisoned: {}", e)))?;
262            // Re-check after acquiring lock in case another thread already built it
263            if !self.index_dirty.load(Ordering::Acquire) {
264                let cached = self.vector_index.read()
265                    .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
266                let idx = cached.as_ref()
267                    .cloned()
268                    .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache empty after build")))?
269                    ;
270                return Ok((idx, IndexLoadKind::CachedInMemory));
271            }
272            // Try to load from disk first (Phase 2.1 persistent index)
273            let cache_dir = self.get_index_cache_dir();
274            if let Ok(index) = self.load_index_from_disk(&cache_dir) {
275                // Index loaded successfully from cache
276                // Note: We still rebuild HNSW from cached spans due to lifetime constraints in hnsw_rs
277                // This is faster than loading from SQLite, but not as fast as loading HNSW structure directly
278                let mut cached = self.vector_index.write()
279                    .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
280                *cached = Some(index.clone());
281                self.index_dirty.store(false, Ordering::Release);
282                return Ok((index, IndexLoadKind::LoadedFromCache));
283            }
284            
285            // Build index from spans (load from SQLite)
286            // For large repos, this can take 1-2 minutes
287            let spans = self.get_all_spans()?;
288            let index = Arc::new(VectorIndex::build(spans));
289            
290            // Save to disk for next time (Phase 2.1)
291            // This saves both HNSW dump files and spans cache
292            // Note: HNSW structure can't be directly loaded due to lifetime constraints,
293            // but caching spans still provides significant speedup (avoids SQLite queries)
294            let _ = self.save_index_to_disk(&cache_dir, &index);
295            
296            // Update cache
297            let mut cached = self.vector_index.write()
298                .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
299            *cached = Some(index.clone());
300            
301            // Mark as clean
302            self.index_dirty.store(false, Ordering::Release);
303            
304            Ok((index, IndexLoadKind::BuiltFromSpans))
305        } else {
306            // Return cached index
307            let cached = self.vector_index.read()
308                .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
309            let idx = cached.as_ref()
310                .cloned()
311                .ok_or_else(|| crate::types::Error::Other(anyhow::anyhow!("Index cache is None but not dirty - this should not happen")))?;
312            Ok((idx, IndexLoadKind::CachedInMemory))
313        }
314    }
315    
316    /// Get the path to the index cache directory
317    fn get_index_cache_dir(&self) -> PathBuf {
318        // Store index cache in a directory next to database: db.sqlite -> db.sqlite.idx/
319        let mut cache_dir = self.db_path.clone();
320        cache_dir.set_extension("sqlite.idx");
321        cache_dir
322    }
323    
324    /// Calculate a hash of all spans to detect changes
325    fn calculate_spans_hash(&self) -> Result<String> {
326        let spans = self.get_all_spans()?;
327        let mut hasher = Sha256::new();
328        for span in &spans {
329            hasher.update(span.id.as_bytes());
330            if let Some(emb) = &span.embedding {
331                hasher.update(&emb.len().to_le_bytes());
332            }
333        }
334        Ok(format!("{:x}", hasher.finalize()))
335    }
336    
337    /// Load index from disk if valid
338    fn load_index_from_disk(&self, cache_dir: &Path) -> Result<Arc<VectorIndex>> {
339        // Try to load using VectorIndex::load_from_disk
340        match VectorIndex::load_from_disk(cache_dir) {
341            Ok(Some(index)) => {
342                // Verify hash matches current spans (double-check)
343                let current_hash = self.calculate_spans_hash()?;
344                let cached_spans = index.spans();
345                let mut hasher = Sha256::new();
346                for span in cached_spans {
347                    hasher.update(span.id.as_bytes());
348                    if let Some(emb) = &span.embedding {
349                        hasher.update(&emb.len().to_le_bytes());
350                    }
351                }
352                let cached_hash = format!("{:x}", hasher.finalize());
353                
354                if cached_hash == current_hash {
355                    Ok(Arc::new(index))
356                } else {
357                    Err(crate::types::Error::NotFound("Index cache is stale".to_string()))
358                }
359            }
360            Ok(None) => Err(crate::types::Error::NotFound("Index cache not found".to_string())),
361            Err(e) => Err(e),
362        }
363    }
364    
365    /// Save index to disk for persistence
366    fn save_index_to_disk(&self, cache_dir: &Path, index: &VectorIndex) -> Result<()> {
367        // Use VectorIndex::save_to_disk which saves both HNSW dump and spans
368        index.save_to_disk(cache_dir)
369            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Failed to save index to disk: {}", e)))?;
370        Ok(())
371    }
372
373    /// Get all spans from the database
374    ///
375    /// # Returns
376    ///
377    /// Vector of all spans
378    pub fn get_all_spans(&self) -> Result<Vec<Span>> {
379        let conn = self.conn.lock()
380            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
381        let mut stmt = conn.prepare(
382            "SELECT id, artifact_id, start_line, end_line, text,
383                    embedding, embedding_model, token_count, metadata
384             FROM spans",
385        )?;
386
387        let spans = stmt
388            .query_map([], |row| {
389                Ok(Span {
390                    id: row.get(0)?,
391                    artifact_id: row.get(1)?,
392                    start_line: row.get::<_, i64>(2)? as usize,
393                    end_line: row.get::<_, i64>(3)? as usize,
394                    text: row.get(4)?,
395                    embedding: row
396                        .get::<_, Option<Vec<u8>>>(5)?
397                        .map(|bytes| deserialize_embedding(&bytes)),
398                    embedding_model: row.get(6)?,
399                    token_count: row.get::<_, i64>(7)? as usize,
400                    metadata: row
401                        .get::<_, Option<String>>(8)?
402                        .and_then(|s| serde_json::from_str(&s).ok()),
403                })
404            })?
405            .collect::<std::result::Result<Vec<_>, _>>()?;
406
407        Ok(spans)
408    }
409
410    /// Get artifact by ID
411    ///
412    /// # Arguments
413    ///
414    /// * `artifact_id` - The artifact ID to look up
415    ///
416    /// # Returns
417    ///
418    /// The artifact if found
419    pub fn get_artifact(&self, artifact_id: &str) -> Result<Option<Artifact>> {
420        let conn = self.conn.lock()
421            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
422        let mut stmt = conn.prepare(
423            "SELECT id, path, content, content_hash, metadata, created_at
424             FROM artifacts WHERE id = ?1",
425        )?;
426
427        let artifact = stmt
428            .query_row(params![artifact_id], |row| {
429                Ok(Artifact {
430                    id: row.get(0)?,
431                    path: row.get(1)?,
432                    content: row.get(2)?,
433                    content_hash: row.get(3)?,
434                    metadata: row
435                        .get::<_, Option<String>>(4)?
436                        .and_then(|s| serde_json::from_str(&s).ok()),
437                    created_at: row
438                        .get::<_, String>(5)?
439                        .parse()
440                        .unwrap_or_else(|_| chrono::Utc::now()),
441                })
442            })
443            .optional()?;
444
445        Ok(artifact)
446    }
447
448    /// Get artifact by path
449    ///
450    /// Returns the artifact row matching the unique path, if present.
451    pub fn get_artifact_by_path(&self, path: &str) -> Result<Option<Artifact>> {
452        let conn = self.conn.lock()
453            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
454        let mut stmt = conn.prepare(
455            "SELECT id, path, content, content_hash, metadata, created_at
456             FROM artifacts WHERE path = ?1",
457        )?;
458
459        let artifact = stmt
460            .query_row(params![path], |row| {
461                Ok(Artifact {
462                    id: row.get(0)?,
463                    path: row.get(1)?,
464                    content: row.get(2)?,
465                    content_hash: row.get(3)?,
466                    metadata: row
467                        .get::<_, Option<String>>(4)?
468                        .and_then(|s| serde_json::from_str(&s).ok()),
469                    created_at: row
470                        .get::<_, String>(5)?
471                        .parse()
472                        .unwrap_or_else(|_| chrono::Utc::now()),
473                })
474            })
475            .optional()?;
476
477        Ok(artifact)
478    }
479
480    /// Search spans by text content (simple keyword matching)
481    ///
482    /// # Arguments
483    ///
484    /// * `query` - The search query
485    /// * `limit` - Maximum number of results
486    ///
487    /// # Returns
488    ///
489    /// Vector of matching spans
490    pub fn search_spans(&self, query: &str, limit: usize) -> Result<Vec<Span>> {
491        let conn = self.conn.lock()
492            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
493        let mut stmt = conn.prepare(
494            "SELECT id, artifact_id, start_line, end_line, text,
495                    embedding, embedding_model, token_count, metadata
496             FROM spans
497             WHERE text LIKE ?1
498             LIMIT ?2",
499        )?;
500
501        let pattern = format!("%{}%", query);
502        let spans = stmt
503            .query_map(params![pattern, limit as i64], |row| {
504                Ok(Span {
505                    id: row.get(0)?,
506                    artifact_id: row.get(1)?,
507                    start_line: row.get::<_, i64>(2)? as usize,
508                    end_line: row.get::<_, i64>(3)? as usize,
509                    text: row.get(4)?,
510                    embedding: row
511                        .get::<_, Option<Vec<u8>>>(5)?
512                        .map(|bytes| deserialize_embedding(&bytes)),
513                    embedding_model: row.get(6)?,
514                    token_count: row.get::<_, i64>(7)? as usize,
515                    metadata: row
516                        .get::<_, Option<String>>(8)?
517                        .and_then(|s| serde_json::from_str(&s).ok()),
518                })
519            })?
520            .collect::<std::result::Result<Vec<_>, _>>()?;
521
522        Ok(spans)
523    }
524
525    /// Get database statistics
526    ///
527    /// # Returns
528    ///
529    /// (artifacts_count, spans_count, total_tokens)
530    pub fn get_stats(&self) -> Result<(usize, usize, usize)> {
531        let conn = self.conn.lock()
532            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
533
534        let artifacts_count: i64 = conn.query_row("SELECT COUNT(*) FROM artifacts", [], |row| {
535            row.get(0)
536        })?;
537
538        let spans_count: i64 = conn.query_row("SELECT COUNT(*) FROM spans", [], |row| row.get(0))?;
539
540        let total_tokens: i64 = conn
541            .query_row("SELECT COALESCE(SUM(token_count), 0) FROM spans", [], |row| {
542                row.get(0)
543            })?;
544
545        Ok((
546            artifacts_count as usize,
547            spans_count as usize,
548            total_tokens as usize,
549        ))
550    }
551
552    /// Clear all data from the database
553    pub fn clear(&self) -> Result<()> {
554        let conn = self.conn.lock()
555            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
556        conn.execute("DELETE FROM spans", [])?;
557        conn.execute("DELETE FROM artifacts", [])?;
558        // Clear cached index
559        let mut cached = self.vector_index.write()
560            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Index lock poisoned: {}", e)))?;
561        *cached = None;
562        self.index_dirty.store(true, Ordering::Release);
563        // Delete index cache directory
564        let _ = std::fs::remove_dir_all(self.get_index_cache_dir());
565        Ok(())
566    }
567
568    // ========== Session Management Operations ==========
569
570    /// Create a new session
571    ///
572    /// # Arguments
573    ///
574    /// * `user_id` - Optional user identifier
575    /// * `title` - Optional session title
576    ///
577    /// # Returns
578    ///
579    /// The newly created session
580    pub fn create_session(&self, user_id: Option<&str>, title: Option<&str>) -> Result<Session> {
581        let conn = self.conn.lock()
582            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
583
584        let id = uuid::Uuid::new_v4().to_string();
585        let now = chrono::Utc::now();
586
587        conn.execute(
588            "INSERT INTO sessions (id, user_id, title, metadata, created_at, updated_at, last_message_at)
589             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
590            params![
591                id,
592                user_id,
593                title,
594                None::<String>, // metadata
595                now.to_rfc3339(),
596                now.to_rfc3339(),
597                None::<String>, // last_message_at
598            ],
599        )?;
600
601        Ok(Session {
602            id,
603            user_id: user_id.map(|s| s.to_string()),
604            title: title.map(|s| s.to_string()),
605            metadata: None,
606            created_at: now,
607            updated_at: now,
608            last_message_at: None,
609        })
610    }
611
612    /// Get a session by ID
613    ///
614    /// # Arguments
615    ///
616    /// * `session_id` - The session ID to look up
617    ///
618    /// # Returns
619    ///
620    /// The session if found
621    pub fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
622        let conn = self.conn.lock()
623            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
624
625        let mut stmt = conn.prepare(
626            "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
627             FROM sessions WHERE id = ?1",
628        )?;
629
630        let session = stmt.query_row(params![session_id], |row| {
631            Ok(Session {
632                id: row.get(0)?,
633                user_id: row.get(1)?,
634                title: row.get(2)?,
635                metadata: row.get::<_, Option<String>>(3)?
636                    .and_then(|s| serde_json::from_str(&s).ok()),
637                created_at: row.get::<_, String>(4)?
638                    .parse()
639                    .unwrap_or_else(|_| chrono::Utc::now()),
640                updated_at: row.get::<_, String>(5)?
641                    .parse()
642                    .unwrap_or_else(|_| chrono::Utc::now()),
643                last_message_at: row.get::<_, Option<String>>(6)?
644                    .and_then(|s| s.parse().ok()),
645            })
646        }).optional()?;
647
648        Ok(session)
649    }
650
651    /// List sessions for a user (or all sessions if user_id is None)
652    ///
653    /// # Arguments
654    ///
655    /// * `user_id` - Optional user ID to filter by
656    /// * `limit` - Maximum number of sessions to return
657    ///
658    /// # Returns
659    ///
660    /// Vector of sessions, sorted by updated_at descending
661    pub fn list_sessions(&self, user_id: Option<&str>, limit: Option<usize>) -> Result<Vec<Session>> {
662        let conn = self.conn.lock()
663            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
664
665        let limit_val = limit.unwrap_or(100) as i64;
666
667        let mut sessions = Vec::new();
668
669        if let Some(uid) = user_id {
670            let mut stmt = conn.prepare(
671                "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
672                 FROM sessions WHERE user_id = ?1
673                 ORDER BY updated_at DESC
674                 LIMIT ?2"
675            )?;
676
677            let rows = stmt.query_map(params![uid, limit_val], |row| {
678                Ok(Session {
679                    id: row.get(0)?,
680                    user_id: row.get(1)?,
681                    title: row.get(2)?,
682                    metadata: row.get::<_, Option<String>>(3)?
683                        .and_then(|s| serde_json::from_str(&s).ok()),
684                    created_at: row.get::<_, String>(4)?
685                        .parse()
686                        .unwrap_or_else(|_| chrono::Utc::now()),
687                    updated_at: row.get::<_, String>(5)?
688                        .parse()
689                        .unwrap_or_else(|_| chrono::Utc::now()),
690                    last_message_at: row.get::<_, Option<String>>(6)?
691                        .and_then(|s| s.parse().ok()),
692                })
693            })?;
694
695            for row in rows {
696                sessions.push(row?);
697            }
698        } else {
699            let mut stmt = conn.prepare(
700                "SELECT id, user_id, title, metadata, created_at, updated_at, last_message_at
701                 FROM sessions
702                 ORDER BY updated_at DESC
703                 LIMIT ?1"
704            )?;
705
706            let rows = stmt.query_map(params![limit_val], |row| {
707                Ok(Session {
708                    id: row.get(0)?,
709                    user_id: row.get(1)?,
710                    title: row.get(2)?,
711                    metadata: row.get::<_, Option<String>>(3)?
712                        .and_then(|s| serde_json::from_str(&s).ok()),
713                    created_at: row.get::<_, String>(4)?
714                        .parse()
715                        .unwrap_or_else(|_| chrono::Utc::now()),
716                    updated_at: row.get::<_, String>(5)?
717                        .parse()
718                        .unwrap_or_else(|_| chrono::Utc::now()),
719                    last_message_at: row.get::<_, Option<String>>(6)?
720                        .and_then(|s| s.parse().ok()),
721                })
722            })?;
723
724            for row in rows {
725                sessions.push(row?);
726            }
727        }
728
729        Ok(sessions)
730    }
731
732    /// Update session metadata
733    ///
734    /// # Arguments
735    ///
736    /// * `session_id` - The session ID to update
737    /// * `title` - Optional new title
738    /// * `metadata` - Optional new metadata
739    ///
740    /// # Returns
741    ///
742    /// Ok(()) if successful
743    pub fn update_session(
744        &self,
745        session_id: &str,
746        title: Option<&str>,
747        metadata: Option<&serde_json::Value>,
748    ) -> Result<()> {
749        let conn = self.conn.lock()
750            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
751
752        let now = chrono::Utc::now();
753
754        conn.execute(
755            "UPDATE sessions
756             SET title = COALESCE(?1, title),
757                 metadata = COALESCE(?2, metadata),
758                 updated_at = ?3
759             WHERE id = ?4",
760            params![
761                title,
762                metadata.map(|m| m.to_string()),
763                now.to_rfc3339(),
764                session_id,
765            ],
766        )?;
767
768        Ok(())
769    }
770
771    /// Delete a session (cascades to messages and working sets)
772    ///
773    /// # Arguments
774    ///
775    /// * `session_id` - The session ID to delete
776    ///
777    /// # Returns
778    ///
779    /// Ok(()) if successful
780    pub fn delete_session(&self, session_id: &str) -> Result<()> {
781        let conn = self.conn.lock()
782            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
783
784        conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
785
786        Ok(())
787    }
788
789    /// Add a message to a session
790    ///
791    /// # Arguments
792    ///
793    /// * `session_id` - The session to add the message to
794    /// * `role` - Message role (user, assistant, system, tool)
795    /// * `content` - Message content
796    /// * `metadata` - Optional metadata
797    ///
798    /// # Returns
799    ///
800    /// The newly created message
801    pub fn add_message(
802        &self,
803        session_id: &str,
804        role: MessageRole,
805        content: &str,
806        metadata: Option<&serde_json::Value>,
807    ) -> Result<Message> {
808        let mut conn = self.conn.lock()
809            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
810
811        let tx = conn.transaction()?;
812
813        // Get the next sequence number for this session
814        let sequence_number: i64 = tx.query_row(
815            "SELECT COALESCE(MAX(sequence_number), -1) + 1 FROM messages WHERE session_id = ?1",
816            params![session_id],
817            |row| row.get(0),
818        )?;
819
820        let id = uuid::Uuid::new_v4().to_string();
821        let now = chrono::Utc::now();
822
823        tx.execute(
824            "INSERT INTO messages (id, session_id, role, content, metadata, sequence_number, created_at)
825             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
826            params![
827                id,
828                session_id,
829                role.as_str(),
830                content,
831                metadata.map(|m| m.to_string()),
832                sequence_number,
833                now.to_rfc3339(),
834            ],
835        )?;
836
837        // Update session's last_message_at and updated_at
838        tx.execute(
839            "UPDATE sessions
840             SET last_message_at = ?1, updated_at = ?1
841             WHERE id = ?2",
842            params![now.to_rfc3339(), session_id],
843        )?;
844
845        tx.commit()?;
846
847        Ok(Message {
848            id,
849            session_id: session_id.to_string(),
850            role,
851            content: content.to_string(),
852            metadata: metadata.cloned(),
853            sequence_number: sequence_number as usize,
854            created_at: now,
855        })
856    }
857
858    /// Get messages for a session
859    ///
860    /// # Arguments
861    ///
862    /// * `session_id` - The session ID
863    /// * `limit` - Optional limit on number of messages (most recent first)
864    ///
865    /// # Returns
866    ///
867    /// Vector of messages in chronological order
868    pub fn get_messages(&self, session_id: &str, limit: Option<usize>) -> Result<Vec<Message>> {
869        let conn = self.conn.lock()
870            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
871
872        let mut messages = Vec::new();
873
874        if let Some(lim) = limit {
875            let mut stmt = conn.prepare(
876                "SELECT id, session_id, role, content, metadata, sequence_number, created_at
877                 FROM messages
878                 WHERE session_id = ?1
879                 ORDER BY sequence_number ASC
880                 LIMIT ?2"
881            )?;
882
883            let rows = stmt.query_map(params![session_id, lim as i64], |row| {
884                let role_str: String = row.get(2)?;
885                let role = MessageRole::from_str(&role_str)
886                    .unwrap_or(MessageRole::User);
887
888                Ok(Message {
889                    id: row.get(0)?,
890                    session_id: row.get(1)?,
891                    role,
892                    content: row.get(3)?,
893                    metadata: row.get::<_, Option<String>>(4)?
894                        .and_then(|s| serde_json::from_str(&s).ok()),
895                    sequence_number: row.get::<_, i64>(5)? as usize,
896                    created_at: row.get::<_, String>(6)?
897                        .parse()
898                        .unwrap_or_else(|_| chrono::Utc::now()),
899                })
900            })?;
901
902            for row in rows {
903                messages.push(row?);
904            }
905        } else {
906            let mut stmt = conn.prepare(
907                "SELECT id, session_id, role, content, metadata, sequence_number, created_at
908                 FROM messages
909                 WHERE session_id = ?1
910                 ORDER BY sequence_number ASC"
911            )?;
912
913            let rows = stmt.query_map(params![session_id], |row| {
914                let role_str: String = row.get(2)?;
915                let role = MessageRole::from_str(&role_str)
916                    .unwrap_or(MessageRole::User);
917
918                Ok(Message {
919                    id: row.get(0)?,
920                    session_id: row.get(1)?,
921                    role,
922                    content: row.get(3)?,
923                    metadata: row.get::<_, Option<String>>(4)?
924                        .and_then(|s| serde_json::from_str(&s).ok()),
925                    sequence_number: row.get::<_, i64>(5)? as usize,
926                    created_at: row.get::<_, String>(6)?
927                        .parse()
928                        .unwrap_or_else(|_| chrono::Utc::now()),
929                })
930            })?;
931
932            for row in rows {
933                messages.push(row?);
934            }
935        }
936
937        Ok(messages)
938    }
939
940    /// Associate a working set with a session
941    ///
942    /// # Arguments
943    ///
944    /// * `session_id` - The session ID
945    /// * `message_id` - Optional message ID that triggered this compilation
946    /// * `working_set` - The working set to associate
947    /// * `query` - Query that generated this working set
948    /// * `config` - Configuration used for compilation
949    ///
950    /// # Returns
951    ///
952    /// The newly created SessionWorkingSet
953    pub fn associate_working_set(
954        &self,
955        session_id: &str,
956        message_id: Option<&str>,
957        working_set: &WorkingSet,
958        query: &str,
959        config: &CompilerConfig,
960    ) -> Result<SessionWorkingSet> {
961        let conn = self.conn.lock()
962            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
963
964        let id = uuid::Uuid::new_v4().to_string();
965        let working_set_id = uuid::Uuid::new_v4().to_string(); // Generate a unique ID for this working set
966        let now = chrono::Utc::now();
967
968        conn.execute(
969            "INSERT INTO session_working_sets (id, session_id, message_id, working_set_id, query, config, created_at)
970             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
971            params![
972                id,
973                session_id,
974                message_id,
975                working_set_id,
976                query,
977                serde_json::to_string(config)?,
978                now.to_rfc3339(),
979            ],
980        )?;
981
982        Ok(SessionWorkingSet {
983            id,
984            session_id: session_id.to_string(),
985            message_id: message_id.map(|s| s.to_string()),
986            working_set: working_set.clone(),
987            query: query.to_string(),
988            config: config.clone(),
989            created_at: now,
990        })
991    }
992
993    /// Get session with all messages and working sets
994    ///
995    /// # Arguments
996    ///
997    /// * `session_id` - The session ID
998    ///
999    /// # Returns
1000    ///
1001    /// SessionWithMessages if found
1002    pub fn get_session_full(&self, session_id: &str) -> Result<Option<SessionWithMessages>> {
1003        let session = self.get_session(session_id)?;
1004
1005        if session.is_none() {
1006            return Ok(None);
1007        }
1008
1009        let session = session.unwrap();
1010        let messages = self.get_messages(session_id, None)?;
1011
1012        // Get working sets for this session
1013        let conn = self.conn.lock()
1014            .map_err(|e| crate::types::Error::Other(anyhow::anyhow!("Database lock poisoned: {}", e)))?;
1015
1016        let mut stmt = conn.prepare(
1017            "SELECT id, session_id, message_id, working_set_id, query, config, created_at
1018             FROM session_working_sets
1019             WHERE session_id = ?1
1020             ORDER BY created_at ASC",
1021        )?;
1022
1023        let working_sets = stmt.query_map(params![session_id], |row| {
1024            let config_str: String = row.get(5)?;
1025            let config: CompilerConfig = serde_json::from_str(&config_str)
1026                .unwrap_or_default();
1027
1028            // Note: We can't reconstruct the full WorkingSet from storage without additional data
1029            // For now, we'll create a placeholder. In a real implementation, you'd store the
1030            // working set data as JSON and deserialize it here.
1031            let working_set = WorkingSet {
1032                text: String::new(),
1033                spans: Vec::new(),
1034                citations: Vec::new(),
1035                tokens_used: 0,
1036                query: row.get::<_, String>(4)?,
1037                compilation_time_ms: 0,
1038            };
1039
1040            Ok(SessionWorkingSet {
1041                id: row.get(0)?,
1042                session_id: row.get(1)?,
1043                message_id: row.get(2)?,
1044                working_set,
1045                query: row.get(4)?,
1046                config,
1047                created_at: row.get::<_, String>(6)?
1048                    .parse()
1049                    .unwrap_or_else(|_| chrono::Utc::now()),
1050            })
1051        })?
1052        .collect::<std::result::Result<Vec<_>, _>>()?;
1053
1054        Ok(Some(SessionWithMessages {
1055            session,
1056            messages,
1057            working_sets,
1058        }))
1059    }
1060}
1061
/// Serialize an embedding vector to bytes for storage
///
/// Each `f32` is written as its 4 little-endian bytes, in input order, so the
/// result is `4 * embedding.len()` bytes long.
fn serialize_embedding(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
1066
/// Deserialize an embedding vector from bytes
///
/// The input is read as consecutive 4-byte little-endian `f32` values. Any
/// trailing bytes that do not form a complete 4-byte group are ignored.
fn deserialize_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for group in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(group);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
1074
// Unit tests for the database layer. Each test opens its own `:memory:`
// database through `Database::new`.
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// A freshly created database reports zero for all three `get_stats` counters.
    #[test]
    fn test_database_creation() {
        let db = Database::new(":memory:").unwrap();
        let (artifacts, spans, tokens) = db.get_stats().unwrap();
        assert_eq!(artifacts, 0);
        assert_eq!(spans, 0);
        assert_eq!(tokens, 0);
    }

    /// Inserting one artifact raises the first `get_stats` counter to 1.
    #[test]
    fn test_insert_artifact() {
        let db = Database::new(":memory:").unwrap();

        let artifact = Artifact {
            id: Uuid::new_v4().to_string(),
            path: "test.txt".to_string(),
            content: "Test content".to_string(),
            content_hash: "hash123".to_string(),
            metadata: None,
            created_at: chrono::Utc::now(),
        };

        db.insert_artifact(&artifact).unwrap();

        let (count, _, _) = db.get_stats().unwrap();
        assert_eq!(count, 1);
    }

    /// Serializing and then deserializing an embedding returns the same
    /// values (compared with a small tolerance).
    #[test]
    fn test_embedding_serialization() {
        let original = vec![1.0, 2.5, -3.14, 0.0];
        let bytes = serialize_embedding(&original);
        let restored = deserialize_embedding(&bytes);

        assert_eq!(original.len(), restored.len());
        for (a, b) in original.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 0.0001);
        }
    }

    // ========== Session Management Tests ==========

    /// A new session gets a non-empty ID, keeps the given user and title, and
    /// starts without metadata or a last-message timestamp.
    #[test]
    fn test_create_session() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user123"), Some("Test Session")).unwrap();

        assert!(!session.id.is_empty());
        assert_eq!(session.user_id, Some("user123".to_string()));
        assert_eq!(session.title, Some("Test Session".to_string()));
        assert!(session.metadata.is_none());
        assert!(session.last_message_at.is_none());
    }

    /// A created session can be read back by ID with the same user and title.
    #[test]
    fn test_get_session() {
        let db = Database::new(":memory:").unwrap();

        let created = db.create_session(Some("user456"), Some("Another Session")).unwrap();
        let retrieved = db.get_session(&created.id).unwrap();

        assert!(retrieved.is_some());
        let session = retrieved.unwrap();
        assert_eq!(session.id, created.id);
        assert_eq!(session.user_id, created.user_id);
        assert_eq!(session.title, created.title);
    }

    /// Looking up an unknown session ID yields `Ok(None)` rather than an error.
    #[test]
    fn test_get_nonexistent_session() {
        let db = Database::new(":memory:").unwrap();

        let result = db.get_session("nonexistent-id").unwrap();
        assert!(result.is_none());
    }

    /// `list_sessions` honours both the optional user filter and the optional limit.
    #[test]
    fn test_list_sessions() {
        let db = Database::new(":memory:").unwrap();

        // Create multiple sessions
        db.create_session(Some("user1"), Some("Session 1")).unwrap();
        db.create_session(Some("user1"), Some("Session 2")).unwrap();
        db.create_session(Some("user2"), Some("Session 3")).unwrap();

        // List all sessions
        let all_sessions = db.list_sessions(None, None).unwrap();
        assert_eq!(all_sessions.len(), 3);

        // List sessions for user1
        let user1_sessions = db.list_sessions(Some("user1"), None).unwrap();
        assert_eq!(user1_sessions.len(), 2);

        // List sessions for user2
        let user2_sessions = db.list_sessions(Some("user2"), None).unwrap();
        assert_eq!(user2_sessions.len(), 1);

        // Test limit
        let limited = db.list_sessions(None, Some(2)).unwrap();
        assert_eq!(limited.len(), 2);
    }

    /// Title and metadata can be updated independently of each other.
    #[test]
    fn test_update_session() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Original Title")).unwrap();

        // Update title
        db.update_session(&session.id, Some("Updated Title"), None).unwrap();

        let updated = db.get_session(&session.id).unwrap().unwrap();
        assert_eq!(updated.title, Some("Updated Title".to_string()));

        // Update metadata
        let metadata = serde_json::json!({"key": "value"});
        db.update_session(&session.id, None, Some(&metadata)).unwrap();

        let updated2 = db.get_session(&session.id).unwrap().unwrap();
        assert!(updated2.metadata.is_some());
        assert_eq!(updated2.metadata.unwrap()["key"], "value");
    }

    /// A deleted session can no longer be fetched.
    #[test]
    fn test_delete_session() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("To Delete")).unwrap();

        // Verify session exists
        assert!(db.get_session(&session.id).unwrap().is_some());

        // Delete session
        db.delete_session(&session.id).unwrap();

        // Verify session is gone
        assert!(db.get_session(&session.id).unwrap().is_none());
    }

    /// Messages receive sequence numbers 0, 1, ... and adding one sets the
    /// session's `last_message_at`.
    #[test]
    fn test_add_message() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

        // Add first message
        let msg1 = db.add_message(&session.id, MessageRole::User, "Hello", None).unwrap();
        assert_eq!(msg1.sequence_number, 0);
        assert_eq!(msg1.content, "Hello");
        assert_eq!(msg1.role.as_str(), "user");

        // Add second message
        let msg2 = db.add_message(&session.id, MessageRole::Assistant, "Hi there!", None).unwrap();
        assert_eq!(msg2.sequence_number, 1);
        assert_eq!(msg2.content, "Hi there!");
        assert_eq!(msg2.role.as_str(), "assistant");

        // Verify session was updated
        let updated_session = db.get_session(&session.id).unwrap().unwrap();
        assert!(updated_session.last_message_at.is_some());
    }

    /// Metadata passed to `add_message` is present on the returned message.
    #[test]
    fn test_add_message_with_metadata() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

        let metadata = serde_json::json!({"tool": "search", "query": "test"});
        let msg = db.add_message(&session.id, MessageRole::Tool, "Result", Some(&metadata)).unwrap();

        assert!(msg.metadata.is_some());
        assert_eq!(msg.metadata.unwrap()["tool"], "search");
    }

    /// `get_messages` returns every message in sequence order and respects the limit.
    #[test]
    fn test_get_messages() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

        // Add multiple messages
        db.add_message(&session.id, MessageRole::User, "Message 1", None).unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Message 2", None).unwrap();
        db.add_message(&session.id, MessageRole::User, "Message 3", None).unwrap();

        // Get all messages
        let messages = db.get_messages(&session.id, None).unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].sequence_number, 0);
        assert_eq!(messages[1].sequence_number, 1);
        assert_eq!(messages[2].sequence_number, 2);

        // Test limit
        let limited = db.get_messages(&session.id, Some(2)).unwrap();
        assert_eq!(limited.len(), 2);
    }

    /// Messages come back in insertion order with consecutive sequence numbers.
    #[test]
    fn test_message_ordering() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();

        // Add messages
        db.add_message(&session.id, MessageRole::User, "First", None).unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Second", None).unwrap();
        db.add_message(&session.id, MessageRole::User, "Third", None).unwrap();

        let messages = db.get_messages(&session.id, None).unwrap();

        // Verify chronological order
        assert_eq!(messages[0].content, "First");
        assert_eq!(messages[1].content, "Second");
        assert_eq!(messages[2].content, "Third");

        // Verify sequence numbers are consecutive
        for (i, msg) in messages.iter().enumerate() {
            assert_eq!(msg.sequence_number, i);
        }
    }

    /// The value returned by `associate_working_set` carries the session,
    /// message, query and working set that were passed in.
    #[test]
    fn test_associate_working_set() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Chat Session")).unwrap();
        let message = db.add_message(&session.id, MessageRole::User, "Query", None).unwrap();

        // Create a working set
        let working_set = WorkingSet {
            text: "Test context".to_string(),
            spans: Vec::new(),
            citations: Vec::new(),
            tokens_used: 100,
            query: "test query".to_string(),
            compilation_time_ms: 50,
        };

        let config = CompilerConfig::default();

        let sws = db.associate_working_set(
            &session.id,
            Some(&message.id),
            &working_set,
            "test query",
            &config,
        ).unwrap();

        assert_eq!(sws.session_id, session.id);
        assert_eq!(sws.message_id, Some(message.id));
        assert_eq!(sws.query, "test query");
        assert_eq!(sws.working_set.text, "Test context");
    }

    /// `get_session_full` returns the session together with its messages and
    /// associated working sets (only the counts are checked here).
    #[test]
    fn test_get_session_full() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("Full Session")).unwrap();

        // Add messages
        let msg1 = db.add_message(&session.id, MessageRole::User, "Hello", None).unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Hi!", None).unwrap();

        // Add working set
        let working_set = WorkingSet {
            text: "Context".to_string(),
            spans: Vec::new(),
            citations: Vec::new(),
            tokens_used: 50,
            query: "test".to_string(),
            compilation_time_ms: 25,
        };

        db.associate_working_set(
            &session.id,
            Some(&msg1.id),
            &working_set,
            "test",
            &CompilerConfig::default(),
        ).unwrap();

        // Get full session
        let full = db.get_session_full(&session.id).unwrap();
        assert!(full.is_some());

        let swm = full.unwrap();
        assert_eq!(swm.session.id, session.id);
        assert_eq!(swm.messages.len(), 2);
        assert_eq!(swm.working_sets.len(), 1);
    }

    /// Deleting a session also removes the messages that belonged to it.
    #[test]
    fn test_delete_session_cascade() {
        let db = Database::new(":memory:").unwrap();

        let session = db.create_session(Some("user1"), Some("To Delete")).unwrap();

        // Add messages
        db.add_message(&session.id, MessageRole::User, "Message 1", None).unwrap();
        db.add_message(&session.id, MessageRole::Assistant, "Message 2", None).unwrap();

        // Verify messages exist
        let messages_before = db.get_messages(&session.id, None).unwrap();
        assert_eq!(messages_before.len(), 2);

        // Delete session
        db.delete_session(&session.id).unwrap();

        // Verify messages are gone (cascade delete)
        let messages_after = db.get_messages(&session.id, None).unwrap();
        assert_eq!(messages_after.len(), 0);
    }

    /// `MessageRole` converts to and from its string form for all four roles,
    /// and an unknown string is rejected.
    #[test]
    fn test_message_role_conversion() {
        assert_eq!(MessageRole::User.as_str(), "user");
        assert_eq!(MessageRole::Assistant.as_str(), "assistant");
        assert_eq!(MessageRole::System.as_str(), "system");
        assert_eq!(MessageRole::Tool.as_str(), "tool");

        assert!(matches!(MessageRole::from_str("user").unwrap(), MessageRole::User));
        assert!(matches!(MessageRole::from_str("assistant").unwrap(), MessageRole::Assistant));
        assert!(matches!(MessageRole::from_str("system").unwrap(), MessageRole::System));
        assert!(matches!(MessageRole::from_str("tool").unwrap(), MessageRole::Tool));

        assert!(MessageRole::from_str("invalid").is_err());
    }
}