1use std::path::PathBuf;
7use std::sync::Arc;
8
/// Defensively sanitizes a string for interpolation into a SQL literal.
///
/// Applies, in order: strip NUL bytes, double backslashes, double single
/// quotes, then strip `;`, `--`, `/*`, and `*/`. Passes run sequentially,
/// so each later pass sees the output of the earlier ones.
fn sanitize_sql_string(s: &str) -> String {
    // (pattern, replacement) pairs; order matches the original replace chain.
    const PASSES: [(&str, &str); 7] = [
        ("\0", ""),
        ("\\", "\\\\"),
        ("'", "''"),
        (";", ""),
        ("--", ""),
        ("/*", ""),
        ("*/", ""),
    ];
    PASSES
        .into_iter()
        .fold(s.to_owned(), |acc, (from, to)| acc.replace(from, to))
}
27
/// Returns `true` when `s` looks like a safe identifier: non-empty, at most
/// 64 bytes, and composed only of ASCII hex digits and dashes (UUID-shaped).
fn is_valid_id(s: &str) -> bool {
    if s.is_empty() || s.len() > 64 {
        return false;
    }
    // Byte-wise check is equivalent to the char-wise one: any multi-byte
    // character contains bytes that are neither hex digits nor '-'.
    s.bytes().all(|b| b.is_ascii_hexdigit() || b == b'-')
}
34
35use arrow_array::{
36 Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
37 RecordBatchIterator, StringArray,
38};
39use arrow_schema::{DataType, Field, Schema};
40use chrono::{TimeZone, Utc};
41use futures::TryStreamExt;
42use lancedb::Table;
43use lancedb::connect;
44use lancedb::index::scalar::FullTextSearchQuery;
45use lancedb::query::{ExecutableQuery, QueryBase};
46use tracing::{debug, info};
47
48use crate::boost_similarity;
49use crate::chunker::{ChunkingConfig, chunk_content};
50use crate::document::ContentType;
51use crate::embedder::Embedder;
52use crate::error::{Result, SedimentError};
53use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};
54
/// Items whose content exceeds this many characters are also split into chunks.
const CHUNK_THRESHOLD: usize = 1000;

/// Similarity floor above which a stored item is reported as a potential conflict.
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

/// Maximum number of conflict candidates fetched after storing an item.
const CONFLICT_SEARCH_LIMIT: usize = 5;

/// Hard cap on chunks persisted per item; excess chunks are truncated with a warning.
const MAX_CHUNKS_PER_ITEM: usize = 200;

/// Number of chunk texts embedded per call to `embed_document_batch`.
const EMBEDDING_BATCH_SIZE: usize = 32;

/// Maximum additive boost contributed by the full-text (BM25) ranking.
const FTS_BOOST_MAX: f32 = 0.12;

/// Exponent applied to the normalized BM25 score when computing the FTS boost.
const FTS_GAMMA: f32 = 2.0;

/// Row count below which vector searches bypass the ANN index (brute force).
const VECTOR_INDEX_THRESHOLD: usize = 5000;
/// Handle over the LanceDB store: connection, embedder, optional project
/// scope, lazily-created tables, and tunable FTS boost parameters.
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    // When set, newly stored items default to this project and matching
    // results are boosted during search.
    project_id: Option<String>,
    // `None` until the table exists or is first created on demand.
    items_table: Option<Table>,
    chunks_table: Option<Table>,
    // Runtime copies of FTS_BOOST_MAX / FTS_GAMMA (env-overridable under
    // the `bench` feature).
    fts_boost_max: f32,
    fts_gamma: f32,
}
100
/// Row counts reported by `Database::stats`.
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

