Skip to main content

lance_graph/
lance_vector_search.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Lance Vector Search API for lance-graph
5//!
6//! This module provides a flexible API for vector similarity search that can work with:
7//! - In-memory RecordBatches (brute-force search)
8//! - Lance datasets (ANN search with indices)
9//!
10//! This is distinct from the UDF-based vector search (`vector_distance()`, `vector_similarity()`)
11//! which is integrated into Cypher queries. This API provides explicit two-step search for
12//! GraphRAG workflows where you want to:
13//! 1. Use Cypher for graph traversal and filtering
14//! 2. Use VectorSearch for similarity ranking with Lance ANN indices
15//!
16//! # Example
17//!
18//! ```ignore
19//! use lance_graph::lance_vector_search::VectorSearch;
20//! use lance_graph::ast::DistanceMetric;
21//!
22//! // Step 1: Run Cypher query to get candidates
23//! let candidates = query.execute(datasets, None).await?;
24//!
25//! // Step 2: Rerank by vector similarity
26//! let results = VectorSearch::new("embedding")
27//!     .query_vector(vec![0.1, 0.2, 0.3])
28//!     .metric(DistanceMetric::Cosine)
29//!     .top_k(10)
30//!     .search(&candidates)
31//!     .await?;
32//! ```
33
34use crate::ast::DistanceMetric;
35use crate::datafusion_planner::vector_ops;
36use crate::error::{GraphError, Result};
37use arrow::array::{Array, ArrayRef, Float32Array, UInt32Array};
38use arrow::compute::take;
39use arrow::datatypes::{DataType, Field, Schema};
40use arrow::record_batch::RecordBatch;
41use std::sync::Arc;
42
43/// Builder for vector similarity search operations
44///
45/// Supports both brute-force search on RecordBatches and ANN search on Lance datasets.
46#[derive(Debug, Clone)]
47pub struct VectorSearch {
48    /// Name of the vector column to search
49    column: String,
50    /// Query vector for similarity computation
51    query_vector: Option<Vec<f32>>,
52    /// Distance metric (L2, Cosine, Dot)
53    metric: DistanceMetric,
54    /// Number of results to return
55    top_k: usize,
56    /// Whether to include distance/similarity scores in output
57    include_distance: bool,
58    /// Name for the distance column (default: "_distance")
59    distance_column_name: String,
60}
61
62impl VectorSearch {
63    /// Create a new VectorSearch builder for the specified column
64    ///
65    /// # Arguments
66    /// * `column` - Name of the vector column in the data
67    ///
68    /// # Example
69    /// ```ignore
70    /// let search = VectorSearch::new("embedding");
71    /// ```
72    pub fn new(column: &str) -> Self {
73        Self {
74            column: column.to_string(),
75            query_vector: None,
76            metric: DistanceMetric::L2,
77            top_k: 10,
78            include_distance: true,
79            distance_column_name: "_distance".to_string(),
80        }
81    }
82
83    /// Set the query vector for similarity search
84    ///
85    /// # Arguments
86    /// * `vec` - The query vector (must match dimension of data vectors)
87    pub fn query_vector(mut self, vec: Vec<f32>) -> Self {
88        self.query_vector = Some(vec);
89        self
90    }
91
92    /// Set the distance metric
93    ///
94    /// # Arguments
95    /// * `metric` - Distance metric (L2, Cosine, or Dot)
96    pub fn metric(mut self, metric: DistanceMetric) -> Self {
97        self.metric = metric;
98        self
99    }
100
101    /// Set the number of results to return
102    ///
103    /// # Arguments
104    /// * `k` - Maximum number of results
105    pub fn top_k(mut self, k: usize) -> Self {
106        self.top_k = k;
107        self
108    }
109
110    /// Whether to include distance scores in the output
111    ///
112    /// # Arguments
113    /// * `include` - If true, adds a distance column to results
114    pub fn include_distance(mut self, include: bool) -> Self {
115        self.include_distance = include;
116        self
117    }
118
119    /// Set the name for the distance column
120    ///
121    /// # Arguments
122    /// * `name` - Column name for distance values (default: "_distance")
123    pub fn distance_column_name(mut self, name: &str) -> Self {
124        self.distance_column_name = name.to_string();
125        self
126    }
127
128    /// Perform brute-force vector search on a RecordBatch
129    ///
130    /// This method computes distances for all vectors in the batch and returns
131    /// the top-k results sorted by distance (ascending).
132    ///
133    /// # Arguments
134    /// * `data` - RecordBatch containing the vector column
135    ///
136    /// # Returns
137    /// A new RecordBatch with the top-k rows, optionally including a distance column
138    ///
139    /// # Example
140    /// ```ignore
141    /// let results = VectorSearch::new("embedding")
142    ///     .query_vector(query_vec)
143    ///     .metric(DistanceMetric::Cosine)
144    ///     .top_k(10)
145    ///     .search(&candidates)
146    ///     .await?;
147    /// ```
148    pub async fn search(&self, data: &RecordBatch) -> Result<RecordBatch> {
149        let query_vector = self
150            .query_vector
151            .as_ref()
152            .ok_or_else(|| GraphError::ConfigError {
153                message: "Query vector is required for search".to_string(),
154                location: snafu::Location::new(file!(), line!(), column!()),
155            })?;
156
157        // Find the vector column
158        let schema = data.schema();
159        let column_idx = schema
160            .index_of(&self.column)
161            .map_err(|_| GraphError::ConfigError {
162                message: format!("Vector column '{}' not found in data", self.column),
163                location: snafu::Location::new(file!(), line!(), column!()),
164            })?;
165
166        let vector_column = data.column(column_idx);
167
168        // Extract vectors and compute distances
169        let vectors = vector_ops::extract_vectors(vector_column)?;
170        let distances = vector_ops::compute_vector_distances(&vectors, query_vector, &self.metric);
171
172        // Get top-k indices (sorted by distance ascending)
173        let top_k_indices = self.get_top_k_indices(&distances);
174
175        // Build result batch
176        self.build_result_batch(data, &top_k_indices, &distances)
177    }
178
179    /// Perform ANN vector search on a Lance dataset
180    ///
181    /// This method uses Lance's native ANN search via `scan().nearest()`,
182    /// which leverages vector indices (IVF_PQ, IVF_HNSW, etc.) when available.
183    ///
184    /// # Arguments
185    /// * `dataset` - Lance dataset with vector column
186    ///
187    /// # Returns
188    /// A RecordBatch with the top-k nearest neighbors
189    ///
190    /// # Example
191    /// ```ignore
192    /// let dataset = lance::Dataset::open("data.lance").await?;
193    /// let results = VectorSearch::new("embedding")
194    ///     .query_vector(query_vec)
195    ///     .metric(DistanceMetric::L2)
196    ///     .top_k(10)
197    ///     .search_lance(&dataset)
198    ///     .await?;
199    /// ```
200    pub async fn search_lance(&self, dataset: &lance::Dataset) -> Result<RecordBatch> {
201        use arrow::compute::concat_batches;
202        use futures::TryStreamExt;
203
204        let query_vector = self
205            .query_vector
206            .as_ref()
207            .ok_or_else(|| GraphError::ConfigError {
208                message: "Query vector is required for search".to_string(),
209                location: snafu::Location::new(file!(), line!(), column!()),
210            })?;
211
212        // Convert metric to Lance's DistanceType
213        let lance_metric = match self.metric {
214            DistanceMetric::L2 => lance_linalg::distance::DistanceType::L2,
215            DistanceMetric::Cosine => lance_linalg::distance::DistanceType::Cosine,
216            DistanceMetric::Dot => lance_linalg::distance::DistanceType::Dot,
217        };
218
219        // Create query array
220        let query_array = Float32Array::from(query_vector.clone());
221
222        // Build scanner with ANN search
223        let mut scanner = dataset.scan();
224        scanner
225            .nearest(&self.column, &query_array as &dyn Array, self.top_k)
226            .map_err(|e| GraphError::ExecutionError {
227                message: format!("Failed to configure nearest neighbor search: {}", e),
228                location: snafu::Location::new(file!(), line!(), column!()),
229            })?
230            .distance_metric(lance_metric);
231
232        // Execute scan and collect results
233        let stream = scanner
234            .try_into_stream()
235            .await
236            .map_err(|e| GraphError::ExecutionError {
237                message: format!("Failed to create scan stream: {}", e),
238                location: snafu::Location::new(file!(), line!(), column!()),
239            })?;
240
241        let batches: Vec<RecordBatch> =
242            stream
243                .try_collect()
244                .await
245                .map_err(|e| GraphError::ExecutionError {
246                    message: format!("Failed to collect scan results: {}", e),
247                    location: snafu::Location::new(file!(), line!(), column!()),
248                })?;
249
250        if batches.is_empty() {
251            // Return empty batch with dataset schema
252            let lance_schema = dataset.schema();
253            let arrow_schema: Schema = lance_schema.into();
254            return Ok(RecordBatch::new_empty(Arc::new(arrow_schema)));
255        }
256
257        // Concatenate batches
258        let schema = batches[0].schema();
259        concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
260            message: format!("Failed to concatenate result batches: {}", e),
261            location: snafu::Location::new(file!(), line!(), column!()),
262        })
263    }
264
265    /// Get indices of top-k smallest distances
266    fn get_top_k_indices(&self, distances: &[f32]) -> Vec<u32> {
267        // Create (index, distance) pairs
268        let mut indexed: Vec<(usize, f32)> = distances.iter().cloned().enumerate().collect();
269
270        // Sort by distance (ascending)
271        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
272
273        // Take top-k indices
274        indexed
275            .into_iter()
276            .take(self.top_k)
277            .map(|(idx, _)| idx as u32)
278            .collect()
279    }
280
281    /// Build result RecordBatch from original data and top-k indices
282    fn build_result_batch(
283        &self,
284        data: &RecordBatch,
285        indices: &[u32],
286        distances: &[f32],
287    ) -> Result<RecordBatch> {
288        let indices_array = UInt32Array::from(indices.to_vec());
289
290        // Take rows from each column
291        let mut columns: Vec<ArrayRef> = Vec::with_capacity(data.num_columns() + 1);
292        let mut fields: Vec<Field> = Vec::with_capacity(data.num_columns() + 1);
293
294        for (i, field) in data.schema().fields().iter().enumerate() {
295            let column = data.column(i);
296            let taken = take(column.as_ref(), &indices_array, None).map_err(|e| {
297                GraphError::ExecutionError {
298                    message: format!("Failed to select rows: {}", e),
299                    location: snafu::Location::new(file!(), line!(), column!()),
300                }
301            })?;
302            columns.push(taken);
303            fields.push(field.as_ref().clone());
304        }
305
306        // Add distance column if requested
307        if self.include_distance {
308            let selected_distances: Vec<f32> =
309                indices.iter().map(|&i| distances[i as usize]).collect();
310            let distance_array = Arc::new(Float32Array::from(selected_distances)) as ArrayRef;
311            columns.push(distance_array);
312            fields.push(Field::new(
313                &self.distance_column_name,
314                DataType::Float32,
315                false,
316            ));
317        }
318
319        let schema = Arc::new(Schema::new(fields));
320        RecordBatch::try_new(schema, columns).map_err(|e| GraphError::ExecutionError {
321            message: format!("Failed to create result batch: {}", e),
322            location: snafu::Location::new(file!(), line!(), column!()),
323        })
324    }
325}
326
327/// Result of a vector search operation with metadata
328#[derive(Debug)]
329pub struct VectorSearchResult {
330    /// The result data
331    pub data: RecordBatch,
332    /// Whether ANN index was used (vs brute-force)
333    pub used_ann_index: bool,
334    /// Number of vectors scanned (for brute-force)
335    pub vectors_scanned: usize,
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{FixedSizeListArray, Int64Array, StringArray};
    use arrow::datatypes::FieldRef;

    /// Build the 5-row fixture (`id`, `name`, 3-d `embedding`) shared by every test.
    fn create_test_batch() -> RecordBatch {
        // Create schema with 3D embeddings
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                "embedding",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3),
                false,
            ),
        ]));

        // Create test data with embeddings
        let embedding_data = vec![
            1.0, 0.0, 0.0, // Alice - closest to [1,0,0]
            0.9, 0.1, 0.0, // Bob - second closest
            0.0, 1.0, 0.0, // Carol - orthogonal
            0.0, 0.0, 1.0, // David - orthogonal
            0.5, 0.5, 0.0, // Eve - medium distance
        ];

        let field = Arc::new(Field::new("item", DataType::Float32, true)) as FieldRef;
        let values = Arc::new(Float32Array::from(embedding_data));
        let embeddings = FixedSizeListArray::try_new(field, 3, values, None).unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(StringArray::from(vec![
                    "Alice", "Bob", "Carol", "David", "Eve",
                ])),
                Arc::new(embeddings),
            ],
        )
        .unwrap()
    }

    /// Collect the `id` column (index 0) of a result batch.
    fn ids(batch: &RecordBatch) -> Vec<i64> {
        batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("id column should be Int64")
            .iter()
            .map(|v| v.expect("id should not be null"))
            .collect()
    }

    /// Collect the `name` column (index 1) of a result batch as owned strings.
    fn names(batch: &RecordBatch) -> Vec<String> {
        batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("name column should be Utf8")
            .iter()
            .map(|v| v.expect("name should not be null").to_string())
            .collect()
    }

    /// Collect the distance column called `column_name` from a result batch.
    fn distances(batch: &RecordBatch, column_name: &str) -> Vec<f32> {
        batch
            .column_by_name(column_name)
            .expect("distance column should exist")
            .as_any()
            .downcast_ref::<Float32Array>()
            .expect("distance column should be Float32")
            .iter()
            .map(|v| v.expect("distance should not be null"))
            .collect()
    }

    /// True when `values` is sorted in non-decreasing order.
    fn is_non_decreasing(values: &[f32]) -> bool {
        values.windows(2).all(|w| w[0] <= w[1])
    }

    #[tokio::test]
    async fn test_vector_search_basic() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(3)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 3);

        // Alice is identical to the query, Bob is nearly so, and Eve is next;
        // Carol and David are farther away.
        assert_eq!(names(&results), vec!["Alice", "Bob", "Eve"]);
        // Every column must be gathered with the same row indices.
        assert_eq!(ids(&results), vec![1, 2, 5]);
    }

    #[tokio::test]
    async fn test_vector_search_cosine() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::Cosine)
            .top_k(2)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 2);
        assert_eq!(names(&results), vec!["Alice", "Bob"]);
    }

    #[tokio::test]
    async fn test_vector_search_with_distance() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(true)
            .search(&batch)
            .await
            .unwrap();

        // Should have original columns + distance column
        assert_eq!(results.num_columns(), 4);
        assert!(results.schema().field_with_name("_distance").is_ok());

        let dists = distances(&results, "_distance");
        // First result should have distance 0 (identical vector)
        assert_eq!(dists[0], 0.0);
        assert!(is_non_decreasing(&dists));
    }

    #[tokio::test]
    async fn test_vector_search_without_distance() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(false)
            .search(&batch)
            .await
            .unwrap();

        // Should have only original columns
        assert_eq!(results.num_columns(), 3);
        assert!(results.schema().field_with_name("_distance").is_err());
    }

    #[tokio::test]
    async fn test_vector_search_custom_distance_column() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .distance_column_name("similarity_score")
            .search(&batch)
            .await
            .unwrap();

        let schema = results.schema();
        assert!(schema.field_with_name("similarity_score").is_ok());
        // The custom name replaces the default rather than adding a second column.
        assert!(schema.field_with_name("_distance").is_err());
    }

    #[tokio::test]
    async fn test_vector_search_missing_query() {
        let batch = create_test_batch();

        let result = VectorSearch::new("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .search(&batch)
            .await;

        assert!(matches!(result, Err(GraphError::ConfigError { .. })));
    }

    #[tokio::test]
    async fn test_vector_search_missing_column() {
        let batch = create_test_batch();

        let result = VectorSearch::new("nonexistent")
            .query_vector(vec![1.0, 0.0, 0.0])
            .top_k(2)
            .search(&batch)
            .await;

        assert!(matches!(result, Err(GraphError::ConfigError { .. })));
    }

    #[tokio::test]
    async fn test_vector_search_top_k_larger_than_data() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(100) // More than 5 rows
            .search(&batch)
            .await
            .unwrap();

        // Should return all 5 rows, still ordered by distance
        assert_eq!(results.num_rows(), 5);
        assert!(is_non_decreasing(&distances(&results, "_distance")));
    }

    #[tokio::test]
    async fn test_vector_search_top_k_zero() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(0)
            .search(&batch)
            .await
            .unwrap();

        // No rows, but the schema (original columns + distance) is kept.
        assert_eq!(results.num_rows(), 0);
        assert_eq!(results.num_columns(), 4);
    }
}