use std::path::PathBuf;
use std::sync::Arc;

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

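/// Best-effort sanitizer for string values interpolated into LanceDB filter
/// expressions: strips NUL bytes, semicolons, and SQL comment tokens, and
/// escapes backslashes and single quotes. Defense in depth alongside
/// `is_valid_id`, not a substitute for parameterized queries.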
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
        .replace("*/", "")
}

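/// Accepts only non-empty strings of at most 64 ASCII hex digits or hyphens
/// (e.g. UUIDs), rejecting anything that could alter a filter expression.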
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}

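/// Items whose content exceeds this many characters are chunked.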
const CHUNK_THRESHOLD: usize = 1000;

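/// Minimum similarity (`1 / (1 + distance)`) for a stored item to be
/// reported as a potential conflict.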
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

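/// Maximum number of potential conflicts reported by a store.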
const CONFLICT_SEARCH_LIMIT: usize = 5;

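/// Hard cap on chunks stored per item; extra chunks are dropped with a warning.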
const MAX_CHUNKS_PER_ITEM: usize = 200;

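/// Number of chunk texts embedded per batch.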
const EMBEDDING_BATCH_SIZE: usize = 32;

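/// Maximum additive boost a perfect BM25 match contributes on top of vector
/// similarity.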
const FTS_BOOST_MAX: f32 = 0.12;

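/// Exponent applied to the normalized BM25 score; values above 1 concentrate
/// the boost on the strongest text matches.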
const FTS_GAMMA: f32 = 2.0;

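/// Row count below which vector queries bypass the ANN index and scan
/// exhaustively; at or above it, queries use the index with a refine factor.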
const VECTOR_INDEX_THRESHOLD: usize = 5000;

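/// Embedded vector store over LanceDB holding `items` and their `chunks`,
/// plus the embedder and the FTS boost parameters used at query time.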
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
    fts_boost_max: f32,
    fts_gamma: f32,
}

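/// Row counts reported by [`Database::stats`].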
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

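/// Current schema version. Version 2 removed the legacy `tags` column from
/// `items`.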
const SCHEMA_VERSION: i32 = 2;

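/// Arrow schema for the `items` table (current version).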
fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

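/// Arrow schema for the `chunks` table.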
fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
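    /// Opens (or creates) the database at `path` with no project scope and a
    /// freshly constructed default embedder.
    ///
    /// A usage sketch (the `Item` constructor and filter defaults shown here
    /// are assumptions, not defined in this file):
    ///
    /// ```ignore
    /// let mut db = Database::open("/tmp/sediment").await?;
    /// let stored = db.store_item(Item::new("note to self")).await?;
    /// let hits = db
    ///     .search_items("note", 10, ItemFilters::default())
    ///     .await?;
    /// ```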
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

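    /// Opens the database with an optional project ID used to scope stores,
    /// searches, and listings.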
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

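    /// Opens the database with a caller-supplied embedder. Creates the parent
    /// directory if needed, connects, recovers any interrupted migration,
    /// migrates old schemas, and ensures vector/FTS indexes exist.
    ///
    /// With the `bench` feature, the `SEDIMENT_FTS_BOOST_MAX` and
    /// `SEDIMENT_FTS_GAMMA` environment variables override the FTS boost
    /// parameters.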
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        #[cfg(feature = "bench")]
        let (fts_boost_max, fts_gamma) = {
            let boost = std::env::var("SEDIMENT_FTS_BOOST_MAX")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_BOOST_MAX);
            let gamma = std::env::var("SEDIMENT_FTS_GAMMA")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_GAMMA);
            (boost, gamma)
        };
        #[cfg(not(feature = "bench"))]
        let (fts_boost_max, fts_gamma) = (FTS_BOOST_MAX, FTS_GAMMA);

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
            fts_boost_max,
            fts_gamma,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

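    /// Replaces the project scope used for subsequent operations.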
    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

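    /// Returns the current project scope, if any.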
    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

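    /// Opens existing tables, recovering from an interrupted migration and
    /// migrating old schemas first. Tables that don't exist yet are created
    /// lazily on first write.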
    async fn ensure_tables(&mut self) -> Result<()> {
        let mut table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items_migrated".to_string()) {
            info!("Detected interrupted migration, recovering...");
            self.recover_interrupted_migration(&table_names).await?;
            table_names =
                self.db.table_names().execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to list tables: {}", e))
                })?;
        }

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

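    /// Returns `true` if `items` still carries the legacy `tags` column and
    /// must be migrated.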
    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }

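    /// Recovers from a crash mid-migration, keyed off a leftover
    /// `items_migrated` staging table:
    ///
    /// - Case A: `items` is gone; restore it from the staging table.
    /// - Case B: `items` still has the old schema; the staging copy is
    ///   incomplete, so drop it and let migration rerun.
    /// - Case C: `items` already has the new schema; just drop the leftover
    ///   staging table.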
    async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
        let has_items = table_names.contains(&"items".to_string());

        if !has_items {
            info!("Recovery case A: restoring items from items_migrated");
            let staging = self
                .db
                .open_table("items_migrated")
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to open staging table: {}", e))
                })?;

            let results = staging
                .query()
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;

            let schema = Arc::new(item_schema());
            let new_table = self
                .db
                .create_empty_table("items", schema.clone())
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;

            if !results.is_empty() {
                let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
                new_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to restore items: {}", e))
                    })?;
            }

            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop staging table: {}", e))
            })?;
            info!("Recovery case A completed");
        } else {
            let has_old_schema = self.check_needs_migration().await?;

            if has_old_schema {
                info!("Recovery case B: dropping incomplete staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            } else {
                info!("Recovery case C: dropping leftover staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            }
        }

        Ok(())
    }

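    /// Migrates `items` to the current schema via a staging table so the
    /// operation can be recovered if interrupted: copy converted rows into
    /// `items_migrated`, verify row counts, drop the old `items`, rebuild it
    /// from staging, then drop the staging table.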
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
        let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        if old_count != new_count {
            return Err(SedimentError::Database(format!(
                "Migration row count mismatch: old={}, new={}",
                old_count, new_count
            )));
        }
        info!("Migrating {} items to new schema", old_count);

        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
        if table_names.contains(&"items_migrated".to_string()) {
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop stale staging: {}", e))
            })?;
        }

        let schema = Arc::new(item_schema());
        let staging_table = self
            .db
            .create_empty_table("items_migrated", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create staging table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
            staging_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert into staging: {}", e))
                })?;
        }

        let staging_count = staging_table
            .count_rows(None)
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
        if staging_count != old_count {
            let _ = self.db.drop_table("items_migrated").await;
            return Err(SedimentError::Database(format!(
                "Staging row count mismatch: expected {}, got {}",
                old_count, staging_count
            )));
        }

        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let staging_data = staging_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;

        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !staging_data.is_empty() {
            let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        self.db
            .drop_table("items_migrated")
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;

        info!("Schema migration completed successfully");
        Ok(())
    }

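    /// Projects a legacy batch onto the current item schema by copying every
    /// surviving column (this drops the legacy `tags` column).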
    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        let schema = Arc::new(item_schema());

        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
            .clone();

        let content_col = batch
            .column_by_name("content")
            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
            .clone();

        let project_id_col = batch
            .column_by_name("project_id")
            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
            .clone();

        let is_chunked_col = batch
            .column_by_name("is_chunked")
            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
            .clone();

        let created_at_col = batch
            .column_by_name("created_at")
            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
            .clone();

        let vector_col = batch
            .column_by_name("vector")
            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
            .clone();

        RecordBatch::try_new(
            schema,
            vec![
                id_col,
                content_col,
                project_id_col,
                is_chunked_col,
                created_at_col,
                vector_col,
            ],
        )
        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
    }

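    /// Creates a vector index on `items`/`chunks` once a table holds at least
    /// 256 rows, and (re)builds the FTS index on `content` whenever rows
    /// exist. Index failures are logged, not fatal.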
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);

                let indices = table.list_indices().await.unwrap_or_default();

                if row_count >= MIN_ROWS_FOR_INDEX {
                    let has_vector_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"vector".to_string()));

                    if !has_vector_index {
                        info!(
                            "Creating vector index on {} table ({} rows)",
                            name, row_count
                        );
                        match table
                            .create_index(&["vector"], lancedb::index::Index::Auto)
                            .execute()
                            .await
                        {
                            Ok(_) => info!("Vector index created on {} table", name),
                            Err(e) => {
                                tracing::warn!("Failed to create vector index on {}: {}", name, e);
                            }
                        }
                    }
                }

                if row_count > 0 {
                    match table
                        .create_index(&["content"], lancedb::index::Index::FTS(Default::default()))
                        .replace(true)
                        .execute()
                        .await
                    {
                        Ok(_) => {
                            debug!("FTS index refreshed on {} table ({} rows)", name, row_count)
                        }
                        Err(e) => {
                            tracing::warn!("Failed to create FTS index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

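    /// Returns the `items` table, creating it (empty) on first use.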
    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

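    /// Returns the `chunks` table, creating it (empty) on first use.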
    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

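    /// Stores an item: fills in the project scope, embeds the content, writes
    /// the row, chunks and embeds long content (rolling the item back if
    /// chunk storage fails), and reports near-duplicate items above the
    /// conflict threshold.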
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
                let _ = self.delete_item(&item.id).await;
                return Err(e);
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        let potential_conflicts = self
            .find_similar_items_by_vector(
                &item.embedding,
                Some(&item.id),
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

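    /// Embeds chunk texts in batches of `EMBEDDING_BATCH_SIZE` and writes one
    /// row per chunk.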
    async fn store_chunks(
        &mut self,
        item_id: &str,
        chunk_results: &[crate::chunker::ChunkResult],
    ) -> Result<()> {
        let embedder = self.embedder.clone();
        let chunks_table = self.get_chunks_table().await?;

        let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
        let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
        for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
            let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
            let batch_embeddings = embedder.embed_batch(&chunk_texts[batch_start..batch_end])?;
            all_embeddings.extend(batch_embeddings);
        }

        let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
        for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
            let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
            if let Some(ctx) = &chunk_result.context {
                chunk = chunk.with_context(ctx);
            }
            chunk.embedding = embedding;
            all_chunk_batches.push(chunk_to_batch(&chunk)?);
        }

        if !all_chunk_batches.is_empty() {
            let schema = Arc::new(chunk_schema());
            let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
            chunks_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
        }

        Ok(())
    }

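    /// Runs a BM25 full-text query over `content` and returns an id -> score
    /// map. Returns `None` on any failure so callers can fall back to pure
    /// vector ranking.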
    async fn fts_rank_items(
        &self,
        table: &Table,
        query: &str,
        limit: usize,
    ) -> Option<std::collections::HashMap<String, f32>> {
        let fts_query =
            FullTextSearchQuery::new(query.to_string()).columns(Some(vec!["content".to_string()]));

        let fts_results = table
            .query()
            .full_text_search(fts_query)
            .limit(limit)
            .execute()
            .await
            .ok()?
            .try_collect::<Vec<_>>()
            .await
            .ok()?;

        let mut scores = std::collections::HashMap::new();
        for batch in fts_results {
            let ids = batch
                .column_by_name("id")
                .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
            let bm25_scores = batch
                .column_by_name("_score")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
            for i in 0..ids.len() {
                if !ids.is_null(i) {
                    let score = bm25_scores.map(|s| s.value(i)).unwrap_or(0.0);
                    scores.insert(ids.value(i).to_string(), score);
                }
            }
        }
        Some(scores)
    }

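    /// Hybrid search: vector similarity over items and chunks, with an
    /// additive FTS boost on item matches and a project-scope boost on all
    /// matches. Results are deduplicated per item, keeping the best score.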
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let row_count = table.count_rows(None).await.unwrap_or(0);
            let base_query = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
            let query_builder = if row_count < VECTOR_INDEX_THRESHOLD {
                base_query.bypass_vector_index().limit(limit * 2)
            } else {
                base_query.refine_factor(10).limit(limit * 2)
            };

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            let mut vector_items: Vec<(Item, f32)> = Vec::new();
            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);
                    if similarity >= min_similarity {
                        vector_items.push((item, similarity));
                    }
                }
            }

            let fts_ranking = self.fts_rank_items(table, query, limit * 2).await;

            let max_bm25 = fts_ranking
                .as_ref()
                .and_then(|scores| scores.values().cloned().reduce(f32::max))
                .unwrap_or(1.0)
                .max(f32::EPSILON);

            for (item, similarity) in vector_items {
                let fts_boost = fts_ranking.as_ref().map_or(0.0, |scores| {
                    scores.get(&item.id).map_or(0.0, |&bm25_score| {
                        self.fts_boost_max * (bm25_score / max_bm25).powf(self.fts_gamma)
                    })
                });
                let boosted_similarity = boost_similarity(
                    similarity + fts_boost,
                    item.project_id.as_deref(),
                    self.project_id.as_deref(),
                );

                let result = SearchResult::from_item(&item, boosted_similarity);
                results_map
                    .entry(item.id.clone())
                    .or_insert((result, boosted_similarity));
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_row_count = chunks_table.count_rows(None).await.unwrap_or(0);
            let chunk_base_query = chunks_table.vector_search(query_embedding).map_err(|e| {
                SedimentError::Database(format!("Failed to build chunk search: {}", e))
            })?;
            let chunk_results = if chunk_row_count < VECTOR_INDEX_THRESHOLD {
                chunk_base_query.bypass_vector_index().limit(limit * 3)
            } else {
                chunk_base_query.refine_factor(10).limit(limit * 3)
            }
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to collect chunk results: {}", e))
            })?;

            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            let chunk_item_ids: Vec<&str> = chunk_matches.keys().map(|id| id.as_str()).collect();
            let parent_items = self.get_items_batch(&chunk_item_ids).await?;
            let parent_map: std::collections::HashMap<&str, &Item> = parent_items
                .iter()
                .map(|item| (item.id.as_str(), item))
                .collect();

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = parent_map.get(item_id.as_str()) {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(sr, _)| sr).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }

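    /// Embeds `content` and returns stored items whose similarity meets
    /// `min_similarity`, for conflict detection.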
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;
        self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
            .await
    }

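    /// Like `find_similar_items`, but for a precomputed embedding, optionally
    /// excluding one item ID (e.g. the freshly stored item itself).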
    pub async fn find_similar_items_by_vector(
        &self,
        embedding: &[f32],
        exclude_id: Option<&str>,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let row_count = table.count_rows(None).await.unwrap_or(0);
        let base_query = table
            .vector_search(embedding.to_vec())
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
        let results = if row_count < VECTOR_INDEX_THRESHOLD {
            base_query.bypass_vector_index().limit(limit)
        } else {
            base_query.refine_factor(10).limit(limit)
        }
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                if exclude_id.is_some_and(|eid| eid == item.id) {
                    continue;
                }

                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

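    /// Lists items by scope: the current project, global (no project), or
    /// everything. Project scope with no configured project ID yields an
    /// empty list.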
    pub async fn list_items(
        &mut self,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    if !is_valid_id(pid) {
                        return Err(SedimentError::Database(
                            "Invalid project_id for list filter".to_string(),
                        ));
                    }
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {}
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

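    /// Fetches a single item by ID; returns `None` for unknown or invalid IDs.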
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

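    /// Fetches multiple items by ID in a single `IN (...)` query, silently
    /// skipping invalid IDs.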
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

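    /// Deletes an item and its chunks. Returns `false` if the ID is invalid
    /// or not present. Chunk deletion failures are logged, not fatal.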
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        if let Some(chunks_table) = &self.chunks_table
            && let Err(e) = chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
        {
            tracing::warn!("Failed to delete chunks for item {}: {}", id, e);
        }

        Ok(true)
    }

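    /// Returns row counts for both tables.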
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

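/// Rewrites `project_id` from `old_id` to `new_id` across the `items` table,
/// returning the number of rows updated. Both IDs must pass `is_valid_id`.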
pub async fn migrate_project_id(
    db_path: &std::path::Path,
    old_id: &str,
    new_id: &str,
) -> Result<u64> {
    if !is_valid_id(old_id) || !is_valid_id(new_id) {
        return Err(SedimentError::Database(
            "Invalid project ID for migration".to_string(),
        ));
    }

    let db = connect(db_path.to_str().ok_or_else(|| {
        SedimentError::Database("Database path contains invalid UTF-8".to_string())
    })?)
    .execute()
    .await
    .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;

    let table_names = db
        .table_names()
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

    let mut total_updated = 0u64;

    if table_names.contains(&"items".to_string()) {
        let table =
            db.open_table("items").execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to open items table: {}", e))
            })?;

        let updated = table
            .update()
            .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
            .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;

        total_updated += updated;
        info!(
            "Migrated {} items from project {} to {}",
            updated, old_id, new_id
        );
    }

    Ok(total_updated)
}

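/// Scores a search hit with time decay and a usage bonus:
///
/// - `freshness = 1 / (1 + age_days / 30)`, where age is measured from
///   `last_accessed_at` when present, otherwise `created_at`;
/// - `frequency = 1 + 0.1 * ln(1 + access_count)`.
///
/// The result is `similarity * freshness * frequency`, clamped to 0.0 for
/// non-finite inputs or outputs. A worked example matching the unit tests
/// below (crate path assumed):
///
/// ```ignore
/// // 30 days old, never accessed: freshness = 0.5, frequency = 1.0.
/// let now = 1_700_000_000_i64;
/// let s = score_with_decay(0.8, now, now - 30 * 86_400, 0, None);
/// assert!((s - 0.4).abs() < 1e-3);
/// ```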
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

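/// Heuristically classifies content, checking the strictest formats first:
/// valid JSON objects/arrays, then YAML (two or more `key: value` lines, or
/// one after a `---` document marker), then Markdown headings, then common
/// code keywords at line starts, falling back to plain text.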
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

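/// Serializes one item into a single-row Arrow batch matching `item_schema`.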
fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

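/// Deserializes Arrow batch rows back into `Item`s, tolerating missing
/// optional columns.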
fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            project_id,
            is_chunked,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

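/// Serializes one chunk into a single-row Arrow batch matching `chunk_schema`.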
fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

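/// Deserializes Arrow batch rows into `Chunk`s. Embeddings are not read
/// back; `embedding` is left empty.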
fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

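/// Wraps an embedding slice in the fixed-size list layout used by the
/// `vector` columns.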
fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected  rest");
        assert_eq!(sanitize_sql_string("*/"), "");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000);

        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }

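    /// Legacy (pre-v2) items schema including the removed `tags` column, used
    /// to exercise the migration paths.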
    fn old_item_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("project_id", DataType::Utf8, true),
            Field::new("tags", DataType::Utf8, true),
            Field::new("is_chunked", DataType::Boolean, false),
            Field::new("created_at", DataType::Int64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    EMBEDDING_DIM as i32,
                ),
                false,
            ),
        ])
    }

    fn old_item_batch(id: &str, content: &str) -> RecordBatch {
        let schema = Arc::new(old_item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])), // project_id
                Arc::new(StringArray::from(vec![None::<&str>])), // tags
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_detects_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch = old_item_batch("test-id-1", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            needs_migration,
            "Old schema with tags column should need migration"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_false_for_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items", schema)
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "New schema should not need migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_migrate_schema_preserves_data() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch1 = old_item_batch("id-aaa", "first item content");
        let batch2 = old_item_batch("id-bbb", "second item content");
        let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            !needs_migration,
            "Schema should be migrated (no tags column)"
        );

        let item_a = db.get_item("id-aaa").await.unwrap();
        assert!(item_a.is_some(), "Item id-aaa should be preserved");
        assert_eq!(item_a.unwrap().content, "first item content");

        let item_b = db.get_item("id-bbb").await.unwrap();
        assert!(item_b.is_some(), "Item id-bbb should be preserved");
        assert_eq!(item_b.unwrap().content, "second item content");

        let stats = db.stats().await.unwrap();
        assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_a_only_staging() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["staging-id"])),
                Arc::new(StringArray::from(vec!["staging content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items_migrated", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("staging-id").await.unwrap();
        assert!(item.is_some(), "Item should be recovered from staging");
        assert_eq!(item.unwrap().content, "staging content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_b_both_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let old_schema = Arc::new(old_item_schema());
        let batch = old_item_batch("old-id", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "Should have migrated after recovery");

        let item = db.get_item("old-id").await.unwrap();
        assert!(
            item.is_some(),
            "Item should be preserved through recovery + migration"
        );

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_c_both_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());

        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            new_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["new-id"])),
                Arc::new(StringArray::from(vec!["new content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("new-id").await.unwrap();
        assert!(item.is_some(), "Item should be untouched");
        assert_eq!(item.unwrap().content, "new content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_list_items_rejects_invalid_project_id() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let malicious_pid = "'; DROP TABLE items;--".to_string();

        let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
            .await
            .unwrap();

        let result = db.list_items(Some(10), crate::ListScope::Project).await;

        assert!(result.is_err(), "Should reject invalid project_id");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Invalid project_id"),
            "Error should mention invalid project_id, got: {}",
            err_msg
        );
    }
}