use pg_query::protobuf::node;
use crate::locks::DmlKind;
use crate::parse::format_relation;
#[cfg(feature = "catalog")]
use {
anyhow::Result,
crate::workload::{QueryFamily, TransactionBaseline, WorkloadProfile, make_label},
};
/// Collects a [`WorkloadProfile`] snapshot for the connected database.
///
/// Combines per-query statistics read from `pg_stat_statements` (via
/// `fetch_query_families`) with a session/transaction baseline read from
/// `pg_stat_activity` (via `fetch_transaction_baseline`). `pg_version_num` is the
/// numeric server version; when it is at least `140000` the age of the
/// `pg_stat_statements` statistics window is also recorded. A failure to read that
/// window is tolerated and reported as `None`.
///
/// # Errors
/// Propagates failures from the `pg_stat_statements` and `pg_stat_activity` queries.
#[cfg(feature = "catalog")]
pub async fn fetch_workload(
    client: &tokio_postgres::Client,
    pg_version_num: i32,
) -> Result<WorkloadProfile> {
    const STATS_WINDOW_SQL: &str =
        "SELECT EXTRACT(EPOCH FROM (now() - stats_reset))::float8 FROM pg_stat_statements_info";

    let (query_families, stats_reset) = fetch_query_families(client, pg_version_num).await?;
    let transaction_baseline = fetch_transaction_baseline(client).await?;
    // A family whose SQL produced no table list is counted as unparseable.
    let unparseable_queries = query_families
        .iter()
        .filter(|family| family.tables.is_empty())
        .count();

    let mut stats_window_seconds: Option<f64> = None;
    if pg_version_num >= 140000 {
        // Best effort: a failed query or a missing row leaves the window unknown.
        if let Ok(Some(row)) = client.query_opt(STATS_WINDOW_SQL, &[]).await {
            stats_window_seconds = row.get::<_, Option<f64>>(0);
        }
    }

    Ok(WorkloadProfile {
        query_families,
        transaction_baseline,
        collected_at: chrono_now(),
        stats_reset,
        stats_window_seconds,
        unparseable_queries,
    })
}
/// Returns the current UTC time formatted as `YYYY-MM-DDTHH:MM:SSZ`.
///
/// Computed in-process from the system clock. Earlier versions spawned the external
/// `date` binary: that blocked async callers, was unavailable on some platforms, and
/// silently produced an empty string when `date` exited with an error. Returns
/// `"unknown"` only if the system clock reads earlier than the Unix epoch.
#[cfg(feature = "catalog")]
fn chrono_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => format_utc_timestamp(elapsed.as_secs()),
        Err(_) => "unknown".into(),
    }
}

/// Formats `unix_secs` (seconds since 1970-01-01T00:00:00Z) as `YYYY-MM-DDTHH:MM:SSZ`
/// in the proleptic Gregorian calendar.
///
/// The day-to-date conversion is Howard Hinnant's `civil_from_days` algorithm,
/// restricted to non-negative day counts.
// Only `chrono_now` calls this, and `chrono_now` needs the `catalog` feature. The helper
// is kept unconditional so that its calendar math is always compiled and testable.
#[cfg_attr(not(feature = "catalog"), allow(dead_code))]
fn format_utc_timestamp(unix_secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    let days = unix_secs / SECS_PER_DAY;
    let secs_of_day = unix_secs % SECS_PER_DAY;
    let hour = secs_of_day / 3_600;
    let minute = secs_of_day % 3_600 / 60;
    let second = secs_of_day % 60;

    // Shift the epoch to 0000-03-01 so each leap day falls at the end of a "year".
    let z = days + 719_468;
    let era = z / 146_097; // 400-year Gregorian cycles
    let doe = z % 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = era * 400 + yoe + u64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}
