use crate::distributed::shard::{EdgeData, GraphShard, NodeData, NodeId, ShardId};
use crate::{GraphError, Result};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;
/// A plan describing how a query will be executed across shards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Unique plan id (a UUID v4 string assigned at planning time).
    pub query_id: String,
    /// The original query text the plan was derived from.
    pub query: String,
    /// Shards this plan will touch.
    pub target_shards: Vec<ShardId>,
    /// Ordered execution steps.
    pub steps: Vec<QueryStep>,
    /// Heuristic, unitless cost estimate (higher = more expensive).
    pub estimated_cost: f64,
    /// Whether the plan spans multiple shards (`plan_query` always sets `true`).
    pub is_distributed: bool,
    /// When the plan was created.
    pub created_at: DateTime<Utc>,
}
/// A single operation in a `QueryPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryStep {
    /// Scan all nodes on one shard, optionally narrowed by label/predicate.
    NodeScan {
        shard_id: ShardId,
        /// Keep only nodes carrying this label, if set.
        label: Option<String>,
        /// Free-form predicate string; semantics left to the executor.
        filter: Option<String>,
    },
    /// Scan all edges on one shard, optionally narrowed by edge type.
    EdgeScan {
        shard_id: ShardId,
        edge_type: Option<String>,
    },
    /// Join results from two shards on `join_key`.
    Join {
        left_shard: ShardId,
        right_shard: ShardId,
        join_key: String,
    },
    /// Aggregate results, optionally grouped by a property name.
    Aggregate {
        operation: AggregateOp,
        group_by: Option<String>,
    },
    /// Keep only rows matching the predicate.
    Filter { predicate: String },
    /// Sort by `key`, ascending or descending.
    Sort { key: String, ascending: bool },
    /// Truncate the result set to at most `count` entries.
    Limit { count: usize },
}
/// Aggregate operations; each `String` payload names the property aggregated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregateOp {
    Count,
    Sum(String),
    Avg(String),
    Min(String),
    Max(String),
}
/// The materialized output of executing a `QueryPlan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Id of the plan that produced this result.
    pub query_id: String,
    /// Nodes returned by the query.
    pub nodes: Vec<NodeData>,
    /// Edges returned by the query.
    pub edges: Vec<EdgeData>,
    /// Aggregate values keyed by aggregate name (e.g. "count").
    pub aggregates: HashMap<String, serde_json::Value>,
    /// Execution statistics for this run.
    pub stats: QueryStats,
}
/// Execution statistics attached to every `QueryResult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStats {
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Number of shards the plan targeted.
    pub shards_queried: usize,
    /// Total nodes read before any filtering was applied.
    pub nodes_scanned: usize,
    /// Total edges read before any filtering was applied.
    pub edges_scanned: usize,
    /// Whether this result was served from the query cache.
    pub cached: bool,
}
/// Coordinates query planning/execution and transaction bookkeeping across
/// a set of registered shards.
pub struct ShardCoordinator {
    /// Registered shards, keyed by shard id.
    shards: Arc<DashMap<ShardId, Arc<GraphShard>>>,
    /// Results cached by raw query string (no TTL or eviction in this file).
    query_cache: Arc<DashMap<String, QueryResult>>,
    /// In-flight transactions keyed by transaction id (UUID string).
    transactions: Arc<DashMap<String, Transaction>>,
}
impl ShardCoordinator {
    /// Creates a coordinator with no shards, an empty query cache, and no
    /// in-flight transactions.
    pub fn new() -> Self {
        Self {
            shards: Arc::new(DashMap::new()),
            query_cache: Arc::new(DashMap::new()),
            transactions: Arc::new(DashMap::new()),
        }
    }

    /// Registers `shard` under `shard_id`, replacing any previous entry.
    pub fn register_shard(&self, shard_id: ShardId, shard: Arc<GraphShard>) {
        info!("Registering shard {} with coordinator", shard_id);
        self.shards.insert(shard_id, shard);
    }

    /// Removes the shard from the coordinator.
    ///
    /// # Errors
    /// Returns `GraphError::ShardError` if no shard is registered under
    /// `shard_id`.
    pub fn unregister_shard(&self, shard_id: ShardId) -> Result<()> {
        info!("Unregistering shard {}", shard_id);
        self.shards
            .remove(&shard_id)
            .ok_or_else(|| GraphError::ShardError(format!("Shard {} not found", shard_id)))?;
        Ok(())
    }

