Skip to main content

reddb_server/storage/query/executors/
vector.rs

//! Vector Query Executor
//!
//! Executes VECTOR SEARCH queries using HNSW approximate nearest neighbor search.
//! Supports metadata filtering, multiple distance metrics, and cross-references.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::storage::engine::distance::{distance, DistanceMetric};
10use crate::storage::engine::hnsw::{HnswConfig, HnswIndex};
11use crate::storage::engine::unified_index::UnifiedIndex;
12use crate::storage::engine::vector_metadata::{MetadataFilter, MetadataValue};
13use crate::storage::engine::vector_store::VectorStore;
14use crate::storage::query::ast::{QueryExpr, VectorQuery, VectorSource};
15use crate::storage::query::sql_lowering::effective_vector_filter;
16use crate::storage::query::unified::{
17    ExecutionError, QueryStats, UnifiedRecord, UnifiedResult, VectorSearchResult,
18};
19use crate::storage::schema::Value;
20
21/// Vector query executor using HNSW index
22pub struct VectorExecutor {
23    /// Vector store for segment management
24    vector_store: Arc<VectorStore>,
25    /// Cross-reference index for linking vectors to nodes/rows
26    unified_index: Option<Arc<UnifiedIndex>>,
27}
28
29impl VectorExecutor {
30    /// Create a new vector executor
31    pub fn new(vector_store: Arc<VectorStore>) -> Self {
32        Self {
33            vector_store,
34            unified_index: None,
35        }
36    }
37
38    /// Create with cross-reference support
39    pub fn with_unified_index(mut self, index: Arc<UnifiedIndex>) -> Self {
40        self.unified_index = Some(index);
41        self
42    }
43
44    /// Execute a vector search query
45    pub fn execute(&self, query: &VectorQuery) -> Result<UnifiedResult, ExecutionError> {
46        let start = std::time::Instant::now();
47
48        // Resolve the query vector
49        let query_vector = self.resolve_vector_source(&query.query_vector)?;
50
51        // Get the collection
52        let collection = self.vector_store.get(&query.collection).ok_or_else(|| {
53            ExecutionError::new(format!("Vector collection not found: {}", query.collection))
54        })?;
55
56        // Search the vector store with filter
57        let search_results = collection.search_with_filter(
58            &query_vector,
59            query.k,
60            effective_vector_filter(query).as_ref(),
61        );
62
63        // Build result
64        let mut result = UnifiedResult::with_columns(vec![
65            "id".to_string(),
66            "distance".to_string(),
67            "collection".to_string(),
68        ]);
69
70        if query.include_vectors {
71            result.columns.push("vector".to_string());
72        }
73        if query.include_metadata {
74            result.columns.push("metadata".to_string());
75        }
76
77        // Convert search results to unified records
78        for sr in search_results {
79            // Apply threshold filter if specified
80            if let Some(threshold) = query.threshold {
81                if sr.distance > threshold {
82                    continue;
83                }
84            }
85
86            let mut record = UnifiedRecord::new();
87
88            // Build vector search result
89            let mut vsr = VectorSearchResult::new(sr.id, &query.collection, sr.distance);
90
91            // Include vector data if requested and available
92            if query.include_vectors {
93                if let Some(vec_data) = sr.vector {
94                    vsr = vsr.with_vector(vec_data);
95                }
96            }
97
98            // Include metadata if requested and available
99            if query.include_metadata {
100                if let Some(ref meta_entry) = sr.metadata {
101                    // Convert MetadataEntry to HashMap<String, Value>
102                    let mut meta_map: HashMap<String, Value> = HashMap::new();
103                    for (k, v) in &meta_entry.strings {
104                        meta_map.insert(k.clone(), Value::text(v.clone()));
105                    }
106                    for (k, v) in &meta_entry.integers {
107                        meta_map.insert(k.clone(), Value::Integer(*v));
108                    }
109                    for (k, v) in &meta_entry.floats {
110                        meta_map.insert(k.clone(), Value::Float(*v));
111                    }
112                    for (k, v) in &meta_entry.bools {
113                        meta_map.insert(k.clone(), Value::Boolean(*v));
114                    }
115                    vsr = vsr.with_metadata(meta_map);
116                }
117            }
118
119            // Add cross-references if available
120            if let Some(ref unified) = self.unified_index {
121                // Check for linked node
122                if let Some(node_id) = unified.get_vector_node(&query.collection, sr.id) {
123                    vsr = vsr.with_linked_node(node_id);
124                }
125
126                // Check for linked row
127                if let Some(row_key) = unified.get_vector_row(&query.collection, sr.id) {
128                    vsr = vsr.with_linked_row(&row_key.table, row_key.row_id);
129                }
130            }
131
132            // Add basic values to record
133            record.set_arc(Arc::from("id"), Value::Integer(sr.id as i64));
134            record.set_arc(Arc::from("distance"), Value::Float(sr.distance as f64));
135            record.set_arc(
136                Arc::from("collection"),
137                Value::text(query.collection.clone()),
138            );
139
140            record.vector_results.push(vsr);
141            result.push(record);
142        }
143
144        // Update stats
145        result.stats = QueryStats {
146            nodes_scanned: 0,
147            edges_scanned: 0,
148            rows_scanned: result.len() as u64,
149            exec_time_us: start.elapsed().as_micros() as u64,
150        };
151
152        Ok(result)
153    }
154
155    /// Resolve vector source to actual vector data
156    fn resolve_vector_source(&self, source: &VectorSource) -> Result<Vec<f32>, ExecutionError> {
157        match source {
158            VectorSource::Literal(vec) => Ok(vec.clone()),
159
160            VectorSource::Text(text) => {
161                // Text embedding would require an embedding model
162                // For now, return an error indicating this needs external embedding
163                Err(ExecutionError::new(format!(
164                    "Text embedding not yet implemented. Provide a literal vector or use an embedding service for: '{}'",
165                    text
166                )))
167            }
168
169            VectorSource::Reference {
170                collection,
171                vector_id,
172            } => {
173                if let Some(coll) = self.vector_store.get(collection) {
174                    coll.get(*vector_id).cloned().ok_or_else(|| {
175                        ExecutionError::new(format!(
176                            "Reference vector not found: {}:{}",
177                            collection, vector_id
178                        ))
179                    })
180                } else {
181                    Err(ExecutionError::new(format!(
182                        "Vector collection not found: {}",
183                        collection
184                    )))
185                }
186            }
187
188            VectorSource::Subquery(expr) => self.resolve_subquery_vector(expr.as_ref()),
189        }
190    }
191
192    fn resolve_subquery_vector(&self, expr: &QueryExpr) -> Result<Vec<f32>, ExecutionError> {
193        match expr {
194            QueryExpr::Vector(query) => {
195                let result = self.execute(query)?;
196                let (collection, vector_id) =
197                    vector_subquery_reference(&result.records, &query.collection)?;
198                self.resolve_vector_source(&VectorSource::Reference {
199                    collection,
200                    vector_id,
201                })
202            }
203            other => Err(ExecutionError::new(format!(
204                "Vector subqueries currently support only nested VECTOR SEARCH expressions, got {}",
205                query_expr_name(other)
206            ))),
207        }
208    }
209}
210
211/// Convert MetadataValue to Value for unified results
212fn metadata_value_to_value(mv: MetadataValue) -> Value {
213    match mv {
214        MetadataValue::String(s) => Value::text(s),
215        MetadataValue::Integer(i) => Value::Integer(i),
216        MetadataValue::Float(f) => Value::Float(f),
217        MetadataValue::Bool(b) => Value::Boolean(b),
218        MetadataValue::Null => Value::Null,
219    }
220}
221
// ============================================================================
// In-Memory Executor for Testing
// ============================================================================
225
226/// Simple in-memory vector executor for testing without full VectorStore
227pub struct InMemoryVectorExecutor {
228    /// Vectors indexed by (collection, id)
229    vectors: HashMap<(String, u64), Vec<f32>>,
230    /// Metadata indexed by (collection, id)
231    metadata: HashMap<(String, u64), HashMap<String, MetadataValue>>,
232    /// HNSW indexes by collection
233    indexes: HashMap<String, HnswIndex>,
234    /// Cross-reference index
235    unified_index: Option<Arc<UnifiedIndex>>,
236}
237
238impl InMemoryVectorExecutor {
239    /// Create a new in-memory executor
240    pub fn new() -> Self {
241        Self {
242            vectors: HashMap::new(),
243            metadata: HashMap::new(),
244            indexes: HashMap::new(),
245            unified_index: None,
246        }
247    }
248
249    /// Add cross-reference support
250    pub fn with_unified_index(mut self, index: Arc<UnifiedIndex>) -> Self {
251        self.unified_index = Some(index);
252        self
253    }
254
255    /// Add a vector to a collection
256    pub fn add_vector(
257        &mut self,
258        collection: &str,
259        id: u64,
260        vector: Vec<f32>,
261        meta: Option<HashMap<String, MetadataValue>>,
262    ) {
263        let dim = vector.len();
264
265        // Store vector
266        self.vectors
267            .insert((collection.to_string(), id), vector.clone());
268
269        // Store metadata
270        if let Some(m) = meta {
271            self.metadata.insert((collection.to_string(), id), m);
272        }
273
274        // Add to HNSW index
275        let index = self
276            .indexes
277            .entry(collection.to_string())
278            .or_insert_with(|| {
279                let config = HnswConfig {
280                    m: 16,
281                    m_max0: 32,
282                    ef_construction: 200,
283                    ef_search: 50,
284                    ml: 1.0 / (16.0_f64).ln(),
285                    metric: DistanceMetric::L2,
286                };
287                HnswIndex::new(dim, config)
288            });
289
290        index.insert_with_id(id, vector.clone());
291    }
292
293    /// Execute a vector query
294    pub fn execute(&self, query: &VectorQuery) -> Result<UnifiedResult, ExecutionError> {
295        let start = std::time::Instant::now();
296
297        // Resolve query vector
298        let query_vector = match &query.query_vector {
299            VectorSource::Literal(v) => v.clone(),
300            VectorSource::Reference {
301                collection,
302                vector_id,
303            } => self
304                .vectors
305                .get(&(collection.clone(), *vector_id))
306                .cloned()
307                .ok_or_else(|| ExecutionError::new("Reference vector not found"))?,
308            VectorSource::Text(t) => {
309                return Err(ExecutionError::new(format!(
310                    "Text embedding not implemented: '{}'",
311                    t
312                )));
313            }
314            VectorSource::Subquery(expr) => self.resolve_subquery_vector(expr.as_ref())?,
315        };
316
317        let metric = query.metric.unwrap_or(DistanceMetric::L2);
318
319        // Get or create result
320        let mut result = UnifiedResult::with_columns(vec![
321            "id".to_string(),
322            "distance".to_string(),
323            "collection".to_string(),
324        ]);
325
326        // Search using HNSW if available, otherwise brute force
327        let search_results: Vec<(u64, f32)> =
328            if let Some(index) = self.indexes.get(&query.collection) {
329                // HNSW search returns DistanceResult with id and distance
330                let mut results: Vec<_> = index
331                    .search(&query_vector, query.k)
332                    .into_iter()
333                    .map(|r| (r.id, r.distance))
334                    .collect();
335                results.sort_by(|a, b| {
336                    match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
337                        std::cmp::Ordering::Equal => a.0.cmp(&b.0),
338                        ordering => ordering,
339                    }
340                });
341                results
342            } else {
343                // Brute force search
344                self.brute_force_search(&query.collection, &query_vector, query.k, metric)
345            };
346
347        for (vector_id, dist) in search_results {
348            // Apply threshold
349            if let Some(threshold) = query.threshold {
350                if dist > threshold {
351                    continue;
352                }
353            }
354
355            // Apply metadata filter
356            if let Some(ref filter) = query.filter {
357                let key = (query.collection.clone(), vector_id);
358                if let Some(meta) = self.metadata.get(&key) {
359                    if !evaluate_filter(filter, meta) {
360                        continue;
361                    }
362                } else {
363                    continue; // No metadata, filter fails
364                }
365            }
366
367            let mut record = UnifiedRecord::new();
368            let mut vsr = VectorSearchResult::new(vector_id, &query.collection, dist);
369
370            if query.include_vectors {
371                if let Some(vec) = self.vectors.get(&(query.collection.clone(), vector_id)) {
372                    vsr = vsr.with_vector(vec.clone());
373                }
374            }
375
376            if query.include_metadata {
377                if let Some(meta) = self.metadata.get(&(query.collection.clone(), vector_id)) {
378                    let meta_map: HashMap<String, Value> = meta
379                        .iter()
380                        .map(|(k, v)| (k.clone(), metadata_value_to_value(v.clone())))
381                        .collect();
382                    vsr = vsr.with_metadata(meta_map);
383                }
384            }
385
386            // Add cross-references
387            if let Some(ref unified) = self.unified_index {
388                if let Some(node_id) = unified.get_vector_node(&query.collection, vector_id) {
389                    vsr = vsr.with_linked_node(node_id);
390                }
391
392                if let Some(row_key) = unified.get_vector_row(&query.collection, vector_id) {
393                    vsr = vsr.with_linked_row(&row_key.table, row_key.row_id);
394                }
395            }
396
397            record.set_arc(Arc::from("id"), Value::Integer(vector_id as i64));
398            record.set_arc(Arc::from("distance"), Value::Float(dist as f64));
399            record.set_arc(
400                Arc::from("collection"),
401                Value::text(query.collection.clone()),
402            );
403            record.vector_results.push(vsr);
404            result.push(record);
405        }
406
407        result.stats = QueryStats {
408            nodes_scanned: 0,
409            edges_scanned: 0,
410            rows_scanned: self.vectors.len() as u64,
411            exec_time_us: start.elapsed().as_micros() as u64,
412        };
413
414        Ok(result)
415    }
416
417    fn resolve_subquery_vector(&self, expr: &QueryExpr) -> Result<Vec<f32>, ExecutionError> {
418        match expr {
419            QueryExpr::Vector(query) => {
420                let result = self.execute(query)?;
421                let (collection, vector_id) =
422                    vector_subquery_reference(&result.records, &query.collection)?;
423                self.vectors
424                    .get(&(collection, vector_id))
425                    .cloned()
426                    .ok_or_else(|| ExecutionError::new("Subquery reference vector not found"))
427            }
428            other => Err(ExecutionError::new(format!(
429                "Vector subqueries currently support only nested VECTOR SEARCH expressions, got {}",
430                query_expr_name(other)
431            ))),
432        }
433    }
434
435    /// Brute force search when no index is available
436    fn brute_force_search(
437        &self,
438        collection: &str,
439        query: &[f32],
440        k: usize,
441        metric: DistanceMetric,
442    ) -> Vec<(u64, f32)> {
443        let mut results: Vec<(u64, f32)> = self
444            .vectors
445            .iter()
446            .filter(|((c, _), _)| c == collection)
447            .map(|((_, id), vec)| {
448                let dist = distance(query, vec, metric);
449                (*id, dist)
450            })
451            .collect();
452
453        results.sort_by(
454            |a, b| match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
455                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
456                ordering => ordering,
457            },
458        );
459        results.truncate(k);
460        results
461    }
462}
463
464impl Default for InMemoryVectorExecutor {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470/// Evaluate a metadata filter against metadata values
471fn evaluate_filter(filter: &MetadataFilter, metadata: &HashMap<String, MetadataValue>) -> bool {
472    match filter {
473        MetadataFilter::Eq(field, value) => metadata
474            .get(field)
475            .map(|candidate| candidate.matches_eq(value))
476            .unwrap_or(false),
477        MetadataFilter::Ne(field, value) => metadata
478            .get(field)
479            .map(|candidate| !candidate.matches_eq(value))
480            .unwrap_or(true),
481        MetadataFilter::Lt(field, value) => metadata
482            .get(field)
483            .and_then(|candidate| candidate.compare(value))
484            .map(|ord| ord == std::cmp::Ordering::Less)
485            .unwrap_or(false),
486        MetadataFilter::Lte(field, value) => metadata
487            .get(field)
488            .and_then(|candidate| candidate.compare(value))
489            .map(|ord| ord != std::cmp::Ordering::Greater)
490            .unwrap_or(false),
491        MetadataFilter::Gt(field, value) => metadata
492            .get(field)
493            .and_then(|candidate| candidate.compare(value))
494            .map(|ord| ord == std::cmp::Ordering::Greater)
495            .unwrap_or(false),
496        MetadataFilter::Gte(field, value) => metadata
497            .get(field)
498            .and_then(|candidate| candidate.compare(value))
499            .map(|ord| ord != std::cmp::Ordering::Less)
500            .unwrap_or(false),
501        MetadataFilter::In(field, values) => metadata
502            .get(field)
503            .map(|candidate| values.iter().any(|value| candidate.matches_eq(value)))
504            .unwrap_or(false),
505        MetadataFilter::NotIn(field, values) => metadata
506            .get(field)
507            .map(|candidate| !values.iter().any(|value| candidate.matches_eq(value)))
508            .unwrap_or(true),
509        MetadataFilter::Contains(field, substring) => {
510            if let Some(MetadataValue::String(s)) = metadata.get(field) {
511                s.contains(substring)
512            } else {
513                false
514            }
515        }
516        MetadataFilter::And(filters) => filters.iter().all(|f| evaluate_filter(f, metadata)),
517        MetadataFilter::Or(filters) => filters.iter().any(|f| evaluate_filter(f, metadata)),
518        MetadataFilter::Not(inner) => !evaluate_filter(inner, metadata),
519        MetadataFilter::StartsWith(field, prefix) => {
520            if let Some(MetadataValue::String(s)) = metadata.get(field) {
521                s.starts_with(prefix)
522            } else {
523                false
524            }
525        }
526        MetadataFilter::EndsWith(field, suffix) => {
527            if let Some(MetadataValue::String(s)) = metadata.get(field) {
528                s.ends_with(suffix)
529            } else {
530                false
531            }
532        }
533        MetadataFilter::Exists(field) => metadata.contains_key(field),
534        MetadataFilter::NotExists(field) => !metadata.contains_key(field),
535    }
536}
537
538fn vector_subquery_reference(
539    records: &[UnifiedRecord],
540    default_collection: &str,
541) -> Result<(String, u64), ExecutionError> {
542    let record = records
543        .first()
544        .ok_or_else(|| ExecutionError::new("Vector subquery returned no rows"))?;
545
546    let collection: String = match record.get("collection") {
547        Some(Value::Text(collection)) => collection.to_string(),
548        _ => default_collection.to_string(),
549    };
550
551    let vector_id = match record.get("id") {
552        Some(Value::Integer(id)) if *id >= 0 => *id as u64,
553        Some(Value::UnsignedInteger(id)) => *id,
554        other => {
555            return Err(ExecutionError::new(format!(
556                "Vector subquery must expose an integer id column, got {other:?}"
557            )));
558        }
559    };
560
561    Ok((collection, vector_id))
562}
563
564fn query_expr_name(expr: &QueryExpr) -> &'static str {
565    match expr {
566        QueryExpr::Table(_) => "table",
567        QueryExpr::Graph(_) => "graph",
568        QueryExpr::Join(_) => "join",
569        QueryExpr::Path(_) => "path",
570        QueryExpr::Vector(_) => "vector",
571        QueryExpr::Hybrid(_) => "hybrid",
572        QueryExpr::Insert(_) => "insert",
573        QueryExpr::Update(_) => "update",
574        QueryExpr::Delete(_) => "delete",
575        QueryExpr::CreateTable(_) => "create_table",
576        QueryExpr::DropTable(_) => "drop_table",
577        QueryExpr::DropGraph(_) => "drop_graph",
578        QueryExpr::DropVector(_) => "drop_vector",
579        QueryExpr::DropDocument(_) => "drop_document",
580        QueryExpr::DropKv(_) => "drop_kv",
581        QueryExpr::DropCollection(_) => "drop_collection",
582        QueryExpr::Truncate(_) => "truncate",
583        QueryExpr::AlterTable(_) => "alter_table",
584        QueryExpr::GraphCommand(_) => "graph_command",
585        QueryExpr::SearchCommand(_) => "search_command",
586        QueryExpr::Ask(_) => "ask",
587        QueryExpr::CreateIndex(_) => "create_index",
588        QueryExpr::DropIndex(_) => "drop_index",
589        QueryExpr::ProbabilisticCommand(_) => "probabilistic_command",
590        QueryExpr::CreateTimeSeries(_) => "create_timeseries",
591        QueryExpr::DropTimeSeries(_) => "drop_timeseries",
592        QueryExpr::CreateQueue(_) => "create_queue",
593        QueryExpr::AlterQueue(_) => "alter_queue",
594        QueryExpr::DropQueue(_) => "drop_queue",
595        QueryExpr::QueueSelect(_) => "queue_select",
596        QueryExpr::QueueCommand(_) => "queue_command",
597        QueryExpr::KvCommand(_) => "kv_command",
598        QueryExpr::ConfigCommand(_) => "config_command",
599        QueryExpr::CreateTree(_) => "create_tree",
600        QueryExpr::DropTree(_) => "drop_tree",
601        QueryExpr::TreeCommand(_) => "tree_command",
602        QueryExpr::SetConfig { .. } => "set_config",
603        QueryExpr::ShowConfig { .. } => "show_config",
604        QueryExpr::SetSecret { .. } => "set_secret",
605        QueryExpr::DeleteSecret { .. } => "delete_secret",
606        QueryExpr::ShowSecrets { .. } => "show_secrets",
607        QueryExpr::SetTenant(_) => "set_tenant",
608        QueryExpr::ShowTenant => "show_tenant",
609        QueryExpr::ExplainAlter(_) => "explain_alter",
610        QueryExpr::TransactionControl(_) => "transaction_control",
611        QueryExpr::MaintenanceCommand(_) => "maintenance_command",
612        QueryExpr::CreateSchema(_) => "create_schema",
613        QueryExpr::DropSchema(_) => "drop_schema",
614        QueryExpr::CreateSequence(_) => "create_sequence",
615        QueryExpr::DropSequence(_) => "drop_sequence",
616        QueryExpr::CopyFrom(_) => "copy_from",
617        QueryExpr::CreateView(_) => "create_view",
618        QueryExpr::DropView(_) => "drop_view",
619        QueryExpr::RefreshMaterializedView(_) => "refresh_materialized_view",
620        QueryExpr::CreatePolicy(_) => "create_policy",
621        QueryExpr::DropPolicy(_) => "drop_policy",
622        QueryExpr::CreateServer(_) => "create_server",
623        QueryExpr::DropServer(_) => "drop_server",
624        QueryExpr::CreateForeignTable(_) => "create_foreign_table",
625        QueryExpr::DropForeignTable(_) => "drop_foreign_table",
626        QueryExpr::Grant(_) => "grant",
627        QueryExpr::Revoke(_) => "revoke",
628        QueryExpr::AlterUser(_) => "alter_user",
629        QueryExpr::CreateIamPolicy { .. } => "create_iam_policy",
630        QueryExpr::DropIamPolicy { .. } => "drop_iam_policy",
631        QueryExpr::AttachPolicy { .. } => "attach_policy",
632        QueryExpr::DetachPolicy { .. } => "detach_policy",
633        QueryExpr::ShowPolicies { .. } => "show_policies",
634        QueryExpr::ShowEffectivePermissions { .. } => "show_effective_permissions",
635        QueryExpr::SimulatePolicy { .. } => "simulate_policy",
636        QueryExpr::CreateMigration(_) => "create_migration",
637        QueryExpr::ApplyMigration(_) => "apply_migration",
638        QueryExpr::RollbackMigration(_) => "rollback_migration",
639        QueryExpr::ExplainMigration(_) => "explain_migration",
640        QueryExpr::EventsBackfill(_) => "events_backfill",
641        QueryExpr::EventsBackfillStatus { .. } => "events_backfill_status",
642    }
643}
644
// ============================================================================
// Tests
// ============================================================================
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the query shape shared by every test: L2 metric, no filter, no
    /// threshold, and neither vectors nor metadata in the output. Tests that
    /// need something else override fields with struct-update syntax.
    fn l2_query(collection: &str, source: VectorSource, k: usize) -> VectorQuery {
        VectorQuery {
            alias: None,
            collection: collection.to_string(),
            query_vector: source,
            k,
            filter: None,
            metric: Some(DistanceMetric::L2),
            include_vectors: false,
            include_metadata: false,
            threshold: None,
        }
    }

    /// Metadata map with a string `type` and an integer `severity`.
    fn vuln_meta(kind: &str, severity: i64) -> HashMap<String, MetadataValue> {
        let mut meta = HashMap::new();
        meta.insert("type".to_string(), MetadataValue::String(kind.to_string()));
        meta.insert("severity".to_string(), MetadataValue::Integer(severity));
        meta
    }

    #[test]
    fn test_in_memory_vector_search() {
        let mut executor = InMemoryVectorExecutor::new();

        let fixtures = vec![
            (1, vec![1.0, 0.0, 0.0]),
            (2, vec![0.0, 1.0, 0.0]),
            (3, vec![0.0, 0.0, 1.0]),
            (4, vec![0.9, 0.1, 0.0]),
        ];
        for (id, vector) in fixtures {
            executor.add_vector("test", id, vector, None);
        }

        let query = l2_query("test", VectorSource::Literal(vec![1.0, 0.0, 0.0]), 2);

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        // The exact match (vector 1) must rank first.
        let nearest = &result.records[0];
        assert_eq!(nearest.get("id"), Some(&Value::Integer(1)));
    }

    #[test]
    fn test_vector_search_with_metadata_filter() {
        let mut executor = InMemoryVectorExecutor::new();

        executor.add_vector("vulns", 1, vec![1.0, 0.0], Some(vuln_meta("cve", 9)));
        executor.add_vector("vulns", 2, vec![0.9, 0.1], Some(vuln_meta("cve", 5)));
        executor.add_vector("vulns", 3, vec![0.8, 0.2], Some(vuln_meta("advisory", 8)));

        // type = 'cve' AND severity >= 7
        let only_severe_cves = MetadataFilter::And(vec![
            MetadataFilter::Eq("type".to_string(), MetadataValue::String("cve".to_string())),
            MetadataFilter::Gte("severity".to_string(), MetadataValue::Integer(7)),
        ]);
        let query = VectorQuery {
            filter: Some(only_severe_cves),
            include_metadata: true,
            ..l2_query("vulns", VectorSource::Literal(vec![1.0, 0.0]), 10)
        };

        let result = executor.execute(&query).unwrap();

        // Only vector 1 satisfies both conditions (type=cve, severity=9).
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(1)));
    }

    #[test]
    fn test_vector_search_with_threshold() {
        let mut executor = InMemoryVectorExecutor::new();

        executor.add_vector("test", 1, vec![1.0, 0.0], None);
        // Orthogonal to the query, so well outside the threshold.
        executor.add_vector("test", 2, vec![0.0, 1.0], None);

        let query = VectorQuery {
            threshold: Some(0.5),
            ..l2_query("test", VectorSource::Literal(vec![1.0, 0.0]), 10)
        };

        let result = executor.execute(&query).unwrap();

        // Only vector 1 is close enough.
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_vector_search_include_vectors() {
        let mut executor = InMemoryVectorExecutor::new();

        executor.add_vector("test", 1, vec![1.0, 2.0, 3.0], None);

        let query = VectorQuery {
            include_vectors: true,
            ..l2_query("test", VectorSource::Literal(vec![1.0, 2.0, 3.0]), 1)
        };

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);

        let hit = &result.records[0].vector_results[0];
        assert!(hit.vector.is_some());
        let returned = hit.vector.as_ref().unwrap();
        assert_eq!(returned, &vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vector_executor_reference_source() {
        let mut store = VectorStore::new();
        let collection = store.create_collection("refs", 2);
        let ref_id = collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0], None).unwrap();

        let executor = VectorExecutor::new(Arc::new(store));
        let source = VectorSource::Reference {
            collection: "refs".to_string(),
            vector_id: ref_id,
        };
        let query = l2_query("refs", source, 1);

        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(0)));
    }

    #[test]
    fn test_vector_executor_subquery_source() {
        let mut store = VectorStore::new();
        let collection = store.create_collection("refs", 2);
        collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0], None).unwrap();

        let executor = VectorExecutor::new(Arc::new(store));

        // The inner search picks the reference vector for the outer search.
        let inner = l2_query("refs", VectorSource::Literal(vec![1.0, 0.0]), 1);
        let outer = l2_query(
            "refs",
            VectorSource::Subquery(Box::new(QueryExpr::Vector(inner))),
            1,
        );

        let result = executor.execute(&outer).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(0)));
    }
}