use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Configuration for the EXPLAIN-based query cost guard.
#[derive(Debug, Clone)]
pub struct ExplainConfig {
// Operating mode of the guard (off / precheck / enforce).
pub mode: ExplainMode,
// NOTE(review): presumably the expand/JOIN depth at which an EXPLAIN
// precheck is triggered — confirm at the call site.
pub depth_threshold: usize,
// Maximum planner "Total Cost" allowed before a query is rejected
// (see `check_estimate`).
pub max_total_cost: f64,
// Maximum planner "Plan Rows" estimate allowed before rejection
// (see `check_estimate`).
pub max_plan_rows: u64,
// Time-to-live for cached EXPLAIN estimates.
pub cache_ttl: Duration,
}
impl Default for ExplainConfig {
fn default() -> Self {
Self {
mode: ExplainMode::Precheck,
depth_threshold: 3,
max_total_cost: 100_000.0,
max_plan_rows: 1_000_000,
cache_ttl: Duration::from_secs(300),
}
}
}
/// Operating mode of the cost guard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
// No EXPLAIN checking at all.
Off,
// Run an EXPLAIN check before executing (the default mode).
Precheck,
// NOTE(review): how Enforce differs from Precheck is not visible in
// this file — confirm semantics at the call site.
Enforce,
}
/// Planner estimates extracted from `EXPLAIN (FORMAT JSON)` output
/// (see `parse_explain_json`).
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
// The plan's "Total Cost" value (planner cost units).
pub total_cost: f64,
// The plan's "Plan Rows" estimate.
pub plan_rows: u64,
}
/// A cached estimate plus the bookkeeping needed for TTL and drift checks.
struct CachedEstimate {
estimate: ExplainEstimate,
// Insertion time; compared against the cache TTL on reads and prunes.
cached_at: Instant,
// Duplicate of `estimate.plan_rows`, used by the drift check in
// `ExplainCache::get`. Redundant with the field inside `estimate`;
// kept for direct access.
plan_rows: u64,
}
/// Thread-safe, TTL-bounded cache of EXPLAIN estimates keyed by a
/// query-shape hash.
pub struct ExplainCache {
// Shape-hash -> cached estimate, guarded by a mutex for shared access.
entries: Mutex<HashMap<u64, CachedEstimate>>,
// Time-to-live applied to every entry.
ttl: Duration,
// Hard cap on stored entries; inserts beyond it are dropped.
max_entries: usize,
}
impl ExplainCache {
pub fn new(ttl: Duration) -> Self {
Self {
entries: Mutex::new(HashMap::new()),
ttl,
max_entries: 10_000,
}
}
pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
let entries = self.entries.lock().ok()?;
let entry = entries.get(&shape_hash)?;
if entry.cached_at.elapsed() < self.ttl {
if let Some(current) = current_reltuples
&& entry.plan_rows > 0
{
let cached = entry.plan_rows as f64;
let drift = ((current as f64) - cached).abs() / cached;
let abs_delta = (current as i64 - entry.plan_rows as i64).unsigned_abs();
if drift > 0.5 && abs_delta > 10_000 {
return None; }
}
Some(entry.estimate.clone())
} else {
None
}
}
pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
if let Ok(mut entries) = self.entries.lock() {
if entries.len() >= self.max_entries / 2 {
let ttl = self.ttl;
entries.retain(|_, v| v.cached_at.elapsed() < ttl);
}
if entries.len() >= self.max_entries {
return;
}
entries.insert(
shape_hash,
CachedEstimate {
plan_rows: estimate.plan_rows,
estimate,
cached_at: Instant::now(),
},
);
}
}
pub fn len(&self) -> usize {
self.entries.lock().map(|e| e.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Parse PostgreSQL `EXPLAIN (FORMAT JSON)` output just far enough to pull
/// out the planner's total cost and row estimate.
///
/// Uses a lightweight textual scan rather than a full JSON parser; returns
/// `None` when either field cannot be located.
pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
    let total_cost = extract_json_number(json_str, "Total Cost")?;
    let plan_rows = extract_json_number(json_str, "Plan Rows")?;
    Some(ExplainEstimate {
        total_cost,
        // Negative or NaN row estimates saturate to 0 under `as u64`.
        plan_rows: plan_rows as u64,
    })
}
/// Extract the first numeric value for `key` from a JSON string using a
/// lightweight textual scan (no JSON parser dependency).
///
/// Looks for `"key":`, skips whitespace, takes the longest run of number
/// characters (digits, `.`, `-`, `+`, `e`/`E`), and parses it as `f64`.
/// Returns `None` when the key is absent or the value is not a number.
///
/// NOTE(review): assumes no whitespace between the key and the colon,
/// which matches PostgreSQL's EXPLAIN (FORMAT JSON) output — confirm if
/// other producers are ever fed in.
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let pattern = format!("\"{}\":", key);
    let start = json.find(&pattern)?;
    let after_key = &json[start + pattern.len()..];
    let trimmed = after_key.trim_start();
    // Scan to the first non-number character. If the input ends mid-number
    // (no trailing delimiter), consume to the end of the string instead of
    // failing — the previous `?` here rejected such trailing values.
    let end = trimmed
        .find(|c: char| {
            !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
        })
        .unwrap_or(trimmed.len());
    let num_str = &trimmed[..end];
    num_str.parse::<f64>().ok()
}
/// Outcome of checking a query's planner estimates against configured limits.
#[derive(Debug)]
pub enum ExplainDecision {
// Estimates are within limits; the query may run.
Allow,
// An estimate exceeded a limit; carries the offending values and limits.
Reject {
total_cost: f64,
plan_rows: u64,
max_cost: f64,
max_rows: u64,
},
// NOTE(review): no check was performed — presumably when the mode is
// Off or the query is below the depth threshold; confirm at call site.
Skipped,
}
impl ExplainDecision {
    /// True when this decision is a `Reject`.
    pub fn is_rejected(&self) -> bool {
        matches!(self, ExplainDecision::Reject { .. })
    }

