Skip to main content

sediment/
db.rs

1//! Database module using LanceDB for vector storage
2//!
3//! Provides a simple interface for storing and searching items
4//! using LanceDB's native vector search capabilities.
5
6use std::path::PathBuf;
7use std::sync::Arc;
8
9use arrow_array::{
10    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
11    RecordBatchIterator, StringArray,
12};
13use arrow_schema::{DataType, Field, Schema};
14use chrono::{TimeZone, Utc};
15use futures::TryStreamExt;
16use lancedb::Table;
17use lancedb::connect;
18use lancedb::query::{ExecutableQuery, QueryBase};
19use tracing::{debug, info};
20
21use crate::chunker::{ChunkingConfig, chunk_content};
22use crate::document::ContentType;
23use crate::embedder::{EMBEDDING_DIM, Embedder};
24use crate::error::{Result, SedimentError};
25use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};
26
/// Threshold for auto-chunking (compared against `content.len()`, i.e. UTF-8 bytes)
28const CHUNK_THRESHOLD: usize = 1000;
29
30/// Similarity threshold for conflict detection
31const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;
32
33/// Maximum number of conflicts to return
34const CONFLICT_SEARCH_LIMIT: usize = 5;
35
36/// Database wrapper for LanceDB
37pub struct Database {
38    db: lancedb::Connection,
39    embedder: Arc<Embedder>,
40    project_id: Option<String>,
41    items_table: Option<Table>,
42    chunks_table: Option<Table>,
43}
44
/// Row counts reported by [`Database::stats`].
///
/// Both counters default to zero.
#[derive(Clone, Debug, Default)]
pub struct DatabaseStats {
    /// Number of rows in the `items` table.
    pub item_count: usize,
    /// Number of rows in the `chunks` table.
    pub chunk_count: usize,
}
51
52// Arrow schema builders
53fn item_schema() -> Schema {
54    Schema::new(vec![
55        Field::new("id", DataType::Utf8, false),
56        Field::new("content", DataType::Utf8, false),
57        Field::new("title", DataType::Utf8, true),
58        Field::new("tags", DataType::Utf8, true), // JSON array as string
59        Field::new("source", DataType::Utf8, true),
60        Field::new("metadata", DataType::Utf8, true), // JSON as string
61        Field::new("project_id", DataType::Utf8, true),
62        Field::new("is_chunked", DataType::Boolean, false),
63        Field::new("expires_at", DataType::Int64, true), // Unix timestamp
64        Field::new("created_at", DataType::Int64, false), // Unix timestamp
65        Field::new(
66            "vector",
67            DataType::FixedSizeList(
68                Arc::new(Field::new("item", DataType::Float32, true)),
69                EMBEDDING_DIM as i32,
70            ),
71            false,
72        ),
73    ])
74}
75
76fn chunk_schema() -> Schema {
77    Schema::new(vec![
78        Field::new("id", DataType::Utf8, false),
79        Field::new("item_id", DataType::Utf8, false),
80        Field::new("chunk_index", DataType::Int32, false),
81        Field::new("content", DataType::Utf8, false),
82        Field::new("context", DataType::Utf8, true),
83        Field::new(
84            "vector",
85            DataType::FixedSizeList(
86                Arc::new(Field::new("item", DataType::Float32, true)),
87                EMBEDDING_DIM as i32,
88            ),
89            false,
90        ),
91    ])
92}
93
94impl Database {
95    /// Open or create a database at the given path
96    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
97        Self::open_with_project(path, None).await
98    }
99
100    /// Open or create a database at the given path with a project ID
101    pub async fn open_with_project(
102        path: impl Into<PathBuf>,
103        project_id: Option<String>,
104    ) -> Result<Self> {
105        let embedder = Arc::new(Embedder::new()?);
106        Self::open_with_embedder(path, project_id, embedder).await
107    }
108
109    /// Open or create a database with a pre-existing embedder.
110    ///
111    /// This constructor is useful for connection pooling scenarios where
112    /// the expensive embedder should be loaded once and shared across
113    /// multiple database connections.
114    ///
115    /// # Arguments
116    ///
117    /// * `path` - Path to the database directory
118    /// * `project_id` - Optional project ID for scoped operations
119    /// * `embedder` - Shared embedder instance
120    pub async fn open_with_embedder(
121        path: impl Into<PathBuf>,
122        project_id: Option<String>,
123        embedder: Arc<Embedder>,
124    ) -> Result<Self> {
125        let path = path.into();
126        info!("Opening database at {:?}", path);
127
128        // Ensure parent directory exists
129        if let Some(parent) = path.parent() {
130            std::fs::create_dir_all(parent).map_err(|e| {
131                SedimentError::Database(format!("Failed to create database directory: {}", e))
132            })?;
133        }
134
135        let db = connect(path.to_str().unwrap())
136            .execute()
137            .await
138            .map_err(|e| {
139                SedimentError::Database(format!("Failed to connect to database: {}", e))
140            })?;
141
142        let mut database = Self {
143            db,
144            embedder,
145            project_id,
146            items_table: None,
147            chunks_table: None,
148        };
149
150        database.ensure_tables().await?;
151        database.ensure_vector_index().await?;
152
153        Ok(database)
154    }
155
156    /// Set the current project ID for scoped operations
157    pub fn set_project_id(&mut self, project_id: Option<String>) {
158        self.project_id = project_id;
159    }
160
161    /// Get the current project ID
162    pub fn project_id(&self) -> Option<&str> {
163        self.project_id.as_deref()
164    }
165
166    /// Ensure all required tables exist
167    async fn ensure_tables(&mut self) -> Result<()> {
168        // Check for existing tables
169        let table_names = self
170            .db
171            .table_names()
172            .execute()
173            .await
174            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
175
176        // Items table
177        if table_names.contains(&"items".to_string()) {
178            self.items_table =
179                Some(self.db.open_table("items").execute().await.map_err(|e| {
180                    SedimentError::Database(format!("Failed to open items: {}", e))
181                })?);
182        }
183
184        // Chunks table
185        if table_names.contains(&"chunks".to_string()) {
186            self.chunks_table =
187                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
188                    SedimentError::Database(format!("Failed to open chunks: {}", e))
189                })?);
190        }
191
192        Ok(())
193    }
194
195    /// Ensure vector indexes exist on tables with enough rows.
196    ///
197    /// LanceDB requires at least 256 rows before creating an index.
198    /// Once created, the index converts brute-force scans to HNSW/IVF-PQ.
199    async fn ensure_vector_index(&self) -> Result<()> {
200        const MIN_ROWS_FOR_INDEX: usize = 256;
201
202        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
203            if let Some(table) = table_opt {
204                let row_count = table.count_rows(None).await.unwrap_or(0);
205                if row_count < MIN_ROWS_FOR_INDEX {
206                    continue;
207                }
208
209                // Check if index already exists by listing indices
210                let indices = table.list_indices().await.unwrap_or_default();
211
212                let has_vector_index = indices
213                    .iter()
214                    .any(|idx| idx.columns.contains(&"vector".to_string()));
215
216                if !has_vector_index {
217                    info!(
218                        "Creating vector index on {} table ({} rows)",
219                        name, row_count
220                    );
221                    match table
222                        .create_index(&["vector"], lancedb::index::Index::Auto)
223                        .execute()
224                        .await
225                    {
226                        Ok(_) => info!("Vector index created on {} table", name),
227                        Err(e) => {
228                            // Non-fatal: brute-force search still works
229                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
230                        }
231                    }
232                }
233            }
234        }
235
236        Ok(())
237    }
238
239    /// Get or create the items table
240    async fn get_items_table(&mut self) -> Result<&Table> {
241        if self.items_table.is_none() {
242            let schema = Arc::new(item_schema());
243            let table = self
244                .db
245                .create_empty_table("items", schema)
246                .execute()
247                .await
248                .map_err(|e| {
249                    SedimentError::Database(format!("Failed to create items table: {}", e))
250                })?;
251            self.items_table = Some(table);
252        }
253        Ok(self.items_table.as_ref().unwrap())
254    }
255
256    /// Get or create the chunks table
257    async fn get_chunks_table(&mut self) -> Result<&Table> {
258        if self.chunks_table.is_none() {
259            let schema = Arc::new(chunk_schema());
260            let table = self
261                .db
262                .create_empty_table("chunks", schema)
263                .execute()
264                .await
265                .map_err(|e| {
266                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
267                })?;
268            self.chunks_table = Some(table);
269        }
270        Ok(self.chunks_table.as_ref().unwrap())
271    }
272
273    // ==================== Item Operations ====================
274
275    /// Store an item with automatic chunking for long content
276    ///
277    /// Returns a `StoreResult` containing the new item ID and any potential conflicts
278    /// (items with similarity >= 0.85 to the new content).
279    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
280        // Set project_id if not already set and we have a current project
281        if item.project_id.is_none() {
282            item.project_id = self.project_id.clone();
283        }
284
285        // Check for potential conflicts before storing
286        let potential_conflicts = self
287            .find_similar_items(
288                &item.content,
289                CONFLICT_SIMILARITY_THRESHOLD,
290                CONFLICT_SEARCH_LIMIT,
291            )
292            .await
293            .unwrap_or_default();
294
295        // Determine if we need to chunk
296        let should_chunk = item.content.len() > CHUNK_THRESHOLD;
297        item.is_chunked = should_chunk;
298
299        // Generate item embedding
300        let embedding_text = item.embedding_text();
301        let embedding = self.embedder.embed(&embedding_text)?;
302        item.embedding = embedding;
303
304        // Store the item
305        let table = self.get_items_table().await?;
306        let batch = item_to_batch(&item)?;
307        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));
308
309        table
310            .add(Box::new(batches))
311            .execute()
312            .await
313            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;
314
315        // If chunking is needed, create and store chunks
316        if should_chunk {
317            let embedder = self.embedder.clone();
318            let chunks_table = self.get_chunks_table().await?;
319
320            // Detect content type for smart chunking
321            let content_type = detect_content_type(&item.content);
322            let config = ChunkingConfig::default();
323            let chunk_results = chunk_content(&item.content, content_type, &config);
324
325            for (i, chunk_result) in chunk_results.iter().enumerate() {
326                let mut chunk = Chunk::new(&item.id, i, &chunk_result.content);
327
328                if let Some(ctx) = &chunk_result.context {
329                    chunk = chunk.with_context(ctx);
330                }
331
332                let chunk_embedding = embedder.embed(&chunk.content)?;
333                chunk.embedding = chunk_embedding;
334
335                let chunk_batch = chunk_to_batch(&chunk)?;
336                let batches =
337                    RecordBatchIterator::new(vec![Ok(chunk_batch)], Arc::new(chunk_schema()));
338
339                chunks_table
340                    .add(Box::new(batches))
341                    .execute()
342                    .await
343                    .map_err(|e| {
344                        SedimentError::Database(format!("Failed to store chunk: {}", e))
345                    })?;
346            }
347
348            debug!(
349                "Stored item: {} with {} chunks",
350                item.id,
351                chunk_results.len()
352            );
353        } else {
354            debug!("Stored item: {} (no chunking)", item.id);
355        }
356
357        Ok(StoreResult {
358            id: item.id,
359            potential_conflicts,
360        })
361    }
362
363    /// Search items by semantic similarity
364    pub async fn search_items(
365        &mut self,
366        query: &str,
367        limit: usize,
368        filters: ItemFilters,
369    ) -> Result<Vec<SearchResult>> {
370        // Generate query embedding
371        let query_embedding = self.embedder.embed(query)?;
372        let min_similarity = filters.min_similarity.unwrap_or(0.3);
373
374        // We need to search both items and chunks, then merge results
375        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
376            std::collections::HashMap::new();
377
378        // Search items table directly (for non-chunked items and chunked items by title)
379        if let Some(table) = &self.items_table {
380            let mut filter_parts = Vec::new();
381
382            if !filters.include_expired {
383                let now = Utc::now().timestamp();
384                filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
385            }
386
387            let mut query_builder = table
388                .vector_search(query_embedding.clone())
389                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
390                .limit(limit * 2);
391
392            if !filter_parts.is_empty() {
393                let filter_str = filter_parts.join(" AND ");
394                query_builder = query_builder.only_if(filter_str);
395            }
396
397            let results = query_builder
398                .execute()
399                .await
400                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
401                .try_collect::<Vec<_>>()
402                .await
403                .map_err(|e| {
404                    SedimentError::Database(format!("Failed to collect results: {}", e))
405                })?;
406
407            for batch in results {
408                let items = batch_to_items(&batch)?;
409                let distances = batch
410                    .column_by_name("_distance")
411                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
412
413                for (i, item) in items.into_iter().enumerate() {
414                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
415                    let similarity = 1.0 / (1.0 + distance);
416
417                    if similarity < min_similarity {
418                        continue;
419                    }
420
421                    // Apply tag filter
422                    if let Some(ref filter_tags) = filters.tags
423                        && !filter_tags.iter().any(|t| item.tags.contains(t))
424                    {
425                        continue;
426                    }
427
428                    // Apply project boosting
429                    let boosted_similarity = boost_similarity(
430                        similarity,
431                        item.project_id.as_deref(),
432                        self.project_id.as_deref(),
433                    );
434
435                    let result = SearchResult::from_item(&item, boosted_similarity);
436                    results_map
437                        .entry(item.id.clone())
438                        .or_insert((result, boosted_similarity));
439                }
440            }
441        }
442
443        // Search chunks table (for chunked items)
444        if let Some(chunks_table) = &self.chunks_table {
445            let chunk_results = chunks_table
446                .vector_search(query_embedding)
447                .map_err(|e| {
448                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
449                })?
450                .limit(limit * 3)
451                .execute()
452                .await
453                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
454                .try_collect::<Vec<_>>()
455                .await
456                .map_err(|e| {
457                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
458                })?;
459
460            // Group chunks by item and find best chunk for each item
461            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
462                std::collections::HashMap::new();
463
464            for batch in chunk_results {
465                let chunks = batch_to_chunks(&batch)?;
466                let distances = batch
467                    .column_by_name("_distance")
468                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
469
470                for (i, chunk) in chunks.into_iter().enumerate() {
471                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
472                    let similarity = 1.0 / (1.0 + distance);
473
474                    if similarity < min_similarity {
475                        continue;
476                    }
477
478                    // Keep track of best matching chunk per item
479                    chunk_matches
480                        .entry(chunk.item_id.clone())
481                        .and_modify(|(content, best_sim)| {
482                            if similarity > *best_sim {
483                                *content = chunk.content.clone();
484                                *best_sim = similarity;
485                            }
486                        })
487                        .or_insert((chunk.content.clone(), similarity));
488                }
489            }
490
491            // Fetch parent items for chunk matches
492            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
493                if let Some(item) = self.get_item(&item_id).await? {
494                    // Apply tag filter
495                    if let Some(ref filter_tags) = filters.tags
496                        && !filter_tags.iter().any(|t| item.tags.contains(t))
497                    {
498                        continue;
499                    }
500
501                    // Apply project boosting
502                    let boosted_similarity = boost_similarity(
503                        chunk_similarity,
504                        item.project_id.as_deref(),
505                        self.project_id.as_deref(),
506                    );
507
508                    let result =
509                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);
510
511                    // Update if this chunk-based result is better
512                    results_map
513                        .entry(item_id)
514                        .and_modify(|(existing, existing_sim)| {
515                            if boosted_similarity > *existing_sim {
516                                *existing = result.clone();
517                                *existing_sim = boosted_similarity;
518                            }
519                        })
520                        .or_insert((result, boosted_similarity));
521                }
522            }
523        }
524
525        // Convert map to sorted vec
526        let mut search_results: Vec<SearchResult> =
527            results_map.into_values().map(|(r, _)| r).collect();
528        search_results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
529        search_results.truncate(limit);
530
531        Ok(search_results)
532    }
533
534    /// Find items similar to the given content (for conflict detection)
535    ///
536    /// This searches the items table directly by content embedding to find
537    /// potentially conflicting items before storing new content.
538    pub async fn find_similar_items(
539        &mut self,
540        content: &str,
541        min_similarity: f32,
542        limit: usize,
543    ) -> Result<Vec<ConflictInfo>> {
544        // Generate embedding for the content
545        let embedding = self.embedder.embed(content)?;
546
547        let table = match &self.items_table {
548            Some(t) => t,
549            None => return Ok(Vec::new()),
550        };
551
552        // Build filter for non-expired items
553        let now = Utc::now().timestamp();
554        let filter = format!("(expires_at IS NULL OR expires_at > {})", now);
555
556        let results = table
557            .vector_search(embedding)
558            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
559            .limit(limit)
560            .only_if(filter)
561            .execute()
562            .await
563            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
564            .try_collect::<Vec<_>>()
565            .await
566            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;
567
568        let mut conflicts = Vec::new();
569
570        for batch in results {
571            let items = batch_to_items(&batch)?;
572            let distances = batch
573                .column_by_name("_distance")
574                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
575
576            for (i, item) in items.into_iter().enumerate() {
577                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
578                let similarity = 1.0 / (1.0 + distance);
579
580                if similarity >= min_similarity {
581                    conflicts.push(ConflictInfo {
582                        id: item.id,
583                        content: item.content,
584                        similarity,
585                    });
586                }
587            }
588        }
589
590        // Sort by similarity descending
591        conflicts.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
592
593        Ok(conflicts)
594    }
595
596    /// List items with optional filters
597    pub async fn list_items(
598        &mut self,
599        filters: ItemFilters,
600        limit: Option<usize>,
601        scope: crate::ListScope,
602    ) -> Result<Vec<Item>> {
603        let table = match &self.items_table {
604            Some(t) => t,
605            None => return Ok(Vec::new()),
606        };
607
608        let mut filter_parts = Vec::new();
609
610        if !filters.include_expired {
611            let now = Utc::now().timestamp();
612            filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
613        }
614
615        // Apply scope filter
616        match scope {
617            crate::ListScope::Project => {
618                if let Some(ref pid) = self.project_id {
619                    filter_parts.push(format!("project_id = '{}'", pid));
620                }
621            }
622            crate::ListScope::Global => {
623                filter_parts.push("project_id IS NULL".to_string());
624            }
625            crate::ListScope::All => {
626                // No additional filter
627            }
628        }
629
630        let mut query = table.query();
631
632        if !filter_parts.is_empty() {
633            let filter_str = filter_parts.join(" AND ");
634            query = query.only_if(filter_str);
635        }
636
637        if let Some(l) = limit {
638            query = query.limit(l);
639        }
640
641        let results = query
642            .execute()
643            .await
644            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
645            .try_collect::<Vec<_>>()
646            .await
647            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
648
649        let mut items = Vec::new();
650        for batch in results {
651            items.extend(batch_to_items(&batch)?);
652        }
653
654        // Apply tag filter
655        if let Some(ref filter_tags) = filters.tags {
656            items.retain(|item| filter_tags.iter().any(|t| item.tags.contains(t)));
657        }
658
659        Ok(items)
660    }
661
662    /// Get an item by ID
663    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
664        let table = match &self.items_table {
665            Some(t) => t,
666            None => return Ok(None),
667        };
668
669        let results = table
670            .query()
671            .only_if(format!("id = '{}'", id))
672            .limit(1)
673            .execute()
674            .await
675            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
676            .try_collect::<Vec<_>>()
677            .await
678            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
679
680        for batch in results {
681            let items = batch_to_items(&batch)?;
682            if let Some(item) = items.into_iter().next() {
683                return Ok(Some(item));
684            }
685        }
686
687        Ok(None)
688    }
689
690    /// Get multiple items by ID in a single query
691    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
692        let table = match &self.items_table {
693            Some(t) => t,
694            None => return Ok(Vec::new()),
695        };
696
697        if ids.is_empty() {
698            return Ok(Vec::new());
699        }
700
701        let quoted: Vec<String> = ids.iter().map(|id| format!("'{}'", id)).collect();
702        let filter = format!("id IN ({})", quoted.join(", "));
703
704        let results = table
705            .query()
706            .only_if(filter)
707            .execute()
708            .await
709            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
710            .try_collect::<Vec<_>>()
711            .await
712            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;
713
714        let mut items = Vec::new();
715        for batch in results {
716            items.extend(batch_to_items(&batch)?);
717        }
718
719        Ok(items)
720    }
721
722    /// Delete an item and its chunks
723    pub async fn delete_item(&self, id: &str) -> Result<bool> {
724        // Delete chunks first
725        if let Some(chunks_table) = &self.chunks_table {
726            chunks_table
727                .delete(&format!("item_id = '{}'", id))
728                .await
729                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
730        }
731
732        // Delete item
733        let table = match &self.items_table {
734            Some(t) => t,
735            None => return Ok(false),
736        };
737
738        table
739            .delete(&format!("id = '{}'", id))
740            .await
741            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;
742
743        Ok(true)
744    }
745
746    /// Get database statistics
747    pub async fn stats(&self) -> Result<DatabaseStats> {
748        let mut stats = DatabaseStats::default();
749
750        if let Some(table) = &self.items_table {
751            stats.item_count = table
752                .count_rows(None)
753                .await
754                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
755        }
756
757        if let Some(table) = &self.chunks_table {
758            stats.chunk_count = table
759                .count_rows(None)
760                .await
761                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
762        }
763
764        Ok(stats)
765    }
766}
767
768// ==================== Decay Scoring ====================
769
/// Compute a decay-adjusted score for a search result.
///
/// Formula: `similarity * freshness * frequency`
/// - freshness = 1.0 / (1.0 + age_days / 30.0)  (halves after 30 days)
/// - frequency = 1.0 + 0.1 * ln(1 + access_count)
///
/// `now`, `created_at`, and `last_accessed_at` are unix timestamps in seconds.
/// Age is measured from `last_accessed_at` when present, otherwise from
/// `created_at`. If no access record exists, pass `access_count = 0` and
/// `last_accessed_at = None`.
///
/// A reference time in the future is treated as age zero. The subtraction
/// saturates, so extreme timestamps cannot overflow.
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    let reference_time = last_accessed_at.unwrap_or(created_at);
    // Saturating: `now - reference_time` overflows for e.g. `i64::MAX - (-1)`.
    let age_secs = now.saturating_sub(reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    similarity * (freshness * frequency) as f32
}
794
795// ==================== Helper Functions ====================
796
/// Adjust a similarity score according to project affinity.
///
/// When both the item and the current context have a project:
/// - same project: multiply by 1.15, capped at 1.0;
/// - different projects: multiply by 0.95.
///
/// If either side has no project, `base` is returned unchanged.
fn boost_similarity(base: f32, item_project: Option<&str>, current_project: Option<&str>) -> f32 {
    if let (Some(item), Some(current)) = (item_project, current_project) {
        if item == current {
            // Same project: boost
            (base * 1.15).min(1.0)
        } else {
            // Different project: slight penalty
            base * 0.95
        }
    } else {
        // Global or no context
        base
    }
}
805
806/// Detect content type for smart chunking
807fn detect_content_type(content: &str) -> ContentType {
808    let trimmed = content.trim();
809
810    // Check for JSON
811    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
812        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
813        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
814    {
815        return ContentType::Json;
816    }
817
818    // Check for YAML (common patterns)
819    if trimmed.contains(":\n") || trimmed.starts_with("---") {
820        // Simple heuristic: looks like YAML if it has key: value patterns
821        let lines: Vec<&str> = trimmed.lines().take(5).collect();
822        let yaml_like = lines.iter().any(|line| {
823            let l = line.trim();
824            !l.is_empty() && !l.starts_with('#') && l.contains(':') && !l.starts_with("http")
825        });
826        if yaml_like {
827            return ContentType::Yaml;
828        }
829    }
830
831    // Check for Markdown (has headers)
832    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
833        return ContentType::Markdown;
834    }
835
836    // Check for code (common patterns)
837    let code_patterns = [
838        "fn ",
839        "pub fn ",
840        "def ",
841        "class ",
842        "function ",
843        "const ",
844        "let ",
845        "var ",
846        "import ",
847        "export ",
848        "struct ",
849        "impl ",
850        "trait ",
851    ];
852    if code_patterns.iter().any(|p| trimmed.contains(p)) {
853        return ContentType::Code;
854    }
855
856    ContentType::Text
857}
858
859// ==================== Arrow Conversion Helpers ====================
860
861fn item_to_batch(item: &Item) -> Result<RecordBatch> {
862    let schema = Arc::new(item_schema());
863
864    let id = StringArray::from(vec![item.id.as_str()]);
865    let content = StringArray::from(vec![item.content.as_str()]);
866    let title = StringArray::from(vec![item.title.as_deref()]);
867    let tags = StringArray::from(vec![serde_json::to_string(&item.tags).ok()]);
868    let source = StringArray::from(vec![item.source.as_deref()]);
869    let metadata = StringArray::from(vec![item.metadata.as_ref().map(|m| m.to_string())]);
870    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
871    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
872    let expires_at = Int64Array::from(vec![item.expires_at.map(|t| t.timestamp())]);
873    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);
874
875    let vector = create_embedding_array(&item.embedding)?;
876
877    RecordBatch::try_new(
878        schema,
879        vec![
880            Arc::new(id),
881            Arc::new(content),
882            Arc::new(title),
883            Arc::new(tags),
884            Arc::new(source),
885            Arc::new(metadata),
886            Arc::new(project_id),
887            Arc::new(is_chunked),
888            Arc::new(expires_at),
889            Arc::new(created_at),
890            Arc::new(vector),
891        ],
892    )
893    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
894}
895
896fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
897    let mut items = Vec::new();
898
899    let id_col = batch
900        .column_by_name("id")
901        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
902        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
903
904    let content_col = batch
905        .column_by_name("content")
906        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
907        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
908
909    let title_col = batch
910        .column_by_name("title")
911        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
912
913    let tags_col = batch
914        .column_by_name("tags")
915        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
916
917    let source_col = batch
918        .column_by_name("source")
919        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
920
921    let metadata_col = batch
922        .column_by_name("metadata")
923        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
924
925    let project_id_col = batch
926        .column_by_name("project_id")
927        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
928
929    let is_chunked_col = batch
930        .column_by_name("is_chunked")
931        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());
932
933    let expires_at_col = batch
934        .column_by_name("expires_at")
935        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
936
937    let created_at_col = batch
938        .column_by_name("created_at")
939        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
940
941    for i in 0..batch.num_rows() {
942        let id = id_col.value(i).to_string();
943        let content = content_col.value(i).to_string();
944
945        let title = title_col.and_then(|c| {
946            if c.is_null(i) {
947                None
948            } else {
949                Some(c.value(i).to_string())
950            }
951        });
952
953        let tags: Vec<String> = tags_col
954            .and_then(|c| {
955                if c.is_null(i) {
956                    None
957                } else {
958                    serde_json::from_str(c.value(i)).ok()
959                }
960            })
961            .unwrap_or_default();
962
963        let source = source_col.and_then(|c| {
964            if c.is_null(i) {
965                None
966            } else {
967                Some(c.value(i).to_string())
968            }
969        });
970
971        let metadata = metadata_col.and_then(|c| {
972            if c.is_null(i) {
973                None
974            } else {
975                serde_json::from_str(c.value(i)).ok()
976            }
977        });
978
979        let project_id = project_id_col.and_then(|c| {
980            if c.is_null(i) {
981                None
982            } else {
983                Some(c.value(i).to_string())
984            }
985        });
986
987        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);
988
989        let expires_at = expires_at_col.and_then(|c| {
990            if c.is_null(i) {
991                None
992            } else {
993                Some(Utc.timestamp_opt(c.value(i), 0).unwrap())
994            }
995        });
996
997        let created_at = created_at_col
998            .map(|c| Utc.timestamp_opt(c.value(i), 0).unwrap())
999            .unwrap_or_else(Utc::now);
1000
1001        let item = Item {
1002            id,
1003            content,
1004            embedding: Vec::new(),
1005            title,
1006            tags,
1007            source,
1008            metadata,
1009            project_id,
1010            is_chunked,
1011            expires_at,
1012            created_at,
1013        };
1014
1015        items.push(item);
1016    }
1017
1018    Ok(items)
1019}
1020
1021fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
1022    let schema = Arc::new(chunk_schema());
1023
1024    let id = StringArray::from(vec![chunk.id.as_str()]);
1025    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
1026    let chunk_index = Int32Array::from(vec![chunk.chunk_index as i32]);
1027    let content = StringArray::from(vec![chunk.content.as_str()]);
1028    let context = StringArray::from(vec![chunk.context.as_deref()]);
1029
1030    let vector = create_embedding_array(&chunk.embedding)?;
1031
1032    RecordBatch::try_new(
1033        schema,
1034        vec![
1035            Arc::new(id),
1036            Arc::new(item_id),
1037            Arc::new(chunk_index),
1038            Arc::new(content),
1039            Arc::new(context),
1040            Arc::new(vector),
1041        ],
1042    )
1043    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1044}
1045
1046fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
1047    let mut chunks = Vec::new();
1048
1049    let id_col = batch
1050        .column_by_name("id")
1051        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1052        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1053
1054    let item_id_col = batch
1055        .column_by_name("item_id")
1056        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1057        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;
1058
1059    let chunk_index_col = batch
1060        .column_by_name("chunk_index")
1061        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
1062        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;
1063
1064    let content_col = batch
1065        .column_by_name("content")
1066        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1067        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1068
1069    let context_col = batch
1070        .column_by_name("context")
1071        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1072
1073    for i in 0..batch.num_rows() {
1074        let id = id_col.value(i).to_string();
1075        let item_id = item_id_col.value(i).to_string();
1076        let chunk_index = chunk_index_col.value(i) as usize;
1077        let content = content_col.value(i).to_string();
1078        let context = context_col.and_then(|c| {
1079            if c.is_null(i) {
1080                None
1081            } else {
1082                Some(c.value(i).to_string())
1083            }
1084        });
1085
1086        let chunk = Chunk {
1087            id,
1088            item_id,
1089            chunk_index,
1090            content,
1091            embedding: Vec::new(),
1092            context,
1093        };
1094
1095        chunks.push(chunk);
1096    }
1097
1098    Ok(chunks)
1099}
1100
1101fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
1102    let values = Float32Array::from(embedding.to_vec());
1103    let field = Arc::new(Field::new("item", DataType::Float32, true));
1104
1105    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
1106        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
1107}
1108
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed "current time" (Unix seconds) shared by all decay tests.
    const NOW: i64 = 1700000000;
    /// Number of seconds in one day.
    const DAY: i64 = 86400;

    /// Assert that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance, "got {}", actual);
    }

    #[test]
    fn test_score_with_decay_fresh_item() {
        // Just created, never accessed:
        // freshness = 1.0, frequency = 1.0 + 0.1 * ln(1) = 1.0
        let score = score_with_decay(0.8, NOW, NOW, 0, None);
        assert_close(score, 0.8 * 1.0 * 1.0, 0.001);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        // 30 days old, never accessed:
        // freshness = 1/(1+1) = 0.5, frequency = 1.0
        let created = NOW - 30 * DAY;
        let score = score_with_decay(0.8, NOW, created, 0, None);
        assert_close(score, 0.8 * 0.5, 0.001);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        // 30 days old but accessed just now, ten times in total:
        // freshness = 1.0 (just accessed), frequency = 1.0 + 0.1 * ln(11) ≈ 1.2397
        let created = NOW - 30 * DAY;
        let last_accessed = NOW;
        let score = score_with_decay(0.8, NOW, created, 10, Some(last_accessed));

        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        assert_close(score, 0.8 * 1.0 * freq as f32, 0.01);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        // 90 days old, never accessed:
        // freshness = 1/(1+3) = 0.25
        let created = NOW - 90 * DAY;
        let score = score_with_decay(0.8, NOW, created, 0, None);
        assert_close(score, 0.8 * 0.25, 0.001);
    }
}