Skip to main content

lance_graph/datafusion_planner/
vector_ops.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Vector Operations
5//!
6//! Helpers for vector similarity search and distance computation
7
8use crate::ast::DistanceMetric;
9use crate::error::{GraphError, Result};
10use arrow::array::{Array, ArrayRef, FixedSizeListArray, Float32Array, ListArray};
11
12/// Extract vectors from Arrow ListArray or FixedSizeListArray
13///
14/// Accepts both types for user convenience:
15/// - FixedSizeListArray: from Lance datasets or explicit construction
16/// - ListArray: from natural table construction with nested lists
17pub fn extract_vectors(array: &ArrayRef) -> Result<Vec<Vec<f32>>> {
18    // Try FixedSizeListArray first (more common in Lance)
19    if let Some(list_array) = array.as_any().downcast_ref::<FixedSizeListArray>() {
20        let mut vectors = Vec::with_capacity(list_array.len());
21        for i in 0..list_array.len() {
22            if list_array.is_null(i) {
23                return Err(GraphError::ExecutionError {
24                    message: "Null vector in FixedSizeListArray".to_string(),
25                    location: snafu::Location::new(file!(), line!(), column!()),
26                });
27            }
28            let value_array = list_array.value(i);
29            let float_array = value_array
30                .as_any()
31                .downcast_ref::<Float32Array>()
32                .ok_or_else(|| GraphError::ExecutionError {
33                    message: "Expected Float32Array in vector".to_string(),
34                    location: snafu::Location::new(file!(), line!(), column!()),
35                })?;
36
37            let vec: Vec<f32> = (0..float_array.len())
38                .map(|j| float_array.value(j))
39                .collect();
40            vectors.push(vec);
41        }
42        return Ok(vectors);
43    }
44
45    // Try ListArray (from nested list construction)
46    if let Some(list_array) = array.as_any().downcast_ref::<ListArray>() {
47        let mut vectors = Vec::with_capacity(list_array.len());
48        for i in 0..list_array.len() {
49            if list_array.is_null(i) {
50                return Err(GraphError::ExecutionError {
51                    message: "Null vector in ListArray".to_string(),
52                    location: snafu::Location::new(file!(), line!(), column!()),
53                });
54            }
55            let value_array = list_array.value(i);
56            let float_array = value_array
57                .as_any()
58                .downcast_ref::<Float32Array>()
59                .ok_or_else(|| GraphError::ExecutionError {
60                    message: "Expected Float32Array in vector".to_string(),
61                    location: snafu::Location::new(file!(), line!(), column!()),
62                })?;
63
64            let vec: Vec<f32> = (0..float_array.len())
65                .map(|j| float_array.value(j))
66                .collect();
67            vectors.push(vec);
68        }
69        return Ok(vectors);
70    }
71
72    Err(GraphError::ExecutionError {
73        message: "Expected ListArray or FixedSizeListArray for vector column".to_string(),
74        location: snafu::Location::new(file!(), line!(), column!()),
75    })
76}
77
78/// Extract a single vector from a ScalarValue
79/// This avoids allocating a full array when we just need one vector
80pub fn extract_single_vector_from_scalar(
81    scalar: &datafusion::scalar::ScalarValue,
82) -> Result<Vec<f32>> {
83    // Convert scalar to a single-element array, then extract
84    let array = scalar.to_array().map_err(|e| GraphError::ExecutionError {
85        message: format!("Failed to convert scalar to array: {}", e),
86        location: snafu::Location::new(file!(), line!(), column!()),
87    })?;
88
89    let list_array = array
90        .as_any()
91        .downcast_ref::<FixedSizeListArray>()
92        .ok_or_else(|| GraphError::ExecutionError {
93            message: "Expected FixedSizeListArray for vector scalar".to_string(),
94            location: snafu::Location::new(file!(), line!(), column!()),
95        })?;
96
97    if list_array.is_empty() {
98        return Err(GraphError::ExecutionError {
99            message: "Empty vector array".to_string(),
100            location: snafu::Location::new(file!(), line!(), column!()),
101        });
102    }
103
104    let value_array = list_array.value(0);
105    let float_array = value_array
106        .as_any()
107        .downcast_ref::<Float32Array>()
108        .ok_or_else(|| GraphError::ExecutionError {
109            message: "Expected Float32Array in vector".to_string(),
110            location: snafu::Location::new(file!(), line!(), column!()),
111        })?;
112
113    Ok((0..float_array.len())
114        .map(|j| float_array.value(j))
115        .collect())
116}
117
/// Euclidean (L2) distance between `a` and `b`.
///
/// When the two slices have different lengths the distance is undefined and
/// `f32::MAX` is returned as a "worst possible" sentinel.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| (lhs - rhs).powi(2))
            .sum();
        squared_sum.sqrt()
    } else {
        // Dimension mismatch - return max distance
        f32::MAX
    }
}
131
/// Compute cosine distance (`1 - cosine_similarity`) between two vectors.
///
/// Returns a value in `[0, 2]` where 0 means identical direction and 2 means
/// opposite direction. The maximum distance `2.0` is also returned when the
/// slices differ in length or when either vector has zero norm, because the
/// angle between them is undefined in those cases.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        // Dimension mismatch - return max distance
        return 2.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 2.0; // Maximum distance for zero vectors
    }

    // Rounding can push the ratio marginally outside [-1, 1] (e.g. a vector
    // compared with itself), which would yield a negative distance or one
    // above 2. Clamp so the documented range always holds.
    let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    1.0 - similarity
}
151
/// Compute cosine similarity (for the `vector_similarity` function).
///
/// Returns a value in `[-1, 1]` where 1 means identical direction and -1
/// means opposite direction. The minimum similarity `-1.0` is also returned
/// when the slices differ in length or when either vector has zero norm,
/// because the angle between them is undefined in those cases.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        // Dimension mismatch - return minimum similarity
        return -1.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return -1.0; // Minimum similarity for zero vectors
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp so the
    // documented range always holds.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
