Skip to main content

lance_context_core/
store.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Duration;
4
5use arrow_array::builder::{
6    FixedSizeListBuilder, Float32Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
7    StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
8};
9use arrow_array::types::Int8Type;
10use arrow_array::{
11    Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
12    LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray,
13    TimestampMicrosecondArray,
14};
15use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
16use chrono::{DateTime, Timelike, Utc};
17use futures::TryStreamExt;
18use lance::dataset::mem_wal::{
19    DatasetMemWalExt, LsmScanner, ShardManifestStore, ShardSnapshot, ShardWriterConfig,
20};
21use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
22use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams};
23use lance::index::DatasetIndexExt;
24use lance::io::{ObjectStoreParams, StorageOptionsAccessor};
25use lance::{Error as LanceError, Result as LanceResult};
26use lance_index::mem_wal::MEM_WAL_INDEX_NAME;
27use lance_index::scalar::ScalarIndexParams;
28use lance_index::IndexType;
29use tokio::sync::Mutex;
30use tokio::task::JoinHandle;
31use tracing::{error, info, warn};
32use uuid::Uuid;
33
34use crate::record::{ContextRecord, RecordFilters, SearchResult, StateMetadata};
35use crate::serde::CONTENT_TYPE_TOMBSTONE;
36
/// Fixed vector length of the `embedding` column; search queries must match it.
const DEFAULT_EMBEDDING_DIM: i32 = 1536;
/// Number of hits returned by a search when the caller supplies no limit.
const DEFAULT_SEARCH_LIMIT: usize = 10;
/// Manifest scan batch size handed to `ShardManifestStore::new`.
const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16;
/// Name under which the scalar index on the `id` column is registered.
const ID_INDEX_NAME: &str = "id_idx";
42
/// Tuning knobs for manual and background compaction of the context dataset.
#[derive(Debug, Clone)]
pub struct CompactionConfig {
    /// Spawn the periodic background compaction task when the store is opened.
    pub enabled: bool,
    /// Fragment count at or above which compaction is considered worthwhile.
    pub min_fragments: usize,
    /// Desired rows per fragment after compaction (forwarded to Lance).
    pub target_rows_per_fragment: usize,
    /// Upper bound on rows per row group (forwarded to Lance).
    pub max_rows_per_group: usize,
    /// Physically drop soft-deleted rows while compacting (forwarded to Lance).
    pub materialize_deletions: bool,
    /// Deleted-row fraction (0.0-1.0) that triggers materialization (forwarded to Lance).
    pub materialize_deletions_threshold: f32,
    /// Compaction worker threads; `None` lets Lance choose.
    pub num_threads: Option<usize>,
    /// Seconds between background compaction checks.
    pub check_interval_secs: u64,
    /// UTC hour windows `(start_hour, end_hour)`, end exclusive, during which
    /// compaction is skipped.
    pub quiet_hours: Vec<(u8, u8)>,
}