/// Current on-disk schema version; migration to it is triggered by the
/// presence of a legacy `tags` column (see `check_needs_migration`).
const SCHEMA_VERSION: i32 = 2;
110
111fn item_schema(dim: usize) -> Schema {
113 Schema::new(vec![
114 Field::new("id", DataType::Utf8, false),
115 Field::new("content", DataType::Utf8, false),
116 Field::new("project_id", DataType::Utf8, true),
117 Field::new("is_chunked", DataType::Boolean, false),
118 Field::new("created_at", DataType::Int64, false), Field::new(
120 "vector",
121 DataType::FixedSizeList(
122 Arc::new(Field::new("item", DataType::Float32, true)),
123 dim as i32,
124 ),
125 false,
126 ),
127 ])
128}
129
130fn chunk_schema(dim: usize) -> Schema {
131 Schema::new(vec![
132 Field::new("id", DataType::Utf8, false),
133 Field::new("item_id", DataType::Utf8, false),
134 Field::new("chunk_index", DataType::Int32, false),
135 Field::new("content", DataType::Utf8, false),
136 Field::new("context", DataType::Utf8, true),
137 Field::new(
138 "vector",
139 DataType::FixedSizeList(
140 Arc::new(Field::new("item", DataType::Float32, true)),
141 dim as i32,
142 ),
143 false,
144 ),
145 ])
146}
147
148impl Database {
    /// Opens the database at `path` with a freshly constructed embedder and
    /// no project scope.
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }
153
    /// Opens the database at `path`, scoped to `project_id`, constructing a
    /// new [`Embedder`] internally.
    ///
    /// # Errors
    /// Fails if the embedder cannot be created or the database cannot be
    /// opened.
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }
162
    /// Opens (creating if necessary) the database at `path` with a caller
    /// supplied embedder, then runs table setup, any pending schema
    /// migration recovery, and index maintenance.
    ///
    /// # Errors
    /// Fails on directory creation, non-UTF-8 paths, connection errors, or
    /// any error from `ensure_tables` / `ensure_vector_index`.
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        // Make sure the parent directory exists before LanceDB touches it.
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        // Under the `bench` feature the FTS boost parameters may be
        // overridden via environment variables (invalid values fall back to
        // the compiled-in defaults).
        #[cfg(feature = "bench")]
        let (fts_boost_max, fts_gamma) = {
            let boost = std::env::var("SEDIMENT_FTS_BOOST_MAX")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_BOOST_MAX);
            let gamma = std::env::var("SEDIMENT_FTS_GAMMA")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_GAMMA);
            (boost, gamma)
        };
        #[cfg(not(feature = "bench"))]
        let (fts_boost_max, fts_gamma) = (FTS_BOOST_MAX, FTS_GAMMA);

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
            fts_boost_max,
            fts_gamma,
        };

        // Recover/migrate/open tables first, then (re)build indices.
        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }
228
    /// Replaces the active project scope used for storing and boosting.
    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

    /// Returns the active project scope, if any.
    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }
238
    /// Recovers any interrupted migration, migrates an old `items` schema if
    /// needed, then opens the `items` and `chunks` tables that exist.
    ///
    /// Tables that do not exist yet are left as `None` and created lazily by
    /// `get_items_table` / `get_chunks_table`.
    async fn ensure_tables(&mut self) -> Result<()> {
        let mut table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        // A leftover staging table means a previous migration was cut short.
        if table_names.contains(&"items_migrated".to_string()) {
            info!("Detected interrupted migration, recovering...");
            self.recover_interrupted_migration(&table_names).await?;
            // Re-list: recovery may have created/dropped tables.
            table_names =
                self.db.table_names().execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to list tables: {}", e))
                })?;
        }

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }
287
    /// Returns `true` when the `items` table still carries the legacy
    /// `tags` column, i.e. it predates the current schema version.
    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        // Presence of "tags" is the marker of the old schema.
        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }
303
    /// Recovers from a migration that was interrupted while the
    /// `items_migrated` staging table existed.
    ///
    /// Three cases:
    /// - A: `items` was already dropped — rebuild it from the staging data,
    ///   then drop staging.
    /// - B: `items` still has the old schema — staging is incomplete; drop
    ///   it and let the normal migration path rerun.
    /// - C: `items` already has the new schema — staging is a stale
    ///   leftover; just drop it.
    async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
        let has_items = table_names.contains(&"items".to_string());

        if !has_items {
            info!("Recovery case A: restoring items from items_migrated");
            let staging = self
                .db
                .open_table("items_migrated")
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to open staging table: {}", e))
                })?;

            // Read the full staging contents into memory before rebuilding.
            let results = staging
                .query()
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;

            let dim = self.embedder.dimension();
            let schema = Arc::new(item_schema(dim));
            let new_table = self
                .db
                .create_empty_table("items", schema.clone())
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;

            if !results.is_empty() {
                let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
                new_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to restore items: {}", e))
                    })?;
            }

            // Only drop staging after the restore succeeded.
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop staging table: {}", e))
            })?;
            info!("Recovery case A completed");
        } else {
            let has_old_schema = self.check_needs_migration().await?;

            if has_old_schema {
                info!("Recovery case B: dropping incomplete staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            } else {
                info!("Recovery case C: dropping leftover staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            }
        }

        Ok(())
    }
386
    /// Migrates `items` from the legacy schema to the current one via a
    /// staging table (`items_migrated`), verifying row counts at each step
    /// so an interruption is recoverable by `recover_interrupted_migration`.
    ///
    /// Sequence: read old rows -> convert -> write to staging -> verify ->
    /// drop old `items` -> rebuild `items` from staging -> drop staging.
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        // Convert every batch to the new schema before writing anything.
        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        // Sanity check: conversion must not lose or invent rows.
        let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
        let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        if old_count != new_count {
            return Err(SedimentError::Database(format!(
                "Migration row count mismatch: old={}, new={}",
                old_count, new_count
            )));
        }
        info!("Migrating {} items to new schema", old_count);

        // Remove any stale staging table from a previous attempt.
        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
        if table_names.contains(&"items_migrated".to_string()) {
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop stale staging: {}", e))
            })?;
        }

        let dim = self.embedder.dimension();
        let schema = Arc::new(item_schema(dim));
        let staging_table = self
            .db
            .create_empty_table("items_migrated", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create staging table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
            staging_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert into staging: {}", e))
                })?;
        }

        // Verify staging before destroying the old table; on mismatch drop
        // staging (best effort) and abort with the old data intact.
        let staging_count = staging_table
            .count_rows(None)
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
        if staging_count != old_count {
            let _ = self.db.drop_table("items_migrated").await;
            return Err(SedimentError::Database(format!(
                "Staging row count mismatch: expected {}, got {}",
                old_count, staging_count
            )));
        }

        // Point of no return: from here recovery case A applies if we crash.
        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let staging_data = staging_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;

        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !staging_data.is_empty() {
            let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        self.db
            .drop_table("items_migrated")
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;

        info!("Schema migration completed successfully");
        Ok(())
    }
