1use crate::ast::DistanceMetric;
35use crate::datafusion_planner::vector_ops;
36use crate::error::{GraphError, Result};
37use arrow::array::{Array, ArrayRef, Float32Array, UInt32Array};
38use arrow::compute::take;
39use arrow::datatypes::{DataType, Field, Schema};
40use arrow::record_batch::RecordBatch;
41use std::sync::Arc;
42
43#[derive(Debug, Clone)]
47pub struct VectorSearch {
48 column: String,
50 query_vector: Option<Vec<f32>>,
52 metric: DistanceMetric,
54 top_k: usize,
56 include_distance: bool,
58 distance_column_name: String,
60}
61
62impl VectorSearch {
63 pub fn new(column: &str) -> Self {
73 Self {
74 column: column.to_string(),
75 query_vector: None,
76 metric: DistanceMetric::L2,
77 top_k: 10,
78 include_distance: true,
79 distance_column_name: "_distance".to_string(),
80 }
81 }
82
83 pub fn query_vector(mut self, vec: Vec<f32>) -> Self {
88 self.query_vector = Some(vec);
89 self
90 }
91
92 pub fn metric(mut self, metric: DistanceMetric) -> Self {
97 self.metric = metric;
98 self
99 }
100
101 pub fn top_k(mut self, k: usize) -> Self {
106 self.top_k = k;
107 self
108 }
109
110 pub fn include_distance(mut self, include: bool) -> Self {
115 self.include_distance = include;
116 self
117 }
118
119 pub fn distance_column_name(mut self, name: &str) -> Self {
124 self.distance_column_name = name.to_string();
125 self
126 }
127
128 pub fn column(&self) -> &str {
132 &self.column
133 }
134
135 pub fn get_query_vector(&self) -> Option<&[f32]> {
137 self.query_vector.as_deref()
138 }
139
140 pub fn get_metric(&self) -> &DistanceMetric {
142 &self.metric
143 }
144
145 pub fn get_top_k(&self) -> usize {
147 self.top_k
148 }
149
150 pub async fn search(&self, data: &RecordBatch) -> Result<RecordBatch> {
171 let query_vector = self
172 .query_vector
173 .as_ref()
174 .ok_or_else(|| GraphError::ConfigError {
175 message: "Query vector is required for search".to_string(),
176 location: snafu::Location::new(file!(), line!(), column!()),
177 })?;
178
179 let schema = data.schema();
181 let column_idx = schema
182 .index_of(&self.column)
183 .map_err(|_| GraphError::ConfigError {
184 message: format!("Vector column '{}' not found in data", self.column),
185 location: snafu::Location::new(file!(), line!(), column!()),
186 })?;
187
188 let vector_column = data.column(column_idx);
189
190 let vectors = vector_ops::extract_vectors(vector_column)?;
192 let distances = vector_ops::compute_vector_distances(&vectors, query_vector, &self.metric);
193
194 let top_k_indices = self.get_top_k_indices(&distances);
196
197 self.build_result_batch(data, &top_k_indices, &distances)
199 }
200
201 pub async fn search_lance(&self, dataset: &lance::Dataset) -> Result<RecordBatch> {
223 use arrow::compute::concat_batches;
224 use futures::TryStreamExt;
225
226 let query_vector = self
227 .query_vector
228 .as_ref()
229 .ok_or_else(|| GraphError::ConfigError {
230 message: "Query vector is required for search".to_string(),
231 location: snafu::Location::new(file!(), line!(), column!()),
232 })?;
233
234 let lance_metric = match self.metric {
236 DistanceMetric::L2 => lance_linalg::distance::DistanceType::L2,
237 DistanceMetric::Cosine => lance_linalg::distance::DistanceType::Cosine,
238 DistanceMetric::Dot => lance_linalg::distance::DistanceType::Dot,
239 };
240
241 let query_array = Float32Array::from(query_vector.clone());
243
244 let mut scanner = dataset.scan();
246 scanner
247 .nearest(&self.column, &query_array as &dyn Array, self.top_k)
248 .map_err(|e| GraphError::ExecutionError {
249 message: format!("Failed to configure nearest neighbor search: {}", e),
250 location: snafu::Location::new(file!(), line!(), column!()),
251 })?
252 .distance_metric(lance_metric);
253
254 let stream = scanner
256 .try_into_stream()
257 .await
258 .map_err(|e| GraphError::ExecutionError {
259 message: format!("Failed to create scan stream: {}", e),
260 location: snafu::Location::new(file!(), line!(), column!()),
261 })?;
262
263 let batches: Vec<RecordBatch> =
264 stream
265 .try_collect()
266 .await
267 .map_err(|e| GraphError::ExecutionError {
268 message: format!("Failed to collect scan results: {}", e),
269 location: snafu::Location::new(file!(), line!(), column!()),
270 })?;
271
272 if batches.is_empty() {
273 let lance_schema = dataset.schema();
275 let arrow_schema: Schema = lance_schema.into();
276 return Ok(RecordBatch::new_empty(Arc::new(arrow_schema)));
277 }
278
279 let schema = batches[0].schema();
281 concat_batches(&schema, &batches).map_err(|e| GraphError::ExecutionError {
282 message: format!("Failed to concatenate result batches: {}", e),
283 location: snafu::Location::new(file!(), line!(), column!()),
284 })
285 }
286
287 fn get_top_k_indices(&self, distances: &[f32]) -> Vec<u32> {
289 let mut indexed: Vec<(usize, f32)> = distances.iter().cloned().enumerate().collect();
291
292 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
294
295 indexed
297 .into_iter()
298 .take(self.top_k)
299 .map(|(idx, _)| idx as u32)
300 .collect()
301 }
302
303 fn build_result_batch(
305 &self,
306 data: &RecordBatch,
307 indices: &[u32],
308 distances: &[f32],
309 ) -> Result<RecordBatch> {
310 let indices_array = UInt32Array::from(indices.to_vec());
311
312 let mut columns: Vec<ArrayRef> = Vec::with_capacity(data.num_columns() + 1);
314 let mut fields: Vec<Field> = Vec::with_capacity(data.num_columns() + 1);
315
316 for (i, field) in data.schema().fields().iter().enumerate() {
317 let column = data.column(i);
318 let taken = take(column.as_ref(), &indices_array, None).map_err(|e| {
319 GraphError::ExecutionError {
320 message: format!("Failed to select rows: {}", e),
321 location: snafu::Location::new(file!(), line!(), column!()),
322 }
323 })?;
324 columns.push(taken);
325 fields.push(field.as_ref().clone());
326 }
327
328 if self.include_distance {
330 let selected_distances: Vec<f32> =
331 indices.iter().map(|&i| distances[i as usize]).collect();
332 let distance_array = Arc::new(Float32Array::from(selected_distances)) as ArrayRef;
333 columns.push(distance_array);
334 fields.push(Field::new(
335 &self.distance_column_name,
336 DataType::Float32,
337 false,
338 ));
339 }
340
341 let schema = Arc::new(Schema::new(fields));
342 RecordBatch::try_new(schema, columns).map_err(|e| GraphError::ExecutionError {
343 message: format!("Failed to create result batch: {}", e),
344 location: snafu::Location::new(file!(), line!(), column!()),
345 })
346 }
347}
348
349#[derive(Debug)]
351pub struct VectorSearchResult {
352 pub data: RecordBatch,
354 pub used_ann_index: bool,
356 pub vectors_scanned: usize,
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{FixedSizeListArray, Int64Array, StringArray};
    use arrow::datatypes::FieldRef;

    /// Builds the five-row fixture (`id`, `name`, 3-d `embedding`) shared by every test.
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                "embedding",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3),
                false,
            ),
        ]));

        let embedding_data = vec![
            1.0, 0.0, 0.0, // Alice
            0.9, 0.1, 0.0, // Bob
            0.0, 1.0, 0.0, // Carol
            0.0, 0.0, 1.0, // David
            0.5, 0.5, 0.0, // Eve
        ];

        let field = Arc::new(Field::new("item", DataType::Float32, true)) as FieldRef;
        let values = Arc::new(Float32Array::from(embedding_data));
        let embeddings = FixedSizeListArray::try_new(field, 3, values, None).unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(StringArray::from(vec![
                    "Alice", "Bob", "Carol", "David", "Eve",
                ])),
                Arc::new(embeddings),
            ],
        )
        .unwrap()
    }

    /// Downcasts column `idx` of `batch` to the concrete array type `T`.
    fn column_as<T: 'static>(batch: &RecordBatch, idx: usize) -> &T {
        batch.column(idx).as_any().downcast_ref::<T>().unwrap()
    }

    #[tokio::test]
    async fn test_vector_search_basic() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(3)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 3);

        let names = column_as::<StringArray>(&results, 1);
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
    }

    #[tokio::test]
    async fn test_vector_search_cosine() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::Cosine)
            .top_k(2)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 2);

        let names = column_as::<StringArray>(&results, 1);
        assert_eq!(names.value(0), "Alice");
    }

    #[tokio::test]
    async fn test_vector_search_with_distance() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(true)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_columns(), 4);
        assert!(results.schema().field_with_name("_distance").is_ok());

        let distances = column_as::<Float32Array>(&results, 3);
        assert_eq!(distances.value(0), 0.0);
    }

    #[tokio::test]
    async fn test_vector_search_without_distance() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .include_distance(false)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_columns(), 3);
    }

    #[tokio::test]
    async fn test_vector_search_custom_distance_column() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(2)
            .distance_column_name("similarity_score")
            .search(&batch)
            .await
            .unwrap();

        assert!(results.schema().field_with_name("similarity_score").is_ok());
    }

    #[tokio::test]
    async fn test_vector_search_missing_query() {
        let batch = create_test_batch();

        let result = VectorSearch::new("embedding")
            .metric(DistanceMetric::L2)
            .top_k(2)
            .search(&batch)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_missing_column() {
        let batch = create_test_batch();

        let result = VectorSearch::new("nonexistent")
            .query_vector(vec![1.0, 0.0, 0.0])
            .top_k(2)
            .search(&batch)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_vector_search_top_k_larger_than_data() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(100)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 5);
    }

    #[tokio::test]
    async fn test_vector_search_zero_top_k() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .top_k(0)
            .search(&batch)
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 0);
        // Input columns plus the distance column are still present.
        assert_eq!(results.num_columns(), 4);
    }

    #[tokio::test]
    async fn test_vector_search_orders_by_ascending_distance() {
        let batch = create_test_batch();

        let results = VectorSearch::new("embedding")
            .query_vector(vec![1.0, 0.0, 0.0])
            .metric(DistanceMetric::L2)
            .top_k(5)
            .search(&batch)
            .await
            .unwrap();

        let distances = column_as::<Float32Array>(&results, 3);
        let distances: Vec<f32> = (0..distances.len()).map(|i| distances.value(i)).collect();
        assert!(distances.windows(2).all(|pair| pair[0] <= pair[1]));

        // Rows stay aligned with their distances; Carol (3) and David (4) are equidistant
        // from the query and keep their input order.
        let ids = column_as::<Int64Array>(&results, 0);
        let ids: Vec<i64> = (0..ids.len()).map(|i| ids.value(i)).collect();
        assert_eq!(ids, vec![1, 2, 5, 3, 4]);
    }
}