    /// Looks up a shard handle by id.
    pub fn get_shard(&self, shard_id: ShardId) -> Option<Arc<GraphShard>> {
        self.shards.get(&shard_id).map(|s| Arc::clone(s.value()))
    }

    /// Returns the ids of all registered shards (unordered).
    pub fn list_shards(&self) -> Vec<ShardId> {
        self.shards.iter().map(|e| *e.key()).collect()
    }

    /// Builds a `QueryPlan` for `query`, targeting every registered shard.
    ///
    /// # Errors
    /// Propagates any error from step parsing.
    pub fn plan_query(&self, query: &str) -> Result<QueryPlan> {
        let query_id = Uuid::new_v4().to_string();
        let target_shards: Vec<ShardId> = self.list_shards();
        let steps = self.parse_query_steps(query)?;
        let estimated_cost = self.estimate_cost(&steps, &target_shards);
        Ok(QueryPlan {
            query_id,
            query: query.to_string(),
            target_shards,
            steps,
            estimated_cost,
            is_distributed: true,
            created_at: Utc::now(),
        })
    }

    /// Minimal keyword-based "parser": recognizes `MATCH` (full node scan on
    /// every shard), `COUNT` (count aggregate) and `LIMIT <n>`.
    fn parse_query_steps(&self, query: &str) -> Result<Vec<QueryStep>> {
        let mut steps = Vec::new();
        // Lowercase once instead of once per keyword check.
        let lowered = query.to_lowercase();
        if lowered.contains("match") {
            for shard_id in self.list_shards() {
                steps.push(QueryStep::NodeScan {
                    shard_id,
                    label: None,
                    filter: None,
                });
            }
        }
        if lowered.contains("count") {
            steps.push(QueryStep::Aggregate {
                operation: AggregateOp::Count,
                group_by: None,
            });
        }
        if let Some(limit_pos) = lowered.find("limit") {
            // Take the token right after "limit"; a malformed count is
            // silently ignored (best-effort parsing, as before).
            if let Some(count_str) = query[limit_pos..].split_whitespace().nth(1) {
                if let Ok(count) = count_str.parse::<usize>() {
                    steps.push(QueryStep::Limit { count });
                }
            }
        }
        Ok(steps)
    }

    /// Sums fixed per-step weights, then scales by shard count.
    ///
    /// NOTE(review): `NodeScan` steps are already emitted once per shard by
    /// `parse_query_steps`, so multiplying by `target_shards.len()` may
    /// double-count scans — confirm the intended cost model.
    fn estimate_cost(&self, steps: &[QueryStep], target_shards: &[ShardId]) -> f64 {
        let mut cost = 0.0;
        for step in steps {
            match step {
                QueryStep::NodeScan { .. } => cost += 10.0,
                QueryStep::EdgeScan { .. } => cost += 15.0,
                QueryStep::Join { .. } => cost += 50.0,
                QueryStep::Aggregate { .. } => cost += 20.0,
                QueryStep::Filter { .. } => cost += 5.0,
                QueryStep::Sort { .. } => cost += 30.0,
                QueryStep::Limit { .. } => cost += 1.0,
            }
        }
        cost * target_shards.len() as f64
    }