impl Default for CompactionConfig {
    /// Background compaction disabled. When enabled, it triggers at 5+ fragments,
    /// is checked every 300 s, and has no quiet hours.
    fn default() -> Self {
        Self {
            quiet_hours: Vec::new(),
            check_interval_secs: 300,
            num_threads: None,
            materialize_deletions_threshold: 0.1,
            materialize_deletions: true,
            max_rows_per_group: 1024,
            target_rows_per_fragment: 1_000_000,
            min_fragments: 5,
            enabled: false,
        }
    }
}
81
/// Which scalar index, if any, the store maintains on the `id` column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IdIndexType {
    /// Do not index `id` (the default).
    #[default]
    None,
    /// Lightweight zone-map index (per-fragment min/max statistics).
    ZoneMap,
    /// B-tree index, better suited to point lookups at a higher build cost.
    BTree,
}
93
94/// Statistics about compaction status and history.
95#[derive(Debug, Clone)]
96pub struct CompactionStats {
97    /// Current number of fragments in the dataset.
98    pub total_fragments: usize,
99    /// Whether a compaction is currently in progress.
100    pub is_compacting: bool,
101    /// Timestamp of the last successful compaction.
102    pub last_compaction: Option<DateTime<Utc>>,
103    /// Error message from the last failed compaction.
104    pub last_error: Option<String>,
105    /// Total number of successful compactions performed.
106    pub total_compactions: u64,
107}
108
109/// Internal state for tracking background compaction.
110struct CompactionState {
111    background_task: Option<JoinHandle<()>>,
112    is_compacting: bool,
113    last_compaction: Option<DateTime<Utc>>,
114    last_error: Option<String>,
115    total_compactions: u64,
116}
117
/// Payload columns that callers may switch to blob encoding through
/// `ContextStoreOptions::blob_columns`; any other name is rejected on open.
const VALID_BLOB_COLUMNS: &[&str] = &["text_payload", "binary_payload"];
120
121/// Persistent Lance-backed context store.
122#[derive(Clone)]
123pub struct ContextStore {
124    dataset: Dataset,
125    compaction_state: Arc<Mutex<CompactionState>>,
126    pub compaction_config: CompactionConfig,
127    blob_columns: HashSet<String>,
128    id_index_type: IdIndexType,
129}
130
131/// Additional configuration when opening a [`ContextStore`].
132#[derive(Debug, Clone, Default)]
133pub struct ContextStoreOptions {
134    pub storage_options: Option<HashMap<String, String>>,
135    pub compaction: CompactionConfig,
136    /// Column names that should use Lance V1 blob encoding.
137    /// Valid values: `"text_payload"`, `"binary_payload"`.
138    pub blob_columns: HashSet<String>,
139    /// Type of scalar index to create on the `id` column.
140    pub id_index_type: IdIndexType,
141}
142
143impl ContextStoreOptions {
144    #[must_use]
145    pub fn storage_options(&self) -> Option<HashMap<String, String>> {
146        self.storage_options.clone()
147    }
148}
149
150impl ContextStore {
151    /// Open an existing context dataset or create a new one with the project schema.
152    pub async fn open(uri: &str) -> LanceResult<Self> {
153        Self::open_with_options(uri, ContextStoreOptions::default()).await
154    }
155
156    /// Open a dataset with explicit object store configuration (e.g. S3 credentials).
157    pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult<Self> {
158        // Validate blob_columns
159        for col in &options.blob_columns {
160            if !VALID_BLOB_COLUMNS.contains(&col.as_str()) {
161                return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
162                    "invalid blob column '{}': valid columns are {:?}",
163                    col, VALID_BLOB_COLUMNS
164                ))));
165            }
166        }
167
168        let storage_options = options.storage_options();
169        let blob_columns = options.blob_columns.clone();
170        let dataset = match Self::load_with_options(uri, storage_options.clone()).await {
171            Ok(dataset) => dataset,
172            Err(LanceError::DatasetNotFound { .. }) => {
173                Self::create_with_options(uri, storage_options, &blob_columns).await?
174            }
175            Err(err) => return Err(err),
176        };
177
178        let mut store = Self {
179            dataset,
180            compaction_state: Arc::new(Mutex::new(CompactionState {
181                background_task: None,
182                is_compacting: false,
183                last_compaction: None,
184                last_error: None,
185                total_compactions: 0,
186            })),
187            compaction_config: options.compaction,
188            blob_columns,
189            id_index_type: options.id_index_type,
190        };
191
192        // Ensure id index if configured
193        store.ensure_id_index().await?;
194
195        // Start background compaction if enabled
196        store.start_background_compaction().await?;
197
198        Ok(store)
199    }
200
201    /// Append context records to the store and return the new dataset version.
202    pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
203        if entries.is_empty() {
204            return Ok(self.dataset.manifest.version);
205        }
206
207        self.validate_unique_ids(entries).await?;
208        self.write_entries(entries).await
209    }
210
211    async fn write_entries(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
212        if entries.is_empty() {
213            return Ok(self.dataset.manifest.version);
214        }
215
216        // Group entries by (bot_id, session_id)
217        let mut groups: HashMap<(Option<String>, Option<String>), Vec<ContextRecord>> =
218            HashMap::new();
219        for entry in entries {
220            let key = (entry.bot_id.clone(), entry.session_id.clone());
221            groups.entry(key).or_default().push(entry.clone());
222        }
223
224        // Ensure MemWAL is initialized (once for the dataset)
225        {
226            let indices = self.dataset.load_indices().await?;
227            let has_mem_wal = indices.iter().any(|i| i.name == MEM_WAL_INDEX_NAME);
228
229            if !has_mem_wal {
230                // ZoneMap indices are not supported by MemWAL; exclude them
231                let maintained_indexes: Vec<String> = indices
232                    .iter()
233                    .filter(|i| {
234                        !(self.id_index_type == IdIndexType::ZoneMap && i.name == ID_INDEX_NAME)
235                    })
236                    .map(|i| i.name.clone())
237                    .collect();
238                self.dataset
239                    .initialize_mem_wal()
240                    .unsharded()
241                    .maintained_indexes(maintained_indexes)
242                    .execute()
243                    .await?;
244            }
245        }
246
247        for ((bot_id, session_id), group_entries) in groups {
248            let region_id = Self::derive_region_id(&bot_id, &session_id);
249            let batch = self.records_to_batch(&group_entries)?;
250            let config = ShardWriterConfig {
251                shard_id: region_id,
252                ..Default::default()
253            };
254
255            let writer = self.dataset.mem_wal_writer(region_id, config).await?;
256            writer.put(vec![batch]).await?;
257            writer.close().await?;
258        }
259
260        Ok(self.dataset.manifest.version)
261    }
262
263    /// Logically forget a record by internal storage id.
264    ///
265    /// This writes a tombstone with the same primary key, preserving prior
266    /// dataset versions while hiding the record from default reads.
267    pub async fn delete_by_id(&mut self, id: &str) -> LanceResult<bool> {
268        let Some(record) = self.get_by_id(id).await? else {
269            return Ok(false);
270        };
271        self.write_tombstone_for(record).await?;
272        Ok(true)
273    }
274
275    /// Logically forget a record by caller-supplied external id.
276    pub async fn delete_by_external_id(&mut self, external_id: &str) -> LanceResult<bool> {
277        let Some(record) = self.get_by_external_id(external_id).await? else {
278            return Ok(false);
279        };
280        self.write_tombstone_for(record).await?;
281        Ok(true)
282    }
283
284    async fn write_tombstone_for(&mut self, record: ContextRecord) -> LanceResult<u64> {
285        let tombstone = ContextRecord {
286            id: record.id,
287            external_id: record.external_id,
288            run_id: record.run_id,
289            bot_id: record.bot_id,
290            session_id: record.session_id,
291            created_at: Utc::now(),
292            role: record.role,
293            state_metadata: None,
294            metadata: None,
295            content_type: CONTENT_TYPE_TOMBSTONE.to_string(),
296            text_payload: None,
297            binary_payload: None,
298            embedding: None,
299        };
300        self.write_entries(std::slice::from_ref(&tombstone)).await
301    }
302
303    async fn validate_unique_ids(&self, entries: &[ContextRecord]) -> LanceResult<()> {
304        let mut ids = HashSet::new();
305        let mut external_ids = HashSet::new();
306        for entry in entries {
307            if entry.is_tombstone() {
308                return Err(ArrowError::InvalidArgumentError(format!(
309                    "content_type '{}' is reserved for internal tombstones",
310                    CONTENT_TYPE_TOMBSTONE
311                ))
312                .into());
313            }
314            if !ids.insert(entry.id.as_str()) {
315                return Err(ArrowError::InvalidArgumentError(format!(
316                    "duplicate id '{}' in batch",
317                    entry.id
318                ))
319                .into());
320            }
321            if let Some(external_id) = &entry.external_id {
322                if !external_ids.insert(external_id.as_str()) {
323                    return Err(ArrowError::InvalidArgumentError(format!(
324                        "duplicate external_id '{}' in batch",
325                        external_id
326                    ))
327                    .into());
328                }
329            }
330        }
331
332        for record in self.list(None, None).await? {
333            if ids.contains(record.id.as_str()) {
334                return Err(ArrowError::InvalidArgumentError(format!(
335                    "id '{}' already exists",
336                    record.id
337                ))
338                .into());
339            }
340            if let Some(external_id) = record.external_id {
341                if external_ids.contains(external_id.as_str()) {
342                    return Err(ArrowError::InvalidArgumentError(format!(
343                        "external_id '{}' already exists",
344                        external_id
345                    ))
346                    .into());
347                }
348            }
349        }
350
351        Ok(())
352    }
353
354    fn derive_region_id(bot_id: &Option<String>, session_id: &Option<String>) -> Uuid {
355        let mut input = String::new();
356
357        if let Some(bid) = bot_id {
358            input.push_str(bid);
359        }
360        input.push('#');
361        if let Some(sid) = session_id {
362            input.push_str(sid);
363        }
364
365        // Use OID namespace as a base for our deterministic UUIDs
366        Uuid::new_v5(&Uuid::NAMESPACE_OID, input.as_bytes())
367    }
368
369    /// Current dataset version.
370    pub fn version(&self) -> u64 {
371        self.dataset.manifest.version
372    }
373
374    /// Checkout a specific dataset version.
375    pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> {
376        let dataset = self.dataset.checkout_version(version_id).await?;
377        self.dataset = dataset;
378        Ok(())
379    }
380
381    /// List all records in the dataset.
382    pub async fn list(
383        &self,
384        limit: Option<usize>,
385        offset: Option<usize>,
386    ) -> LanceResult<Vec<ContextRecord>> {
387        self.list_filtered(limit, offset, None).await
388    }
389
390    /// List records matching filters.
391    pub async fn list_filtered(
392        &self,
393        limit: Option<usize>,
394        offset: Option<usize>,
395        filters: Option<&RecordFilters>,
396    ) -> LanceResult<Vec<ContextRecord>> {
397        let scanner = self.lsm_scanner().await?;
398        let mut stream = scanner.try_into_stream().await?;
399        let mut results = Vec::new();
400        while let Some(batch) = stream.try_next().await? {
401            results.extend(
402                batch_to_records(&batch)?
403                    .into_iter()
404                    .filter(|record| !record.is_tombstone()),
405            );
406        }
407
408        if let Some(filters) = filters.filter(|filters| !filters.is_empty()) {
409            results.retain(|record| filters.matches(record));
410        }
411
412        if let Some(offset) = offset {
413            results = results.into_iter().skip(offset).collect();
414        }
415        if let Some(limit) = limit {
416            results.truncate(limit);
417        }
418        Ok(results)
419    }
420
421    /// Find a record by its internal storage id.
422    pub async fn get_by_id(&self, id: &str) -> LanceResult<Option<ContextRecord>> {
423        Ok(self
424            .list(None, None)
425            .await?
426            .into_iter()
427            .find(|record| record.id == id))
428    }
429
430    /// Find a record by its caller-supplied external id.
431    pub async fn get_by_external_id(
432        &self,
433        external_id: &str,
434    ) -> LanceResult<Option<ContextRecord>> {
435        Ok(self
436            .list(None, None)
437            .await?
438            .into_iter()
439            .find(|record| record.external_id.as_deref() == Some(external_id)))
440    }
441
442    /// Perform a nearest-neighbor search over stored embeddings.
443    pub async fn search(
444        &self,
445        query: &[f32],
446        limit: Option<usize>,
447    ) -> LanceResult<Vec<SearchResult>> {
448        self.search_filtered(query, limit, None).await
449    }
450
451    /// Perform a nearest-neighbor search over stored embeddings matching filters.
452    pub async fn search_filtered(
453        &self,
454        query: &[f32],
455        limit: Option<usize>,
456        filters: Option<&RecordFilters>,
457    ) -> LanceResult<Vec<SearchResult>> {
458        if query.len() != DEFAULT_EMBEDDING_DIM as usize {
459            return Err(ArrowError::InvalidArgumentError(format!(
460                "query length {} does not match embedding dimension {}",
461                query.len(),
462                DEFAULT_EMBEDDING_DIM
463            ))
464            .into());
465        }
466
467        let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
468        if top_k == 0 {
469            return Ok(Vec::new());
470        }
471
472        let mut results: Vec<SearchResult> = self
473            .list_filtered(None, None, filters)
474            .await?
475            .into_iter()
476            .filter_map(|record| {
477                let distance = l2_distance(query, record.embedding.as_ref()?);
478                Some(SearchResult { record, distance })
479            })
480            .collect();
481        results.sort_by(|left, right| left.distance.total_cmp(&right.distance));
482        results.truncate(top_k);
483        Ok(results)
484    }
485
486    async fn lsm_scanner(&self) -> LanceResult<LsmScanner> {
487        let object_store = self.dataset.object_store(None).await?;
488        let branch_location = self.dataset.branch_location();
489        let shard_ids = self.dataset.list_mem_wal_latest_shard_ids().await?;
490
491        let mut shard_snapshots = Vec::with_capacity(shard_ids.len());
492        for shard_id in shard_ids {
493            let manifest_store = ShardManifestStore::new(
494                object_store.clone(),
495                &branch_location.path,
496                shard_id,
497                DEFAULT_MANIFEST_SCAN_BATCH_SIZE,
498            );
499            let Some(manifest) = manifest_store.read_latest().await? else {
500                continue;
501            };
502
503            let mut snapshot = ShardSnapshot::new(shard_id)
504                .with_spec_id(manifest.shard_spec_id)
505                .with_current_generation(manifest.current_generation);
506            for flushed in manifest.flushed_generations {
507                snapshot = snapshot.with_flushed_generation(flushed.generation, flushed.path);
508            }
509            shard_snapshots.push(snapshot);
510        }
511
512        Ok(LsmScanner::new(
513            Arc::new(self.dataset.clone()),
514            shard_snapshots,
515            vec!["id".to_string()],
516        ))
517    }
518
519    /// Manually trigger compaction to merge small fragments.
520    pub async fn compact(
521        &mut self,
522        options: Option<CompactionConfig>,
523    ) -> LanceResult<CompactionMetrics> {
524        let config = options.unwrap_or_else(|| self.compaction_config.clone());
525
526        info!(
527            "Starting compaction: {} fragments",
528            self.dataset.count_fragments()
529        );
530        let start = std::time::Instant::now();
531
532        // Mark as compacting
533        {
534            let mut state = self.compaction_state.lock().await;
535            if state.is_compacting {
536                warn!("Compaction already in progress, skipping");
537                return Err(LanceError::from(ArrowError::InvalidArgumentError(
538                    "Compaction already in progress".to_string(),
539                )));
540            }
541            state.is_compacting = true;
542        }
543
544        // Build Lance CompactionOptions
545        let lance_options = CompactionOptions {
546            target_rows_per_fragment: config.target_rows_per_fragment,
547            max_rows_per_group: config.max_rows_per_group,
548            materialize_deletions: config.materialize_deletions,
549            materialize_deletions_threshold: config.materialize_deletions_threshold,
550            num_threads: config.num_threads,
551            ..Default::default()
552        };
553
554        // Run compaction
555        let result = compact_files(&mut self.dataset, lance_options, None).await;
556
557        // Update state
558        let mut state = self.compaction_state.lock().await;
559        state.is_compacting = false;
560
561        match result {
562            Ok(metrics) => {
563                state.last_compaction = Some(Utc::now());
564                state.total_compactions += 1;
565                state.last_error = None;
566                drop(state); // Release lock before ensure_id_index
567
568                info!(
569                    "Compaction completed in {:?}: removed {} fragments ({}files), added {} fragments ({} files)",
570                    start.elapsed(),
571                    metrics.fragments_removed,
572                    metrics.files_removed,
573                    metrics.fragments_added,
574                    metrics.files_added
575                );
576
577                // Reload dataset to see new version
578                self.dataset = Dataset::open(self.dataset.uri()).await?;
579
580                // Ensure id index exists after compaction
581                // (handles first-time creation on previously empty dataset)
582                if let Err(e) = self.ensure_id_index().await {
583                    warn!("Failed to ensure id index after compaction: {}", e);
584                }
585
586                Ok(metrics)
587            }
588            Err(e) => {
589                error!("Compaction failed: {}", e);
590                state.last_error = Some(e.to_string());
591                Err(e)
592            }
593        }
594    }
595
596    /// Check if compaction should run based on configuration thresholds.
597    pub async fn should_compact(&self) -> LanceResult<bool> {
598        let fragment_count = self.dataset.count_fragments();
599
600        if fragment_count < self.compaction_config.min_fragments {
601            return Ok(false);
602        }
603
604        // Check quiet hours
605        if !self.compaction_config.quiet_hours.is_empty() {
606            let now = Utc::now();
607            let current_hour = now.hour() as u8;
608
609            for (start, end) in &self.compaction_config.quiet_hours {
610                if current_hour >= *start && current_hour < *end {
611                    info!("Skipping compaction during quiet hours ({}-{})", start, end);
612                    return Ok(false);
613                }
614            }
615        }
616
617        Ok(true)
618    }
619
620    /// Get current compaction statistics.
621    pub async fn compaction_stats(&self) -> LanceResult<CompactionStats> {
622        let state = self.compaction_state.lock().await;
623
624        Ok(CompactionStats {
625            total_fragments: self.dataset.count_fragments(),
626            is_compacting: state.is_compacting,
627            last_compaction: state.last_compaction,
628            last_error: state.last_error.clone(),
629            total_compactions: state.total_compactions,
630        })
631    }
632
633    /// Ensure the configured id index exists on the dataset.
634    async fn ensure_id_index(&mut self) -> LanceResult<()> {
635        if self.id_index_type == IdIndexType::None {
636            return Ok(());
637        }
638
639        let indices = self.dataset.load_indices().await?;
640        if indices.iter().any(|i| i.name == ID_INDEX_NAME) {
641            return Ok(());
642        }
643
644        self.create_id_index().await
645    }
646
647    /// Create (or replace) the scalar index on the `id` column.
648    pub async fn create_id_index(&mut self) -> LanceResult<()> {
649        let index_type = match self.id_index_type {
650            IdIndexType::ZoneMap => IndexType::ZoneMap,
651            IdIndexType::BTree => IndexType::BTree,
652            IdIndexType::None => return Ok(()),
653        };
654
655        info!("Creating {:?} index on id column", index_type);
656
657        let params = ScalarIndexParams::default();
658
659        self.dataset
660            .create_index_builder(&["id"], index_type, &params)
661            .name(ID_INDEX_NAME.to_string())
662            .replace(true)
663            .await?;
664
665        // Reload dataset to pick up new index
666        self.dataset = Dataset::open(self.dataset.uri()).await?;
667
668        Ok(())
669    }
670
671    /// Start background compaction task if enabled.
672    async fn start_background_compaction(&mut self) -> LanceResult<()> {
673        if !self.compaction_config.enabled {
674            return Ok(());
675        }
676
677        let mut state = self.compaction_state.lock().await;
678        if state.background_task.is_some() {
679            warn!("Background compaction already running");
680            return Ok(());
681        }
682
683        info!(
684            "Starting background compaction (interval: {}s, min fragments: {})",
685            self.compaction_config.check_interval_secs, self.compaction_config.min_fragments
686        );
687
688        let mut store_clone = self.clone();
689        let interval_secs = self.compaction_config.check_interval_secs;
690
691        let task = tokio::spawn(async move {
692            let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
693
694            loop {
695                interval.tick().await;
696
697                match store_clone.should_compact().await {
698                    Ok(true) => {
699                        info!("Background compaction triggered");
700                        if let Err(e) = store_clone.compact(None).await {
701                            error!("Background compaction failed: {}", e);
702                        }
703                    }
704                    Ok(false) => {
705                        // Not needed or in quiet hours
706                    }
707                    Err(e) => {
708                        error!("Error checking compaction need: {}", e);
709                    }
710                }
711            }
712        });
713
714        state.background_task = Some(task);
715        Ok(())
716    }
717
718    /// Stop background compaction task.
719    pub async fn stop_background_compaction(&mut self) -> LanceResult<()> {
720        let mut state = self.compaction_state.lock().await;
721
722        if let Some(task) = state.background_task.take() {
723            info!("Stopping background compaction");
724            task.abort();
725        }
726
727        Ok(())
728    }
729
730    /// Lance schema for the context store.
731    ///
732    /// When `blob_columns` contains a column name, that column is stored using
733    /// Lance V1 blob encoding (out-of-line binary buffers). For `text_payload`,
734    /// this also changes the Arrow type from `LargeUtf8` to `LargeBinary`.
735    pub fn schema(blob_columns: &HashSet<String>) -> Schema {
736        Self::schema_with_options(blob_columns, true, true)
737    }
738
739    fn schema_with_options(
740        blob_columns: &HashSet<String>,
741        include_external_id: bool,
742        include_metadata: bool,
743    ) -> Schema {
744        let mut id_metadata = HashMap::new();
745        id_metadata.insert(
746            "lance-schema:unenforced-primary-key".to_string(),
747            "true".to_string(),
748        );
749
750        let text_field = if blob_columns.contains("text_payload") {
751            let mut metadata = HashMap::new();
752            metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
753            Field::new("text_payload", DataType::LargeBinary, true).with_metadata(metadata)
754        } else {
755            Field::new("text_payload", DataType::LargeUtf8, true)
756        };
757
758        let binary_field = if blob_columns.contains("binary_payload") {
759            let mut metadata = HashMap::new();
760            metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
761            Field::new("binary_payload", DataType::LargeBinary, true).with_metadata(metadata)
762        } else {
763            Field::new("binary_payload", DataType::LargeBinary, true)
764        };
765
766        let mut fields = vec![Field::new("id", DataType::Utf8, false).with_metadata(id_metadata)];
767        if include_external_id {
768            fields.push(Field::new("external_id", DataType::Utf8, true));
769        }
770        fields.extend([
771            Field::new("run_id", DataType::Utf8, false),
772            Field::new("bot_id", DataType::Utf8, true),
773            Field::new("session_id", DataType::Utf8, true),
774            Field::new(
775                "created_at",
776                DataType::Timestamp(TimeUnit::Microsecond, None),
777                false,
778            ),
779            Field::new(
780                "role",
781                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
782                false,
783            ),
784            Field::new(
785                "state_metadata",
786                DataType::Struct(
787                    vec![
788                        Field::new("step", DataType::Int32, true),
789                        Field::new("active_plan_id", DataType::Utf8, true),
790                        Field::new("tokens_used", DataType::Int32, true),
791                        Field::new("custom", DataType::Utf8, true),
792                    ]
793                    .into(),
794                ),
795                true,
796            ),
797        ]);
798        if include_metadata {
799            fields.push(Field::new("metadata", DataType::LargeUtf8, true));
800        }
801        fields.extend([
802            Field::new("content_type", DataType::Utf8, false),
803            text_field,
804            binary_field,
805            Field::new(
806                "embedding",
807                DataType::FixedSizeList(
808                    Arc::new(Field::new("item", DataType::Float32, true)),
809                    DEFAULT_EMBEDDING_DIM,
810                ),
811                true,
812            ),
813        ]);
814
815        Schema::new(fields)
816    }
817
818    async fn load_with_options(
819        uri: &str,
820        storage_options: Option<HashMap<String, String>>,
821    ) -> LanceResult<Dataset> {
822        if let Some(options) = storage_options {
823            DatasetBuilder::from_uri(uri)
824                .with_storage_options(options)
825                .load()
826                .await
827        } else {
828            Dataset::open(uri).await
829        }
830    }
831
832    async fn create_with_options(
833        uri: &str,
834        storage_options: Option<HashMap<String, String>>,
835        blob_columns: &HashSet<String>,
836    ) -> LanceResult<Dataset> {
837        let schema = Arc::new(Self::schema(blob_columns));
838        let empty_batch = RecordBatch::new_empty(schema.clone());
839        let batches = RecordBatchIterator::new(
840            vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
841            schema.clone(),
842        );
843
844        let mut params = WriteParams {
845            mode: WriteMode::Create,
846            ..Default::default()
847        };
848
849        if let Some(options) = storage_options {
850            let store_params = ObjectStoreParams {
851                storage_options_accessor: Some(Arc::new(
852                    StorageOptionsAccessor::with_static_options(options),
853                )),
854                ..Default::default()
855            };
856            params.store_params = Some(store_params);
857        }
858
859        Dataset::write(batches, uri, Some(params)).await
860    }
861
862    fn records_to_batch(&self, entries: &[ContextRecord]) -> LanceResult<RecordBatch> {
863        let include_external_id = self
864            .dataset
865            .schema()
866            .field_paths()
867            .iter()
868            .any(|path| path == "external_id");
869        if !include_external_id && entries.iter().any(|entry| entry.external_id.is_some()) {
870            return Err(ArrowError::InvalidArgumentError(
871                "external_id requires a context dataset created with external_id support"
872                    .to_string(),
873            )
874            .into());
875        }
876        let include_metadata = self
877            .dataset
878            .schema()
879            .field_paths()
880            .iter()
881            .any(|path| path == "metadata");
882        if !include_metadata && entries.iter().any(|entry| entry.metadata.is_some()) {
883            return Err(ArrowError::InvalidArgumentError(
884                "metadata requires a context dataset created with metadata support".to_string(),
885            )
886            .into());
887        }
888
889        let mut id_builder = StringBuilder::new();
890        let mut external_id_builder = StringBuilder::new();
891        let mut run_id_builder = StringBuilder::new();
892        let mut bot_id_builder = StringBuilder::new();
893        let mut session_id_builder = StringBuilder::new();
894        let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
895        let mut role_builder = StringDictionaryBuilder::<Int8Type>::new();
896        let mut metadata_builder = LargeStringBuilder::new();
897        let mut content_type_builder = StringBuilder::new();
898        let mut binary_builder = LargeBinaryBuilder::new();
899
900        let text_is_blob = self.blob_columns.contains("text_payload");
901        let mut text_string_builder = if !text_is_blob {
902            Some(LargeStringBuilder::new())
903        } else {
904            None
905        };
906        let mut text_binary_builder = if text_is_blob {
907            Some(LargeBinaryBuilder::new())
908        } else {
909            None
910        };
911
912        let state_fields: Vec<FieldRef> = vec![
913            Arc::new(Field::new("step", DataType::Int32, true)),
914            Arc::new(Field::new("active_plan_id", DataType::Utf8, true)),
915            Arc::new(Field::new("tokens_used", DataType::Int32, true)),
916            Arc::new(Field::new("custom", DataType::Utf8, true)),
917        ];
918        let mut state_builder = StructBuilder::new(
919            state_fields,
920            vec![
921                Box::new(Int32Builder::new()),
922                Box::new(StringBuilder::new()),
923                Box::new(Int32Builder::new()),
924                Box::new(StringBuilder::new()),
925            ],
926        );
927
928        let mut embedding_builder =
929            FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM);
930
931        for entry in entries {
932            id_builder.append_value(&entry.id);
933            external_id_builder.append_option(entry.external_id.as_deref());
934            run_id_builder.append_value(&entry.run_id);
935            bot_id_builder.append_option(entry.bot_id.as_deref());
936            session_id_builder.append_option(entry.session_id.as_deref());
937            created_at_builder.append_value(entry.created_at.timestamp_micros());
938            role_builder.append(&entry.role)?;
939            match &entry.metadata {
940                Some(metadata) => metadata_builder.append_value(metadata.to_string()),
941                None => metadata_builder.append_null(),
942            }
943            content_type_builder.append_value(&entry.content_type);
944
945            if text_is_blob {
946                match &entry.text_payload {
947                    Some(value) => text_binary_builder
948                        .as_mut()
949                        .unwrap()
950                        .append_value(value.as_bytes()),
951                    None => text_binary_builder.as_mut().unwrap().append_null(),
952                }
953            } else {
954                match &entry.text_payload {
955                    Some(value) => text_string_builder.as_mut().unwrap().append_value(value),
956                    None => text_string_builder.as_mut().unwrap().append_null(),
957                }
958            }
959
960            match &entry.binary_payload {
961                Some(value) => binary_builder.append_value(value),
962                None => binary_builder.append_null(),
963            }
964
965            if let Some(metadata) = &entry.state_metadata {
966                state_builder
967                    .field_builder::<Int32Builder>(0)
968                    .unwrap()
969                    .append_option(metadata.step);
970                state_builder
971                    .field_builder::<StringBuilder>(1)
972                    .unwrap()
973                    .append_option(metadata.active_plan_id.as_deref());
974                state_builder
975                    .field_builder::<Int32Builder>(2)
976                    .unwrap()
977                    .append_option(metadata.tokens_used);
978                state_builder
979                    .field_builder::<StringBuilder>(3)
980                    .unwrap()
981                    .append_option(metadata.custom.as_deref());
982                state_builder.append(true);
983            } else {
984                state_builder
985                    .field_builder::<Int32Builder>(0)
986                    .unwrap()
987                    .append_null();
988                state_builder
989                    .field_builder::<StringBuilder>(1)
990                    .unwrap()
991                    .append_null();
992                state_builder
993                    .field_builder::<Int32Builder>(2)
994                    .unwrap()
995                    .append_null();
996                state_builder
997                    .field_builder::<StringBuilder>(3)
998                    .unwrap()
999                    .append_null();
1000                state_builder.append(false);
1001            }
1002
1003            if let Some(embedding) = &entry.embedding {
1004                if embedding.len() != DEFAULT_EMBEDDING_DIM as usize {
1005                    return Err(ArrowError::InvalidArgumentError(format!(
1006                        "embedding length {} does not match expected dimension {}",
1007                        embedding.len(),
1008                        DEFAULT_EMBEDDING_DIM
1009                    ))
1010                    .into());
1011                }
1012                {
1013                    let values_builder = embedding_builder.values();
1014                    for value in embedding {
1015                        values_builder.append_value(*value);
1016                    }
1017                }
1018                embedding_builder.append(true);
1019            } else {
1020                // FixedSizeListBuilder requires padding values for null slots.
1021                let values_builder = embedding_builder.values();
1022                for _ in 0..DEFAULT_EMBEDDING_DIM {
1023                    values_builder.append_null();
1024                }
1025                embedding_builder.append(false);
1026            }
1027        }
1028
1029        let id_array: ArrayRef = Arc::new(id_builder.finish());
1030        let external_id_array: ArrayRef = Arc::new(external_id_builder.finish());
1031        let run_id_array: ArrayRef = Arc::new(run_id_builder.finish());
1032        let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish());
1033        let session_id_array: ArrayRef = Arc::new(session_id_builder.finish());
1034        let created_at_array: ArrayRef = Arc::new(created_at_builder.finish());
1035        let role_array: ArrayRef = Arc::new(role_builder.finish());
1036        let metadata_array: ArrayRef = Arc::new(metadata_builder.finish());
1037        let content_type_array: ArrayRef = Arc::new(content_type_builder.finish());
1038        let text_array: ArrayRef = if text_is_blob {
1039            Arc::new(text_binary_builder.unwrap().finish())
1040        } else {
1041            Arc::new(text_string_builder.unwrap().finish())
1042        };
1043        let binary_array: ArrayRef = Arc::new(binary_builder.finish());
1044        let state_array: ArrayRef = Arc::new(state_builder.finish());
1045        let embedding_array: ArrayRef = Arc::new(embedding_builder.finish());
1046
1047        let schema = Arc::new(Self::schema_with_options(
1048            &self.blob_columns,
1049            include_external_id,
1050            include_metadata,
1051        ));
1052        let mut arrays = vec![id_array];
1053        if include_external_id {
1054            arrays.push(external_id_array);
1055        }
1056        arrays.extend([
1057            run_id_array,
1058            bot_id_array,
1059            session_id_array,
1060            created_at_array,
1061            role_array,
1062            state_array,
1063        ]);
1064        if include_metadata {
1065            arrays.push(metadata_array);
1066        }
1067        arrays.extend([
1068            content_type_array,
1069            text_array,
1070            binary_array,
1071            embedding_array,
1072        ]);
1073        let batch = RecordBatch::try_new(schema, arrays)?;
1074
1075        Ok(batch)
1076    }
1077}
1078
1079impl Drop for ContextStore {
1080    fn drop(&mut self) {
1081        // Best-effort cleanup of background task
1082        if let Ok(mut state) = self.compaction_state.try_lock() {
1083            if let Some(task) = state.background_task.take() {
1084                task.abort();
1085            }
1086        }
1087    }
1088}
1089
1090/// Convert a record batch to context records.
1091fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
1092    let id_array = column_as::<StringArray>(batch, "id")?;
1093    let external_id_array = column_as_optional::<StringArray>(batch, "external_id");
1094    let run_id_array = column_as::<StringArray>(batch, "run_id")?;
1095    let bot_id_array = column_as_optional::<StringArray>(batch, "bot_id");
1096    let session_id_array = column_as_optional::<StringArray>(batch, "session_id");
1097    let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
1098    let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
1099    let state_array = column_as::<StructArray>(batch, "state_metadata")?;
1100    let metadata_array = column_as_optional::<LargeStringArray>(batch, "metadata");
1101    let content_type_array = column_as::<StringArray>(batch, "content_type")?;
1102    let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
1103    let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;
1104
1105    // Auto-detect whether text_payload is LargeBinary (blob) or LargeUtf8 (default)
1106    let text_is_binary = batch
1107        .schema()
1108        .field_with_name("text_payload")
1109        .is_ok_and(|f| f.data_type() == &DataType::LargeBinary);
1110
1111    let text_string_array = if !text_is_binary {
1112        Some(column_as::<LargeStringArray>(batch, "text_payload")?)
1113    } else {
1114        None
1115    };
1116    let text_binary_array = if text_is_binary {
1117        Some(column_as::<LargeBinaryArray>(batch, "text_payload")?)
1118    } else {
1119        None
1120    };
1121
1122    let step_array = state_array
1123        .column(0)
1124        .as_ref()
1125        .as_any()
1126        .downcast_ref::<Int32Array>()
1127        .ok_or_else(|| {
1128            LanceError::from(ArrowError::InvalidArgumentError(
1129                "step column has unexpected data type".to_string(),
1130            ))
1131        })?;
1132    let active_plan_array = state_array
1133        .column(1)
1134        .as_ref()
1135        .as_any()
1136        .downcast_ref::<StringArray>()
1137        .ok_or_else(|| {
1138            LanceError::from(ArrowError::InvalidArgumentError(
1139                "active_plan_id column has unexpected data type".to_string(),
1140            ))
1141        })?;
1142    let tokens_used_array = state_array
1143        .column(2)
1144        .as_ref()
1145        .as_any()
1146        .downcast_ref::<Int32Array>()
1147        .ok_or_else(|| {
1148            LanceError::from(ArrowError::InvalidArgumentError(
1149                "tokens_used column has unexpected data type".to_string(),
1150            ))
1151        })?;
1152    let custom_array = state_array
1153        .column(3)
1154        .as_ref()
1155        .as_any()
1156        .downcast_ref::<StringArray>()
1157        .ok_or_else(|| {
1158            LanceError::from(ArrowError::InvalidArgumentError(
1159                "custom column has unexpected data type".to_string(),
1160            ))
1161        })?;
1162
1163    let mut results = Vec::with_capacity(batch.num_rows());
1164    for row in 0..batch.num_rows() {
1165        let created_at =
1166            DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| {
1167                LanceError::from(ArrowError::InvalidArgumentError(format!(
1168                    "invalid timestamp value {}",
1169                    created_at_array.value(row)
1170                )))
1171            })?;
1172
1173        let state_metadata = if state_array.is_null(row) {
1174            None
1175        } else {
1176            Some(StateMetadata {
1177                step: if step_array.is_null(row) {
1178                    None
1179                } else {
1180                    Some(step_array.value(row))
1181                },
1182                active_plan_id: if active_plan_array.is_null(row) {
1183                    None
1184                } else {
1185                    Some(active_plan_array.value(row).to_string())
1186                },
1187                tokens_used: if tokens_used_array.is_null(row) {
1188                    None
1189                } else {
1190                    Some(tokens_used_array.value(row))
1191                },
1192                custom: if custom_array.is_null(row) {
1193                    None
1194                } else {
1195                    Some(custom_array.value(row).to_string())
1196                },
1197            })
1198        };
1199
1200        let text_payload = if text_is_binary {
1201            let arr = text_binary_array.unwrap();
1202            if arr.is_null(row) {
1203                None
1204            } else {
1205                Some(String::from_utf8_lossy(arr.value(row)).to_string())
1206            }
1207        } else {
1208            let arr = text_string_array.unwrap();
1209            if arr.is_null(row) {
1210                None
1211            } else {
1212                Some(arr.value(row).to_string())
1213            }
1214        };
1215
1216        let binary_payload = if binary_array.is_null(row) {
1217            None
1218        } else {
1219            Some(binary_array.value(row).to_vec())
1220        };
1221
1222        let embedding = if embedding_array.is_null(row) {
1223            None
1224        } else {
1225            Some(embedding_from_list(embedding_array, row)?)
1226        };
1227
1228        let role = if role_array.is_null(row) {
1229            return Err(LanceError::from(ArrowError::InvalidArgumentError(
1230                "role column contains null values".to_string(),
1231            )));
1232        } else {
1233            let role_values = role_array
1234                .values()
1235                .as_any()
1236                .downcast_ref::<StringArray>()
1237                .ok_or_else(|| {
1238                    LanceError::from(ArrowError::InvalidArgumentError(
1239                        "role dictionary values are not strings".to_string(),
1240                    ))
1241                })?;
1242            let key = role_array.keys().value(row) as usize;
1243            role_values.value(key).to_string()
1244        };
1245
1246        let bot_id = bot_id_array.and_then(|arr| {
1247            if arr.is_null(row) {
1248                None
1249            } else {
1250                Some(arr.value(row).to_string())
1251            }
1252        });
1253
1254        let session_id = session_id_array.and_then(|arr| {
1255            if arr.is_null(row) {
1256                None
1257            } else {
1258                Some(arr.value(row).to_string())
1259            }
1260        });
1261
1262        let metadata = match metadata_array {
1263            Some(arr) if !arr.is_null(row) => {
1264                Some(serde_json::from_str(arr.value(row)).map_err(|err| {
1265                    LanceError::from(ArrowError::InvalidArgumentError(format!(
1266                        "invalid metadata JSON for record {}: {}",
1267                        id_array.value(row),
1268                        err
1269                    )))
1270                })?)
1271            }
1272            _ => None,
1273        };
1274
1275        results.push(ContextRecord {
1276            id: id_array.value(row).to_string(),
1277            external_id: external_id_array.and_then(|arr| {
1278                if arr.is_null(row) {
1279                    None
1280                } else {
1281                    Some(arr.value(row).to_string())
1282                }
1283            }),
1284            run_id: run_id_array.value(row).to_string(),
1285            bot_id,
1286            session_id,
1287            created_at,
1288            role,
1289            state_metadata,
1290            metadata,
1291            content_type: content_type_array.value(row).to_string(),
1292            text_payload,
1293            binary_payload,
1294            embedding,
1295        });
1296    }
1297
1298    Ok(results)
1299}
1300
1301fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
1302    let values = list.value(row);
1303    let float_array = values
1304        .as_ref()
1305        .as_any()
1306        .downcast_ref::<Float32Array>()
1307        .ok_or_else(|| {
1308            LanceError::from(ArrowError::InvalidArgumentError(
1309                "embedding column does not contain float32 values".to_string(),
1310            ))
1311        })?;
1312    let mut embedding = Vec::with_capacity(float_array.len());
1313    for idx in 0..float_array.len() {
1314        embedding.push(float_array.value(idx));
1315    }
1316    Ok(embedding)
1317}
1318
/// Euclidean (L2) distance between two vectors.
///
/// Elements are paired by position. If the slices differ in length, the trailing
/// elements of the longer slice are ignored.
fn l2_distance(left: &[f32], right: &[f32]) -> f32 {
    let squared_sum: f32 = left
        .iter()
        .zip(right.iter())
        .map(|(a, b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
1329
1330fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
1331where
1332    A: Array + 'static,
1333{
1334    let column = batch.column_by_name(name).ok_or_else(|| {
1335        LanceError::from(ArrowError::InvalidArgumentError(format!(
1336            "column '{name}' not found"
1337        )))
1338    })?;
1339    column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
1340        LanceError::from(ArrowError::InvalidArgumentError(format!(
1341            "column '{name}' has unexpected data type"
1342        )))
1343    })
1344}
1345
1346fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A>
1347where
1348    A: Array + 'static,
1349{
1350    batch
1351        .column_by_name(name)
1352        .and_then(|col| col.as_ref().as_any().downcast_ref::<A>())
1353}
1354
1355#[cfg(test)]
1356mod tests {
1357    use super::*;
1358    use crate::serde::CONTENT_TYPE_TEXT;
1359    use chrono::Utc;
1360    use tempfile::TempDir;
1361
1362    fn make_embedding(pivot: f32) -> Vec<f32> {
1363        let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
1364        if !values.is_empty() {
1365            values[0] = pivot;
1366        }
1367        values
1368    }
1369
1370    fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
1371        ContextRecord {
1372            id: id.to_string(),
1373            external_id: None,
1374            run_id: format!("run-{id}"),
1375            bot_id: None,
1376            session_id: None,
1377            created_at: Utc::now(),
1378            role: "user".to_string(),
1379            state_metadata: Some(StateMetadata {
1380                step: Some(1),
1381                active_plan_id: Some("plan".to_string()),
1382                tokens_used: Some(10),
1383                custom: None,
1384            }),
1385            metadata: None,
1386            content_type: CONTENT_TYPE_TEXT.to_string(),
1387            text_payload: Some(format!("payload-{id}")),
1388            binary_payload: None,
1389            embedding: Some(make_embedding(embedding_pivot)),
1390        }
1391    }
1392
1393    #[test]
1394    fn search_orders_by_distance() {
1395        let dir = TempDir::new().unwrap();
1396        let uri = dir.path().to_string_lossy().to_string();
1397        let runtime = tokio::runtime::Runtime::new().unwrap();
1398        runtime.block_on(async {
1399            let mut store = ContextStore::open(&uri).await.unwrap();
1400            let first = text_record("a", 0.0);
1401            let second = text_record("b", 1.0);
1402            store.add(&[first.clone(), second.clone()]).await.unwrap();
1403
1404            let query = make_embedding(1.0);
1405            let results = store.search(&query, Some(2)).await.unwrap();
1406
1407            assert_eq!(results.len(), 2);
1408            assert_eq!(results[0].record.id, second.id);
1409            assert!(
1410                results[0].distance <= results[1].distance,
1411                "results not ordered by distance: {:?}",
1412                results
1413            );
1414        });
1415    }
1416
1417    #[test]
1418    fn search_validates_query_length() {
1419        let dir = TempDir::new().unwrap();
1420        let uri = dir.path().to_string_lossy().to_string();
1421        let runtime = tokio::runtime::Runtime::new().unwrap();
1422        runtime.block_on(async {
1423            let store = ContextStore::open(&uri).await.unwrap();
1424            let err = store.search(&[0.0_f32], None).await.unwrap_err();
1425            let message = err.to_string();
1426            assert!(
1427                message.contains("embedding dimension"),
1428                "unexpected error message: {message}"
1429            );
1430        });
1431    }
1432
1433    #[test]
1434    fn external_id_roundtrips_and_supports_lookup() {
1435        let dir = TempDir::new().unwrap();
1436        let uri = dir.path().to_string_lossy().to_string();
1437        let runtime = tokio::runtime::Runtime::new().unwrap();
1438        runtime.block_on(async {
1439            let mut store = ContextStore::open(&uri).await.unwrap();
1440            let mut record = text_record("a", 0.0);
1441            record.external_id = Some("doc-123#chunk-1".to_string());
1442            store.add(std::slice::from_ref(&record)).await.unwrap();
1443
1444            let by_external_id = store
1445                .get_by_external_id("doc-123#chunk-1")
1446                .await
1447                .unwrap()
1448                .unwrap();
1449            assert_eq!(by_external_id.id, record.id);
1450            assert_eq!(by_external_id.external_id, record.external_id);
1451
1452            let by_id = store.get_by_id(&record.id).await.unwrap().unwrap();
1453            assert_eq!(by_id.external_id.as_deref(), Some("doc-123#chunk-1"));
1454
1455            let missing = store.get_by_external_id("missing").await.unwrap();
1456            assert!(missing.is_none());
1457        });
1458    }
1459
1460    #[test]
1461    fn add_rejects_duplicate_external_id() {
1462        let dir = TempDir::new().unwrap();
1463        let uri = dir.path().to_string_lossy().to_string();
1464        let runtime = tokio::runtime::Runtime::new().unwrap();
1465        runtime.block_on(async {
1466            let mut store = ContextStore::open(&uri).await.unwrap();
1467            let mut first = text_record("a", 0.0);
1468            first.external_id = Some("doc-123#chunk-1".to_string());
1469            store.add(std::slice::from_ref(&first)).await.unwrap();
1470
1471            let mut duplicate = text_record("b", 0.0);
1472            duplicate.external_id = first.external_id.clone();
1473            let err = store.add(&[duplicate]).await.unwrap_err();
1474            let message = err.to_string();
1475            assert!(
1476                message.contains("external_id") && message.contains("already exists"),
1477                "unexpected error message: {message}"
1478            );
1479        });
1480    }
1481
1482    #[test]
1483    fn add_rejects_reserved_tombstone_content_type() {
1484        let dir = TempDir::new().unwrap();
1485        let uri = dir.path().to_string_lossy().to_string();
1486        let runtime = tokio::runtime::Runtime::new().unwrap();
1487        runtime.block_on(async {
1488            let mut store = ContextStore::open(&uri).await.unwrap();
1489            let mut record = text_record("a", 0.0);
1490            record.content_type = CONTENT_TYPE_TOMBSTONE.to_string();
1491
1492            let err = store.add(&[record]).await.unwrap_err();
1493            let message = err.to_string();
1494            assert!(
1495                message.contains("reserved") && message.contains("tombstone"),
1496                "unexpected error message: {message}"
1497            );
1498        });
1499    }
1500
1501    #[test]
1502    fn delete_by_external_id_hides_record_from_default_reads() {
1503        let dir = TempDir::new().unwrap();
1504        let uri = dir.path().to_string_lossy().to_string();
1505        let runtime = tokio::runtime::Runtime::new().unwrap();
1506        runtime.block_on(async {
1507            let mut store = ContextStore::open(&uri).await.unwrap();
1508            let mut first = text_record("a", 0.0);
1509            first.external_id = Some("doc-123#chunk-1".to_string());
1510            let second = text_record("b", 2.0);
1511            store.add(&[first.clone(), second.clone()]).await.unwrap();
1512
1513            assert!(store
1514                .delete_by_external_id("doc-123#chunk-1")
1515                .await
1516                .unwrap());
1517
1518            assert!(store
1519                .get_by_external_id("doc-123#chunk-1")
1520                .await
1521                .unwrap()
1522                .is_none());
1523            assert!(store.get_by_id(&first.id).await.unwrap().is_none());
1524
1525            let records = store.list(None, None).await.unwrap();
1526            assert_eq!(records.len(), 1);
1527            assert_eq!(records[0].id, second.id);
1528
1529            let query = make_embedding(0.0);
1530            let hits = store.search(&query, Some(10)).await.unwrap();
1531            assert_eq!(hits.len(), 1);
1532            assert_eq!(hits[0].record.id, second.id);
1533        });
1534    }
1535
1536    #[test]
1537    fn delete_by_id_hides_record_from_default_reads() {
1538        let dir = TempDir::new().unwrap();
1539        let uri = dir.path().to_string_lossy().to_string();
1540        let runtime = tokio::runtime::Runtime::new().unwrap();
1541        runtime.block_on(async {
1542            let mut store = ContextStore::open(&uri).await.unwrap();
1543            let mut first = text_record("a", 0.0);
1544            first.external_id = Some("doc-123#chunk-1".to_string());
1545            let second = text_record("b", 2.0);
1546            store.add(&[first.clone(), second.clone()]).await.unwrap();
1547
1548            assert!(store.delete_by_id(&first.id).await.unwrap());
1549
1550            assert!(store.get_by_id(&first.id).await.unwrap().is_none());
1551            assert!(store
1552                .get_by_external_id("doc-123#chunk-1")
1553                .await
1554                .unwrap()
1555                .is_none());
1556
1557            let records = store.list(None, None).await.unwrap();
1558            assert_eq!(records.len(), 1);
1559            assert_eq!(records[0].id, second.id);
1560
1561            let query = make_embedding(0.0);
1562            let hits = store.search(&query, Some(10)).await.unwrap();
1563            assert_eq!(hits.len(), 1);
1564            assert_eq!(hits[0].record.id, second.id);
1565        });
1566    }
1567
1568    #[test]
1569    fn delete_missing_id_is_noop() {
1570        let dir = TempDir::new().unwrap();
1571        let uri = dir.path().to_string_lossy().to_string();
1572        let runtime = tokio::runtime::Runtime::new().unwrap();
1573        runtime.block_on(async {
1574            let mut store = ContextStore::open(&uri).await.unwrap();
1575            assert!(!store.delete_by_id("missing").await.unwrap());
1576            assert!(!store.delete_by_external_id("missing").await.unwrap());
1577        });
1578    }
1579
1580    #[test]
1581    fn external_id_can_be_reused_after_delete() {
1582        let dir = TempDir::new().unwrap();
1583        let uri = dir.path().to_string_lossy().to_string();
1584        let runtime = tokio::runtime::Runtime::new().unwrap();
1585        runtime.block_on(async {
1586            let mut store = ContextStore::open(&uri).await.unwrap();
1587            let mut first = text_record("a", 0.0);
1588            first.external_id = Some("doc-123#chunk-1".to_string());
1589            store.add(std::slice::from_ref(&first)).await.unwrap();
1590            assert!(store
1591                .delete_by_external_id("doc-123#chunk-1")
1592                .await
1593                .unwrap());
1594
1595            let mut replacement = text_record("b", 1.0);
1596            replacement.external_id = first.external_id.clone();
1597            store.add(std::slice::from_ref(&replacement)).await.unwrap();
1598
1599            let by_external_id = store
1600                .get_by_external_id("doc-123#chunk-1")
1601                .await
1602                .unwrap()
1603                .unwrap();
1604            assert_eq!(by_external_id.id, replacement.id);
1605            assert_eq!(store.list(None, None).await.unwrap().len(), 1);
1606        });
1607    }
1608
1609    #[test]
1610    fn test_region_id_derivation_explicit() {
1611        let bot_id = Some("bot-123".to_string());
1612        let session_id = Some("session-456".to_string());
1613
1614        let region_id_1 = ContextStore::derive_region_id(&bot_id, &session_id);
1615        let region_id_2 = ContextStore::derive_region_id(&bot_id, &session_id);
1616
1617        assert_eq!(
1618            region_id_1, region_id_2,
1619            "Region ID should be deterministic for same inputs"
1620        );
1621
1622        let other_session = Some("session-789".to_string());
1623        let region_id_3 = ContextStore::derive_region_id(&bot_id, &other_session);
1624
1625        assert_ne!(
1626            region_id_1, region_id_3,
1627            "Region ID should differ for different inputs"
1628        );
1629
1630        // Test None/None case (now deterministic based on empty strings)
1631        let region_id_none = ContextStore::derive_region_id(&None, &None);
1632        let region_id_none_2 = ContextStore::derive_region_id(&None, &None);
1633        assert_eq!(
1634            region_id_none, region_id_none_2,
1635            "Region ID for None/None should be deterministic"
1636        );
1637    }
1638
1639    #[test]
1640    fn test_add_multiple_regions() {
1641        let dir = TempDir::new().unwrap();
1642        let uri = dir.path().to_string_lossy().to_string();
1643        let runtime = tokio::runtime::Runtime::new().unwrap();
1644
1645        runtime.block_on(async {
1646            let mut store = ContextStore::open(&uri).await.unwrap();
1647
1648            // Create records for different regions
1649            let mut record1 = text_record("r1", 0.0);
1650            record1.bot_id = Some("bot-A".to_string());
1651            record1.session_id = Some("session-1".to_string());
1652
1653            let mut record2 = text_record("r2", 0.0);
1654            record2.bot_id = Some("bot-B".to_string());
1655            record2.session_id = Some("session-2".to_string());
1656
1657            // Add them in a single batch
1658            store
1659                .add(&[record1.clone(), record2.clone()])
1660                .await
1661                .unwrap();
1662
1663            // Reload store to verify persistence
1664            let store = ContextStore::open(&uri).await.unwrap();
1665
1666            // Verify we can list them back
1667            let results = store.list(None, None).await.unwrap();
1668            assert_eq!(results.len(), 2);
1669
1670            let ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
1671            assert!(ids.contains(&"r1".to_string()));
1672            assert!(ids.contains(&"r2".to_string()));
1673        });
1674    }
1675
1676    #[test]
1677    fn test_blob_binary_payload() {
1678        let dir = TempDir::new().unwrap();
1679        let uri = dir.path().to_string_lossy().to_string();
1680        let runtime = tokio::runtime::Runtime::new().unwrap();
1681
1682        runtime.block_on(async {
1683            let options = ContextStoreOptions {
1684                blob_columns: HashSet::from(["binary_payload".to_string()]),
1685                ..Default::default()
1686            };
1687            let mut store = ContextStore::open_with_options(&uri, options)
1688                .await
1689                .unwrap();
1690
1691            let mut record = text_record("blob-bin-1", 0.0);
1692            record.binary_payload = Some(vec![0xDE, 0xAD, 0xBE, 0xEF]);
1693            store.add(std::slice::from_ref(&record)).await.unwrap();
1694
1695            // Verify schema has blob metadata on binary_payload
1696            let schema = ContextStore::schema(&store.blob_columns);
1697            let field = schema.field_with_name("binary_payload").unwrap();
1698            assert_eq!(
1699                field.metadata().get("lance-encoding:blob"),
1700                Some(&"true".to_string()),
1701            );
1702            // text_payload should remain LargeUtf8 without blob metadata
1703            let text_field = schema.field_with_name("text_payload").unwrap();
1704            assert_eq!(text_field.data_type(), &DataType::LargeUtf8);
1705            assert!(text_field.metadata().get("lance-encoding:blob").is_none());
1706        });
1707    }
1708
1709    #[test]
1710    fn test_blob_text_payload() {
1711        let dir = TempDir::new().unwrap();
1712        let uri = dir.path().to_string_lossy().to_string();
1713        let runtime = tokio::runtime::Runtime::new().unwrap();
1714
1715        runtime.block_on(async {
1716            let options = ContextStoreOptions {
1717                blob_columns: HashSet::from(["text_payload".to_string()]),
1718                ..Default::default()
1719            };
1720            let mut store = ContextStore::open_with_options(&uri, options)
1721                .await
1722                .unwrap();
1723
1724            let record = text_record("blob-txt-1", 0.0);
1725            store.add(std::slice::from_ref(&record)).await.unwrap();
1726
1727            // Roundtrip: records_to_batch -> batch_to_records
1728            let batch = store
1729                .records_to_batch(std::slice::from_ref(&record))
1730                .unwrap();
1731            let batch_schema = batch.schema();
1732            let text_field = batch_schema.field_with_name("text_payload").unwrap();
1733            assert_eq!(
1734                text_field.data_type(),
1735                &DataType::LargeBinary,
1736                "text_payload should be LargeBinary when blob-encoded"
1737            );
1738
1739            let roundtripped = batch_to_records(&batch).unwrap();
1740            assert_eq!(roundtripped.len(), 1);
1741            assert_eq!(
1742                roundtripped[0].text_payload, record.text_payload,
1743                "text payload should survive blob roundtrip"
1744            );
1745        });
1746    }
1747
1748    #[test]
1749    fn test_blob_both_columns() {
1750        let dir = TempDir::new().unwrap();
1751        let uri = dir.path().to_string_lossy().to_string();
1752        let runtime = tokio::runtime::Runtime::new().unwrap();
1753
1754        runtime.block_on(async {
1755            let options = ContextStoreOptions {
1756                blob_columns: HashSet::from([
1757                    "text_payload".to_string(),
1758                    "binary_payload".to_string(),
1759                ]),
1760                ..Default::default()
1761            };
1762            let mut store = ContextStore::open_with_options(&uri, options)
1763                .await
1764                .unwrap();
1765
1766            let mut record = text_record("blob-both-1", 0.0);
1767            record.binary_payload = Some(b"hello binary".to_vec());
1768            store.add(std::slice::from_ref(&record)).await.unwrap();
1769
1770            // Both columns should have blob metadata
1771            let schema = ContextStore::schema(&store.blob_columns);
1772            let text_field = schema.field_with_name("text_payload").unwrap();
1773            let bin_field = schema.field_with_name("binary_payload").unwrap();
1774            assert_eq!(
1775                text_field.metadata().get("lance-encoding:blob"),
1776                Some(&"true".to_string()),
1777            );
1778            assert_eq!(
1779                bin_field.metadata().get("lance-encoding:blob"),
1780                Some(&"true".to_string()),
1781            );
1782
1783            // Roundtrip via batch
1784            let batch = store
1785                .records_to_batch(std::slice::from_ref(&record))
1786                .unwrap();
1787            let roundtripped = batch_to_records(&batch).unwrap();
1788            assert_eq!(roundtripped.len(), 1);
1789            assert_eq!(roundtripped[0].text_payload, record.text_payload);
1790            assert_eq!(roundtripped[0].binary_payload, record.binary_payload);
1791        });
1792    }
1793
1794    #[test]
1795    fn test_no_blob_default() {
1796        // Default options should produce no blob metadata
1797        let schema = ContextStore::schema(&HashSet::new());
1798        let text_field = schema.field_with_name("text_payload").unwrap();
1799        let bin_field = schema.field_with_name("binary_payload").unwrap();
1800
1801        assert_eq!(text_field.data_type(), &DataType::LargeUtf8);
1802        assert!(text_field.metadata().get("lance-encoding:blob").is_none());
1803        assert_eq!(bin_field.data_type(), &DataType::LargeBinary);
1804        assert!(bin_field.metadata().get("lance-encoding:blob").is_none());
1805    }
1806
1807    #[test]
1808    fn test_blob_schema_metadata() {
1809        let blob_columns =
1810            HashSet::from(["text_payload".to_string(), "binary_payload".to_string()]);
1811        let schema = ContextStore::schema(&blob_columns);
1812
1813        let text_field = schema.field_with_name("text_payload").unwrap();
1814        assert_eq!(text_field.data_type(), &DataType::LargeBinary);
1815        assert_eq!(
1816            text_field.metadata().get("lance-encoding:blob"),
1817            Some(&"true".to_string()),
1818        );
1819
1820        let bin_field = schema.field_with_name("binary_payload").unwrap();
1821        assert_eq!(bin_field.data_type(), &DataType::LargeBinary);
1822        assert_eq!(
1823            bin_field.metadata().get("lance-encoding:blob"),
1824            Some(&"true".to_string()),
1825        );
1826
1827        // Non-blob fields should have no blob metadata
1828        let id_field = schema.field_with_name("id").unwrap();
1829        assert!(id_field.metadata().get("lance-encoding:blob").is_none());
1830    }
1831
1832    #[test]
1833    fn test_blob_invalid_column_name() {
1834        let dir = TempDir::new().unwrap();
1835        let uri = dir.path().to_string_lossy().to_string();
1836        let runtime = tokio::runtime::Runtime::new().unwrap();
1837
1838        runtime.block_on(async {
1839            let options = ContextStoreOptions {
1840                blob_columns: HashSet::from(["nonexistent_column".to_string()]),
1841                ..Default::default()
1842            };
1843            let result = ContextStore::open_with_options(&uri, options).await;
1844            assert!(result.is_err(), "should reject invalid blob column names");
1845            let err_msg = result.err().unwrap().to_string();
1846            assert!(
1847                err_msg.contains("invalid blob column"),
1848                "error should mention invalid blob column: {err_msg}"
1849            );
1850        });
1851    }
1852
1853    #[test]
1854    fn test_batch_to_records_autodetects_text_type() {
1855        // Verify that batch_to_records works on both LargeUtf8 and LargeBinary
1856        // text_payload without needing configuration.
1857        let runtime = tokio::runtime::Runtime::new().unwrap();
1858        runtime.block_on(async {
1859            // Build a batch with text_payload as LargeUtf8 (default)
1860            let dir1 = TempDir::new().unwrap();
1861            let uri1 = dir1.path().to_string_lossy().to_string();
1862            let store_default = ContextStore::open(&uri1).await.unwrap();
1863            let record = text_record("auto-1", 0.0);
1864            let batch_utf8 = store_default
1865                .records_to_batch(std::slice::from_ref(&record))
1866                .unwrap();
1867            let results_utf8 = batch_to_records(&batch_utf8).unwrap();
1868            assert_eq!(results_utf8[0].text_payload, record.text_payload);
1869
1870            // Build a batch with text_payload as LargeBinary (blob)
1871            let dir2 = TempDir::new().unwrap();
1872            let uri2 = dir2.path().to_string_lossy().to_string();
1873            let options = ContextStoreOptions {
1874                blob_columns: HashSet::from(["text_payload".to_string()]),
1875                ..Default::default()
1876            };
1877            let store_blob = ContextStore::open_with_options(&uri2, options)
1878                .await
1879                .unwrap();
1880            let batch_binary = store_blob
1881                .records_to_batch(std::slice::from_ref(&record))
1882                .unwrap();
1883            let results_binary = batch_to_records(&batch_binary).unwrap();
1884            assert_eq!(results_binary[0].text_payload, record.text_payload);
1885        });
1886    }
1887
1888    #[test]
1889    fn test_id_index_btree() {
1890        let dir = TempDir::new().unwrap();
1891        let uri = dir.path().to_string_lossy().to_string();
1892        let runtime = tokio::runtime::Runtime::new().unwrap();
1893
1894        runtime.block_on(async {
1895            let options = ContextStoreOptions {
1896                id_index_type: IdIndexType::BTree,
1897                ..Default::default()
1898            };
1899            let mut store = ContextStore::open_with_options(&uri, options)
1900                .await
1901                .unwrap();
1902
1903            // Index should be created eagerly on open
1904            let indices = store.dataset.load_indices().await.unwrap();
1905            assert!(
1906                indices.iter().any(|i| i.name == ID_INDEX_NAME),
1907                "btree index should be created on open"
1908            );
1909
1910            // Add data and verify it still works with the index
1911            for i in 0..5 {
1912                store
1913                    .add(&[text_record(&format!("btree-{i}"), i as f32)])
1914                    .await
1915                    .unwrap();
1916            }
1917            store.compact(None).await.unwrap();
1918
1919            // Index should still exist after compaction
1920            let indices = store.dataset.load_indices().await.unwrap();
1921            assert!(
1922                indices.iter().any(|i| i.name == ID_INDEX_NAME),
1923                "btree index should persist after compaction"
1924            );
1925        });
1926    }
1927
1928    #[test]
1929    fn test_id_index_zonemap() {
1930        let dir = TempDir::new().unwrap();
1931        let uri = dir.path().to_string_lossy().to_string();
1932        let runtime = tokio::runtime::Runtime::new().unwrap();
1933
1934        runtime.block_on(async {
1935            let options = ContextStoreOptions {
1936                id_index_type: IdIndexType::ZoneMap,
1937                ..Default::default()
1938            };
1939            let mut store = ContextStore::open_with_options(&uri, options)
1940                .await
1941                .unwrap();
1942
1943            // Index should be created eagerly on open
1944            let indices = store.dataset.load_indices().await.unwrap();
1945            assert!(
1946                indices.iter().any(|i| i.name == ID_INDEX_NAME),
1947                "zonemap index should be created on open"
1948            );
1949
1950            for i in 0..5 {
1951                store
1952                    .add(&[text_record(&format!("zm-{i}"), i as f32)])
1953                    .await
1954                    .unwrap();
1955            }
1956            store.compact(None).await.unwrap();
1957
1958            let indices = store.dataset.load_indices().await.unwrap();
1959            assert!(
1960                indices.iter().any(|i| i.name == ID_INDEX_NAME),
1961                "zonemap index should persist after compaction"
1962            );
1963        });
1964    }
1965
1966    #[test]
1967    fn test_id_index_none_by_default() {
1968        let dir = TempDir::new().unwrap();
1969        let uri = dir.path().to_string_lossy().to_string();
1970        let runtime = tokio::runtime::Runtime::new().unwrap();
1971
1972        runtime.block_on(async {
1973            let mut store = ContextStore::open(&uri).await.unwrap();
1974
1975            store.add(&[text_record("no-idx-1", 0.0)]).await.unwrap();
1976            store.compact(None).await.unwrap();
1977
1978            let indices = store.dataset.load_indices().await.unwrap();
1979            assert!(
1980                !indices.iter().any(|i| i.name == ID_INDEX_NAME),
1981                "no id index should be created when IdIndexType::None"
1982            );
1983        });
1984    }
1985
1986    #[test]
1987    fn test_id_index_idempotent() {
1988        let dir = TempDir::new().unwrap();
1989        let uri = dir.path().to_string_lossy().to_string();
1990        let runtime = tokio::runtime::Runtime::new().unwrap();
1991
1992        runtime.block_on(async {
1993            let options = ContextStoreOptions {
1994                id_index_type: IdIndexType::BTree,
1995                ..Default::default()
1996            };
1997            let mut store = ContextStore::open_with_options(&uri, options)
1998                .await
1999                .unwrap();
2000
2001            for i in 0..5 {
2002                store
2003                    .add(&[text_record(&format!("idem-{i}"), i as f32)])
2004                    .await
2005                    .unwrap();
2006            }
2007
2008            // Create index twice -- second call should be a no-op
2009            store.create_id_index().await.unwrap();
2010            let v1 = store.version();
2011            store.ensure_id_index().await.unwrap();
2012            let v2 = store.version();
2013            assert_eq!(v1, v2, "ensure_id_index should not recreate existing index");
2014        });
2015    }
2016}