use hickory_resolver::TokioResolver;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
sync::{Notify, RwLock},
time::interval,
};
use url::Url;
use crate::{
error::{FusekiError, FusekiResult},
federation::{
FederationConfig, ServiceCapabilities, ServiceEndpoint, ServiceHealth, ServiceMetadata,
},
};
/// Shorthand for the crate's [`FusekiResult`], used by every fallible operation in this module.
type Result<T> = FusekiResult<T>;
/// A source of SPARQL endpoint registrations consulted by [`ServiceDiscovery`] on every pass.
#[derive(Debug, Clone)]
pub enum DiscoveryMethod {
    /// A fixed list of registrations, (re)inserted into the endpoint map on every pass.
    Static(Vec<ServiceRegistration>),
    /// DNS `SRV` lookup of `_sparql._tcp.<domain>`, probing common ports if the lookup fails.
    Dns {
        /// Domain whose `_sparql._tcp` SRV records are queried.
        domain: String,
    },
    /// Consul health-API query for passing instances of the `sparql` service.
    Consul {
        /// Base URL of the Consul HTTP API.
        endpoint: Url,
    },
    /// Kubernetes discovery (not implemented yet; currently a no-op).
    Kubernetes {
        /// Namespace to search.
        namespace: String,
    },
    /// Refreshes the capabilities of endpoints already in the map from their SPARQL service
    /// descriptions.
    ServiceDescription,
}
/// A statically configured SPARQL endpoint, consumed by [`DiscoveryMethod::Static`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceRegistration {
    /// Key under which the endpoint is stored in the shared endpoint map.
    pub id: String,
    /// SPARQL endpoint URL.
    pub url: Url,
    /// Descriptive metadata copied onto the registered endpoint.
    pub metadata: ServiceMetadata,
}
/// Discovers SPARQL endpoints and records them in a shared endpoint map.
///
/// Register sources with [`ServiceDiscovery::add_method`], then call
/// [`ServiceDiscovery::start`]; [`ServiceDiscovery::stop`] ends the periodic background task.
pub struct ServiceDiscovery {
    /// Federation settings; `discovery_interval` sets the period of the background task.
    config: FederationConfig,
    /// Shared map of discovered endpoints, keyed by service id.
    endpoints: Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    /// Discovery methods, run in registration order on every pass.
    discovery_methods: Vec<DiscoveryMethod>,
    /// HTTP client shared by all HTTP-based discovery and health checks.
    http_client: Client,
    /// Signals the background task to exit.
    shutdown: Arc<Notify>,
}
impl ServiceDiscovery {
/// Creates a discovery service that writes into the shared `endpoints` map.
///
/// No discovery methods are registered yet; add them with [`ServiceDiscovery::add_method`].
/// The internal HTTP client uses a 10-second default request timeout.
///
/// # Panics
/// Panics if the `reqwest` HTTP client cannot be built.
pub fn new(
    config: FederationConfig,
    endpoints: Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
) -> Self {
    let http_client = Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .expect("HTTP client build should succeed");
    let shutdown = Arc::new(Notify::new());
    Self {
        discovery_methods: Vec::new(),
        http_client,
        shutdown,
        config,
        endpoints,
    }
}
/// Registers another discovery method; methods run in registration order on every pass.
///
/// The background task spawned by [`ServiceDiscovery::start`] works on a copy of the methods
/// registered at that moment, so methods added afterwards are not picked up by it.
pub fn add_method(&mut self, method: DiscoveryMethod) {
    let methods = &mut self.discovery_methods;
    methods.push(method);
}
/// Starts discovery: spawns a periodic background task and runs one discovery pass inline.
///
/// The background task re-runs the methods registered so far once per
/// `config.discovery_interval`, until [`ServiceDiscovery::stop`] is called. Its first pass is
/// scheduled one full interval after `start`, because the inline pass already covers "now". If a
/// pass overruns the interval, the next one waits a full interval rather than firing catch-up
/// passes back to back. Errors from background passes are logged at `error` level, not returned.
///
/// # Errors
/// Returns the error of the inline initial pass. The background task keeps running in that case;
/// call `stop` to end it.
///
/// # Panics
/// The background task panics if `config.discovery_interval` is zero (a `tokio::time::interval`
/// requirement).
pub async fn start(&self) -> Result<()> {
    let shutdown = self.shutdown.clone();
    let period = self.config.discovery_interval;
    let endpoints = self.endpoints.clone();
    let methods = self.discovery_methods.clone();
    let client = self.http_client.clone();
    tokio::spawn(async move {
        let mut ticker = interval(period);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // A tokio interval's first tick completes immediately. `start` already runs the initial
        // pass inline, so consume that tick here instead of running a concurrent duplicate pass.
        ticker.tick().await;
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    if let Err(e) = Self::discover_services(&methods, &endpoints, &client).await {
                        tracing::error!("Service discovery error: {}", e);
                    }
                }
                _ = shutdown.notified() => {
                    tracing::info!("Service discovery shutting down");
                    break;
                }
            }
        }
    });
    Self::discover_services(&self.discovery_methods, &self.endpoints, &self.http_client).await
}
/// Asks the background task spawned by [`ServiceDiscovery::start`] to exit.
///
/// [`Notify::notify_one`] stores a permit when the task is not currently waiting. A stop
/// requested in the middle of a pass therefore takes effect as soon as that pass finishes. Each
/// call signals a single task. Always returns `Ok(())`.
pub async fn stop(&self) -> Result<()> {
    let signal = &self.shutdown;
    signal.notify_one();
    Ok(())
}
async fn discover_services(
methods: &[DiscoveryMethod],
endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
client: &Client,
) -> Result<()> {
for method in methods {
match method {
DiscoveryMethod::Static(registrations) => {
Self::discover_static(registrations, endpoints).await?;
}
DiscoveryMethod::ServiceDescription => {
Self::discover_via_service_description(endpoints, client).await?;
}
DiscoveryMethod::Dns { domain } => {
Self::discover_via_dns(domain, endpoints, client).await?;
}
DiscoveryMethod::Consul { endpoint } => {
Self::discover_via_consul(endpoint, endpoints, client).await?;
}
DiscoveryMethod::Kubernetes { namespace } => {
Self::discover_via_kubernetes(namespace, endpoints, client).await?;
}
}
}
Ok(())
}
/// Inserts (or overwrites) one endpoint per static registration, keyed by the registration id.
///
/// Each endpoint starts with [`ServiceHealth::Unknown`] and default capabilities. Never fails.
async fn discover_static(
    registrations: &[ServiceRegistration],
    endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
) -> Result<()> {
    let entries = registrations.iter().map(|registration| {
        let endpoint = ServiceEndpoint {
            url: registration.url.clone(),
            metadata: registration.metadata.clone(),
            health: ServiceHealth::Unknown,
            capabilities: ServiceCapabilities::default(),
        };
        (registration.id.clone(), endpoint)
    });
    endpoints.write().await.extend(entries);
    Ok(())
}
async fn discover_via_service_description(
endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
client: &Client,
) -> Result<()> {
let eps = endpoints.read().await;
let service_urls: Vec<_> = eps.values().map(|ep| ep.url.clone()).collect();
drop(eps);
for url in service_urls {
if let Ok(capabilities) = Self::fetch_service_description(&url, client).await {
let mut eps = endpoints.write().await;
if let Some(ep) = eps.values_mut().find(|ep| ep.url == url) {
ep.capabilities = capabilities;
}
}
}
Ok(())
}
/// Queries `base_url` for its SPARQL service description and converts the results into
/// [`ServiceCapabilities`] via `parse_sparql_json_results`.
///
/// The query is sent as a GET `query` parameter and requests
/// `application/sparql-results+json`.
///
/// # Errors
/// Returns [`FusekiError::Internal`] in any of these cases:
/// * the request fails;
/// * the endpoint answers with a non-success status;
/// * the body cannot be read;
/// * the body is not valid JSON.
async fn fetch_service_description(
    base_url: &Url,
    client: &Client,
) -> Result<ServiceCapabilities> {
    // Supported languages (sd:supportedLanguage) and features (sd:feature) are both bound to
    // ?feature, so every term `parse_feature_uri` knows how to label can actually be returned.
    // The triple count is read from the default dataset, or from its default graph as in the
    // Service Description spec's examples.
    let query = r#"
PREFIX sd: <http://www.w3.org/ns/sparql-service-description#>
PREFIX void: <http://rdfs.org/ns/void#>
SELECT ?feature ?triples WHERE {
    ?service a sd:Service .
    { ?service sd:supportedLanguage ?feature }
    UNION
    { ?service sd:feature ?feature }
    OPTIONAL {
        ?service sd:defaultDataset/(void:triples|sd:defaultGraph/void:triples) ?triples
    }
}
"#;
    let response = client
        .get(base_url.as_str())
        .query(&[("query", query)])
        .header("Accept", "application/sparql-results+json")
        .send()
        .await
        .map_err(|e| FusekiError::Internal {
            message: format!("Failed to fetch service description: {e}"),
        })?;
    if !response.status().is_success() {
        return Err(FusekiError::Internal {
            message: format!("Service description query failed: {}", response.status()),
        });
    }
    let json_text = response.text().await.map_err(|e| FusekiError::Internal {
        message: format!("Failed to read response body: {e}"),
    })?;
    Self::parse_sparql_json_results(&json_text)
}
/// Builds [`ServiceCapabilities`] from a SPARQL JSON results document.
///
/// Reads the `feature` and `triples` bindings produced by the service-description query:
/// * each distinct `feature` URI is labelled with `parse_feature_uri`, in first-seen order;
/// * the last `triples` value that parses as an unsigned integer becomes `dataset_size`.
///
/// When no features are found, SPARQL 1.1 Query and Update are assumed. The advertised result
/// formats are a fixed list; they are not read from the document.
///
/// # Errors
/// Returns [`FusekiError::Internal`] if `json_text` is not valid JSON. Valid JSON without
/// `results.bindings` is not an error and yields the defaults.
fn parse_sparql_json_results(json_text: &str) -> Result<ServiceCapabilities> {
    let json: serde_json::Value =
        serde_json::from_str(json_text).map_err(|e| FusekiError::Internal {
            message: format!("Failed to parse SPARQL JSON results: {e}"),
        })?;
    let bindings = json["results"]["bindings"]
        .as_array()
        .map(Vec::as_slice)
        .unwrap_or_default();
    let mut sparql_features: Vec<String> = Vec::new();
    let mut dataset_size = None;
    for binding in bindings {
        if let Some(feature_uri) = binding["feature"]["value"].as_str() {
            let feature = Self::parse_feature_uri(feature_uri);
            if !sparql_features.contains(&feature) {
                sparql_features.push(feature);
            }
        }
        // Parse straight into u64: going through usize rejected counts above u32::MAX on
        // 32-bit targets.
        if let Some(count) = binding["triples"]["value"]
            .as_str()
            .and_then(|value| value.parse::<u64>().ok())
        {
            dataset_size = Some(count);
        }
    }
    if sparql_features.is_empty() {
        sparql_features = vec![
            "SPARQL 1.1 Query".to_string(),
            "SPARQL 1.1 Update".to_string(),
        ];
    }
    Ok(ServiceCapabilities {
        sparql_features,
        result_formats: vec![
            "application/sparql-results+json".to_string(),
            "application/sparql-results+xml".to_string(),
            "text/turtle".to_string(),
            "application/rdf+xml".to_string(),
        ],
        dataset_size,
        avg_response_time: None,
        max_result_size: None,
    })
}
/// Converts a SPARQL service-description language or feature URI into a readable label.
///
/// Well-known terms of the `sd:` vocabulary map to fixed labels (e.g. `sd:SPARQL11Query` ->
/// `"SPARQL 1.1 Query"`). Any other URI is reduced to its local name, meaning the text after the
/// last `#` or `/`. A URI with an empty local name (e.g. ending in `/`) is returned unchanged.
fn parse_feature_uri(uri: &str) -> String {
    const SD_NS: &str = "http://www.w3.org/ns/sparql-service-description#";
    let known_label = match uri.strip_prefix(SD_NS) {
        Some("SPARQL11Query") => Some("SPARQL 1.1 Query"),
        Some("SPARQL11Update") => Some("SPARQL 1.1 Update"),
        Some("SPARQL10Query") => Some("SPARQL 1.0 Query"),
        Some("DereferencesURIs") => Some("Dereferences URIs"),
        Some("UnionDefaultGraph") => Some("Union Default Graph"),
        Some("BasicFederatedQuery") => Some("Basic Federated Query"),
        _ => None,
    };
    if let Some(label) = known_label {
        return label.to_string();
    }
    // Hash namespaces (like `sd:` itself) put the local name after '#', slash namespaces after
    // '/'. Splitting on '/' alone produced labels like "sparql-service-description#EmptyGraphs".
    match uri.rsplit(|c: char| c == '#' || c == '/').next() {
        Some(local_name) if !local_name.is_empty() => local_name.to_string(),
        _ => uri.to_string(),
    }
}
/// Discovers SPARQL services advertised by DNS `SRV` records for `_sparql._tcp.<domain>`.
///
/// Each SRV target is probed at `<scheme>://<host>:<port>/sparql` with `check_service_health`.
/// Port 443 uses `https` and every other port uses `http`, matching
/// `discover_via_fallback_ports`. Targets that answer (with any HTTP status) are stored as
/// `dns-<host>:<port>`. Unreachable targets and SRV data that does not form a valid URL are
/// logged and skipped. If the SRV lookup itself fails, discovery falls back to
/// `discover_via_fallback_ports` for `domain`.
///
/// All probes finish before the endpoint map is write-locked, so readers are never blocked
/// behind network I/O.
///
/// # Errors
/// Returns [`FusekiError::Internal`] if the DNS resolver cannot be created. In the fallback case,
/// returns whatever `discover_via_fallback_ports` returns.
async fn discover_via_dns(
    domain: &str,
    endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    client: &Client,
) -> Result<()> {
    tracing::info!(
        "Starting DNS-based service discovery for domain: {}",
        domain
    );
    let resolver = TokioResolver::builder_tokio()
        .map_err(|e| FusekiError::Internal {
            message: format!("Failed to create DNS resolver: {e}"),
        })?
        .build();
    let srv_query = format!("_sparql._tcp.{domain}");
    let lookup = match resolver.srv_lookup(&srv_query).await {
        Ok(lookup) => lookup,
        Err(e) => {
            tracing::warn!("No SRV records found for {}: {}", srv_query, e);
            return Self::discover_via_fallback_ports(domain, endpoints, client).await;
        }
    };
    let mut discovered = Vec::new();
    for record in lookup.iter() {
        let target = record.target().to_string();
        // SRV targets are fully-qualified names ending in '.', which URLs do not want.
        let host = target.trim_end_matches('.');
        let port = record.port();
        let scheme = if port == 443 { "https" } else { "http" };
        let service_url = match Url::parse(&format!("{scheme}://{host}:{port}/sparql")) {
            Ok(url) => url,
            Err(e) => {
                tracing::warn!("Invalid URL for SRV record {}:{}: {}", target, port, e);
                continue;
            }
        };
        let service_id = format!("dns-{host}:{port}");
        match Self::check_service_health(&service_url, client).await {
            Ok(health) => {
                tracing::info!(
                    "Discovered SPARQL service: {} at {}:{}",
                    service_id,
                    host,
                    port
                );
                let endpoint = ServiceEndpoint {
                    url: service_url,
                    metadata: ServiceMetadata {
                        name: format!("SPARQL Service at {host}:{port}"),
                        description: Some(format!(
                            "Discovered via DNS SRV record for {domain}"
                        )),
                        version: None,
                        contact: None,
                        location: None,
                        tags: vec![],
                    },
                    health,
                    capabilities: ServiceCapabilities::default(),
                };
                discovered.push((service_id, endpoint));
            }
            Err(e) => {
                tracing::warn!("Service at {}:{} is not reachable: {}", host, port, e);
            }
        }
    }
    let discovered_count = discovered.len();
    // Lock only for the in-memory insert; all network I/O has already completed.
    endpoints.write().await.extend(discovered);
    tracing::info!(
        "DNS discovery completed: {} services discovered for domain {}",
        discovered_count,
        domain
    );
    Ok(())
}
/// Discovers SPARQL services registered in Consul under the service name `sparql`.
///
/// Queries `<consul_endpoint>/v1/health/service/sparql?passing=true` and keeps only instances
/// whose `Checks` all report `passing`. Each instance's URL is `http://<address>:<port><path>`:
/// * `address` is `Service.Address`. When that is empty or missing it falls back to
///   `Node.Address`, then to `localhost`. IPv6 literals are wrapped in brackets.
/// * `port` is `Service.Port` (default `8080`). A value that does not fit in `u16` skips the
///   instance with a warning.
/// * `path` comes from a `sparql-path=<path>` tag (default `/sparql`) and always starts with
///   `/`.
///
/// A `version=<v>` tag sets the metadata version, and all tags are copied into the metadata.
/// Candidates are probed with `check_service_health`. Reachable ones are stored as
/// `consul-<Service.ID>`, using the service name when there is no ID. All probing happens before
/// the endpoint map is write-locked, so readers are never blocked behind network I/O.
///
/// # Errors
/// Returns [`FusekiError::Internal`] in any of these cases:
/// * `consul_endpoint` cannot take a path (a "cannot-be-a-base" URL);
/// * the Consul request fails or returns a non-success status;
/// * the response body is not valid JSON.
async fn discover_via_consul(
    consul_endpoint: &Url,
    endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    client: &Client,
) -> Result<()> {
    tracing::info!(
        "Starting Consul-based service discovery from: {}",
        consul_endpoint
    );
    // Build the API URL through path segments rather than string concatenation. `Url`'s Display
    // always renders at least "/" as the path, so `format!("{url}/v1/...")` requested
    // "//v1/...". `pop_if_empty` also drops the empty segment left by a trailing slash.
    let mut consul_url = consul_endpoint.clone();
    consul_url
        .path_segments_mut()
        .map_err(|_| FusekiError::Internal {
            message: format!("Consul endpoint cannot be used as a base URL: {consul_endpoint}"),
        })?
        .pop_if_empty()
        .extend(["v1", "health", "service", "sparql"]);
    consul_url.query_pairs_mut().append_pair("passing", "true");
    let response = client
        .get(consul_url)
        .header("Accept", "application/json")
        .timeout(Duration::from_secs(10))
        .send()
        .await
        .map_err(|e| FusekiError::Internal {
            message: format!("Failed to query Consul: {e}"),
        })?;
    if !response.status().is_success() {
        return Err(FusekiError::Internal {
            message: format!("Consul query failed with status: {}", response.status()),
        });
    }
    let consul_services: serde_json::Value =
        response.json().await.map_err(|e| FusekiError::Internal {
            message: format!("Failed to parse Consul response: {e}"),
        })?;
    let entries = consul_services
        .as_array()
        .map(Vec::as_slice)
        .unwrap_or_default();
    let mut discovered = Vec::new();
    for entry in entries {
        let (Some(service_obj), Some(checks)) = (
            entry.get("Service").and_then(|s| s.as_object()),
            entry.get("Checks").and_then(|c| c.as_array()),
        ) else {
            continue;
        };
        let service_name = service_obj
            .get("Service")
            .and_then(|s| s.as_str())
            .unwrap_or("unknown");
        let service_id = service_obj
            .get("ID")
            .and_then(|id| id.as_str())
            .unwrap_or(service_name);
        let all_passing = checks
            .iter()
            .all(|check| check.get("Status").and_then(|s| s.as_str()) == Some("passing"));
        if !all_passing {
            tracing::debug!("Skipping unhealthy Consul service: {}", service_id);
            continue;
        }
        // Consul leaves `Service.Address` empty when the instance uses its node's address.
        let address = service_obj
            .get("Address")
            .and_then(|a| a.as_str())
            .filter(|a| !a.is_empty())
            .or_else(|| {
                entry
                    .get("Node")
                    .and_then(|node| node.get("Address"))
                    .and_then(|a| a.as_str())
                    .filter(|a| !a.is_empty())
            })
            .unwrap_or("localhost");
        // IPv6 literals must be bracketed to form a valid URL authority.
        let host = if address.contains(':') && !address.starts_with('[') {
            format!("[{address}]")
        } else {
            address.to_string()
        };
        let raw_port = service_obj
            .get("Port")
            .and_then(|p| p.as_u64())
            .unwrap_or(8080);
        let Ok(port) = u16::try_from(raw_port) else {
            tracing::warn!(
                "Invalid port {} for Consul service {}",
                raw_port,
                service_id
            );
            continue;
        };
        let tags: Vec<&str> = service_obj
            .get("Tags")
            .and_then(|t| t.as_array())
            .map(|values| values.iter().filter_map(|t| t.as_str()).collect())
            .unwrap_or_default();
        // Normalise so a tag like `sparql-path=query` still yields "/query".
        let sparql_path = format!(
            "/{}",
            tags.iter()
                .find_map(|tag| tag.strip_prefix("sparql-path="))
                .unwrap_or("/sparql")
                .trim_start_matches('/')
        );
        let service_url = match Url::parse(&format!("http://{host}:{port}{sparql_path}")) {
            Ok(url) => url,
            Err(e) => {
                tracing::warn!(
                    "Invalid URL for Consul service {}:{}:{}: {}",
                    service_id,
                    address,
                    port,
                    e
                );
                continue;
            }
        };
        match Self::check_service_health(&service_url, client).await {
            Ok(health) => {
                let consul_service_id = format!("consul-{service_id}");
                tracing::info!(
                    "Discovered SPARQL service via Consul: {} at {}",
                    consul_service_id,
                    service_url
                );
                let endpoint = ServiceEndpoint {
                    url: service_url,
                    metadata: ServiceMetadata {
                        name: format!("Consul Service: {service_name}"),
                        description: Some(format!(
                            "Discovered via Consul from {consul_endpoint}"
                        )),
                        version: tags
                            .iter()
                            .find_map(|tag| tag.strip_prefix("version="))
                            .map(str::to_string),
                        contact: None,
                        location: None,
                        tags: tags.iter().map(|t| t.to_string()).collect(),
                    },
                    health,
                    capabilities: ServiceCapabilities::default(),
                };
                discovered.push((consul_service_id, endpoint));
            }
            Err(e) => {
                tracing::warn!(
                    "Consul service {} at {} failed health check: {}",
                    service_id,
                    service_url,
                    e
                );
            }
        }
    }
    let discovered_count = discovered.len();
    // Lock only for the in-memory insert; all network I/O has already completed.
    endpoints.write().await.extend(discovered);
    tracing::info!(
        "Consul discovery completed: {} services discovered from {}",
        discovered_count,
        consul_endpoint
    );
    Ok(())
}
/// Placeholder for Kubernetes-based discovery; currently a no-op.
///
/// Nothing is queried and the endpoint map is not touched. The call only emits a `debug` log
/// naming `namespace` and always returns `Ok(())`.
async fn discover_via_kubernetes(
    namespace: &str,
    _shared_endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    _http: &Client,
) -> Result<()> {
    tracing::debug!("Kubernetes discovery not yet implemented for namespace: {namespace}");
    Ok(())
}
/// Probes `url` with a trivial `ASK` query, using a 5-second request timeout.
///
/// # Returns
/// * `Ok(ServiceHealth::Healthy)` for a 2xx response.
/// * `Ok(ServiceHealth::Unhealthy)` for any other HTTP status. The service is reachable but not
///   answering successfully.
///
/// # Errors
/// Returns [`FusekiError::Internal`] when no HTTP response is obtained (for example a connection
/// failure or timeout). Discovery callers use this to skip unreachable services.
async fn check_service_health(url: &Url, client: &Client) -> Result<ServiceHealth> {
    const ASK_ANY: &str = "ASK { ?s ?p ?o }";
    let outcome = client
        .get(url.as_str())
        .query(&[("query", ASK_ANY)])
        .header("Accept", "application/sparql-results+json")
        .timeout(Duration::from_secs(5))
        .send()
        .await;
    let reply = match outcome {
        Ok(reply) => reply,
        Err(e) => {
            tracing::warn!("Failed to reach service at {}: {}", url, e);
            return Err(FusekiError::Internal {
                message: format!("Service health check failed: {e}"),
            });
        }
    };
    let status = reply.status();
    if !status.is_success() {
        tracing::warn!("Service at {} returned status: {}", url, status);
        return Ok(ServiceHealth::Unhealthy);
    }
    tracing::debug!("Service at {} is healthy", url);
    Ok(ServiceHealth::Healthy)
}
/// Probes common SPARQL ports and paths on `domain`; used when no SRV records are available.
///
/// Ports 8080, 3030, 8000, 80 and 443 are tried in that order (443 over `https`, the rest over
/// `http`). On each port the paths `/sparql`, `/query` and `/sparql/query` are tried until one
/// answers `check_service_health`. That endpoint is stored as `fallback-<domain>:<port><path>`,
/// and probing moves on to the next port.
///
/// The probes run sequentially and all finish before the endpoint map is write-locked. The map
/// therefore stays readable while up to fifteen probes run. Unreachable probes are logged at
/// `debug` level. Never returns an error.
async fn discover_via_fallback_ports(
    domain: &str,
    endpoints: &Arc<RwLock<HashMap<String, ServiceEndpoint>>>,
    client: &Client,
) -> Result<()> {
    const COMMON_PORTS: [u16; 5] = [8080, 3030, 8000, 80, 443];
    const COMMON_PATHS: [&str; 3] = ["/sparql", "/query", "/sparql/query"];
    let mut discovered = Vec::new();
    for port in COMMON_PORTS {
        let scheme = if port == 443 { "https" } else { "http" };
        for path in COMMON_PATHS {
            let service_url = match Url::parse(&format!("{scheme}://{domain}:{port}{path}")) {
                Ok(url) => url,
                Err(e) => {
                    tracing::debug!(
                        "Invalid fallback URL for {}:{}{}: {}",
                        domain,
                        port,
                        path,
                        e
                    );
                    continue;
                }
            };
            match Self::check_service_health(&service_url, client).await {
                Ok(health) => {
                    let service_id = format!("fallback-{domain}:{port}{path}");
                    tracing::info!(
                        "Discovered SPARQL service via fallback: {} at {}",
                        service_id,
                        service_url
                    );
                    let endpoint = ServiceEndpoint {
                        url: service_url,
                        metadata: ServiceMetadata {
                            name: format!("SPARQL Service at {domain}:{port}{path}"),
                            description: Some(
                                "Discovered via fallback port scanning".to_string(),
                            ),
                            version: None,
                            contact: None,
                            location: None,
                            tags: vec![],
                        },
                        health,
                        capabilities: ServiceCapabilities::default(),
                    };
                    discovered.push((service_id, endpoint));
                    // One endpoint per port is enough; move on to the next port.
                    break;
                }
                Err(_) => {
                    tracing::debug!("No SPARQL service found at {}:{}{}", domain, port, path);
                }
            }
        }
    }
    let discovered_count = discovered.len();
    // Lock only for the in-memory insert; all network I/O has already completed.
    endpoints.write().await.extend(discovered);
    tracing::info!(
        "Fallback discovery completed: {} services discovered for domain {}",
        discovered_count,
        domain
    );
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A static registration is present in the shared map once `start` has returned, because
    /// `start` awaits the initial discovery pass.
    #[tokio::test]
    async fn test_static_discovery() {
        let shared = Arc::new(RwLock::new(HashMap::new()));
        let mut discovery =
            ServiceDiscovery::new(FederationConfig::default(), Arc::clone(&shared));
        discovery.add_method(DiscoveryMethod::Static(vec![ServiceRegistration {
            id: "test-service".to_string(),
            url: Url::parse("http://example.com/sparql").unwrap(),
            metadata: ServiceMetadata {
                name: "Test Service".to_string(),
                ..Default::default()
            },
        }]));
        discovery.start().await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        let registered = shared.read().await;
        assert_eq!(registered.len(), 1);
        assert!(registered.contains_key("test-service"));
        discovery.stop().await.unwrap();
    }
}