    /// Executes `plan`, serving from the query cache when possible.
    ///
    /// # Errors
    /// Currently infallible in practice (missing shards are skipped), but the
    /// signature allows future step executors to fail.
    pub async fn execute_query(&self, plan: QueryPlan) -> Result<QueryResult> {
        let start = std::time::Instant::now();
        info!(
            "Executing query {} across {} shards",
            plan.query_id,
            plan.target_shards.len()
        );
        if let Some(cached) = self.query_cache.get(&plan.query) {
            debug!("Query cache hit for: {}", plan.query);
            let mut result = cached.value().clone();
            // Bug fix: mark cache hits as cached so stats are accurate
            // (previously the stored `cached: false` was returned verbatim).
            result.stats.cached = true;
            return Ok(result);
        }
        let mut nodes = Vec::new();
        let mut edges = Vec::new();
        let mut aggregates = HashMap::new();
        let mut nodes_scanned = 0;
        let mut edges_scanned = 0;
        for step in &plan.steps {
            match step {
                QueryStep::NodeScan {
                    shard_id,
                    label,
                    filter: _, // predicate filtering is not implemented yet
                } => {
                    // A shard that disappeared since planning is skipped.
                    if let Some(shard) = self.get_shard(*shard_id) {
                        let shard_nodes = shard.list_nodes();
                        nodes_scanned += shard_nodes.len();
                        let filtered: Vec<_> = if let Some(label_filter) = label {
                            shard_nodes
                                .into_iter()
                                .filter(|n| n.labels.contains(label_filter))
                                .collect()
                        } else {
                            shard_nodes
                        };
                        nodes.extend(filtered);
                    }
                }
                QueryStep::EdgeScan {
                    shard_id,
                    edge_type,
                } => {
                    if let Some(shard) = self.get_shard(*shard_id) {
                        let shard_edges = shard.list_edges();
                        edges_scanned += shard_edges.len();
                        let filtered: Vec<_> = if let Some(type_filter) = edge_type {
                            shard_edges
                                .into_iter()
                                .filter(|e| &e.edge_type == type_filter)
                                .collect()
                        } else {
                            shard_edges
                        };
                        edges.extend(filtered);
                    }
                }
                QueryStep::Aggregate {
                    operation,
                    group_by: _, // grouping is not implemented yet
                } => match operation {
                    AggregateOp::Count => {
                        aggregates.insert(
                            "count".to_string(),
                            serde_json::Value::Number(nodes.len().into()),
                        );
                    }
                    // Property-based aggregates need per-node property access,
                    // which is not implemented at the coordinator level.
                    AggregateOp::Sum(_)
                    | AggregateOp::Avg(_)
                    | AggregateOp::Min(_)
                    | AggregateOp::Max(_) => {}
                },
                QueryStep::Limit { count } => {
                    // Note: only nodes are truncated; edges are left intact.
                    nodes.truncate(*count);
                }
                // Spelled out (instead of `_`) so adding a QueryStep variant
                // forces this match to be revisited.
                QueryStep::Join { .. } | QueryStep::Filter { .. } | QueryStep::Sort { .. } => {}
            }
        }
        let execution_time_ms = start.elapsed().as_millis() as u64;
        let result = QueryResult {
            query_id: plan.query_id.clone(),
            nodes,
            edges,
            aggregates,
            stats: QueryStats {
                execution_time_ms,
                shards_queried: plan.target_shards.len(),
                nodes_scanned,
                edges_scanned,
                cached: false,
            },
        };
        self.query_cache.insert(plan.query.clone(), result.clone());
        info!(
            "Query {} completed in {}ms",
            plan.query_id, execution_time_ms
        );
        Ok(result)
    }

    /// Starts a new transaction and returns its id.
    pub fn begin_transaction(&self) -> String {
        let tx_id = Uuid::new_v4().to_string();
        let transaction = Transaction::new(tx_id.clone());
        self.transactions.insert(tx_id.clone(), transaction);
        info!("Started transaction: {}", tx_id);
        tx_id
    }

    /// Commits (currently: just removes) the transaction.
    ///
    /// # Errors
    /// Returns `GraphError::CoordinatorError` if the transaction id is unknown.
    pub async fn commit_transaction(&self, tx_id: &str) -> Result<()> {
        // `remove` returns the entry; we only need to know it existed.
        if self.transactions.remove(tx_id).is_some() {
            info!("Committing transaction: {}", tx_id);
            Ok(())
        } else {
            Err(GraphError::CoordinatorError(format!(
                "Transaction not found: {}",
                tx_id
            )))
        }
    }

    /// Rolls back (currently: just removes) the transaction.
    ///
    /// # Errors
    /// Returns `GraphError::CoordinatorError` if the transaction id is unknown.
    pub async fn rollback_transaction(&self, tx_id: &str) -> Result<()> {
        if self.transactions.remove(tx_id).is_some() {
            warn!("Rolling back transaction: {}", tx_id);
            Ok(())
        } else {
            Err(GraphError::CoordinatorError(format!(
                "Transaction not found: {}",
                tx_id
            )))
        }
    }

    /// Drops every cached query result.
    pub fn clear_cache(&self) {
        self.query_cache.clear();
        info!("Query cache cleared");
    }
}

