1use crate::ast::DistanceMetric;
35use crate::datafusion_planner::vector_ops;
36use crate::error::{GraphError, Result};
37use arrow::array::{Array, ArrayRef, Float32Array, UInt32Array};
38use arrow::compute::take;
39use arrow::datatypes::{DataType, Field, Schema};
40use arrow::record_batch::RecordBatch;
41use std::sync::Arc;
42
43#[derive(Debug, Clone)]
47pub struct VectorSearch {
48 column: String,
50 query_vector: Option<Vec<f32>>,
52 metric: DistanceMetric,
54 top_k: usize,
56 include_distance: bool,
58 distance_column_name: String,
60}
61
62impl VectorSearch {
63 pub fn new(column: &str) -> Self {
73 Self {
74 column: column.to_string(),
75 query_vector: None,
76 metric: DistanceMetric::L2,
77 top_k: 10,
78 include_distance: true,
79 distance_column_name: "_distance".to_string(),
80 }
81 }
82
83 pub fn query_vector(mut self, vec: Vec<f32>) -> Self {
88 self.query_vector = Some(vec);
89 self
90 }
91
92 pub fn metric(mut self, metric: DistanceMetric) -> Self {
97 self.metric = metric;
98 self
99 }
100
101 pub fn top_k(mut self, k: usize) -> Self {
106 self.top_k = k;
107 self
108 }
109
110 pub fn include_distance(mut self, include: bool) -> Self {
115 self.include_distance = include;
116 self
117 }
118
119 pub fn distance_column_name(mut self, name: &str) -> Self {
124 self.distance_column_name = name.to_string();
125 self
126 }
127
128 pub async fn search(&self, data: &RecordBatch) -> Result<RecordBatch> {
149 let query_vector = self
150 .query_vector
151 .as_ref()
152 .ok_or_else(|| GraphError::ConfigError {
153 message: "Query vector is required for search".to_string(),
154 location: snafu::Location::new(file!(), line!(), column!()),
155 })?;
156
157 let schema = data.schema();
159 let column_idx = schema
160 .index_of(&self.column)
161 .map_err(|_| GraphError::ConfigError {
162 message: format!("Vector column '{}' not found in data", self.column),
163 location: snafu::Location::new(file!(), line!(), column!()),
164 })?;
165
166 let vector_column = data.column(column_idx);
167
168 let vectors = vector_ops::extract_vectors(vector_column)?;
170 let distances = vector_ops::compute_vector_distances(&vectors, query_vector, &self.metric);
171
172 let top_k_indices = self.get_top_k_indices(&distances);
174
175 self.build_result_batch(data, &top_k_indices, &distances)
177 }
178
179 pub async fn search_lance(&self, dataset: &lance::Dataset) -> Result<RecordBatch> {
201 use arrow::compute::concat_batches;
202 use futures::TryStreamExt;
203
204 let query_vector = self
205 .query_vector
206 .as_ref()
207 .ok_or_else(|| GraphError::ConfigError {
208 message: "Query vector is required for search".to_string(),
209 location: snafu::Location::new(file!(), line!(), column!()),
210 })?;
211
212 let lance_metric = match self.metric {
214 DistanceMetric::L2 => lance_linalg::distance::DistanceType::L2,
215 DistanceMetric::Cosine => lance_linalg::distance::DistanceType::Cosine,
216 DistanceMetric::Dot => lance_linalg::distance::DistanceType::Dot,
217 };
218
219 let query_array = Float32Array::from(query_vector.clone());
221
222 let mut scanner = dataset.scan();
224 scanner
225 .nearest(&self.column, &query_array as &dyn Array, self.top_k)
226 .map_err(|e| GraphError::ExecutionError {
227 message: format!("Failed to configure nearest neighbor search: {}", e),
228 location: snafu::Location::new(file!(), line!(), column!()),
229 })?
230 .distance_metric(lance_metric);
231
232 let stream = scanner
234 .try_into_stream()
235 .await
236 .map_err(|e| GraphError::ExecutionError {
237 message: format!("Failed to create scan stream: {}", e),
238 location: snafu::Location::new(file!(), line!(), column!()),
239 })?;
240
241 let batches: Vec<RecordBatch> =
242 stream
243 .try_collect()
244 .await
245 .map_err(|e| GraphError::ExecutionError {
246 message: format!("Failed to collect scan results: {}", e),
247 location: snafu::Location::new(file!(), line!(), column!()),
248 })?;
249
250 if batches.is_empty() {
251 let lance_schema = dataset.schema();
253 let arrow_schema: Schema = lance_schema.into();
254 return Ok(RecordBatch::new_empty(Arc::new(arrow_schema)));
255 }
256
257 let schema = batches[0].schema();
259 concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
260 message: format!("Failed to concatenate result batches: {}", e),
261 location: snafu::Location::new(file!(), line!(), column!()),
262 })
263 }
264
265 fn get_top_k_indices(&self, distances: &[f32]) -> Vec<u32> {
267 let mut indexed: Vec<(usize, f32)> = distances.iter().cloned().enumerate().collect();
269
270 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
272
273 indexed
275 .into_iter()
276 .take(self.top_k)
277 .map(|(idx, _)| idx as u32)
278 .collect()
279 }
280
281 fn build_result_batch(
283 &self,
284 data: &RecordBatch,
285 indices: &[u32],
286 distances: &[f32],
287 ) -> Result<RecordBatch> {
288 let indices_array = UInt32Array::from(indices.to_vec());
289
290 let mut columns: Vec<ArrayRef> = Vec::with_capacity(data.num_columns() + 1);
292 let mut fields: Vec<Field> = Vec::with_capacity(data.num_columns() + 1);
293
294 for (i, field) in data.schema().fields().iter().enumerate() {
295 let column = data.column(i);
296 let taken = take(column.as_ref(), &indices_array, None).map_err(|e| {
297 GraphError::ExecutionError {
298 message: format!("Failed to select rows: {}", e),
299 location: snafu::Location::new(file!(), line!(), column!()),
300 }
301 })?;
302 columns.push(taken);
303 fields.push(field.as_ref().clone());
304 }
305
306 if self.include_distance {
308 let selected_distances: Vec<f32> =
309 indices.iter().map(|&i| distances[i as usize]).collect();
310 let distance_array = Arc::new(Float32Array::from(selected_distances)) as ArrayRef;
311 columns.push(distance_array);
312 fields.push(Field::new(
313 &self.distance_column_name,
314 DataType::Float32,
315 false,
316 ));
317 }
318
319 let schema = Arc::new(Schema::new(fields));
320 RecordBatch::try_new(schema, columns).map_err(|e| GraphError::ExecutionError {
321 message: format!("Failed to create result batch: {}", e),
322 location: snafu::Location::new(file!(), line!(), column!()),
323 })
324 }
325}
326
327#[derive(Debug)]
329pub struct VectorSearchResult {
330 pub data: RecordBatch,
332 pub used_ann_index: bool,
334 pub vectors_scanned: usize,
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{FixedSizeListArray, Int64Array, StringArray};
    use arrow::datatypes::FieldRef;

    /// Five people with 3-dimensional embeddings; row order matches the
    /// `id` / `name` columns.
    fn create_test_batch() -> RecordBatch {
        let embedding_data: Vec<f32> = vec![
            1.0, 0.0, 0.0, // Alice
            0.9, 0.1, 0.0, // Bob
            0.0, 1.0, 0.0, // Carol
            0.0, 0.0, 1.0, // David
            0.5, 0.5, 0.0, // Eve
        ];
        let item_field = Arc::new(Field::new("item", DataType::Float32, true)) as FieldRef;
        let embeddings = FixedSizeListArray::try_new(
            item_field,
            3,
            Arc::new(Float32Array::from(embedding_data)),
            None,
        )
        .unwrap();

        let ids = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let names = StringArray::from(vec!["Alice", "Bob", "Carol", "David", "Eve"]);

        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                "embedding",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3),
                false,
            ),
        ]);

        RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(ids), Arc::new(names), Arc::new(embeddings)],
        )
        .unwrap()
    }

    /// A search on `column` whose query is the unit vector along the first axis.
    fn unit_x_search(column: &str) -> VectorSearch {
        VectorSearch::new(column).query_vector(vec![1.0, 0.0, 0.0])
    }

    /// The `name` column (index 1) of a result batch.
    fn name_column(batch: &RecordBatch) -> &StringArray {
        batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
    }

    #[tokio::test]
    async fn test_vector_search_basic() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::L2)
            .top_k(3);

        let results = search.search(&batch).await.unwrap();

        assert_eq!(results.num_rows(), 3);
        let names = name_column(&results);
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
    }

    #[tokio::test]
    async fn test_vector_search_cosine() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::Cosine)
            .top_k(2);

        let results = search.search(&batch).await.unwrap();

        assert_eq!(results.num_rows(), 2);
        assert_eq!(name_column(&results).value(0), "Alice");
    }

    #[tokio::test]
    async fn test_vector_search_with_distance() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(true);

        let results = search.search(&batch).await.unwrap();

        // Three input columns plus the appended distance column.
        assert_eq!(results.num_columns(), 4);
        assert!(results.schema().field_with_name("_distance").is_ok());

        let distances = results
            .column(3)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(distances.value(0), 0.0);
    }

    #[tokio::test]
    async fn test_vector_search_without_distance() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(false);

        let results = search.search(&batch).await.unwrap();

        assert_eq!(results.num_columns(), 3);
    }

    #[tokio::test]
    async fn test_vector_search_custom_distance_column() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .distance_column_name("similarity_score");

        let results = search.search(&batch).await.unwrap();

        assert!(results
            .schema()
            .field_with_name("similarity_score")
            .is_ok());
    }

    #[tokio::test]
    async fn test_vector_search_missing_query() {
        let batch = create_test_batch();
        // No query vector configured on purpose.
        let search = VectorSearch::new("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2);

        let result = search.search(&batch).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_missing_column() {
        let batch = create_test_batch();
        let search = unit_x_search("nonexistent").top_k(2);

        let result = search.search(&batch).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_top_k_larger_than_data() {
        let batch = create_test_batch();
        let search = unit_x_search("embedding")
            .metric(DistanceMetric::L2)
            .top_k(100);

        let results = search.search(&batch).await.unwrap();

        // Only five rows exist, so all of them come back.
        assert_eq!(results.num_rows(), 5);
    }
}