use crate::ast::DistanceMetric;
use crate::datafusion_planner::vector_ops;
use crate::error::{GraphError, Result};
use arrow::array::{Array, ArrayRef, Float32Array, UInt32Array};
use arrow::compute::take;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;
/// Builder-style configuration for a top-k vector similarity search.
///
/// Construct with [`VectorSearch::new`], chain the setters, then run
/// [`VectorSearch::search`] on an in-memory `RecordBatch` or
/// [`VectorSearch::search_lance`] on a Lance dataset.
#[derive(Debug, Clone)]
pub struct VectorSearch {
/// Name of the column holding the vectors to search.
column: String,
/// Vector compared against every row; searching while this is `None` is an error.
query_vector: Option<Vec<f32>>,
/// Distance metric used to rank rows.
metric: DistanceMetric,
/// Maximum number of rows returned.
top_k: usize,
/// Whether `search` appends a Float32 distance column to its results.
include_distance: bool,
/// Name given to the distance column appended by `search`.
distance_column_name: String,
}
impl VectorSearch {
    /// Creates a search over the vector column named `column`.
    ///
    /// Defaults: no query vector (one must be set before searching),
    /// [`DistanceMetric::L2`], `top_k = 10`, and a distance column named
    /// `"_distance"` included in the results.
    pub fn new(column: &str) -> Self {
        Self {
            column: column.to_string(),
            query_vector: None,
            metric: DistanceMetric::L2,
            top_k: 10,
            include_distance: true,
            distance_column_name: "_distance".to_string(),
        }
    }

    /// Sets the vector every row is compared against.
    pub fn query_vector(mut self, vec: Vec<f32>) -> Self {
        self.query_vector = Some(vec);
        self
    }

    /// Sets the distance metric (defaults to L2).
    pub fn metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Sets the maximum number of rows returned (defaults to 10).
    pub fn top_k(mut self, k: usize) -> Self {
        self.top_k = k;
        self
    }

    /// Controls whether results carry a distance column (defaults to `true`).
    pub fn include_distance(mut self, include: bool) -> Self {
        self.include_distance = include;
        self
    }

    /// Sets the name of the distance column (defaults to `"_distance"`).
    pub fn distance_column_name(mut self, name: &str) -> Self {
        self.distance_column_name = name.to_string();
        self
    }

    /// Name of the vector column being searched.
    pub fn column(&self) -> &str {
        &self.column
    }

    /// The query vector, if one has been set.
    pub fn get_query_vector(&self) -> Option<&[f32]> {
        self.query_vector.as_deref()
    }

    /// The configured distance metric.
    pub fn get_metric(&self) -> &DistanceMetric {
        &self.metric
    }

    /// The configured result limit.
    pub fn get_top_k(&self) -> usize {
        self.top_k
    }

