1use std::cmp::Ordering;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4use std::time::Duration;
5
6use arrow_array::builder::{
7 FixedSizeListBuilder, Float32Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
8 ListBuilder, StringBuilder, StringDictionaryBuilder, StructBuilder,
9 TimestampMicrosecondBuilder,
10};
11use arrow_array::types::Int8Type;
12use arrow_array::{
13 Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
14 LargeBinaryArray, LargeStringArray, ListArray, RecordBatch, RecordBatchIterator, StringArray,
15 StructArray, TimestampMicrosecondArray,
16};
17use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
18use chrono::{DateTime, Timelike, Utc};
19use futures::TryStreamExt;
20use lance::dataset::mem_wal::{
21 DatasetMemWalExt, LsmScanner, ShardManifestStore, ShardSnapshot, ShardWriterConfig,
22};
23use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
24use lance::dataset::NewColumnTransform;
25use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams};
26use lance::index::DatasetIndexExt;
27use lance::io::{ObjectStoreParams, StorageOptionsAccessor};
28use lance::{Error as LanceError, Result as LanceResult};
29use lance_index::mem_wal::MEM_WAL_INDEX_NAME;
30use lance_index::scalar::ScalarIndexParams;
31use lance_index::IndexType;
32use tokio::sync::Mutex;
33use tokio::task::JoinHandle;
34use tracing::{error, info, warn};
35use uuid::Uuid;
36
37use crate::record::{
38 ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, RetrieveResult,
39 SearchResult, StateMetadata, UpdateResult, UpsertResult, LIFECYCLE_ACTIVE,
40};
41use crate::serde::CONTENT_TYPE_TOMBSTONE;
42
/// Embedding width used for a new dataset when `ContextStoreOptions::embedding_dim` is unset.
const DEFAULT_EMBEDDING_DIM: i32 = 1_536;
/// Result count used by search/retrieve when the caller passes no `limit`.
const DEFAULT_SEARCH_LIMIT: usize = 10;
/// Batch size handed to `ShardManifestStore::new` when reading shard manifests.
const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16;
/// Rank-fusion constant; presumably the `k` of reciprocal rank fusion used by the
/// retrieve helpers defined elsewhere in this file (NOTE(review): confirm).
const RRF_K: f32 = 60.0;
/// Name given to the scalar index on the `id` column.
const ID_INDEX_NAME: &str = "id_idx";
/// Name of the list-of-struct column holding record relationships.
const RELATIONSHIPS_COLUMN: &str = "relationships";
/// Schema metadata key for the distance metric (presumably read by
/// `distance_metric_from_schema`; NOTE(review): confirm).
const DISTANCE_METRIC_METADATA_KEY: &str = "lance-context:distance_metric";
53
/// Settings controlling manual and background compaction of the dataset.
#[derive(Clone, Debug)]
pub struct CompactionConfig {
    /// When `false`, no background compaction task is started on open.
    pub enabled: bool,
    /// `should_compact` returns `false` while the fragment count is below this.
    pub min_fragments: usize,
    /// Forwarded to Lance `CompactionOptions::target_rows_per_fragment`.
    pub target_rows_per_fragment: usize,
    /// Forwarded to Lance `CompactionOptions::max_rows_per_group`.
    pub max_rows_per_group: usize,
    /// Forwarded to Lance `CompactionOptions::materialize_deletions`.
    pub materialize_deletions: bool,
    /// Forwarded to Lance `CompactionOptions::materialize_deletions_threshold`.
    pub materialize_deletions_threshold: f32,
    /// Forwarded to Lance `CompactionOptions::num_threads` (`None` = Lance default).
    pub num_threads: Option<usize>,
    /// Period, in seconds, of the background compaction check loop.
    pub check_interval_secs: u64,
    /// UTC hour windows `(start, end)` during which `should_compact` declines to run.
    pub quiet_hours: Vec<(u8, u8)>,
}

