use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Configuration knobs for [`QueryCostAnalyzer`].
///
/// `#[serde(default)]` means any field omitted during deserialization falls
/// back to the `Default` impl below, so partial configs are accepted.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct QueryCostConfig {
    /// Hard ceiling on a single query's computed cost; costlier queries are rejected.
    pub max_cost_per_query: u64,
    /// Cost charged per counted field line in the query text.
    pub base_cost_per_field: u64,
    /// Substring pattern -> percentage surcharge applied when the pattern
    /// appears anywhere in the query text (100 doubles the cost).
    pub field_cost_multipliers: HashMap<String, u64>,
    /// Total cost a single user may spend within one `budget_window`.
    pub user_cost_budget: u64,
    /// Length of the rolling per-user budget accounting window.
    pub budget_window: Duration,
    /// When true, every accepted query's cost is recorded for analytics.
    pub track_expensive_queries: bool,
    /// Percentile (0.0..=1.0) used by `get_expensive_threshold`.
    pub expensive_percentile: f64,
    /// When true, costs are scaled by the current load factor.
    pub adaptive_costs: bool,
    /// Factor applied to costs when system load is high (> 0.8).
    pub high_load_multiplier: f64,
}
impl Default for QueryCostConfig {
fn default() -> Self {
Self {
max_cost_per_query: 1000,
base_cost_per_field: 1,
field_cost_multipliers: HashMap::new(),
user_cost_budget: 10_000,
budget_window: Duration::from_secs(60),
track_expensive_queries: true,
expensive_percentile: 0.95,
adaptive_costs: true,
high_load_multiplier: 2.0,
}
}
}
/// Analyzes query cost, enforces per-user budgets, and keeps a bounded
/// history of recent costs for percentile analytics.
pub struct QueryCostAnalyzer {
    /// Static limits, multipliers, and windows.
    config: QueryCostConfig,
    /// Running spend per user id within the current budget window.
    user_budgets: Arc<RwLock<HashMap<String, UserBudget>>>,
    /// Bounded FIFO of recent accepted-query costs (analytics source).
    query_costs: Arc<RwLock<VecDeque<u64>>>,
    /// Multiplier applied to costs when `config.adaptive_costs` is set.
    current_load_factor: Arc<RwLock<f64>>,
}
/// One user's cost spend inside the current accounting window.
#[derive(Debug, Clone)]
struct UserBudget {
    /// Cost accumulated since `window_start`.
    spent: u64,
    /// Start of the current window; spend resets when the window lapses.
    window_start: Instant,
}
impl QueryCostAnalyzer {
    /// Upper bound on the cost-history ring buffer used for analytics.
    const MAX_COST_HISTORY: usize = 10_000;
    /// Upper bound on concurrently tracked per-user budget entries.
    const MAX_USER_BUDGET_ENTRIES: usize = 50_000;

    /// Creates an analyzer with empty budget/history state and a neutral
    /// (1.0) load factor.
    pub fn new(config: QueryCostConfig) -> Self {
        Self {
            config,
            user_budgets: Arc::new(RwLock::new(HashMap::new())),
            query_costs: Arc::new(RwLock::new(VecDeque::with_capacity(Self::MAX_COST_HISTORY))),
            current_load_factor: Arc::new(RwLock::new(1.0)),
        }
    }

