use super::{ParsedQuery, QueryType, SparqlResult};
use crate::error::{Error, Result};
use aingle_graph::{GraphDB, TriplePattern as GraphTriplePattern};
use spargebra::{
algebra::{Expression, GraphPattern},
term::{NamedNodePattern, TermPattern},
Query,
};
use std::collections::HashMap;
pub fn execute_query(graph: &GraphDB, query: &ParsedQuery) -> Result<SparqlResult> {
match query.query_type {
QueryType::Select => execute_select(graph, query),
QueryType::Ask => execute_ask(graph, query),
QueryType::Construct => execute_construct(graph, query),
QueryType::Describe => execute_describe(graph, query),
}
}
fn execute_select(graph: &GraphDB, query: &ParsedQuery) -> Result<SparqlResult> {
match &query.query {
Query::Select { pattern, .. } => {
let mut bindings = Vec::new();
let mut variables = Vec::new();
let results = execute_pattern(graph, pattern)?;
if let Some(first) = results.first() {
variables = first.keys().cloned().collect();
}
for result in results {
let binding: serde_json::Value =
serde_json::to_value(&result).map_err(|e| Error::Internal(e.to_string()))?;
bindings.push(binding);
}
Ok(SparqlResult {
result_type: "bindings".to_string(),
variables: Some(variables),
bindings: Some(bindings),
boolean: None,
triple_count: None,
})
}
_ => Err(Error::Internal("Expected SELECT query".to_string())),
}
}
fn execute_ask(graph: &GraphDB, query: &ParsedQuery) -> Result<SparqlResult> {
match &query.query {
Query::Ask { pattern, .. } => {
let results = execute_pattern(graph, pattern)?;
let exists = !results.is_empty();
Ok(SparqlResult {
result_type: "boolean".to_string(),
variables: None,
bindings: None,
boolean: Some(exists),
triple_count: None,
})
}
_ => Err(Error::Internal("Expected ASK query".to_string())),
}
}
fn execute_construct(graph: &GraphDB, query: &ParsedQuery) -> Result<SparqlResult> {
match &query.query {
Query::Construct { pattern, .. } => {
let results = execute_pattern(graph, pattern)?;
let count = results.len();
Ok(SparqlResult {
result_type: "graph".to_string(),
variables: None,
bindings: Some(
results
.into_iter()
.map(|r| serde_json::to_value(&r).unwrap_or_default())
.collect(),
),
boolean: None,
triple_count: Some(count),
})
}
_ => Err(Error::Internal("Expected CONSTRUCT query".to_string())),
}
}
/// Executes a DESCRIBE query.
///
/// NOTE(review): the query itself is ignored. Every triple in the graph is returned as a
/// `{ "subject", "predicate", "object" }` JSON object, rather than only the triples
/// about the described resources.
///
/// `triple_count` reports the number of returned triples, consistent with the other
/// "graph" result (`execute_construct`).
///
/// # Errors
/// Propagates errors from `graph.find`.
fn execute_describe(graph: &GraphDB, _query: &ParsedQuery) -> Result<SparqlResult> {
    let triples: Vec<serde_json::Value> = graph
        .find(GraphTriplePattern::any())?
        .into_iter()
        .map(|t| {
            serde_json::json!({
                "subject": t.subject.to_string(),
                "predicate": t.predicate.to_string(),
                "object": t.object.to_string(),
            })
        })
        .collect();
    let count = triples.len();
    Ok(SparqlResult {
        result_type: "graph".to_string(),
        variables: None,
        bindings: Some(triples),
        boolean: None,
        triple_count: Some(count),
    })
}
fn execute_pattern(
graph: &GraphDB,
pattern: &GraphPattern,
) -> Result<Vec<HashMap<String, String>>> {
let mut results = Vec::new();
match pattern {
GraphPattern::Bgp { patterns } => {
if patterns.is_empty() {
let all_triples = graph.find(GraphTriplePattern::any())?;
for triple in all_triples {
let mut binding = HashMap::new();
binding.insert("s".to_string(), triple.subject.to_string());
binding.insert("p".to_string(), triple.predicate.to_string());
binding.insert("o".to_string(), format!("{}", triple.object));
results.push(binding);
}
} else {
for pattern in patterns {
let all_triples = graph.find(GraphTriplePattern::any())?;
for triple in all_triples {
let mut binding = HashMap::new();
let mut matched = true;
match &pattern.subject {
TermPattern::Variable(v) => {
binding.insert(v.as_str().to_string(), triple.subject.to_string());
}
TermPattern::NamedNode(n) => {
if triple.subject.to_string() != format!("<{}>", n.as_str()) {
matched = false;
}
}
_ => {
matched = false;
}
}
if !matched {
continue;
}
match &pattern.predicate {
NamedNodePattern::Variable(v) => {
binding
.insert(v.as_str().to_string(), triple.predicate.to_string());
}
NamedNodePattern::NamedNode(n) => {
let pred_str = triple.predicate.to_string();
let expected = format!("<{}>", n.as_str());
if pred_str != expected {
let local_name = n.as_str().rsplit('/').next().unwrap_or("");
if pred_str != format!("<{}>", local_name) {
matched = false;
}
}
}
}
if !matched {
continue;
}
match &pattern.object {
TermPattern::Variable(v) => {
binding
.insert(v.as_str().to_string(), format!("{}", triple.object));
}
TermPattern::NamedNode(n) => {
if triple.object.to_string() != format!("<{}>", n.as_str()) {
matched = false;
}
}
TermPattern::Literal(lit) => {
if triple.object.to_string() != format!("\"{}\"", lit.value()) {
matched = false;
}
}
_ => {
matched = false;
}
}
if matched {
results.push(binding);
}
}
}
}
}
GraphPattern::Filter { inner, expr } => {
results = execute_pattern(graph, inner)?;
results.retain(|binding| evaluate_filter_expression(expr, binding).unwrap_or(false));
}
GraphPattern::Project { inner, variables } => {
results = execute_pattern(graph, inner)?;
results = results
.into_iter()
.map(|mut binding| {
let projected: HashMap<String, String> = variables
.iter()
.filter_map(|v| {
let var_name = v.as_str().to_string();
binding.remove(&var_name).map(|val| (var_name, val))
})
.collect();
projected
})
.collect();
}
GraphPattern::Join { left, right } => {
let left_results = execute_pattern(graph, left)?;
let right_results = execute_pattern(graph, right)?;
for l in &left_results {
for r in &right_results {
let mut combined = l.clone();
combined.extend(r.clone());
results.push(combined);
}
}
}
GraphPattern::Union { left, right } => {
results.extend(execute_pattern(graph, left)?);
results.extend(execute_pattern(graph, right)?);
}
GraphPattern::LeftJoin { left, right, .. } => {
results = execute_pattern(graph, left)?;
if let Ok(right_results) = execute_pattern(graph, right) {
for r in right_results {
if !results.iter().any(|l| l == &r) {
results.push(r);
}
}
}
}
_ => {
let all_triples = graph.find(GraphTriplePattern::any())?;
for triple in all_triples {
let mut binding = HashMap::new();
binding.insert("s".to_string(), triple.subject.to_string());
binding.insert("p".to_string(), triple.predicate.to_string());
binding.insert("o".to_string(), format!("{}", triple.object));
results.push(binding);
}
}
}
Ok(results)
}
/// Evaluates a FILTER expression against one solution binding.
///
/// Comparisons work on the string values produced by `evaluate_term`. `=` is string
/// equality; `<`, `>`, `<=` and `>=` compare numerically via `compare_values`.
///
/// `&&` and `||` follow SPARQL's three-valued logic: an error on one side is absorbed
/// when the other side alone decides the result (`false && error` is `false`,
/// `true || error` is `true`); otherwise the error is propagated.
///
/// NOTE(review): `EXISTS` and every expression form not matched below evaluate to
/// `true`, i.e. they never filter a solution out.
///
/// # Errors
/// Propagates errors from `evaluate_term` (e.g. an unbound variable) and from function
/// calls. The FILTER operator treats an error as `false`.
fn evaluate_filter_expression(
    expr: &Expression,
    binding: &HashMap<String, String>,
) -> Result<bool> {
    match expr {
        Expression::Equal(left, right) => {
            let l = evaluate_term(left, binding)?;
            let r = evaluate_term(right, binding)?;
            Ok(l == r)
        }
        Expression::Less(left, right) => {
            let l = evaluate_term(left, binding)?;
            let r = evaluate_term(right, binding)?;
            compare_values(&l, &r, |a, b| a < b)
        }
        Expression::Greater(left, right) => {
            let l = evaluate_term(left, binding)?;
            let r = evaluate_term(right, binding)?;
            compare_values(&l, &r, |a, b| a > b)
        }
        Expression::LessOrEqual(left, right) => {
            let l = evaluate_term(left, binding)?;
            let r = evaluate_term(right, binding)?;
            compare_values(&l, &r, |a, b| a <= b)
        }
        Expression::GreaterOrEqual(left, right) => {
            let l = evaluate_term(left, binding)?;
            let r = evaluate_term(right, binding)?;
            compare_values(&l, &r, |a, b| a >= b)
        }
        Expression::And(left, right) => match evaluate_filter_expression(left, binding) {
            Ok(false) => Ok(false),
            Ok(true) => evaluate_filter_expression(right, binding),
            // `error && false` is `false`; `error && true` and `error && error` stay errors.
            Err(err) => match evaluate_filter_expression(right, binding) {
                Ok(false) => Ok(false),
                _ => Err(err),
            },
        },
        Expression::Or(left, right) => match evaluate_filter_expression(left, binding) {
            Ok(true) => Ok(true),
            Ok(false) => evaluate_filter_expression(right, binding),
            // `error || true` is `true`; `error || false` and `error || error` stay errors.
            Err(err) => match evaluate_filter_expression(right, binding) {
                Ok(true) => Ok(true),
                _ => Err(err),
            },
        },
        Expression::Not(inner) => Ok(!evaluate_filter_expression(inner, binding)?),
        Expression::Bound(var) => Ok(binding.contains_key(var.as_str())),
        Expression::FunctionCall(func, args) => evaluate_function_call(func, args, binding),
        // EXISTS would need graph access, which this evaluator does not have.
        Expression::Exists(_) => Ok(true),
        _ => Ok(true),
    }
}
fn evaluate_term(expr: &Expression, binding: &HashMap<String, String>) -> Result<String> {
match expr {
Expression::Variable(var) => binding
.get(var.as_str())
.cloned()
.ok_or_else(|| Error::UnboundVariable(var.as_str().to_string())),
Expression::Literal(lit) => Ok(lit.value().to_string()),
Expression::NamedNode(node) => Ok(node.as_str().to_string()),
_ => Err(Error::UnsupportedExpression),
}
}
/// Compares two string values numerically using `cmp`.
///
/// Both strings must parse as `f64`. If either does not, the comparison is simply
/// `false`; this function never returns an error.
fn compare_values<F>(left: &str, right: &str, cmp: F) -> Result<bool>
where
    F: Fn(f64, f64) -> bool,
{
    let outcome = match (left.parse::<f64>(), right.parse::<f64>()) {
        (Ok(a), Ok(b)) => cmp(a, b),
        _ => false,
    };
    Ok(outcome)
}
/// Evaluates a built-in function call used as a FILTER condition.
///
/// * `REGEX(text, pattern [, flags])`: delegates to `evaluate_regex`; `false` with
///   fewer than two arguments.
/// * `STR(x)`: `true` whenever its argument can be evaluated.
/// * `isIRI(x)` / `isLiteral(x)`: a constant argument is classified by its expression
///   kind. A variable argument is classified by the rendering of its bound value
///   (previously a variable never counted as an IRI or a literal).
/// * `isBlank(x)`: `true` when the evaluated argument starts with `_:`. Evaluation
///   failures yield `false`.
/// * `LANGMATCHES` and all other functions: always `true` (not implemented).
///
/// # Errors
/// Propagates errors from `evaluate_term` (for REGEX, STR and variable arguments of
/// isIRI/isLiteral, e.g. an unbound variable) and from `evaluate_regex`.
fn evaluate_function_call(
    func: &spargebra::algebra::Function,
    args: &[Expression],
    binding: &HashMap<String, String>,
) -> Result<bool> {
    use spargebra::algebra::Function;

    // NOTE(review): assumes graph nodes render as `<iri>`, the form the BGP matcher
    // compares named nodes against. Confirm against `aingle_graph`'s `Display` impls.
    fn is_iri_rendering(value: &str) -> bool {
        value.starts_with('<') && value.ends_with('>')
    }

    match func {
        Function::Regex => {
            if args.len() < 2 {
                return Ok(false);
            }
            let text_val = evaluate_term(&args[0], binding)?;
            let pattern_val = evaluate_term(&args[1], binding)?;
            let flags_val = if args.len() >= 3 {
                Some(evaluate_term(&args[2], binding)?)
            } else {
                None
            };
            evaluate_regex(&text_val, &pattern_val, flags_val.as_deref())
        }
        Function::Str => {
            if args.is_empty() {
                return Ok(false);
            }
            evaluate_term(&args[0], binding)?;
            Ok(true)
        }
        // Language-tag matching is not implemented; always passes.
        Function::LangMatches => Ok(true),
        Function::IsIri => match args.first() {
            Some(Expression::NamedNode(_)) => Ok(true),
            Some(Expression::Variable(_)) => {
                Ok(is_iri_rendering(&evaluate_term(&args[0], binding)?))
            }
            _ => Ok(false),
        },
        Function::IsBlank => match args.first() {
            Some(arg) => Ok(evaluate_term(arg, binding)
                .map(|val| val.starts_with("_:"))
                .unwrap_or(false)),
            None => Ok(false),
        },
        Function::IsLiteral => match args.first() {
            Some(Expression::Literal(_)) => Ok(true),
            Some(Expression::Variable(_)) => {
                // Anything bound in the graph that is neither a node nor a blank node.
                let value = evaluate_term(&args[0], binding)?;
                Ok(!is_iri_rendering(&value) && !value.starts_with("_:"))
            }
            _ => Ok(false),
        },
        _ => Ok(true),
    }
}
/// Maximum REGEX pattern length, in characters, accepted before compilation is attempted.
const MAX_REGEX_PATTERN_LEN: usize = 256;
/// Compiled-program size limit (bytes) passed to `RegexBuilder::size_limit`.
const MAX_REGEX_SIZE: usize = 10 * 1024;