532
533 fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
535 let schema = Arc::new(item_schema(self.embedder.dimension()));
536
537 let id_col = batch
539 .column_by_name("id")
540 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
541 .clone();
542
543 let content_col = batch
544 .column_by_name("content")
545 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
546 .clone();
547
548 let project_id_col = batch
549 .column_by_name("project_id")
550 .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
551 .clone();
552
553 let is_chunked_col = batch
554 .column_by_name("is_chunked")
555 .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
556 .clone();
557
558 let created_at_col = batch
559 .column_by_name("created_at")
560 .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
561 .clone();
562
563 let vector_col = batch
564 .column_by_name("vector")
565 .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
566 .clone();
567
568 RecordBatch::try_new(
569 schema,
570 vec![
571 id_col,
572 content_col,
573 project_id_col,
574 is_chunked_col,
575 created_at_col,
576 vector_col,
577 ],
578 )
579 .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
580 }
581
    /// Best-effort index maintenance for the `items` and `chunks` tables:
    /// creates a vector index once a table has enough rows, and refreshes
    /// the FTS index on `content` whenever the table is non-empty.
    ///
    /// Index failures are logged as warnings rather than propagated, so a
    /// missing index degrades search performance instead of breaking opens.
    async fn ensure_vector_index(&self) -> Result<()> {
        // Below this row count an ANN index is not worth building.
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);

                let indices = table.list_indices().await.unwrap_or_default();

                if row_count >= MIN_ROWS_FOR_INDEX {
                    let has_vector_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"vector".to_string()));

                    if !has_vector_index {
                        info!(
                            "Creating vector index on {} table ({} rows)",
                            name, row_count
                        );
                        match table
                            .create_index(&["vector"], lancedb::index::Index::Auto)
                            .execute()
                            .await
                        {
                            Ok(_) => info!("Vector index created on {} table", name),
                            Err(e) => {
                                // Non-fatal: searches fall back to brute force.
                                tracing::warn!("Failed to create vector index on {}: {}", name, e);
                            }
                        }
                    }
                }

                // FTS index is rebuilt with replace(true) on every open so it
                // picks up rows added since the last refresh.
                if row_count > 0 {
                    match table
                        .create_index(&["content"], lancedb::index::Index::FTS(Default::default()))
                        .replace(true)
                        .execute()
                        .await
                    {
                        Ok(_) => {
                            debug!("FTS index refreshed on {} table ({} rows)", name, row_count)
                        }
                        Err(e) => {
                            // Non-fatal: FTS boost simply won't apply.
                            tracing::warn!("Failed to create FTS index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }
646
    /// Returns the `items` table, creating it empty (with the current
    /// embedder's dimension) on first use.
    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema(self.embedder.dimension()));
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        // Just populated above if it was None.
        Ok(self.items_table.as_ref().unwrap())
    }
663
    /// Returns the `chunks` table, creating it empty (with the current
    /// embedder's dimension) on first use.
    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema(self.embedder.dimension()));
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        // Just populated above if it was None.
        Ok(self.chunks_table.as_ref().unwrap())
    }
