Skip to main content

rlm_rs/storage/
sqlite.rs

1//! `SQLite` storage implementation.
2//!
3//! Provides persistent storage using `SQLite` with proper transaction
4//! management and migration support.
5
6// SQLite stores all integers as i64. These casts are intentional and safe
7// because we only store non-negative values that fit in usize.
8#![allow(clippy::cast_possible_truncation)]
9#![allow(clippy::cast_sign_loss)]
10
11use crate::core::{Buffer, BufferMetadata, Chunk, ChunkMetadata, Context};
12use crate::error::{Result, StorageError};
13use crate::storage::schema::{
14    CHECK_SCHEMA_SQL, CURRENT_SCHEMA_VERSION, GET_VERSION_SQL, SCHEMA_SQL, SET_VERSION_SQL,
15};
16use crate::storage::traits::{Storage, StorageStats};
17use rusqlite::{Connection, OptionalExtension, params};
18use std::path::{Path, PathBuf};
19
20/// SQLite-based storage implementation.
21///
22/// Provides persistent storage for RLM state with full ACID guarantees.
23///
24/// # Examples
25///
26/// ```no_run
27/// use rlm_rs::storage::{SqliteStorage, Storage};
28///
29/// let mut storage = SqliteStorage::open("rlm-state.db").unwrap();
30/// storage.init().unwrap();
31/// ```
32pub struct SqliteStorage {
33    /// `SQLite` connection.
34    conn: Connection,
35    /// Path to the database file (None for in-memory).
36    path: Option<PathBuf>,
37}
38
39impl SqliteStorage {
40    /// Opens or creates a `SQLite` database at the given path.
41    ///
42    /// # Arguments
43    ///
44    /// * `path` - Path to the database file. Parent directory must exist.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the database cannot be opened or initialized.
49    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
50        let path = path.as_ref().to_path_buf();
51
52        // Ensure parent directory exists
53        if let Some(parent) = path.parent()
54            && !parent.exists()
55        {
56            std::fs::create_dir_all(parent).map_err(|e| StorageError::Database(e.to_string()))?;
57        }
58
59        let conn = Connection::open(&path).map_err(StorageError::from)?;
60
61        // Enable foreign keys
62        conn.execute("PRAGMA foreign_keys = ON;", [])
63            .map_err(StorageError::from)?;
64
65        // Use WAL mode for better concurrent access (returns result, use query_row)
66        let _: String = conn
67            .query_row("PRAGMA journal_mode = WAL;", [], |row| row.get(0))
68            .map_err(StorageError::from)?;
69
70        Ok(Self {
71            conn,
72            path: Some(path),
73        })
74    }
75
76    /// Creates an in-memory `SQLite` database.
77    ///
78    /// Useful for testing.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the database cannot be created.
83    pub fn in_memory() -> Result<Self> {
84        let conn = Connection::open_in_memory().map_err(StorageError::from)?;
85        conn.execute("PRAGMA foreign_keys = ON;", [])
86            .map_err(StorageError::from)?;
87
88        Ok(Self { conn, path: None })
89    }
90
91    /// Returns the database path (None for in-memory).
92    #[must_use]
93    pub fn path(&self) -> Option<&Path> {
94        self.path.as_deref()
95    }
96
97    /// Gets the current schema version.
98    fn get_schema_version(&self) -> Result<Option<u32>> {
99        let version: Option<String> = self
100            .conn
101            .query_row(GET_VERSION_SQL, [], |row| row.get(0))
102            .optional()
103            .map_err(StorageError::from)?;
104
105        Ok(version.and_then(|v| v.parse().ok()))
106    }
107
108    /// Sets the schema version.
109    fn set_schema_version(&self, version: u32) -> Result<()> {
110        self.conn
111            .execute(SET_VERSION_SQL, params![version.to_string()])
112            .map_err(StorageError::from)?;
113        Ok(())
114    }
115
116    /// Returns current Unix timestamp.
117    #[allow(clippy::cast_possible_wrap)]
118    fn now() -> i64 {
119        std::time::SystemTime::now()
120            .duration_since(std::time::UNIX_EPOCH)
121            .map_or(0, |d| d.as_secs() as i64)
122    }
123}
124
125impl Storage for SqliteStorage {
126    fn init(&mut self) -> Result<()> {
127        // Check if already initialized
128        let is_init: i64 = self
129            .conn
130            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
131            .map_err(StorageError::from)?;
132
133        if is_init == 0 {
134            // Fresh install - create schema
135            self.conn
136                .execute_batch(SCHEMA_SQL)
137                .map_err(StorageError::from)?;
138            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
139        } else if let Some(current) = self.get_schema_version()?
140            && current < CURRENT_SCHEMA_VERSION
141        {
142            // Run migrations
143            let migrations = crate::storage::schema::get_migrations_from(current);
144            for migration in migrations {
145                self.conn
146                    .execute_batch(migration.sql)
147                    .map_err(|e| StorageError::Migration(e.to_string()))?;
148            }
149            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
150        }
151
152        Ok(())
153    }
154
155    fn is_initialized(&self) -> Result<bool> {
156        let count: i64 = self
157            .conn
158            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
159            .map_err(StorageError::from)?;
160        Ok(count > 0)
161    }
162
163    fn reset(&mut self) -> Result<()> {
164        self.conn
165            .execute_batch(
166                r"
167            DELETE FROM chunk_embeddings;
168            DELETE FROM chunks;
169            DELETE FROM buffers;
170            DELETE FROM context;
171            DELETE FROM metadata;
172        ",
173            )
174            .map_err(StorageError::from)?;
175        Ok(())
176    }
177
178    // ==================== Context Operations ====================
179
180    fn save_context(&mut self, context: &Context) -> Result<()> {
181        let data = serde_json::to_string(context).map_err(StorageError::from)?;
182        let now = Self::now();
183
184        self.conn
185            .execute(
186                r"
187            INSERT OR REPLACE INTO context (id, data, created_at, updated_at)
188            VALUES (1, ?, COALESCE((SELECT created_at FROM context WHERE id = 1), ?), ?)
189        ",
190                params![data, now, now],
191            )
192            .map_err(StorageError::from)?;
193
194        Ok(())
195    }
196
197    fn load_context(&self) -> Result<Option<Context>> {
198        let data: Option<String> = self
199            .conn
200            .query_row("SELECT data FROM context WHERE id = 1", [], |row| {
201                row.get(0)
202            })
203            .optional()
204            .map_err(StorageError::from)?;
205
206        match data {
207            Some(json) => {
208                let context = serde_json::from_str(&json).map_err(StorageError::from)?;
209                Ok(Some(context))
210            }
211            None => Ok(None),
212        }
213    }
214
215    fn delete_context(&mut self) -> Result<()> {
216        self.conn
217            .execute("DELETE FROM context WHERE id = 1", [])
218            .map_err(StorageError::from)?;
219        Ok(())
220    }
221
222    // ==================== Buffer Operations ====================
223
224    #[allow(clippy::cast_possible_wrap)]
225    fn add_buffer(&mut self, buffer: &Buffer) -> Result<i64> {
226        let now = Self::now();
227
228        self.conn
229            .execute(
230                r"
231            INSERT INTO buffers (
232                name, source_path, content, content_type, content_hash,
233                size, line_count, chunk_count, created_at, updated_at
234            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
235        ",
236                params![
237                    buffer.name,
238                    buffer
239                        .source
240                        .as_ref()
241                        .map(|p| p.to_string_lossy().to_string()),
242                    buffer.content,
243                    buffer.metadata.content_type,
244                    buffer.metadata.content_hash,
245                    buffer.metadata.size as i64,
246                    buffer.metadata.line_count.map(|c| c as i64),
247                    buffer.metadata.chunk_count.map(|c| c as i64),
248                    now,
249                    now,
250                ],
251            )
252            .map_err(StorageError::from)?;
253
254        Ok(self.conn.last_insert_rowid())
255    }
256
257    fn get_buffer(&self, id: i64) -> Result<Option<Buffer>> {
258        let result = self
259            .conn
260            .query_row(
261                r"
262            SELECT id, name, source_path, content, content_type, content_hash,
263                   size, line_count, chunk_count, created_at, updated_at
264            FROM buffers WHERE id = ?
265        ",
266                params![id],
267                |row| {
268                    Ok(Buffer {
269                        id: Some(row.get::<_, i64>(0)?),
270                        name: row.get(1)?,
271                        source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
272                        content: row.get(3)?,
273                        metadata: BufferMetadata {
274                            content_type: row.get(4)?,
275                            content_hash: row.get(5)?,
276                            size: row.get::<_, i64>(6)? as usize,
277                            line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
278                            chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
279                            created_at: row.get(9)?,
280                            updated_at: row.get(10)?,
281                        },
282                    })
283                },
284            )
285            .optional()
286            .map_err(StorageError::from)?;
287
288        Ok(result)
289    }
290
291    fn get_buffer_by_name(&self, name: &str) -> Result<Option<Buffer>> {
292        let id: Option<i64> = self
293            .conn
294            .query_row(
295                "SELECT id FROM buffers WHERE name = ?",
296                params![name],
297                |row| row.get(0),
298            )
299            .optional()
300            .map_err(StorageError::from)?;
301
302        id.map_or(Ok(None), |id| self.get_buffer(id))
303    }
304
305    fn list_buffers(&self) -> Result<Vec<Buffer>> {
306        let mut stmt = self
307            .conn
308            .prepare(
309                r"
310            SELECT id, name, source_path, content, content_type, content_hash,
311                   size, line_count, chunk_count, created_at, updated_at
312            FROM buffers ORDER BY id
313        ",
314            )
315            .map_err(StorageError::from)?;
316
317        let buffers = stmt
318            .query_map([], |row| {
319                Ok(Buffer {
320                    id: Some(row.get::<_, i64>(0)?),
321                    name: row.get(1)?,
322                    source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
323                    content: row.get(3)?,
324                    metadata: BufferMetadata {
325                        content_type: row.get(4)?,
326                        content_hash: row.get(5)?,
327                        size: row.get::<_, i64>(6)? as usize,
328                        line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
329                        chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
330                        created_at: row.get(9)?,
331                        updated_at: row.get(10)?,
332                    },
333                })
334            })
335            .map_err(StorageError::from)?
336            .collect::<std::result::Result<Vec<_>, _>>()
337            .map_err(StorageError::from)?;
338
339        Ok(buffers)
340    }
341
342    #[allow(clippy::cast_possible_wrap)]
343    fn update_buffer(&mut self, buffer: &Buffer) -> Result<()> {
344        let id = buffer.id.ok_or_else(|| StorageError::BufferNotFound {
345            identifier: "no ID".to_string(),
346        })?;
347
348        let now = Self::now();
349
350        self.conn
351            .execute(
352                r"
353            UPDATE buffers SET
354                name = ?, source_path = ?, content = ?, content_type = ?,
355                content_hash = ?, size = ?, line_count = ?, chunk_count = ?,
356                updated_at = ?
357            WHERE id = ?
358        ",
359                params![
360                    buffer.name,
361                    buffer
362                        .source
363                        .as_ref()
364                        .map(|p| p.to_string_lossy().to_string()),
365                    buffer.content,
366                    buffer.metadata.content_type,
367                    buffer.metadata.content_hash,
368                    buffer.metadata.size as i64,
369                    buffer.metadata.line_count.map(|c| c as i64),
370                    buffer.metadata.chunk_count.map(|c| c as i64),
371                    now,
372                    id,
373                ],
374            )
375            .map_err(StorageError::from)?;
376
377        Ok(())
378    }
379
380    fn delete_buffer(&mut self, id: i64) -> Result<()> {
381        // Chunks are deleted automatically via CASCADE
382        self.conn
383            .execute("DELETE FROM buffers WHERE id = ?", params![id])
384            .map_err(StorageError::from)?;
385        Ok(())
386    }
387
388    fn buffer_count(&self) -> Result<usize> {
389        let count: i64 = self
390            .conn
391            .query_row("SELECT COUNT(*) FROM buffers", [], |row| row.get(0))
392            .map_err(StorageError::from)?;
393        Ok(count as usize)
394    }
395
396    // ==================== Chunk Operations ====================
397
398    #[allow(clippy::cast_possible_wrap)]
399    fn add_chunks(&mut self, buffer_id: i64, chunks: &[Chunk]) -> Result<()> {
400        let tx = self.conn.transaction().map_err(StorageError::from)?;
401        let now = Self::now();
402
403        {
404            let mut stmt = tx
405                .prepare(
406                    r"
407                INSERT INTO chunks (
408                    buffer_id, content, byte_start, byte_end, chunk_index,
409                    strategy, token_count, line_start, line_end, has_overlap,
410                    content_hash, custom_metadata, created_at
411                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
412            ",
413                )
414                .map_err(StorageError::from)?;
415
416            for chunk in chunks {
417                let custom_meta = chunk.metadata.custom.clone();
418
419                let (line_start, line_end) = chunk
420                    .metadata
421                    .line_range
422                    .as_ref()
423                    .map_or((None, None), |r| (Some(r.start as i64), Some(r.end as i64)));
424
425                stmt.execute(params![
426                    buffer_id,
427                    chunk.content,
428                    chunk.byte_range.start as i64,
429                    chunk.byte_range.end as i64,
430                    chunk.index as i64,
431                    chunk.metadata.strategy,
432                    chunk.metadata.token_count.map(|c| c as i64),
433                    line_start,
434                    line_end,
435                    i64::from(chunk.metadata.has_overlap),
436                    chunk.metadata.content_hash,
437                    custom_meta,
438                    now,
439                ])
440                .map_err(StorageError::from)?;
441            }
442        }
443
444        tx.commit().map_err(StorageError::from)?;
445
446        // Update chunk count on buffer
447        self.conn
448            .execute(
449                "UPDATE buffers SET chunk_count = ? WHERE id = ?",
450                params![chunks.len() as i64, buffer_id],
451            )
452            .map_err(StorageError::from)?;
453
454        Ok(())
455    }
456
457    fn get_chunks(&self, buffer_id: i64) -> Result<Vec<Chunk>> {
458        let mut stmt = self
459            .conn
460            .prepare(
461                r"
462            SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
463                   strategy, token_count, line_start, line_end, has_overlap,
464                   content_hash, custom_metadata, created_at
465            FROM chunks WHERE buffer_id = ? ORDER BY chunk_index
466        ",
467            )
468            .map_err(StorageError::from)?;
469
470        let chunks = stmt
471            .query_map(params![buffer_id], |row| {
472                let line_start: Option<i64> = row.get(8)?;
473                let line_end: Option<i64> = row.get(9)?;
474                let line_range = match (line_start, line_end) {
475                    (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
476                    _ => None,
477                };
478
479                Ok(Chunk {
480                    id: Some(row.get::<_, i64>(0)?),
481                    buffer_id: row.get(1)?,
482                    content: row.get(2)?,
483                    byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
484                    index: row.get::<_, i64>(5)? as usize,
485                    metadata: ChunkMetadata {
486                        strategy: row.get(6)?,
487                        token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
488                        line_range,
489                        has_overlap: row.get::<_, i64>(10)? != 0,
490                        content_hash: row.get(11)?,
491                        custom: row.get(12)?,
492                        created_at: row.get(13)?,
493                    },
494                })
495            })
496            .map_err(StorageError::from)?
497            .collect::<std::result::Result<Vec<_>, _>>()
498            .map_err(StorageError::from)?;
499
500        Ok(chunks)
501    }
502
503    fn get_chunk(&self, id: i64) -> Result<Option<Chunk>> {
504        let result = self
505            .conn
506            .query_row(
507                r"
508            SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
509                   strategy, token_count, line_start, line_end, has_overlap,
510                   content_hash, custom_metadata, created_at
511            FROM chunks WHERE id = ?
512        ",
513                params![id],
514                |row| {
515                    let line_start: Option<i64> = row.get(8)?;
516                    let line_end: Option<i64> = row.get(9)?;
517                    let line_range = match (line_start, line_end) {
518                        (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
519                        _ => None,
520                    };
521
522                    Ok(Chunk {
523                        id: Some(row.get::<_, i64>(0)?),
524                        buffer_id: row.get(1)?,
525                        content: row.get(2)?,
526                        byte_range: (row.get::<_, i64>(3)? as usize)
527                            ..(row.get::<_, i64>(4)? as usize),
528                        index: row.get::<_, i64>(5)? as usize,
529                        metadata: ChunkMetadata {
530                            strategy: row.get(6)?,
531                            token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
532                            line_range,
533                            has_overlap: row.get::<_, i64>(10)? != 0,
534                            content_hash: row.get(11)?,
535                            custom: row.get(12)?,
536                            created_at: row.get(13)?,
537                        },
538                    })
539                },
540            )
541            .optional()
542            .map_err(StorageError::from)?;
543
544        Ok(result)
545    }
546
547    fn delete_chunks(&mut self, buffer_id: i64) -> Result<()> {
548        self.conn
549            .execute("DELETE FROM chunks WHERE buffer_id = ?", params![buffer_id])
550            .map_err(StorageError::from)?;
551
552        // Update chunk count on buffer
553        self.conn
554            .execute(
555                "UPDATE buffers SET chunk_count = 0 WHERE id = ?",
556                params![buffer_id],
557            )
558            .map_err(StorageError::from)?;
559
560        Ok(())
561    }
562
563    fn chunk_count(&self, buffer_id: i64) -> Result<usize> {
564        let count: i64 = self
565            .conn
566            .query_row(
567                "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
568                params![buffer_id],
569                |row| row.get(0),
570            )
571            .map_err(StorageError::from)?;
572        Ok(count as usize)
573    }
574
575    // ==================== Utility Operations ====================
576
577    fn export_buffers(&self) -> Result<String> {
578        let buffers = self.list_buffers()?;
579        let mut output = String::new();
580
581        for (i, buffer) in buffers.iter().enumerate() {
582            if i > 0 {
583                output.push_str("\n\n");
584            }
585            output.push_str(&buffer.content);
586        }
587
588        Ok(output)
589    }
590
591    fn stats(&self) -> Result<StorageStats> {
592        let buffer_count = self.buffer_count()?;
593
594        let chunk_count: i64 = self
595            .conn
596            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
597            .map_err(StorageError::from)?;
598
599        let total_size: i64 = self
600            .conn
601            .query_row("SELECT COALESCE(SUM(size), 0) FROM buffers", [], |row| {
602                row.get(0)
603            })
604            .map_err(StorageError::from)?;
605
606        let has_context = self.load_context()?.is_some();
607
608        let schema_version = self.get_schema_version()?.unwrap_or(0);
609
610        let db_size = self
611            .path
612            .as_ref()
613            .and_then(|p| std::fs::metadata(p).ok().map(|m| m.len()));
614
615        Ok(StorageStats {
616            buffer_count,
617            chunk_count: chunk_count as usize,
618            total_content_size: total_size as usize,
619            has_context,
620            schema_version,
621            db_size,
622        })
623    }
624}
625
626// ==================== Embedding & Search Operations ====================
627
628impl SqliteStorage {
629    /// Retrieves multiple chunks by their IDs in a single query.
630    ///
631    /// More efficient than calling [`Storage::get_chunk`] repeatedly when fetching
632    /// several chunks at once (e.g., populating search result previews).
633    ///
634    /// Returns a map from chunk ID to [`Chunk`].
635    ///
636    /// # Errors
637    ///
638    /// Returns an error if the query fails.
639    pub fn get_chunks_by_ids(&self, ids: &[i64]) -> Result<std::collections::HashMap<i64, Chunk>> {
640        if ids.is_empty() {
641            return Ok(std::collections::HashMap::new());
642        }
643
644        let placeholders = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
645        let sql = format!(
646            "SELECT id, buffer_id, content, byte_start, byte_end, chunk_index, \
647             strategy, token_count, line_start, line_end, has_overlap, \
648             content_hash, custom_metadata, created_at \
649             FROM chunks WHERE id IN ({placeholders})"
650        );
651
652        let mut stmt = self.conn.prepare(&sql).map_err(StorageError::from)?;
653
654        let chunks = stmt
655            .query_map(rusqlite::params_from_iter(ids.iter().copied()), |row| {
656                let line_start: Option<i64> = row.get(8)?;
657                let line_end: Option<i64> = row.get(9)?;
658                let line_range = match (line_start, line_end) {
659                    (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
660                    _ => None,
661                };
662
663                Ok(Chunk {
664                    id: Some(row.get::<_, i64>(0)?),
665                    buffer_id: row.get(1)?,
666                    content: row.get(2)?,
667                    byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
668                    index: row.get::<_, i64>(5)? as usize,
669                    metadata: ChunkMetadata {
670                        strategy: row.get(6)?,
671                        token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
672                        line_range,
673                        has_overlap: row.get::<_, i64>(10)? != 0,
674                        content_hash: row.get(11)?,
675                        custom: row.get(12)?,
676                        created_at: row.get(13)?,
677                    },
678                })
679            })
680            .map_err(StorageError::from)?
681            .collect::<std::result::Result<Vec<_>, _>>()
682            .map_err(StorageError::from)?;
683
684        Ok(chunks
685            .into_iter()
686            .filter_map(|c| c.id.map(|id| (id, c)))
687            .collect())
688    }
689
690    /// Stores an embedding for a chunk.
691    ///
692    /// # Arguments
693    ///
694    /// * `chunk_id` - The chunk ID to associate the embedding with.
695    /// * `embedding` - The embedding vector (f32 array).
696    /// * `model_name` - Optional name of the model that generated the embedding.
697    ///
698    /// # Errors
699    ///
700    /// Returns an error if the embedding cannot be stored.
701    #[allow(clippy::cast_possible_wrap)]
702    pub fn store_embedding(
703        &mut self,
704        chunk_id: i64,
705        embedding: &[f32],
706        model_name: Option<&str>,
707    ) -> Result<()> {
708        let now = Self::now();
709
710        // Serialize f32 array to bytes (little-endian)
711        let mut bytes = Vec::with_capacity(embedding.len() * 4);
712        for f in embedding {
713            bytes.extend_from_slice(&f.to_le_bytes());
714        }
715
716        self.conn
717            .execute(
718                r"
719                INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
720                VALUES (?, ?, ?, ?, ?)
721            ",
722                params![chunk_id, bytes, embedding.len() as i64, model_name, now],
723            )
724            .map_err(StorageError::from)?;
725
726        Ok(())
727    }
728
729    /// Retrieves the embedding for a chunk.
730    ///
731    /// # Errors
732    ///
733    /// Returns an error if the query fails.
734    pub fn get_embedding(&self, chunk_id: i64) -> Result<Option<Vec<f32>>> {
735        let result: Option<Vec<u8>> = self
736            .conn
737            .query_row(
738                "SELECT embedding FROM chunk_embeddings WHERE chunk_id = ?",
739                params![chunk_id],
740                |row| row.get(0),
741            )
742            .optional()
743            .map_err(StorageError::from)?;
744
745        Ok(result.map(|bytes| {
746            bytes
747                .chunks_exact(4)
748                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
749                .collect()
750        }))
751    }
752
753    /// Gets the distinct model names used for embeddings in a buffer.
754    ///
755    /// Returns the set of model names used to generate embeddings for
756    /// chunks belonging to the specified buffer.
757    ///
758    /// # Errors
759    ///
760    /// Returns an error if the query fails.
761    pub fn get_embedding_models(&self, buffer_id: i64) -> Result<Vec<String>> {
762        let mut stmt = self
763            .conn
764            .prepare(
765                r"
766                SELECT DISTINCT ce.model_name
767                FROM chunk_embeddings ce
768                JOIN chunks c ON ce.chunk_id = c.id
769                WHERE c.buffer_id = ? AND ce.model_name IS NOT NULL
770                ",
771            )
772            .map_err(StorageError::from)?;
773
774        let models = stmt
775            .query_map(params![buffer_id], |row| row.get::<_, String>(0))
776            .map_err(StorageError::from)?
777            .filter_map(std::result::Result::ok)
778            .collect();
779
780        Ok(models)
781    }
782
783    /// Gets the count of embeddings by model name for a buffer.
784    ///
785    /// Returns a list of (`model_name`, count) pairs.
786    ///
787    /// # Errors
788    ///
789    /// Returns an error if the query fails.
790    pub fn get_embedding_model_counts(&self, buffer_id: i64) -> Result<Vec<(Option<String>, i64)>> {
791        let mut stmt = self
792            .conn
793            .prepare(
794                r"
795                SELECT ce.model_name, COUNT(*) as count
796                FROM chunk_embeddings ce
797                JOIN chunks c ON ce.chunk_id = c.id
798                WHERE c.buffer_id = ?
799                GROUP BY ce.model_name
800                ",
801            )
802            .map_err(StorageError::from)?;
803
804        let counts = stmt
805            .query_map(params![buffer_id], |row| {
806                Ok((row.get::<_, Option<String>>(0)?, row.get::<_, i64>(1)?))
807            })
808            .map_err(StorageError::from)?
809            .filter_map(std::result::Result::ok)
810            .collect();
811
812        Ok(counts)
813    }
814
815    /// Stores embeddings for multiple chunks in a batch.
816    ///
817    /// # Errors
818    ///
819    /// Returns an error if any embedding cannot be stored.
820    #[allow(clippy::cast_possible_wrap)]
821    pub fn store_embeddings_batch(
822        &mut self,
823        embeddings: &[(i64, Vec<f32>)],
824        model_name: Option<&str>,
825    ) -> Result<()> {
826        let tx = self.conn.transaction().map_err(StorageError::from)?;
827        let now = Self::now();
828
829        {
830            let mut stmt = tx
831                .prepare(
832                    r"
833                    INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
834                    VALUES (?, ?, ?, ?, ?)
835                ",
836                )
837                .map_err(StorageError::from)?;
838
839            for (chunk_id, embedding) in embeddings {
840                let mut bytes = Vec::with_capacity(embedding.len() * 4);
841                bytes.extend(embedding.iter().flat_map(|f| f.to_le_bytes()));
842
843                stmt.execute(params![
844                    chunk_id,
845                    bytes,
846                    embedding.len() as i64,
847                    model_name,
848                    now
849                ])
850                .map_err(StorageError::from)?;
851            }
852        }
853
854        tx.commit().map_err(StorageError::from)?;
855        Ok(())
856    }
857
858    /// Deletes the embedding for a chunk.
859    ///
860    /// # Errors
861    ///
862    /// Returns an error if deletion fails.
863    pub fn delete_embedding(&mut self, chunk_id: i64) -> Result<()> {
864        self.conn
865            .execute(
866                "DELETE FROM chunk_embeddings WHERE chunk_id = ?",
867                params![chunk_id],
868            )
869            .map_err(StorageError::from)?;
870        Ok(())
871    }
872
873    /// Performs FTS5 BM25 full-text search.
874    ///
875    /// Returns chunk IDs and their BM25 scores (higher is better match).
876    ///
877    /// # Arguments
878    ///
879    /// * `query` - The search query (supports FTS5 query syntax).
880    /// * `limit` - Maximum number of results to return.
881    ///
882    /// # Errors
883    ///
884    /// Returns an error if the search fails.
885    #[allow(clippy::cast_possible_wrap)]
886    pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<(i64, f64)>> {
887        // FTS5 bm25() returns negative scores, more negative = better match
888        // We negate it so higher scores = better match
889
890        // Convert space-separated terms to OR query for more forgiving search
891        // Each term is quoted to escape FTS5 special characters (?, *, ^, etc.)
892        // "CLI tool?" becomes '"CLI" OR "tool?"' so special chars are treated as literals
893        let fts_query = query
894            .split_whitespace()
895            .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
896            .collect::<Vec<_>>()
897            .join(" OR ");
898
899        let mut stmt = self
900            .conn
901            .prepare(
902                r"
903                SELECT rowid, -bm25(chunks_fts) as score
904                FROM chunks_fts
905                WHERE chunks_fts MATCH ?
906                ORDER BY score DESC
907                LIMIT ?
908            ",
909            )
910            .map_err(StorageError::from)?;
911
912        let results = stmt
913            .query_map(params![fts_query, limit as i64], |row| {
914                Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
915            })
916            .map_err(StorageError::from)?
917            .collect::<std::result::Result<Vec<_>, _>>()
918            .map_err(StorageError::from)?;
919
920        Ok(results)
921    }
922
923    /// Returns all chunk embeddings for vector similarity search.
924    ///
925    /// # Errors
926    ///
927    /// Returns an error if the query fails.
928    pub fn get_all_embeddings(&self) -> Result<Vec<(i64, Vec<f32>)>> {
929        let mut stmt = self
930            .conn
931            .prepare("SELECT chunk_id, embedding FROM chunk_embeddings")
932            .map_err(StorageError::from)?;
933
934        let results = stmt
935            .query_map([], |row| {
936                let chunk_id: i64 = row.get(0)?;
937                let bytes: Vec<u8> = row.get(1)?;
938                let embedding: Vec<f32> = bytes
939                    .chunks_exact(4)
940                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
941                    .collect();
942                Ok((chunk_id, embedding))
943            })
944            .map_err(StorageError::from)?
945            .collect::<std::result::Result<Vec<_>, _>>()
946            .map_err(StorageError::from)?;
947
948        Ok(results)
949    }
950
951    /// Counts chunks with embeddings.
952    ///
953    /// # Errors
954    ///
955    /// Returns an error if the count fails.
956    pub fn embedding_count(&self) -> Result<usize> {
957        let count: i64 = self
958            .conn
959            .query_row("SELECT COUNT(*) FROM chunk_embeddings", [], |row| {
960                row.get(0)
961            })
962            .map_err(StorageError::from)?;
963        Ok(count as usize)
964    }
965
966    /// Returns `true` if every chunk in the buffer has an embedding (or the buffer has no chunks).
967    ///
968    /// Uses a single `NOT EXISTS` query instead of per-chunk lookups, making it O(1) in terms
969    /// of round-trips regardless of how many chunks the buffer contains.
970    ///
971    /// # Errors
972    ///
973    /// Returns an error if the query fails.
974    pub fn all_chunks_have_embeddings(&self, buffer_id: i64) -> Result<bool> {
975        let result: i64 = self
976            .conn
977            .query_row(
978                r"
979                SELECT NOT EXISTS (
980                    SELECT 1 FROM chunks c
981                    LEFT JOIN chunk_embeddings e ON e.chunk_id = c.id
982                    WHERE c.buffer_id = ? AND e.chunk_id IS NULL
983                )
984                ",
985                params![buffer_id],
986                |row| row.get(0),
987            )
988            .map_err(StorageError::from)?;
989        Ok(result != 0)
990    }
991
992    /// Checks if a chunk has an embedding.
993    ///
994    /// # Errors
995    ///
996    /// Returns an error if the query fails.
997    pub fn has_embedding(&self, chunk_id: i64) -> Result<bool> {
998        let count: i64 = self
999            .conn
1000            .query_row(
1001                "SELECT COUNT(*) FROM chunk_embeddings WHERE chunk_id = ?",
1002                params![chunk_id],
1003                |row| row.get(0),
1004            )
1005            .map_err(StorageError::from)?;
1006        Ok(count > 0)
1007    }
1008
1009    /// Gets chunk IDs that need embedding (either no embedding or wrong model).
1010    ///
1011    /// This is used for incremental embedding updates. Returns chunks that:
1012    /// - Have no embedding at all, OR
1013    /// - Have an embedding with a different model name (if `current_model` is provided)
1014    ///
1015    /// # Arguments
1016    ///
1017    /// * `buffer_id` - The buffer to check.
1018    /// * `current_model` - Optional model name to check against. If provided,
1019    ///   chunks with different models are included.
1020    ///
1021    /// # Errors
1022    ///
1023    /// Returns an error if the query fails.
1024    pub fn get_chunks_needing_embedding(
1025        &self,
1026        buffer_id: i64,
1027        current_model: Option<&str>,
1028    ) -> Result<Vec<i64>> {
1029        let mut results = Vec::new();
1030
1031        // Get chunks without any embedding
1032        let mut stmt = self
1033            .conn
1034            .prepare(
1035                r"
1036                SELECT c.id FROM chunks c
1037                LEFT JOIN chunk_embeddings e ON c.id = e.chunk_id
1038                WHERE c.buffer_id = ? AND e.chunk_id IS NULL
1039                ",
1040            )
1041            .map_err(StorageError::from)?;
1042
1043        let rows = stmt
1044            .query_map(params![buffer_id], |row| row.get(0))
1045            .map_err(StorageError::from)?;
1046
1047        for row in rows {
1048            results.push(row.map_err(StorageError::from)?);
1049        }
1050
1051        // If model specified, also get chunks with different model
1052        if let Some(model) = current_model {
1053            let mut stmt = self
1054                .conn
1055                .prepare(
1056                    r"
1057                    SELECT c.id FROM chunks c
1058                    INNER JOIN chunk_embeddings e ON c.id = e.chunk_id
1059                    WHERE c.buffer_id = ? AND (e.model_name IS NULL OR e.model_name != ?)
1060                    ",
1061                )
1062                .map_err(StorageError::from)?;
1063
1064            let rows = stmt
1065                .query_map(params![buffer_id, model], |row| row.get(0))
1066                .map_err(StorageError::from)?;
1067
1068            for row in rows {
1069                results.push(row.map_err(StorageError::from)?);
1070            }
1071        }
1072
1073        // Deduplicate (in case of overlap, though shouldn't happen)
1074        results.sort_unstable();
1075        results.dedup();
1076        Ok(results)
1077    }
1078
1079    /// Gets chunks without any embedding for a buffer.
1080    ///
1081    /// Simpler version of `get_chunks_needing_embedding` when model doesn't matter.
1082    ///
1083    /// # Errors
1084    ///
1085    /// Returns an error if the query fails.
1086    pub fn get_chunks_without_embedding(&self, buffer_id: i64) -> Result<Vec<i64>> {
1087        self.get_chunks_needing_embedding(buffer_id, None)
1088    }
1089
1090    /// Deletes embeddings with a specific model name.
1091    ///
1092    /// Useful for cleaning up embeddings from old models before re-embedding.
1093    ///
1094    /// # Arguments
1095    ///
1096    /// * `buffer_id` - The buffer to clean.
1097    /// * `model_name` - The model name to match (or None to match NULL).
1098    ///
1099    /// # Returns
1100    ///
1101    /// The number of embeddings deleted.
1102    ///
1103    /// # Errors
1104    ///
1105    /// Returns an error if deletion fails.
1106    pub fn delete_embeddings_by_model(
1107        &mut self,
1108        buffer_id: i64,
1109        model_name: Option<&str>,
1110    ) -> Result<usize> {
1111        let deleted = match model_name {
1112            Some(name) => self
1113                .conn
1114                .execute(
1115                    r"
1116                    DELETE FROM chunk_embeddings
1117                    WHERE chunk_id IN (
1118                        SELECT id FROM chunks WHERE buffer_id = ?
1119                    ) AND model_name = ?
1120                    ",
1121                    params![buffer_id, name],
1122                )
1123                .map_err(StorageError::from)?,
1124            None => self
1125                .conn
1126                .execute(
1127                    r"
1128                    DELETE FROM chunk_embeddings
1129                    WHERE chunk_id IN (
1130                        SELECT id FROM chunks WHERE buffer_id = ?
1131                    ) AND model_name IS NULL
1132                    ",
1133                    params![buffer_id],
1134                )
1135                .map_err(StorageError::from)?,
1136        };
1137        Ok(deleted)
1138    }
1139
1140    /// Gets embedding statistics for a buffer.
1141    ///
1142    /// Returns counts of embedded vs total chunks, and model breakdown.
1143    ///
1144    /// # Errors
1145    ///
1146    /// Returns an error if the query fails.
1147    pub fn get_embedding_stats(&self, buffer_id: i64) -> Result<EmbeddingStats> {
1148        // Total chunks
1149        let total_chunks: i64 = self
1150            .conn
1151            .query_row(
1152                "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
1153                params![buffer_id],
1154                |row| row.get(0),
1155            )
1156            .map_err(StorageError::from)?;
1157
1158        // Embedded chunks
1159        let embedded_chunks: i64 = self
1160            .conn
1161            .query_row(
1162                r"
1163                SELECT COUNT(*) FROM chunk_embeddings e
1164                INNER JOIN chunks c ON e.chunk_id = c.id
1165                WHERE c.buffer_id = ?
1166                ",
1167                params![buffer_id],
1168                |row| row.get(0),
1169            )
1170            .map_err(StorageError::from)?;
1171
1172        // Model counts
1173        let model_counts = self.get_embedding_model_counts(buffer_id)?;
1174
1175        Ok(EmbeddingStats {
1176            total_chunks: total_chunks as usize,
1177            embedded_chunks: embedded_chunks as usize,
1178            model_counts,
1179        })
1180    }
1181}
1182
1183/// Statistics about embeddings for a buffer.
1184#[derive(Debug, Clone)]
1185pub struct EmbeddingStats {
1186    /// Total number of chunks in the buffer.
1187    pub total_chunks: usize,
1188    /// Number of chunks with embeddings.
1189    pub embedded_chunks: usize,
1190    /// Count of embeddings by model (`model_name`, count).
1191    pub model_counts: Vec<(Option<String>, i64)>,
1192}
1193
1194#[cfg(test)]
1195mod tests {
1196    use super::*;
1197    use crate::core::ContextValue;
1198
1199    fn setup() -> SqliteStorage {
1200        let mut storage = SqliteStorage::in_memory().unwrap();
1201        storage.init().unwrap();
1202        storage
1203    }
1204
1205    #[test]
1206    fn test_init() {
1207        let mut storage = SqliteStorage::in_memory().unwrap();
1208        assert!(storage.init().is_ok());
1209        assert!(storage.is_initialized().unwrap());
1210    }
1211
1212    #[test]
1213    fn test_init_idempotent() {
1214        let mut storage = SqliteStorage::in_memory().unwrap();
1215        assert!(storage.init().is_ok());
1216        assert!(storage.init().is_ok()); // Second init should be fine
1217    }
1218
1219    #[test]
1220    fn test_context_crud() {
1221        let mut storage = setup();
1222
1223        // No context initially
1224        assert!(storage.load_context().unwrap().is_none());
1225
1226        // Save context
1227        let mut ctx = Context::new();
1228        ctx.set_variable("key".to_string(), ContextValue::String("value".to_string()));
1229        storage.save_context(&ctx).unwrap();
1230
1231        // Load context
1232        let loaded = storage.load_context().unwrap().unwrap();
1233        assert_eq!(
1234            loaded.get_variable("key"),
1235            Some(&ContextValue::String("value".to_string()))
1236        );
1237
1238        // Delete context
1239        storage.delete_context().unwrap();
1240        assert!(storage.load_context().unwrap().is_none());
1241    }
1242
1243    #[test]
1244    fn test_buffer_crud() {
1245        let mut storage = setup();
1246
1247        // Add buffer
1248        let buffer = Buffer::from_named("test".to_string(), "Hello, world!".to_string());
1249        let id = storage.add_buffer(&buffer).unwrap();
1250        assert!(id > 0);
1251
1252        // Get buffer
1253        let loaded = storage.get_buffer(id).unwrap().unwrap();
1254        assert_eq!(loaded.name, Some("test".to_string()));
1255        assert_eq!(loaded.content, "Hello, world!");
1256
1257        // Get by name
1258        let by_name = storage.get_buffer_by_name("test").unwrap().unwrap();
1259        assert_eq!(by_name.id, Some(id));
1260
1261        // List buffers
1262        let buffers = storage.list_buffers().unwrap();
1263        assert_eq!(buffers.len(), 1);
1264
1265        // Update buffer
1266        let mut updated = loaded;
1267        updated.content = "Updated content".to_string();
1268        storage.update_buffer(&updated).unwrap();
1269
1270        let reloaded = storage.get_buffer(id).unwrap().unwrap();
1271        assert_eq!(reloaded.content, "Updated content");
1272
1273        // Delete buffer
1274        storage.delete_buffer(id).unwrap();
1275        assert!(storage.get_buffer(id).unwrap().is_none());
1276    }
1277
1278    #[test]
1279    fn test_chunk_crud() {
1280        let mut storage = setup();
1281
1282        // Create buffer first
1283        let buffer = Buffer::from_content("Hello, world!".to_string());
1284        let buffer_id = storage.add_buffer(&buffer).unwrap();
1285
1286        // Add chunks
1287        let chunks = vec![
1288            Chunk::new(buffer_id, "Hello, ".to_string(), 0..7, 0),
1289            Chunk::new(buffer_id, "world!".to_string(), 7..13, 1),
1290        ];
1291        storage.add_chunks(buffer_id, &chunks).unwrap();
1292
1293        // Get chunks
1294        let loaded = storage.get_chunks(buffer_id).unwrap();
1295        assert_eq!(loaded.len(), 2);
1296        assert_eq!(loaded[0].content, "Hello, ");
1297        assert_eq!(loaded[1].content, "world!");
1298
1299        // Chunk count
1300        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 2);
1301
1302        // Get single chunk
1303        let chunk_id = loaded[0].id.unwrap();
1304        let single = storage.get_chunk(chunk_id).unwrap().unwrap();
1305        assert_eq!(single.content, "Hello, ");
1306
1307        // Delete chunks
1308        storage.delete_chunks(buffer_id).unwrap();
1309        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 0);
1310    }
1311
1312    #[test]
1313    fn test_cascade_delete() {
1314        let mut storage = setup();
1315
1316        // Create buffer with chunks
1317        let buffer = Buffer::from_content("Hello, world!".to_string());
1318        let buffer_id = storage.add_buffer(&buffer).unwrap();
1319
1320        let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
1321        storage.add_chunks(buffer_id, &chunks).unwrap();
1322
1323        // Verify chunk exists
1324        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 1);
1325
1326        // Delete buffer - chunks should be deleted too
1327        storage.delete_buffer(buffer_id).unwrap();
1328
1329        // Verify no orphan chunks (query all chunks)
1330        let count: i64 = storage
1331            .conn
1332            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
1333            .unwrap();
1334        assert_eq!(count, 0);
1335    }
1336
1337    #[test]
1338    fn test_reset() {
1339        let mut storage = setup();
1340
1341        // Add some data
1342        let ctx = Context::new();
1343        storage.save_context(&ctx).unwrap();
1344
1345        let buffer = Buffer::from_content("test".to_string());
1346        storage.add_buffer(&buffer).unwrap();
1347
1348        // Reset
1349        storage.reset().unwrap();
1350
1351        // Verify empty
1352        assert!(storage.load_context().unwrap().is_none());
1353        assert_eq!(storage.buffer_count().unwrap(), 0);
1354    }
1355
1356    #[test]
1357    fn test_stats() {
1358        let mut storage = setup();
1359
1360        // Empty stats
1361        let stats = storage.stats().unwrap();
1362        assert_eq!(stats.buffer_count, 0);
1363        assert_eq!(stats.chunk_count, 0);
1364        assert!(!stats.has_context);
1365
1366        // Add data
1367        let ctx = Context::new();
1368        storage.save_context(&ctx).unwrap();
1369
1370        let buffer = Buffer::from_content("Hello, world!".to_string());
1371        let buffer_id = storage.add_buffer(&buffer).unwrap();
1372
1373        let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
1374        storage.add_chunks(buffer_id, &chunks).unwrap();
1375
1376        // Stats with data
1377        let stats = storage.stats().unwrap();
1378        assert_eq!(stats.buffer_count, 1);
1379        assert_eq!(stats.chunk_count, 1);
1380        assert!(stats.has_context);
1381        assert_eq!(stats.total_content_size, 13);
1382    }
1383
1384    #[test]
1385    fn test_export_buffers() {
1386        let mut storage = setup();
1387
1388        storage
1389            .add_buffer(&Buffer::from_content("First".to_string()))
1390            .unwrap();
1391        storage
1392            .add_buffer(&Buffer::from_content("Second".to_string()))
1393            .unwrap();
1394
1395        let exported = storage.export_buffers().unwrap();
1396        assert_eq!(exported, "First\n\nSecond");
1397    }
1398
1399    // Helper: create a buffer with one chunk and return (buffer_id, chunk_id).
1400    fn setup_buffer_with_chunk(storage: &mut SqliteStorage) -> (i64, i64) {
1401        let buffer = Buffer::from_content("test content".to_string());
1402        let buffer_id = storage.add_buffer(&buffer).unwrap();
1403        let chunks = vec![Chunk::new(buffer_id, "test content".to_string(), 0..12, 0)];
1404        storage.add_chunks(buffer_id, &chunks).unwrap();
1405        let chunk_id = storage.get_chunks(buffer_id).unwrap()[0].id.unwrap();
1406        (buffer_id, chunk_id)
1407    }
1408
1409    #[test]
1410    fn test_store_and_get_embedding() {
1411        let mut storage = setup();
1412        let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1413
1414        let embedding = vec![0.1_f32, 0.2, 0.3, 0.4];
1415        storage
1416            .store_embedding(chunk_id, &embedding, Some("test-model"))
1417            .unwrap();
1418
1419        let loaded = storage.get_embedding(chunk_id).unwrap().unwrap();
1420        assert_eq!(loaded.len(), 4);
1421        for (a, b) in loaded.iter().zip(embedding.iter()) {
1422            assert!((a - b).abs() < 1e-6, "expected {b}, got {a}");
1423        }
1424    }
1425
1426    #[test]
1427    fn test_get_embedding_nonexistent() {
1428        let storage = setup();
1429        let result = storage.get_embedding(9999).unwrap();
1430        assert!(result.is_none());
1431    }
1432
1433    #[test]
1434    fn test_store_embedding_upsert() {
1435        let mut storage = setup();
1436        let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1437
1438        let embedding1 = vec![0.1_f32, 0.2];
1439        storage
1440            .store_embedding(chunk_id, &embedding1, Some("model-a"))
1441            .unwrap();
1442
1443        // Upsert with new values
1444        let embedding2 = vec![0.9_f32, 0.8];
1445        storage
1446            .store_embedding(chunk_id, &embedding2, Some("model-a"))
1447            .unwrap();
1448
1449        let loaded = storage.get_embedding(chunk_id).unwrap().unwrap();
1450        for (a, b) in loaded.iter().zip(embedding2.iter()) {
1451            assert!((a - b).abs() < 1e-6);
1452        }
1453    }
1454
1455    #[test]
1456    fn test_has_embedding() {
1457        let mut storage = setup();
1458        let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1459
1460        assert!(!storage.has_embedding(chunk_id).unwrap());
1461
1462        storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();
1463
1464        assert!(storage.has_embedding(chunk_id).unwrap());
1465    }
1466
1467    #[test]
1468    fn test_embedding_count() {
1469        let mut storage = setup();
1470        assert_eq!(storage.embedding_count().unwrap(), 0);
1471
1472        let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1473
1474        // Add a second chunk
1475        let chunks2 = vec![Chunk::new(buffer_id, "more".to_string(), 0..4, 1)];
1476        storage.add_chunks(buffer_id, &chunks2).unwrap();
1477        let chunk_id2 = storage
1478            .get_chunks(buffer_id)
1479            .unwrap()
1480            .into_iter()
1481            .find(|c| c.id != Some(chunk_id))
1482            .unwrap()
1483            .id
1484            .unwrap();
1485
1486        storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();
1487        assert_eq!(storage.embedding_count().unwrap(), 1);
1488
1489        storage
1490            .store_embedding(chunk_id2, &[0.2_f32], None)
1491            .unwrap();
1492        assert_eq!(storage.embedding_count().unwrap(), 2);
1493    }
1494
1495    #[test]
1496    fn test_delete_embedding() {
1497        let mut storage = setup();
1498        let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1499
1500        storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();
1501        assert!(storage.has_embedding(chunk_id).unwrap());
1502
1503        storage.delete_embedding(chunk_id).unwrap();
1504        assert!(!storage.has_embedding(chunk_id).unwrap());
1505    }
1506
1507    #[test]
1508    fn test_store_embeddings_batch() {
1509        let mut storage = setup();
1510        let (buffer_id, chunk_id1) = setup_buffer_with_chunk(&mut storage);
1511
1512        // Add second chunk
1513        let chunks2 = vec![Chunk::new(buffer_id, "second".to_string(), 0..6, 1)];
1514        storage.add_chunks(buffer_id, &chunks2).unwrap();
1515        let chunk_id2 = storage
1516            .get_chunks(buffer_id)
1517            .unwrap()
1518            .into_iter()
1519            .find(|c| c.id != Some(chunk_id1))
1520            .unwrap()
1521            .id
1522            .unwrap();
1523
1524        let batch = vec![
1525            (chunk_id1, vec![0.1_f32, 0.2]),
1526            (chunk_id2, vec![0.3_f32, 0.4]),
1527        ];
1528        storage
1529            .store_embeddings_batch(&batch, Some("batch-model"))
1530            .unwrap();
1531
1532        assert!(storage.has_embedding(chunk_id1).unwrap());
1533        assert!(storage.has_embedding(chunk_id2).unwrap());
1534        assert_eq!(storage.embedding_count().unwrap(), 2);
1535    }
1536
1537    #[test]
1538    fn test_get_all_embeddings() {
1539        let mut storage = setup();
1540        let (buffer_id, chunk_id1) = setup_buffer_with_chunk(&mut storage);
1541
1542        let chunks2 = vec![Chunk::new(buffer_id, "second".to_string(), 0..6, 1)];
1543        storage.add_chunks(buffer_id, &chunks2).unwrap();
1544        let chunk_id2 = storage
1545            .get_chunks(buffer_id)
1546            .unwrap()
1547            .into_iter()
1548            .find(|c| c.id != Some(chunk_id1))
1549            .unwrap()
1550            .id
1551            .unwrap();
1552
1553        storage
1554            .store_embedding(chunk_id1, &[0.1_f32], Some("m"))
1555            .unwrap();
1556        storage
1557            .store_embedding(chunk_id2, &[0.2_f32], Some("m"))
1558            .unwrap();
1559
1560        let all = storage.get_all_embeddings().unwrap();
1561        assert_eq!(all.len(), 2);
1562    }
1563
1564    #[test]
1565    fn test_get_embedding_models() {
1566        let mut storage = setup();
1567        let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1568
1569        // No models initially
1570        assert!(storage.get_embedding_models(buffer_id).unwrap().is_empty());
1571
1572        storage
1573            .store_embedding(chunk_id, &[0.1_f32], Some("model-x"))
1574            .unwrap();
1575
1576        let models = storage.get_embedding_models(buffer_id).unwrap();
1577        assert_eq!(models.len(), 1);
1578        assert_eq!(models[0], "model-x");
1579    }
1580
1581    #[test]
1582    fn test_get_embedding_model_counts() {
1583        let mut storage = setup();
1584        let (buffer_id, chunk_id1) = setup_buffer_with_chunk(&mut storage);
1585
1586        let chunks2 = vec![Chunk::new(buffer_id, "extra".to_string(), 0..5, 1)];
1587        storage.add_chunks(buffer_id, &chunks2).unwrap();
1588        let chunk_id2 = storage
1589            .get_chunks(buffer_id)
1590            .unwrap()
1591            .into_iter()
1592            .find(|c| c.id != Some(chunk_id1))
1593            .unwrap()
1594            .id
1595            .unwrap();
1596
1597        storage
1598            .store_embedding(chunk_id1, &[0.1_f32], Some("model-a"))
1599            .unwrap();
1600        storage
1601            .store_embedding(chunk_id2, &[0.2_f32], Some("model-a"))
1602            .unwrap();
1603
1604        let counts = storage.get_embedding_model_counts(buffer_id).unwrap();
1605        assert_eq!(counts.len(), 1);
1606        assert_eq!(counts[0], (Some("model-a".to_string()), 2));
1607    }
1608
1609    #[test]
1610    fn test_get_chunks_needing_embedding_no_embeddings() {
1611        let mut storage = setup();
1612        let (buffer_id, _chunk_id) = setup_buffer_with_chunk(&mut storage);
1613
1614        let needing = storage
1615            .get_chunks_needing_embedding(buffer_id, None)
1616            .unwrap();
1617        assert_eq!(needing.len(), 1);
1618    }
1619
1620    #[test]
1621    fn test_get_chunks_needing_embedding_with_model() {
1622        let mut storage = setup();
1623        let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1624
1625        // Add second chunk
1626        let chunks2 = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
1627        storage.add_chunks(buffer_id, &chunks2).unwrap();
1628        let chunk_id2 = storage
1629            .get_chunks(buffer_id)
1630            .unwrap()
1631            .into_iter()
1632            .find(|c| c.id != Some(chunk_id))
1633            .unwrap()
1634            .id
1635            .unwrap();
1636
1637        // chunk_id has model-a, chunk_id2 has no embedding
1638        storage
1639            .store_embedding(chunk_id, &[0.1_f32], Some("model-a"))
1640            .unwrap();
1641
1642        // When checking for model-a: chunk_id2 needs one (no embedding)
1643        let needing = storage
1644            .get_chunks_needing_embedding(buffer_id, Some("model-a"))
1645            .unwrap();
1646        assert!(needing.contains(&chunk_id2));
1647        assert!(!needing.contains(&chunk_id));
1648
1649        // When checking for model-b: chunk_id has wrong model, chunk_id2 has none
1650        let needing_b = storage
1651            .get_chunks_needing_embedding(buffer_id, Some("model-b"))
1652            .unwrap();
1653        assert!(needing_b.contains(&chunk_id));
1654        assert!(needing_b.contains(&chunk_id2));
1655    }
1656
1657    #[test]
1658    fn test_get_chunks_without_embedding() {
1659        let mut storage = setup();
1660        let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1661
1662        let without = storage.get_chunks_without_embedding(buffer_id).unwrap();
1663        assert_eq!(without.len(), 1);
1664        assert!(without.contains(&chunk_id));
1665
1666        storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();
1667
1668        let without_after = storage.get_chunks_without_embedding(buffer_id).unwrap();
1669        assert!(without_after.is_empty());
1670    }
1671
1672    #[test]
1673    fn test_delete_embeddings_by_model_named() {
1674        let mut storage = setup();
1675        let (buffer_id, chunk_id1) = setup_buffer_with_chunk(&mut storage);
1676
1677        let chunks2 = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
1678        storage.add_chunks(buffer_id, &chunks2).unwrap();
1679        let chunk_id2 = storage
1680            .get_chunks(buffer_id)
1681            .unwrap()
1682            .into_iter()
1683            .find(|c| c.id != Some(chunk_id1))
1684            .unwrap()
1685            .id
1686            .unwrap();
1687
1688        storage
1689            .store_embedding(chunk_id1, &[0.1_f32], Some("model-a"))
1690            .unwrap();
1691        storage
1692            .store_embedding(chunk_id2, &[0.2_f32], Some("model-b"))
1693            .unwrap();
1694
1695        let deleted = storage
1696            .delete_embeddings_by_model(buffer_id, Some("model-a"))
1697            .unwrap();
1698        assert_eq!(deleted, 1);
1699        assert!(!storage.has_embedding(chunk_id1).unwrap());
1700        assert!(storage.has_embedding(chunk_id2).unwrap());
1701    }
1702
1703    #[test]
1704    fn test_delete_embeddings_by_model_null() {
1705        let mut storage = setup();
1706        let (buffer_id, chunk_id1) = setup_buffer_with_chunk(&mut storage);
1707
1708        let chunks2 = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
1709        storage.add_chunks(buffer_id, &chunks2).unwrap();
1710        let chunk_id2 = storage
1711            .get_chunks(buffer_id)
1712            .unwrap()
1713            .into_iter()
1714            .find(|c| c.id != Some(chunk_id1))
1715            .unwrap()
1716            .id
1717            .unwrap();
1718
1719        storage
1720            .store_embedding(chunk_id1, &[0.1_f32], None)
1721            .unwrap();
1722        storage
1723            .store_embedding(chunk_id2, &[0.2_f32], Some("model-b"))
1724            .unwrap();
1725
1726        // Delete only the NULL-model embedding
1727        let deleted = storage.delete_embeddings_by_model(buffer_id, None).unwrap();
1728        assert_eq!(deleted, 1);
1729        assert!(!storage.has_embedding(chunk_id1).unwrap());
1730        assert!(storage.has_embedding(chunk_id2).unwrap());
1731    }
1732
1733    #[test]
1734    fn test_get_embedding_stats() {
1735        let mut storage = setup();
1736        let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1737
1738        // Add second chunk
1739        let chunks2 = vec![Chunk::new(buffer_id, "extra".to_string(), 0..5, 1)];
1740        storage.add_chunks(buffer_id, &chunks2).unwrap();
1741
1742        // Stats before embedding
1743        let stats = storage.get_embedding_stats(buffer_id).unwrap();
1744        assert_eq!(stats.total_chunks, 2);
1745        assert_eq!(stats.embedded_chunks, 0);
1746
1747        // Embed one chunk
1748        storage
1749            .store_embedding(chunk_id, &[0.1_f32], Some("m1"))
1750            .unwrap();
1751
1752        let stats = storage.get_embedding_stats(buffer_id).unwrap();
1753        assert_eq!(stats.total_chunks, 2);
1754        assert_eq!(stats.embedded_chunks, 1);
1755        assert_eq!(stats.model_counts.len(), 1);
1756    }
1757
1758    #[test]
1759    fn test_search_fts_finds_match() {
1760        let mut storage = setup();
1761        let buffer = Buffer::from_content("The quick brown fox jumps".to_string());
1762        let buffer_id = storage.add_buffer(&buffer).unwrap();
1763
1764        let chunks = vec![
1765            Chunk::new(buffer_id, "The quick brown fox".to_string(), 0..19, 0),
1766            Chunk::new(buffer_id, "fox jumps".to_string(), 20..29, 1),
1767        ];
1768        storage.add_chunks(buffer_id, &chunks).unwrap();
1769
1770        let results = storage.search_fts("quick", 10).unwrap();
1771        // At least the first chunk should match
1772        assert!(!results.is_empty());
1773        // Scores should be positive
1774        for (_id, score) in &results {
1775            assert!(
1776                *score >= 0.0,
1777                "BM25 score should be non-negative, got {score}"
1778            );
1779        }
1780    }
1781
1782    #[test]
1783    fn test_search_fts_no_match() {
1784        let mut storage = setup();
1785        let buffer = Buffer::from_content("The quick brown fox".to_string());
1786        let buffer_id = storage.add_buffer(&buffer).unwrap();
1787
1788        let chunks = vec![Chunk::new(
1789            buffer_id,
1790            "The quick brown fox".to_string(),
1791            0..19,
1792            0,
1793        )];
1794        storage.add_chunks(buffer_id, &chunks).unwrap();
1795
1796        let results = storage.search_fts("zzzyyyxxx", 10).unwrap();
1797        assert!(results.is_empty());
1798    }
1799
1800    #[test]
1801    fn test_search_fts_respects_limit() {
1802        let mut storage = setup();
1803        let buffer = Buffer::from_content("hello world hello world hello".to_string());
1804        let buffer_id = storage.add_buffer(&buffer).unwrap();
1805
1806        // Add 3 chunks all containing "hello"
1807        let chunks = vec![
1808            Chunk::new(buffer_id, "hello world".to_string(), 0..11, 0),
1809            Chunk::new(buffer_id, "hello world".to_string(), 12..23, 1),
1810            Chunk::new(buffer_id, "hello".to_string(), 24..29, 2),
1811        ];
1812        storage.add_chunks(buffer_id, &chunks).unwrap();
1813
1814        let results = storage.search_fts("hello", 2).unwrap();
1815        assert!(results.len() <= 2);
1816    }
1817
1818    #[test]
1819    fn test_get_chunks_by_ids_empty() {
1820        let mut storage = SqliteStorage::in_memory().unwrap();
1821        storage.init().unwrap();
1822
1823        let result = storage.get_chunks_by_ids(&[]).unwrap();
1824        assert!(result.is_empty());
1825    }
1826
1827    #[test]
1828    fn test_get_chunks_by_ids_batch() {
1829        let mut storage = SqliteStorage::in_memory().unwrap();
1830        storage.init().unwrap();
1831
1832        let buffer_id = storage
1833            .add_buffer(&Buffer::from_content("abc def ghi".to_string()))
1834            .unwrap();
1835        let chunks = vec![
1836            Chunk::new(buffer_id, "abc".to_string(), 0..3, 0),
1837            Chunk::new(buffer_id, "def".to_string(), 4..7, 1),
1838            Chunk::new(buffer_id, "ghi".to_string(), 8..11, 2),
1839        ];
1840        storage.add_chunks(buffer_id, &chunks).unwrap();
1841
1842        // Fetch by known IDs
1843        let all = storage.get_chunks(buffer_id).unwrap();
1844        let ids: Vec<i64> = all.iter().filter_map(|c| c.id).take(2).collect();
1845        assert_eq!(ids.len(), 2);
1846
1847        let map = storage.get_chunks_by_ids(&ids).unwrap();
1848        assert_eq!(map.len(), 2);
1849        for id in &ids {
1850            assert!(map.contains_key(id));
1851        }
1852    }
1853
1854    #[test]
1855    fn test_get_chunks_by_ids_missing_id() {
1856        let mut storage = SqliteStorage::in_memory().unwrap();
1857        storage.init().unwrap();
1858
1859        // Query for an ID that doesn't exist
1860        let map = storage.get_chunks_by_ids(&[99999]).unwrap();
1861        assert!(map.is_empty());
1862    }
1863
1864    #[test]
1865    fn test_fts_and_embedding_pipeline() {
1866        // Integration: index chunks, store embeddings, then verify that FTS
1867        // hits match chunks that also have embeddings.
1868        let mut storage = setup();
1869        let buffer = Buffer::from_content("machine learning and neural networks".to_string());
1870        let buffer_id = storage.add_buffer(&buffer).unwrap();
1871
1872        let chunks = vec![
1873            Chunk::new(
1874                buffer_id,
1875                "machine learning algorithms".to_string(),
1876                0..27,
1877                0,
1878            ),
1879            Chunk::new(
1880                buffer_id,
1881                "neural networks architecture".to_string(),
1882                28..56,
1883                1,
1884            ),
1885        ];
1886        storage.add_chunks(buffer_id, &chunks).unwrap();
1887
1888        // Store embeddings for both chunks
1889        let chunk_ids: Vec<i64> = storage
1890            .get_chunks(buffer_id)
1891            .unwrap()
1892            .into_iter()
1893            .map(|c| c.id.unwrap())
1894            .collect();
1895        assert_eq!(chunk_ids.len(), 2);
1896
1897        let embeddings: &[&[f32]] = &[&[0.1, 0.2], &[0.2, 0.4]];
1898        for (&chunk_id, embedding) in chunk_ids.iter().zip(embeddings) {
1899            storage
1900                .store_embedding(chunk_id, embedding, Some("test-model"))
1901                .unwrap();
1902        }
1903
1904        // FTS search should find the relevant chunk
1905        let fts_results = storage.search_fts("machine", 10).unwrap();
1906        assert!(!fts_results.is_empty());
1907
1908        // Every FTS result should also have an embedding stored
1909        for (chunk_id, _score) in &fts_results {
1910            assert!(
1911                storage.has_embedding(*chunk_id).unwrap(),
1912                "FTS result chunk {chunk_id} should have an embedding"
1913            );
1914        }
1915
1916        // Verify embedding count matches what was stored
1917        assert_eq!(storage.embedding_count().unwrap(), 2);
1918    }
1919}