/// Reads the top `pg_stat_statements` entries for the current database and turns each
/// into a [`QueryFamily`].
///
/// Returns `(families, stats_reset)`:
/// - If the `pg_stat_statements` extension is not installed (or the probe for it fails),
///   returns no families and no reset time.
/// - Up to 500 statements are read, ordered by call count. Statements whose text contains
///   `pg_stat` are skipped. So are rows hidden from the current role, which have a NULL
///   `queryid`/`query`.
/// - `stats_reset` (and a per-family calls/sec rate) is only available when
///   `pg_version_num >= 140000`, where `pg_stat_statements_info` is queried. Otherwise the
///   reset time is `None` and the rate defaults to `0.0`.
///
/// # Errors
/// Fails if the `pg_stat_statements` query fails or a row has unexpected column types.
#[cfg(feature = "catalog")]
async fn fetch_query_families(
    client: &tokio_postgres::Client,
    pg_version_num: i32,
) -> Result<(Vec<QueryFamily>, Option<String>)> {
    // Best effort: a failed probe is treated as "extension not installed".
    let has_extension: bool = client
        .query_one(
            "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'pg_stat_statements')",
            &[],
        )
        .await
        .map(|row| row.get(0))
        .unwrap_or(false);
    if !has_extension {
        return Ok((vec![], None));
    }

    let stats_reset: Option<String> = if pg_version_num >= 140000 {
        client
            .query_opt("SELECT stats_reset::text FROM pg_stat_statements_info", &[])
            .await
            .ok()
            .flatten()
            .and_then(|row| row.get(0))
    } else {
        None
    };

    let sql = pg_stat_statements_sql(pg_version_num);
    let rows = client
        .query(sql.as_str(), &[])
        .await
        .map_err(|e| anyhow::anyhow!("pg_stat_statements query failed: {e}"))?;

    let mut families = Vec::with_capacity(rows.len());
    for row in &rows {
        // `try_get` turns a type mismatch into an error rather than a panic.
        let queryid: i64 = row.try_get(0)?;
        let query: String = row.try_get(1)?;
        let mean_exec_ms: f64 = row.try_get(3)?;
        let stddev_exec_ms: f64 = row.try_get(4)?;
        let calls_per_sec: Option<f64> = row.try_get(5)?;
        // p95 estimate under a normal-distribution assumption: mean + 1.645·σ.
        // Left unknown when there is no spread to extrapolate from.
        let p95_exec_ms =
            (stddev_exec_ms > 0.0).then(|| mean_exec_ms + 1.645 * stddev_exec_ms);
        let (tables, dml_kind) = extract_tables_and_dml(&query).unwrap_or_default();
        let lock_mode = dml_kind.lock_mode();
        let label = make_label(&query);
        families.push(QueryFamily {
            queryid,
            normalised_sql: query,
            label,
            tables,
            dml_kind,
            lock_mode,
            calls_per_sec: calls_per_sec.unwrap_or(0.0),
            mean_exec_ms,
            p95_exec_ms,
        });
    }
    Ok((families, stats_reset))
}