680
    /// Stores an item: fills in the project scope, embeds the content,
    /// writes the item row, optionally chunks long content, and returns the
    /// id plus any highly similar existing items as potential conflicts.
    ///
    /// If chunk storage fails, the just-written item row is rolled back
    /// (best effort) before the error is returned.
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        // Default to the database's active project scope.
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        // Character count (not bytes) decides whether to chunk.
        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed_document(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches =
            RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema(item.embedding.len())));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            // Cap runaway chunk counts; tail chunks are dropped.
            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            // Roll back the item row if its chunks could not be stored.
            if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
                let _ = self.delete_item(&item.id).await;
                return Err(e);
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        // Conflict detection is advisory; failures degrade to "no conflicts".
        let potential_conflicts = self
            .find_similar_items_by_vector(
                &item.embedding,
                Some(&item.id),
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }
763
764 async fn store_chunks(
766 &mut self,
767 item_id: &str,
768 chunk_results: &[crate::chunker::ChunkResult],
769 ) -> Result<()> {
770 let embedder = self.embedder.clone();
771 let chunks_table = self.get_chunks_table().await?;
772
773 let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
775 let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
776 for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
777 let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
778 let batch_embeddings =
779 embedder.embed_document_batch(&chunk_texts[batch_start..batch_end])?;
780 all_embeddings.extend(batch_embeddings);
781 }
782
783 let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
785 for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
786 let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
787 if let Some(ctx) = &chunk_result.context {
788 chunk = chunk.with_context(ctx);
789 }
790 chunk.embedding = embedding;
791 all_chunk_batches.push(chunk_to_batch(&chunk)?);
792 }
793
794 if !all_chunk_batches.is_empty() {
796 let schema = Arc::new(chunk_schema(embedder.dimension()));
797 let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
798 chunks_table
799 .add(Box::new(batches))
800 .execute()
801 .await
802 .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
803 }
804
805 Ok(())
806 }
807
    /// Runs a full-text search over `content` and returns a map from row id
    /// to BM25 `_score`. Returns `None` on any failure (FTS is an optional
    /// boost, so errors are swallowed rather than propagated).
    async fn fts_rank_items(
        &self,
        table: &Table,
        query: &str,
        limit: usize,
    ) -> Option<std::collections::HashMap<String, f32>> {
        let fts_query =
            FullTextSearchQuery::new(query.to_string()).columns(Some(vec!["content".to_string()]));

        let fts_results = table
            .query()
            .full_text_search(fts_query)
            .limit(limit)
            .execute()
            .await
            .ok()?
            .try_collect::<Vec<_>>()
            .await
            .ok()?;

        let mut scores = std::collections::HashMap::new();
        for batch in fts_results {
            let ids = batch
                .column_by_name("id")
                .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
            // `_score` may be absent; missing scores default to 0.0 below.
            let bm25_scores = batch
                .column_by_name("_score")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
            for i in 0..ids.len() {
                if !ids.is_null(i) {
                    let score = bm25_scores.map(|s| s.value(i)).unwrap_or(0.0);
                    scores.insert(ids.value(i).to_string(), score);
                }
            }
        }
        Some(scores)
    }
847
    /// Hybrid search over items and chunks.
    ///
    /// Pipeline: vector-search the `items` table (over-fetching 2x `limit`),
    /// add an FTS/BM25 boost of up to `fts_boost_max`, then vector-search
    /// the `chunks` table (over-fetching 3x) and fold the best chunk per
    /// parent item in as an excerpt. Results below `min_similarity`
    /// (default 0.3) are dropped; same-project results are boosted via
    /// `boost_similarity`; final output is sorted by boosted similarity and
    /// truncated to `limit`.
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        // Clamp to keep over-fetch multiples bounded.
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed_query(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        // item id -> (best result so far, its boosted similarity)
        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let row_count = table.count_rows(None).await.unwrap_or(0);
            let base_query = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
            // Small tables: brute force beats an approximate index.
            let query_builder = if row_count < VECTOR_INDEX_THRESHOLD {
                base_query.bypass_vector_index().limit(limit * 2)
            } else {
                base_query.refine_factor(10).limit(limit * 2)
            };

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            let mut vector_items: Vec<(Item, f32)> = Vec::new();
            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    // Map distance to a (0, 1] similarity.
                    let similarity = 1.0 / (1.0 + distance);
                    if similarity >= min_similarity {
                        vector_items.push((item, similarity));
                    }
                }
            }

            let fts_ranking = self.fts_rank_items(table, query, limit * 2).await;

            // Normalize BM25 scores by the max so the boost is relative.
            let max_bm25 = fts_ranking
                .as_ref()
                .and_then(|scores| scores.values().cloned().reduce(f32::max))
                .unwrap_or(1.0)
                .max(f32::EPSILON);

            for (item, similarity) in vector_items {
                // boost = fts_boost_max * (bm25/max)^gamma; 0 if no FTS hit.
                let fts_boost = fts_ranking.as_ref().map_or(0.0, |scores| {
                    scores.get(&item.id).map_or(0.0, |&bm25_score| {
                        self.fts_boost_max * (bm25_score / max_bm25).powf(self.fts_gamma)
                    })
                });
                let boosted_similarity = boost_similarity(
                    similarity + fts_boost,
                    item.project_id.as_deref(),
                    self.project_id.as_deref(),
                );

                let result = SearchResult::from_item(&item, boosted_similarity);
                results_map
                    .entry(item.id.clone())
                    .or_insert((result, boosted_similarity));
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_row_count = chunks_table.count_rows(None).await.unwrap_or(0);
            let chunk_base_query = chunks_table.vector_search(query_embedding).map_err(|e| {
                SedimentError::Database(format!("Failed to build chunk search: {}", e))
            })?;
            let chunk_results = if chunk_row_count < VECTOR_INDEX_THRESHOLD {
                chunk_base_query.bypass_vector_index().limit(limit * 3)
            } else {
                chunk_base_query.refine_factor(10).limit(limit * 3)
            }
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to collect chunk results: {}", e))
            })?;

            // parent item id -> (best chunk content, best chunk similarity)
            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    // Keep only the highest-scoring chunk per parent item.
                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            // Resolve parent items in one batched lookup.
            let chunk_item_ids: Vec<&str> = chunk_matches.keys().map(|id| id.as_str()).collect();
            let parent_items = self.get_items_batch(&chunk_item_ids).await?;
            let parent_map: std::collections::HashMap<&str, &Item> = parent_items
                .iter()
                .map(|item| (item.id.as_str(), item))
                .collect();

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = parent_map.get(item_id.as_str()) {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(item, boosted_similarity, excerpt);

                    // A chunk match replaces an item match only if it scores higher.
                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        // Sort descending by similarity; NaN comparisons treated as equal.
        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(sr, _)| sr).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }
