1use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::storage::query::unified::ExecutionError;
10use crate::storage::schema::Value;
11use crate::storage::{CrossRef, EntityData, EntityId, EntityKind, RefType, Store, UnifiedEntity};
12
13use super::context::{ChunkSource, ContextChunk, RetrievalContext};
14use super::RagConfig;
15
16#[derive(Debug, Clone)]
18pub struct UnifiedQueryResult {
19 pub entities: Vec<MatchedEntity>,
21 pub stats: UnifiedQueryStats,
23}
24
25impl UnifiedQueryResult {
26 pub fn new() -> Self {
27 Self {
28 entities: Vec::new(),
29 stats: UnifiedQueryStats::default(),
30 }
31 }
32
33 pub fn push(&mut self, entity: MatchedEntity) {
34 self.entities.push(entity);
35 }
36
37 pub fn len(&self) -> usize {
38 self.entities.len()
39 }
40
41 pub fn is_empty(&self) -> bool {
42 self.entities.is_empty()
43 }
44}
45
46impl Default for UnifiedQueryResult {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct MatchedEntity {
55 pub entity: UnifiedEntity,
57 pub score: f32,
59 pub source: MatchSource,
61 pub via_refs: Vec<CrossRef>,
63}
64
65impl MatchedEntity {
66 pub fn new(entity: UnifiedEntity, score: f32, source: MatchSource) -> Self {
67 Self {
68 entity,
69 score,
70 source,
71 via_refs: Vec::new(),
72 }
73 }
74
75 pub fn with_refs(mut self, refs: Vec<CrossRef>) -> Self {
76 self.via_refs = refs;
77 self
78 }
79}
80
/// Kind of lookup that produced a [`MatchedEntity`].
///
/// Derives `Hash` in addition to equality so a source can be used as a
/// `HashMap`/`HashSet` key (for example to group or count matches per source).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchSource {
    /// Matched by vector similarity against the query vector.
    VectorSimilarity,
    /// Matched by a graph pattern.
    GraphPattern,
    /// Matched by a metadata/table filter.
    TableFilter,
    /// Reached by following cross-references from another entity.
    CrossReference,
    /// Matched by more than one kind of lookup at once.
    Hybrid,
}
95
/// Counters and timing collected while a unified query executes.
///
/// All fields start at zero via `Default`. `PartialEq`/`Eq` are derived so two
/// snapshots can be compared directly (useful in assertions and regression checks).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UnifiedQueryStats {
    /// Vector-similarity counter maintained by vector search.
    pub vector_comparisons: usize,
    /// Graph-pattern counter; no code visible in this file section updates it.
    pub graph_patterns_checked: usize,
    /// Table-row counter; no code visible in this file section updates it.
    pub table_rows_scanned: usize,
    /// Incremented once per entity expanded during cross-reference traversal.
    pub cross_refs_followed: usize,
    /// Elapsed time of the query in microseconds.
    pub execution_time_us: u64,
}
110
111pub struct UnifiedStoreAdapter {
113 store: Arc<Store>,
115}
116
117impl UnifiedStoreAdapter {
118 pub fn new(store: Arc<Store>) -> Self {
120 Self { store }
121 }
122
123 pub fn vector_search(
125 &self,
126 query_vector: &[f32],
127 collections: Option<&[&str]>,
128 k: usize,
129 _metadata_filter: Option<MetadataQuery>,
130 ) -> Result<UnifiedQueryResult, ExecutionError> {
131 let start = std::time::Instant::now();
132 let mut result = UnifiedQueryResult::new();
133
134 let collection_names: Vec<String> = if let Some(cols) = collections {
136 cols.iter().map(|s| s.to_string()).collect()
137 } else {
138 self.store.list_collections()
139 };
140
141 for col_name in &collection_names {
143 let manager = match self.store.get_collection(col_name) {
144 Some(m) => m,
145 None => continue,
146 };
147
148 let entities = manager.query_all(|_| true);
150 for entity in entities {
151 if let EntityData::Vector(ref vec_data) = entity.data {
153 let similarity = cosine_similarity(query_vector, &vec_data.dense);
154 if similarity > 0.0 {
155 result.push(MatchedEntity::new(
156 entity.clone(),
157 similarity,
158 MatchSource::VectorSimilarity,
159 ));
160 result.stats.vector_comparisons += 1;
161 }
162 }
163
164 for slot in entity.embeddings() {
166 let similarity = cosine_similarity(query_vector, &slot.vector);
167 if similarity > 0.5 {
168 result.push(MatchedEntity::new(
169 entity.clone(),
170 similarity,
171 MatchSource::VectorSimilarity,
172 ));
173 result.stats.vector_comparisons += 1;
174 }
175 }
176 }
177 }
178
179 result.entities.sort_by(|a, b| {
181 b.score
182 .partial_cmp(&a.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 .then_with(|| a.entity.id.cmp(&b.entity.id))
185 });
186 result.entities.truncate(k);
187
188 result.stats.execution_time_us = start.elapsed().as_micros() as u64;
189 Ok(result)
190 }
191
192 pub fn find_by_cross_ref(
194 &self,
195 source_id: EntityId,
196 ref_type: RefType,
197 max_depth: u32,
198 ) -> Result<UnifiedQueryResult, ExecutionError> {
199 let start = std::time::Instant::now();
200 let mut result = UnifiedQueryResult::new();
201 let mut visited = std::collections::HashSet::new();
202 let mut frontier = vec![(source_id, 0u32, vec![])];
203
204 while let Some((current_id, depth, path)) = frontier.pop() {
205 if depth > max_depth || visited.contains(¤t_id) {
206 continue;
207 }
208 visited.insert(current_id);
209
210 if let Some((col_name, entity)) = self.store.get_any(current_id) {
212 if current_id != source_id {
214 let matched = MatchedEntity::new(
215 entity.clone(),
216 1.0 - (depth as f32 * 0.2),
217 MatchSource::CrossReference,
218 )
219 .with_refs(path.clone());
220 result.push(matched);
221 }
222
223 for (target_id, link_type, target_collection) in
225 self.store.get_refs_from(current_id)
226 {
227 if link_type == ref_type || matches!(ref_type, RefType::RelatedTo) {
228 let mut new_path = path.clone();
229 new_path.push(CrossRef::new(
230 current_id,
231 target_id,
232 target_collection,
233 link_type,
234 ));
235 frontier.push((target_id, depth + 1, new_path));
236 }
237 }
238
239 result.stats.cross_refs_followed += 1;
240 }
241 }
242
243 result.stats.execution_time_us = start.elapsed().as_micros() as u64;
244 Ok(result)
245 }
246
247 pub fn multi_modal_query(
249 &self,
250 query: MultiModalQuery,
251 ) -> Result<UnifiedQueryResult, ExecutionError> {
252 let start = std::time::Instant::now();
253 let mut result = UnifiedQueryResult::new();
254
255 let mut vector_results = HashMap::new();
257 if let Some(ref qvec) = query.query_vector {
258 let vec_result = self.vector_search(
259 qvec,
260 query.collections.as_deref(),
261 query.vector_k.unwrap_or(10),
262 query.metadata_filter.clone(),
263 )?;
264 for m in vec_result.entities {
265 vector_results.insert(m.entity.id, m.score);
266 }
267 }
268
269 let mut graph_matches = std::collections::HashSet::new();
271 if let Some(ref pattern) = query.graph_pattern {
272 self.match_graph_pattern(pattern, &mut graph_matches)?;
273 }
274
275 for col_name in &self.store.list_collections() {
277 if let Some(cols) = &query.collections {
278 if !cols.contains(&col_name.as_str()) {
279 continue;
280 }
281 }
282
283 let manager = match self.store.get_collection(col_name) {
284 Some(m) => m,
285 None => continue,
286 };
287
288 let entities = manager.query_all(|_| true);
290 for entity in entities {
291 let mut score = 0.0f32;
292 let mut sources = vec![];
293
294 if let Some(&vec_score) = vector_results.get(&entity.id) {
296 score += vec_score * query.vector_weight.unwrap_or(0.5);
297 sources.push(MatchSource::VectorSimilarity);
298 }
299
300 if graph_matches.contains(&entity.id) {
302 score += 0.8 * query.graph_weight.unwrap_or(0.3);
303 sources.push(MatchSource::GraphPattern);
304 }
305
306 if let Some(ref filter) = query.metadata_filter {
308 if self.matches_metadata(&entity, filter) {
309 score += 0.5 * query.table_weight.unwrap_or(0.2);
310 sources.push(MatchSource::TableFilter);
311 }
312 }
313
314 if score >= query.min_score.unwrap_or(0.1) {
316 let source = if sources.len() > 1 {
317 MatchSource::Hybrid
318 } else {
319 sources.first().copied().unwrap_or(MatchSource::Hybrid)
320 };
321
322 result.push(MatchedEntity::new(entity, score, source));
323 }
324 }
325 }
326
327 result.entities.sort_by(|a, b| {
329 b.score
330 .partial_cmp(&a.score)
331 .unwrap_or(std::cmp::Ordering::Equal)
332 .then_with(|| a.entity.id.cmp(&b.entity.id))
333 });
334
335 if let Some(limit) = query.limit {
337 result.entities.truncate(limit);
338 }
339
340 result.stats.execution_time_us = start.elapsed().as_micros() as u64;
341 Ok(result)
342 }
343
344 pub fn expand_entity_context(
346 &self,
347 entity_id: EntityId,
348 config: &RagConfig,
349 ) -> Result<RetrievalContext, ExecutionError> {
350 let mut context = RetrievalContext::new(format!("expand:{}", entity_id.0));
351
352 let (collection, entity) = self
354 .store
355 .get_any(entity_id)
356 .ok_or_else(|| ExecutionError::new(format!("Entity {} not found", entity_id.0)))?;
357
358 context.add_chunk(entity_to_chunk(&entity, &collection, 1.0));
360
361 let refs_result =
363 self.find_by_cross_ref(entity_id, RefType::RelatedTo, config.graph_depth)?;
364 for matched in refs_result.entities {
365 context.add_chunk(entity_to_chunk(&matched.entity, "cross_ref", matched.score));
366 }
367
368 if !entity.embeddings().is_empty() && config.expand_cross_refs {
370 let primary_vec = &entity.embeddings()[0].vector;
371 let similar = self.vector_search(primary_vec, None, 5, None)?;
372 for matched in similar.entities {
373 if matched.entity.id != entity_id {
374 context.add_chunk(entity_to_chunk(
375 &matched.entity,
376 "similar",
377 matched.score * 0.8,
378 ));
379 }
380 }
381 }
382
383 Ok(context)
384 }
385
386 fn matches_metadata(&self, entity: &UnifiedEntity, filter: &MetadataQuery) -> bool {
388 let properties: HashMap<String, Value> = match &entity.data {
390 EntityData::Node(node) => node.properties.clone(),
391 EntityData::Edge(edge) => edge.properties.clone(),
392 EntityData::Row(row) => row.named.clone().unwrap_or_default(),
393 EntityData::Vector(_) => HashMap::new(),
394 EntityData::TimeSeries(_) => HashMap::new(),
395 EntityData::QueueMessage(_) => HashMap::new(),
396 };
397
398 for (key, expected) in &filter.conditions {
399 let prop_val = properties.get(key);
400 let matches = match (prop_val, expected) {
401 (Some(Value::Text(s)), QueryCondition::Equals(QueryValue::String(exp))) => {
402 &**s == exp.as_str()
403 }
404 (Some(Value::Integer(i)), QueryCondition::Equals(QueryValue::Int(exp))) => {
405 *i == *exp
406 }
407 (Some(Value::Float(f)), QueryCondition::Equals(QueryValue::Float(exp))) => {
408 *f == *exp
409 }
410 (Some(Value::Boolean(b)), QueryCondition::Equals(QueryValue::Bool(exp))) => {
411 *b == *exp
412 }
413 (Some(Value::Integer(i)), QueryCondition::GreaterThan(QueryValue::Int(n))) => {
414 *i > *n
415 }
416 (Some(Value::Float(f)), QueryCondition::GreaterThan(QueryValue::Float(n))) => {
417 *f > *n
418 }
419 (Some(Value::Integer(i)), QueryCondition::LessThan(QueryValue::Int(n))) => *i < *n,
420 (Some(Value::Float(f)), QueryCondition::LessThan(QueryValue::Float(n))) => *f < *n,
421 (Some(Value::Text(s)), QueryCondition::Contains(substr)) => {
422 s.contains(substr.as_str())
423 }
424 _ => false,
425 };
426 if !matches {
427 return false;
428 }
429 }
430 true
431 }
432
433 fn match_graph_pattern(
435 &self,
436 pattern: &GraphQueryPattern,
437 matches: &mut std::collections::HashSet<EntityId>,
438 ) -> Result<(), ExecutionError> {
439 for col_name in &self.store.list_collections() {
440 let manager = match self.store.get_collection(col_name) {
441 Some(m) => m,
442 None => continue,
443 };
444
445 let entities = manager.query_all(|_| true);
446 for entity in entities {
447 let is_match = match (&entity.kind, &pattern.node_pattern) {
448 (EntityKind::GraphNode(ref node), Some(pat)) => {
449 let label_match = pat.label.as_ref().is_none_or(|l| &node.label == l);
450 let type_match =
451 pat.node_type.as_ref().is_none_or(|t| &node.node_type == t);
452 label_match && type_match
453 }
454 (EntityKind::GraphEdge(ref edge), Some(pat)) => {
455 pat.label.as_ref() == Some(&edge.label)
456 }
457 (_, None) => true,
458 _ => false,
459 };
460
461 if is_match {
462 matches.insert(entity.id);
463 }
464 }
465 }
466
467 Ok(())
468 }
469}
470
471#[derive(Debug, Clone, Default)]
473pub struct MultiModalQuery {
474 pub query_vector: Option<Vec<f32>>,
476 pub collections: Option<Vec<&'static str>>,
478 pub vector_k: Option<usize>,
480 pub graph_pattern: Option<GraphQueryPattern>,
482 pub metadata_filter: Option<MetadataQuery>,
484 pub vector_weight: Option<f32>,
486 pub graph_weight: Option<f32>,
488 pub table_weight: Option<f32>,
490 pub min_score: Option<f32>,
492 pub limit: Option<usize>,
494}
495
496impl MultiModalQuery {
497 pub fn new() -> Self {
498 Self::default()
499 }
500
501 pub fn with_vector(mut self, vector: Vec<f32>, k: usize) -> Self {
502 self.query_vector = Some(vector);
503 self.vector_k = Some(k);
504 self
505 }
506
507 pub fn with_graph_pattern(mut self, pattern: GraphQueryPattern) -> Self {
508 self.graph_pattern = Some(pattern);
509 self
510 }
511
512 pub fn with_metadata(mut self, filter: MetadataQuery) -> Self {
513 self.metadata_filter = Some(filter);
514 self
515 }
516
517 pub fn with_weights(mut self, vector: f32, graph: f32, table: f32) -> Self {
518 self.vector_weight = Some(vector);
519 self.graph_weight = Some(graph);
520 self.table_weight = Some(table);
521 self
522 }
523
524 pub fn with_limit(mut self, limit: usize) -> Self {
525 self.limit = Some(limit);
526 self
527 }
528}
529
530#[derive(Debug, Clone, Default)]
532pub struct GraphQueryPattern {
533 pub node_pattern: Option<NodePattern>,
535 pub edge_patterns: Vec<EdgePatternSpec>,
537}
538
/// Constraint on a graph node's label and type.
///
/// `Default` yields the wildcard pattern (both fields unset), which also
/// enables `NodePattern { label: Some(..), ..Default::default() }`;
/// `PartialEq`/`Eq` allow patterns to be compared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodePattern {
    /// Required label; any label is accepted for nodes when unset.
    pub label: Option<String>,
    /// Required node type; any type is accepted when unset.
    pub node_type: Option<String>,
}
545
546#[derive(Debug, Clone)]
548pub struct EdgePatternSpec {
549 pub label: Option<String>,
550 pub direction: EdgeDirection,
551}
552
/// Direction of an edge relative to the node being matched.
///
/// Derives equality and hashing so directions can be compared and used as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeDirection {
    /// Edge leaving the node.
    Outgoing,
    /// Edge arriving at the node.
    Incoming,
    /// Either direction.
    Any,
}
559
560#[derive(Debug, Clone, Default)]
562pub struct MetadataQuery {
563 pub conditions: HashMap<String, QueryCondition>,
564}
565
566impl MetadataQuery {
567 pub fn new() -> Self {
568 Self::default()
569 }
570
571 pub fn eq(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
572 self.conditions
573 .insert(key.into(), QueryCondition::Equals(value.into()));
574 self
575 }
576
577 pub fn gt(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
578 self.conditions
579 .insert(key.into(), QueryCondition::GreaterThan(value.into()));
580 self
581 }
582
583 pub fn lt(mut self, key: impl Into<String>, value: impl Into<QueryValue>) -> Self {
584 self.conditions
585 .insert(key.into(), QueryCondition::LessThan(value.into()));
586 self
587 }
588
589 pub fn contains(mut self, key: impl Into<String>, substr: impl Into<String>) -> Self {
590 self.conditions
591 .insert(key.into(), QueryCondition::Contains(substr.into()));
592 self
593 }
594}
595
596#[derive(Debug, Clone)]
597pub enum QueryCondition {
598 Equals(QueryValue),
599 GreaterThan(QueryValue),
600 LessThan(QueryValue),
601 Contains(String),
602}
603
/// Literal value used on the right-hand side of a [`QueryCondition`].
///
/// `PartialEq` is derived so values can be compared and asserted on; `Eq` is
/// not, because the `Float` payload is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Owned text.
    String(String),
    /// Boolean.
    Bool(bool),
}
611
612impl From<i64> for QueryValue {
613 fn from(v: i64) -> Self {
614 QueryValue::Int(v)
615 }
616}
617
618impl From<f64> for QueryValue {
619 fn from(v: f64) -> Self {
620 QueryValue::Float(v)
621 }
622}
623
624impl From<&str> for QueryValue {
625 fn from(v: &str) -> Self {
626 QueryValue::String(v.to_string())
627 }
628}
629
630impl From<String> for QueryValue {
631 fn from(v: String) -> Self {
632 QueryValue::String(v)
633 }
634}
635
636impl From<bool> for QueryValue {
637 fn from(v: bool) -> Self {
638 QueryValue::Bool(v)
639 }
640}
641
/// Cosine similarity of two vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// vector has zero (or non-finite) magnitude.
///
/// Sums are accumulated in `f64` so that components whose squares would
/// overflow or underflow `f32` still produce the correct angle, and the final
/// ratio is clamped so rounding can never push it outside `[-1.0, 1.0]`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // `denom > 0.0` is false for NaN, so NaN inputs also land in the else branch.
    if denom > 0.0 && denom.is_finite() {
        (dot / denom).clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
669
670fn entity_to_chunk(entity: &UnifiedEntity, collection: &str, score: f32) -> ContextChunk {
672 let content = match &entity.data {
673 EntityData::Row(row) => {
674 let fields: Vec<String> = row
675 .columns
676 .iter()
677 .enumerate()
678 .map(|(i, v)| format!("col{}: {:?}", i, v))
679 .collect();
680 fields.join(", ")
681 }
682 EntityData::Node(node) => {
683 let props: Vec<String> = node
684 .properties
685 .iter()
686 .map(|(k, v)| format!("{}: {:?}", k, v))
687 .collect();
688 format!("Node: {}", props.join(", "))
689 }
690 EntityData::Edge(edge) => {
691 format!("Edge: weight={}", edge.weight)
692 }
693 EntityData::Vector(vec) => {
694 format!(
695 "Vector: dim={}, sparse={}",
696 vec.dense.len(),
697 vec.sparse.is_some()
698 )
699 }
700 EntityData::TimeSeries(ts) => {
701 format!("TimeSeries: metric={}, value={}", ts.metric, ts.value)
702 }
703 EntityData::QueueMessage(msg) => {
704 format!(
705 "QueueMessage: attempts={}, acked={}",
706 msg.attempts, msg.acked
707 )
708 }
709 };
710
711 let (source, entity_type) = match &entity.kind {
712 EntityKind::TableRow { table, .. } => (
713 ChunkSource::Table(table.to_string()),
714 Some(super::EntityType::Unknown), ),
716 EntityKind::GraphNode(ref node) => (
717 ChunkSource::Graph,
718 Some(match node.node_type.to_lowercase().as_str() {
720 "host" => super::EntityType::Host,
721 "service" => super::EntityType::Service,
722 "port" => super::EntityType::Port,
723 "vulnerability" | "vuln" => super::EntityType::Vulnerability,
724 "credential" | "cred" => super::EntityType::Credential,
725 "user" => super::EntityType::User,
726 "certificate" | "cert" => super::EntityType::Certificate,
727 "domain" => super::EntityType::Domain,
728 "network" => super::EntityType::Network,
729 "technology" | "tech" => super::EntityType::Technology,
730 "endpoint" => super::EntityType::Endpoint,
731 _ => super::EntityType::Unknown,
732 }),
733 ),
734 EntityKind::GraphEdge(_) => (
735 ChunkSource::Graph,
736 Some(super::EntityType::Unknown), ),
738 EntityKind::Vector { collection: col } => (
739 ChunkSource::Vector(col.clone()),
740 Some(super::EntityType::Unknown), ),
742 EntityKind::TimeSeriesPoint(ref ts) => (
743 ChunkSource::Table(ts.series.clone()),
744 Some(super::EntityType::Unknown),
745 ),
746 EntityKind::QueueMessage { queue, .. } => (
747 ChunkSource::Table(queue.clone()),
748 Some(super::EntityType::Unknown),
749 ),
750 };
751
752 ContextChunk {
753 content,
754 source,
755 relevance: score,
756 entity_type,
757 entity_id: Some(entity.id.0.to_string()),
758 metadata: HashMap::new(),
759 vector_distance: Some(1.0 - score), graph_depth: None,
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical, orthogonal and 45-degree vectors give ~1, ~0 and ~0.707.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];

        let same = cosine_similarity(&x_axis, &vec![1.0, 0.0, 0.0]);
        assert!((same - 1.0).abs() < 0.001);

        let orthogonal = cosine_similarity(&x_axis, &vec![0.0, 1.0, 0.0]);
        assert!(orthogonal.abs() < 0.001);

        let diagonal = cosine_similarity(&x_axis, &vec![1.0, 1.0, 0.0]);
        assert!(diagonal > 0.7 && diagonal < 0.72);
    }

    /// Each builder call on a distinct key adds one condition.
    #[test]
    fn test_metadata_query_builder() {
        let built = MetadataQuery::new()
            .eq("type", "host")
            .gt("score", 0.5f64)
            .contains("name", "server");

        let condition_count = built.conditions.len();
        assert_eq!(condition_count, 3);
    }

    /// The builder stores the vector, its `k`, and the limit.
    #[test]
    fn test_multi_modal_query_builder() {
        let built = MultiModalQuery::new()
            .with_vector(vec![1.0, 0.0, 0.0], 10)
            .with_weights(0.6, 0.3, 0.1)
            .with_limit(20);

        assert!(built.query_vector.is_some());
        assert_eq!(built.vector_k, Some(10));
        assert_eq!(built.limit, Some(20));
    }
}