impl Default for CompactionConfig {
    /// Background compaction off; thresholds tuned for large fragments.
    fn default() -> Self {
        Self {
            quiet_hours: Vec::new(),
            check_interval_secs: 300,
            num_threads: None,
            materialize_deletions_threshold: 0.1,
            materialize_deletions: true,
            max_rows_per_group: 1024,
            target_rows_per_fragment: 1_000_000,
            min_fragments: 5,
            enabled: false,
        }
    }
}
92
/// Kind of scalar index kept on the `id` column.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IdIndexType {
    /// No id index is created.
    #[default]
    None,
    /// Lance zone-map index; it is not listed as a MemWAL-maintained index.
    ZoneMap,
    /// Lance B-tree index.
    BTree,
}
104
/// Distance function used to rank vector search results (smaller is closer).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance; the default.
    #[default]
    L2,
    /// Cosine distance.
    Cosine,
    /// Dot-product based distance.
    Dot,
}
120
121impl DistanceMetric {
122 pub fn parse(value: &str) -> LanceResult<Self> {
128 match value.to_ascii_lowercase().as_str() {
129 "l2" | "euclidean" => Ok(Self::L2),
130 "cosine" => Ok(Self::Cosine),
131 "dot" | "dot_product" => Ok(Self::Dot),
132 other => Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
133 "invalid distance metric '{other}': valid values are 'l2', 'cosine', 'dot'"
134 )))),
135 }
136 }
137
138 #[must_use]
142 pub fn distance(self, query: &[f32], candidate: &[f32]) -> f32 {
143 match self {
144 Self::L2 => l2_distance(query, candidate),
145 Self::Cosine => cosine_distance(query, candidate),
146 Self::Dot => dot_distance(query, candidate),
147 }
148 }
149
150 #[must_use]
153 pub fn as_str(self) -> &'static str {
154 match self {
155 Self::L2 => "l2",
156 Self::Cosine => "cosine",
157 Self::Dot => "dot",
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct CompactionStats {
165 pub total_fragments: usize,
167 pub is_compacting: bool,
169 pub last_compaction: Option<DateTime<Utc>>,
171 pub last_error: Option<String>,
173 pub total_compactions: u64,
175}
176
177struct CompactionState {
179 background_task: Option<JoinHandle<()>>,
180 is_compacting: bool,
181 last_compaction: Option<DateTime<Utc>>,
182 last_error: Option<String>,
183 total_compactions: u64,
184}
185
/// Payload columns that may be requested via `ContextStoreOptions::blob_columns`.
const VALID_BLOB_COLUMNS: &[&str] = &[
    "text_payload",
    "binary_payload",
];
188
189#[derive(Clone)]
191pub struct ContextStore {
192 dataset: Dataset,
193 compaction_state: Arc<Mutex<CompactionState>>,
194 pub compaction_config: CompactionConfig,
195 blob_columns: HashSet<String>,
196 id_index_type: IdIndexType,
197 embedding_dim: i32,
198 distance_metric: DistanceMetric,
199}
200
201#[derive(Debug, Clone, Default)]
203pub struct ContextStoreOptions {
204 pub storage_options: Option<HashMap<String, String>>,
205 pub compaction: CompactionConfig,
206 pub embedding_dim: Option<i32>,
209 pub blob_columns: HashSet<String>,
212 pub id_index_type: IdIndexType,
214 pub distance_metric: Option<DistanceMetric>,
221}
222
223impl ContextStoreOptions {
224 #[must_use]
225 pub fn storage_options(&self) -> Option<HashMap<String, String>> {
226 self.storage_options.clone()
227 }
228}
229
230fn relationship_struct_fields() -> Vec<Field> {
231 vec![
232 Field::new("target_id", DataType::Utf8, true),
233 Field::new("relation", DataType::Utf8, true),
234 Field::new("weight", DataType::Float32, true),
235 ]
236}
237
238fn relationship_struct_data_type() -> DataType {
239 DataType::Struct(relationship_struct_fields().into())
240}
241
242fn relationship_list_item_field() -> FieldRef {
243 Arc::new(Field::new("item", relationship_struct_data_type(), true))
244}
245
246fn relationship_field() -> Field {
247 Field::new(
248 RELATIONSHIPS_COLUMN,
249 DataType::List(relationship_list_item_field()),
250 true,
251 )
252}
253
254fn relationship_struct_builder() -> StructBuilder {
255 let fields: Vec<FieldRef> = relationship_struct_fields()
256 .into_iter()
257 .map(|field| Arc::new(field) as FieldRef)
258 .collect();
259 StructBuilder::new(
260 fields,
261 vec![
262 Box::new(StringBuilder::new()),
263 Box::new(StringBuilder::new()),
264 Box::new(Float32Builder::new()),
265 ],
266 )
267}
268
269impl ContextStore {
270 pub async fn open(uri: &str) -> LanceResult<Self> {
272 Self::open_with_options(uri, ContextStoreOptions::default()).await
273 }
274
275 pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult<Self> {
277 for col in &options.blob_columns {
279 if !VALID_BLOB_COLUMNS.contains(&col.as_str()) {
280 return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
281 "invalid blob column '{}': valid columns are {:?}",
282 col, VALID_BLOB_COLUMNS
283 ))));
284 }
285 }
286
287 let requested_embedding_dim = match options.embedding_dim {
288 Some(dim) => {
289 validate_embedding_dim(dim)?;
290 dim
291 }
292 None => DEFAULT_EMBEDDING_DIM,
293 };
294 let storage_options = options.storage_options();
295 let blob_columns = options.blob_columns.clone();
296 let (dataset, created) = match Self::load_with_options(uri, storage_options.clone()).await {
297 Ok(dataset) => (dataset, false),
298 Err(LanceError::DatasetNotFound { .. }) => {
299 let dataset = Self::create_with_options(
300 uri,
301 storage_options,
302 &blob_columns,
303 requested_embedding_dim,
304 options.distance_metric.unwrap_or_default(),
305 )
306 .await?;
307 (dataset, true)
308 }
309 Err(err) => return Err(err),
310 };
311 let arrow_schema: Schema = dataset.schema().into();
312 let embedding_dim = embedding_dim_from_schema(&arrow_schema)?;
313 if !created && options.embedding_dim.is_some() && embedding_dim != requested_embedding_dim {
314 return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
315 "existing context embedding dimension {} does not match requested dimension {}",
316 embedding_dim, requested_embedding_dim
317 ))));
318 }
319 let distance_metric = distance_metric_from_schema(&arrow_schema)?;
320 if !created {
321 if let Some(requested) = options.distance_metric {
322 if requested != distance_metric {
323 return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
324 "existing context distance metric '{}' does not match requested metric '{}'",
325 distance_metric.as_str(),
326 requested.as_str()
327 ))));
328 }
329 }
330 }
331
332 let mut store = Self {
333 dataset,
334 compaction_state: Arc::new(Mutex::new(CompactionState {
335 background_task: None,
336 is_compacting: false,
337 last_compaction: None,
338 last_error: None,
339 total_compactions: 0,
340 })),
341 compaction_config: options.compaction,
342 blob_columns,
343 id_index_type: options.id_index_type,
344 embedding_dim,
345 distance_metric,
346 };
347
348 store.ensure_id_index().await?;
350
351 store.start_background_compaction().await?;
353
354 Ok(store)
355 }
356
357 #[must_use]
359 pub fn embedding_dim(&self) -> i32 {
360 self.embedding_dim
361 }
362
363 pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
365 if entries.is_empty() {
366 return Ok(self.dataset.manifest.version);
367 }
368
369 self.validate_unique_ids(entries).await?;
370 self.write_entries(entries).await
371 }
372
373 async fn write_entries(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
374 if entries.is_empty() {
375 return Ok(self.dataset.manifest.version);
376 }
377
378 let mut groups: HashMap<(Option<String>, Option<String>), Vec<ContextRecord>> =
380 HashMap::new();
381 for entry in entries {
382 let key = (entry.bot_id.clone(), entry.session_id.clone());
383 groups.entry(key).or_default().push(entry.clone());
384 }
385
386 {
388 let indices = self.dataset.load_indices().await?;
389 let has_mem_wal = indices.iter().any(|i| i.name == MEM_WAL_INDEX_NAME);
390
391 if !has_mem_wal {
392 let maintained_indexes: Vec<String> = indices
394 .iter()
395 .filter(|i| {
396 !(self.id_index_type == IdIndexType::ZoneMap && i.name == ID_INDEX_NAME)
397 })
398 .map(|i| i.name.clone())
399 .collect();
400 self.dataset
401 .initialize_mem_wal()
402 .unsharded()
403 .maintained_indexes(maintained_indexes)
404 .execute()
405 .await?;
406 }
407 }
408
409 for ((bot_id, session_id), group_entries) in groups {
410 let region_id = Self::derive_region_id(&bot_id, &session_id);
411 let batch = self.records_to_batch(&group_entries)?;
412 let config = ShardWriterConfig {
413 shard_id: region_id,
414 ..Default::default()
415 };
416
417 let writer = self.dataset.mem_wal_writer(region_id, config).await?;
418 writer.put(vec![batch]).await?;
419 writer.close().await?;
420 }
421
422 Ok(self.dataset.manifest.version)
423 }
424
425 pub async fn delete_by_id(&mut self, id: &str) -> LanceResult<bool> {
430 let Some(record) = self.get_by_id(id).await? else {
431 return Ok(false);
432 };
433 self.write_tombstone_for(record).await?;
434 Ok(true)
435 }
436
437 pub async fn delete_by_external_id(&mut self, external_id: &str) -> LanceResult<bool> {
439 let Some(record) = self.get_by_external_id(external_id).await? else {
440 return Ok(false);
441 };
442 self.write_tombstone_for(record).await?;
443 Ok(true)
444 }
445
446 pub async fn upsert_by_external_id(
454 &mut self,
455 mut record: ContextRecord,
456 ) -> LanceResult<UpsertResult> {
457 let Some(external_id) = record.external_id.clone() else {
458 return Err(ArrowError::InvalidArgumentError(
459 "upsert_by_external_id requires external_id".to_string(),
460 )
461 .into());
462 };
463 if external_id.is_empty() {
464 return Err(ArrowError::InvalidArgumentError(
465 "upsert_by_external_id requires a non-empty external_id".to_string(),
466 )
467 .into());
468 }
469 if record.is_tombstone() {
470 return Err(ArrowError::InvalidArgumentError(format!(
471 "content_type '{}' is reserved for internal tombstones",
472 CONTENT_TYPE_TOMBSTONE
473 ))
474 .into());
475 }
476 record.supersedes_id = None;
477 record.superseded_by_id = None;
478 self.validate_new_record_id(&record).await?;
479
480 let matches: Vec<ContextRecord> = self
481 .list(None, None)
482 .await?
483 .into_iter()
484 .filter(|existing| existing.external_id.as_deref() == Some(external_id.as_str()))
485 .collect();
486
487 match matches.as_slice() {
488 [] => {
489 let version = self.add(std::slice::from_ref(&record)).await?;
490 Ok(UpsertResult {
491 record,
492 inserted: true,
493 replaced_id: None,
494 version,
495 })
496 }
497 [existing] => {
498 record.supersedes_id = Some(existing.id.clone());
499 let version = self.write_entries(std::slice::from_ref(&record)).await?;
500 Ok(UpsertResult {
501 record,
502 inserted: false,
503 replaced_id: Some(existing.id.clone()),
504 version,
505 })
506 }
507 _ => Err(ArrowError::InvalidArgumentError(format!(
508 "external_id '{}' matches multiple visible records",
509 external_id
510 ))
511 .into()),
512 }
513 }
514
515 pub async fn update_by_id(
521 &mut self,
522 id: &str,
523 patch: RecordPatch,
524 ) -> LanceResult<Option<UpdateResult>> {
525 if id.is_empty() {
526 return Err(ArrowError::InvalidArgumentError(
527 "update_by_id requires a non-empty id".to_string(),
528 )
529 .into());
530 }
531 let Some(existing) = self.get_by_id(id).await? else {
532 return Ok(None);
533 };
534 self.update_visible_record(existing, patch).await.map(Some)
535 }
536
537 pub async fn update_by_external_id(
541 &mut self,
542 external_id: &str,
543 patch: RecordPatch,
544 ) -> LanceResult<Option<UpdateResult>> {
545 if external_id.is_empty() {
546 return Err(ArrowError::InvalidArgumentError(
547 "update_by_external_id requires a non-empty external_id".to_string(),
548 )
549 .into());
550 }
551
552 let matches: Vec<ContextRecord> = self
553 .list(None, None)
554 .await?
555 .into_iter()
556 .filter(|existing| existing.external_id.as_deref() == Some(external_id))
557 .collect();
558
559 match matches.as_slice() {
560 [] => Ok(None),
561 [existing] => self
562 .update_visible_record(existing.clone(), patch)
563 .await
564 .map(Some),
565 _ => Err(ArrowError::InvalidArgumentError(format!(
566 "external_id '{}' matches multiple visible records",
567 external_id
568 ))
569 .into()),
570 }
571 }
572
573 async fn update_visible_record(
574 &mut self,
575 existing: ContextRecord,
576 patch: RecordPatch,
577 ) -> LanceResult<UpdateResult> {
578 if patch.is_empty() {
579 return Err(ArrowError::InvalidArgumentError(
580 "update requires at least one patch field".to_string(),
581 )
582 .into());
583 }
584
585 let mut record = existing.clone();
586 record.id = Uuid::new_v4().to_string();
587 record.run_id = Uuid::new_v4().to_string();
588 record.created_at = Utc::now();
589 record.supersedes_id = Some(existing.id.clone());
590 record.superseded_by_id = None;
591
592 if let Some(bot_id) = patch.bot_id {
593 record.bot_id = Some(bot_id);
594 }
595 if let Some(session_id) = patch.session_id {
596 record.session_id = Some(session_id);
597 }
598 if let Some(tenant) = patch.tenant {
599 record.tenant = Some(tenant);
600 }
601 if let Some(source) = patch.source {
602 record.source = Some(source);
603 }
604 if let Some(state_metadata) = patch.state_metadata {
605 record.state_metadata = Some(state_metadata);
606 }
607 if let Some(metadata) = patch.metadata {
608 record.metadata = Some(metadata);
609 }
610 if let Some(relationships) = patch.relationships {
611 record.relationships = relationships;
612 }
613 if let Some(expires_at) = patch.expires_at {
614 record.expires_at = Some(expires_at);
615 }
616 if let Some(retention_policy) = patch.retention_policy {
617 record.retention_policy = Some(retention_policy);
618 }
619 if let Some(lifecycle_status) = patch.lifecycle_status {
620 record.lifecycle_status = lifecycle_status;
621 }
622 if let Some(retired_at) = patch.retired_at {
623 record.retired_at = Some(retired_at);
624 }
625 if let Some(retired_reason) = patch.retired_reason {
626 record.retired_reason = Some(retired_reason);
627 }
628 if let Some(embedding) = patch.embedding {
629 record.embedding = Some(embedding);
630 }
631
632 self.validate_new_record_id(&record).await?;
633 let version = self.write_entries(std::slice::from_ref(&record)).await?;
634 Ok(UpdateResult {
635 record,
636 replaced_id: existing.id,
637 version,
638 })
639 }
640
641 async fn write_tombstone_for(&mut self, record: ContextRecord) -> LanceResult<u64> {
642 let tombstone = ContextRecord {
643 id: record.id,
644 external_id: record.external_id,
645 run_id: record.run_id,
646 bot_id: record.bot_id,
647 session_id: record.session_id,
648 tenant: record.tenant,
649 source: record.source,
650 created_at: Utc::now(),
651 role: record.role,
652 state_metadata: None,
653 metadata: None,
654 relationships: Vec::new(),
655 expires_at: None,
656 retention_policy: None,
657 lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
658 retired_at: None,
659 retired_reason: None,
660 supersedes_id: None,
661 superseded_by_id: None,
662 content_type: CONTENT_TYPE_TOMBSTONE.to_string(),
663 text_payload: None,
664 binary_payload: None,
665 embedding: None,
666 };
667 self.write_entries(std::slice::from_ref(&tombstone)).await
668 }
669
670 async fn validate_unique_ids(&self, entries: &[ContextRecord]) -> LanceResult<()> {
671 let mut ids = HashSet::new();
672 let mut external_ids = HashSet::new();
673 for entry in entries {
674 if entry.is_tombstone() {
675 return Err(ArrowError::InvalidArgumentError(format!(
676 "content_type '{}' is reserved for internal tombstones",
677 CONTENT_TYPE_TOMBSTONE
678 ))
679 .into());
680 }
681 if !ids.insert(entry.id.as_str()) {
682 return Err(ArrowError::InvalidArgumentError(format!(
683 "duplicate id '{}' in batch",
684 entry.id
685 ))
686 .into());
687 }
688 if let Some(external_id) = &entry.external_id {
689 if !external_ids.insert(external_id.as_str()) {
690 return Err(ArrowError::InvalidArgumentError(format!(
691 "duplicate external_id '{}' in batch",
692 external_id
693 ))
694 .into());
695 }
696 }
697 }
698
699 for record in self
700 .list_with_options(None, None, LifecycleQueryOptions::new(true, true))
701 .await?
702 {
703 if ids.contains(record.id.as_str()) {
704 return Err(ArrowError::InvalidArgumentError(format!(
705 "id '{}' already exists",
706 record.id
707 ))
708 .into());
709 }
710 if let Some(external_id) = record.external_id {
711 if external_ids.contains(external_id.as_str()) {
712 return Err(ArrowError::InvalidArgumentError(format!(
713 "external_id '{}' already exists",
714 external_id
715 ))
716 .into());
717 }
718 }
719 }
720
721 Ok(())
722 }
723
724 async fn validate_new_record_id(&self, entry: &ContextRecord) -> LanceResult<()> {
725 for record in self
726 .list_with_options(None, None, LifecycleQueryOptions::new(true, true))
727 .await?
728 {
729 if record.id == entry.id {
730 return Err(ArrowError::InvalidArgumentError(format!(
731 "id '{}' already exists",
732 entry.id
733 ))
734 .into());
735 }
736 }
737 Ok(())
738 }
739
740 fn derive_region_id(bot_id: &Option<String>, session_id: &Option<String>) -> Uuid {
741 let mut input = String::new();
742
743 if let Some(bid) = bot_id {
744 input.push_str(bid);
745 }
746 input.push('#');
747 if let Some(sid) = session_id {
748 input.push_str(sid);
749 }
750
751 Uuid::new_v5(&Uuid::NAMESPACE_OID, input.as_bytes())
753 }
754
755 fn has_relationships_column(&self) -> bool {
756 self.dataset
757 .schema()
758 .field_paths()
759 .iter()
760 .any(|path| path == RELATIONSHIPS_COLUMN)
761 }
762
763 pub fn version(&self) -> u64 {
765 self.dataset.manifest.version
766 }
767
768 pub async fn migrate_relationships_column(&mut self) -> LanceResult<bool> {
773 if self.has_relationships_column() {
774 return Ok(false);
775 }
776
777 let schema = Arc::new(Schema::new(vec![relationship_field()]));
778 self.dataset
779 .add_columns(NewColumnTransform::AllNulls(schema), None, None)
780 .await?;
781 Ok(true)
782 }
783
784 pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> {
786 let dataset = self.dataset.checkout_version(version_id).await?;
787 self.dataset = dataset;
788 Ok(())
789 }
790
791 pub async fn get(&self, id: &str) -> LanceResult<Option<ContextRecord>> {
793 let escaped_id = id.replace('\'', "''");
794 let mut scanner = self.dataset.scan();
795 scanner.filter(&format!("id = '{}'", escaped_id))?;
796 scanner.limit(Some(1), None)?;
797
798 let mut stream = scanner.try_into_stream().await?;
799 if let Some(batch) = stream.try_next().await? {
800 let records = batch_to_records(&batch)?;
801 return Ok(records.into_iter().next());
802 }
803 Ok(None)
804 }
805
806 pub async fn list(
808 &self,
809 limit: Option<usize>,
810 offset: Option<usize>,
811 ) -> LanceResult<Vec<ContextRecord>> {
812 self.list_filtered_with_options(limit, offset, None, LifecycleQueryOptions::default())
813 .await
814 }
815
816 pub async fn list_filtered(
818 &self,
819 limit: Option<usize>,
820 offset: Option<usize>,
821 filters: Option<&RecordFilters>,
822 ) -> LanceResult<Vec<ContextRecord>> {
823 self.list_filtered_with_options(limit, offset, filters, LifecycleQueryOptions::default())
824 .await
825 }
826
827 pub async fn list_with_options(
829 &self,
830 limit: Option<usize>,
831 offset: Option<usize>,
832 options: LifecycleQueryOptions,
833 ) -> LanceResult<Vec<ContextRecord>> {
834 self.list_filtered_with_options(limit, offset, None, options)
835 .await
836 }
837
838 pub async fn list_filtered_with_options(
840 &self,
841 limit: Option<usize>,
842 offset: Option<usize>,
843 filters: Option<&RecordFilters>,
844 options: LifecycleQueryOptions,
845 ) -> LanceResult<Vec<ContextRecord>> {
846 let scanner = self.lsm_scanner().await?;
847 let mut stream = scanner.try_into_stream().await?;
848 let mut results = Vec::new();
849 while let Some(batch) = stream.try_next().await? {
850 results.extend(batch_to_records(&batch)?);
851 }
852
853 let superseded_ids: HashSet<String> = results
854 .iter()
855 .filter_map(|record| {
856 let supersedes_id = record.supersedes_id.as_ref()?;
857 if supersedes_id == &record.id {
858 None
859 } else {
860 Some(supersedes_id.clone())
861 }
862 })
863 .collect();
864 results.retain(|record| {
865 options.is_visible(record)
866 && (options.include_retired || !superseded_ids.contains(&record.id))
867 });
868 if let Some(filters) = filters.filter(|filters| !filters.is_empty()) {
869 results.retain(|record| filters.matches(record));
870 }
871
872 if let Some(offset) = offset {
873 results = results.into_iter().skip(offset).collect();
874 }
875 if let Some(limit) = limit {
876 results.truncate(limit);
877 }
878 Ok(results)
879 }
880
881 pub async fn get_by_id(&self, id: &str) -> LanceResult<Option<ContextRecord>> {
883 Ok(self
884 .list(None, None)
885 .await?
886 .into_iter()
887 .find(|record| record.id == id))
888 }
889
890 pub async fn get_by_external_id(
892 &self,
893 external_id: &str,
894 ) -> LanceResult<Option<ContextRecord>> {
895 Ok(self
896 .list(None, None)
897 .await?
898 .into_iter()
899 .find(|record| record.external_id.as_deref() == Some(external_id)))
900 }
901
902 pub async fn list_related(
904 &self,
905 target_id: &str,
906 relation: Option<&str>,
907 limit: Option<usize>,
908 ) -> LanceResult<Vec<ContextRecord>> {
909 self.list_related_with_options(target_id, relation, limit, LifecycleQueryOptions::default())
910 .await
911 }
912
913 pub async fn list_related_with_options(
915 &self,
916 target_id: &str,
917 relation: Option<&str>,
918 limit: Option<usize>,
919 options: LifecycleQueryOptions,
920 ) -> LanceResult<Vec<ContextRecord>> {
921 let mut results: Vec<ContextRecord> = self
922 .list_with_options(None, None, options)
923 .await?
924 .into_iter()
925 .filter(|record| {
926 record.relationships.iter().any(|relationship| {
927 relationship.target_id == target_id
928 && relation.is_none_or(|value| relationship.relation == value)
929 })
930 })
931 .collect();
932
933 if let Some(limit) = limit {
934 results.truncate(limit);
935 }
936 Ok(results)
937 }
938
939 pub async fn search(
941 &self,
942 query: &[f32],
943 limit: Option<usize>,
944 ) -> LanceResult<Vec<SearchResult>> {
945 self.search_filtered_with_options(query, limit, None, LifecycleQueryOptions::default())
946 .await
947 }
948
949 pub async fn search_filtered(
951 &self,
952 query: &[f32],
953 limit: Option<usize>,
954 filters: Option<&RecordFilters>,
955 ) -> LanceResult<Vec<SearchResult>> {
956 self.search_filtered_with_options(query, limit, filters, LifecycleQueryOptions::default())
957 .await
958 }
959
960 pub async fn search_with_options(
962 &self,
963 query: &[f32],
964 limit: Option<usize>,
965 options: LifecycleQueryOptions,
966 ) -> LanceResult<Vec<SearchResult>> {
967 self.search_filtered_with_options(query, limit, None, options)
968 .await
969 }
970
971 pub async fn search_filtered_with_options(
973 &self,
974 query: &[f32],
975 limit: Option<usize>,
976 filters: Option<&RecordFilters>,
977 options: LifecycleQueryOptions,
978 ) -> LanceResult<Vec<SearchResult>> {
979 validate_query_dimension(query, self.embedding_dim)?;
980
981 let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
982 if top_k == 0 {
983 return Ok(Vec::new());
984 }
985
986 let mut results: Vec<SearchResult> = self
987 .list_filtered_with_options(None, None, filters, options)
988 .await?
989 .into_iter()
990 .filter_map(|record| {
991 let distance = self
992 .distance_metric
993 .distance(query, record.embedding.as_ref()?);
994 Some(SearchResult { record, distance })
995 })
996 .collect();
997 results.sort_by(|left, right| left.distance.total_cmp(&right.distance));
998 results.truncate(top_k);
999 Ok(results)
1000 }
1001
    /// Hybrid retrieval that fuses a vector channel and a lexical text channel.
    ///
    /// Candidates come from `list_filtered_with_options`, so lifecycle
    /// visibility, supersession and `filters` are applied before ranking. Each
    /// channel ranks its own hits with 1-based ranks and reports them through
    /// `add_retrieve_channel`, which accumulates per-record entries in
    /// `candidates`. The merged list is ordered with `compare_retrieve_results`
    /// and truncated to `limit` (default `DEFAULT_SEARCH_LIMIT`).
    ///
    /// # Errors
    /// Fails when:
    /// - `text` yields no query terms and no `vector` is given;
    /// - `vector` is rejected by `validate_query_dimension`;
    /// - listing records fails.
    ///
    /// `limit == Some(0)` returns an empty result after validation.
    pub async fn retrieve_filtered_with_options(
        &self,
        text: Option<&str>,
        vector: Option<&[f32]>,
        limit: Option<usize>,
        filters: Option<&RecordFilters>,
        options: LifecycleQueryOptions,
    ) -> LanceResult<Vec<RetrieveResult>> {
        // Text that produces no terms counts as "no text".
        let text_terms = text.map(unique_query_terms).unwrap_or_default();
        let has_text = !text_terms.is_empty();

        if !has_text && vector.is_none() {
            return Err(ArrowError::InvalidArgumentError(
                "retrieve requires text or vector".to_string(),
            )
            .into());
        }

        if let Some(query) = vector {
            validate_query_dimension(query, self.embedding_dim)?;
        }

        let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
        if top_k == 0 {
            return Ok(Vec::new());
        }

        let records = self
            .list_filtered_with_options(None, None, filters, options)
            .await?;
        // Keyed per record; filled by add_retrieve_channel from both channels.
        let mut candidates: HashMap<String, RetrieveResult> = HashMap::new();

        if let Some(query) = vector {
            // Vector channel: records without an embedding are skipped; ascending
            // distance, ties broken by id.
            let mut vector_hits: Vec<(usize, f32)> = records
                .iter()
                .enumerate()
                .filter_map(|(index, record)| {
                    let distance = self
                        .distance_metric
                        .distance(query, record.embedding.as_ref()?);
                    Some((index, distance))
                })
                .collect();
            vector_hits.sort_by(|left, right| {
                left.1
                    .total_cmp(&right.1)
                    .then_with(|| records[left.0].id.cmp(&records[right.0].id))
            });

            for (rank, (index, distance)) in vector_hits.into_iter().enumerate() {
                add_retrieve_channel(
                    &mut candidates,
                    &records[index],
                    rank + 1,
                    "vector",
                    Some(distance),
                    None,
                );
            }
        }

        if has_text {
            // Text channel: only records for which lexical_score returns Some;
            // descending score, ties broken by id.
            let mut text_hits: Vec<(usize, f32)> = records
                .iter()
                .enumerate()
                .filter_map(|(index, record)| {
                    lexical_score(&text_terms, record.text_payload.as_deref())
                        .map(|score| (index, score))
                })
                .collect();
            text_hits.sort_by(|left, right| {
                right
                    .1
                    .total_cmp(&left.1)
                    .then_with(|| records[left.0].id.cmp(&records[right.0].id))
            });

            for (rank, (index, score)) in text_hits.into_iter().enumerate() {
                add_retrieve_channel(
                    &mut candidates,
                    &records[index],
                    rank + 1,
                    "text",
                    None,
                    Some(score),
                );
            }
        }

        let mut results: Vec<RetrieveResult> = candidates.into_values().collect();
        results.sort_by(compare_retrieve_results);
        results.truncate(top_k);
        Ok(results)
    }
