use anyhow::{Context, Result};
use futures_util::future;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::time::{Duration, Instant};
pub struct DatasetFederation {
endpoints: Vec<SparqlEndpoint>,
join_optimizer: JoinOptimizer,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparqlEndpoint {
pub id: String,
pub url: String,
pub auth_header: Option<String>,
pub timeout_secs: u64,
pub max_concurrent_queries: usize,
pub supported_features: HashSet<String>,
pub statistics: Option<EndpointStatistics>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointStatistics {
#[serde(with = "duration_serde")]
pub avg_response_time: Duration,
pub triple_count: Option<u64>,
pub indexes: Vec<String>,
#[serde(with = "instant_serde")]
pub last_updated: Instant,
}
impl Default for EndpointStatistics {
fn default() -> Self {
Self {
avg_response_time: Duration::from_millis(100),
triple_count: None,
indexes: Vec::new(),
last_updated: Instant::now(),
}
}
}
mod duration_serde {
use serde::{Deserialize, Deserializer, Serializer};
use std::time::Duration;
pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u64(duration.as_millis() as u64)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let millis = u64::deserialize(deserializer)?;
Ok(Duration::from_millis(millis))
}
}
mod instant_serde {
use serde::{Deserializer, Serializer};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
pub fn serialize<S>(_instant: &Instant, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let now = SystemTime::now();
let epoch_duration = now.duration_since(UNIX_EPOCH).unwrap_or_default();
let instant_timestamp = epoch_duration.as_secs();
serializer.serialize_u64(instant_timestamp)
}
pub fn deserialize<'de, D>(_deserializer: D) -> Result<Instant, D::Error>
where
D: Deserializer<'de>,
{
Ok(Instant::now())
}
}
pub struct JoinOptimizer {
#[allow(dead_code)]
cost_model: CostModel,
}
#[derive(Debug, Clone)]
pub struct CostModel {
#[allow(dead_code)]
network_latency_ms: f64,
#[allow(dead_code)]
transfer_cost_per_mb: f64,
#[allow(dead_code)]
local_processing_cost: f64,
}
#[derive(Debug, Clone)]
pub struct FederatedStep {
pub endpoint_id: String,
pub sparql_query: String,
pub expected_result_size: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct JoinPattern {
pub left_step: usize,
pub right_step: usize,
pub join_variables: Vec<String>,
}
impl DatasetFederation {
pub fn new() -> Self {
Self {
endpoints: Vec::new(),
join_optimizer: JoinOptimizer::new(),
}
}
pub fn add_endpoint(&mut self, endpoint: SparqlEndpoint) {
self.endpoints.push(endpoint);
}
pub async fn federate_sparql_query(&self, query: &str) -> Result<serde_json::Value> {
let query_plan = self.plan_federated_query(query).await?;
self.execute_federated_plan(&query_plan).await
}
async fn plan_federated_query(&self, query: &str) -> Result<Vec<FederatedStep>> {
let mut steps = Vec::new();
for endpoint in &self.endpoints {
if self.endpoint_can_contribute(endpoint, query).await? {
let adapted_query = self.adapt_query_for_endpoint(query, endpoint)?;
steps.push(FederatedStep {
endpoint_id: endpoint.id.clone(),
sparql_query: adapted_query,
expected_result_size: None,
});
}
}
self.join_optimizer.optimize_joins(&mut steps);
Ok(steps)
}
async fn endpoint_can_contribute(
&self,
endpoint: &SparqlEndpoint,
query: &str,
) -> Result<bool> {
if query.contains("FILTER") && !endpoint.supported_features.contains("filters") {
return Ok(false);
}
if query.contains("GROUP BY") && !endpoint.supported_features.contains("aggregation") {
return Ok(false);
}
Ok(true)
}
fn adapt_query_for_endpoint(&self, query: &str, endpoint: &SparqlEndpoint) -> Result<String> {
let mut adapted = query.to_string();
if !adapted.contains("SERVICE") {
adapted = format!(
"SELECT * WHERE {{ SERVICE <{}> {{ {} }} }}",
endpoint.url, adapted
);
}
Ok(adapted)
}
async fn execute_federated_plan(&self, plan: &[FederatedStep]) -> Result<serde_json::Value> {
let futures: Vec<_> = plan
.iter()
.map(|step| self.execute_federated_step(step))
.collect();
let step_results = future::try_join_all(futures).await?;
self.merge_federated_results(&step_results)
}
async fn execute_federated_step(&self, step: &FederatedStep) -> Result<serde_json::Value> {
let endpoint = self
.endpoints
.iter()
.find(|ep| ep.id == step.endpoint_id)
.ok_or_else(|| anyhow::anyhow!("Endpoint not found: {}", step.endpoint_id))?;
let client = reqwest::Client::new();
let mut request = client
.post(&endpoint.url)
.header("Content-Type", "application/sparql-query")
.body(step.sparql_query.clone());
if let Some(auth) = &endpoint.auth_header {
request = request.header("Authorization", auth);
}
let response = request
.timeout(Duration::from_secs(endpoint.timeout_secs))
.send()
.await
.context("Failed to execute SPARQL query")?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"SPARQL query failed with status: {}",
response.status()
));
}
let result = response
.json()
.await
.context("Failed to parse SPARQL response")?;
Ok(result)
}
fn merge_federated_results(&self, results: &[serde_json::Value]) -> Result<serde_json::Value> {
let mut merged_bindings = Vec::new();
for result in results {
if let Some(result_obj) = result.as_object() {
if let Some(results_obj) = result_obj.get("results") {
if let Some(bindings) = results_obj.get("bindings").and_then(|b| b.as_array()) {
merged_bindings.extend(bindings.iter().cloned());
}
}
}
}
Ok(serde_json::json!({
"head": {
"vars": []
},
"results": {
"bindings": merged_bindings
}
}))
}
pub async fn update_endpoint_statistics(&mut self, endpoint_id: &str) -> Result<()> {
let endpoint_idx = self
.endpoints
.iter()
.position(|ep| ep.id == endpoint_id)
.ok_or_else(|| anyhow::anyhow!("Endpoint not found: {}", endpoint_id))?;
let stats = self
.assess_endpoint_capabilities(&self.endpoints[endpoint_idx])
.await?;
self.endpoints[endpoint_idx].statistics = Some(stats);
Ok(())
}
async fn assess_endpoint_capabilities(
&self,
endpoint: &SparqlEndpoint,
) -> Result<EndpointStatistics> {
let start_time = Instant::now();
let test_query = "SELECT (COUNT(*) as ?count) WHERE { ?s ?p ?o }";
let client = reqwest::Client::new();
let mut request = client
.post(&endpoint.url)
.header("Content-Type", "application/sparql-query")
.body(test_query);
if let Some(auth) = &endpoint.auth_header {
request = request.header("Authorization", auth);
}
let response = request
.timeout(Duration::from_secs(endpoint.timeout_secs))
.send()
.await
.context("Failed to assess endpoint capabilities")?;
let response_time = start_time.elapsed();
let mut triple_count = None;
if response.status().is_success() {
if let Ok(result) = response.json::<serde_json::Value>().await {
if let Some(bindings) = result.pointer("/results/bindings/0/count/value") {
if let Some(count_str) = bindings.as_str() {
triple_count = count_str.parse().ok();
}
}
}
}
Ok(EndpointStatistics {
avg_response_time: response_time,
triple_count,
indexes: vec!["spo".to_string()], last_updated: Instant::now(),
})
}
}
impl Default for JoinOptimizer {
fn default() -> Self {
Self::new()
}
}
impl JoinOptimizer {
pub fn new() -> Self {
Self {
cost_model: CostModel {
network_latency_ms: 50.0,
transfer_cost_per_mb: 1.0,
local_processing_cost: 0.1,
},
}
}
pub fn optimize_joins(&self, steps: &mut [FederatedStep]) {
steps.sort_by(|a, b| {
match (a.expected_result_size, b.expected_result_size) {
(Some(size_a), Some(size_b)) => size_a.cmp(&size_b),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => std::cmp::Ordering::Equal,
}
});
}
pub fn identify_join_patterns(&self, steps: &[FederatedStep]) -> Vec<JoinPattern> {
let mut patterns = Vec::new();
for (i, step_a) in steps.iter().enumerate() {
for (j, step_b) in steps.iter().enumerate().skip(i + 1) {
let common_vars =
self.find_common_variables(&step_a.sparql_query, &step_b.sparql_query);
if !common_vars.is_empty() {
patterns.push(JoinPattern {
left_step: i,
right_step: j,
join_variables: common_vars,
});
}
}
}
patterns
}
fn find_common_variables(&self, query_a: &str, query_b: &str) -> Vec<String> {
let var_regex =
regex::Regex::new(r"\?(\w+)").expect("parse should succeed for valid input");
let vars_a: HashSet<String> = var_regex
.captures_iter(query_a)
.map(|cap| cap[1].to_string())
.collect();
let vars_b: HashSet<String> = var_regex
.captures_iter(query_b)
.map(|cap| cap[1].to_string())
.collect();
vars_a.intersection(&vars_b).cloned().collect()
}
}
impl Default for DatasetFederation {
fn default() -> Self {
Self::new()
}
}