1039
    /// Embeds `content` and returns stored items whose similarity meets
    /// `min_similarity`, up to `limit` results.
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed_document(content)?;
        self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
            .await
    }
1054
    /// Vector-searches the `items` table for rows similar to `embedding`,
    /// excluding `exclude_id` (typically the item just stored), filtering by
    /// `min_similarity`, and returning results sorted most-similar first.
    ///
    /// Returns an empty list when the `items` table does not exist yet.
    pub async fn find_similar_items_by_vector(
        &self,
        embedding: &[f32],
        exclude_id: Option<&str>,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let row_count = table.count_rows(None).await.unwrap_or(0);
        let base_query = table
            .vector_search(embedding.to_vec())
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
        // Small tables: brute-force scan instead of the approximate index.
        let results = if row_count < VECTOR_INDEX_THRESHOLD {
            base_query.bypass_vector_index().limit(limit)
        } else {
            base_query.refine_factor(10).limit(limit)
        }
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                // Skip the item we are comparing against.
                if exclude_id.is_some_and(|eid| eid == item.id) {
                    continue;
                }

                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                // Map distance to a (0, 1] similarity.
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        // Most similar first; NaN comparisons treated as equal.
        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }
1121
    /// Lists items filtered by `scope`: the active project, global
    /// (project-less) items, or everything. Returns an empty list when the
    /// table is missing, or when `Project` scope is requested with no
    /// active project id.
    pub async fn list_items(
        &mut self,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    // Validate before interpolating into the SQL filter.
                    if !is_valid_id(pid) {
                        return Err(SedimentError::Database(
                            "Invalid project_id for list filter".to_string(),
                        ));
                    }
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {
                // No filter: return rows from every project plus global.
            }
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }
1184
    /// Fetches a single item by id. Invalid ids and missing tables yield
    /// `Ok(None)` rather than an error.
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        // Reject malformed ids before building the SQL filter.
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        // Return the first decoded row, if any batch produced one.
        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }
1215
    /// Fetches multiple items by id in one `IN (...)` query. Invalid ids are
    /// silently skipped; an empty or fully-invalid id list yields an empty
    /// result. Result order is not guaranteed to match `ids`.
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        // Keep only well-formed ids, quoted for the SQL filter.
        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }
1254
    /// Deletes an item and its chunks. Returns `Ok(false)` when the id is
    /// invalid, the table is missing, or the item does not exist; `Ok(true)`
    /// once the item row is deleted. Chunk deletion is best effort — a
    /// failure there is logged but still reported as success.
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        // Existence check so the return value distinguishes "deleted" from
        // "was never there".
        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        // Orphaned chunks are tolerated: log and continue.
        if let Some(chunks_table) = &self.chunks_table
            && let Err(e) = chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
        {
            tracing::warn!("Failed to delete chunks for item {}: {}", id, e);
        }

        Ok(true)
    }