/// Builds the `pg_stat_statements` query for the given server version.
///
/// Column positions: 0 `queryid`, 1 `query`, 2 `calls`, 3 mean time (ms),
/// 4 stddev time (ms), 5 calls/sec (NULL before PostgreSQL 14).
#[cfg(feature = "catalog")]
fn pg_stat_statements_sql(pg_version_num: i32) -> String {
    // PostgreSQL 13 renamed mean_time/stddev_time to mean_exec_time/stddev_exec_time.
    let (mean_col, stddev_col) = if pg_version_num >= 130000 {
        ("s.mean_exec_time", "s.stddev_exec_time")
    } else {
        ("s.mean_time", "s.stddev_time")
    };
    // pg_stat_statements_info (and therefore a stats window for a rate) arrived in 14.
    let (calls_per_sec, info_join) = if pg_version_num >= 140000 {
        (
            "s.calls::float8 / GREATEST(EXTRACT(EPOCH FROM (now() - i.stats_reset)), 1)",
            "CROSS JOIN pg_stat_statements_info i",
        )
    } else {
        ("NULL::float8", "")
    };
    format!(
        "SELECT
            s.queryid,
            s.query,
            s.calls,
            {mean_col} AS mean_exec_time,
            {stddev_col} AS stddev_exec_time,
            {calls_per_sec} AS calls_per_sec
        FROM pg_stat_statements s
        {info_join}
        WHERE s.queryid IS NOT NULL
          AND s.query IS NOT NULL
          AND s.query NOT LIKE '%pg_stat%'
          AND s.dbid = (SELECT oid FROM pg_database WHERE datname = current_database())
        ORDER BY s.calls DESC
        LIMIT 500"
    )
}
/// Summarises current sessions in `pg_stat_activity`, excluding this connection.
///
/// Reports:
/// - the number of `active` sessions;
/// - the number of `idle in transaction` sessions;
/// - the median, p95 and maximum transaction age in milliseconds, taken over active
///   sessions that have an open transaction. Each age is `0.0` when there is no such
///   session.
///
/// # Errors
/// Fails if the `pg_stat_activity` query fails.
#[cfg(feature = "catalog")]
async fn fetch_transaction_baseline(
    client: &tokio_postgres::Client,
) -> Result<TransactionBaseline> {
    // A single scan of pg_stat_activity feeds every aggregate. Ages use plain timestamp
    // subtraction rather than age(). age() yields a calendar-"symbolic" interval
    // (years/months), which EXTRACT(EPOCH ...) converts using 30-day months, so it is not
    // an exact elapsed time. A NULL xact_start gives a NULL age, which percentile_cont
    // and max ignore.
    let row = client
        .query_one(
            "WITH activity AS (
                SELECT
                    state,
                    EXTRACT(EPOCH FROM (clock_timestamp() - xact_start))::float8 * 1000 AS xact_age_ms
                FROM pg_stat_activity
                WHERE pid != pg_backend_pid()
            )
            SELECT
                count(*) FILTER (WHERE state = 'active'),
                count(*) FILTER (WHERE state = 'idle in transaction'),
                COALESCE(
                    percentile_cont(0.5) WITHIN GROUP (ORDER BY xact_age_ms)
                        FILTER (WHERE state = 'active' AND xact_age_ms IS NOT NULL),
                    0::float8
                ),
                COALESCE(
                    percentile_cont(0.95) WITHIN GROUP (ORDER BY xact_age_ms)
                        FILTER (WHERE state = 'active' AND xact_age_ms IS NOT NULL),
                    0::float8
                ),
                COALESCE(
                    max(xact_age_ms) FILTER (WHERE state = 'active'),
                    0::float8
                )
            FROM activity",
            &[],
        )
        .await
        .map_err(|e| anyhow::anyhow!("pg_stat_activity query failed: {e}"))?;
    Ok(TransactionBaseline {
        active_sessions: row.get(0),
        idle_in_transaction: row.get(1),
        median_age_ms: row.get(2),
        p95_age_ms: row.get(3),
        max_age_ms: row.get(4),
    })
}
/// Extracts the relations touched by a SQL statement together with its DML kind.
///
/// Only the first statement in `sql` is inspected.
/// - `SELECT`: returns the sorted, de-duplicated tables named in its `FROM` clause. For
///   `UNION`/`INTERSECT`/`EXCEPT`, tables from every branch are included. The kind is
///   [`DmlKind::SelectForUpdate`] when any locking clause (`FOR UPDATE`, `FOR SHARE`, …) is
///   present, and [`DmlKind::Select`] otherwise.
/// - `INSERT` / `UPDATE` / `DELETE`: returns only the target relation.
///
/// Returns `None` if `sql` fails to parse, if the statement is of any other kind, or if a
/// `SELECT` names no tables.
///
/// NOTE(review): the `WITH` clause is never inspected. A CTE name used in `FROM` is
/// therefore reported as if it were a table, and tables read inside CTE bodies are
/// missed.
pub fn extract_tables_and_dml(sql: &str) -> Option<(Vec<String>, DmlKind)> {
    let parsed = pg_query::parse(sql).ok()?;
    let stmt = parsed.protobuf.stmts.first()?;
    let wrapper = stmt.stmt.as_ref()?;
    let n = wrapper.node.as_ref()?;
    match n {
        node::Node::SelectStmt(select) => {
            let mut tables = Vec::new();
            collect_select_tables(select, &mut tables);
            tables.sort();
            tables.dedup();
            if tables.is_empty() {
                return None;
            }
            let kind = if select.locking_clause.is_empty() {
                DmlKind::Select
            } else {
                DmlKind::SelectForUpdate
            };
            Some((tables, kind))
        }
        node::Node::InsertStmt(insert) => {
            let table = insert.relation.as_ref().map(format_relation)?;
            Some((vec![table], DmlKind::Insert))
        }
        node::Node::UpdateStmt(update) => {
            let table = update.relation.as_ref().map(format_relation)?;
            Some((vec![table], DmlKind::Update))
        }
        node::Node::DeleteStmt(delete) => {
            let table = delete.relation.as_ref().map(format_relation)?;
            Some((vec![table], DmlKind::Delete))
        }
        _ => None,
    }
}