    /// Brute-force search over an in-memory batch.
    ///
    /// Computes the distance from each row's vector in the configured column to the
    /// query vector. It keeps the `top_k` rows with the smallest distance, in ascending
    /// order; ties keep their original row order and NaN distances rank last. The
    /// result contains every original column, plus a Float32 distance column when
    /// `include_distance` is set.
    ///
    /// # Errors
    /// - `GraphError::ConfigError` if no query vector is set or the column is missing.
    /// - Any error from vector extraction.
    /// - `GraphError::ExecutionError` if the result batch cannot be assembled.
    pub async fn search(&self, data: &RecordBatch) -> Result<RecordBatch> {
        let query_vector = self.require_query_vector()?;
        let schema = data.schema();
        let column_idx = schema
            .index_of(&self.column)
            .map_err(|_| GraphError::ConfigError {
                message: format!("Vector column '{}' not found in data", self.column),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let vector_column = data.column(column_idx);
        let vectors = vector_ops::extract_vectors(vector_column)?;
        let distances = vector_ops::compute_vector_distances(&vectors, query_vector, &self.metric);
        let top_k_indices = self.get_top_k_indices(&distances);
        self.build_result_batch(data, &top_k_indices, &distances)
    }

    /// Nearest-neighbour search delegated to a Lance dataset scanner.
    ///
    /// Asks Lance for the `top_k` nearest rows to the query vector under the
    /// configured metric and concatenates the streamed batches. It then applies the
    /// `include_distance` / `distance_column_name` options to Lance's `_distance`
    /// column, so both search paths honour the same builder options.
    ///
    /// When Lance yields no batches, an empty batch with the dataset's schema is
    /// returned.
    ///
    /// # Errors
    /// - `GraphError::ConfigError` if no query vector is set.
    /// - `GraphError::ExecutionError` if the scan cannot be configured, run, or combined.
    pub async fn search_lance(&self, dataset: &lance::Dataset) -> Result<RecordBatch> {
        use arrow::compute::concat_batches;
        use futures::TryStreamExt;

        let query_vector = self.require_query_vector()?;
        let lance_metric = match self.metric {
            DistanceMetric::L2 => lance_linalg::distance::DistanceType::L2,
            DistanceMetric::Cosine => lance_linalg::distance::DistanceType::Cosine,
            DistanceMetric::Dot => lance_linalg::distance::DistanceType::Dot,
        };
        let query_array = Float32Array::from(query_vector.clone());
        let mut scanner = dataset.scan();
        scanner
            .nearest(&self.column, &query_array as &dyn Array, self.top_k)
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to configure nearest neighbor search: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?
            .distance_metric(lance_metric);
        let stream = scanner
            .try_into_stream()
            .await
            .map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to create scan stream: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            })?;
        let batches: Vec<RecordBatch> =
            stream
                .try_collect()
                .await
                .map_err(|e| GraphError::ExecutionError {
                    message: format!("Failed to collect scan results: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                })?;
        let Some(first) = batches.first() else {
            // NOTE(review): this schema comes straight from the dataset and has no
            // distance column, unlike the non-empty path — confirm callers accept that.
            let lance_schema = dataset.schema();
            let arrow_schema: Schema = lance_schema.into();
            return Ok(RecordBatch::new_empty(Arc::new(arrow_schema)));
        };
        let schema = first.schema();
        let combined = concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to concatenate result batches: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })?;
        self.apply_distance_options(combined)
    }

    /// Returns the configured query vector, or a `ConfigError` if none was set.
    fn require_query_vector(&self) -> Result<&Vec<f32>> {
        self.query_vector
            .as_ref()
            .ok_or_else(|| GraphError::ConfigError {
                message: "Query vector is required for search".to_string(),
                location: snafu::Location::new(file!(), line!(), column!()),
            })
    }

    /// Applies `include_distance` / `distance_column_name` to a Lance result batch.
    ///
    /// Lance names its distance column `_distance`. That column is dropped when
    /// distances are not wanted, and renamed when a custom name is configured.
    /// A batch without a `_distance` column is returned unchanged.
    fn apply_distance_options(&self, batch: RecordBatch) -> Result<RecordBatch> {
        const LANCE_DISTANCE_COLUMN: &str = "_distance";

        let schema = batch.schema();
        let Ok(distance_idx) = schema.index_of(LANCE_DISTANCE_COLUMN) else {
            return Ok(batch);
        };

        if !self.include_distance {
            let keep: Vec<usize> = (0..batch.num_columns())
                .filter(|&i| i != distance_idx)
                .collect();
            return batch.project(&keep).map_err(|e| GraphError::ExecutionError {
                message: format!("Failed to drop distance column: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            });
        }

        if self.distance_column_name == LANCE_DISTANCE_COLUMN {
            return Ok(batch);
        }

        let fields: Vec<Field> = schema
            .fields()
            .iter()
            .enumerate()
            .map(|(i, field)| {
                let field = field.as_ref().clone();
                if i == distance_idx {
                    field.with_name(self.distance_column_name.as_str())
                } else {
                    field
                }
            })
            .collect();
        let renamed = Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone()));
        RecordBatch::try_new(renamed, batch.columns().to_vec()).map_err(|e| {
            GraphError::ExecutionError {
                message: format!("Failed to rename distance column: {}", e),
                location: snafu::Location::new(file!(), line!(), column!()),
            }
        })
    }

    /// Returns the row indices of the `top_k` smallest distances, in ascending order.
    ///
    /// The ordering is a strict total order: ascending distance, NaN after every
    /// number, and ties broken by row index. This makes the result deterministic and
    /// identical to a stable sort for NaN-free input. A plain `partial_cmp` comparator
    /// is not a total order once NaN appears, which scrambles results and can make
    /// `sort_by` panic on Rust >= 1.81.
    fn get_top_k_indices(&self, distances: &[f32]) -> Vec<u32> {
        let by_distance_then_row = |a: &usize, b: &usize| -> std::cmp::Ordering {
            let (da, db) = (distances[*a], distances[*b]);
            let by_distance = match (da.is_nan(), db.is_nan()) {
                // Neither is NaN, so partial_cmp always returns Some here.
                (false, false) => da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal),
                // `true > false` puts NaN after any number.
                (a_nan, b_nan) => a_nan.cmp(&b_nan),
            };
            by_distance.then_with(|| a.cmp(b))
        };

        let mut rows: Vec<usize> = (0..distances.len()).collect();
        let k = self.top_k.min(rows.len());
        if k == 0 {
            return Vec::new();
        }
        if k < rows.len() {
            // Partition so the k best rows come first; only those need full sorting.
            rows.select_nth_unstable_by(k - 1, by_distance_then_row);
            rows.truncate(k);
        }
        rows.sort_unstable_by(by_distance_then_row);
        // Row indices fit in u32 for any batch `take` with UInt32 indices can address.
        rows.into_iter().map(|idx| idx as u32).collect()
    }

    /// Gathers rows `indices` from `data`, optionally appending their distances.
    ///
    /// Field definitions and schema-level metadata of `data` are preserved. The
    /// distance column is non-nullable Float32 named `distance_column_name`.
    /// Every entry of `indices` must be a valid index into `distances`.
    fn build_result_batch(
        &self,
        data: &RecordBatch,
        indices: &[u32],
        distances: &[f32],
    ) -> Result<RecordBatch> {
        let indices_array = UInt32Array::from(indices.to_vec());
        let source_schema = data.schema();
        let capacity = data.num_columns() + 1;
        let mut columns: Vec<ArrayRef> = Vec::with_capacity(capacity);
        let mut fields: Vec<Field> = Vec::with_capacity(capacity);
        for (column, field) in data.columns().iter().zip(source_schema.fields().iter()) {
            let taken = take(column.as_ref(), &indices_array, None).map_err(|e| {
                GraphError::ExecutionError {
                    message: format!("Failed to select rows: {}", e),
                    location: snafu::Location::new(file!(), line!(), column!()),
                }
            })?;
            columns.push(taken);
            fields.push(field.as_ref().clone());
        }
        if self.include_distance {
            let selected_distances: Vec<f32> =
                indices.iter().map(|&i| distances[i as usize]).collect();
            columns.push(Arc::new(Float32Array::from(selected_distances)) as ArrayRef);
            fields.push(Field::new(
                &self.distance_column_name,
                DataType::Float32,
                false,
            ));
        }
        let schema = Arc::new(Schema::new_with_metadata(
            fields,
            source_schema.metadata().clone(),
        ));
        RecordBatch::try_new(schema, columns).map_err(|e| GraphError::ExecutionError {
            message: format!("Failed to create result batch: {}", e),
            location: snafu::Location::new(file!(), line!(), column!()),
        })
    }
}
/// Vector search output bundled with execution details.
///
/// NOTE(review): no method of `VectorSearch` in this file returns this type — confirm
/// where it is populated and whether the field meanings below hold.
#[derive(Debug)]
pub struct VectorSearchResult {
/// The result rows.
pub data: RecordBatch,
/// Presumably `true` when an approximate-nearest-neighbour index served the query — confirm.
pub used_ann_index: bool,
/// Presumably the number of vectors compared during the search — confirm.
pub vectors_scanned: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{FixedSizeListArray, Int64Array, StringArray};
    use arrow::datatypes::FieldRef;

