//! LanceDB-backed storage layer: Arrow schema definitions, item/chunk
//! persistence, and vector search with project-aware scoring.

use std::path::PathBuf;
use std::sync::Arc;

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

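/// Content length, in bytes, above which an item's content is additionally
/// split into embedded chunks (see `store_item`).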
const CHUNK_THRESHOLD: usize = 1000;

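/// Similarity at or above which an existing item is reported as a potential
/// conflict when a new item is stored.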
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

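/// Maximum number of candidate items fetched during the conflict check.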
const CONFLICT_SEARCH_LIMIT: usize = 5;

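/// LanceDB-backed vector store for items and their chunks, scoped to an
/// optional project id.
///
/// A sketch of typical usage (error handling, the async runtime, and the
/// item construction are elided and illustrative):
///
/// ```ignore
/// let mut db = Database::open_with_project("./data/db", Some("my-project".into())).await?;
/// let stored = db.store_item(item).await?;
/// let hits = db.search_items("how do I rotate logs?", 10, ItemFilters::default()).await?;
/// ```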
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
}

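/// Row counts reported by [`Database::stats`].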
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

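/// Arrow schema for the `items` table. `tags` and `metadata` are stored as
/// JSON-encoded strings; timestamps are Unix seconds.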
fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("title", DataType::Utf8, true),
        Field::new("tags", DataType::Utf8, true),
        Field::new("source", DataType::Utf8, true),
        Field::new("metadata", DataType::Utf8, true),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("expires_at", DataType::Int64, true),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

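/// Arrow schema for the `chunks` table; each row embeds one chunk of a parent item.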
fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
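    /// Open (or create) a database at `path` with a default embedder and no project scope.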
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

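    /// Open (or create) a database at `path`, scoping stores and project-aware
    /// queries to `project_id`.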
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

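    /// Open (or create) a database at `path` with a caller-supplied embedder.
    ///
    /// Ensures the parent directory exists, connects to LanceDB, opens any
    /// existing `items`/`chunks` tables, and creates vector indices once the
    /// tables are large enough to benefit from one.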
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        // `to_string_lossy` avoids panicking on non-UTF-8 paths.
        let db = connect(path.to_string_lossy().as_ref())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to connect to database: {}", e))
            })?;

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

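    /// Change the project scope used by subsequent stores and queries.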
    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

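    /// The current project scope, if any.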
    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

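    /// Open the `items` and `chunks` tables if they already exist. Missing
    /// tables are created lazily on first write.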
    async fn ensure_tables(&mut self) -> Result<()> {
        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

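    /// Create a vector index on each table that has passed a minimum row count
    /// and lacks one. Index creation is best-effort: failures are logged and
    /// never prevent the database from opening.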
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);
                if row_count < MIN_ROWS_FOR_INDEX {
                    continue;
                }

                let indices = table.list_indices().await.unwrap_or_default();

                let has_vector_index = indices
                    .iter()
                    .any(|idx| idx.columns.contains(&"vector".to_string()));

                if !has_vector_index {
                    info!(
                        "Creating vector index on {} table ({} rows)",
                        name, row_count
                    );
                    match table
                        .create_index(&["vector"], lancedb::index::Index::Auto)
                        .execute()
                        .await
                    {
                        Ok(_) => info!("Vector index created on {} table", name),
                        Err(e) => {
                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

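    /// Return the `items` table, creating it empty on first use.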
    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

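    /// Return the `chunks` table, creating it empty on first use.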
    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

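    /// Store an item and return its id along with any potential conflicts.
    ///
    /// Fills in the project scope, runs a best-effort similarity check against
    /// existing items, embeds the content, and writes the item row. Content
    /// longer than `CHUNK_THRESHOLD` is also split into chunks, each embedded
    /// and stored separately for finer-grained retrieval.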
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        // Inherit the database's project scope unless the item already has one.
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        // Best-effort conflict check; a failed search should not block the store.
        let potential_conflicts = self
            .find_similar_items(
                &item.content,
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        let should_chunk = item.content.len() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let embedder = self.embedder.clone();
            let chunks_table = self.get_chunks_table().await?;

            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let chunk_results = chunk_content(&item.content, content_type, &config);

            for (i, chunk_result) in chunk_results.iter().enumerate() {
                let mut chunk = Chunk::new(&item.id, i, &chunk_result.content);

                if let Some(ctx) = &chunk_result.context {
                    chunk = chunk.with_context(ctx);
                }

                let chunk_embedding = embedder.embed(&chunk.content)?;
                chunk.embedding = chunk_embedding;

                let chunk_batch = chunk_to_batch(&chunk)?;
                let batches =
                    RecordBatchIterator::new(vec![Ok(chunk_batch)], Arc::new(chunk_schema()));

                chunks_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to store chunk: {}", e))
                    })?;
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

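    /// Semantic search over items and chunks.
    ///
    /// Runs a vector search against whole items and another against chunks,
    /// keeps the best match per item (a chunk match contributes an excerpt),
    /// applies tag/expiry filters and project-affinity boosting, then returns
    /// the top `limit` results sorted by similarity.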
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        // Best result per item id, keyed alongside its boosted similarity.
        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        // Pass 1: vector search over whole items.
        if let Some(table) = &self.items_table {
            let mut filter_parts = Vec::new();

            if !filters.include_expired {
                let now = Utc::now().timestamp();
                filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
            }

            let mut query_builder = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
                .limit(limit * 2);

            if !filter_parts.is_empty() {
                let filter_str = filter_parts.join(" AND ");
                query_builder = query_builder.only_if(filter_str);
            }

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    if let Some(ref filter_tags) = filters.tags
                        && !filter_tags.iter().any(|t| item.tags.contains(t))
                    {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result = SearchResult::from_item(&item, boosted_similarity);
                    results_map
                        .entry(item.id.clone())
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        // Pass 2: vector search over chunks, keeping the best chunk per item.
        if let Some(chunks_table) = &self.chunks_table {
            let chunk_results = chunks_table
                .vector_search(query_embedding)
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
                })?
                .limit(limit * 3)
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
                })?;

            // item_id -> (best chunk content, best chunk similarity)
            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            // Merge chunk matches into the item-level results.
            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = self.get_item(&item_id).await? {
                    if let Some(ref filter_tags) = filters.tags
                        && !filter_tags.iter().any(|t| item.tags.contains(t))
                    {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(r, _)| r).collect();
        search_results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
        search_results.truncate(limit);

        Ok(search_results)
    }

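    /// Find non-expired items whose embeddings are at least `min_similarity`
    /// close to `content`, sorted by descending similarity. Used as the
    /// conflict check in `store_item`.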
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;

        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let now = Utc::now().timestamp();
        let filter = format!("(expires_at IS NULL OR expires_at > {})", now);

        let results = table
            .vector_search(embedding)
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
            .limit(limit)
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());

        Ok(conflicts)
    }

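    /// List items without a search query, filtered by expiry, tags, and the
    /// given project scope.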
    pub async fn list_items(
        &mut self,
        filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        if !filters.include_expired {
            let now = Utc::now().timestamp();
            filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
        }

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    filter_parts.push(format!("project_id = '{}'", pid));
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {
                // No project filter: return items from every scope.
            }
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        // Tag filtering happens in memory since tags are stored as a JSON string.
        if let Some(ref filter_tags) = filters.tags {
            items.retain(|item| filter_tags.iter().any(|t| item.tags.contains(t)));
        }

        Ok(items)
    }

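    /// Fetch a single item by id.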
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", id))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

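    /// Fetch several items by id in one query using a SQL `IN` filter.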
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids.iter().map(|id| format!("'{}'", id)).collect();
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

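    /// Delete an item and any chunks derived from it.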
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        // Remove any chunks belonging to the item first.
        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", id))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        table
            .delete(&format!("id = '{}'", id))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

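    /// Row counts for both tables.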
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

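/// Combine a similarity score with freshness and access-frequency factors:
///
/// `freshness = 1 / (1 + age_days / 30)`, where age is measured from
/// `last_accessed_at` (falling back to `created_at`), and
/// `frequency = 1 + 0.1 * ln(1 + access_count)`.
///
/// A brand-new, never-accessed item therefore keeps its raw similarity:
///
/// ```ignore
/// // Illustrative call; `now` is a Unix timestamp in seconds.
/// let now = 1_700_000_000;
/// assert!((score_with_decay(0.8, now, now, 0, None) - 0.8).abs() < 1e-3);
/// ```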
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    similarity * (freshness * frequency) as f32
}

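/// Nudge a similarity score by project affinity: same-project matches are
/// boosted by 15% (capped at 1.0), cross-project matches are penalized by 5%,
/// and global items pass through unchanged.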
fn boost_similarity(base: f32, item_project: Option<&str>, current_project: Option<&str>) -> f32 {
    match (item_project, current_project) {
        // Same project: boost, capped at 1.0.
        (Some(m), Some(c)) if m == c => (base * 1.15).min(1.0),
        // Different project: slight penalty.
        (Some(_), Some(_)) => base * 0.95,
        // Global item or no current project: unchanged.
        _ => base,
    }
}

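/// Heuristically classify content as JSON, YAML, Markdown, code, or plain
/// text, checked in that order.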
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    // JSON: looks like an object/array and actually parses.
    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    // YAML: front-matter marker or key: value lines near the top.
    if trimmed.contains(":\n") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(5).collect();
        let yaml_like = lines.iter().any(|line| {
            let l = line.trim();
            !l.is_empty() && !l.starts_with('#') && l.contains(':') && !l.starts_with("http")
        });
        if yaml_like {
            return ContentType::Yaml;
        }
    }

    // Markdown: ATX-style headings.
    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    // Code: common keywords across Rust/Python/JS-style languages.
    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    if code_patterns.iter().any(|p| trimmed.contains(p)) {
        return ContentType::Code;
    }

    ContentType::Text
}

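/// Convert an item into a single-row Arrow `RecordBatch` matching `item_schema`.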
fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let title = StringArray::from(vec![item.title.as_deref()]);
    let tags = StringArray::from(vec![serde_json::to_string(&item.tags).ok()]);
    let source = StringArray::from(vec![item.source.as_deref()]);
    let metadata = StringArray::from(vec![item.metadata.as_ref().map(|m| m.to_string())]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let expires_at = Int64Array::from(vec![item.expires_at.map(|t| t.timestamp())]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(title),
            Arc::new(tags),
            Arc::new(source),
            Arc::new(metadata),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(expires_at),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

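/// Decode items from an Arrow batch. Embeddings are not read back; `embedding`
/// is left empty.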
fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let title_col = batch
        .column_by_name("title")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let tags_col = batch
        .column_by_name("tags")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let source_col = batch
        .column_by_name("source")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let metadata_col = batch
        .column_by_name("metadata")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let expires_at_col = batch
        .column_by_name("expires_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let title = title_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let tags: Vec<String> = tags_col
            .and_then(|c| {
                if c.is_null(i) {
                    None
                } else {
                    serde_json::from_str(c.value(i)).ok()
                }
            })
            .unwrap_or_default();

        let source = source_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let metadata = metadata_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                serde_json::from_str(c.value(i)).ok()
            }
        });

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let expires_at = expires_at_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(Utc.timestamp_opt(c.value(i), 0).unwrap())
            }
        });

        let created_at = created_at_col
            .map(|c| Utc.timestamp_opt(c.value(i), 0).unwrap())
            .unwrap_or_else(Utc::now);

        let item = Item {
            id,
            content,
            embedding: Vec::new(),
            title,
            tags,
            source,
            metadata,
            project_id,
            is_chunked,
            expires_at,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

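/// Convert a chunk into a single-row Arrow `RecordBatch` matching `chunk_schema`.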
fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![chunk.chunk_index as i32]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

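/// Decode chunks from an Arrow batch. As with items, embeddings are left empty.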
fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

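/// Wrap an embedding in a `FixedSizeListArray` of length `EMBEDDING_DIM`.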
fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now; // just created
        let score = score_with_decay(0.8, now, created, 0, None);
        // freshness = 1.0, frequency = 1.0
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400; // 30 days old
        let score = score_with_decay(0.8, now, created, 0, None);
        // freshness = 1 / (1 + 1) = 0.5
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now; // accessed just now
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        // freshness = 1.0 (age is measured from last access)
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400; // 90 days old
        let score = score_with_decay(0.8, now, created, 0, None);
        // freshness = 1 / (1 + 3) = 0.25
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }
}