use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use arrow_array::builder::{
    FixedSizeListBuilder, Float32Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
    StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
};
use arrow_array::types::Int8Type;
use arrow_array::{
    Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
    LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray,
    TimestampMicrosecondArray,
};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
use chrono::{DateTime, Timelike, Utc};
use futures::TryStreamExt;
use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use lance::{Error as LanceError, Result as LanceResult};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{error, info, warn};

use crate::record::{ContextRecord, SearchResult, StateMetadata};

const DEFAULT_EMBEDDING_DIM: i32 = 1536;
const DEFAULT_SEARCH_LIMIT: usize = 10;

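/// Configuration for Lance fragment compaction, used both by manual calls to
/// [`ContextStore::compact`] and by the optional background task.
///
/// Illustrative construction (not a doctest; the field values are examples only):
///
/// ```ignore
/// let config = CompactionConfig {
///     enabled: true,
///     min_fragments: 10,
///     check_interval_secs: 600,
///     ..CompactionConfig::default()
/// };
/// ```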
#[derive(Debug, Clone)]
pub struct CompactionConfig {
    /// Whether the background compaction task is started when the store is opened.
    pub enabled: bool,
    /// Minimum number of fragments before compaction is considered worthwhile.
    pub min_fragments: usize,
    /// Target number of rows per fragment after compaction.
    pub target_rows_per_fragment: usize,
    /// Maximum number of rows per group within a fragment.
    pub max_rows_per_group: usize,
    /// Whether deleted rows are rewritten out of fragments during compaction.
    pub materialize_deletions: bool,
    /// Fraction of deleted rows in a fragment that triggers materialization.
    pub materialize_deletions_threshold: f32,
    /// Number of threads used for compaction; `None` lets Lance choose.
    pub num_threads: Option<usize>,
    /// Interval, in seconds, between background compaction checks.
    pub check_interval_secs: u64,
    /// UTC hour ranges `(start, end)` during which compaction is skipped.
    pub quiet_hours: Vec<(u8, u8)>,
}

impl Default for CompactionConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            min_fragments: 5,
            target_rows_per_fragment: 1_000_000,
            max_rows_per_group: 1024,
            materialize_deletions: true,
            materialize_deletions_threshold: 0.1,
            num_threads: None,
            check_interval_secs: 300,
            quiet_hours: vec![],
        }
    }
}

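/// Snapshot of compaction activity returned by [`ContextStore::compaction_stats`].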
#[derive(Debug, Clone)]
pub struct CompactionStats {
    pub total_fragments: usize,
    pub is_compacting: bool,
    pub last_compaction: Option<DateTime<Utc>>,
    pub last_error: Option<String>,
    pub total_compactions: u64,
}
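
/// Mutable compaction bookkeeping shared with the background task.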
85
86struct CompactionState {
88 background_task: Option<JoinHandle<()>>,
89 is_compacting: bool,
90 last_compaction: Option<DateTime<Utc>>,
91 last_error: Option<String>,
92 total_compactions: u64,
93}
94
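/// Append-only context store backed by a Lance dataset, with vector search and
/// optional background fragment compaction. Clones share the same compaction
/// state, so a cloned handle observes the same background-task bookkeeping.
///
/// Illustrative usage (not a doctest; `records` and `query_embedding` are
/// placeholders):
///
/// ```ignore
/// let mut store = ContextStore::open("/tmp/context.lance").await?;
/// store.add(&records).await?;
/// let hits = store.search(&query_embedding, Some(5)).await?;
/// ```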
#[derive(Clone)]
pub struct ContextStore {
    dataset: Dataset,
    compaction_state: Arc<Mutex<CompactionState>>,
    pub compaction_config: CompactionConfig,
}

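/// Options used when opening or creating a [`ContextStore`].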
#[derive(Debug, Clone, Default)]
pub struct ContextStoreOptions {
    pub storage_options: Option<HashMap<String, String>>,
    pub compaction: CompactionConfig,
}

impl ContextStoreOptions {
    #[must_use]
    pub fn storage_options(&self) -> Option<HashMap<String, String>> {
        self.storage_options.clone()
    }
}

impl ContextStore {
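    /// Open (or create) a store at `uri` using default [`ContextStoreOptions`].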
    pub async fn open(uri: &str) -> LanceResult<Self> {
        Self::open_with_options(uri, ContextStoreOptions::default()).await
    }

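    /// Open the store at `uri`, creating an empty dataset if none exists, and
    /// start background compaction when it is enabled in `options`.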
    pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult<Self> {
        let storage_options = options.storage_options();
        let dataset = match Self::load_with_options(uri, storage_options.clone()).await {
            Ok(dataset) => dataset,
            Err(LanceError::DatasetNotFound { .. }) => {
                Self::create_with_options(uri, storage_options).await?
            }
            Err(err) => return Err(err),
        };

        let mut store = Self {
            dataset,
            compaction_state: Arc::new(Mutex::new(CompactionState {
                background_task: None,
                is_compacting: false,
                last_compaction: None,
                last_error: None,
                total_compactions: 0,
            })),
            compaction_config: options.compaction,
        };

        store.start_background_compaction().await?;

        Ok(store)
    }

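    /// Append `entries` to the dataset, returning the manifest version after the
    /// write (or the current version if `entries` is empty).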
    pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
        if entries.is_empty() {
            return Ok(self.dataset.manifest.version);
        }

        let batch = Self::records_to_batch(entries)?;
        let schema = batch.schema();
        let reader = RecordBatchIterator::new(
            vec![Ok::<RecordBatch, ArrowError>(batch)].into_iter(),
            schema,
        );
        self.dataset.append(reader, None).await?;

        Ok(self.dataset.manifest.version)
    }

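    /// Current version of the underlying dataset manifest.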
    pub fn version(&self) -> u64 {
        self.dataset.manifest.version
    }

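    /// Check out the dataset at `version_id`, replacing the current handle.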
    pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> {
        let dataset = self.dataset.checkout_version(version_id).await?;
        self.dataset = dataset;
        Ok(())
    }

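    /// Scan the dataset and return records, honouring the optional `limit` and `offset`.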
    pub async fn list(
        &self,
        limit: Option<usize>,
        offset: Option<usize>,
    ) -> LanceResult<Vec<ContextRecord>> {
        let mut scanner = self.dataset.scan();
        if let Some(limit) = limit {
            scanner.limit(Some(limit as i64), offset.map(|o| o as i64))?;
        } else if let Some(offset) = offset {
            scanner.limit(None, Some(offset as i64))?;
        }

        let mut stream = scanner.try_into_stream().await?;
        let mut results = Vec::new();
        while let Some(batch) = stream.try_next().await? {
            results.extend(batch_to_records(&batch)?);
        }
        Ok(results)
    }

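    /// Nearest-neighbour search over the `embedding` column.
    ///
    /// `query` must have exactly `DEFAULT_EMBEDDING_DIM` elements; `limit`
    /// defaults to `DEFAULT_SEARCH_LIMIT`.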
    pub async fn search(
        &self,
        query: &[f32],
        limit: Option<usize>,
    ) -> LanceResult<Vec<SearchResult>> {
        if query.len() != DEFAULT_EMBEDDING_DIM as usize {
            return Err(ArrowError::InvalidArgumentError(format!(
                "query length {} does not match embedding dimension {}",
                query.len(),
                DEFAULT_EMBEDDING_DIM
            ))
            .into());
        }

        let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
        if top_k == 0 {
            return Ok(Vec::new());
        }

        let query_array = Float32Array::from(query.to_vec());

        let mut scanner = self.dataset.scan();
        scanner.nearest("embedding", &query_array, top_k)?;
        scanner.limit(Some(top_k as i64), Some(0))?;

        let mut stream = scanner.try_into_stream().await?;
        let mut results = Vec::new();
        while let Some(batch) = stream.try_next().await? {
            results.extend(batch_to_search_results(&batch)?);
        }
        Ok(results)
    }

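    /// Run compaction now, using `options` if given or the store's configuration
    /// otherwise. Returns an error if a compaction is already in progress.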
    pub async fn compact(
        &mut self,
        options: Option<CompactionConfig>,
    ) -> LanceResult<CompactionMetrics> {
        let config = options.unwrap_or_else(|| self.compaction_config.clone());

        info!(
            "Starting compaction: {} fragments",
            self.dataset.count_fragments()
        );
        let start = std::time::Instant::now();

        {
            let mut state = self.compaction_state.lock().await;
            if state.is_compacting {
                warn!("Compaction already in progress, skipping");
                return Err(LanceError::from(ArrowError::InvalidArgumentError(
                    "Compaction already in progress".to_string(),
                )));
            }
            state.is_compacting = true;
        }

        let lance_options = CompactionOptions {
            target_rows_per_fragment: config.target_rows_per_fragment,
            max_rows_per_group: config.max_rows_per_group,
            materialize_deletions: config.materialize_deletions,
            materialize_deletions_threshold: config.materialize_deletions_threshold,
            num_threads: config.num_threads,
            ..Default::default()
        };

        let result = compact_files(&mut self.dataset, lance_options, None).await;

        let mut state = self.compaction_state.lock().await;
        state.is_compacting = false;

        match result {
            Ok(metrics) => {
                state.last_compaction = Some(Utc::now());
                state.total_compactions += 1;
                state.last_error = None;

                info!(
                    "Compaction completed in {:?}: removed {} fragments ({} files), added {} fragments ({} files)",
                    start.elapsed(),
                    metrics.fragments_removed,
                    metrics.files_removed,
                    metrics.fragments_added,
                    metrics.files_added
                );

                // Reload so the handle reflects the compacted manifest.
                self.dataset = Dataset::open(self.dataset.uri()).await?;

                Ok(metrics)
            }
            Err(e) => {
                error!("Compaction failed: {}", e);
                state.last_error = Some(e.to_string());
                Err(e)
            }
        }
    }

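    /// Whether compaction should run now: the fragment count must reach
    /// `min_fragments` and the current UTC hour must fall outside every
    /// configured quiet-hours range (ranges do not wrap past midnight).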
    pub async fn should_compact(&self) -> LanceResult<bool> {
        let fragment_count = self.dataset.count_fragments();

        if fragment_count < self.compaction_config.min_fragments {
            return Ok(false);
        }

        if !self.compaction_config.quiet_hours.is_empty() {
            let now = Utc::now();
            let current_hour = now.hour() as u8;

            for (start, end) in &self.compaction_config.quiet_hours {
                if current_hour >= *start && current_hour < *end {
                    info!("Skipping compaction during quiet hours ({}-{})", start, end);
                    return Ok(false);
                }
            }
        }

        Ok(true)
    }

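    /// Current compaction statistics, combining dataset state with task bookkeeping.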
    pub async fn compaction_stats(&self) -> LanceResult<CompactionStats> {
        let state = self.compaction_state.lock().await;

        Ok(CompactionStats {
            total_fragments: self.dataset.count_fragments(),
            is_compacting: state.is_compacting,
            last_compaction: state.last_compaction,
            last_error: state.last_error.clone(),
            total_compactions: state.total_compactions,
        })
    }

    async fn start_background_compaction(&mut self) -> LanceResult<()> {
        if !self.compaction_config.enabled {
            return Ok(());
        }

        let mut state = self.compaction_state.lock().await;
        if state.background_task.is_some() {
            warn!("Background compaction already running");
            return Ok(());
        }

        info!(
            "Starting background compaction (interval: {}s, min fragments: {})",
            self.compaction_config.check_interval_secs, self.compaction_config.min_fragments
        );

        let mut store_clone = self.clone();
        let interval_secs = self.compaction_config.check_interval_secs;

        let task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

            loop {
                interval.tick().await;

                match store_clone.should_compact().await {
                    Ok(true) => {
                        info!("Background compaction triggered");
                        if let Err(e) = store_clone.compact(None).await {
                            error!("Background compaction failed: {}", e);
                        }
                    }
                    Ok(false) => {}
                    Err(e) => {
                        error!("Error checking compaction need: {}", e);
                    }
                }
            }
        });

        state.background_task = Some(task);
        Ok(())
    }

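    /// Abort the background compaction task, if one is running.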
    pub async fn stop_background_compaction(&mut self) -> LanceResult<()> {
        let mut state = self.compaction_state.lock().await;

        if let Some(task) = state.background_task.take() {
            info!("Stopping background compaction");
            task.abort();
        }

        Ok(())
    }

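    /// Arrow schema of the context dataset; `records_to_batch` must produce
    /// columns in exactly this order.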
    pub fn schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("run_id", DataType::Utf8, false),
            Field::new("bot_id", DataType::Utf8, true),
            Field::new("session_id", DataType::Utf8, true),
            Field::new(
                "created_at",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new(
                "role",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "state_metadata",
                DataType::Struct(
                    vec![
                        Field::new("step", DataType::Int32, true),
                        Field::new("active_plan_id", DataType::Utf8, true),
                        Field::new("tokens_used", DataType::Int32, true),
                        Field::new("custom", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("content_type", DataType::Utf8, false),
            Field::new("text_payload", DataType::LargeUtf8, true),
            Field::new("binary_payload", DataType::LargeBinary, true),
            Field::new(
                "embedding",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    DEFAULT_EMBEDDING_DIM,
                ),
                true,
            ),
        ])
    }

    async fn load_with_options(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
    ) -> LanceResult<Dataset> {
        if let Some(options) = storage_options {
            DatasetBuilder::from_uri(uri)
                .with_storage_options(options)
                .load()
                .await
        } else {
            Dataset::open(uri).await
        }
    }

    async fn create_with_options(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
    ) -> LanceResult<Dataset> {
        let schema = Arc::new(Self::schema());
        let empty_batch = RecordBatch::new_empty(schema.clone());
        let batches = RecordBatchIterator::new(
            vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
            schema.clone(),
        );

        let mut params = WriteParams {
            mode: WriteMode::Create,
            ..Default::default()
        };

        if let Some(options) = storage_options {
            let store_params = ObjectStoreParams {
                storage_options: Some(options),
                ..Default::default()
            };
            params.store_params = Some(store_params);
        }

        Dataset::write(batches, uri, Some(params)).await
    }

    fn records_to_batch(entries: &[ContextRecord]) -> LanceResult<RecordBatch> {
        let mut id_builder = StringBuilder::new();
        let mut run_id_builder = StringBuilder::new();
        let mut bot_id_builder = StringBuilder::new();
        let mut session_id_builder = StringBuilder::new();
        let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len());
        let mut role_builder = StringDictionaryBuilder::<Int8Type>::new();
        let mut content_type_builder = StringBuilder::new();
        let mut text_builder = LargeStringBuilder::new();
        let mut binary_builder = LargeBinaryBuilder::new();

        let state_fields: Vec<FieldRef> = vec![
            Arc::new(Field::new("step", DataType::Int32, true)),
            Arc::new(Field::new("active_plan_id", DataType::Utf8, true)),
            Arc::new(Field::new("tokens_used", DataType::Int32, true)),
            Arc::new(Field::new("custom", DataType::Utf8, true)),
        ];
        let mut state_builder = StructBuilder::new(
            state_fields,
            vec![
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
                Box::new(Int32Builder::new()),
                Box::new(StringBuilder::new()),
            ],
        );

        let mut embedding_builder =
            FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM);

        for entry in entries {
            id_builder.append_value(&entry.id);
            run_id_builder.append_value(&entry.run_id);
            bot_id_builder.append_option(entry.bot_id.as_deref());
            session_id_builder.append_option(entry.session_id.as_deref());
            created_at_builder.append_value(entry.created_at.timestamp_micros());
            role_builder.append(&entry.role)?;
            content_type_builder.append_value(&entry.content_type);

            match &entry.text_payload {
                Some(value) => text_builder.append_value(value),
                None => text_builder.append_null(),
            }

            match &entry.binary_payload {
                Some(value) => binary_builder.append_value(value),
                None => binary_builder.append_null(),
            }

            if let Some(metadata) = &entry.state_metadata {
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_option(metadata.step);
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_option(metadata.active_plan_id.as_deref());
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_option(metadata.tokens_used);
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_option(metadata.custom.as_deref());
                state_builder.append(true);
            } else {
                state_builder
                    .field_builder::<Int32Builder>(0)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(1)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<Int32Builder>(2)
                    .unwrap()
                    .append_null();
                state_builder
                    .field_builder::<StringBuilder>(3)
                    .unwrap()
                    .append_null();
                state_builder.append(false);
            }

            if let Some(embedding) = &entry.embedding {
                if embedding.len() != DEFAULT_EMBEDDING_DIM as usize {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "embedding length {} does not match expected dimension {}",
                        embedding.len(),
                        DEFAULT_EMBEDDING_DIM
                    ))
                    .into());
                }
                {
                    let values_builder = embedding_builder.values();
                    for value in embedding {
                        values_builder.append_value(*value);
                    }
                }
                embedding_builder.append(true);
            } else {
                let values_builder = embedding_builder.values();
                for _ in 0..DEFAULT_EMBEDDING_DIM {
                    values_builder.append_null();
                }
                embedding_builder.append(false);
            }
        }

        let id_array: ArrayRef = Arc::new(id_builder.finish());
        let run_id_array: ArrayRef = Arc::new(run_id_builder.finish());
        let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish());
        let session_id_array: ArrayRef = Arc::new(session_id_builder.finish());
        let created_at_array: ArrayRef = Arc::new(created_at_builder.finish());
        let role_array: ArrayRef = Arc::new(role_builder.finish());
        let content_type_array: ArrayRef = Arc::new(content_type_builder.finish());
        let text_array: ArrayRef = Arc::new(text_builder.finish());
        let binary_array: ArrayRef = Arc::new(binary_builder.finish());
        let state_array: ArrayRef = Arc::new(state_builder.finish());
        let embedding_array: ArrayRef = Arc::new(embedding_builder.finish());

        let schema = Arc::new(Self::schema());
        let batch = RecordBatch::try_new(
            schema,
            vec![
                id_array,
                run_id_array,
                bot_id_array,
                session_id_array,
                created_at_array,
                role_array,
                state_array,
                content_type_array,
                text_array,
                binary_array,
                embedding_array,
            ],
        )?;

        Ok(batch)
    }
}

impl Drop for ContextStore {
    fn drop(&mut self) {
        if let Ok(mut state) = self.compaction_state.try_lock() {
            if let Some(task) = state.background_task.take() {
                task.abort();
            }
        }
    }
}

fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>> {
    let records = batch_to_records(batch)?;

    let distance_column = batch.column_by_name("_distance").ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(
            "search results missing _distance column".to_string(),
        ))
    })?;
    let distance_array = distance_column
        .as_ref()
        .as_any()
        .downcast_ref::<Float32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "_distance column has unexpected data type".to_string(),
            ))
        })?;

    Ok(records
        .into_iter()
        .enumerate()
        .map(|(i, record)| SearchResult {
            record,
            distance: distance_array.value(i),
        })
        .collect())
}

fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
    let id_array = column_as::<StringArray>(batch, "id")?;
    let run_id_array = column_as::<StringArray>(batch, "run_id")?;
    let bot_id_array = column_as_optional::<StringArray>(batch, "bot_id");
    let session_id_array = column_as_optional::<StringArray>(batch, "session_id");
    let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
    let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
    let state_array = column_as::<StructArray>(batch, "state_metadata")?;
    let content_type_array = column_as::<StringArray>(batch, "content_type")?;
    let text_array = column_as::<LargeStringArray>(batch, "text_payload")?;
    let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
    let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;

    let step_array = state_array
        .column(0)
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "step column has unexpected data type".to_string(),
            ))
        })?;
    let active_plan_array = state_array
        .column(1)
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "active_plan_id column has unexpected data type".to_string(),
            ))
        })?;
    let tokens_used_array = state_array
        .column(2)
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "tokens_used column has unexpected data type".to_string(),
            ))
        })?;
    let custom_array = state_array
        .column(3)
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "custom column has unexpected data type".to_string(),
            ))
        })?;

    let mut results = Vec::with_capacity(batch.num_rows());
    for row in 0..batch.num_rows() {
        let created_at =
            DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| {
                LanceError::from(ArrowError::InvalidArgumentError(format!(
                    "invalid timestamp value {}",
                    created_at_array.value(row)
                )))
            })?;

        let state_metadata = if state_array.is_null(row) {
            None
        } else {
            Some(StateMetadata {
                step: if step_array.is_null(row) {
                    None
                } else {
                    Some(step_array.value(row))
                },
                active_plan_id: if active_plan_array.is_null(row) {
                    None
                } else {
                    Some(active_plan_array.value(row).to_string())
                },
                tokens_used: if tokens_used_array.is_null(row) {
                    None
                } else {
                    Some(tokens_used_array.value(row))
                },
                custom: if custom_array.is_null(row) {
                    None
                } else {
                    Some(custom_array.value(row).to_string())
                },
            })
        };

        let text_payload = if text_array.is_null(row) {
            None
        } else {
            Some(text_array.value(row).to_string())
        };

        let binary_payload = if binary_array.is_null(row) {
            None
        } else {
            Some(binary_array.value(row).to_vec())
        };

        let embedding = if embedding_array.is_null(row) {
            None
        } else {
            Some(embedding_from_list(embedding_array, row)?)
        };

        let role = if role_array.is_null(row) {
            return Err(LanceError::from(ArrowError::InvalidArgumentError(
                "role column contains null values".to_string(),
            )));
        } else {
            let role_values = role_array
                .values()
                .as_any()
                .downcast_ref::<StringArray>()
                .ok_or_else(|| {
                    LanceError::from(ArrowError::InvalidArgumentError(
                        "role dictionary values are not strings".to_string(),
                    ))
                })?;
            let key = role_array.keys().value(row) as usize;
            role_values.value(key).to_string()
        };

        let bot_id = bot_id_array.and_then(|arr| {
            if arr.is_null(row) {
                None
            } else {
                Some(arr.value(row).to_string())
            }
        });

        let session_id = session_id_array.and_then(|arr| {
            if arr.is_null(row) {
                None
            } else {
                Some(arr.value(row).to_string())
            }
        });

        results.push(ContextRecord {
            id: id_array.value(row).to_string(),
            run_id: run_id_array.value(row).to_string(),
            bot_id,
            session_id,
            created_at,
            role,
            state_metadata,
            content_type: content_type_array.value(row).to_string(),
            text_payload,
            binary_payload,
            embedding,
        });
    }

    Ok(results)
}

fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
    let values = list.value(row);
    let float_array = values
        .as_ref()
        .as_any()
        .downcast_ref::<Float32Array>()
        .ok_or_else(|| {
            LanceError::from(ArrowError::InvalidArgumentError(
                "embedding column does not contain float32 values".to_string(),
            ))
        })?;
    let mut embedding = Vec::with_capacity(float_array.len());
    for idx in 0..float_array.len() {
        embedding.push(float_array.value(idx));
    }
    Ok(embedding)
}

fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
where
    A: Array + 'static,
{
    let column = batch.column_by_name(name).ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(format!(
            "column '{name}' not found"
        )))
    })?;
    column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
        LanceError::from(ArrowError::InvalidArgumentError(format!(
            "column '{name}' has unexpected data type"
        )))
    })
}

fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A>
where
    A: Array + 'static,
{
    batch
        .column_by_name(name)
        .and_then(|col| col.as_ref().as_any().downcast_ref::<A>())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde::CONTENT_TYPE_TEXT;
    use chrono::Utc;
    use tempfile::TempDir;

    fn make_embedding(pivot: f32) -> Vec<f32> {
        let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
        if !values.is_empty() {
            values[0] = pivot;
        }
        values
    }

    fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
        ContextRecord {
            id: id.to_string(),
            run_id: format!("run-{id}"),
            bot_id: None,
            session_id: None,
            created_at: Utc::now(),
            role: "user".to_string(),
            state_metadata: Some(StateMetadata {
                step: Some(1),
                active_plan_id: Some("plan".to_string()),
                tokens_used: Some(10),
                custom: None,
            }),
            content_type: CONTENT_TYPE_TEXT.to_string(),
            text_payload: Some(format!("payload-{id}")),
            binary_payload: None,
            embedding: Some(make_embedding(embedding_pivot)),
        }
    }

    #[test]
    fn search_orders_by_distance() {
        let dir = TempDir::new().unwrap();
        let uri = dir.path().to_string_lossy().to_string();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let first = text_record("a", 0.0);
            let second = text_record("b", 1.0);
            store.add(&[first.clone(), second.clone()]).await.unwrap();

            let query = make_embedding(1.0);
            let results = store.search(&query, Some(2)).await.unwrap();

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].record.id, second.id);
            assert!(
                results[0].distance <= results[1].distance,
                "results not ordered by distance: {:?}",
                results
            );
        });
    }

    #[test]
    fn search_validates_query_length() {
        let dir = TempDir::new().unwrap();
        let uri = dir.path().to_string_lossy().to_string();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let store = ContextStore::open(&uri).await.unwrap();
            let err = store.search(&[0.0_f32], None).await.unwrap_err();
            let message = err.to_string();
            assert!(
                message.contains("embedding dimension"),
                "unexpected error message: {message}"
            );
        });
    }
}