    /// Five people, each with a 3-d embedding:
    /// Alice [1,0,0], Bob [0.9,0.1,0], Carol [0,1,0], David [0,0,1], Eve [0.5,0.5,0].
    fn create_test_batch() -> RecordBatch {
        let item: FieldRef = Arc::new(Field::new("item", DataType::Float32, true));
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                "embedding",
                DataType::FixedSizeList(Arc::clone(&item), 3),
                false,
            ),
        ]));

        let flat: Vec<f32> = vec![
            1.0, 0.0, 0.0, // Alice
            0.9, 0.1, 0.0, // Bob
            0.0, 1.0, 0.0, // Carol
            0.0, 0.0, 1.0, // David
            0.5, 0.5, 0.0, // Eve
        ];
        let embeddings =
            FixedSizeListArray::try_new(item, 3, Arc::new(Float32Array::from(flat)), None)
                .expect("15 values split evenly into 5 lists of 3");

        let ids = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let names = StringArray::from(vec!["Alice", "Bob", "Carol", "David", "Eve"]);
        let columns: Vec<ArrayRef> = vec![Arc::new(ids), Arc::new(names), Arc::new(embeddings)];
        RecordBatch::try_new(schema, columns).expect("columns match the schema")
    }

    /// A search for the unit-x query `[1, 0, 0]` under `metric`, keeping `k` rows.
    fn unit_x_search(metric: DistanceMetric, k: usize) -> VectorSearch {
        VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(metric)
            .top_k(k)
    }

    /// The `name` column (index 1) of a result batch.
    fn names_of(batch: &RecordBatch) -> &StringArray {
        batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("column 1 is the Utf8 `name` column")
    }

    #[tokio::test]
    async fn test_vector_search_basic() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::L2, 3)
            .search(&batch)
            .await
            .unwrap();
        assert_eq!(results.num_rows(), 3);
        let names = names_of(&results);
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
    }

    #[tokio::test]
    async fn test_vector_search_cosine() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::Cosine, 2)
            .search(&batch)
            .await
            .unwrap();
        assert_eq!(results.num_rows(), 2);
        assert_eq!(names_of(&results).value(0), "Alice");
    }

    #[tokio::test]
    async fn test_vector_search_with_distance() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::L2, 2)
            .include_distance(true)
            .search(&batch)
            .await
            .unwrap();
        assert_eq!(results.num_columns(), 4);
        assert!(results.schema().field_with_name("_distance").is_ok());
        let distance_values = results
            .column(3)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(distance_values.value(0), 0.0);
    }

    #[tokio::test]
    async fn test_vector_search_without_distance() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::L2, 2)
            .include_distance(false)
            .search(&batch)
            .await
            .unwrap();
        assert_eq!(results.num_columns(), 3);
    }

    #[tokio::test]
    async fn test_vector_search_custom_distance_column() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::L2, 2)
            .distance_column_name("similarity_score")
            .search(&batch)
            .await
            .unwrap();
        assert!(results.schema().field_with_name("similarity_score").is_ok());
    }

    #[tokio::test]
    async fn test_vector_search_missing_query() {
        let batch = create_test_batch();
        let outcome = VectorSearch::new("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .search(&batch)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_missing_column() {
        let batch = create_test_batch();
        let outcome = VectorSearch::new("nonexistent")
            .query_vector(vec![1.0, 0.0, 0.0])
            .top_k(2)
            .search(&batch)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_top_k_larger_than_data() {
        let batch = create_test_batch();
        let results = unit_x_search(DistanceMetric::L2, 100)
            .search(&batch)
            .await
            .unwrap();
        assert_eq!(results.num_rows(), 5);
    }
}