use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use tracing::{debug, error, info};
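
/// A single executed query together with its timing, result size, and outcome.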
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryLogEntry {
pub timestamp: DateTime<Local>,
pub query: String,
pub execution_time_ms: f64,
pub rows_returned: usize,
pub success: bool,
pub error: Option<String>,
pub client_addr: String,
pub protocol: String,
}
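
/// Configuration for [`QueryLogger`].
///
/// `only_errors` suppresses logging of successful queries, and
/// `include_results` switches the file format from a one-line summary to the
/// full JSON-serialized [`QueryLogEntry`].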
#[derive(Debug, Clone, Deserialize)]
pub struct QueryLogConfig {
pub enabled: bool,
pub file_path: Option<PathBuf>,
pub max_memory_entries: usize,
pub slow_query_threshold_ms: Option<f64>,
pub include_results: bool,
pub only_errors: bool,
}
impl Default for QueryLogConfig {
fn default() -> Self {
Self {
enabled: false,
file_path: None,
max_memory_entries: 1000,
slow_query_threshold_ms: Some(100.0),
include_results: false,
only_errors: false,
}
}
}
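
/// Asynchronous query logger: keeps a bounded in-memory history of
/// [`QueryLogEntry`] values, maintains running [`QueryStats`], and optionally
/// appends each entry to a log file.
///
/// A minimal usage sketch (marked `ignore` because the crate-level module
/// path is assumed):
///
/// ```ignore
/// let logger = QueryLogger::new(QueryLogConfig {
///     enabled: true,
///     ..Default::default()
/// });
/// logger
///     .log_query(
///         "SELECT 1",
///         std::time::Duration::from_millis(2),
///         1,
///         true,
///         None,
///         "127.0.0.1:5432".to_string(),
///         "postgresql",
///     )
///     .await;
/// println!("{}", logger.generate_report().await);
/// ```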
pub struct QueryLogger {
config: QueryLogConfig,
entries: Arc<RwLock<VecDeque<QueryLogEntry>>>,
stats: Arc<RwLock<QueryStats>>,
}
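
/// Running aggregate statistics over all logged queries.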
#[derive(Debug, Clone, Default, Serialize)]
pub struct QueryStats {
pub total_queries: u64,
pub successful_queries: u64,
pub failed_queries: u64,
pub total_execution_time_ms: f64,
pub avg_execution_time_ms: f64,
pub max_execution_time_ms: f64,
pub min_execution_time_ms: f64,
pub slow_queries: u64,
pub total_rows_returned: u64,
}
impl QueryLogger {
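    /// Creates a logger with an empty history and zeroed statistics.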
pub fn new(config: QueryLogConfig) -> Self {
Self {
config,
entries: Arc::new(RwLock::new(VecDeque::new())),
stats: Arc::new(RwLock::new(QueryStats::default())),
}
}
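
    /// Records one executed query.
    ///
    /// Returns early when logging is disabled, or when `only_errors` is set
    /// and the query succeeded. Otherwise it updates the aggregate stats,
    /// appends the entry to the bounded in-memory history, writes to the
    /// configured log file (if any), and emits a tracing event.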
#[allow(clippy::too_many_arguments)]
pub async fn log_query(
&self,
query: &str,
execution_time: Duration,
rows_returned: usize,
success: bool,
error: Option<String>,
client_addr: String,
protocol: &str,
) {
if !self.config.enabled {
return;
}
        if self.config.only_errors && success {
            return;
        }
        let execution_time_ms = execution_time.as_secs_f64() * 1000.0;
let is_slow = self
.config
.slow_query_threshold_ms
.map(|threshold| execution_time_ms >= threshold)
.unwrap_or(false);
let entry = QueryLogEntry {
timestamp: Local::now(),
query: query.to_string(),
execution_time_ms,
rows_returned,
success,
error: error.clone(),
client_addr,
protocol: protocol.to_string(),
};
        {
            // Update the running aggregates under a short-lived write lock.
            let mut stats = self.stats.write().await;
stats.total_queries += 1;
if success {
stats.successful_queries += 1;
} else {
stats.failed_queries += 1;
}
stats.total_execution_time_ms += execution_time_ms;
stats.avg_execution_time_ms =
stats.total_execution_time_ms / stats.total_queries as f64;
if execution_time_ms > stats.max_execution_time_ms {
stats.max_execution_time_ms = execution_time_ms;
}
            // 0.0 doubles as the "unset" sentinel, so the first sample always wins.
            if stats.min_execution_time_ms == 0.0
                || execution_time_ms < stats.min_execution_time_ms
            {
                stats.min_execution_time_ms = execution_time_ms;
            }
if is_slow {
stats.slow_queries += 1;
}
stats.total_rows_returned += rows_returned as u64;
}
        {
            // Append to the bounded in-memory history, evicting the oldest
            // entries once `max_memory_entries` is exceeded.
            let mut entries = self.entries.write().await;
entries.push_back(entry.clone());
while entries.len() > self.config.max_memory_entries {
entries.pop_front();
}
}
if let Some(ref file_path) = self.config.file_path {
if let Err(e) = self.write_to_file(&entry, file_path).await {
error!("Failed to write query log to file: {}", e);
}
}
if is_slow {
info!(
"🐌 Slow query detected: {} ms - {}",
execution_time_ms,
truncate_query(query, 100)
);
} else if !success {
info!(
"❌ Query failed: {} - {}",
truncate_query(query, 100),
error.unwrap_or_default()
);
} else {
debug!(
"✅ Query executed: {} ms - {}",
execution_time_ms,
truncate_query(query, 100)
);
}
}
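
    /// Appends one entry to the log file, either as a JSON object
    /// (`include_results`) or as a one-line human-readable summary.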
    async fn write_to_file(&self, entry: &QueryLogEntry, file_path: &Path) -> crate::Result<()> {
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(file_path)
.await?;
let log_line = if self.config.include_results {
serde_json::to_string(entry).map_err(|e| {
crate::YamlBaseError::Io(std::io::Error::other(format!(
"Failed to serialize log entry: {}",
e
)))
})?
} else {
format!(
"[{}] {} | {:.2}ms | {} rows | {} | {}\n",
entry.timestamp.format("%Y-%m-%d %H:%M:%S"),
if entry.success { "OK" } else { "FAIL" },
entry.execution_time_ms,
entry.rows_returned,
entry.client_addr,
truncate_query(&entry.query, 200)
)
};
        file.write_all(log_line.as_bytes()).await?;
        // Neither branch ends with a newline, so append exactly one per entry.
        file.write_all(b"\n").await?;
Ok(())
}
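
    /// Returns up to `limit` entries, most recent first.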
pub async fn get_recent_queries(&self, limit: usize) -> Vec<QueryLogEntry> {
let entries = self.entries.read().await;
entries.iter().rev().take(limit).cloned().collect()
}
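
    /// Returns up to `limit` entries at or above the slow-query threshold,
    /// slowest first.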
pub async fn get_slow_queries(&self, limit: usize) -> Vec<QueryLogEntry> {
let threshold = self.config.slow_query_threshold_ms.unwrap_or(100.0);
let entries = self.entries.read().await;
let mut slow_queries: Vec<_> = entries
.iter()
.filter(|e| e.execution_time_ms >= threshold)
.cloned()
.collect();
slow_queries.sort_by(|a, b| {
b.execution_time_ms
.partial_cmp(&a.execution_time_ms)
.unwrap_or(std::cmp::Ordering::Equal)
});
slow_queries.into_iter().take(limit).collect()
}
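
    /// Returns up to `limit` failed entries, most recent first.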
pub async fn get_failed_queries(&self, limit: usize) -> Vec<QueryLogEntry> {
let entries = self.entries.read().await;
entries
.iter()
.filter(|e| !e.success)
.rev()
.take(limit)
.cloned()
.collect()
}
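
    /// Returns a snapshot of the aggregate statistics.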
pub async fn get_stats(&self) -> QueryStats {
self.stats.read().await.clone()
}
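
    /// Clears the in-memory history and resets the statistics.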
pub async fn clear(&self) {
self.entries.write().await.clear();
*self.stats.write().await = QueryStats::default();
info!("Query log cleared");
}
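
    /// Writes the entire in-memory history to `path` as pretty-printed JSON.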
    pub async fn export_to_json(&self, path: &Path) -> crate::Result<()> {
let entries = self.entries.read().await;
let json = serde_json::to_string_pretty(&*entries).map_err(|e| {
crate::YamlBaseError::Io(std::io::Error::other(format!(
"Failed to serialize queries: {}",
e
)))
})?;
tokio::fs::write(path, json).await?;
info!("Exported {} queries to {}", entries.len(), path.display());
Ok(())
}
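
    /// Renders a human-readable performance report covering totals, timing,
    /// throughput, the slowest queries, and the most frequent queries.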
pub async fn generate_report(&self) -> String {
        let stats = self.stats.read().await;
        let entries = self.entries.read().await;
        // Guard the divisor so an empty log reports 0.0% rather than NaN.
        let total = stats.total_queries.max(1) as f64;
        let mut report = String::new();
report.push_str("=== Query Performance Report ===\n\n");
report.push_str("📊 Overall Statistics:\n");
report.push_str(&format!(" Total Queries: {}\n", stats.total_queries));
report.push_str(&format!(
" Successful: {} ({:.1}%)\n",
stats.successful_queries,
            (stats.successful_queries as f64 / total) * 100.0
));
report.push_str(&format!(
" Failed: {} ({:.1}%)\n",
stats.failed_queries,
            (stats.failed_queries as f64 / total) * 100.0
));
report.push('\n');
report.push_str("⏱️ Performance Metrics:\n");
report.push_str(&format!(
" Average Time: {:.2} ms\n",
stats.avg_execution_time_ms
));
report.push_str(&format!(
" Min Time: {:.2} ms\n",
stats.min_execution_time_ms
));
report.push_str(&format!(
" Max Time: {:.2} ms\n",
stats.max_execution_time_ms
));
report.push_str(&format!(
" Slow Queries: {} ({:.1}%)\n",
stats.slow_queries,
            (stats.slow_queries as f64 / total) * 100.0
));
report.push('\n');
report.push_str("📈 Throughput:\n");
report.push_str(&format!(" Total Rows: {}\n", stats.total_rows_returned));
report.push_str(&format!(
" Avg Rows/Query: {:.1}\n",
            stats.total_rows_returned as f64 / total
));
if stats.slow_queries > 0 {
report.push_str("\n🐌 Top 5 Slowest Queries:\n");
let mut slow_queries: Vec<_> = entries
.iter()
.filter(|e| {
e.execution_time_ms >= self.config.slow_query_threshold_ms.unwrap_or(100.0)
})
.collect();
slow_queries.sort_by(|a, b| {
b.execution_time_ms
.partial_cmp(&a.execution_time_ms)
.unwrap_or(std::cmp::Ordering::Equal)
});
for (i, entry) in slow_queries.iter().take(5).enumerate() {
report.push_str(&format!(
" {}. {:.2} ms - {}\n",
i + 1,
entry.execution_time_ms,
truncate_query(&entry.query, 60)
));
}
}
let mut query_counts = std::collections::HashMap::new();
for entry in entries.iter() {
*query_counts
.entry(normalize_query(&entry.query))
.or_insert(0) += 1;
}
if !query_counts.is_empty() {
let mut freq_queries: Vec<_> = query_counts.iter().collect();
freq_queries.sort_by(|a, b| b.1.cmp(a.1));
report.push_str("\n🔥 Top 5 Most Frequent Queries:\n");
for (i, (query, count)) in freq_queries.iter().take(5).enumerate() {
report.push_str(&format!(
" {}. {} times - {}\n",
i + 1,
count,
truncate_query(query, 60)
));
}
}
report
}
}
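
/// Collapses newlines and tabs to spaces and truncates to `max_len`
/// characters, appending `...` when the query was cut.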
fn truncate_query(query: &str, max_len: usize) -> String {
    let normalized = query.replace(['\n', '\t'], " ");
    if normalized.chars().count() <= max_len {
        normalized
    } else {
        // Cut on a char boundary: byte slicing would panic on multi-byte UTF-8.
        let truncate_at = max_len.saturating_sub(3);
        let truncated: String = normalized.chars().take(truncate_at).collect();
        format!("{truncated}...")
    }
}
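
/// Normalizes a query for frequency grouping: lowercased, with whitespace
/// collapsed to single spaces.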
fn normalize_query(query: &str) -> String {
query
.to_lowercase()
.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_query_logging() {
let config = QueryLogConfig {
enabled: true,
..Default::default()
};
let logger = QueryLogger::new(config);
logger
.log_query(
"SELECT * FROM users",
Duration::from_millis(10),
5,
true,
None,
"127.0.0.1:12345".to_string(),
"postgresql",
)
.await;
logger
.log_query(
"SELECT * FROM products WHERE price > 100",
Duration::from_millis(150),
20,
true,
None,
"127.0.0.1:12346".to_string(),
"postgresql",
)
.await;
logger
.log_query(
"SELECT * FROM invalid_table",
Duration::from_millis(5),
0,
false,
Some("Table not found".to_string()),
"127.0.0.1:12347".to_string(),
"postgresql",
)
.await;
let stats = logger.get_stats().await;
assert_eq!(stats.total_queries, 3);
assert_eq!(stats.successful_queries, 2);
assert_eq!(stats.failed_queries, 1);
assert_eq!(stats.slow_queries, 1);
let recent = logger.get_recent_queries(10).await;
assert_eq!(recent.len(), 3);
let slow = logger.get_slow_queries(10).await;
assert_eq!(slow.len(), 1);
assert!(slow[0].query.contains("products"));
let failed = logger.get_failed_queries(10).await;
assert_eq!(failed.len(), 1);
assert!(failed[0].query.contains("invalid_table"));
}
#[test]
fn test_query_truncation() {
let query = "SELECT very_long_column_name_1, very_long_column_name_2, very_long_column_name_3 FROM table";
let truncated = truncate_query(query, 30);
assert_eq!(truncated, "SELECT very_long_column_nam...");
}
#[test]
fn test_query_normalization() {
let query1 = "SELECT *\n FROM\t users WHERE id = 1";
let query2 = "select * from users where id = 1";
assert_eq!(normalize_query(query1), normalize_query(query2));
}
}