1#![allow(clippy::cast_possible_truncation)]
9#![allow(clippy::cast_sign_loss)]
10
11use crate::core::{Buffer, BufferMetadata, Chunk, ChunkMetadata, Context};
12use crate::error::{Result, StorageError};
13use crate::storage::schema::{
14 CHECK_SCHEMA_SQL, CURRENT_SCHEMA_VERSION, GET_VERSION_SQL, SCHEMA_SQL, SET_VERSION_SQL,
15};
16use crate::storage::traits::{Storage, StorageStats};
17use rusqlite::{Connection, OptionalExtension, params};
18use std::path::{Path, PathBuf};
19
20pub struct SqliteStorage {
33 conn: Connection,
35 path: Option<PathBuf>,
37}
38
39impl SqliteStorage {
40 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
50 let path = path.as_ref().to_path_buf();
51
52 if let Some(parent) = path.parent()
54 && !parent.exists()
55 {
56 std::fs::create_dir_all(parent).map_err(|e| StorageError::Database(e.to_string()))?;
57 }
58
59 let conn = Connection::open(&path).map_err(StorageError::from)?;
60
61 conn.execute("PRAGMA foreign_keys = ON;", [])
63 .map_err(StorageError::from)?;
64
65 let _: String = conn
67 .query_row("PRAGMA journal_mode = WAL;", [], |row| row.get(0))
68 .map_err(StorageError::from)?;
69
70 Ok(Self {
71 conn,
72 path: Some(path),
73 })
74 }
75
76 pub fn in_memory() -> Result<Self> {
84 let conn = Connection::open_in_memory().map_err(StorageError::from)?;
85 conn.execute("PRAGMA foreign_keys = ON;", [])
86 .map_err(StorageError::from)?;
87
88 Ok(Self { conn, path: None })
89 }
90
91 #[must_use]
93 pub fn path(&self) -> Option<&Path> {
94 self.path.as_deref()
95 }
96
97 fn get_schema_version(&self) -> Result<Option<u32>> {
99 let version: Option<String> = self
100 .conn
101 .query_row(GET_VERSION_SQL, [], |row| row.get(0))
102 .optional()
103 .map_err(StorageError::from)?;
104
105 Ok(version.and_then(|v| v.parse().ok()))
106 }
107
108 fn set_schema_version(&self, version: u32) -> Result<()> {
110 self.conn
111 .execute(SET_VERSION_SQL, params![version.to_string()])
112 .map_err(StorageError::from)?;
113 Ok(())
114 }
115
116 #[allow(clippy::cast_possible_wrap)]
118 fn now() -> i64 {
119 std::time::SystemTime::now()
120 .duration_since(std::time::UNIX_EPOCH)
121 .map_or(0, |d| d.as_secs() as i64)
122 }
123}
124
125impl Storage for SqliteStorage {
126 fn init(&mut self) -> Result<()> {
127 let is_init: i64 = self
129 .conn
130 .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
131 .map_err(StorageError::from)?;
132
133 if is_init == 0 {
134 self.conn
136 .execute_batch(SCHEMA_SQL)
137 .map_err(StorageError::from)?;
138 self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
139 } else if let Some(current) = self.get_schema_version()?
140 && current < CURRENT_SCHEMA_VERSION
141 {
142 let migrations = crate::storage::schema::get_migrations_from(current);
144 for migration in migrations {
145 self.conn
146 .execute_batch(migration.sql)
147 .map_err(|e| StorageError::Migration(e.to_string()))?;
148 }
149 self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
150 }
151
152 Ok(())
153 }
154
155 fn is_initialized(&self) -> Result<bool> {
156 let count: i64 = self
157 .conn
158 .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
159 .map_err(StorageError::from)?;
160 Ok(count > 0)
161 }
162
163 fn reset(&mut self) -> Result<()> {
164 self.conn
165 .execute_batch(
166 r"
167 DELETE FROM chunk_embeddings;
168 DELETE FROM chunks;
169 DELETE FROM buffers;
170 DELETE FROM context;
171 DELETE FROM metadata;
172 ",
173 )
174 .map_err(StorageError::from)?;
175 Ok(())
176 }
177
178 fn save_context(&mut self, context: &Context) -> Result<()> {
181 let data = serde_json::to_string(context).map_err(StorageError::from)?;
182 let now = Self::now();
183
184 self.conn
185 .execute(
186 r"
187 INSERT OR REPLACE INTO context (id, data, created_at, updated_at)
188 VALUES (1, ?, COALESCE((SELECT created_at FROM context WHERE id = 1), ?), ?)
189 ",
190 params![data, now, now],
191 )
192 .map_err(StorageError::from)?;
193
194 Ok(())
195 }
196
197 fn load_context(&self) -> Result<Option<Context>> {
198 let data: Option<String> = self
199 .conn
200 .query_row("SELECT data FROM context WHERE id = 1", [], |row| {
201 row.get(0)
202 })
203 .optional()
204 .map_err(StorageError::from)?;
205
206 match data {
207 Some(json) => {
208 let context = serde_json::from_str(&json).map_err(StorageError::from)?;
209 Ok(Some(context))
210 }
211 None => Ok(None),
212 }
213 }
214
215 fn delete_context(&mut self) -> Result<()> {
216 self.conn
217 .execute("DELETE FROM context WHERE id = 1", [])
218 .map_err(StorageError::from)?;
219 Ok(())
220 }
221
222 #[allow(clippy::cast_possible_wrap)]
225 fn add_buffer(&mut self, buffer: &Buffer) -> Result<i64> {
226 let now = Self::now();
227
228 self.conn
229 .execute(
230 r"
231 INSERT INTO buffers (
232 name, source_path, content, content_type, content_hash,
233 size, line_count, chunk_count, created_at, updated_at
234 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
235 ",
236 params![
237 buffer.name,
238 buffer
239 .source
240 .as_ref()
241 .map(|p| p.to_string_lossy().to_string()),
242 buffer.content,
243 buffer.metadata.content_type,
244 buffer.metadata.content_hash,
245 buffer.metadata.size as i64,
246 buffer.metadata.line_count.map(|c| c as i64),
247 buffer.metadata.chunk_count.map(|c| c as i64),
248 now,
249 now,
250 ],
251 )
252 .map_err(StorageError::from)?;
253
254 Ok(self.conn.last_insert_rowid())
255 }
256
257 fn get_buffer(&self, id: i64) -> Result<Option<Buffer>> {
258 let result = self
259 .conn
260 .query_row(
261 r"
262 SELECT id, name, source_path, content, content_type, content_hash,
263 size, line_count, chunk_count, created_at, updated_at
264 FROM buffers WHERE id = ?
265 ",
266 params![id],
267 |row| {
268 Ok(Buffer {
269 id: Some(row.get::<_, i64>(0)?),
270 name: row.get(1)?,
271 source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
272 content: row.get(3)?,
273 metadata: BufferMetadata {
274 content_type: row.get(4)?,
275 content_hash: row.get(5)?,
276 size: row.get::<_, i64>(6)? as usize,
277 line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
278 chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
279 created_at: row.get(9)?,
280 updated_at: row.get(10)?,
281 },
282 })
283 },
284 )
285 .optional()
286 .map_err(StorageError::from)?;
287
288 Ok(result)
289 }
290
291 fn get_buffer_by_name(&self, name: &str) -> Result<Option<Buffer>> {
292 let id: Option<i64> = self
293 .conn
294 .query_row(
295 "SELECT id FROM buffers WHERE name = ?",
296 params![name],
297 |row| row.get(0),
298 )
299 .optional()
300 .map_err(StorageError::from)?;
301
302 id.map_or(Ok(None), |id| self.get_buffer(id))
303 }
304
305 fn list_buffers(&self) -> Result<Vec<Buffer>> {
306 let mut stmt = self
307 .conn
308 .prepare(
309 r"
310 SELECT id, name, source_path, content, content_type, content_hash,
311 size, line_count, chunk_count, created_at, updated_at
312 FROM buffers ORDER BY id
313 ",
314 )
315 .map_err(StorageError::from)?;
316
317 let buffers = stmt
318 .query_map([], |row| {
319 Ok(Buffer {
320 id: Some(row.get::<_, i64>(0)?),
321 name: row.get(1)?,
322 source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
323 content: row.get(3)?,
324 metadata: BufferMetadata {
325 content_type: row.get(4)?,
326 content_hash: row.get(5)?,
327 size: row.get::<_, i64>(6)? as usize,
328 line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
329 chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
330 created_at: row.get(9)?,
331 updated_at: row.get(10)?,
332 },
333 })
334 })
335 .map_err(StorageError::from)?
336 .collect::<std::result::Result<Vec<_>, _>>()
337 .map_err(StorageError::from)?;
338
339 Ok(buffers)
340 }
341
342 #[allow(clippy::cast_possible_wrap)]
343 fn update_buffer(&mut self, buffer: &Buffer) -> Result<()> {
344 let id = buffer.id.ok_or_else(|| StorageError::BufferNotFound {
345 identifier: "no ID".to_string(),
346 })?;
347
348 let now = Self::now();
349
350 self.conn
351 .execute(
352 r"
353 UPDATE buffers SET
354 name = ?, source_path = ?, content = ?, content_type = ?,
355 content_hash = ?, size = ?, line_count = ?, chunk_count = ?,
356 updated_at = ?
357 WHERE id = ?
358 ",
359 params![
360 buffer.name,
361 buffer
362 .source
363 .as_ref()
364 .map(|p| p.to_string_lossy().to_string()),
365 buffer.content,
366 buffer.metadata.content_type,
367 buffer.metadata.content_hash,
368 buffer.metadata.size as i64,
369 buffer.metadata.line_count.map(|c| c as i64),
370 buffer.metadata.chunk_count.map(|c| c as i64),
371 now,
372 id,
373 ],
374 )
375 .map_err(StorageError::from)?;
376
377 Ok(())
378 }
379
380 fn delete_buffer(&mut self, id: i64) -> Result<()> {
381 self.conn
383 .execute("DELETE FROM buffers WHERE id = ?", params![id])
384 .map_err(StorageError::from)?;
385 Ok(())
386 }
387
388 fn buffer_count(&self) -> Result<usize> {
389 let count: i64 = self
390 .conn
391 .query_row("SELECT COUNT(*) FROM buffers", [], |row| row.get(0))
392 .map_err(StorageError::from)?;
393 Ok(count as usize)
394 }
395
396 #[allow(clippy::cast_possible_wrap)]
399 fn add_chunks(&mut self, buffer_id: i64, chunks: &[Chunk]) -> Result<()> {
400 let tx = self.conn.transaction().map_err(StorageError::from)?;
401 let now = Self::now();
402
403 {
404 let mut stmt = tx
405 .prepare(
406 r"
407 INSERT INTO chunks (
408 buffer_id, content, byte_start, byte_end, chunk_index,
409 strategy, token_count, line_start, line_end, has_overlap,
410 content_hash, custom_metadata, created_at
411 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
412 ",
413 )
414 .map_err(StorageError::from)?;
415
416 for chunk in chunks {
417 let custom_meta = chunk.metadata.custom.clone();
418
419 let (line_start, line_end) = chunk
420 .metadata
421 .line_range
422 .as_ref()
423 .map_or((None, None), |r| (Some(r.start as i64), Some(r.end as i64)));
424
425 stmt.execute(params![
426 buffer_id,
427 chunk.content,
428 chunk.byte_range.start as i64,
429 chunk.byte_range.end as i64,
430 chunk.index as i64,
431 chunk.metadata.strategy,
432 chunk.metadata.token_count.map(|c| c as i64),
433 line_start,
434 line_end,
435 i64::from(chunk.metadata.has_overlap),
436 chunk.metadata.content_hash,
437 custom_meta,
438 now,
439 ])
440 .map_err(StorageError::from)?;
441 }
442 }
443
444 tx.commit().map_err(StorageError::from)?;
445
446 self.conn
448 .execute(
449 "UPDATE buffers SET chunk_count = ? WHERE id = ?",
450 params![chunks.len() as i64, buffer_id],
451 )
452 .map_err(StorageError::from)?;
453
454 Ok(())
455 }
456
457 fn get_chunks(&self, buffer_id: i64) -> Result<Vec<Chunk>> {
458 let mut stmt = self
459 .conn
460 .prepare(
461 r"
462 SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
463 strategy, token_count, line_start, line_end, has_overlap,
464 content_hash, custom_metadata, created_at
465 FROM chunks WHERE buffer_id = ? ORDER BY chunk_index
466 ",
467 )
468 .map_err(StorageError::from)?;
469
470 let chunks = stmt
471 .query_map(params![buffer_id], |row| {
472 let line_start: Option<i64> = row.get(8)?;
473 let line_end: Option<i64> = row.get(9)?;
474 let line_range = match (line_start, line_end) {
475 (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
476 _ => None,
477 };
478
479 Ok(Chunk {
480 id: Some(row.get::<_, i64>(0)?),
481 buffer_id: row.get(1)?,
482 content: row.get(2)?,
483 byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
484 index: row.get::<_, i64>(5)? as usize,
485 metadata: ChunkMetadata {
486 strategy: row.get(6)?,
487 token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
488 line_range,
489 has_overlap: row.get::<_, i64>(10)? != 0,
490 content_hash: row.get(11)?,
491 custom: row.get(12)?,
492 created_at: row.get(13)?,
493 },
494 })
495 })
496 .map_err(StorageError::from)?
497 .collect::<std::result::Result<Vec<_>, _>>()
498 .map_err(StorageError::from)?;
499
500 Ok(chunks)
501 }
502
503 fn get_chunk(&self, id: i64) -> Result<Option<Chunk>> {
504 let result = self
505 .conn
506 .query_row(
507 r"
508 SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
509 strategy, token_count, line_start, line_end, has_overlap,
510 content_hash, custom_metadata, created_at
511 FROM chunks WHERE id = ?
512 ",
513 params![id],
514 |row| {
515 let line_start: Option<i64> = row.get(8)?;
516 let line_end: Option<i64> = row.get(9)?;
517 let line_range = match (line_start, line_end) {
518 (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
519 _ => None,
520 };
521
522 Ok(Chunk {
523 id: Some(row.get::<_, i64>(0)?),
524 buffer_id: row.get(1)?,
525 content: row.get(2)?,
526 byte_range: (row.get::<_, i64>(3)? as usize)
527 ..(row.get::<_, i64>(4)? as usize),
528 index: row.get::<_, i64>(5)? as usize,
529 metadata: ChunkMetadata {
530 strategy: row.get(6)?,
531 token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
532 line_range,
533 has_overlap: row.get::<_, i64>(10)? != 0,
534 content_hash: row.get(11)?,
535 custom: row.get(12)?,
536 created_at: row.get(13)?,
537 },
538 })
539 },
540 )
541 .optional()
542 .map_err(StorageError::from)?;
543
544 Ok(result)
545 }
546
547 fn delete_chunks(&mut self, buffer_id: i64) -> Result<()> {
548 self.conn
549 .execute("DELETE FROM chunks WHERE buffer_id = ?", params![buffer_id])
550 .map_err(StorageError::from)?;
551
552 self.conn
554 .execute(
555 "UPDATE buffers SET chunk_count = 0 WHERE id = ?",
556 params![buffer_id],
557 )
558 .map_err(StorageError::from)?;
559
560 Ok(())
561 }
562
563 fn chunk_count(&self, buffer_id: i64) -> Result<usize> {
564 let count: i64 = self
565 .conn
566 .query_row(
567 "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
568 params![buffer_id],
569 |row| row.get(0),
570 )
571 .map_err(StorageError::from)?;
572 Ok(count as usize)
573 }
574
575 fn export_buffers(&self) -> Result<String> {
578 let buffers = self.list_buffers()?;
579 let mut output = String::new();
580
581 for (i, buffer) in buffers.iter().enumerate() {
582 if i > 0 {
583 output.push_str("\n\n");
584 }
585 output.push_str(&buffer.content);
586 }
587
588 Ok(output)
589 }
590
591 fn stats(&self) -> Result<StorageStats> {
592 let buffer_count = self.buffer_count()?;
593
594 let chunk_count: i64 = self
595 .conn
596 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
597 .map_err(StorageError::from)?;
598
599 let total_size: i64 = self
600 .conn
601 .query_row("SELECT COALESCE(SUM(size), 0) FROM buffers", [], |row| {
602 row.get(0)
603 })
604 .map_err(StorageError::from)?;
605
606 let has_context = self.load_context()?.is_some();
607
608 let schema_version = self.get_schema_version()?.unwrap_or(0);
609
610 let db_size = self
611 .path
612 .as_ref()
613 .and_then(|p| std::fs::metadata(p).ok().map(|m| m.len()));
614
615 Ok(StorageStats {
616 buffer_count,
617 chunk_count: chunk_count as usize,
618 total_content_size: total_size as usize,
619 has_context,
620 schema_version,
621 db_size,
622 })
623 }
624}
625
626impl SqliteStorage {
629 pub fn get_chunks_by_ids(&self, ids: &[i64]) -> Result<std::collections::HashMap<i64, Chunk>> {
640 if ids.is_empty() {
641 return Ok(std::collections::HashMap::new());
642 }
643
644 let placeholders = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
645 let sql = format!(
646 "SELECT id, buffer_id, content, byte_start, byte_end, chunk_index, \
647 strategy, token_count, line_start, line_end, has_overlap, \
648 content_hash, custom_metadata, created_at \
649 FROM chunks WHERE id IN ({placeholders})"
650 );
651
652 let mut stmt = self.conn.prepare(&sql).map_err(StorageError::from)?;
653
654 let chunks = stmt
655 .query_map(rusqlite::params_from_iter(ids.iter().copied()), |row| {
656 let line_start: Option<i64> = row.get(8)?;
657 let line_end: Option<i64> = row.get(9)?;
658 let line_range = match (line_start, line_end) {
659 (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
660 _ => None,
661 };
662
663 Ok(Chunk {
664 id: Some(row.get::<_, i64>(0)?),
665 buffer_id: row.get(1)?,
666 content: row.get(2)?,
667 byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
668 index: row.get::<_, i64>(5)? as usize,
669 metadata: ChunkMetadata {
670 strategy: row.get(6)?,
671 token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
672 line_range,
673 has_overlap: row.get::<_, i64>(10)? != 0,
674 content_hash: row.get(11)?,
675 custom: row.get(12)?,
676 created_at: row.get(13)?,
677 },
678 })
679 })
680 .map_err(StorageError::from)?
681 .collect::<std::result::Result<Vec<_>, _>>()
682 .map_err(StorageError::from)?;
683
684 Ok(chunks
685 .into_iter()
686 .filter_map(|c| c.id.map(|id| (id, c)))
687 .collect())
688 }
689
690 #[allow(clippy::cast_possible_wrap)]
702 pub fn store_embedding(
703 &mut self,
704 chunk_id: i64,
705 embedding: &[f32],
706 model_name: Option<&str>,
707 ) -> Result<()> {
708 let now = Self::now();
709
710 let mut bytes = Vec::with_capacity(embedding.len() * 4);
712 for f in embedding {
713 bytes.extend_from_slice(&f.to_le_bytes());
714 }
715
716 self.conn
717 .execute(
718 r"
719 INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
720 VALUES (?, ?, ?, ?, ?)
721 ",
722 params![chunk_id, bytes, embedding.len() as i64, model_name, now],
723 )
724 .map_err(StorageError::from)?;
725
726 Ok(())
727 }
728
729 pub fn get_embedding(&self, chunk_id: i64) -> Result<Option<Vec<f32>>> {
735 let result: Option<Vec<u8>> = self
736 .conn
737 .query_row(
738 "SELECT embedding FROM chunk_embeddings WHERE chunk_id = ?",
739 params![chunk_id],
740 |row| row.get(0),
741 )
742 .optional()
743 .map_err(StorageError::from)?;
744
745 Ok(result.map(|bytes| {
746 bytes
747 .chunks_exact(4)
748 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
749 .collect()
750 }))
751 }
752
753 pub fn get_embedding_models(&self, buffer_id: i64) -> Result<Vec<String>> {
762 let mut stmt = self
763 .conn
764 .prepare(
765 r"
766 SELECT DISTINCT ce.model_name
767 FROM chunk_embeddings ce
768 JOIN chunks c ON ce.chunk_id = c.id
769 WHERE c.buffer_id = ? AND ce.model_name IS NOT NULL
770 ",
771 )
772 .map_err(StorageError::from)?;
773
774 let models = stmt
775 .query_map(params![buffer_id], |row| row.get::<_, String>(0))
776 .map_err(StorageError::from)?
777 .filter_map(std::result::Result::ok)
778 .collect();
779
780 Ok(models)
781 }
782
783 pub fn get_embedding_model_counts(&self, buffer_id: i64) -> Result<Vec<(Option<String>, i64)>> {
791 let mut stmt = self
792 .conn
793 .prepare(
794 r"
795 SELECT ce.model_name, COUNT(*) as count
796 FROM chunk_embeddings ce
797 JOIN chunks c ON ce.chunk_id = c.id
798 WHERE c.buffer_id = ?
799 GROUP BY ce.model_name
800 ",
801 )
802 .map_err(StorageError::from)?;
803
804 let counts = stmt
805 .query_map(params![buffer_id], |row| {
806 Ok((row.get::<_, Option<String>>(0)?, row.get::<_, i64>(1)?))
807 })
808 .map_err(StorageError::from)?
809 .filter_map(std::result::Result::ok)
810 .collect();
811
812 Ok(counts)
813 }
814
815 #[allow(clippy::cast_possible_wrap)]
821 pub fn store_embeddings_batch(
822 &mut self,
823 embeddings: &[(i64, Vec<f32>)],
824 model_name: Option<&str>,
825 ) -> Result<()> {
826 let tx = self.conn.transaction().map_err(StorageError::from)?;
827 let now = Self::now();
828
829 {
830 let mut stmt = tx
831 .prepare(
832 r"
833 INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
834 VALUES (?, ?, ?, ?, ?)
835 ",
836 )
837 .map_err(StorageError::from)?;
838
839 for (chunk_id, embedding) in embeddings {
840 let mut bytes = Vec::with_capacity(embedding.len() * 4);
841 bytes.extend(embedding.iter().flat_map(|f| f.to_le_bytes()));
842
843 stmt.execute(params![
844 chunk_id,
845 bytes,
846 embedding.len() as i64,
847 model_name,
848 now
849 ])
850 .map_err(StorageError::from)?;
851 }
852 }
853
854 tx.commit().map_err(StorageError::from)?;
855 Ok(())
856 }
857
858 pub fn delete_embedding(&mut self, chunk_id: i64) -> Result<()> {
864 self.conn
865 .execute(
866 "DELETE FROM chunk_embeddings WHERE chunk_id = ?",
867 params![chunk_id],
868 )
869 .map_err(StorageError::from)?;
870 Ok(())
871 }
872
873 #[allow(clippy::cast_possible_wrap)]
886 pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<(i64, f64)>> {
887 let fts_query = query
894 .split_whitespace()
895 .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
896 .collect::<Vec<_>>()
897 .join(" OR ");
898
899 let mut stmt = self
900 .conn
901 .prepare(
902 r"
903 SELECT rowid, -bm25(chunks_fts) as score
904 FROM chunks_fts
905 WHERE chunks_fts MATCH ?
906 ORDER BY score DESC
907 LIMIT ?
908 ",
909 )
910 .map_err(StorageError::from)?;
911
912 let results = stmt
913 .query_map(params![fts_query, limit as i64], |row| {
914 Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
915 })
916 .map_err(StorageError::from)?
917 .collect::<std::result::Result<Vec<_>, _>>()
918 .map_err(StorageError::from)?;
919
920 Ok(results)
921 }
922
923 pub fn get_all_embeddings(&self) -> Result<Vec<(i64, Vec<f32>)>> {
929 let mut stmt = self
930 .conn
931 .prepare("SELECT chunk_id, embedding FROM chunk_embeddings")
932 .map_err(StorageError::from)?;
933
934 let results = stmt
935 .query_map([], |row| {
936 let chunk_id: i64 = row.get(0)?;
937 let bytes: Vec<u8> = row.get(1)?;
938 let embedding: Vec<f32> = bytes
939 .chunks_exact(4)
940 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
941 .collect();
942 Ok((chunk_id, embedding))
943 })
944 .map_err(StorageError::from)?
945 .collect::<std::result::Result<Vec<_>, _>>()
946 .map_err(StorageError::from)?;
947
948 Ok(results)
949 }
950
951 pub fn embedding_count(&self) -> Result<usize> {
957 let count: i64 = self
958 .conn
959 .query_row("SELECT COUNT(*) FROM chunk_embeddings", [], |row| {
960 row.get(0)
961 })
962 .map_err(StorageError::from)?;
963 Ok(count as usize)
964 }
965
966 pub fn all_chunks_have_embeddings(&self, buffer_id: i64) -> Result<bool> {
975 let result: i64 = self
976 .conn
977 .query_row(
978 r"
979 SELECT NOT EXISTS (
980 SELECT 1 FROM chunks c
981 LEFT JOIN chunk_embeddings e ON e.chunk_id = c.id
982 WHERE c.buffer_id = ? AND e.chunk_id IS NULL
983 )
984 ",
985 params![buffer_id],
986 |row| row.get(0),
987 )
988 .map_err(StorageError::from)?;
989 Ok(result != 0)
990 }
991
992 pub fn has_embedding(&self, chunk_id: i64) -> Result<bool> {
998 let count: i64 = self
999 .conn
1000 .query_row(
1001 "SELECT COUNT(*) FROM chunk_embeddings WHERE chunk_id = ?",
1002 params![chunk_id],
1003 |row| row.get(0),
1004 )
1005 .map_err(StorageError::from)?;
1006 Ok(count > 0)
1007 }
1008
1009 pub fn get_chunks_needing_embedding(
1025 &self,
1026 buffer_id: i64,
1027 current_model: Option<&str>,
1028 ) -> Result<Vec<i64>> {
1029 let mut results = Vec::new();
1030
1031 let mut stmt = self
1033 .conn
1034 .prepare(
1035 r"
1036 SELECT c.id FROM chunks c
1037 LEFT JOIN chunk_embeddings e ON c.id = e.chunk_id
1038 WHERE c.buffer_id = ? AND e.chunk_id IS NULL
1039 ",
1040 )
1041 .map_err(StorageError::from)?;
1042
1043 let rows = stmt
1044 .query_map(params![buffer_id], |row| row.get(0))
1045 .map_err(StorageError::from)?;
1046
1047 for row in rows {
1048 results.push(row.map_err(StorageError::from)?);
1049 }
1050
1051 if let Some(model) = current_model {
1053 let mut stmt = self
1054 .conn
1055 .prepare(
1056 r"
1057 SELECT c.id FROM chunks c
1058 INNER JOIN chunk_embeddings e ON c.id = e.chunk_id
1059 WHERE c.buffer_id = ? AND (e.model_name IS NULL OR e.model_name != ?)
1060 ",
1061 )
1062 .map_err(StorageError::from)?;
1063
1064 let rows = stmt
1065 .query_map(params![buffer_id, model], |row| row.get(0))
1066 .map_err(StorageError::from)?;
1067
1068 for row in rows {
1069 results.push(row.map_err(StorageError::from)?);
1070 }
1071 }
1072
1073 results.sort_unstable();
1075 results.dedup();
1076 Ok(results)
1077 }
1078
1079 pub fn get_chunks_without_embedding(&self, buffer_id: i64) -> Result<Vec<i64>> {
1087 self.get_chunks_needing_embedding(buffer_id, None)
1088 }
1089
1090 pub fn delete_embeddings_by_model(
1107 &mut self,
1108 buffer_id: i64,
1109 model_name: Option<&str>,
1110 ) -> Result<usize> {
1111 let deleted = match model_name {
1112 Some(name) => self
1113 .conn
1114 .execute(
1115 r"
1116 DELETE FROM chunk_embeddings
1117 WHERE chunk_id IN (
1118 SELECT id FROM chunks WHERE buffer_id = ?
1119 ) AND model_name = ?
1120 ",
1121 params![buffer_id, name],
1122 )
1123 .map_err(StorageError::from)?,
1124 None => self
1125 .conn
1126 .execute(
1127 r"
1128 DELETE FROM chunk_embeddings
1129 WHERE chunk_id IN (
1130 SELECT id FROM chunks WHERE buffer_id = ?
1131 ) AND model_name IS NULL
1132 ",
1133 params![buffer_id],
1134 )
1135 .map_err(StorageError::from)?,
1136 };
1137 Ok(deleted)
1138 }
1139
1140 pub fn get_embedding_stats(&self, buffer_id: i64) -> Result<EmbeddingStats> {
1148 let total_chunks: i64 = self
1150 .conn
1151 .query_row(
1152 "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
1153 params![buffer_id],
1154 |row| row.get(0),
1155 )
1156 .map_err(StorageError::from)?;
1157
1158 let embedded_chunks: i64 = self
1160 .conn
1161 .query_row(
1162 r"
1163 SELECT COUNT(*) FROM chunk_embeddings e
1164 INNER JOIN chunks c ON e.chunk_id = c.id
1165 WHERE c.buffer_id = ?
1166 ",
1167 params![buffer_id],
1168 |row| row.get(0),
1169 )
1170 .map_err(StorageError::from)?;
1171
1172 let model_counts = self.get_embedding_model_counts(buffer_id)?;
1174
1175 Ok(EmbeddingStats {
1176 total_chunks: total_chunks as usize,
1177 embedded_chunks: embedded_chunks as usize,
1178 model_counts,
1179 })
1180 }
1181}
1182
/// Embedding coverage summary for one buffer, as produced by
/// `SqliteStorage::get_embedding_stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct EmbeddingStats {
    /// Number of chunks belonging to the buffer.
    pub total_chunks: usize,
    /// Number of embedding rows attached to those chunks.
    pub embedded_chunks: usize,
    /// Embedding counts grouped by model name (`None` = no model recorded).
    pub model_counts: Vec<(Option<String>, i64)>,
}

impl EmbeddingStats {
    /// Chunks still lacking an embedding (never negative).
    #[must_use]
    pub fn missing_chunks(&self) -> usize {
        self.total_chunks.saturating_sub(self.embedded_chunks)
    }
}
1193
1194#[cfg(test)]
1195mod tests {
1196 use super::*;
1197 use crate::core::ContextValue;
1198
1199 fn setup() -> SqliteStorage {
1200 let mut storage = SqliteStorage::in_memory().unwrap();
1201 storage.init().unwrap();
1202 storage
1203 }
1204
1205 #[test]
1206 fn test_init() {
1207 let mut storage = SqliteStorage::in_memory().unwrap();
1208 assert!(storage.init().is_ok());
1209 assert!(storage.is_initialized().unwrap());
1210 }
1211
1212 #[test]
1213 fn test_init_idempotent() {
1214 let mut storage = SqliteStorage::in_memory().unwrap();
1215 assert!(storage.init().is_ok());
1216 assert!(storage.init().is_ok()); }
1218
1219 #[test]
1220 fn test_context_crud() {
1221 let mut storage = setup();
1222
1223 assert!(storage.load_context().unwrap().is_none());
1225
1226 let mut ctx = Context::new();
1228 ctx.set_variable("key".to_string(), ContextValue::String("value".to_string()));
1229 storage.save_context(&ctx).unwrap();
1230
1231 let loaded = storage.load_context().unwrap().unwrap();
1233 assert_eq!(
1234 loaded.get_variable("key"),
1235 Some(&ContextValue::String("value".to_string()))
1236 );
1237
1238 storage.delete_context().unwrap();
1240 assert!(storage.load_context().unwrap().is_none());
1241 }
1242
1243 #[test]
1244 fn test_buffer_crud() {
1245 let mut storage = setup();
1246
1247 let buffer = Buffer::from_named("test".to_string(), "Hello, world!".to_string());
1249 let id = storage.add_buffer(&buffer).unwrap();
1250 assert!(id > 0);
1251
1252 let loaded = storage.get_buffer(id).unwrap().unwrap();
1254 assert_eq!(loaded.name, Some("test".to_string()));
1255 assert_eq!(loaded.content, "Hello, world!");
1256
1257 let by_name = storage.get_buffer_by_name("test").unwrap().unwrap();
1259 assert_eq!(by_name.id, Some(id));
1260
1261 let buffers = storage.list_buffers().unwrap();
1263 assert_eq!(buffers.len(), 1);
1264
1265 let mut updated = loaded;
1267 updated.content = "Updated content".to_string();
1268 storage.update_buffer(&updated).unwrap();
1269
1270 let reloaded = storage.get_buffer(id).unwrap().unwrap();
1271 assert_eq!(reloaded.content, "Updated content");
1272
1273 storage.delete_buffer(id).unwrap();
1275 assert!(storage.get_buffer(id).unwrap().is_none());
1276 }
1277
1278 #[test]
1279 fn test_chunk_crud() {
1280 let mut storage = setup();
1281
1282 let buffer = Buffer::from_content("Hello, world!".to_string());
1284 let buffer_id = storage.add_buffer(&buffer).unwrap();
1285
1286 let chunks = vec![
1288 Chunk::new(buffer_id, "Hello, ".to_string(), 0..7, 0),
1289 Chunk::new(buffer_id, "world!".to_string(), 7..13, 1),
1290 ];
1291 storage.add_chunks(buffer_id, &chunks).unwrap();
1292
1293 let loaded = storage.get_chunks(buffer_id).unwrap();
1295 assert_eq!(loaded.len(), 2);
1296 assert_eq!(loaded[0].content, "Hello, ");
1297 assert_eq!(loaded[1].content, "world!");
1298
1299 assert_eq!(storage.chunk_count(buffer_id).unwrap(), 2);
1301
1302 let chunk_id = loaded[0].id.unwrap();
1304 let single = storage.get_chunk(chunk_id).unwrap().unwrap();
1305 assert_eq!(single.content, "Hello, ");
1306
1307 storage.delete_chunks(buffer_id).unwrap();
1309 assert_eq!(storage.chunk_count(buffer_id).unwrap(), 0);
1310 }
1311
1312 #[test]
1313 fn test_cascade_delete() {
1314 let mut storage = setup();
1315
1316 let buffer = Buffer::from_content("Hello, world!".to_string());
1318 let buffer_id = storage.add_buffer(&buffer).unwrap();
1319
1320 let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
1321 storage.add_chunks(buffer_id, &chunks).unwrap();
1322
1323 assert_eq!(storage.chunk_count(buffer_id).unwrap(), 1);
1325
1326 storage.delete_buffer(buffer_id).unwrap();
1328
1329 let count: i64 = storage
1331 .conn
1332 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
1333 .unwrap();
1334 assert_eq!(count, 0);
1335 }
1336
1337 #[test]
1338 fn test_reset() {
1339 let mut storage = setup();
1340
1341 let ctx = Context::new();
1343 storage.save_context(&ctx).unwrap();
1344
1345 let buffer = Buffer::from_content("test".to_string());
1346 storage.add_buffer(&buffer).unwrap();
1347
1348 storage.reset().unwrap();
1350
1351 assert!(storage.load_context().unwrap().is_none());
1353 assert_eq!(storage.buffer_count().unwrap(), 0);
1354 }
1355
1356 #[test]
1357 fn test_stats() {
1358 let mut storage = setup();
1359
1360 let stats = storage.stats().unwrap();
1362 assert_eq!(stats.buffer_count, 0);
1363 assert_eq!(stats.chunk_count, 0);
1364 assert!(!stats.has_context);
1365
1366 let ctx = Context::new();
1368 storage.save_context(&ctx).unwrap();
1369
1370 let buffer = Buffer::from_content("Hello, world!".to_string());
1371 let buffer_id = storage.add_buffer(&buffer).unwrap();
1372
1373 let chunks = vec![Chunk::new(buffer_id, "Hello".to_string(), 0..5, 0)];
1374 storage.add_chunks(buffer_id, &chunks).unwrap();
1375
1376 let stats = storage.stats().unwrap();
1378 assert_eq!(stats.buffer_count, 1);
1379 assert_eq!(stats.chunk_count, 1);
1380 assert!(stats.has_context);
1381 assert_eq!(stats.total_content_size, 13);
1382 }
1383
1384 #[test]
1385 fn test_export_buffers() {
1386 let mut storage = setup();
1387
1388 storage
1389 .add_buffer(&Buffer::from_content("First".to_string()))
1390 .unwrap();
1391 storage
1392 .add_buffer(&Buffer::from_content("Second".to_string()))
1393 .unwrap();
1394
1395 let exported = storage.export_buffers().unwrap();
1396 assert_eq!(exported, "First\n\nSecond");
1397 }
1398
1399 fn setup_buffer_with_chunk(storage: &mut SqliteStorage) -> (i64, i64) {
1401 let buffer = Buffer::from_content("test content".to_string());
1402 let buffer_id = storage.add_buffer(&buffer).unwrap();
1403 let chunks = vec![Chunk::new(buffer_id, "test content".to_string(), 0..12, 0)];
1404 storage.add_chunks(buffer_id, &chunks).unwrap();
1405 let chunk_id = storage.get_chunks(buffer_id).unwrap()[0].id.unwrap();
1406 (buffer_id, chunk_id)
1407 }
1408
1409 #[test]
1410 fn test_store_and_get_embedding() {
1411 let mut storage = setup();
1412 let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);
1413
1414 let embedding = vec![0.1_f32, 0.2, 0.3, 0.4];
1415 storage
1416 .store_embedding(chunk_id, &embedding, Some("test-model"))
1417 .unwrap();
1418
1419 let loaded = storage.get_embedding(chunk_id).unwrap().unwrap();
1420 assert_eq!(loaded.len(), 4);
1421 for (a, b) in loaded.iter().zip(embedding.iter()) {
1422 assert!((a - b).abs() < 1e-6, "expected {b}, got {a}");
1423 }
1424 }
1425
/// Asking for the embedding of an unknown chunk id yields `Ok(None)`, not an error.
#[test]
fn test_get_embedding_nonexistent() {
    let storage = setup();
    let missing_chunk_id = 9999;
    assert!(storage.get_embedding(missing_chunk_id).unwrap().is_none());
}
1432
/// Storing a second embedding for the same chunk and model replaces the first
/// rather than adding another row.
#[test]
fn test_store_embedding_upsert() {
    let mut storage = setup();
    let (_buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);

    let embedding1 = vec![0.1_f32, 0.2];
    storage
        .store_embedding(chunk_id, &embedding1, Some("model-a"))
        .unwrap();

    let embedding2 = vec![0.9_f32, 0.8];
    storage
        .store_embedding(chunk_id, &embedding2, Some("model-a"))
        .unwrap();

    // The second store must overwrite, not duplicate.
    assert_eq!(storage.embedding_count().unwrap(), 1);

    let loaded = storage.get_embedding(chunk_id).unwrap().unwrap();
    // Check the length first: `zip` stops at the shorter side, so an empty or
    // truncated result would otherwise pass the element loop vacuously.
    assert_eq!(loaded.len(), embedding2.len());
    for (a, b) in loaded.iter().zip(embedding2.iter()) {
        assert!((a - b).abs() < 1e-6, "expected {b}, got {a}");
    }
}
1454
/// `has_embedding` flips from false to true once an embedding is stored.
#[test]
fn test_has_embedding() {
    let mut storage = setup();
    let (_, chunk_id) = setup_buffer_with_chunk(&mut storage);

    let before = storage.has_embedding(chunk_id).unwrap();
    assert!(!before, "fresh chunk should have no embedding");

    storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();

    let after = storage.has_embedding(chunk_id).unwrap();
    assert!(after, "embedding should be visible after store");
}
1466
/// `embedding_count` starts at zero and grows by one per embedded chunk.
#[test]
fn test_embedding_count() {
    let mut storage = setup();
    assert_eq!(storage.embedding_count().unwrap(), 0);

    let (buffer_id, first_chunk) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "more".to_string(), 0..4, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    // Whichever stored chunk is not the fixture chunk is the one just added.
    let second_chunk = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(first_chunk))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    storage.store_embedding(first_chunk, &[0.1_f32], None).unwrap();
    assert_eq!(storage.embedding_count().unwrap(), 1);

    storage.store_embedding(second_chunk, &[0.2_f32], None).unwrap();
    assert_eq!(storage.embedding_count().unwrap(), 2);
}
1494
/// `delete_embedding` removes a previously stored embedding.
#[test]
fn test_delete_embedding() {
    let mut storage = setup();
    let (_, chunk_id) = setup_buffer_with_chunk(&mut storage);

    storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();
    assert!(
        storage.has_embedding(chunk_id).unwrap(),
        "embedding should exist before delete"
    );

    storage.delete_embedding(chunk_id).unwrap();
    assert!(
        !storage.has_embedding(chunk_id).unwrap(),
        "embedding should be gone after delete"
    );
}
1506
/// `store_embeddings_batch` persists every `(chunk_id, vector)` pair it is given.
#[test]
fn test_store_embeddings_batch() {
    let mut storage = setup();
    let (buffer_id, first) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "second".to_string(), 0..6, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let second = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(first))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    let batch = vec![(first, vec![0.1_f32, 0.2]), (second, vec![0.3_f32, 0.4])];
    storage
        .store_embeddings_batch(&batch, Some("batch-model"))
        .unwrap();

    for chunk_id in [first, second] {
        assert!(storage.has_embedding(chunk_id).unwrap());
    }
    assert_eq!(storage.embedding_count().unwrap(), 2);
}
1536
/// `get_all_embeddings` returns every stored embedding (two in this setup).
#[test]
fn test_get_all_embeddings() {
    let mut storage = setup();
    let (buffer_id, first) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "second".to_string(), 0..6, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let second = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(first))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    for (chunk_id, value) in [(first, 0.1_f32), (second, 0.2_f32)] {
        storage
            .store_embedding(chunk_id, &[value], Some("m"))
            .unwrap();
    }

    let everything = storage.get_all_embeddings().unwrap();
    assert_eq!(everything.len(), 2);
}
1563
/// `get_embedding_models` lists no models before anything is embedded, then
/// the model name used for a stored embedding.
#[test]
fn test_get_embedding_models() {
    let mut storage = setup();
    let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);

    let initial = storage.get_embedding_models(buffer_id).unwrap();
    assert!(initial.is_empty(), "no models before any embedding is stored");

    storage
        .store_embedding(chunk_id, &[0.1_f32], Some("model-x"))
        .unwrap();

    let models = storage.get_embedding_models(buffer_id).unwrap();
    assert_eq!(models.len(), 1);
    assert_eq!(models[0], "model-x");
}
1580
/// `get_embedding_model_counts` aggregates a buffer's embeddings per model name.
#[test]
fn test_get_embedding_model_counts() {
    let mut storage = setup();
    let (buffer_id, first) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "extra".to_string(), 0..5, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let second = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(first))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    // Both chunks share one model, so a single aggregated row is expected.
    for (chunk_id, value) in [(first, 0.1_f32), (second, 0.2_f32)] {
        storage
            .store_embedding(chunk_id, &[value], Some("model-a"))
            .unwrap();
    }

    let counts = storage.get_embedding_model_counts(buffer_id).unwrap();
    assert_eq!(counts.len(), 1);
    assert_eq!(counts[0], (Some("model-a".to_string()), 2));
}
1608
/// With no model filter and nothing embedded yet, the fixture's single chunk
/// is reported as needing an embedding.
#[test]
fn test_get_chunks_needing_embedding_no_embeddings() {
    let mut storage = setup();
    let (buffer_id, _) = setup_buffer_with_chunk(&mut storage);

    let pending = storage
        .get_chunks_needing_embedding(buffer_id, None)
        .unwrap();
    assert_eq!(pending.len(), 1);
}
1619
/// `get_chunks_needing_embedding` is model-aware: a chunk embedded with
/// `model-a` is still reported as needing a `model-b` embedding.
#[test]
fn test_get_chunks_needing_embedding_with_model() {
    let mut storage = setup();
    let (buffer_id, embedded) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let pending = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(embedded))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    // Only the fixture chunk gets a `model-a` embedding.
    storage
        .store_embedding(embedded, &[0.1_f32], Some("model-a"))
        .unwrap();

    let needing_a = storage
        .get_chunks_needing_embedding(buffer_id, Some("model-a"))
        .unwrap();
    assert!(needing_a.contains(&pending));
    assert!(!needing_a.contains(&embedded));

    let needing_b = storage
        .get_chunks_needing_embedding(buffer_id, Some("model-b"))
        .unwrap();
    for chunk_id in [embedded, pending] {
        assert!(needing_b.contains(&chunk_id));
    }
}
1656
/// `get_chunks_without_embedding` lists the un-embedded chunk and drops it
/// once an embedding has been stored for it.
#[test]
fn test_get_chunks_without_embedding() {
    let mut storage = setup();
    let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);

    let before = storage.get_chunks_without_embedding(buffer_id).unwrap();
    assert_eq!(before.len(), 1);
    assert!(before.contains(&chunk_id));

    storage.store_embedding(chunk_id, &[0.1_f32], None).unwrap();

    let after = storage.get_chunks_without_embedding(buffer_id).unwrap();
    assert!(after.is_empty());
}
1671
/// Deleting by a named model removes only that model's embeddings and returns
/// the number of rows removed.
#[test]
fn test_delete_embeddings_by_model_named() {
    let mut storage = setup();
    let (buffer_id, model_a_chunk) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let model_b_chunk = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(model_a_chunk))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    for (chunk_id, value, model) in [
        (model_a_chunk, 0.1_f32, "model-a"),
        (model_b_chunk, 0.2_f32, "model-b"),
    ] {
        storage
            .store_embedding(chunk_id, &[value], Some(model))
            .unwrap();
    }

    let removed = storage
        .delete_embeddings_by_model(buffer_id, Some("model-a"))
        .unwrap();
    assert_eq!(removed, 1);
    assert!(!storage.has_embedding(model_a_chunk).unwrap());
    assert!(storage.has_embedding(model_b_chunk).unwrap());
}
1702
/// Passing `None` as the model deletes only embeddings stored without a model
/// name, leaving named-model embeddings intact.
#[test]
fn test_delete_embeddings_by_model_null() {
    let mut storage = setup();
    let (buffer_id, untagged) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "b".to_string(), 0..1, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();
    let tagged = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .find(|c| c.id != Some(untagged))
        .and_then(|c| c.id)
        .expect("second chunk should be stored with an id");

    for (chunk_id, value, model) in [(untagged, 0.1_f32, None), (tagged, 0.2_f32, Some("model-b"))] {
        storage.store_embedding(chunk_id, &[value], model).unwrap();
    }

    let removed = storage.delete_embeddings_by_model(buffer_id, None).unwrap();
    assert_eq!(removed, 1);
    assert!(!storage.has_embedding(untagged).unwrap());
    assert!(storage.has_embedding(tagged).unwrap());
}
1732
/// `get_embedding_stats` reports total vs. embedded chunk counts plus a
/// per-model breakdown, before and after one chunk is embedded.
#[test]
fn test_get_embedding_stats() {
    let mut storage = setup();
    let (buffer_id, chunk_id) = setup_buffer_with_chunk(&mut storage);

    let extra = vec![Chunk::new(buffer_id, "extra".to_string(), 0..5, 1)];
    storage.add_chunks(buffer_id, &extra).unwrap();

    let initial = storage.get_embedding_stats(buffer_id).unwrap();
    assert_eq!(initial.total_chunks, 2);
    assert_eq!(initial.embedded_chunks, 0);

    storage
        .store_embedding(chunk_id, &[0.1_f32], Some("m1"))
        .unwrap();

    let updated = storage.get_embedding_stats(buffer_id).unwrap();
    assert_eq!(updated.total_chunks, 2);
    assert_eq!(updated.embedded_chunks, 1);
    assert_eq!(updated.model_counts.len(), 1);
}
1757
/// FTS search for a term returns only the chunk that contains it, and every
/// score is non-negative.
#[test]
fn test_search_fts_finds_match() {
    let mut storage = setup();
    let buffer = Buffer::from_content("The quick brown fox jumps".to_string());
    let buffer_id = storage.add_buffer(&buffer).unwrap();

    // Ranges are positions within the 25-char buffer text above. The second
    // chunk "fox jumps" sits at 16..25 (the old 20..29 ran past the end), and
    // the two chunks deliberately overlap on "fox".
    let chunks = vec![
        Chunk::new(buffer_id, "The quick brown fox".to_string(), 0..19, 0),
        Chunk::new(buffer_id, "fox jumps".to_string(), 16..25, 1),
    ];
    storage.add_chunks(buffer_id, &chunks).unwrap();

    let results = storage.search_fts("quick", 10).unwrap();
    // "quick" appears only in the first chunk, so exactly one hit is expected.
    assert_eq!(results.len(), 1, "expected exactly one chunk to match");
    for (_id, score) in &results {
        // The expectation here is that the implementation reports BM25 as a
        // non-negative relevance score.
        assert!(
            *score >= 0.0,
            "BM25 score should be non-negative, got {score}"
        );
    }
}
1781
/// A query term that occurs in no stored chunk yields no FTS hits.
#[test]
fn test_search_fts_no_match() {
    let mut storage = setup();
    let text = "The quick brown fox";
    let buffer_id = storage
        .add_buffer(&Buffer::from_content(text.to_string()))
        .unwrap();

    let chunks = vec![Chunk::new(buffer_id, text.to_string(), 0..19, 0)];
    storage.add_chunks(buffer_id, &chunks).unwrap();

    let hits = storage.search_fts("zzzyyyxxx", 10).unwrap();
    assert!(hits.is_empty());
}
1799
/// `search_fts` caps the number of hits at `limit` when more chunks match.
#[test]
fn test_search_fts_respects_limit() {
    let mut storage = setup();
    let buffer = Buffer::from_content("hello world hello world hello".to_string());
    let buffer_id = storage.add_buffer(&buffer).unwrap();

    // All three chunks contain "hello", i.e. more matches than the limit below.
    let chunks = vec![
        Chunk::new(buffer_id, "hello world".to_string(), 0..11, 0),
        Chunk::new(buffer_id, "hello world".to_string(), 12..23, 1),
        Chunk::new(buffer_id, "hello".to_string(), 24..29, 2),
    ];
    storage.add_chunks(buffer_id, &chunks).unwrap();

    let results = storage.search_fts("hello", 2).unwrap();
    // Exactly the limit: a bare `<= 2` would also pass if the search found nothing.
    assert_eq!(results.len(), 2);
}
1817
/// Looking up an empty id list returns an empty result rather than an error.
#[test]
fn test_get_chunks_by_ids_empty() {
    let mut storage = SqliteStorage::in_memory().unwrap();
    storage.init().unwrap();

    let no_ids: [i64; 0] = [];
    assert!(storage.get_chunks_by_ids(&no_ids).unwrap().is_empty());
}
1826
/// `get_chunks_by_ids` returns exactly the requested subset of chunks, keyed by id.
#[test]
fn test_get_chunks_by_ids_batch() {
    let mut storage = SqliteStorage::in_memory().unwrap();
    storage.init().unwrap();

    let buffer_id = storage
        .add_buffer(&Buffer::from_content("abc def ghi".to_string()))
        .unwrap();
    let chunks = vec![
        Chunk::new(buffer_id, "abc".to_string(), 0..3, 0),
        Chunk::new(buffer_id, "def".to_string(), 4..7, 1),
        Chunk::new(buffer_id, "ghi".to_string(), 8..11, 2),
    ];
    storage.add_chunks(buffer_id, &chunks).unwrap();

    // Request only the first two stored ids out of three.
    let wanted: Vec<i64> = storage
        .get_chunks(buffer_id)
        .unwrap()
        .iter()
        .filter_map(|chunk| chunk.id)
        .take(2)
        .collect();
    assert_eq!(wanted.len(), 2);

    let by_id = storage.get_chunks_by_ids(&wanted).unwrap();
    assert_eq!(by_id.len(), 2);
    assert!(wanted.iter().all(|id| by_id.contains_key(id)));
}
1853
/// Ids with no matching chunk are simply absent from the result (no error).
#[test]
fn test_get_chunks_by_ids_missing_id() {
    let mut storage = SqliteStorage::in_memory().unwrap();
    storage.init().unwrap();

    let unknown_id = 99999;
    let found = storage.get_chunks_by_ids(&[unknown_id]).unwrap();
    assert!(found.is_empty());
}
1863
/// End-to-end check that FTS search and embedding storage agree on chunk ids:
/// every chunk returned by `search_fts` is one that has a stored embedding.
#[test]
fn test_fts_and_embedding_pipeline() {
    let mut storage = setup();

    // The buffer text is the two chunk texts joined by a single space, so the
    // chunk ranges below (0..27 and 28..56) really index into it. The previous
    // text was only 36 chars and did not contain either chunk.
    let buffer = Buffer::from_content(
        "machine learning algorithms neural networks architecture".to_string(),
    );
    let buffer_id = storage.add_buffer(&buffer).unwrap();

    let chunks = vec![
        Chunk::new(
            buffer_id,
            "machine learning algorithms".to_string(),
            0..27,
            0,
        ),
        Chunk::new(
            buffer_id,
            "neural networks architecture".to_string(),
            28..56,
            1,
        ),
    ];
    storage.add_chunks(buffer_id, &chunks).unwrap();

    let chunk_ids: Vec<i64> = storage
        .get_chunks(buffer_id)
        .unwrap()
        .into_iter()
        .map(|c| c.id.unwrap())
        .collect();
    assert_eq!(chunk_ids.len(), 2);

    // Give every chunk an embedding so any FTS hit can be cross-checked.
    let embeddings: &[&[f32]] = &[&[0.1, 0.2], &[0.2, 0.4]];
    for (&chunk_id, embedding) in chunk_ids.iter().zip(embeddings) {
        storage
            .store_embedding(chunk_id, embedding, Some("test-model"))
            .unwrap();
    }

    let fts_results = storage.search_fts("machine", 10).unwrap();
    assert!(!fts_results.is_empty());

    for (chunk_id, _score) in &fts_results {
        assert!(
            storage.has_embedding(*chunk_id).unwrap(),
            "FTS result chunk {chunk_id} should have an embedding"
        );
    }

    assert_eq!(storage.embedding_count().unwrap(), 2);
}
1919}