    /// Human-readable rejection message, or `None` for `Allow`/`Skipped`.
    pub fn rejection_message(&self) -> Option<String> {
        let ExplainDecision::Reject {
            total_cost,
            plan_rows,
            max_cost,
            max_rows,
        } = self
        else {
            return None;
        };
        Some(format!(
            "Query rejected: estimated cost {:.0} exceeds limit {:.0}, \
            or estimated rows {} exceeds limit {}. \
            Try narrowing your filters, reducing ?expand depth, or using pagination.",
            total_cost, max_cost, plan_rows, max_rows
        ))
    }

    /// Structured rejection details (values, limits, and remediation
    /// suggestions), or `None` for `Allow`/`Skipped`.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        let ExplainDecision::Reject {
            total_cost,
            plan_rows,
            max_cost,
            max_rows,
        } = self
        else {
            return None;
        };
        Some(ExplainRejectionDetail {
            estimated_cost: *total_cost,
            cost_limit: *max_cost,
            estimated_rows: *plan_rows,
            row_limit: *max_rows,
            suggestions: vec![
                "Add WHERE clauses to narrow the result set".to_string(),
                "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
                "Use ?limit and ?offset for pagination".to_string(),
                "Add indexes on frequently filtered columns".to_string(),
            ],
        })
    }
}
/// Structured details about a rejected query, suitable for error responses.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
// The planner cost estimate that triggered (or accompanied) rejection.
pub estimated_cost: f64,
// The configured cost ceiling.
pub cost_limit: f64,
// The planner row estimate.
pub estimated_rows: u64,
// The configured row ceiling.
pub row_limit: u64,
// Human-readable remediation suggestions for the caller.
pub suggestions: Vec<String>,
}
/// Compare a planner estimate against the configured ceilings.
///
/// Rejects when EITHER the total cost or the row estimate exceeds its
/// limit; otherwise allows the query. `Skipped` is never produced here.
pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
    let cost_exceeded = estimate.total_cost > config.max_total_cost;
    let rows_exceeded = estimate.plan_rows > config.max_plan_rows;
    if !cost_exceeded && !rows_exceeded {
        return ExplainDecision::Allow;
    }
    ExplainDecision::Reject {
        total_cost: estimate.total_cost,
        plan_rows: estimate.plan_rows,
        max_cost: config.max_total_cost,
        max_rows: config.max_plan_rows,
    }
}
// Unit tests for JSON extraction, estimate checking, and the cache's TTL
// and drift-invalidation behavior.
#[cfg(test)]
mod tests {
use super::*;
// A flat Seq Scan plan parses into its cost and row estimate.
#[test]
fn test_parse_explain_json() {
let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
let est = parse_explain_json(json).unwrap();
assert!((est.total_cost - 1234.56).abs() < 0.01);
assert_eq!(est.plan_rows, 5000);
}
// A join plan with large estimates parses the same way.
#[test]
fn test_parse_explain_json_nested_join() {
let json = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
let est = parse_explain_json(json).unwrap();
assert!((est.total_cost - 250000.0).abs() < 0.01);
assert_eq!(est.plan_rows, 2_000_000);
}
// Inputs lacking the expected keys yield None rather than panicking.
#[test]
fn test_parse_explain_json_invalid() {
assert!(parse_explain_json("not json").is_none());
assert!(parse_explain_json("{}").is_none());
assert!(parse_explain_json("[]").is_none());
}
// Estimates under both limits are allowed.
#[test]
fn test_check_estimate_allow() {
let config = ExplainConfig::default();
let est = ExplainEstimate {
total_cost: 100.0,
plan_rows: 500,
};
let decision = check_estimate(&est, &config);
assert!(!decision.is_rejected());
}
// Exceeding the cost limit alone rejects, and the message carries the cost.
#[test]
fn test_check_estimate_reject_cost() {
let config = ExplainConfig::default();
let est = ExplainEstimate {
total_cost: 200_000.0,
plan_rows: 500,
};
let decision = check_estimate(&est, &config);
assert!(decision.is_rejected());
assert!(decision.rejection_message().unwrap().contains("200000"));
}
// Exceeding the row limit alone also rejects.
#[test]
fn test_check_estimate_reject_rows() {
let config = ExplainConfig::default();
let est = ExplainEstimate {
total_cost: 50.0,
plan_rows: 5_000_000,
};
let decision = check_estimate(&est, &config);
assert!(decision.is_rejected());
}
// Basic insert/get round-trip; misses return None.
#[test]
fn test_cache_basic() {
let cache = ExplainCache::new(Duration::from_secs(60));
assert!(cache.is_empty());
cache.insert(
42,
ExplainEstimate {
total_cost: 100.0,
plan_rows: 50,
},
);
assert_eq!(cache.len(), 1);
let cached = cache.get(42, None).unwrap();
assert!((cached.total_cost - 100.0).abs() < 0.01);
assert_eq!(cached.plan_rows, 50);
assert!(cache.get(99, None).is_none());
}
// Entries older than the TTL are treated as misses.
#[test]
fn test_cache_expiry() {
let cache = ExplainCache::new(Duration::from_millis(1));
cache.insert(
1,
ExplainEstimate {
total_cost: 100.0,
plan_rows: 50,
},
);
std::thread::sleep(Duration::from_millis(5));
assert!(cache.get(1, None).is_none());
}
// Drift invalidation requires BOTH >50% relative and >10_000 absolute
// change; zero-row baselines are exempt.
#[test]
fn test_cache_drift_invalidation() {
let cache = ExplainCache::new(Duration::from_secs(60));
cache.insert(
1,
ExplainEstimate {
total_cost: 50.0,
plan_rows: 1000,
},
);
assert!(cache.get(1, None).is_some());
assert!(cache.get(1, Some(1000)).is_some());
// 60% drift but only 600 rows absolute — below the absolute threshold.
assert!(
cache.get(1, Some(1600)).is_some(),
"small table should not thrash"
);
assert!(
cache.get(1, Some(400)).is_some(),
"small shrinkage should not thrash"
);
cache.insert(
3,
ExplainEstimate {
total_cost: 500.0,
plan_rows: 50_000,
},
);
// 70% drift and 35_000 rows absolute — both thresholds exceeded.
assert!(
cache.get(3, Some(85_000)).is_none(),
"large drift should invalidate"
);
// 40% drift — below the relative threshold.
assert!(
cache.get(3, Some(70_000)).is_some(),
"moderate drift should not invalidate"
);
// 60% shrinkage and 30_000 rows absolute — both thresholds exceeded.
assert!(
cache.get(3, Some(20_000)).is_none(),
"large shrinkage should invalidate"
);
cache.insert(
2,
ExplainEstimate {
total_cost: 10.0,
plan_rows: 0,
},
);
// Zero cached plan_rows: drift check is skipped entirely.
assert!(cache.get(2, Some(999_999)).is_some());
}
// Default config matches the documented conservative values.
#[test]
fn test_explain_mode_default() {
let config = ExplainConfig::default();
assert_eq!(config.mode, ExplainMode::Precheck);
assert_eq!(config.depth_threshold, 3);
assert!((config.max_total_cost - 100_000.0).abs() < 0.01);
assert_eq!(config.max_plan_rows, 1_000_000);
}
}