Skip to main content

reddb_server/storage/query/rag/
unified_adapter.rs

//! Unified Store Adapter for RAG Engine
//!
//! Bridges the unified RedDB store with the existing RAG retrieval infrastructure,
//! enabling queries that seamlessly combine tables, graphs, and vectors.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::storage::query::unified::ExecutionError;
10use crate::storage::schema::Value;
11use crate::storage::{CrossRef, EntityData, EntityId, EntityKind, RefType, Store, UnifiedEntity};
12
13use super::context::{ChunkSource, ContextChunk, RetrievalContext};
14use super::RagConfig;
15
16/// Result from a unified multi-modal query
17#[derive(Debug, Clone)]
18pub struct UnifiedQueryResult {
19    /// Matched entities (rows, nodes, edges, vectors)
20    pub entities: Vec<MatchedEntity>,
21    /// Query statistics
22    pub stats: UnifiedQueryStats,
23}
24
25impl UnifiedQueryResult {
26    pub fn new() -> Self {
27        Self {
28            entities: Vec::new(),
29            stats: UnifiedQueryStats::default(),
30        }
31    }
32
33    pub fn push(&mut self, entity: MatchedEntity) {
34        self.entities.push(entity);
35    }
36
37    pub fn len(&self) -> usize {
38        self.entities.len()
39    }
40
41    pub fn is_empty(&self) -> bool {
42        self.entities.is_empty()
43    }
44}
45
46impl Default for UnifiedQueryResult {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52/// A matched entity with relevance score and source information
53#[derive(Debug, Clone)]
54pub struct MatchedEntity {
55    /// The entity itself
56    pub entity: UnifiedEntity,
57    /// Relevance score (0.0 - 1.0)
58    pub score: f32,
59    /// Source of the match
60    pub source: MatchSource,
61    /// Cross-references followed to reach this entity
62    pub via_refs: Vec<CrossRef>,
63}
64
65impl MatchedEntity {
66    pub fn new(entity: UnifiedEntity, score: f32, source: MatchSource) -> Self {
67        Self {
68            entity,
69            score,
70            source,
71            via_refs: Vec::new(),
72        }
73    }
74
75    pub fn with_refs(mut self, refs: Vec<CrossRef>) -> Self {
76        self.via_refs = refs;
77        self
78    }
79}
80
/// Source of a match in unified query
///
/// Records which part of a query caused an entity to be returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchSource {
    /// Direct vector similarity search (assigned by `vector_search`)
    VectorSimilarity,
    /// Graph pattern match
    GraphPattern,
    /// Table filter match (assigned when a metadata filter accepts the entity)
    TableFilter,
    /// Cross-reference expansion (assigned by `find_by_cross_ref`)
    CrossReference,
    /// Hybrid scoring: `multi_modal_query` uses this when more than one
    /// signal contributed to the score, and as its fallback when none did
    Hybrid,
}
95
/// Statistics for unified query execution
///
/// All counters start at zero via the derived `Default`.
#[derive(Debug, Clone, Default)]
pub struct UnifiedQueryStats {
    /// Number of vector comparisons
    pub vector_comparisons: usize,
    /// Number of graph patterns checked
    // NOTE(review): no code visible in this file writes this counter — confirm
    // whether it is maintained elsewhere.
    pub graph_patterns_checked: usize,
    /// Number of table rows scanned
    // NOTE(review): no code visible in this file writes this counter — confirm
    // whether it is maintained elsewhere.
    pub table_rows_scanned: usize,
    /// Number of cross-refs followed
    pub cross_refs_followed: usize,
    /// Execution time in microseconds (elapsed time measured with
    /// `std::time::Instant` around the query method)
    pub execution_time_us: u64,
}
110
/// Adapter that connects the store to RAG queries
///
/// Every query method reads from the wrapped store through
/// `list_collections`, `get_collection`, `get_any` and `get_refs_from`.
pub struct UnifiedStoreAdapter {
    /// The store, held behind a shared `Arc` handle
    store: Arc<Store>,
}
116
117impl UnifiedStoreAdapter {
118    /// Create a new adapter for the given store
119    pub fn new(store: Arc<Store>) -> Self {
120        Self { store }
121    }
122
123    /// Search vectors across all collections
124    pub fn vector_search(
125        &self,
126        query_vector: &[f32],
127        collections: Option<&[&str]>,
128        k: usize,
129        _metadata_filter: Option<MetadataQuery>,
130    ) -> Result<UnifiedQueryResult, ExecutionError> {
131        let start = std::time::Instant::now();
132        let mut result = UnifiedQueryResult::new();
133
134        // Get all collections if not specified
135        let collection_names: Vec<String> = if let Some(cols) = collections {
136            cols.iter().map(|s| s.to_string()).collect()
137        } else {
138            self.store.list_collections()
139        };
140
141        // Search each collection using query_all
142        for col_name in &collection_names {
143            let manager = match self.store.get_collection(col_name) {
144                Some(m) => m,
145                None => continue,
146            };
147
148            // Use query_all to scan entities
149            let entities = manager.query_all(|_| true);
150            for entity in entities {
151                // Check if it's a vector entity
152                if let EntityData::Vector(ref vec_data) = entity.data {
153                    let similarity = cosine_similarity(query_vector, &vec_data.dense);
154                    if similarity > 0.0 {
155                        result.push(MatchedEntity::new(
156                            entity.clone(),
157                            similarity,
158                            MatchSource::VectorSimilarity,
159                        ));
160                        result.stats.vector_comparisons += 1;
161                    }
162                }
163
164                // Also check embeddings in any entity type
165                for slot in entity.embeddings() {
166                    let similarity = cosine_similarity(query_vector, &slot.vector);
167                    if similarity > 0.5 {
168                        result.push(MatchedEntity::new(
169                            entity.clone(),
170                            similarity,
171                            MatchSource::VectorSimilarity,
172                        ));
173                        result.stats.vector_comparisons += 1;
174                    }
175                }
176            }
177        }
178
179        // Sort by score and take top k
180        result.entities.sort_by(|a, b| {
181            b.score
182                .partial_cmp(&a.score)
183                .unwrap_or(std::cmp::Ordering::Equal)
184                .then_with(|| a.entity.id.cmp(&b.entity.id))
185        });
186        result.entities.truncate(k);
187
188        result.stats.execution_time_us = start.elapsed().as_micros() as u64;
189        Ok(result)
190    }
191
192    /// Find entities by cross-reference type
193    pub fn find_by_cross_ref(
194        &self,
195        source_id: EntityId,
196        ref_type: RefType,
197        max_depth: u32,
198    ) -> Result<UnifiedQueryResult, ExecutionError> {
199        let start = std::time::Instant::now();
200        let mut result = UnifiedQueryResult::new();
201        let mut visited = std::collections::HashSet::new();
202        let mut frontier = vec![(source_id, 0u32, vec![])];
203
204        while let Some((current_id, depth, path)) = frontier.pop() {
205            if depth > max_depth || visited.contains(&current_id) {
206                continue;
207            }
208            visited.insert(current_id);
209
210            // Find the entity
211            if let Some((col_name, entity)) = self.store.get_any(current_id) {
212                // Add to results if not the source
213                if current_id != source_id {
214                    let matched = MatchedEntity::new(
215                        entity.clone(),
216                        1.0 - (depth as f32 * 0.2),
217                        MatchSource::CrossReference,
218                    )
219                    .with_refs(path.clone());
220                    result.push(matched);
221                }
222
223                // Expand cross-refs of matching type
224                for (target_id, link_type, target_collection) in
225                    self.store.get_refs_from(current_id)
226                {
227                    if link_type == ref_type || matches!(ref_type, RefType::RelatedTo) {
228                        let mut new_path = path.clone();
229                        new_path.push(CrossRef::new(
230                            current_id,
231                            target_id,
232                            target_collection,
233                            link_type,
234                        ));
235                        frontier.push((target_id, depth + 1, new_path));
236                    }
237                }
238
239                result.stats.cross_refs_followed += 1;
240            }
241        }
242
243        result.stats.execution_time_us = start.elapsed().as_micros() as u64;
244        Ok(result)
245    }
246
247    /// Execute a multi-modal query combining vector, graph, and table filters
248    pub fn multi_modal_query(
249        &self,
250        query: MultiModalQuery,
251    ) -> Result<UnifiedQueryResult, ExecutionError> {
252        let start = std::time::Instant::now();
253        let mut result = UnifiedQueryResult::new();
254
255        // 1. Vector search if query vector provided
256        let mut vector_results = HashMap::new();
257        if let Some(ref qvec) = query.query_vector {
258            let vec_result = self.vector_search(
259                qvec,
260                query.collections.as_deref(),
261                query.vector_k.unwrap_or(10),
262                query.metadata_filter.clone(),
263            )?;
264            for m in vec_result.entities {
265                vector_results.insert(m.entity.id, m.score);
266            }
267        }
268
269        // 2. Pattern matching for graph entities
270        let mut graph_matches = std::collections::HashSet::new();
271        if let Some(ref pattern) = query.graph_pattern {
272            self.match_graph_pattern(pattern, &mut graph_matches)?;
273        }
274
275        // 3. Scan all collections and score entities
276        for col_name in &self.store.list_collections() {
277            if let Some(cols) = &query.collections {
278                if !cols.contains(&col_name.as_str()) {
279                    continue;
280                }
281            }
282
283            let manager = match self.store.get_collection(col_name) {
284                Some(m) => m,
285                None => continue,
286            };
287
288            // Use query_all to get entities
289            let entities = manager.query_all(|_| true);
290            for entity in entities {
291                let mut score = 0.0f32;
292                let mut sources = vec![];
293
294                // Vector similarity score
295                if let Some(&vec_score) = vector_results.get(&entity.id) {
296                    score += vec_score * query.vector_weight.unwrap_or(0.5);
297                    sources.push(MatchSource::VectorSimilarity);
298                }
299
300                // Graph pattern match
301                if graph_matches.contains(&entity.id) {
302                    score += 0.8 * query.graph_weight.unwrap_or(0.3);
303                    sources.push(MatchSource::GraphPattern);
304                }
305
306                // Metadata filter match - check entity properties
307                if let Some(ref filter) = query.metadata_filter {
308                    if self.matches_metadata(&entity, filter) {
309                        score += 0.5 * query.table_weight.unwrap_or(0.2);
310                        sources.push(MatchSource::TableFilter);
311                    }
312                }
313
314                // Add if score is above threshold
315                if score >= query.min_score.unwrap_or(0.1) {
316                    let source = if sources.len() > 1 {
317                        MatchSource::Hybrid
318                    } else {
319                        sources.first().copied().unwrap_or(MatchSource::Hybrid)
320                    };
321
322                    result.push(MatchedEntity::new(entity, score, source));
323                }
324            }
325        }
326
327        // Sort by score
328        result.entities.sort_by(|a, b| {
329            b.score
330                .partial_cmp(&a.score)
331                .unwrap_or(std::cmp::Ordering::Equal)
332                .then_with(|| a.entity.id.cmp(&b.entity.id))
333        });
334
335        // Apply limit
336        if let Some(limit) = query.limit {
337            result.entities.truncate(limit);
338        }
339
340        result.stats.execution_time_us = start.elapsed().as_micros() as u64;
341        Ok(result)
342    }
343
344    /// Expand context around an entity by following cross-refs
345    pub fn expand_entity_context(
346        &self,
347        entity_id: EntityId,
348        config: &RagConfig,
349    ) -> Result<RetrievalContext, ExecutionError> {
350        let mut context = RetrievalContext::new(format!("expand:{}", entity_id.0));
351
352        // Find the entity first
353        let (collection, entity) = self
354            .store
355            .get_any(entity_id)
356            .ok_or_else(|| ExecutionError::new(format!("Entity {} not found", entity_id.0)))?;
357
358        // Add the entity itself as a chunk
359        context.add_chunk(entity_to_chunk(&entity, &collection, 1.0));
360
361        // Follow cross-refs up to configured depth
362        let refs_result =
363            self.find_by_cross_ref(entity_id, RefType::RelatedTo, config.graph_depth)?;
364        for matched in refs_result.entities {
365            context.add_chunk(entity_to_chunk(&matched.entity, "cross_ref", matched.score));
366        }
367
368        // If the entity has embeddings, find similar vectors
369        if !entity.embeddings().is_empty() && config.expand_cross_refs {
370            let primary_vec = &entity.embeddings()[0].vector;
371            let similar = self.vector_search(primary_vec, None, 5, None)?;
372            for matched in similar.entities {
373                if matched.entity.id != entity_id {
374                    context.add_chunk(entity_to_chunk(
375                        &matched.entity,
376                        "similar",
377                        matched.score * 0.8,
378                    ));
379                }
380            }
381        }
382
383        Ok(context)
384    }
385
386    /// Check if an entity matches metadata filter by checking properties
387    fn matches_metadata(&self, entity: &UnifiedEntity, filter: &MetadataQuery) -> bool {
388        // Extract properties from entity data
389        let properties: HashMap<String, Value> = match &entity.data {
390            EntityData::Node(node) => node.properties.clone(),
391            EntityData::Edge(edge) => edge.properties.clone(),
392            EntityData::Row(row) => row.named.clone().unwrap_or_default(),
393            EntityData::Vector(_) => HashMap::new(),
394            EntityData::TimeSeries(_) => HashMap::new(),
395            EntityData::QueueMessage(_) => HashMap::new(),
396        };
397
398        for (key, expected) in &filter.conditions {
399            let prop_val = properties.get(key);
400            let matches = match (prop_val, expected) {
401                (Some(Value::Text(s)), QueryCondition::Equals(QueryValue::String(exp))) => {
402                    &**s == exp.as_str()
403                }
404                (Some(Value::Integer(i)), QueryCondition::Equals(QueryValue::Int(exp))) => {
405                    *i == *exp
406                }
407                (Some(Value::Float(f)), QueryCondition::Equals(QueryValue::Float(exp))) => {
408                    *f == *exp
409                }
410                (Some(Value::Boolean(b)), QueryCondition::Equals(QueryValue::Bool(exp))) => {
411                    *b == *exp
412                }
413                (Some(Value::Integer(i)), QueryCondition::GreaterThan(QueryValue::Int(n))) => {
414                    *i > *n
415                }
416                (Some(Value::Float(f)), QueryCondition::GreaterThan(QueryValue::Float(n))) => {
417                    *f > *n
418                }
419                (Some(Value::Integer(i)), QueryCondition::LessThan(QueryValue::Int(n))) => *i < *n,
420                (Some(Value::Float(f)), QueryCondition::LessThan(QueryValue::Float(n))) => *f < *n,
421                (Some(Value::Text(s)), QueryCondition::Contains(substr)) => {
422                    s.contains(substr.as_str())
423                }
424                _ => false,
425            };
426            if !matches {
427                return false;
428            }
429        }
430        true
431    }
432
433    /// Match graph pattern against entities
434    fn match_graph_pattern(
435        &self,
436        pattern: &GraphQueryPattern,
437        matches: &mut std::collections::HashSet<EntityId>,
438    ) -> Result<(), ExecutionError> {
439        for col_name in &self.store.list_collections() {
440            let manager = match self.store.get_collection(col_name) {
441                Some(m) => m,
442                None => continue,
443            };
444
445            let entities = manager.query_all(|_| true);
446            for entity in entities {
447                let is_match = match (&entity.kind, &pattern.node_pattern) {
448                    (EntityKind::GraphNode(ref node), Some(pat)) => {
449                        let label_match = pat.label.as_ref().is_none_or(|l| &node.label == l);
450                        let type_match =
451                            pat.node_type.as_ref().is_none_or(|t| &node.node_type == t);
452                        label_match && type_match
453                    }
454                    (EntityKind::GraphEdge(ref edge), Some(pat)) => {
455                        pat.label.as_ref() == Some(&edge.label)
456                    }
457                    (_, None) => true,
458                    _ => false,
459                };
460
461                if is_match {
462                    matches.insert(entity.id);
463                }
464            }
465        }
466
467        Ok(())
468    }
469}
470
471/// Multi-modal query specification
472#[derive(Debug, Clone, Default)]
473pub struct MultiModalQuery {
474    /// Query vector for similarity search
475    pub query_vector: Option<Vec<f32>>,
476    /// Collections to search (None = all)
477    pub collections: Option<Vec<&'static str>>,
478    /// Number of vectors to retrieve
479    pub vector_k: Option<usize>,
480    /// Graph pattern to match
481    pub graph_pattern: Option<GraphQueryPattern>,
482    /// Metadata filter conditions
483    pub metadata_filter: Option<MetadataQuery>,
484    /// Weight for vector similarity (0.0-1.0)
485    pub vector_weight: Option<f32>,
486    /// Weight for graph pattern match (0.0-1.0)
487    pub graph_weight: Option<f32>,
488    /// Weight for table/metadata filter (0.0-1.0)
489    pub table_weight: Option<f32>,
490    /// Minimum combined score
491    pub min_score: Option<f32>,
492    /// Maximum results to return
493    pub limit: Option<usize>,
494}
495
496impl MultiModalQuery {
497    pub fn new() -> Self {
498        Self::default()
499    }
500
501    pub fn with_vector(mut self, vector: Vec<f32>, k: usize) -> Self {
502        self.query_vector = Some(vector);
503        self.vector_k = Some(k);
504        self
505    }
506
507    pub fn with_graph_pattern(mut self, pattern: GraphQueryPattern) -> Self {
508        self.graph_pattern = Some(pattern);
509        self
510    }
511
512    pub fn with_metadata(mut self, filter: MetadataQuery) -> Self {
513        self.metadata_filter = Some(filter);
514        self
515    }
516
517    pub fn with_weights(mut self, vector: f32, graph: f32, table: f32) -> Self {
518        self.vector_weight = Some(vector);
519        self.graph_weight = Some(graph);
520        self.table_weight = Some(table);
521        self
522    }
523
524    pub fn with_limit(mut self, limit: usize) -> Self {
525        self.limit = Some(limit);
526        self
527    }
528}
529
/// Graph pattern for matching
///
/// The derived `Default` yields no node pattern and no edge patterns.
#[derive(Debug, Clone, Default)]
pub struct GraphQueryPattern {
    /// Node pattern (label, type); when `None`, `match_graph_pattern`
    /// accepts every entity
    pub node_pattern: Option<NodePattern>,
    /// Edge patterns to match
    // NOTE(review): `match_graph_pattern` in this file never reads this
    // field — confirm whether another component consumes it.
    pub edge_patterns: Vec<EdgePatternSpec>,
}
538
/// Node pattern
///
/// For graph nodes an unset field acts as a wildcard. For graph edges,
/// `match_graph_pattern` requires `label` to be set and equal to the edge's
/// label.
#[derive(Debug, Clone)]
pub struct NodePattern {
    /// Label to match; `None` accepts any node label
    pub label: Option<String>,
    /// Node type to match; `None` accepts any node type
    pub node_type: Option<String>,
}
545
/// Edge pattern
// NOTE(review): nothing visible in this file evaluates edge patterns —
// confirm where (or whether) they are applied.
#[derive(Debug, Clone)]
pub struct EdgePatternSpec {
    /// Edge label to match, if any
    pub label: Option<String>,
    /// Direction constraint for the edge
    pub direction: EdgeDirection,
}
552
/// Direction constraint carried by an `EdgePatternSpec`.
// NOTE(review): variant meanings are inferred from their names only; no code
// visible in this file interprets them — confirm against the consumer.
#[derive(Debug, Clone, Copy)]
pub enum EdgeDirection {
    /// Presumably an edge leaving the matched node
    Outgoing,
    /// Presumably an edge arriving at the matched node
    Incoming,
    /// Presumably either direction
    Any,
}
559
560/// Metadata query filter
561#[derive(Debug, Clone, Default)]
562pub struct MetadataQuery {
563    pub conditions: HashMap<String, QueryCondition>,
564}
565
566impl MetadataQuery {
567    pub fn new() -> Self {
568        Self::default()
569    }
570
571    pub fn eq(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
572        self.conditions
573            .insert(key.into(), QueryCondition::Equals(value.into()));
574        self
575    }
576
577    pub fn gt(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
578        self.conditions
579            .insert(key.into(), QueryCondition::GreaterThan(value.into()));
580        self
581    }
582
583    pub fn lt(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
584        self.conditions
585            .insert(key.into(), QueryCondition::LessThan(value.into()));
586        self
587    }
588
589    pub fn contains(mut self, key: impl Into<String>, substr: impl Into<String>) -> Self {
590        self.conditions
591            .insert(key.into(), QueryCondition::Contains(substr.into()));
592        self
593    }
594}
595
/// A single comparison applied to one entity property.
///
/// As evaluated by `matches_metadata`, a condition only holds when the
/// property's value type pairs with the condition's operand type; every
/// other combination counts as a mismatch.
#[derive(Debug, Clone)]
pub enum QueryCondition {
    /// Property equals the operand (text, integer, float or boolean)
    Equals(QueryValue),
    /// Property is strictly greater than the operand (integer or float)
    GreaterThan(QueryValue),
    /// Property is strictly less than the operand (integer or float)
    LessThan(QueryValue),
    /// Text property contains the given substring
    Contains(String),
}
603
/// Operand of a `QueryCondition`.
///
/// The `From` conversions below let builder methods accept plain Rust
/// values (`i64`, `f64`, `&str`, `String`, `bool`).
#[derive(Debug, Clone)]
pub enum QueryValue {
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
}

