qail_pg/driver/explain.rs

//! EXPLAIN-based query cost estimation for pre-check rejection.
//!
//! Provides runtime cost-based rejection of queries that would be too
//! expensive, using PostgreSQL's `EXPLAIN (FORMAT JSON)` output.
//!
//! # Modes
//! - **Off**: No EXPLAIN pre-check
//! - **Precheck**: Run EXPLAIN on cache-miss for queries with expand depth ≥ threshold
//! - **Enforce**: Always run EXPLAIN and enforce cost thresholds
//!
//! # Caching
//! EXPLAIN results are cached by `AST_shape_hash + rls_signature` with configurable TTL.
//! This avoids repeated EXPLAIN calls for the same query shape.
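//!
//! # Example
//! A minimal sketch of the pre-check flow; `shape_hash` and `explain_json` stand
//! in for values the driver produces elsewhere (the query's shape hash and the
//! raw `EXPLAIN (FORMAT JSON)` output):
//! ```ignore
//! let config = ExplainConfig::default();
//! let cache = ExplainCache::new(config.cache_ttl);
//!
//! let estimate = match cache.get(shape_hash, None) {
//!     Some(cached) => cached,
//!     None => {
//!         // Cache miss: run `EXPLAIN (FORMAT JSON) <sql>` and parse the output.
//!         let fresh = parse_explain_json(&explain_json).expect("well-formed EXPLAIN JSON");
//!         cache.insert(shape_hash, fresh.clone());
//!         fresh
//!     }
//! };
//!
//! let decision = check_estimate(&estimate, &config);
//! if decision.is_rejected() {
//!     // Surface decision.rejection_message() / rejection_detail() to the client.
//! } else {
//!     // Safe to execute the query.
//! }
//! ```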

use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};

/// Configuration for EXPLAIN pre-check behavior.
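///
/// # Example
/// A sketch of tightening the limits for a cost-sensitive tenant; the specific
/// numbers are illustrative, not recommendations:
/// ```ignore
/// let config = ExplainConfig {
///     mode: ExplainMode::Enforce,
///     max_total_cost: 20_000.0,
///     max_plan_rows: 100_000,
///     ..ExplainConfig::default()
/// };
/// ```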
#[derive(Debug, Clone)]
pub struct ExplainConfig {
    /// Operating mode for EXPLAIN pre-check.
    pub mode: ExplainMode,

    /// Run EXPLAIN for queries with expand_depth >= this value.
    /// Default: 3 (queries joining 3+ tables get pre-checked).
    pub depth_threshold: usize,

    /// Reject if PostgreSQL's estimated total cost exceeds this.
    /// Default: 100,000 (unitless PostgreSQL planner cost).
    pub max_total_cost: f64,

    /// Reject if PostgreSQL estimates more rows than this.
    /// Default: 1,000,000 rows.
    pub max_plan_rows: u64,

    /// TTL for cached EXPLAIN results.
    /// Default: 5 minutes.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    fn default() -> Self {
        Self {
            mode: ExplainMode::Precheck,
            depth_threshold: 3,
            max_total_cost: 100_000.0,
            max_plan_rows: 1_000_000,
            cache_ttl: Duration::from_secs(300),
        }
    }
}

/// Operating mode for EXPLAIN pre-check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// No EXPLAIN pre-check — fastest, no protection.
    Off,
    /// Run EXPLAIN on cache-miss for queries above depth threshold.
    /// Recommended default for production.
    Precheck,
    /// Always run EXPLAIN and enforce — strictest, slight latency cost.
    /// Recommended for staging or high-security tenants.
    Enforce,
}

/// Result of an EXPLAIN pre-check.
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
    /// PostgreSQL's estimated total cost (arbitrary units).
    pub total_cost: f64,
    /// PostgreSQL's estimated number of rows returned.
    pub plan_rows: u64,
}

/// Cached EXPLAIN result with TTL and row-estimate snapshot.
struct CachedEstimate {
    estimate: ExplainEstimate,
    cached_at: Instant,
    /// Row estimate snapshot at cache time, for drift detection.
    plan_rows: u64,
}

/// In-memory cache for EXPLAIN estimates, keyed by AST shape hash.
pub struct ExplainCache {
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    ttl: Duration,
    /// Maximum number of cached entries, to prevent OOM from shape explosion.
    max_entries: usize,
}

impl ExplainCache {
    /// Create a new EXPLAIN cache with the given TTL.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries: 10_000,
        }
    }

    /// Get a cached estimate if it exists, hasn't expired, and the row estimate
    /// hasn't drifted significantly since it was cached.
    ///
    /// `current_reltuples` is the current `pg_class.reltuples` for the primary
    /// table. If provided, and the cached `plan_rows` has drifted from it by more
    /// than 50% *and* by more than 10,000 rows, the entry is treated as stale
    /// (data skew) and a miss is returned.
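    ///
    /// A small sketch of the drift rule (`hash` and the numbers are illustrative):
    /// ```ignore
    /// let cache = ExplainCache::new(Duration::from_secs(300));
    /// cache.insert(hash, ExplainEstimate { total_cost: 500.0, plan_rows: 50_000 });
    ///
    /// cache.get(hash, None);          // TTL check only, no drift check
    /// cache.get(hash, Some(55_000));  // 10% drift: still a hit
    /// cache.get(hash, Some(120_000)); // >50% and >10,000 rows of drift: miss
    /// ```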
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.entries.lock().ok()?;
        let entry = entries.get(&shape_hash)?;
        if entry.cached_at.elapsed() < self.ttl {
            // Row-estimate drift check: invalidate only if BOTH conditions are met:
            // 1. Relative change > 50% (data skew)
            // 2. Absolute delta > 10,000 rows (prevents small-table thrash)
            if let Some(current) = current_reltuples
                && entry.plan_rows > 0
            {
                let cached = entry.plan_rows as f64;
                let drift = ((current as f64) - cached).abs() / cached;
                let abs_delta = (current as i64 - entry.plan_rows as i64).unsigned_abs();
                if drift > 0.5 && abs_delta > 10_000 {
                    return None; // Stale — significant data skew detected
                }
            }
            Some(entry.estimate.clone())
        } else {
            None
        }
    }

    /// Store an estimate in the cache.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        if let Ok(mut entries) = self.entries.lock() {
            // Evict expired entries when approaching capacity
            if entries.len() >= self.max_entries / 2 {
                let ttl = self.ttl;
                entries.retain(|_, v| v.cached_at.elapsed() < ttl);
            }
            // Hard cap: if still at capacity after eviction, skip insert
            if entries.len() >= self.max_entries {
                return;
            }
            entries.insert(shape_hash, CachedEstimate {
                plan_rows: estimate.plan_rows,
                estimate,
                cached_at: Instant::now(),
            });
        }
    }

    /// Number of cached entries (for metrics).
    pub fn len(&self) -> usize {
        self.entries.lock().map(|e| e.len()).unwrap_or(0)
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Parse `EXPLAIN (FORMAT JSON)` output to extract cost estimates.
///
/// Uses lightweight string parsing to avoid adding serde_json as a
/// dependency to the pg driver crate. The EXPLAIN JSON format is stable:
/// ```json
/// [{"Plan": {"Total Cost": 1234.56, "Plan Rows": 5000, ...}}]
/// ```
///
/// Only the first occurrence of each key is read; in EXPLAIN output that is the
/// root plan node (nested sub-plans appear later, under `"Plans"`).
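///
/// A minimal usage sketch (the literal below stands in for real `EXPLAIN` output):
/// ```ignore
/// let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 812.5, "Plan Rows": 4200}}]"#;
/// let est = parse_explain_json(json).expect("parsable plan");
/// assert_eq!(est.plan_rows, 4200);
/// ```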
pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
    let total_cost = extract_json_number(json_str, "Total Cost")?;
    let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;

    Some(ExplainEstimate {
        total_cost,
        plan_rows,
    })
}

/// Extract a numeric value after `"key":` from a JSON string.
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let pattern = format!("\"{}\":", key);
    let start = json.find(&pattern)?;
    let after_key = &json[start + pattern.len()..];

    // Skip whitespace
    let trimmed = after_key.trim_start();

    // Parse the number (integer, float, or scientific notation). If the string
    // ends mid-number, take everything that remains.
    let end = trimmed
        .find(|c: char| !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+')
        .unwrap_or(trimmed.len());
    let num_str = &trimmed[..end];
    num_str.parse::<f64>().ok()
}

/// Decision from the EXPLAIN pre-check.
#[derive(Debug)]
pub enum ExplainDecision {
    /// Query is allowed to proceed.
    Allow,
    /// Query is rejected with an explanation.
    Reject {
        total_cost: f64,
        plan_rows: u64,
        max_cost: f64,
        max_rows: u64,
    },
    /// EXPLAIN was skipped (mode=Off or below depth threshold).
    Skipped,
}

impl ExplainDecision {
    /// Returns true if the query should be rejected.
    pub fn is_rejected(&self) -> bool {
        matches!(self, ExplainDecision::Reject { .. })
    }

    /// Human-readable rejection message for API responses.
    pub fn rejection_message(&self) -> Option<String> {
        match self {
            ExplainDecision::Reject { total_cost, plan_rows, max_cost, max_rows } => {
                Some(format!(
                    "Query rejected: estimated cost {:.0} exceeds limit {:.0}, \
                     or estimated rows {} exceeds limit {}. \
                     Try narrowing your filters, reducing ?expand depth, or using pagination.",
                    total_cost, max_cost, plan_rows, max_rows
                ))
            }
            _ => None,
        }
    }

    /// Machine-readable rejection detail for structured API error responses.
    ///
    /// Returns `None` for `Allow` and `Skipped` decisions.
    /// Client SDKs can use this to programmatically react to cost rejections.
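    ///
    /// A sketch of wiring the detail into a structured error response (`decision`
    /// and the surrounding response layer are assumed to exist in the caller):
    /// ```ignore
    /// if let Some(detail) = decision.rejection_detail() {
    ///     eprintln!(
    ///         "cost guard: estimated cost {} > {} or rows {} > {}",
    ///         detail.estimated_cost, detail.cost_limit,
    ///         detail.estimated_rows, detail.row_limit,
    ///     );
    ///     // Map `detail.suggestions` into the API error body for the client SDK.
    /// }
    /// ```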
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        match self {
            ExplainDecision::Reject { total_cost, plan_rows, max_cost, max_rows } => {
                Some(ExplainRejectionDetail {
                    estimated_cost: *total_cost,
                    cost_limit: *max_cost,
                    estimated_rows: *plan_rows,
                    row_limit: *max_rows,
                    suggestions: vec![
                        "Add WHERE clauses to narrow the result set".to_string(),
                        "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
                        "Use ?limit and ?offset for pagination".to_string(),
                        "Add indexes on frequently filtered columns".to_string(),
                    ],
                })
            }
            _ => None,
        }
    }
}

/// Structured rejection detail for EXPLAIN cost guard violations.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    pub estimated_cost: f64,
    pub cost_limit: f64,
    pub estimated_rows: u64,
    pub row_limit: u64,
    pub suggestions: Vec<String>,
}

/// Check an EXPLAIN estimate against configured thresholds.
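///
/// A short sketch of the decision boundary under default limits (values are
/// illustrative):
/// ```ignore
/// let config = ExplainConfig::default();
/// let cheap = ExplainEstimate { total_cost: 900.0, plan_rows: 1_000 };
/// let pricey = ExplainEstimate { total_cost: 250_000.0, plan_rows: 1_000 };
/// assert!(!check_estimate(&cheap, &config).is_rejected());
/// assert!(check_estimate(&pricey, &config).is_rejected());
/// ```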
pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
    if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
        ExplainDecision::Reject {
            total_cost: estimate.total_cost,
            plan_rows: estimate.plan_rows,
            max_cost: config.max_total_cost,
            max_rows: config.max_plan_rows,
        }
    } else {
        ExplainDecision::Allow
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_explain_json() {
        let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 1234.56).abs() < 0.01);
        assert_eq!(est.plan_rows, 5000);
    }

    #[test]
    fn test_parse_explain_json_nested_join() {
        let json = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 250000.0).abs() < 0.01);
        assert_eq!(est.plan_rows, 2_000_000);
    }

    #[test]
    fn test_parse_explain_json_invalid() {
        assert!(parse_explain_json("not json").is_none());
        assert!(parse_explain_json("{}").is_none());
        assert!(parse_explain_json("[]").is_none());
    }

    #[test]
    fn test_check_estimate_allow() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 100.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(!decision.is_rejected());
    }

    #[test]
    fn test_check_estimate_reject_cost() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 200_000.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
        assert!(decision.rejection_message().unwrap().contains("200000"));
    }

    #[test]
    fn test_check_estimate_reject_rows() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 50.0, plan_rows: 5_000_000 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
    }

    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });
        assert_eq!(cache.len(), 1);

        let cached = cache.get(42, None).unwrap();
        assert!((cached.total_cost - 100.0).abs() < 0.01);
        assert_eq!(cached.plan_rows, 50);

        // Miss for unknown key
        assert!(cache.get(99, None).is_none());
    }

    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });

        // Wait for expiry
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        // ── Small dataset: relative drift alone should NOT invalidate ──
        cache.insert(1, ExplainEstimate { total_cost: 50.0, plan_rows: 1000 });

        // No reltuples — pure TTL, should hit
        assert!(cache.get(1, None).is_some());

        // Same estimate — no drift, should hit
        assert!(cache.get(1, Some(1000)).is_some());

        // 60% relative drift but only 600 absolute — below 10k floor, should STILL hit
        assert!(cache.get(1, Some(1600)).is_some(), "small table should not thrash");

        // 60% shrinkage but only 600 absolute — should STILL hit
        assert!(cache.get(1, Some(400)).is_some(), "small shrinkage should not thrash");

        // ── Large dataset: BOTH relative AND absolute thresholds exceeded ──
        cache.insert(3, ExplainEstimate { total_cost: 500.0, plan_rows: 50_000 });

        // 70% drift + 35k absolute (both above threshold) — should miss
        assert!(cache.get(3, Some(85_000)).is_none(), "large drift should invalidate");

        // 40% drift + 20k absolute (relative below 50%) — should STILL hit
        assert!(cache.get(3, Some(70_000)).is_some(), "moderate drift should not invalidate");

        // 60% shrinkage + 30k absolute (both above threshold) — should miss
        assert!(cache.get(3, Some(20_000)).is_none(), "large shrinkage should invalidate");

        // Edge: plan_rows = 0 in cache — skip drift check entirely
        cache.insert(2, ExplainEstimate { total_cost: 10.0, plan_rows: 0 });
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    #[test]
    fn test_explain_mode_default() {
        let config = ExplainConfig::default();
        assert_eq!(config.mode, ExplainMode::Precheck);
        assert_eq!(config.depth_threshold, 3);
        assert!((config.max_total_cost - 100_000.0).abs() < 0.01);
        assert_eq!(config.max_plan_rows, 1_000_000);
    }
}