use std::path::PathBuf;
use std::sync::Arc;

/// Best-effort sanitization for values interpolated into LanceDB filter
/// expressions: strips NUL bytes, SQL comment markers, and semicolons, and
/// escapes backslashes and single quotes.
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
        .replace("*/", "")
}

/// IDs are UUID/hex-style: ASCII hex digits and dashes, 1 to 64 characters.
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}
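
// Illustrative note (not part of the API): `is_valid_id` and `sanitize_sql_string`
// are used together below whenever an id is spliced into a LanceDB filter
// expression, e.g.
//
//     if is_valid_id(id) {
//         let filter = format!("id = '{}'", sanitize_sql_string(id));
//         // pass `filter` to `.only_if(filter)`
//     }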

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

/// Content longer than this many characters is split into chunks.
const CHUNK_THRESHOLD: usize = 1000;

/// Minimum similarity for another item to be reported as a potential conflict.
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

/// Maximum number of potential conflicts returned for a stored item.
const CONFLICT_SEARCH_LIMIT: usize = 5;

/// Upper bound on chunks stored for a single item.
const MAX_CHUNKS_PER_ITEM: usize = 200;

/// Number of chunk texts embedded per batch call.
const EMBEDDING_BATCH_SIZE: usize = 32;

pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
}

#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

/// Current on-disk schema version.
const SCHEMA_VERSION: i32 = 2;

fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
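    /// Open (or create) a database at `path` with a default [`Embedder`].
    ///
    /// Illustrative sketch only (not compiled); it assumes the crate is named
    /// `sediment`, that `Database` is re-exported at the crate root, and that a
    /// Tokio runtime is running:
    ///
    /// ```ignore
    /// let db = sediment::Database::open("/tmp/sediment-db").await?;
    /// println!("{} items stored", db.stats().await?.item_count);
    /// ```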
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

    async fn ensure_tables(&mut self) -> Result<()> {
        let mut table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items_migrated".to_string()) {
            info!("Detected interrupted migration, recovering...");
            self.recover_interrupted_migration(&table_names).await?;
            table_names =
                self.db.table_names().execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to list tables: {}", e))
                })?;
        }

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        // The old (v1) schema carried a `tags` column; its presence means migration is needed.
        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }

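    /// Recover from a schema migration that was interrupted part-way through.
    /// Which case applies is determined by what survived on disk:
    /// - Case A: only `items_migrated` exists, so `items` is rebuilt from the staging copy.
    /// - Case B: `items` still has the old schema, so the incomplete staging table is dropped
    ///   and migration runs again afterwards.
    /// - Case C: `items` already has the new schema, so the leftover staging table is dropped.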
    async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
        let has_items = table_names.contains(&"items".to_string());

        if !has_items {
            info!("Recovery case A: restoring items from items_migrated");
            let staging = self
                .db
                .open_table("items_migrated")
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to open staging table: {}", e))
                })?;

            let results = staging
                .query()
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;

            let schema = Arc::new(item_schema());
            let new_table = self
                .db
                .create_empty_table("items", schema.clone())
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;

            if !results.is_empty() {
                let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
                new_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to restore items: {}", e))
                    })?;
            }

            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop staging table: {}", e))
            })?;
            info!("Recovery case A completed");
        } else {
            let has_old_schema = self.check_needs_migration().await?;

            if has_old_schema {
                info!("Recovery case B: dropping incomplete staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            } else {
                info!("Recovery case C: dropping leftover staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            }
        }

        Ok(())
    }

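    /// Migrate `items` from the v1 schema (which still carried a `tags` column) to the
    /// current schema: copy every row into an `items_migrated` staging table, verify row
    /// counts, rebuild `items` from the staging copy, then drop the staging table. If this
    /// is interrupted at any point, `recover_interrupted_migration` repairs things on the
    /// next open.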
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
        let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        if old_count != new_count {
            return Err(SedimentError::Database(format!(
                "Migration row count mismatch: old={}, new={}",
                old_count, new_count
            )));
        }
        info!("Migrating {} items to new schema", old_count);

        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
        if table_names.contains(&"items_migrated".to_string()) {
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop stale staging: {}", e))
            })?;
        }

        let schema = Arc::new(item_schema());
        let staging_table = self
            .db
            .create_empty_table("items_migrated", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create staging table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
            staging_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert into staging: {}", e))
                })?;
        }

        let staging_count = staging_table
            .count_rows(None)
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
        if staging_count != old_count {
            let _ = self.db.drop_table("items_migrated").await;
            return Err(SedimentError::Database(format!(
                "Staging row count mismatch: expected {}, got {}",
                old_count, staging_count
            )));
        }

        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let staging_data = staging_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;

        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !staging_data.is_empty() {
            let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        self.db
            .drop_table("items_migrated")
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;

        info!("Schema migration completed successfully");
        Ok(())
    }

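    /// Project a batch read with the old schema onto the current one by keeping only the
    /// surviving columns; the `tags` column is simply not carried over.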
    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        let schema = Arc::new(item_schema());

        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
            .clone();

        let content_col = batch
            .column_by_name("content")
            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
            .clone();

        let project_id_col = batch
            .column_by_name("project_id")
            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
            .clone();

        let is_chunked_col = batch
            .column_by_name("is_chunked")
            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
            .clone();

        let created_at_col = batch
            .column_by_name("created_at")
            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
            .clone();

        let vector_col = batch
            .column_by_name("vector")
            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
            .clone();

        RecordBatch::try_new(
            schema,
            vec![
                id_col,
                content_col,
                project_id_col,
                is_chunked_col,
                created_at_col,
                vector_col,
            ],
        )
        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
    }

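    /// Lazily create indexes: an ANN index on `vector` once a table has at least 256 rows,
    /// and a full-text index on `content` for any non-empty table. Index creation failures
    /// are logged as warnings rather than treated as fatal.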
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);

                let indices = table.list_indices().await.unwrap_or_default();

                if row_count >= MIN_ROWS_FOR_INDEX {
                    let has_vector_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"vector".to_string()));

                    if !has_vector_index {
                        info!(
                            "Creating vector index on {} table ({} rows)",
                            name, row_count
                        );
                        match table
                            .create_index(&["vector"], lancedb::index::Index::Auto)
                            .execute()
                            .await
                        {
                            Ok(_) => info!("Vector index created on {} table", name),
                            Err(e) => {
                                tracing::warn!("Failed to create vector index on {}: {}", name, e);
                            }
                        }
                    }
                }

                if row_count > 0 {
                    let has_fts_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"content".to_string()));

                    if !has_fts_index {
                        info!("Creating FTS index on {} table ({} rows)", name, row_count);
                        match table
                            .create_index(
                                &["content"],
                                lancedb::index::Index::FTS(Default::default()),
                            )
                            .execute()
                            .await
                        {
                            Ok(_) => info!("FTS index created on {} table", name),
                            Err(e) => {
                                tracing::warn!("Failed to create FTS index on {}: {}", name, e);
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

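    /// Store an item: default its project id, embed the content, write the row, chunk and
    /// embed long content, then report stored items that look like near-duplicates.
    ///
    /// Illustrative sketch only (not compiled; the exact `Item` constructor is assumed):
    ///
    /// ```ignore
    /// let result = db.store_item(Item::new("prefer rustfmt defaults")).await?;
    /// for conflict in &result.potential_conflicts {
    ///     println!("similar existing item {} ({:.2})", conflict.id, conflict.similarity);
    /// }
    /// ```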
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
                // Roll back the item row so we don't keep a half-stored item.
                let _ = self.delete_item(&item.id).await;
                return Err(e);
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        let potential_conflicts = self
            .find_similar_items_by_vector(
                &item.embedding,
                Some(&item.id),
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

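    /// Embed chunk contents in batches of `EMBEDDING_BATCH_SIZE` and append one row per
    /// chunk to the chunks table.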
    async fn store_chunks(
        &mut self,
        item_id: &str,
        chunk_results: &[crate::chunker::ChunkResult],
    ) -> Result<()> {
        let embedder = self.embedder.clone();
        let chunks_table = self.get_chunks_table().await?;

        let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
        let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
        for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
            let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
            let batch_embeddings = embedder.embed_batch(&chunk_texts[batch_start..batch_end])?;
            all_embeddings.extend(batch_embeddings);
        }

        let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
        for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
            let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
            if let Some(ctx) = &chunk_result.context {
                chunk = chunk.with_context(ctx);
            }
            chunk.embedding = embedding;
            all_chunk_batches.push(chunk_to_batch(&chunk)?);
        }

        if !all_chunk_batches.is_empty() {
            let schema = Arc::new(chunk_schema());
            let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
            chunks_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
        }

        Ok(())
    }

    /// Rank items by full-text relevance for `query`; returns a map from item id to its
    /// FTS rank (0 = best), or `None` if the FTS path is unavailable.
    async fn fts_rank_items(
        &self,
        table: &Table,
        query: &str,
        limit: usize,
    ) -> Option<std::collections::HashMap<String, usize>> {
        let fts_query =
            FullTextSearchQuery::new(query.to_string()).columns(Some(vec!["content".to_string()]));

        let fts_results = table
            .query()
            .full_text_search(fts_query)
            .limit(limit)
            .execute()
            .await
            .ok()?
            .try_collect::<Vec<_>>()
            .await
            .ok()?;

        let mut ranks = std::collections::HashMap::new();
        let mut rank = 0usize;
        for batch in fts_results {
            let ids = batch
                .column_by_name("id")
                .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
            for i in 0..ids.len() {
                if !ids.is_null(i) {
                    ranks.insert(ids.value(i).to_string(), rank);
                    rank += 1;
                }
            }
        }
        Some(ranks)
    }

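    /// Hybrid search over items and chunks: vector similarity (boosted for the current
    /// project) is fused with full-text relevance via reciprocal rank fusion, and the top
    /// `limit` results are returned.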
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        let mut fts_ranking: Option<std::collections::HashMap<String, usize>> = None;

        if let Some(table) = &self.items_table {
            let row_count = table.count_rows(None).await.unwrap_or(0);
            let base_query = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
            // Small tables are scanned exactly; larger ones use the ANN index with refinement.
            let query_builder = if row_count < 5000 {
                base_query.bypass_vector_index().limit(limit * 2)
            } else {
                base_query.refine_factor(10).limit(limit * 2)
            };

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            let mut vector_items: Vec<(Item, f32)> = Vec::new();
            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);
                    if similarity >= min_similarity {
                        vector_items.push((item, similarity));
                    }
                }
            }

            fts_ranking = self.fts_rank_items(table, query, limit * 2).await;

            for (item, similarity) in vector_items {
                let boosted_similarity = boost_similarity(
                    similarity,
                    item.project_id.as_deref(),
                    self.project_id.as_deref(),
                );

                let result = SearchResult::from_item(&item, boosted_similarity);
                results_map
                    .entry(item.id.clone())
                    .or_insert((result, boosted_similarity));
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_row_count = chunks_table.count_rows(None).await.unwrap_or(0);
            let chunk_base_query = chunks_table.vector_search(query_embedding).map_err(|e| {
                SedimentError::Database(format!("Failed to build chunk search: {}", e))
            })?;
            let chunk_results = if chunk_row_count < 5000 {
                chunk_base_query.bypass_vector_index().limit(limit * 3)
            } else {
                chunk_base_query.refine_factor(10).limit(limit * 3)
            }
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to collect chunk results: {}", e))
            })?;

            // Keep only the best-matching chunk per parent item.
            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = self.get_item(&item_id).await? {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        // Reciprocal rank fusion: each result scores 1 / (k + rank) per ranking
        // (vector similarity and full-text), and the combined score decides the order.
        const RRF_K: f32 = 60.0;

        let mut ranked: Vec<(SearchResult, f32)> = results_map.into_values().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut rrf_scored: Vec<(SearchResult, f32)> = ranked
            .into_iter()
            .enumerate()
            .map(|(vector_rank, (sr, _sim))| {
                let vector_rrf = 1.0 / (RRF_K + vector_rank as f32);
                let fts_rrf = fts_ranking.as_ref().map_or(0.0, |ranks| {
                    ranks.get(&sr.id).map_or(0.0, |&r| 1.0 / (RRF_K + r as f32))
                });
                (sr, vector_rrf + fts_rrf)
            })
            .collect();

        rrf_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let search_results: Vec<SearchResult> = rrf_scored
            .into_iter()
            .take(limit)
            .map(|(sr, _)| sr)
            .collect();

        Ok(search_results)
    }

    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;
        self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
            .await
    }

    pub async fn find_similar_items_by_vector(
        &self,
        embedding: &[f32],
        exclude_id: Option<&str>,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let row_count = table.count_rows(None).await.unwrap_or(0);
        let base_query = table
            .vector_search(embedding.to_vec())
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
        let results = if row_count < 5000 {
            base_query.bypass_vector_index().limit(limit)
        } else {
            base_query.refine_factor(10).limit(limit)
        }
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                if exclude_id.is_some_and(|eid| eid == item.id) {
                    continue;
                }

                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

    pub async fn list_items(
        &mut self,
        _filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    if !is_valid_id(pid) {
                        return Err(SedimentError::Database(
                            "Invalid project_id for list filter".to_string(),
                        ));
                    }
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    // No current project: a project-scoped listing has nothing to show.
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {
                // No filter: list items from every project and the global scope.
            }
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        // Delete dependent chunks first, then the item row itself.
        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

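/// Rewrite `project_id` values in the `items` table from `old_id` to `new_id`, returning
/// the number of rows updated. This opens its own connection, so it can run without
/// constructing a [`Database`].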
pub async fn migrate_project_id(
    db_path: &std::path::Path,
    old_id: &str,
    new_id: &str,
) -> Result<u64> {
    if !is_valid_id(old_id) || !is_valid_id(new_id) {
        return Err(SedimentError::Database(
            "Invalid project ID for migration".to_string(),
        ));
    }

    let db = connect(db_path.to_str().ok_or_else(|| {
        SedimentError::Database("Database path contains invalid UTF-8".to_string())
    })?)
    .execute()
    .await
    .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;

    let table_names = db
        .table_names()
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

    let mut total_updated = 0u64;

    if table_names.contains(&"items".to_string()) {
        let table =
            db.open_table("items").execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to open items table: {}", e))
            })?;

        let updated = table
            .update()
            .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
            .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;

        total_updated += updated;
        info!(
            "Migrated {} items from project {} to {}",
            updated, old_id, new_id
        );
    }

    Ok(total_updated)
}

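/// Combine a similarity score with recency and usage signals:
/// `score = similarity * freshness * frequency`, where
/// `freshness = 1 / (1 + age_days / 30)` and
/// `frequency = 1 + 0.1 * ln(1 + access_count)`.
/// Age is measured from `last_accessed_at` when present, otherwise `created_at`;
/// non-finite inputs or results collapse to 0.0.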
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

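/// Heuristic content-type detection used to pick a chunking strategy: JSON (delimiters plus
/// a successful parse), YAML (`key: value` lines or a `---` header), Markdown (heading
/// markers), code (common keyword prefixes at line starts), otherwise plain text.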
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            project_id,
            is_chunked,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
        // Fresh and never accessed: freshness = 1.0, frequency = 1.0.
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        // 30 days old: freshness = 1 / (1 + 1) = 0.5.
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        // Accessed just now, so freshness = 1.0; frequency = 1 + 0.1 * ln(11).
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        // 90 days old: freshness = 1 / (1 + 3) = 0.25.
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        // Removing the comment markers leaves the surrounding spaces behind.
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected  rest");
        assert_eq!(sanitize_sql_string("*/"), "");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        // 500 emoji are 2000 bytes in UTF-8 but only 500 characters.
        assert_eq!(emoji_content.len(), 2000);
        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }

    fn old_item_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("project_id", DataType::Utf8, true),
            // Only the old (v1) schema carries a `tags` column.
            Field::new("tags", DataType::Utf8, true),
            Field::new("is_chunked", DataType::Boolean, false),
            Field::new("created_at", DataType::Int64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    EMBEDDING_DIM as i32,
                ),
                false,
            ),
        ])
    }

    fn old_item_batch(id: &str, content: &str) -> RecordBatch {
        let schema = Arc::new(old_item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])), // project_id
                Arc::new(StringArray::from(vec![None::<&str>])), // tags
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_detects_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch = old_item_batch("test-id-1", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            needs_migration,
            "Old schema with tags column should need migration"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_false_for_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items", schema)
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "New schema should not need migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_migrate_schema_preserves_data() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        // Note: ids must be hex-and-dash so they pass `is_valid_id` when read back
        // through `get_item`.
        let schema = Arc::new(old_item_schema());
        let batch1 = old_item_batch("aaa-111", "first item content");
        let batch2 = old_item_batch("bbb-222", "second item content");
        let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            !needs_migration,
            "Schema should be migrated (no tags column)"
        );

        let item_a = db.get_item("aaa-111").await.unwrap();
        assert!(item_a.is_some(), "Item aaa-111 should be preserved");
        assert_eq!(item_a.unwrap().content, "first item content");

        let item_b = db.get_item("bbb-222").await.unwrap();
        assert!(item_b.is_some(), "Item bbb-222 should be preserved");
        assert_eq!(item_b.unwrap().content, "second item content");

        let stats = db.stats().await.unwrap();
        assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_a_only_staging() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                // Hex-and-dash id so `get_item` accepts it after recovery.
                Arc::new(StringArray::from(vec!["aaaa-0001"])),
                Arc::new(StringArray::from(vec!["staging content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items_migrated", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("aaaa-0001").await.unwrap();
        assert!(item.is_some(), "Item should be recovered from staging");
        assert_eq!(item.unwrap().content, "staging content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_b_both_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let old_schema = Arc::new(old_item_schema());
        let batch = old_item_batch("bbbb-0002", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "Should have migrated after recovery");

        let item = db.get_item("bbbb-0002").await.unwrap();
        assert!(
            item.is_some(),
            "Item should be preserved through recovery + migration"
        );

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_c_both_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());

        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            new_schema.clone(),
            vec![
                // Hex-and-dash id so `get_item` accepts it afterwards.
                Arc::new(StringArray::from(vec!["cccc-0003"])),
                Arc::new(StringArray::from(vec!["new content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("cccc-0003").await.unwrap();
        assert!(item.is_some(), "Item should be untouched");
        assert_eq!(item.unwrap().content, "new content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_list_items_rejects_invalid_project_id() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let malicious_pid = "'; DROP TABLE items;--".to_string();

        let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
            .await
            .unwrap();

        let result = db
            .list_items(ItemFilters::new(), Some(10), crate::ListScope::Project)
            .await;

        assert!(result.is_err(), "Should reject invalid project_id");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Invalid project_id"),
            "Error should mention invalid project_id, got: {}",
            err_msg
        );
    }
}