use crate::functions::Expression;
use crate::model::StarTerm;
use crate::query::{BasicGraphPattern, Binding, QueryExecutor};
use crate::StarResult;
use std::collections::{HashMap, HashSet};
use tracing::{debug, span, Level};
/// A SPARQL graph pattern tree evaluated by `EnhancedSparqlExecutor::execute_graph_pattern`.
#[derive(Debug, Clone)]
pub enum GraphPattern {
    /// Triple patterns matched by the base `QueryExecutor::execute_bgp`.
    BasicPattern(BasicGraphPattern),
    /// `OPTIONAL { ... }`: the inner pattern's solutions, or a single empty solution
    /// when the inner pattern has none.
    Optional(Box<GraphPattern>),
    /// `{ A } UNION { B }`: the solutions of the left side followed by those of the right.
    Union(Box<GraphPattern>, Box<GraphPattern>),
    /// `GRAPH g { ... }`.
    ///
    /// NOTE(review): the executor currently ignores `graph` and evaluates `pattern`
    /// against the default data.
    Graph {
        /// Graph name or variable the pattern is scoped to.
        graph: TermOrVariable,
        /// Pattern evaluated inside the graph.
        pattern: Box<GraphPattern>,
    },
    /// `{ A } MINUS { B }`: solutions of the first pattern minus those excluded by the second.
    Minus(Box<GraphPattern>, Box<GraphPattern>),
    /// A group `{ P1 P2 ... }`: members are evaluated in order and joined together.
    Group(Vec<GraphPattern>),
}
/// Either a concrete RDF-star term or a named variable.
#[derive(Debug, Clone)]
pub enum TermOrVariable {
    /// A fixed term.
    Term(StarTerm),
    /// A variable name (whether a leading `?` is included is not established here — TODO confirm).
    Variable(String),
}
/// Solution modifiers applied after pattern matching, grouping and aggregation.
///
/// The `Default` value applies no ordering, no de-duplication, no offset and no limit.
#[derive(Debug, Clone, Default)]
pub struct SolutionModifier {
    /// `ORDER BY` keys, compared in sequence until one differs.
    pub order_by: Vec<OrderCondition>,
    /// `DISTINCT`: drop repeated solutions, keeping the first occurrence.
    pub distinct: bool,
    /// `LIMIT`: maximum number of solutions returned after the offset is applied.
    pub limit: Option<usize>,
    /// `OFFSET`: number of leading solutions to skip.
    pub offset: Option<usize>,
}
/// One `ORDER BY` key.
#[derive(Debug, Clone)]
pub struct OrderCondition {
    /// Expression evaluated against each solution to obtain the sort key.
    pub expression: Expression,
    /// `true` for `ASC`, `false` for `DESC` (the comparison result is reversed).
    pub ascending: bool,
}
/// `BIND(expression AS ?variable)`.
#[derive(Debug, Clone)]
pub struct BindClause {
    /// Expression evaluated against each solution.
    pub expression: Expression,
    /// Variable that receives the evaluated term.
    pub variable: String,
}
/// Inline data (`VALUES`) joined with the pattern's solutions.
#[derive(Debug, Clone)]
pub struct ValuesClause {
    /// Variables, in the same order as the cells of each row in `data`.
    pub variables: Vec<String>,
    /// Data rows. A `None` cell (SPARQL `UNDEF`), or a missing trailing cell, leaves
    /// the corresponding variable unbound for that row.
    pub data: Vec<Vec<Option<StarTerm>>>,
}
/// SPARQL aggregate functions supported by `GROUP BY` evaluation.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (the only payload is a `String`),
/// so aggregate kinds can be used as keys in hash maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// `COUNT(expr)`.
    Count,
    /// `SUM(expr)`.
    Sum,
    /// `AVG(expr)`.
    Avg,
    /// `MIN(expr)`.
    Min,
    /// `MAX(expr)`.
    Max,
    /// `GROUP_CONCAT(expr; SEPARATOR = separator)`.
    GroupConcat {
        /// String placed between consecutive values.
        separator: String,
    },
    /// `SAMPLE(expr)`.
    Sample,
}
/// One aggregate in a `GROUP BY` clause, e.g. `(SUM(?x) AS ?total)`.
#[derive(Debug, Clone)]
pub struct Aggregation {
    /// Aggregate function to apply to the group's values.
    pub function: AggregateFunction,
    /// Expression evaluated against every solution of the group.
    pub expression: Expression,
    /// Variable bound to the aggregate result on the group's output row.
    pub as_variable: String,
    /// Requests `DISTINCT` aggregation; see `EnhancedSparqlExecutor::apply_aggregation`
    /// for which functions honor it.
    pub distinct: bool,
}
/// `GROUP BY` with its aggregates and `HAVING` filters.
#[derive(Debug, Clone)]
pub struct GroupByClause {
    /// Grouping expressions; solutions with equal evaluated values share a group.
    pub expressions: Vec<Expression>,
    /// Aggregates computed for every group and bound on the group's output row.
    pub aggregations: Vec<Aggregation>,
    /// Filters on the aggregated row; a group is kept only if every expression is truthy.
    pub having: Vec<Expression>,
}
/// A query evaluated by `EnhancedSparqlExecutor::execute`.
///
/// Stages run in this order: `pattern`, `values`, `bind_clauses`, `group_by`, `modifiers`.
#[derive(Debug, Clone)]
pub struct EnhancedQuery {
    /// Main graph pattern (the `WHERE` clause).
    pub pattern: GraphPattern,
    /// `BIND` clauses, applied in order.
    pub bind_clauses: Vec<BindClause>,
    /// Optional inline `VALUES` data joined with the pattern's solutions.
    pub values: Option<ValuesClause>,
    /// Optional grouping/aggregation stage.
    pub group_by: Option<GroupByClause>,
    /// `DISTINCT`, `ORDER BY`, `OFFSET` and `LIMIT`.
    pub modifiers: SolutionModifier,
}
/// Evaluates `EnhancedQuery` values (OPTIONAL, UNION, MINUS, BIND, VALUES, GROUP BY,
/// aggregates and solution modifiers) on top of a basic `QueryExecutor`.
pub struct EnhancedSparqlExecutor {
    /// Executor used to match basic graph patterns.
    base_executor: QueryExecutor,
}
impl EnhancedSparqlExecutor {
pub fn new(base_executor: QueryExecutor) -> Self {
Self { base_executor }
}
fn binding_to_map(binding: &Binding) -> HashMap<String, StarTerm> {
binding
.variables()
.into_iter()
.filter_map(|var| binding.get(var).map(|term| (var.clone(), term.clone())))
.collect()
}
pub fn execute(&mut self, query: &EnhancedQuery) -> StarResult<Vec<Binding>> {
let span = span!(Level::INFO, "execute_enhanced_query");
let _enter = span.enter();
let mut bindings = self.execute_graph_pattern(&query.pattern)?;
if let Some(ref values) = query.values {
bindings = self.apply_values(bindings, values)?;
}
for bind in &query.bind_clauses {
bindings = self.apply_bind(bindings, bind)?;
}
if let Some(ref group_by) = query.group_by {
bindings = self.apply_group_by(bindings, group_by)?;
}
bindings = self.apply_modifiers(bindings, &query.modifiers)?;
debug!("Enhanced query produced {} final bindings", bindings.len());
Ok(bindings)
}
pub fn execute_graph_pattern(&mut self, pattern: &GraphPattern) -> StarResult<Vec<Binding>> {
match pattern {
GraphPattern::BasicPattern(bgp) => self.base_executor.execute_bgp(bgp),
GraphPattern::Optional(inner) => {
let base_bindings = vec![Binding::new()];
let optional_bindings = self.execute_graph_pattern(inner)?;
if optional_bindings.is_empty() {
Ok(base_bindings)
} else {
Ok(optional_bindings)
}
}
GraphPattern::Union(left, right) => {
let mut left_bindings = self.execute_graph_pattern(left)?;
let right_bindings = self.execute_graph_pattern(right)?;
left_bindings.extend(right_bindings);
Ok(left_bindings)
}
GraphPattern::Graph { pattern, .. } => {
self.execute_graph_pattern(pattern)
}
GraphPattern::Minus(left, right) => {
let left_bindings = self.execute_graph_pattern(left)?;
let right_bindings = self.execute_graph_pattern(right)?;
let right_set: HashSet<_> =
right_bindings.iter().map(|b| format!("{:?}", b)).collect();
Ok(left_bindings
.into_iter()
.filter(|b| !right_set.contains(&format!("{:?}", b)))
.collect())
}
GraphPattern::Group(patterns) => {
let mut current_bindings = vec![Binding::new()];
for pattern in patterns {
let pattern_bindings = self.execute_graph_pattern(pattern)?;
current_bindings = self.join_bindings(current_bindings, pattern_bindings)?;
}
Ok(current_bindings)
}
}
}
fn join_bindings(&self, left: Vec<Binding>, right: Vec<Binding>) -> StarResult<Vec<Binding>> {
let mut result = Vec::new();
for left_binding in &left {
for right_binding in &right {
if let Some(merged) = left_binding.merge(right_binding) {
result.push(merged);
}
}
}
Ok(result)
}
fn apply_values(
&self,
bindings: Vec<Binding>,
values: &ValuesClause,
) -> StarResult<Vec<Binding>> {
let mut result = Vec::new();
for row in &values.data {
let mut value_binding = Binding::new();
for (i, var) in values.variables.iter().enumerate() {
if let Some(Some(term)) = row.get(i) {
value_binding.bind(var, term.clone());
}
}
for binding in &bindings {
if let Some(merged) = binding.merge(&value_binding) {
result.push(merged);
}
}
}
Ok(result)
}
fn apply_bind(&self, bindings: Vec<Binding>, bind: &BindClause) -> StarResult<Vec<Binding>> {
let mut result = Vec::new();
for binding in bindings {
let mut new_binding = binding.clone();
let binding_map = Self::binding_to_map(&binding);
if let Ok(term) =
crate::functions::ExpressionEvaluator::evaluate(&bind.expression, &binding_map)
{
new_binding.bind(&bind.variable, term);
result.push(new_binding);
} else {
continue;
}
}
Ok(result)
}
fn apply_group_by(
&self,
bindings: Vec<Binding>,
group_by: &GroupByClause,
) -> StarResult<Vec<Binding>> {
let mut groups: HashMap<String, Vec<Binding>> = HashMap::new();
for binding in bindings {
let group_key = self.compute_group_key(&binding, &group_by.expressions)?;
groups.entry(group_key).or_default().push(binding);
}
let mut result = Vec::new();
for (_, group_bindings) in groups {
if group_bindings.is_empty() {
continue;
}
let mut aggregated_binding = group_bindings[0].clone();
for agg in &group_by.aggregations {
let agg_result = self.apply_aggregation(&group_bindings, agg)?;
aggregated_binding.bind(&agg.as_variable, agg_result);
}
let passes_having = group_by.having.iter().all(|having_expr| {
let binding_map = Self::binding_to_map(&aggregated_binding);
if let Ok(result) =
crate::functions::ExpressionEvaluator::evaluate(having_expr, &binding_map)
{
self.is_truthy(&result)
} else {
false
}
});
if passes_having {
result.push(aggregated_binding);
}
}
Ok(result)
}
fn compute_group_key(
&self,
binding: &Binding,
expressions: &[Expression],
) -> StarResult<String> {
let mut key = String::new();
for expr in expressions {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
key.push_str(&format!("{:?}", term));
key.push('|');
}
}
Ok(key)
}
fn apply_aggregation(&self, bindings: &[Binding], agg: &Aggregation) -> StarResult<StarTerm> {
match agg.function {
AggregateFunction::Count => {
let count = if agg.distinct {
self.count_distinct(bindings, &agg.expression)?
} else {
bindings.len()
};
StarTerm::literal(&count.to_string())
}
AggregateFunction::Sum => self.sum_aggregation(bindings, &agg.expression),
AggregateFunction::Avg => {
let sum = self.sum_numeric(bindings, &agg.expression)?;
let count = bindings.len() as f64;
let avg = sum / count;
StarTerm::literal(&avg.to_string())
}
AggregateFunction::Min => self.min_aggregation(bindings, &agg.expression),
AggregateFunction::Max => self.max_aggregation(bindings, &agg.expression),
AggregateFunction::GroupConcat { ref separator } => {
self.group_concat(bindings, &agg.expression, separator)
}
AggregateFunction::Sample => {
if let Some(binding) = bindings.first() {
let binding_map = Self::binding_to_map(binding);
crate::functions::ExpressionEvaluator::evaluate(&agg.expression, &binding_map)
} else {
StarTerm::literal("")
}
}
}
}
fn count_distinct(&self, bindings: &[Binding], expr: &Expression) -> StarResult<usize> {
let mut seen = HashSet::new();
for binding in bindings {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
seen.insert(format!("{:?}", term));
}
}
Ok(seen.len())
}
fn sum_numeric(&self, bindings: &[Binding], expr: &Expression) -> StarResult<f64> {
let mut sum = 0.0;
for binding in bindings {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
if let Some(literal) = term.as_literal() {
if let Ok(num) = literal.value.parse::<f64>() {
sum += num;
}
}
}
}
Ok(sum)
}
fn sum_aggregation(&self, bindings: &[Binding], expr: &Expression) -> StarResult<StarTerm> {
let sum = self.sum_numeric(bindings, expr)?;
StarTerm::literal(&sum.to_string())
}
fn min_aggregation(&self, bindings: &[Binding], expr: &Expression) -> StarResult<StarTerm> {
let mut min_val: Option<f64> = None;
for binding in bindings {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
if let Some(literal) = term.as_literal() {
if let Ok(num) = literal.value.parse::<f64>() {
min_val = Some(min_val.map_or(num, |m| m.min(num)));
}
}
}
}
if let Some(min) = min_val {
StarTerm::literal(&min.to_string())
} else {
StarTerm::literal("")
}
}
fn max_aggregation(&self, bindings: &[Binding], expr: &Expression) -> StarResult<StarTerm> {
let mut max_val: Option<f64> = None;
for binding in bindings {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
if let Some(literal) = term.as_literal() {
if let Ok(num) = literal.value.parse::<f64>() {
max_val = Some(max_val.map_or(num, |m| m.max(num)));
}
}
}
}
if let Some(max) = max_val {
StarTerm::literal(&max.to_string())
} else {
StarTerm::literal("")
}
}
fn group_concat(
&self,
bindings: &[Binding],
expr: &Expression,
separator: &str,
) -> StarResult<StarTerm> {
let mut values = Vec::new();
for binding in bindings {
let binding_map = Self::binding_to_map(binding);
if let Ok(term) = crate::functions::ExpressionEvaluator::evaluate(expr, &binding_map) {
values.push(format!("{}", term));
}
}
StarTerm::literal(&values.join(separator))
}
fn apply_modifiers(
&self,
mut bindings: Vec<Binding>,
modifiers: &SolutionModifier,
) -> StarResult<Vec<Binding>> {
if modifiers.distinct {
let mut seen = HashSet::new();
bindings.retain(|binding| {
let key = format!("{:?}", binding);
seen.insert(key.clone())
});
}
if !modifiers.order_by.is_empty() {
bindings.sort_by(|a, b| {
for order_cond in &modifiers.order_by {
let a_map = Self::binding_to_map(a);
let b_map = Self::binding_to_map(b);
let a_val = crate::functions::ExpressionEvaluator::evaluate(
&order_cond.expression,
&a_map,
)
.ok();
let b_val = crate::functions::ExpressionEvaluator::evaluate(
&order_cond.expression,
&b_map,
)
.ok();
let cmp = match (a_val, b_val) {
(Some(a_term), Some(b_term)) => self.compare_terms(&a_term, &b_term),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => std::cmp::Ordering::Equal,
};
if cmp != std::cmp::Ordering::Equal {
return if order_cond.ascending {
cmp
} else {
cmp.reverse()
};
}
}
std::cmp::Ordering::Equal
});
}
let start = modifiers.offset.unwrap_or(0);
let end = modifiers.limit.map(|l| start + l).unwrap_or(bindings.len());
Ok(bindings.into_iter().skip(start).take(end - start).collect())
}
fn compare_terms(&self, a: &StarTerm, b: &StarTerm) -> std::cmp::Ordering {
format!("{:?}", a).cmp(&format!("{:?}", b))
}
fn is_truthy(&self, term: &StarTerm) -> bool {
if let Some(literal) = term.as_literal() {
if let Some(datatype) = &literal.datatype {
if datatype.iri == "http://www.w3.org/2001/XMLSchema#boolean" {
return literal.value == "true";
}
}
!literal.value.is_empty() && literal.value != "false" && literal.value != "0"
} else {
false
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::StarTriple;
    use crate::query::{QueryExecutor, TermPattern, TriplePattern};
    use crate::StarStore;

    /// Builds the one-triple BGP `?subject <predicate> ?object`.
    fn single_pattern(subject: &str, predicate: &str, object: &str) -> BasicGraphPattern {
        let mut bgp = BasicGraphPattern::new();
        bgp.add_pattern(TriplePattern::new(
            TermPattern::Variable(subject.to_string()),
            TermPattern::Term(StarTerm::iri(predicate).unwrap()),
            TermPattern::Variable(object.to_string()),
        ));
        bgp
    }

    /// Executor over a store holding only `<alice> <name> "Alice"`.
    fn alice_executor() -> EnhancedSparqlExecutor {
        let store = StarStore::new();
        let triple = StarTriple::new(
            StarTerm::iri("http://example.org/alice").unwrap(),
            StarTerm::iri("http://example.org/name").unwrap(),
            StarTerm::literal("Alice").unwrap(),
        );
        store.insert(&triple).unwrap();
        EnhancedSparqlExecutor::new(QueryExecutor::new(store))
    }

    /// Executor over an empty store, for tests that only exercise post-processing stages.
    fn empty_executor() -> EnhancedSparqlExecutor {
        EnhancedSparqlExecutor::new(QueryExecutor::new(StarStore::new()))
    }

    /// Ten solutions binding `age` to "20" through "29", in ascending order.
    fn age_bindings() -> Vec<Binding> {
        (0..10)
            .map(|i| {
                let mut binding = Binding::new();
                binding.bind("age", StarTerm::literal(&format!("{}", 20 + i)).unwrap());
                binding
            })
            .collect()
    }

    /// The term bound to `age` in `binding`, if any.
    fn age_of(binding: &Binding) -> Option<StarTerm> {
        EnhancedSparqlExecutor::binding_to_map(binding).get("age").cloned()
    }

    #[test]
    fn test_optional_pattern() {
        let mut executor = alice_executor();

        // Matching inner pattern: OPTIONAL yields the inner solutions.
        let matching = GraphPattern::Optional(Box::new(GraphPattern::BasicPattern(
            single_pattern("x", "http://example.org/name", "name"),
        )));
        assert_eq!(executor.execute_graph_pattern(&matching).unwrap().len(), 1);

        // Non-matching inner pattern: OPTIONAL still yields exactly one empty solution.
        let missing = GraphPattern::Optional(Box::new(GraphPattern::BasicPattern(
            single_pattern("x", "http://example.org/email", "email"),
        )));
        let bindings = executor.execute_graph_pattern(&missing).unwrap();
        assert_eq!(bindings.len(), 1);
        assert!(EnhancedSparqlExecutor::binding_to_map(&bindings[0]).is_empty());

        // Inside a group, a non-matching OPTIONAL must not drop the required solutions.
        let group = GraphPattern::Group(vec![
            GraphPattern::BasicPattern(single_pattern("x", "http://example.org/name", "name")),
            missing,
        ]);
        assert_eq!(executor.execute_graph_pattern(&group).unwrap().len(), 1);
    }

    #[test]
    fn test_solution_modifiers() {
        let executor = empty_executor();
        let modifiers = SolutionModifier {
            limit: Some(5),
            offset: Some(2),
            distinct: false,
            order_by: vec![],
        };
        let result = executor.apply_modifiers(age_bindings(), &modifiers).unwrap();
        assert_eq!(result.len(), 5);
        // Without ORDER BY the input order is preserved, so the window is "22" to "26".
        assert_eq!(age_of(&result[0]), Some(StarTerm::literal("22").unwrap()));
        assert_eq!(age_of(&result[4]), Some(StarTerm::literal("26").unwrap()));
    }

    #[test]
    fn test_order_by_descending() {
        let executor = empty_executor();
        let modifiers = SolutionModifier {
            order_by: vec![OrderCondition {
                expression: Expression::var("age"),
                ascending: false,
            }],
            limit: Some(1),
            ..Default::default()
        };
        let result = executor.apply_modifiers(age_bindings(), &modifiers).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(age_of(&result[0]), Some(StarTerm::literal("29").unwrap()));
    }

    #[test]
    fn test_aggregations() {
        let executor = empty_executor();
        let bindings: Vec<Binding> = (1..=5)
            .map(|i| {
                let mut binding = Binding::new();
                binding.bind("value", StarTerm::literal(&i.to_string()).unwrap());
                binding
            })
            .collect();
        let run = |function: AggregateFunction| {
            let agg = Aggregation {
                function,
                expression: Expression::var("value"),
                as_variable: "result".to_string(),
                distinct: false,
            };
            executor.apply_aggregation(&bindings, &agg).unwrap()
        };
        assert_eq!(run(AggregateFunction::Count), StarTerm::literal("5").unwrap());
        assert_eq!(run(AggregateFunction::Sum), StarTerm::literal("15").unwrap());
        assert_eq!(run(AggregateFunction::Avg), StarTerm::literal("3").unwrap());
        assert_eq!(run(AggregateFunction::Min), StarTerm::literal("1").unwrap());
        assert_eq!(run(AggregateFunction::Max), StarTerm::literal("5").unwrap());
    }
}