1#![allow(clippy::cast_possible_truncation)]
9#![allow(clippy::cast_sign_loss)]
10
11use crate::core::{Buffer, BufferMetadata, Chunk, ChunkMetadata, Context};
12use crate::error::{Result, StorageError};
13use crate::storage::schema::{
14 CHECK_SCHEMA_SQL, CURRENT_SCHEMA_VERSION, GET_VERSION_SQL, SCHEMA_SQL, SET_VERSION_SQL,
15};
16use crate::storage::traits::{Storage, StorageStats};
17use rusqlite::{Connection, OptionalExtension, params};
18use std::path::{Path, PathBuf};
19
20pub struct SqliteStorage {
33 conn: Connection,
35 path: Option<PathBuf>,
37}
38
39impl SqliteStorage {
40 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
50 let path = path.as_ref().to_path_buf();
51
52 if let Some(parent) = path.parent()
54 && !parent.exists()
55 {
56 std::fs::create_dir_all(parent).map_err(|e| StorageError::Database(e.to_string()))?;
57 }
58
59 let conn = Connection::open(&path).map_err(StorageError::from)?;
60
61 conn.execute("PRAGMA foreign_keys = ON;", [])
63 .map_err(StorageError::from)?;
64
65 let _: String = conn
67 .query_row("PRAGMA journal_mode = WAL;", [], |row| row.get(0))
68 .map_err(StorageError::from)?;
69
70 Ok(Self {
71 conn,
72 path: Some(path),
73 })
74 }
75
76 pub fn in_memory() -> Result<Self> {
84 let conn = Connection::open_in_memory().map_err(StorageError::from)?;
85 conn.execute("PRAGMA foreign_keys = ON;", [])
86 .map_err(StorageError::from)?;
87
88 Ok(Self { conn, path: None })
89 }
90
91 #[must_use]
93 pub fn path(&self) -> Option<&Path> {
94 self.path.as_deref()
95 }
96
97 fn get_schema_version(&self) -> Result<Option<u32>> {
99 let version: Option<String> = self
100 .conn
101 .query_row(GET_VERSION_SQL, [], |row| row.get(0))
102 .optional()
103 .map_err(StorageError::from)?;
104
105 Ok(version.and_then(|v| v.parse().ok()))
106 }
107
108 fn set_schema_version(&self, version: u32) -> Result<()> {
110 self.conn
111 .execute(SET_VERSION_SQL, params![version.to_string()])
112 .map_err(StorageError::from)?;
113 Ok(())
114 }
115
116 #[allow(clippy::cast_possible_wrap)]
118 fn now() -> i64 {
119 std::time::SystemTime::now()
120 .duration_since(std::time::UNIX_EPOCH)
121 .map(|d| d.as_secs() as i64)
122 .unwrap_or(0)
123 }
124}
125
126impl Storage for SqliteStorage {
127 fn init(&mut self) -> Result<()> {
128 let is_init: i64 = self
130 .conn
131 .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
132 .map_err(StorageError::from)?;
133
134 if is_init == 0 {
135 self.conn
137 .execute_batch(SCHEMA_SQL)
138 .map_err(StorageError::from)?;
139 self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
140 } else if let Some(current) = self.get_schema_version()?
141 && current < CURRENT_SCHEMA_VERSION
142 {
143 let migrations = crate::storage::schema::get_migrations_from(current);
145 for migration in migrations {
146 self.conn
147 .execute_batch(migration.sql)
148 .map_err(|e| StorageError::Migration(e.to_string()))?;
149 }
150 self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
151 }
152
153 Ok(())
154 }
155
156 fn is_initialized(&self) -> Result<bool> {
157 let count: i64 = self
158 .conn
159 .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
160 .map_err(StorageError::from)?;
161 Ok(count > 0)
162 }
163
164 fn reset(&mut self) -> Result<()> {
165 self.conn
166 .execute_batch(
167 r"
168 DELETE FROM chunk_embeddings;
169 DELETE FROM chunks;
170 DELETE FROM buffers;
171 DELETE FROM context;
172 DELETE FROM metadata;
173 ",
174 )
175 .map_err(StorageError::from)?;
176 Ok(())
177 }
178
179 fn save_context(&mut self, context: &Context) -> Result<()> {
182 let data = serde_json::to_string(context).map_err(StorageError::from)?;
183 let now = Self::now();
184
185 self.conn
186 .execute(
187 r"
188 INSERT OR REPLACE INTO context (id, data, created_at, updated_at)
189 VALUES (1, ?, COALESCE((SELECT created_at FROM context WHERE id = 1), ?), ?)
190 ",
191 params![data, now, now],
192 )
193 .map_err(StorageError::from)?;
194
195 Ok(())
196 }
197
198 fn load_context(&self) -> Result<Option<Context>> {
199 let data: Option<String> = self
200 .conn
201 .query_row("SELECT data FROM context WHERE id = 1", [], |row| {
202 row.get(0)
203 })
204 .optional()
205 .map_err(StorageError::from)?;
206
207 match data {
208 Some(json) => {
209 let context = serde_json::from_str(&json).map_err(StorageError::from)?;
210 Ok(Some(context))
211 }
212 None => Ok(None),
213 }
214 }
215
216 fn delete_context(&mut self) -> Result<()> {
217 self.conn
218 .execute("DELETE FROM context WHERE id = 1", [])
219 .map_err(StorageError::from)?;
220 Ok(())
221 }
222
223 #[allow(clippy::cast_possible_wrap)]
226 fn add_buffer(&mut self, buffer: &Buffer) -> Result<i64> {
227 let now = Self::now();
228
229 self.conn
230 .execute(
231 r"
232 INSERT INTO buffers (
233 name, source_path, content, content_type, content_hash,
234 size, line_count, chunk_count, created_at, updated_at
235 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
236 ",
237 params![
238 buffer.name,
239 buffer
240 .source
241 .as_ref()
242 .map(|p| p.to_string_lossy().to_string()),
243 buffer.content,
244 buffer.metadata.content_type,
245 buffer.metadata.content_hash,
246 buffer.metadata.size as i64,
247 buffer.metadata.line_count.map(|c| c as i64),
248 buffer.metadata.chunk_count.map(|c| c as i64),
249 now,
250 now,
251 ],
252 )
253 .map_err(StorageError::from)?;
254
255 Ok(self.conn.last_insert_rowid())
256 }
257
258 fn get_buffer(&self, id: i64) -> Result<Option<Buffer>> {
259 let result = self
260 .conn
261 .query_row(
262 r"
263 SELECT id, name, source_path, content, content_type, content_hash,
264 size, line_count, chunk_count, created_at, updated_at
265 FROM buffers WHERE id = ?
266 ",
267 params![id],
268 |row| {
269 Ok(Buffer {
270 id: Some(row.get::<_, i64>(0)?),
271 name: row.get(1)?,
272 source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
273 content: row.get(3)?,
274 metadata: BufferMetadata {
275 content_type: row.get(4)?,
276 content_hash: row.get(5)?,
277 size: row.get::<_, i64>(6)? as usize,
278 line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
279 chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
280 created_at: row.get(9)?,
281 updated_at: row.get(10)?,
282 },
283 })
284 },
285 )
286 .optional()
287 .map_err(StorageError::from)?;
288
289 Ok(result)
290 }
291
292 fn get_buffer_by_name(&self, name: &str) -> Result<Option<Buffer>> {
293 let id: Option<i64> = self
294 .conn
295 .query_row(
296 "SELECT id FROM buffers WHERE name = ?",
297 params![name],
298 |row| row.get(0),
299 )
300 .optional()
301 .map_err(StorageError::from)?;
302
303 id.map_or(Ok(None), |id| self.get_buffer(id))
304 }
305
306 fn list_buffers(&self) -> Result<Vec<Buffer>> {
307 let mut stmt = self
308 .conn
309 .prepare(
310 r"
311 SELECT id, name, source_path, content, content_type, content_hash,
312 size, line_count, chunk_count, created_at, updated_at
313 FROM buffers ORDER BY id
314 ",
315 )
316 .map_err(StorageError::from)?;
317
318 let buffers = stmt
319 .query_map([], |row| {
320 Ok(Buffer {
321 id: Some(row.get::<_, i64>(0)?),
322 name: row.get(1)?,
323 source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
324 content: row.get(3)?,
325 metadata: BufferMetadata {
326 content_type: row.get(4)?,
327 content_hash: row.get(5)?,
328 size: row.get::<_, i64>(6)? as usize,
329 line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
330 chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
331 created_at: row.get(9)?,
332 updated_at: row.get(10)?,
333 },
334 })
335 })
336 .map_err(StorageError::from)?
337 .collect::<std::result::Result<Vec<_>, _>>()
338 .map_err(StorageError::from)?;
339
340 Ok(buffers)
341 }
342
343 #[allow(clippy::cast_possible_wrap)]
344 fn update_buffer(&mut self, buffer: &Buffer) -> Result<()> {
345 let id = buffer.id.ok_or_else(|| StorageError::BufferNotFound {
346 identifier: "no ID".to_string(),
347 })?;
348
349 let now = Self::now();
350
351 self.conn
352 .execute(
353 r"
354 UPDATE buffers SET
355 name = ?, source_path = ?, content = ?, content_type = ?,
356 content_hash = ?, size = ?, line_count = ?, chunk_count = ?,
357 updated_at = ?
358 WHERE id = ?
359 ",
360 params![
361 buffer.name,
362 buffer
363 .source
364 .as_ref()
365 .map(|p| p.to_string_lossy().to_string()),
366 buffer.content,
367 buffer.metadata.content_type,
368 buffer.metadata.content_hash,
369 buffer.metadata.size as i64,
370 buffer.metadata.line_count.map(|c| c as i64),
371 buffer.metadata.chunk_count.map(|c| c as i64),
372 now,
373 id,
374 ],
375 )
376 .map_err(StorageError::from)?;
377
378 Ok(())
379 }
380
381 fn delete_buffer(&mut self, id: i64) -> Result<()> {
382 self.conn
384 .execute("DELETE FROM buffers WHERE id = ?", params![id])
385 .map_err(StorageError::from)?;
386 Ok(())
387 }
388
389 fn buffer_count(&self) -> Result<usize> {
390 let count: i64 = self
391 .conn
392 .query_row("SELECT COUNT(*) FROM buffers", [], |row| row.get(0))
393 .map_err(StorageError::from)?;
394 Ok(count as usize)
395 }
396
397 #[allow(clippy::cast_possible_wrap)]
400 fn add_chunks(&mut self, buffer_id: i64, chunks: &[Chunk]) -> Result<()> {
401 let tx = self.conn.transaction().map_err(StorageError::from)?;
402 let now = Self::now();
403
404 {
405 let mut stmt = tx
406 .prepare(
407 r"
408 INSERT INTO chunks (
409 buffer_id, content, byte_start, byte_end, chunk_index,
410 strategy, token_count, line_start, line_end, has_overlap,
411 content_hash, custom_metadata, created_at
412 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
413 ",
414 )
415 .map_err(StorageError::from)?;
416
417 for chunk in chunks {
418 let custom_meta = chunk.metadata.custom.clone();
419
420 let (line_start, line_end) = chunk
421 .metadata
422 .line_range
423 .as_ref()
424 .map_or((None, None), |r| (Some(r.start as i64), Some(r.end as i64)));
425
426 stmt.execute(params![
427 buffer_id,
428 chunk.content,
429 chunk.byte_range.start as i64,
430 chunk.byte_range.end as i64,
431 chunk.index as i64,
432 chunk.metadata.strategy,
433 chunk.metadata.token_count.map(|c| c as i64),
434 line_start,
435 line_end,
436 i64::from(chunk.metadata.has_overlap),
437 chunk.metadata.content_hash,
438 custom_meta,
439 now,
440 ])
441 .map_err(StorageError::from)?;
442 }
443 }
444
445 tx.commit().map_err(StorageError::from)?;
446
447 self.conn
449 .execute(
450 "UPDATE buffers SET chunk_count = ? WHERE id = ?",
451 params![chunks.len() as i64, buffer_id],
452 )
453 .map_err(StorageError::from)?;
454
455 Ok(())
456 }
457
458 fn get_chunks(&self, buffer_id: i64) -> Result<Vec<Chunk>> {
459 let mut stmt = self
460 .conn
461 .prepare(
462 r"
463 SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
464 strategy, token_count, line_start, line_end, has_overlap,
465 content_hash, custom_metadata, created_at
466 FROM chunks WHERE buffer_id = ? ORDER BY chunk_index
467 ",
468 )
469 .map_err(StorageError::from)?;
470
471 let chunks = stmt
472 .query_map(params![buffer_id], |row| {
473 let line_start: Option<i64> = row.get(8)?;
474 let line_end: Option<i64> = row.get(9)?;
475 let line_range = match (line_start, line_end) {
476 (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
477 _ => None,
478 };
479
480 Ok(Chunk {
481 id: Some(row.get::<_, i64>(0)?),
482 buffer_id: row.get(1)?,
483 content: row.get(2)?,
484 byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
485 index: row.get::<_, i64>(5)? as usize,
486 metadata: ChunkMetadata {
487 strategy: row.get(6)?,
488 token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
489 line_range,
490 has_overlap: row.get::<_, i64>(10)? != 0,
491 content_hash: row.get(11)?,
492 custom: row.get(12)?,
493 created_at: row.get(13)?,
494 },
495 })
496 })
497 .map_err(StorageError::from)?
498 .collect::<std::result::Result<Vec<_>, _>>()
499 .map_err(StorageError::from)?;
500
501 Ok(chunks)
502 }
503
504 fn get_chunk(&self, id: i64) -> Result<Option<Chunk>> {
505 let result = self
506 .conn
507 .query_row(
508 r"
509 SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
510 strategy, token_count, line_start, line_end, has_overlap,
511 content_hash, custom_metadata, created_at
512 FROM chunks WHERE id = ?
513 ",
514 params![id],
515 |row| {
516 let line_start: Option<i64> = row.get(8)?;
517 let line_end: Option<i64> = row.get(9)?;
518 let line_range = match (line_start, line_end) {
519 (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
520 _ => None,
521 };
522
523 Ok(Chunk {
524 id: Some(row.get::<_, i64>(0)?),
525 buffer_id: row.get(1)?,
526 content: row.get(2)?,
527 byte_range: (row.get::<_, i64>(3)? as usize)
528 ..(row.get::<_, i64>(4)? as usize),
529 index: row.get::<_, i64>(5)? as usize,
530 metadata: ChunkMetadata {
531 strategy: row.get(6)?,
532 token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
533 line_range,
534 has_overlap: row.get::<_, i64>(10)? != 0,
535 content_hash: row.get(11)?,
536 custom: row.get(12)?,
537 created_at: row.get(13)?,
538 },
539 })
540 },
541 )
542 .optional()
543 .map_err(StorageError::from)?;
544
545 Ok(result)
546 }
547
548 fn delete_chunks(&mut self, buffer_id: i64) -> Result<()> {
549 self.conn
550 .execute("DELETE FROM chunks WHERE buffer_id = ?", params![buffer_id])
551 .map_err(StorageError::from)?;
552
553 self.conn
555 .execute(
556 "UPDATE buffers SET chunk_count = 0 WHERE id = ?",
557 params![buffer_id],
558 )
559 .map_err(StorageError::from)?;
560
561 Ok(())
562 }
563
564 fn chunk_count(&self, buffer_id: i64) -> Result<usize> {
565 let count: i64 = self
566 .conn
567 .query_row(
568 "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
569 params![buffer_id],
570 |row| row.get(0),
571 )
572 .map_err(StorageError::from)?;
573 Ok(count as usize)
574 }
575
576 fn export_buffers(&self) -> Result<String> {
579 let buffers = self.list_buffers()?;
580 let mut output = String::new();
581
582 for (i, buffer) in buffers.iter().enumerate() {
583 if i > 0 {
584 output.push_str("\n\n");
585 }
586 output.push_str(&buffer.content);
587 }
588
589 Ok(output)
590 }
591
592 fn stats(&self) -> Result<StorageStats> {
593 let buffer_count = self.buffer_count()?;
594
595 let chunk_count: i64 = self
596 .conn
597 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
598 .map_err(StorageError::from)?;
599
600 let total_size: i64 = self
601 .conn
602 .query_row("SELECT COALESCE(SUM(size), 0) FROM buffers", [], |row| {
603 row.get(0)
604 })
605 .map_err(StorageError::from)?;
606
607 let has_context = self.load_context()?.is_some();
608
609 let schema_version = self.get_schema_version()?.unwrap_or(0);
610
611 let db_size = self
612 .path
613 .as_ref()
614 .and_then(|p| std::fs::metadata(p).ok().map(|m| m.len()));
615
616 Ok(StorageStats {
617 buffer_count,
618 chunk_count: chunk_count as usize,
619 total_content_size: total_size as usize,
620 has_context,
621 schema_version,
622 db_size,
623 })
624 }
625}
626
627impl SqliteStorage {
630 #[allow(clippy::cast_possible_wrap)]
642 pub fn store_embedding(
643 &mut self,
644 chunk_id: i64,
645 embedding: &[f32],
646 model_name: Option<&str>,
647 ) -> Result<()> {
648 let now = Self::now();
649
650 let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
652
653 self.conn
654 .execute(
655 r"
656 INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
657 VALUES (?, ?, ?, ?, ?)
658 ",
659 params![chunk_id, bytes, embedding.len() as i64, model_name, now],
660 )
661 .map_err(StorageError::from)?;
662
663 Ok(())
664 }
665
666 pub fn get_embedding(&self, chunk_id: i64) -> Result<Option<Vec<f32>>> {
672 let result: Option<Vec<u8>> = self
673 .conn
674 .query_row(
675 "SELECT embedding FROM chunk_embeddings WHERE chunk_id = ?",
676 params![chunk_id],
677 |row| row.get(0),
678 )
679 .optional()
680 .map_err(StorageError::from)?;
681
682 Ok(result.map(|bytes| {
683 bytes
684 .chunks_exact(4)
685 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
686 .collect()
687 }))
688 }
689
690 pub fn get_embedding_models(&self, buffer_id: i64) -> Result<Vec<String>> {
699 let mut stmt = self
700 .conn
701 .prepare(
702 r"
703 SELECT DISTINCT ce.model_name
704 FROM chunk_embeddings ce
705 JOIN chunks c ON ce.chunk_id = c.id
706 WHERE c.buffer_id = ? AND ce.model_name IS NOT NULL
707 ",
708 )
709 .map_err(StorageError::from)?;
710
711 let models = stmt
712 .query_map(params![buffer_id], |row| row.get::<_, String>(0))
713 .map_err(StorageError::from)?
714 .filter_map(std::result::Result::ok)
715 .collect();
716
717 Ok(models)
718 }
719
720 pub fn get_embedding_model_counts(&self, buffer_id: i64) -> Result<Vec<(Option<String>, i64)>> {
728 let mut stmt = self
729 .conn
730 .prepare(
731 r"
732 SELECT ce.model_name, COUNT(*) as count
733 FROM chunk_embeddings ce
734 JOIN chunks c ON ce.chunk_id = c.id
735 WHERE c.buffer_id = ?
736 GROUP BY ce.model_name
737 ",
738 )
739 .map_err(StorageError::from)?;
740
741 let counts = stmt
742 .query_map(params![buffer_id], |row| {
743 Ok((row.get::<_, Option<String>>(0)?, row.get::<_, i64>(1)?))
744 })
745 .map_err(StorageError::from)?
746 .filter_map(std::result::Result::ok)
747 .collect();
748
749 Ok(counts)
750 }
751
752 #[allow(clippy::cast_possible_wrap)]
758 pub fn store_embeddings_batch(
759 &mut self,
760 embeddings: &[(i64, Vec<f32>)],
761 model_name: Option<&str>,
762 ) -> Result<()> {
763 let tx = self.conn.transaction().map_err(StorageError::from)?;
764 let now = Self::now();
765
766 {
767 let mut stmt = tx
768 .prepare(
769 r"
770 INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
771 VALUES (?, ?, ?, ?, ?)
772 ",
773 )
774 .map_err(StorageError::from)?;
775
776 for (chunk_id, embedding) in embeddings {
777 let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
778
779 stmt.execute(params![
780 chunk_id,
781 bytes,
782 embedding.len() as i64,
783 model_name,
784 now
785 ])
786 .map_err(StorageError::from)?;
787 }
788 }
789
790 tx.commit().map_err(StorageError::from)?;
791 Ok(())
792 }
793
794 pub fn delete_embedding(&mut self, chunk_id: i64) -> Result<()> {
800 self.conn
801 .execute(
802 "DELETE FROM chunk_embeddings WHERE chunk_id = ?",
803 params![chunk_id],
804 )
805 .map_err(StorageError::from)?;
806 Ok(())
807 }
808
809 #[allow(clippy::cast_possible_wrap)]
822 pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<(i64, f64)>> {
823 let fts_query = query
830 .split_whitespace()
831 .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
832 .collect::<Vec<_>>()
833 .join(" OR ");
834
835 let mut stmt = self
836 .conn
837 .prepare(
838 r"
839 SELECT rowid, -bm25(chunks_fts) as score
840 FROM chunks_fts
841 WHERE chunks_fts MATCH ?
842 ORDER BY score DESC
843 LIMIT ?
844 ",
845 )
846 .map_err(StorageError::from)?;
847
848 let results = stmt
849 .query_map(params![fts_query, limit as i64], |row| {
850 Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
851 })
852 .map_err(StorageError::from)?
853 .collect::<std::result::Result<Vec<_>, _>>()
854 .map_err(StorageError::from)?;
855
856 Ok(results)
857 }
858
859 pub fn get_all_embeddings(&self) -> Result<Vec<(i64, Vec<f32>)>> {
865 let mut stmt = self
866 .conn
867 .prepare("SELECT chunk_id, embedding FROM chunk_embeddings")
868 .map_err(StorageError::from)?;
869
870 let results = stmt
871 .query_map([], |row| {
872 let chunk_id: i64 = row.get(0)?;
873 let bytes: Vec<u8> = row.get(1)?;
874 let embedding: Vec<f32> = bytes
875 .chunks_exact(4)
876 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
877 .collect();
878 Ok((chunk_id, embedding))
879 })
880 .map_err(StorageError::from)?
881 .collect::<std::result::Result<Vec<_>, _>>()
882 .map_err(StorageError::from)?;
883
884 Ok(results)
885 }
886
887 pub fn embedding_count(&self) -> Result<usize> {
893 let count: i64 = self
894 .conn
895 .query_row("SELECT COUNT(*) FROM chunk_embeddings", [], |row| {
896 row.get(0)
897 })
898 .map_err(StorageError::from)?;
899 Ok(count as usize)
900 }
901
902 pub fn has_embedding(&self, chunk_id: i64) -> Result<bool> {
908 let count: i64 = self
909 .conn
910 .query_row(
911 "SELECT COUNT(*) FROM chunk_embeddings WHERE chunk_id = ?",
912 params![chunk_id],
913 |row| row.get(0),
914 )
915 .map_err(StorageError::from)?;
916 Ok(count > 0)
917 }
918
919 pub fn get_chunks_needing_embedding(
935 &self,
936 buffer_id: i64,
937 current_model: Option<&str>,
938 ) -> Result<Vec<i64>> {
939 let mut results = Vec::new();
940
941 let mut stmt = self
943 .conn
944 .prepare(
945 r"
946 SELECT c.id FROM chunks c
947 LEFT JOIN chunk_embeddings e ON c.id = e.chunk_id
948 WHERE c.buffer_id = ? AND e.chunk_id IS NULL
949 ",
950 )
951 .map_err(StorageError::from)?;
952
953 let rows = stmt
954 .query_map(params![buffer_id], |row| row.get(0))
955 .map_err(StorageError::from)?;
956
957 for row in rows {
958 results.push(row.map_err(StorageError::from)?);
959 }
960
961 if let Some(model) = current_model {
963 let mut stmt = self
964 .conn
965 .prepare(
966 r"
967 SELECT c.id FROM chunks c
968 INNER JOIN chunk_embeddings e ON c.id = e.chunk_id
969 WHERE c.buffer_id = ? AND (e.model_name IS NULL OR e.model_name != ?)
970 ",
971 )
972 .map_err(StorageError::from)?;
973
974 let rows = stmt
975 .query_map(params![buffer_id, model], |row| row.get(0))
976 .map_err(StorageError::from)?;
977
978 for row in rows {
979 results.push(row.map_err(StorageError::from)?);
980 }
981 }
982
983 results.sort_unstable();
985 results.dedup();
986 Ok(results)
987 }
988
989 pub fn get_chunks_without_embedding(&self, buffer_id: i64) -> Result<Vec<i64>> {
997 self.get_chunks_needing_embedding(buffer_id, None)
998 }
999
1000 pub fn delete_embeddings_by_model(
1017 &mut self,
1018 buffer_id: i64,
1019 model_name: Option<&str>,
1020 ) -> Result<usize> {
1021 let deleted = match model_name {
1022 Some(name) => self
1023 .conn
1024 .execute(
1025 r"
1026 DELETE FROM chunk_embeddings
1027 WHERE chunk_id IN (
1028 SELECT id FROM chunks WHERE buffer_id = ?
1029 ) AND model_name = ?
1030 ",
1031 params![buffer_id, name],
1032 )
1033 .map_err(StorageError::from)?,
1034 None => self
1035 .conn
1036 .execute(
1037 r"
1038 DELETE FROM chunk_embeddings
1039 WHERE chunk_id IN (
1040 SELECT id FROM chunks WHERE buffer_id = ?
1041 ) AND model_name IS NULL
1042 ",
1043 params![buffer_id],
1044 )
1045 .map_err(StorageError::from)?,
1046 };
1047 Ok(deleted)
1048 }
1049
1050 pub fn get_embedding_stats(&self, buffer_id: i64) -> Result<EmbeddingStats> {
1058 let total_chunks: i64 = self
1060 .conn
1061 .query_row(
1062 "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
1063 params![buffer_id],
1064 |row| row.get(0),
1065 )
1066 .map_err(StorageError::from)?;
1067
1068 let embedded_chunks: i64 = self
1070 .conn
1071 .query_row(
1072 r"
1073 SELECT COUNT(*) FROM chunk_embeddings e
1074 INNER JOIN chunks c ON e.chunk_id = c.id
1075 WHERE c.buffer_id = ?
1076 ",
1077 params![buffer_id],
1078 |row| row.get(0),
1079 )
1080 .map_err(StorageError::from)?;
1081
1082 let model_counts = self.get_embedding_model_counts(buffer_id)?;
1084
1085 Ok(EmbeddingStats {
1086 total_chunks: total_chunks as usize,
1087 embedded_chunks: embedded_chunks as usize,
1088 model_counts,
1089 })
1090 }
1091}
1092
/// Embedding coverage for the chunks of one buffer, as produced by
/// [`SqliteStorage::get_embedding_stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Number of chunks stored for the buffer.
    pub total_chunks: usize,
    /// Number of embedding rows attached to the buffer's chunks.
    pub embedded_chunks: usize,
    /// `(model_name, count)` pairs, one per distinct model name; `None` groups
    /// embeddings stored without a model name.
    pub model_counts: Vec<(Option<String>, i64)>,
}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ContextValue;

    /// Fresh, initialized in-memory storage.
    fn setup() -> SqliteStorage {
        let mut storage = SqliteStorage::in_memory().unwrap();
        storage.init().unwrap();
        storage
    }

    /// Adds a buffer with two chunks; returns `(buffer_id, first_chunk_id, second_chunk_id)`.
    fn setup_two_chunks(storage: &mut SqliteStorage) -> (i64, i64, i64) {
        let buffer_id = storage
            .add_buffer(&Buffer::from_content("Hello, world!".to_string()))
            .unwrap();
        let chunks = vec![
            Chunk::new(buffer_id, "Hello, ".to_string(), 0..7, 0),
            Chunk::new(buffer_id, "world!".to_string(), 7..13, 1),
        ];
        storage.add_chunks(buffer_id, &chunks).unwrap();
        let loaded = storage.get_chunks(buffer_id).unwrap();
        (buffer_id, loaded[0].id.unwrap(), loaded[1].id.unwrap())
    }

    #[test]
    fn test_init() {
        let mut storage = SqliteStorage::in_memory().unwrap();
        assert!(storage.init().is_ok());
        assert!(storage.is_initialized().unwrap());
    }

    #[test]
    fn test_init_idempotent() {
        let mut storage = SqliteStorage::in_memory().unwrap();
        assert!(storage.init().is_ok());
        assert!(storage.init().is_ok());
    }

    #[test]
    fn test_context_crud() {
        let mut storage = setup();

        assert!(storage.load_context().unwrap().is_none());

        let mut ctx = Context::new();
        ctx.set_variable("key".to_string(), ContextValue::String("value".to_string()));
        storage.save_context(&ctx).unwrap();

        let loaded = storage.load_context().unwrap().unwrap();
        assert_eq!(
            loaded.get_variable("key"),
            Some(&ContextValue::String("value".to_string()))
        );

        storage.delete_context().unwrap();
        assert!(storage.load_context().unwrap().is_none());
    }

    #[test]
    fn test_buffer_crud() {
        let mut storage = setup();

        let buffer = Buffer::from_named("test".to_string(), "Hello, world!".to_string());
        let id = storage.add_buffer(&buffer).unwrap();
        assert!(id > 0);

        let loaded = storage.get_buffer(id).unwrap().unwrap();
        assert_eq!(loaded.name, Some("test".to_string()));
        assert_eq!(loaded.content, "Hello, world!");

        let by_name = storage.get_buffer_by_name("test").unwrap().unwrap();
        assert_eq!(by_name.id, Some(id));
        assert!(storage.get_buffer_by_name("missing").unwrap().is_none());

        let buffers = storage.list_buffers().unwrap();
        assert_eq!(buffers.len(), 1);

        let mut updated = loaded;
        updated.content = "Updated content".to_string();
        storage.update_buffer(&updated).unwrap();

        let reloaded = storage.get_buffer(id).unwrap().unwrap();
        assert_eq!(reloaded.content, "Updated content");

        storage.delete_buffer(id).unwrap();
        assert!(storage.get_buffer(id).unwrap().is_none());
    }

    #[test]
    fn test_chunk_crud() {
        let mut storage = setup();

        let buffer = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = storage.add_buffer(&buffer).unwrap();

        let chunks = vec![
            Chunk::new(buffer_id, "Hello, ".to_string(), 0..7, 0),
            Chunk::new(buffer_id, "world!".to_string(), 7..13, 1),
        ];
        storage.add_chunks(buffer_id, &chunks).unwrap();

        let loaded = storage.get_chunks(buffer_id).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].content, "Hello, ");
        assert_eq!(loaded[1].content, "world!");

        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 2);
        // The cached count on the buffer row tracks the stored chunks.
        let cached = storage.get_buffer(buffer_id).unwrap().unwrap();
        assert_eq!(cached.metadata.chunk_count, Some(2));

        let chunk_id = loaded[0].id.unwrap();
        let single = storage.get_chunk(chunk_id).unwrap().unwrap();
        assert_eq!(single.content, "Hello, ");

        storage.delete_chunks(buffer_id).unwrap();
        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 0);
        let cached = storage.get_buffer(buffer_id).unwrap().unwrap();
        assert_eq!(cached.metadata.chunk_count, Some(0));
    }

    #[test]
    fn test_cascade_delete() {
        let mut storage = setup();

        let buffer = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = storage.add_buffer(&buffer).unwrap();

        let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
        storage.add_chunks(buffer_id, &chunks).unwrap();

        assert_eq!(storage.chunk_count(buffer_id).unwrap(), 1);

        storage.delete_buffer(buffer_id).unwrap();

        let count: i64 = storage
            .conn
            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_reset() {
        let mut storage = setup();

        let ctx = Context::new();
        storage.save_context(&ctx).unwrap();

        let buffer = Buffer::from_content("test".to_string());
        storage.add_buffer(&buffer).unwrap();

        storage.reset().unwrap();

        assert!(storage.load_context().unwrap().is_none());
        assert_eq!(storage.buffer_count().unwrap(), 0);
    }

    #[test]
    fn test_stats() {
        let mut storage = setup();

        let stats = storage.stats().unwrap();
        assert_eq!(stats.buffer_count, 0);
        assert_eq!(stats.chunk_count, 0);
        assert!(!stats.has_context);

        let ctx = Context::new();
        storage.save_context(&ctx).unwrap();

        let buffer = Buffer::from_content("Hello, world!".to_string());
        let buffer_id = storage.add_buffer(&buffer).unwrap();

        let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
        storage.add_chunks(buffer_id, &chunks).unwrap();

        let stats = storage.stats().unwrap();
        assert_eq!(stats.buffer_count, 1);
        assert_eq!(stats.chunk_count, 1);
        assert!(stats.has_context);
        assert_eq!(stats.total_content_size, 13);
    }

    #[test]
    fn test_export_buffers() {
        let mut storage = setup();

        storage
            .add_buffer(&Buffer::from_content("First".to_string()))
            .unwrap();
        storage
            .add_buffer(&Buffer::from_content("Second".to_string()))
            .unwrap();

        let exported = storage.export_buffers().unwrap();
        assert_eq!(exported, "First\n\nSecond");
    }

    #[test]
    fn test_embedding_roundtrip() {
        let mut storage = setup();
        let (buffer_id, first, second) = setup_two_chunks(&mut storage);

        let vector = vec![1.0_f32, -2.5, 0.125];
        storage
            .store_embedding(first, &vector, Some("model-a"))
            .unwrap();

        assert_eq!(storage.get_embedding(first).unwrap(), Some(vector.clone()));
        assert_eq!(storage.get_embedding(second).unwrap(), None);
        assert!(storage.has_embedding(first).unwrap());
        assert!(!storage.has_embedding(second).unwrap());
        assert_eq!(storage.embedding_count().unwrap(), 1);
        assert_eq!(storage.get_all_embeddings().unwrap(), vec![(first, vector)]);

        assert_eq!(
            storage.get_chunks_without_embedding(buffer_id).unwrap(),
            vec![second]
        );
        assert_eq!(
            storage
                .get_chunks_needing_embedding(buffer_id, Some("model-a"))
                .unwrap(),
            vec![second]
        );
        // Under a different current model, the already-embedded chunk is stale too.
        assert_eq!(
            storage
                .get_chunks_needing_embedding(buffer_id, Some("model-b"))
                .unwrap(),
            vec![first, second]
        );
        assert_eq!(
            storage.get_embedding_models(buffer_id).unwrap(),
            vec!["model-a".to_string()]
        );
    }

    #[test]
    fn test_embedding_batch_and_delete_by_model() {
        let mut storage = setup();
        let (buffer_id, first, second) = setup_two_chunks(&mut storage);

        storage
            .store_embedding(first, &[0.5, 0.25], Some("model-a"))
            .unwrap();
        assert_eq!(
            storage
                .delete_embeddings_by_model(buffer_id, Some("model-a"))
                .unwrap(),
            1
        );
        assert_eq!(storage.embedding_count().unwrap(), 0);

        storage
            .store_embeddings_batch(&[(first, vec![1.0, 2.0]), (second, vec![3.0, 4.0])], None)
            .unwrap();
        assert_eq!(storage.embedding_count().unwrap(), 2);
        assert_eq!(
            storage.get_embedding_model_counts(buffer_id).unwrap(),
            vec![(None, 2)]
        );
        assert!(storage.get_embedding_models(buffer_id).unwrap().is_empty());

        assert_eq!(
            storage.delete_embeddings_by_model(buffer_id, None).unwrap(),
            2
        );
        assert_eq!(storage.embedding_count().unwrap(), 0);
    }

    #[test]
    fn test_embedding_stats() {
        let mut storage = setup();
        let (buffer_id, first, _second) = setup_two_chunks(&mut storage);

        storage
            .store_embedding(first, &[1.0, 2.0, 3.0], Some("model-a"))
            .unwrap();

        let stats = storage.get_embedding_stats(buffer_id).unwrap();
        assert_eq!(stats.total_chunks, 2);
        assert_eq!(stats.embedded_chunks, 1);
        assert_eq!(
            stats.model_counts,
            vec![(Some("model-a".to_string()), 1)]
        );
    }
}