use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use sqlx::{PgPool, Row};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use tracing::{debug, info};

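/// Analyzes PostgreSQL query plans and surfaces index, query, and
/// connection-pool optimization recommendations.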
pub struct QueryOptimizer {
    db_pool: Arc<PgPool>,
}

impl QueryOptimizer {
    pub fn new(db_pool: Arc<PgPool>) -> Self {
        Self { db_pool }
    }

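    /// Runs `EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON)` on the given query and
    /// parses the resulting plan. Because `ANALYZE` actually executes the
    /// statement, queries that would mutate data are rejected up front.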
    pub async fn analyze_query(&self, query: &str) -> Result<QueryAnalysis> {
        if !self.is_safe_query_for_explain(query) {
            return Err(anyhow!(
                "Query contains potentially unsafe statements for EXPLAIN"
            ));
        }

        let explain_query = format!("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {query}");

        let row = timeout(
            Duration::from_secs(30),
            sqlx::query(&explain_query).fetch_one(self.db_pool.as_ref()),
        )
        .await
        .map_err(|_| anyhow!("Query analysis timed out after 30 seconds"))??;

        let plan_json: serde_json::Value = row.get(0);

        self.parse_explain_output(plan_json)
    }

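    /// Conservative deny-list check before running `EXPLAIN ANALYZE`. Note
    /// that this is naive substring matching: a query touching a column named
    /// `created_at` is rejected because it contains `CREATE`. Refusing the
    /// occasional safe query is the intended trade-off.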
    fn is_safe_query_for_explain(&self, query: &str) -> bool {
        let dangerous_keywords = [
            "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "GRANT", "REVOKE",
        ];
        let upper_query = query.to_uppercase();

        if dangerous_keywords
            .iter()
            .any(|&keyword| upper_query.contains(keyword))
        {
            return false;
        }

        // EXPLAIN ANALYZE executes the statement, so writes must be rejected too.
        if upper_query.contains("INSERT") || upper_query.contains("UPDATE") {
            return false;
        }

        true
    }

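    /// Extracts timing, cost, and row metrics from the JSON produced by
    /// `EXPLAIN (FORMAT JSON)`, which wraps the plan in a single-element array.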
    fn parse_explain_output(&self, plan: serde_json::Value) -> Result<QueryAnalysis> {
        let plan_array = plan
            .as_array()
            .ok_or_else(|| anyhow!("Invalid EXPLAIN output format"))?;

        let plan_obj = plan_array
            .first()
            .and_then(|p| p.as_object())
            .ok_or_else(|| anyhow!("Invalid plan structure"))?;

        let execution_time = plan_obj
            .get("Execution Time")
            .and_then(|t| t.as_f64())
            .unwrap_or(0.0);

        let planning_time = plan_obj
            .get("Planning Time")
            .and_then(|t| t.as_f64())
            .unwrap_or(0.0);

        let plan_details = plan_obj
            .get("Plan")
            .cloned()
            .unwrap_or(serde_json::Value::Null);

        let (node_type, rows_scanned, cost) = self.extract_plan_metrics(&plan_details);

        let issues = self.identify_query_issues(&plan_details);

        let recommendations = self.generate_recommendations(&issues);

        Ok(QueryAnalysis {
            query_type: node_type,
            execution_time_ms: execution_time,
            planning_time_ms: planning_time,
            total_time_ms: execution_time + planning_time,
            rows_scanned,
            estimated_cost: cost,
            issues,
            recommendations,
            full_plan: plan_details,
        })
    }

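    /// Pulls node type, actual row count, and planner cost from the root plan node.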
    fn extract_plan_metrics(&self, plan: &serde_json::Value) -> (String, u64, f64) {
        let node_type = plan
            .get("Node Type")
            .and_then(|n| n.as_str())
            .unwrap_or("Unknown")
            .to_string();

        let rows_scanned = plan
            .get("Actual Rows")
            .and_then(|r| r.as_u64())
            .unwrap_or(0);

        let cost = plan
            .get("Total Cost")
            .and_then(|c| c.as_f64())
            .unwrap_or(0.0);

        (node_type, rows_scanned, cost)
    }

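    /// Flags common red flags on the root plan node: large sequential scans,
    /// filter conditions that suggest a missing index, and hot nested loops.
    /// Only the root node is inspected; child `Plans` are not recursed into.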
    fn identify_query_issues(&self, plan: &serde_json::Value) -> Vec<QueryIssue> {
        let mut issues = Vec::new();

        let node_type = plan.get("Node Type").and_then(|n| n.as_str());

        if node_type == Some("Seq Scan") {
            if let Some(rows) = plan.get("Actual Rows").and_then(|r| r.as_u64()) {
                if rows > 1000 {
                    issues.push(QueryIssue {
                        severity: IssueSeverity::High,
                        issue_type: "Sequential Scan".to_string(),
                        description: format!("Sequential scan on {rows} rows"),
                        impact: "High query latency".to_string(),
                    });
                }
            }
        }

        if let Some(filter) = plan.get("Filter").and_then(|f| f.as_str()) {
            if !filter.is_empty() {
                issues.push(QueryIssue {
                    severity: IssueSeverity::Medium,
                    issue_type: "Missing Index".to_string(),
                    description: format!("Filter condition without index: {filter}"),
                    impact: "Increased scan time".to_string(),
                });
            }
        }

        if node_type == Some("Nested Loop") {
            if let Some(loops) = plan.get("Actual Loops").and_then(|l| l.as_u64()) {
                if loops > 100 {
                    issues.push(QueryIssue {
                        severity: IssueSeverity::High,
                        issue_type: "Inefficient Join".to_string(),
                        description: format!("Nested loop with {loops} iterations"),
                        impact: "Quadratic join cost".to_string(),
                    });
                }
            }
        }

        issues
    }

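    /// Maps each detected issue type to a human-readable remediation suggestion.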
    fn generate_recommendations(&self, issues: &[QueryIssue]) -> Vec<String> {
        let mut recommendations = Vec::new();

        for issue in issues {
            match issue.issue_type.as_str() {
                "Sequential Scan" => {
                    recommendations
                        .push("Consider adding an index on frequently queried columns".to_string());
                }
                "Missing Index" => {
                    recommendations.push(
                        "Create an index on the filter columns to improve query performance"
                            .to_string(),
                    );
                }
                "Inefficient Join" => {
                    recommendations.push(
                        "Consider using hash join or merge join instead of nested loop".to_string(),
                    );
                }
                _ => {}
            }
        }

        recommendations
    }

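    /// Surfaces tables that look under-indexed (heavy sequential scans with
    /// little index usage, via `pg_stat_user_tables`) or over-indexed (many
    /// indexes that slow down writes, via `pg_indexes`).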
    pub async fn get_index_recommendations(&self) -> Result<Vec<IndexRecommendation>> {
        let mut recommendations = Vec::new();

        let missing_indexes_query = r#"
            SELECT
                schemaname,
                tablename,
                seq_scan,
                seq_tup_read,
                idx_scan,
                idx_tup_fetch
            FROM pg_stat_user_tables
            WHERE seq_scan > 0
              AND seq_tup_read > 100000
              AND (idx_scan IS NULL OR idx_scan < seq_scan / 10)
            ORDER BY seq_tup_read DESC
            LIMIT 10
        "#;

        let rows = sqlx::query(missing_indexes_query)
            .fetch_all(self.db_pool.as_ref())
            .await?;

        for row in rows {
            let table_name: String = row.get("tablename");
            let seq_scans: i64 = row.get("seq_scan");
            let rows_read: i64 = row.get("seq_tup_read");

            recommendations.push(IndexRecommendation {
                table_name: table_name.clone(),
                reason: format!(
                    "Table has {seq_scans} sequential scans reading {rows_read} rows total"
                ),
                suggested_columns: vec![],
                estimated_improvement: "50-90% reduction in scan time".to_string(),
                priority: if rows_read > 1_000_000 {
                    RecommendationPriority::High
                } else {
                    RecommendationPriority::Medium
                },
            });
        }

        // Count indexes per table to flag tables whose write path may suffer
        // from index maintenance overhead.
        let index_count_query = r#"
            SELECT
                indexname,
                tablename,
                indexdef
            FROM pg_indexes
            WHERE schemaname = 'public'
            ORDER BY tablename, indexname
        "#;

        let index_rows = sqlx::query(index_count_query)
            .fetch_all(self.db_pool.as_ref())
            .await?;

        let mut index_map: HashMap<String, Vec<String>> = HashMap::new();

        for row in index_rows {
            let table: String = row.get("tablename");
            let index: String = row.get("indexname");
            let _definition: String = row.get("indexdef");

            index_map.entry(table).or_default().push(index);
        }

        for (table, indexes) in index_map {
            if indexes.len() > 5 {
                recommendations.push(IndexRecommendation {
                    table_name: table,
                    reason: format!(
                        "Table has {} indexes which may slow down writes",
                        indexes.len()
                    ),
                    suggested_columns: vec![],
                    estimated_improvement: "10-20% improvement in write performance".to_string(),
                    priority: RecommendationPriority::Low,
                });
            }
        }

        Ok(recommendations)
    }

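    /// Inspects `pg_stat_activity` for the current database and derives pool
    /// sizing and idle-timeout suggestions from the connection-state breakdown.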
    pub async fn optimize_connection_pool(&self) -> Result<ConnectionPoolRecommendation> {
        let conn_stats_query = r#"
            SELECT
                count(*) as total_connections,
                count(*) FILTER (WHERE state = 'active') as active_connections,
                count(*) FILTER (WHERE state = 'idle') as idle_connections,
                count(*) FILTER (WHERE state = 'idle in transaction') as idle_in_transaction,
                max(EXTRACT(EPOCH FROM (now() - state_change))) as max_idle_time
            FROM pg_stat_activity
            WHERE datname = current_database()
        "#;

        let row = sqlx::query(conn_stats_query)
            .fetch_one(self.db_pool.as_ref())
            .await?;

        let total_connections: i64 = row.get("total_connections");
        let active_connections: i64 = row.get("active_connections");
        let idle_connections: i64 = row.get("idle_connections");
        let idle_in_transaction: i64 = row.get("idle_in_transaction");
        let max_idle_time: Option<f64> = row.get("max_idle_time");

        let mut recommendations = Vec::new();

        if idle_connections > active_connections * 3 {
            recommendations.push("Reduce max_idle_connections to save resources".to_string());
        }

        if idle_in_transaction > 0 {
            recommendations.push("Investigate and fix idle-in-transaction connections".to_string());
        }

        if let Some(idle_time) = max_idle_time {
            if idle_time > 300.0 {
                recommendations
                    .push("Set connection idle timeout to prevent zombie connections".to_string());
            }
        }

        // Heuristic: 1.5x headroom over current active connections, clamped
        // to a sane range.
        let suggested_pool_size = ((active_connections as f64 * 1.5) as u32).clamp(10, 100);

        Ok(ConnectionPoolRecommendation {
            current_connections: total_connections as u32,
            active_connections: active_connections as u32,
            idle_connections: idle_connections as u32,
            suggested_pool_size,
            suggested_idle_timeout: Duration::from_secs(300),
            recommendations,
        })
    }

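    /// Runs every analysis pass and bundles the results into a single report
    /// with an overall health score.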
    pub async fn run_full_analysis(&self) -> Result<FullOptimizationReport> {
        info!("Running full database optimization analysis");

        let slow_queries = self.identify_slow_queries().await?;

        let index_recommendations = self.get_index_recommendations().await?;

        let connection_pool = self.optimize_connection_pool().await?;

        let health_score = self.calculate_health_score(&slow_queries, &index_recommendations);

        let summary = self.generate_summary(&slow_queries, &index_recommendations);

        Ok(FullOptimizationReport {
            timestamp: chrono::Utc::now(),
            health_score,
            slow_queries,
            index_recommendations,
            connection_pool,
            summary,
        })
    }

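    /// Reads the ten slowest statements (by mean execution time, in
    /// milliseconds) from `pg_stat_statements`. The extension is optional, so
    /// a query failure degrades gracefully to an empty list.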
    async fn identify_slow_queries(&self) -> Result<Vec<SlowQuery>> {
        let slow_queries_query = r#"
            SELECT
                calls,
                total_exec_time,
                mean_exec_time,
                stddev_exec_time,
                query
            FROM pg_stat_statements
            WHERE mean_exec_time > 100
            ORDER BY mean_exec_time DESC
            LIMIT 10
        "#;

        match sqlx::query(slow_queries_query)
            .fetch_all(self.db_pool.as_ref())
            .await
        {
            Ok(rows) => {
                let mut queries = Vec::new();
                for row in rows {
                    queries.push(SlowQuery {
                        query: row.get("query"),
                        total_calls: row.get("calls"),
                        mean_time_ms: row.get("mean_exec_time"),
                        total_time_ms: row.get("total_exec_time"),
                    });
                }
                Ok(queries)
            }
            Err(_) => {
                debug!("pg_stat_statements not available, skipping slow query analysis");
                Ok(Vec::new())
            }
        }
    }

    fn calculate_health_score(
        &self,
        slow_queries: &[SlowQuery],
        index_recs: &[IndexRecommendation],
    ) -> u32 {
        let mut score = 100u32;

        // Deduct 5 points per slow query.
        score = score.saturating_sub((slow_queries.len() * 5) as u32);

        // Deduct per index recommendation, weighted by priority.
        for rec in index_recs {
            match rec.priority {
                RecommendationPriority::High => score = score.saturating_sub(10),
                RecommendationPriority::Medium => score = score.saturating_sub(5),
                RecommendationPriority::Low => score = score.saturating_sub(2),
            }
        }

        score.min(100)
    }

    fn generate_summary(
        &self,
        slow_queries: &[SlowQuery],
        index_recs: &[IndexRecommendation],
    ) -> String {
        format!(
            "Found {} slow queries and {} index optimization opportunities",
            slow_queries.len(),
            index_recs.len()
        )
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
    pub query_type: String,
    pub execution_time_ms: f64,
    pub planning_time_ms: f64,
    pub total_time_ms: f64,
    pub rows_scanned: u64,
    pub estimated_cost: f64,
    pub issues: Vec<QueryIssue>,
    pub recommendations: Vec<String>,
    pub full_plan: serde_json::Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryIssue {
    pub severity: IssueSeverity,
    pub issue_type: String,
    pub description: String,
    pub impact: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueSeverity {
    Low,
    Medium,
    High,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexRecommendation {
    pub table_name: String,
    pub reason: String,
    pub suggested_columns: Vec<String>,
    pub estimated_improvement: String,
    pub priority: RecommendationPriority,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationPriority {
    Low,
    Medium,
    High,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionPoolRecommendation {
    pub current_connections: u32,
    pub active_connections: u32,
    pub idle_connections: u32,
    pub suggested_pool_size: u32,
    pub suggested_idle_timeout: Duration,
    pub recommendations: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SlowQuery {
    pub query: String,
    pub total_calls: i64,
    pub mean_time_ms: f64,
    pub total_time_ms: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullOptimizationReport {
    pub timestamp: chrono::DateTime<chrono::Utc>,
    pub health_score: u32,
    pub slow_queries: Vec<SlowQuery>,
    pub index_recommendations: Vec<IndexRecommendation>,
    pub connection_pool: ConnectionPoolRecommendation,
    pub summary: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_score_calculation() {
        let optimizer = QueryOptimizer {
            // Dummy DSN: connect_lazy only parses the URL and never connects,
            // but an empty string would fail URL parsing and panic the unwrap.
            db_pool: Arc::new(PgPool::connect_lazy("postgres://localhost/test").unwrap()),
        };

        let slow_queries = vec![SlowQuery {
            query: "SELECT * FROM test".to_string(),
            total_calls: 100,
            mean_time_ms: 150.0,
            total_time_ms: 15000.0,
        }];

        let index_recs = vec![IndexRecommendation {
            table_name: "test".to_string(),
            reason: "Missing index".to_string(),
            suggested_columns: vec![],
            estimated_improvement: "50%".to_string(),
            priority: RecommendationPriority::High,
        }];

        // 100 - 5 (one slow query) - 10 (one high-priority recommendation) = 85.
        let score = optimizer.calculate_health_score(&slow_queries, &index_recs);
        assert_eq!(score, 85);
    }
}