use crate::shard::{ShardId, ShardRouter};
use crate::shard_manager::ShardManager;
use crate::{ClusterError, Result};
use oxirs_core::model::Triple;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
#[derive(Debug, Clone)]
pub struct QueryPlan {
pub query_id: String,
pub shard_targets: Vec<ShardTarget>,
pub optimization_hints: QueryOptimizationHints,
pub estimated_cost: f64,
}
#[derive(Debug, Clone)]
pub struct ShardTarget {
pub shard_id: ShardId,
pub preferred_node: u64,
pub alternative_nodes: Vec<u64>,
pub selectivity: f64,
}
#[derive(Debug, Clone, Default)]
pub struct QueryOptimizationHints {
pub use_index: bool,
pub parallel_execution: bool,
pub limit: Option<usize>,
pub order_by: Option<String>,
pub enable_cache: bool,
pub timeout_ms: Option<u64>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RoutingStatistics {
pub total_queries: u64,
pub single_shard_queries: u64,
pub multi_shard_queries: u64,
pub avg_shards_per_query: f64,
pub cache_hit_rate: f64,
pub avg_latency_ms: f64,
}
pub struct QueryRouter {
shard_router: Arc<ShardRouter>,
shard_manager: Arc<ShardManager>,
query_cache: Arc<RwLock<QueryCache>>,
statistics: Arc<RwLock<RoutingStatistics>>,
cost_model: Arc<dyn CostModel>,
}
struct QueryCache {
entries: HashMap<String, CacheEntry>,
max_size: usize,
hits: u64,
misses: u64,
}
struct CacheEntry {
plan: QueryPlan,
#[allow(dead_code)]
timestamp: u64,
access_count: u64,
}
pub trait CostModel: Send + Sync {
fn estimate_shard_cost(&self, shard_id: ShardId, selectivity: f64) -> f64;
fn estimate_network_cost(&self, data_size: usize, hop_count: u32) -> f64;
fn estimate_merge_cost(&self, shard_count: usize, result_size: usize) -> f64;
}
pub struct DefaultCostModel;
impl CostModel for DefaultCostModel {
fn estimate_shard_cost(&self, _shard_id: ShardId, selectivity: f64) -> f64 {
10.0 + (100.0 * selectivity)
}
fn estimate_network_cost(&self, data_size: usize, hop_count: u32) -> f64 {
(hop_count as f64 * 5.0) + (data_size as f64 / 1_000_000.0)
}
fn estimate_merge_cost(&self, shard_count: usize, result_size: usize) -> f64 {
(shard_count as f64).log2() * (result_size as f64 / 1000.0)
}
}
impl QueryRouter {
pub fn new(shard_router: Arc<ShardRouter>, shard_manager: Arc<ShardManager>) -> Self {
Self {
shard_router,
shard_manager,
query_cache: Arc::new(RwLock::new(QueryCache {
entries: HashMap::new(),
max_size: 1000,
hits: 0,
misses: 0,
})),
statistics: Arc::new(RwLock::new(RoutingStatistics::default())),
cost_model: Arc::new(DefaultCostModel),
}
}
pub fn with_cost_model(mut self, cost_model: Arc<dyn CostModel>) -> Self {
self.cost_model = cost_model;
self
}
pub async fn plan_query(
&self,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
hints: QueryOptimizationHints,
) -> Result<QueryPlan> {
let query_id = self.generate_query_id(subject, predicate, object);
if hints.enable_cache {
if let Some(plan) = self.check_cache(&query_id).await {
return Ok(plan);
}
}
let shard_ids = self
.shard_router
.route_query_pattern(subject, predicate, object)
.await?;
{
let mut stats = self.statistics.write().await;
stats.total_queries += 1;
if shard_ids.len() == 1 {
stats.single_shard_queries += 1;
} else {
stats.multi_shard_queries += 1;
}
stats.avg_shards_per_query = (stats.avg_shards_per_query
* (stats.total_queries - 1) as f64
+ shard_ids.len() as f64)
/ stats.total_queries as f64;
}
let mut shard_targets = Vec::new();
let mut total_cost = 0.0;
for &shard_id in &shard_ids {
if let Some(metadata) = self.shard_router.get_shard_metadata(shard_id).await {
let selectivity = self.estimate_selectivity(subject, predicate, object);
let target = ShardTarget {
shard_id,
preferred_node: metadata.primary_node,
alternative_nodes: metadata
.node_ids
.iter()
.filter(|&&id| id != metadata.primary_node)
.copied()
.collect(),
selectivity,
};
let shard_cost = self.cost_model.estimate_shard_cost(shard_id, selectivity);
total_cost += shard_cost;
shard_targets.push(target);
}
}
if shard_targets.len() > 1 {
let merge_cost = self.cost_model.estimate_merge_cost(
shard_targets.len(),
1000, );
total_cost += merge_cost;
}
let plan = QueryPlan {
query_id: query_id.clone(),
shard_targets,
optimization_hints: hints,
estimated_cost: total_cost,
};
self.cache_plan(query_id, plan.clone()).await;
Ok(plan)
}
pub async fn plan_federated_query(
&self,
local_pattern: (Option<&str>, Option<&str>, Option<&str>),
remote_endpoints: Vec<String>,
hints: QueryOptimizationHints,
) -> Result<FederatedQueryPlan> {
let local_plan = self
.plan_query(
local_pattern.0,
local_pattern.1,
local_pattern.2,
hints.clone(),
)
.await?;
let mut remote_plans = Vec::new();
for endpoint in remote_endpoints {
remote_plans.push(RemoteQueryPlan {
endpoint: endpoint.clone(),
estimated_latency_ms: 100.0, estimated_result_size: 1000,
});
}
Ok(FederatedQueryPlan {
local_plan,
remote_plans,
merge_strategy: MergeStrategy::Union,
})
}
pub async fn optimize_plan(&self, plan: &mut QueryPlan) -> Result<()> {
for target in &mut plan.shard_targets {
if let Some(metadata) = self.shard_router.get_shard_metadata(target.shard_id).await {
if metadata.state != crate::shard::ShardState::Active {
if let Some(alt) = target.alternative_nodes.first() {
target.preferred_node = *alt;
warn!(
"Shard {} primary offline, using alternative node {}",
target.shard_id, alt
);
}
}
}
}
if plan.shard_targets.len() > 1 && plan.optimization_hints.parallel_execution {
debug!(
"Enabling parallel execution for {} shards",
plan.shard_targets.len()
);
}
Ok(())
}
pub async fn execute_plan(&self, plan: QueryPlan) -> Result<Vec<Triple>> {
let start_time = std::time::Instant::now();
let mut all_results = Vec::new();
if plan.optimization_hints.parallel_execution && plan.shard_targets.len() > 1 {
let mut handles = Vec::new();
for _target in plan.shard_targets {
let _shard_manager = self.shard_manager.clone();
let handle = tokio::spawn(async move {
Vec::<Triple>::new()
});
handles.push(handle);
}
for handle in handles {
let results = handle
.await
.map_err(|e| ClusterError::Runtime(format!("Query execution failed: {e}")))?;
all_results.extend(results);
}
} else {
for _target in plan.shard_targets {
}
}
if let Some(limit) = plan.optimization_hints.limit {
all_results.truncate(limit);
}
let latency_ms = start_time.elapsed().as_millis() as f64;
{
let mut stats = self.statistics.write().await;
stats.avg_latency_ms = (stats.avg_latency_ms * (stats.total_queries - 1) as f64
+ latency_ms)
/ stats.total_queries as f64;
}
Ok(all_results)
}
fn generate_query_id(
&self,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
) -> String {
format!("{subject:?}:{predicate:?}:{object:?}")
}
fn estimate_selectivity(
&self,
subject: Option<&str>,
predicate: Option<&str>,
object: Option<&str>,
) -> f64 {
let specificity = [subject.is_some(), predicate.is_some(), object.is_some()]
.iter()
.filter(|&&x| x)
.count();
match specificity {
3 => 0.01, 2 => 0.1, 1 => 0.5, _ => 1.0, }
}
async fn check_cache(&self, query_id: &str) -> Option<QueryPlan> {
let mut cache = self.query_cache.write().await;
if let Some(entry) = cache.entries.get_mut(query_id) {
entry.access_count += 1;
let plan = entry.plan.clone();
cache.hits += 1;
let hit_rate = cache.hits as f64 / (cache.hits + cache.misses) as f64;
self.statistics.write().await.cache_hit_rate = hit_rate;
Some(plan)
} else {
cache.misses += 1;
None
}
}
async fn cache_plan(&self, query_id: String, plan: QueryPlan) {
let mut cache = self.query_cache.write().await;
if cache.entries.len() >= cache.max_size {
if let Some((evict_id, _)) = cache
.entries
.iter()
.min_by_key(|(_, entry)| entry.access_count)
{
let evict_id = evict_id.clone();
cache.entries.remove(&evict_id);
}
}
cache.entries.insert(
query_id,
CacheEntry {
plan,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs(),
access_count: 1,
},
);
}
pub async fn get_statistics(&self) -> RoutingStatistics {
self.statistics.read().await.clone()
}
}
#[derive(Debug, Clone)]
pub struct FederatedQueryPlan {
pub local_plan: QueryPlan,
pub remote_plans: Vec<RemoteQueryPlan>,
pub merge_strategy: MergeStrategy,
}
#[derive(Debug, Clone)]
pub struct RemoteQueryPlan {
pub endpoint: String,
pub estimated_latency_ms: f64,
pub estimated_result_size: usize,
}
#[derive(Debug, Clone)]
pub enum MergeStrategy {
Union,
Intersection,
Join(Vec<String>),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::network::{NetworkConfig, NetworkService};
use crate::shard::ShardingStrategy;
use crate::shard_manager::ShardManagerConfig;
use crate::storage::mock::MockStorageBackend;
#[tokio::test]
async fn test_query_planning() {
let strategy = ShardingStrategy::Hash { num_shards: 4 };
let shard_router = Arc::new(ShardRouter::new(strategy));
shard_router.init_shards(4, 3).await.unwrap();
let storage = Arc::new(MockStorageBackend::new());
let network = Arc::new(NetworkService::new(1, NetworkConfig::default()));
let shard_manager = Arc::new(ShardManager::new(
1,
shard_router.clone(),
ShardManagerConfig::default(),
storage,
network,
));
let query_router = QueryRouter::new(shard_router, shard_manager);
let hints = QueryOptimizationHints {
parallel_execution: true,
enable_cache: true,
..Default::default()
};
let plan = query_router
.plan_query(Some("http://example.org/subject"), None, None, hints)
.await
.unwrap();
assert_eq!(plan.shard_targets.len(), 1);
assert!(plan.estimated_cost > 0.0);
}
#[test]
fn test_selectivity_estimation() {
let query_router = QueryRouter::new(
Arc::new(ShardRouter::new(ShardingStrategy::Hash { num_shards: 1 })),
Arc::new(ShardManager::new(
1,
Arc::new(ShardRouter::new(ShardingStrategy::Hash { num_shards: 1 })),
ShardManagerConfig::default(),
Arc::new(MockStorageBackend::new()),
Arc::new(NetworkService::new(1, NetworkConfig::default())),
)),
);
assert_eq!(
query_router.estimate_selectivity(Some("s"), Some("p"), Some("o")),
0.01
);
assert_eq!(
query_router.estimate_selectivity(Some("s"), Some("p"), None),
0.1
);
assert_eq!(
query_router.estimate_selectivity(Some("s"), None, None),
0.5
);
assert_eq!(query_router.estimate_selectivity(None, None, None), 1.0);
}
#[test]
fn test_cost_model() {
let cost_model = DefaultCostModel;
let shard_cost = cost_model.estimate_shard_cost(0, 0.5);
assert!(shard_cost > 0.0);
let network_cost = cost_model.estimate_network_cost(1_000_000, 2);
assert!(network_cost > 0.0);
let merge_cost = cost_model.estimate_merge_cost(4, 1000);
assert!(merge_cost > 0.0);
}
}