use async_graphql::{Context, Result, Object, SimpleObject, Enum, InputObject};
use std::sync::Arc;
use tokio::sync::RwLock;
use std::pin::Pin;
use std::future::Future;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use crate::graphql_subscriptions::UserContext;
/// Result of planning/optimizing a single SQL query.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct QueryPlan {
/// Planner-assigned id; `execute_query` copies it onto the returned result.
pub query_id: String,
pub original_query: String,
/// Rewritten form of the query, if the planner produced one.
pub optimized_query: Option<String>,
pub execution_plan: ExecutionPlan,
/// Planner cost estimate (units are planner-defined — TODO confirm scale).
pub estimated_cost: f64,
pub estimated_rows: u64,
pub optimization_suggestions: Vec<OptimizationSuggestion>,
pub created_at: DateTime<Utc>,
}
/// Ordered set of steps the engine would take to run a query.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct ExecutionPlan {
pub steps: Vec<ExecutionStep>,
/// Whether independent steps may run concurrently.
pub parallelizable: bool,
pub estimated_duration_ms: u64,
}
/// One unit of work within an `ExecutionPlan`.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct ExecutionStep {
pub step_type: StepType,
pub description: String,
pub estimated_cost: f64,
pub estimated_rows: u64,
/// Steps that must complete first — presumably indices into
/// `ExecutionPlan::steps`; confirm against the planner implementation.
pub dependencies: Vec<usize>,
}
/// Kind of work performed by an `ExecutionStep`.
#[derive(Enum, Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub enum StepType {
TableScan,
IndexScan,
Filter,
Sort,
Aggregate,
Join,
Limit,
Projection,
}
/// A single planner recommendation for improving a query.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct OptimizationSuggestion {
pub suggestion_type: SuggestionType,
pub description: String,
pub impact: ImpactLevel,
/// Expected improvement if applied — units are planner-defined (TODO confirm
/// whether this is a ratio, percentage, or cost delta).
pub estimated_improvement: f64,
}
/// Category of an `OptimizationSuggestion`.
#[derive(Enum, Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub enum SuggestionType {
AddIndex,
RewriteQuery,
UseLimit,
AvoidSelectStar,
OptimizeJoin,
CacheResult,
}
/// Severity/benefit level of an optimization suggestion.
#[derive(Enum, Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImpactLevel {
Low,
Medium,
High,
Critical,
}
/// Per-execution metrics record. Internal only: unlike the other result
/// types it does not derive `SimpleObject`, so it is not exposed via GraphQL.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct QueryMetrics {
pub query_id: String,
pub execution_time_ms: u64,
pub rows_returned: u64,
pub cache_hit: bool,
pub timestamp: DateTime<Utc>,
/// Error message when the execution failed, `None` on success.
pub error: Option<String>,
}
/// A query whose execution crossed the slow threshold (see the metrics
/// collector, which records queries taking longer than 1 second).
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct SlowQuery {
pub query: String,
/// Whole seconds the query took (truncated, not rounded).
pub execution_time_secs: u64,
pub timestamp: DateTime<Utc>,
/// How often this query was observed as slow.
pub frequency: u32,
pub optimization_applied: bool,
}
/// GraphQL input describing one query in a batch request.
#[derive(InputObject, Serialize, Deserialize, Clone, Debug)]
pub struct BatchQuery {
pub query: String,
pub parameters: Option<Vec<serde_json::Value>>,
/// Cache TTL in seconds — NOTE(review): this field is currently not read by
/// `execute_batch_queries`; batch results are not cached.
pub cache_ttl: Option<i32>,
}
/// Outcome of executing a single query.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug, Default)]
pub struct QueryResult {
/// Result rows as raw JSON values.
pub rows: Vec<serde_json::Value>,
pub affected_rows: u64,
pub execution_time_ms: u64,
/// True when the result was served from the cache.
pub cached: bool,
pub query_id: String,
}
/// Aggregated statistics reported by the `query_statistics` resolver.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct QueryStatistics {
pub queries_executed: u64,
/// Fraction of cache lookups that hit, in [0, 1].
pub cache_hit_rate: f64,
pub avg_execution_time_ms: f64,
pub slow_queries: Vec<SlowQuery>,
pub most_frequent_queries: Vec<QueryFrequency>,
}
/// How often a normalized query pattern has run and its average latency.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct QueryFrequency {
/// Whitespace-collapsed, lowercased query text used as the grouping key.
pub query_pattern: String,
pub count: u32,
pub avg_execution_time_ms: f64,
}
/// Async result cache. Methods return boxed futures (rather than being
/// `async fn`) so the trait stays object-safe — it is used as `Arc<dyn Cache>`.
pub trait Cache: Send + Sync {
/// Look up a cached result by key; `None` on miss.
fn get<'a>(&'a self, key: &'a str) -> Pin<Box<dyn Future<Output = Option<QueryResult>> + Send + 'a>>;
/// Store `value` under `key` for at most `ttl`.
fn set<'a>(&'a self, key: &'a str, value: &'a QueryResult, ttl: std::time::Duration) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Remove a single entry.
fn invalidate<'a>(&'a self, key: &'a str) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Remove all entries.
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Snapshot of hit/miss/size counters.
fn stats<'a>(&'a self) -> Pin<Box<dyn Future<Output = CacheStats> + Send + 'a>>;
}
/// Counters reported by a `Cache` implementation.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct CacheStats {
pub hits: u64,
pub misses: u64,
/// Current number of entries — presumably entry count rather than bytes;
/// confirm against the cache implementation.
pub size: usize,
pub max_size: usize,
}
/// Abstraction over a database connection pool. Boxed-future methods keep
/// the trait object-safe for use as `Arc<dyn ConnectionPool>`.
pub trait ConnectionPool: Send + Sync {
/// Run one parameterized query.
fn execute<'a>(&'a self, query: &'a str, params: &'a [serde_json::Value]) -> Pin<Box<dyn Future<Output = Result<QueryResult>> + Send + 'a>>;
/// Run several queries; result order is implementation-defined here.
fn execute_batch<'a>(&'a self, queries: &'a [BatchQuery]) -> Pin<Box<dyn Future<Output = Result<Vec<QueryResult>>> + Send + 'a>>;
/// Snapshot of pool occupancy counters.
fn get_stats<'a>(&'a self) -> Pin<Box<dyn Future<Output = ConnectionPoolStats> + Send + 'a>>;
}
/// Occupancy counters for a `ConnectionPool` (internal only — no
/// `SimpleObject` derive, so not exposed via GraphQL).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ConnectionPoolStats {
pub active_connections: u32,
pub idle_connections: u32,
pub total_connections: u32,
pub max_connections: u32,
}
/// Query planning/optimization backend. Boxed-future methods keep the trait
/// object-safe for use as `Arc<dyn QueryPlanner>`.
pub trait QueryPlanner: Send + Sync {
/// Produce a (possibly rewritten) plan for `query` with optional parameters.
fn optimize<'a>(&'a self, query: &'a str, params: &'a Option<Vec<serde_json::Value>>) -> Pin<Box<dyn Future<Output = Result<QueryPlan>> + Send + 'a>>;
/// Produce an execution plan without rewriting the query.
fn explain<'a>(&'a self, query: &'a str) -> Pin<Box<dyn Future<Output = Result<ExecutionPlan>> + Send + 'a>>;
/// Static validation; an empty vector means the query is valid.
fn validate<'a>(&'a self, query: &'a str) -> Pin<Box<dyn Future<Output = Result<Vec<ValidationError>>> + Send + 'a>>;
}
/// One finding from `QueryPlanner::validate`.
#[derive(SimpleObject, Serialize, Deserialize, Clone, Debug)]
pub struct ValidationError {
pub message: String,
pub severity: ValidationSeverity,
/// Source position within the query text, when the planner can provide it.
pub line: Option<u32>,
pub column: Option<u32>,
}
/// Severity of a `ValidationError`.
#[derive(Enum, Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationSeverity {
Error,
Warning,
Info,
}
/// GraphQL object that executes SQL through a pooled connection with
/// plan optimization, result caching, and metrics collection. All fields
/// are `Arc`s, so cloning the executor is cheap (see the manual `Clone` impl).
pub struct OptimizedQueryExecutor {
cache: Arc<dyn Cache>,
query_planner: Arc<dyn QueryPlanner>,
connection_pool: Arc<dyn ConnectionPool>,
/// Shared mutable metrics; `RwLock` because reads (statistics) dominate.
metrics: Arc<RwLock<QueryMetricsCollector>>,
}
/// In-memory accumulator for execution and cache metrics.
#[derive(Default)]
struct QueryMetricsCollector {
queries_executed: u64,
cache_hits: u64,
cache_misses: u64,
/// Running average latency over all executed queries.
avg_execution_time: std::time::Duration,
/// Recent slow queries (the collector caps this list at 100 entries).
slow_queries: Vec<SlowQuery>,
/// Per-pattern frequency stats, keyed by normalized query text.
query_frequencies: HashMap<String, QueryFrequency>,
}
impl QueryMetricsCollector {
    /// Record one completed query: update the running average latency, track
    /// slow queries (over 1 second), and update per-pattern frequency stats.
    fn update_execution_metrics(&mut self, execution_time: std::time::Duration, query: &str) {
        self.queries_executed += 1;
        // Incremental running mean: avg_n = (avg_{n-1} * (n-1) + x) / n.
        // Computed in f64 seconds to avoid the Duration * u32 truncation the
        // previous `queries_executed as u32` cast suffered from.
        let n = self.queries_executed as f64;
        let total_secs =
            self.avg_execution_time.as_secs_f64() * (n - 1.0) + execution_time.as_secs_f64();
        self.avg_execution_time = std::time::Duration::from_secs_f64(total_secs / n);
        if execution_time > std::time::Duration::from_secs(1) {
            self.slow_queries.push(SlowQuery {
                query: query.to_string(),
                // Struct field is whole seconds (was previously initialized
                // with a Duration under the wrong field name).
                execution_time_secs: execution_time.as_secs(),
                timestamp: Utc::now(),
                frequency: 1,
                optimization_applied: false,
            });
            // Keep only the 100 most recent slow queries.
            if self.slow_queries.len() > 100 {
                self.slow_queries.remove(0);
            }
        }
        // Group by the normalized pattern so formatting differences collapse
        // into one entry (the pattern, not the raw query, is stored).
        let query_pattern = self.normalize_query(query);
        let frequency = self
            .query_frequencies
            .entry(query_pattern.clone())
            .or_insert_with(|| QueryFrequency {
                query_pattern,
                count: 0,
                avg_execution_time_ms: 0.0,
            });
        frequency.count += 1;
        let total_ms = frequency.avg_execution_time_ms * (frequency.count - 1) as f64
            + execution_time.as_millis() as f64;
        frequency.avg_execution_time_ms = total_ms / frequency.count as f64;
    }
    /// Collapse runs of whitespace to single spaces and lowercase, so
    /// textually different but equivalent queries share one pattern key.
    /// (`split_whitespace` already treats every whitespace char as a
    /// separator, so no pre-mapping pass is needed.)
    fn normalize_query(&self, query: &str) -> String {
        query
            .split_whitespace()
            .collect::<Vec<&str>>()
            .join(" ")
            .to_lowercase()
    }
    /// Count a cache hit.
    fn update_cache_hit_metrics(&mut self) {
        self.cache_hits += 1;
    }
    /// Count a cache miss.
    fn update_cache_miss_metrics(&mut self) {
        self.cache_misses += 1;
    }
    /// Hit rate in [0, 1]; 0.0 before any lookup has been recorded.
    fn get_cache_hit_rate(&self) -> f64 {
        let total_requests = self.cache_hits + self.cache_misses;
        if total_requests == 0 {
            0.0
        } else {
            self.cache_hits as f64 / total_requests as f64
        }
    }
}
#[Object]
impl OptimizedQueryExecutor {
async fn execute_query(
&self,
ctx: &Context<'_>,
#[graphql(desc = "SQL query to execute")] query: String,
#[graphql(desc = "Query parameters")] parameters: Option<Vec<serde_json::Value>>,
#[graphql(desc = "Cache TTL in seconds")] cache_ttl: Option<i32>
) -> Result<QueryResult> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
self::validate_query_permissions(user_context, &query)?;
let cache_key = self::generate_cache_key(&query, ¶meters);
if let Some(cached_result) = self.cache.get(&cache_key).await {
self::update_cache_hit_metrics(&self.metrics).await;
return Ok(cached_result);
}
self::update_cache_miss_metrics(&self.metrics).await;
let optimized_plan = Pin::from(self.query_planner.optimize(&query, ¶meters)).await?;
let start_time = std::time::Instant::now();
let mut result = Pin::from(self.connection_pool.execute(&query, ¶meters.as_ref().unwrap_or(&vec![]))).await?;
let execution_time = start_time.elapsed();
self::update_execution_metrics(&self.metrics, execution_time, &query).await;
result.query_id = optimized_plan.query_id.clone();
result.execution_time_ms = execution_time.as_millis() as u64;
let ttl = std::time::Duration::from_secs(cache_ttl.unwrap_or(300) as u64);
self.cache.set(&cache_key, &result, ttl).await;
Ok(result)
}
async fn execute_batch_queries(
&self,
ctx: &Context<'_>,
#[graphql(desc = "List of queries to execute")] queries: Vec<BatchQuery>
) -> Result<Vec<QueryResult>> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
for batch_query in &queries {
self::validate_query_permissions(user_context, &batch_query.query)?;
}
let mut read_queries = Vec::new();
let mut write_queries = Vec::new();
for (i, batch_query) in queries.into_iter().enumerate() {
if self::is_read_query(&batch_query.query) {
read_queries.push((i, batch_query));
} else {
write_queries.push((i, batch_query));
}
}
let read_futures: Vec<_> = read_queries.into_iter()
.map(|(index, batch_query)| {
let executor = self.clone();
async move {
let result = executor.execute_single_query(&batch_query.query, &batch_query.parameters).await?;
Ok((index, result))
}
})
.collect();
let read_results = futures::future::try_join_all(read_futures).await?;
let mut write_results = Vec::new();
for (index, batch_query) in write_queries {
let result = self.execute_single_query(&batch_query.query, &batch_query.parameters).await?;
write_results.push((index, result));
}
let mut combined_results = vec![QueryResult::default(); read_results.len() + write_results.len()];
for (index, result) in read_results {
combined_results[index] = result;
}
for (index, result) in write_results {
combined_results[index] = result;
}
Ok(combined_results)
}
async fn query_statistics(&self, ctx: &Context<'_>) -> Result<QueryStatistics> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("query.statistics") {
return Err(async_graphql::Error::new("Insufficient permissions for query statistics"));
}
let metrics = self.metrics.read().await;
Ok(QueryStatistics {
queries_executed: metrics.queries_executed,
cache_hit_rate: metrics.get_cache_hit_rate(),
avg_execution_time_ms: metrics.avg_execution_time.as_millis() as f64,
slow_queries: metrics.slow_queries.clone(),
most_frequent_queries: metrics.query_frequencies.values().cloned().collect(),
})
}
async fn optimize_query(
&self,
ctx: &Context<'_>,
#[graphql(desc = "SQL query to optimize")] query: String,
#[graphql(desc = "Query parameters")] parameters: Option<Vec<serde_json::Value>>
) -> Result<QueryPlan> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("query.optimize") {
return Err(async_graphql::Error::new("Insufficient permissions for query optimization"));
}
let plan = Pin::from(self.query_planner.optimize(&query, ¶meters)).await?;
Ok(plan)
}
async fn explain_query(
&self,
ctx: &Context<'_>,
#[graphql(desc = "SQL query to explain")] query: String
) -> Result<ExecutionPlan> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("query.explain") {
return Err(async_graphql::Error::new("Insufficient permissions for query explanation"));
}
let plan = Pin::from(self.query_planner.explain(&query)).await?;
Ok(plan)
}
async fn validate_query(
&self,
ctx: &Context<'_>,
#[graphql(desc = "SQL query to validate")] query: String
) -> Result<Vec<ValidationError>> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("query.validate") {
return Err(async_graphql::Error::new("Insufficient permissions for query validation"));
}
let errors = Pin::from(self.query_planner.validate(&query)).await?;
Ok(errors)
}
async fn clear_cache(&self, ctx: &Context<'_>) -> Result<bool> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("cache.clear") {
return Err(async_graphql::Error::new("Insufficient permissions to clear cache"));
}
self.cache.clear().await;
Ok(true)
}
async fn cache_statistics(&self, ctx: &Context<'_>) -> Result<CacheStats> {
let user_context = ctx.data::<UserContext>()
.map_err(|_| async_graphql::Error::new("User context not available"))?;
if !user_context.has_permission("cache.stats") {
return Err(async_graphql::Error::new("Insufficient permissions for cache statistics"));
}
let stats = Pin::from(self.cache.stats()).await;
Ok(stats)
}
let parsed_query = self::parse_query(query)?;
for table in &parsed_query.tables {
if !user_context.has_table_permission(table, &parsed_query.operation) {
return Err(async_graphql::Error::new(format!("Access denied to table: {}", table)));
}
}
if parsed_query.is_admin_operation && !user_context.has_permission("admin.query") {
return Err(async_graphql::Error::new("Admin privileges required for this query"));
}
Ok(())
}
fn parse_query(&self, query: &str) -> Result<ParsedQuery> {
let query_lower = query.to_lowercase();
let operation = if query_lower.contains("select") {
"SELECT".to_string()
} else if query_lower.contains("insert") {
"INSERT".to_string()
} else if query_lower.contains("update") {
"UPDATE".to_string()
} else if query_lower.contains("delete") {
"DELETE".to_string()
} else {
return Err(async_graphql::Error::new("Unsupported query type"));
};
let mut tables = Vec::new();
if query_lower.contains("from") {
let words: Vec<&str> = query_lower.split_whitespace().collect();
for (i, word) in words.iter().enumerate() {
if *word == "from" && i + 1 < words.len() {
tables.push(words[i + 1].to_string());
}
}
}
let is_admin_operation = query_lower.contains("drop") ||
query_lower.contains("create") ||
query_lower.contains("alter") ||
query_lower.contains("truncate");
Ok(ParsedQuery {
operation,
tables,
is_admin_operation,
})
}
fn generate_cache_key(&self, query: &str, parameters: &Option<Vec<serde_json::Value>>) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
query.hash(&mut hasher);
if let Some(params) = parameters {
for param in params {
param.to_string().hash(&mut hasher);
}
}
format!("query_{:x}", hasher.finish())
}
async fn execute_single_query(&self, query: &str, parameters: &Option<Vec<serde_json::Value>>) -> Result<QueryResult> {
let start_time = std::time::Instant::now();
let result = Pin::from(self.connection_pool.execute(query, ¶meters.as_ref().unwrap_or(&vec![]))).await?;
let execution_time = start_time.elapsed();
let mut final_result = result;
final_result.execution_time_ms = execution_time.as_millis() as u64;
Ok(final_result)
}
fn is_read_query(&self, query: &str) -> bool {
let query_lower = query.to_lowercase();
query_lower.contains("select") || query_lower.contains("show") || query_lower.contains("describe")
}
async fn update_cache_hit_metrics(&self, metrics: &Arc<RwLock<QueryMetricsCollector>>) {
let mut metrics = metrics.write().await;
metrics.update_cache_hit_metrics();
}
async fn update_cache_miss_metrics(&self, metrics: &Arc<RwLock<QueryMetricsCollector>>) {
let mut metrics = metrics.write().await;
metrics.update_cache_miss_metrics();
}
async fn update_execution_metrics(&self, metrics: &Arc<RwLock<QueryMetricsCollector>>, execution_time: std::time::Duration, query: &str) {
let mut metrics = metrics.write().await;
metrics.update_execution_metrics(execution_time, query);
}
}
impl Clone for OptimizedQueryExecutor {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
query_planner: self.query_planner.clone(),
connection_pool: self.connection_pool.clone(),
metrics: self.metrics.clone(),
}
}
}
/// Minimal parse of a SQL string, used only for permission checks.
#[derive(Debug, Clone)]
struct ParsedQuery {
/// Uppercase verb: "SELECT" | "INSERT" | "UPDATE" | "DELETE".
operation: String,
/// Lowercased table names found after `FROM` keywords.
tables: Vec<String>,
/// True when the text contains drop/create/alter/truncate.
is_admin_operation: bool,
}
impl UserContext {
    /// True when the user may perform `operation` on `table`: matched against
    /// the specific `table:<op>:<table>` permission, either wildcard form,
    /// the full wildcard, or the `admin` role.
    pub fn has_table_permission(&self, table: &str, operation: &str) -> bool {
        let op = operation.to_lowercase();
        let candidates = [
            format!("table:{}:{}", op, table),
            format!("table:*:{}", table),
            format!("table:{}:*", op),
            String::from("table:*:*"),
        ];
        candidates.iter().any(|perm| self.has_permission(perm)) || self.has_role("admin")
    }
}
/// Increment the shared cache-hit counter.
async fn update_cache_hit_metrics(metrics: &Arc<RwLock<QueryMetricsCollector>>) {
    metrics.write().await.cache_hits += 1;
}
/// Increment the shared cache-miss counter.
async fn update_cache_miss_metrics(metrics: &Arc<RwLock<QueryMetricsCollector>>) {
    metrics.write().await.cache_misses += 1;
}
/// Record one completed execution in the shared collector: update the running
/// average latency and track slow queries (over 1 second), capped at 100
/// entries to match `QueryMetricsCollector::update_execution_metrics`.
async fn update_execution_metrics(metrics: &Arc<RwLock<QueryMetricsCollector>>, duration: std::time::Duration, query: &str) {
    let mut metrics = metrics.write().await;
    metrics.queries_executed += 1;
    // Running mean, not a sum: the previous `avg_execution_time += duration`
    // accumulated a total into a field that query_statistics reports as the
    // average. avg_n = (avg_{n-1} * (n-1) + x) / n, computed in f64 seconds.
    let n = metrics.queries_executed as f64;
    let total_secs =
        metrics.avg_execution_time.as_secs_f64() * (n - 1.0) + duration.as_secs_f64();
    metrics.avg_execution_time = std::time::Duration::from_secs_f64(total_secs / n);
    if duration.as_millis() > 1000 {
        metrics.slow_queries.push(SlowQuery {
            query: query.to_string(),
            execution_time_secs: duration.as_secs(),
            timestamp: Utc::now(),
            frequency: 1,
            optimization_applied: false,
        });
        // Bound the list like the collector does (keep the most recent 100).
        if metrics.slow_queries.len() > 100 {
            metrics.slow_queries.remove(0);
        }
    }
}
fn parse_query(query: &str) -> Result<ParsedQuery> {
let query_lower = query.to_lowercase();
let tables = if query_lower.contains("from") {
query_lower.split("from").nth(1)
.and_then(|s| s.split_whitespace().next())
.map(|s| vec![s.to_string()])
.unwrap_or_default()
} else {
vec![]
};
Ok(ParsedQuery {
original: query.to_string(),
tables,
is_select: query_lower.starts_with("select"),
})
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): create_test_executor() is unimplemented!(), so every test in
// this module currently panics when run. These are scaffolding; they need
// mock Cache / QueryPlanner / ConnectionPool implementations to pass.
#[tokio::test]
async fn test_query_parsing() {
let executor = create_test_executor();
// SELECT verb, the table after FROM, and no DDL keywords present.
let result = executor.parse_query("SELECT * FROM users WHERE id = 1").unwrap();
assert_eq!(result.operation, "SELECT");
assert!(result.tables.contains(&"users".to_string()));
assert!(!result.is_admin_operation);
}
#[tokio::test]
async fn test_cache_key_generation() {
let executor = create_test_executor();
// Same query + same params must hash to the same key…
let key1 = executor.generate_cache_key("SELECT * FROM users", &None);
let key2 = executor.generate_cache_key("SELECT * FROM users", &None);
assert_eq!(key1, key2);
// …while differing parameters must produce a different key.
let key3 = executor.generate_cache_key("SELECT * FROM users", &Some(vec![serde_json::json!(1)]));
assert_ne!(key1, key3);
}
#[tokio::test]
async fn test_query_classification() {
let executor = create_test_executor();
// Reads (run concurrently in batches) vs writes (run sequentially).
assert!(executor.is_read_query("SELECT * FROM users"));
assert!(executor.is_read_query("SHOW TABLES"));
assert!(!executor.is_read_query("INSERT INTO users VALUES (1)"));
assert!(!executor.is_read_query("UPDATE users SET name = 'test'"));
assert!(!executor.is_read_query("DELETE FROM users"));
}
// TODO: build an executor wired to in-memory mocks of the three traits.
fn create_test_executor() -> OptimizedQueryExecutor {
unimplemented!("Test executor creation not implemented")
}
}