/// Appends the `FROM`-clause tables of `select` to `tables`, and recurses into both
/// branches of a set operation. For a set operation the top-level `from_clause` is empty
/// and the queried tables live in `larg`/`rarg`.
fn collect_select_tables(select: &pg_query::protobuf::SelectStmt, tables: &mut Vec<String>) {
    tables.extend(extract_from_clause_tables(&select.from_clause));
    for branch in select.larg.iter().chain(select.rarg.iter()) {
        collect_select_tables(branch, tables);
    }
}
/// Collects the relation names referenced by the items of a `FROM` clause.
///
/// Each item is walked with `extract_range_vars`. The result is sorted and contains no
/// duplicates.
fn extract_from_clause_tables(from_clause: &[pg_query::protobuf::Node]) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    from_clause
        .iter()
        .for_each(|item| extract_range_vars(item, &mut found));
    // Equal strings are indistinguishable, so an unstable sort gives the same result.
    found.sort_unstable();
    found.dedup();
    found
}
/// Appends every relation referenced by a single `FROM` item to `tables`.
///
/// - A plain relation (`RangeVar`) contributes its formatted name.
/// - A join (`JoinExpr`) is walked recursively: left side first, then right side.
/// - Subqueries (`RangeSubselect`) and all other item kinds contribute nothing.
fn extract_range_vars(node: &pg_query::protobuf::Node, tables: &mut Vec<String>) {
    match node.node.as_ref() {
        Some(node::Node::RangeVar(range_var)) => tables.push(format_relation(range_var)),
        Some(node::Node::JoinExpr(join)) => {
            for side in join.larg.iter().chain(join.rarg.iter()) {
                extract_range_vars(side, tables);
            }
        }
        // Subqueries in FROM are deliberately not descended into.
        Some(node::Node::RangeSubselect(_)) => {}
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `sql` yields exactly `expected_tables` (in order) and `expected_kind`.
    fn assert_extracts(sql: &str, expected_tables: &[&str], expected_kind: DmlKind) {
        let (tables, kind) = extract_tables_and_dml(sql)
            .unwrap_or_else(|| panic!("no tables/DML extracted from {sql:?}"));
        assert_eq!(tables, expected_tables);
        assert_eq!(kind, expected_kind);
    }

    #[test]
    fn extracts_select_table() {
        assert_extracts("SELECT * FROM orders WHERE id = $1", &["orders"], DmlKind::Select);
    }

    #[test]
    fn extracts_insert_table() {
        assert_extracts(
            "INSERT INTO orders (customer_id, total) VALUES ($1, $2)",
            &["orders"],
            DmlKind::Insert,
        );
    }

    #[test]
    fn extracts_update_table() {
        assert_extracts(
            "UPDATE orders SET status = $1 WHERE id = $2",
            &["orders"],
            DmlKind::Update,
        );
    }

    #[test]
    fn extracts_delete_table() {
        assert_extracts(
            "DELETE FROM orders WHERE created_at < $1",
            &["orders"],
            DmlKind::Delete,
        );
    }

    #[test]
    fn extracts_join_tables() {
        assert_extracts(
            "SELECT o.id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
            &["customers", "orders"],
            DmlKind::Select,
        );
    }

    #[test]
    fn extracts_schema_qualified_table() {
        assert_extracts(
            "SELECT * FROM public.orders WHERE id = $1",
            &["public.orders"],
            DmlKind::Select,
        );
    }

    #[test]
    fn returns_none_for_utility_statements() {
        for sql in ["CREATE TABLE foo (id int)", "ALTER TABLE foo ADD COLUMN bar int"] {
            assert!(extract_tables_and_dml(sql).is_none(), "expected None for {sql:?}");
        }
    }

    #[test]
    fn select_for_update_detected() {
        assert_extracts(
            "SELECT * FROM orders WHERE id = $1 FOR UPDATE",
            &["orders"],
            DmlKind::SelectForUpdate,
        );
    }
}