1097
1098 async fn lsm_scanner(&self) -> LanceResult<LsmScanner> {
1099 let object_store = self.dataset.object_store(None).await?;
1100 let branch_location = self.dataset.branch_location();
1101 let shard_ids = self.dataset.list_mem_wal_latest_shard_ids().await?;
1102
1103 let mut shard_snapshots = Vec::with_capacity(shard_ids.len());
1104 for shard_id in shard_ids {
1105 let manifest_store = ShardManifestStore::new(
1106 object_store.clone(),
1107 &branch_location.path,
1108 shard_id,
1109 DEFAULT_MANIFEST_SCAN_BATCH_SIZE,
1110 );
1111 let Some(manifest) = manifest_store.read_latest().await? else {
1112 continue;
1113 };
1114
1115 let mut snapshot = ShardSnapshot::new(shard_id)
1116 .with_spec_id(manifest.shard_spec_id)
1117 .with_current_generation(manifest.current_generation);
1118 for flushed in manifest.flushed_generations {
1119 snapshot = snapshot.with_flushed_generation(flushed.generation, flushed.path);
1120 }
1121 shard_snapshots.push(snapshot);
1122 }
1123
1124 Ok(LsmScanner::new(
1125 Arc::new(self.dataset.clone()),
1126 shard_snapshots,
1127 vec!["id".to_string()],
1128 ))
1129 }
1130
1131 pub async fn compact(
1133 &mut self,
1134 options: Option<CompactionConfig>,
1135 ) -> LanceResult<CompactionMetrics> {
1136 let config = options.unwrap_or_else(|| self.compaction_config.clone());
1137
1138 info!(
1139 "Starting compaction: {} fragments",
1140 self.dataset.count_fragments()
1141 );
1142 let start = std::time::Instant::now();
1143
1144 {
1146 let mut state = self.compaction_state.lock().await;
1147 if state.is_compacting {
1148 warn!("Compaction already in progress, skipping");
1149 return Err(LanceError::from(ArrowError::InvalidArgumentError(
1150 "Compaction already in progress".to_string(),
1151 )));
1152 }
1153 state.is_compacting = true;
1154 }
1155
1156 let lance_options = CompactionOptions {
1158 target_rows_per_fragment: config.target_rows_per_fragment,
1159 max_rows_per_group: config.max_rows_per_group,
1160 materialize_deletions: config.materialize_deletions,
1161 materialize_deletions_threshold: config.materialize_deletions_threshold,
1162 num_threads: config.num_threads,
1163 ..Default::default()
1164 };
1165
1166 let result = compact_files(&mut self.dataset, lance_options, None).await;
1168
1169 let mut state = self.compaction_state.lock().await;
1171 state.is_compacting = false;
1172
1173 match result {
1174 Ok(metrics) => {
1175 state.last_compaction = Some(Utc::now());
1176 state.total_compactions += 1;
1177 state.last_error = None;
1178 drop(state); info!(
1181 "Compaction completed in {:?}: removed {} fragments ({}files), added {} fragments ({} files)",
1182 start.elapsed(),
1183 metrics.fragments_removed,
1184 metrics.files_removed,
1185 metrics.fragments_added,
1186 metrics.files_added
1187 );
1188
1189 self.dataset = Dataset::open(self.dataset.uri()).await?;
1191
1192 if let Err(e) = self.ensure_id_index().await {
1195 warn!("Failed to ensure id index after compaction: {}", e);
1196 }
1197
1198 Ok(metrics)
1199 }
1200 Err(e) => {
1201 error!("Compaction failed: {}", e);
1202 state.last_error = Some(e.to_string());
1203 Err(e)
1204 }
1205 }
1206 }
1207
1208 pub async fn should_compact(&self) -> LanceResult<bool> {
1210 let fragment_count = self.dataset.count_fragments();
1211
1212 if fragment_count < self.compaction_config.min_fragments {
1213 return Ok(false);
1214 }
1215
1216 if !self.compaction_config.quiet_hours.is_empty() {
1218 let now = Utc::now();
1219 let current_hour = now.hour() as u8;
1220
1221 for (start, end) in &self.compaction_config.quiet_hours {
1222 if current_hour >= *start && current_hour < *end {
1223 info!("Skipping compaction during quiet hours ({}-{})", start, end);
1224 return Ok(false);
1225 }
1226 }
1227 }
1228
1229 Ok(true)
1230 }
1231
1232 pub async fn compaction_stats(&self) -> LanceResult<CompactionStats> {
1234 let state = self.compaction_state.lock().await;
1235
1236 Ok(CompactionStats {
1237 total_fragments: self.dataset.count_fragments(),
1238 is_compacting: state.is_compacting,
1239 last_compaction: state.last_compaction,
1240 last_error: state.last_error.clone(),
1241 total_compactions: state.total_compactions,
1242 })
1243 }
1244
1245 async fn ensure_id_index(&mut self) -> LanceResult<()> {
1247 if self.id_index_type == IdIndexType::None {
1248 return Ok(());
1249 }
1250
1251 let indices = self.dataset.load_indices().await?;
1252 if indices.iter().any(|i| i.name == ID_INDEX_NAME) {
1253 return Ok(());
1254 }
1255
1256 self.create_id_index().await
1257 }
1258
1259 pub async fn create_id_index(&mut self) -> LanceResult<()> {
1261 let index_type = match self.id_index_type {
1262 IdIndexType::ZoneMap => IndexType::ZoneMap,
1263 IdIndexType::BTree => IndexType::BTree,
1264 IdIndexType::None => return Ok(()),
1265 };
1266
1267 info!("Creating {:?} index on id column", index_type);
1268
1269 let params = ScalarIndexParams::default();
1270
1271 self.dataset
1272 .create_index_builder(&["id"], index_type, ¶ms)
1273 .name(ID_INDEX_NAME.to_string())
1274 .replace(true)
1275 .await?;
1276
1277 self.dataset = Dataset::open(self.dataset.uri()).await?;
1279
1280 Ok(())
1281 }
1282
1283 async fn start_background_compaction(&mut self) -> LanceResult<()> {
1285 if !self.compaction_config.enabled {
1286 return Ok(());
1287 }
1288
1289 let mut state = self.compaction_state.lock().await;
1290 if state.background_task.is_some() {
1291 warn!("Background compaction already running");
1292 return Ok(());
1293 }
1294
1295 info!(
1296 "Starting background compaction (interval: {}s, min fragments: {})",
1297 self.compaction_config.check_interval_secs, self.compaction_config.min_fragments
1298 );
1299
1300 let mut store_clone = self.clone();
1301 let interval_secs = self.compaction_config.check_interval_secs;
1302
1303 let task = tokio::spawn(async move {
1304 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
1305
1306 loop {
1307 interval.tick().await;
1308
1309 match store_clone.should_compact().await {
1310 Ok(true) => {
1311 info!("Background compaction triggered");
1312 if let Err(e) = store_clone.compact(None).await {
1313 error!("Background compaction failed: {}", e);
1314 }
1315 }
1316 Ok(false) => {
1317 }
1319 Err(e) => {
1320 error!("Error checking compaction need: {}", e);
1321 }
1322 }
1323 }
1324 });
1325
1326 state.background_task = Some(task);
1327 Ok(())
1328 }
1329
1330 pub async fn stop_background_compaction(&mut self) -> LanceResult<()> {
1332 let mut state = self.compaction_state.lock().await;
1333
1334 if let Some(task) = state.background_task.take() {
1335 info!("Stopping background compaction");
1336 task.abort();
1337 }
1338
1339 Ok(())
1340 }
1341
1342 pub fn schema(blob_columns: &HashSet<String>) -> Schema {
1348 Self::schema_with_embedding_dim(blob_columns, DEFAULT_EMBEDDING_DIM)
1349 }
1350
1351 pub fn schema_with_embedding_dim(blob_columns: &HashSet<String>, embedding_dim: i32) -> Schema {
1353 Self::schema_with_options(
1354 blob_columns,
1355 true,
1356 true,
1357 true,
1358 true,
1359 embedding_dim,
1360 DistanceMetric::default(),
1361 )
1362 }
1363
1364 fn schema_with_options(
1365 blob_columns: &HashSet<String>,
1366 include_external_id: bool,
1367 include_metadata: bool,
1368 include_relationships: bool,
1369 include_lifecycle: bool,
1370 embedding_dim: i32,
1371 distance_metric: DistanceMetric,
1372 ) -> Schema {
1373 let mut id_metadata = HashMap::new();
1374 id_metadata.insert(
1375 "lance-schema:unenforced-primary-key".to_string(),
1376 "true".to_string(),
1377 );
1378
1379 let text_field = if blob_columns.contains("text_payload") {
1380 let mut metadata = HashMap::new();
1381 metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
1382 Field::new("text_payload", DataType::LargeBinary, true).with_metadata(metadata)
1383 } else {
1384 Field::new("text_payload", DataType::LargeUtf8, true)
1385 };
1386
1387 let binary_field = if blob_columns.contains("binary_payload") {
1388 let mut metadata = HashMap::new();
1389 metadata.insert("lance-encoding:blob".to_string(), "true".to_string());
1390 Field::new("binary_payload", DataType::LargeBinary, true).with_metadata(metadata)
1391 } else {
1392 Field::new("binary_payload", DataType::LargeBinary, true)
1393 };
1394
1395 let mut fields = vec![Field::new("id", DataType::Utf8, false).with_metadata(id_metadata)];
1396 if include_external_id {
1397 fields.push(Field::new("external_id", DataType::Utf8, true));
1398 }
1399 fields.extend([
1400 Field::new("run_id", DataType::Utf8, false),
1401 Field::new("bot_id", DataType::Utf8, true),
1402 Field::new("session_id", DataType::Utf8, true),
1403 Field::new("tenant", DataType::Utf8, true),
1404 Field::new("source", DataType::Utf8, true),
1405 Field::new(
1406 "created_at",
1407 DataType::Timestamp(TimeUnit::Microsecond, None),
1408 false,
1409 ),
1410 Field::new(
1411 "role",
1412 DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
1413 false,
1414 ),
1415 Field::new(
1416 "state_metadata",
1417 DataType::Struct(
1418 vec![
1419 Field::new("step", DataType::Int32, true),
1420 Field::new("active_plan_id", DataType::Utf8, true),
1421 Field::new("tokens_used", DataType::Int32, true),
1422 Field::new("custom", DataType::Utf8, true),
1423 ]
1424 .into(),
1425 ),
1426 true,
1427 ),
1428 ]);
1429 if include_metadata {
1430 fields.push(Field::new("metadata", DataType::LargeUtf8, true));
1431 }
1432 if include_relationships {
1433 fields.push(relationship_field());
1434 }
1435 if include_lifecycle {
1436 fields.extend([
1437 Field::new(
1438 "expires_at",
1439 DataType::Timestamp(TimeUnit::Microsecond, None),
1440 true,
1441 ),
1442 Field::new("retention_policy", DataType::Utf8, true),
1443 Field::new("lifecycle_status", DataType::Utf8, false),
1444 Field::new(
1445 "retired_at",
1446 DataType::Timestamp(TimeUnit::Microsecond, None),
1447 true,
1448 ),
1449 Field::new("retired_reason", DataType::Utf8, true),
1450 Field::new("supersedes_id", DataType::Utf8, true),
1451 Field::new("superseded_by_id", DataType::Utf8, true),
1452 ]);
1453 }
1454 fields.extend([
1455 Field::new("content_type", DataType::Utf8, false),
1456 text_field,
1457 binary_field,
1458 Field::new(
1459 "embedding",
1460 DataType::FixedSizeList(
1461 Arc::new(Field::new("item", DataType::Float32, true)),
1462 embedding_dim,
1463 ),
1464 true,
1465 ),
1466 ]);
1467
1468 let schema_metadata = HashMap::from([(
1469 DISTANCE_METRIC_METADATA_KEY.to_string(),
1470 distance_metric.as_str().to_string(),
1471 )]);
1472
1473 Schema::new_with_metadata(fields, schema_metadata)
1474 }
1475
1476 async fn load_with_options(
1477 uri: &str,
1478 storage_options: Option<HashMap<String, String>>,
1479 ) -> LanceResult<Dataset> {
1480 if let Some(options) = storage_options {
1481 DatasetBuilder::from_uri(uri)
1482 .with_storage_options(options)
1483 .load()
1484 .await
1485 } else {
1486 Dataset::open(uri).await
1487 }
1488 }
1489
1490 async fn create_with_options(
1491 uri: &str,
1492 storage_options: Option<HashMap<String, String>>,
1493 blob_columns: &HashSet<String>,
1494 embedding_dim: i32,
1495 distance_metric: DistanceMetric,
1496 ) -> LanceResult<Dataset> {
1497 let schema = Arc::new(Self::schema_with_options(
1498 blob_columns,
1499 true,
1500 true,
1501 true,
1502 true,
1503 embedding_dim,
1504 distance_metric,
1505 ));
1506 let empty_batch = RecordBatch::new_empty(schema.clone());
1507 let batches = RecordBatchIterator::new(
1508 vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
1509 schema.clone(),
1510 );
1511
1512 let mut params = WriteParams {
1513 mode: WriteMode::Create,
1514 ..Default::default()
1515 };
1516
1517 if let Some(options) = storage_options {
1518 let store_params = ObjectStoreParams {
1519 storage_options_accessor: Some(Arc::new(
1520 StorageOptionsAccessor::with_static_options(options),
1521 )),
1522 ..Default::default()
1523 };
1524 params.store_params = Some(store_params);
1525 }
1526
1527 Dataset::write(batches, uri, Some(params)).await
1528 }
1529
1530 fn records_to_batch(&self, entries: &[ContextRecord]) -> LanceResult<RecordBatch> {
1531 let include_external_id = self
1532 .dataset
1533 .schema()
1534 .field_paths()
1535 .iter()
1536 .any(|path| path == "external_id");
1537 let include_lifecycle = self
1538 .dataset
1539 .schema()
1540 .field_paths()
1541 .iter()
1542 .any(|path| path == "expires_at");
1543 let include_metadata = self
1544 .dataset
1545 .schema()
1546 .field_paths()
1547 .iter()
1548 .any(|path| path == "metadata");
1549 let include_tenant = self
1550 .dataset
1551 .schema()
1552 .field_paths()
1553 .iter()
1554 .any(|path| path == "tenant");
1555 let include_source = self
1556 .dataset
1557 .schema()
1558 .field_paths()
1559 .iter()
1560 .any(|path| path == "source");
1561 let include_relationships = self.has_relationships_column();
1562 if !include_external_id && entries.iter().any(|entry| entry.external_id.is_some()) {
1563 return Err(ArrowError::InvalidArgumentError(
1564 "external_id requires a context dataset created with external_id support"
1565 .to_string(),
1566 )
1567 .into());
1568 }
1569 if !include_metadata && entries.iter().any(|entry| entry.metadata.is_some()) {
1570 return Err(ArrowError::InvalidArgumentError(
1571 "metadata requires a context dataset created with metadata support".to_string(),
1572 )
1573 .into());
1574 }
1575 if !include_tenant && entries.iter().any(|entry| entry.tenant.is_some()) {
1576 return Err(ArrowError::InvalidArgumentError(
1577 "tenant requires a context dataset created with partition-key column support"
1578 .to_string(),
1579 )
1580 .into());
1581 }
1582 if !include_source && entries.iter().any(|entry| entry.source.is_some()) {
1583 return Err(ArrowError::InvalidArgumentError(
1584 "source requires a context dataset created with partition-key column support"
1585 .to_string(),
1586 )
1587 .into());
1588 }
1589 if !include_relationships && entries.iter().any(|entry| !entry.relationships.is_empty()) {
1590 return Err(ArrowError::InvalidArgumentError(
1591 "relationships require a context dataset with relationships support; run migrate_relationships_column() on older datasets".to_string(),
1592 )
1593 .into());
1594 }
1595 if !include_lifecycle && entries.iter().any(ContextRecord::has_non_default_lifecycle) {
1596 return Err(ArrowError::InvalidArgumentError(
1597 "lifecycle fields require a context dataset created with lifecycle support"
1598 .to_string(),
1599 )
1600 .into());
1601 }
1602
1603 let mut id_builder = StringBuilder::new();
1604 let mut external_id_builder = StringBuilder::new();
1605 let mut run_id_builder = StringBuilder::new();
1606 let mut bot_id_builder = StringBuilder::new();
1607 let mut session_id_builder = StringBuilder::new();
1608 let mut tenant_builder = StringBuilder::new();
1609 let mut source_builder = StringBuilder::new();
1610 let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
1611 let mut role_builder = StringDictionaryBuilder::<Int8Type>::new();
1612 let mut metadata_builder = LargeStringBuilder::new();
1613 let mut relationships_builder = ListBuilder::new(relationship_struct_builder())
1614 .with_field(relationship_list_item_field());
1615 let mut expires_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
1616 let mut retention_policy_builder = StringBuilder::new();
1617 let mut lifecycle_status_builder = StringBuilder::new();
1618 let mut retired_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
1619 let mut retired_reason_builder = StringBuilder::new();
1620 let mut supersedes_id_builder = StringBuilder::new();
1621 let mut superseded_by_id_builder = StringBuilder::new();
1622 let mut content_type_builder = StringBuilder::new();
1623 let mut binary_builder = LargeBinaryBuilder::new();
1624
1625 let text_is_blob = self.blob_columns.contains("text_payload");
1626 let mut text_string_builder = if !text_is_blob {
1627 Some(LargeStringBuilder::new())
1628 } else {
1629 None
1630 };
1631 let mut text_binary_builder = if text_is_blob {
1632 Some(LargeBinaryBuilder::new())
1633 } else {
1634 None
1635 };
1636
1637 let state_fields: Vec<FieldRef> = vec![
1638 Arc::new(Field::new("step", DataType::Int32, true)),
1639 Arc::new(Field::new("active_plan_id", DataType::Utf8, true)),
1640 Arc::new(Field::new("tokens_used", DataType::Int32, true)),
1641 Arc::new(Field::new("custom", DataType::Utf8, true)),
1642 ];
1643 let mut state_builder = StructBuilder::new(
1644 state_fields,
1645 vec![
1646 Box::new(Int32Builder::new()),
1647 Box::new(StringBuilder::new()),
1648 Box::new(Int32Builder::new()),
1649 Box::new(StringBuilder::new()),
1650 ],
1651 );
1652
1653 let mut embedding_builder =
1654 FixedSizeListBuilder::new(Float32Builder::new(), self.embedding_dim);
1655
1656 for entry in entries {
1657 id_builder.append_value(&entry.id);
1658 external_id_builder.append_option(entry.external_id.as_deref());
1659 run_id_builder.append_value(&entry.run_id);
1660 bot_id_builder.append_option(entry.bot_id.as_deref());
1661 session_id_builder.append_option(entry.session_id.as_deref());
1662 tenant_builder.append_option(entry.tenant.as_deref());
1663 source_builder.append_option(entry.source.as_deref());
1664 created_at_builder.append_value(entry.created_at.timestamp_micros());
1665 role_builder.append(&entry.role)?;
1666 match &entry.metadata {
1667 Some(metadata) => metadata_builder.append_value(metadata.to_string()),
1668 None => metadata_builder.append_null(),
1669 }
1670 for relationship in &entry.relationships {
1671 let values_builder = relationships_builder.values();
1672 values_builder
1673 .field_builder::<StringBuilder>(0)
1674 .unwrap()
1675 .append_value(&relationship.target_id);
1676 values_builder
1677 .field_builder::<StringBuilder>(1)
1678 .unwrap()
1679 .append_value(&relationship.relation);
1680 values_builder
1681 .field_builder::<Float32Builder>(2)
1682 .unwrap()
1683 .append_option(relationship.weight);
1684 values_builder.append(true);
1685 }
1686 relationships_builder.append(true);
1687 expires_at_builder
1688 .append_option(entry.expires_at.map(|value| value.timestamp_micros()));
1689 retention_policy_builder.append_option(entry.retention_policy.as_deref());
1690 lifecycle_status_builder.append_value(&entry.lifecycle_status);
1691 retired_at_builder
1692 .append_option(entry.retired_at.map(|value| value.timestamp_micros()));
1693 retired_reason_builder.append_option(entry.retired_reason.as_deref());
1694 supersedes_id_builder.append_option(entry.supersedes_id.as_deref());
1695 superseded_by_id_builder.append_option(entry.superseded_by_id.as_deref());
1696 content_type_builder.append_value(&entry.content_type);
1697
1698 if text_is_blob {
1699 match &entry.text_payload {
1700 Some(value) => text_binary_builder
1701 .as_mut()
1702 .unwrap()
1703 .append_value(value.as_bytes()),
1704 None => text_binary_builder.as_mut().unwrap().append_null(),
1705 }
1706 } else {
1707 match &entry.text_payload {
1708 Some(value) => text_string_builder.as_mut().unwrap().append_value(value),
1709 None => text_string_builder.as_mut().unwrap().append_null(),
1710 }
1711 }
1712
1713 match &entry.binary_payload {
1714 Some(value) => binary_builder.append_value(value),
1715 None => binary_builder.append_null(),
1716 }
1717
1718 if let Some(metadata) = &entry.state_metadata {
1719 state_builder
1720 .field_builder::<Int32Builder>(0)
1721 .unwrap()
1722 .append_option(metadata.step);
1723 state_builder
1724 .field_builder::<StringBuilder>(1)
1725 .unwrap()
1726 .append_option(metadata.active_plan_id.as_deref());
1727 state_builder
1728 .field_builder::<Int32Builder>(2)
1729 .unwrap()
1730 .append_option(metadata.tokens_used);
1731 state_builder
1732 .field_builder::<StringBuilder>(3)
1733 .unwrap()
1734 .append_option(metadata.custom.as_deref());
1735 state_builder.append(true);
1736 } else {
1737 state_builder
1738 .field_builder::<Int32Builder>(0)
1739 .unwrap()
1740 .append_null();
1741 state_builder
1742 .field_builder::<StringBuilder>(1)
1743 .unwrap()
1744 .append_null();
1745 state_builder
1746 .field_builder::<Int32Builder>(2)
1747 .unwrap()
1748 .append_null();
1749 state_builder
1750 .field_builder::<StringBuilder>(3)
1751 .unwrap()
1752 .append_null();
1753 state_builder.append(false);
1754 }
1755
1756 if let Some(embedding) = &entry.embedding {
1757 if embedding.len() != self.embedding_dim as usize {
1758 return Err(ArrowError::InvalidArgumentError(format!(
1759 "embedding length {} does not match expected dimension {}",
1760 embedding.len(),
1761 self.embedding_dim
1762 ))
1763 .into());
1764 }
1765 {
1766 let values_builder = embedding_builder.values();
1767 for value in embedding {
1768 values_builder.append_value(*value);
1769 }
1770 }
1771 embedding_builder.append(true);
1772 } else {
1773 let values_builder = embedding_builder.values();
1775 for _ in 0..self.embedding_dim {
1776 values_builder.append_null();
1777 }
1778 embedding_builder.append(false);
1779 }
1780 }
1781
1782 let id_array: ArrayRef = Arc::new(id_builder.finish());
1783 let external_id_array: ArrayRef = Arc::new(external_id_builder.finish());
1784 let run_id_array: ArrayRef = Arc::new(run_id_builder.finish());
1785 let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish());
1786 let session_id_array: ArrayRef = Arc::new(session_id_builder.finish());
1787 let tenant_array: ArrayRef = Arc::new(tenant_builder.finish());
1788 let source_array: ArrayRef = Arc::new(source_builder.finish());
1789 let created_at_array: ArrayRef = Arc::new(created_at_builder.finish());
1790 let role_array: ArrayRef = Arc::new(role_builder.finish());
1791 let metadata_array: ArrayRef = Arc::new(metadata_builder.finish());
1792 let relationships_array: ArrayRef = Arc::new(relationships_builder.finish());
1793 let expires_at_array: ArrayRef = Arc::new(expires_at_builder.finish());
1794 let retention_policy_array: ArrayRef = Arc::new(retention_policy_builder.finish());
1795 let lifecycle_status_array: ArrayRef = Arc::new(lifecycle_status_builder.finish());
1796 let retired_at_array: ArrayRef = Arc::new(retired_at_builder.finish());
1797 let retired_reason_array: ArrayRef = Arc::new(retired_reason_builder.finish());
1798 let supersedes_id_array: ArrayRef = Arc::new(supersedes_id_builder.finish());
1799 let superseded_by_id_array: ArrayRef = Arc::new(superseded_by_id_builder.finish());
1800 let content_type_array: ArrayRef = Arc::new(content_type_builder.finish());
1801 let text_array: ArrayRef = if text_is_blob {
1802 Arc::new(text_binary_builder.unwrap().finish())
1803 } else {
1804 Arc::new(text_string_builder.unwrap().finish())
1805 };
1806 let binary_array: ArrayRef = Arc::new(binary_builder.finish());
1807 let state_array: ArrayRef = Arc::new(state_builder.finish());
1808 let embedding_array: ArrayRef = Arc::new(embedding_builder.finish());
1809
1810 let mut arrays_by_name = HashMap::from([("id".to_string(), id_array)]);
1811 if include_external_id {
1812 arrays_by_name.insert("external_id".to_string(), external_id_array);
1813 }
1814 arrays_by_name.extend([
1815 ("run_id".to_string(), run_id_array),
1816 ("bot_id".to_string(), bot_id_array),
1817 ("session_id".to_string(), session_id_array),
1818 ("created_at".to_string(), created_at_array),
1819 ("role".to_string(), role_array),
1820 ("state_metadata".to_string(), state_array),
1821 ]);
1822 if include_tenant {
1823 arrays_by_name.insert("tenant".to_string(), tenant_array);
1824 }
1825 if include_source {
1826 arrays_by_name.insert("source".to_string(), source_array);
1827 }
1828 if include_metadata {
1829 arrays_by_name.insert("metadata".to_string(), metadata_array);
1830 }
1831 if include_relationships {
1832 arrays_by_name.insert(RELATIONSHIPS_COLUMN.to_string(), relationships_array);
1833 }
1834 if include_lifecycle {
1835 arrays_by_name.extend([
1836 ("expires_at".to_string(), expires_at_array),
1837 ("retention_policy".to_string(), retention_policy_array),
1838 ("lifecycle_status".to_string(), lifecycle_status_array),
1839 ("retired_at".to_string(), retired_at_array),
1840 ("retired_reason".to_string(), retired_reason_array),
1841 ("supersedes_id".to_string(), supersedes_id_array),
1842 ("superseded_by_id".to_string(), superseded_by_id_array),
1843 ]);
1844 }
1845 arrays_by_name.extend([
1846 ("content_type".to_string(), content_type_array),
1847 ("text_payload".to_string(), text_array),
1848 ("binary_payload".to_string(), binary_array),
1849 ("embedding".to_string(), embedding_array),
1850 ]);
1851
1852 let schema: Arc<Schema> = Arc::new(self.dataset.schema().into());
1853 let arrays = schema
1854 .fields()
1855 .iter()
1856 .map(|field| {
1857 arrays_by_name.remove(field.name().as_str()).ok_or_else(|| {
1858 LanceError::from(ArrowError::InvalidArgumentError(format!(
1859 "unsupported dataset column '{}'",
1860 field.name()
1861 )))
1862 })
1863 })
1864 .collect::<LanceResult<Vec<_>>>()?;
1865 let batch = RecordBatch::try_new(schema, arrays)?;
1866
1867 Ok(batch)
1868 }
1869}
1870
1871impl Drop for ContextStore {
1872 fn drop(&mut self) {
1873 if let Ok(mut state) = self.compaction_state.try_lock() {
1875 if let Some(task) = state.background_task.take() {
1876 task.abort();
1877 }
1878 }
1879 }
1880}
1881
1882fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
1884 let id_array = column_as::<StringArray>(batch, "id")?;
1885 let external_id_array = column_as_optional::<StringArray>(batch, "external_id");
1886 let run_id_array = column_as::<StringArray>(batch, "run_id")?;
1887 let bot_id_array = column_as_optional::<StringArray>(batch, "bot_id");
1888 let session_id_array = column_as_optional::<StringArray>(batch, "session_id");
1889 let tenant_array = column_as_optional::<StringArray>(batch, "tenant");
1890 let source_array = column_as_optional::<StringArray>(batch, "source");
1891 let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
1892 let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
1893 let state_array = column_as::<StructArray>(batch, "state_metadata")?;
1894 let metadata_array = column_as_optional::<LargeStringArray>(batch, "metadata");
1895 let relationships_array = column_as_optional::<ListArray>(batch, RELATIONSHIPS_COLUMN);
1896 let expires_at_array = column_as_optional::<TimestampMicrosecondArray>(batch, "expires_at");
1897 let retention_policy_array = column_as_optional::<StringArray>(batch, "retention_policy");
1898 let lifecycle_status_array = column_as_optional::<StringArray>(batch, "lifecycle_status");
1899 let retired_at_array = column_as_optional::<TimestampMicrosecondArray>(batch, "retired_at");
1900 let retired_reason_array = column_as_optional::<StringArray>(batch, "retired_reason");
1901 let supersedes_id_array = column_as_optional::<StringArray>(batch, "supersedes_id");
1902 let superseded_by_id_array = column_as_optional::<StringArray>(batch, "superseded_by_id");
1903 let content_type_array = column_as::<StringArray>(batch, "content_type")?;
1904 let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
1905 let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;
1906
1907 let text_is_binary = batch
1909 .schema()
1910 .field_with_name("text_payload")
1911 .is_ok_and(|f| f.data_type() == &DataType::LargeBinary);
1912
1913 let text_string_array = if !text_is_binary {
1914 Some(column_as::<LargeStringArray>(batch, "text_payload")?)
1915 } else {
1916 None
1917 };
1918 let text_binary_array = if text_is_binary {
1919 Some(column_as::<LargeBinaryArray>(batch, "text_payload")?)
1920 } else {
1921 None
1922 };
1923
1924 let step_array = state_array
1925 .column(0)
1926 .as_ref()
1927 .as_any()
1928 .downcast_ref::<Int32Array>()
1929 .ok_or_else(|| {
1930 LanceError::from(ArrowError::InvalidArgumentError(
1931 "step column has unexpected data type".to_string(),
1932 ))
1933 })?;
1934 let active_plan_array = state_array
1935 .column(1)
1936 .as_ref()
1937 .as_any()
1938 .downcast_ref::<StringArray>()
1939 .ok_or_else(|| {
1940 LanceError::from(ArrowError::InvalidArgumentError(
1941 "active_plan_id column has unexpected data type".to_string(),
1942 ))
1943 })?;
1944 let tokens_used_array = state_array
1945 .column(2)
1946 .as_ref()
1947 .as_any()
1948 .downcast_ref::<Int32Array>()
1949 .ok_or_else(|| {
1950 LanceError::from(ArrowError::InvalidArgumentError(
1951 "tokens_used column has unexpected data type".to_string(),
1952 ))
1953 })?;
1954 let custom_array = state_array
1955 .column(3)
1956 .as_ref()
1957 .as_any()
1958 .downcast_ref::<StringArray>()
1959 .ok_or_else(|| {
1960 LanceError::from(ArrowError::InvalidArgumentError(
1961 "custom column has unexpected data type".to_string(),
1962 ))
1963 })?;
1964
1965 let mut results = Vec::with_capacity(batch.num_rows());
1966 for row in 0..batch.num_rows() {
1967 let created_at = timestamp_from_micros(created_at_array.value(row), "created_at")?;
1968
1969 let state_metadata = if state_array.is_null(row) {
1970 None
1971 } else {
1972 Some(StateMetadata {
1973 step: if step_array.is_null(row) {
1974 None
1975 } else {
1976 Some(step_array.value(row))
1977 },
1978 active_plan_id: if active_plan_array.is_null(row) {
1979 None
1980 } else {
1981 Some(active_plan_array.value(row).to_string())
1982 },
1983 tokens_used: if tokens_used_array.is_null(row) {
1984 None
1985 } else {
1986 Some(tokens_used_array.value(row))
1987 },
1988 custom: if custom_array.is_null(row) {
1989 None
1990 } else {
1991 Some(custom_array.value(row).to_string())
1992 },
1993 })
1994 };
1995
1996 let text_payload = if text_is_binary {
1997 let arr = text_binary_array.unwrap();
1998 if arr.is_null(row) {
1999 None
2000 } else {
2001 Some(String::from_utf8_lossy(arr.value(row)).to_string())
2002 }
2003 } else {
2004 let arr = text_string_array.unwrap();
2005 if arr.is_null(row) {
2006 None
2007 } else {
2008 Some(arr.value(row).to_string())
2009 }
2010 };
2011
2012 let binary_payload = if binary_array.is_null(row) {
2013 None
2014 } else {
2015 Some(binary_array.value(row).to_vec())
2016 };
2017
2018 let embedding = if embedding_array.is_null(row) {
2019 None
2020 } else {
2021 Some(embedding_from_list(embedding_array, row)?)
2022 };
2023
2024 let role = if role_array.is_null(row) {
2025 return Err(LanceError::from(ArrowError::InvalidArgumentError(
2026 "role column contains null values".to_string(),
2027 )));
2028 } else {
2029 let role_values = role_array
2030 .values()
2031 .as_any()
2032 .downcast_ref::<StringArray>()
2033 .ok_or_else(|| {
2034 LanceError::from(ArrowError::InvalidArgumentError(
2035 "role dictionary values are not strings".to_string(),
2036 ))
2037 })?;
2038 let key = role_array.keys().value(row) as usize;
2039 role_values.value(key).to_string()
2040 };
2041
2042 let bot_id = bot_id_array.and_then(|arr| {
2043 if arr.is_null(row) {
2044 None
2045 } else {
2046 Some(arr.value(row).to_string())
2047 }
2048 });
2049
2050 let session_id = session_id_array.and_then(|arr| {
2051 if arr.is_null(row) {
2052 None
2053 } else {
2054 Some(arr.value(row).to_string())
2055 }
2056 });
2057
2058 let tenant = tenant_array.and_then(|arr| {
2059 if arr.is_null(row) {
2060 None
2061 } else {
2062 Some(arr.value(row).to_string())
2063 }
2064 });
2065
2066 let source = source_array.and_then(|arr| {
2067 if arr.is_null(row) {
2068 None
2069 } else {
2070 Some(arr.value(row).to_string())
2071 }
2072 });
2073
2074 let metadata = match metadata_array {
2075 Some(arr) if !arr.is_null(row) => {
2076 Some(serde_json::from_str(arr.value(row)).map_err(|err| {
2077 LanceError::from(ArrowError::InvalidArgumentError(format!(
2078 "invalid metadata JSON for record {}: {}",
2079 id_array.value(row),
2080 err
2081 )))
2082 })?)
2083 }
2084 _ => None,
2085 };
2086 let relationships = match relationships_array {
2087 Some(arr) if !arr.is_null(row) => relationships_from_list(arr, row)?,
2088 _ => Vec::new(),
2089 };
2090 let expires_at = optional_timestamp_from_array(expires_at_array, row, "expires_at")?;
2091 let retention_policy = optional_string_from_array(retention_policy_array, row);
2092 let lifecycle_status = optional_string_from_array(lifecycle_status_array, row)
2093 .unwrap_or_else(|| LIFECYCLE_ACTIVE.to_string());
2094 let retired_at = optional_timestamp_from_array(retired_at_array, row, "retired_at")?;
2095 let retired_reason = optional_string_from_array(retired_reason_array, row);
2096 let supersedes_id = optional_string_from_array(supersedes_id_array, row);
2097 let superseded_by_id = optional_string_from_array(superseded_by_id_array, row);
2098
2099 results.push(ContextRecord {
2100 id: id_array.value(row).to_string(),
2101 external_id: external_id_array.and_then(|arr| {
2102 if arr.is_null(row) {
2103 None
2104 } else {
2105 Some(arr.value(row).to_string())
2106 }
2107 }),
2108 run_id: run_id_array.value(row).to_string(),
2109 bot_id,
2110 session_id,
2111 tenant,
2112 source,
2113 created_at,
2114 role,
2115 state_metadata,
2116 metadata,
2117 relationships,
2118 expires_at,
2119 retention_policy,
2120 lifecycle_status,
2121 retired_at,
2122 retired_reason,
2123 supersedes_id,
2124 superseded_by_id,
2125 content_type: content_type_array.value(row).to_string(),
2126 text_payload,
2127 binary_payload,
2128 embedding,
2129 });
2130 }
2131
2132 Ok(results)
2133}
2134
2135fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
2136 let values = list.value(row);
2137 let float_array = values
2138 .as_ref()
2139 .as_any()
2140 .downcast_ref::<Float32Array>()
2141 .ok_or_else(|| {
2142 LanceError::from(ArrowError::InvalidArgumentError(
2143 "embedding column does not contain float32 values".to_string(),
2144 ))
2145 })?;
2146 let mut embedding = Vec::with_capacity(float_array.len());
2147 for idx in 0..float_array.len() {
2148 embedding.push(float_array.value(idx));
2149 }
2150 Ok(embedding)
2151}
2152
2153fn relationships_from_list(list: &ListArray, row: usize) -> LanceResult<Vec<Relationship>> {
2154 let values = list.value(row);
2155 let struct_array = values
2156 .as_ref()
2157 .as_any()
2158 .downcast_ref::<StructArray>()
2159 .ok_or_else(|| {
2160 LanceError::from(ArrowError::InvalidArgumentError(
2161 "relationships column does not contain struct values".to_string(),
2162 ))
2163 })?;
2164
2165 let target_id_array = struct_array
2166 .column(0)
2167 .as_ref()
2168 .as_any()
2169 .downcast_ref::<StringArray>()
2170 .ok_or_else(|| {
2171 LanceError::from(ArrowError::InvalidArgumentError(
2172 "relationships.target_id column has unexpected data type".to_string(),
2173 ))
2174 })?;
2175 let relation_array = struct_array
2176 .column(1)
2177 .as_ref()
2178 .as_any()
2179 .downcast_ref::<StringArray>()
2180 .ok_or_else(|| {
2181 LanceError::from(ArrowError::InvalidArgumentError(
2182 "relationships.relation column has unexpected data type".to_string(),
2183 ))
2184 })?;
2185 let weight_array = struct_array
2186 .column(2)
2187 .as_ref()
2188 .as_any()
2189 .downcast_ref::<Float32Array>()
2190 .ok_or_else(|| {
2191 LanceError::from(ArrowError::InvalidArgumentError(
2192 "relationships.weight column has unexpected data type".to_string(),
2193 ))
2194 })?;
2195
2196 let mut relationships = Vec::with_capacity(struct_array.len());
2197 for idx in 0..struct_array.len() {
2198 if struct_array.is_null(idx) {
2199 continue;
2200 }
2201 if target_id_array.is_null(idx) {
2202 return Err(LanceError::from(ArrowError::InvalidArgumentError(
2203 "relationships.target_id contains null values".to_string(),
2204 )));
2205 }
2206 if relation_array.is_null(idx) {
2207 return Err(LanceError::from(ArrowError::InvalidArgumentError(
2208 "relationships.relation contains null values".to_string(),
2209 )));
2210 }
2211
2212 relationships.push(Relationship {
2213 target_id: target_id_array.value(idx).to_string(),
2214 relation: relation_array.value(idx).to_string(),
2215 weight: if weight_array.is_null(idx) {
2216 None
2217 } else {
2218 Some(weight_array.value(idx))
2219 },
2220 });
2221 }
2222 Ok(relationships)
2223}
2224
2225fn timestamp_from_micros(value: i64, column: &str) -> LanceResult<DateTime<Utc>> {
2226 DateTime::from_timestamp_micros(value).ok_or_else(|| {
2227 LanceError::from(ArrowError::InvalidArgumentError(format!(
2228 "invalid timestamp value {value} in column '{column}'"
2229 )))
2230 })
2231}
2232
2233fn optional_timestamp_from_array(
2234 array: Option<&TimestampMicrosecondArray>,
2235 row: usize,
2236 column: &str,
2237) -> LanceResult<Option<DateTime<Utc>>> {
2238 let Some(array) = array else {
2239 return Ok(None);
2240 };
2241 if array.is_null(row) {
2242 Ok(None)
2243 } else {
2244 timestamp_from_micros(array.value(row), column).map(Some)
2245 }
2246}
2247
2248fn optional_string_from_array(array: Option<&StringArray>, row: usize) -> Option<String> {
2249 array.and_then(|arr| {
2250 if arr.is_null(row) {
2251 None
2252 } else {
2253 Some(arr.value(row).to_string())
2254 }
2255 })
2256}
2257
/// Euclidean (L2) distance between two vectors.
///
/// Elements are paired with `zip`, so when the slices differ in length the trailing
/// elements of the longer one are ignored; pass equal-length vectors for a meaningful result.
fn l2_distance(left: &[f32], right: &[f32]) -> f32 {
    let sum_of_squares: f32 = left
        .iter()
        .zip(right.iter())
        .map(|(&a, &b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();
    sum_of_squares.sqrt()
}
2268
2269fn validate_embedding_dim(embedding_dim: i32) -> LanceResult<()> {
2270 if embedding_dim <= 0 {
2271 return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
2272 "embedding_dim must be positive, got {embedding_dim}"
2273 ))));
2274 }
2275 Ok(())
2276}
2277
2278fn validate_query_dimension(query: &[f32], embedding_dim: i32) -> LanceResult<()> {
2279 if query.len() != embedding_dim as usize {
2280 return Err(ArrowError::InvalidArgumentError(format!(
2281 "query length {} does not match embedding dimension {}",
2282 query.len(),
2283 embedding_dim
2284 ))
2285 .into());
2286 }
2287 Ok(())
2288}
2289
2290fn unique_query_terms(text: &str) -> Vec<String> {
2291 let mut seen = HashSet::new();
2292 tokenize_for_retrieval(text)
2293 .into_iter()
2294 .filter(|term| seen.insert(term.clone()))
2295 .collect()
2296}
2297
/// Splits `text` into lowercase alphanumeric terms for lexical retrieval.
///
/// Any character for which `char::is_alphanumeric` is false acts as a separator, and
/// empty pieces are discarded, so `"POLICY-123"` yields `["policy", "123"]`.
/// Lowercasing is applied per character (`char::to_lowercase`) rather than with
/// `str::to_lowercase`, so no context-sensitive rules (e.g. Greek final sigma) apply.
fn tokenize_for_retrieval(text: &str) -> Vec<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(|piece| piece.chars().flat_map(char::to_lowercase).collect::<String>())
        .collect()
}
2316
2317fn lexical_score(query_terms: &[String], text: Option<&str>) -> Option<f32> {
2318 let text = text?;
2319 if query_terms.is_empty() {
2320 return None;
2321 }
2322
2323 let payload_terms: HashSet<String> = tokenize_for_retrieval(text).into_iter().collect();
2324 if payload_terms.is_empty() {
2325 return None;
2326 }
2327
2328 let matched_terms = query_terms
2329 .iter()
2330 .filter(|term| payload_terms.contains(*term))
2331 .count();
2332 if matched_terms == 0 {
2333 return None;
2334 }
2335
2336 Some(matched_terms as f32 / query_terms.len() as f32)
2337}
2338
2339fn add_retrieve_channel(
2340 candidates: &mut HashMap<String, RetrieveResult>,
2341 record: &ContextRecord,
2342 rank: usize,
2343 channel: &str,
2344 vector_distance: Option<f32>,
2345 text_score: Option<f32>,
2346) {
2347 let candidate = candidates
2348 .entry(record.id.clone())
2349 .or_insert_with(|| RetrieveResult {
2350 record: record.clone(),
2351 score: 0.0,
2352 vector_distance: None,
2353 text_score: None,
2354 matched_channels: Vec::new(),
2355 });
2356 candidate.score += 1.0 / (RRF_K + rank as f32);
2357 if let Some(distance) = vector_distance {
2358 candidate.vector_distance = Some(distance);
2359 }
2360 if let Some(score) = text_score {
2361 candidate.text_score = Some(score);
2362 }
2363 if !candidate
2364 .matched_channels
2365 .iter()
2366 .any(|existing| existing == channel)
2367 {
2368 candidate.matched_channels.push(channel.to_string());
2369 }
2370}
2371
2372fn compare_retrieve_results(left: &RetrieveResult, right: &RetrieveResult) -> Ordering {
2373 right
2374 .score
2375 .total_cmp(&left.score)
2376 .then_with(|| compare_optional_distance(left.vector_distance, right.vector_distance))
2377 .then_with(|| compare_optional_score(left.text_score, right.text_score))
2378 .then_with(|| left.record.id.cmp(&right.record.id))
2379}
2380
/// Orders optional distances ascending (closer first), with `None` after any `Some`.
///
/// Present values are compared with `f32::total_cmp`.
fn compare_optional_distance(left: Option<f32>, right: Option<f32>) -> Ordering {
    match left.zip(right) {
        Some((a, b)) => a.total_cmp(&b),
        // At least one side is missing: a present value sorts before an absent one.
        None => left.is_none().cmp(&right.is_none()),
    }
}
2389
/// Orders optional scores descending (higher first), with `None` after any `Some`.
///
/// Present values are compared with `f32::total_cmp`.
fn compare_optional_score(left: Option<f32>, right: Option<f32>) -> Ordering {
    match left.zip(right) {
        Some((a, b)) => b.total_cmp(&a),
        // At least one side is missing: a present value sorts before an absent one.
        None => left.is_none().cmp(&right.is_none()),
    }
}
2398
2399fn embedding_dim_from_schema(schema: &Schema) -> LanceResult<i32> {
2400 let field = schema
2401 .field_with_name("embedding")
2402 .map_err(LanceError::from)?;
2403 let DataType::FixedSizeList(item_field, embedding_dim) = field.data_type() else {
2404 return Err(LanceError::from(ArrowError::InvalidArgumentError(
2405 "embedding column must be a FixedSizeList<Float32>".to_string(),
2406 )));
2407 };
2408 if item_field.data_type() != &DataType::Float32 {
2409 return Err(LanceError::from(ArrowError::InvalidArgumentError(
2410 "embedding column must contain Float32 values".to_string(),
2411 )));
2412 }
2413 validate_embedding_dim(*embedding_dim)?;
2414 Ok(*embedding_dim)
2415}
2416
2417fn distance_metric_from_schema(schema: &Schema) -> LanceResult<DistanceMetric> {
2422 match schema.metadata.get(DISTANCE_METRIC_METADATA_KEY) {
2423 Some(value) => DistanceMetric::parse(value),
2424 None => Ok(DistanceMetric::default()),
2425 }
2426}
2427
/// Inner product of two vectors.
///
/// Elements are paired with `zip`, so when the lengths differ the trailing elements of
/// the longer slice are ignored.
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right.iter()).map(|(&a, &b)| a * b);
    products.sum()
}
2435
2436fn cosine_distance(left: &[f32], right: &[f32]) -> f32 {
2441 let dot = dot_product(left, right);
2442 let left_norm = dot_product(left, left).sqrt();
2443 let right_norm = dot_product(right, right).sqrt();
2444 if left_norm == 0.0 || right_norm == 0.0 {
2445 return 1.0;
2446 }
2447 1.0 - (dot / (left_norm * right_norm))
2448}
2449
2450fn dot_distance(left: &[f32], right: &[f32]) -> f32 {
2453 -dot_product(left, right)
2454}
2455
2456fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
2457where
2458 A: Array + 'static,
2459{
2460 let column = batch.column_by_name(name).ok_or_else(|| {
2461 LanceError::from(ArrowError::InvalidArgumentError(format!(
2462 "column '{name}' not found"
2463 )))
2464 })?;
2465 column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
2466 LanceError::from(ArrowError::InvalidArgumentError(format!(
2467 "column '{name}' has unexpected data type"
2468 )))
2469 })
2470}
2471
2472fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A>
2473where
2474 A: Array + 'static,
2475{
2476 batch
2477 .column_by_name(name)
2478 .and_then(|col| col.as_ref().as_any().downcast_ref::<A>())
2479}
2480
2481#[cfg(test)]
2482mod tests {
2483 use super::*;
2484 use crate::serde::CONTENT_TYPE_TEXT;
2485 use chrono::{Duration as ChronoDuration, Utc};
2486 use tempfile::TempDir;
2487
2488 fn make_embedding_with_dim(dim: usize, pivot: f32) -> Vec<f32> {
2489 let mut values = vec![0.0; dim];
2490 if !values.is_empty() {
2491 values[0] = pivot;
2492 }
2493 values
2494 }
2495
2496 fn make_embedding(pivot: f32) -> Vec<f32> {
2497 make_embedding_with_dim(DEFAULT_EMBEDDING_DIM as usize, pivot)
2498 }
2499
2500 fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
2501 ContextRecord {
2502 id: id.to_string(),
2503 external_id: None,
2504 run_id: format!("run-{id}"),
2505 bot_id: None,
2506 session_id: None,
2507 tenant: None,
2508 source: None,
2509 created_at: Utc::now(),
2510 role: "user".to_string(),
2511 state_metadata: Some(StateMetadata {
2512 step: Some(1),
2513 active_plan_id: Some("plan".to_string()),
2514 tokens_used: Some(10),
2515 custom: None,
2516 }),
2517 metadata: None,
2518 relationships: Vec::new(),
2519 expires_at: None,
2520 retention_policy: None,
2521 lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
2522 retired_at: None,
2523 retired_reason: None,
2524 supersedes_id: None,
2525 superseded_by_id: None,
2526 content_type: CONTENT_TYPE_TEXT.to_string(),
2527 text_payload: Some(format!("payload-{id}")),
2528 binary_payload: None,
2529 embedding: Some(make_embedding(embedding_pivot)),
2530 }
2531 }
2532
2533 #[test]
2534 fn search_orders_by_distance() {
2535 let dir = TempDir::new().unwrap();
2536 let uri = dir.path().to_string_lossy().to_string();
2537 let runtime = tokio::runtime::Runtime::new().unwrap();
2538 runtime.block_on(async {
2539 let mut store = ContextStore::open(&uri).await.unwrap();
2540 let first = text_record("a", 0.0);
2541 let second = text_record("b", 1.0);
2542 store.add(&[first.clone(), second.clone()]).await.unwrap();
2543
2544 let query = make_embedding(1.0);
2545 let results = store.search(&query, Some(2)).await.unwrap();
2546
2547 assert_eq!(results.len(), 2);
2548 assert_eq!(results[0].record.id, second.id);
2549 assert!(
2550 results[0].distance <= results[1].distance,
2551 "results not ordered by distance: {:?}",
2552 results
2553 );
2554 });
2555 }
2556
2557 #[test]
2558 fn search_validates_query_length() {
2559 let dir = TempDir::new().unwrap();
2560 let uri = dir.path().to_string_lossy().to_string();
2561 let runtime = tokio::runtime::Runtime::new().unwrap();
2562 runtime.block_on(async {
2563 let store = ContextStore::open(&uri).await.unwrap();
2564 let err = store.search(&[0.0_f32], None).await.unwrap_err();
2565 let message = err.to_string();
2566 assert!(
2567 message.contains("embedding dimension"),
2568 "unexpected error message: {message}"
2569 );
2570 });
2571 }
2572
2573 fn make_embedding2(x0: f32, x1: f32) -> Vec<f32> {
2574 let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
2575 values[0] = x0;
2576 values[1] = x1;
2577 values
2578 }
2579
2580 fn text_record_with(id: &str, embedding: Vec<f32>) -> ContextRecord {
2581 let mut record = text_record(id, 0.0);
2582 record.embedding = Some(embedding);
2583 record
2584 }
2585
2586 #[test]
2587 fn distance_metric_parse_and_math() {
2588 assert_eq!(DistanceMetric::parse("l2").unwrap(), DistanceMetric::L2);
2589 assert_eq!(DistanceMetric::parse("L2").unwrap(), DistanceMetric::L2);
2590 assert_eq!(
2591 DistanceMetric::parse("cosine").unwrap(),
2592 DistanceMetric::Cosine
2593 );
2594 assert_eq!(DistanceMetric::parse("DOT").unwrap(), DistanceMetric::Dot);
2595 assert!(DistanceMetric::parse("manhattan").is_err());
2596 assert_eq!(DistanceMetric::default(), DistanceMetric::L2);
2597
2598 let a = [1.0_f32, 0.0];
2599 let b = [1.0_f32, 1.0];
2600 assert!((DistanceMetric::L2.distance(&a, &b) - 1.0).abs() < 1e-6);
2602 assert!((DistanceMetric::Cosine.distance(&a, &b) - (1.0 - 0.707_106_77)).abs() < 1e-5);
2604 assert!((DistanceMetric::Dot.distance(&a, &b) + 1.0).abs() < 1e-6);
2606 let zero = [0.0_f32, 0.0];
2608 assert!((DistanceMetric::Cosine.distance(&a, &zero) - 1.0).abs() < 1e-6);
2609 }
2610
2611 #[test]
2612 fn search_metric_changes_ranking() {
2613 let runtime = tokio::runtime::Runtime::new().unwrap();
2614 runtime.block_on(async {
2615 let query = make_embedding2(1.0, 0.0);
2617 let aligned = make_embedding2(10.0, 0.0);
2620 let near = make_embedding2(1.0, 1.0);
2622
2623 let l2_dir = TempDir::new().unwrap();
2625 let mut l2_store = ContextStore::open(&l2_dir.path().to_string_lossy())
2626 .await
2627 .unwrap();
2628 l2_store
2629 .add(&[
2630 text_record_with("aligned", aligned.clone()),
2631 text_record_with("near", near.clone()),
2632 ])
2633 .await
2634 .unwrap();
2635 let l2_results = l2_store.search(&query, Some(2)).await.unwrap();
2636 assert_eq!(l2_results[0].record.id, "near");
2637
2638 let cos_dir = TempDir::new().unwrap();
2640 let cos_opts = ContextStoreOptions {
2641 distance_metric: Some(DistanceMetric::Cosine),
2642 ..Default::default()
2643 };
2644 let mut cos_store =
2645 ContextStore::open_with_options(&cos_dir.path().to_string_lossy(), cos_opts)
2646 .await
2647 .unwrap();
2648 cos_store
2649 .add(&[
2650 text_record_with("aligned", aligned.clone()),
2651 text_record_with("near", near.clone()),
2652 ])
2653 .await
2654 .unwrap();
2655 let cos_results = cos_store.search(&query, Some(2)).await.unwrap();
2656 assert_eq!(cos_results[0].record.id, "aligned");
2657
2658 let dot_dir = TempDir::new().unwrap();
2660 let dot_opts = ContextStoreOptions {
2661 distance_metric: Some(DistanceMetric::Dot),
2662 ..Default::default()
2663 };
2664 let mut dot_store =
2665 ContextStore::open_with_options(&dot_dir.path().to_string_lossy(), dot_opts)
2666 .await
2667 .unwrap();
2668 dot_store
2669 .add(&[
2670 text_record_with("aligned", aligned),
2671 text_record_with("near", near),
2672 ])
2673 .await
2674 .unwrap();
2675 let dot_results = dot_store.search(&query, Some(2)).await.unwrap();
2676 assert_eq!(dot_results[0].record.id, "aligned");
2677 });
2678 }
2679
2680 #[test]
2681 fn distance_metric_persists_across_reopen() {
2682 let runtime = tokio::runtime::Runtime::new().unwrap();
2683 runtime.block_on(async {
2684 let dir = TempDir::new().unwrap();
2685 let uri = dir.path().to_string_lossy().to_string();
2686 let query = make_embedding2(1.0, 0.0);
2687 let aligned = make_embedding2(10.0, 0.0);
2688 let near = make_embedding2(1.0, 1.0);
2689
2690 {
2692 let opts = ContextStoreOptions {
2693 distance_metric: Some(DistanceMetric::Cosine),
2694 ..Default::default()
2695 };
2696 let mut store = ContextStore::open_with_options(&uri, opts).await.unwrap();
2697 store
2698 .add(&[
2699 text_record_with("aligned", aligned.clone()),
2700 text_record_with("near", near.clone()),
2701 ])
2702 .await
2703 .unwrap();
2704 }
2705
2706 let store = ContextStore::open(&uri).await.unwrap();
2709 assert_eq!(store.distance_metric, DistanceMetric::Cosine);
2710 let results = store.search(&query, Some(2)).await.unwrap();
2711 assert_eq!(results[0].record.id, "aligned");
2712 });
2713 }
2714
2715 #[test]
2716 fn distance_metric_mismatch_errors() {
2717 let runtime = tokio::runtime::Runtime::new().unwrap();
2718 runtime.block_on(async {
2719 let dir = TempDir::new().unwrap();
2720 let uri = dir.path().to_string_lossy().to_string();
2721 ContextStore::open_with_options(
2722 &uri,
2723 ContextStoreOptions {
2724 distance_metric: Some(DistanceMetric::Cosine),
2725 ..Default::default()
2726 },
2727 )
2728 .await
2729 .unwrap();
2730
2731 let result = ContextStore::open_with_options(
2732 &uri,
2733 ContextStoreOptions {
2734 distance_metric: Some(DistanceMetric::Dot),
2735 ..Default::default()
2736 },
2737 )
2738 .await;
2739 let err = match result {
2740 Ok(_) => panic!("expected a distance-metric mismatch error"),
2741 Err(err) => err,
2742 };
2743 assert!(
2744 err.to_string().contains("distance metric"),
2745 "unexpected error: {err}"
2746 );
2747 });
2748 }
2749
2750 #[test]
2751 fn distance_metric_from_schema_defaults_l2_when_absent() {
2752 let schema = Schema::new(vec![Field::new("id", DataType::Utf8, false)]);
2754 assert_eq!(
2755 distance_metric_from_schema(&schema).unwrap(),
2756 DistanceMetric::L2
2757 );
2758 }
2759
2760 #[test]
2761 fn retrieve_fuses_text_and_vector_channels() {
2762 let dir = TempDir::new().unwrap();
2763 let uri = dir.path().to_string_lossy().to_string();
2764 let runtime = tokio::runtime::Runtime::new().unwrap();
2765 runtime.block_on(async {
2766 let mut store = ContextStore::open(&uri).await.unwrap();
2767 let mut semantic_near = text_record("semantic-near", 0.0);
2768 semantic_near.text_payload = Some("general rollout risk guidance".to_string());
2769 let mut exact_policy = text_record("exact-policy", 1.0);
2770 exact_policy.text_payload = Some("POLICY-123 blocks service-a rollouts".to_string());
2771
2772 store
2773 .add(&[semantic_near.clone(), exact_policy.clone()])
2774 .await
2775 .unwrap();
2776
2777 let query = make_embedding(0.0);
2778 let results = store
2779 .retrieve_filtered_with_options(
2780 Some("POLICY-123 service-a"),
2781 Some(&query),
2782 Some(2),
2783 None,
2784 LifecycleQueryOptions::default(),
2785 )
2786 .await
2787 .unwrap();
2788
2789 assert_eq!(results.len(), 2);
2790 assert_eq!(results[0].record.id, exact_policy.id);
2791 assert!(results[0].score > results[1].score);
2792 assert!(results[0].vector_distance.is_some());
2793 assert_eq!(results[0].text_score, Some(1.0));
2794 assert_eq!(results[0].matched_channels, ["vector", "text"]);
2795 });
2796 }
2797
2798 #[test]
2799 fn custom_embedding_dimension_round_trips_add_search_and_reopen() {
2800 let dir = TempDir::new().unwrap();
2801 let uri = dir.path().to_string_lossy().to_string();
2802 let runtime = tokio::runtime::Runtime::new().unwrap();
2803 runtime.block_on(async {
2804 let options = ContextStoreOptions {
2805 embedding_dim: Some(3),
2806 ..Default::default()
2807 };
2808 let mut store = ContextStore::open_with_options(&uri, options)
2809 .await
2810 .unwrap();
2811 assert_eq!(store.embedding_dim(), 3);
2812
2813 let mut first = text_record("custom-a", 0.0);
2814 first.embedding = Some(make_embedding_with_dim(3, 0.0));
2815 let mut second = text_record("custom-b", 0.0);
2816 second.embedding = Some(make_embedding_with_dim(3, 1.0));
2817 store.add(&[first.clone(), second.clone()]).await.unwrap();
2818
2819 let query = make_embedding_with_dim(3, 1.0);
2820 let results = store.search(&query, Some(2)).await.unwrap();
2821 assert_eq!(results[0].record.id, second.id);
2822
2823 let reopened = ContextStore::open(&uri).await.unwrap();
2824 assert_eq!(reopened.embedding_dim(), 3);
2825 let results = reopened.search(&query, Some(1)).await.unwrap();
2826 assert_eq!(results[0].record.id, second.id);
2827
2828 let err = reopened
2829 .search(&make_embedding(1.0), None)
2830 .await
2831 .unwrap_err();
2832 assert!(
2833 err.to_string().contains("embedding dimension 3"),
2834 "unexpected error message: {err}"
2835 );
2836 });
2837 }
2838
2839 #[test]
2840 fn existing_default_dimension_dataset_opens_without_options() {
2841 let dir = TempDir::new().unwrap();
2842 let uri = dir.path().to_string_lossy().to_string();
2843 let runtime = tokio::runtime::Runtime::new().unwrap();
2844 runtime.block_on(async {
2845 let mut store = ContextStore::open(&uri).await.unwrap();
2846 assert_eq!(store.embedding_dim(), DEFAULT_EMBEDDING_DIM);
2847 store.add(&[text_record("default-dim", 0.0)]).await.unwrap();
2848 drop(store);
2849
2850 let reopened = ContextStore::open(&uri).await.unwrap();
2851 assert_eq!(reopened.embedding_dim(), DEFAULT_EMBEDDING_DIM);
2852 reopened
2853 .search(&make_embedding(0.0), Some(1))
2854 .await
2855 .unwrap();
2856 });
2857 }
2858
2859 #[test]
2860 fn opening_existing_dataset_rejects_mismatched_requested_dimension() {
2861 let dir = TempDir::new().unwrap();
2862 let uri = dir.path().to_string_lossy().to_string();
2863 let runtime = tokio::runtime::Runtime::new().unwrap();
2864 runtime.block_on(async {
2865 let options = ContextStoreOptions {
2866 embedding_dim: Some(3),
2867 ..Default::default()
2868 };
2869 ContextStore::open_with_options(&uri, options)
2870 .await
2871 .unwrap();
2872
2873 let mismatched = ContextStoreOptions {
2874 embedding_dim: Some(4),
2875 ..Default::default()
2876 };
2877 let err = match ContextStore::open_with_options(&uri, mismatched).await {
2878 Ok(_) => panic!("expected mismatched embedding dimension to fail"),
2879 Err(err) => err,
2880 };
2881 assert!(
2882 err.to_string()
2883 .contains("does not match requested dimension 4"),
2884 "unexpected error message: {err}"
2885 );
2886 });
2887 }
2888
2889 #[test]
2890 fn list_hides_expired_and_retired_records_by_default() {
2891 let dir = TempDir::new().unwrap();
2892 let uri = dir.path().to_string_lossy().to_string();
2893 let runtime = tokio::runtime::Runtime::new().unwrap();
2894 runtime.block_on(async {
2895 let mut store = ContextStore::open(&uri).await.unwrap();
2896 let active = text_record("active", 0.0);
2897 let mut expired = text_record("expired", 0.0);
2898 expired.expires_at = Some(Utc::now() - ChronoDuration::minutes(1));
2899 let mut superseded = text_record("superseded", 0.0);
2900 superseded.lifecycle_status = "superseded".to_string();
2901 superseded.retired_reason = Some("replaced by newer fact".to_string());
2902 superseded.superseded_by_id = Some("active".to_string());
2903
2904 store
2905 .add(&[active.clone(), expired.clone(), superseded.clone()])
2906 .await
2907 .unwrap();
2908
2909 let visible = store.list(None, None).await.unwrap();
2910 assert_eq!(visible.len(), 1);
2911 assert_eq!(visible[0].id, active.id);
2912
2913 let all = store
2914 .list_with_options(None, None, LifecycleQueryOptions::new(true, true))
2915 .await
2916 .unwrap();
2917 assert_eq!(all.len(), 3);
2918 let expired_roundtrip = all.iter().find(|record| record.id == expired.id).unwrap();
2919 assert_eq!(
2920 expired_roundtrip
2921 .expires_at
2922 .map(|value| value.timestamp_micros()),
2923 expired.expires_at.map(|value| value.timestamp_micros())
2924 );
2925 let superseded_roundtrip = all
2926 .iter()
2927 .find(|record| record.id == superseded.id)
2928 .unwrap();
2929 assert_eq!(superseded_roundtrip.lifecycle_status, "superseded");
2930 assert_eq!(
2931 superseded_roundtrip.superseded_by_id.as_deref(),
2932 Some("active")
2933 );
2934 });
2935 }
2936
2937 #[test]
2938 fn list_hides_records_superseded_by_newer_pointer() {
2939 let dir = TempDir::new().unwrap();
2940 let uri = dir.path().to_string_lossy().to_string();
2941 let runtime = tokio::runtime::Runtime::new().unwrap();
2942 runtime.block_on(async {
2943 let mut store = ContextStore::open(&uri).await.unwrap();
2944 let old = text_record("old", 0.0);
2945 let mut replacement = text_record("new", 1.0);
2946 replacement.supersedes_id = Some(old.id.clone());
2947 store
2948 .add(&[old.clone(), replacement.clone()])
2949 .await
2950 .unwrap();
2951
2952 let visible = store.list(None, None).await.unwrap();
2953 assert_eq!(visible.len(), 1);
2954 assert_eq!(visible[0].id, replacement.id);
2955
2956 let history = store
2957 .list_with_options(None, None, LifecycleQueryOptions::new(false, true))
2958 .await
2959 .unwrap();
2960 assert_eq!(history.len(), 2);
2961 assert!(history.iter().any(|record| record.id == old.id));
2962 assert!(history.iter().any(|record| record.id == replacement.id));
2963 });
2964 }
2965
2966 #[test]
2967 fn search_filters_lifecycle_before_ranking() {
2968 let dir = TempDir::new().unwrap();
2969 let uri = dir.path().to_string_lossy().to_string();
2970 let runtime = tokio::runtime::Runtime::new().unwrap();
2971 runtime.block_on(async {
2972 let mut store = ContextStore::open(&uri).await.unwrap();
2973 let active = text_record("active", 1.0);
2974 let mut expired_better_match = text_record("expired", 0.0);
2975 expired_better_match.expires_at = Some(Utc::now() - ChronoDuration::minutes(1));
2976 store
2977 .add(&[active.clone(), expired_better_match.clone()])
2978 .await
2979 .unwrap();
2980
2981 let query = make_embedding(0.0);
2982 let visible = store.search(&query, Some(1)).await.unwrap();
2983 assert_eq!(visible.len(), 1);
2984 assert_eq!(visible[0].record.id, active.id);
2985
2986 let all = store
2987 .search_with_options(&query, Some(1), LifecycleQueryOptions::new(true, false))
2988 .await
2989 .unwrap();
2990 assert_eq!(all.len(), 1);
2991 assert_eq!(all[0].record.id, expired_better_match.id);
2992 });
2993 }
2994
2995 #[test]
2996 fn external_id_roundtrips_and_supports_lookup() {
2997 let dir = TempDir::new().unwrap();
2998 let uri = dir.path().to_string_lossy().to_string();
2999 let runtime = tokio::runtime::Runtime::new().unwrap();
3000 runtime.block_on(async {
3001 let mut store = ContextStore::open(&uri).await.unwrap();
3002 let mut record = text_record("a", 0.0);
3003 record.external_id = Some("doc-123#chunk-1".to_string());
3004 store.add(std::slice::from_ref(&record)).await.unwrap();
3005
3006 let by_external_id = store
3007 .get_by_external_id("doc-123#chunk-1")
3008 .await
3009 .unwrap()
3010 .unwrap();
3011 assert_eq!(by_external_id.id, record.id);
3012 assert_eq!(by_external_id.external_id, record.external_id);
3013
3014 let by_id = store.get_by_id(&record.id).await.unwrap().unwrap();
3015 assert_eq!(by_id.external_id.as_deref(), Some("doc-123#chunk-1"));
3016
3017 let missing = store.get_by_external_id("missing").await.unwrap();
3018 assert!(missing.is_none());
3019 });
3020 }
3021
3022 #[test]
3023 fn upsert_by_external_id_inserts_then_replaces_visible_record() {
3024 let dir = TempDir::new().unwrap();
3025 let uri = dir.path().to_string_lossy().to_string();
3026 let runtime = tokio::runtime::Runtime::new().unwrap();
3027 runtime.block_on(async {
3028 let mut store = ContextStore::open(&uri).await.unwrap();
3029
3030 let mut first = text_record("first", 0.0);
3031 first.external_id = Some("doc-123#chunk-1".to_string());
3032 let inserted = store.upsert_by_external_id(first.clone()).await.unwrap();
3033 assert!(inserted.inserted);
3034 assert_eq!(inserted.replaced_id, None);
3035 assert_eq!(inserted.record.id, first.id);
3036
3037 let mut replacement = text_record("replacement", 1.0);
3038 replacement.external_id = first.external_id.clone();
3039 let replaced = store
3040 .upsert_by_external_id(replacement.clone())
3041 .await
3042 .unwrap();
3043 assert!(!replaced.inserted);
3044 assert_eq!(replaced.replaced_id.as_deref(), Some(first.id.as_str()));
3045 assert_eq!(
3046 replaced.record.supersedes_id.as_deref(),
3047 Some(first.id.as_str())
3048 );
3049
3050 let visible = store.list(None, None).await.unwrap();
3051 assert_eq!(visible.len(), 1);
3052 assert_eq!(visible[0].id, replacement.id);
3053
3054 let by_external_id = store
3055 .get_by_external_id("doc-123#chunk-1")
3056 .await
3057 .unwrap()
3058 .unwrap();
3059 assert_eq!(by_external_id.id, replacement.id);
3060
3061 let history = store
3062 .list_with_options(None, None, LifecycleQueryOptions::new(false, true))
3063 .await
3064 .unwrap();
3065 assert_eq!(history.len(), 2);
3066 assert!(history.iter().any(|record| record.id == first.id));
3067 assert!(history.iter().any(|record| record.id == replacement.id));
3068 });
3069 }
3070
3071 #[test]
3072 fn update_by_external_id_patches_mutable_fields_and_preserves_payload() {
3073 let dir = TempDir::new().unwrap();
3074 let uri = dir.path().to_string_lossy().to_string();
3075 let runtime = tokio::runtime::Runtime::new().unwrap();
3076 runtime.block_on(async {
3077 let mut store = ContextStore::open(&uri).await.unwrap();
3078
3079 let mut record = text_record("stable", 0.0);
3080 record.external_id = Some("doc-123#chunk-1".to_string());
3081 record.metadata = Some(serde_json::json!({"revision": 1}));
3082 store.add(std::slice::from_ref(&record)).await.unwrap();
3083
3084 let patch = RecordPatch {
3085 bot_id: Some("bot-a".to_string()),
3086 session_id: Some("session-a".to_string()),
3087 metadata: Some(serde_json::json!({"revision": 2, "confidence": 0.9})),
3088 relationships: Some(vec![Relationship {
3089 target_id: "doc-123".to_string(),
3090 relation: "derived_from".to_string(),
3091 weight: None,
3092 }]),
3093 ..Default::default()
3094 };
3095 let updated = store
3096 .update_by_external_id("doc-123#chunk-1", patch)
3097 .await
3098 .unwrap()
3099 .unwrap();
3100
3101 assert_eq!(updated.replaced_id, record.id);
3102 assert_ne!(updated.record.id, record.id);
3103 assert_eq!(updated.record.external_id, record.external_id);
3104 assert_eq!(updated.record.text_payload, record.text_payload);
3105 assert_eq!(updated.record.embedding, record.embedding);
3106 assert_eq!(updated.record.bot_id.as_deref(), Some("bot-a"));
3107 assert_eq!(updated.record.session_id.as_deref(), Some("session-a"));
3108 assert_eq!(
3109 updated.record.metadata,
3110 Some(serde_json::json!({"revision": 2, "confidence": 0.9}))
3111 );
3112 assert_eq!(updated.record.relationships.len(), 1);
3113 assert_eq!(
3114 updated.record.supersedes_id.as_deref(),
3115 Some(record.id.as_str())
3116 );
3117
3118 let visible = store
3119 .get_by_external_id("doc-123#chunk-1")
3120 .await
3121 .unwrap()
3122 .unwrap();
3123 assert_eq!(visible.id, updated.record.id);
3124
3125 let history = store
3126 .list_with_options(None, None, LifecycleQueryOptions::new(false, true))
3127 .await
3128 .unwrap();
3129 assert_eq!(history.len(), 2);
3130 assert!(history.iter().any(|item| item.id == record.id));
3131 assert!(history.iter().any(|item| item.id == updated.record.id));
3132 });
3133 }
3134
3135 #[test]
3136 fn deferred_embedding_patch_makes_raw_record_searchable() {
3137 let dir = TempDir::new().unwrap();
3138 let uri = dir.path().to_string_lossy().to_string();
3139 let runtime = tokio::runtime::Runtime::new().unwrap();
3140 runtime.block_on(async {
3141 let mut store = ContextStore::open(&uri).await.unwrap();
3142
3143 let mut by_ext = text_record("raw-ext", 0.0);
3145 by_ext.embedding = None;
3146 by_ext.external_id = Some("doc-1#chunk-1".to_string());
3147 let mut by_id = text_record("raw-id", 0.0);
3148 by_id.embedding = None;
3149 by_id.external_id = None;
3150 store.add(&[by_ext.clone(), by_id.clone()]).await.unwrap();
3151
3152 let query = make_embedding(1.0);
3154 assert!(store.search(&query, Some(10)).await.unwrap().is_empty());
3155
3156 let enriched_ext = store
3158 .update_by_external_id(
3159 "doc-1#chunk-1",
3160 RecordPatch {
3161 embedding: Some(make_embedding(1.0)),
3162 ..Default::default()
3163 },
3164 )
3165 .await
3166 .unwrap()
3167 .unwrap();
3168 assert_eq!(enriched_ext.record.embedding, Some(make_embedding(1.0)));
3169 assert_eq!(enriched_ext.record.text_payload, by_ext.text_payload);
3171
3172 let enriched_id = store
3174 .update_by_id(
3175 &by_id.id,
3176 RecordPatch {
3177 embedding: Some(make_embedding(0.0)),
3178 ..Default::default()
3179 },
3180 )
3181 .await
3182 .unwrap()
3183 .unwrap();
3184 assert_eq!(enriched_id.record.embedding, Some(make_embedding(0.0)));
3185
3186 let results = store.search(&query, Some(10)).await.unwrap();
3188 let ids: Vec<&str> = results.iter().map(|r| r.record.id.as_str()).collect();
3189 assert!(ids.contains(&enriched_ext.record.id.as_str()));
3190 assert!(ids.contains(&enriched_id.record.id.as_str()));
3191 assert_eq!(results[0].record.id, enriched_ext.record.id);
3193 });
3194 }
3195
3196 #[test]
3197 fn relationships_roundtrip_and_support_related_lookup() {
3198 let dir = TempDir::new().unwrap();
3199 let uri = dir.path().to_string_lossy().to_string();
3200 let runtime = tokio::runtime::Runtime::new().unwrap();
3201 runtime.block_on(async {
3202 let mut store = ContextStore::open(&uri).await.unwrap();
3203 let mut related = text_record("related", 0.0);
3204 related.relationships = vec![
3205 Relationship {
3206 target_id: "doc-1#chunk-1".to_string(),
3207 relation: "cites".to_string(),
3208 weight: Some(0.75),
3209 },
3210 Relationship {
3211 target_id: "service-a".to_string(),
3212 relation: "mentions".to_string(),
3213 weight: None,
3214 },
3215 ];
3216 let unrelated = text_record("unrelated", 1.0);
3217 store.add(&[related.clone(), unrelated]).await.unwrap();
3218
3219 let listed = store.list(None, None).await.unwrap();
3220 let roundtrip = listed
3221 .iter()
3222 .find(|record| record.id == related.id)
3223 .unwrap();
3224 assert_eq!(roundtrip.relationships, related.relationships);
3225
3226 let by_target = store
3227 .list_related("doc-1#chunk-1", None, None)
3228 .await
3229 .unwrap();
3230 assert_eq!(by_target.len(), 1);
3231 assert_eq!(by_target[0].id, related.id);
3232
3233 let by_relation = store
3234 .list_related("doc-1#chunk-1", Some("cites"), None)
3235 .await
3236 .unwrap();
3237 assert_eq!(by_relation.len(), 1);
3238 assert_eq!(by_relation[0].id, related.id);
3239
3240 let wrong_relation = store
3241 .list_related("doc-1#chunk-1", Some("mentions"), None)
3242 .await
3243 .unwrap();
3244 assert!(wrong_relation.is_empty());
3245 });
3246 }
3247
/// A dataset created without the relationships column rejects relationship
/// writes until `migrate_relationships_column` runs. That migration reports
/// `true` the first time and `false` once the column exists.
#[test]
fn migrate_relationships_column_adds_missing_column() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        // Seed an empty dataset directly through Lance, using schema options
        // that (per the assertion after opening) omit the relationships column.
        let legacy_schema = Arc::new(ContextStore::schema_with_options(
            &HashSet::new(),
            true,
            true,
            false,
            true,
            DEFAULT_EMBEDDING_DIM,
            DistanceMetric::default(),
        ));
        let seed = RecordBatch::new_empty(legacy_schema.clone());
        let reader = RecordBatchIterator::new(
            std::iter::once(Ok::<RecordBatch, ArrowError>(seed)),
            legacy_schema,
        );
        let create_params = WriteParams {
            mode: WriteMode::Create,
            ..Default::default()
        };
        Dataset::write(reader, &uri, Some(create_params))
            .await
            .unwrap();

        let mut store = ContextStore::open(&uri).await.unwrap();
        assert!(!store.has_relationships_column());

        let mut record = text_record("with-relationships", 0.0);
        record.relationships.push(Relationship {
            target_id: "target".to_string(),
            relation: "mentions".to_string(),
            weight: None,
        });

        // Before migrating, the write fails and the error names the migration.
        let err = store.add(std::slice::from_ref(&record)).await.unwrap_err();
        assert!(
            err.to_string().contains("migrate_relationships_column"),
            "unexpected error: {err}"
        );

        // The first migration adds the column; the second is a no-op.
        let added = store.migrate_relationships_column().await.unwrap();
        assert!(added);
        assert!(store.has_relationships_column());
        let added_again = store.migrate_relationships_column().await.unwrap();
        assert!(!added_again);

        store.add(std::slice::from_ref(&record)).await.unwrap();
        let fetched = store.get_by_id(&record.id).await.unwrap().unwrap();
        assert_eq!(fetched.relationships, record.relationships);
    });
}
3303
/// A second record that reuses a live `external_id` must be refused with an
/// error naming the field and the conflict.
#[test]
fn add_rejects_duplicate_external_id() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();
        let shared_external_id = Some("doc-123#chunk-1".to_string());

        let mut original = text_record("a", 0.0);
        original.external_id = shared_external_id.clone();
        store.add(std::slice::from_ref(&original)).await.unwrap();

        let mut clash = text_record("b", 0.0);
        clash.external_id = shared_external_id;
        let message = store.add(&[clash]).await.unwrap_err().to_string();
        let names_field = message.contains("external_id");
        let names_conflict = message.contains("already exists");
        assert!(
            names_field && names_conflict,
            "unexpected error message: {message}"
        );
    });
}
3325
/// User records may not use the reserved tombstone content type.
#[test]
fn add_rejects_reserved_tombstone_content_type() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();
        let forged = {
            let mut rec = text_record("a", 0.0);
            rec.content_type = CONTENT_TYPE_TOMBSTONE.to_string();
            rec
        };

        let message = store.add(&[forged]).await.unwrap_err().to_string();
        assert!(
            message.contains("reserved") && message.contains("tombstone"),
            "unexpected error message: {message}"
        );
    });
}
3344
/// After `delete_by_external_id`, the record is invisible to external-id
/// lookup, id lookup, `list`, and vector `search`. Its neighbour stays visible.
#[test]
fn delete_by_external_id_hides_record_from_default_reads() {
    const EXTERNAL_ID: &str = "doc-123#chunk-1";

    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();
        let mut doomed = text_record("a", 0.0);
        doomed.external_id = Some(EXTERNAL_ID.to_string());
        let survivor = text_record("b", 2.0);
        store.add(&[doomed.clone(), survivor.clone()]).await.unwrap();

        let deleted = store.delete_by_external_id(EXTERNAL_ID).await.unwrap();
        assert!(deleted);

        let via_external = store.get_by_external_id(EXTERNAL_ID).await.unwrap();
        assert!(via_external.is_none());
        let via_id = store.get_by_id(&doomed.id).await.unwrap();
        assert!(via_id.is_none());

        let remaining = store.list(None, None).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, survivor.id);

        let probe = make_embedding(0.0);
        let hits = store.search(&probe, Some(10)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.id, survivor.id);
    });
}
3379
/// After `delete_by_id`, the record is also hidden from external-id lookup,
/// `list`, and vector `search`. Its neighbour stays visible.
#[test]
fn delete_by_id_hides_record_from_default_reads() {
    const EXTERNAL_ID: &str = "doc-123#chunk-1";

    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();
        let mut target = text_record("a", 0.0);
        target.external_id = Some(EXTERNAL_ID.to_string());
        let keeper = text_record("b", 2.0);
        store.add(&[target.clone(), keeper.clone()]).await.unwrap();

        let deleted = store.delete_by_id(&target.id).await.unwrap();
        assert!(deleted);

        let via_id = store.get_by_id(&target.id).await.unwrap();
        assert!(via_id.is_none());
        let via_external = store.get_by_external_id(EXTERNAL_ID).await.unwrap();
        assert!(via_external.is_none());

        let remaining = store.list(None, None).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, keeper.id);

        let probe = make_embedding(0.0);
        let hits = store.search(&probe, Some(10)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.id, keeper.id);
    });
}
3411
/// Deleting an unknown key by either id kind returns `false` instead of
/// erroring.
#[test]
fn delete_missing_id_is_noop() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let removed_by_id = store.delete_by_id("missing").await.unwrap();
        assert!(!removed_by_id);

        let removed_by_external = store.delete_by_external_id("missing").await.unwrap();
        assert!(!removed_by_external);
    });
}
3423
/// Once a record is deleted, its `external_id` can be assigned to a new
/// record. Lookups then resolve to the replacement, and only it is listed.
#[test]
fn external_id_can_be_reused_after_delete() {
    const EXTERNAL_ID: &str = "doc-123#chunk-1";

    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut original = text_record("a", 0.0);
        original.external_id = Some(EXTERNAL_ID.to_string());
        store.add(std::slice::from_ref(&original)).await.unwrap();
        let deleted = store.delete_by_external_id(EXTERNAL_ID).await.unwrap();
        assert!(deleted);

        let mut replacement = text_record("b", 1.0);
        replacement.external_id = original.external_id.clone();
        store.add(std::slice::from_ref(&replacement)).await.unwrap();

        let resolved = store
            .get_by_external_id(EXTERNAL_ID)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resolved.id, replacement.id);

        let listed = store.list(None, None).await.unwrap();
        assert_eq!(listed.len(), 1);
    });
}
3452
/// `derive_region_id` is a pure function of (bot_id, session_id):
/// - repeated calls with the same inputs agree;
/// - changing the session changes the result;
/// - the all-`None` case is also stable.
#[test]
fn test_region_id_derivation_explicit() {
    let bot = Some("bot-123".to_string());
    let session = Some("session-456".to_string());
    let other_session = Some("session-789".to_string());

    let baseline = ContextStore::derive_region_id(&bot, &session);
    let repeat = ContextStore::derive_region_id(&bot, &session);
    assert_eq!(
        baseline, repeat,
        "Region ID should be deterministic for same inputs"
    );

    let shifted = ContextStore::derive_region_id(&bot, &other_session);
    assert_ne!(
        baseline, shifted,
        "Region ID should differ for different inputs"
    );

    let anonymous = ContextStore::derive_region_id(&None, &None);
    let anonymous_repeat = ContextStore::derive_region_id(&None, &None);
    assert_eq!(
        anonymous, anonymous_repeat,
        "Region ID for None/None should be deterministic"
    );
}
3482
/// Records written with different (bot_id, session_id) pairs in one `add` call
/// must all be listed when the dataset is read through a freshly opened handle.
///
/// NOTE(review): these pairs presumably map to distinct MemWAL regions via
/// `derive_region_id`; that mapping is not visible in this test.
#[test]
fn test_add_multiple_regions() {
    let dir = TempDir::new().unwrap();
    let uri = dir.path().to_string_lossy().to_string();
    let runtime = tokio::runtime::Runtime::new().unwrap();

    runtime.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let mut record1 = text_record("r1", 0.0);
        record1.bot_id = Some("bot-A".to_string());
        record1.session_id = Some("session-1".to_string());

        let mut record2 = text_record("r2", 0.0);
        record2.bot_id = Some("bot-B".to_string());
        record2.session_id = Some("session-2".to_string());

        // Compare against the records' own ids instead of hard-coded literals,
        // so the test does not depend on how `text_record` assigns ids.
        let expected: HashSet<String> = [record1.id.clone(), record2.id.clone()]
            .into_iter()
            .collect();

        store.add(&[record1, record2]).await.unwrap();

        // Read back through a second, freshly opened handle.
        let store = ContextStore::open(&uri).await.unwrap();
        let results = store.list(None, None).await.unwrap();
        assert_eq!(results.len(), 2);

        let listed: HashSet<String> = results.iter().map(|r| r.id.clone()).collect();
        assert_eq!(listed, expected);
    });
}
3519
/// Blob-encoding only `binary_payload`:
/// - tags that column (and only that column) with `lance-encoding:blob`;
/// - leaves `text_payload` as plain LargeUtf8;
/// - still round-trips the bytes through the record/batch conversion.
#[test]
fn test_blob_binary_payload() {
    let dir = TempDir::new().unwrap();
    let uri = dir.path().to_string_lossy().to_string();
    let runtime = tokio::runtime::Runtime::new().unwrap();

    runtime.block_on(async {
        let options = ContextStoreOptions {
            blob_columns: HashSet::from(["binary_payload".to_string()]),
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let mut record = text_record("blob-bin-1", 0.0);
        record.binary_payload = Some(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        store.add(std::slice::from_ref(&record)).await.unwrap();

        let schema = ContextStore::schema(&store.blob_columns);
        let field = schema.field_with_name("binary_payload").unwrap();
        assert_eq!(field.data_type(), &DataType::LargeBinary);
        assert_eq!(
            field.metadata().get("lance-encoding:blob"),
            Some(&"true".to_string()),
        );

        // The non-selected text column keeps its default, non-blob encoding.
        let text_field = schema.field_with_name("text_payload").unwrap();
        assert_eq!(text_field.data_type(), &DataType::LargeUtf8);
        assert!(text_field.metadata().get("lance-encoding:blob").is_none());

        // Previously the payload bytes were never read back. Verify they survive
        // record -> batch -> record, as `test_blob_both_columns` already does.
        let batch = store
            .records_to_batch(std::slice::from_ref(&record))
            .unwrap();
        let roundtripped = batch_to_records(&batch).unwrap();
        assert_eq!(roundtripped.len(), 1);
        assert_eq!(roundtripped[0].binary_payload, record.binary_payload);
    });
}
3552
/// When `text_payload` is blob-encoded, batches carry it as LargeBinary and
/// the original text still decodes back unchanged.
#[test]
fn test_blob_text_payload() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let mut blob_columns = HashSet::new();
        blob_columns.insert("text_payload".to_string());
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let record = text_record("blob-txt-1", 0.0);
        let single = std::slice::from_ref(&record);
        store.add(single).await.unwrap();

        let batch = store.records_to_batch(single).unwrap();
        let schema = batch.schema();
        let text_type = schema.field_with_name("text_payload").unwrap().data_type();
        assert_eq!(
            text_type,
            &DataType::LargeBinary,
            "text_payload should be LargeBinary when blob-encoded"
        );

        let decoded = batch_to_records(&batch).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(
            decoded[0].text_payload, record.text_payload,
            "text payload should survive blob roundtrip"
        );
    });
}
3591
/// Blob-encoding both payload columns tags each with `lance-encoding:blob`,
/// and both payloads survive the record/batch round trip.
#[test]
fn test_blob_both_columns() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let blob_columns: HashSet<String> = ["text_payload", "binary_payload"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let options = ContextStoreOptions {
            blob_columns,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let mut record = text_record("blob-both-1", 0.0);
        record.binary_payload = Some(b"hello binary".to_vec());
        let single = std::slice::from_ref(&record);
        store.add(single).await.unwrap();

        let schema = ContextStore::schema(&store.blob_columns);
        for column in ["text_payload", "binary_payload"] {
            let field = schema.field_with_name(column).unwrap();
            assert_eq!(
                field.metadata().get("lance-encoding:blob"),
                Some(&"true".to_string()),
            );
        }

        let batch = store.records_to_batch(single).unwrap();
        let decoded = batch_to_records(&batch).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].text_payload, record.text_payload);
        assert_eq!(decoded[0].binary_payload, record.binary_payload);
    });
}
3637
/// With no blob columns requested, payload columns keep their plain Arrow
/// types and carry no `lance-encoding:blob` metadata.
#[test]
fn test_no_blob_default() {
    let schema = ContextStore::schema(&HashSet::new());

    let expectations = [
        ("text_payload", DataType::LargeUtf8),
        ("binary_payload", DataType::LargeBinary),
    ];
    for (column, expected_type) in expectations {
        let field = schema.field_with_name(column).unwrap();
        assert_eq!(field.data_type(), &expected_type);
        assert!(field.metadata().get("lance-encoding:blob").is_none());
    }
}
3650
/// Requesting both payload columns as blobs turns each into LargeBinary with
/// blob metadata. Unrelated columns such as `id` are left untagged.
#[test]
fn test_blob_schema_metadata() {
    let requested: HashSet<String> = ["text_payload", "binary_payload"]
        .into_iter()
        .map(String::from)
        .collect();
    let schema = ContextStore::schema(&requested);

    for column in ["text_payload", "binary_payload"] {
        let field = schema.field_with_name(column).unwrap();
        assert_eq!(field.data_type(), &DataType::LargeBinary);
        assert_eq!(
            field.metadata().get("lance-encoding:blob"),
            Some(&"true".to_string()),
        );
    }

    let id_field = schema.field_with_name("id").unwrap();
    assert!(id_field.metadata().get("lance-encoding:blob").is_none());
}
3675
/// Opening with an unknown blob column name fails, and the error explains why.
#[test]
fn test_blob_invalid_column_name() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            blob_columns: HashSet::from(["nonexistent_column".to_string()]),
            ..Default::default()
        };
        let err_msg = match ContextStore::open_with_options(&uri, options).await {
            Ok(_) => panic!("should reject invalid blob column names"),
            Err(err) => err.to_string(),
        };
        assert!(
            err_msg.contains("invalid blob column"),
            "error should mention invalid blob column: {err_msg}"
        );
    });
}
3696
/// `batch_to_records` must decode `text_payload` whether the batch stores it
/// as LargeUtf8 (default store) or LargeBinary (blob-encoded store).
///
/// The test now asserts that the two batches really use those two types.
/// Without that check, a change in `records_to_batch` could make both batches
/// identical and the test would exercise one decode path twice without
/// noticing.
#[test]
fn test_batch_to_records_autodetects_text_type() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let record = text_record("auto-1", 0.0);

        // Default store: text_payload is plain LargeUtf8.
        let dir1 = TempDir::new().unwrap();
        let uri1 = dir1.path().to_string_lossy().to_string();
        let store_default = ContextStore::open(&uri1).await.unwrap();
        let batch_utf8 = store_default
            .records_to_batch(std::slice::from_ref(&record))
            .unwrap();
        let utf8_schema = batch_utf8.schema();
        assert_eq!(
            utf8_schema
                .field_with_name("text_payload")
                .unwrap()
                .data_type(),
            &DataType::LargeUtf8
        );
        let results_utf8 = batch_to_records(&batch_utf8).unwrap();
        assert_eq!(results_utf8.len(), 1);
        assert_eq!(results_utf8[0].text_payload, record.text_payload);

        // Blob store: text_payload is LargeBinary.
        let dir2 = TempDir::new().unwrap();
        let uri2 = dir2.path().to_string_lossy().to_string();
        let options = ContextStoreOptions {
            blob_columns: HashSet::from(["text_payload".to_string()]),
            ..Default::default()
        };
        let store_blob = ContextStore::open_with_options(&uri2, options)
            .await
            .unwrap();
        let batch_binary = store_blob
            .records_to_batch(std::slice::from_ref(&record))
            .unwrap();
        let binary_schema = batch_binary.schema();
        assert_eq!(
            binary_schema
                .field_with_name("text_payload")
                .unwrap()
                .data_type(),
            &DataType::LargeBinary
        );
        let results_binary = batch_to_records(&batch_binary).unwrap();
        assert_eq!(results_binary.len(), 1);
        assert_eq!(results_binary[0].text_payload, record.text_payload);
    });
}
3731
/// With `IdIndexType::BTree`, the id index exists right after opening and is
/// still present after several appends and a compaction.
#[test]
fn test_id_index_btree() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::BTree,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let after_open = store.dataset.load_indices().await.unwrap();
        assert!(
            after_open.iter().any(|idx| idx.name == ID_INDEX_NAME),
            "btree index should be created on open"
        );

        // Several single-record appends, then compact them together.
        for n in 0..5u8 {
            let record = text_record(&format!("btree-{n}"), f32::from(n));
            store.add(&[record]).await.unwrap();
        }
        store.compact(None).await.unwrap();

        let after_compact = store.dataset.load_indices().await.unwrap();
        assert!(
            after_compact.iter().any(|idx| idx.name == ID_INDEX_NAME),
            "btree index should persist after compaction"
        );
    });
}
3771
/// With `IdIndexType::ZoneMap`, the id index exists right after opening and is
/// still present after several appends and a compaction.
#[test]
fn test_id_index_zonemap() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::ZoneMap,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        let after_open = store.dataset.load_indices().await.unwrap();
        assert!(
            after_open.iter().any(|idx| idx.name == ID_INDEX_NAME),
            "zonemap index should be created on open"
        );

        // Several single-record appends, then compact them together.
        for n in 0..5u8 {
            let record = text_record(&format!("zm-{n}"), f32::from(n));
            store.add(&[record]).await.unwrap();
        }
        store.compact(None).await.unwrap();

        let after_compact = store.dataset.load_indices().await.unwrap();
        assert!(
            after_compact.iter().any(|idx| idx.name == ID_INDEX_NAME),
            "zonemap index should persist after compaction"
        );
    });
}
3809
/// A store opened with default options builds no id index, even after an
/// append and a compaction.
#[test]
fn test_id_index_none_by_default() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let mut store = ContextStore::open(&uri).await.unwrap();

        let record = text_record("no-idx-1", 0.0);
        store.add(&[record]).await.unwrap();
        store.compact(None).await.unwrap();

        let indices = store.dataset.load_indices().await.unwrap();
        let has_id_index = indices.iter().any(|idx| idx.name == ID_INDEX_NAME);
        assert!(
            !has_id_index,
            "no id index should be created when IdIndexType::None"
        );
    });
}
3829
/// Calling `ensure_id_index` when the id index already exists must not create
/// a new dataset version.
#[test]
fn test_id_index_idempotent() {
    let tmp = TempDir::new().unwrap();
    let uri = tmp.path().to_string_lossy().to_string();
    let rt = tokio::runtime::Runtime::new().unwrap();

    rt.block_on(async {
        let options = ContextStoreOptions {
            id_index_type: IdIndexType::BTree,
            ..Default::default()
        };
        let mut store = ContextStore::open_with_options(&uri, options)
            .await
            .unwrap();

        for n in 0..5u8 {
            let record = text_record(&format!("idem-{n}"), f32::from(n));
            store.add(&[record]).await.unwrap();
        }

        // Build the index explicitly, then check that `ensure_id_index` leaves
        // the dataset version untouched.
        store.create_id_index().await.unwrap();
        let version_before = store.version();
        store.ensure_id_index().await.unwrap();
        let version_after = store.version();
        assert_eq!(
            version_before, version_after,
            "ensure_id_index should not recreate existing index"
        );
    });
}
3860}