1use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use crate::storage::engine::graph_store::GraphStore;
20use crate::storage::query::ast::CompareOp;
21use crate::storage::query::modes::sparql::{
22 SparqlFilter, SparqlParser, SparqlQuery, SparqlTerm, TriplePattern,
23};
24use crate::storage::query::unified::{
25 ExecutionError, MatchedEdge, MatchedNode, QueryStats, UnifiedRecord, UnifiedResult,
26};
27use crate::storage::schema::Value;
28
/// One SPARQL solution: a set of variable -> value assignments, optionally layered over a
/// parent binding whose assignments are visible through [`Binding::get`].
#[derive(Debug, Clone, Default)]
pub struct Binding {
    /// Assignments made directly on this binding, keyed by variable name without the `?`.
    values: HashMap<String, BoundValue>,
    /// Enclosing scope consulted when a variable is not assigned locally.
    parent: Option<Box<Binding>>,
}

/// A value bound to a SPARQL variable.
#[derive(Debug, Clone, PartialEq)]
pub enum BoundValue {
    /// A graph node, identified by its id.
    Node(String),
    /// An edge as `(from, edge_type, to)`.
    Edge(String, String, String),
    /// A string literal.
    Literal(String),
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
    /// A boolean literal.
    Boolean(bool),
}

impl BoundValue {
    /// Returns the node id when this value is a [`BoundValue::Node`], `None` otherwise.
    pub fn as_node_id(&self) -> Option<&str> {
        match self {
            Self::Node(id) => Some(id.as_str()),
            _ => None,
        }
    }

    /// Renders the value as plain text; edges are rendered as `from--type-->to`.
    pub fn to_string_value(&self) -> String {
        match self {
            Self::Node(id) => id.clone(),
            Self::Edge(from, etype, to) => format!("{}--{}-->{}", from, etype, to),
            Self::Literal(s) => s.clone(),
            Self::Integer(i) => i.to_string(),
            Self::Float(f) => f.to_string(),
            Self::Boolean(b) => b.to_string(),
        }
    }
}

impl Binding {
    /// Creates an empty binding with no parent.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty binding whose lookups fall back to `parent`.
    pub fn with_parent(parent: Binding) -> Self {
        Self {
            values: HashMap::new(),
            parent: Some(Box::new(parent)),
        }
    }

    /// Assigns `value` to `var` on this binding. A single leading `?` is stripped, so `?x`
    /// and `x` name the same variable. An existing local assignment is replaced.
    pub fn bind(&mut self, var: &str, value: BoundValue) {
        let var_name = var.strip_prefix('?').unwrap_or(var);
        self.values.insert(var_name.to_string(), value);
    }

    /// Looks up `var` (with or without a leading `?`), first locally, then in the parent chain.
    pub fn get(&self, var: &str) -> Option<&BoundValue> {
        let var_name = var.strip_prefix('?').unwrap_or(var);
        self.values
            .get(var_name)
            .or_else(|| self.parent.as_ref().and_then(|p| p.get(var_name)))
    }

    /// Returns `true` when `var` is bound here or in any ancestor.
    pub fn contains(&self, var: &str) -> bool {
        self.get(var).is_some()
    }

    /// Joins two solutions.
    ///
    /// Returns a copy of `self` extended with every variable visible in `other`, including
    /// variables `other` inherits from its parent chain. Returns `None` when a variable
    /// visible in both carries different values (the solutions are incompatible).
    pub fn merge(&self, other: &Binding) -> Option<Binding> {
        let mut merged = self.clone();
        // `vars()` + `get()` walk `other`'s whole parent chain. Iterating only
        // `other.values` dropped inherited assignments and missed conflicts with them.
        for var in other.vars() {
            if let Some(value) = other.get(&var) {
                match merged.get(&var) {
                    Some(existing) if existing != value => return None,
                    Some(_) => {}
                    None => merged.bind(&var, value.clone()),
                }
            }
        }
        Some(merged)
    }