1292
    /// Returns row counts for the items and chunks tables; tables that do
    /// not exist yet count as zero.
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
1313}
1314
1315pub async fn migrate_project_id(
1322 db_path: &std::path::Path,
1323 old_id: &str,
1324 new_id: &str,
1325) -> Result<u64> {
1326 if !is_valid_id(old_id) || !is_valid_id(new_id) {
1327 return Err(SedimentError::Database(
1328 "Invalid project ID for migration".to_string(),
1329 ));
1330 }
1331
1332 let db = connect(db_path.to_str().ok_or_else(|| {
1333 SedimentError::Database("Database path contains invalid UTF-8".to_string())
1334 })?)
1335 .execute()
1336 .await
1337 .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;
1338
1339 let table_names = db
1340 .table_names()
1341 .execute()
1342 .await
1343 .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
1344
1345 let mut total_updated = 0u64;
1346
1347 if table_names.contains(&"items".to_string()) {
1348 let table =
1349 db.open_table("items").execute().await.map_err(|e| {
1350 SedimentError::Database(format!("Failed to open items table: {}", e))
1351 })?;
1352
1353 let updated = table
1354 .update()
1355 .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
1356 .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
1357 .execute()
1358 .await
1359 .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;
1360
1361 total_updated += updated;
1362 info!(
1363 "Migrated {} items from project {} to {}",
1364 updated, old_id, new_id
1365 );
1366 }
1367
1368 Ok(total_updated)
1369}
1370
/// Combines raw similarity with recency and access-frequency multipliers.
///
/// Freshness follows `1 / (1 + age_days / 30)` — half strength at 30 days —
/// with age measured from the last access (falling back to creation time).
/// Frequency adds a logarithmic bonus of `0.1 * ln(1 + access_count)`.
/// A non-finite similarity or result yields `0.0`.
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    // Guard against NaN/Inf similarities poisoning the ranking.
    if !similarity.is_finite() {
        return 0.0;
    }

    // Age anchors on the most recent touch; clock skew that would make it
    // negative is clamped to zero.
    let anchor = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - anchor).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let scored = similarity * (freshness * frequency) as f32;
    if scored.is_finite() { scored } else { 0.0 }
}
1403
1404fn detect_content_type(content: &str) -> ContentType {
1408 let trimmed = content.trim();
1409
1410 if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
1412 || (trimmed.starts_with('[') && trimmed.ends_with(']')))
1413 && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
1414 {
1415 return ContentType::Json;
1416 }
1417
1418 if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
1422 let lines: Vec<&str> = trimmed.lines().take(10).collect();
1423 let yaml_key_count = lines
1424 .iter()
1425 .filter(|line| {
1426 let l = line.trim();
1427 !l.is_empty()
1430 && !l.starts_with('#')
1431 && !l.contains("://")
1432 && l.contains(": ")
1433 && l.split(": ").next().is_some_and(|key| {
1434 let k = key.trim_start_matches("- ");
1435 !k.is_empty()
1436 && k.chars()
1437 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
1438 })
1439 })
1440 .count();
1441 if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
1443 return ContentType::Yaml;
1444 }
1445 }
1446
1447 if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
1449 return ContentType::Markdown;
1450 }
1451
1452 let code_patterns = [
1455 "fn ",
1456 "pub fn ",
1457 "def ",
1458 "class ",
1459 "function ",
1460 "const ",
1461 "let ",
1462 "var ",
1463 "import ",
1464 "export ",
1465 "struct ",
1466 "impl ",
1467 "trait ",
1468 ];
1469 let has_code_pattern = trimmed.lines().any(|line| {
1470 let l = line.trim();
1471 code_patterns.iter().any(|p| l.starts_with(p))
1472 });
1473 if has_code_pattern {
1474 return ContentType::Code;
1475 }
1476
1477 ContentType::Text
1478}
1479
1480fn item_to_batch(item: &Item) -> Result<RecordBatch> {
1483 let schema = Arc::new(item_schema(item.embedding.len()));
1484
1485 let id = StringArray::from(vec![item.id.as_str()]);
1486 let content = StringArray::from(vec![item.content.as_str()]);
1487 let project_id = StringArray::from(vec![item.project_id.as_deref()]);
1488 let is_chunked = BooleanArray::from(vec![item.is_chunked]);
1489 let created_at = Int64Array::from(vec![item.created_at.timestamp()]);
1490
1491 let vector = create_embedding_array(&item.embedding)?;
1492
1493 RecordBatch::try_new(
1494 schema,
1495 vec![
1496 Arc::new(id),
1497 Arc::new(content),
1498 Arc::new(project_id),
1499 Arc::new(is_chunked),
1500 Arc::new(created_at),
1501 Arc::new(vector),
1502 ],
1503 )
1504 .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1505}
1506
1507fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
1508 let mut items = Vec::new();
1509
1510 let id_col = batch
1511 .column_by_name("id")
1512 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1513 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1514
1515 let content_col = batch
1516 .column_by_name("content")
1517 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1518 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1519
1520 let project_id_col = batch
1521 .column_by_name("project_id")
1522 .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1523
1524 let is_chunked_col = batch
1525 .column_by_name("is_chunked")
1526 .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());
1527
1528 let created_at_col = batch
1529 .column_by_name("created_at")
1530 .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
1531
1532 let vector_col = batch
1533 .column_by_name("vector")
1534 .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());
1535
1536 for i in 0..batch.num_rows() {
1537 let id = id_col.value(i).to_string();
1538 let content = content_col.value(i).to_string();
1539
1540 let project_id = project_id_col.and_then(|c| {
1541 if c.is_null(i) {
1542 None
1543 } else {
1544 Some(c.value(i).to_string())
1545 }
1546 });
1547
1548 let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);
1549
1550 let created_at = created_at_col
1551 .map(|c| {
1552 Utc.timestamp_opt(c.value(i), 0)
1553 .single()
1554 .unwrap_or_else(Utc::now)
1555 })
1556 .unwrap_or_else(Utc::now);
1557
1558 let embedding = vector_col
1559 .and_then(|col| {
1560 let value = col.value(i);
1561 value
1562 .as_any()
1563 .downcast_ref::<Float32Array>()
1564 .map(|arr| arr.values().to_vec())
1565 })
1566 .unwrap_or_default();
1567
1568 let item = Item {
1569 id,
1570 content,
1571 embedding,
1572 project_id,
1573 is_chunked,
1574 created_at,
1575 };
1576
1577 items.push(item);
1578 }
1579
1580 Ok(items)
1581}
1582
1583fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
1584 let schema = Arc::new(chunk_schema(chunk.embedding.len()));
1585
1586 let id = StringArray::from(vec![chunk.id.as_str()]);
1587 let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
1588 let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
1589 let content = StringArray::from(vec![chunk.content.as_str()]);
1590 let context = StringArray::from(vec![chunk.context.as_deref()]);
1591
1592 let vector = create_embedding_array(&chunk.embedding)?;
1593
1594 RecordBatch::try_new(
1595 schema,
1596 vec![
1597 Arc::new(id),
1598 Arc::new(item_id),
1599 Arc::new(chunk_index),
1600 Arc::new(content),
1601 Arc::new(context),
1602 Arc::new(vector),
1603 ],
1604 )
1605 .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1606}
1607
1608fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
1609 let mut chunks = Vec::new();
1610
1611 let id_col = batch
1612 .column_by_name("id")
1613 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1614 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1615
1616 let item_id_col = batch
1617 .column_by_name("item_id")
1618 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1619 .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;
1620
1621 let chunk_index_col = batch
1622 .column_by_name("chunk_index")
1623 .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
1624 .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;
1625
1626 let content_col = batch
1627 .column_by_name("content")
1628 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1629 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1630
1631 let context_col = batch
1632 .column_by_name("context")
1633 .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1634
1635 for i in 0..batch.num_rows() {
1636 let id = id_col.value(i).to_string();
1637 let item_id = item_id_col.value(i).to_string();
1638 let chunk_index = chunk_index_col.value(i) as usize;
1639 let content = content_col.value(i).to_string();
1640 let context = context_col.and_then(|c| {
1641 if c.is_null(i) {
1642 None
1643 } else {
1644 Some(c.value(i).to_string())
1645 }
1646 });
1647
1648 let chunk = Chunk {
1649 id,
1650 item_id,
1651 chunk_index,
1652 content,
1653 embedding: Vec::new(),
1654 context,
1655 };
1656
1657 chunks.push(chunk);
1658 }
1659
1660 Ok(chunks)
1661}
1662
1663fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
1664 let dim = embedding.len();
1665 let values = Float32Array::from(embedding.to_vec());
1666 let field = Arc::new(Field::new("item", DataType::Float32, true));
1667
1668 FixedSizeListArray::try_new(field, dim as i32, Arc::new(values), None)
1669 .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
1670}
1671
#[cfg(test)]
mod tests {
    use super::*;

