1use std::path::PathBuf;
7use std::sync::Arc;
8
9fn sanitize_sql_string(s: &str) -> String {
19 s.replace('\0', "")
20 .replace('\\', "\\\\")
21 .replace('\'', "''")
22 .replace(';', "")
23 .replace("--", "")
24 .replace("/*", "")
25 .replace("*/", "")
26}
27
28fn is_valid_id(s: &str) -> bool {
32 !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
33}
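
// Note on the two helpers above: LanceDB filter expressions (`only_if`, `delete`,
// `update().only_if`) are built in this module by string interpolation, so IDs are first
// gated through `is_valid_id` (ASCII hex digits and dashes only, at most 64 chars, i.e.
// UUID-like) and then run through `sanitize_sql_string` as defense in depth. For example,
// sanitize_sql_string("it's; --") yields "it''s " (quote doubled, `;` and `--` stripped).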
34
35use arrow_array::{
36 Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
37 RecordBatchIterator, StringArray,
38};
39use arrow_schema::{DataType, Field, Schema};
40use chrono::{TimeZone, Utc};
41use futures::TryStreamExt;
42use lancedb::Table;
43use lancedb::connect;
44use lancedb::index::scalar::FullTextSearchQuery;
45use lancedb::query::{ExecutableQuery, QueryBase};
46use tracing::{debug, info};
47
48use crate::boost_similarity;
49use crate::chunker::{ChunkingConfig, chunk_content};
50use crate::document::ContentType;
51use crate::embedder::{EMBEDDING_DIM, Embedder};
52use crate::error::{Result, SedimentError};
53use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};
54
55const CHUNK_THRESHOLD: usize = 1000;
57
58const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;
60
61const CONFLICT_SEARCH_LIMIT: usize = 5;
63
64const MAX_CHUNKS_PER_ITEM: usize = 200;
68
69const EMBEDDING_BATCH_SIZE: usize = 32;
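
// How the constants above are used (summarized from the code below): items whose content
// exceeds CHUNK_THRESHOLD characters are additionally split and stored in the chunks table
// (at most MAX_CHUNKS_PER_ITEM chunks, embedded EMBEDDING_BATCH_SIZE texts at a time);
// after every store, neighbours with similarity >= CONFLICT_SIMILARITY_THRESHOLD (at most
// CONFLICT_SEARCH_LIMIT of them) are reported back as potential conflicts.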
72
73pub struct Database {
75 db: lancedb::Connection,
76 embedder: Arc<Embedder>,
77 project_id: Option<String>,
78 items_table: Option<Table>,
79 chunks_table: Option<Table>,
80}
81
82#[derive(Debug, Default, Clone)]
84pub struct DatabaseStats {
85 pub item_count: usize,
86 pub chunk_count: usize,
87}
88
89const SCHEMA_VERSION: i32 = 2;
91
92fn item_schema() -> Schema {
94 Schema::new(vec![
95 Field::new("id", DataType::Utf8, false),
96 Field::new("content", DataType::Utf8, false),
97 Field::new("project_id", DataType::Utf8, true),
98 Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
101 "vector",
102 DataType::FixedSizeList(
103 Arc::new(Field::new("item", DataType::Float32, true)),
104 EMBEDDING_DIM as i32,
105 ),
106 false,
107 ),
108 ])
109}
110
111fn chunk_schema() -> Schema {
112 Schema::new(vec![
113 Field::new("id", DataType::Utf8, false),
114 Field::new("item_id", DataType::Utf8, false),
115 Field::new("chunk_index", DataType::Int32, false),
116 Field::new("content", DataType::Utf8, false),
117 Field::new("context", DataType::Utf8, true),
118 Field::new(
119 "vector",
120 DataType::FixedSizeList(
121 Arc::new(Field::new("item", DataType::Float32, true)),
122 EMBEDDING_DIM as i32,
123 ),
124 false,
125 ),
126 ])
127}
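
// Both schemas above store embeddings as a FixedSizeList<Float32> of length EMBEDDING_DIM
// whose child field is named "item"; create_embedding_array() below must keep using the
// same child field name and length so the column types of inserted batches line up.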
128
129impl Database {
130 pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
132 Self::open_with_project(path, None).await
133 }
134
135 pub async fn open_with_project(
137 path: impl Into<PathBuf>,
138 project_id: Option<String>,
139 ) -> Result<Self> {
140 let embedder = Arc::new(Embedder::new()?);
141 Self::open_with_embedder(path, project_id, embedder).await
142 }
143
144 pub async fn open_with_embedder(
156 path: impl Into<PathBuf>,
157 project_id: Option<String>,
158 embedder: Arc<Embedder>,
159 ) -> Result<Self> {
160 let path = path.into();
161 info!("Opening database at {:?}", path);
162
163 if let Some(parent) = path.parent() {
165 std::fs::create_dir_all(parent).map_err(|e| {
166 SedimentError::Database(format!("Failed to create database directory: {}", e))
167 })?;
168 }
169
170 let db = connect(path.to_str().ok_or_else(|| {
171 SedimentError::Database("Database path contains invalid UTF-8".to_string())
172 })?)
173 .execute()
174 .await
175 .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;
176
177 let mut database = Self {
178 db,
179 embedder,
180 project_id,
181 items_table: None,
182 chunks_table: None,
183 };
184
185 database.ensure_tables().await?;
186 database.ensure_vector_index().await?;
187
188 Ok(database)
189 }
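
    // Example usage (sketch; `Item::new` is assumed to be the constructor exposed by the
    // `item` module, and a Tokio runtime plus a local embedding model are assumed):
    //
    //     let mut db = Database::open_with_project("./data/db", Some("abc123".into())).await?;
    //     let stored = db.store_item(Item::new("favourite port is 5433")).await?;
    //     let hits = db.search_items("which port?", 5, ItemFilters::new()).await?;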
190
191 pub fn set_project_id(&mut self, project_id: Option<String>) {
193 self.project_id = project_id;
194 }
195
196 pub fn project_id(&self) -> Option<&str> {
198 self.project_id.as_deref()
199 }
200
201 async fn ensure_tables(&mut self) -> Result<()> {
203 let mut table_names = self
205 .db
206 .table_names()
207 .execute()
208 .await
209 .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
210
211 if table_names.contains(&"items_migrated".to_string()) {
213 info!("Detected interrupted migration, recovering...");
214 self.recover_interrupted_migration(&table_names).await?;
215 table_names =
217 self.db.table_names().execute().await.map_err(|e| {
218 SedimentError::Database(format!("Failed to list tables: {}", e))
219 })?;
220 }
221
222 if table_names.contains(&"items".to_string()) {
224 let needs_migration = self.check_needs_migration().await?;
225 if needs_migration {
226 info!("Migrating database schema to version {}", SCHEMA_VERSION);
227 self.migrate_schema().await?;
228 }
229 }
230
231 if table_names.contains(&"items".to_string()) {
233 self.items_table =
234 Some(self.db.open_table("items").execute().await.map_err(|e| {
235 SedimentError::Database(format!("Failed to open items: {}", e))
236 })?);
237 }
238
239 if table_names.contains(&"chunks".to_string()) {
241 self.chunks_table =
242 Some(self.db.open_table("chunks").execute().await.map_err(|e| {
243 SedimentError::Database(format!("Failed to open chunks: {}", e))
244 })?);
245 }
246
247 Ok(())
248 }
249
250 async fn check_needs_migration(&self) -> Result<bool> {
252 let table = self.db.open_table("items").execute().await.map_err(|e| {
253 SedimentError::Database(format!("Failed to open items for check: {}", e))
254 })?;
255
256 let schema = table
257 .schema()
258 .await
259 .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;
260
261 let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
263 Ok(has_tags)
264 }
265
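    // An interrupted schema migration leaves an `items_migrated` staging table behind.
    // Recovery distinguishes three cases: (A) `items` is gone, so the staging table holds
    // the data and is copied back; (B) `items` still has the old schema, so the stale
    // staging table is dropped and ensure_tables() re-runs the migration; (C) both tables
    // are already on the new schema, so the leftover staging table is simply dropped.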
266 async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
277 let has_items = table_names.contains(&"items".to_string());
278
279 if !has_items {
280 info!("Recovery case A: restoring items from items_migrated");
282 let staging = self
283 .db
284 .open_table("items_migrated")
285 .execute()
286 .await
287 .map_err(|e| {
288 SedimentError::Database(format!("Failed to open staging table: {}", e))
289 })?;
290
291 let results = staging
292 .query()
293 .execute()
294 .await
295 .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
296 .try_collect::<Vec<_>>()
297 .await
298 .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;
299
300 let schema = Arc::new(item_schema());
301 let new_table = self
302 .db
303 .create_empty_table("items", schema.clone())
304 .execute()
305 .await
306 .map_err(|e| {
307 SedimentError::Database(format!("Failed to create items table: {}", e))
308 })?;
309
310 if !results.is_empty() {
311 let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
312 new_table
313 .add(Box::new(batches))
314 .execute()
315 .await
316 .map_err(|e| {
317 SedimentError::Database(format!("Failed to restore items: {}", e))
318 })?;
319 }
320
321 self.db.drop_table("items_migrated").await.map_err(|e| {
322 SedimentError::Database(format!("Failed to drop staging table: {}", e))
323 })?;
324 info!("Recovery case A completed");
325 } else {
326 let has_old_schema = self.check_needs_migration().await?;
328
329 if has_old_schema {
330 info!("Recovery case B: dropping incomplete staging table");
332 self.db.drop_table("items_migrated").await.map_err(|e| {
333 SedimentError::Database(format!("Failed to drop staging table: {}", e))
334 })?;
335 } else {
337 info!("Recovery case C: dropping leftover staging table");
339 self.db.drop_table("items_migrated").await.map_err(|e| {
340 SedimentError::Database(format!("Failed to drop staging table: {}", e))
341 })?;
342 }
343 }
344
345 Ok(())
346 }
347
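    // Migration to SCHEMA_VERSION 2 (dropping the old `tags` column) is copy-on-write:
    // read every row from `items`, convert it, write it into an `items_migrated` staging
    // table, verify the row counts match, and only then drop and rebuild `items` from the
    // staging copy. A crash at any point is handled by recover_interrupted_migration().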
348 async fn migrate_schema(&mut self) -> Result<()> {
362 info!("Starting schema migration...");
363
364 let old_table = self
366 .db
367 .open_table("items")
368 .execute()
369 .await
370 .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;
371
372 let results = old_table
373 .query()
374 .execute()
375 .await
376 .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
377 .try_collect::<Vec<_>>()
378 .await
379 .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;
380
381 let mut new_batches = Vec::new();
383 for batch in &results {
384 let converted = self.convert_batch_to_new_schema(batch)?;
385 new_batches.push(converted);
386 }
387
388 let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
390 let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
391 if old_count != new_count {
392 return Err(SedimentError::Database(format!(
393 "Migration row count mismatch: old={}, new={}",
394 old_count, new_count
395 )));
396 }
397 info!("Migrating {} items to new schema", old_count);
398
399 let table_names = self
401 .db
402 .table_names()
403 .execute()
404 .await
405 .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
406 if table_names.contains(&"items_migrated".to_string()) {
407 self.db.drop_table("items_migrated").await.map_err(|e| {
408 SedimentError::Database(format!("Failed to drop stale staging: {}", e))
409 })?;
410 }
411
412 let schema = Arc::new(item_schema());
414 let staging_table = self
415 .db
416 .create_empty_table("items_migrated", schema.clone())
417 .execute()
418 .await
419 .map_err(|e| {
420 SedimentError::Database(format!("Failed to create staging table: {}", e))
421 })?;
422
423 if !new_batches.is_empty() {
424 let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
425 staging_table
426 .add(Box::new(batches))
427 .execute()
428 .await
429 .map_err(|e| {
430 SedimentError::Database(format!("Failed to insert into staging: {}", e))
431 })?;
432 }
433
434 let staging_count = staging_table
436 .count_rows(None)
437 .await
438 .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
439 if staging_count != old_count {
440 let _ = self.db.drop_table("items_migrated").await;
442 return Err(SedimentError::Database(format!(
443 "Staging row count mismatch: expected {}, got {}",
444 old_count, staging_count
445 )));
446 }
447
448 self.db.drop_table("items").await.map_err(|e| {
450 SedimentError::Database(format!("Failed to drop old items table: {}", e))
451 })?;
452
453 let staging_data = staging_table
455 .query()
456 .execute()
457 .await
458 .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
459 .try_collect::<Vec<_>>()
460 .await
461 .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;
462
463 let new_table = self
464 .db
465 .create_empty_table("items", schema.clone())
466 .execute()
467 .await
468 .map_err(|e| {
469 SedimentError::Database(format!("Failed to create new items table: {}", e))
470 })?;
471
472 if !staging_data.is_empty() {
473 let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
474 new_table
475 .add(Box::new(batches))
476 .execute()
477 .await
478 .map_err(|e| {
479 SedimentError::Database(format!("Failed to insert migrated items: {}", e))
480 })?;
481 }
482
483 self.db
485 .drop_table("items_migrated")
486 .await
487 .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;
488
489 info!("Schema migration completed successfully");
490 Ok(())
491 }
492
493 fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
495 let schema = Arc::new(item_schema());
496
497 let id_col = batch
499 .column_by_name("id")
500 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
501 .clone();
502
503 let content_col = batch
504 .column_by_name("content")
505 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
506 .clone();
507
508 let project_id_col = batch
509 .column_by_name("project_id")
510 .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
511 .clone();
512
513 let is_chunked_col = batch
514 .column_by_name("is_chunked")
515 .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
516 .clone();
517
518 let created_at_col = batch
519 .column_by_name("created_at")
520 .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
521 .clone();
522
523 let vector_col = batch
524 .column_by_name("vector")
525 .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
526 .clone();
527
528 RecordBatch::try_new(
529 schema,
530 vec![
531 id_col,
532 content_col,
533 project_id_col,
534 is_chunked_col,
535 created_at_col,
536 vector_col,
537 ],
538 )
539 .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
540 }
541
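    // Index maintenance is best-effort: an ANN index on `vector` is only worth building
    // once a table has a few hundred rows (MIN_ROWS_FOR_INDEX below), while the full-text
    // index on `content` is created as soon as any rows exist. Failures are logged as
    // warnings rather than propagated, since brute-force search still works without them.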
542 async fn ensure_vector_index(&self) -> Result<()> {
548 const MIN_ROWS_FOR_INDEX: usize = 256;
549
550 for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
551 if let Some(table) = table_opt {
552 let row_count = table.count_rows(None).await.unwrap_or(0);
553
554 let indices = table.list_indices().await.unwrap_or_default();
556
557 if row_count >= MIN_ROWS_FOR_INDEX {
559 let has_vector_index = indices
560 .iter()
561 .any(|idx| idx.columns.contains(&"vector".to_string()));
562
563 if !has_vector_index {
564 info!(
565 "Creating vector index on {} table ({} rows)",
566 name, row_count
567 );
568 match table
569 .create_index(&["vector"], lancedb::index::Index::Auto)
570 .execute()
571 .await
572 {
573 Ok(_) => info!("Vector index created on {} table", name),
574 Err(e) => {
575 tracing::warn!("Failed to create vector index on {}: {}", name, e);
577 }
578 }
579 }
580 }
581
582 if row_count > 0 {
584 let has_fts_index = indices
585 .iter()
586 .any(|idx| idx.columns.contains(&"content".to_string()));
587
588 if !has_fts_index {
589 info!("Creating FTS index on {} table ({} rows)", name, row_count);
590 match table
591 .create_index(
592 &["content"],
593 lancedb::index::Index::FTS(Default::default()),
594 )
595 .execute()
596 .await
597 {
598 Ok(_) => info!("FTS index created on {} table", name),
599 Err(e) => {
600 tracing::warn!("Failed to create FTS index on {}: {}", name, e);
602 }
603 }
604 }
605 }
606 }
607 }
608
609 Ok(())
610 }
611
612 async fn get_items_table(&mut self) -> Result<&Table> {
614 if self.items_table.is_none() {
615 let schema = Arc::new(item_schema());
616 let table = self
617 .db
618 .create_empty_table("items", schema)
619 .execute()
620 .await
621 .map_err(|e| {
622 SedimentError::Database(format!("Failed to create items table: {}", e))
623 })?;
624 self.items_table = Some(table);
625 }
626 Ok(self.items_table.as_ref().unwrap())
627 }
628
629 async fn get_chunks_table(&mut self) -> Result<&Table> {
631 if self.chunks_table.is_none() {
632 let schema = Arc::new(chunk_schema());
633 let table = self
634 .db
635 .create_empty_table("chunks", schema)
636 .execute()
637 .await
638 .map_err(|e| {
639 SedimentError::Database(format!("Failed to create chunks table: {}", e))
640 })?;
641 self.chunks_table = Some(table);
642 }
643 Ok(self.chunks_table.as_ref().unwrap())
644 }
645
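    // store_item() embeds the whole item, inserts it, and, for content longer than
    // CHUNK_THRESHOLD characters, also stores per-chunk rows with their own embeddings.
    // If chunk storage fails, the freshly inserted item row is rolled back so the two
    // tables stay consistent. The returned StoreResult carries any near-duplicates found
    // above CONFLICT_SIMILARITY_THRESHOLD so callers can surface potential conflicts.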
646 pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
653 if item.project_id.is_none() {
655 item.project_id = self.project_id.clone();
656 }
657
658 let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
661 item.is_chunked = should_chunk;
662
663 let embedding_text = item.embedding_text();
665 let embedding = self.embedder.embed(&embedding_text)?;
666 item.embedding = embedding;
667
668 let table = self.get_items_table().await?;
670 let batch = item_to_batch(&item)?;
671 let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));
672
673 table
674 .add(Box::new(batches))
675 .execute()
676 .await
677 .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;
678
679 if should_chunk {
681 let content_type = detect_content_type(&item.content);
682 let config = ChunkingConfig::default();
683 let mut chunk_results = chunk_content(&item.content, content_type, &config);
684
685 if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
687 tracing::warn!(
688 "Chunk count {} exceeds limit {}, truncating",
689 chunk_results.len(),
690 MAX_CHUNKS_PER_ITEM
691 );
692 chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
693 }
694
695 if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
696 let _ = self.delete_item(&item.id).await;
698 return Err(e);
699 }
700
701 debug!(
702 "Stored item: {} with {} chunks",
703 item.id,
704 chunk_results.len()
705 );
706 } else {
707 debug!("Stored item: {} (no chunking)", item.id);
708 }
709
710 let potential_conflicts = self
713 .find_similar_items_by_vector(
714 &item.embedding,
715 Some(&item.id),
716 CONFLICT_SIMILARITY_THRESHOLD,
717 CONFLICT_SEARCH_LIMIT,
718 )
719 .await
720 .unwrap_or_default();
721
722 Ok(StoreResult {
723 id: item.id,
724 potential_conflicts,
725 })
726 }
727
728 async fn store_chunks(
730 &mut self,
731 item_id: &str,
732 chunk_results: &[crate::chunker::ChunkResult],
733 ) -> Result<()> {
734 let embedder = self.embedder.clone();
735 let chunks_table = self.get_chunks_table().await?;
736
737 let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
739 let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
740 for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
741 let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
742 let batch_embeddings = embedder.embed_batch(&chunk_texts[batch_start..batch_end])?;
743 all_embeddings.extend(batch_embeddings);
744 }
745
746 let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
748 for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
749 let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
750 if let Some(ctx) = &chunk_result.context {
751 chunk = chunk.with_context(ctx);
752 }
753 chunk.embedding = embedding;
754 all_chunk_batches.push(chunk_to_batch(&chunk)?);
755 }
756
757 if !all_chunk_batches.is_empty() {
759 let schema = Arc::new(chunk_schema());
760 let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
761 chunks_table
762 .add(Box::new(batches))
763 .execute()
764 .await
765 .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
766 }
767
768 Ok(())
769 }
770
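    // fts_rank_items() runs a full-text query over `content` and returns each hit's rank;
    // search_items() below folds that rank into the vector score as a small additive
    // boost (0.08 / (1 + rank)), so exact keyword matches bubble up without overwhelming
    // semantic similarity. Any FTS failure simply disables the boost (hence the Option).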
771 async fn fts_rank_items(
774 &self,
775 table: &Table,
776 query: &str,
777 limit: usize,
778 ) -> Option<std::collections::HashMap<String, usize>> {
779 let fts_query =
780 FullTextSearchQuery::new(query.to_string()).columns(Some(vec!["content".to_string()]));
781
782 let fts_results = table
783 .query()
784 .full_text_search(fts_query)
785 .limit(limit)
786 .execute()
787 .await
788 .ok()?
789 .try_collect::<Vec<_>>()
790 .await
791 .ok()?;
792
793 let mut ranks = std::collections::HashMap::new();
794 let mut rank = 0usize;
795 for batch in fts_results {
796 let ids = batch
797 .column_by_name("id")
798 .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
799 for i in 0..ids.len() {
800 if !ids.is_null(i) {
801 ranks.insert(ids.value(i).to_string(), rank);
802 rank += 1;
803 }
804 }
805 }
806 Some(ranks)
807 }
808
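    // search_items() is a hybrid search: vector search over whole items and over chunks
    // (keeping only the best-scoring chunk per item as an excerpt), an optional full-text
    // boost, and a project-affinity boost via boost_similarity(). Small tables bypass the
    // ANN index and scan exactly; larger ones use the index with a refine factor.
    // Similarity is mapped from the returned `_distance` as 1 / (1 + distance), and hits
    // below min_similarity (default 0.3) are dropped before the final sort and truncation.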
809 pub async fn search_items(
811 &mut self,
812 query: &str,
813 limit: usize,
814 filters: ItemFilters,
815 ) -> Result<Vec<SearchResult>> {
816 let limit = limit.min(1000);
818 self.ensure_vector_index().await?;
820
821 let query_embedding = self.embedder.embed(query)?;
823 let min_similarity = filters.min_similarity.unwrap_or(0.3);
824
825 let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
827 std::collections::HashMap::new();
828
829 if let Some(table) = &self.items_table {
831 let row_count = table.count_rows(None).await.unwrap_or(0);
832 let base_query = table
833 .vector_search(query_embedding.clone())
834 .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
835 let query_builder = if row_count < 5000 {
836 base_query.bypass_vector_index().limit(limit * 2)
837 } else {
838 base_query.refine_factor(10).limit(limit * 2)
839 };
840
841 let results = query_builder
842 .execute()
843 .await
844 .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
845 .try_collect::<Vec<_>>()
846 .await
847 .map_err(|e| {
848 SedimentError::Database(format!("Failed to collect results: {}", e))
849 })?;
850
851 let mut vector_items: Vec<(Item, f32)> = Vec::new();
853 for batch in results {
854 let items = batch_to_items(&batch)?;
855 let distances = batch
856 .column_by_name("_distance")
857 .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
858
859 for (i, item) in items.into_iter().enumerate() {
860 let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
861 let similarity = 1.0 / (1.0 + distance);
862 if similarity >= min_similarity {
863 vector_items.push((item, similarity));
864 }
865 }
866 }
867
868 let fts_ranking = self.fts_rank_items(table, query, limit * 2).await;
870
871 for (item, similarity) in vector_items {
875 let fts_boost = fts_ranking.as_ref().map_or(0.0, |ranks| {
876 ranks
877 .get(&item.id)
878 .map_or(0.0, |&fts_rank| 0.08 / (1.0 + fts_rank as f32))
879 });
880 let boosted_similarity = boost_similarity(
881 similarity + fts_boost,
882 item.project_id.as_deref(),
883 self.project_id.as_deref(),
884 );
885
886 let result = SearchResult::from_item(&item, boosted_similarity);
887 results_map
888 .entry(item.id.clone())
889 .or_insert((result, boosted_similarity));
890 }
891 }
892
893 if let Some(chunks_table) = &self.chunks_table {
895 let chunk_row_count = chunks_table.count_rows(None).await.unwrap_or(0);
896 let chunk_base_query = chunks_table.vector_search(query_embedding).map_err(|e| {
897 SedimentError::Database(format!("Failed to build chunk search: {}", e))
898 })?;
899 let chunk_results = if chunk_row_count < 5000 {
900 chunk_base_query.bypass_vector_index().limit(limit * 3)
901 } else {
902 chunk_base_query.refine_factor(10).limit(limit * 3)
903 }
904 .execute()
905 .await
906 .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
907 .try_collect::<Vec<_>>()
908 .await
909 .map_err(|e| {
910 SedimentError::Database(format!("Failed to collect chunk results: {}", e))
911 })?;
912
913 let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
915 std::collections::HashMap::new();
916
917 for batch in chunk_results {
918 let chunks = batch_to_chunks(&batch)?;
919 let distances = batch
920 .column_by_name("_distance")
921 .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
922
923 for (i, chunk) in chunks.into_iter().enumerate() {
924 let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
925 let similarity = 1.0 / (1.0 + distance);
926
927 if similarity < min_similarity {
928 continue;
929 }
930
931 chunk_matches
933 .entry(chunk.item_id.clone())
934 .and_modify(|(content, best_sim)| {
935 if similarity > *best_sim {
936 *content = chunk.content.clone();
937 *best_sim = similarity;
938 }
939 })
940 .or_insert((chunk.content.clone(), similarity));
941 }
942 }
943
944 for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
946 if let Some(item) = self.get_item(&item_id).await? {
947 let boosted_similarity = boost_similarity(
949 chunk_similarity,
950 item.project_id.as_deref(),
951 self.project_id.as_deref(),
952 );
953
954 let result =
955 SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);
956
957 results_map
959 .entry(item_id)
960 .and_modify(|(existing, existing_sim)| {
961 if boosted_similarity > *existing_sim {
962 *existing = result.clone();
963 *existing_sim = boosted_similarity;
964 }
965 })
966 .or_insert((result, boosted_similarity));
967 }
968 }
969 }
970
971 let mut search_results: Vec<SearchResult> =
974 results_map.into_values().map(|(sr, _)| sr).collect();
975 search_results.sort_by(|a, b| {
976 b.similarity
977 .partial_cmp(&a.similarity)
978 .unwrap_or(std::cmp::Ordering::Equal)
979 });
980 search_results.truncate(limit);
981
982 Ok(search_results)
983 }
984
985 pub async fn find_similar_items(
990 &mut self,
991 content: &str,
992 min_similarity: f32,
993 limit: usize,
994 ) -> Result<Vec<ConflictInfo>> {
995 let embedding = self.embedder.embed(content)?;
996 self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
997 .await
998 }
999
1000 pub async fn find_similar_items_by_vector(
1004 &self,
1005 embedding: &[f32],
1006 exclude_id: Option<&str>,
1007 min_similarity: f32,
1008 limit: usize,
1009 ) -> Result<Vec<ConflictInfo>> {
1010 let table = match &self.items_table {
1011 Some(t) => t,
1012 None => return Ok(Vec::new()),
1013 };
1014
1015 let row_count = table.count_rows(None).await.unwrap_or(0);
1016 let base_query = table
1017 .vector_search(embedding.to_vec())
1018 .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
1019 let results = if row_count < 5000 {
1020 base_query.bypass_vector_index().limit(limit)
1021 } else {
1022 base_query.refine_factor(10).limit(limit)
1023 }
1024 .execute()
1025 .await
1026 .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
1027 .try_collect::<Vec<_>>()
1028 .await
1029 .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;
1030
1031 let mut conflicts = Vec::new();
1032
1033 for batch in results {
1034 let items = batch_to_items(&batch)?;
1035 let distances = batch
1036 .column_by_name("_distance")
1037 .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
1038
1039 for (i, item) in items.into_iter().enumerate() {
1040 if exclude_id.is_some_and(|eid| eid == item.id) {
1041 continue;
1042 }
1043
1044 let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
1045 let similarity = 1.0 / (1.0 + distance);
1046
1047 if similarity >= min_similarity {
1048 conflicts.push(ConflictInfo {
1049 id: item.id,
1050 content: item.content,
1051 similarity,
1052 });
1053 }
1054 }
1055 }
1056
1057 conflicts.sort_by(|a, b| {
1059 b.similarity
1060 .partial_cmp(&a.similarity)
1061 .unwrap_or(std::cmp::Ordering::Equal)
1062 });
1063
1064 Ok(conflicts)
1065 }
1066
1067 pub async fn list_items(
1069 &mut self,
1070 _filters: ItemFilters,
1071 limit: Option<usize>,
1072 scope: crate::ListScope,
1073 ) -> Result<Vec<Item>> {
1074 let table = match &self.items_table {
1075 Some(t) => t,
1076 None => return Ok(Vec::new()),
1077 };
1078
1079 let mut filter_parts = Vec::new();
1080
1081 match scope {
1083 crate::ListScope::Project => {
1084 if let Some(ref pid) = self.project_id {
1085 if !is_valid_id(pid) {
1086 return Err(SedimentError::Database(
1087 "Invalid project_id for list filter".to_string(),
1088 ));
1089 }
1090 filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
1091 } else {
1092 return Ok(Vec::new());
1094 }
1095 }
1096 crate::ListScope::Global => {
1097 filter_parts.push("project_id IS NULL".to_string());
1098 }
1099 crate::ListScope::All => {
1100 }
1102 }
1103
1104 let mut query = table.query();
1105
1106 if !filter_parts.is_empty() {
1107 let filter_str = filter_parts.join(" AND ");
1108 query = query.only_if(filter_str);
1109 }
1110
1111 if let Some(l) = limit {
1112 query = query.limit(l);
1113 }
1114
1115 let results = query
1116 .execute()
1117 .await
1118 .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
1119 .try_collect::<Vec<_>>()
1120 .await
1121 .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
1122
1123 let mut items = Vec::new();
1124 for batch in results {
1125 items.extend(batch_to_items(&batch)?);
1126 }
1127
1128 Ok(items)
1129 }
1130
1131 pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
1133 if !is_valid_id(id) {
1134 return Ok(None);
1135 }
1136 let table = match &self.items_table {
1137 Some(t) => t,
1138 None => return Ok(None),
1139 };
1140
1141 let results = table
1142 .query()
1143 .only_if(format!("id = '{}'", sanitize_sql_string(id)))
1144 .limit(1)
1145 .execute()
1146 .await
1147 .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
1148 .try_collect::<Vec<_>>()
1149 .await
1150 .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
1151
1152 for batch in results {
1153 let items = batch_to_items(&batch)?;
1154 if let Some(item) = items.into_iter().next() {
1155 return Ok(Some(item));
1156 }
1157 }
1158
1159 Ok(None)
1160 }
1161
1162 pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
1164 let table = match &self.items_table {
1165 Some(t) => t,
1166 None => return Ok(Vec::new()),
1167 };
1168
1169 if ids.is_empty() {
1170 return Ok(Vec::new());
1171 }
1172
1173 let quoted: Vec<String> = ids
1174 .iter()
1175 .filter(|id| is_valid_id(id))
1176 .map(|id| format!("'{}'", sanitize_sql_string(id)))
1177 .collect();
1178 if quoted.is_empty() {
1179 return Ok(Vec::new());
1180 }
1181 let filter = format!("id IN ({})", quoted.join(", "));
1182
1183 let results = table
1184 .query()
1185 .only_if(filter)
1186 .execute()
1187 .await
1188 .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
1189 .try_collect::<Vec<_>>()
1190 .await
1191 .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;
1192
1193 let mut items = Vec::new();
1194 for batch in results {
1195 items.extend(batch_to_items(&batch)?);
1196 }
1197
1198 Ok(items)
1199 }
1200
1201 pub async fn delete_item(&self, id: &str) -> Result<bool> {
1204 if !is_valid_id(id) {
1205 return Ok(false);
1206 }
1207 let table = match &self.items_table {
1209 Some(t) => t,
1210 None => return Ok(false),
1211 };
1212
1213 let exists = self.get_item(id).await?.is_some();
1214 if !exists {
1215 return Ok(false);
1216 }
1217
1218 if let Some(chunks_table) = &self.chunks_table {
1220 chunks_table
1221 .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
1222 .await
1223 .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
1224 }
1225
1226 table
1228 .delete(&format!("id = '{}'", sanitize_sql_string(id)))
1229 .await
1230 .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;
1231
1232 Ok(true)
1233 }
1234
1235 pub async fn stats(&self) -> Result<DatabaseStats> {
1237 let mut stats = DatabaseStats::default();
1238
1239 if let Some(table) = &self.items_table {
1240 stats.item_count = table
1241 .count_rows(None)
1242 .await
1243 .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
1244 }
1245
1246 if let Some(table) = &self.chunks_table {
1247 stats.chunk_count = table
1248 .count_rows(None)
1249 .await
1250 .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
1251 }
1252
1253 Ok(stats)
1254 }
1255}
1256
1257pub async fn migrate_project_id(
1264 db_path: &std::path::Path,
1265 old_id: &str,
1266 new_id: &str,
1267) -> Result<u64> {
1268 if !is_valid_id(old_id) || !is_valid_id(new_id) {
1269 return Err(SedimentError::Database(
1270 "Invalid project ID for migration".to_string(),
1271 ));
1272 }
1273
1274 let db = connect(db_path.to_str().ok_or_else(|| {
1275 SedimentError::Database("Database path contains invalid UTF-8".to_string())
1276 })?)
1277 .execute()
1278 .await
1279 .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;
1280
1281 let table_names = db
1282 .table_names()
1283 .execute()
1284 .await
1285 .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
1286
1287 let mut total_updated = 0u64;
1288
1289 if table_names.contains(&"items".to_string()) {
1290 let table =
1291 db.open_table("items").execute().await.map_err(|e| {
1292 SedimentError::Database(format!("Failed to open items table: {}", e))
1293 })?;
1294
1295 let updated = table
1296 .update()
1297 .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
1298 .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
1299 .execute()
1300 .await
1301 .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;
1302
1303 total_updated += updated;
1304 info!(
1305 "Migrated {} items from project {} to {}",
1306 updated, old_id, new_id
1307 );
1308 }
1309
1310 Ok(total_updated)
1311}
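
// Note: only the `items` table is touched above; chunk rows are keyed by item_id and do
// not store a project_id column (see chunk_schema), so they need no rewrite when a
// project is re-identified.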
1312
1313pub fn score_with_decay(
1324 similarity: f32,
1325 now: i64,
1326 created_at: i64,
1327 access_count: u32,
1328 last_accessed_at: Option<i64>,
1329) -> f32 {
1330 if !similarity.is_finite() {
1332 return 0.0;
1333 }
1334
1335 let reference_time = last_accessed_at.unwrap_or(created_at);
1336 let age_secs = (now - reference_time).max(0) as f64;
1337 let age_days = age_secs / 86400.0;
1338
1339 let freshness = 1.0 / (1.0 + age_days / 30.0);
1340 let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();
1341
1342 let result = similarity * (freshness * frequency) as f32;
1343 if result.is_finite() { result } else { 0.0 }
1344}
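
// Worked example for the decay above: an item last touched 60 days ago with 4 recorded
// accesses and raw similarity 0.9 scores roughly
//   0.9 * (1 / (1 + 60/30)) * (1 + 0.1 * ln(5)) ≈ 0.9 * 0.333 * 1.161 ≈ 0.35,
// i.e. recency dominates and access frequency only nudges the result upward.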
1345
1346fn detect_content_type(content: &str) -> ContentType {
1350 let trimmed = content.trim();
1351
1352 if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
1354 || (trimmed.starts_with('[') && trimmed.ends_with(']')))
1355 && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
1356 {
1357 return ContentType::Json;
1358 }
1359
1360 if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
1364 let lines: Vec<&str> = trimmed.lines().take(10).collect();
1365 let yaml_key_count = lines
1366 .iter()
1367 .filter(|line| {
1368 let l = line.trim();
1369 !l.is_empty()
1372 && !l.starts_with('#')
1373 && !l.contains("://")
1374 && l.contains(": ")
1375 && l.split(": ").next().is_some_and(|key| {
1376 let k = key.trim_start_matches("- ");
1377 !k.is_empty()
1378 && k.chars()
1379 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
1380 })
1381 })
1382 .count();
1383 if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
1385 return ContentType::Yaml;
1386 }
1387 }
1388
1389 if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
1391 return ContentType::Markdown;
1392 }
1393
1394 let code_patterns = [
1397 "fn ",
1398 "pub fn ",
1399 "def ",
1400 "class ",
1401 "function ",
1402 "const ",
1403 "let ",
1404 "var ",
1405 "import ",
1406 "export ",
1407 "struct ",
1408 "impl ",
1409 "trait ",
1410 ];
1411 let has_code_pattern = trimmed.lines().any(|line| {
1412 let l = line.trim();
1413 code_patterns.iter().any(|p| l.starts_with(p))
1414 });
1415 if has_code_pattern {
1416 return ContentType::Code;
1417 }
1418
1419 ContentType::Text
1420}
1421
1422fn item_to_batch(item: &Item) -> Result<RecordBatch> {
1425 let schema = Arc::new(item_schema());
1426
1427 let id = StringArray::from(vec![item.id.as_str()]);
1428 let content = StringArray::from(vec![item.content.as_str()]);
1429 let project_id = StringArray::from(vec![item.project_id.as_deref()]);
1430 let is_chunked = BooleanArray::from(vec![item.is_chunked]);
1431 let created_at = Int64Array::from(vec![item.created_at.timestamp()]);
1432
1433 let vector = create_embedding_array(&item.embedding)?;
1434
1435 RecordBatch::try_new(
1436 schema,
1437 vec![
1438 Arc::new(id),
1439 Arc::new(content),
1440 Arc::new(project_id),
1441 Arc::new(is_chunked),
1442 Arc::new(created_at),
1443 Arc::new(vector),
1444 ],
1445 )
1446 .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1447}
1448
1449fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
1450 let mut items = Vec::new();
1451
1452 let id_col = batch
1453 .column_by_name("id")
1454 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1455 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1456
1457 let content_col = batch
1458 .column_by_name("content")
1459 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1460 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1461
1462 let project_id_col = batch
1463 .column_by_name("project_id")
1464 .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1465
1466 let is_chunked_col = batch
1467 .column_by_name("is_chunked")
1468 .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());
1469
1470 let created_at_col = batch
1471 .column_by_name("created_at")
1472 .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
1473
1474 let vector_col = batch
1475 .column_by_name("vector")
1476 .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());
1477
1478 for i in 0..batch.num_rows() {
1479 let id = id_col.value(i).to_string();
1480 let content = content_col.value(i).to_string();
1481
1482 let project_id = project_id_col.and_then(|c| {
1483 if c.is_null(i) {
1484 None
1485 } else {
1486 Some(c.value(i).to_string())
1487 }
1488 });
1489
1490 let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);
1491
1492 let created_at = created_at_col
1493 .map(|c| {
1494 Utc.timestamp_opt(c.value(i), 0)
1495 .single()
1496 .unwrap_or_else(Utc::now)
1497 })
1498 .unwrap_or_else(Utc::now);
1499
1500 let embedding = vector_col
1501 .and_then(|col| {
1502 let value = col.value(i);
1503 value
1504 .as_any()
1505 .downcast_ref::<Float32Array>()
1506 .map(|arr| arr.values().to_vec())
1507 })
1508 .unwrap_or_default();
1509
1510 let item = Item {
1511 id,
1512 content,
1513 embedding,
1514 project_id,
1515 is_chunked,
1516 created_at,
1517 };
1518
1519 items.push(item);
1520 }
1521
1522 Ok(items)
1523}
1524
1525fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
1526 let schema = Arc::new(chunk_schema());
1527
1528 let id = StringArray::from(vec![chunk.id.as_str()]);
1529 let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
1530 let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
1531 let content = StringArray::from(vec![chunk.content.as_str()]);
1532 let context = StringArray::from(vec![chunk.context.as_deref()]);
1533
1534 let vector = create_embedding_array(&chunk.embedding)?;
1535
1536 RecordBatch::try_new(
1537 schema,
1538 vec![
1539 Arc::new(id),
1540 Arc::new(item_id),
1541 Arc::new(chunk_index),
1542 Arc::new(content),
1543 Arc::new(context),
1544 Arc::new(vector),
1545 ],
1546 )
1547 .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1548}
1549
1550fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
1551 let mut chunks = Vec::new();
1552
1553 let id_col = batch
1554 .column_by_name("id")
1555 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1556 .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1557
1558 let item_id_col = batch
1559 .column_by_name("item_id")
1560 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1561 .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;
1562
1563 let chunk_index_col = batch
1564 .column_by_name("chunk_index")
1565 .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
1566 .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;
1567
1568 let content_col = batch
1569 .column_by_name("content")
1570 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1571 .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1572
1573 let context_col = batch
1574 .column_by_name("context")
1575 .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1576
1577 for i in 0..batch.num_rows() {
1578 let id = id_col.value(i).to_string();
1579 let item_id = item_id_col.value(i).to_string();
1580 let chunk_index = chunk_index_col.value(i) as usize;
1581 let content = content_col.value(i).to_string();
1582 let context = context_col.and_then(|c| {
1583 if c.is_null(i) {
1584 None
1585 } else {
1586 Some(c.value(i).to_string())
1587 }
1588 });
1589
1590 let chunk = Chunk {
1591 id,
1592 item_id,
1593 chunk_index,
1594 content,
1595 embedding: Vec::new(),
1596 context,
1597 };
1598
1599 chunks.push(chunk);
1600 }
1601
1602 Ok(chunks)
1603}
1604
1605fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
1606 let values = Float32Array::from(embedding.to_vec());
1607 let field = Arc::new(Field::new("item", DataType::Float32, true));
1608
1609 FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
1610 .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
1611}
1612
1613#[cfg(test)]
1614mod tests {
1615 use super::*;
1616
1617 #[test]
1618 fn test_score_with_decay_fresh_item() {
1619 let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
1622 let expected = 0.8 * 1.0 * 1.0;
1624 assert!((score - expected).abs() < 0.001, "got {}", score);
1625 }
1626
1627 #[test]
1628 fn test_score_with_decay_30_day_old() {
1629 let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
1632 let expected = 0.8 * 0.5;
1634 assert!((score - expected).abs() < 0.001, "got {}", score);
1635 }
1636
1637 #[test]
1638 fn test_score_with_decay_frequent_access() {
1639 let now = 1700000000i64;
1640 let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
1643 let freq = 1.0 + 0.1 * (11.0_f64).ln();
1645 let expected = 0.8 * 1.0 * freq as f32;
1646 assert!((score - expected).abs() < 0.01, "got {}", score);
1647 }
1648
1649 #[test]
1650 fn test_score_with_decay_old_and_unused() {
1651 let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
1654 let expected = 0.8 * 0.25;
1656 assert!((score - expected).abs() < 0.001, "got {}", score);
1657 }
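
    // Added sketch: non-finite similarities are clamped to 0.0 and, all else equal,
    // fresher items must outrank staler ones.
    #[test]
    fn test_score_with_decay_guards_and_ordering() {
        let now = 1700000000i64;
        assert_eq!(score_with_decay(f32::NAN, now, now, 0, None), 0.0);
        assert_eq!(score_with_decay(f32::INFINITY, now, now, 0, None), 0.0);

        let fresh = score_with_decay(0.8, now, now, 0, None);
        let stale = score_with_decay(0.8, now, now - 60 * 86400, 0, None);
        assert!(fresh > stale, "fresher items should score higher");
    }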
1658
1659 #[test]
1660 fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
1661 assert_eq!(sanitize_sql_string("hello"), "hello");
1662 assert_eq!(sanitize_sql_string("it's"), "it''s");
1663 assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
1664 assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
1665 }
1666
1667 #[test]
1668 fn test_sanitize_sql_string_strips_null_bytes() {
1669 assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
1670 assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
1671 assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
1673 assert_eq!(sanitize_sql_string("clean"), "clean");
1674 }
1675
1676 #[test]
1677 fn test_sanitize_sql_string_strips_semicolons() {
1678 assert_eq!(
1679 sanitize_sql_string("a; DROP TABLE items"),
1680 "a DROP TABLE items"
1681 );
1682 assert_eq!(sanitize_sql_string("normal;"), "normal");
1683 }
1684
1685 #[test]
1686 fn test_sanitize_sql_string_strips_comments() {
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected  rest");
1695 assert_eq!(sanitize_sql_string("*/"), "");
1697 }
1698
1699 #[test]
1700 fn test_sanitize_sql_string_adversarial_inputs() {
1701 assert_eq!(
1703 sanitize_sql_string("'; DROP TABLE items;--"),
1704 "'' DROP TABLE items"
1705 );
1706 assert_eq!(
1708 sanitize_sql_string("hello\u{200B}world"),
1709 "hello\u{200B}world"
1710 );
1711 assert_eq!(sanitize_sql_string(""), "");
1713 assert_eq!(sanitize_sql_string("\0;\0"), "");
1715 }
1716
1717 #[test]
1718 fn test_is_valid_id() {
1719 assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
1721 assert!(is_valid_id("abcdef0123456789"));
1722 assert!(!is_valid_id(""));
1724 assert!(!is_valid_id("'; DROP TABLE items;--"));
1725 assert!(!is_valid_id("hello world"));
1726 assert!(!is_valid_id("abc\0def"));
1727 assert!(!is_valid_id(&"a".repeat(65)));
1729 }
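
    // Added sketch: the 64-character length cap on is_valid_id is inclusive, and
    // non-hex ASCII letters are rejected.
    #[test]
    fn test_is_valid_id_boundaries() {
        assert!(is_valid_id(&"a".repeat(64)));
        assert!(!is_valid_id(&"a".repeat(65)));
        assert!(!is_valid_id("xyz"));
    }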
1730
1731 #[test]
1732 fn test_detect_content_type_yaml_not_prose() {
1733 let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
1735 let detected = detect_content_type(prose);
1736 assert_ne!(
1737 detected,
1738 ContentType::Yaml,
1739 "Prose with colons should not be detected as YAML"
1740 );
1741
1742 let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
1744 let detected = detect_content_type(yaml);
1745 assert_eq!(detected, ContentType::Yaml);
1746 }
1747
1748 #[test]
1749 fn test_detect_content_type_yaml_with_separator() {
1750 let yaml = "---\nname: test\nversion: 1.0";
1751 let detected = detect_content_type(yaml);
1752 assert_eq!(detected, ContentType::Yaml);
1753 }
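
    // Added sketch: exercises the markdown, code, and plain-text branches of
    // detect_content_type, complementing the YAML-focused tests above.
    #[test]
    fn test_detect_content_type_markdown_code_and_text() {
        let md = "# Title\n\nSome prose under a heading.";
        assert_eq!(detect_content_type(md), ContentType::Markdown);

        let code = "pub fn add(a: i32, b: i32) -> i32 {\n    a + b\n}";
        assert_eq!(detect_content_type(code), ContentType::Code);

        let prose = "Just a plain sentence with nothing special.";
        assert_eq!(detect_content_type(prose), ContentType::Text);
    }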
1754
1755 #[test]
1756 fn test_chunk_threshold_uses_chars_not_bytes() {
1757 let emoji_content = "😀".repeat(500);
1760 assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000); // four UTF-8 bytes per emoji

        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
1764 assert!(
1765 !should_chunk,
1766 "500 chars should not exceed 1000-char threshold"
1767 );
1768
1769 let long_content = "a".repeat(1001);
1771 let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
1772 assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
1773 }
1774
1775 #[test]
1776 fn test_schema_version() {
1777 let version = SCHEMA_VERSION;
1779 assert!(version >= 2, "Schema version should be at least 2");
1780 }
1781
1782 fn old_item_schema() -> Schema {
1784 Schema::new(vec![
1785 Field::new("id", DataType::Utf8, false),
1786 Field::new("content", DataType::Utf8, false),
1787 Field::new("project_id", DataType::Utf8, true),
            Field::new("tags", DataType::Utf8, true),
            Field::new("is_chunked", DataType::Boolean, false),
1790 Field::new("created_at", DataType::Int64, false),
1791 Field::new(
1792 "vector",
1793 DataType::FixedSizeList(
1794 Arc::new(Field::new("item", DataType::Float32, true)),
1795 EMBEDDING_DIM as i32,
1796 ),
1797 false,
1798 ),
1799 ])
1800 }
1801
1802 fn old_item_batch(id: &str, content: &str) -> RecordBatch {
1804 let schema = Arc::new(old_item_schema());
1805 let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
1806 let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
1807 let vector = FixedSizeListArray::try_new(
1808 vector_field,
1809 EMBEDDING_DIM as i32,
1810 Arc::new(vector_values),
1811 None,
1812 )
1813 .unwrap();
1814
1815 RecordBatch::try_new(
1816 schema,
1817 vec![
1818 Arc::new(StringArray::from(vec![id])),
1819 Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])), // project_id
                Arc::new(StringArray::from(vec![None::<&str>])), // tags
                Arc::new(BooleanArray::from(vec![false])),
1823 Arc::new(Int64Array::from(vec![1700000000i64])),
1824 Arc::new(vector),
1825 ],
1826 )
1827 .unwrap()
1828 }
1829
1830 #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_detects_old_schema() {
1833 let tmp = tempfile::TempDir::new().unwrap();
1834 let db_path = tmp.path().join("data");
1835
1836 let db_conn = lancedb::connect(db_path.to_str().unwrap())
1838 .execute()
1839 .await
1840 .unwrap();
1841
1842 let schema = Arc::new(old_item_schema());
1843 let batch = old_item_batch("test-id-1", "old content");
1844 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
1845 db_conn
1846 .create_table("items", Box::new(batches))
1847 .execute()
1848 .await
1849 .unwrap();
1850
1851 let db = Database {
1853 db: db_conn,
1854 embedder: Arc::new(Embedder::new().unwrap()),
1855 project_id: None,
1856 items_table: None,
1857 chunks_table: None,
1858 };
1859
1860 let needs_migration = db.check_needs_migration().await.unwrap();
1861 assert!(
1862 needs_migration,
1863 "Old schema with tags column should need migration"
1864 );
1865 }
1866
1867 #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_false_for_new_schema() {
1870 let tmp = tempfile::TempDir::new().unwrap();
1871 let db_path = tmp.path().join("data");
1872
1873 let db_conn = lancedb::connect(db_path.to_str().unwrap())
1875 .execute()
1876 .await
1877 .unwrap();
1878
1879 let schema = Arc::new(item_schema());
1880 db_conn
1881 .create_empty_table("items", schema)
1882 .execute()
1883 .await
1884 .unwrap();
1885
1886 let db = Database {
1887 db: db_conn,
1888 embedder: Arc::new(Embedder::new().unwrap()),
1889 project_id: None,
1890 items_table: None,
1891 chunks_table: None,
1892 };
1893
1894 let needs_migration = db.check_needs_migration().await.unwrap();
1895 assert!(!needs_migration, "New schema should not need migration");
1896 }
1897
1898 #[tokio::test]
    #[ignore]
    async fn test_migrate_schema_preserves_data() {
1901 let tmp = tempfile::TempDir::new().unwrap();
1902 let db_path = tmp.path().join("data");
1903
1904 let db_conn = lancedb::connect(db_path.to_str().unwrap())
1906 .execute()
1907 .await
1908 .unwrap();
1909
1910 let schema = Arc::new(old_item_schema());
1911 let batch1 = old_item_batch("id-aaa", "first item content");
1912 let batch2 = old_item_batch("id-bbb", "second item content");
1913 let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
1914 db_conn
1915 .create_table("items", Box::new(batches))
1916 .execute()
1917 .await
1918 .unwrap();
1919 drop(db_conn);
1920
1921 let embedder = Arc::new(Embedder::new().unwrap());
1923 let db = Database::open_with_embedder(&db_path, None, embedder)
1924 .await
1925 .unwrap();
1926
1927 let needs_migration = db.check_needs_migration().await.unwrap();
1929 assert!(
1930 !needs_migration,
1931 "Schema should be migrated (no tags column)"
1932 );
1933
1934 let item_a = db.get_item("id-aaa").await.unwrap();
1936 assert!(item_a.is_some(), "Item id-aaa should be preserved");
1937 assert_eq!(item_a.unwrap().content, "first item content");
1938
1939 let item_b = db.get_item("id-bbb").await.unwrap();
1940 assert!(item_b.is_some(), "Item id-bbb should be preserved");
1941 assert_eq!(item_b.unwrap().content, "second item content");
1942
1943 let stats = db.stats().await.unwrap();
1945 assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
1946 }
1947
1948 #[tokio::test]
    #[ignore]
    async fn test_recover_case_a_only_staging() {
1951 let tmp = tempfile::TempDir::new().unwrap();
1952 let db_path = tmp.path().join("data");
1953
1954 let db_conn = lancedb::connect(db_path.to_str().unwrap())
1956 .execute()
1957 .await
1958 .unwrap();
1959
1960 let schema = Arc::new(item_schema());
1961 let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
1962 let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
1963 let vector = FixedSizeListArray::try_new(
1964 vector_field,
1965 EMBEDDING_DIM as i32,
1966 Arc::new(vector_values),
1967 None,
1968 )
1969 .unwrap();
1970
1971 let batch = RecordBatch::try_new(
1972 schema.clone(),
1973 vec![
1974 Arc::new(StringArray::from(vec!["staging-id"])),
1975 Arc::new(StringArray::from(vec!["staging content"])),
1976 Arc::new(StringArray::from(vec![None::<&str>])),
1977 Arc::new(BooleanArray::from(vec![false])),
1978 Arc::new(Int64Array::from(vec![1700000000i64])),
1979 Arc::new(vector),
1980 ],
1981 )
1982 .unwrap();
1983
1984 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
1985 db_conn
1986 .create_table("items_migrated", Box::new(batches))
1987 .execute()
1988 .await
1989 .unwrap();
1990 drop(db_conn);
1991
1992 let embedder = Arc::new(Embedder::new().unwrap());
1994 let db = Database::open_with_embedder(&db_path, None, embedder)
1995 .await
1996 .unwrap();
1997
1998 let item = db.get_item("staging-id").await.unwrap();
2000 assert!(item.is_some(), "Item should be recovered from staging");
2001 assert_eq!(item.unwrap().content, "staging content");
2002
2003 let table_names = db.db.table_names().execute().await.unwrap();
2005 assert!(
2006 !table_names.contains(&"items_migrated".to_string()),
2007 "Staging table should be dropped"
2008 );
2009 }
2010
2011 #[tokio::test]
    #[ignore]
    async fn test_recover_case_b_both_old_schema() {
2014 let tmp = tempfile::TempDir::new().unwrap();
2015 let db_path = tmp.path().join("data");
2016
2017 let db_conn = lancedb::connect(db_path.to_str().unwrap())
2019 .execute()
2020 .await
2021 .unwrap();
2022
2023 let old_schema = Arc::new(old_item_schema());
2025 let batch = old_item_batch("old-id", "old content");
2026 let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
2027 db_conn
2028 .create_table("items", Box::new(batches))
2029 .execute()
2030 .await
2031 .unwrap();
2032
2033 let new_schema = Arc::new(item_schema());
2035 db_conn
2036 .create_empty_table("items_migrated", new_schema)
2037 .execute()
2038 .await
2039 .unwrap();
2040 drop(db_conn);
2041
2042 let embedder = Arc::new(Embedder::new().unwrap());
2044 let db = Database::open_with_embedder(&db_path, None, embedder)
2045 .await
2046 .unwrap();
2047
2048 let needs_migration = db.check_needs_migration().await.unwrap();
2050 assert!(!needs_migration, "Should have migrated after recovery");
2051
2052 let item = db.get_item("old-id").await.unwrap();
2054 assert!(
2055 item.is_some(),
2056 "Item should be preserved through recovery + migration"
2057 );
2058
2059 let table_names = db.db.table_names().execute().await.unwrap();
2061 assert!(
2062 !table_names.contains(&"items_migrated".to_string()),
2063 "Staging table should be dropped"
2064 );
2065 }
2066
2067 #[tokio::test]
    #[ignore]
    async fn test_recover_case_c_both_new_schema() {
2070 let tmp = tempfile::TempDir::new().unwrap();
2071 let db_path = tmp.path().join("data");
2072
2073 let db_conn = lancedb::connect(db_path.to_str().unwrap())
2075 .execute()
2076 .await
2077 .unwrap();
2078
2079 let new_schema = Arc::new(item_schema());
2080
2081 let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
2083 let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
2084 let vector = FixedSizeListArray::try_new(
2085 vector_field,
2086 EMBEDDING_DIM as i32,
2087 Arc::new(vector_values),
2088 None,
2089 )
2090 .unwrap();
2091
2092 let batch = RecordBatch::try_new(
2093 new_schema.clone(),
2094 vec![
2095 Arc::new(StringArray::from(vec!["new-id"])),
2096 Arc::new(StringArray::from(vec!["new content"])),
2097 Arc::new(StringArray::from(vec![None::<&str>])),
2098 Arc::new(BooleanArray::from(vec![false])),
2099 Arc::new(Int64Array::from(vec![1700000000i64])),
2100 Arc::new(vector),
2101 ],
2102 )
2103 .unwrap();
2104
2105 let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
2106 db_conn
2107 .create_table("items", Box::new(batches))
2108 .execute()
2109 .await
2110 .unwrap();
2111
2112 db_conn
2114 .create_empty_table("items_migrated", new_schema)
2115 .execute()
2116 .await
2117 .unwrap();
2118 drop(db_conn);
2119
2120 let embedder = Arc::new(Embedder::new().unwrap());
2122 let db = Database::open_with_embedder(&db_path, None, embedder)
2123 .await
2124 .unwrap();
2125
2126 let item = db.get_item("new-id").await.unwrap();
2128 assert!(item.is_some(), "Item should be untouched");
2129 assert_eq!(item.unwrap().content, "new content");
2130
2131 let table_names = db.db.table_names().execute().await.unwrap();
2133 assert!(
2134 !table_names.contains(&"items_migrated".to_string()),
2135 "Staging table should be dropped"
2136 );
2137 }
2138
2139 #[tokio::test]
    #[ignore]
    async fn test_list_items_rejects_invalid_project_id() {
2142 let tmp = tempfile::TempDir::new().unwrap();
2143 let db_path = tmp.path().join("data");
2144 let malicious_pid = "'; DROP TABLE items;--".to_string();
2145
2146 let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
2147 .await
2148 .unwrap();
2149
2150 let result = db
2151 .list_items(ItemFilters::new(), Some(10), crate::ListScope::Project)
2152 .await;
2153
2154 assert!(result.is_err(), "Should reject invalid project_id");
2155 let err_msg = result.unwrap_err().to_string();
2156 assert!(
2157 err_msg.contains("Invalid project_id"),
2158 "Error should mention invalid project_id, got: {}",
2159 err_msg
2160 );
2161 }
2162}