    /// Names (without `?`) of all variables visible from this binding, deduplicated across
    /// the parent chain. The order is unspecified.
    pub fn vars(&self) -> Vec<String> {
        let mut vars: HashSet<_> = self.values.keys().cloned().collect();
        if let Some(ref parent) = self.parent {
            for v in parent.vars() {
                vars.insert(v);
            }
        }
        vars.into_iter().collect()
    }
}
138
139pub struct SparqlExecutor {
141 graph: Arc<GraphStore>,
142}
143
144impl SparqlExecutor {
145 pub fn new(graph: Arc<GraphStore>) -> Self {
147 Self { graph }
148 }
149
150 pub fn execute(&self, query: &str) -> Result<UnifiedResult, ExecutionError> {
152 let parsed = SparqlParser::parse(query).map_err(|e| ExecutionError::new(e.to_string()))?;
153 self.execute_query(&parsed)
154 }
155
156 pub fn execute_query(&self, query: &SparqlQuery) -> Result<UnifiedResult, ExecutionError> {
158 let mut stats = QueryStats::default();
159
160 let initial = vec![Binding::new()];
162
163 let mut bindings = self.execute_patterns(&query.where_patterns, initial, &mut stats)?;
165
166 for filter in &query.filters {
168 bindings = self.apply_filter(bindings, filter)?;
169 }
170
171 for optional in &query.optionals {
173 bindings = self.execute_optional(bindings, optional, &mut stats)?;
174 }
175
176 if let Some(limit) = query.limit {
178 bindings.truncate(limit as usize);
179 }
180
181 self.project_results(&query.select, bindings, stats)
183 }
184
185 fn execute_patterns(
187 &self,
188 patterns: &[TriplePattern],
189 bindings: Vec<Binding>,
190 stats: &mut QueryStats,
191 ) -> Result<Vec<Binding>, ExecutionError> {
192 let mut current = bindings;
193
194 for pattern in patterns {
195 current = self.match_pattern(pattern, current, stats)?;
196 if current.is_empty() {
197 break;
198 }
199 }
200
201 Ok(current)
202 }
203
204 fn match_pattern(
206 &self,
207 pattern: &TriplePattern,
208 bindings: Vec<Binding>,
209 stats: &mut QueryStats,
210 ) -> Result<Vec<Binding>, ExecutionError> {
211 let mut results = Vec::new();
212
213 for binding in bindings {
214 let subjects = self.resolve_term(&pattern.subject, &binding, stats);
216
217 for subject in subjects {
218 let subject_id = match &subject {
220 BoundValue::Node(id) => id.clone(),
221 BoundValue::Literal(s) => s.clone(),
222 _ => continue,
223 };
224
225 for (edge_type, target, _weight) in self.graph.outgoing_edges(&subject_id) {
227 stats.edges_scanned += 1;
228
229 if !self.predicate_matches(&pattern.predicate, edge_type.as_str(), &binding) {
231 continue;
232 }
233
234 let object_value = self.resolve_object(&pattern.object, &binding, &target);
236 if object_value.is_none() {
237 continue;
238 }
239
240 let mut new_binding = binding.clone();
242
243 if let SparqlTerm::Variable(var) = &pattern.subject {
245 new_binding.bind(var, subject.clone());
246 }
247
248 if let SparqlTerm::Variable(var) = &pattern.predicate {
250 new_binding.bind(var, BoundValue::Literal(format!("{:?}", edge_type)));
251 }
252
253 if let SparqlTerm::Variable(var) = &pattern.object {
255 if let Some(obj) = object_value {
256 new_binding.bind(var, obj);
257 }
258 }
259
260 results.push(new_binding);
261 }
262
263 if self.is_type_predicate(&pattern.predicate) {
265 if let Some(node) = self.graph.get_node(&subject_id) {
266 stats.nodes_scanned += 1;
267 let node_type_str = format!("{:?}", node.node_type);
268
269 if self.object_matches_type(&pattern.object, &node_type_str, &binding) {
270 let mut new_binding = binding.clone();
271
272 if let SparqlTerm::Variable(var) = &pattern.subject {
273 new_binding.bind(var, BoundValue::Node(subject_id.clone()));
274 }
275 if let SparqlTerm::Variable(var) = &pattern.object {
276 new_binding.bind(var, BoundValue::Literal(node_type_str));
277 }
278
279 results.push(new_binding);
280 }
281 }
282 }
283 }
284 }
285
286 Ok(results)
287 }
288
289 fn resolve_term(
291 &self,
292 term: &SparqlTerm,
293 binding: &Binding,
294 stats: &mut QueryStats,
295 ) -> Vec<BoundValue> {
296 match term {
297 SparqlTerm::Variable(var) => {
298 if let Some(value) = binding.get(var) {
300 return vec![value.clone()];
301 }
302 self.graph
304 .iter_nodes()
305 .map(|n| {
306 stats.nodes_scanned += 1;
307 BoundValue::Node(n.id.clone())
308 })
309 .collect()
310 }
311 SparqlTerm::PrefixedName(prefix, local) => {
312 let id = if prefix.is_empty() {
313 local.clone()
314 } else {
315 format!("{}:{}", prefix, local)
316 };
317 vec![BoundValue::Node(id)]
318 }
319 SparqlTerm::Iri(iri) => {
320 let id = iri
322 .rsplit('/')
323 .next()
324 .or_else(|| iri.rsplit('#').next())
325 .unwrap_or(iri);
326 vec![BoundValue::Node(id.to_string())]
327 }
328 SparqlTerm::Literal(lit) => {
329 vec![BoundValue::Literal(lit.clone())]
330 }
331 SparqlTerm::TypedLiteral(lit, _datatype) => {
332 vec![BoundValue::Literal(lit.clone())]
333 }
334 SparqlTerm::Number(n) => {
335 vec![BoundValue::Float(*n)]
336 }
337 SparqlTerm::Boolean(b) => {
338 vec![BoundValue::Boolean(*b)]
339 }
340 SparqlTerm::A => {
341 vec![BoundValue::Literal("rdf:type".to_string())]
342 }
343 }
344 }
345
346 fn predicate_matches(
348 &self,
349 predicate: &SparqlTerm,
350 edge_label: &str,
351 binding: &Binding,
352 ) -> bool {
353 match predicate {
354 SparqlTerm::Variable(var) => {
355 if let Some(bound) = binding.get(var) {
356 let bound_str = bound.to_string_value().to_lowercase();
357 let edge_str = edge_label.to_lowercase();
358 return bound_str == edge_str || edge_str.contains(&bound_str);
359 }
360 true }
362 SparqlTerm::PrefixedName(_, local) => {
363 let pred_clean = local.to_lowercase();
364 let edge_str = edge_label.to_lowercase();
365 edge_str == pred_clean
366 || edge_str.contains(&pred_clean)
367 || self.predicate_alias_matches(&pred_clean, edge_label)
368 }
369 SparqlTerm::Iri(iri) => {
370 let local = iri
371 .rsplit('/')
372 .next()
373 .or_else(|| iri.rsplit('#').next())
374 .unwrap_or(iri);
375 let pred_clean = local.to_lowercase();
376 let edge_str = edge_label.to_lowercase();
377 edge_str == pred_clean
378 || edge_str.contains(&pred_clean)
379 || self.predicate_alias_matches(&pred_clean, edge_label)
380 }
381 SparqlTerm::A => false, _ => false,
383 }
384 }
385
386 fn predicate_alias_matches(&self, predicate: &str, edge_label: &str) -> bool {
390 matches!(
391 (predicate, edge_label),
392 ("hasservice" | "has_service" | "service", "has_service")
393 | ("connectsto" | "connects_to" | "connects", "connects_to")
394 | ("hasuser" | "has_user", "has_user")
395 | ("usestech" | "uses_tech" | "uses", "uses_tech")
396 | ("authaccess" | "auth_access", "auth_access")
397 | ("hasendpoint" | "has_endpoint", "has_endpoint")
398 | (
399 "hascert" | "has_cert" | "hascertificate" | "has_certificate",
400 "has_cert"
401 )
402 | ("contains" | "has_subdomain" | "hassubdomain", "contains")
403 | (
404 "affectedby"
405 | "affected_by"
406 | "hasvulnerability"
407 | "has_vuln"
408 | "vulnerable_to",
409 "affected_by"
410 )
411 | (
412 "relatedto" | "related_to" | "memberof" | "member_of",
413 "related_to"
414 )
415 )
416 }
417
418 fn is_type_predicate(&self, predicate: &SparqlTerm) -> bool {
420 match predicate {
421 SparqlTerm::A => true,
422 SparqlTerm::PrefixedName(_prefix, local) => {
423 local == "type" }
425 SparqlTerm::Iri(iri) => iri.ends_with("type") || iri.ends_with("#type"),
426 _ => false,
427 }
428 }
429
430 fn object_matches_type(&self, object: &SparqlTerm, node_type: &str, binding: &Binding) -> bool {
432 match object {
433 SparqlTerm::Variable(var) => {
434 if let Some(bound) = binding.get(var) {
435 bound.to_string_value().to_lowercase() == node_type.to_lowercase()
436 } else {
437 true }
439 }
440 SparqlTerm::PrefixedName(_, local) => {
441 node_type.to_lowercase() == local.to_lowercase()
442 || node_type.to_lowercase().contains(&local.to_lowercase())
443 }
444 SparqlTerm::Iri(iri) => {
445 let local = iri
446 .rsplit('/')
447 .next()
448 .or_else(|| iri.rsplit('#').next())
449 .unwrap_or(iri);
450 node_type.to_lowercase() == local.to_lowercase()
451 || node_type.to_lowercase().contains(&local.to_lowercase())
452 }
453 SparqlTerm::Literal(lit) => {
454 node_type.to_lowercase() == lit.to_lowercase()
455 || node_type.to_lowercase().contains(&lit.to_lowercase())
456 }
457 _ => false,
458 }
459 }
460
461 fn resolve_object(
463 &self,
464 object: &SparqlTerm,
465 binding: &Binding,
466 target: &str,
467 ) -> Option<BoundValue> {
468 match object {
469 SparqlTerm::Variable(var) => {
470 if let Some(bound) = binding.get(var) {
471 if bound.as_node_id() == Some(target) {
473 return Some(bound.clone());
474 }
475 return None;
476 }
477 Some(BoundValue::Node(target.to_string()))
479 }
480 SparqlTerm::PrefixedName(_, local) => {
481 if target == local || target.ends_with(local) || target.contains(local) {
482 Some(BoundValue::Node(target.to_string()))
483 } else {
484 None
485 }
486 }
487 SparqlTerm::Iri(iri) => {
488 let id = iri
489 .rsplit('/')
490 .next()
491 .or_else(|| iri.rsplit('#').next())
492 .unwrap_or(iri);
493 if target == id || target.ends_with(id) || target.contains(id) {
494 Some(BoundValue::Node(target.to_string()))
495 } else {
496 None
497 }
498 }
499 SparqlTerm::Literal(_) | SparqlTerm::TypedLiteral(_, _) => {
500 None
502 }
503 _ => None,
504 }
505 }
506
507 fn apply_filter(
509 &self,
510 bindings: Vec<Binding>,
511 filter: &SparqlFilter,
512 ) -> Result<Vec<Binding>, ExecutionError> {
513 Ok(bindings
514 .into_iter()
515 .filter(|b| self.evaluate_filter(filter, b))
516 .collect())
517 }
518
519 fn evaluate_filter(&self, filter: &SparqlFilter, binding: &Binding) -> bool {
521 match filter {
522 SparqlFilter::Compare(var, op, term) => {
523 if let Some(bound) = binding.get(var) {
524 let bound_str = bound.to_string_value();
525 let term_str = self.term_to_string(term);
526
527 match op {
528 CompareOp::Eq => bound_str.to_lowercase() == term_str.to_lowercase(),
529 CompareOp::Ne => bound_str.to_lowercase() != term_str.to_lowercase(),
530 CompareOp::Lt => self.compare_numeric(&bound_str, &term_str, |a, b| a < b),
531 CompareOp::Le => self.compare_numeric(&bound_str, &term_str, |a, b| a <= b),
532 CompareOp::Gt => self.compare_numeric(&bound_str, &term_str, |a, b| a > b),
533 CompareOp::Ge => self.compare_numeric(&bound_str, &term_str, |a, b| a >= b),
534 }
535 } else {
536 false
537 }
538 }
539 SparqlFilter::Regex(var, pattern, _flags) => {
540 if let Some(value) = binding.get(var) {
541 let s = value.to_string_value();
542 s.contains(pattern) } else {
544 false
545 }
546 }
547 SparqlFilter::Bound(var) => binding.contains(var),
548 SparqlFilter::NotBound(var) => !binding.contains(var),
549 SparqlFilter::IsIri(var) => binding
550 .get(var)
551 .map(|v| matches!(v, BoundValue::Node(_)))
552 .unwrap_or(false),
553 SparqlFilter::IsLiteral(var) => binding
554 .get(var)
555 .map(|v| !matches!(v, BoundValue::Node(_)))
556 .unwrap_or(false),
557 SparqlFilter::Contains(var, substring) => {
558 if let Some(value) = binding.get(var) {
559 value.to_string_value().contains(substring)
560 } else {
561 false
562 }
563 }
564 SparqlFilter::StrStarts(var, prefix) => {
565 if let Some(value) = binding.get(var) {
566 value.to_string_value().starts_with(prefix)
567 } else {
568 false
569 }
570 }
571 SparqlFilter::StrEnds(var, suffix) => {
572 if let Some(value) = binding.get(var) {
573 value.to_string_value().ends_with(suffix)
574 } else {
575 false
576 }
577 }
578 SparqlFilter::And(left, right) => {
579 self.evaluate_filter(left, binding) && self.evaluate_filter(right, binding)
580 }
581 SparqlFilter::Or(left, right) => {
582 self.evaluate_filter(left, binding) || self.evaluate_filter(right, binding)
583 }
584 SparqlFilter::Not(inner) => !self.evaluate_filter(inner, binding),
585 }
586 }
587
588 fn term_to_string(&self, term: &SparqlTerm) -> String {
590 match term {
591 SparqlTerm::Variable(v) => format!("?{}", v),
592 SparqlTerm::PrefixedName(p, l) => {
593 if p.is_empty() {
594 l.clone()
595 } else {
596 format!("{}:{}", p, l)
597 }
598 }
599 SparqlTerm::Iri(iri) => iri.clone(),
600 SparqlTerm::Literal(lit) => lit.clone(),
601 SparqlTerm::TypedLiteral(lit, _) => lit.clone(),
602 SparqlTerm::Number(n) => n.to_string(),
603 SparqlTerm::Boolean(b) => b.to_string(),
604 SparqlTerm::A => "rdf:type".to_string(),
605 }
606 }
607
608 fn compare_numeric<F>(&self, a: &str, b: &str, f: F) -> bool
610 where
611 F: Fn(f64, f64) -> bool,
612 {
613 let a_num: f64 = a.parse().unwrap_or(0.0);
614 let b_num: f64 = b.parse().unwrap_or(0.0);
615 f(a_num, b_num)
616 }
617
618 fn execute_optional(
620 &self,
621 bindings: Vec<Binding>,
622 optional_patterns: &[TriplePattern],
623 stats: &mut QueryStats,
624 ) -> Result<Vec<Binding>, ExecutionError> {
625 let mut results = Vec::new();
626
627 for binding in bindings {
628 let optional_matches =
630 self.execute_patterns(optional_patterns, vec![binding.clone()], stats)?;
631
632 if optional_matches.is_empty() {
633 results.push(binding);
635 } else {
636 results.extend(optional_matches);
638 }
639 }
640
641 Ok(results)
642 }
643
644 fn project_results(
646 &self,
647 select: &[String],
648 bindings: Vec<Binding>,
649 stats: QueryStats,
650 ) -> Result<UnifiedResult, ExecutionError> {
651 let mut result = UnifiedResult::empty();
652 result.stats = stats;
653
654 let columns: Vec<String> = if select.is_empty() || select.iter().any(|s| s == "*") {
656 if let Some(first) = bindings.first() {
658 first.vars()
659 } else {
660 Vec::new()
661 }
662 } else {
663 select
664 .iter()
665 .map(|s| s.strip_prefix('?').unwrap_or(s).to_string())
666 .collect()
667 };
668 result.columns = columns.clone();
669
670 for binding in bindings {
672 let mut record = UnifiedRecord::new();
673
674 for col in &columns {
675 if let Some(value) = binding.get(col) {
676 match value {
677 BoundValue::Node(id) => {
678 if let Some(node) = self.graph.get_node(id) {
680 record.set_node(col, MatchedNode::from_stored(&node));
681 }
682 record.set(col, Value::text(id.clone()));
683 }
684 BoundValue::Edge(from, etype, to) => {
685 record.set_edge(
686 col,
687 MatchedEdge::from_tuple(from, etype.clone(), to, 1.0),
688 );
689 record.set(col, Value::text(format!("{}->{}({})", from, to, etype)));
690 }
691 BoundValue::Literal(s) => {
692 record.set(col, Value::text(s.clone()));
693 }
694 BoundValue::Integer(i) => {
695 record.set(col, Value::Integer(*i));
696 }
697 BoundValue::Float(f) => {
698 record.set(col, Value::Float(*f));
699 }
700 BoundValue::Boolean(b) => {
701 record.set(col, Value::Boolean(*b));
702 }
703 }
704 }
705 }
706
707 result.push(record);
708 }
709
710 Ok(result)
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::test_support::service_graph_with_user;

    /// Builds an executor over the shared service/user fixture graph.
    fn fixture_executor() -> SparqlExecutor {
        SparqlExecutor::new(service_graph_with_user())
    }

    /// Runs `query` against a fresh fixture executor; panics if parsing or execution fails.
    fn run(query: &str) -> UnifiedResult {
        fixture_executor().execute(query).unwrap()
    }

    #[test]
    fn test_simple_pattern() {
        let outcome = run("SELECT ?s WHERE { ?s :hasService ?o }");
        assert!(!outcome.is_empty());
    }

    #[test]
    fn test_type_pattern() {
        let hosts = run("SELECT ?h WHERE { ?h a :Host }");
        assert_eq!(hosts.records.len(), 2);
    }

    #[test]
    fn test_binding() {
        let mut scope = Binding::new();
        scope.bind("?x", BoundValue::Node("test".to_string()));

        // Lookups work with and without the leading `?`.
        assert!(scope.contains("?x"));
        assert!(scope.contains("x"));
        let bound = scope.get("x").unwrap();
        assert_eq!(bound.as_node_id(), Some("test"));
    }

    #[test]
    fn test_optional() {
        let rows = run("SELECT ?h ?u WHERE { ?h a :Host } OPTIONAL { ?h :hasUser ?u }");
        assert_eq!(rows.records.len(), 2);
    }
}