1use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::storage::engine::distance::{distance, DistanceMetric};
10use crate::storage::engine::hnsw::{HnswConfig, HnswIndex};
11use crate::storage::engine::unified_index::UnifiedIndex;
12use crate::storage::engine::vector_metadata::{MetadataFilter, MetadataValue};
13use crate::storage::engine::vector_store::VectorStore;
14use crate::storage::query::ast::{QueryExpr, VectorQuery, VectorSource};
15use crate::storage::query::sql_lowering::effective_vector_filter;
16use crate::storage::query::unified::{
17 ExecutionError, QueryStats, UnifiedRecord, UnifiedResult, VectorSearchResult,
18};
19use crate::storage::schema::Value;
20
21pub struct VectorExecutor {
23 vector_store: Arc<VectorStore>,
25 unified_index: Option<Arc<UnifiedIndex>>,
27}
28
29impl VectorExecutor {
30 pub fn new(vector_store: Arc<VectorStore>) -> Self {
32 Self {
33 vector_store,
34 unified_index: None,
35 }
36 }
37
38 pub fn with_unified_index(mut self, index: Arc<UnifiedIndex>) -> Self {
40 self.unified_index = Some(index);
41 self
42 }
43
44 pub fn execute(&self, query: &VectorQuery) -> Result<UnifiedResult, ExecutionError> {
46 let start = std::time::Instant::now();
47
48 let query_vector = self.resolve_vector_source(&query.query_vector)?;
50
51 let collection = self.vector_store.get(&query.collection).ok_or_else(|| {
53 ExecutionError::new(format!("Vector collection not found: {}", query.collection))
54 })?;
55
56 let search_results = collection.search_with_filter(
58 &query_vector,
59 query.k,
60 effective_vector_filter(query).as_ref(),
61 );
62
63 let mut result = UnifiedResult::with_columns(vec![
65 "id".to_string(),
66 "distance".to_string(),
67 "collection".to_string(),
68 ]);
69
70 if query.include_vectors {
71 result.columns.push("vector".to_string());
72 }
73 if query.include_metadata {
74 result.columns.push("metadata".to_string());
75 }
76
77 for sr in search_results {
79 if let Some(threshold) = query.threshold {
81 if sr.distance > threshold {
82 continue;
83 }
84 }
85
86 let mut record = UnifiedRecord::new();
87
88 let mut vsr = VectorSearchResult::new(sr.id, &query.collection, sr.distance);
90
91 if query.include_vectors {
93 if let Some(vec_data) = sr.vector {
94 vsr = vsr.with_vector(vec_data);
95 }
96 }
97
98 if query.include_metadata {
100 if let Some(ref meta_entry) = sr.metadata {
101 let mut meta_map: HashMap<String, Value> = HashMap::new();
103 for (k, v) in &meta_entry.strings {
104 meta_map.insert(k.clone(), Value::text(v.clone()));
105 }
106 for (k, v) in &meta_entry.integers {
107 meta_map.insert(k.clone(), Value::Integer(*v));
108 }
109 for (k, v) in &meta_entry.floats {
110 meta_map.insert(k.clone(), Value::Float(*v));
111 }
112 for (k, v) in &meta_entry.bools {
113 meta_map.insert(k.clone(), Value::Boolean(*v));
114 }
115 vsr = vsr.with_metadata(meta_map);
116 }
117 }
118
119 if let Some(ref unified) = self.unified_index {
121 if let Some(node_id) = unified.get_vector_node(&query.collection, sr.id) {
123 vsr = vsr.with_linked_node(node_id);
124 }
125
126 if let Some(row_key) = unified.get_vector_row(&query.collection, sr.id) {
128 vsr = vsr.with_linked_row(&row_key.table, row_key.row_id);
129 }
130 }
131
132 record.set_arc(Arc::from("id"), Value::Integer(sr.id as i64));
134 record.set_arc(Arc::from("distance"), Value::Float(sr.distance as f64));
135 record.set_arc(
136 Arc::from("collection"),
137 Value::text(query.collection.clone()),
138 );
139
140 record.vector_results.push(vsr);
141 result.push(record);
142 }
143
144 result.stats = QueryStats {
146 nodes_scanned: 0,
147 edges_scanned: 0,
148 rows_scanned: result.len() as u64,
149 exec_time_us: start.elapsed().as_micros() as u64,
150 };
151
152 Ok(result)
153 }
154
155 fn resolve_vector_source(&self, source: &VectorSource) -> Result<Vec<f32>, ExecutionError> {
157 match source {
158 VectorSource::Literal(vec) => Ok(vec.clone()),
159
160 VectorSource::Text(text) => {
161 Err(ExecutionError::new(format!(
164 "Text embedding not yet implemented. Provide a literal vector or use an embedding service for: '{}'",
165 text
166 )))
167 }
168
169 VectorSource::Reference {
170 collection,
171 vector_id,
172 } => {
173 if let Some(coll) = self.vector_store.get(collection) {
174 coll.get(*vector_id).cloned().ok_or_else(|| {
175 ExecutionError::new(format!(
176 "Reference vector not found: {}:{}",
177 collection, vector_id
178 ))
179 })
180 } else {
181 Err(ExecutionError::new(format!(
182 "Vector collection not found: {}",
183 collection
184 )))
185 }
186 }
187
188 VectorSource::Subquery(expr) => self.resolve_subquery_vector(expr.as_ref()),
189 }
190 }
191
192 fn resolve_subquery_vector(&self, expr: &QueryExpr) -> Result<Vec<f32>, ExecutionError> {
193 match expr {
194 QueryExpr::Vector(query) => {
195 let result = self.execute(query)?;
196 let (collection, vector_id) =
197 vector_subquery_reference(&result.records, &query.collection)?;
198 self.resolve_vector_source(&VectorSource::Reference {
199 collection,
200 vector_id,
201 })
202 }
203 other => Err(ExecutionError::new(format!(
204 "Vector subqueries currently support only nested VECTOR SEARCH expressions, got {}",
205 query_expr_name(other)
206 ))),
207 }
208 }
209}
210
211fn metadata_value_to_value(mv: MetadataValue) -> Value {
213 match mv {
214 MetadataValue::String(s) => Value::text(s),
215 MetadataValue::Integer(i) => Value::Integer(i),
216 MetadataValue::Float(f) => Value::Float(f),
217 MetadataValue::Bool(b) => Value::Boolean(b),
218 MetadataValue::Null => Value::Null,
219 }
220}
221
222pub struct InMemoryVectorExecutor {
228 vectors: HashMap<(String, u64), Vec<f32>>,
230 metadata: HashMap<(String, u64), HashMap<String, MetadataValue>>,
232 indexes: HashMap<String, HnswIndex>,
234 unified_index: Option<Arc<UnifiedIndex>>,
236}
237
238impl InMemoryVectorExecutor {
239 pub fn new() -> Self {
241 Self {
242 vectors: HashMap::new(),
243 metadata: HashMap::new(),
244 indexes: HashMap::new(),
245 unified_index: None,
246 }
247 }
248
249 pub fn with_unified_index(mut self, index: Arc<UnifiedIndex>) -> Self {
251 self.unified_index = Some(index);
252 self
253 }
254
255 pub fn add_vector(
257 &mut self,
258 collection: &str,
259 id: u64,
260 vector: Vec<f32>,
261 meta: Option<HashMap<String, MetadataValue>>,
262 ) {
263 let dim = vector.len();
264
265 self.vectors
267 .insert((collection.to_string(), id), vector.clone());
268
269 if let Some(m) = meta {
271 self.metadata.insert((collection.to_string(), id), m);
272 }
273
274 let index = self
276 .indexes
277 .entry(collection.to_string())
278 .or_insert_with(|| {
279 let config = HnswConfig {
280 m: 16,
281 m_max0: 32,
282 ef_construction: 200,
283 ef_search: 50,
284 ml: 1.0 / (16.0_f64).ln(),
285 metric: DistanceMetric::L2,
286 };
287 HnswIndex::new(dim, config)
288 });
289
290 index.insert_with_id(id, vector.clone());
291 }
292
293 pub fn execute(&self, query: &VectorQuery) -> Result<UnifiedResult, ExecutionError> {
295 let start = std::time::Instant::now();
296
297 let query_vector = match &query.query_vector {
299 VectorSource::Literal(v) => v.clone(),
300 VectorSource::Reference {
301 collection,
302 vector_id,
303 } => self
304 .vectors
305 .get(&(collection.clone(), *vector_id))
306 .cloned()
307 .ok_or_else(|| ExecutionError::new("Reference vector not found"))?,
308 VectorSource::Text(t) => {
309 return Err(ExecutionError::new(format!(
310 "Text embedding not implemented: '{}'",
311 t
312 )));
313 }
314 VectorSource::Subquery(expr) => self.resolve_subquery_vector(expr.as_ref())?,
315 };
316
317 let metric = query.metric.unwrap_or(DistanceMetric::L2);
318
319 let mut result = UnifiedResult::with_columns(vec![
321 "id".to_string(),
322 "distance".to_string(),
323 "collection".to_string(),
324 ]);
325
326 let search_results: Vec<(u64, f32)> =
328 if let Some(index) = self.indexes.get(&query.collection) {
329 let mut results: Vec<_> = index
331 .search(&query_vector, query.k)
332 .into_iter()
333 .map(|r| (r.id, r.distance))
334 .collect();
335 results.sort_by(|a, b| {
336 match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
337 std::cmp::Ordering::Equal => a.0.cmp(&b.0),
338 ordering => ordering,
339 }
340 });
341 results
342 } else {
343 self.brute_force_search(&query.collection, &query_vector, query.k, metric)
345 };
346
347 for (vector_id, dist) in search_results {
348 if let Some(threshold) = query.threshold {
350 if dist > threshold {
351 continue;
352 }
353 }
354
355 if let Some(ref filter) = query.filter {
357 let key = (query.collection.clone(), vector_id);
358 if let Some(meta) = self.metadata.get(&key) {
359 if !evaluate_filter(filter, meta) {
360 continue;
361 }
362 } else {
363 continue; }
365 }
366
367 let mut record = UnifiedRecord::new();
368 let mut vsr = VectorSearchResult::new(vector_id, &query.collection, dist);
369
370 if query.include_vectors {
371 if let Some(vec) = self.vectors.get(&(query.collection.clone(), vector_id)) {
372 vsr = vsr.with_vector(vec.clone());
373 }
374 }
375
376 if query.include_metadata {
377 if let Some(meta) = self.metadata.get(&(query.collection.clone(), vector_id)) {
378 let meta_map: HashMap<String, Value> = meta
379 .iter()
380 .map(|(k, v)| (k.clone(), metadata_value_to_value(v.clone())))
381 .collect();
382 vsr = vsr.with_metadata(meta_map);
383 }
384 }
385
386 if let Some(ref unified) = self.unified_index {
388 if let Some(node_id) = unified.get_vector_node(&query.collection, vector_id) {
389 vsr = vsr.with_linked_node(node_id);
390 }
391
392 if let Some(row_key) = unified.get_vector_row(&query.collection, vector_id) {
393 vsr = vsr.with_linked_row(&row_key.table, row_key.row_id);
394 }
395 }
396
397 record.set_arc(Arc::from("id"), Value::Integer(vector_id as i64));
398 record.set_arc(Arc::from("distance"), Value::Float(dist as f64));
399 record.set_arc(
400 Arc::from("collection"),
401 Value::text(query.collection.clone()),
402 );
403 record.vector_results.push(vsr);
404 result.push(record);
405 }
406
407 result.stats = QueryStats {
408 nodes_scanned: 0,
409 edges_scanned: 0,
410 rows_scanned: self.vectors.len() as u64,
411 exec_time_us: start.elapsed().as_micros() as u64,
412 };
413
414 Ok(result)
415 }
416
417 fn resolve_subquery_vector(&self, expr: &QueryExpr) -> Result<Vec<f32>, ExecutionError> {
418 match expr {
419 QueryExpr::Vector(query) => {
420 let result = self.execute(query)?;
421 let (collection, vector_id) =
422 vector_subquery_reference(&result.records, &query.collection)?;
423 self.vectors
424 .get(&(collection, vector_id))
425 .cloned()
426 .ok_or_else(|| ExecutionError::new("Subquery reference vector not found"))
427 }
428 other => Err(ExecutionError::new(format!(
429 "Vector subqueries currently support only nested VECTOR SEARCH expressions, got {}",
430 query_expr_name(other)
431 ))),
432 }
433 }
434
435 fn brute_force_search(
437 &self,
438 collection: &str,
439 query: &[f32],
440 k: usize,
441 metric: DistanceMetric,
442 ) -> Vec<(u64, f32)> {
443 let mut results: Vec<(u64, f32)> = self
444 .vectors
445 .iter()
446 .filter(|((c, _), _)| c == collection)
447 .map(|((_, id), vec)| {
448 let dist = distance(query, vec, metric);
449 (*id, dist)
450 })
451 .collect();
452
453 results.sort_by(
454 |a, b| match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
455 std::cmp::Ordering::Equal => a.0.cmp(&b.0),
456 ordering => ordering,
457 },
458 );
459 results.truncate(k);
460 results
461 }
462}
463
464impl Default for InMemoryVectorExecutor {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
470fn evaluate_filter(filter: &MetadataFilter, metadata: &HashMap<String, MetadataValue>) -> bool {
472 match filter {
473 MetadataFilter::Eq(field, value) => metadata
474 .get(field)
475 .map(|candidate| candidate.matches_eq(value))
476 .unwrap_or(false),
477 MetadataFilter::Ne(field, value) => metadata
478 .get(field)
479 .map(|candidate| !candidate.matches_eq(value))
480 .unwrap_or(true),
481 MetadataFilter::Lt(field, value) => metadata
482 .get(field)
483 .and_then(|candidate| candidate.compare(value))
484 .map(|ord| ord == std::cmp::Ordering::Less)
485 .unwrap_or(false),
486 MetadataFilter::Lte(field, value) => metadata
487 .get(field)
488 .and_then(|candidate| candidate.compare(value))
489 .map(|ord| ord != std::cmp::Ordering::Greater)
490 .unwrap_or(false),
491 MetadataFilter::Gt(field, value) => metadata
492 .get(field)
493 .and_then(|candidate| candidate.compare(value))
494 .map(|ord| ord == std::cmp::Ordering::Greater)
495 .unwrap_or(false),
496 MetadataFilter::Gte(field, value) => metadata
497 .get(field)
498 .and_then(|candidate| candidate.compare(value))
499 .map(|ord| ord != std::cmp::Ordering::Less)
500 .unwrap_or(false),
501 MetadataFilter::In(field, values) => metadata
502 .get(field)
503 .map(|candidate| values.iter().any(|value| candidate.matches_eq(value)))
504 .unwrap_or(false),
505 MetadataFilter::NotIn(field, values) => metadata
506 .get(field)
507 .map(|candidate| !values.iter().any(|value| candidate.matches_eq(value)))
508 .unwrap_or(true),
509 MetadataFilter::Contains(field, substring) => {
510 if let Some(MetadataValue::String(s)) = metadata.get(field) {
511 s.contains(substring)
512 } else {
513 false
514 }
515 }
516 MetadataFilter::And(filters) => filters.iter().all(|f| evaluate_filter(f, metadata)),
517 MetadataFilter::Or(filters) => filters.iter().any(|f| evaluate_filter(f, metadata)),
518 MetadataFilter::Not(inner) => !evaluate_filter(inner, metadata),
519 MetadataFilter::StartsWith(field, prefix) => {
520 if let Some(MetadataValue::String(s)) = metadata.get(field) {
521 s.starts_with(prefix)
522 } else {
523 false
524 }
525 }
526 MetadataFilter::EndsWith(field, suffix) => {
527 if let Some(MetadataValue::String(s)) = metadata.get(field) {
528 s.ends_with(suffix)
529 } else {
530 false
531 }
532 }
533 MetadataFilter::Exists(field) => metadata.contains_key(field),
534 MetadataFilter::NotExists(field) => !metadata.contains_key(field),
535 }
536}
537
538fn vector_subquery_reference(
539 records: &[UnifiedRecord],
540 default_collection: &str,
541) -> Result<(String, u64), ExecutionError> {
542 let record = records
543 .first()
544 .ok_or_else(|| ExecutionError::new("Vector subquery returned no rows"))?;
545
546 let collection: String = match record.get("collection") {
547 Some(Value::Text(collection)) => collection.to_string(),
548 _ => default_collection.to_string(),
549 };
550
551 let vector_id = match record.get("id") {
552 Some(Value::Integer(id)) if *id >= 0 => *id as u64,
553 Some(Value::UnsignedInteger(id)) => *id,
554 other => {
555 return Err(ExecutionError::new(format!(
556 "Vector subquery must expose an integer id column, got {other:?}"
557 )));
558 }
559 };
560
561 Ok((collection, vector_id))
562}
563
564fn query_expr_name(expr: &QueryExpr) -> &'static str {
565 match expr {
566 QueryExpr::Table(_) => "table",
567 QueryExpr::Graph(_) => "graph",
568 QueryExpr::Join(_) => "join",
569 QueryExpr::Path(_) => "path",
570 QueryExpr::Vector(_) => "vector",
571 QueryExpr::Hybrid(_) => "hybrid",
572 QueryExpr::Insert(_) => "insert",
573 QueryExpr::Update(_) => "update",
574 QueryExpr::Delete(_) => "delete",
575 QueryExpr::CreateTable(_) => "create_table",
576 QueryExpr::DropTable(_) => "drop_table",
577 QueryExpr::DropGraph(_) => "drop_graph",
578 QueryExpr::DropVector(_) => "drop_vector",
579 QueryExpr::DropDocument(_) => "drop_document",
580 QueryExpr::DropKv(_) => "drop_kv",
581 QueryExpr::DropCollection(_) => "drop_collection",
582 QueryExpr::Truncate(_) => "truncate",
583 QueryExpr::AlterTable(_) => "alter_table",
584 QueryExpr::GraphCommand(_) => "graph_command",
585 QueryExpr::SearchCommand(_) => "search_command",
586 QueryExpr::Ask(_) => "ask",
587 QueryExpr::CreateIndex(_) => "create_index",
588 QueryExpr::DropIndex(_) => "drop_index",
589 QueryExpr::ProbabilisticCommand(_) => "probabilistic_command",
590 QueryExpr::CreateTimeSeries(_) => "create_timeseries",
591 QueryExpr::DropTimeSeries(_) => "drop_timeseries",
592 QueryExpr::CreateQueue(_) => "create_queue",
593 QueryExpr::AlterQueue(_) => "alter_queue",
594 QueryExpr::DropQueue(_) => "drop_queue",
595 QueryExpr::QueueSelect(_) => "queue_select",
596 QueryExpr::QueueCommand(_) => "queue_command",
597 QueryExpr::KvCommand(_) => "kv_command",
598 QueryExpr::ConfigCommand(_) => "config_command",
599 QueryExpr::CreateTree(_) => "create_tree",
600 QueryExpr::DropTree(_) => "drop_tree",
601 QueryExpr::TreeCommand(_) => "tree_command",
602 QueryExpr::SetConfig { .. } => "set_config",
603 QueryExpr::ShowConfig { .. } => "show_config",
604 QueryExpr::SetSecret { .. } => "set_secret",
605 QueryExpr::DeleteSecret { .. } => "delete_secret",
606 QueryExpr::ShowSecrets { .. } => "show_secrets",
607 QueryExpr::SetTenant(_) => "set_tenant",
608 QueryExpr::ShowTenant => "show_tenant",
609 QueryExpr::ExplainAlter(_) => "explain_alter",
610 QueryExpr::TransactionControl(_) => "transaction_control",
611 QueryExpr::MaintenanceCommand(_) => "maintenance_command",
612 QueryExpr::CreateSchema(_) => "create_schema",
613 QueryExpr::DropSchema(_) => "drop_schema",
614 QueryExpr::CreateSequence(_) => "create_sequence",
615 QueryExpr::DropSequence(_) => "drop_sequence",
616 QueryExpr::CopyFrom(_) => "copy_from",
617 QueryExpr::CreateView(_) => "create_view",
618 QueryExpr::DropView(_) => "drop_view",
619 QueryExpr::RefreshMaterializedView(_) => "refresh_materialized_view",
620 QueryExpr::CreatePolicy(_) => "create_policy",
621 QueryExpr::DropPolicy(_) => "drop_policy",
622 QueryExpr::CreateServer(_) => "create_server",
623 QueryExpr::DropServer(_) => "drop_server",
624 QueryExpr::CreateForeignTable(_) => "create_foreign_table",
625 QueryExpr::DropForeignTable(_) => "drop_foreign_table",
626 QueryExpr::Grant(_) => "grant",
627 QueryExpr::Revoke(_) => "revoke",
628 QueryExpr::AlterUser(_) => "alter_user",
629 QueryExpr::CreateIamPolicy { .. } => "create_iam_policy",
630 QueryExpr::DropIamPolicy { .. } => "drop_iam_policy",
631 QueryExpr::AttachPolicy { .. } => "attach_policy",
632 QueryExpr::DetachPolicy { .. } => "detach_policy",
633 QueryExpr::ShowPolicies { .. } => "show_policies",
634 QueryExpr::ShowEffectivePermissions { .. } => "show_effective_permissions",
635 QueryExpr::SimulatePolicy { .. } => "simulate_policy",
636 QueryExpr::CreateMigration(_) => "create_migration",
637 QueryExpr::ApplyMigration(_) => "apply_migration",
638 QueryExpr::RollbackMigration(_) => "rollback_migration",
639 QueryExpr::ExplainMigration(_) => "explain_migration",
640 QueryExpr::EventsBackfill(_) => "events_backfill",
641 QueryExpr::EventsBackfillStatus { .. } => "events_backfill_status",
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an L2 query over `collection` with no filter, no threshold and no
    /// optional payloads; individual tests override fields as needed.
    fn base_query(collection: &str, source: VectorSource, k: usize) -> VectorQuery {
        VectorQuery {
            alias: None,
            collection: collection.to_string(),
            query_vector: source,
            k,
            filter: None,
            metric: Some(DistanceMetric::L2),
            include_vectors: false,
            include_metadata: false,
            threshold: None,
        }
    }

    /// Metadata map carrying a `type` string and a `severity` integer.
    fn typed_meta(kind: &str, severity: i64) -> HashMap<String, MetadataValue> {
        vec![
            ("type".to_string(), MetadataValue::String(kind.to_string())),
            ("severity".to_string(), MetadataValue::Integer(severity)),
        ]
        .into_iter()
        .collect()
    }

    #[test]
    fn test_in_memory_vector_search() {
        let mut executor = InMemoryVectorExecutor::new();
        for (id, v) in vec![
            (1, vec![1.0, 0.0, 0.0]),
            (2, vec![0.0, 1.0, 0.0]),
            (3, vec![0.0, 0.0, 1.0]),
            (4, vec![0.9, 0.1, 0.0]),
        ] {
            executor.add_vector("test", id, v, None);
        }

        let query = base_query("test", VectorSource::Literal(vec![1.0, 0.0, 0.0]), 2);
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.len(), 2);
        // The exact match is ranked first.
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(1)));
    }

    #[test]
    fn test_vector_search_with_metadata_filter() {
        let mut executor = InMemoryVectorExecutor::new();
        executor.add_vector("vulns", 1, vec![1.0, 0.0], Some(typed_meta("cve", 9)));
        executor.add_vector("vulns", 2, vec![0.9, 0.1], Some(typed_meta("cve", 5)));
        executor.add_vector("vulns", 3, vec![0.8, 0.2], Some(typed_meta("advisory", 8)));

        let mut query = base_query("vulns", VectorSource::Literal(vec![1.0, 0.0]), 10);
        query.filter = Some(MetadataFilter::And(vec![
            MetadataFilter::Eq("type".to_string(), MetadataValue::String("cve".to_string())),
            MetadataFilter::Gte("severity".to_string(), MetadataValue::Integer(7)),
        ]));
        query.include_metadata = true;

        let result = executor.execute(&query).unwrap();

        // Only the CVE with severity >= 7 survives the filter.
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(1)));
    }

    #[test]
    fn test_vector_search_with_threshold() {
        let mut executor = InMemoryVectorExecutor::new();
        executor.add_vector("test", 1, vec![1.0, 0.0], None);
        executor.add_vector("test", 2, vec![0.0, 1.0], None);

        let mut query = base_query("test", VectorSource::Literal(vec![1.0, 0.0]), 10);
        query.threshold = Some(0.5);

        let result = executor.execute(&query).unwrap();

        // The orthogonal vector lies beyond the threshold.
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_vector_search_include_vectors() {
        let mut executor = InMemoryVectorExecutor::new();
        executor.add_vector("test", 1, vec![1.0, 2.0, 3.0], None);

        let mut query = base_query("test", VectorSource::Literal(vec![1.0, 2.0, 3.0]), 1);
        query.include_vectors = true;

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);

        let vsr = &result.records[0].vector_results[0];
        assert!(vsr.vector.is_some());
        assert_eq!(vsr.vector.as_ref().unwrap(), &vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_executor_reference_source() {
        let mut store = VectorStore::new();
        let collection = store.create_collection("refs", 2);
        let ref_id = collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0], None).unwrap();

        let executor = VectorExecutor::new(Arc::new(store));
        let source = VectorSource::Reference {
            collection: "refs".to_string(),
            vector_id: ref_id,
        };
        let query = base_query("refs", source, 1);

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(0)));
    }

    #[test]
    fn test_vector_executor_subquery_source() {
        let mut store = VectorStore::new();
        let collection = store.create_collection("refs", 2);
        collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0], None).unwrap();

        let executor = VectorExecutor::new(Arc::new(store));
        let inner = base_query("refs", VectorSource::Literal(vec![1.0, 0.0]), 1);
        let source = VectorSource::Subquery(Box::new(QueryExpr::Vector(inner)));
        let query = base_query("refs", source, 1);

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(0)));
    }
}