/// Evaluates SPARQL `REGEX(text, pattern, flags)`.
///
/// Supported flags, which map directly onto `regex::RegexBuilder`:
/// * `i`: case-insensitive.
/// * `s`: `.` also matches newlines.
/// * `m`: `^` and `$` match at line boundaries.
///
/// Other flag characters (including `x`) are ignored.
///
/// The pattern comes from the query, so its length is capped at `MAX_REGEX_PATTERN_LEN`
/// characters and the compiled program at `MAX_REGEX_SIZE` bytes.
///
/// # Errors
/// Returns `Error::InvalidRegex` if the pattern is too long or fails to compile.
fn evaluate_regex(text: &str, pattern: &str, flags: Option<&str>) -> Result<bool> {
    // Count characters, not bytes, so the limit matches what the error message states.
    if pattern.chars().count() > MAX_REGEX_PATTERN_LEN {
        return Err(Error::InvalidRegex(format!(
            "Regex pattern exceeds maximum length of {} characters",
            MAX_REGEX_PATTERN_LEN
        )));
    }
    let flags = flags.unwrap_or("");
    let regex = regex::RegexBuilder::new(pattern)
        .case_insensitive(flags.contains('i'))
        .dot_matches_new_line(flags.contains('s'))
        .multi_line(flags.contains('m'))
        .size_limit(MAX_REGEX_SIZE)
        .build()
        .map_err(|e| Error::InvalidRegex(e.to_string()))?;
    Ok(regex.is_match(text))
}
#[cfg(test)]
mod tests {
    use super::*;
    use aingle_graph::{NodeId, Predicate, Triple, Value};
    use spargebra::algebra::Function;
    use spargebra::term::{Literal, Variable};

