1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use arrow_array::builder::{
6 FixedSizeListBuilder, Float32Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
7 StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
8};
9use arrow_array::types::Int8Type;
10use arrow_array::{
11 Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
12 LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray,
13 TimestampMicrosecondArray,
14};
15use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
16use chrono::{DateTime, Timelike, Utc};
17use futures::TryStreamExt;
18use lance::dataset::mem_wal::{
19 DatasetMemWalExt, LsmScanner, ShardManifestStore, ShardSnapshot, ShardWriterConfig,
20};
21use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
22use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams};
23use lance::index::DatasetIndexExt;
24use lance::io::{ObjectStoreParams, StorageOptionsAccessor};
25use lance::{Error as LanceError, Result as LanceResult};
26use lance_index::mem_wal::MEM_WAL_INDEX_NAME;
27use lance_index::scalar::ScalarIndexParams;
28use lance_index::IndexType;
29use tokio::sync::Mutex;
30use tokio::task::JoinHandle;
31use tracing::{error, info, warn};
32use uuid::Uuid;
33
34use crate::record::{ContextRecord, RecordFilters, SearchResult, StateMetadata};
35use crate::serde::CONTENT_TYPE_TOMBSTONE;
36
/// Width of the fixed-size `embedding` column. Query vectors and stored embeddings whose
/// length differs from this value are rejected.
const DEFAULT_EMBEDDING_DIM: i32 = 1536;
/// Number of results a search returns when the caller passes no limit.
const DEFAULT_SEARCH_LIMIT: usize = 10;
/// Batch-size argument handed to `ShardManifestStore::new` when reading shard manifests.
const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16;
/// Name under which the scalar index on the `id` column is registered.
const ID_INDEX_NAME: &str = "id_idx";
42
/// Settings that drive both manual and background compaction of the dataset.
#[derive(Debug, Clone)]
pub struct CompactionConfig {
    /// When `false`, no background compaction task is started.
    pub enabled: bool,
    /// Fragment count below which `should_compact` reports that no compaction is due.
    pub min_fragments: usize,
    /// Forwarded to Lance's `CompactionOptions::target_rows_per_fragment`.
    pub target_rows_per_fragment: usize,
    /// Forwarded to Lance's `CompactionOptions::max_rows_per_group`.
    pub max_rows_per_group: usize,
    /// Forwarded to Lance's `CompactionOptions::materialize_deletions`.
    pub materialize_deletions: bool,
    /// Forwarded to Lance's `CompactionOptions::materialize_deletions_threshold`.
    pub materialize_deletions_threshold: f32,
    /// Forwarded to Lance's `CompactionOptions::num_threads`.
    pub num_threads: Option<usize>,
    /// Seconds between two background checks for whether compaction is due.
    pub check_interval_secs: u64,
    /// `(start_hour, end_hour)` pairs, compared against the current UTC hour, during which
    /// `should_compact` declines to compact.
    pub quiet_hours: Vec<(u8, u8)>,
}