    /// Computes the cost of `query` and rejects it when it exceeds
    /// `max_cost_per_query`.
    ///
    /// Cost = counted fields × `base_cost_per_field`, surcharged by every
    /// matching `field_cost_multipliers` entry (percentages, compounding),
    /// then scaled by the current load factor when `adaptive_costs` is on.
    /// Accepted costs are recorded for analytics when
    /// `track_expensive_queries` is set.
    ///
    /// # Errors
    /// Returns a descriptive message when the final cost exceeds the
    /// configured per-query maximum.
    pub async fn calculate_query_cost(&self, query: &str) -> Result<QueryCostResult, String> {
        let start = Instant::now();
        let field_count = self.count_fields(query);
        let complexity = self.calculate_complexity(query);
        // Saturating arithmetic throughout: huge queries or misconfigured
        // multipliers must cap out instead of overflowing u64 (which would
        // panic in debug builds).
        let mut total_cost = (field_count as u64).saturating_mul(self.config.base_cost_per_field);

        // Surcharges compound, and integer division makes the outcome
        // order-sensitive — sort matched patterns so the final cost is
        // deterministic rather than following HashMap iteration order.
        let mut matched: Vec<(&str, u64)> = self
            .config
            .field_cost_multipliers
            .iter()
            .filter(|(pattern, _)| query.contains(pattern.as_str()))
            .map(|(pattern, multiplier)| (pattern.as_str(), *multiplier))
            .collect();
        matched.sort_unstable();
        for (_, multiplier) in matched {
            total_cost = total_cost.saturating_add(total_cost.saturating_mul(multiplier) / 100);
        }

        if self.config.adaptive_costs {
            let load_factor = *self.current_load_factor.read().await;
            total_cost = (total_cost as f64 * load_factor) as u64;
        }

        if total_cost > self.config.max_cost_per_query {
            return Err(format!(
                "Query cost {} exceeds maximum allowed cost {}",
                total_cost, self.config.max_cost_per_query
            ));
        }

        if self.config.track_expensive_queries {
            let mut costs = self.query_costs.write().await;
            // Bounded history: evict the oldest sample once full.
            if costs.len() >= Self::MAX_COST_HISTORY {
                costs.pop_front();
            }
            costs.push_back(total_cost);
        }

        Ok(QueryCostResult {
            total_cost,
            field_count,
            complexity,
            calculation_time: start.elapsed(),
        })
    }

    /// Charges `query_cost` against `user_id`'s rolling budget.
    ///
    /// The window resets lazily on access once `budget_window` has elapsed.
    ///
    /// # Errors
    /// Fails when the charge would exceed `user_cost_budget`, or when the
    /// tracking table is full of still-active entries and `user_id` is new.
    pub async fn check_user_budget(&self, user_id: &str, query_cost: u64) -> Result<(), String> {
        let mut budgets = self.user_budgets.write().await;
        let now = Instant::now();
        if budgets.len() >= Self::MAX_USER_BUDGET_ENTRIES && !budgets.contains_key(user_id) {
            // Evict expired windows before rejecting a brand-new user.
            budgets.retain(|_, b| now.duration_since(b.window_start) <= self.config.budget_window);
            if budgets.len() >= Self::MAX_USER_BUDGET_ENTRIES {
                return Err(format!(
                    "Query cost budget tracking capacity ({}) exceeded; try again later",
                    Self::MAX_USER_BUDGET_ENTRIES
                ));
            }
        }
        let budget = budgets
            .entry(user_id.to_string())
            .or_insert_with(|| UserBudget {
                spent: 0,
                window_start: now,
            });
        if now.duration_since(budget.window_start) > self.config.budget_window {
            budget.spent = 0;
            budget.window_start = now;
        }
        // saturating_add: a plain `+` could overflow u64 and panic in debug
        // builds when callers pass extreme costs.
        if budget.spent.saturating_add(query_cost) > self.config.user_cost_budget {
            return Err(format!(
                "User {} exceeded query cost budget ({}/{} in last {:?})",
                user_id, budget.spent, self.config.user_cost_budget, self.config.budget_window
            ));
        }
        budget.spent = budget.spent.saturating_add(query_cost);
        Ok(())
    }

    /// Updates the adaptive cost multiplier from system load.
    ///
    /// `cpu_usage` and `memory_usage` are expected in `0.0..=1.0`; their
    /// mean selects the factor: > 0.8 → `high_load_multiplier`,
    /// > 0.6 → 1.5, otherwise 1.0.
    pub async fn update_load_factor(&self, cpu_usage: f64, memory_usage: f64) {
        let load = (cpu_usage + memory_usage) / 2.0;
        let factor = if load > 0.8 {
            self.config.high_load_multiplier
        } else if load > 0.6 {
            1.5
        } else {
            1.0
        };
        *self.current_load_factor.write().await = factor;
    }