    /// Builds a binding map holding a single `name -> value` entry.
    fn binding_of(name: &str, value: &str) -> HashMap<String, String> {
        let mut binding = HashMap::new();
        binding.insert(name.to_string(), value.to_string());
        binding
    }

    /// The expression `?name`.
    fn var(name: &str) -> Expression {
        Expression::Variable(Variable::new(name).unwrap())
    }

    /// A simple (untyped, untagged) string literal expression.
    fn lit(value: &str) -> Expression {
        Expression::Literal(Literal::new_simple_literal(value))
    }

    /// Inserts the triple `subject predicate object` into `graph`, panicking on failure.
    fn add_triple(graph: &GraphDB, subject: &'static str, predicate: &'static str, object: Value) {
        graph
            .insert(Triple::new(
                NodeId::named(subject),
                Predicate::named(predicate),
                object,
            ))
            .unwrap();
    }

    #[test]
    fn test_execute_basic_select() {
        let graph = GraphDB::memory().unwrap();
        add_triple(&graph, "alice", "knows", Value::Node(NodeId::named("bob")));
        let query = super::super::parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }").unwrap();
        let outcome = execute_query(&graph, &query).unwrap();
        assert_eq!(outcome.result_type, "bindings");
        assert!(outcome.bindings.is_some());
    }

    #[test]
    fn test_filter_equality() {
        let graph = GraphDB::memory().unwrap();
        add_triple(&graph, "alice", "name", Value::literal("Alice"));
        add_triple(&graph, "bob", "name", Value::literal("Bob"));
        let query = super::super::parse_sparql(
            r#"SELECT ?s WHERE { ?s <http://example.org/name> ?o . FILTER(?o = "Alice") }"#,
        )
        .unwrap();
        let outcome = execute_query(&graph, &query).unwrap();
        assert_eq!(outcome.result_type, "bindings");
        assert!(outcome.bindings.is_some());
    }

    #[test]
    fn test_filter_comparison_numeric() {
        let expr = Expression::Greater(Box::new(var("age")), Box::new(lit("18")));
        assert!(evaluate_filter_expression(&expr, &binding_of("age", "25")).unwrap());
    }

    #[test]
    fn test_filter_regex() {
        let expr = Expression::FunctionCall(Function::Regex, vec![var("name"), lit("^John")]);
        assert!(evaluate_filter_expression(&expr, &binding_of("name", "John Smith")).unwrap());
    }

    #[test]
    fn test_filter_logical_and() {
        // 18 <= ?age && ?age <= 30
        let at_least_18 = Expression::GreaterOrEqual(Box::new(var("age")), Box::new(lit("18")));
        let at_most_30 = Expression::LessOrEqual(Box::new(var("age")), Box::new(lit("30")));
        let expr = Expression::And(Box::new(at_least_18), Box::new(at_most_30));
        assert!(evaluate_filter_expression(&expr, &binding_of("age", "25")).unwrap());
    }

    #[test]
    fn test_filter_not_equal() {
        let is_nyc = Expression::Equal(Box::new(var("city")), Box::new(lit("NYC")));
        let expr = Expression::Not(Box::new(is_nyc));
        assert!(evaluate_filter_expression(&expr, &binding_of("city", "LA")).unwrap());
    }

    #[test]
    fn test_evaluate_filter_bound() {
        let binding = binding_of("x", "value");
        let bound_x = Expression::Bound(Variable::new("x").unwrap());
        assert!(evaluate_filter_expression(&bound_x, &binding).unwrap());
        let bound_y = Expression::Bound(Variable::new("y").unwrap());
        assert!(!evaluate_filter_expression(&bound_y, &binding).unwrap());
    }
}