Skip to main content

lance_graph/
lance_vector_search.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Lance Vector Search API for lance-graph
5//!
6//! This module provides a flexible API for vector similarity search that can work with:
7//! - In-memory RecordBatches (brute-force search)
8//! - Lance datasets (ANN search with indices)
9//!
10//! This is distinct from the UDF-based vector search (`vector_distance()`, `vector_similarity()`)
11//! which is integrated into Cypher queries. This API provides explicit two-step search for
12//! GraphRAG workflows where you want to:
13//! 1. Use Cypher for graph traversal and filtering
14//! 2. Use VectorSearch for similarity ranking with Lance ANN indices
15//!
16//! # Example
17//!
18//! ```ignore
19//! use lance_graph::lance_vector_search::VectorSearch;
20//! use lance_graph::ast::DistanceMetric;
21//!
22//! // Step 1: Run Cypher query to get candidates
23//! let candidates = query.execute(datasets, None).await?;
24//!
25//! // Step 2: Rerank by vector similarity
26//! let results = VectorSearch::new("embedding")
27//!     .query_vector(vec![0.1, 0.2, 0.3])
28//!     .metric(DistanceMetric::Cosine)
29//!     .top_k(10)
30//!     .search(&candidates)
31//!     .await?;
32//! ```
33
34use crate::ast::DistanceMetric;
35use crate::datafusion_planner::vector_ops;
36use crate::error::{GraphError, Result};
37use arrow::array::{Array, ArrayRef, Float32Array, UInt32Array};
38use arrow::compute::take;
39use arrow::datatypes::{DataType, Field, Schema};
40use arrow::record_batch::RecordBatch;
41use std::sync::Arc;
42
43/// Builder for vector similarity search operations
44///
45/// Supports both brute-force search on RecordBatches and ANN search on Lance datasets.
46#[derive(Debug, Clone)]
47pub struct VectorSearch {
48    /// Name of the vector column to search
49    column: String,
50    /// Query vector for similarity computation
51    query_vector: Option<Vec<f32>>,
52    /// Distance metric (L2, Cosine, Dot)
53    metric: DistanceMetric,
54    /// Number of results to return
55    top_k: usize,
56    /// Whether to include distance/similarity scores in output
57    include_distance: bool,
58    /// Name for the distance column (default: "_distance")
59    distance_column_name: String,
60}
61
62impl VectorSearch {
63    /// Create a new VectorSearch builder for the specified column
64    ///
65    /// # Arguments
66    /// * `column` - Name of the vector column in the data
67    ///
68    /// # Example
69    /// ```ignore
70    /// let search = VectorSearch::new("embedding");
71    /// ```
72    pub fn new(column: &str) -> Self {
73        Self {
74            column: column.to_string(),
75            query_vector: None,
76            metric: DistanceMetric::L2,
77            top_k: 10,
78            include_distance: true,
79            distance_column_name: "_distance".to_string(),
80        }
81    }
82
83    /// Set the query vector for similarity search
84    ///
85    /// # Arguments
86    /// * `vec` - The query vector (must match dimension of data vectors)
87    pub fn query_vector(mut self, vec: Vec<f32>) -> Self {
88        self.query_vector = Some(vec);
89        self
90    }
91
92    /// Set the distance metric
93    ///
94    /// # Arguments
95    /// * `metric` - Distance metric (L2, Cosine, or Dot)
96    pub fn metric(mut self, metric: DistanceMetric) -> Self {
97        self.metric = metric;
98        self
99    }
100
101    /// Set the number of results to return
102    ///
103    /// # Arguments
104    /// * `k` - Maximum number of results
105    pub fn top_k(mut self, k: usize) -> Self {
106        self.top_k = k;
107        self
108    }
109
110    /// Whether to include distance scores in the output
111    ///
112    /// # Arguments
113    /// * `include` - If true, adds a distance column to results
114    pub fn include_distance(mut self, include: bool) -> Self {
115        self.include_distance = include;
116        self
117    }
118
119    /// Set the name for the distance column
120    ///
121    /// # Arguments
122    /// * `name` - Column name for distance values (default: "_distance")
123    pub fn distance_column_name(mut self, name: &str) -> Self {
124        self.distance_column_name = name.to_string();
125        self
126    }
127
128    // Getters for accessing internal state (used by Python bindings)
129
130    /// Get the column name
131    pub fn column(&self) -> &str {
132        &self.column
133    }
134
135    /// Get the query vector if set
136    pub fn get_query_vector(&self) -> Option<&[f32]> {
137        self.query_vector.as_deref()
138    }
139
140    /// Get the distance metric
141    pub fn get_metric(&self) -> &DistanceMetric {
142        &self.metric
143    }
144
145    /// Get the top_k value
146    pub fn get_top_k(&self) -> usize {
147        self.top_k
148    }
149
150    /// Perform brute-force vector search on a RecordBatch
151    ///
152    /// This method computes distances for all vectors in the batch and returns
153    /// the top-k results sorted by distance (ascending).
154    ///
155    /// # Arguments
156    /// * `data` - RecordBatch containing the vector column
157    ///
158    /// # Returns
159    /// A new RecordBatch with the top-k rows, optionally including a distance column
160    ///
161    /// # Example
162    /// ```ignore
163    /// let results = VectorSearch::new("embedding")
164    ///     .query_vector(query_vec)
165    ///     .metric(DistanceMetric::Cosine)
166    ///     .top_k(10)
167    ///     .search(&candidates)
168    ///     .await?;
169    /// ```
170    pub async fn search(&self, data: &RecordBatch) -> Result<RecordBatch> {
171        let query_vector = self
172            .query_vector
173            .as_ref()
174            .ok_or_else(|| GraphError::ConfigError {
175                message: "Query vector is required for search".to_string(),
176                location: snafu::Location::new(file!(), line!(), column!()),
177            })?;
178
179        // Find the vector column
180        let schema = data.schema();
181        let column_idx = schema
182            .index_of(&self.column)
183            .map_err(|_| GraphError::ConfigError {
184                message: format!("Vector column '{}' not found in data", self.column),
185                location: snafu::Location::new(file!(), line!(), column!()),
186            })?;
187
188        let vector_column = data.column(column_idx);
189
190        // Extract vectors and compute distances
191        let vectors = vector_ops::extract_vectors(vector_column)?;
192        let distances = vector_ops::compute_vector_distances(&vectors, query_vector, &self.metric);
193
194        // Get top-k indices (sorted by distance ascending)
195        let top_k_indices = self.get_top_k_indices(&distances);
196
197        // Build result batch
198        self.build_result_batch(data, &top_k_indices, &distances)
199    }
200
201    /// Perform ANN vector search on a Lance dataset
202    ///
203    /// This method uses Lance's native ANN search via `scan().nearest()`,
204    /// which leverages vector indices (IVF_PQ, IVF_HNSW, etc.) when available.
205    ///
206    /// # Arguments
207    /// * `dataset` - Lance dataset with vector column
208    ///
209    /// # Returns
210    /// A RecordBatch with the top-k nearest neighbors
211    ///
212    /// # Example
213    /// ```ignore
214    /// let dataset = lance::Dataset::open("data.lance").await?;
215    /// let results = VectorSearch::new("embedding")
216    ///     .query_vector(query_vec)
217    ///     .metric(DistanceMetric::L2)
218    ///     .top_k(10)
219    ///     .search_lance(&dataset)
220    ///     .await?;
221    /// ```
222    pub async fn search_lance(&self, dataset: &lance::Dataset) -> Result<RecordBatch> {
223        use arrow::compute::concat_batches;
224        use futures::TryStreamExt;
225
226        let query_vector = self
227            .query_vector
228            .as_ref()
229            .ok_or_else(|| GraphError::ConfigError {
230                message: "Query vector is required for search".to_string(),
231                location: snafu::Location::new(file!(), line!(), column!()),
232            })?;
233
234        // Convert metric to Lance's DistanceType
235        let lance_metric = match self.metric {
236            DistanceMetric::L2 => lance_linalg::distance::DistanceType::L2,
237            DistanceMetric::Cosine => lance_linalg::distance::DistanceType::Cosine,
238            DistanceMetric::Dot => lance_linalg::distance::DistanceType::Dot,
239        };
240
241        // Create query array
242        let query_array = Float32Array::from(query_vector.clone());
243
244        // Build scanner with ANN search
245        let mut scanner = dataset.scan();
246        scanner
247            .nearest(&self.column, &query_array as &dyn Array, self.top_k)
248            .map_err(|e| GraphError::ExecutionError {
249                message: format!("Failed to configure nearest neighbor search: {}", e),
250                location: snafu::Location::new(file!(), line!(), column!()),
251            })?
252            .distance_metric(lance_metric);
253
254        // Execute scan and collect results
255        let stream = scanner
256            .try_into_stream()
257            .await
258            .map_err(|e| GraphError::ExecutionError {
259                message: format!("Failed to create scan stream: {}", e),
260                location: snafu::Location::new(file!(), line!(), column!()),
261            })?;
262
263        let batches: Vec<RecordBatch> =
264            stream
265                .try_collect()
266                .await
267                .map_err(|e| GraphError::ExecutionError {
268                    message: format!("Failed to collect scan results: {}", e),
269                    location: snafu::Location::new(file!(), line!(), column!()),
270                })?;
271
272        if batches.is_empty() {
273            // Return empty batch with dataset schema
274            let lance_schema = dataset.schema();
275            let arrow_schema: Schema = lance_schema.into();
276            return Ok(RecordBatch::new_empty(Arc::new(arrow_schema)));
277        }
278
279        // Concatenate batches
280        let schema = batches[0].schema();
281        concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
282            message: format!("Failed to concatenate result batches: {}", e),
283            location: snafu::Location::new(file!(), line!(), column!()),
284        })
285    }
286
287    /// Get indices of top-k smallest distances
288    fn get_top_k_indices(&self, distances: &[f32]) -> Vec<u32> {
289        // Create (index, distance) pairs
290        let mut indexed: Vec<(usize, f32)> = distances.iter().cloned().enumerate().collect();
291
292        // Sort by distance (ascending)
293        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
294
295        // Take top-k indices
296        indexed
297            .into_iter()
298            .take(self.top_k)
299            .map(|(idx, _)| idx as u32)
300            .collect()
301    }
302
303    /// Build result RecordBatch from original data and top-k indices
304    fn build_result_batch(
305        &self,
306        data: &RecordBatch,
307        indices: &[u32],
308        distances: &[f32],
309    ) -> Result<RecordBatch> {
310        let indices_array = UInt32Array::from(indices.to_vec());
311
312        // Take rows from each column
313        let mut columns: Vec<ArrayRef> = Vec::with_capacity(data.num_columns() + 1);
314        let mut fields: Vec<Field> = Vec::with_capacity(data.num_columns() + 1);
315
316        for (i, field) in data.schema().fields().iter().enumerate() {
317            let column = data.column(i);
318            let taken = take(column.as_ref(), &indices_array, None).map_err(|e| {
319                GraphError::ExecutionError {
320                    message: format!("Failed to select rows: {}", e),
321                    location: snafu::Location::new(file!(), line!(), column!()),
322                }
323            })?;
324            columns.push(taken);
325            fields.push(field.as_ref().clone());
326        }
327
328        // Add distance column if requested
329        if self.include_distance {
330            let selected_distances: Vec<f32> =
331                indices.iter().map(|&i| distances[i as usize]).collect();
332            let distance_array = Arc::new(Float32Array::from(selected_distances)) as ArrayRef;
333            columns.push(distance_array);
334            fields.push(Field::new(
335                &self.distance_column_name,
336                DataType::Float32,
337                false,
338            ));
339        }
340
341        let schema = Arc::new(Schema::new(fields));
342        RecordBatch::try_new(schema, columns).map_err(|e| GraphError::ExecutionError {
343            message: format!("Failed to create result batch: {}", e),
344            location: snafu::Location::new(file!(), line!(), column!()),
345        })
346    }
347}
348
349/// Result of a vector search operation with metadata
350#[derive(Debug)]
351pub struct VectorSearchResult {
352    /// The result data
353    pub data: RecordBatch,
354    /// Whether ANN index was used (vs brute-force)
355    pub used_ann_index: bool,
356    /// Number of vectors scanned (for brute-force)
357    pub vectors_scanned: usize,
358}
359
360#[cfg(test)]
361mod tests {
362    use super::*;
363    use arrow::array::{FixedSizeListArray, Int64Array, StringArray};
364    use arrow::datatypes::FieldRef;
365
366    fn create_test_batch() -> RecordBatch {
367        // Create schema with 3D embeddings
368        let schema = Arc::new(Schema::new(vec![
369            Field::new("id", DataType::Int64, false),
370            Field::new("name", DataType::Utf8, false),
371            Field::new(
372                "embedding",
373                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3),
374                false,
375            ),
376        ]));
377
378        // Create test data with embeddings
379        let embedding_data = vec![
380            1.0, 0.0, 0.0, // Alice - closest to [1,0,0]
381            0.9, 0.1, 0.0, // Bob - second closest
382            0.0, 1.0, 0.0, // Carol - orthogonal
383            0.0, 0.0, 1.0, // David - orthogonal
384            0.5, 0.5, 0.0, // Eve - medium distance
385        ];
386
387        let field = Arc::new(Field::new("item", DataType::Float32, true)) as FieldRef;
388        let values = Arc::new(Float32Array::from(embedding_data));
389        let embeddings = FixedSizeListArray::try_new(field, 3, values, None).unwrap();
390
391        RecordBatch::try_new(
392            schema,
393            vec![
394                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
395                Arc::new(StringArray::from(vec![
396                    "Alice", "Bob", "Carol", "David", "Eve",
397                ])),
398                Arc::new(embeddings),
399            ],
400        )
401        .unwrap()
402    }
403
404    #[tokio::test]
405    async fn test_vector_search_basic() {
406        let batch = create_test_batch();
407
408        let results = VectorSearch::new("embedding")
409            .query_vector(vec![1.0, 0.0, 0.0])
410            .metric(DistanceMetric::L2)
411            .top_k(3)
412            .search(&batch)
413            .await
414            .unwrap();
415
416        assert_eq!(results.num_rows(), 3);
417
418        // Check that Alice is first (closest to [1,0,0])
419        let names = results
420            .column(1)
421            .as_any()
422            .downcast_ref::<StringArray>()
423            .unwrap();
424        assert_eq!(names.value(0), "Alice");
425        assert_eq!(names.value(1), "Bob");
426    }
427
428    #[tokio::test]
429    async fn test_vector_search_cosine() {
430        let batch = create_test_batch();
431
432        let results = VectorSearch::new("embedding")
433            .query_vector(vec![1.0, 0.0, 0.0])
434            .metric(DistanceMetric::Cosine)
435            .top_k(2)
436            .search(&batch)
437            .await
438            .unwrap();
439
440        assert_eq!(results.num_rows(), 2);
441
442        let names = results
443            .column(1)
444            .as_any()
445            .downcast_ref::<StringArray>()
446            .unwrap();
447        assert_eq!(names.value(0), "Alice");
448    }
449
450    #[tokio::test]
451    async fn test_vector_search_with_distance() {
452        let batch = create_test_batch();
453
454        let results = VectorSearch::new("embedding")
455            .query_vector(vec![1.0, 0.0, 0.0])
456            .metric(DistanceMetric::L2)
457            .top_k(2)
458            .include_distance(true)
459            .search(&batch)
460            .await
461            .unwrap();
462
463        // Should have original columns + distance column
464        assert_eq!(results.num_columns(), 4);
465
466        // Check distance column exists and has correct name
467        let schema = results.schema();
468        assert!(schema.field_with_name("_distance").is_ok());
469
470        // First result should have distance 0 (identical vector)
471        let distances = results
472            .column(3)
473            .as_any()
474            .downcast_ref::<Float32Array>()
475            .unwrap();
476        assert_eq!(distances.value(0), 0.0);
477    }
478
479    #[tokio::test]
480    async fn test_vector_search_without_distance() {
481        let batch = create_test_batch();
482
483        let results = VectorSearch::new("embedding")
484            .query_vector(vec![1.0, 0.0, 0.0])
485            .metric(DistanceMetric::L2)
486            .top_k(2)
487            .include_distance(false)
488            .search(&batch)
489            .await
490            .unwrap();
491
492        // Should have only original columns
493        assert_eq!(results.num_columns(), 3);
494    }
495
496    #[tokio::test]
497    async fn test_vector_search_custom_distance_column() {
498        let batch = create_test_batch();
499
500        let results = VectorSearch::new("embedding")
501            .query_vector(vec![1.0, 0.0, 0.0])
502            .metric(DistanceMetric::L2)
503            .top_k(2)
504            .distance_column_name("similarity_score")
505            .search(&batch)
506            .await
507            .unwrap();
508
509        let schema = results.schema();
510        assert!(schema.field_with_name("similarity_score").is_ok());
511    }
512
513    #[tokio::test]
514    async fn test_vector_search_missing_query() {
515        let batch = create_test_batch();
516
517        let result = VectorSearch::new("embedding")
518            .metric(DistanceMetric::L2)
519            .top_k(2)
520            .search(&batch)
521            .await;
522
523        assert!(result.is_err());
524    }
525
526    #[tokio::test]
527    async fn test_vector_search_missing_column() {
528        let batch = create_test_batch();
529
530        let result = VectorSearch::new("nonexistent")
531            .query_vector(vec![1.0, 0.0, 0.0])
532            .top_k(2)
533            .search(&batch)
534            .await;
535
536        assert!(result.is_err());
537    }
538
539    #[tokio::test]
540    async fn test_vector_search_top_k_larger_than_data() {
541        let batch = create_test_batch();
542
543        let results = VectorSearch::new("embedding")
544            .query_vector(vec![1.0, 0.0, 0.0])
545            .metric(DistanceMetric::L2)
546            .top_k(100) // More than 5 rows
547            .search(&batch)
548            .await
549            .unwrap();
550
551        // Should return all 5 rows
552        assert_eq!(results.num_rows(), 5);
553    }
554}