    /// Returns the cost at `expensive_percentile` of the recorded history,
    /// or `max_cost_per_query` when no history exists yet.
    pub async fn get_expensive_threshold(&self) -> u64 {
        let costs = self.query_costs.read().await;
        if costs.is_empty() {
            return self.config.max_cost_per_query;
        }
        let mut sorted: Vec<u64> = costs.iter().copied().collect();
        sorted.sort_unstable();
        // Truncating index, clamped so a percentile of 1.0 stays in bounds.
        let index = ((sorted.len() as f64 * self.config.expensive_percentile) as usize)
            .min(sorted.len() - 1);
        sorted[index]
    }

    /// Summarizes the recorded cost history (average, median, p95/p99,
    /// min/max). Returns zeroed analytics when no history exists.
    pub async fn get_analytics(&self) -> QueryCostAnalytics {
        let costs = self.query_costs.read().await;
        if costs.is_empty() {
            return QueryCostAnalytics::default();
        }
        let mut sorted: Vec<u64> = costs.iter().copied().collect();
        sorted.sort_unstable();
        let len = sorted.len();
        // Saturating fold: a history full of near-u64::MAX samples would
        // overflow a plain `sum()`.
        let sum = sorted.iter().fold(0u64, |acc, &c| acc.saturating_add(c));
        QueryCostAnalytics {
            total_queries: len,
            average_cost: sum / len as u64,
            median_cost: sorted[len / 2],
            p95_cost: sorted[((len as f64 * 0.95) as usize).min(len - 1)],
            p99_cost: sorted[((len as f64 * 0.99) as usize).min(len - 1)],
            max_cost: *sorted.last().unwrap(),
            min_cost: *sorted.first().unwrap(),
        }
    }

    /// Counts "field-like" lines: non-empty, non-comment lines that are not
    /// closing braces, operation keywords, or a lone `{`.
    fn count_fields(&self, query: &str) -> usize {
        query
            .lines()
            .map(str::trim)
            .filter(|line| !line.starts_with('#') && !line.is_empty())
            .filter(|line| {
                !line.starts_with('}')
                    && !line.starts_with("query")
                    && !line.starts_with("mutation")
                    && !line.starts_with("subscription")
            })
            .filter(|line| *line != "{")
            .count()
    }

    /// Rough complexity heuristic: open-brace count (depth proxy) times the
    /// field count (breadth proxy).
    fn calculate_complexity(&self, query: &str) -> usize {
        let depth = query.matches('{').count();
        let breadth = self.count_fields(query);
        depth * breadth
    }

    /// Drops budget entries whose window expired more than one extra
    /// `budget_window` ago (2x grace keeps recently-expired windows around
    /// so their lazy reset in `check_user_budget` can still fire cheaply).
    pub async fn cleanup_expired_budgets(&self) {
        let mut budgets = self.user_budgets.write().await;
        let now = Instant::now();
        budgets.retain(|_, budget| {
            now.duration_since(budget.window_start) <= self.config.budget_window * 2
        });
    }
}
/// Outcome of a successful cost calculation.
#[derive(Debug, Clone)]
pub struct QueryCostResult {
    /// Final cost after multipliers and (optionally) load scaling.
    pub total_cost: u64,
    /// Number of field-like lines counted in the query text.
    pub field_count: usize,
    /// Heuristic depth-times-breadth complexity score.
    pub complexity: usize,
    /// Wall-clock time spent computing this result.
    pub calculation_time: Duration,
}
/// Summary statistics over the recorded query-cost history.
/// `Default` (all zeros) is returned when no history exists.
#[derive(Debug, Clone, Default)]
pub struct QueryCostAnalytics {
    /// Number of cost samples summarized.
    pub total_queries: usize,
    /// Integer mean of the samples.
    pub average_cost: u64,
    /// Upper median (element at len / 2 of the sorted samples).
    pub median_cost: u64,
    /// 95th-percentile cost.
    pub p95_cost: u64,
    /// 99th-percentile cost.
    pub p99_cost: u64,
    /// Largest recorded cost.
    pub max_cost: u64,
    /// Smallest recorded cost.
    pub min_cost: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_query_cost() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            max_cost_per_query: 100,
            base_cost_per_field: 10,
            ..Default::default()
        });
        let query = r#"
