use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use crate::raft::OxirsNodeId;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedQueryPlan {
pub query_id: String,
pub original_sparql: String,
pub subqueries: Vec<SubqueryPlan>,
pub join_operations: Vec<JoinOperation>,
pub aggregation_plan: Option<AggregationPlan>,
pub estimated_cost: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubqueryPlan {
pub subquery_id: String,
pub target_node: OxirsNodeId,
pub sparql_fragment: String,
pub variables: Vec<String>,
pub estimated_rows: u64,
pub estimated_latency_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JoinOperation {
pub left_subquery: String,
pub right_subquery: String,
pub join_variables: Vec<String>,
pub join_type: JoinType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum JoinType {
Inner,
Left,
Optional,
Union,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationPlan {
pub group_by: Vec<String>,
pub aggregates: Vec<AggregateFunction>,
pub having_conditions: Vec<String>,
pub order_by: Vec<OrderByClause>,
pub limit: Option<u64>,
pub offset: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregateFunction {
pub function: String,
pub variable: String,
pub alias: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderByClause {
pub variable: String,
pub ascending: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStats {
pub query_id: String,
pub execution_time_ms: u64,
pub nodes_involved: u32,
pub total_intermediate_results: u64,
pub final_result_count: u64,
pub network_transfer_bytes: u64,
pub cache_hits: u32,
pub cache_misses: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ResultBinding {
pub variables: BTreeMap<String, String>,
}
impl Default for ResultBinding {
fn default() -> Self {
Self::new()
}
}
impl ResultBinding {
pub fn new() -> Self {
Self {
variables: BTreeMap::new(),
}
}
pub fn add_binding(&mut self, variable: String, value: String) {
self.variables.insert(variable, value);
}
pub fn get(&self, variable: &str) -> Option<&String> {
self.variables.get(variable)
}
pub fn merge(&self, other: &ResultBinding) -> Option<ResultBinding> {
let mut merged = self.clone();
for (var, val) in &other.variables {
if let Some(existing) = merged.variables.get(var) {
if existing != val {
return None; }
} else {
merged.variables.insert(var.clone(), val.clone());
}
}
Some(merged)
}
}
#[derive(Debug)]
pub struct DistributedQueryExecutor {
#[allow(dead_code)]
node_id: OxirsNodeId,
cluster_nodes: Arc<RwLock<HashSet<OxirsNodeId>>>,
query_cache: Arc<RwLock<HashMap<String, Vec<ResultBinding>>>>,
statistics: Arc<RwLock<HashMap<String, QueryStats>>>,
}
impl DistributedQueryExecutor {
pub fn new(node_id: OxirsNodeId) -> Self {
Self {
node_id,
cluster_nodes: Arc::new(RwLock::new(HashSet::new())),
query_cache: Arc::new(RwLock::new(HashMap::new())),
statistics: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn add_node(&self, node_id: OxirsNodeId) {
let mut nodes = self.cluster_nodes.write().await;
nodes.insert(node_id);
info!("Added node {} to distributed query executor", node_id);
}
pub async fn remove_node(&self, node_id: OxirsNodeId) {
let mut nodes = self.cluster_nodes.write().await;
nodes.remove(&node_id);
info!("Removed node {} from distributed query executor", node_id);
}
pub async fn execute_query(&self, sparql: &str) -> Result<Vec<ResultBinding>> {
let query_id = uuid::Uuid::new_v4().to_string();
let start_time = std::time::Instant::now();
info!("Executing distributed query {}: {}", query_id, sparql);
if let Some(cached_results) = self.check_cache(sparql).await {
info!("Cache hit for query {}", query_id);
return Ok(cached_results);
}
let plan = self.create_execution_plan(&query_id, sparql).await?;
let subquery_results = self.execute_subqueries(&plan).await?;
let final_results = self.combine_results(&plan, subquery_results).await?;
self.cache_results(sparql, &final_results).await;
let execution_time = start_time.elapsed().as_millis() as u64;
self.record_statistics(&query_id, &plan, &final_results, execution_time)
.await;
info!(
"Completed distributed query {} in {}ms, {} results",
query_id,
execution_time,
final_results.len()
);
Ok(final_results)
}
async fn create_execution_plan(
&self,
query_id: &str,
sparql: &str,
) -> Result<DistributedQueryPlan> {
let parsed = self.parse_sparql(sparql)?;
let data_distribution = self.analyze_data_distribution().await?;
let subqueries = self.create_subqueries(&parsed, &data_distribution).await?;
let join_operations = self.plan_joins(&subqueries)?;
let aggregation_plan = self.create_aggregation_plan(&parsed)?;
let estimated_cost = self.estimate_cost(&subqueries, &join_operations).await;
Ok(DistributedQueryPlan {
query_id: query_id.to_string(),
original_sparql: sparql.to_string(),
subqueries,
join_operations,
aggregation_plan,
estimated_cost,
})
}
fn parse_sparql(&self, sparql: &str) -> Result<ParsedQuery> {
let mut variables = Vec::new();
let mut triple_patterns = Vec::new();
let filters = Vec::new();
if let Some(select_part) = sparql.split("WHERE").next() {
if select_part.contains("SELECT") {
let vars_part = select_part.replace("SELECT", "").trim().to_string();
if vars_part != "*" {
variables = vars_part
.split_whitespace()
.filter(|v| v.starts_with('?'))
.map(|v| v.to_string())
.collect();
}
}
}
if let Some(where_part) = sparql.split("WHERE").nth(1) {
let clean_where = where_part.replace(['{', '}'], "");
for line in clean_where.lines() {
let line = line.trim();
if !line.is_empty() && line.contains(' ') {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 {
triple_patterns.push(TriplePattern {
subject: parts[0].to_string(),
predicate: parts[1].to_string(),
object: parts[2].replace('.', ""),
});
}
}
}
}
Ok(ParsedQuery {
variables,
triple_patterns,
filters,
limit: None,
offset: None,
order_by: Vec::new(),
})
}
async fn analyze_data_distribution(&self) -> Result<DataDistribution> {
let nodes = self.cluster_nodes.read().await;
let mut distribution = DataDistribution {
node_triple_counts: HashMap::new(),
predicate_distribution: HashMap::new(),
subject_distribution: HashMap::new(),
};
for &node_id in nodes.iter() {
distribution.node_triple_counts.insert(node_id, 10000); }
Ok(distribution)
}
async fn create_subqueries(
&self,
parsed: &ParsedQuery,
_distribution: &DataDistribution,
) -> Result<Vec<SubqueryPlan>> {
let mut subqueries = Vec::new();
let nodes: Vec<_> = self.cluster_nodes.read().await.iter().cloned().collect();
if nodes.is_empty() {
return Err(anyhow::anyhow!("No nodes available for query execution"));
}
for (i, triple_pattern) in parsed.triple_patterns.iter().enumerate() {
let target_node = nodes[i % nodes.len()];
let sparql_fragment = format!(
"SELECT {} WHERE {{ {} {} {} }}",
parsed.variables.join(" "),
triple_pattern.subject,
triple_pattern.predicate,
triple_pattern.object
);
subqueries.push(SubqueryPlan {
subquery_id: format!("subquery_{i}"),
target_node,
sparql_fragment,
variables: parsed.variables.clone(),
estimated_rows: 1000, estimated_latency_ms: 50,
});
}
Ok(subqueries)
}
fn plan_joins(&self, subqueries: &[SubqueryPlan]) -> Result<Vec<JoinOperation>> {
let mut joins = Vec::new();
for i in 0..subqueries.len().saturating_sub(1) {
let left = &subqueries[i];
let right = &subqueries[i + 1];
let common_vars: Vec<String> = left
.variables
.iter()
.filter(|v| right.variables.contains(v))
.cloned()
.collect();
if !common_vars.is_empty() {
joins.push(JoinOperation {
left_subquery: left.subquery_id.clone(),
right_subquery: right.subquery_id.clone(),
join_variables: common_vars,
join_type: JoinType::Inner,
});
}
}
Ok(joins)
}
fn create_aggregation_plan(&self, parsed: &ParsedQuery) -> Result<Option<AggregationPlan>> {
if parsed.order_by.is_empty() && parsed.limit.is_none() {
return Ok(None);
}
Ok(Some(AggregationPlan {
group_by: Vec::new(),
aggregates: Vec::new(),
having_conditions: Vec::new(),
order_by: parsed.order_by.clone(),
limit: parsed.limit,
offset: parsed.offset,
}))
}
async fn estimate_cost(&self, subqueries: &[SubqueryPlan], joins: &[JoinOperation]) -> f64 {
let mut total_cost = 0.0;
for subquery in subqueries {
total_cost += subquery.estimated_rows as f64 * 0.001; total_cost += subquery.estimated_latency_ms as f64 * 0.01; }
for _join in joins {
total_cost += 10.0; }
total_cost
}
async fn execute_subqueries(
&self,
plan: &DistributedQueryPlan,
) -> Result<HashMap<String, Vec<ResultBinding>>> {
let mut results = HashMap::new();
let mut handles = Vec::new();
for subquery in &plan.subqueries {
let subquery_clone = subquery.clone();
let handle =
tokio::spawn(async move { Self::execute_single_subquery(subquery_clone).await });
handles.push((subquery.subquery_id.clone(), handle));
}
for (subquery_id, handle) in handles {
match handle.await {
Ok(Ok(subquery_results)) => {
results.insert(subquery_id, subquery_results);
}
Ok(Err(e)) => {
error!("Subquery {} failed: {}", subquery_id, e);
return Err(e);
}
Err(e) => {
error!("Subquery {} task failed: {}", subquery_id, e);
return Err(anyhow::anyhow!("Task execution failed: {}", e));
}
}
}
Ok(results)
}
async fn execute_single_subquery(subquery: SubqueryPlan) -> Result<Vec<ResultBinding>> {
debug!(
"Executing subquery {} on node {}",
subquery.subquery_id, subquery.target_node
);
let client = reqwest::Client::builder()
.timeout(tokio::time::Duration::from_millis(
subquery.estimated_latency_ms * 3,
))
.build()
.map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
let endpoint_url = format!("http://node-{}/sparql", subquery.target_node);
let response = client
.post(&endpoint_url)
.header("Content-Type", "application/sparql-query")
.header("Accept", "application/sparql-results+json")
.body(subquery.sparql_fragment.clone())
.send()
.await;
match response {
Ok(resp) if resp.status().is_success() => {
let json: serde_json::Value = resp
.json()
.await
.map_err(|e| anyhow::anyhow!("Failed to parse JSON response: {}", e))?;
Self::parse_sparql_json_results(json)
}
Ok(resp) => {
warn!(
"Node {} returned error status {}: falling back to simulation",
subquery.target_node,
resp.status()
);
Self::simulate_subquery_execution(subquery).await
}
Err(e) => {
warn!(
"Failed to reach node {}: {} - falling back to simulation",
subquery.target_node, e
);
Self::simulate_subquery_execution(subquery).await
}
}
}
fn parse_sparql_json_results(json: serde_json::Value) -> Result<Vec<ResultBinding>> {
let bindings_array = json
.get("results")
.and_then(|r| r.get("bindings"))
.and_then(|b| b.as_array())
.ok_or_else(|| anyhow::anyhow!("Invalid SPARQL JSON results format"))?;
let mut results = Vec::new();
for binding_obj in bindings_array {
if let Some(binding_map) = binding_obj.as_object() {
let mut result_binding = ResultBinding::new();
for (var_name, var_value) in binding_map {
if let Some(value_obj) = var_value.as_object() {
let value = value_obj
.get("value")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
result_binding.add_binding(format!("?{var_name}"), value);
}
}
results.push(result_binding);
}
}
Ok(results)
}
async fn simulate_subquery_execution(subquery: SubqueryPlan) -> Result<Vec<ResultBinding>> {
tokio::time::sleep(tokio::time::Duration::from_millis(
subquery.estimated_latency_ms,
))
.await;
let mut results = Vec::new();
let result_count = std::cmp::min(subquery.estimated_rows, 100);
for i in 0..result_count {
let mut binding = ResultBinding::new();
for var in &subquery.variables {
let value = match var.as_str() {
"?s" | "?subject" => format!("http://example.org/resource_{i}"),
"?p" | "?predicate" => format!("http://example.org/property_{}", i % 10),
"?o" | "?object" => format!("\"Object value {i}\""),
"?name" => format!("\"Name {i}\""),
"?type" => "http://example.org/Type".to_string(),
_ => format!("value_{}_{}", subquery.target_node, i),
};
binding.add_binding(var.clone(), value);
}
results.push(binding);
}
Ok(results)
}
async fn combine_results(
&self,
plan: &DistributedQueryPlan,
subquery_results: HashMap<String, Vec<ResultBinding>>,
) -> Result<Vec<ResultBinding>> {
let mut current_results = Vec::new();
if let Some(first_subquery) = plan.subqueries.first() {
if let Some(first_results) = subquery_results.get(&first_subquery.subquery_id) {
current_results = first_results.clone();
}
}
for join in &plan.join_operations {
if let Some(right_results) = subquery_results.get(&join.right_subquery) {
current_results = self
.execute_join(¤t_results, right_results, join)
.await?;
}
}
if let Some(agg_plan) = &plan.aggregation_plan {
current_results = self.apply_aggregation(current_results, agg_plan).await?;
}
Ok(current_results)
}
async fn execute_join(
&self,
left_results: &[ResultBinding],
right_results: &[ResultBinding],
join: &JoinOperation,
) -> Result<Vec<ResultBinding>> {
let mut joined_results = Vec::new();
match join.join_type {
JoinType::Inner => {
for left_binding in left_results {
for right_binding in right_results {
if self.bindings_compatible(
left_binding,
right_binding,
&join.join_variables,
) {
if let Some(merged) = left_binding.merge(right_binding) {
joined_results.push(merged);
}
}
}
}
}
JoinType::Left => {
for left_binding in left_results {
let mut found_match = false;
for right_binding in right_results {
if self.bindings_compatible(
left_binding,
right_binding,
&join.join_variables,
) {
if let Some(merged) = left_binding.merge(right_binding) {
joined_results.push(merged);
found_match = true;
}
}
}
if !found_match {
joined_results.push(left_binding.clone());
}
}
}
JoinType::Optional => {
joined_results = Box::pin(self.execute_join(
left_results,
right_results,
&JoinOperation {
left_subquery: join.left_subquery.clone(),
right_subquery: join.right_subquery.clone(),
join_variables: join.join_variables.clone(),
join_type: JoinType::Left,
},
))
.await?;
}
JoinType::Union => {
joined_results.extend_from_slice(left_results);
joined_results.extend_from_slice(right_results);
joined_results.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
joined_results.dedup();
}
}
Ok(joined_results)
}
fn bindings_compatible(
&self,
left: &ResultBinding,
right: &ResultBinding,
join_variables: &[String],
) -> bool {
for var in join_variables {
if let (Some(left_val), Some(right_val)) = (left.get(var), right.get(var)) {
if left_val != right_val {
return false;
}
}
}
true
}
async fn apply_aggregation(
&self,
mut results: Vec<ResultBinding>,
agg_plan: &AggregationPlan,
) -> Result<Vec<ResultBinding>> {
if !agg_plan.order_by.is_empty() {
results.sort_by(|a, b| {
for order_clause in &agg_plan.order_by {
let empty_string = String::new();
let a_val = a.get(&order_clause.variable).unwrap_or(&empty_string);
let b_val = b.get(&order_clause.variable).unwrap_or(&empty_string);
let cmp = if order_clause.ascending {
a_val.cmp(b_val)
} else {
b_val.cmp(a_val)
};
if cmp != std::cmp::Ordering::Equal {
return cmp;
}
}
std::cmp::Ordering::Equal
});
}
if let Some(offset) = agg_plan.offset {
if offset < results.len() as u64 {
results = results.into_iter().skip(offset as usize).collect();
} else {
results.clear();
}
}
if let Some(limit) = agg_plan.limit {
results.truncate(limit as usize);
}
Ok(results)
}
async fn check_cache(&self, sparql: &str) -> Option<Vec<ResultBinding>> {
let cache = self.query_cache.read().await;
cache.get(sparql).cloned()
}
async fn cache_results(&self, sparql: &str, results: &[ResultBinding]) {
let mut cache = self.query_cache.write().await;
cache.insert(sparql.to_string(), results.to_vec());
if cache.len() > 1000 {
let keys_to_remove: Vec<_> = cache.keys().take(100).cloned().collect();
for key in keys_to_remove {
cache.remove(&key);
}
}
}
async fn record_statistics(
&self,
query_id: &str,
plan: &DistributedQueryPlan,
results: &[ResultBinding],
execution_time_ms: u64,
) {
let stats = QueryStats {
query_id: query_id.to_string(),
execution_time_ms,
nodes_involved: plan.subqueries.len() as u32,
total_intermediate_results: plan.subqueries.iter().map(|s| s.estimated_rows).sum(),
final_result_count: results.len() as u64,
network_transfer_bytes: 0, cache_hits: 0,
cache_misses: 1,
};
let mut statistics = self.statistics.write().await;
statistics.insert(query_id.to_string(), stats);
if statistics.len() > 10000 {
let keys_to_remove: Vec<_> = statistics.keys().take(1000).cloned().collect();
for key in keys_to_remove {
statistics.remove(&key);
}
}
}
pub async fn get_statistics(&self) -> HashMap<String, QueryStats> {
self.statistics.read().await.clone()
}
pub async fn clear_cache(&self) {
let mut cache = self.query_cache.write().await;
cache.clear();
info!("Query cache cleared");
}
}
#[derive(Debug, Clone)]
struct ParsedQuery {
variables: Vec<String>,
triple_patterns: Vec<TriplePattern>,
#[allow(dead_code)]
filters: Vec<String>,
limit: Option<u64>,
offset: Option<u64>,
order_by: Vec<OrderByClause>,
}
#[derive(Debug, Clone)]
struct TriplePattern {
subject: String,
predicate: String,
object: String,
}
#[derive(Debug, Clone)]
struct DataDistribution {
node_triple_counts: HashMap<OxirsNodeId, u64>,
#[allow(dead_code)]
predicate_distribution: HashMap<String, Vec<OxirsNodeId>>,
#[allow(dead_code)]
subject_distribution: HashMap<String, Vec<OxirsNodeId>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_distributed_query_executor_creation() {
let executor = DistributedQueryExecutor::new(1);
executor.add_node(2).await;
executor.add_node(3).await;
let nodes = executor.cluster_nodes.read().await;
assert_eq!(nodes.len(), 2);
assert!(nodes.contains(&2));
assert!(nodes.contains(&3));
}
#[tokio::test]
async fn test_result_binding() {
let mut binding1 = ResultBinding::new();
binding1.add_binding("?x".to_string(), "value1".to_string());
let mut binding2 = ResultBinding::new();
binding2.add_binding("?y".to_string(), "value2".to_string());
let merged = binding1.merge(&binding2).unwrap();
assert_eq!(merged.get("?x"), Some(&"value1".to_string()));
assert_eq!(merged.get("?y"), Some(&"value2".to_string()));
}
#[tokio::test]
async fn test_result_binding_conflict() {
let mut binding1 = ResultBinding::new();
binding1.add_binding("?x".to_string(), "value1".to_string());
let mut binding2 = ResultBinding::new();
binding2.add_binding("?x".to_string(), "value2".to_string());
let merged = binding1.merge(&binding2);
assert!(merged.is_none()); }
#[test]
fn test_sparql_parsing() {
let executor = DistributedQueryExecutor::new(1);
let sparql = "SELECT ?x ?y WHERE { ?x <predicate> ?y }";
let parsed = executor.parse_sparql(sparql).unwrap();
assert_eq!(parsed.variables, vec!["?x", "?y"]);
assert_eq!(parsed.triple_patterns.len(), 1);
}
#[tokio::test]
async fn test_query_execution() {
let executor = DistributedQueryExecutor::new(1);
executor.add_node(2).await;
executor.add_node(3).await;
let sparql = "SELECT ?x WHERE { ?x <type> <Person> }";
let results = executor.execute_query(sparql).await.unwrap();
assert!(!results.is_empty());
}
#[tokio::test]
async fn test_query_caching() {
let executor = DistributedQueryExecutor::new(1);
executor.add_node(2).await;
let sparql = "SELECT ?x WHERE { ?x <type> <Person> }";
let results1 = executor.execute_query(sparql).await.unwrap();
let results2 = executor.execute_query(sparql).await.unwrap();
assert_eq!(results1.len(), results2.len());
}
}