Skip to main content

rlm_rs/storage/
sqlite.rs

1//! `SQLite` storage implementation.
2//!
3//! Provides persistent storage using `SQLite` with proper transaction
4//! management and migration support.
5
6// SQLite stores all integers as i64. These casts are intentional and safe
7// because we only store non-negative values that fit in usize.
8#![allow(clippy::cast_possible_truncation)]
9#![allow(clippy::cast_sign_loss)]
10
11use crate::core::{Buffer, BufferMetadata, Chunk, ChunkMetadata, Context};
12use crate::error::{Result, StorageError};
13use crate::storage::schema::{
14    CHECK_SCHEMA_SQL, CURRENT_SCHEMA_VERSION, GET_VERSION_SQL, SCHEMA_SQL, SET_VERSION_SQL,
15};
16use crate::storage::traits::{Storage, StorageStats};
17use rusqlite::{Connection, OptionalExtension, params};
18use std::path::{Path, PathBuf};
19
20/// SQLite-based storage implementation.
21///
22/// Provides persistent storage for RLM state with full ACID guarantees.
23///
24/// # Examples
25///
26/// ```no_run
27/// use rlm_rs::storage::{SqliteStorage, Storage};
28///
29/// let mut storage = SqliteStorage::open("rlm-state.db").unwrap();
30/// storage.init().unwrap();
31/// ```
32pub struct SqliteStorage {
33    /// `SQLite` connection.
34    conn: Connection,
35    /// Path to the database file (None for in-memory).
36    path: Option<PathBuf>,
37}
38
39impl SqliteStorage {
40    /// Opens or creates a `SQLite` database at the given path.
41    ///
42    /// # Arguments
43    ///
44    /// * `path` - Path to the database file. Parent directory must exist.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the database cannot be opened or initialized.
49    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
50        let path = path.as_ref().to_path_buf();
51
52        // Ensure parent directory exists
53        if let Some(parent) = path.parent()
54            && !parent.exists()
55        {
56            std::fs::create_dir_all(parent).map_err(|e| StorageError::Database(e.to_string()))?;
57        }
58
59        let conn = Connection::open(&path).map_err(StorageError::from)?;
60
61        // Enable foreign keys
62        conn.execute("PRAGMA foreign_keys = ON;", [])
63            .map_err(StorageError::from)?;
64
65        // Use WAL mode for better concurrent access (returns result, use query_row)
66        let _: String = conn
67            .query_row("PRAGMA journal_mode = WAL;", [], |row| row.get(0))
68            .map_err(StorageError::from)?;
69
70        Ok(Self {
71            conn,
72            path: Some(path),
73        })
74    }
75
76    /// Creates an in-memory `SQLite` database.
77    ///
78    /// Useful for testing.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the database cannot be created.
83    pub fn in_memory() -> Result<Self> {
84        let conn = Connection::open_in_memory().map_err(StorageError::from)?;
85        conn.execute("PRAGMA foreign_keys = ON;", [])
86            .map_err(StorageError::from)?;
87
88        Ok(Self { conn, path: None })
89    }
90
91    /// Returns the database path (None for in-memory).
92    #[must_use]
93    pub fn path(&self) -> Option<&Path> {
94        self.path.as_deref()
95    }
96
97    /// Gets the current schema version.
98    fn get_schema_version(&self) -> Result<Option<u32>> {
99        let version: Option<String> = self
100            .conn
101            .query_row(GET_VERSION_SQL, [], |row| row.get(0))
102            .optional()
103            .map_err(StorageError::from)?;
104
105        Ok(version.and_then(|v| v.parse().ok()))
106    }
107
108    /// Sets the schema version.
109    fn set_schema_version(&self, version: u32) -> Result<()> {
110        self.conn
111            .execute(SET_VERSION_SQL, params![version.to_string()])
112            .map_err(StorageError::from)?;
113        Ok(())
114    }
115
116    /// Returns current Unix timestamp.
117    #[allow(clippy::cast_possible_wrap)]
118    fn now() -> i64 {
119        std::time::SystemTime::now()
120            .duration_since(std::time::UNIX_EPOCH)
121            .map(|d| d.as_secs() as i64)
122            .unwrap_or(0)
123    }
124}
125
126impl Storage for SqliteStorage {
127    fn init(&mut self) -> Result<()> {
128        // Check if already initialized
129        let is_init: i64 = self
130            .conn
131            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
132            .map_err(StorageError::from)?;
133
134        if is_init == 0 {
135            // Fresh install - create schema
136            self.conn
137                .execute_batch(SCHEMA_SQL)
138                .map_err(StorageError::from)?;
139            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
140        } else if let Some(current) = self.get_schema_version()?
141            && current < CURRENT_SCHEMA_VERSION
142        {
143            // Run migrations
144            let migrations = crate::storage::schema::get_migrations_from(current);
145            for migration in migrations {
146                self.conn
147                    .execute_batch(migration.sql)
148                    .map_err(|e| StorageError::Migration(e.to_string()))?;
149            }
150            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
151        }
152
153        Ok(())
154    }
155
156    fn is_initialized(&self) -> Result<bool> {
157        let count: i64 = self
158            .conn
159            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
160            .map_err(StorageError::from)?;
161        Ok(count > 0)
162    }
163
164    fn reset(&mut self) -> Result<()> {
165        self.conn
166            .execute_batch(
167                r"
168            DELETE FROM chunk_embeddings;
169            DELETE FROM chunks;
170            DELETE FROM buffers;
171            DELETE FROM context;
172            DELETE FROM metadata;
173        ",
174            )
175            .map_err(StorageError::from)?;
176        Ok(())
177    }
178
179    // ==================== Context Operations ====================
180
181    fn save_context(&mut self, context: &Context) -> Result<()> {
182        let data = serde_json::to_string(context).map_err(StorageError::from)?;
183        let now = Self::now();
184
185        self.conn
186            .execute(
187                r"
188            INSERT OR REPLACE INTO context (id, data, created_at, updated_at)
189            VALUES (1, ?, COALESCE((SELECT created_at FROM context WHERE id = 1), ?), ?)
190        ",
191                params![data, now, now],
192            )
193            .map_err(StorageError::from)?;
194
195        Ok(())
196    }
197
198    fn load_context(&self) -> Result<Option<Context>> {
199        let data: Option<String> = self
200            .conn
201            .query_row("SELECT data FROM context WHERE id = 1", [], |row| {
202                row.get(0)
203            })
204            .optional()
205            .map_err(StorageError::from)?;
206
207        match data {
208            Some(json) => {
209                let context = serde_json::from_str(&json).map_err(StorageError::from)?;
210                Ok(Some(context))
211            }
212            None => Ok(None),
213        }
214    }
215
216    fn delete_context(&mut self) -> Result<()> {
217        self.conn
218            .execute("DELETE FROM context WHERE id = 1", [])
219            .map_err(StorageError::from)?;
220        Ok(())
221    }
222
223    // ==================== Buffer Operations ====================
224
225    #[allow(clippy::cast_possible_wrap)]
226    fn add_buffer(&mut self, buffer: &Buffer) -> Result<i64> {
227        let now = Self::now();
228
229        self.conn
230            .execute(
231                r"
232            INSERT INTO buffers (
233                name, source_path, content, content_type, content_hash,
234                size, line_count, chunk_count, created_at, updated_at
235            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
236        ",
237                params![
238                    buffer.name,
239                    buffer
240                        .source
241                        .as_ref()
242                        .map(|p| p.to_string_lossy().to_string()),
243                    buffer.content,
244                    buffer.metadata.content_type,
245                    buffer.metadata.content_hash,
246                    buffer.metadata.size as i64,
247                    buffer.metadata.line_count.map(|c| c as i64),
248                    buffer.metadata.chunk_count.map(|c| c as i64),
249                    now,
250                    now,
251                ],
252            )
253            .map_err(StorageError::from)?;
254
255        Ok(self.conn.last_insert_rowid())
256    }
257
258    fn get_buffer(&self, id: i64) -> Result<Option<Buffer>> {
259        let result = self
260            .conn
261            .query_row(
262                r"
263            SELECT id, name, source_path, content, content_type, content_hash,
264                   size, line_count, chunk_count, created_at, updated_at
265            FROM buffers WHERE id = ?
266        ",
267                params![id],
268                |row| {
269                    Ok(Buffer {
270                        id: Some(row.get::<_, i64>(0)?),
271                        name: row.get(1)?,
272                        source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
273                        content: row.get(3)?,
274                        metadata: BufferMetadata {
275                            content_type: row.get(4)?,
276                            content_hash: row.get(5)?,
277                            size: row.get::<_, i64>(6)? as usize,
278                            line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
279                            chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
280                            created_at: row.get(9)?,
281                            updated_at: row.get(10)?,
282                        },
283                    })
284                },
285            )
286            .optional()
287            .map_err(StorageError::from)?;
288
289        Ok(result)
290    }
291
292    fn get_buffer_by_name(&self, name: &str) -> Result<Option<Buffer>> {
293        let id: Option<i64> = self
294            .conn
295            .query_row(
296                "SELECT id FROM buffers WHERE name = ?",
297                params![name],
298                |row| row.get(0),
299            )
300            .optional()
301            .map_err(StorageError::from)?;
302
303        id.map_or(Ok(None), |id| self.get_buffer(id))
304    }
305
306    fn list_buffers(&self) -> Result<Vec<Buffer>> {
307        let mut stmt = self
308            .conn
309            .prepare(
310                r"
311            SELECT id, name, source_path, content, content_type, content_hash,
312                   size, line_count, chunk_count, created_at, updated_at
313            FROM buffers ORDER BY id
314        ",
315            )
316            .map_err(StorageError::from)?;
317
318        let buffers = stmt
319            .query_map([], |row| {
320                Ok(Buffer {
321                    id: Some(row.get::<_, i64>(0)?),
322                    name: row.get(1)?,
323                    source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
324                    content: row.get(3)?,
325                    metadata: BufferMetadata {
326                        content_type: row.get(4)?,
327                        content_hash: row.get(5)?,
328                        size: row.get::<_, i64>(6)? as usize,
329                        line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
330                        chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
331                        created_at: row.get(9)?,
332                        updated_at: row.get(10)?,
333                    },
334                })
335            })
336            .map_err(StorageError::from)?
337            .collect::<std::result::Result<Vec<_>, _>>()
338            .map_err(StorageError::from)?;
339
340        Ok(buffers)
341    }
342
343    #[allow(clippy::cast_possible_wrap)]
344    fn update_buffer(&mut self, buffer: &Buffer) -> Result<()> {
345        let id = buffer.id.ok_or_else(|| StorageError::BufferNotFound {
346            identifier: "no ID".to_string(),
347        })?;
348
349        let now = Self::now();
350
351        self.conn
352            .execute(
353                r"
354            UPDATE buffers SET
355                name = ?, source_path = ?, content = ?, content_type = ?,
356                content_hash = ?, size = ?, line_count = ?, chunk_count = ?,
357                updated_at = ?
358            WHERE id = ?
359        ",
360                params![
361                    buffer.name,
362                    buffer
363                        .source
364                        .as_ref()
365                        .map(|p| p.to_string_lossy().to_string()),
366                    buffer.content,
367                    buffer.metadata.content_type,
368                    buffer.metadata.content_hash,
369                    buffer.metadata.size as i64,
370                    buffer.metadata.line_count.map(|c| c as i64),
371                    buffer.metadata.chunk_count.map(|c| c as i64),
372                    now,
373                    id,
374                ],
375            )
376            .map_err(StorageError::from)?;
377
378        Ok(())
379    }
380
381    fn delete_buffer(&mut self, id: i64) -> Result<()> {
382        // Chunks are deleted automatically via CASCADE
383        self.conn
384            .execute("DELETE FROM buffers WHERE id = ?", params![id])
385            .map_err(StorageError::from)?;
386        Ok(())
387    }
388
389    fn buffer_count(&self) -> Result<usize> {
390        let count: i64 = self
391            .conn
392            .query_row("SELECT COUNT(*) FROM buffers", [], |row| row.get(0))
393            .map_err(StorageError::from)?;
394        Ok(count as usize)
395    }
396
397    // ==================== Chunk Operations ====================
398
399    #[allow(clippy::cast_possible_wrap)]
400    fn add_chunks(&mut self, buffer_id: i64, chunks: &[Chunk]) -> Result<()> {
401        let tx = self.conn.transaction().map_err(StorageError::from)?;
402        let now = Self::now();
403
404        {
405            let mut stmt = tx
406                .prepare(
407                    r"
408                INSERT INTO chunks (
409                    buffer_id, content, byte_start, byte_end, chunk_index,
410                    strategy, token_count, line_start, line_end, has_overlap,
411                    content_hash, custom_metadata, created_at
412                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
413            ",
414                )
415                .map_err(StorageError::from)?;
416
417            for chunk in chunks {
418                let custom_meta = chunk.metadata.custom.clone();
419
420                let (line_start, line_end) = chunk
421                    .metadata
422                    .line_range
423                    .as_ref()
424                    .map_or((None, None), |r| (Some(r.start as i64), Some(r.end as i64)));
425
426                stmt.execute(params![
427                    buffer_id,
428                    chunk.content,
429                    chunk.byte_range.start as i64,
430                    chunk.byte_range.end as i64,
431                    chunk.index as i64,
432                    chunk.metadata.strategy,
433                    chunk.metadata.token_count.map(|c| c as i64),
434                    line_start,
435                    line_end,
436                    i64::from(chunk.metadata.has_overlap),
437                    chunk.metadata.content_hash,
438                    custom_meta,
439                    now,
440                ])
441                .map_err(StorageError::from)?;
442            }
443        }
444
445        tx.commit().map_err(StorageError::from)?;
446
447        // Update chunk count on buffer
448        self.conn
449            .execute(
450                "UPDATE buffers SET chunk_count = ? WHERE id = ?",
451                params![chunks.len() as i64, buffer_id],
452            )
453            .map_err(StorageError::from)?;
454
455        Ok(())
456    }
457
458    fn get_chunks(&self, buffer_id: i64) -> Result<Vec<Chunk>> {
459        let mut stmt = self
460            .conn
461            .prepare(
462                r"
463            SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
464                   strategy, token_count, line_start, line_end, has_overlap,
465                   content_hash, custom_metadata, created_at
466            FROM chunks WHERE buffer_id = ? ORDER BY chunk_index
467        ",
468            )
469            .map_err(StorageError::from)?;
470
471        let chunks = stmt
472            .query_map(params![buffer_id], |row| {
473                let line_start: Option<i64> = row.get(8)?;
474                let line_end: Option<i64> = row.get(9)?;
475                let line_range = match (line_start, line_end) {
476                    (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
477                    _ => None,
478                };
479
480                Ok(Chunk {
481                    id: Some(row.get::<_, i64>(0)?),
482                    buffer_id: row.get(1)?,
483                    content: row.get(2)?,
484                    byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
485                    index: row.get::<_, i64>(5)? as usize,
486                    metadata: ChunkMetadata {
487                        strategy: row.get(6)?,
488                        token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
489                        line_range,
490                        has_overlap: row.get::<_, i64>(10)? != 0,
491                        content_hash: row.get(11)?,
492                        custom: row.get(12)?,
493                        created_at: row.get(13)?,
494                    },
495                })
496            })
497            .map_err(StorageError::from)?
498            .collect::<std::result::Result<Vec<_>, _>>()
499            .map_err(StorageError::from)?;
500
501        Ok(chunks)
502    }
503
504    fn get_chunk(&self, id: i64) -> Result<Option<Chunk>> {
505        let result = self
506            .conn
507            .query_row(
508                r"
509            SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
510                   strategy, token_count, line_start, line_end, has_overlap,
511                   content_hash, custom_metadata, created_at
512            FROM chunks WHERE id = ?
513        ",
514                params![id],
515                |row| {
516                    let line_start: Option<i64> = row.get(8)?;
517                    let line_end: Option<i64> = row.get(9)?;
518                    let line_range = match (line_start, line_end) {
519                        (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
520                        _ => None,
521                    };
522
523                    Ok(Chunk {
524                        id: Some(row.get::<_, i64>(0)?),
525                        buffer_id: row.get(1)?,
526                        content: row.get(2)?,
527                        byte_range: (row.get::<_, i64>(3)? as usize)
528                            ..(row.get::<_, i64>(4)? as usize),
529                        index: row.get::<_, i64>(5)? as usize,
530                        metadata: ChunkMetadata {
531                            strategy: row.get(6)?,
532                            token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
533                            line_range,
534                            has_overlap: row.get::<_, i64>(10)? != 0,
535                            content_hash: row.get(11)?,
536                            custom: row.get(12)?,
537                            created_at: row.get(13)?,
538                        },
539                    })
540                },
541            )
542            .optional()
543            .map_err(StorageError::from)?;
544
545        Ok(result)
546    }
547
548    fn delete_chunks(&mut self, buffer_id: i64) -> Result<()> {
549        self.conn
550            .execute("DELETE FROM chunks WHERE buffer_id = ?", params![buffer_id])
551            .map_err(StorageError::from)?;
552
553        // Update chunk count on buffer
554        self.conn
555            .execute(
556                "UPDATE buffers SET chunk_count = 0 WHERE id = ?",
557                params![buffer_id],
558            )
559            .map_err(StorageError::from)?;
560
561        Ok(())
562    }
563
564    fn chunk_count(&self, buffer_id: i64) -> Result<usize> {
565        let count: i64 = self
566            .conn
567            .query_row(
568                "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
569                params![buffer_id],
570                |row| row.get(0),
571            )
572            .map_err(StorageError::from)?;
573        Ok(count as usize)
574    }
575
576    // ==================== Utility Operations ====================
577
578    fn export_buffers(&self) -> Result<String> {
579        let buffers = self.list_buffers()?;
580        let mut output = String::new();
581
582        for (i, buffer) in buffers.iter().enumerate() {
583            if i > 0 {
584                output.push_str("\n\n");
585            }
586            output.push_str(&buffer.content);
587        }
588
589        Ok(output)
590    }
591
592    fn stats(&self) -> Result<StorageStats> {
593        let buffer_count = self.buffer_count()?;
594
595        let chunk_count: i64 = self
596            .conn
597            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
598            .map_err(StorageError::from)?;
599
600        let total_size: i64 = self
601            .conn
602            .query_row("SELECT COALESCE(SUM(size), 0) FROM buffers", [], |row| {
603                row.get(0)
604            })
605            .map_err(StorageError::from)?;
606
607        let has_context = self.load_context()?.is_some();
608
609        let schema_version = self.get_schema_version()?.unwrap_or(0);
610
611        let db_size = self
612            .path
613            .as_ref()
614            .and_then(|p| std::fs::metadata(p).ok().map(|m| m.len()));
615
616        Ok(StorageStats {
617            buffer_count,
618            chunk_count: chunk_count as usize,
619            total_content_size: total_size as usize,
620            has_context,
621            schema_version,
622            db_size,
623        })
624    }
625}
626
627// ==================== Embedding & Search Operations ====================
628
629impl SqliteStorage {
630    /// Stores an embedding for a chunk.
631    ///
632    /// # Arguments
633    ///
634    /// * `chunk_id` - The chunk ID to associate the embedding with.
635    /// * `embedding` - The embedding vector (f32 array).
636    /// * `model_name` - Optional name of the model that generated the embedding.
637    ///
638    /// # Errors
639    ///
640    /// Returns an error if the embedding cannot be stored.
641    #[allow(clippy::cast_possible_wrap)]
642    pub fn store_embedding(
643        &mut self,
644        chunk_id: i64,
645        embedding: &[f32],
646        model_name: Option<&str>,
647    ) -> Result<()> {
648        let now = Self::now();
649
650        // Serialize f32 array to bytes (little-endian)
651        let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
652
653        self.conn
654            .execute(
655                r"
656                INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
657                VALUES (?, ?, ?, ?, ?)
658            ",
659                params![chunk_id, bytes, embedding.len() as i64, model_name, now],
660            )
661            .map_err(StorageError::from)?;
662
663        Ok(())
664    }
665
666    /// Retrieves the embedding for a chunk.
667    ///
668    /// # Errors
669    ///
670    /// Returns an error if the query fails.
671    pub fn get_embedding(&self, chunk_id: i64) -> Result<Option<Vec<f32>>> {
672        let result: Option<Vec<u8>> = self
673            .conn
674            .query_row(
675                "SELECT embedding FROM chunk_embeddings WHERE chunk_id = ?",
676                params![chunk_id],
677                |row| row.get(0),
678            )
679            .optional()
680            .map_err(StorageError::from)?;
681
682        Ok(result.map(|bytes| {
683            bytes
684                .chunks_exact(4)
685                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
686                .collect()
687        }))
688    }
689
690    /// Gets the distinct model names used for embeddings in a buffer.
691    ///
692    /// Returns the set of model names used to generate embeddings for
693    /// chunks belonging to the specified buffer.
694    ///
695    /// # Errors
696    ///
697    /// Returns an error if the query fails.
698    pub fn get_embedding_models(&self, buffer_id: i64) -> Result<Vec<String>> {
699        let mut stmt = self
700            .conn
701            .prepare(
702                r"
703                SELECT DISTINCT ce.model_name
704                FROM chunk_embeddings ce
705                JOIN chunks c ON ce.chunk_id = c.id
706                WHERE c.buffer_id = ? AND ce.model_name IS NOT NULL
707                ",
708            )
709            .map_err(StorageError::from)?;
710
711        let models = stmt
712            .query_map(params![buffer_id], |row| row.get::<_, String>(0))
713            .map_err(StorageError::from)?
714            .filter_map(std::result::Result::ok)
715            .collect();
716
717        Ok(models)
718    }
719
720    /// Gets the count of embeddings by model name for a buffer.
721    ///
722    /// Returns a list of (`model_name`, count) pairs.
723    ///
724    /// # Errors
725    ///
726    /// Returns an error if the query fails.
727    pub fn get_embedding_model_counts(&self, buffer_id: i64) -> Result<Vec<(Option<String>, i64)>> {
728        let mut stmt = self
729            .conn
730            .prepare(
731                r"
732                SELECT ce.model_name, COUNT(*) as count
733                FROM chunk_embeddings ce
734                JOIN chunks c ON ce.chunk_id = c.id
735                WHERE c.buffer_id = ?
736                GROUP BY ce.model_name
737                ",
738            )
739            .map_err(StorageError::from)?;
740
741        let counts = stmt
742            .query_map(params![buffer_id], |row| {
743                Ok((row.get::<_, Option<String>>(0)?, row.get::<_, i64>(1)?))
744            })
745            .map_err(StorageError::from)?
746            .filter_map(std::result::Result::ok)
747            .collect();
748
749        Ok(counts)
750    }
751
752    /// Stores embeddings for multiple chunks in a batch.
753    ///
754    /// # Errors
755    ///
756    /// Returns an error if any embedding cannot be stored.
757    #[allow(clippy::cast_possible_wrap)]
758    pub fn store_embeddings_batch(
759        &mut self,
760        embeddings: &[(i64, Vec<f32>)],
761        model_name: Option<&str>,
762    ) -> Result<()> {
763        let tx = self.conn.transaction().map_err(StorageError::from)?;
764        let now = Self::now();
765
766        {
767            let mut stmt = tx
768                .prepare(
769                    r"
770                    INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
771                    VALUES (?, ?, ?, ?, ?)
772                ",
773                )
774                .map_err(StorageError::from)?;
775
776            for (chunk_id, embedding) in embeddings {
777                let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
778
779                stmt.execute(params![
780                    chunk_id,
781                    bytes,
782                    embedding.len() as i64,
783                    model_name,
784                    now
785                ])
786                .map_err(StorageError::from)?;
787            }
788        }
789
790        tx.commit().map_err(StorageError::from)?;
791        Ok(())
792    }
793
794    /// Deletes the embedding for a chunk.
795    ///
796    /// # Errors
797    ///
798    /// Returns an error if deletion fails.
799    pub fn delete_embedding(&mut self, chunk_id: i64) -> Result<()> {
800        self.conn
801            .execute(
802                "DELETE FROM chunk_embeddings WHERE chunk_id = ?",
803                params![chunk_id],
804            )
805            .map_err(StorageError::from)?;
806        Ok(())
807    }
808
809    /// Performs FTS5 BM25 full-text search.
810    ///
811    /// Returns chunk IDs and their BM25 scores (lower is better match).
812    ///
813    /// # Arguments
814    ///
815    /// * `query` - The search query (supports FTS5 query syntax).
816    /// * `limit` - Maximum number of results to return.
817    ///
818    /// # Errors
819    ///
820    /// Returns an error if the search fails.
821    #[allow(clippy::cast_possible_wrap)]
822    pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<(i64, f64)>> {
823        // FTS5 bm25() returns negative scores, more negative = better match
824        // We negate it so higher scores = better match
825
826        // Convert space-separated terms to OR query for more forgiving search
827        // Each term is quoted to escape FTS5 special characters (?, *, ^, etc.)
828        // "CLI tool?" becomes '"CLI" OR "tool?"' so special chars are treated as literals
829        let fts_query = query
830            .split_whitespace()
831            .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
832            .collect::<Vec<_>>()
833            .join(" OR ");
834
835        let mut stmt = self
836            .conn
837            .prepare(
838                r"
839                SELECT rowid, -bm25(chunks_fts) as score
840                FROM chunks_fts
841                WHERE chunks_fts MATCH ?
842                ORDER BY score DESC
843                LIMIT ?
844            ",
845            )
846            .map_err(StorageError::from)?;
847
848        let results = stmt
849            .query_map(params![fts_query, limit as i64], |row| {
850                Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
851            })
852            .map_err(StorageError::from)?
853            .collect::<std::result::Result<Vec<_>, _>>()
854            .map_err(StorageError::from)?;
855
856        Ok(results)
857    }
858
859    /// Returns all chunk embeddings for vector similarity search.
860    ///
861    /// # Errors
862    ///
863    /// Returns an error if the query fails.
864    pub fn get_all_embeddings(&self) -> Result<Vec<(i64, Vec<f32>)>> {
865        let mut stmt = self
866            .conn
867            .prepare("SELECT chunk_id, embedding FROM chunk_embeddings")
868            .map_err(StorageError::from)?;
869
870        let results = stmt
871            .query_map([], |row| {
872                let chunk_id: i64 = row.get(0)?;
873                let bytes: Vec<u8> = row.get(1)?;
874                let embedding: Vec<f32> = bytes
875                    .chunks_exact(4)
876                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
877                    .collect();
878                Ok((chunk_id, embedding))
879            })
880            .map_err(StorageError::from)?
881            .collect::<std::result::Result<Vec<_>, _>>()
882            .map_err(StorageError::from)?;
883
884        Ok(results)
885    }
886
887    /// Counts chunks with embeddings.
888    ///
889    /// # Errors
890    ///
891    /// Returns an error if the count fails.
892    pub fn embedding_count(&self) -> Result<usize> {
893        let count: i64 = self
894            .conn
895            .query_row("SELECT COUNT(*) FROM chunk_embeddings", [], |row| {
896                row.get(0)
897            })
898            .map_err(StorageError::from)?;
899        Ok(count as usize)
900    }
901
902    /// Checks if a chunk has an embedding.
903    ///
904    /// # Errors
905    ///
906    /// Returns an error if the query fails.
907    pub fn has_embedding(&self, chunk_id: i64) -> Result<bool> {
908        let count: i64 = self
909            .conn
910            .query_row(
911                "SELECT COUNT(*) FROM chunk_embeddings WHERE chunk_id = ?",
912                params![chunk_id],
913                |row| row.get(0),
914            )
915            .map_err(StorageError::from)?;
916        Ok(count > 0)
917    }
918
919    /// Gets chunk IDs that need embedding (either no embedding or wrong model).
920    ///
921    /// This is used for incremental embedding updates. Returns chunks that:
922    /// - Have no embedding at all, OR
923    /// - Have an embedding with a different model name (if `current_model` is provided)
924    ///
925    /// # Arguments
926    ///
927    /// * `buffer_id` - The buffer to check.
928    /// * `current_model` - Optional model name to check against. If provided,
929    ///   chunks with different models are included.
930    ///
931    /// # Errors
932    ///
933    /// Returns an error if the query fails.
934    pub fn get_chunks_needing_embedding(
935        &self,
936        buffer_id: i64,
937        current_model: Option<&str>,
938    ) -> Result<Vec<i64>> {
939        let mut results = Vec::new();
940
941        // Get chunks without any embedding
942        let mut stmt = self
943            .conn
944            .prepare(
945                r"
946                SELECT c.id FROM chunks c
947                LEFT JOIN chunk_embeddings e ON c.id = e.chunk_id
948                WHERE c.buffer_id = ? AND e.chunk_id IS NULL
949                ",
950            )
951            .map_err(StorageError::from)?;
952
953        let rows = stmt
954            .query_map(params![buffer_id], |row| row.get(0))
955            .map_err(StorageError::from)?;
956
957        for row in rows {
958            results.push(row.map_err(StorageError::from)?);
959        }
960
961        // If model specified, also get chunks with different model
962        if let Some(model) = current_model {
963            let mut stmt = self
964                .conn
965                .prepare(
966                    r"
967                    SELECT c.id FROM chunks c
968                    INNER JOIN chunk_embeddings e ON c.id = e.chunk_id
969                    WHERE c.buffer_id = ? AND (e.model_name IS NULL OR e.model_name != ?)
970                    ",
971                )
972                .map_err(StorageError::from)?;
973
974            let rows = stmt
975                .query_map(params![buffer_id, model], |row| row.get(0))
976                .map_err(StorageError::from)?;
977
978            for row in rows {
979                results.push(row.map_err(StorageError::from)?);
980            }
981        }
982
983        // Deduplicate (in case of overlap, though shouldn't happen)
984        results.sort_unstable();
985        results.dedup();
986        Ok(results)
987    }
988
989    /// Gets chunks without any embedding for a buffer.
990    ///
991    /// Simpler version of `get_chunks_needing_embedding` when model doesn't matter.
992    ///
993    /// # Errors
994    ///
995    /// Returns an error if the query fails.
996    pub fn get_chunks_without_embedding(&self, buffer_id: i64) -> Result<Vec<i64>> {
997        self.get_chunks_needing_embedding(buffer_id, None)
998    }
999
1000    /// Deletes embeddings with a specific model name.
1001    ///
1002    /// Useful for cleaning up embeddings from old models before re-embedding.
1003    ///
1004    /// # Arguments
1005    ///
1006    /// * `buffer_id` - The buffer to clean.
1007    /// * `model_name` - The model name to match (or None to match NULL).
1008    ///
1009    /// # Returns
1010    ///
1011    /// The number of embeddings deleted.
1012    ///
1013    /// # Errors
1014    ///
1015    /// Returns an error if deletion fails.
1016    pub fn delete_embeddings_by_model(
1017        &mut self,
1018        buffer_id: i64,
1019        model_name: Option<&str>,
1020    ) -> Result<usize> {
1021        let deleted = match model_name {
1022            Some(name) => self
1023                .conn
1024                .execute(
1025                    r"
1026                    DELETE FROM chunk_embeddings
1027                    WHERE chunk_id IN (
1028                        SELECT id FROM chunks WHERE buffer_id = ?
1029                    ) AND model_name = ?
1030                    ",
1031                    params![buffer_id, name],
1032                )
1033                .map_err(StorageError::from)?,
1034            None => self
1035                .conn
1036                .execute(
1037                    r"
1038                    DELETE FROM chunk_embeddings
1039                    WHERE chunk_id IN (
1040                        SELECT id FROM chunks WHERE buffer_id = ?
1041                    ) AND model_name IS NULL
1042                    ",
1043                    params![buffer_id],
1044                )
1045                .map_err(StorageError::from)?,
1046        };
1047        Ok(deleted)
1048    }
1049
1050    /// Gets embedding statistics for a buffer.
1051    ///
1052    /// Returns counts of embedded vs total chunks, and model breakdown.
1053    ///
1054    /// # Errors
1055    ///
1056    /// Returns an error if the query fails.
1057    pub fn get_embedding_stats(&self, buffer_id: i64) -> Result<EmbeddingStats> {
1058        // Total chunks
1059        let total_chunks: i64 = self
1060            .conn
1061            .query_row(
1062                "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
1063                params![buffer_id],
1064                |row| row.get(0),
1065            )
1066            .map_err(StorageError::from)?;
1067
1068        // Embedded chunks
1069        let embedded_chunks: i64 = self
1070            .conn
1071            .query_row(
1072                r"
1073                SELECT COUNT(*) FROM chunk_embeddings e
1074                INNER JOIN chunks c ON e.chunk_id = c.id
1075                WHERE c.buffer_id = ?
1076                ",
1077                params![buffer_id],
1078                |row| row.get(0),
1079            )
1080            .map_err(StorageError::from)?;
1081
1082        // Model counts
1083        let model_counts = self.get_embedding_model_counts(buffer_id)?;
1084
1085        Ok(EmbeddingStats {
1086            total_chunks: total_chunks as usize,
1087            embedded_chunks: embedded_chunks as usize,
1088            model_counts,
1089        })
1090    }
1091}
1092
/// Embedding coverage summary for a single buffer.
///
/// Returned by `SqliteStorage::get_embedding_stats`. The type is plain data,
/// so it derives the common comparison and default traits: two snapshots can
/// be compared directly (e.g. with `assert_eq!`), and `Default` yields an
/// all-zero summary with no model entries.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Total number of chunks in the buffer.
    pub total_chunks: usize,
    /// Number of chunks with embeddings.
    pub embedded_chunks: usize,
    /// Count of embeddings by model, as (`model_name`, count) pairs.
    ///
    /// NOTE(review): a `None` name presumably corresponds to embeddings stored
    /// with a NULL `model_name` — confirm against `get_embedding_model_counts`.
    pub model_counts: Vec<(Option<String>, i64)>,
}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ContextValue;

    /// Creates an in-memory storage on which `init` has already been run.
    fn setup() -> SqliteStorage {
        let mut db = SqliteStorage::in_memory().unwrap();
        db.init().unwrap();
        db
    }

    /// `init` succeeds on a fresh database and marks it as initialized.
    #[test]
    fn test_init() {
        let mut db = SqliteStorage::in_memory().unwrap();
        assert!(db.init().is_ok());
        assert!(db.is_initialized().unwrap());
    }

    /// Running `init` a second time must not fail.
    #[test]
    fn test_init_idempotent() {
        let mut db = SqliteStorage::in_memory().unwrap();
        assert!(db.init().is_ok());
        assert!(db.init().is_ok()); // repeat call is expected to succeed
    }

    /// Context round-trip: absent -> saved -> loaded -> deleted.
    #[test]
    fn test_context_crud() {
        let mut db = setup();

        // Nothing stored yet.
        assert!(db.load_context().unwrap().is_none());

        // Store a context holding a single variable.
        let mut context = Context::new();
        context.set_variable("key".to_string(), ContextValue::String("value".to_string()));
        db.save_context(&context).unwrap();

        // Read it back and check the variable survived.
        let restored = db.load_context().unwrap().unwrap();
        let expected = ContextValue::String("value".to_string());
        assert_eq!(restored.get_variable("key"), Some(&expected));

        // Remove it again.
        db.delete_context().unwrap();
        assert!(db.load_context().unwrap().is_none());
    }

    /// Buffer lifecycle: add, fetch by id and name, list, update, delete.
    #[test]
    fn test_buffer_crud() {
        let mut db = setup();

        // Insert a named buffer.
        let original = Buffer::from_named("test".to_string(), "Hello, world!".to_string());
        let buffer_id = db.add_buffer(&original).unwrap();
        assert!(buffer_id > 0);

        // Fetch by id.
        let fetched = db.get_buffer(buffer_id).unwrap().unwrap();
        assert_eq!(fetched.name, Some("test".to_string()));
        assert_eq!(fetched.content, "Hello, world!");

        // Fetch by name.
        let named = db.get_buffer_by_name("test").unwrap().unwrap();
        assert_eq!(named.id, Some(buffer_id));

        // Listing shows exactly the one buffer.
        let all = db.list_buffers().unwrap();
        assert_eq!(all.len(), 1);

        // Change the content and persist it.
        let mut edited = fetched;
        edited.content = "Updated content".to_string();
        db.update_buffer(&edited).unwrap();

        let after_update = db.get_buffer(buffer_id).unwrap().unwrap();
        assert_eq!(after_update.content, "Updated content");

        // Delete and confirm it is gone.
        db.delete_buffer(buffer_id).unwrap();
        assert!(db.get_buffer(buffer_id).unwrap().is_none());
    }

    /// Chunk lifecycle: add, list, count, fetch one, delete all.
    #[test]
    fn test_chunk_crud() {
        let mut db = setup();

        // Chunks need an owning buffer.
        let parent = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = db.add_buffer(&parent).unwrap();

        // Insert two chunks covering the buffer content.
        let pieces = vec![
            Chunk::new(buffer_id, "Hello, ".to_string(), 0..7, 0),
            Chunk::new(buffer_id, "world!".to_string(), 7..13, 1),
        ];
        db.add_chunks(buffer_id, &pieces).unwrap();

        // Both come back in order.
        let stored = db.get_chunks(buffer_id).unwrap();
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].content, "Hello, ");
        assert_eq!(stored[1].content, "world!");

        // The count agrees.
        assert_eq!(db.chunk_count(buffer_id).unwrap(), 2);

        // Look up the first chunk by its id.
        let first_id = stored[0].id.unwrap();
        let first = db.get_chunk(first_id).unwrap().unwrap();
        assert_eq!(first.content, "Hello, ");

        // Remove every chunk of the buffer.
        db.delete_chunks(buffer_id).unwrap();
        assert_eq!(db.chunk_count(buffer_id).unwrap(), 0);
    }

    /// Deleting a buffer must also remove its chunks.
    #[test]
    fn test_cascade_delete() {
        let mut db = setup();

        // One buffer with one chunk.
        let parent = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = db.add_buffer(&parent).unwrap();

        let pieces = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
        db.add_chunks(buffer_id, &pieces).unwrap();

        // The chunk is there before the delete.
        assert_eq!(db.chunk_count(buffer_id).unwrap(), 1);

        // Dropping the buffer should take the chunk with it.
        db.delete_buffer(buffer_id).unwrap();

        // Count across the whole table to catch orphaned rows.
        let remaining = db
            .conn
            .query_row("SELECT COUNT(*) FROM chunks", [], |row| {
                row.get::<_, i64>(0)
            })
            .unwrap();
        assert_eq!(remaining, 0);
    }

    /// `reset` wipes both the context and the buffers.
    #[test]
    fn test_reset() {
        let mut db = setup();

        // Seed a context and a buffer.
        let context = Context::new();
        db.save_context(&context).unwrap();

        let scratch = Buffer::from_content("test".to_string());
        db.add_buffer(&scratch).unwrap();

        // Wipe everything.
        db.reset().unwrap();

        // Nothing should be left.
        assert!(db.load_context().unwrap().is_none());
        assert_eq!(db.buffer_count().unwrap(), 0);
    }

    /// `stats` reflects an empty store first, then the inserted data.
    #[test]
    fn test_stats() {
        let mut db = setup();

        // Fresh store: everything is zero / absent.
        let before = db.stats().unwrap();
        assert_eq!(before.buffer_count, 0);
        assert_eq!(before.chunk_count, 0);
        assert!(!before.has_context);

        // Seed a context, a buffer and one chunk.
        let context = Context::new();
        db.save_context(&context).unwrap();

        let parent = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = db.add_buffer(&parent).unwrap();

        let pieces = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
        db.add_chunks(buffer_id, &pieces).unwrap();

        // The figures now match what was inserted.
        let after = db.stats().unwrap();
        assert_eq!(after.buffer_count, 1);
        assert_eq!(after.chunk_count, 1);
        assert!(after.has_context);
        // "Hello, world!" is 13 bytes long.
        assert_eq!(after.total_content_size, 13);
    }

    /// Exported buffers are joined with a blank line between them.
    #[test]
    fn test_export_buffers() {
        let mut db = setup();

        for text in ["First", "Second"] {
            db.add_buffer(&Buffer::from_content(text.to_string()))
                .unwrap();
        }

        let dump = db.export_buffers().unwrap();
        assert_eq!(dump, "First\n\nSecond");
    }
}