impl Default for CompactionConfig {
    /// Background compaction is off by default; the remaining values only take effect once
    /// compaction is enabled or triggered manually.
    fn default() -> Self {
        Self {
            // Scheduling.
            enabled: false,
            check_interval_secs: 300,
            quiet_hours: Vec::new(),
            min_fragments: 5,
            // Values handed through to Lance's compaction options.
            target_rows_per_fragment: 1_000_000,
            max_rows_per_group: 1024,
            materialize_deletions: true,
            materialize_deletions_threshold: 0.1,
            num_threads: None,
        }
    }
}
81
/// Kind of scalar index to keep on the `id` column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IdIndexType {
    /// No index is created on `id` (the default).
    #[default]
    None,
    /// Index created with `IndexType::ZoneMap`.
    ZoneMap,
    /// Index created with `IndexType::BTree`.
    BTree,
}
93
/// Point-in-time compaction summary returned by `ContextStore::compaction_stats`.
#[derive(Debug, Clone)]
pub struct CompactionStats {
    /// Fragment count reported by the store's current dataset handle.
    pub total_fragments: usize,
    /// Whether a compaction is flagged as running in the shared state.
    pub is_compacting: bool,
    /// Time the most recent successful compaction finished, if any.
    pub last_compaction: Option<DateTime<Utc>>,
    /// Error text of the most recent failed compaction; cleared by a later success.
    pub last_error: Option<String>,
    /// Number of successful compactions recorded in the shared state.
    pub total_compactions: u64,
}
108
/// Mutable compaction bookkeeping, kept behind `Arc<Mutex<_>>` inside `ContextStore`.
struct CompactionState {
    /// Handle of the spawned background compaction loop, when one has been started.
    background_task: Option<JoinHandle<()>>,
    /// Set while `compact` is running; used to reject overlapping compactions.
    is_compacting: bool,
    /// Time the most recent successful compaction finished.
    last_compaction: Option<DateTime<Utc>>,
    /// Error text of the most recent failed compaction.
    last_error: Option<String>,
    /// Count of successful compactions.
    total_compactions: u64,
}
117
/// Column names accepted in `ContextStoreOptions::blob_columns`; any other name makes
/// `open_with_options` fail.
const VALID_BLOB_COLUMNS: &[&str] = &["text_payload", "binary_payload"];
120
/// Handle to a Lance dataset holding `ContextRecord`s.
///
/// Cloning copies the dataset handle and configuration, but the compaction state is behind
/// an `Arc` and is therefore shared between all clones.
#[derive(Clone)]
pub struct ContextStore {
    /// Dataset handle this store reads from and writes to.
    dataset: Dataset,
    /// Compaction bookkeeping shared by every clone of this store.
    compaction_state: Arc<Mutex<CompactionState>>,
    /// Settings used by `compact` (when no override is passed) and by the background task.
    pub compaction_config: CompactionConfig,
    /// Payload columns that are declared with blob encoding in the schema.
    blob_columns: HashSet<String>,
    /// Kind of index kept on the `id` column.
    id_index_type: IdIndexType,
}
130
/// Options accepted by `ContextStore::open_with_options`.
#[derive(Debug, Clone, Default)]
pub struct ContextStoreOptions {
    /// Storage options passed to Lance when loading or creating the dataset.
    pub storage_options: Option<HashMap<String, String>>,
    /// Compaction settings copied into the opened store.
    pub compaction: CompactionConfig,
    /// Payload columns to declare as blobs; each name must appear in `VALID_BLOB_COLUMNS`.
    pub blob_columns: HashSet<String>,
    /// Kind of index to ensure on the `id` column when the store is opened.
    pub id_index_type: IdIndexType,
}
142
143impl ContextStoreOptions {
144 #[must_use]
145 pub fn storage_options(&self) -> Option<HashMap<String, String>> {
146 self.storage_options.clone()
147 }
148}
149
150impl ContextStore {
151 pub async fn open(uri: &str) -> LanceResult<Self> {
153 Self::open_with_options(uri, ContextStoreOptions::default()).await
154 }
155
156 pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult<Self> {
158 for col in &options.blob_columns {
160 if !VALID_BLOB_COLUMNS.contains(&col.as_str()) {
161 return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
162 "invalid blob column '{}': valid columns are {:?}",
163 col, VALID_BLOB_COLUMNS
164 ))));
165 }
166 }
167
168 let storage_options = options.storage_options();
169 let blob_columns = options.blob_columns.clone();
170 let dataset = match Self::load_with_options(uri, storage_options.clone()).await {
171 Ok(dataset) => dataset,
172 Err(LanceError::DatasetNotFound { .. }) => {
173 Self::create_with_options(uri, storage_options, &blob_columns).await?
174 }
175 Err(err) => return Err(err),
176 };
177
178 let mut store = Self {
179 dataset,
180 compaction_state: Arc::new(Mutex::new(CompactionState {
181 background_task: None,
182 is_compacting: false,
183 last_compaction: None,
184 last_error: None,
185 total_compactions: 0,
186 })),
187 compaction_config: options.compaction,
188 blob_columns,
189 id_index_type: options.id_index_type,
190 };
191
192 store.ensure_id_index().await?;
194
195 store.start_background_compaction().await?;
197
198 Ok(store)
199 }
200
201 pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
203 if entries.is_empty() {
204 return Ok(self.dataset.manifest.version);
205 }
206
207 self.validate_unique_ids(entries).await?;
208 self.write_entries(entries).await
209 }
210
211 async fn write_entries(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
212 if entries.is_empty() {
213 return Ok(self.dataset.manifest.version);
214 }
215
216 let mut groups: HashMap<(Option<String>, Option<String>), Vec<ContextRecord>> =
218 HashMap::new();
219 for entry in entries {
220 let key = (entry.bot_id.clone(), entry.session_id.clone());
221 groups.entry(key).or_default().push(entry.clone());
222 }
223
224 {
226 let indices = self.dataset.load_indices().await?;
227 let has_mem_wal = indices.iter().any(|i| i.name == MEM_WAL_INDEX_NAME);
228
229 if !has_mem_wal {
230 let maintained_indexes: Vec<String> = indices
232 .iter()
233 .filter(|i| {
234 !(self.id_index_type == IdIndexType::ZoneMap && i.name == ID_INDEX_NAME)
235 })
236 .map(|i| i.name.clone())
237 .collect();
238 self.dataset
239 .initialize_mem_wal()
240 .unsharded()
241 .maintained_indexes(maintained_indexes)
242 .execute()
243 .await?;
244 }
245 }
246
247 for ((bot_id, session_id), group_entries) in groups {
248 let region_id = Self::derive_region_id(&bot_id, &session_id);
249 let batch = self.records_to_batch(&group_entries)?;
250 let config = ShardWriterConfig {
251 shard_id: region_id,
252 ..Default::default()
253 };
254
255 let writer = self.dataset.mem_wal_writer(region_id, config).await?;
256 writer.put(vec![batch]).await?;
257 writer.close().await?;
258 }
259
260 Ok(self.dataset.manifest.version)
261 }
262
263 pub async fn delete_by_id(&mut self, id: &str) -> LanceResult<bool> {
268 let Some(record) = self.get_by_id(id).await? else {
269 return Ok(false);
270 };
271 self.write_tombstone_for(record).await?;
272 Ok(true)
273 }
274
275 pub async fn delete_by_external_id(&mut self, external_id: &str) -> LanceResult<bool> {
277 let Some(record) = self.get_by_external_id(external_id).await? else {
278 return Ok(false);
279 };
280 self.write_tombstone_for(record).await?;
281 Ok(true)
282 }
283
284 async fn write_tombstone_for(&mut self, record: ContextRecord) -> LanceResult<u64> {
285 let tombstone = ContextRecord {
286 id: record.id,
287 external_id: record.external_id,
288 run_id: record.run_id,
289 bot_id: record.bot_id,
290 session_id: record.session_id,
291 created_at: Utc::now(),
292 role: record.role,
293 state_metadata: None,
294 metadata: None,
295 content_type: CONTENT_TYPE_TOMBSTONE.to_string(),
296 text_payload: None,
297 binary_payload: None,
298 embedding: None,
299 };
300 self.write_entries(std::slice::from_ref(&tombstone)).await
301 }
302
303 async fn validate_unique_ids(&self, entries: &[ContextRecord]) -> LanceResult<()> {
304 let mut ids = HashSet::new();
305 let mut external_ids = HashSet::new();
306 for entry in entries {
307 if entry.is_tombstone() {
308 return Err(ArrowError::InvalidArgumentError(format!(
309 "content_type '{}' is reserved for internal tombstones",
310 CONTENT_TYPE_TOMBSTONE
311 ))
312 .into());
313 }
314 if !ids.insert(entry.id.as_str()) {
315 return Err(ArrowError::InvalidArgumentError(format!(
316 "duplicate id '{}' in batch",
317 entry.id
318 ))
319 .into());
320 }
321 if let Some(external_id) = &entry.external_id {
322 if !external_ids.insert(external_id.as_str()) {
323 return Err(ArrowError::InvalidArgumentError(format!(
324 "duplicate external_id '{}' in batch",
325 external_id
326 ))
327 .into());
328 }
329 }
330 }
331
332 for record in self.list(None, None).await? {
333 if ids.contains(record.id.as_str()) {
334 return Err(ArrowError::InvalidArgumentError(format!(
335 "id '{}' already exists",
336 record.id
337 ))
338 .into());
339 }
340 if let Some(external_id) = record.external_id {
341 if external_ids.contains(external_id.as_str()) {
342 return Err(ArrowError::InvalidArgumentError(format!(
343 "external_id '{}' already exists",
344 external_id
345 ))
346 .into());
347 }
348 }
349 }
350
351 Ok(())
352 }
353
354 fn derive_region_id(bot_id: &Option<String>, session_id: &Option<String>) -> Uuid {
355 let mut input = String::new();
356
357 if let Some(bid) = bot_id {
358 input.push_str(bid);
359 }
360 input.push('#');
361 if let Some(sid) = session_id {
362 input.push_str(sid);
363 }
364
365 Uuid::new_v5(&Uuid::NAMESPACE_OID, input.as_bytes())
367 }
368
369 pub fn version(&self) -> u64 {
371 self.dataset.manifest.version
372 }
373
374 pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> {
376 let dataset = self.dataset.checkout_version(version_id).await?;
377 self.dataset = dataset;
378 Ok(())
379 }
380
381 pub async fn list(
383 &self,
384 limit: Option<usize>,
385 offset: Option<usize>,
386 ) -> LanceResult<Vec<ContextRecord>> {
387 self.list_filtered(limit, offset, None).await
388 }
389
390 pub async fn list_filtered(
392 &self,
393 limit: Option<usize>,
394 offset: Option<usize>,
395 filters: Option<&RecordFilters>,
396 ) -> LanceResult<Vec<ContextRecord>> {
397 let scanner = self.lsm_scanner().await?;
398 let mut stream = scanner.try_into_stream().await?;
399 let mut results = Vec::new();
400 while let Some(batch) = stream.try_next().await? {
401 results.extend(
402 batch_to_records(&batch)?
403 .into_iter()
404 .filter(|record| !record.is_tombstone()),
405 );
406 }
407
408 if let Some(filters) = filters.filter(|filters| !filters.is_empty()) {
409 results.retain(|record| filters.matches(record));
410 }
411
412 if let Some(offset) = offset {
413 results = results.into_iter().skip(offset).collect();
414 }
415 if let Some(limit) = limit {
416 results.truncate(limit);
417 }
418 Ok(results)
419 }
420
421 pub async fn get_by_id(&self, id: &str) -> LanceResult<Option<ContextRecord>> {
423 Ok(self
424 .list(None, None)
425 .await?
426 .into_iter()
427 .find(|record| record.id == id))
428 }
429
430 pub async fn get_by_external_id(
432 &self,
433 external_id: &str,
434 ) -> LanceResult<Option<ContextRecord>> {
435 Ok(self
436 .list(None, None)
437 .await?
438 .into_iter()
439 .find(|record| record.external_id.as_deref() == Some(external_id)))
440 }
441
442 pub async fn search(
444 &self,
445 query: &[f32],
446 limit: Option<usize>,
447 ) -> LanceResult<Vec<SearchResult>> {
448 self.search_filtered(query, limit, None).await
449 }
450
451 pub async fn search_filtered(
453 &self,
454 query: &[f32],
455 limit: Option<usize>,
456 filters: Option<&RecordFilters>,
457 ) -> LanceResult<Vec<SearchResult>> {
458 if query.len() != DEFAULT_EMBEDDING_DIM as usize {
459 return Err(ArrowError::InvalidArgumentError(format!(
460 "query length {} does not match embedding dimension {}",
461 query.len(),
462 DEFAULT_EMBEDDING_DIM
463 ))
464 .into());
465 }
466
467 let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
468 if top_k == 0 {
469 return Ok(Vec::new());
470 }
471
472 let mut results: Vec<SearchResult> = self
473 .list_filtered(None, None, filters)
474 .await?
475 .into_iter()
476 .filter_map(|record| {
477 let distance = l2_distance(query, record.embedding.as_ref()?);
478 Some(SearchResult { record, distance })
479 })
480 .collect();
481 results.sort_by(|left, right| left.distance.total_cmp(&right.distance));
482 results.truncate(top_k);
483 Ok(results)
484 }
485
486 async fn lsm_scanner(&self) -> LanceResult<LsmScanner> {
487 let object_store = self.dataset.object_store(None).await?;
488 let branch_location = self.dataset.branch_location();
489 let shard_ids = self.dataset.list_mem_wal_latest_shard_ids().await?;
490
491 let mut shard_snapshots = Vec::with_capacity(shard_ids.len());
492 for shard_id in shard_ids {
493 let manifest_store = ShardManifestStore::new(
494 object_store.clone(),
495 &branch_location.path,
496 shard_id,
497 DEFAULT_MANIFEST_SCAN_BATCH_SIZE,
498 );
499 let Some(manifest) = manifest_store.read_latest().await? else {
500 continue;
501 };
502
503 let mut snapshot = ShardSnapshot::new(shard_id)
504 .with_spec_id(manifest.shard_spec_id)
505 .with_current_generation(manifest.current_generation);
506 for flushed in manifest.flushed_generations {
507 snapshot = snapshot.with_flushed_generation(flushed.generation, flushed.path);
508 }
509 shard_snapshots.push(snapshot);
510 }
511
512 Ok(LsmScanner::new(
513 Arc::new(self.dataset.clone()),
514 shard_snapshots,
515 vec!["id".to_string()],
516 ))
517 }
518
519 pub async fn compact(
521 &mut self,
522 options: Option<CompactionConfig>,
523 ) -> LanceResult<CompactionMetrics> {
524 let config = options.unwrap_or_else(|| self.compaction_config.clone());
525
526 info!(
527 "Starting compaction: {} fragments",
528 self.dataset.count_fragments()
529 );
530 let start = std::time::Instant::now();
531
532 {
534 let mut state = self.compaction_state.lock().await;
535 if state.is_compacting {
536 warn!("Compaction already in progress, skipping");
537 return Err(LanceError::from(ArrowError::InvalidArgumentError(
538 "Compaction already in progress".to_string(),
539 )));
540 }
541 state.is_compacting = true;
542 }
543
544 let lance_options = CompactionOptions {
546 target_rows_per_fragment: config.target_rows_per_fragment,
547 max_rows_per_group: config.max_rows_per_group,
548 materialize_deletions: config.materialize_deletions,
549 materialize_deletions_threshold: config.materialize_deletions_threshold,
550 num_threads: config.num_threads,
551 ..Default::default()
552 };
553
554 let result = compact_files(&mut self.dataset, lance_options, None).await;
556
557 let mut state = self.compaction_state.lock().await;
559 state.is_compacting = false;
560
561 match result {
562 Ok(metrics) => {
563 state.last_compaction = Some(Utc::now());
564 state.total_compactions += 1;
565 state.last_error = None;
566 drop(state); info!(
569 "Compaction completed in {:?}: removed {} fragments ({}files), added {} fragments ({} files)",
570 start.elapsed(),
571 metrics.fragments_removed,
572 metrics.files_removed,
573 metrics.fragments_added,
574 metrics.files_added
575 );
576
577 self.dataset = Dataset::open(self.dataset.uri()).await?;
579
580 if let Err(e) = self.ensure_id_index().await {
583 warn!("Failed to ensure id index after compaction: {}", e);
584 }
585
586 Ok(metrics)
587 }
588 Err(e) => {
589 error!("Compaction failed: {}", e);
590 state.last_error = Some(e.to_string());
591 Err(e)
592 }
593 }
594 }
595
596 pub async fn should_compact(&self) -> LanceResult<bool> {
598 let fragment_count = self.dataset.count_fragments();
599
600 if fragment_count < self.compaction_config.min_fragments {
601 return Ok(false);
602 }
603
604 if !self.compaction_config.quiet_hours.is_empty() {
606 let now = Utc::now();
607 let current_hour = now.hour() as u8;
608
609 for (start, end) in &self.compaction_config.quiet_hours {
610 if current_hour >= *start && current_hour < *end {
611 info!("Skipping compaction during quiet hours ({}-{})", start, end);
612 return Ok(false);
613 }
614 }
615 }
616
617 Ok(true)
618 }
619
620 pub async fn compaction_stats(&self) -> LanceResult<CompactionStats> {
622 let state = self.compaction_state.lock().await;
623
624 Ok(CompactionStats {
625 total_fragments: self.dataset.count_fragments(),
626 is_compacting: state.is_compacting,
627 last_compaction: state.last_compaction,
628 last_error: state.last_error.clone(),
629 total_compactions: state.total_compactions,
630 })
631 }
632
633 async fn ensure_id_index(&mut self) -> LanceResult<()> {
635 if self.id_index_type == IdIndexType::None {
636 return Ok(());
637 }
638
639 let indices = self.dataset.load_indices().await?;
640 if indices.iter().any(|i| i.name == ID_INDEX_NAME) {
641 return Ok(());
642 }
643
644 self.create_id_index().await
645 }
646
647 pub async fn create_id_index(&mut self) -> LanceResult<()> {
649 let index_type = match self.id_index_type {
650 IdIndexType::ZoneMap => IndexType::ZoneMap,
651 IdIndexType::BTree => IndexType::BTree,
652 IdIndexType::None => return Ok(()),
653 };
654
655 info!("Creating {:?} index on id column", index_type);
656
657 let params = ScalarIndexParams::default();
658
659 self.dataset
660 .create_index_builder(&["id"], index_type, ¶ms)
661 .name(ID_INDEX_NAME.to_string())
662 .replace(true)
663 .await?;
664
665 self.dataset = Dataset::open(self.dataset.uri()).await?;
667
668 Ok(())
669 }
670
671 async fn start_background_compaction(&mut self) -> LanceResult<()> {
673 if !self.compaction_config.enabled {
674 return Ok(());
675 }
676
677 let mut state = self.compaction_state.lock().await;
678 if state.background_task.is_some() {
679 warn!("Background compaction already running");
680 return Ok(());
681 }
682
683 info!(
684 "Starting background compaction (interval: {}s, min fragments: {})",
685 self.compaction_config.check_interval_secs, self.compaction_config.min_fragments
686 );
687
688 let mut store_clone = self.clone();
689 let interval_secs = self.compaction_config.check_interval_secs;
690
691 let task = tokio::spawn(async move {
692 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
693
694 loop {
695 interval.tick().await;
696
697 match store_clone.should_compact().await {
698 Ok(true) => {
699 info!("Background compaction triggered");
700 if let Err(e) = store_clone.compact(None).await {
701 error!("Background compaction failed: {}", e);
702 }
703 }
704 Ok(false) => {
705 }
707 Err(e) => {
708 error!("Error checking compaction need: {}", e);
709 }
710 }
711 }
712 });
713
714 state.background_task = Some(task);
715 Ok(())
716 }
717
718 pub async fn stop_background_compaction(&mut self) -> LanceResult<()> {
720 let mut state = self.compaction_state.lock().await;
721
722 if let Some(task) = state.background_task.take() {
723 info!("Stopping background compaction");
724 task.abort();
725 }
726
727 Ok(())
728 }
729
730 pub fn schema(blob_columns: &HashSet<String>) -> Schema {
736 Self::schema_with_options(blob_columns, true, true)
737 }
738
739 fn schema_with_options(
740 blob_columns: &HashSet<String>,
741 include_external_id: bool,
742 include_metadata: bool,
743 ) -> Schema {
744 let mut id_metadata = HashMap::new();
745 id_metadata.insert(
746 "lance-schema:unenforced-primary-key".to_string(),
747 "true".to_string(),
748 );
749
750 let text_field = if blob_columns.contains("text_payload") {
751 let mut metadata = HashMap::new();
752 metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
753 Field::new("text_payload", DataType::LargeBinary, true).with_metadata(metadata)
754 } else {
755 Field::new("text_payload", DataType::LargeUtf8, true)
756 };
757
758 let binary_field = if blob_columns.contains("binary_payload") {
759 let mut metadata = HashMap::new();
760 metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
761 Field::new("binary_payload", DataType::LargeBinary, true).with_metadata(metadata)
762 } else {
763 Field::new("binary_payload", DataType::LargeBinary, true)
764 };
765
766 let mut fields = vec![Field::new("id", DataType::Utf8, false).with_metadata(id_metadata)];
767 if include_external_id {
768 fields.push(Field::new("external_id", DataType::Utf8, true));
769 }
770 fields.extend([
771 Field::new("run_id", DataType::Utf8, false),
772 Field::new("bot_id", DataType::Utf8, true),
773 Field::new("session_id", DataType::Utf8, true),
774 Field::new(
775 "created_at",
776 DataType::Timestamp(TimeUnit::Microsecond, None),
777 false,
778 ),
779 Field::new(
780 "role",
781 DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
782 false,
783 ),
784 Field::new(
785 "state_metadata",
786 DataType::Struct(
787 vec![
788 Field::new("step", DataType::Int32, true),
789 Field::new("active_plan_id", DataType::Utf8, true),
790 Field::new("tokens_used", DataType::Int32, true),
791 Field::new("custom", DataType::Utf8, true),
792 ]
793 .into(),
794 ),
795 true,
796 ),
797 ]);
798 if include_metadata {
799 fields.push(Field::new("metadata", DataType::LargeUtf8, true));
800 }
801 fields.extend([
802 Field::new("content_type", DataType::Utf8, false),
803 text_field,
804 binary_field,
805 Field::new(
806 "embedding",
807 DataType::FixedSizeList(
808 Arc::new(Field::new("item", DataType::Float32, true)),
809 DEFAULT_EMBEDDING_DIM,
810 ),
811 true,
812 ),
813 ]);
814
815 Schema::new(fields)
816 }
817
818 async fn load_with_options(
819 uri: &str,
820 storage_options: Option<HashMap<String, String>>,
821 ) -> LanceResult<Dataset> {
822 if let Some(options) = storage_options {
823 DatasetBuilder::from_uri(uri)
824 .with_storage_options(options)
825 .load()
826 .await
827 } else {
828 Dataset::open(uri).await
829 }
830 }
831
832 async fn create_with_options(
833 uri: &str,
834 storage_options: Option<HashMap<String, String>>,
835 blob_columns: &HashSet<String>,
836 ) -> LanceResult<Dataset> {
837 let schema = Arc::new(Self::schema(blob_columns));
838 let empty_batch = RecordBatch::new_empty(schema.clone());
839 let batches = RecordBatchIterator::new(
840 vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
841 schema.clone(),
842 );
843
844 let mut params = WriteParams {
845 mode: WriteMode::Create,
846 ..Default::default()
847 };
848
849 if let Some(options) = storage_options {
850 let store_params = ObjectStoreParams {
851 storage_options_accessor: Some(Arc::new(
852 StorageOptionsAccessor::with_static_options(options),
853 )),
854 ..Default::default()
855 };
856 params.store_params = Some(store_params);
857 }
858
859 Dataset::write(batches, uri, Some(params)).await
860 }
861
    /// Converts `entries` into one Arrow `RecordBatch` matching the dataset's columns.
    ///
    /// The `external_id` and `metadata` columns are emitted only when the dataset schema
    /// contains them. `text_payload` is encoded as `LargeBinary` when it is configured as
    /// a blob column and as `LargeUtf8` otherwise.
    ///
    /// # Errors
    /// Returns an invalid-argument error when an entry carries an `external_id` or
    /// `metadata` value the dataset has no column for, or when an embedding does not have
    /// `DEFAULT_EMBEDDING_DIM` elements. Errors from the role dictionary builder and from
    /// `RecordBatch::try_new` are propagated.
    fn records_to_batch(&self, entries: &[ContextRecord]) -> LanceResult<RecordBatch> {
        // Optional columns are detected from the dataset's own schema.
        let include_external_id = self
            .dataset
            .schema()
            .field_paths()
            .iter()
            .any(|path| path == "external_id");
        if !include_external_id && entries.iter().any(|entry| entry.external_id.is_some()) {
            return Err(ArrowError::InvalidArgumentError(
                "external_id requires a context dataset created with external_id support"
                    .to_string(),
            )
            .into());
        }
        let include_metadata = self
            .dataset
            .schema()
            .field_paths()
            .iter()
            .any(|path| path == "metadata");
        if !include_metadata && entries.iter().any(|entry| entry.metadata.is_some()) {
            return Err(ArrowError::InvalidArgumentError(
                "metadata requires a context dataset created with metadata support".to_string(),
            )
            .into());
        }

        // One builder per column; the optional ones are built regardless and simply not
        // added to the batch when the dataset lacks the column.
        let mut id_builder = StringBuilder::new();
        let mut external_id_builder = StringBuilder::new();
        let mut run_id_builder = StringBuilder::new();
        let mut bot_id_builder = StringBuilder::new();
        let mut session_id_builder = StringBuilder::new();
        let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
        let mut role_builder = StringDictionaryBuilder::<Int8Type>::new();
        let mut metadata_builder = LargeStringBuilder::new();
        let mut content_type_builder = StringBuilder::new();
        let mut binary_builder = LargeBinaryBuilder::new();

        // Exactly one of the two text builders is `Some`, selected by `text_is_blob`; the
        // `unwrap()` calls below rely on that pairing.
        let text_is_blob = self.blob_columns.contains("text_payload");
        let mut text_string_builder = if !text_is_blob {
            Some(LargeStringBuilder::new())
        } else {
            None
        };
        let mut text_binary_builder = if text_is_blob {
            Some(LargeBinaryBuilder::new())
        } else {
            None
        };

        // Child builders are listed in the same order and with the same types as
        // `state_fields`, which the typed `field_builder::<T>(i)` lookups depend on.
        let state_fields: Vec<FieldRef> = vec![
            Arc::new(Field::new("step", DataType::Int32, true)),
            Arc::new(Field::new("active_plan_id", DataType::Utf8, true)),
            Arc::new(Field::new("tokens_used", DataType::Int32, true)),
            Arc::new(Field::new("custom", DataType::Utf8, true)),
        ];
        let mut state_builder = StructBuilder::new(
            state_fields,
            vec![
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
            ],
        );

        let mut embedding_builder =
            FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM);

        for entry in entries {
            id_builder.append_value(&entry.id);
            external_id_builder.append_option(entry.external_id.as_deref());
            run_id_builder.append_value(&entry.run_id);
            bot_id_builder.append_option(entry.bot_id.as_deref());
            session_id_builder.append_option(entry.session_id.as_deref());
            created_at_builder.append_value(entry.created_at.timestamp_micros());
            role_builder.append(&entry.role)?;
            match &entry.metadata {
                Some(metadata) => metadata_builder.append_value(metadata.to_string()),
                None => metadata_builder.append_null(),
            }
            content_type_builder.append_value(&entry.content_type);

            if text_is_blob {
                match &entry.text_payload {
                    Some(value) => text_binary_builder
                        .as_mut()
                        .unwrap()
                        .append_value(value.as_bytes()),
                    None => text_binary_builder.as_mut().unwrap().append_null(),
                }
            } else {
                match &entry.text_payload {
                    Some(value) => text_string_builder.as_mut().unwrap().append_value(value),
                    None => text_string_builder.as_mut().unwrap().append_null(),
                }
            }

            match &entry.binary_payload {
                Some(value) => binary_builder.append_value(value),
                None => binary_builder.append_null(),
            }

            if let Some(metadata) = &entry.state_metadata {
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_option(metadata.step);
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_option(metadata.active_plan_id.as_deref());
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_option(metadata.tokens_used);
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_option(metadata.custom.as_deref());
                state_builder.append(true);
            } else {
                // A null struct slot still gets one (null) entry in every child builder.
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_null();
                state_builder.append(false);
            }

            if let Some(embedding) = &entry.embedding {
                if embedding.len() != DEFAULT_EMBEDDING_DIM as usize {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "embedding length {} does not match expected dimension {}",
                        embedding.len(),
                        DEFAULT_EMBEDDING_DIM
                    ))
                    .into());
                }
                {
                    let values_builder = embedding_builder.values();
                    for value in embedding {
                        values_builder.append_value(*value);
                    }
                }
                embedding_builder.append(true);
            } else {
                // A null list slot is padded with `DEFAULT_EMBEDDING_DIM` null child values.
                let values_builder = embedding_builder.values();
                for _ in 0..DEFAULT_EMBEDDING_DIM {
                    values_builder.append_null();
                }
                embedding_builder.append(false);
            }
        }

        let id_array: ArrayRef = Arc::new(id_builder.finish());
        let external_id_array: ArrayRef = Arc::new(external_id_builder.finish());
        let run_id_array: ArrayRef = Arc::new(run_id_builder.finish());
        let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish());
        let session_id_array: ArrayRef = Arc::new(session_id_builder.finish());
        let created_at_array: ArrayRef = Arc::new(created_at_builder.finish());
        let role_array: ArrayRef = Arc::new(role_builder.finish());
        let metadata_array: ArrayRef = Arc::new(metadata_builder.finish());
        let content_type_array: ArrayRef = Arc::new(content_type_builder.finish());
        let text_array: ArrayRef = if text_is_blob {
            Arc::new(text_binary_builder.unwrap().finish())
        } else {
            Arc::new(text_string_builder.unwrap().finish())
        };
        let binary_array: ArrayRef = Arc::new(binary_builder.finish());
        let state_array: ArrayRef = Arc::new(state_builder.finish());
        let embedding_array: ArrayRef = Arc::new(embedding_builder.finish());

        // The arrays below are pushed in the same column order that
        // `schema_with_options` declares its fields.
        let schema = Arc::new(Self::schema_with_options(
            &self.blob_columns,
            include_external_id,
            include_metadata,
        ));
        let mut arrays = vec![id_array];
        if include_external_id {
            arrays.push(external_id_array);
        }
        arrays.extend([
            run_id_array,
            bot_id_array,
            session_id_array,
            created_at_array,
            role_array,
            state_array,
        ]);
        if include_metadata {
            arrays.push(metadata_array);
        }
        arrays.extend([
            content_type_array,
            text_array,
            binary_array,
            embedding_array,
        ]);
        let batch = RecordBatch::try_new(schema, arrays)?;

        Ok(batch)
    }