/// `new()` takes no arguments, so `Default` comes for free (clippy
/// `new_without_default`).
impl Default for ShardCoordinator {
    fn default() -> Self {
        Self::new()
    }
}
/// A coordinator-side transaction record.
///
/// NOTE(review): within this file only `id` is ever written via `new`; the
/// other fields are never read — presumably consumed by a commit/abort flow
/// elsewhere. Confirm before removing.
#[derive(Debug, Clone)]
struct Transaction {
    // Transaction id (UUID string, also the map key in ShardCoordinator).
    id: String,
    // Shards enlisted in this transaction.
    shards: HashSet<ShardId>,
    // Current lifecycle state.
    state: TransactionState,
    // Creation timestamp.
    created_at: DateTime<Utc>,
}
impl Transaction {
fn new(id: String) -> Self {
Self {
id,
shards: HashSet::new(),
state: TransactionState::Active,
created_at: Utc::now(),
}
}
}
/// Lifecycle states for a `Transaction`.
///
/// NOTE(review): `Preparing` suggests a two-phase-commit flow, but no code in
/// this file transitions out of `Active` — confirm against the commit path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionState {
    Active,
    Preparing,
    Committed,
    Aborted,
}
/// High-level entry point wrapping a `ShardCoordinator` plus configuration.
pub struct Coordinator {
    // Shared shard coordinator; handed out via `shard_coordinator()`.
    shard_coordinator: Arc<ShardCoordinator>,
    // NOTE(review): stored and exposed via `config()`, but never consulted
    // during execution in this file (e.g. `enable_cache` is not checked).
    config: CoordinatorConfig,
}
impl Coordinator {
    /// Builds a coordinator around a fresh, empty `ShardCoordinator`.
    pub fn new(config: CoordinatorConfig) -> Self {
        let shard_coordinator = Arc::new(ShardCoordinator::new());
        Self {
            shard_coordinator,
            config,
        }
    }

    /// Hands out a shared handle to the underlying shard coordinator.
    pub fn shard_coordinator(&self) -> Arc<ShardCoordinator> {
        Arc::clone(&self.shard_coordinator)
    }

    /// Plans and then executes `query` against all registered shards.
    pub async fn execute(&self, query: &str) -> Result<QueryResult> {
        let coordinator = &self.shard_coordinator;
        let plan = coordinator.plan_query(query)?;
        coordinator.execute_query(plan).await
    }

    /// Read-only access to this coordinator's configuration.
    pub fn config(&self) -> &CoordinatorConfig {
        &self.config
    }
}
/// Tunables for the `Coordinator`.
///
/// NOTE(review): none of these settings are enforced anywhere in this file —
/// confirm they are read by the execution path or callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Enable the query result cache.
    pub enable_cache: bool,
    /// Cache entry time-to-live, in seconds.
    pub cache_ttl_seconds: u64,
    /// Maximum allowed query runtime, in seconds.
    pub max_query_time_seconds: u64,
    /// Enable query plan optimization.
    pub enable_optimization: bool,
}
impl Default for CoordinatorConfig {
fn default() -> Self {
Self {
enable_cache: true,
cache_ttl_seconds: 300,
max_query_time_seconds: 60,
enable_optimization: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::shard::ShardMetadata;
    use crate::distributed::shard::ShardStrategy;

    /// Builds the single hash-sharded `GraphShard` used by the tests below.
    fn sample_shard() -> Arc<GraphShard> {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        Arc::new(GraphShard::new(metadata))
    }

    #[tokio::test]
    async fn test_shard_coordinator() {
        let coordinator = ShardCoordinator::new();
        coordinator.register_shard(0, sample_shard());
        // Registration is visible through both listing and direct lookup.
        assert_eq!(coordinator.list_shards().len(), 1);
        assert!(coordinator.get_shard(0).is_some());
    }

    #[tokio::test]
    async fn test_query_planning() {
        let coordinator = ShardCoordinator::new();
        coordinator.register_shard(0, sample_shard());
        let plan = coordinator.plan_query("MATCH (n:Person) RETURN n").unwrap();
        // A MATCH query against one shard must yield an id and at least one step.
        assert!(!plan.query_id.is_empty());
        assert!(!plan.steps.is_empty());
    }

    #[tokio::test]
    async fn test_transaction() {
        let coordinator = ShardCoordinator::new();
        let tx = coordinator.begin_transaction();
        assert!(!tx.is_empty());
        // Committing a freshly started transaction succeeds.
        coordinator.commit_transaction(&tx).await.unwrap();
    }
}