query {
user {
id
name
email
}
}
"#;
        // Countable lines: `user {`, `id`, `name`, `email` -> 4 fields.
        let result = analyzer.calculate_query_cost(query).await.unwrap();
        assert_eq!(result.field_count, 4);
        assert!(result.total_cost > 0);
    }

    #[tokio::test]
    async fn test_expensive_query_rejection() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            max_cost_per_query: 10,
            base_cost_per_field: 10,
            ..Default::default()
        });
        let query = r#"
query {
users {
id
name
posts {
id
title
}
}
}
"#;
        // 6 fields x 10 = 60, well above the 10-point ceiling.
        assert!(analyzer.calculate_query_cost(query).await.is_err());
    }

    #[tokio::test]
    async fn test_field_multipliers() {
        let mut multipliers = HashMap::new();
        multipliers.insert("expensiveField".to_string(), 100);
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            base_cost_per_field: 10,
            field_cost_multipliers: multipliers,
            ..Default::default()
        });
        let cheap_cost = analyzer
            .calculate_query_cost("query {\n normalField\n}")
            .await
            .unwrap()
            .total_cost;
        let pricey_cost = analyzer
            .calculate_query_cost("query {\n expensiveField\n}")
            .await
            .unwrap()
            .total_cost;
        // A 100% surcharge doubles the single-field base cost.
        assert_eq!(cheap_cost, 10);
        assert_eq!(pricey_cost, 20);
        assert!(pricey_cost > cheap_cost);
    }

    #[tokio::test]
    async fn test_adaptive_costs() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            base_cost_per_field: 10,
            adaptive_costs: true,
            high_load_multiplier: 2.0,
            ..Default::default()
        });
        let query = "query {\n field\n}";
        // Mean load 0.1 keeps the neutral 1.0 factor.
        analyzer.update_load_factor(0.1, 0.1).await;
        let relaxed = analyzer.calculate_query_cost(query).await.unwrap().total_cost;
        // Mean load 0.9 triggers the 2.0 high-load multiplier.
        analyzer.update_load_factor(0.9, 0.9).await;
        let stressed = analyzer.calculate_query_cost(query).await.unwrap().total_cost;
        assert_eq!(relaxed, 10);
        assert_eq!(stressed, 20);
    }

    #[tokio::test]
    async fn test_user_budget_enforcement() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            user_cost_budget: 100,
            budget_window: Duration::from_secs(60),
            ..Default::default()
        });
        // Two 50-point charges fill the budget exactly; the next must fail.
        assert!(analyzer.check_user_budget("user1", 50).await.is_ok());
        assert!(analyzer.check_user_budget("user1", 50).await.is_ok());
        assert!(analyzer.check_user_budget("user1", 10).await.is_err());
    }

    #[tokio::test]
    async fn test_user_budget_expiration() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            user_cost_budget: 100,
            budget_window: Duration::from_millis(50),
            ..Default::default()
        });
        analyzer.check_user_budget("user1", 100).await.unwrap();
        assert!(analyzer.check_user_budget("user1", 1).await.is_err());
        // After the window lapses the spend resets lazily on next access.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(analyzer.check_user_budget("user1", 1).await.is_ok());
    }

    #[tokio::test]
    async fn test_analytics() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig::default());
        {
            // Seed the history under a single write lock.
            let mut costs = analyzer.query_costs.write().await;
            for cost in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] {
                costs.push_back(cost);
            }
        }
        let analytics = analyzer.get_analytics().await;
        assert_eq!(analytics.total_queries, 10);
        assert_eq!(analytics.average_cost, 55);
        assert_eq!(analytics.median_cost, 60);
        // (10 * 0.95) as usize = 9 -> largest sample.
        assert_eq!(analyzer.get_expensive_threshold().await, 100);
    }

    #[tokio::test]
    async fn test_cleanup_expired_budgets() {
        let analyzer = QueryCostAnalyzer::new(QueryCostConfig {
            budget_window: Duration::from_millis(10),
            ..Default::default()
        });
        analyzer.check_user_budget("userA", 10).await.unwrap();
        // 25ms exceeds the 2x-window retention grace used by cleanup.
        tokio::time::sleep(Duration::from_millis(25)).await;
        analyzer.cleanup_expired_budgets().await;
        assert!(analyzer.user_budgets.read().await.is_empty());
    }
}