use pgrx::prelude::*;
use super::state::{DAG_STATE, TRAJECTORY_BUFFER, TrajectoryEntry};
use super::guc;
/// Planner hook: runs PostgreSQL's `standard_planner` and hands back its plan.
///
/// The extension's GUC is consulted, but no planner-side changes are made yet,
/// so the plan is returned unmodified whether the extension is enabled or not.
pub fn planner_hook(
    parse: *mut pg_sys::Query,
    query_string: *const std::os::raw::c_char,
    cursor_options: std::os::raw::c_int,
    bound_params: *mut pg_sys::ParamListInfoData,
) -> *mut pg_sys::PlannedStmt {
    // SAFETY: every argument is forwarded unchanged from this hook's parameters;
    // the caller is assumed to pass the values PostgreSQL supplied to its planner
    // hook, which satisfy `standard_planner`'s contract — TODO confirm at the caller.
    let planned_stmt =
        unsafe { pg_sys::standard_planner(parse, query_string, cursor_options, bound_params) };

    let enabled = guc::is_enabled();
    if !enabled {
        return planned_stmt;
    }

    // Enabled: no planner instrumentation is applied yet, so the standard plan
    // is passed through as-is.
    planned_stmt
}
/// Executor-start hook: records the moment a query begins executing.
///
/// When the extension is enabled, a `TrajectoryEntry` keyed by
/// `compute_query_hash(query_desc)` is inserted into `TRAJECTORY_BUFFER`
/// with the current time and no DAG structure attached yet.
/// `eflags` is accepted for signature compatibility but not consulted.
pub fn executor_start_hook(
    query_desc: *mut pg_sys::QueryDesc,
    eflags: std::os::raw::c_int,
) {
    // Executor flags are not used by this extension (yet).
    let _ = eflags;

    if guc::is_enabled() {
        let key = compute_query_hash(query_desc);
        let entry = TrajectoryEntry {
            query_hash: key,
            start_time: std::time::Instant::now(),
            dag_structure: None,
        };
        TRAJECTORY_BUFFER.insert(key, entry);
    }
}
/// Executor-end hook: closes out the trajectory opened by the start hook.
///
/// When the extension is enabled, the entry keyed by
/// `compute_query_hash(query_desc)` is removed from `TRAJECTORY_BUFFER`; if one
/// was present, the trajectory counter on `DAG_STATE` is incremented.
pub fn executor_end_hook(query_desc: *mut pg_sys::QueryDesc) {
    if guc::is_enabled() {
        let key = compute_query_hash(query_desc);
        let removed = TRAJECTORY_BUFFER.remove(&key);
        if let Some((_, trajectory)) = removed {
            // Elapsed time is measured but not yet recorded anywhere.
            let _elapsed = trajectory.start_time.elapsed();
            DAG_STATE.increment_trajectories();
        }
    }
}
/// Derives the key that pairs an executor-start event with its executor-end
/// event in `TRAJECTORY_BUFFER`.
///
/// The key mixes the address of the `QueryDesc` with its source text (when
/// present). The previous implementation hashed a constant, so every
/// execution shared one key. Concurrent or nested executions then overwrote
/// each other's trajectory entries.
///
/// NOTE(review): assumes the start and end hooks receive the same `QueryDesc`
/// pointer for one execution (as PostgreSQL's ExecutorStart/ExecutorEnd do) —
/// confirm against the hook registration code.
fn compute_query_hash(query_desc: *mut pg_sys::QueryDesc) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::ffi::CStr;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();

    // The descriptor address distinguishes executions that are live at the
    // same time, e.g. a nested query started while an outer one still runs.
    (query_desc as usize).hash(&mut hasher);

    if !query_desc.is_null() {
        // SAFETY: `query_desc` is non-null and is assumed to point to a live
        // QueryDesc for the duration of the hook call — TODO confirm callers
        // never pass a dangling pointer.
        let source_text = unsafe { (*query_desc).sourceText };
        if !source_text.is_null() {
            // SAFETY: the query source is assumed to be a NUL-terminated C
            // string that stays valid while the QueryDesc is alive — TODO confirm.
            unsafe { CStr::from_ptr(source_text) }
                .to_bytes()
                .hash(&mut hasher);
        }
    }

    hasher.finish()
}