1use crate::ast::DistanceMetric;
9use crate::error::{GraphError, Result};
10use arrow::array::{Array, ArrayRef, FixedSizeListArray, Float32Array, ListArray};
11
12pub fn extract_vectors(array: &ArrayRef) -> Result<Vec<Vec<f32>>> {
18 if let Some(list_array) = array.as_any().downcast_ref::<FixedSizeListArray>() {
20 let mut vectors = Vec::with_capacity(list_array.len());
21 for i in 0..list_array.len() {
22 if list_array.is_null(i) {
23 return Err(GraphError::ExecutionError {
24 message: "Null vector in FixedSizeListArray".to_string(),
25 location: snafu::Location::new(file!(), line!(), column!()),
26 });
27 }
28 let value_array = list_array.value(i);
29 let float_array = value_array
30 .as_any()
31 .downcast_ref::<Float32Array>()
32 .ok_or_else(|| GraphError::ExecutionError {
33 message: "Expected Float32Array in vector".to_string(),
34 location: snafu::Location::new(file!(), line!(), column!()),
35 })?;
36
37 let vec: Vec<f32> = (0..float_array.len())
38 .map(|j| float_array.value(j))
39 .collect();
40 vectors.push(vec);
41 }
42 return Ok(vectors);
43 }
44
45 if let Some(list_array) = array.as_any().downcast_ref::<ListArray>() {
47 let mut vectors = Vec::with_capacity(list_array.len());
48 for i in 0..list_array.len() {
49 if list_array.is_null(i) {
50 return Err(GraphError::ExecutionError {
51 message: "Null vector in ListArray".to_string(),
52 location: snafu::Location::new(file!(), line!(), column!()),
53 });
54 }
55 let value_array = list_array.value(i);
56 let float_array = value_array
57 .as_any()
58 .downcast_ref::<Float32Array>()
59 .ok_or_else(|| GraphError::ExecutionError {
60 message: "Expected Float32Array in vector".to_string(),
61 location: snafu::Location::new(file!(), line!(), column!()),
62 })?;
63
64 let vec: Vec<f32> = (0..float_array.len())
65 .map(|j| float_array.value(j))
66 .collect();
67 vectors.push(vec);
68 }
69 return Ok(vectors);
70 }
71
72 Err(GraphError::ExecutionError {
73 message: "Expected ListArray or FixedSizeListArray for vector column".to_string(),
74 location: snafu::Location::new(file!(), line!(), column!()),
75 })
76}
77
78pub fn extract_single_vector_from_scalar(
81 scalar: &datafusion::scalar::ScalarValue,
82) -> Result<Vec<f32>> {
83 let array = scalar.to_array().map_err(|e| GraphError::ExecutionError {
85 message: format!("Failed to convert scalar to array: {}", e),
86 location: snafu::Location::new(file!(), line!(), column!()),
87 })?;
88
89 let list_array = array
90 .as_any()
91 .downcast_ref::<FixedSizeListArray>()
92 .ok_or_else(|| GraphError::ExecutionError {
93 message: "Expected FixedSizeListArray for vector scalar".to_string(),
94 location: snafu::Location::new(file!(), line!(), column!()),
95 })?;
96
97 if list_array.is_empty() {
98 return Err(GraphError::ExecutionError {
99 message: "Empty vector array".to_string(),
100 location: snafu::Location::new(file!(), line!(), column!()),
101 });
102 }
103
104 let value_array = list_array.value(0);
105 let float_array = value_array
106 .as_any()
107 .downcast_ref::<Float32Array>()
108 .ok_or_else(|| GraphError::ExecutionError {
109 message: "Expected Float32Array in vector".to_string(),
110 location: snafu::Location::new(file!(), line!(), column!()),
111 })?;
112
113 Ok((0..float_array.len())
114 .map(|j| float_array.value(j))
115 .collect())
116}
117
/// Euclidean (L2) distance between two vectors.
///
/// Returns `f32::MAX` as a sentinel when the slices differ in length.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| {
                let delta = lhs - rhs;
                delta.powi(2)
            })
            .sum();
        squared_sum.sqrt()
    } else {
        f32::MAX
    }
}
131
/// Cosine distance between two vectors, computed as `1 - cosine similarity`.
///
/// Returns `2.0` as a sentinel when the slices differ in length or when either
/// vector has a zero norm.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Sentinel returned whenever the cosine is undefined for the given inputs.
    const UNDEFINED: f32 = 2.0;

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return UNDEFINED;
    }

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return UNDEFINED;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    1.0 - inner / (mag_a * mag_b)
}
151
/// Cosine similarity between two vectors: dot product divided by the product of norms.
///
/// Returns `-1.0` as a sentinel when the slices differ in length or when either
/// vector has a zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Sentinel returned whenever the cosine is undefined for the given inputs.
    const UNDEFINED: f32 = -1.0;

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return UNDEFINED;
    }

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return UNDEFINED;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner / (mag_a * mag_b)
}
170
/// Dot-product "distance": the negated dot product, so larger products rank as closer.
///
/// Returns `f32::MAX` as a sentinel when the slices differ in length.
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let inner: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
        -inner
    } else {
        f32::MAX
    }
}
181
/// Dot-product similarity: the plain dot product of the two vectors.
///
/// Returns `f32::MIN` as a sentinel when the slices differ in length.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let inner: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
        inner
    } else {
        f32::MIN
    }
}
191
192pub fn compute_vector_distances(
194 vectors: &[Vec<f32>],
195 query_vector: &[f32],
196 metric: &DistanceMetric,
197) -> Vec<f32> {
198 vectors
199 .iter()
200 .map(|v| match metric {
201 DistanceMetric::L2 => l2_distance(v, query_vector),
202 DistanceMetric::Cosine => cosine_distance(v, query_vector),
203 DistanceMetric::Dot => dot_product_distance(v, query_vector),
204 })
205 .collect()
206}
207
208pub fn compute_vector_similarities(
210 vectors: &[Vec<f32>],
211 query_vector: &[f32],
212 metric: &DistanceMetric,
213) -> Vec<f32> {
214 vectors
215 .iter()
216 .map(|v| match metric {
217 DistanceMetric::L2 => {
218 let dist = l2_distance(v, query_vector);
220 if dist == 0.0 {
221 1.0 } else {
223 1.0 / (1.0 + dist) }
225 }
226 DistanceMetric::Cosine => cosine_similarity(v, query_vector),
227 DistanceMetric::Dot => dot_product_similarity(v, query_vector),
228 })
229 .collect()
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListBuilder;
    use arrow::buffer::NullBuffer;
    use arrow::datatypes::{DataType, Field};
    use datafusion::scalar::ScalarValue;
    use std::sync::Arc;

    /// Builds a `FixedSizeListArray` of `dimension`-wide Float32 vectors from a flat
    /// value buffer, optionally marking whole rows as null.
    fn fixed_size_vectors(
        flat_values: Vec<f32>,
        dimension: i32,
        nulls: Option<NullBuffer>,
    ) -> FixedSizeListArray {
        let item_field = Arc::new(Field::new("item", DataType::Float32, true));
        let values = Arc::new(Float32Array::from(flat_values));
        FixedSizeListArray::try_new(item_field, dimension, values, nulls).unwrap()
    }

    /// Flat buffer holding the three 3-dimensional unit vectors, row after row.
    fn unit_vectors_flat() -> Vec<f32> {
        vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    }

    #[test]
    fn test_l2_distance() {
        // Two orthogonal unit vectors are sqrt(2) apart.
        let dist = l2_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((dist - 1.414).abs() < 0.01);
    }

    #[test]
    fn test_l2_distance_identical() {
        let point = [1.0, 2.0, 3.0];
        assert_eq!(l2_distance(&point, &point), 0.0);
    }

    #[test]
    fn test_cosine_distance() {
        // Identical direction => distance 0.
        let axis = [1.0, 0.0, 0.0];
        assert_eq!(cosine_distance(&axis, &axis), 0.0);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        // Orthogonal directions => distance 1.
        let dist = cosine_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert_eq!(dist, 1.0);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical direction => similarity 1.
        let axis = [1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&axis, &axis), 1.0);
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let sim = dot_product_similarity(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert_eq!(sim, 32.0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let short = [1.0, 2.0];
        let long = [1.0, 2.0, 3.0];

        // Each function reports a mismatch through its own sentinel value.
        assert_eq!(l2_distance(&short, &long), f32::MAX);
        assert_eq!(cosine_distance(&short, &long), 2.0);
        assert_eq!(dot_product_distance(&short, &long), f32::MAX);
        assert_eq!(dot_product_similarity(&short, &long), f32::MIN);
    }

    #[test]
    fn test_extract_single_vector_from_scalar() {
        let list_array = fixed_size_vectors(vec![1.0, 2.0, 3.0], 3, None);
        let scalar = ScalarValue::try_from_array(&list_array, 0).unwrap();

        let vector = extract_single_vector_from_scalar(&scalar)
            .expect("a FixedSizeList<Float32> scalar should be extractable");

        assert_eq!(vector, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_extract_single_vector_from_scalar_different_dimensions() {
        let list_array = fixed_size_vectors(vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, None);
        let scalar = ScalarValue::try_from_array(&list_array, 0).unwrap();

        let vector = extract_single_vector_from_scalar(&scalar)
            .expect("a 5-dimensional scalar should be extractable");

        assert_eq!(vector.len(), 5);
        assert!((vector[0] - 0.1).abs() < 0.001);
        assert!((vector[4] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_compute_vector_distances_broadcast() {
        let data_vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let query_vector = vec![1.0, 0.0, 0.0];

        let distances = compute_vector_distances(&data_vectors, &query_vector, &DistanceMetric::L2);

        assert_eq!(distances.len(), 3);
        // Same vector as the query.
        assert_eq!(distances[0], 0.0);
        // Orthogonal unit vectors: sqrt(2) away.
        assert!((distances[1] - 1.414).abs() < 0.01);
        assert!((distances[2] - 1.414).abs() < 0.01);
    }

    #[test]
    fn test_compute_vector_similarities_broadcast() {
        let data_vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];
        let query_vector = vec![1.0, 0.0, 0.0];

        let similarities =
            compute_vector_similarities(&data_vectors, &query_vector, &DistanceMetric::Cosine);

        assert_eq!(similarities.len(), 3);
        // Same direction as the query.
        assert_eq!(similarities[0], 1.0);
        // Orthogonal to the query.
        assert_eq!(similarities[1], 0.0);
        // 45 degrees from the query: cos(45deg) ~= 0.707.
        assert!((similarities[2] - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_extract_vectors_from_fixed_size_list() {
        let array_ref: ArrayRef = Arc::new(fixed_size_vectors(unit_vectors_flat(), 3, None));

        let vectors = extract_vectors(&array_ref)
            .expect("a FixedSizeList<Float32> column should be extractable");

        assert_eq!(
            vectors,
            vec![
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 0.0, 1.0],
            ]
        );
    }

    #[test]
    fn test_extract_vectors_from_list_array() {
        let rows: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]];

        let mut list_builder = ListBuilder::new(Float32Array::builder(9));
        for row in &rows {
            for &component in row {
                list_builder.values().append_value(component);
            }
            // Close the current row as a valid (non-null) list entry.
            list_builder.append(true);
        }
        let array_ref: ArrayRef = Arc::new(list_builder.finish());

        let vectors =
            extract_vectors(&array_ref).expect("a List<Float32> column should be extractable");

        assert_eq!(
            vectors,
            vec![
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.5, 0.5, 0.0],
            ]
        );
    }

    #[test]
    fn test_extract_vectors_rejects_invalid_type() {
        // A flat Float32 column is not a list column and must be refused.
        let array_ref: ArrayRef = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]));

        let message = extract_vectors(&array_ref)
            .expect_err("a non-list column should be rejected")
            .to_string();

        assert!(message.contains("Expected ListArray or FixedSizeListArray"));
    }

    #[test]
    fn test_extract_vectors_rejects_null_in_fixed_size_list() {
        // The middle row is marked null.
        let validity = NullBuffer::from(vec![true, false, true]);
        let array_ref: ArrayRef =
            Arc::new(fixed_size_vectors(unit_vectors_flat(), 3, Some(validity)));

        let message = extract_vectors(&array_ref)
            .expect_err("a null row should be rejected")
            .to_string();

        assert!(message.contains("Null vector in FixedSizeListArray"));
    }
}