    // ---- score_with_decay -------------------------------------------------

    // Age 0, no accesses: freshness and frequency are both 1.0, so the
    // score equals the raw similarity.
    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now; let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    // At exactly 30 days, freshness = 1 / (1 + 30/30) = 0.5.
    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400; let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    // A recent access resets freshness to 1.0 even for an old item, and
    // 10 accesses add a frequency bonus of 0.1 * ln(11).
    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now; let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    // At 90 days, freshness = 1 / (1 + 90/30) = 0.25.
    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400; let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    // ---- sanitize_sql_string / is_valid_id --------------------------------

    // Single quotes are doubled (SQL escaping); backslashes are doubled.
    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    // NUL bytes are removed outright; other transformations still apply.
    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    // Statement separators are removed, preventing stacked queries.
    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    // SQL comment markers ("--", "/*", "*/") are stripped.
    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        assert_eq!(sanitize_sql_string("val' -- comment"), "val'' comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val'' block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected rest");
        assert_eq!(sanitize_sql_string("*/"), "");
    }

    // Classic injection payloads come out defanged; benign unicode
    // (zero-width space) passes through untouched.
    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    // IDs are hex digits plus '-' only, 1..=64 chars (covers UUIDs).
    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    // ---- detect_content_type ----------------------------------------------

    // Prose containing colons must not trip the YAML heuristic, while
    // genuine key: value documents must.
    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    // A "---" document separator lowers the YAML key threshold to one.
    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    // Chunking decisions must count chars, not bytes: 500 4-byte emoji
    // are 2000 bytes but only 500 chars, below CHUNK_THRESHOLD.
    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000); let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

    // Guards against accidentally rolling the schema version backwards.
    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }

    // ---- schema migration / recovery fixtures -----------------------------

    use crate::embedder::EMBEDDING_DIM;

    // The v1 items schema: identical to the current one except for the
    // extra "tags" column, whose presence marks a pre-migration table.
    fn old_item_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("project_id", DataType::Utf8, true),
            Field::new("tags", DataType::Utf8, true), Field::new("is_chunked", DataType::Boolean, false),
            Field::new("created_at", DataType::Int64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    EMBEDDING_DIM as i32,
                ),
                false,
            ),
        ])
    }

    // Builds a one-row batch in the old schema with a zero vector and a
    // fixed timestamp; "tags" is left null.
    fn old_item_batch(id: &str, content: &str) -> RecordBatch {
        let schema = Arc::new(old_item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])), Arc::new(StringArray::from(vec![None::<&str>])), Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap()
    }

    // NOTE(review): the tokio tests below are #[ignore]d — they build a
    // real Embedder and an on-disk LanceDB, presumably too heavy (or
    // requiring model assets) for default CI runs. Confirm before relying
    // on them in automation.

    // Opening a DB whose items table still has the "tags" column must be
    // reported as needing migration.
    #[tokio::test]
    #[ignore] async fn test_check_needs_migration_detects_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch = old_item_batch("test-id-1", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            needs_migration,
            "Old schema with tags column should need migration"
        );
    }

    // A table already in the current schema must NOT be flagged.
    #[tokio::test]
    #[ignore] async fn test_check_needs_migration_false_for_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema(EMBEDDING_DIM));
        db_conn
            .create_empty_table("items", schema)
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "New schema should not need migration");
    }

    // End-to-end: opening a DB with an old-schema table migrates it in
    // place and every row survives with its content intact.
    #[tokio::test]
    #[ignore] async fn test_migrate_schema_preserves_data() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch1 = old_item_batch("id-aaa", "first item content");
        let batch2 = old_item_batch("id-bbb", "second item content");
        let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            !needs_migration,
            "Schema should be migrated (no tags column)"
        );

        let item_a = db.get_item("id-aaa").await.unwrap();
        assert!(item_a.is_some(), "Item id-aaa should be preserved");
        assert_eq!(item_a.unwrap().content, "first item content");

        let item_b = db.get_item("id-bbb").await.unwrap();
        assert!(item_b.is_some(), "Item id-bbb should be preserved");
        assert_eq!(item_b.unwrap().content, "second item content");

        let stats = db.stats().await.unwrap();
        assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
    }

    // Crash recovery, case A: only the staging table ("items_migrated")
    // exists — its rows must be promoted and the staging table dropped.
    #[tokio::test]
    #[ignore] async fn test_recover_case_a_only_staging() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema(EMBEDDING_DIM));
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["staging-id"])),
                Arc::new(StringArray::from(vec!["staging content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items_migrated", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("staging-id").await.unwrap();
        assert!(item.is_some(), "Item should be recovered from staging");
        assert_eq!(item.unwrap().content, "staging content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    // Crash recovery, case B: old-schema "items" plus a leftover staging
    // table — migration must complete and data survive.
    #[tokio::test]
    #[ignore] async fn test_recover_case_b_both_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let old_schema = Arc::new(old_item_schema());
        let batch = old_item_batch("old-id", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema(EMBEDDING_DIM));
        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "Should have migrated after recovery");

        let item = db.get_item("old-id").await.unwrap();
        assert!(
            item.is_some(),
            "Item should be preserved through recovery + migration"
        );

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    // Crash recovery, case C: "items" is already current and a stale
    // staging table remains — items stay untouched, staging is dropped.
    #[tokio::test]
    #[ignore] async fn test_recover_case_c_both_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema(EMBEDDING_DIM));

        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            new_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["new-id"])),
                Arc::new(StringArray::from(vec!["new content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("new-id").await.unwrap();
        assert!(item.is_some(), "Item should be untouched");
        assert_eq!(item.unwrap().content, "new content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    // A project_id that fails is_valid_id must surface an error from
    // list_items rather than being interpolated into a filter.
    #[tokio::test]
    #[ignore] async fn test_list_items_rejects_invalid_project_id() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let malicious_pid = "'; DROP TABLE items;--".to_string();

        let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
            .await
            .unwrap();

        let result = db.list_items(Some(10), crate::ListScope::Project).await;

        assert!(result.is_err(), "Should reject invalid project_id");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Invalid project_id"),
            "Error should mention invalid project_id, got: {}",
            err_msg
        );
    }
}