1077}
1078
1079impl Drop for ContextStore {
1080 fn drop(&mut self) {
1081 if let Ok(mut state) = self.compaction_state.try_lock() {
1083 if let Some(task) = state.background_task.take() {
1084 task.abort();
1085 }
1086 }
1087 }
1088}
1089
1090fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
1092 let id_array = column_as::<StringArray>(batch, "id")?;
1093 let external_id_array = column_as_optional::<StringArray>(batch, "external_id");
1094 let run_id_array = column_as::<StringArray>(batch, "run_id")?;
1095 let bot_id_array = column_as_optional::<StringArray>(batch, "bot_id");
1096 let session_id_array = column_as_optional::<StringArray>(batch, "session_id");
1097 let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
1098 let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
1099 let state_array = column_as::<StructArray>(batch, "state_metadata")?;
1100 let metadata_array = column_as_optional::<LargeStringArray>(batch, "metadata");
1101 let content_type_array = column_as::<StringArray>(batch, "content_type")?;
1102 let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
1103 let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;
1104
1105 let text_is_binary = batch
1107 .schema()
1108 .field_with_name("text_payload")
1109 .is_ok_and(|f| f.data_type() == &DataType::LargeBinary);
1110
1111 let text_string_array = if !text_is_binary {
1112 Some(column_as::<LargeStringArray>(batch, "text_payload")?)
1113 } else {
1114 None
1115 };
1116 let text_binary_array = if text_is_binary {
1117 Some(column_as::<LargeBinaryArray>(batch, "text_payload")?)
1118 } else {
1119 None
1120 };
1121
1122 let step_array = state_array
1123 .column(0)
1124 .as_ref()
1125 .as_any()
1126 .downcast_ref::<Int32Array>()
1127 .ok_or_else(|| {
1128 LanceError::from(ArrowError::InvalidArgumentError(
1129 "step column has unexpected data type".to_string(),
1130 ))
1131 })?;
1132 let active_plan_array = state_array
1133 .column(1)
1134 .as_ref()
1135 .as_any()
1136 .downcast_ref::<StringArray>()
1137 .ok_or_else(|| {
1138 LanceError::from(ArrowError::InvalidArgumentError(
1139 "active_plan_id column has unexpected data type".to_string(),
1140 ))
1141 })?;
1142 let tokens_used_array = state_array
1143 .column(2)
1144 .as_ref()
1145 .as_any()
1146 .downcast_ref::<Int32Array>()
1147 .ok_or_else(|| {
1148 LanceError::from(ArrowError::InvalidArgumentError(
1149 "tokens_used column has unexpected data type".to_string(),
1150 ))
1151 })?;
1152 let custom_array = state_array
1153 .column(3)
1154 .as_ref()
1155 .as_any()
1156 .downcast_ref::<StringArray>()
1157 .ok_or_else(|| {
1158 LanceError::from(ArrowError::InvalidArgumentError(
1159 "custom column has unexpected data type".to_string(),
1160 ))
1161 })?;
1162
1163 let mut results = Vec::with_capacity(batch.num_rows());
1164 for row in 0..batch.num_rows() {
1165 let created_at =
1166 DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| {
1167 LanceError::from(ArrowError::InvalidArgumentError(format!(
1168 "invalid timestamp value {}",
1169 created_at_array.value(row)
1170 )))
1171 })?;
1172
1173 let state_metadata = if state_array.is_null(row) {
1174 None
1175 } else {
1176 Some(StateMetadata {
1177 step: if step_array.is_null(row) {
1178 None
1179 } else {
1180 Some(step_array.value(row))
1181 },
1182 active_plan_id: if active_plan_array.is_null(row) {
1183 None
1184 } else {
1185 Some(active_plan_array.value(row).to_string())
1186 },
1187 tokens_used: if tokens_used_array.is_null(row) {
1188 None
1189 } else {
1190 Some(tokens_used_array.value(row))
1191 },
1192 custom: if custom_array.is_null(row) {
1193 None
1194 } else {
1195 Some(custom_array.value(row).to_string())
1196 },
1197 })
1198 };
1199
1200 let text_payload = if text_is_binary {
1201 let arr = text_binary_array.unwrap();
1202 if arr.is_null(row) {
1203 None
1204 } else {
1205 Some(String::from_utf8_lossy(arr.value(row)).to_string())
1206 }
1207 } else {
1208 let arr = text_string_array.unwrap();
1209 if arr.is_null(row) {
1210 None
1211 } else {
1212 Some(arr.value(row).to_string())
1213 }
1214 };
1215
1216 let binary_payload = if binary_array.is_null(row) {
1217 None
1218 } else {
1219 Some(binary_array.value(row).to_vec())
1220 };
1221
1222 let embedding = if embedding_array.is_null(row) {
1223 None
1224 } else {
1225 Some(embedding_from_list(embedding_array, row)?)
1226 };
1227
1228 let role = if role_array.is_null(row) {
1229 return Err(LanceError::from(ArrowError::InvalidArgumentError(
1230 "role column contains null values".to_string(),
1231 )));
1232 } else {
1233 let role_values = role_array
1234 .values()
1235 .as_any()
1236 .downcast_ref::<StringArray>()
1237 .ok_or_else(|| {
1238 LanceError::from(ArrowError::InvalidArgumentError(
1239 "role dictionary values are not strings".to_string(),
1240 ))
1241 })?;
1242 let key = role_array.keys().value(row) as usize;
1243 role_values.value(key).to_string()
1244 };
1245
1246 let bot_id = bot_id_array.and_then(|arr| {
1247 if arr.is_null(row) {
1248 None
1249 } else {
1250 Some(arr.value(row).to_string())
1251 }
1252 });
1253
1254 let session_id = session_id_array.and_then(|arr| {
1255 if arr.is_null(row) {
1256 None
1257 } else {
1258 Some(arr.value(row).to_string())
1259 }
1260 });
1261
1262 let metadata = match metadata_array {
1263 Some(arr) if !arr.is_null(row) => {
1264 Some(serde_json::from_str(arr.value(row)).map_err(|err| {
1265 LanceError::from(ArrowError::InvalidArgumentError(format!(
1266 "invalid metadata JSON for record {}: {}",
1267 id_array.value(row),
1268 err
1269 )))
1270 })?)
1271 }
1272 _ => None,
1273 };
1274
1275 results.push(ContextRecord {
1276 id: id_array.value(row).to_string(),
1277 external_id: external_id_array.and_then(|arr| {
1278 if arr.is_null(row) {
1279 None
1280 } else {
1281 Some(arr.value(row).to_string())
1282 }
1283 }),
1284 run_id: run_id_array.value(row).to_string(),
1285 bot_id,
1286 session_id,
1287 created_at,
1288 role,
1289 state_metadata,
1290 metadata,
1291 content_type: content_type_array.value(row).to_string(),
1292 text_payload,
1293 binary_payload,
1294 embedding,
1295 });
1296 }
1297
1298 Ok(results)
1299}
1300
1301fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
1302 let values = list.value(row);
1303 let float_array = values
1304 .as_ref()
1305 .as_any()
1306 .downcast_ref::<Float32Array>()
1307 .ok_or_else(|| {
1308 LanceError::from(ArrowError::InvalidArgumentError(
1309 "embedding column does not contain float32 values".to_string(),
1310 ))
1311 })?;
1312 let mut embedding = Vec::with_capacity(float_array.len());
1313 for idx in 0..float_array.len() {
1314 embedding.push(float_array.value(idx));
1315 }
1316 Ok(embedding)
1317}
1318
/// Euclidean (L2) distance between `left` and `right`.
///
/// Components are paired with `zip`, so when the slices differ in length the trailing
/// components of the longer slice do not contribute to the result.
fn l2_distance(left: &[f32], right: &[f32]) -> f32 {
    let squared_sum: f32 = left
        .iter()
        .zip(right.iter())
        .map(|(&a, &b)| {
            let diff = a - b;
            diff * diff
        })
        .sum::<f32>();
    squared_sum.sqrt()
}
1329
1330fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
1331where
1332 A: Array + 'static,
1333{
1334 let column = batch.column_by_name(name).ok_or_else(|| {
1335 LanceError::from(ArrowError::InvalidArgumentError(format!(
1336 "column '{name}' not found"
1337 )))
1338 })?;
1339 column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
1340 LanceError::from(ArrowError::InvalidArgumentError(format!(
1341 "column '{name}' has unexpected data type"
1342 )))
1343 })
1344}
1345
1346fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A>
1347where
1348 A: Array + 'static,
1349{
1350 batch
1351 .column_by_name(name)
1352 .and_then(|col| col.as_ref().as_any().downcast_ref::<A>())
1353}
1354
1355#[cfg(test)]
1356mod tests {
1357 use super::*;
1358 use crate::serde::CONTENT_TYPE_TEXT;
1359 use chrono::Utc;
1360 use tempfile::TempDir;
1361
1362 fn make_embedding(pivot: f32) -> Vec<f32> {
1363 let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
1364 if !values.is_empty() {
1365 values[0] = pivot;
1366 }
1367 values
1368 }
1369
1370 fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
1371 ContextRecord {
1372 id: id.to_string(),
1373 external_id: None,
1374 run_id: format!("run-{id}"),
1375 bot_id: None,
1376 session_id: None,
1377 created_at: Utc::now(),
1378 role: "user".to_string(),
1379 state_metadata: Some(StateMetadata {
1380 step: Some(1),
1381 active_plan_id: Some("plan".to_string()),
1382 tokens_used: Some(10),
1383 custom: None,
1384 }),
1385 metadata: None,
1386 content_type: CONTENT_TYPE_TEXT.to_string(),
1387 text_payload: Some(format!("payload-{id}")),
1388 binary_payload: None,
1389 embedding: Some(make_embedding(embedding_pivot)),
1390 }
1391 }
1392
/// A query equal to the second record's embedding must rank that record first, and the
/// returned distances must be non-decreasing.
#[test]
fn search_orders_by_distance() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();
        let far = text_record("a", 0.0);
        let near = text_record("b", 1.0);
        let near_id = near.id.clone();
        store.add(&[far, near]).await.unwrap();

        let probe = make_embedding(1.0);
        let hits = store.search(&probe, Some(2)).await.unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].record.id, near_id);
        assert!(
            hits[0].distance <= hits[1].distance,
            "results not ordered by distance: {:?}",
            hits
        );
    });
}
1416
/// Searching with a one-component query must fail with an error that mentions the
/// embedding dimension.
#[test]
fn search_validates_query_length() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let store = ContextStore::open(&uri).await.unwrap();
        let message = store
            .search(&[0.0_f32], None)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            message.contains("embedding dimension"),
            "unexpected error message: {message}"
        );
    });
}
1432
/// A record stored with an external id can be fetched by that external id and by its own
/// id, while an unknown external id yields `None`.
#[test]
fn external_id_roundtrips_and_supports_lookup() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let external_id = "doc-123#chunk-1";
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut stored = text_record("a", 0.0);
        stored.external_id = Some(external_id.to_string());
        store.add(std::slice::from_ref(&stored)).await.unwrap();

        let found = store.get_by_external_id(external_id).await.unwrap().unwrap();
        assert_eq!(found.id, stored.id);
        assert_eq!(found.external_id, stored.external_id);

        let found_by_id = store.get_by_id(&stored.id).await.unwrap().unwrap();
        assert_eq!(found_by_id.external_id.as_deref(), Some(external_id));

        let absent = store.get_by_external_id("missing").await.unwrap();
        assert!(absent.is_none());
    });
}
1459
/// Adding a second record that reuses a live external id must fail with an
/// "external_id ... already exists" error.
#[test]
fn add_rejects_duplicate_external_id() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut original = text_record("a", 0.0);
        original.external_id = Some("doc-123#chunk-1".to_string());
        store.add(std::slice::from_ref(&original)).await.unwrap();

        let mut clashing = text_record("b", 0.0);
        clashing.external_id = original.external_id.clone();

        let message = store.add(&[clashing]).await.unwrap_err().to_string();
        let mentions_conflict =
            message.contains("external_id") && message.contains("already exists");
        assert!(mentions_conflict, "unexpected error message: {message}");
    });
}
1481
/// A record whose content type is the tombstone marker must be rejected by `add` with an
/// error mentioning that the type is reserved.
#[test]
fn add_rejects_reserved_tombstone_content_type() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut forged = text_record("a", 0.0);
        forged.content_type = CONTENT_TYPE_TOMBSTONE.to_string();

        let message = store.add(&[forged]).await.unwrap_err().to_string();
        let mentions_reserved = message.contains("reserved") && message.contains("tombstone");
        assert!(mentions_reserved, "unexpected error message: {message}");
    });
}
1500
/// After `delete_by_external_id`, the record is no longer returned by point lookups,
/// `list`, or `search`, while the other record stays visible.
#[test]
fn delete_by_external_id_hides_record_from_default_reads() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let external_id = "doc-123#chunk-1";
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut doomed = text_record("a", 0.0);
        doomed.external_id = Some(external_id.to_string());
        let survivor = text_record("b", 2.0);
        store
            .add(&[doomed.clone(), survivor.clone()])
            .await
            .unwrap();

        let deleted = store.delete_by_external_id(external_id).await.unwrap();
        assert!(deleted);

        let via_external_id = store.get_by_external_id(external_id).await.unwrap();
        assert!(via_external_id.is_none());
        let via_id = store.get_by_id(&doomed.id).await.unwrap();
        assert!(via_id.is_none());

        let listed = store.list(None, None).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, survivor.id);

        let probe = make_embedding(0.0);
        let hits = store.search(&probe, Some(10)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.id, survivor.id);
    });
}
1535
/// After `delete_by_id`, the record is no longer returned by point lookups (by id or by
/// external id), `list`, or `search`, while the other record stays visible.
#[test]
fn delete_by_id_hides_record_from_default_reads() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let external_id = "doc-123#chunk-1";
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut doomed = text_record("a", 0.0);
        doomed.external_id = Some(external_id.to_string());
        let survivor = text_record("b", 2.0);
        store
            .add(&[doomed.clone(), survivor.clone()])
            .await
            .unwrap();

        let deleted = store.delete_by_id(&doomed.id).await.unwrap();
        assert!(deleted);

        let via_id = store.get_by_id(&doomed.id).await.unwrap();
        assert!(via_id.is_none());
        let via_external_id = store.get_by_external_id(external_id).await.unwrap();
        assert!(via_external_id.is_none());

        let listed = store.list(None, None).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, survivor.id);

        let probe = make_embedding(0.0);
        let hits = store.search(&probe, Some(10)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.id, survivor.id);
    });
}
1567
/// Deleting an unknown id or external id on an empty store succeeds and reports `false`.
#[test]
fn delete_missing_id_is_noop() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let removed_by_id = store.delete_by_id("missing").await.unwrap();
        assert!(!removed_by_id);

        let removed_by_external_id = store.delete_by_external_id("missing").await.unwrap();
        assert!(!removed_by_external_id);
    });
}
1579
/// Once a record is deleted by external id, a new record may claim the same external id
/// and becomes the only record listed.
#[test]
fn external_id_can_be_reused_after_delete() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let external_id = "doc-123#chunk-1";
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut original = text_record("a", 0.0);
        original.external_id = Some(external_id.to_string());
        store.add(std::slice::from_ref(&original)).await.unwrap();

        let deleted = store.delete_by_external_id(external_id).await.unwrap();
        assert!(deleted);

        let mut successor = text_record("b", 1.0);
        successor.external_id = original.external_id.clone();
        store.add(std::slice::from_ref(&successor)).await.unwrap();

        let current = store.get_by_external_id(external_id).await.unwrap().unwrap();
        assert_eq!(current.id, successor.id);

        let listed = store.list(None, None).await.unwrap();
        assert_eq!(listed.len(), 1);
    });
}
1608
/// `derive_region_id` returns equal ids for equal inputs (including `None`/`None`) and a
/// different id when the session id changes.
#[test]
fn test_region_id_derivation_explicit() {
    let bot = Some("bot-123".to_string());
    let session = Some("session-456".to_string());

    let baseline = ContextStore::derive_region_id(&bot, &session);
    let repeated = ContextStore::derive_region_id(&bot, &session);
    assert_eq!(
        baseline, repeated,
        "Region ID should be deterministic for same inputs"
    );

    let different_session = Some("session-789".to_string());
    let shifted = ContextStore::derive_region_id(&bot, &different_session);
    assert_ne!(
        baseline, shifted,
        "Region ID should differ for different inputs"
    );

    let anonymous_first = ContextStore::derive_region_id(&None, &None);
    let anonymous_second = ContextStore::derive_region_id(&None, &None);
    assert_eq!(
        anonymous_first, anonymous_second,
        "Region ID for None/None should be deterministic"
    );
}
1638
/// Records with different bot/session pairs added in one call are both listed after the
/// store is opened again on the same URI.
#[test]
fn test_add_multiple_regions() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let mut writer = ContextStore::open(&uri).await.unwrap();

        let mut from_bot_a = text_record("r1", 0.0);
        from_bot_a.bot_id = Some("bot-A".to_string());
        from_bot_a.session_id = Some("session-1".to_string());

        let mut from_bot_b = text_record("r2", 0.0);
        from_bot_b.bot_id = Some("bot-B".to_string());
        from_bot_b.session_id = Some("session-2".to_string());

        writer.add(&[from_bot_a, from_bot_b]).await.unwrap();

        // A second handle is opened while the writing handle is still alive.
        let reader = ContextStore::open(&uri).await.unwrap();

        let listed = reader.list(None, None).await.unwrap();
        assert_eq!(listed.len(), 2);

        let has_r1 = listed.iter().any(|record| record.id == "r1");
        let has_r2 = listed.iter().any(|record| record.id == "r2");
        assert!(has_r1);
        assert!(has_r2);
    });
}
1675
/// With only `binary_payload` configured as a blob column, the schema marks that field
/// with blob metadata and leaves `text_payload` as plain `LargeUtf8`.
#[test]
fn test_blob_binary_payload() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let blob_columns = HashSet::from(["binary_payload".to_string()]);
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let mut record = text_record("blob-bin-1", 0.0);
        record.binary_payload = Some(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        store.add(std::slice::from_ref(&record)).await.unwrap();

        let schema = ContextStore::schema(&store.blob_columns);

        let binary_field = schema.field_with_name("binary_payload").unwrap();
        let binary_blob_flag = binary_field.metadata().get("lance-encoding:blob");
        assert_eq!(binary_blob_flag, Some(&"true".to_string()));

        let text_field = schema.field_with_name("text_payload").unwrap();
        assert_eq!(text_field.data_type(), &DataType::LargeUtf8);
        assert!(text_field.metadata().get("lance-encoding:blob").is_none());
    });
}
1708
/// With `text_payload` configured as a blob column, the built batch stores it as
/// `LargeBinary` and `batch_to_records` still recovers the original text.
#[test]
fn test_blob_text_payload() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let blob_columns = HashSet::from(["text_payload".to_string()]);
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let record = text_record("blob-txt-1", 0.0);
        let records = std::slice::from_ref(&record);
        store.add(records).await.unwrap();

        let batch = store.records_to_batch(records).unwrap();
        let schema = batch.schema();
        let text_type = schema.field_with_name("text_payload").unwrap().data_type();
        assert_eq!(
            text_type,
            &DataType::LargeBinary,
            "text_payload should be LargeBinary when blob-encoded"
        );

        let decoded = batch_to_records(&batch).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(
            decoded[0].text_payload, record.text_payload,
            "text payload should survive blob roundtrip"
        );
    });
}
1747
/// With both payload columns configured as blobs, both schema fields carry blob metadata
/// and a batch built from a record decodes back to the same text and binary payloads.
#[test]
fn test_blob_both_columns() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let blob_columns = HashSet::from([
            "text_payload".to_string(),
            "binary_payload".to_string(),
        ]);
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let mut record = text_record("blob-both-1", 0.0);
        record.binary_payload = Some(b"hello binary".to_vec());
        let records = std::slice::from_ref(&record);
        store.add(records).await.unwrap();

        let schema = ContextStore::schema(&store.blob_columns);
        let text_field = schema.field_with_name("text_payload").unwrap();
        let binary_field = schema.field_with_name("binary_payload").unwrap();
        let expected_flag = "true".to_string();
        assert_eq!(
            text_field.metadata().get("lance-encoding:blob"),
            Some(&expected_flag),
        );
        assert_eq!(
            binary_field.metadata().get("lance-encoding:blob"),
            Some(&expected_flag),
        );

        let batch = store.records_to_batch(records).unwrap();
        let decoded = batch_to_records(&batch).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].text_payload, record.text_payload);
        assert_eq!(decoded[0].binary_payload, record.binary_payload);
    });
}
1793
/// With no blob columns, the schema keeps `text_payload` as `LargeUtf8` and
/// `binary_payload` as `LargeBinary`, and neither field carries blob metadata.
#[test]
fn test_no_blob_default() {
    let no_blobs = HashSet::new();
    let schema = ContextStore::schema(&no_blobs);

    let text_field = schema.field_with_name("text_payload").unwrap();
    let binary_field = schema.field_with_name("binary_payload").unwrap();

    assert_eq!(text_field.data_type(), &DataType::LargeUtf8);
    assert!(text_field.metadata().get("lance-encoding:blob").is_none());
    assert_eq!(binary_field.data_type(), &DataType::LargeBinary);
    assert!(binary_field.metadata().get("lance-encoding:blob").is_none());
}
1806
/// With both payload columns configured as blobs, both fields are `LargeBinary` with blob
/// metadata set to "true", while the `id` field carries no blob metadata.
#[test]
fn test_blob_schema_metadata() {
    let blob_columns = HashSet::from([
        "text_payload".to_string(),
        "binary_payload".to_string(),
    ]);
    let schema = ContextStore::schema(&blob_columns);
    let expected_flag = "true".to_string();

    let text_field = schema.field_with_name("text_payload").unwrap();
    assert_eq!(text_field.data_type(), &DataType::LargeBinary);
    assert_eq!(
        text_field.metadata().get("lance-encoding:blob"),
        Some(&expected_flag),
    );

    let binary_field = schema.field_with_name("binary_payload").unwrap();
    assert_eq!(binary_field.data_type(), &DataType::LargeBinary);
    assert_eq!(
        binary_field.metadata().get("lance-encoding:blob"),
        Some(&expected_flag),
    );

    let id_field = schema.field_with_name("id").unwrap();
    assert!(id_field.metadata().get("lance-encoding:blob").is_none());
}
1831
/// Opening a store with an unknown blob column name must fail with an error mentioning
/// "invalid blob column".
#[test]
fn test_blob_invalid_column_name() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let blob_columns = HashSet::from(["nonexistent_column".to_string()]);
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };

        let Err(err) = ContextStore::open_with_options(&uri, options).await else {
            panic!("should reject invalid blob column names");
        };
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("invalid blob column"),
            "error should mention invalid blob column: {err_msg}"
        );
    });
}
1852
/// `batch_to_records` recovers the same text payload from a batch built by a default
/// store and from a batch built by a store with `text_payload` as a blob column.
#[test]
fn test_batch_to_records_autodetects_text_type() {
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let record = text_record("auto-1", 0.0);
        let records = std::slice::from_ref(&record);

        // Default store: no blob columns configured.
        let plain_dir = TempDir::new().unwrap();
        let plain_uri = plain_dir.path().to_string_lossy().to_string();
        let plain_store = ContextStore::open(&plain_uri).await.unwrap();
        let plain_batch = plain_store.records_to_batch(records).unwrap();
        let from_plain = batch_to_records(&plain_batch).unwrap();
        assert_eq!(from_plain[0].text_payload, record.text_payload);

        // Blob store: `text_payload` configured as a blob column.
        let blob_dir = TempDir::new().unwrap();
        let blob_uri = blob_dir.path().to_string_lossy().to_string();
        let blob_options = ContextStoreOptions {
            blob_columns: HashSet::from(["text_payload".to_string()]),
            ..Default::default()
        };
        let blob_store = ContextStore::open_with_options(&blob_uri, blob_options)
            .await
            .unwrap();
        let blob_batch = blob_store.records_to_batch(records).unwrap();
        let from_blob = batch_to_records(&blob_batch).unwrap();
        assert_eq!(from_blob[0].text_payload, record.text_payload);
    });
}
1887
/// With `IdIndexType::BTree`, an index named `ID_INDEX_NAME` is present right after open
/// and is still present after five single-record adds followed by a compaction.
#[test]
fn test_id_index_btree() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::BTree,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let present_on_open = store
            .dataset
            .load_indices()
            .await
            .unwrap()
            .iter()
            .any(|index| index.name == ID_INDEX_NAME);
        assert!(present_on_open, "btree index should be created on open");

        for n in 0..5 {
            let record = text_record(&format!("btree-{n}"), n as f32);
            store.add(&[record]).await.unwrap();
        }
        store.compact(None).await.unwrap();

        let present_after_compaction = store
            .dataset
            .load_indices()
            .await
            .unwrap()
            .iter()
            .any(|index| index.name == ID_INDEX_NAME);
        assert!(
            present_after_compaction,
            "btree index should persist after compaction"
        );
    });
}
1927
/// With `IdIndexType::ZoneMap`, an index named `ID_INDEX_NAME` is present right after open
/// and is still present after five single-record adds followed by a compaction.
#[test]
fn test_id_index_zonemap() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::ZoneMap,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let present_on_open = store
            .dataset
            .load_indices()
            .await
            .unwrap()
            .iter()
            .any(|index| index.name == ID_INDEX_NAME);
        assert!(present_on_open, "zonemap index should be created on open");

        for n in 0..5 {
            let record = text_record(&format!("zm-{n}"), n as f32);
            store.add(&[record]).await.unwrap();
        }
        store.compact(None).await.unwrap();

        let present_after_compaction = store
            .dataset
            .load_indices()
            .await
            .unwrap()
            .iter()
            .any(|index| index.name == ID_INDEX_NAME);
        assert!(
            present_after_compaction,
            "zonemap index should persist after compaction"
        );
    });
}
1965
/// A store opened with default options has no index named `ID_INDEX_NAME` after an add
/// and a compaction.
#[test]
fn test_id_index_none_by_default() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let record = text_record("no-idx-1", 0.0);
        store.add(&[record]).await.unwrap();
        store.compact(None).await.unwrap();

        let id_index_present = store
            .dataset
            .load_indices()
            .await
            .unwrap()
            .iter()
            .any(|index| index.name == ID_INDEX_NAME);
        assert!(
            !id_index_present,
            "no id index should be created when IdIndexType::None"
        );
    });
}
1985
/// Calling `ensure_id_index` after `create_id_index` must leave the store version
/// unchanged.
#[test]
fn test_id_index_idempotent() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::BTree,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        for n in 0..5 {
            let record = text_record(&format!("idem-{n}"), n as f32);
            store.add(&[record]).await.unwrap();
        }

        store.create_id_index().await.unwrap();
        let version_after_create = store.version();

        store.ensure_id_index().await.unwrap();
        let version_after_ensure = store.version();

        assert_eq!(
            version_after_create, version_after_ensure,
            "ensure_id_index should not recreate existing index"
        );
    });
}
2016}