170
/// Dot-product "distance" between `a` and `b`.
///
/// The dot product is negated so that, like the other distances, a lower
/// value means a better match when sorting. Slices of different lengths
/// yield `f32::MAX`, the worst distance, so they sort last.
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let dot: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
        -dot
    } else {
        // Dimension mismatch - return worst distance to exclude from results
        f32::MAX
    }
}
181
/// Dot-product similarity between `a` and `b` (for the `vector_similarity`
/// function); higher means more similar.
///
/// Slices of different lengths yield `f32::MIN`, the worst similarity.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let dot: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
        dot
    } else {
        // Dimension mismatch - return worst similarity to exclude from results
        f32::MIN
    }
}
191
192/// Compute vector distance for an array of vectors against a single query vector
193pub fn compute_vector_distances(
194    vectors: &[Vec<f32>],
195    query_vector: &[f32],
196    metric: &DistanceMetric,
197) -> Vec<f32> {
198    vectors
199        .iter()
200        .map(|v| match metric {
201            DistanceMetric::L2 => l2_distance(v, query_vector),
202            DistanceMetric::Cosine => cosine_distance(v, query_vector),
203            DistanceMetric::Dot => dot_product_distance(v, query_vector),
204        })
205        .collect()
206}
207
208/// Compute vector similarities for an array of vectors against a single query vector
209pub fn compute_vector_similarities(
210    vectors: &[Vec<f32>],
211    query_vector: &[f32],
212    metric: &DistanceMetric,
213) -> Vec<f32> {
214    vectors
215        .iter()
216        .map(|v| match metric {
217            DistanceMetric::L2 => {
218                // For L2, convert distance to similarity (inverse)
219                let dist = l2_distance(v, query_vector);
220                if dist == 0.0 {
221                    1.0 // Perfect match
222                } else {
223                    1.0 / (1.0 + dist) // Similarity decreases as distance increases
224                }
225            }
226            DistanceMetric::Cosine => cosine_similarity(v, query_vector),
227            DistanceMetric::Dot => dot_product_similarity(v, query_vector),
228        })
229        .collect()
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListBuilder;
    use arrow::buffer::NullBuffer;
    use arrow::datatypes::{DataType, Field};
    use datafusion::scalar::ScalarValue;
    use std::sync::Arc;

    /// Build a `FixedSizeListArray` of `dim`-wide Float32 rows from a flat
    /// buffer of values, with an optional row-level null mask.
    fn fixed_size_list(dim: i32, flat: Vec<f32>, nulls: Option<NullBuffer>) -> FixedSizeListArray {
        let item_field = Arc::new(Field::new("item", DataType::Float32, true));
        let values: ArrayRef = Arc::new(Float32Array::from(flat));
        FixedSizeListArray::try_new(item_field, dim, values, nulls).unwrap()
    }

    /// Flat buffer holding the three 3D unit vectors, one per row.
    fn unit_vectors_flat() -> Vec<f32> {
        vec![
            1.0, 0.0, 0.0, // Vector 1
            0.0, 1.0, 0.0, // Vector 2
            0.0, 0.0, 1.0, // Vector 3
        ]
    }

    #[test]
    fn test_l2_distance() {
        let dist = l2_distance(&[1.0_f32, 0.0, 0.0], &[0.0_f32, 1.0, 0.0]);
        assert!((dist - 1.414).abs() < 0.01); // sqrt(2)
    }

    #[test]
    fn test_l2_distance_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        assert_eq!(l2_distance(&v, &v), 0.0);
    }

    #[test]
    fn test_cosine_distance() {
        let v = [1.0_f32, 0.0, 0.0];
        assert_eq!(cosine_distance(&v, &v), 0.0); // Identical vectors
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        assert_eq!(cosine_distance(&x_axis, &y_axis), 1.0); // Orthogonal vectors
    }

    #[test]
    fn test_cosine_similarity() {
        let v = [1.0_f32, 0.0, 0.0];
        assert_eq!(cosine_similarity(&v, &v), 1.0); // Identical
    }

    #[test]
    fn test_dot_product() {
        let sim = dot_product_similarity(&[1.0_f32, 2.0, 3.0], &[4.0_f32, 5.0, 6.0]);
        assert_eq!(sim, 32.0); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dimension_mismatch() {
        let short = [1.0_f32, 2.0];
        let long = [1.0_f32, 2.0, 3.0];

        // Every metric reports its "worst" sentinel for mismatched dimensions.
        assert_eq!(l2_distance(&short, &long), f32::MAX);
        assert_eq!(cosine_distance(&short, &long), 2.0);
        assert_eq!(dot_product_distance(&short, &long), f32::MAX);
        assert_eq!(dot_product_similarity(&short, &long), f32::MIN);
    }

    #[test]
    fn test_extract_single_vector_from_scalar() {
        // A FixedSizeList scalar holding the 3D vector [1.0, 2.0, 3.0]
        let list_array = fixed_size_list(3, vec![1.0, 2.0, 3.0], None);
        let scalar = ScalarValue::try_from_array(&list_array, 0).unwrap();

        let extracted = extract_single_vector_from_scalar(&scalar)
            .expect("a FixedSizeList scalar should yield its vector");

        assert_eq!(extracted.len(), 3);
        assert_eq!(extracted, vec![1.0_f32, 2.0, 3.0]);
    }

    #[test]
    fn test_extract_single_vector_from_scalar_different_dimensions() {
        // A 5D vector [0.1, 0.2, 0.3, 0.4, 0.5]
        let list_array = fixed_size_list(5, vec![0.1, 0.2, 0.3, 0.4, 0.5], None);
        let scalar = ScalarValue::try_from_array(&list_array, 0).unwrap();

        let extracted = extract_single_vector_from_scalar(&scalar)
            .expect("a FixedSizeList scalar should yield its vector");

        assert_eq!(extracted.len(), 5);
        assert!((extracted[0] - 0.1).abs() < 0.001);
        assert!((extracted[4] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_compute_vector_distances_broadcast() {
        // A single query vector is compared against every data vector.
        let data_vectors: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let query_vector = [1.0_f32, 0.0, 0.0];

        let distances = compute_vector_distances(&data_vectors, &query_vector, &DistanceMetric::L2);

        assert_eq!(distances.len(), 3);
        assert_eq!(distances[0], 0.0); // Same as query
        for orthogonal in &distances[1..] {
            assert!((orthogonal - 1.414).abs() < 0.01); // Orthogonal
        }
    }

    #[test]
    fn test_compute_vector_similarities_broadcast() {
        // A single query vector is compared against every data vector.
        let data_vectors: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0], // 45 degrees from x-axis
        ];
        let query_vector = [1.0_f32, 0.0, 0.0];

        let similarities =
            compute_vector_similarities(&data_vectors, &query_vector, &DistanceMetric::Cosine);

        assert_eq!(similarities.len(), 3);
        assert_eq!(similarities[0], 1.0); // Same as query
        assert_eq!(similarities[1], 0.0); // Orthogonal
        assert!((similarities[2] - 0.707).abs() < 0.01); // cos(45°) ≈ 0.707
    }

    #[test]
    fn test_extract_vectors_from_fixed_size_list() {
        // FixedSizeListArray with three 3D vectors
        let array_ref: ArrayRef = Arc::new(fixed_size_list(3, unit_vectors_flat(), None));

        let vectors = extract_vectors(&array_ref).expect("FixedSizeListArray should be accepted");

        let expected: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        assert_eq!(vectors.len(), 3);
        assert_eq!(vectors, expected);
    }

    #[test]
    fn test_extract_vectors_from_list_array() {
        // ListArray rows may vary in length, though here they are all 3 wide.
        let rows: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]];

        let mut list_builder = ListBuilder::new(Float32Array::builder(9));
        for row in &rows {
            for &component in row {
                list_builder.values().append_value(component);
            }
            list_builder.append(true);
        }
        let array_ref: ArrayRef = Arc::new(list_builder.finish());

        let vectors = extract_vectors(&array_ref).expect("ListArray should be accepted");

        assert_eq!(vectors.len(), 3);
        for (actual, wanted) in vectors.iter().zip(rows.iter()) {
            assert_eq!(actual.as_slice(), &wanted[..]);
        }
    }

    #[test]
    fn test_extract_vectors_rejects_invalid_type() {
        // A plain Float32Array is not a list of vectors and must be rejected.
        let array_ref: ArrayRef = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]));

        let err = extract_vectors(&array_ref).expect_err("non-list arrays must be rejected");

        assert!(err
            .to_string()
            .contains("Expected ListArray or FixedSizeListArray"));
    }

    #[test]
    fn test_extract_vectors_rejects_null_in_fixed_size_list() {
        // Same three unit vectors, but the second row is marked null.
        let null_buffer = NullBuffer::from(vec![true, false, true]);
        let array_ref: ArrayRef =
            Arc::new(fixed_size_list(3, unit_vectors_flat(), Some(null_buffer)));

        let err = extract_vectors(&array_ref).expect_err("null rows must be rejected");

        assert!(err
            .to_string()
            .contains("Null vector in FixedSizeListArray"));
    }
}