use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilerConfig {
pub enabled: bool,
pub max_depth: usize,
pub min_duration_us: u64,
pub include_memory: bool,
}
impl Default for ProfilerConfig {
fn default() -> Self {
Self {
enabled: true,
max_depth: 10,
min_duration_us: 10,
include_memory: false,
}
}
}
pub struct QueryProfiler {
config: ProfilerConfig,
queries: Vec<QueryProfile>,
current_query: Option<QueryProfile>,
stage_stack: Vec<(String, Instant)>,
}
impl QueryProfiler {
pub fn new(config: ProfilerConfig) -> Self {
Self {
config,
queries: Vec::new(),
current_query: None,
stage_stack: Vec::new(),
}
}
pub fn start_query(&mut self, name: impl Into<String>) {
if !self.config.enabled {
return;
}
self.current_query = Some(QueryProfile {
name: name.into(),
start_time: Instant::now(),
end_time: None,
stages: Vec::new(),
total_duration_ms: 0.0,
});
self.stage_stack.clear();
}
pub fn start_stage(&mut self, name: impl Into<String>) {
if !self.config.enabled || self.current_query.is_none() {
return;
}
if self.stage_stack.len() >= self.config.max_depth {
return;
}
self.stage_stack.push((name.into(), Instant::now()));
}
pub fn end_stage(&mut self) {
if !self.config.enabled || self.current_query.is_none() {
return;
}
if let Some((name, start_time)) = self.stage_stack.pop() {
let duration = start_time.elapsed();
let duration_us = duration.as_micros() as u64;
if duration_us < self.config.min_duration_us {
return;
}
let depth = self.stage_stack.len();
let stage = ProfileStage {
name,
depth,
duration_ms: duration.as_secs_f64() * 1000.0,
start_offset_ms: 0.0, };
if let Some(query) = &mut self.current_query {
query.stages.push(stage);
}
}
}
pub fn end_query(&mut self) {
if !self.config.enabled {
return;
}
if let Some(mut query) = self.current_query.take() {
if let Some(end_time) = query.end_time {
query.total_duration_ms =
end_time.duration_since(query.start_time).as_secs_f64() * 1000.0;
} else {
let end_time = Instant::now();
query.total_duration_ms =
end_time.duration_since(query.start_time).as_secs_f64() * 1000.0;
query.end_time = Some(end_time);
}
self.queries.push(query);
}
self.stage_stack.clear();
}
pub fn queries(&self) -> &[QueryProfile] {
&self.queries
}
pub fn find_bottlenecks(&self, threshold: f64) -> Vec<&ProfileStage> {
let mut bottlenecks = Vec::new();
for query in &self.queries {
let threshold_ms = query.total_duration_ms * threshold;
for stage in &query.stages {
if stage.duration_ms >= threshold_ms {
bottlenecks.push(stage);
}
}
}
bottlenecks.sort_by(|a, b| b.duration_ms.partial_cmp(&a.duration_ms).unwrap());
bottlenecks
}
pub fn generate_flame_graph(&self) -> Result<String> {
if self.queries.is_empty() {
return Err(anyhow!("No queries to profile"));
}
let mut svg = String::new();
svg.push_str(r#"<?xml version="1.0" standalone="no"?>"#);
svg.push_str("\n");
svg.push_str(r#"<svg width="1200" height="600" xmlns="http://www.w3.org/2000/svg">"#);
svg.push_str("\n");
svg.push_str(r#"<text x="600" y="30" text-anchor="middle" font-size="20" font-family="Arial">Query Flame Graph</text>"#);
svg.push_str("\n");
let mut y = 60;
for (i, query) in self.queries.iter().enumerate() {
svg.push_str(&format!(
r#"<text x="10" y="{}" font-size="14" font-family="Arial">{}: {:.2}ms</text>"#,
y, query.name, query.total_duration_ms
));
svg.push_str("\n");
y += 20;
for stage in &query.stages {
let width = (stage.duration_ms / query.total_duration_ms * 1000.0) as i32;
let x = 100 + (stage.depth * 50);
svg.push_str(&format!(
r#"<rect x="{}" y="{}" width="{}" height="15" fill="rgb(200, {}, 100)" stroke="black"/>"#,
x,
y,
width.max(1),
(255 - stage.depth * 30).min(255)
));
svg.push_str("\n");
svg.push_str(&format!(
r#"<text x="{}" y="{}" font-size="10" font-family="Arial">{} ({:.2}ms)</text>"#,
x + 5,
y + 12,
stage.name,
stage.duration_ms
));
svg.push_str("\n");
y += 20;
}
y += 30;
if i >= 10 {
break; }
}
svg.push_str("</svg>");
Ok(svg)
}
pub fn export_json(&self) -> Result<String> {
Ok(serde_json::to_string_pretty(&self.queries)?)
}
pub fn summary(&self) -> ProfileSummary {
let total_queries = self.queries.len();
let total_time_ms: f64 = self.queries.iter().map(|q| q.total_duration_ms).sum();
let avg_time_ms = if total_queries > 0 {
total_time_ms / total_queries as f64
} else {
0.0
};
let max_time_ms = self
.queries
.iter()
.map(|q| q.total_duration_ms)
.fold(0.0f64, f64::max);
let min_time_ms = self
.queries
.iter()
.map(|q| q.total_duration_ms)
.fold(f64::MAX, f64::min);
ProfileSummary {
total_queries,
total_time_ms,
avg_time_ms,
min_time_ms: if min_time_ms == f64::MAX {
0.0
} else {
min_time_ms
},
max_time_ms,
}
}
pub fn clear(&mut self) {
self.queries.clear();
self.current_query = None;
self.stage_stack.clear();
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProfile {
pub name: String,
#[serde(skip, default = "Instant::now")]
start_time: Instant,
#[serde(skip)]
end_time: Option<Instant>,
pub total_duration_ms: f64,
pub stages: Vec<ProfileStage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfileStage {
pub name: String,
pub depth: usize,
pub duration_ms: f64,
pub start_offset_ms: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfileSummary {
pub total_queries: usize,
pub total_time_ms: f64,
pub avg_time_ms: f64,
pub min_time_ms: f64,
pub max_time_ms: f64,
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_profiler_config_default() {
let config = ProfilerConfig::default();
assert!(config.enabled);
assert_eq!(config.max_depth, 10);
}
#[test]
fn test_basic_profiling() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("test_query");
thread::sleep(Duration::from_millis(10));
profiler.end_query();
assert_eq!(profiler.queries().len(), 1);
assert!(profiler.queries()[0].total_duration_ms >= 10.0);
}
#[test]
fn test_stage_profiling() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("test_query");
profiler.start_stage("stage1");
thread::sleep(Duration::from_millis(5));
profiler.end_stage();
profiler.start_stage("stage2");
thread::sleep(Duration::from_millis(5));
profiler.end_stage();
profiler.end_query();
assert_eq!(profiler.queries()[0].stages.len(), 2);
assert!(profiler.queries()[0].stages[0].duration_ms >= 5.0);
}
#[test]
fn test_nested_stages() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("nested_query");
profiler.start_stage("outer");
profiler.start_stage("inner");
thread::sleep(Duration::from_millis(5));
profiler.end_stage(); profiler.end_stage(); profiler.end_query();
let query = &profiler.queries()[0];
assert_eq!(query.stages.len(), 2);
assert_eq!(query.stages[0].depth, 1); assert_eq!(query.stages[1].depth, 0); }
#[test]
fn test_disabled_profiler() {
let mut config = ProfilerConfig::default();
config.enabled = false;
let mut profiler = QueryProfiler::new(config);
profiler.start_query("disabled_query");
profiler.start_stage("stage1");
profiler.end_stage();
profiler.end_query();
assert_eq!(profiler.queries().len(), 0);
}
#[test]
fn test_find_bottlenecks() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("bottleneck_query");
profiler.start_stage("fast_stage");
thread::sleep(Duration::from_millis(1));
profiler.end_stage();
profiler.start_stage("slow_stage");
thread::sleep(Duration::from_millis(50));
profiler.end_stage();
profiler.end_query();
let bottlenecks = profiler.find_bottlenecks(0.5); assert!(!bottlenecks.is_empty());
assert_eq!(bottlenecks[0].name, "slow_stage");
}
#[test]
fn test_summary_statistics() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
for i in 0..3 {
profiler.start_query(format!("query_{}", i));
thread::sleep(Duration::from_millis(10));
profiler.end_query();
}
let summary = profiler.summary();
assert_eq!(summary.total_queries, 3);
assert!(summary.avg_time_ms >= 10.0);
assert!(summary.max_time_ms >= summary.min_time_ms);
}
#[test]
fn test_flame_graph_generation() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("flame_query");
profiler.start_stage("stage1");
thread::sleep(Duration::from_millis(5));
profiler.end_stage();
profiler.end_query();
let svg = profiler.generate_flame_graph();
assert!(svg.is_ok());
assert!(svg.unwrap().contains("<svg"));
}
#[test]
fn test_json_export() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("json_query");
profiler.start_stage("stage1");
thread::sleep(Duration::from_millis(5));
profiler.end_stage();
profiler.end_query();
let json = profiler.export_json();
assert!(json.is_ok());
assert!(json.unwrap().contains("json_query"));
}
#[test]
fn test_clear() {
let mut profiler = QueryProfiler::new(ProfilerConfig::default());
profiler.start_query("test_query");
profiler.end_query();
assert_eq!(profiler.queries().len(), 1);
profiler.clear();
assert_eq!(profiler.queries().len(), 0);
}
#[test]
fn test_min_duration_filter() {
let mut config = ProfilerConfig::default();
config.min_duration_us = 1_000_000;
let mut profiler = QueryProfiler::new(config);
profiler.start_query("filter_query");
profiler.start_stage("fast_stage");
thread::sleep(Duration::from_micros(100)); profiler.end_stage();
profiler.end_query();
assert_eq!(profiler.queries()[0].stages.len(), 0);
}
#[test]
fn test_max_depth_limit() {
let mut config = ProfilerConfig::default();
config.max_depth = 2;
let mut profiler = QueryProfiler::new(config);
profiler.start_query("depth_query");
profiler.start_stage("level1");
profiler.start_stage("level2");
profiler.start_stage("level3"); thread::sleep(Duration::from_millis(5));
profiler.end_stage();
profiler.end_stage();
profiler.end_stage();
profiler.end_query();
assert_eq!(profiler.queries()[0].stages.len(), 2);
}
#[test]
fn test_empty_profiler_flame_graph() {
let profiler = QueryProfiler::new(ProfilerConfig::default());
let result = profiler.generate_flame_graph();
assert!(result.is_err());
}
}