// lance_context_core/store.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use arrow_array::builder::{
    FixedSizeListBuilder, Float32Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
    StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
};
use arrow_array::types::Int8Type;
use arrow_array::{
    Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
    LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray,
    TimestampMicrosecondArray,
};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
use chrono::{DateTime, Timelike, Utc};
use futures::TryStreamExt;
use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use lance::{Error as LanceError, Result as LanceResult};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{error, info, warn};

use crate::record::{ContextRecord, SearchResult, StateMetadata};

/// Embedding length used for the semantic index column.
const DEFAULT_EMBEDDING_DIM: i32 = 1536;
const DEFAULT_SEARCH_LIMIT: usize = 10;

/// Configuration for background compaction.
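///
/// A minimal construction sketch (values are illustrative, not tuned
/// recommendations):
///
/// ```ignore
/// let config = CompactionConfig {
///     enabled: true,
///     min_fragments: 10,
///     // Skip compaction between 09:00 and 17:00 UTC.
///     quiet_hours: vec![(9, 17)],
///     ..Default::default()
/// };
/// ```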
#[derive(Debug, Clone)]
pub struct CompactionConfig {
    /// Whether background compaction is enabled.
    pub enabled: bool,
    /// Minimum number of fragments to trigger compaction.
    pub min_fragments: usize,
    /// Target rows per fragment after compaction.
    pub target_rows_per_fragment: usize,
    /// Maximum rows per row group.
    pub max_rows_per_group: usize,
    /// Whether to materialize (remove) deleted rows during compaction.
    pub materialize_deletions: bool,
    /// Deletion threshold (0.0-1.0) to trigger materialization.
    pub materialize_deletions_threshold: f32,
    /// Number of threads for compaction (None = auto).
    pub num_threads: Option<usize>,
    /// Interval in seconds between compaction checks.
    pub check_interval_secs: u64,
    /// Quiet hours (UTC) during which compaction is skipped, given as
    /// `(start_hour, end_hour)` pairs; a range may wrap past midnight,
    /// e.g. `(22, 6)`.
    pub quiet_hours: Vec<(u8, u8)>,
}

impl Default for CompactionConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            min_fragments: 5,
            target_rows_per_fragment: 1_000_000,
            max_rows_per_group: 1024,
            materialize_deletions: true,
            materialize_deletions_threshold: 0.1,
            num_threads: None,
            check_interval_secs: 300,
            quiet_hours: vec![],
        }
    }
}

/// Statistics about compaction status and history.
#[derive(Debug, Clone)]
pub struct CompactionStats {
    /// Current number of fragments in the dataset.
    pub total_fragments: usize,
    /// Whether a compaction is currently in progress.
    pub is_compacting: bool,
    /// Timestamp of the last successful compaction.
    pub last_compaction: Option<DateTime<Utc>>,
    /// Error message from the last failed compaction.
    pub last_error: Option<String>,
    /// Total number of successful compactions performed.
    pub total_compactions: u64,
}

/// Internal state for tracking background compaction.
struct CompactionState {
    background_task: Option<JoinHandle<()>>,
    is_compacting: bool,
    last_compaction: Option<DateTime<Utc>>,
    last_error: Option<String>,
    total_compactions: u64,
}

/// Persistent Lance-backed context store.
#[derive(Clone)]
pub struct ContextStore {
    dataset: Dataset,
    compaction_state: Arc<Mutex<CompactionState>>,
    pub compaction_config: CompactionConfig,
}

/// Additional configuration when opening a [`ContextStore`].
#[derive(Debug, Clone, Default)]
pub struct ContextStoreOptions {
    pub storage_options: Option<HashMap<String, String>>,
    pub compaction: CompactionConfig,
}

impl ContextStoreOptions {
    #[must_use]
    pub fn storage_options(&self) -> Option<HashMap<String, String>> {
        self.storage_options.clone()
    }
}

impl ContextStore {
    /// Open an existing context dataset or create a new one with the project schema.
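    ///
    /// A minimal usage sketch (assumes a local path; error handling elided):
    ///
    /// ```ignore
    /// let mut store = ContextStore::open("/tmp/context.lance").await?;
    /// println!("opened at version {}", store.version());
    /// ```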
    pub async fn open(uri: &str) -> LanceResult<Self> {
        Self::open_with_options(uri, ContextStoreOptions::default()).await
    }

    /// Open a dataset with explicit object store configuration (e.g. S3 credentials).
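    ///
    /// A sketch for an S3-backed store; the option keys follow `object_store`
    /// S3 naming and the values are placeholders (adapt both for your backend):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut storage = HashMap::new();
    /// storage.insert("aws_region".to_string(), "us-east-1".to_string());
    /// storage.insert("aws_access_key_id".to_string(), "<access-key-id>".to_string());
    /// storage.insert("aws_secret_access_key".to_string(), "<secret>".to_string());
    ///
    /// let options = ContextStoreOptions {
    ///     storage_options: Some(storage),
    ///     compaction: CompactionConfig { enabled: true, ..Default::default() },
    /// };
    /// let store = ContextStore::open_with_options("s3://bucket/context.lance", options).await?;
    /// ```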
    pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult<Self> {
        let storage_options = options.storage_options();
        let dataset = match Self::load_with_options(uri, storage_options.clone()).await {
            Ok(dataset) => dataset,
            Err(LanceError::DatasetNotFound { .. }) => {
                Self::create_with_options(uri, storage_options).await?
            }
            Err(err) => return Err(err),
        };

        let mut store = Self {
            dataset,
            compaction_state: Arc::new(Mutex::new(CompactionState {
                background_task: None,
                is_compacting: false,
                last_compaction: None,
                last_error: None,
                total_compactions: 0,
            })),
            compaction_config: options.compaction,
        };

        // Start background compaction if enabled
        store.start_background_compaction().await?;

        Ok(store)
    }

    /// Append context records to the store and return the new dataset version.
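    ///
    /// A sketch of appending a single text record (field values are
    /// illustrative; see [`ContextRecord`] for the full shape):
    ///
    /// ```ignore
    /// let record = ContextRecord {
    ///     id: "ctx-1".to_string(),
    ///     run_id: "run-1".to_string(),
    ///     bot_id: None,
    ///     session_id: None,
    ///     created_at: chrono::Utc::now(),
    ///     role: "user".to_string(),
    ///     state_metadata: None,
    ///     content_type: "text".to_string(),
    ///     text_payload: Some("hello".to_string()),
    ///     binary_payload: None,
    ///     embedding: None, // or Some(vec![0.0; 1536]) to make it searchable
    /// };
    /// let version = store.add(&[record]).await?;
    /// ```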
    pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
        if entries.is_empty() {
            return Ok(self.dataset.manifest.version);
        }

        let batch = Self::records_to_batch(entries)?;
        let schema = batch.schema();
        let reader = RecordBatchIterator::new(
            vec![Ok::<RecordBatch, ArrowError>(batch)].into_iter(),
            schema,
        );
        self.dataset.append(reader, None).await?;

        Ok(self.dataset.manifest.version)
    }

    /// Current dataset version.
    pub fn version(&self) -> u64 {
        self.dataset.manifest.version
    }

    /// Checkout a specific dataset version.
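    ///
    /// A time-travel sketch (assumes `records` is a prepared
    /// `Vec<ContextRecord>`):
    ///
    /// ```ignore
    /// let before = store.version();
    /// store.add(&records).await?; // writes a new version
    /// store.checkout(before).await?; // read the pre-append state again
    /// ```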
    pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> {
        let dataset = self.dataset.checkout_version(version_id).await?;
        self.dataset = dataset;
        Ok(())
    }

    /// List all records in the dataset.
    pub async fn list(
        &self,
        limit: Option<usize>,
        offset: Option<usize>,
    ) -> LanceResult<Vec<ContextRecord>> {
        let mut scanner = self.dataset.scan();
        if let Some(limit) = limit {
            scanner.limit(Some(limit as i64), offset.map(|o| o as i64))?;
        } else if let Some(offset) = offset {
            scanner.limit(None, Some(offset as i64))?;
        }

        let mut stream = scanner.try_into_stream().await?;
        let mut results = Vec::new();
        while let Some(batch) = stream.try_next().await? {
            results.extend(batch_to_records(&batch)?);
        }
        Ok(results)
    }

    /// Perform a nearest-neighbor search over stored embeddings.
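    ///
    /// A sketch; the query must match `DEFAULT_EMBEDDING_DIM` (1536):
    ///
    /// ```ignore
    /// let query = vec![0.0_f32; 1536];
    /// for hit in store.search(&query, Some(5)).await? {
    ///     println!("{} (distance {})", hit.record.id, hit.distance);
    /// }
    /// ```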
    pub async fn search(
        &self,
        query: &[f32],
        limit: Option<usize>,
    ) -> LanceResult<Vec<SearchResult>> {
        if query.len() != DEFAULT_EMBEDDING_DIM as usize {
            return Err(ArrowError::InvalidArgumentError(format!(
                "query length {} does not match embedding dimension {}",
                query.len(),
                DEFAULT_EMBEDDING_DIM
            ))
            .into());
        }

        let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
        if top_k == 0 {
            return Ok(Vec::new());
        }

        let query_array = Float32Array::from(query.to_vec());

        let mut scanner = self.dataset.scan();
        scanner.nearest("embedding", &query_array, top_k)?;
        scanner.limit(Some(top_k as i64), Some(0))?;

        let mut stream = scanner.try_into_stream().await?;
        let mut results = Vec::new();
        while let Some(batch) = stream.try_next().await? {
            results.extend(batch_to_search_results(&batch)?);
        }
        Ok(results)
    }

    /// Manually trigger compaction to merge small fragments.
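    ///
    /// A sketch using the store's own configuration:
    ///
    /// ```ignore
    /// let metrics = store.compact(None).await?;
    /// println!(
    ///     "removed {} fragments, added {}",
    ///     metrics.fragments_removed, metrics.fragments_added
    /// );
    /// ```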
    pub async fn compact(
        &mut self,
        options: Option<CompactionConfig>,
    ) -> LanceResult<CompactionMetrics> {
        let config = options.unwrap_or_else(|| self.compaction_config.clone());

        info!(
            "Starting compaction: {} fragments",
            self.dataset.count_fragments()
        );
        let start = std::time::Instant::now();

        // Mark as compacting
        {
            let mut state = self.compaction_state.lock().await;
            if state.is_compacting {
                warn!("Compaction already in progress, skipping");
                return Err(LanceError::from(ArrowError::InvalidArgumentError(
                    "Compaction already in progress".to_string(),
                )));
            }
            state.is_compacting = true;
        }

        // Build Lance CompactionOptions
        let lance_options = CompactionOptions {
            target_rows_per_fragment: config.target_rows_per_fragment,
            max_rows_per_group: config.max_rows_per_group,
            materialize_deletions: config.materialize_deletions,
            materialize_deletions_threshold: config.materialize_deletions_threshold,
            num_threads: config.num_threads,
            ..Default::default()
        };

        // Run compaction
        let result = compact_files(&mut self.dataset, lance_options, None).await;

        // Update state
        let mut state = self.compaction_state.lock().await;
        state.is_compacting = false;

        match result {
            Ok(metrics) => {
                state.last_compaction = Some(Utc::now());
                state.total_compactions += 1;
                state.last_error = None;

                info!(
                    "Compaction completed in {:?}: removed {} fragments ({} files), added {} fragments ({} files)",
                    start.elapsed(),
                    metrics.fragments_removed,
                    metrics.files_removed,
                    metrics.fragments_added,
                    metrics.files_added
                );

                // Reload the dataset so this handle sees the post-compaction
                // version. NOTE: reopening by URI alone does not reapply any
                // custom storage options supplied at open time, so remote
                // stores that need explicit credentials may fail here.
                self.dataset = Dataset::open(self.dataset.uri()).await?;

                Ok(metrics)
            }
            Err(e) => {
                error!("Compaction failed: {}", e);
                state.last_error = Some(e.to_string());
                Err(e)
            }
        }
    }

    /// Check if compaction should run based on configuration thresholds.
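    ///
    /// A sketch for driving compaction by hand instead of relying on the
    /// background task:
    ///
    /// ```ignore
    /// if store.should_compact().await? {
    ///     store.compact(None).await?;
    /// }
    /// ```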
    pub async fn should_compact(&self) -> LanceResult<bool> {
        let fragment_count = self.dataset.count_fragments();

        if fragment_count < self.compaction_config.min_fragments {
            return Ok(false);
        }

        // Check quiet hours
        if !self.compaction_config.quiet_hours.is_empty() {
            let now = Utc::now();
            let current_hour = now.hour() as u8;

            for (start, end) in &self.compaction_config.quiet_hours {
                // A quiet-hours range may wrap past midnight, e.g. (22, 6).
                let in_quiet = if start <= end {
                    current_hour >= *start && current_hour < *end
                } else {
                    current_hour >= *start || current_hour < *end
                };
                if in_quiet {
                    info!("Skipping compaction during quiet hours ({}-{})", start, end);
                    return Ok(false);
                }
            }
        }

        Ok(true)
    }

    /// Get current compaction statistics.
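    ///
    /// A quick inspection sketch:
    ///
    /// ```ignore
    /// let stats = store.compaction_stats().await?;
    /// println!(
    ///     "{} fragments, {} compactions so far",
    ///     stats.total_fragments, stats.total_compactions
    /// );
    /// ```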
    pub async fn compaction_stats(&self) -> LanceResult<CompactionStats> {
        let state = self.compaction_state.lock().await;

        Ok(CompactionStats {
            total_fragments: self.dataset.count_fragments(),
            is_compacting: state.is_compacting,
            last_compaction: state.last_compaction,
            last_error: state.last_error.clone(),
            total_compactions: state.total_compactions,
        })
    }

    /// Start background compaction task if enabled.
    async fn start_background_compaction(&mut self) -> LanceResult<()> {
        if !self.compaction_config.enabled {
            return Ok(());
        }

        let mut state = self.compaction_state.lock().await;
        if state.background_task.is_some() {
            warn!("Background compaction already running");
            return Ok(());
        }

        info!(
            "Starting background compaction (interval: {}s, min fragments: {})",
            self.compaction_config.check_interval_secs, self.compaction_config.min_fragments
        );

        let mut store_clone = self.clone();
        let interval_secs = self.compaction_config.check_interval_secs;

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

            loop {
                interval.tick().await;

                match store_clone.should_compact().await {
                    Ok(true) => {
                        info!("Background compaction triggered");
                        if let Err(e) = store_clone.compact(None).await {
                            error!("Background compaction failed: {}", e);
                        }
                    }
                    Ok(false) => {
                        // Not needed or in quiet hours
                    }
                    Err(e) => {
                        error!("Error checking compaction need: {}", e);
                    }
                }
            }
        });

        state.background_task = Some(task);
        Ok(())
    }

    /// Stop background compaction task.
    pub async fn stop_background_compaction(&mut self) -> LanceResult<()> {
        let mut state = self.compaction_state.lock().await;

        if let Some(task) = state.background_task.take() {
            info!("Stopping background compaction");
            task.abort();
        }

        Ok(())
    }

    /// Lance schema for the context store.
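    ///
    /// A sketch of inspecting the fixed schema (11 columns, including the
    /// 1536-wide `embedding` vector):
    ///
    /// ```ignore
    /// let schema = ContextStore::schema();
    /// assert_eq!(schema.fields().len(), 11);
    /// assert!(schema.field_with_name("embedding").is_ok());
    /// ```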
    pub fn schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("run_id", DataType::Utf8, false),
            Field::new("bot_id", DataType::Utf8, true),
            Field::new("session_id", DataType::Utf8, true),
            Field::new(
                "created_at",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new(
                "role",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "state_metadata",
                DataType::Struct(
                    vec![
                        Field::new("step", DataType::Int32, true),
                        Field::new("active_plan_id", DataType::Utf8, true),
                        Field::new("tokens_used", DataType::Int32, true),
                        Field::new("custom", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("content_type", DataType::Utf8, false),
            Field::new("text_payload", DataType::LargeUtf8, true),
            Field::new("binary_payload", DataType::LargeBinary, true),
            Field::new(
                "embedding",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    DEFAULT_EMBEDDING_DIM,
                ),
                true,
            ),
        ])
    }

    async fn load_with_options(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
    ) -> LanceResult<Dataset> {
        if let Some(options) = storage_options {
            DatasetBuilder::from_uri(uri)
                .with_storage_options(options)
                .load()
                .await
        } else {
            Dataset::open(uri).await
        }
    }

    async fn create_with_options(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
    ) -> LanceResult<Dataset> {
        let schema = Arc::new(Self::schema());
        let empty_batch = RecordBatch::new_empty(schema.clone());
        let batches = RecordBatchIterator::new(
            vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
            schema.clone(),
        );

        let mut params = WriteParams {
            mode: WriteMode::Create,
            ..Default::default()
        };

        if let Some(options) = storage_options {
            let store_params = ObjectStoreParams {
                storage_options: Some(options),
                ..Default::default()
            };
            params.store_params = Some(store_params);
        }

        Dataset::write(batches, uri, Some(params)).await
    }

    fn records_to_batch(entries: &[ContextRecord]) -> LanceResult<RecordBatch> {
        let mut id_builder = StringBuilder::new();
        let mut run_id_builder = StringBuilder::new();
        let mut bot_id_builder = StringBuilder::new();
        let mut session_id_builder = StringBuilder::new();
        let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
        let mut role_builder = StringDictionaryBuilder::<Int8Type>::new();
        let mut content_type_builder = StringBuilder::new();
        let mut text_builder = LargeStringBuilder::new();
        let mut binary_builder = LargeBinaryBuilder::new();

        let state_fields: Vec<FieldRef> = vec![
            Arc::new(Field::new("step", DataType::Int32, true)),
            Arc::new(Field::new("active_plan_id", DataType::Utf8, true)),
            Arc::new(Field::new("tokens_used", DataType::Int32, true)),
            Arc::new(Field::new("custom", DataType::Utf8, true)),
        ];
        let mut state_builder = StructBuilder::new(
            state_fields,
            vec![
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
            ],
        );

        let mut embedding_builder =
            FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM);

        for entry in entries {
            id_builder.append_value(&entry.id);
            run_id_builder.append_value(&entry.run_id);
            bot_id_builder.append_option(entry.bot_id.as_deref());
            session_id_builder.append_option(entry.session_id.as_deref());
            created_at_builder.append_value(entry.created_at.timestamp_micros());
            role_builder.append(&entry.role)?;
            content_type_builder.append_value(&entry.content_type);

            match &entry.text_payload {
                Some(value) => text_builder.append_value(value),
                None => text_builder.append_null(),
            }

            match &entry.binary_payload {
                Some(value) => binary_builder.append_value(value),
                None => binary_builder.append_null(),
            }

            if let Some(metadata) = &entry.state_metadata {
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_option(metadata.step);
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_option(metadata.active_plan_id.as_deref());
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_option(metadata.tokens_used);
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_option(metadata.custom.as_deref());
                state_builder.append(true);
            } else {
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_null();
                state_builder.append(false);
            }

            if let Some(embedding) = &entry.embedding {
                if embedding.len() != DEFAULT_EMBEDDING_DIM as usize {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "embedding length {} does not match expected dimension {}",
                        embedding.len(),
                        DEFAULT_EMBEDDING_DIM
                    ))
                    .into());
                }
                {
                    let values_builder = embedding_builder.values();
                    for value in embedding {
                        values_builder.append_value(*value);
                    }
                }
                embedding_builder.append(true);
            } else {
                // FixedSizeListBuilder requires padding values for null slots.
                let values_builder = embedding_builder.values();
                for _ in 0..DEFAULT_EMBEDDING_DIM {
                    values_builder.append_null();
                }
                embedding_builder.append(false);
            }
        }

        let id_array: ArrayRef = Arc::new(id_builder.finish());
        let run_id_array: ArrayRef = Arc::new(run_id_builder.finish());
        let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish());
        let session_id_array: ArrayRef = Arc::new(session_id_builder.finish());
        let created_at_array: ArrayRef = Arc::new(created_at_builder.finish());
        let role_array: ArrayRef = Arc::new(role_builder.finish());
        let content_type_array: ArrayRef = Arc::new(content_type_builder.finish());
        let text_array: ArrayRef = Arc::new(text_builder.finish());
        let binary_array: ArrayRef = Arc::new(binary_builder.finish());
        let state_array: ArrayRef = Arc::new(state_builder.finish());
        let embedding_array: ArrayRef = Arc::new(embedding_builder.finish());

        let schema = Arc::new(Self::schema());
        let batch = RecordBatch::try_new(
            schema,
            vec![
                id_array,
                run_id_array,
                bot_id_array,
                session_id_array,
                created_at_array,
                role_array,
                state_array,
                content_type_array,
                text_array,
                binary_array,
                embedding_array,
            ],
        )?;

        Ok(batch)
    }
}

impl Drop for ContextStore {
    fn drop(&mut self) {
        // Best-effort cleanup of background task
        if let Ok(mut state) = self.compaction_state.try_lock() {
            if let Some(task) = state.background_task.take() {
                task.abort();
            }
        }
    }
}

fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>> {
    let records = batch_to_records(batch)?;

    let distance_column = batch.column_by_name("_distance").ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(
            "search results missing _distance column".to_string(),
        ))
    })?;
    let distance_array = distance_column
        .as_ref()
        .as_any()
        .downcast_ref::<Float32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "_distance column has unexpected data type".to_string(),
            ))
        })?;

    Ok(records
        .into_iter()
        .enumerate()
        .map(|(i, record)| SearchResult {
            record,
            distance: distance_array.value(i),
        })
        .collect())
}

/// Convert a record batch to context records.
fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
    let id_array = column_as::<StringArray>(batch, "id")?;
    let run_id_array = column_as::<StringArray>(batch, "run_id")?;
    let bot_id_array = column_as_optional::<StringArray>(batch, "bot_id");
    let session_id_array = column_as_optional::<StringArray>(batch, "session_id");
    let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
    let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
    let state_array = column_as::<StructArray>(batch, "state_metadata")?;
    let content_type_array = column_as::<StringArray>(batch, "content_type")?;
    let text_array = column_as::<LargeStringArray>(batch, "text_payload")?;
    let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
    let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;

    let step_array = state_array
        .column(0)
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "step column has unexpected data type".to_string(),
            ))
        })?;
    let active_plan_array = state_array
        .column(1)
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "active_plan_id column has unexpected data type".to_string(),
            ))
        })?;
    let tokens_used_array = state_array
        .column(2)
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "tokens_used column has unexpected data type".to_string(),
            ))
        })?;
    let custom_array = state_array
        .column(3)
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "custom column has unexpected data type".to_string(),
            ))
        })?;

    let mut results = Vec::with_capacity(batch.num_rows());
    for row in 0..batch.num_rows() {
        let created_at =
            DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| {
                LanceError::from(ArrowError::InvalidArgumentError(format!(
                    "invalid timestamp value {}",
                    created_at_array.value(row)
                )))
            })?;

        let state_metadata = if state_array.is_null(row) {
            None
        } else {
            Some(StateMetadata {
                step: if step_array.is_null(row) {
                    None
                } else {
                    Some(step_array.value(row))
                },
                active_plan_id: if active_plan_array.is_null(row) {
                    None
                } else {
                    Some(active_plan_array.value(row).to_string())
                },
                tokens_used: if tokens_used_array.is_null(row) {
                    None
                } else {
                    Some(tokens_used_array.value(row))
                },
                custom: if custom_array.is_null(row) {
                    None
                } else {
                    Some(custom_array.value(row).to_string())
                },
            })
        };

        let text_payload = if text_array.is_null(row) {
            None
        } else {
            Some(text_array.value(row).to_string())
        };

        let binary_payload = if binary_array.is_null(row) {
            None
        } else {
            Some(binary_array.value(row).to_vec())
        };

        let embedding = if embedding_array.is_null(row) {
            None
        } else {
            Some(embedding_from_list(embedding_array, row)?)
        };

        let role = if role_array.is_null(row) {
            return Err(LanceError::from(ArrowError::InvalidArgumentError(
                "role column contains null values".to_string(),
            )));
        } else {
            let role_values = role_array
                .values()
                .as_any()
                .downcast_ref::<StringArray>()
                .ok_or_else(|| {
                    LanceError::from(ArrowError::InvalidArgumentError(
                        "role dictionary values are not strings".to_string(),
                    ))
                })?;
            let key = role_array.keys().value(row) as usize;
            role_values.value(key).to_string()
        };

        let bot_id = bot_id_array.and_then(|arr| {
            if arr.is_null(row) {
                None
            } else {
                Some(arr.value(row).to_string())
            }
        });

        let session_id = session_id_array.and_then(|arr| {
            if arr.is_null(row) {
                None
            } else {
                Some(arr.value(row).to_string())
            }
        });

        results.push(ContextRecord {
            id: id_array.value(row).to_string(),
            run_id: run_id_array.value(row).to_string(),
            bot_id,
            session_id,
            created_at,
            role,
            state_metadata,
            content_type: content_type_array.value(row).to_string(),
            text_payload,
            binary_payload,
            embedding,
        });
    }

    Ok(results)
}

fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
    let values = list.value(row);
    let float_array = values
        .as_ref()
        .as_any()
        .downcast_ref::<Float32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "embedding column does not contain float32 values".to_string(),
            ))
        })?;
    let mut embedding = Vec::with_capacity(float_array.len());
    for idx in 0..float_array.len() {
        embedding.push(float_array.value(idx));
    }
    Ok(embedding)
}

fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
where
    A: Array + 'static,
{
    let column = batch.column_by_name(name).ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(format!(
            "column '{name}' not found"
        )))
    })?;
    column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(format!(
            "column '{name}' has unexpected data type"
        )))
    })
}

fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A>
where
    A: Array + 'static,
{
    batch
        .column_by_name(name)
        .and_then(|col| col.as_ref().as_any().downcast_ref::<A>())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde::CONTENT_TYPE_TEXT;
    use chrono::Utc;
    use tempfile::TempDir;

    fn make_embedding(pivot: f32) -> Vec<f32> {
        let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
        if !values.is_empty() {
            values[0] = pivot;
        }
        values
    }

    fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
        ContextRecord {
            id: id.to_string(),
            run_id: format!("run-{id}"),
            bot_id: None,
            session_id: None,
            created_at: Utc::now(),
            role: "user".to_string(),
            state_metadata: Some(StateMetadata {
                step: Some(1),
                active_plan_id: Some("plan".to_string()),
                tokens_used: Some(10),
                custom: None,
            }),
            content_type: CONTENT_TYPE_TEXT.to_string(),
            text_payload: Some(format!("payload-{id}")),
            binary_payload: None,
            embedding: Some(make_embedding(embedding_pivot)),
        }
    }

    #[test]
    fn search_orders_by_distance() {
        let dir = TempDir::new().unwrap();
        let uri = dir.path().to_string_lossy().to_string();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let first = text_record("a", 0.0);
            let second = text_record("b", 1.0);
            store.add(&[first.clone(), second.clone()]).await.unwrap();

            let query = make_embedding(1.0);
            let results = store.search(&query, Some(2)).await.unwrap();

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].record.id, second.id);
            assert!(
                results[0].distance <= results[1].distance,
                "results not ordered by distance: {:?}",
                results
            );
        });
    }

    #[test]
    fn search_validates_query_length() {
        let dir = TempDir::new().unwrap();
        let uri = dir.path().to_string_lossy().to_string();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let store = ContextStore::open(&uri).await.unwrap();
            let err = store.search(&[0.0_f32], None).await.unwrap_err();
            let message = err.to_string();
            assert!(
                message.contains("embedding dimension"),
                "unexpected error message: {message}"
            );
        });
    }
}