impl From<i64> for QueryValue {
    /// Wraps an integer operand.
    fn from(number: i64) -> Self {
        Self::Int(number)
    }
}

impl From<f64> for QueryValue {
    /// Wraps a floating-point operand.
    fn from(number: f64) -> Self {
        Self::Float(number)
    }
}

impl From<&str> for QueryValue {
    /// Copies a borrowed string into an owned text operand.
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}

impl From<String> for QueryValue {
    /// Wraps an owned string without copying it.
    fn from(text: String) -> Self {
        Self::String(text)
    }
}

impl From<bool> for QueryValue {
    /// Wraps a boolean operand.
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}
641
642// ============================================================================
643// Helper Functions
644// ============================================================================
645
/// Calculate cosine similarity between two vectors.
///
/// Sums are accumulated in `f64` so that very large or very small `f32`
/// components neither overflow nor vanish, and the result is clamped to
/// `[-1.0, 1.0]` so rounding can never push it outside the valid range.
///
/// # Returns
/// The cosine of the angle between `a` and `b`, or `0.0` when the vectors
/// differ in length, are empty, have zero magnitude, or contain values that
/// make the result non-finite (for example NaN components).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One pass computes the dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // A NaN denominator fails this comparison and falls through to 0.0.
    if denom > 0.0 {
        let cosine = dot / denom;
        if cosine.is_finite() {
            return cosine.clamp(-1.0, 1.0) as f32;
        }
    }
    0.0
}
669
670/// Convert an entity to a context chunk
671fn entity_to_chunk(entity: &UnifiedEntity, collection: &str, score: f32) -> ContextChunk {
672    let content = match &entity.data {
673        EntityData::Row(row) => {
674            let fields: Vec<String> = row
675                .columns
676                .iter()
677                .enumerate()
678                .map(|(i, v)| format!("col{}: {:?}", i, v))
679                .collect();
680            fields.join(", ")
681        }
682        EntityData::Node(node) => {
683            let props: Vec<String> = node
684                .properties
685                .iter()
686                .map(|(k, v)| format!("{}: {:?}", k, v))
687                .collect();
688            format!("Node: {}", props.join(", "))
689        }
690        EntityData::Edge(edge) => {
691            format!("Edge: weight={}", edge.weight)
692        }
693        EntityData::Vector(vec) => {
694            format!(
695                "Vector: dim={}, sparse={}",
696                vec.dense.len(),
697                vec.sparse.is_some()
698            )
699        }
700        EntityData::TimeSeries(ts) => {
701            format!("TimeSeries: metric={}, value={}", ts.metric, ts.value)
702        }
703        EntityData::QueueMessage(msg) => {
704            format!(
705                "QueueMessage: attempts={}, acked={}",
706                msg.attempts, msg.acked
707            )
708        }
709    };
710
711    let (source, entity_type) = match &entity.kind {
712        EntityKind::TableRow { table, .. } => (
713            ChunkSource::Table(table.to_string()),
714            Some(super::EntityType::Unknown), // Generic table row
715        ),
716        EntityKind::GraphNode(ref node) => (
717            ChunkSource::Graph,
718            // Try to map node_type to EntityType
719            Some(match node.node_type.to_lowercase().as_str() {
720                "host" => super::EntityType::Host,
721                "service" => super::EntityType::Service,
722                "port" => super::EntityType::Port,
723                "vulnerability" | "vuln" => super::EntityType::Vulnerability,
724                "credential" | "cred" => super::EntityType::Credential,
725                "user" => super::EntityType::User,
726                "certificate" | "cert" => super::EntityType::Certificate,
727                "domain" => super::EntityType::Domain,
728                "network" => super::EntityType::Network,
729                "technology" | "tech" => super::EntityType::Technology,
730                "endpoint" => super::EntityType::Endpoint,
731                _ => super::EntityType::Unknown,
732            }),
733        ),
734        EntityKind::GraphEdge(_) => (
735            ChunkSource::Graph,
736            Some(super::EntityType::Unknown), // Edges don't have a direct type mapping
737        ),
738        EntityKind::Vector { collection: col } => (
739            ChunkSource::Vector(col.clone()),
740            Some(super::EntityType::Unknown), // Vectors don't have a direct type mapping
741        ),
742        EntityKind::TimeSeriesPoint(ref ts) => (
743            ChunkSource::Table(ts.series.clone()),
744            Some(super::EntityType::Unknown),
745        ),
746        EntityKind::QueueMessage { queue, .. } => (
747            ChunkSource::Table(queue.clone()),
748            Some(super::EntityType::Unknown),
749        ),
750    };
751
752    ContextChunk {
753        content,
754        source,
755        relevance: score,
756        entity_type,
757        entity_id: Some(entity.id.0.to_string()),
758        metadata: HashMap::new(),
759        vector_distance: Some(1.0 - score), // Convert similarity to distance
760        graph_depth: None,
761    }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical, orthogonal and 45-degree vectors give the expected cosines.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0, 0.0, 0.0];

        // Same direction: similarity of one.
        let same = cosine_similarity(&x_axis, &[1.0, 0.0, 0.0]);
        assert!((same - 1.0).abs() < 0.001);

        // Perpendicular: similarity of zero.
        let perpendicular = cosine_similarity(&x_axis, &[0.0, 1.0, 0.0]);
        assert!(perpendicular.abs() < 0.001);

        // 45 degrees apart: roughly 0.707.
        let sim = cosine_similarity(&x_axis, &[1.0, 1.0, 0.0]);
        assert!(sim > 0.7 && sim < 0.72);
    }

    /// Each builder call on a distinct key adds one condition.
    #[test]
    fn test_metadata_query_builder() {
        let built = MetadataQuery::new()
            .eq("type", "host")
            .gt("score", 0.5f64)
            .contains("name", "server");

        assert_eq!(built.conditions.len(), 3);
    }

    /// The builder stores the vector, its `k`, and the result limit.
    #[test]
    fn test_multi_modal_query_builder() {
        let built = MultiModalQuery::new()
            .with_vector(vec![1.0, 0.0, 0.0], 10)
            .with_weights(0.6, 0.3, 0.1)
            .with_limit(20);

        assert!(built.query_vector.is_some());
        assert_eq!(built.vector_k, Some(10));
        assert_eq!(built.limit, Some(20));
    }
}